Question regarding `ema` in `training_loop`
Michael-H777 opened this issue · 5 comments
Hello,
First of all, thank you for this amazing repo. The framework is quite powerful.
I've been reading your code and have a question. In the `training_loop` function in `training/training_loop.py`, what is the purpose of the variable `ema`? You can find it on line 107, where it is initialized, as well as on line 187, where you update its parameters using the net's weights. I don't see where it's being used, other than being saved during checkpointing.
Michael
Hi,
Thank you for your interest. The `ema` variable stands for the exponential moving average of the network weights during training. The EMA trick typically produces better samples and is widely adopted in the field of generative models. You are correct that the `ema` network is not used during training; however, it is used in evaluation. In the `generate.py` file, we load the `ema` network for evaluation (generating images).
Yilun
I see, thanks for the explanation! I mostly work on segmentation and am new to generative models.
Hi, sorry for re-opening this issue.
I have been quite successful in refactoring your code for my purposes. Thanks for this repo, btw!
I have a question regarding EMA updates. When I monitor the training progress, it takes more time to update the EMA parameters than it does for the actual training step (32x32 CIFAR-10). Have you investigated the difference between updating the EMA after every step vs. updating it once per epoch? That is, do they behave differently in terms of image quality?
Michael
Hi Michael,
It seems that the EMA update simply requires a linear interpolation between `net` and `ema`, which is significantly faster than the forward/backward computation in each training step (I will check the speed with and without EMA later). I haven't experimented with adjusting the frequency of EMA updates. I borrowed the EMA parameters from the EDM paper: https://arxiv.org/abs/2206.00364 .
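Concretely, the per-step EMA update is just a linear interpolation. A framework-agnostic sketch using plain dicts of floats (in practice this runs over parameter tensors, and the fixed `decay` value here is a placeholder, not the repo's actual schedule):

```python
def ema_update(ema_params, net_params, decay=0.999):
    """In-place EMA step: ema <- decay * ema + (1 - decay) * net.

    A sketch over plain dicts of floats; the fixed `decay` is a
    placeholder -- the actual decay schedule in the repo may differ.
    """
    for name in ema_params:
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * net_params[name]

# Example: one update step pulls the EMA slightly toward the current weights.
ema = {"w": 0.0}
net = {"w": 1.0}
ema_update(ema, net, decay=0.9)
print(ema["w"])  # close to 0.1
```

Since this is a single multiply-add per parameter, it should indeed be far cheaper than a forward/backward pass; if it dominates the step time, the bottleneck is likely elsewhere (e.g. host/device synchronization).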
Yilun
ok, thanks.