desy-ml/cheetah

Deepcopy in BPM prevents gradient calculation

Hespe opened this issue · 1 comment

In the BPM, there is a call to deepcopy before passing on the beam down the line:

return deepcopy(incoming)

Unfortunately, this copy currently prevents computing gradients of the BPM reading. The following minimal example

import torch
import cheetah

beam = cheetah.ParameterBeam.from_parameters(mu_x=torch.tensor([0.0], requires_grad=True))

bpm = cheetah.BPM()
bpm.track(beam)

fails with

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see pytorch/pytorch#103001
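The error comes from PyTorch itself, not from Cheetah: `torch.Tensor.__deepcopy__` refuses to copy any tensor that is not a graph leaf. The beam's internal tensors are computed from `mu_x`, so they are non-leaf tensors. The sketch below uses plain PyTorch only (no Cheetah objects) to show that `copy.deepcopy` fails on such a tensor, while `Tensor.clone()` works and stays in the autograd graph.

```python
import copy

import torch

# A leaf tensor we want gradients with respect to.
mu_x = torch.tensor([0.0], requires_grad=True)
# Any operation on it yields a non-leaf tensor, similar to the tensors held inside a beam.
derived = mu_x * 2.0 + 1.0

try:
    copy.deepcopy(derived)
    deepcopy_failed = False
except RuntimeError:
    # Same "Only Tensors created explicitly by the user (graph leaves) ..." error as above.
    deepcopy_failed = True

# clone() is a differentiable operation, so gradients flow back to mu_x.
cloned = derived.clone()
cloned.sum().backward()

print(deepcopy_failed)  # True
print(mu_x.grad)  # tensor([2.])
```

Note that even for leaf tensors, `deepcopy` creates a new, independent leaf, so gradients would not flow back to the original tensor either; `clone()` avoids both problems.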

Could the deepcopy simply be replaced by calls to torch.clone on the beam's tensors?
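A minimal sketch of what that replacement might look like, assuming a beam is essentially a container of tensors. `SimpleBeam` and `SimpleBPM` are hypothetical stand-ins written for illustration; they are not Cheetah's actual `ParameterBeam` or `BPM` classes, and the real fix would need to clone whatever tensors Cheetah's beam classes actually store.

```python
from dataclasses import dataclass

import torch


@dataclass
class SimpleBeam:
    """Hypothetical stand-in for a Cheetah beam: just a container of tensors."""

    mu_x: torch.Tensor
    sigma_x: torch.Tensor

    def clone(self) -> "SimpleBeam":
        # Tensor.clone() is recorded by autograd, unlike deepcopy of a non-leaf tensor.
        return SimpleBeam(mu_x=self.mu_x.clone(), sigma_x=self.sigma_x.clone())


class SimpleBPM:
    """Hypothetical BPM: records a reading and passes on a copy of the beam."""

    def __init__(self) -> None:
        self.reading = None

    def track(self, incoming: SimpleBeam) -> SimpleBeam:
        # Keep the reading in the graph so it can be differentiated.
        self.reading = incoming.mu_x
        # Instead of: return deepcopy(incoming)
        return incoming.clone()


mu_x = torch.tensor([0.0], requires_grad=True)
# mu_x * 1.0 + 0.5 makes the beam's tensor a non-leaf, as in the failing example.
beam = SimpleBeam(mu_x=mu_x * 1.0 + 0.5, sigma_x=torch.tensor([1e-3]))

bpm = SimpleBPM()
outgoing = bpm.track(beam)

# Gradients flow both through the reading and through the outgoing beam.
(bpm.reading.sum() + 3.0 * outgoing.mu_x.sum()).backward()
print(mu_x.grad)  # tensor([4.])
```

The outgoing beam still owns separate memory, so downstream elements cannot accidentally modify the tensors the BPM read, which is presumably what the original deepcopy was meant to guarantee.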