desy-ml/cheetah

quadrupole default dtype

jp-ga opened this issue · 1 comment

The default dtype of Quadrupole (and probably other elements) is torch.float32. This leads to dtype mismatch errors when the element is used with beams whose tensors are torch.double (torch.float64), for example beams created with their default arguments. See https://github.com/desy-ml/cheetah/blob/master/cheetah/accelerator/quadrupole.py#L37
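For context, the mismatch can be reproduced in plain PyTorch without Cheetah. A float32 tensor stands in for an element's transfer matrix and a float64 tensor stands in for beam coordinates. The names and shapes are illustrative only, not Cheetah internals.

```python
import torch

# Stand-in for an element's transfer matrix created with dtype=torch.float32
element_matrix = torch.eye(2, dtype=torch.float32)

# Stand-in for beam coordinates stored as float64 (torch.double)
beam_coords = torch.tensor([1e-3, 2e-4], dtype=torch.float64)

def track(matrix: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    # Matrix multiplication does not promote dtypes, so mixing
    # float32 and float64 operands raises a RuntimeError
    return matrix @ coords

try:
    track(element_matrix, beam_coords)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(f"RuntimeError: {err}")

# Casting one side to the other's dtype makes the product work
tracked = track(element_matrix.to(beam_coords.dtype), beam_coords)
```

Element-wise operations such as `+` or `*` would silently promote to float64 instead, which is why the failure tends to surface only at the matrix products.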

This can be annoying when relying on the default keyword arguments.

This is actually a big item I still want to address. It's kind of related to #113.

In the end, I think everything should work similarly to, say, nn.Linear in PyTorch, where you can do either of

quad = quad.double()
quad = quad.cuda()
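As a rough sketch of the nn.Module behaviour referred to here: if an element were implemented as a torch.nn.Module subclass and registered its attributes as buffers, then .double() and .to(device) would convert them automatically. ToyQuadrupole below is hypothetical and is not Cheetah's actual Quadrupole class.

```python
import torch
from torch import nn

class ToyQuadrupole(nn.Module):
    """Hypothetical element for illustration, not Cheetah's Quadrupole."""

    def __init__(self, length: float, k1: float) -> None:
        super().__init__()
        # Registered buffers are converted by .double(), .float(), .to(...), .cuda()
        self.register_buffer("length", torch.tensor(length))
        self.register_buffer("k1", torch.tensor(k1))

quad = ToyQuadrupole(length=0.2, k1=4.2)
print(quad.k1.dtype)  # torch.float32 under PyTorch's default dtype

quad = quad.double()  # casts all floating-point buffers and parameters
print(quad.k1.dtype)  # torch.float64

# Move to a GPU only if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
quad = quad.to(device)
```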

This requires registering the element's tensors with PyTorch (as parameters or buffers of an nn.Module), and we haven't yet figured out how to do this while keeping assignments like

quad.k1 = torch.tensor([4.2])

functional.
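The tension with such assignments can be seen on plain nn.Module objects. Assigning a tensor to a name already registered as a buffer keeps it registered, but the new tensor is not cast to the module's earlier dtype. Assigning a plain tensor to a name registered as an nn.Parameter is rejected outright. ToyElement below is hypothetical.

```python
import torch
from torch import nn

class ToyElement(nn.Module):
    """Hypothetical element with one buffer and one parameter."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("k1", torch.tensor([0.0]))
        self.angle = nn.Parameter(torch.tensor([0.0]))

elem = ToyElement().double()

# Assigning to a registered buffer name keeps it a buffer ...
elem.k1 = torch.tensor([4.2])
k1_is_buffer = "k1" in dict(elem.named_buffers())
# ... but the new tensor keeps its own dtype despite the earlier .double()
k1_dtype = elem.k1.dtype  # torch.float32 under the default dtype

# Assigning a plain tensor to a registered parameter name raises TypeError
try:
    elem.angle = torch.tensor([0.1])
    parameter_assignment_failed = False
except TypeError:
    parameter_assignment_failed = True
```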

As a workaround for now, you basically have to remember to pass the dtype you want to both the beam and all elements on initialisation, and then it should work.
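The workaround amounts to choosing one dtype and passing it to every tensor-producing constructor. The sketch below mirrors that pattern with plain tensors. drift_matrix is a hypothetical stand-in for an element's transfer map, not Cheetah's API.

```python
import torch

# Choose the dtype once and reuse it everywhere (illustrative choice)
DTYPE = torch.float64

def drift_matrix(length: float, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical stand-in for an element's 2x2 (x, x') transfer map."""
    return torch.tensor([[1.0, length], [0.0, 1.0]], dtype=dtype)

# Every element and the beam receive the same dtype at construction
elements = [drift_matrix(0.5, DTYPE), drift_matrix(1.0, DTYPE)]
coords = torch.tensor([1e-3, 2e-4], dtype=DTYPE)  # (x, x')

for matrix in elements:
    coords = matrix @ coords  # no mismatch, since all dtypes agree

print(coords.dtype)  # torch.float64
```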