Is there a PyTorch function that can combine specific consecutive dimensions of a tensor into one?
Solution 1
I am not sure what you have in mind with "a more elegant way", but Tensor.view() has the advantage of not re-allocating data for the view (the original tensor and the view share the same data), which makes this operation quite lightweight.
As mentioned by @UmangGupta, it is, however, rather straightforward to wrap this function to achieve what you want, e.g.:
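import torch

def magic_combine(x, dim_begin, dim_end):
    # Keep the dims before dim_begin and from dim_end on; merge [dim_begin, dim_end) into one.
    # The -1 lets view() infer the size of the merged dimension.
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(combined_shape)

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5)  # combine dimensions 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])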
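Here is a minimal sketch of the memory behaviour described above, assuming a recent PyTorch. The result of view() shares storage with the original tensor. When the strides cannot be merged (for example after a transpose), view() raises a RuntimeError. In that case Tensor.reshape() is the usual fallback, because it copies only when it has to. The variable names are illustrative only.

```python
import torch

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(1, 2, -1, 6)  # same shape as magic_combine(a, 2, 5)

# view() reuses a's storage, so no data is copied ...
assert b.data_ptr() == a.data_ptr()
# ... and writes through b are visible in a
b[0, 0, 0, 0] = 1.0
assert a[0, 0, 0, 0, 0, 0].item() == 1.0

# Swapping dims 2 and 3 gives a non-contiguous tensor of shape (1, 2, 4, 3, 5, 6)
t = a.transpose(2, 3)
try:
    t.view(1, 2, -1, 6)  # dims 2..4 are no longer adjacent in memory
    view_failed = False
except RuntimeError:
    view_failed = True

# reshape() returns a view when possible and silently copies otherwise
r = t.reshape(1, 2, -1, 6)
print(view_failed, r.shape)  # True torch.Size([1, 2, 60, 6])
```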
Solution 2
import torch
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(*a.shape[:2], -1, *a.shape[5:])  # torch.Size([1, 2, 60, 6])
This seems a bit simpler to me than the currently accepted answer, and it doesn't call the list constructor three times.
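The same unpacking trick works for any dimension range. The sketch below uses the same convention as magic_combine above, where dim_end is exclusive. The name combine_dims is hypothetical. The trick works because slicing a torch.Size returns another torch.Size (a tuple subclass), which can be unpacked with *.

```python
import torch

def combine_dims(x, dim_begin, dim_end):
    # Hypothetical helper: merge dims [dim_begin, dim_end) into one;
    # -1 lets view() infer the size of the merged dimension.
    return x.view(*x.shape[:dim_begin], -1, *x.shape[dim_end:])

a = torch.zeros(1, 2, 3, 4, 5, 6)
print(combine_dims(a, 2, 5).shape)  # torch.Size([1, 2, 60, 6])
print(combine_dims(a, 0, 3).shape)  # torch.Size([6, 4, 5, 6])
```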
Solution 3
This is also possible with einops.
> pip install einops
import torch
from einops import rearrange
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = rearrange(a, 'd0 d1 d2 d3 d4 d5 -> d0 d1 (d2 d3 d4) d5')  # torch.Size([1, 2, 60, 6])
gasoon
Updated on June 05, 2022
Let's call the function I'm looking for "magic_combine". It should combine the consecutive dimensions of the tensor I give to it. More specifically, I want it to do the following:
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimensions 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
I know that Tensor.view() can do something similar, but I'm wondering whether there is a more elegant way to achieve this.
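PyTorch also has a built-in function for exactly this: torch.flatten(input, start_dim, end_dim), which is also available as the Tensor.flatten() method. Its end_dim is inclusive, so the call that matches magic_combine(a, 2, 5) is a.flatten(2, 4). According to the PyTorch documentation, it may return the original object, a view, or a copy, depending on the memory layout. The sketch below assumes a contiguous input.

```python
import torch

a = torch.zeros(1, 2, 3, 4, 5, 6)

# Merge dims 2, 3 and 4 (end_dim is inclusive)
b = a.flatten(2, 4)
print(b.shape)  # torch.Size([1, 2, 60, 6])

# Negative indices work too: dim 2 through the second-to-last dim
assert torch.flatten(a, 2, -2).shape == b.shape

# On a contiguous tensor the result is a view of the same storage
assert b.data_ptr() == a.data_ptr()
```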