Customized legend for matplotlib
The main point here is to show how to create a customized legend. In the example, the last 5 lines are the ones that build the legend: c1, c2 and c3 define the legend entries, and ax.legend attaches them to their labels. We used linestyle='None' to remove the line crossing through each marker; remove it to see how the legend would look otherwise.
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

# Sample data: carat, price and a color grade for each point
carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color = ['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G']
df = pd.DataFrame(dict(carat=carat, price=price, color=color))

fig, ax = plt.subplots()
# Map each color grade to a plotting color
colors = {'D': 'red', 'E': 'blue', 'F': 'green', 'G': 'black'}
ax.scatter(df['carat'], df['price'], c=df['color'].apply(lambda x: colors[x]))

# Legend entries: proxy artists that are never drawn on the axes.
# linestyle='None' keeps a line from crossing through the marker in the legend.
c1 = Line2D([0], [0], color="red", marker="o", linestyle='None')
c2 = Line2D([0], [0], color="blue", marker="s", linestyle='None')
c3 = Line2D([0], [0], color="green", marker="d", linestyle='None')
ax.legend([c1, c2, c3], ['c111', 'c2', 'cIII'], numpoints=1)
plt.show()
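Hard-coding c1, c2 and c3 means the legend can drift out of sync with the scatter (in the example above, the black 'G' points have no entry). A minimal sketch, assuming you want one entry per color grade with the grade itself as the label, builds the proxy handles from the same colors dictionary used for the scatter. The legend title 'color' and the use of Series.map are illustrative choices, not part of the original example.

```python
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color = ['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G']
df = pd.DataFrame(dict(carat=carat, price=price, color=color))
colors = {'D': 'red', 'E': 'blue', 'F': 'green', 'G': 'black'}

fig, ax = plt.subplots()
# Series.map with a dict does the same lookup as the lambda in the original
ax.scatter(df['carat'], df['price'], c=df['color'].map(colors))

# One proxy artist per grade, derived from the same mapping as the scatter,
# so every plotted color gets a legend entry
handles = [
    Line2D([0], [0], color=c, marker='o', linestyle='None')
    for c in colors.values()
]
legend = ax.legend(handles, list(colors.keys()), title='color', numpoints=1)
plt.show()
```

Adding a new grade to the colors dictionary then updates both the scatter colors and the legend without touching the legend code.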
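A better solution is to call plot with the marker as its third positional argument (a format string such as "o"): the markers are drawn without a connecting line, and the label of each call becomes a legend entry. See the example below.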
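# Assumes df and colors from the example above are already defined.
# One plot call per color grade: "o" (third positional argument) draws markers only,
# and label= lets plt.legend() build the entries automatically.
for name, group in df.groupby('color'):
    plt.plot(group['carat'], group['price'], "o", color=colors[name], alpha=0.5, label=name)
plt.legend()
plt.show()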
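To see why the format string removes the crossing line, the short sketch below compares a plot call using "o" with the keyword form marker="o", linestyle="None", and with a default plot call. The sample coordinates and labels are arbitrary.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Format string "o": circle markers, and no line because no line style is given
(fmt_line,) = ax.plot([5, 10, 20], [100, 200, 300], "o", label="format string")
# Equivalent keyword form: an explicit marker plus linestyle='None'
(kw_line,) = ax.plot([5, 10, 20], [150, 250, 350], marker="o", linestyle="None", label="keywords")
# Default plot call for comparison: a solid line
(default_line,) = ax.plot([5, 10, 20], [200, 300, 400], label="default")
legend = ax.legend()
plt.show()

print(fmt_line.get_linestyle(), kw_line.get_linestyle(), default_line.get_linestyle())
```

The first two lines both report the line style 'None', which is why their legend entries show a bare marker, while the default call reports '-' and draws a solid line.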
References:
This example is not mine; I edited it to fit our purpose. I lost the original reference and couldn't find it again. If you find it, please let me know so I can add it to the references.
1. matplotlib official website
2. http://stackoverflow.com/questions/21285885/remove-line-through-marker-in-matplotlib-legend