A PyTorch dataset sampler that always draws class-balanced batches.
Be sure to use a batch_size that is an integer multiple of the number of classes.
For example, if your train_dataset has 10 classes and you use batch_size=30
with the BalancedBatchSampler:

train_loader = torch.utils.data.DataLoader(train_dataset, sampler=BalancedBatchSampler(train_dataset), batch_size=30)

then each batch yielded by train_loader contains exactly 3 samples from each of the 10 classes.
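To make the balancing guarantee concrete, here is a minimal, self-contained sketch of how such a sampler can work: it keeps one index list per class and emits indices in strict class rotation, so any batch whose size is a multiple of the number of classes contains equally many samples per class. This is a hypothetical re-implementation for illustration (plain Python, no torch dependency), not the library's actual code; the class name and constructor signature here are assumptions.

```python
import random
from collections import defaultdict

class BalancedBatchSampler:
    """Sketch: yields dataset indices in strict class rotation, so any
    batch whose size is a multiple of the number of classes contains
    equally many samples per class. Hypothetical re-implementation."""

    def __init__(self, labels, seed=0):
        # Group dataset indices by class label.
        self.by_class = defaultdict(list)
        for idx, label in enumerate(labels):
            self.by_class[label].append(idx)
        self.rng = random.Random(seed)

    def __iter__(self):
        # Shuffle each class's pool independently, then emit one index
        # per class per round; leftover samples of larger classes are dropped.
        pools = {c: list(idxs) for c, idxs in self.by_class.items()}
        for pool in pools.values():
            self.rng.shuffle(pool)
        n_rounds = min(len(p) for p in pools.values())
        for i in range(n_rounds):
            for c in sorted(pools):
                yield pools[c][i]

    def __len__(self):
        return min(len(p) for p in self.by_class.values()) * len(self.by_class)

# Toy check: 10 classes, 12 samples each, batch_size = 30.
labels = [i % 10 for i in range(120)]
sampler = BalancedBatchSampler(labels)
order = list(sampler)
batch = order[:30]  # the first batch a DataLoader with batch_size=30 would see
counts = {c: sum(1 for i in batch if labels[i] == c) for c in range(10)}
print(counts)  # every class appears exactly 3 times
```

In a real training setup the generated indices would be consumed by torch.utils.data.DataLoader via its sampler argument, exactly as in the snippet above; the DataLoader then slices the rotated index stream into batches, which is why batch_size must be a multiple of the class count.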