K-means clustering using sklearn.cluster


In practice, it's impossible to visualize 750-dimensional data directly.

But there are ways around this. For example, you can first reduce the dimensionality with PCA to a fairly low number of dimensions, like 4. Scikit-learn provides this as sklearn.decomposition.PCA.
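A minimal sketch of the PCA step, assuming random data as a hypothetical stand-in for the real 10 x 750 array (the seed and the variable names are illustrative, not from the original):

```python
import numpy as np
from sklearn.decomposition import PCA

# Hypothetical stand-in for the real data: 10 samples x 750 features.
rng = np.random.default_rng(0)
X = rng.normal(size=(10, 750))

# Reduce the 750 features to 4 principal components.
pca = PCA(n_components=4)
X_reduced = pca.fit_transform(X)

print(X_reduced.shape)                # (10, 4)
print(pca.explained_variance_ratio_)  # share of total variance kept by each component
```

Note that n_components cannot exceed min(n_samples, n_features), which is 10 for this data.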

Then you can draw a matrix of plots, where each plot shows only two features. With pandas, you can draw these plots very easily using the scatter_matrix function (pandas.plotting.scatter_matrix).
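A sketch of the scatter-matrix idea, again assuming hypothetical random data. Colouring the points by K-means label (via the `c` keyword, which scatter_matrix forwards to matplotlib's scatter) is an illustrative choice, as are the column names PC1 to PC4:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 750))  # hypothetical 10 x 750 data

# Cluster on the original data, reduce to 4 components only for plotting.
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)
X_reduced = PCA(n_components=4).fit_transform(X)

df = pd.DataFrame(X_reduced, columns=["PC1", "PC2", "PC3", "PC4"])

# One scatter plot per pair of components, histograms on the diagonal;
# points are coloured by their cluster label.
axes = scatter_matrix(df, c=labels, figsize=(8, 8), diagonal="hist")
plt.show()
```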

Note that in your case you are only using PCA for visualization, so you should still run K-means clustering on the original data. After getting the centroids, project them with the same PCA model you fitted before (its transform method), so they land in the same coordinates as the projected data points.
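A minimal sketch of that order of operations, assuming hypothetical random data and illustrative settings (n_init=10, random_state=0):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 750))  # hypothetical 10 x 750 data

# 1. Cluster on the ORIGINAL 750-dimensional data.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
labels = kmeans.fit_predict(X)
centroids = kmeans.cluster_centers_            # shape (3, 750)

# 2. Fit PCA on the data, purely for visualization.
pca = PCA(n_components=4)
X_reduced = pca.fit_transform(X)               # shape (10, 4)

# 3. Project the centroids with the SAME fitted PCA model.
centroids_reduced = pca.transform(centroids)   # shape (3, 4)
```

Calling transform on the already-fitted model, rather than fitting a second PCA on the centroids, keeps the centroids in the same coordinate system as the projected points.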

Here is an example plot created by the scatter_matrix function: [image: example scatter_matrix plot]

Author: jacky_learns_to_code


Updated on June 05, 2022

Comments

  • jacky_learns_to_code, almost 2 years ago

    I came across this K-means clustering tutorial, Unsupervised Machine Learning: Flat Clustering, and below is its code:

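    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import style
    style.use("ggplot")
    
    from sklearn.cluster import KMeans
    
    # 5 samples, 2 features each
    X = np.array([[1, 2], [5, 8], [1.5, 1.8], [1, 0.6], [9, 11]])
    
    # n_init set explicitly (its default changed in scikit-learn 1.4)
    kmeans = KMeans(n_clusters=3, n_init=10)
    kmeans.fit(X)
    
    centroid = kmeans.cluster_centers_  # one row per cluster
    labels = kmeans.labels_             # cluster index of each sample
    
    print(centroid)
    print(labels)
    
    colors = ["g.", "r.", "c."]  # one marker style per cluster
    
    # Plot each sample using only its first two features
    for i in range(len(X)):
        print("coordinate:", X[i], "label:", labels[i])
        plt.plot(X[i][0], X[i][1], colors[labels[i]], markersize=10)
    
    # Mark the centroids with crosses (again only the first two features)
    plt.scatter(centroid[:, 0], centroid[:, 1], marker="x", s=150, linewidths=5, zorder=10)
    
    plt.show()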
    

    In this example, each sample has only 2 features: [1,2], [5,8], [1.5,1.8], etc.

    I have tried replacing X with a 10 x 750 matrix (10 samples, 750 features) stored in an np.array(), but the resulting graph just does not make any sense.

    How could I alter the above code to solve my problem?
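One way to alter the script along the lines of the answer above: keep K-means on all 750 features, but project both the samples and the centroids to 2 principal components before plotting, so the original single-plot layout still works. This is a sketch assuming hypothetical random data in place of the real array; the seed, n_init, and colours are illustrative:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 750))  # hypothetical stand-in for the 10 x 750 array

# Cluster in all 750 dimensions.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
labels = kmeans.fit_predict(X)

# Project samples and centroids to 2-D with the same PCA model, only for plotting.
pca = PCA(n_components=2)
X_2d = pca.fit_transform(X)                            # shape (10, 2)
centroids_2d = pca.transform(kmeans.cluster_centers_)  # shape (3, 2)

colors = np.array(["g", "r", "c"])  # one colour per cluster
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=colors[labels], s=50)
plt.scatter(centroids_2d[:, 0], centroids_2d[:, 1],
            marker="x", s=150, linewidths=5, zorder=10, c="k")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.show()
```

The original loop plotted X[i][0] and X[i][1], i.e. only the first 2 of 750 features, which is why the graph looked meaningless; the PCA axes instead summarize the directions of greatest variance across all features.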