Axes from plt.subplots() is a "numpy.ndarray" object and has no attribute "plot"

121,254

Solution 1

The problem here is with how matplotlib handles subplots. Just do the following:

fig, axes = plt.subplots(nrows=1, ncols=2)
for axis in axes:
    print(type(axis))

Each axis printed is a matplotlib Axes object, because axes here is a 1D numpy array, which can be traversed using a single index, i.e. axes[0], axes[1], and so on. But if you do

fig, axes = plt.subplots(nrows=2, ncols=2)
for axis in axes:
    print(type(axis))

each axis printed is a numpy ndarray (one row of the grid), because axes is now a 2D array, which can be traversed only using two indices, i.e. axes[0, 0], axes[1, 0], and so on. So be mindful of how you write your for loop to traverse the axes object.
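To make the difference concrete, here is a minimal sketch that reports what plt.subplots returns for a few layouts. The helper name describe is hypothetical, and the Agg backend is selected only so the snippet runs without a display. It also shows the squeeze=False argument, which asks matplotlib to always return a 2D array regardless of the layout.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt


def describe(nrows, ncols, **kwargs):
    """Return (type name, shape) of the axes object from plt.subplots.

    Hypothetical helper for illustration only.
    """
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, **kwargs)
    shape = getattr(axes, "shape", None)  # a single Axes has no .shape
    plt.close(fig)
    return type(axes).__name__, shape


single = describe(1, 1)                 # a single Axes object, shape is None
row = describe(1, 2)                    # ndarray with shape (2,)
grid = describe(2, 2)                   # ndarray with shape (2, 2)
forced = describe(1, 1, squeeze=False)  # ndarray with shape (1, 1)

for label, result in [("1x1", single), ("1x2", row),
                      ("2x2", grid), ("1x1 squeeze=False", forced)]:
    print(label, result)
```

If a script has to cope with arbitrary row and column counts, passing squeeze=False and always indexing with two indices is one way to avoid special cases.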

Solution 2

If you use N-by-1 graphs, for example fig, ax = plt.subplots(3, 1), then ax is a 1D array and a single index is enough: ax[plot_count].plot(...)
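A short sketch of that N-by-1 case follows. The data (x and its powers) is made up for illustration, and the Agg backend is an assumption so that the snippet runs headless; with an interactive backend you would call plt.show() at the end.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this for interactive use
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0.0, 1.0, 20)  # made-up sample data

# Three rows, one column: ax is a 1D array of shape (3,)
fig, ax = plt.subplots(3, 1)

for plot_count in range(3):
    # A single integer index selects one Axes
    ax[plot_count].plot(x, x ** (plot_count + 1))
    ax[plot_count].set_title(f"x^{plot_count + 1}")

fig.tight_layout()
```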

Solution 3

The axes are in a 2-D array, not a 1-D one, so you can't iterate through them using a single loop. You need one more loop:

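import numpy as np
import matplotlib.pyplot as plt

a = np.arange(5)  # sample data (assumed; a is not defined in the original answer)
fig, axes = plt.subplots(nrows=2, ncols=2)
plt.tight_layout()
for ho in axes:      # each ho is one row of Axes
    for i in ho:     # each i is a single Axes
        i.plot(a, a**2)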

This works without problems, but if I try:

for i in axes:
    i.plot(a, a**2)

the error occurs, because each i is then a whole row of Axes (a numpy.ndarray) rather than a single Axes.
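If you would rather keep a single loop, one option is to iterate over the flattened array instead of nesting. The sketch below uses the ndarray attribute flat, which visits every element of the 2D array in row-major order; the data in a is made up, and the Agg backend is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run
import matplotlib.pyplot as plt
import numpy as np

a = np.arange(5)  # made-up sample data

fig, axes = plt.subplots(nrows=2, ncols=2)

# axes.flat yields the four Axes one by one, so one loop is enough
for axis in axes.flat:
    axis.plot(a, a ** 2)

fig.tight_layout()
```

Note that this only applies when axes is an array; with plt.subplots(1, 1) a single Axes is returned unless squeeze=False is passed.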

Author: Lucubrator

Interested in Python.

Updated on July 09, 2022

Comments

  • Lucubrator
    Lucubrator almost 2 years

    The information below may be superfluous if you are trying to understand the error message. Please start off by reading the answer by @user707650.

    Using MatPlotLib, I wanted a generalizable script that creates the following from my data.

    A window containing a subplots arranged so that there are b subplots per column. I want to be able to change the values of a and b.

    If I have data for 2a subplots, I want 2 windows, each with the previously described "a subplots arranged according to b subplots per column".

    The x and y data I am plotting are floats stored in np.arrays and are structured as follows:

    • The x data is always the same for all plots and is of length 6.

       'x_vector': [0.000, 0.005, 0.010, 0.020, 0.030, 0.040]
      
    • The y data of all plots are stored in y_vector, where the data for the first plot is stored at indexes 0 through 5. The data for the second plot is stored at indexes 6 through 11. The third plot gets 12-17, the fourth 18-23, and so on.

    In total, for this dataset, I have 91 plots (i.e. 91*6 = 546 values stored in y_vector).

    Attempt:

    import matplotlib.pyplot as plt
    
    # Options:
    plots_tot = 14 # Total number of plots. In reality there is going to be 7*13 = 91 plots.
    location_of_ydata = 6 # The values for the n:th plot can be found in the y_vector at index 'n*6' through 'n*6 + 6'.
    plots_window = 7 # Total number of plots per window.
    rows = 2 # Number of rows, i.e. number of subplots per column.
    
    # Calculating number of columns:
    prim_cols = plots_window / rows
    extra_cols = 0
    if plots_window % rows > 0:
        extra_cols = 1
    cols = prim_cols + extra_cols
    
    print 'cols:', cols
    print 'rows:', rows
    
    # Plotting:
    n=0
    x=0
    fig, ax = plt.subplots(rows, cols)
    while x <= plots_tot:
        ax[x].plot(x_vector, y_vector[n:(n+location_of_ydata)], 'ro')
        if x % plots_window == plots_window - 1:
            plt.show() # New window for every 7 plots.
        n = n+location_of_ydata
        x = x+1
    

    I get the following error:

    cols: 4
    rows: 2
    Traceback (most recent call last):
      File "Script.py", line 222, in <module>
        ax[x].plot(x_vector, y_vector[n:(n+location_of_ydata)], 'ro')
    AttributeError: 'numpy.ndarray' object has no attribute 'plot'
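For reference, here is one possible way to adapt the script above so that it avoids the error. It is a sketch, not the asker's final code. It is written for Python 3, where / returns a float, so the column count uses math.ceil. It flattens the 2D axes array with ravel() so that a single index works, and it creates a fresh figure for every window of plots_window plots. The y_vector here is placeholder data, since the real values are not shown in the question, and the Agg backend is assumed so that the snippet runs without a display.

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: headless run; remove to see windows
import matplotlib.pyplot as plt
import numpy as np

# Options (same names as in the question):
plots_tot = 14          # total number of plots
location_of_ydata = 6   # number of y values per plot
plots_window = 7        # plots per window
rows = 2                # subplots per column

# Round up so every plot in a window gets a slot
cols = math.ceil(plots_window / rows)

x_vector = np.array([0.000, 0.005, 0.010, 0.020, 0.030, 0.040])
# Placeholder data (assumption): the real y_vector is not given in the question
y_vector = np.arange(plots_tot * location_of_ydata, dtype=float)

figures = []
for start in range(0, plots_tot, plots_window):
    # squeeze=False guarantees a 2D array, ravel() makes it 1D
    fig, ax = plt.subplots(rows, cols, squeeze=False)
    flat_axes = ax.ravel()

    count = min(plots_window, plots_tot - start)  # plots in this window
    for slot in range(count):
        n = (start + slot) * location_of_ydata
        flat_axes[slot].plot(x_vector, y_vector[n:n + location_of_ydata], 'ro')

    # Hide the grid cells that have no data (here: 1 of the 8)
    for unused in flat_axes[count:]:
        unused.set_visible(False)

    figures.append(fig)

# With an interactive backend, plt.show() here would display both windows.
```

The original loop condition x <= plots_tot would also run one iteration too many; iterating over range(...) as above sidesteps that.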