Make legend correspond to colors of scatter points in matplotlib


I was able to make the legend correspond to the colors. The key was to use a separate scatter call for each category in the data, as suggested by Rutger Kassies, so that each call contributes its own legend entry.

Here is the code:

import numpy as np
import matplotlib.pyplot as plt

# Assumed to exist already (see the question): m is a Basemap object,
# labels is the array of KMeans cluster labels, and map_points holds the
# point geometries. The slicing below assumes map_points is sorted by
# cluster label.

# Setting various plot properties
plt.clf()
fig = plt.figure()
# facecolor replaces axisbg, which was deprecated in matplotlib 2.0
ax = fig.add_subplot(111, facecolor='w', frame_on=True)
fig.set_size_inches(18.5, 10.5)

# Creating a discrete list of colors, one per cluster
colors = plt.cm.rainbow(np.linspace(0, 1, 20))

current_plot_range = 0
previous_plot_range = 0

for i, c in enumerate(colors):
    # Slice out the block of points that belongs to cluster i
    previous_plot_range += current_plot_range
    current_plot_range = labels[labels == i].size
    cluster_points = map_points[
        previous_plot_range:previous_plot_range + current_plot_range]
    m.scatter(
        [geom.x for geom in cluster_points],
        [geom.y for geom in cluster_points],
        20, lw=.25, marker='o', color=c, label=i, alpha=0.9,
        antialiased=True, zorder=3)

plt.legend()
m.fillcontinents(color='#555555')

The result looks something like this: [image]
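
The approach above, one scatter call per category, can also be sketched without Basemap. This is a minimal sketch and not the original code: it assumes synthetic random points (`xy`) and random integer labels in place of `map_points` and the KMeans `labels`, draws on a plain matplotlib `Axes` instead of the Basemap object `m`, and selects each cluster with a boolean mask rather than by slicing a sorted list.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Synthetic stand-ins (assumptions): 200 random points, 5 cluster labels
rng = np.random.default_rng(0)
n_clusters = 5
xy = rng.random((200, 2))
labels = rng.integers(0, n_clusters, size=200)

fig, ax = plt.subplots(figsize=(8, 5))
colors = plt.cm.rainbow(np.linspace(0, 1, n_clusters))

for i, c in enumerate(colors):
    mask = labels == i  # picks this cluster's points in any order
    ax.scatter(xy[mask, 0], xy[mask, 1], s=20, lw=0.25, marker='o',
               color=c, label=str(i), alpha=0.9, zorder=3)

# One legend entry per scatter call, colored to match its points
legend = ax.legend(title="cluster")
legend_labels = [t.get_text() for t in legend.get_texts()]
```

Because each call carries its own `label`, `ax.legend()` picks up one entry per cluster, and the mask removes the need to keep the points sorted by label.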


Author: Nitin

Updated on June 21, 2022

Comments

  • Nitin almost 2 years

    I have a plot that I generate with the KMeans algorithm in scikit-learn. Each cluster is drawn in a different color. Here is the plot: [image]

    I need a legend for this plot that maps each color to its cluster number. Ideally, the legend should display the color of the cluster, and the label should be the cluster number. Thanks.

    EDIT: I think I should put some code since people are downvoting this

    import matplotlib.pyplot as plt
    from sklearn.cluster import KMeans

    km = KMeans(n_clusters=20, init='random')
    km.fit(df)  # df is the dataframe which contains points as coordinates
    labels = km.labels_
    plt.clf()
    fig = plt.figure()
    # facecolor replaces axisbg, which was deprecated in matplotlib 2.0
    ax = fig.add_subplot(111, facecolor='w', frame_on=True)
    fig.set_size_inches(18.5, 10.5)

    # Plot the clusters on the map
    # m is a basemap object
    m.scatter(
        [geom.x for geom in map_points],
        [geom.y for geom in map_points],
        20, marker='o', lw=.25,
        c=labels.astype(float),
        alpha=0.9, antialiased=True,
        zorder=3)
    m.fillcontinents(color='#555555')
    plt.show()
    
    • cd98 over 10 years
      @Nitin, if you came up with a solution, please post it back here.
  • SeF over 8 years
    Thanks for this answer! Consider using enumerate instead of zip.
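
For the single-call version in the question (one scatter with `c=labels`), matplotlib 3.1 and later provide `PathCollection.legend_elements()`, which builds one legend handle per color value. This is an alternative to the per-category loop in the answer, shown only as a hedged sketch: `xy` and `labels` are synthetic, the plot uses a plain `Axes`, and whether it behaves the same through Basemap's `m.scatter` is not verified here.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Synthetic stand-ins (assumptions): 150 random points, 4 cluster labels
rng = np.random.default_rng(1)
xy = rng.random((150, 2))
labels = rng.integers(0, 4, size=150)

fig, ax = plt.subplots()
# A single scatter call; colors come from the labels through the colormap
sc = ax.scatter(xy[:, 0], xy[:, 1], c=labels, cmap="rainbow", s=20)

# num=None asks for one handle per unique value in the color array
handles, texts = sc.legend_elements(num=None)
legend = ax.legend(handles, texts, title="cluster")
```

With the default `num="auto"`, matplotlib may thin out the entries when there are many distinct values, so `num=None` is the safer choice for something like 20 clusters.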