How does the axis parameter from NumPy work?

18,225

Solution 1

Clearly,

e.shape == (3, 2, 2)

Summing over an axis is a reduction operation, so the specified axis disappears. Hence,

e.sum(axis=0).shape == (2, 2)
e.sum(axis=1).shape == (3, 2)
e.sum(axis=2).shape == (3, 2)

Intuitively, we are "squashing" the array along the chosen axis, and summing the numbers that get squashed together.
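Here is a minimal sketch that checks the shapes above. It rebuilds the array e from the question, and the loop variable ax is just an illustrative name. It also writes out one entry of two of the results as an explicit sum of the numbers that get squashed together.

```python
import numpy as np

# The 3x2x2 array from the question
e = np.array([[[1, 0],
               [0, 0]],
              [[1, 1],
               [1, 0]],
              [[1, 0],
               [0, 1]]])

print(e.shape)  # (3, 2, 2)
for ax in range(e.ndim):
    # Reducing over axis `ax` drops that entry from the shape tuple
    print(ax, e.sum(axis=ax).shape)

# axis=0: entry [i, j] adds e[0, i, j] + e[1, i, j] + e[2, i, j]
assert e.sum(axis=0)[0, 0] == e[0, 0, 0] + e[1, 0, 0] + e[2, 0, 0]
# axis=2: entry [k, i] adds e[k, i, 0] + e[k, i, 1]
assert e.sum(axis=2)[1, 0] == e[1, 0, 0] + e[1, 0, 1]
```

For this particular e, axis=1 and axis=2 happen to give the same values only because each 2x2 slice e[k] has matching row and column sums. With a less symmetric array, the values would differ, although the shapes would still both be (3, 2).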

Solution 2

To understand the axis intuitively, refer to the picture below (source: Physics Dept., Cornell University).


The (boolean) array in the figure above has shape (8, 3). ndarray.shape returns a tuple whose entries are the lengths of the corresponding dimensions. In this example, 8 is the length of axis 0 and 3 is the length of axis 1.
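The original figure's data is not available, so this sketch uses a hypothetical random boolean array of the same shape (8, 3). It shows how the lengths in the shape tuple map to the axis numbers, and how counting True values along each axis collapses that axis.

```python
import numpy as np

# Hypothetical boolean array with the same shape as the one in the figure
rng = np.random.default_rng(0)
mask = rng.random((8, 3)) > 0.5

print(mask.shape)              # (8, 3): axis 0 has length 8, axis 1 has length 3
print(mask.shape[0], mask.shape[1])

# Counting along axis 0 collapses the 8 rows: one count per column
print(mask.sum(axis=0).shape)  # (3,)
# Counting along axis 1 collapses the 3 columns: one count per row
print(mask.sum(axis=1).shape)  # (8,)
```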

Solution 3

If someone needs this visual description:

[Figure: numpy axis]

Solution 4

There are good answers for visualization; however, it might help to think about this purely from an analytical perspective.

You can create arrays of arbitrary dimension with NumPy. For example, here's a 5-dimensional array:

>>> import numpy as np
>>> a = np.random.rand(2, 3, 4, 5, 6)
>>> a.shape
(2, 3, 4, 5, 6)

You can access any element of this array by specifying indices. For example, here's the first element of this array:

>>> a[0, 0, 0, 0, 0]
0.0038908603263844155

Now if you replace one of the indices with a slice (:), you get all the elements along that dimension:

>>> a[0, 0, :, 0, 0]
array([0.00389086, 0.27394775, 0.26565889, 0.62125279])

When you apply a function like sum with the axis parameter, that dimension is eliminated and an array with one fewer dimension than the original is created. For each cell in the new array, the operation gathers the list of elements along the eliminated axis and applies the reduction function to get a scalar.

>>> np.sum(a, axis=2).shape
(2, 3, 5, 6)

Now you can check that the first element of this array is the sum of the elements above:

>>> np.sum(a, axis=2)[0, 0, 0, 0]
1.1647502999560164

>>> a[0, 0, :, 0, 0].sum()
1.1647502999560164
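To make the "for each cell in the new array" description concrete, here is a sketch that performs the reduction by hand with explicit loops and compares it with np.sum. The helper sum_along_axis is hypothetical and written only for illustration. It is not part of NumPy.

```python
import numpy as np

def sum_along_axis(a, axis):
    """Hypothetical helper: sum `a` over `axis` using explicit loops."""
    # The output shape is the input shape with `axis` removed
    out_shape = a.shape[:axis] + a.shape[axis + 1:]
    out = np.zeros(out_shape, dtype=a.dtype)
    for idx in np.ndindex(*out_shape):
        # Rebuild a full index with a slice at the reduced axis,
        # e.g. (i, j, :, k, l) for axis=2
        full_idx = idx[:axis] + (slice(None),) + idx[axis:]
        # Reduce the 1-D run of elements to a single scalar
        out[idx] = a[full_idx].sum()
    return out

a = np.random.rand(2, 3, 4, 5, 6)
manual = sum_along_axis(a, axis=2)
print(manual.shape)                            # (2, 3, 5, 6)
print(np.allclose(manual, np.sum(a, axis=2)))  # True
```

np.allclose is used instead of exact equality because np.sum may add the numbers in a different order, which can change the last floating-point digit.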

Passing axis=None has a special meaning: the array is flattened and the function is applied to all of its numbers.
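A short sketch of the axis=None case, which is also the default for np.sum. The names a and total are illustrative. The result is a single scalar, equal to the sum of the flattened array and to reducing over every axis listed explicitly.

```python
import numpy as np

a = np.random.rand(2, 3, 4, 5, 6)

# axis=None reduces over every axis at once
total = np.sum(a, axis=None)
print(np.ndim(total))                                      # 0: a single scalar
print(np.isclose(total, a.ravel().sum()))                  # same as the flattened sum
print(np.isclose(total, np.sum(a, axis=(0, 1, 2, 3, 4))))  # same as listing all axes
```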

Now you can think about more complex cases where axis is not just a number but a tuple:

>>> np.sum(a, axis=(2,3)).shape
(2, 3, 6)

We can use the same technique to check how this reduction was done:

>>> np.sum(a, axis=(2,3))[0,0,0]
7.889432081931909

>>> a[0, 0, :, :, 0].sum()
7.88943208193191
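One way to read a tuple axis is as several single-axis reductions applied one after another. This sketch checks that equivalence. The names both and stepwise are illustrative. The tiny difference in the last digit between the two printed values above is most likely the same effect: the additions happen in a different order.

```python
import numpy as np

a = np.random.rand(2, 3, 4, 5, 6)

# Reducing over a tuple of axes at once...
both = np.sum(a, axis=(2, 3))
# ...matches reducing one axis at a time. Sum the higher axis first so
# that axis 2 still refers to the original axis 2 afterwards.
stepwise = np.sum(np.sum(a, axis=3), axis=2)

print(both.shape, stepwise.shape)   # (2, 3, 6) (2, 3, 6)
# Compare with a tolerance, since the summation order differs
print(np.allclose(both, stepwise))  # True
```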

You can also use the same reasoning for adding a dimension to an array instead of removing one:

>>> x = np.random.rand(3, 4)
>>> y = np.random.rand(3, 4)

# New dimension is created on specified axis
>>> np.stack([x, y], axis=2).shape
(3, 4, 2)
>>> np.stack([x, y], axis=0).shape
(2, 3, 4)

# To retrieve item i of the stack, index i along that axis
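The retrieval rule can be checked directly. In this sketch, s2 and s0 are illustrative names for the two stacks built from x and y. np.take with a scalar index is an equivalent way to index along a given axis without writing out the slices by hand.

```python
import numpy as np

x = np.random.rand(3, 4)
y = np.random.rand(3, 4)

s2 = np.stack([x, y], axis=2)  # shape (3, 4, 2)
s0 = np.stack([x, y], axis=0)  # shape (2, 3, 4)

# Item i of the stack is recovered by indexing i along the new axis
print(np.array_equal(s2[:, :, 1], y))              # True
print(np.array_equal(s0[0], x))                    # True
# Equivalent, without writing the slices by hand
print(np.array_equal(np.take(s2, 0, axis=2), x))   # True
```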

I hope this gives you a general and complete understanding of this important parameter.

Solution 5

Some answers are too specific or do not address the main source of confusion. This answer attempts to provide a more general but simple explanation of the concept, with a simple example.

The main source of confusion is related to expressions such as "Axis along which the means are computed", which is how the documentation of the numpy.mean function describes its axis argument. What the heck does "along which" even mean here? "Along which" essentially means that, if the axis is 0, you sum the rows (and divide by the number of rows, since we are computing the mean), and if the axis is 1, you sum the columns. When the axis is 0 (or 1), the rows (or columns) being summed can be scalars, vectors, or even other multi-dimensional arrays.

In [1]: import numpy as np

In [2]: a=np.array([[1, 2], [3, 4]])

In [3]: a
Out[3]: 
array([[1, 2],
       [3, 4]])

In [4]: np.mean(a, axis=0)
Out[4]: array([2., 3.])

In [5]: np.mean(a, axis=1)
Out[5]: array([1.5, 3.5])

So, in the example above, np.mean(a, axis=0) returns array([2., 3.]) because (1 + 3)/2 = 2 and (2 + 4)/2 = 3. It returns an array of two numbers because it returns the mean of the rows for each column (and there are two columns).
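The same arithmetic can be written out by hand and compared with np.mean. This is a small sketch in which rows_mean and cols_mean are illustrative names. With axis=0, the rows are added element-wise and divided by the number of rows. With axis=1, the columns are added and divided by the number of columns.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])

# axis=0: add the rows element-wise, giving one value per column
rows_mean = (a[0] + a[1]) / a.shape[0]
print(rows_mean)   # [2. 3.]
# axis=1: add the columns element-wise, giving one value per row
cols_mean = (a[:, 0] + a[:, 1]) / a.shape[1]
print(cols_mean)   # [1.5 3.5]

print(np.array_equal(rows_mean, np.mean(a, axis=0)))  # True
print(np.array_equal(cols_mean, np.mean(a, axis=1)))  # True
```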

Author: CodyBugstein

Aspiring computer nerd.

Updated on June 06, 2022

Comments

  • CodyBugstein (almost 2 years)

    Can someone explain exactly what the axis parameter in NumPy does?

    I am terribly confused.

    I'm trying to use the function myArray.sum(axis=num)

    At first I thought that if the array itself has 3 dimensions, axis=0 would return three elements, each consisting of the sum of all nested items in that same position. If each dimension contained five dimensions, I expected axis=1 to return a result of five items, and so on.

    However, this is not the case, and the documentation does not do a good job of helping me out (it uses a 3x3x3 array, so it's hard to tell what's happening).

    Here's what I did:

    >>> e
    array([[[1, 0],
            [0, 0]],
    
           [[1, 1],
            [1, 0]],
    
           [[1, 0],
            [0, 1]]])
    >>> e.sum(axis = 0)
    array([[3, 1],
           [1, 1]])
    >>> e.sum(axis=1)
    array([[1, 0],
           [2, 1],
           [1, 1]])
    >>> e.sum(axis=2)
    array([[1, 0],
           [2, 1],
           [1, 1]])
    >>>
    

    Clearly the result is not intuitive.

  • CodyBugstein (about 10 years)
    I don't understand - why does the specified axis disappear?
  • Martin (about 10 years)
    If you have a vector and sum over its elements, you get a number: from a 1-dimensional array you get a scalar. That is how the axis disappears.
  • Benjamin Crouzier (over 6 years)
    @CodyBugstein The axis disappears because sum is an aggregation operation, meaning that when you sum n elements, you end up with just 1. So after the sum, all elements along that axis are collapsed into a single element. Imagine that you are crumpling a rectangular sheet of paper into a horizontal line (a kind of cigar). If you drew a grid of numbers on your sheet, all the vertical elements end up together; only the horizontal dimension remains.
  • debaonline4u (over 5 years)
    You can also follow this question: stackoverflow.com/questions/22149584/…
  • blitu12345 (about 5 years)
    Hey Martin! Thanks for the answer! I was really confused about how the axis parameter works, but this precise answer in just 5 lines makes everything crystal clear. Now I've got a dimension to disappear!
  • off99555 (almost 4 years)
    This is a better understanding than just making an axis disappear. An axis disappearing is a result, not the cause. If you expect the axis to disappear, you will be confused when you try to use operations like sort.
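The last comment's point can be seen directly on the array e from the question. The axis argument selects which index varies, and only reductions remove that axis. In this sketch, keepdims=True (a standard argument of np.sum) keeps the reduced axis with length 1, and np.sort works along an axis without removing it.

```python
import numpy as np

e = np.array([[[1, 0], [0, 0]],
              [[1, 1], [1, 0]],
              [[1, 0], [0, 1]]])

# A reduction removes the axis...
print(e.sum(axis=1).shape)                 # (3, 2)
# ...unless keepdims=True, which leaves it in place with length 1
print(e.sum(axis=1, keepdims=True).shape)  # (3, 1, 2)

# sort also works "along" axis 1, but it is not a reduction,
# so every axis keeps its length
print(np.sort(e, axis=1).shape)            # (3, 2, 2)
```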