Deepcopy in BPM prevents gradient calculation
Hespe opened this issue · 1 comment
Hespe commented
In the BPM, there is a call to deepcopy before the beam is passed on down the line:
cheetah/cheetah/accelerator/bpm.py, line 51 in b38a654
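For context, the offending pattern is presumably along these lines (an illustrative sketch, not the actual line 51 source):

from copy import deepcopy

def track(self, incoming):
    # Copy the beam so the BPM can keep its reading while the original
    # beam continues down the beamline. deepcopy invokes the tensor
    # __deepcopy__ protocol, which PyTorch only supports for graph leaves.
    return deepcopy(incoming)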
Unfortunately, this copy currently inhibits taking gradients of the BPM reading. The reduced example
import torch
import cheetah

# mu_x requires gradients, as it would in an optimisation loop
beam = cheetah.ParameterBeam.from_parameters(mu_x=torch.tensor([0.0], requires_grad=True))
bpm = cheetah.BPM()
bpm.track(beam)  # raises the RuntimeError below
fails with
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see pytorch/pytorch#103001
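This is standard PyTorch behaviour and reproducible without Cheetah: deepcopy is only supported for leaf tensors, while clone stays in the autograd graph:

import copy
import torch

x = torch.tensor([0.0], requires_grad=True)  # leaf tensor
y = x * 2.0  # non-leaf tensor, part of the autograd graph

copy.deepcopy(x)  # works: x is a graph leaf
y.clone().sum().backward()  # clone is differentiable: x.grad is now tensor([2.])
copy.deepcopy(y)  # raises the same RuntimeError as above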
Hespe commented
Maybe the deepcopy can simply be replaced by a call to torch.clone?
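A minimal sketch of what that could look like, assuming the BPM only needs the beam's tensor attributes copied (clone_tensors and its use of vars are illustrative, not Cheetah's actual code):

import copy
import torch

def clone_tensors(obj):
    # Shallow-copy the object, then replace each tensor attribute with a
    # clone. Unlike deepcopy, torch.Tensor.clone is recorded in the
    # autograd graph, so gradients still flow through the copy.
    clone = copy.copy(obj)
    for name, value in vars(obj).items():
        if isinstance(value, torch.Tensor):
            setattr(clone, name, value.clone())
    return clone

Since clone copies the data into fresh memory, the reading would still be protected from later in-place changes to the beam, while remaining differentiable.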