Taking subsets of a PyTorch dataset

Solution 1

You can define a custom sampler for the DataLoader. This avoids recreating the dataset: you only create a new loader for each different sampling.

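import torch
import torchvision
from torch.utils.data import Sampler

class YourSampler(Sampler):
    def __init__(self, mask):
        # store the indices of the samples whose mask entry is True
        self.indices = torch.nonzero(torch.as_tensor(mask)).flatten().tolist()

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        # number of selected samples, not the length of the mask
        return len(self.indices)

transform = torchvision.transforms.ToTensor()
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# example boolean masks over the full dataset (one entry per image)
your_mask = [i % 2 == 0 for i in range(len(trainset))]
your_other_mask = [i % 2 == 1 for i in range(len(trainset))]

sampler1 = YourSampler(your_mask)
sampler2 = YourSampler(your_other_mask)
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                                   sampler=sampler1, shuffle=False, num_workers=2)
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                                   sampler=sampler2, shuffle=False, num_workers=2)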

PS: You can find more info here: http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler
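To check the mask-based sampler without downloading CIFAR10, here is a minimal sketch that swaps in a small synthetic TensorDataset (an assumption standing in for trainset) where sample i simply stores the value i. `MaskSampler` is a hypothetical name that mirrors YourSampler above.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class MaskSampler(Sampler):
    """Yields only the dataset indices where a boolean mask is True."""

    def __init__(self, mask):
        # Convert the mask to a flat list of selected indices once, up front.
        self.indices = torch.nonzero(torch.as_tensor(mask)).flatten().tolist()

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        # Length is the number of selected samples, not the mask length.
        return len(self.indices)


# Synthetic stand-in for trainset: sample i stores the value i.
dataset = TensorDataset(torch.arange(10))
mask = torch.arange(10) % 3 == 0  # selects indices 0, 3, 6, 9

loader = DataLoader(dataset, batch_size=2, sampler=MaskSampler(mask))
# Each batch is a one-element list holding a tensor of values.
seen = [int(v) for (batch,) in loader for v in batch]
print(seen)  # [0, 3, 6, 9]
```

Because the sampler only holds indices, one dataset object can back any number of loaders with different masks.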

Solution 2

torch.utils.data.Subset is easier, supports shuffle, and doesn't require writing your own sampler:

import torchvision
import torch

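# ToTensor() is needed so the default collate function can batch the images
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=torchvision.transforms.ToTensor())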

evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)

trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                            shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                            shuffle=True, num_workers=2)
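A minimal sketch of the same even/odd split on a small synthetic TensorDataset (an assumption, used instead of CIFAR10 so it runs offline). Subset stores only the wrapped dataset and the index list and forwards item lookups to the original dataset, so no image data is copied; shuffle=True only reorders samples within each subset. The helper `collect` is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Synthetic stand-in for CIFAR10: sample i stores the value i.
dataset = TensorDataset(torch.arange(12))

evens = Subset(dataset, list(range(0, len(dataset), 2)))
odds = Subset(dataset, list(range(1, len(dataset), 2)))


def collect(subset):
    """Return the sorted original values seen by a shuffled loader."""
    loader = DataLoader(subset, batch_size=4, shuffle=True)
    return sorted(int(v) for (batch,) in loader for v in batch)


even_values = collect(evens)
odd_values = collect(odds)
print(even_values)  # [0, 2, 4, 6, 8, 10]
print(odd_values)   # [1, 3, 5, 7, 9, 11]
```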

Author: Miriam Farber

Applied Math PhD Student at MIT

Updated on August 03, 2022

Comments

  • Miriam Farber, over 1 year

    I have a network which I want to train on some dataset (say, CIFAR10). I can create a data loader object via

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    

    My question is as follows: suppose I want to run several different training iterations. Let's say I first want to train the network on all images in odd positions, then on all images in even positions, and so on. To do that, I need to be able to access those images. Unfortunately, trainset does not seem to allow such access: trying to do trainset[:1000], or more generally trainset[mask], throws an error.

    Instead, I could do

    trainset.train_data=trainset.train_data[mask]
    trainset.train_labels=trainset.train_labels[mask]
    

    and then

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                                  shuffle=True, num_workers=2)
    

    However, that forces me to create a new copy of the full dataset in each iteration (because I have already modified trainset.train_data, I would need to redefine trainset). Is there some way to avoid this?

    Ideally, I would like to have something "equivalent" to

    trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                                  shuffle=True, num_workers=2)
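Something close to the hoped-for trainset[mask] can be built by turning the mask into an index list and wrapping it in torch.utils.data.Subset. This is a minimal sketch; `masked_subset` is a hypothetical helper name, and the tiny TensorDataset is an assumed stand-in for CIFAR10.

```python
import numpy as np
import torch
from torch.utils.data import Subset, TensorDataset


def masked_subset(dataset, mask):
    """Hypothetical helper: a view of `dataset` restricted to mask == True.

    Behaves roughly like dataset[mask] would, but only stores indices,
    so the underlying data is never copied.
    """
    indices = np.flatnonzero(np.asarray(mask)).tolist()
    return Subset(dataset, indices)


dataset = TensorDataset(torch.arange(8) * 10)  # values 0, 10, ..., 70
mask = np.array([True, False, True, False, False, True, False, False])
view = masked_subset(dataset, mask)

print(len(view))  # 3
print([int(view[i][0]) for i in range(len(view))])  # [0, 20, 50]
```

The resulting view can be passed to DataLoader with shuffle=True exactly like the full dataset.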
    
  • Miriam Farber, over 6 years
    Thanks! One small remark: the sampler option is mutually exclusive with shuffle=True (DataLoader raises an error if both are set), so to get the same result in random order one can do: torch.utils.data.DataLoader(trainset, batch_size=4, sampler=SubsetRandomSampler(np.where(mask)[0]), shuffle=False, num_workers=2)
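A minimal sketch of the SubsetRandomSampler approach from the comment above, on an assumed synthetic TensorDataset. It converts np.where(mask)[0] to a plain list with .tolist() to keep index types simple, and also shows that DataLoader rejects a sampler combined with shuffle=True.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
mask = np.arange(10) >= 5  # keep samples 5..9

# Random order within the masked indices, reshuffled every epoch.
sampler = SubsetRandomSampler(np.where(mask)[0].tolist())
loader = DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=False)
seen = sorted(int(v) for (batch,) in loader for v in batch)
print(seen)  # [5, 6, 7, 8, 9]

# Passing both a sampler and shuffle=True raises ValueError.
try:
    DataLoader(dataset, sampler=sampler, shuffle=True)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```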
  • jodag, over 4 years
    Keep in mind that a list of indices is a valid argument for sampler since it implements __len__ and __iter__. This kind of circumvents the need for a custom sampler class.
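A minimal sketch of passing a plain list of indices as the sampler, on an assumed synthetic TensorDataset. This relies on DataLoader accepting any iterable with __len__ as sampler, which holds in recent PyTorch versions; the list is iterated in its given order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
odd_indices = list(range(1, 10, 2))

# The list itself is the sampler: DataLoader walks it in order.
loader = DataLoader(dataset, batch_size=2, sampler=odd_indices)
seen = [int(v) for (batch,) in loader for v in batch]
print(seen)         # [1, 3, 5, 7, 9]
print(len(loader))  # 3 (the last batch holds a single sample)
```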
  • user650654, almost 4 years
    Converting evens and odds to a list is not necessary--at least in torch 1.5.0, Subset accepts generators: ts1 = Subset(trainset, range(0, len(trainset), 2))
  • noamgot, over 3 years
    It doesn't allow filtering by class, just by the dataset original order, does it?
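Subset takes whatever indices you compute, so filtering by class works as long as the indices are derived from the labels. A minimal sketch on assumed synthetic labels; for torchvision's CIFAR10 the labels are available as trainset.targets in recent torchvision versions.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Synthetic data: 8 samples with integer class labels (stand-in for CIFAR10).
labels = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0])
dataset = TensorDataset(torch.arange(8), labels)

wanted_class = 1
# Indices come from the labels, so any criterion works, not just position.
class_indices = torch.nonzero(labels == wanted_class).flatten().tolist()
class_subset = Subset(dataset, class_indices)

print(class_indices)  # [1, 3, 6]
print(all(int(class_subset[i][1]) == wanted_class
          for i in range(len(class_subset))))  # True
```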
  • Antoine, almost 3 years
    @user650654 Slightly off-topic, but range is not a generator.
  • LudvigH, almost 3 years
    The index set must be a Python Sequence, i.e. a list, tuple, or range.
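The range discussion above can be checked with a short sketch on an assumed synthetic TensorDataset: a range works as Subset indices because it supports len() and indexing, while a true generator is stored without complaint but fails as soon as the Subset's length is requested.

```python
import torch
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.arange(10))

# range is a lazy Sequence (len() and indexing work), not a generator.
evens = Subset(dataset, range(0, len(dataset), 2))
print(len(evens))        # 5
print(int(evens[2][0]))  # 4

# A generator has no len(), so the Subset breaks when its length is needed.
bad = Subset(dataset, (i for i in range(0, 10, 2)))
try:
    len(bad)
    generator_ok = True
except TypeError:
    generator_ok = False
print(generator_ok)  # False
```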
  • Isaac Zhao, almost 2 years
    from torch.utils.data.sampler import SubsetRandomSampler