Error with example fMOW command: incorrect value of "unlabeled_n_groups_per_batch"
joshuafan opened this issue
Hello,
If I directly run this command suggested in the README:
python examples/run_expt.py --dataset fmow --algorithm DANN --unlabeled_split test_unlabeled --root_dir data
I get the following exception:
Traceback (most recent call last):
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 491, in <module>
    main()
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 454, in main
    train(
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 114, in train
    run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True, unlabeled_dataset=unlabeled_dataset)
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 38, in run_epoch
    unlabeled_data_iterator = InfiniteDataIterator(unlabeled_dataset['loader'])
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/utils.py", line 393, in __init__
    self.iter = iter(self.data_loader)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 442, in __iter__
    return self._get_iterator()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 388, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1085, in __init__
    self._reset(loader, first_iter=True)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1118, in _reset
    self._try_put_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1352, in _try_put_index
    index = self._next_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 624, in _next_index
    return next(self._sampler_iter) # may raise StopIteration
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/wilds/common/data_loaders.py", line 131, in __iter__
    groups_for_batch = np.random.choice(
  File "mtrand.pyx", line 984, in numpy.random.mtrand.RandomState.choice
ValueError: Cannot take a larger sample than population when 'replace=False'
I think this occurs because there are only 2 unique years in the test_unlabeled split, but unlabeled_n_groups_per_batch is set to 8, so the sampler tries to draw 8 distinct years without replacement.
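A minimal reproduction of the failing step, assuming the sampler picks distinct groups for each batch with `np.random.choice(..., replace=False)` as the last frames of the traceback suggest. The function `sample_groups_for_batch` and the year values are hypothetical stand-ins, not the WILDS code or real fMoW metadata:

```python
import numpy as np

# Illustrative group labels (years) for an unlabeled split containing only two
# distinct years. These values are made up for the demo, not read from fMoW.
group_ids = np.array([2016, 2017, 2016, 2017, 2017, 2016])


def sample_groups_for_batch(group_ids, n_groups_per_batch):
    """Hypothetical helper: pick distinct groups for one batch without replacement."""
    unique_groups = np.unique(group_ids)
    return np.random.choice(unique_groups, size=n_groups_per_batch, replace=False)


print(sample_groups_for_batch(group_ids, 2))  # works: 2 groups requested, 2 available

try:
    sample_groups_for_batch(group_ids, 8)  # 8 groups requested, only 2 available
except ValueError as err:
    print(err)  # Cannot take a larger sample than population when 'replace=False'
```

Sampling without replacement can never return more items than the population holds, so any requested group count above the number of distinct years fails in the same way.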
I was able to fix this by changing the value of unlabeled_n_groups_per_batch to 2 here: https://github.com/p-lambda/wilds/blob/main/examples/configs/datasets.py#L220
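Besides hard-coding 2 in the config, one possible defensive option is to clamp the requested value to the number of distinct groups actually present in the split. This is only a sketch: `resolve_n_groups_per_batch` is a hypothetical helper, not part of WILDS, and if the sampler imposes other constraints on this value (for example, requiring the batch size to be divisible by it), those would still need to be checked separately:

```python
import warnings

import numpy as np


def resolve_n_groups_per_batch(requested, group_ids):
    """Hypothetical guard: clamp the requested number of groups per batch to the
    number of distinct groups present in the split, warning when it does so."""
    n_available = len(np.unique(group_ids))
    if requested > n_available:
        warnings.warn(
            f"n_groups_per_batch={requested} exceeds the {n_available} groups "
            f"in this split; using {n_available} instead."
        )
        return n_available
    return requested


# Two distinct years, as described for test_unlabeled (illustrative values only).
years = np.array([2016, 2017, 2017, 2016])
print(resolve_n_groups_per_batch(8, years))  # 2, with a warning
print(resolve_n_groups_per_batch(2, years))  # 2
```

Warning instead of failing silently keeps the change visible in the logs, since clamping alters the effective batch composition compared with what the config asked for.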
It would be great if this could be fixed. Thank you so much for releasing these wonderful datasets and baseline algorithms!