Can't use run_segment with apex.amp
I use code like this:
run_segment = optimal_grad_checkpointing(model, inp)
run_segment, optimizer = apex.amp.initialize(run_segment, optimizer, opt_level="O2", verbosity=0)
...
output = run_segment(images)
and get this error:
output = run_segment(images)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
**applier(kwargs, input_caster))
File "/working_dir/OptimalGradCheckpointing/graph.py", line 911, in forward
return graph_forward(x, **self.info_dict)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 838, in graph_forward
output = checkpoint(segment_checkpoint_forward(op), input)
File "/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 155, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 74, in forward
outputs = run_function(*args)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 807, in custom_forward
outputs = segment(*inputs)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 911, in forward
return graph_forward(x, **self.info_dict)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 840, in graph_forward
output = op(input)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 349, in forward
return self._conv_forward(input, self.weight)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 346, in _conv_forward
self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
It would be great if Optimal Gradient Checkpointing could be combined with apex.amp or torch.cuda.amp.
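For what it's worth, the final RuntimeError is easy to reproduce in isolation: a conv layer whose weights are still fp32 fed a half-precision input fails with exactly this dtype mismatch. A minimal stand-alone sketch (not using this repo):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)  # parameters are torch.FloatTensor
x = torch.randn(1, 3, 8, 8).half()     # torch.HalfTensor input

try:
    conv(x)
except RuntimeError as e:
    # Same "Input type ... and weight type ... should be the same" error
    # as in the traceback above.
    print(e)
```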
Hi,
It could be that the PyTorch checkpointing function does not support apex. Did you try torch.cuda.amp?
I would like to try torch.cuda.amp, but torch.cuda.amp.autocast only appeared in PyTorch 1.6, and OptimalGradCheckpointing works only with PyTorch 1.5.
Our automatic graph-parsing implementation depends on torch.jit and is quite sensitive to the PyTorch version. If you have a manual parse_graph function, it should definitely work with 1.6.
For automatic parsing, I haven't tested on 1.6, but I think it is likely to work, because I don't expect many changes from PyTorch 1.5 to 1.6.
Let me know if you are able to use it under PyTorch 1.6. I will also test compatibility with different versions when I get time.
Thanks
Yes, it works with torch.cuda.amp on PyTorch 1.10 after I fixed the line #3 (comment)
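For reference, here is a minimal self-contained sketch of combining gradient checkpointing with torch.cuda.amp. It uses a stand-in nn.Sequential in place of the run_segment returned by this repo, so the model and shapes are illustrative only; on a CPU-only machine the amp features are simply disabled:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # torch.cuda.amp is only active on GPU

# Stand-in for the checkpointed run_segment produced by
# optimal_grad_checkpointing(model, inp).
run_segment = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)).to(device)
optimizer = torch.optim.SGD(run_segment.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

images = torch.randn(4, 16, device=device, requires_grad=True)
target = torch.randn(4, 8, device=device)

optimizer.zero_grad()
# autocast chooses a dtype per op inside the context, so the checkpointed
# segment receives consistently typed tensors.
with torch.cuda.amp.autocast(enabled=use_amp):
    output = checkpoint(run_segment, images)
    loss = nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

GradScaler and autocast are both no-ops when `enabled=False`, so the same training loop runs unchanged with or without mixed precision.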
Thanks!