How does PyTorch DataLoader handle variable-size data?
Solution 1
So how do you handle the fact that your samples are of different lengths? torch.utils.data.DataLoader has a collate_fn parameter, which is used to transform a list of samples into a batch. By default it does this to lists position by position (the first elements of every sample are grouped together, then the second, and so on) rather than padding them. You can write your own collate_fn, which for instance zero-pads the input, truncates it to some predefined length, or applies any other operation of your choice.
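As a minimal sketch of such a function, the following pads or truncates every sample to a fixed length. The name pad_or_truncate_collate and the max_len and pad_value parameters are made up for illustration, and the samples are assumed to be 1D integer sequences.

```python
import torch


def pad_or_truncate_collate(batch, max_len=8, pad_value=0):
    """Turn a list of 1D integer sequences into one (batch, max_len) tensor."""
    out = torch.full((len(batch), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(batch):
        # Truncate anything longer than max_len.
        seq = torch.as_tensor(seq, dtype=torch.long)[:max_len]
        # Copy the real values; the remaining positions keep pad_value.
        out[i, : seq.shape[0]] = seq
    return out


samples = [[1, 2, 3], [4, 5], list(range(1, 12))]
padded = pad_or_truncate_collate(samples, max_len=5)
print(padded)
```

To use a non-default max_len with a DataLoader, the function could be wrapped with functools.partial before being passed as collate_fn.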
Solution 2
This is the way I do it:
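import torch

# `device` was not defined in the original snippet; this is one common choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_fn_padd(batch):
    '''
    Pads a batch of variable-length sequences.
    Note: samples are converted to tensors manually here, since the ToTensor
    transform assumes it takes in images rather than arbitrary tensors.
    '''
    # get sequence lengths
    lengths = torch.tensor([t.shape[0] for t in batch]).to(device)
    # pad to the longest sequence; result has shape (max_len, batch, ...)
    batch = [torch.Tensor(t).to(device) for t in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    # compute mask (note: genuine zeros in the data are also marked as padding)
    mask = (batch != 0).to(device)
    return batch, lengths, mask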
Then I pass that to the DataLoader class as the collate_fn.
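Passing a padding collate function to the loader might look like the sketch below. This variant is an assumption on my part rather than the code above: it keeps collation on the CPU, builds the mask from the lengths, and moves the batch to the device inside the loop. The PyTorch documentation generally advises against returning CUDA tensors from worker processes and recommends pin_memory=True instead. The names collate_on_cpu and toy_data are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def collate_on_cpu(batch):
    """Pad on the CPU and build the mask from the lengths."""
    tensors = [torch.as_tensor(t, dtype=torch.float32) for t in batch]
    lengths = torch.tensor([t.shape[0] for t in tensors])
    padded = pad_sequence(tensors, batch_first=True)  # (batch, max_len)
    # True where a position holds real data. Unlike `padded != 0`,
    # this does not mistake genuine zeros in the data for padding.
    mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]
    return padded, lengths, mask


# Toy dataset (hypothetical): a plain list of variable-length sequences.
toy_data = [[1.0, 0.0, 3.0], [4.0], [5.0, 6.0], [7.0, 8.0, 9.0, 10.0]]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = DataLoader(toy_data, batch_size=2, collate_fn=collate_on_cpu)

shapes = []
for padded, lengths, mask in loader:
    # Move to the device here, in the main process, not inside collate_fn.
    padded, mask = padded.to(device), mask.to(device)
    shapes.append(tuple(padded.shape))
print(shapes)
```

With num_workers greater than 0, the loop would need to sit under an `if __name__ == "__main__":` guard on platforms that spawn worker processes.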
There seems to be a long list of posts about this on the PyTorch forum. Let me link to all of them. Each has its own answers and discussion. It doesn't seem to me that there is one "standard way to do it", but if there is one from an authoritative reference, please share.
It would be nice if the ideal answer mentioned:
- efficiency, e.g. whether to do the processing on the GPU with torch inside the collate function or on the CPU with numpy,
and things of that sort.
List:
- https://discuss.pytorch.org/t/how-to-create-batches-of-a-list-of-varying-dimension-tensors/50773
- https://discuss.pytorch.org/t/how-to-create-a-dataloader-with-variable-size-input/8278
- https://discuss.pytorch.org/t/using-variable-sized-input-is-padding-required/18131
- https://discuss.pytorch.org/t/dataloader-for-various-length-of-data/6418
- https://discuss.pytorch.org/t/how-to-do-padding-based-on-lengths/24442
Bucketing:
- https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284
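For the bucketing idea mentioned in the last link, a rough sketch in plain Python could group sample indices by length, so that each batch needs little padding. The function bucket_batches and its parameters are hypothetical, not a PyTorch API.

```python
import random


def bucket_batches(lengths, batch_size, shuffle=True, seed=0):
    """Group sample indices of similar length into batches."""
    # Sort indices by sequence length so neighbours have similar sizes.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle:
        # Shuffle whole batches so training does not always go short to long.
        random.Random(seed).shuffle(batches)
    return batches


# Lengths of twelve hypothetical samples.
lengths = [3, 2, 3, 2, 7, 6, 8, 5, 2, 2, 3, 2]
batches = bucket_batches(lengths, batch_size=4, shuffle=False)
print(batches)
```

The resulting list of index lists can be passed to DataLoader through its batch_sampler argument, together with a padding collate_fn.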
Solution 3
As @Jatentaki suggested, I wrote my custom collate function and it worked fine.
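import collections.abc

import torch
from torch.utils.data import DataLoader

def get_max_length(x):
    # length of the longest sequence in the batch
    return len(max(x, key=len))

def pad_sequence(seq):
    # left-pad every sequence with zeros up to the longest length
    max_len = get_max_length(seq)

    def _pad(_it, _max_len):
        return [0] * (_max_len - len(_it)) + list(_it)

    return [_pad(it, max_len) for it in seq]

def custom_collate(batch):
    # regroup [(uid, stream), ...] into (uids, streams)
    transposed = zip(*batch)
    lst = []
    for samples in transposed:
        if isinstance(samples[0], int):
            lst.append(torch.LongTensor(samples))
        elif isinstance(samples[0], float):
            lst.append(torch.DoubleTensor(samples))
        # collections.Sequence was removed in Python 3.10; use collections.abc
        elif isinstance(samples[0], collections.abc.Sequence):
            lst.append(torch.LongTensor(pad_sequence(samples)))
    return lst

# StreamDataset, data_path and batch_size are defined elsewhere
stream_dataset = StreamDataset(data_path)
stream_data_loader = DataLoader(dataset=stream_dataset,
                                batch_size=batch_size,
                                collate_fn=custom_collate,
                                shuffle=False)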
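A shorter variant, sketched here as an alternative rather than as the code above, leans on torch.nn.utils.rnn.pad_sequence. Note that it pads on the right, whereas the hand-written helper pads on the left. The function name collate_uid_stream is hypothetical; the sample values are taken from the question's data.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_uid_stream(batch):
    """Collate (uid, stream) pairs; streams are right-padded with zeros."""
    uids, streams = zip(*batch)
    uid_batch = torch.tensor(uids, dtype=torch.long)
    stream_batch = pad_sequence(
        [torch.tensor(s, dtype=torch.long) for s in streams],
        batch_first=True,  # result has shape (batch, max_len)
        padding_value=0,
    )
    return uid_batch, stream_batch


batch = [(0, [24104, 27359, 6684]), (0, [24104, 27359]), (4, [19577, 21608])]
uid_batch, stream_batch = collate_uid_stream(batch)
print(uid_batch)
print(stream_batch)
```

Zero is used as the padding value on the assumption that no real item has id 0; if one does, a different padding_value would be needed.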
Trung Le
Updated on June 25, 2022
Comments
-
Trung Le almost 2 years
I have a dataset that looks like the one below. The first item in each row is the user id, followed by the sequence of items clicked by that user.
0 24104 27359 6684
0 24104 27359
1 16742 31529 31485
1 16742 31529
2 6579 19316 13091 7181 6579 19316 13091
2 6579 19316 13091 7181 6579 19316
2 6579 19316 13091 7181 6579 19316 13091 6579
2 6579 19316 13091 7181 6579
4 19577 21608
4 19577 21608
4 19577 21608 18373
5 3541 9529
5 3541 9529
6 6832 19218 14144
6 6832 19218
7 9751 23424 25067 12606 26245 23083 12606
I define a custom dataset to handle my click log data.
import torch.utils.data as data

class ClickLogDataset(data.Dataset):
    def __init__(self, data_path):
        self.data_path = data_path
        self.uids = []
        self.streams = []

        with open(self.data_path, 'r') as fdata:
            for row in fdata:
                row = row.strip('\n').split('\t')
                self.uids.append(int(row[0]))
                self.streams.append(list(map(int, row[1:])))

    def __len__(self):
        return len(self.uids)

    def __getitem__(self, idx):
        uid, stream = self.uids[idx], self.streams[idx]
        return uid, stream
Then I use a DataLoader to retrieve mini batches from the data for training.
from torch.utils.data.dataloader import DataLoader

clicklog_dataset = ClickLogDataset(data_path)
clicklog_data_loader = DataLoader(dataset=clicklog_dataset, batch_size=16)

for uid_batch, stream_batch in clicklog_data_loader:
    print(uid_batch)
    print(stream_batch)
The code above returns something different from what I expected. I want stream_batch to be a 2D integer tensor of length 16. However, what I get is a list of 1D tensors of length 16, and the list has only one element, like below. Why is that?
# stream_batch
[tensor([24104, 24104, 16742, 16742, 6579, 6579, 6579, 6579, 19577, 19577, 19577, 3541, 3541, 6832, 6832, 9751])]
-
Charlie Parker almost 5 years
Cross-posted: quora.com/unanswered/…
-
Black Jack 21 about 4 years
What if I do not want to pad with extra numbers? I mean, what if I have a fully convolutional neural network, so I do not need same-sized inputs, and in particular I do not want to change the input by padding it (I am doing an explainable AI experiment)?
-
Jatentaki about 4 years
@RedFloyd it's all fine, except you will need to make some adaptations and will lose some performance. In PyTorch (and roughly every other framework) CNN operations such as Conv2d are executed in a "vectorized" fashion over the 1st dimension (usually called the batch dimension). In your case, you will just have to have this dimension equal to 1 and call your network as many times as you have images, instead of stacking them into one big tensor and executing your network once on all of them. This will probably cost you performance but nothing more.
-
Black Jack 21 about 4 years
Thanks for replying. Just to clarify, doing this is essentially SGD, which would be noisy and troublesome to train (i.e., may not converge well)?
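A sketch of the per-image approach discussed in these comments follows. It suggests that running the network once per image need not mean one weight update per image: the per-image losses can be averaged and followed by a single optimizer step. The tiny model and the keep_as_list collate function are hypothetical placeholders.

```python
import torch
from torch import nn


def keep_as_list(batch):
    """collate_fn that skips stacking, so images may differ in size."""
    return batch


# Hypothetical fully convolutional classifier; any input H x W works.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # collapses any spatial size to 1 x 1
    nn.Flatten(),
    nn.Linear(4, 2),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# One "batch" of differently sized images with integer labels.
batch = keep_as_list([
    (torch.randn(3, 16, 16), 0),
    (torch.randn(3, 24, 20), 1),
    (torch.randn(3, 9, 31), 0),
])

optimizer.zero_grad()
losses = []
for image, label in batch:
    logits = model(image.unsqueeze(0))  # batch dimension of 1
    losses.append(loss_fn(logits, torch.tensor([label])))
loss = torch.stack(losses).mean()  # average over the whole mini-batch
loss.backward()
optimizer.step()  # a single update for all three images
```

Layers whose behaviour depends on batch statistics, such as BatchNorm, would still see batches of size 1 in this setup.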
-
Tahlor almost 4 years
Is it customary to put tensors on the GPU in collate? I was under the impression that doing this means you can't use multiple workers in your dataloader. I'd be interested in knowing which approach typically has better performance.
-
financial_physician about 3 years
@Pinocchio why do you compute the sequence lengths and mask? If I understand correctly, once the batch gets passed into the network the network doesn't have a way to use masks or to trim the input, right?
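As one possible illustration of where lengths and a mask can be used, the sketch below feeds the lengths to pack_padded_sequence so a recurrent layer skips padded steps, and uses a mask for a mean that ignores padding. The tensor sizes and the GRU are arbitrary choices for the example, not part of the original answer.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
padded = torch.randn(3, 5, 4)      # (batch, max_len, features)
lengths = torch.tensor([5, 3, 2])  # true lengths, kept on the CPU
for i, n in enumerate(lengths.tolist()):
    padded[i, n:] = 0              # zero out the padded positions

rnn = nn.GRU(input_size=4, hidden_size=6, batch_first=True)
packed = pack_padded_sequence(padded, lengths, batch_first=True,
                              enforce_sorted=False)
packed_out, h_n = rnn(packed)      # padded steps are skipped
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# A mask built from the lengths lets pooling ignore the padding.
mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]
mean = (out * mask[..., None]).sum(dim=1) / lengths[:, None]
print(out.shape, mean.shape)
```

Here h_n holds the hidden state at each sequence's last real step, which would not be the case if the padded tensor were passed to the GRU directly.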
-
financial_physician about 3 years
In case anyone stumbles across this, I think the answer provided by David Ng is the best way to do this: stackoverflow.com/questions/51030782/…