lucidrains/ema-pytorch

about the training speed

winnechan opened this issue · 1 comment

Hi, thanks for your project.

The EMA model is updated on the CPU by iterating over all the parameters of the online model, which keeps GPU utilization low. Does this mean there is no way to speed up training?

Thanks

I noticed that when instantiating KarrasEMA or EMA, you can pass the parameter allow_different_devices (which defaults to False). When allow_different_devices is True, the update handles the EMA parameters living on a different device than the online model's parameters, moving tensors as needed during each copy/lerp; otherwise both copies are expected to be on the same device. Since EMA is an nn.Module, you can also just move the whole EMA wrapper onto the GPU with .to(device) so updates happen entirely on the GPU. Although it might be too late, I hope it helps.