Pandas side-by-side stacked bar plot

11,169

The bars within each class will not sit next to each other as in your first figure, but apart from that, pandas lets you do what you want as follows:

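import pandas as pd

# Survived is 0/1, so its group mean is the survival rate and 1 - mean is the death rate
df_g = df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', lambda x: 1 - x.mean()])
df_g.columns = ['Survived', 'Died']
df_g.plot.bar(stacked=True)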
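A quick way to check what `df_g` holds before plotting is to run the same aggregation on a tiny frame. The sketch below uses invented toy rows (not the real Titanic data) in place of `train.csv`. Because `Survived` is coded 0/1, its group mean is the survival rate, so the two columns of every row sum to 1, which is why each stacked bar reaches full height.

```python
import pandas as pd

# Hypothetical toy data standing in for train.csv (not real passengers)
df = pd.DataFrame({
    'Pclass':   [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
    'Sex':      ['female', 'female', 'male', 'male'] * 3,
    'Survived': [1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0],
})

# Survival rate and death rate per (Pclass, Sex) group
df_g = df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', lambda x: 1 - x.mean()])
df_g.columns = ['Survived', 'Died']
print(df_g)
```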


Here, the horizontal grouping of bars is complicated by the stacking: a stacked bar plot places one bar per row of the index at evenly spaced positions. If, for instance, we only cared about the "Survived" rate, pandas could handle it out of the box:

df.groupby(['Pclass', 'Sex'])['Survived'].mean().unstack().plot.bar()
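`unstack()` pivots the inner `Sex` level of the index into columns, giving one row per `Pclass` and one column per sex; `plot.bar()` then draws a cluster of side-by-side bars for each row. A minimal sketch on invented toy data (not the real Titanic rows), using matplotlib's non-interactive Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend, assumed here so no display is needed
import pandas as pd

# Hypothetical toy data standing in for train.csv
df = pd.DataFrame({
    'Pclass':   [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
    'Sex':      ['female', 'female', 'male', 'male'] * 3,
    'Survived': [1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0],
})

# Rows: Pclass; columns: Sex; values: survival rate
rates = df.groupby(['Pclass', 'Sex'])['Survived'].mean().unstack()
ax = rates.plot.bar()  # one cluster of 2 bars per class
```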


If an ad hoc solution suffices for post-processing the plot, doing so is also not terribly complicated:

import numpy as np

df_g = df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', lambda x: 1 - x.mean()])
df_g.columns = ['Survived', 'Died']
ax = df_g.plot.bar(stacked=True)

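# Move every second bar left by half a unit (bars are 0.5 wide, 1 apart) so each class's pair touches;
# patches 0-5 form the 'Survived' layer and patches 6-11 the 'Died' layer stacked on top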
for i in range(6):
    new_x = ax.patches[i].get_x() - (i%2)/2
    ax.patches[i].set_x(new_x)
    ax.patches[i+6].set_x(new_x)

# Update tick locations correspondingly: bar centre = left edge + half of the 0.5 width
minor_tick_locs = [x.get_x()+1/4 for x in ax.patches[:6]]
major_tick_locs = np.array(minor_tick_locs).reshape(3, 2).mean(axis=1)
ax.set_xticks(minor_tick_locs, minor=True)
ax.set_xticks(major_tick_locs)

# Use indices from dataframe as tick labels
minor_tick_labels = df_g.index.get_level_values(1).values
major_tick_labels = df_g.index.levels[0].values
ax.xaxis.set_ticklabels(minor_tick_labels, minor=True)
ax.xaxis.set_ticklabels(major_tick_labels)

# Remove ticks and organize tick labels to avoid overlap
ax.tick_params(axis='x', which='both', bottom=False)
ax.tick_params(axis='x', which='minor', labelrotation=45)
ax.tick_params(axis='x', which='major', pad=35, labelrotation=0)
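The loop above hard-codes six bars in pairs. The same shifting can be wrapped in a helper; `make_groups_adjacent` below is a hypothetical name, not a pandas or matplotlib API, and it assumes what the answer's code relies on: pandas' default bar width of 0.5, bars centred at 0, 1, 2, ..., and patches ordered column by column (all `Survived` bars, then all `Died` bars). Within each group, the k-th bar moves left by k times the gap between bars (1 - width), which for pairs of 0.5-wide bars is exactly the answer's `(i%2)/2`. The data are invented toy rows.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a display
import numpy as np
import pandas as pd


def make_groups_adjacent(ax, n_rows, n_sub):
    """Hypothetical helper: close the gaps between the n_sub bars of each group.

    Assumes a pandas stacked bar plot with bars centred at 0, 1, 2, ... and
    patches ordered layer by layer (all rows of the first column, then the next).
    """
    n_layers = len(ax.patches) // n_rows
    width = ax.patches[0].get_width()
    centres = []
    for i in range(n_rows):
        shift = (i % n_sub) * (1 - width)  # k-th bar of a group moves left by k gaps
        for layer in range(n_layers):
            p = ax.patches[layer * n_rows + i]
            p.set_x(p.get_x() - shift)
        centres.append(ax.patches[i].get_x() + width / 2)
    minor = np.array(centres)
    major = minor.reshape(-1, n_sub).mean(axis=1)  # one major tick per group
    ax.set_xticks(minor, minor=True)
    ax.set_xticks(major)
    return minor, major


# Hypothetical toy data standing in for train.csv
df = pd.DataFrame({
    'Pclass':   [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
    'Sex':      ['female', 'female', 'male', 'male'] * 3,
    'Survived': [1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0],
})
df_g = df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', lambda x: 1 - x.mean()])
df_g.columns = ['Survived', 'Died']

ax = df_g.plot.bar(stacked=True)
minor, major = make_groups_adjacent(ax, n_rows=len(df_g), n_sub=df_g.index.levels[1].size)
ax.xaxis.set_ticklabels(list(df_g.index.get_level_values(1)), minor=True)
ax.xaxis.set_ticklabels([str(c) for c in df_g.index.levels[0]])
```

Passing `n_sub` explicitly keeps the helper usable for inner levels with more than two categories, at the cost of assuming every group has the same number of bars.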


Author: PyRsquared
Updated on June 11, 2022

Comments

  • PyRsquared

    I want to create a stacked bar plot of the Titanic dataset. The plot needs to group by "Pclass", "Sex" and "Survived". I have managed to do this with a lot of tedious numpy manipulation to produce a normalized plot (where "M" is male and "F" is female).

    Is there a way to do this using pandas' built-in plotting functionality?

    I have tried this:

    import pandas as pd
    import matplotlib.pyplot as plt
    df = pd.read_csv('train.csv')
    df_grouped = df.groupby(['Survived','Sex','Pclass'])['Survived'].count()
    df_grouped.unstack().plot(kind='bar',stacked=True,  colormap='Blues', grid=True, figsize=(13,5));
    


    Which is not what I want. Is there any way to produce the first plot using pandas plotting? Thanks in advance.
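For reference, the normalized shares asked for here can also be computed directly with `pd.crosstab(..., normalize='index')`, which divides each (Pclass, Sex) row of counts by its row total. This is a different route from a `groupby`/`agg` aggregation, sketched on invented toy data (not the real `train.csv`); the crosstab columns are the `Survived` codes 0 and 1, so they are renamed and reordered here.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend, assumed so the sketch runs without a display
import pandas as pd

# Hypothetical toy data standing in for train.csv
df = pd.DataFrame({
    'Pclass':   [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
    'Sex':      ['female', 'female', 'male', 'male'] * 3,
    'Survived': [1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0],
})

# Row-normalized counts: each (Pclass, Sex) row sums to 1
shares = pd.crosstab([df['Pclass'], df['Sex']], df['Survived'], normalize='index')
shares = shares.rename(columns={0: 'Died', 1: 'Survived'})[['Survived', 'Died']]
ax = shares.plot.bar(stacked=True)
```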