jku-vds-lab/paradime

Adapt losses to automatically use correct device


Device-related errors have been fixed in most places, but they remain a problem in the current loss implementation. The losses should detect whether the model of a ParametricDR lives on the CPU or the GPU and move the .sub output accordingly.
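A minimal sketch of the idea in PyTorch, assuming the loss has access to the model: the loss reads the device from the model's parameters and moves its inputs there before combining them. `model_device` and `DeviceAwareLoss` are hypothetical names, not paradime's actual loss API, and the squared-error loss is only a placeholder.

```python
import torch
import torch.nn as nn


def model_device(model: nn.Module) -> torch.device:
    """Return the device of the model's first parameter (CPU if it has none)."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        # Parameter-free modules (e.g. nn.ReLU) fall back to the CPU.
        return torch.device("cpu")


class DeviceAwareLoss(nn.Module):
    """Hypothetical loss that moves its inputs to the model's device."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Look up the device at call time, so moving the model later is handled.
        device = model_device(self.model)
        output = output.to(device)
        target = target.to(device)
        # Placeholder loss: mean squared error on a single device.
        return ((output - target) ** 2).mean()
```

Reading the device from the parameters on every call, rather than caching it when the loss is created, keeps the loss correct if the model is moved with `.to()` afterwards.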

This is mostly done. A RuntimeError still occurs when a model parameter is used inside a transform, but I haven't yet figured out whether it is connected to the loss modifications.

The issue is fixed: the RuntimeError was caused by an in-place operation in the StudentTTransform.
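The failure mode can be reproduced with a small PyTorch sketch. `student_t_inplace` and `student_t` are hypothetical functions, not paradime's StudentTTransform: both compute Student-t affinities (1 + d²/α)^(-(α+1)/2) with a learnable α and then normalize them. The first normalizes in place, which overwrites a tensor that autograd saved for the backward pass, so `backward()` raises a RuntimeError; the second normalizes out of place and backpropagates into α without error.

```python
import torch


def student_t_inplace(sq_dists: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # Hypothetical transform: Student-t kernel with a learnable alpha.
    p = (1.0 + sq_dists / alpha).pow(-(alpha + 1) / 2)
    # In-place normalization overwrites p, which pow saved for its backward pass.
    p.div_(p.sum())
    return p


def student_t(sq_dists: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    p = (1.0 + sq_dists / alpha).pow(-(alpha + 1) / 2)
    # Out-of-place normalization creates a new tensor and leaves p untouched.
    return p / p.sum()


sq = torch.tensor([1.0, 4.0, 9.0])
alpha = torch.nn.Parameter(torch.tensor(1.0))  # stands in for a model parameter

out = student_t(sq, alpha)
out[0].backward()  # works; alpha.grad is populated
```

With α = 1 the unnormalized affinities are 0.5, 0.2 and 0.1, so the normalized output is 0.625, 0.25 and 0.125. Calling `student_t_inplace(sq, alpha).sum().backward()` instead fails with autograd's "modified by an inplace operation" RuntimeError.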