quadrupole default dtype
jp-ga opened this issue
The default dtype of Quadrupole (and probably other elements) is set to torch.float32, which leads to errors when using beams with default tensors (which default to torch.double). See https://github.com/desy-ml/cheetah/blob/master/cheetah/accelerator/quadrupole.py#L37
This can be annoying when trying to use the default kwargs.
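A minimal sketch of the kind of error described above, in plain PyTorch rather than Cheetah: a float32 tensor standing in for an element's transfer matrix is applied to a float64 tensor standing in for a beam. The names `element_matrix` and `beam` are illustrative only. Matrix multiplication does not promote mixed dtypes, so it raises unless one side is cast.

```python
import torch

# Stand-in for an element created with the float32 default (values are illustrative).
element_matrix = torch.eye(2, dtype=torch.float32)

# Stand-in for a beam whose tensors are float64 (torch.double).
beam = torch.tensor([1.0, 0.5], dtype=torch.float64)

try:
    element_matrix @ beam
    mismatch_raised = False
except RuntimeError as err:
    # matmul requires both operands to share a dtype.
    mismatch_raised = True
    print(f"RuntimeError: {err}")

# Casting the element to the beam's dtype makes the product work.
tracked = element_matrix.to(beam.dtype) @ beam
print(tracked.dtype)  # torch.float64
```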
This is actually a big item I still want to address. It's kind of related to #113.
In the end, I think everything should work in a similar way to, say, nn.Linear in PyTorch, where you can do either of:
quad = quad.double()
quad = quad.cuda()
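For reference, this is how those conversion methods behave on nn.Linear in plain PyTorch: `.double()` converts every registered parameter and buffer and returns the module, and `.to(device)` (or `.cuda()`) moves them. The device check is only there so the snippet also runs on CPU-only machines.

```python
import torch
from torch import nn

layer = nn.Linear(3, 2)
print(layer.weight.dtype)  # torch.float32 with PyTorch's default settings

# Converts all parameters and buffers in place and returns the module itself.
layer = layer.double()
print(layer.weight.dtype)  # torch.float64

# Equivalent to layer.cuda() when a GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = layer.to(device)

x = torch.randn(4, 3, dtype=torch.float64, device=device)
y = layer(x)
print(y.dtype, y.shape)  # torch.float64 torch.Size([4, 2])
```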
This requires registering the element attributes with PyTorch, and we haven't quite figured out yet how to do this while keeping assignments like
quad.k1 = torch.tensor([4.2])
working.
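One possible approach, sketched under assumptions: the `Quadrupole` class below is hypothetical and is not Cheetah's actual implementation. It stores `length` and `k1` with `register_buffer`, so that `.double()`, `.cuda()` and `.to()` convert them, and overrides `__setattr__` so that a plain assignment such as `quad.k1 = torch.tensor([4.2])` is cast to the dtype and device the buffer already has instead of silently reintroducing float32.

```python
import torch
from torch import nn


class Quadrupole(nn.Module):
    """Hypothetical sketch of an element whose tensors follow .double()/.to()."""

    def __init__(self, length, k1):
        super().__init__()
        # Buffers (unlike plain attributes) are converted by .double(), .cuda() and .to().
        self.register_buffer("length", torch.as_tensor(length, dtype=torch.float32))
        self.register_buffer("k1", torch.as_tensor(k1, dtype=torch.float32))

    def __setattr__(self, name, value):
        # Keep plain assignments working, but cast new tensors to the
        # dtype and device of the buffer they replace.
        buffers = self.__dict__.get("_buffers")
        if buffers is not None and name in buffers and isinstance(value, torch.Tensor):
            current = buffers[name]
            if current is not None:
                value = value.to(dtype=current.dtype, device=current.device)
        super().__setattr__(name, value)


quad = Quadrupole(length=0.2, k1=[1.0]).double()
quad.k1 = torch.tensor([4.2])  # created as float32 by default
print(quad.k1.dtype)  # torch.float64
```

Silently casting is a design choice with a trade-off: assigning a float64 value to a module that was converted to float32 quietly loses precision, so raising an error on mismatch could be preferable in some workflows.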
As a workaround for now, you basically have to remember to pass the dtype you want to both the beam and all elements on initialisation, and then it should work.
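The workaround amounts to choosing the dtype (and device) once and passing that same value to every constructor. The sketch below illustrates the pattern with hypothetical stand-in functions, `make_beam` and `make_drift`, in plain PyTorch; they are not Cheetah's API, whose exact constructor signatures are not shown here.

```python
import torch

# Choose once and pass everywhere (float64 here is an assumption for the example).
DTYPE = torch.float64
DEVICE = torch.device("cpu")


def make_beam(n_particles, dtype, device):
    """Hypothetical stand-in for a beam: one (x, x') row per particle."""
    return torch.randn(n_particles, 2, dtype=dtype, device=device)


def make_drift(length, dtype, device):
    """Hypothetical stand-in for an element: the 2x2 drift transfer matrix."""
    return torch.tensor([[1.0, length], [0.0, 1.0]], dtype=dtype, device=device)


beam = make_beam(1000, dtype=DTYPE, device=DEVICE)
drift = make_drift(0.5, dtype=DTYPE, device=DEVICE)

# Particles are rows, so apply the matrix from the right via its transpose:
# x_new = x + length * x', x'_new = x'.
tracked = beam @ drift.T
print(tracked.dtype)  # torch.float64
```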