Taking subsets of a pytorch dataset
Solution 1
You can define a custom sampler and pass it to the data loader. This avoids recreating the dataset: you only create a new loader for each different sampling.
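import torch
import torchvision
from torch.utils.data import Sampler

class YourSampler(Sampler):
    def __init__(self, mask):
        # mask: 1-D boolean tensor, True at the positions to sample
        self.mask = mask
    def __iter__(self):
        # yield the selected positions as plain Python ints
        return iter(torch.nonzero(self.mask).flatten().tolist())
    def __len__(self):
        # number of selected samples, not the length of the mask
        return int(self.mask.sum())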
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
sampler1 = YourSampler(your_mask)
sampler2 = YourSampler(your_other_mask)
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                                   sampler=sampler1, shuffle=False, num_workers=2)
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                                   sampler=sampler2, shuffle=False, num_workers=2)
PS: You can find more info here: http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler
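A self-contained sketch of the mask-based sampler idea, for trying it out without downloading CIFAR10. The class name MaskSampler, the toy TensorDataset, and the even/odd masks are assumptions made for illustration; the mask is assumed to be a 1-D boolean tensor.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class MaskSampler(Sampler):
    """Yield the dataset indices at which a boolean mask is True."""

    def __init__(self, mask):
        # Precompute the selected positions once, as plain Python ints.
        self.indices = torch.nonzero(mask).flatten().tolist()

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        # Number of selected samples, not the length of the mask.
        return len(self.indices)


# Toy stand-in for CIFAR10: 10 samples whose label equals their index.
dataset = TensorDataset(torch.arange(10).unsqueeze(1), torch.arange(10))

even_mask = torch.arange(10) % 2 == 0
odd_mask = ~even_mask

# One dataset, two loaders; the underlying data is not copied.
even_loader = DataLoader(dataset, batch_size=4, sampler=MaskSampler(even_mask))
odd_loader = DataLoader(dataset, batch_size=4, sampler=MaskSampler(odd_mask))

even_labels = [int(y) for _, labels in even_loader for y in labels]
odd_labels = [int(y) for _, labels in odd_loader for y in labels]
print(even_labels)  # [0, 2, 4, 6, 8]
print(odd_labels)   # [1, 3, 5, 7, 9]
```

This sampler visits the selected indices in order. For a shuffled pass, the indices would have to be permuted inside __iter__, since DataLoader does not accept shuffle=True together with a sampler.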
Solution 2
torch.utils.data.Subset is easier, supports shuffle, and doesn't require writing your own sampler:
import torchvision
import torch

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=None)
evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                            shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                            shuffle=True, num_workers=2)
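To see what Subset does without downloading anything, here is a small sketch on a toy TensorDataset (an assumption standing in for CIFAR10). It shows that a Subset has its own length and its own 0-based indexing, and that shuffling only reorders the selected samples.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Toy stand-in for CIFAR10: 10 samples whose label equals their index.
dataset = TensorDataset(torch.arange(10).unsqueeze(1), torch.arange(10))

even_set = Subset(dataset, range(0, len(dataset), 2))
odd_set = Subset(dataset, range(1, len(dataset), 2))

# Each Subset reports only the samples it selects.
print(len(even_set), len(odd_set))  # 5 5

# Position 0 of the odd subset maps to position 1 of the full dataset.
_, first_odd_label = odd_set[0]
print(int(first_odd_label))  # 1

# shuffle=True permutes within the subset; no other samples can appear.
even_loader = DataLoader(even_set, batch_size=4, shuffle=True)
seen = sorted(int(y) for _, labels in even_loader for y in labels)
print(seen)  # [0, 2, 4, 6, 8]
```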
Comments
-
Miriam Farber over 1 year
I have a network which I want to train on some dataset (as an example, say CIFAR10). I can create a data loader object via
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
My question is as follows: Suppose I want to run several different training iterations. Let's say I want to train the network first on all images in odd positions, then on all images in even positions, and so on. To do that, I need to be able to access those images. Unfortunately, it seems that trainset does not allow such access. That is, trying to do trainset[:1000] or, more generally, trainset[mask] will throw an error.
I could do instead
trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
and then
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
However, that will force me to create a new copy of the full dataset in each iteration (as I have already changed trainset.train_data, I will need to redefine trainset). Is there some way to avoid it?
Ideally, I would like to have something "equivalent" to
trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4, shuffle=True, num_workers=2)
-
Miriam Farber over 6 years
Thanks! One small remark: apparently sampler is not compatible with shuffle, so in order to achieve the same result one can do:
torch.utils.data.DataLoader(trainset, batch_size=4, sampler=SubsetRandomSampler(np.where(mask)[0]), shuffle=False, num_workers=2)
-
jodag over 4 years
Keep in mind that a list of indices is a valid argument for sampler, since it implements __len__ and __iter__. This kind of circumvents the need for a custom sampler class.
-
user650654 almost 4 years
Converting evens and odds to a list is not necessary; at least in torch 1.5.0, Subset accepts generators:
ts1 = Subset(trainset, range(0, len(trainset), 2))
-
noamgot over 3 years
It doesn't allow filtering by class, just by the dataset original order, does it?
-
Antoine almost 3 years
@user650654 Slightly off-topic, but range is not a generator.
-
LudvigH almost 3 years
The index set must be a Python Sequence, i.e. list, tuple or range.
-
Isaac Zhao almost 2 years
from torch.utils.data.sampler import SubsetRandomSampler
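
Pulling the comments above together, a rough sketch on a toy TensorDataset (an assumption standing in for CIFAR10, with a hypothetical boolean NumPy mask). It shows a shuffled pass with SubsetRandomSampler, a plain list used as the sampler, and, regarding the question about filtering by class, a Subset built from indices computed from the labels. Passing a plain list as sampler is assumed to work on a reasonably recent PyTorch, as the comment describes.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Toy stand-in: 12 samples, feature = index, labels cycle through 0, 1, 2.
labels = torch.arange(12) % 3
dataset = TensorDataset(torch.arange(12).unsqueeze(1), labels)

# Hypothetical boolean mask: keep the first six samples.
mask = np.arange(12) < 6

# 1) Shuffled pass over the masked samples (indices converted to plain ints).
masked_indices = np.where(mask)[0].tolist()
random_loader = DataLoader(dataset, batch_size=4,
                           sampler=SubsetRandomSampler(masked_indices))
seen_random = sorted(int(x) for xs, _ in random_loader for x in xs.flatten())
print(seen_random)  # [0, 1, 2, 3, 4, 5]

# 2) A plain list of indices passed directly as the sampler (fixed order).
list_loader = DataLoader(dataset, batch_size=4, sampler=[0, 2, 4])
seen_list = [int(x) for xs, _ in list_loader for x in xs.flatten()]
print(seen_list)  # [0, 2, 4]

# 3) Filtering by class: Subset takes any indices, so compute them from labels.
class_indices = torch.nonzero(labels == 1).flatten().tolist()
class_set = Subset(dataset, class_indices)
print(class_indices)  # [1, 4, 7, 10]
print(all(int(y) == 1 for _, y in class_set))  # True
```

For a real torchvision dataset, the labels would have to come from the dataset itself; in recent torchvision versions CIFAR10 keeps them in trainset.targets, which is worth checking against the installed version.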