lucidrains/ema-pytorch

question about usage

fire717 opened this issue · 4 comments

Hi!
I want to use EMA like this: save the EMA model's weights, then load the saved weights into the model as normal weights. Is this the right usage?

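```python
import torch
from ema_pytorch import EMA

ema_model = EMA(model)

# train
output = model(data)

# val
output = ema_model.ema_model(data)
ema_model.update()

# save the EMA weights
torch.save(ema_model.ema_model.state_dict(), save_name)

# load them back into the model as normal weights
model.load_state_dict(torch.load(save_name), strict=True)
```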
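The save/load round trip works because the EMA copy has the same architecture as the online model, and therefore the same `state_dict` keys. A minimal check of that idea follows. It assumes the EMA copy behaves like a `copy.deepcopy` of the model; the small `nn.Sequential` model and the in-memory buffer are stand-ins for the real model and `save_name`.

```python
import copy
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the real network (includes a BatchNorm so buffers are exercised too).
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

# Assumption: the EMA copy is structurally a deep copy of the online model.
ema_copy = copy.deepcopy(model)
with torch.no_grad():
    for p in ema_copy.parameters():
        p.add_(0.1)  # pretend the EMA weights have drifted from the online ones

# Save the EMA weights (to memory instead of to save_name).
buffer = io.BytesIO()
torch.save(ema_copy.state_dict(), buffer)
buffer.seek(0)

# Load them into the online model; strict=True raises on any key mismatch.
result = model.load_state_dict(torch.load(buffer), strict=True)
```

If the two modules ever diverged in structure, `strict=True` would raise instead of silently skipping keys, which is why it is a reasonable default here.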

Yup! Would it help if I added a method that sets the online model's weights and buffers to the EMA values?
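A method like the one proposed would presumably copy each EMA parameter and buffer into the matching tensor of the online model, without going through a file. A minimal sketch of that idea in plain PyTorch is below. `copy_ema_to_online` is a hypothetical helper, not necessarily the method ema-pytorch added, and it assumes both modules have the same architecture so that `parameters()` and `buffers()` line up in order.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def copy_ema_to_online(ema_model: nn.Module, online_model: nn.Module) -> None:
    """Overwrite the online model's parameters and buffers with the EMA ones.

    Hypothetical helper; assumes both modules share the same architecture.
    """
    for ema_p, p in zip(ema_model.parameters(), online_model.parameters()):
        p.copy_(ema_p)
    for ema_b, b in zip(ema_model.buffers(), online_model.buffers()):
        b.copy_(ema_b)


torch.manual_seed(0)
online = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
ema = copy.deepcopy(online)
with torch.no_grad():
    for p in ema.parameters():
        p.mul_(0.5)  # make the EMA weights differ from the online weights
    ema[1].running_mean.fill_(1.0)  # and one of the BatchNorm buffers

copy_ema_to_online(ema, online)
```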

nit: you should call ema_model.update() before validation, so the EMA weights you evaluate are up to date.

Validation can also be run as ema_model(data); calling the EMA wrapper forwards to ema_model.ema_model.
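Putting both suggestions together: update the EMA right after each optimizer step, so it is current before validation, and validate by calling the wrapper directly. The sketch below uses a tiny hypothetical `SimpleEMA` class in place of ema-pytorch's `EMA` so that it runs on its own. It only implements the basic rule `ema = beta * ema + (1 - beta) * online` and a `forward` that delegates to the EMA copy; it has none of the real library's other options. `beta`, the model, the data and the loss are made-up placeholders.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleEMA(nn.Module):
    """Hypothetical minimal EMA wrapper (not ema-pytorch's implementation)."""

    def __init__(self, model: nn.Module, beta: float = 0.99):
        super().__init__()
        self.beta = beta
        self.online_model = model
        self.ema_model = copy.deepcopy(model)
        self.ema_model.requires_grad_(False)

    @torch.no_grad()
    def update(self) -> None:
        for ema_p, p in zip(self.ema_model.parameters(), self.online_model.parameters()):
            # lerp_(p, w) computes ema + w * (p - ema) = beta * ema + (1 - beta) * p
            ema_p.lerp_(p, 1.0 - self.beta)
        for ema_b, b in zip(self.ema_model.buffers(), self.online_model.buffers()):
            ema_b.copy_(b)  # this sketch copies buffers rather than averaging them

    def forward(self, *args, **kwargs):
        # Calling the wrapper runs the EMA copy, like ema_model(data) in the thread.
        return self.ema_model(*args, **kwargs)


torch.manual_seed(0)
model = nn.Linear(4, 1)
ema_model = SimpleEMA(model, beta=0.9)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 4), torch.randn(32, 1)

for step in range(20):
    # train
    loss = F.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_model.update()  # right after the optimizer step, i.e. before validation

    # val
    with torch.no_grad():
        val_output = ema_model(x)
```

Saving then works as in the original snippet, via `ema_model.ema_model.state_dict()`.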

Thanks! I think it's good enough as it is!

@fire717 no problem

Added it anyway, in case you have a special need for it.