HomebrewNLP/revlib

revlib does not work with torch amp.

JAYatBUAA opened this issue · 8 comments

Dear authors,
When using RevLib with torch amp, it reports the following error:

```
Traceback (most recent call last):
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/revlib/core.py", line 130, in backward
    mod_out = take_0th_tensor(new_mod.wrapped_module(y0, *ctx.args, **ctx.kwargs))
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 597, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
```

Dear authors, how can this be solved? Thanks in advance.

Unfortunately, I don't know. I don't use torch.amp myself, and I last touched RevLib a while ago. Please let me know if you have a minimal script to reproduce the error or manage to fix it. Just so you know, all pull requests are welcome.
Just so we're on the same page: torch.amp doesn't save memory; it only helps speed by downcasting matrix multiplications to fp16. Is that what you're after? If you want the memory improvements, please give RevLib's intermediate casting a try.
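
For reference, the speed-oriented pattern torch.amp is built for looks roughly like this (a minimal sketch with a made-up model and shapes, not RevLib-specific):

```python
# Canonical torch.amp usage: autocast the forward pass, scale the loss.
# The Linear model and tensor shapes here are hypothetical.
import torch

model = torch.nn.Linear(512, 512).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
inp = torch.randn(8, 512, device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    out = model(inp)           # matmuls are downcast to fp16 here
    loss = out.float().mean()
scaler.scale(loss).backward()  # backward runs outside the autocast block
scaler.step(optimizer)
scaler.update()
```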

Dear author, I guess the error happens in the RevResNet backward pass, where the feature-map dtype (float16) does not match the conv weights' dtype (float32). This does not happen in the forward pass, because the forward pass is wrapped in a torch.cuda.amp.autocast() context, where the conv weights are automatically cast to half precision.
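
If it helps, a minimal reproduction might look roughly like this. It is an untested sketch assuming the revlib.ReversibleSequential API from the README; the Conv2d blocks and shapes are made up:

```python
# Hypothetical reproduction sketch; blocks and shapes are illustrative.
import torch
import revlib

# Reversible nets run two streams, so the input carries twice the
# per-block channel count (16 per stream -> 32 in total).
blocks = [torch.nn.Conv2d(16, 16, (3, 3), padding=1) for _ in range(4)]
model = revlib.ReversibleSequential(*blocks).cuda()
inp = torch.randn(2, 32, 8, 8, device="cuda")

with torch.cuda.amp.autocast():
    loss = model(inp).float().mean()

# backward() runs outside the autocast context: RevLib re-executes the
# blocks here to recompute activations, so fp16 activations hit fp32
# conv weights and raise the RuntimeError from the traceback above.
loss.backward()
```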

Do you have a minimal example to reproduce the error?

I'm currently busy with an important deadline, but I'll get back to you as soon as possible. Thanks for your help.

When loss.backward() is wrapped in the torch.cuda.amp.autocast() context, this error is not reported.
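
In other words, roughly the following avoids the error (a sketch reusing the hypothetical model and input from above):

```python
# Workaround sketch: keep backward() inside autocast so RevLib's
# recomputed forward pass sees the same casting as the original one.
with torch.cuda.amp.autocast():
    loss = model(inp).float().mean()
    loss.backward()
```

Note that the PyTorch docs recommend running backward() outside autocast, so this looks more like a workaround than a proper fix.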

Please share a minimal script to reproduce this error. I'll be able to take it from there.

The best next step would be a PR with a unit test for torch amp; see the sketch below. RevLib is open to contributions, so you're also welcome to submit a PR with the fix itself :)
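
Such a test could look roughly like this (a sketch assuming pytest and the revlib.ReversibleSequential API; blocks and shapes are made up):

```python
# Hypothetical regression test for torch amp; blocks/shapes illustrative.
import pytest
import torch
import revlib

@pytest.mark.skipif(not torch.cuda.is_available(), reason="amp needs CUDA")
def test_backward_under_amp():
    blocks = [torch.nn.Conv2d(16, 16, (3, 3), padding=1) for _ in range(2)]
    model = revlib.ReversibleSequential(*blocks).cuda()
    inp = torch.randn(2, 32, 8, 8, device="cuda")

    with torch.cuda.amp.autocast():
        loss = model(inp).float().mean()

    loss.backward()  # currently raises: fp16 input vs fp32 conv weights
    assert all(p.grad is not None for p in model.parameters())
```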

Thanks a lot for RevLib. I hope to contribute myself once I have enough time.