Add `desired_labels` argument to `collate_fn` in `data_loader`
Closed this issue · 0 comments
SaashaJoshi commented
We want users to be able to specify which labels to retrieve from the loaded data. To support this, `collate_fn` needs to accept a `desired_labels` argument. For example:
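```python
import torch

def collate_fn(batch, desired_labels):
    # Keep only the (img, label) pairs whose label was requested
    filtered_batch = []
    for img, label in batch:
        if label in desired_labels:
            filtered_batch.append((img, label))
    # Stack the remaining samples into batched tensors
    return torch.utils.data.default_collate(filtered_batch)
```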
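`DataLoader` calls `collate_fn` with the batch as its only argument, so `desired_labels` has to be bound before the function is handed to the loader. The sketch below does this with `functools.partial`. It assumes PyTorch 1.11 or later (where `default_collate` is importable from `torch.utils.data`) and plain Python `int` labels. The toy dataset, and the choice to return `None` when every sample in a batch is filtered out, are illustrative assumptions, not part of the proposed API.

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, default_collate


def collate_fn(batch, desired_labels):
    """Keep only samples whose label is in desired_labels, then collate them."""
    filtered_batch = [(img, label) for img, label in batch if label in desired_labels]
    if not filtered_batch:
        # Assumption: default_collate cannot handle an empty list,
        # so this sketch signals "nothing left in this batch" with None
        return None
    return default_collate(filtered_batch)


# Hypothetical toy dataset: 8 "images" of shape (1, 2, 2) with int labels 0-3
dataset = [(torch.full((1, 2, 2), float(i)), i % 4) for i in range(8)]

# Bind desired_labels so the loader can call collate_fn(batch) with one argument
loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=partial(collate_fn, desired_labels={1, 3}),
)

for batch in loader:
    if batch is None:
        continue  # every sample in this batch was filtered out
    imgs, labels = batch
    print(imgs.shape, labels.tolist())
```

Filtering inside `collate_fn` yields batches smaller than `batch_size` (and possibly empty ones). If constant batch sizes matter, selecting the matching indices up front and wrapping the dataset in `torch.utils.data.Subset` is an alternative.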