about the training speed
winnechan opened this issue · 1 comments
winnechan commented
Hi, thanks for your project.
The EMA model is updated on the CPU by iterating over all parameters of the online model, which keeps GPU utilization low. Does this mean there is no way to speed up training?
Thanks
HiDolen commented
I noticed that when instantiating `KarrasEMA` or `EMA`, you can pass the parameter `allow_different_devices` (which defaults to `False`). When `allow_different_devices` is set to `True`, the parameters of the EMA model are moved to the same device as the parameters of the trained model; otherwise, they are kept on the CPU. Although it might be too late, I hope this helps.
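For context, the EMA update itself is just a per-parameter interpolation, which is why keeping the EMA weights on the CPU while the online model sits on the GPU forces a device transfer every step. A minimal pure-Python sketch of the update rule (no torch; function and parameter names here are illustrative, not the library's API):

```python
def ema_update(ema_params, online_params, beta=0.999):
    """In-place exponential moving average: ema <- beta * ema + (1 - beta) * online.

    In the real library this loop runs over model parameter tensors; if the
    EMA copies lived on the CPU while the online parameters lived on the GPU,
    every step would pay a per-parameter device transfer, which is the
    slowdown described in the question above.
    """
    for i, (e, o) in enumerate(zip(ema_params, online_params)):
        ema_params[i] = beta * e + (1 - beta) * o
    return ema_params

# Toy usage with plain floats standing in for parameter tensors.
ema = [0.0, 0.0]
online = [1.0, 2.0]
ema_update(ema, online, beta=0.9)  # ema is now [0.1, 0.2] (up to float rounding)
```

Keeping both parameter sets on the same device (as `allow_different_devices=True` reportedly enables) removes the transfer, so each step is a single fused GPU operation per parameter instead of a CPU round-trip.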