p-lambda/wilds

Error with example fMOW command: incorrect value of "unlabeled_n_groups_per_batch"

joshuafan opened this issue · 0 comments

Hello,
If I run the following command exactly as suggested in the README:
python examples/run_expt.py --dataset fmow --algorithm DANN --unlabeled_split test_unlabeled --root_dir data

I get the following exception:

Traceback (most recent call last):
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 491, in <module>
    main()
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 454, in main
    train(
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 114, in train
    run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True, unlabeled_dataset=unlabeled_dataset)
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 38, in run_epoch
    unlabeled_data_iterator = InfiniteDataIterator(unlabeled_dataset['loader'])
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/utils.py", line 393, in __init__
    self.iter = iter(self.data_loader)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 442, in __iter__
    return self._get_iterator()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 388, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1085, in __init__
    self._reset(loader, first_iter=True)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1118, in _reset
    self._try_put_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1352, in _try_put_index
    index = self._next_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 624, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/wilds/common/data_loaders.py", line 131, in __iter__
    groups_for_batch = np.random.choice(
  File "mtrand.pyx", line 984, in numpy.random.mtrand.RandomState.choice
ValueError: Cannot take a larger sample than population when 'replace=False'

I think this occurs because there are only 2 unique years in the test_unlabeled split, but unlabeled_n_groups_per_batch is set to 8, so it tries to sample 8 years without replacement.
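
The failure can be reproduced with NumPy alone. This is a minimal sketch, assuming the group sampler draws distinct groups with `np.random.choice(..., replace=False)` as the traceback suggests. The function name `sample_groups` and the two placeholder group IDs are illustrative and are not taken from the WILDS code.

```python
import numpy as np

# Hypothetical stand-in for the groups present in the test_unlabeled split
# (two placeholder group IDs; the actual values are not shown in this issue).
available_groups = np.array([0, 1])
n_groups_per_batch = 8  # the configured unlabeled_n_groups_per_batch


def sample_groups(groups, n):
    """Draw n distinct groups, mirroring the call shown in the traceback."""
    return np.random.choice(groups, size=n, replace=False)


try:
    sample_groups(available_groups, n_groups_per_batch)
    error_message = None
except ValueError as exc:
    # Asking for 8 distinct items from a population of 2 is impossible.
    error_message = str(exc)

print(error_message)
```

Requesting 2 or fewer groups from the same array succeeds, which matches the behavior I see after changing the config.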

I was able to fix this by changing the argument unlabeled_n_groups_per_batch to 2, here: https://github.com/p-lambda/wilds/blob/main/examples/configs/datasets.py#L220
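
A more general alternative to hard-coding 2 might be to cap the requested number of groups at the number of groups actually present in the split. The sketch below only illustrates that idea; `clamp_n_groups_per_batch` is a hypothetical helper, not an existing WILDS function, and the group IDs are placeholders.

```python
import numpy as np


def clamp_n_groups_per_batch(requested, group_ids):
    """Hypothetical helper: never request more distinct groups than exist."""
    n_available = len(np.unique(group_ids))
    return min(requested, n_available)


# Placeholder group IDs for a split that contains only two distinct groups.
group_ids = np.array([0, 0, 1, 1, 1])

effective = clamp_n_groups_per_batch(8, group_ids)

# With the capped value, sampling without replacement no longer fails.
sampled = np.random.choice(np.unique(group_ids), size=effective, replace=False)
print(effective, sorted(sampled.tolist()))
```

Whether silently capping the value is preferable to raising a clearer error message is a design choice for the maintainers.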

It would be great if this could be fixed. Thank you so much for releasing these wonderful datasets and baseline algorithms!