How do PyTorch's "Fold" and "Unfold" work?
Solution 1
unfold and fold are used to facilitate "sliding window" operations (like convolutions). Suppose you want to apply a function foo to every 5x5 window in a feature map/image:
from torch.nn import functional as f
windows = f.unfold(x, kernel_size=5)
Now windows has size (batch, 5*5*x.size(1), num_windows), and you can apply foo to windows:
processed = foo(windows)
Now you need to "fold" processed back to the original size of x:
out = f.fold(processed, x.shape[-2:], kernel_size=5)
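You need to take care with padding and kernel_size (and stride or dilation, if you set them), since they may affect your ability to "fold" processed back to the size of x.
Moreover, fold sums over overlapping elements, so you might want to divide the output of fold by the number of windows that cover each element. That count equals the patch size only for elements far enough from the border when the stride is 1; elements near the border are covered by fewer windows.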
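A minimal sketch of that normalization, assuming foo is the identity (so windows is folded back unchanged) and an arbitrary 8x8 input chosen for illustration. The per-element overlap count is obtained by folding a tensor of ones with the same settings:

```python
import torch
import torch.nn.functional as F

# Illustrative input: batch of 2, 3 channels, 8x8 (sizes are assumptions).
x = torch.randn(2, 3, 8, 8)

# Extract every 5x5 window: shape (batch, 3*5*5, num_windows) = (2, 75, 16).
windows = F.unfold(x, kernel_size=5)

# Stand-in for foo(windows); here foo is assumed to be the identity.
processed = windows

# fold sums the contributions of all windows that overlap an element.
summed = F.fold(processed, x.shape[-2:], kernel_size=5)

# Folding ones counts how many windows cover each element.
counts = F.fold(torch.ones_like(processed), x.shape[-2:], kernel_size=5)

# Dividing by the counts averages the overlaps and recovers x here.
out = summed / counts
```

With a foo that actually changes the windows, the same division yields the average of the overlapping predictions for each element rather than the original input.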
Solution 2
unfold imagines a tensor as a longer tensor with repeated columns/rows of values 'folded' on top of each other, which is then "unfolded":
- size determines how large the folds are
- step determines how often it is folded
E.g. for a 2x5 tensor, unfolding it with step=1 and patch size=2 across dim=1:
>>> import torch
>>> x = torch.tensor([[1,2,3,4,5],
...                   [6,7,8,9,10]])
>>> x.unfold(1,2,1)
tensor([[[ 1, 2], [ 2, 3], [ 3, 4], [ 4, 5]],
[[ 6, 7], [ 7, 8], [ 8, 9], [ 9, 10]]])
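fold is roughly the opposite of this operation, but "overlapping" values are summed in the output. Note that there is no Tensor.fold method: torch.nn.functional.fold (and nn.Fold) inverts the column layout produced by torch.nn.functional.unfold, not the layout produced by Tensor.unfold.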
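To make the summing concrete, here is a small sketch that reuses the 2x5 values from the example above (as floats, reshaped to the 4-D layout that F.unfold expects) and folds the 1x2 windows straight back. The kernel shape (1, 2) is chosen to mirror size=2 along dim=1 and is an assumption of this illustration:

```python
import torch
import torch.nn.functional as F

# Same values as above, as a float tensor of shape (N=1, C=1, H=2, W=5).
x = torch.tensor([[1., 2., 3., 4., 5.],
                  [6., 7., 8., 9., 10.]]).reshape(1, 1, 2, 5)

# 1x2 windows with stride 1: shape (1, C*1*2, L) = (1, 2, 8).
cols = F.unfold(x, kernel_size=(1, 2))

# Folding back sums the windows that overlap each element.
folded = F.fold(cols, output_size=(2, 5), kernel_size=(1, 2))

print(folded[0, 0])
# tensor([[ 1.,  4.,  6.,  8.,  5.],
#         [ 6., 14., 16., 18., 10.]])
```

The first and last columns are covered by one window each and come back unchanged; the three middle columns are covered by two windows and come back doubled.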
Solution 3
One-dimensional unfolding is easy:
import torch
x = torch.arange(1, 9).float()
print(x)
# dimension, size, step
print(x.unfold(0, 2, 1))
print(x.unfold(0, 3, 2))
Out:
tensor([1., 2., 3., 4., 5., 6., 7., 8.])
tensor([[1., 2.],
[2., 3.],
[3., 4.],
[4., 5.],
[5., 6.],
[6., 7.],
[7., 8.]])
tensor([[1., 2., 3.],
[3., 4., 5.],
[5., 6., 7.]])
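The number of windows in the outputs above follows from the length n, the window size, and the step as (n - size) // step + 1. A short check of that formula, where num_windows is a hypothetical helper written for this illustration and not part of PyTorch:

```python
import torch

def num_windows(n: int, size: int, step: int) -> int:
    """Count the windows Tensor.unfold yields along a dimension of length n."""
    return (n - size) // step + 1

x = torch.arange(1, 9).float()  # 8 elements, as in the example above

a = x.unfold(0, 2, 1)  # size 2, step 1
b = x.unfold(0, 3, 2)  # size 3, step 2

print(a.shape, num_windows(8, 2, 1))  # 7 windows of length 2
print(b.shape, num_windows(8, 3, 2))  # 3 windows of length 3
```

Trailing elements that do not fill a complete window are dropped, which is why the value 8 does not appear in the second result.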
Two-dimensional unfolding (also called patching):
import torch
patch=(3,3)
x=torch.arange(16).float()
print(x, x.shape)
x2d = x.reshape(1,1,4,4)
print(x2d, x2d.shape)
h,w = patch
c=x2d.size(1)
print(c) # channels
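# Tensor.unfold(dimension, size, step); after both unfolds the shape is (N, C, nH, nW, h, w)
# transpose(1, 3) keeps (C, h, w) together, so the patches come out column by column
r = x2d.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1, c, h, w)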
print(r.shape)
print(r) # result
Out:
tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15.]) torch.Size([16])
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]]]) torch.Size([1, 1, 4, 4])
1
torch.Size([4, 1, 3, 3])
tensor([[[[ 0., 1., 2.],
[ 4., 5., 6.],
[ 8., 9., 10.]]],
[[[ 4., 5., 6.],
[ 8., 9., 10.],
[12., 13., 14.]]],
[[[ 1., 2., 3.],
[ 5., 6., 7.],
[ 9., 10., 11.]]],
[[[ 5., 6., 7.],
[ 9., 10., 11.],
[13., 14., 15.]]]])
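The patches above are listed column by column (the patch starting at 4 comes before the one starting at 1). If row-by-row order is preferred, one possible variant, offered as a sketch and not as the answer's own method, swaps the transpose for a permute. The result can be cross-checked against torch.nn.functional.unfold, which enumerates windows in row-major order:

```python
import torch
import torch.nn.functional as F

x2d = torch.arange(16).float().reshape(1, 1, 4, 4)
h, w = 3, 3
c = x2d.size(1)

patches = (
    x2d.unfold(2, h, 1).unfold(3, w, 1)  # (N, C, nH, nW, h, w)
    .permute(0, 2, 3, 1, 4, 5)           # (N, nH, nW, C, h, w)
    .reshape(-1, c, h, w)                # (N*nH*nW, C, h, w), row-major patches
)

# F.unfold returns (N, C*h*w, L); move L forward and restore (C, h, w).
# This reshape assumes a batch size of 1, as in the example.
cols = F.unfold(x2d, kernel_size=(h, w))
reference = cols.transpose(1, 2).reshape(-1, c, h, w)

print(torch.equal(patches, reference))
print(patches[1, 0])  # second patch now starts at column 1 of the first row
```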
shoshi
Updated on July 24, 2022
Comments
- shoshi, almost 2 years ago
I've gone through the official doc. I'm having a hard time understanding what this function is used for and how it works. Can someone explain this in layman's terms?
I get an error for the example they provide, although the PyTorch version I'm using matches the documentation. Perhaps fixing the error, which I did, is supposed to teach me something? The snippet given in the documentation is:
fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
input = torch.randn(1, 3 * 2 * 2, 1)
output = fold(input)
output.size()
and the fixed snippet is:
fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
input = torch.randn(1, 3 * 2 * 2, 3 * 2 * 2)
output = fold(input)
output.size()
Thanks!
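A likely reason the fixed snippet works: nn.Fold expects input of shape (N, C * kernel_height * kernel_width, L), where L is the number of sliding blocks that fit in output_size. With stride 1 and no padding or dilation, a 2x2 kernel fits into a 4x5 output at (4 - 2 + 1) * (5 - 2 + 1) = 12 positions, so the last dimension has to be 12. In the fixed snippet 3 * 2 * 2 happens to equal 12 as well. The sketch below computes L explicitly; the variable names are chosen for illustration only:

```python
import torch
from torch import nn

output_size = (4, 5)
kernel_size = (2, 2)
channels = 3

# Number of sliding blocks with stride 1, no padding, no dilation.
L = (output_size[0] - kernel_size[0] + 1) * (output_size[1] - kernel_size[1] + 1)

fold = nn.Fold(output_size=output_size, kernel_size=kernel_size)

# Input layout: (N, C * kernel_height * kernel_width, L).
inp = torch.randn(1, channels * kernel_size[0] * kernel_size[1], L)
out = fold(inp)

print(L, out.size())  # L is 12 and out has shape (1, 3, 4, 5)
```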