SaashaJoshi/piQture

Add `desired_labels` argument to `collate_fn` in `data_loader`

Closed this issue · 0 comments

Users need to be able to specify which labels to retrieve from the loaded data. To support this, `collate_fn` needs a `desired_labels` argument. For example:

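```python
import torch


def collate_fn(batch, desired_labels):
    # Keep only the (img, label) pairs whose label was requested.
    filtered_batch = []
    for img, label in batch:
        if label in desired_labels:
            filtered_batch.append((img, label))
    # Stack the surviving samples into batched tensors.
    return torch.utils.data.default_collate(filtered_batch)
```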
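`DataLoader` calls its `collate_fn` with the batch as the only argument, so the extra `desired_labels` parameter has to be bound beforehand. One possible way to do this is `functools.partial`. The sketch below is a self-contained illustration under stated assumptions: the toy dataset of eight 2x2 tensors with labels 0 to 3 is hypothetical and not part of piQture.

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, default_collate


def collate_fn(batch, desired_labels):
    # Keep only the samples whose label is one of desired_labels.
    filtered_batch = [(img, label) for img, label in batch if label in desired_labels]
    # Stack the remaining images and labels into batched tensors.
    return default_collate(filtered_batch)


# Hypothetical toy dataset: eight 2x2 "images" with labels 0, 1, 2, 3, 0, 1, 2, 3.
dataset = [(torch.full((2, 2), float(i)), i % 4) for i in range(8)]

# DataLoader passes only the batch to collate_fn, so desired_labels is
# bound ahead of time with functools.partial.
loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=partial(collate_fn, desired_labels={0, 1}),
)

batches = []
for images, labels in loader:
    batches.append((images.shape, labels.tolist()))
    print(images.shape, labels.tolist())
```

Note that filtering inside `collate_fn` shrinks each batch (here, each batch of 4 keeps only its 2 samples labelled 0 or 1). A batch containing none of the desired labels would reach `default_collate` as an empty list and fail, so filtering the dataset before batching (for example with `torch.utils.data.Subset`) may be worth considering as an alternative.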