A hierarchical clustering and dendrogram example using SciPy and pandas

What do you do when you've got too many variables and don't know which are useful?

Often the answer is a clustering solution, and hierarchical clustering is one of the most popular approaches, especially for gene expression data.

Hierarchical clustering groups similar objects or parameters into clusters. Unlike K-means, it does not require the number of clusters to be fixed in advance. Instead, we can decide on the number of clusters after the fact by choosing where to cut the dendrogram. This is one of the techniques we'll focus on.
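To make "deciding where to cut" concrete, here is a minimal sketch on made-up one-dimensional data (the array X is an assumption for illustration, not part of the alcohol dataset). The tree is built once with linkage, and SciPy's fcluster with criterion='maxclust' then cuts that same tree into however many flat clusters we ask for:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

# Hypothetical 1-D example data: loose groups around 0-3.5, 10-12 and 30
X = np.array([[0.0], [1.0], [3.5], [10.0], [12.0], [30.0]])

# Build the full merge tree once; no number of clusters is needed at this stage
Z = linkage(X, method='complete')

# Afterwards, cut the same tree into 2, 3 or 4 flat clusters
cuts = {k: fcluster(Z, t=k, criterion='maxclust') for k in (2, 3, 4)}
for k, labels in cuts.items():
    print(k, labels)
```

Nothing about the tree changes between the three cuts; only the height at which it is cut does.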

Scikit-learn also has a good hierarchical clustering solution, but we'll focus on SciPy's implementation for now.

SciPy was built to work with NumPy arrays, which carry no row or column labels, so keeping the array's rows and columns aligned with the labels of the pandas DataFrame they came from is key.
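A minimal sketch of one way to keep labels aligned, assuming a small hypothetical DataFrame df: hand SciPy the bare array, then use the row positions in the dendrogram's 'leaves' output to index back into the DataFrame. Passing no_plot=True returns the dendrogram layout without drawing it.

```python
import pandas as pd
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, dendrogram

# Hypothetical toy DataFrame; the rows are the objects we want to cluster
df = pd.DataFrame({'x': [0.0, 10.0, 1.0, 11.5], 'y': [0.0, 10.0, 0.0, 10.0]},
                  index=['a', 'b', 'c', 'd'])

# SciPy only sees the bare NumPy array, so row i of the array is df.index[i]
Z = linkage(pdist(df.to_numpy()), method='complete')

# 'leaves' holds positional row indices in left-to-right leaf order,
# 'ivl' holds the corresponding labels
den = dendrogram(Z, labels=list(df.index), no_plot=True)
ordered = df.iloc[den['leaves']]

print(list(ordered.index), den['ivl'])
```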

First, let's import all the modules we will need.

In [128]:
from collections import defaultdict
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, dendrogram
from matplotlib.colors import rgb2hex, colorConverter
from scipy.cluster.hierarchy import set_link_color_palette
import pandas as pd
import scipy.cluster.hierarchy as sch
%pylab inline
Populating the interactive namespace from numpy and matplotlib

The next two cells are technically optional; I just think they make the dendrogram images look nicer. The first sets the size of figures generated inline by pylab. The second applies Seaborn's "whitegrid" style, which is a good general-purpose choice.

In [129]:
from pylab import rcParams
rcParams['figure.figsize'] = 12, 9
In [130]:
import seaborn as sns
sns.set_style("whitegrid")

We import our dataset as a pandas DataFrame. It comes from FiveThirtyEight's GitHub repository on alcohol consumption by country, which in turn comes from the WHO. The data gives the total alcohol consumption (in litres of pure alcohol) as well as servings of beer, spirits, and wine. We'll use it to see which countries tend to be similar in their alcohol consumption.

In [131]:
clustdf=pd.read_csv("https://raw.githubusercontent.com/fivethirtyeight/data/master/alcohol-consumption/drinks.csv", index_col=0)
In [132]:
clustdf.head(1)
Out[132]:
             beer_servings  spirit_servings  wine_servings  total_litres_of_pure_alcohol
country
Afghanistan              0                0              0                           0.0
In [133]:
clustdf.shape
Out[133]:
(193, 4)

We also need to decide whether to cluster rows or columns. SciPy's pdist treats each row as an observation, so if you'd like to cluster the rows, you can leave the DataFrame as-is; if you'd like to cluster the columns, you have to transpose the DataFrame first. Here we start by clustering the columns.

In [134]:
clustdf_t=clustdf.transpose()
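To see which axis pdist treats as observations, here is a quick sketch on a hypothetical 5 x 3 DataFrame (not the alcohol data). pdist returns a condensed distance matrix with one entry per pair of observations, n(n-1)/2 entries for n observations, and squareform expands it to a full n x n matrix:

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform

# Hypothetical 5-row x 3-column DataFrame
df = pd.DataFrame(np.arange(15, dtype=float).reshape(5, 3),
                  columns=['beer', 'spirit', 'wine'])

# Each ROW is an observation: 5 rows -> 5*4/2 = 10 pairwise distances
row_dist = pdist(df.to_numpy())
# Transposing first makes each original COLUMN an observation: 3*2/2 = 3 distances
col_dist = pdist(df.transpose().to_numpy())

print(len(row_dist), len(col_dist))  # 10 3
print(squareform(col_dist).shape)    # (3, 3)
```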

Then we compute the distance matrix and the linkage matrix using SciPy's pdist and linkage functions. The choice of distance metric and linkage method is NOT trivial. I strongly encourage everyone to check out the SciPy docs for pdist and linkage for details, and to try different settings to see what you get!

In this case, we use pdist's default metric, Euclidean distance, which is the $L^2$-norm of the difference between two points. pdist only measures distances between individual points, but the algorithm repeatedly merges the two closest clusters, so we also need a rule for the distance between two clusters; that rule is the linkage method. You could choose, for example, the distance between the cluster centroids (method='centroid') or the distance between the closest two points in the clusters (method='single').
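A small sketch of how the linkage method alone changes the result, using made-up 1-D points and the same Euclidean distances for every method. The third column of a linkage matrix holds the distance at which each merge happens:

```python
import numpy as np
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage

# Hypothetical 1-D points; Euclidean is pdist's default metric
X = np.array([[0.0], [1.0], [3.5], [10.0], [12.0], [30.0]])
d = pdist(X)

# Each row of a linkage matrix is [cluster_i, cluster_j, merge_distance, size];
# only the rule for the distance between clusters changes between methods
heights = {m: linkage(d, method=m)[:, 2] for m in ('single', 'complete', 'average')}
for m, h in heights.items():
    print(m, h)
```

On these points the final merge happens at 18 for single linkage and at 30 for complete linkage, with average linkage in between, so where you would cut the tree depends heavily on this choice.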

I selected "complete" linkage, which SciPy's documentation also calls the Farthest Point or Voor Hees algorithm: $$ d(u,v) = \max_{i,j} \, \mathrm{dist}(u[i], v[j]) $$ This is just a fancy way of saying that we compute the distance between every point $u[i]$ in cluster $u$ and every point $v[j]$ in cluster $v$, and take the largest of these as the distance between the two clusters.
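The complete-linkage formula can be checked by hand. A minimal sketch with two hypothetical 2-D clusters u and v: scipy.spatial.distance.cdist gives every pairwise distance dist(u[i], v[j]), and the maximum is d(u, v). Running linkage on all four points should merge u and v last, at exactly that height:

```python
import numpy as np
from scipy.spatial.distance import cdist
from scipy.cluster.hierarchy import linkage

# Two hypothetical clusters of 2-D points
u = np.array([[0.0, 0.0], [1.0, 0.0]])
v = np.array([[4.0, 0.0], [4.0, 3.0]])

# All pairwise Euclidean distances dist(u[i], v[j]) as a len(u) x len(v) matrix
pairwise = cdist(u, v)

# Complete linkage: the cluster distance is the largest of these
d_uv = pairwise.max()
print(d_uv)  # 5.0, from u[0] = (0, 0) to v[1] = (4, 3)

# The last merge of the full tree joins u and v at the same height
Z = linkage(np.vstack([u, v]), method='complete')
print(Z[-1, 2])  # 5.0
```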

In [135]:
c_dist = pdist(clustdf_t)  # computing the distance (Euclidean, pdist's default)
c_link = linkage(c_dist, method='complete')  # computing the complete linkage from those distances

With the linkage matrix, we can make a dendrogram. Take note that I am using the columns of the original pandas DataFrame as labels. This dendrogram will look different depending on the hyperparameters chosen for distance and linkage.

In [136]:
B=dendrogram(c_link,labels=list(clustdf.columns))

This is nice, but I'd like a legend that tells me what these colors mean. Plus, clustering by column is not that interesting. We will cluster by the index (countries) and those will be easier to interpret. Since there are 193 countries in this data, it will be difficult to label all these countries legibly.

We define a Clusters class that renders the cluster assignments as an HTML table, and a get_cluster_classes function that extracts those assignments from the dendrogram. Together they give a clean legend showing which countries are in which cluster and what each color means. This code originally comes from http://www.nxn.se/valent/extract-cluster-elements-by-color-in-python

In [137]:
class Clusters(dict):
    # A dict of {color: [labels]} that Jupyter renders as a two-column HTML
    # table: a swatch of the color, then the labels in that cluster
    def _repr_html_(self):
        html = '<table style="border: 0;">'
        for c in self:
            hx = rgb2hex(colorConverter.to_rgb(c))
            html += '<tr style="border: 0;">' \
                    '<td style="background-color: {0}; ' \
                    'border: 0;">' \
                    '<code style="background-color: {0};">'.format(hx)
            html += c + '</code></td>'
            html += '<td style="border: 0"><code>'
            html += repr(self[c]) + '</code>'
            html += '</td></tr>'

        html += '</table>'

        return html
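def get_cluster_classes(den, label='ivl'):
    # Map each link color to the indices of the leaves it touches. In a SciPy
    # dendrogram, leaf k is drawn at x = 5 + 10*k, so an x-coordinate of that
    # form is treated as a leaf.
    cluster_idxs = defaultdict(list)
    for c, pi in zip(den['color_list'], den['icoord']):
        for leg in pi[1:3]:  # x-positions of the link's left and right legs
            i = (leg - 5.0) / 10.0
            if abs(i - int(i)) < 1e-5:
                cluster_idxs[c].append(int(i))

    # Translate leaf indices into labels (den['ivl'] by default)
    cluster_classes = Clusters()
    for c, l in cluster_idxs.items():
        i_l = [den[label][i] for i in l]
        cluster_classes[c] = i_l

    return cluster_classes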
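get_cluster_classes relies on a detail of SciPy's dendrogram layout: leaf k is drawn at x = 5 + 10k, which is where the (leg - 5.0) / 10.0 test comes from. The sketch below, on hypothetical 1-D data, confirms this by collecting the link ends that sit at height 0 (a dcoord value of 0), which are exactly the leaves. One caveat: an internal node can also land on an x-coordinate of that form (for example, a node joining a cluster centered at x = 10 with one centered at x = 40 sits at x = 25), and the x-only test would then record a spurious leaf. That may be why a few names, such as 'Chile', appear twice in the outputs below; also checking that the leg's height is 0, as this sketch does, is one possible way to avoid it.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram

# Hypothetical 1-D data with five observations
X = np.array([[0.0], [1.0], [3.5], [10.0], [12.0]])
den = dendrogram(linkage(X, method='complete'), labels=list('abcde'), no_plot=True)

# Each link is drawn as a "U": icoord holds its 4 x-values and dcoord its
# 4 heights. An end of the U at height 0 sits on a leaf.
leaf_x = sorted({xs[k] for xs, ys in zip(den['icoord'], den['dcoord'])
                 for k in (0, 3) if ys[k] == 0.0})

# Map each x-position back to a label the same way get_cluster_classes does
mapped = [den['ivl'][int((x - 5.0) / 10.0)] for x in leaf_x]
print(leaf_x, mapped)
```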

I define a function that rolls the transpose, distance calculation, linkage, and dendrogram plotting into one step. Most of its lines set matplotlib parameters. It takes a pandas DataFrame and the desired number of clusters as input, draws a dendrogram, and returns the cluster assignments grouped by color.

Besides df and numclust, it takes four optional arguments. transpose clusters the index (rows) of the DataFrame when set to True and the columns when set to False. dataname accepts a string that describes the data and goes in the title of the rendered dendrogram. save indicates whether you'd like to save the figure to a PNG; if you pass a string, it is appended to the file name. xticksize sets the size of the x tick labels, which is important for legibility. Internally, data_dist and data_link use the Euclidean distance and complete linkage described above.

In [157]:
def get_clust_graph(df, numclust, transpose=False, dataname=None, save=False, xticksize=8):
    # transpose=True clusters the index (rows) of df; False clusters its columns.
    # Either way, the objects being clustered end up as the columns of `aml`.
    if transpose:
        aml = df.transpose()
        xl = "x-axis"
    else:
        aml = df
        xl = "y-axis"
    # pdist treats rows as observations, so transpose `aml` back before computing
    data_dist = pdist(aml.transpose())  # Euclidean distance (pdist's default)
    data_link = linkage(data_dist, method='complete')  # complete linkage on those distances
    # truncate_mode="lastp" shows only the last `numclust` merged clusters
    B = dendrogram(data_link, labels=list(aml.columns), p=numclust,
                   truncate_mode="lastp", get_leaves=True,
                   count_sort='ascending', show_contracted=True)
    ax = plt.gca()
    ax.tick_params(axis='x', which='major', labelsize=xticksize)
    ax.tick_params(axis='y', which='major', labelsize=15)
    plt.xlabel(xl)
    plt.ylabel('Distance')
    plt.suptitle(xl + " clustering for " + str(dataname), fontweight='bold', fontsize=16)
    if save:
        # a string passed as `save` is appended to the file name; True adds nothing
        suffix = save if isinstance(save, str) else ""
        plt.savefig(str(df.index.name) + str(numclust) + "tr_" + str(transpose)
                    + "dn_" + str(dataname) + suffix + '.png')
    else:
        print("Not saving")
    return get_cluster_classes(B)

The following is a helper function to show cluster members when the number of clusters is less than the total number of countries.

In [145]:
def give_cluster_assigns(df, numclust, transpose=True, tranpose=None):
    # `tranpose` is the original misspelled keyword, kept as an alias so older calls still work
    if tranpose is not None:
        transpose = tranpose
    # Same convention as get_clust_graph: transpose=True clusters the index (rows),
    # transpose=False clusters the columns. pdist treats rows as observations.
    data = df if transpose else df.transpose()
    data_dist = pdist(data)  # Euclidean distance (pdist's default)
    data_link = linkage(data_dist, method='complete')
    # criterion='maxclust' cuts the tree into at most numclust flat clusters
    cluster_assigns = pd.Series(sch.fcluster(data_link, numclust, criterion='maxclust'),
                                index=data.index)
    for i in range(1, numclust + 1):
        members = list(cluster_assigns[cluster_assigns == i].index)
        print("Cluster ", str(i), ": ( N =", len(members), ")", ", ".join(members))

Let's try doing this with all the countries available from the dataset.

In [146]:
get_clust_graph(clustdf, 193, transpose=True,dataname="Alcohol")
Not saving
Out[146]:
g['Andorra', 'Luxembourg', 'Croatia', 'Denmark', 'Switzerland', 'Portugal', 'France', 'Australia', 'Austria', 'Netherlands', 'Belgium', 'Slovenia', 'Argentina', 'Sweden', 'Italy', 'Uruguay', 'Equatorial Guinea', 'Malta', 'Norway', 'Chile', 'Greece', 'New Zealand', 'United Kingdom', 'Chile']
r['Palau', 'Venezuela', 'Gabon', 'Namibia', 'Lithuania', 'Poland', 'Latvia', 'Ireland', 'Romania', 'Germany', 'Czech Republic', 'Hungary', 'Serbia', 'Spain', 'Canada', 'Finland', 'USA', 'Paraguay', 'Cyprus', 'Belize', 'Panama', 'Brazil', 'Iceland', 'South Africa', 'Angola', 'Mexico']
c['Belarus', 'Grenada', 'Estonia', 'Ukraine', 'Bulgaria', 'St. Lucia', 'Slovakia', 'Russian Federation']
m['Dominica', 'Guyana', 'Haiti', 'Cook Islands', 'Albania', 'Antigua & Barbuda', 'Honduras', 'Jamaica', 'Nicaragua', 'Cuba', 'Liberia', 'United Arab Emirates', 'Armenia', 'Kyrgyzstan', 'Uzbekistan', 'India', 'Sri Lanka', 'Niue', 'St. Kitts & Nevis', 'Dominican Republic', 'Trinidad & Tobago', 'Peru', 'Moldova', 'St. Vincent & the Grenadines', 'Kazakhstan', 'Thailand', 'Bahamas', 'Barbados', 'Suriname', 'China', 'Mongolia', 'Philippines', 'Bosnia-Herzegovina', 'Japan']
y['Georgia', 'Montenegro', 'Laos', 'Sao Tome & Principe', 'Macedonia', 'Fiji', 'Lesotho', 'Mauritius', 'Samoa', 'Burundi', 'Swaziland', 'Congo', 'Vietnam', 'Colombia', 'Ecuador', 'Costa Rica', 'Cameroon', 'South Korea', 'Botswana', 'Seychelles', 'Bolivia', 'Cabo Verde', 'Botswana', 'Lebanon', 'Turkmenistan', 'El Salvador', 'Guatemala', 'Cambodia', 'Israel', 'Bahrain', 'Micronesia', 'Nauru', 'Tunisia', 'Nigeria', 'Rwanda', 'Uganda', 'Mozambique', 'Turkey', 'Kenya', 'Zimbabwe', 'Singapore', 'Solomon Islands', 'Kenya', 'Qatar', 'Tuvalu', 'Syria', 'Azerbaijan', 'Djibouti', 'Kiribati', 'Tonga', 'Zambia', 'Madagascar', 'Oman', 'Vanuatu', 'Guinea-Bissau', 'Papua New Guinea', 'Burkina Faso', 'Sierra Leone', 'Brunei', 'DR Congo', 'Tanzania', 'Algeria', 'Togo', 'Benin', 'Ghana', "Cote d'Ivoire", 'Morocco', 'Senegal', 'Bhutan', 'Ethiopia', 'Central African Republic', 'Chad', 'Eritrea', 'Malaysia', 'Malawi', 'Sudan', 'Jordan', 'Tajikistan', 'Egypt', 'Nepal', 'Gambia', 'Guinea', 'Iraq', 'Indonesia', 'Myanmar', 'Mali', 'Yemen', 'Comoros', 'Niger', 'Timor-Leste', 'Afghanistan', 'Bangladesh', 'North Korea', 'Iran', 'Kuwait', 'Libya', 'Maldives', 'Marshall Islands', 'Mauritania', 'Monaco', 'Pakistan', 'San Marino', 'Somalia', 'Saudi Arabia']

To make the graphic bigger and make it easier to see the countries on the x-axis, you can change the rcParams again:

In [150]:
rcParams['figure.figsize'] = 50, 9
rcParams['axes.labelsize'] = "large"
rcParams['font.size']= 20
In [162]:
get_clust_graph(clustdf, 193, transpose=True,dataname="Alcohol", save="ww2", xticksize=9)
Out[162]:
g['Andorra', 'Luxembourg', 'Croatia', 'Denmark', 'Switzerland', 'Portugal', 'France', 'Australia', 'Austria', 'Netherlands', 'Belgium', 'Slovenia', 'Argentina', 'Sweden', 'Italy', 'Uruguay', 'Equatorial Guinea', 'Malta', 'Norway', 'Chile', 'Greece', 'New Zealand', 'United Kingdom', 'Chile']
r['Palau', 'Venezuela', 'Gabon', 'Namibia', 'Lithuania', 'Poland', 'Latvia', 'Ireland', 'Romania', 'Germany', 'Czech Republic', 'Hungary', 'Serbia', 'Spain', 'Canada', 'Finland', 'USA', 'Paraguay', 'Cyprus', 'Belize', 'Panama', 'Brazil', 'Iceland', 'South Africa', 'Angola', 'Mexico']
c['Belarus', 'Grenada', 'Estonia', 'Ukraine', 'Bulgaria', 'St. Lucia', 'Slovakia', 'Russian Federation']
m['Dominica', 'Guyana', 'Haiti', 'Cook Islands', 'Albania', 'Antigua & Barbuda', 'Honduras', 'Jamaica', 'Nicaragua', 'Cuba', 'Liberia', 'United Arab Emirates', 'Armenia', 'Kyrgyzstan', 'Uzbekistan', 'India', 'Sri Lanka', 'Niue', 'St. Kitts & Nevis', 'Dominican Republic', 'Trinidad & Tobago', 'Peru', 'Moldova', 'St. Vincent & the Grenadines', 'Kazakhstan', 'Thailand', 'Bahamas', 'Barbados', 'Suriname', 'China', 'Mongolia', 'Philippines', 'Bosnia-Herzegovina', 'Japan']
y['Georgia', 'Montenegro', 'Laos', 'Sao Tome & Principe', 'Macedonia', 'Fiji', 'Lesotho', 'Mauritius', 'Samoa', 'Burundi', 'Swaziland', 'Congo', 'Vietnam', 'Colombia', 'Ecuador', 'Costa Rica', 'Cameroon', 'South Korea', 'Botswana', 'Seychelles', 'Bolivia', 'Cabo Verde', 'Botswana', 'Lebanon', 'Turkmenistan', 'El Salvador', 'Guatemala', 'Cambodia', 'Israel', 'Bahrain', 'Micronesia', 'Nauru', 'Tunisia', 'Nigeria', 'Rwanda', 'Uganda', 'Mozambique', 'Turkey', 'Kenya', 'Zimbabwe', 'Singapore', 'Solomon Islands', 'Kenya', 'Qatar', 'Tuvalu', 'Syria', 'Azerbaijan', 'Djibouti', 'Kiribati', 'Tonga', 'Zambia', 'Madagascar', 'Oman', 'Vanuatu', 'Guinea-Bissau', 'Papua New Guinea', 'Burkina Faso', 'Sierra Leone', 'Brunei', 'DR Congo', 'Tanzania', 'Algeria', 'Togo', 'Benin', 'Ghana', "Cote d'Ivoire", 'Morocco', 'Senegal', 'Bhutan', 'Ethiopia', 'Central African Republic', 'Chad', 'Eritrea', 'Malaysia', 'Malawi', 'Sudan', 'Jordan', 'Tajikistan', 'Egypt', 'Nepal', 'Gambia', 'Guinea', 'Iraq', 'Indonesia', 'Myanmar', 'Mali', 'Yemen', 'Comoros', 'Niger', 'Timor-Leste', 'Afghanistan', 'Bangladesh', 'North Korea', 'Iran', 'Kuwait', 'Libya', 'Maldives', 'Marshall Islands', 'Mauritania', 'Monaco', 'Pakistan', 'San Marino', 'Somalia', 'Saudi Arabia']
In [155]:
rcParams['figure.figsize'] = 24, 10

If you want to cut the dendrogram off at a certain number of clusters, pass that number as numclust:

In [158]:
get_clust_graph(clustdf, 10, transpose=True,dataname="Alcohol", xticksize=14)
Not saving
Out[158]:
g['(7)', '(16)']
r['(8)', '(14)', '(4)']
c['(2)', '(6)']
m['(4)', '(30)']
b['(102)']
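To see how the truncated dendrogram relates to a flat clustering, here is a sketch on hypothetical data whose merge heights are all distinct. With truncate_mode='lastp', contracted leaves are labelled with their size in parentheses, like the '(7)' and '(102)' above, and those sizes should match the cluster sizes that fcluster with criterion='maxclust' gives for the same number of clusters:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster

# Hypothetical 1-D data whose merge heights are all distinct
X = np.array([[0.0], [1.0], [3.5], [10.0], [12.0], [30.0]])
Z = linkage(X, method='complete')
k = 3

# 'lastp' keeps only the last k clusters; contracted leaves are labelled
# "(size)", singleton leaves keep their own label (here, the row index)
den = dendrogram(Z, p=k, truncate_mode='lastp', no_plot=True)
leaf_sizes = sorted(int(s.strip('()')) if s.startswith('(') else 1 for s in den['ivl'])

# fcluster cuts the same tree into at most k flat clusters labelled 1..k
flat = fcluster(Z, t=k, criterion='maxclust')
flat_sizes = sorted(int(n) for n in np.bincount(flat)[1:])

print(leaf_sizes, flat_sizes)
```

With tied merge heights, maxclust may return fewer clusters than requested, so the two views can disagree.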

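Note that the colors do not follow numclust: dendrogram assigns them using its color_threshold argument, which by default is 0.7 times the largest merge distance, so the number of colors need not match the number of leaves. We can use the give_cluster_assigns function to see exactly which countries are in these clusters.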

In [86]:
give_cluster_assigns(clustdf,10,tranpose=True)
Cluster  1 : ( N = 23 ) Andorra, Argentina, Australia, Austria, Belgium, Chile, Croatia, Denmark, Equatorial Guinea, France, Greece, Italy, Luxembourg, Malta, Netherlands, New Zealand, Norway, Portugal, Slovenia, Sweden, Switzerland, United Kingdom, Uruguay
Cluster  2 : ( N = 26 ) Angola, Belize, Brazil, Canada, Cyprus, Czech Republic, Finland, Gabon, Germany, Hungary, Iceland, Ireland, Latvia, Lithuania, Mexico, Namibia, Palau, Panama, Paraguay, Poland, Romania, Serbia, South Africa, Spain, USA, Venezuela
Cluster  3 : ( N = 8 ) Belarus, Bulgaria, Estonia, Grenada, Russian Federation, St. Lucia, Slovakia, Ukraine
Cluster  4 : ( N = 102 ) Afghanistan, Algeria, Azerbaijan, Bahrain, Bangladesh, Benin, Bhutan, Bolivia, Botswana, Brunei, Burkina Faso, Burundi, Cote d'Ivoire, Cabo Verde, Cambodia, Cameroon, Central African Republic, Chad, Colombia, Comoros, Congo, Costa Rica, North Korea, DR Congo, Djibouti, Ecuador, Egypt, El Salvador, Eritrea, Ethiopia, Fiji, Gambia, Georgia, Ghana, Guatemala, Guinea, Guinea-Bissau, Indonesia, Iran, Iraq, Israel, Jordan, Kenya, Kiribati, Kuwait, Laos, Lebanon, Lesotho, Libya, Madagascar, Malawi, Malaysia, Maldives, Mali, Marshall Islands, Mauritania, Mauritius, Micronesia, Monaco, Montenegro, Morocco, Mozambique, Myanmar, Nauru, Nepal, Niger, Nigeria, Oman, Pakistan, Papua New Guinea, Qatar, South Korea, Rwanda, Samoa, San Marino, Sao Tome & Principe, Saudi Arabia, Senegal, Seychelles, Sierra Leone, Singapore, Solomon Islands, Somalia, Sudan, Swaziland, Syria, Tajikistan, Macedonia, Timor-Leste, Togo, Tonga, Tunisia, Turkey, Turkmenistan, Tuvalu, Uganda, Tanzania, Vanuatu, Vietnam, Yemen, Zambia, Zimbabwe
Cluster  5 : ( N = 34 ) Albania, Antigua & Barbuda, Armenia, Bahamas, Barbados, Bosnia-Herzegovina, China, Cook Islands, Cuba, Dominica, Dominican Republic, Guyana, Haiti, Honduras, India, Jamaica, Japan, Kazakhstan, Kyrgyzstan, Liberia, Mongolia, Nicaragua, Niue, Peru, Philippines, Moldova, St. Kitts & Nevis, St. Vincent & the Grenadines, Sri Lanka, Suriname, Thailand, Trinidad & Tobago, United Arab Emirates, Uzbekistan