question about usage
fire717 opened this issue · 4 comments
fire717 commented
Hi!
I want to use EMA like this: save the EMA model weights, then load the saved weights as normal model weights. Is this the right usage?
ema_model = EMA(model)
# train
output = model(data)
# val
output = ema_model.ema_model(data)
ema_model.update()
# save
torch.save(ema_model.ema_model.state_dict(), save_name)
# load
model.load_state_dict(torch.load(save_name), strict=True)
lucidrains commented
yup! would it help if I added a method for setting the online model's weights and buffers from the EMAs?
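For readers wondering what such a method might do, here is a minimal sketch in plain PyTorch. The name copy_ema_to_online and the two stand-in networks are hypothetical and used only for illustration; this is not the library's actual API, just one way to overwrite an online model's parameters and buffers with averaged copies.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def copy_ema_to_online(ema_net: nn.Module, online_net: nn.Module) -> None:
    """Hypothetical helper: overwrite the online parameters and buffers in place."""
    for online_p, ema_p in zip(online_net.parameters(), ema_net.parameters()):
        online_p.copy_(ema_p)
    # Buffers (e.g. BatchNorm running statistics) are not parameters,
    # so they have to be copied separately.
    for online_b, ema_b in zip(online_net.buffers(), ema_net.buffers()):
        online_b.copy_(ema_b)


# Stand-in networks: the "EMA" copy is just a perturbed clone here.
online = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
ema_net = copy.deepcopy(online)
with torch.no_grad():
    for p in ema_net.parameters():
        p.add_(1.0)

copy_ema_to_online(ema_net, online)
```

After the call, the online network holds the same values as the averaged copy, which has the same effect as saving the EMA state dict and loading it back into the model.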
lucidrains commented
nit: you should call ema_model.update() before the validation. the validation can also be invoked as ema_model(data)
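Putting the suggested ordering together, a rough end-to-end sketch could look like the following. TinyEMA is a hypothetical stand-in written for this example, not the library's real EMA class (which may have warmup or update-interval behavior that is not modeled here); the toy model, data, and beta value are also assumptions.

```python
import copy
import io

import torch
from torch import nn


class TinyEMA(nn.Module):
    """Hypothetical stand-in for an EMA wrapper, for illustration only."""

    def __init__(self, model: nn.Module, beta: float = 0.99):
        super().__init__()
        self.online_model = model
        self.ema_model = copy.deepcopy(model)
        self.ema_model.requires_grad_(False)
        self.beta = beta

    @torch.no_grad()
    def update(self) -> None:
        # ema = beta * ema + (1 - beta) * online
        for ema_p, p in zip(self.ema_model.parameters(),
                            self.online_model.parameters()):
            ema_p.lerp_(p, 1.0 - self.beta)

    def forward(self, *args, **kwargs):
        # Calling the wrapper runs the averaged model.
        return self.ema_model(*args, **kwargs)


model = nn.Linear(4, 2)
ema = TinyEMA(model, beta=0.9)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data, target = torch.randn(16, 4), torch.randn(16, 2)

for _ in range(5):
    # train
    loss = nn.functional.mse_loss(model(data), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # update the averages first, so validation sees the latest step
    ema.update()
    # val
    with torch.no_grad():
        val_output = ema(data)

# save (an in-memory buffer stands in for save_name)
buffer = io.BytesIO()
torch.save(ema.ema_model.state_dict(), buffer)
buffer.seek(0)

# load into a plain model of the same architecture
fresh_model = nn.Linear(4, 2)
fresh_model.load_state_dict(torch.load(buffer), strict=True)
```

Because the averaged copy has the same architecture as the online model, its state dict has matching keys and loads with strict=True.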
fire717 commented
Thx! I think it's good enough now!
lucidrains commented