EleutherAI/concept-erasure

torch.where bug

cguerner opened this issue · 4 comments

Hi,

Reporting a bug with the latest package version, concept-erasure 0.2.1, on Python 3.10.4 and torch 1.13.0+cu116.

[Screenshot, 2023-08-29]

The second argument to torch.Tensor.where() has to be a tensor, not a float. I fixed it with the following:

Lzeros = torch.zeros(L.shape)
W = V * L.rsqrt().where(mask, Lzeros) @ V.mT
W_inv = V * L.sqrt().where(mask, Lzeros) @ V.mT

Hmm, this is pretty bizarre because:

  1. The PyTorch docs for Tensor.where do say that a scalar should work, and this is true both for PyTorch 2.0 and the 1.13 version you're using.
    [Screenshot, 2023-09-04]
  2. It's working for me when I try it on PyTorch 2.0.
    [Screenshot, 2023-09-04]

If you could make a clean repro that would be helpful. I'd also recommend you try PyTorch 2.0, since it's possible that the torch.where behavior did not match the docs in 1.13 but this was fixed in 2.0. If that's the case we might want to bump the minimum required PyTorch version to 2.0.
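A clean repro could look something like the sketch below. It is only an illustration: the helper `accepts_scalar` and the example tensor are made up for this comment, and the exact exception type raised on older PyTorch versions is an assumption, so both `TypeError` and `RuntimeError` are caught.

```python
import torch


def accepts_scalar(call):
    """Return True if `call()` runs, False if it rejects its arguments.

    Hypothetical helper for this repro; not part of concept-erasure.
    """
    try:
        call()
        return True
    except (TypeError, RuntimeError):
        return False


# Arbitrary example values, not taken from the library.
x = torch.tensor([1.0, -2.0, 3.0])
mask = x > 0

print("torch version:", torch.__version__)
# Method form with a Python float as the second argument.
print("Tensor.where(mask, 0.0):", accepts_scalar(lambda: x.where(mask, 0.0)))
# Function form with the same scalar.
print("torch.where(mask, x, 0.0):", accepts_scalar(lambda: torch.where(mask, x, 0.0)))
```

Running this under each PyTorch version and pasting the output here would show whether the method and the function disagree in that environment.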

Came here to say the same thing -- observed this issue on a Mac M2 with PyTorch 1.13.

I can reproduce with PyTorch 1.13.
[Screenshot, 2023-09-04]

The docs going back to PyTorch 1.7.0 say that scalars are supposed to be allowed in the function torch.where, and Tensor.where is supposed to behave the same way, but it looks like there was actually a discrepancy. The function actually does work on PyTorch 1.13:
[Screenshot, 2023-09-04]

So I think the solution is to use the function instead of the Tensor method.
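For illustration, here is a minimal sketch of what switching from the method to the function might look like. The names `L` and `mask` follow the snippet earlier in the thread, but the values are made up, and this is not necessarily the exact change in the PR.

```python
import torch

# Hypothetical stand-ins for the eigenvalues and mask discussed above.
L = torch.tensor([4.0, 0.0, 1.0])
mask = L > 0

# Function form: the fallback value is passed as a plain Python scalar.
inv_sqrt = torch.where(mask, L.rsqrt(), 0.0)

# Tensor-valued fallback, similar to the workaround in the original report.
# zeros_like keeps the dtype and device of L (an assumption about what is wanted).
zeros = torch.zeros_like(L)
inv_sqrt_tensor = L.rsqrt().where(mask, zeros)

print(inv_sqrt)
print(inv_sqrt_tensor)
```

Both forms replace the infinite entry produced by `rsqrt` at the zero eigenvalue with 0; the function form just avoids allocating the extra zeros tensor.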

Can @cguerner and/or @cemoody confirm that PR #7 fixes the problem in your environments?