PyTorch torch.max over multiple dimensions

11,243

Solution 1

Now, you can do this. The PR was merged (Aug 28) and it is now available in the nightly release.

Simply use torch.amax():

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(torch.amax(x, dim=(1, 2)))

# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])

Original Answer

As of today (April 11, 2020), there is no way to do .min() or .max() over multiple dimensions in PyTorch. There is an open issue about it that you can follow and see if it ever gets implemented. A workaround in your case would be:

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(x.view(x.size(0), -1).max(dim=-1))

# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))

So, if you need only the values: x.view(x.size(0), -1).max(dim=-1).values.

If x is not a contiguous tensor, then .view() will fail. In this case, you should use .reshape() instead.


Update August 26, 2020

This feature is being implemented in PR#43092 and the functions will be called amin and amax. They will return only the values. This is probably being merged soon, so you might be able to access these functions on the nightly build by the time you're reading this :) Have fun.

Solution 2

Although the solution of Berriel solves this specific question, I thought adding some explanation might help everyone to shed some light on the trick that's employed here, so that it can be adapted for (m)any other dimensions.

Let's start by inspecting the shape of the input tensor x:

In [58]: x.shape   
Out[58]: torch.Size([3, 2, 2])

So, we have a 3D tensor of shape (3, 2, 2). Now, as per OP's question, we need to compute maximum of the values in the tensor along both 1st and 2nd dimensions. As of this writing, the torch.max()'s dim argument supports only int. So, we can't use a tuple. Hence, we will use the following trick, which I will call as,

The Flatten & Max Trick: since we want to compute max over both 1st and 2nd dimensions, we will flatten both of these dimensions to a single dimension and leave the 0th dimension untouched. This is exactly what is happening by doing:

In [61]: x.flatten().reshape(x.shape[0], -1).shape   
Out[61]: torch.Size([3, 4])   # 2*2 = 4

So, now we have shrinked the 3D tensor to a 2D tensor (i.e. matrix).

In [62]: x.flatten().reshape(x.shape[0], -1) 
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
        [-0.1821, -0.1747, -0.1526, -0.1453],
        [-0.0642, -0.0568, -0.0347, -0.0274]])

Now, we can simply apply max over the 1st dimension (i.e. in this case, first dimension is also the last dimension), since the flattened dimensions resides in that dimension.

In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1)    # or: `dim = -1`
Out[65]: 
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))

We got 3 values in the resultant tensor since we had 3 rows in the matrix.


Now, on the other hand if you want to compute max over 0th and 1st dimensions, you'd do:

In [80]: x.flatten().reshape(-1, x.shape[-1]).shape 
Out[80]: torch.Size([6, 2])    # 3*2 = 6

In [79]: x.flatten().reshape(-1, x.shape[-1]) 
Out[79]: 
tensor([[-0.3000, -0.2926],
        [-0.2705, -0.2632],
        [-0.1821, -0.1747],
        [-0.1526, -0.1453],
        [-0.0642, -0.0568],
        [-0.0347, -0.0274]])

Now, we can simply apply max over the 0th dimension since that is the result of our flattening. ((also, from our original shape of (3, 2, 2), after taking max over first 2 dimensions, we should get two values as result.)

In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0) 
Out[82]: 
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))

In a similar vein, you can adapt this approach to multiple dimensions and other reduction functions such as min.


Note: I'm following the terminology of 0-based dimensions (0, 1, 2, 3, ...) just to be consistent with PyTorch usage and the code.

Solution 3

If you only want to use the torch.max() function to get the indices of the max entry in a 2D tensor, you can do:

max_i_vals, max_i_indices = torch.max(x, 0)
print('max_i_vals, max_i_indices: ', max_i_vals, max_i_indices)
max_j_index = torch.max(max_i_vals, 0)[1]
print('max_j_index: ', max_j_index)
max_index = [max_i_indices[max_j_index], max_j_index]
print('max_index: ', max_index)

In testing, the above printed out for me:

max_i_vals: tensor([0.7930, 0.7144, 0.6985, 0.7349, 0.9162, 0.5584, 1.4777, 0.8047, 0.9008, 1.0169, 0.6705, 0.9034, 1.1159, 0.8852, 1.0353], grad_fn=\<MaxBackward0>)   
max_i_indices: tensor([ 5,  8, 10,  6, 13, 14,  5,  6,  6,  6, 13,  4, 13, 13, 11])  
max_j_index:  tensor(6)  
max_index:  [tensor(5), tensor(6)]

This approach can be extended for 3 dimensions. While not as visually pleasing as other answers in this post, this answer shows that the problem can be solved using only the torch.max() function (though I do agree built-in support for torch.max() over multiple dimensions would be a boon).

FOLLOW UP
I stumbled upon a similar question in the PyTorch forums and the poster ptrblck offered this line of code as a solution for getting the indices of the maximal entry in the tensor x:

x = (x==torch.max(x)).nonzero()

Not only does this one-liner work with N-dimensional tensors without needing adjustments to the code, but it is also much faster than the approach I wrote of above (at least 2:1 ratio) and faster than the accepted answer (about 3:2 ratio) according to my benchmarks.

Share:
11,243
iGero
Author by

iGero

Updated on June 05, 2022

Comments

  • iGero
    iGero almost 2 years

    Have tensor like :x.shape = [3, 2, 2].

    import torch
    
    x = torch.tensor([
        [[-0.3000, -0.2926],[-0.2705, -0.2632]],
        [[-0.1821, -0.1747],[-0.1526, -0.1453]],
        [[-0.0642, -0.0568],[-0.0347, -0.0274]]
    ])
    

    I need to take .max() over the 2nd and 3rd dimensions. I expect some like this [-0.2632, -0.1453, -0.0274] as output. I tried to use: x.max(dim=(1,2)), but this causes an error.

  • iGero
    iGero about 4 years
    thanks. it works, but need use reshape insted view for avoid error in my case
  • Berriel
    Berriel about 4 years
    @iGero ok, I'll add this note on the answer just in case :) glad it helped
  • San Askaruly
    San Askaruly over 3 years
    oh, getting clear a bit. could you please specify what is "result of flattening"? i would appreciate, thanks!
  • kmario23
    kmario23 over 3 years
    Flattening always returns an 1D tensor of size resulting from the multiplication of the individual dimensions in the original shape (i.e., 3*2*2 here with tensor x)
  • zwep
    zwep over 3 years
    I tried this with pytorch version 1.5.0 and 1.6.0, but there was no method torch.amax. Can you validate that? Or am I doing something wrong?
  • Berriel
    Berriel over 3 years
    @zwep as I said in the answer, this function is currently available in the nightly release. Therefore, you have to upgrade to it if you want to have access to amax, or wait until the next stable release, i.e., 1.7.0.
  • zwep
    zwep over 3 years
    @Berriel ah sorry, I did not know which version was related to the nightly release. Although I dont know if you can talk of a version in such a case
  • Berriel
    Berriel over 3 years
    @zwep don't worry. Well, it is called nightly build/release :) you can find instructions on how to install in the official PyTorch website. It is straightforward. Let me know if you have any problem.