Is there a PyTorch function that can combine specific contiguous dimensions of a tensor into one?


Solution 1

I am not sure what you have in mind with "a more elegant way", but Tensor.view() has the advantage of not re-allocating data for the view (the original tensor and the view share the same data), making this operation quite lightweight.

As mentioned by @UmangGupta, it is however rather straightforward to wrap this function to achieve what you want, e.g.:

import torch

def magic_combine(x, dim_begin, dim_end):
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(combined_shape)

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimensions 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])
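As an aside (not part of the original answer): PyTorch's built-in torch.flatten covers this exact use case, collapsing the inclusive range from start_dim to end_dim into a single dimension. Note the off-by-one difference from magic_combine above: flatten's end_dim is inclusive, while magic_combine's dim_end is exclusive.

```python
import torch

a = torch.zeros(1, 2, 3, 4, 5, 6)
# collapse dimensions 2 through 4 (inclusive) into one dimension
b = torch.flatten(a, start_dim=2, end_dim=4)
print(b.size())
# torch.Size([1, 2, 60, 6])
```

Like view(), flatten returns a view sharing the same data when the memory layout allows it, so it is equally lightweight.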

Solution 2

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(*a.shape[:2], -1, *a.shape[5:])

This seems to me a bit simpler than the currently accepted answer, and it doesn't go through a list constructor (three times).
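One caveat worth adding (an assumption beyond what the answer states, but standard PyTorch behavior): view() requires a compatible memory layout, so after operations like transpose() it can raise a RuntimeError. In that case reshape(), which copies only when necessary, still works:

```python
import torch

# transposing makes the tensor non-contiguous in dims 2..4
a = torch.zeros(1, 2, 3, 4, 5, 6).transpose(2, 4)
try:
    b = a.view(*a.shape[:2], -1, *a.shape[5:])
except RuntimeError:
    # view() cannot merge these strided dimensions; reshape() copies if needed
    b = a.reshape(*a.shape[:2], -1, *a.shape[5:])
print(b.size())
# torch.Size([1, 2, 60, 6])
```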

Solution 3

This is also possible with the einops library.


> pip install einops

import torch
from einops import rearrange

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = rearrange(a, 'd0 d1 d2 d3 d4 d5 -> d0 d1 (d2 d3 d4) d5')
print(b.size())
# torch.Size([1, 2, 60, 6])
Author by gasoon

Updated on June 05, 2022

Comments

  • gasoon almost 2 years

    Let's call the function I'm looking for "magic_combine", which can combine contiguous dimensions of the tensor I give it. To be more specific, I want it to do the following:

    a = torch.zeros(1, 2, 3, 4, 5, 6)  
    b = a.magic_combine(2, 5) # combine dimensions 2, 3, 4 
    print(b.size()) # should be (1, 2, 60, 6)
    

    I know that Tensor.view() can do something similar. But I'm just wondering if there is a more elegant way to achieve the goal?