torch.where API in MNIST and CIFAR10, ImageNet configuration files
NanyangYe opened this issue · 1 comment
Hi,
When we tried to run the code for MNIST and CIFAR10, it threw the following error:
index = torch.where(output.max(1)[1] == y)[0]
TypeError: where() missing 2 required positional argument: "input", "other"
We checked the API docs for PyTorch 1.3, PyTorch 1.0, and PyTorch 0.4.1, and this usage does not seem to be standard. We also tried to run the experiments in the ImageNet folder, but the configuration files the code uses are not in the GitHub repository. Do you know how to fix this? Thank you very much.
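For reference, here is one way the failing line could be written so that it does not depend on the single-argument torch.where. This is a minimal sketch, assuming output holds (batch, classes) scores and y is a 1-D tensor of labels; correct_indices is a hypothetical helper name, not something from the repository. It relies on Tensor.nonzero(), which predates the single-argument torch.where, so it should also run on older releases.

```python
import torch


def correct_indices(output: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return a 1-D tensor of indices where the argmax prediction equals y.

    Intended to match torch.where(output.max(1)[1] == y)[0] on PyTorch >= 1.2,
    without using the single-argument form of torch.where.
    """
    # Predicted class per sample compared against the label (one flag per sample).
    mask = output.max(1)[1] == y
    # nonzero() returns an (N, 1) tensor of indices; view(-1) flattens it to shape (N,).
    return mask.nonzero().view(-1)


output = torch.tensor([[2.0, 1.0], [0.0, 3.0], [5.0, 4.0]])
y = torch.tensor([0, 0, 0])
print(correct_indices(output, y))  # samples 0 and 2 are classified correctly
```

On a recent PyTorch both spellings give the same result; the nonzero() version simply avoids the overload that older releases reject.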
Hi @lincolnBush!
I would double-check your PyTorch version, because v1.3 should have worked. As far as I can tell, the single-argument torch.where(condition) form has been in PyTorch since 1.2 (I've personally been using the latest stable release, v1.4), and you can see it in the API docs right after the description of torch.where(condition, x, y): https://pytorch.org/docs/stable/torch.html#torch.where
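To illustrate the two forms described in the docs, here is a small sketch, assuming PyTorch 1.2 or newer. The tensors are made-up toy values. The single-argument form returns a tuple of index tensors (one per dimension), which is why the code indexes it with [0]; the three-argument form selects elementwise between two tensors.

```python
import torch

# Toy scores for 3 samples and 2 classes, plus their labels.
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([1, 1, 1])
mask = logits.max(1)[1] == labels  # True where the prediction is correct

# Single-argument form: a tuple of index tensors, one per dimension of mask.
idx = torch.where(mask)[0]
# The same indices via nonzero(as_tuple=True).
same = torch.nonzero(mask, as_tuple=True)[0]

# Three-argument form: take from the first tensor where mask is True, else the second.
picked = torch.where(mask, torch.ones(3), torch.zeros(3))

print(idx, same, picked)
```

If torch.__version__ reports something older than 1.2, only the three-argument call is available, which matches the TypeError in the original report.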
You may have missed the ImageNet config files in the ImageNet/config/ folder here. The run_fast_Npx.sh shell scripts (e.g. this one) show how to use them.
Let us know if you have any other questions!
~Eric