Customized legend for matplotlib


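The main point here is to show how to create a customized legend. In the example, the legend is built entirely by the last few lines. c1, c2 and c3 are Line2D "proxy" handles that define each legend entry, and ax.legend pairs them with their labels. The handles are never drawn on the axes, so the legend entries do not have to match the plotted data. We use linestyle='None' to remove the line drawn through each marker. Remove that argument to see what the legend looks like otherwise.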
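To see what linestyle='None' changes without running the full example, here is a minimal sketch that puts both variants side by side in one legend. The labels are made up for illustration, and the Agg backend is an assumption so that the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumed headless backend; drop this line and call plt.show() to see the figure
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots()

# Default Line2D: the legend entry shows a line drawn through the marker
with_line = Line2D([0], [0], color="red", marker="o")
# linestyle='None': the legend entry shows the marker only
marker_only = Line2D([0], [0], color="red", marker="o", linestyle="None")

leg = ax.legend([with_line, marker_only], ["default linestyle", "linestyle='None'"])
print(with_line.get_linestyle(), marker_only.get_linestyle())
```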



import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color = ['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G']

df = pd.DataFrame(dict(carat=carat, price=price, color=color))

fig, ax = plt.subplots()

# map each category to the color used for its points
colors = {'D':'red', 'E':'blue', 'F':'green', 'G':'black'}

# color each point according to its category
ax.scatter(df['carat'], df['price'], c=df['color'].map(colors).tolist())

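# proxy handles: these Line2D objects only appear in the legend, never on the axes
# linestyle='None' removes the line drawn through the marker
c1 = Line2D([0], [0], color="red", marker="o", linestyle='None')
c2 = Line2D([0], [0], color="blue", marker="s", linestyle='None')
c3 = Line2D([0], [0], color="green", marker="d", linestyle='None')

# pass the handles and their labels explicitly; numpoints=1 shows one marker per entry
ax.legend([c1, c2, c3], ['c111', 'c2', 'cIII'], numpoints=1)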
plt.show()
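If the legend should match the scatter colors exactly, the handles can also be generated from the colors dict instead of being written out one by one. This is a sketch rather than the original example: it stores each entry's text in Line2D's label argument and hands the list to legend(handles=...). The legend title 'color' and the Agg backend are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumed headless backend; drop this line and call plt.show() to see the figure
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

df = pd.DataFrame(dict(
    carat=[5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30],
    price=[100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600],
    color=['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G'],
))
colors = {'D': 'red', 'E': 'blue', 'F': 'green', 'G': 'black'}

fig, ax = plt.subplots()
ax.scatter(df['carat'], df['price'], c=df['color'].map(colors).tolist())

# one proxy handle per category; label= keeps the legend text on the handle itself
handles = [
    Line2D([0], [0], color=c, marker='o', linestyle='None', label=name)
    for name, c in colors.items()
]
leg = ax.legend(handles=handles, title='color')
```

Because the handles come from the same dict that colors the points, adding a category to colors updates both the plot and the legend.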

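Alternatively, a simpler solution is to call plot once per category, passing the marker as the third positional argument (the format string, e.g. "o") together with a label for each call. legend() then builds the entries automatically. See the example below.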

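# one plot() call per category: each call sets its own color and label,
# so legend() can build the entries automatically
for name, group in df.groupby('color'):
    # "o" is the format string: circle markers and no connecting line
    plt.plot(group['carat'], group['price'], "o", c=colors[name], alpha=0.5, label=name)
plt.legend()
plt.show()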
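The claim that the third positional argument sets the marker can be checked directly. This sketch, with made-up data and the Agg backend assumed so it runs headless, compares the "o" format string with the equivalent keyword arguments and reads back the labels that legend() will pick up.

```python
import matplotlib
matplotlib.use("Agg")  # assumed headless backend; drop this line and call plt.show() to see the figure
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# "o" as the third positional argument is a format string: circle markers, no line
(line_fmt,) = ax.plot([5, 10, 20], [100, 200, 300], "o", color="red", label="fmt 'o'")
# the same styling spelled out with keyword arguments
(line_kw,) = ax.plot([5, 10, 20], [150, 250, 350], marker="o", linestyle="None",
                     color="blue", label="keywords")

print(line_fmt.get_marker(), line_fmt.get_linestyle())
print(line_kw.get_marker(), line_kw.get_linestyle())

# these are the handles and labels legend() collects when called with no arguments
handles, labels = ax.get_legend_handles_labels()
ax.legend()
```

Both lines end up with a marker and no line style, which is why the format-string version needs no linestyle='None'.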


References:
This example is not mine; I adapted it for this purpose. I lost the original source and could not find it again. If you find it, please let me know so I can add it to the references.
1. The official matplotlib website
2. Stack Overflow question on removing the line through the marker in a matplotlib legend: http://stackoverflow.com/questions/21285885/remove-line-through-marker-in-matplotlib-legend
