torch.where bug
cguerner opened this issue · 4 comments
Hi,
Reporting a bug with the latest version of the package, concept-erasure 0.2.1, on Python 3.10.4 with torch 1.13.0+cu116.
The second argument to `torch.Tensor.where()` has to be a tensor, not a float. I fixed it with the following:

```python
# Workaround: pass a zero tensor, not a Python float, as the fill value
Lzeros = torch.zeros(L.shape)
W = V * L.rsqrt().where(mask, Lzeros) @ V.mT
W_inv = V * L.sqrt().where(mask, Lzeros) @ V.mT
```
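A self-contained sketch of the workaround above, for checking it outside concept-erasure. The setup is an assumption: `L, V` are taken to come from `torch.linalg.eigh` of a covariance matrix, and `mask` from a hypothetical relative tolerance `tol`; the package may compute them differently. It also swaps `torch.zeros(L.shape)` for `torch.zeros_like(L)`, since `torch.zeros` allocates with the default dtype on the default device (normally CPU), which may not match a CUDA or float64 `L`.

```python
import torch

# Hypothetical setup: a rank-deficient covariance (3 features, rank 2).
torch.manual_seed(0)
x = torch.randn(100, 2, dtype=torch.float64)
x = torch.cat([x, x.sum(dim=1, keepdim=True)], dim=1)  # third column depends on the first two
x = x - x.mean(dim=0)
sigma = x.mT @ x / x.shape[0]

# Eigendecomposition of the symmetric covariance matrix.
L, V = torch.linalg.eigh(sigma)
tol = 1e-8 * L.max()  # hypothetical cutoff for "nonzero" eigenvalues
mask = L > tol

# Workaround: a zero *tensor* as the fill value instead of a Python float.
# zeros_like keeps L's dtype and device.
Lzeros = torch.zeros_like(L)
# rsqrt/sqrt of the (near-)zero eigenvalue may be inf or nan; where() discards it.
W = V * L.rsqrt().where(mask, Lzeros) @ V.mT
W_inv = V * L.sqrt().where(mask, Lzeros) @ V.mT

# Projection onto the eigenvectors that survived the mask.
P = V[:, mask] @ V[:, mask].mT
```

With a rank-deficient covariance, `W @ W_inv` should come out as the projection `P` onto the kept eigenvectors rather than the identity.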
Hmm, this is pretty bizarre, because:
- The PyTorch docs for `Tensor.where` do say that a scalar should work, and this is true both for PyTorch 2.0 and for the 1.13 version you're using.
- It's working for me when I try it on PyTorch 2.0.
![Screenshot 2023-09-04 at 12.29.53 AM](https://private-user-images.githubusercontent.com/39116809/265387017-8fb38320-f3c7-4f7f-8f6a-19066eca2845.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTY3MzI0MDcsIm5iZiI6MTcxNjczMjEwNywicGF0aCI6Ii8zOTExNjgwOS8yNjUzODcwMTctOGZiMzgzMjAtZjNjNy00ZjdmLThmNmEtMTkwNjZlY2EyODQ1LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA1MjYlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNTI2VDE0MDE0N1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWU5ZDk1MGEwMGZjMWMyYThkYmZkZjQ1ZTg3YTk2YjU4NjEzNzQ5OTI0MTU5YmQ0YmQ4OTcxYWNlMDMzODc3OWEmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.ZxG6v2ZD_ZSjmSFx8ckMH032uvzUp5G8jwyZloPVXZ4)
If you could make a clean repro, that would be helpful. I'd also recommend trying PyTorch 2.0, since it's possible that the `torch.where` behavior did not match the docs in 1.13 and this was fixed in 2.0. If that's the case, we might want to bump the minimum required PyTorch version to 2.0.
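A minimal repro along these lines might look like the following sketch. It calls both the `Tensor.where` method and the `torch.where` function with a Python float as the fill value and reports which one accepts it on the installed version. The example tensor and the helper `accepts_scalar_other` are made up for illustration, and catching both `TypeError` and `RuntimeError` is a hedge, since the exact exception raised on 1.13 isn't shown in this thread.

```python
import torch


def accepts_scalar_other(fn) -> bool:
    """Return True if calling `fn` (which passes a float fill value) does not raise."""
    try:
        fn()
    except (TypeError, RuntimeError):
        return False
    return True


x = torch.tensor([4.0, 0.0, 9.0])
mask = x > 0  # rsqrt(0) is inf, so the masked-out entry must be replaced

method_ok = accepts_scalar_other(lambda: x.rsqrt().where(mask, 0.0))
function_ok = accepts_scalar_other(lambda: torch.where(mask, x.rsqrt(), 0.0))

print(f"torch {torch.__version__}: Tensor.where={method_ok}, torch.where={function_ok}")
```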
Came here to say the same thing: I observed this issue on a Mac M2 with PyTorch 1.13.
I can reproduce with PyTorch 1.13.
The docs going back to PyTorch 1.7.0 say that scalars are supposed to be allowed in the function `torch.where`, and `Tensor.where` is supposed to behave the same way, but it looks like there was actually a discrepancy. The function does work on PyTorch 1.13.
So I think the solution is to use the function `torch.where` instead of the `Tensor.where` method.
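Applied to the lines from the original report, the function form might look like the sketch below. The helper name `whiten_and_unwhiten` and the relative-tolerance rule for `mask` are hypothetical, not concept-erasure's API; the point is only that `torch.where(mask, L.rsqrt(), 0.0)` takes the scalar directly, so no zero tensor has to be allocated and matched to `L`'s device and dtype.

```python
import torch


def whiten_and_unwhiten(sigma: torch.Tensor, rel_tol: float = 1e-8):
    """Hypothetical helper: whitening matrix W and its pseudo-inverse W_inv for a
    symmetric PSD covariance `sigma`, zeroing near-null eigendirections."""
    L, V = torch.linalg.eigh(sigma)
    mask = L > rel_tol * L.max()  # hypothetical tolerance rule
    # Function form with a Python float as the fill value.
    W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mT
    W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mT
    return W, W_inv


# Usage on a full-rank covariance (4 variables, 200 observations).
torch.manual_seed(0)
x = torch.randn(200, 4, dtype=torch.float64)
sigma = torch.cov(x.mT)  # torch.cov expects variables as rows
W, W_inv = whiten_and_unwhiten(sigma)
```

On a full-rank covariance every eigenvalue passes the mask, so `W @ W_inv` and `W @ sigma @ W` should both be close to the identity.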