lucidrains/ema-pytorch

Support for non-deepcopyable models

stefan-baumann opened this issue · 5 comments

Hi there,

First of all, thanks for the really useful library! There is one thing I have run into that I think would be a good addition: as you're aware, some models cannot be copied via copy.deepcopy. Instead of just raising an exception, you could let users pass in a separate model instance, which is then used as the EMA model. The parameters can then be synced via module.parameters() and updated afterwards as before.
It would be great if you could offer this (or some other alternative) as an option.
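
To make the proposal above concrete, here is a minimal sketch in plain PyTorch of what "build a second instance and sync it via module.parameters()" could look like. It is not the library's implementation: `LockedNet`, `sync_params` and `ema_update` are hypothetical names invented for illustration, and the sketch assumes both instances are constructed identically so that their parameters line up one to one. Buffers (for example BatchNorm statistics) are ignored here.

```python
import threading

import torch
from torch import nn


class LockedNet(nn.Module):
    """Hypothetical model that copy.deepcopy cannot handle because it holds a lock."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self._lock = threading.Lock()  # lock objects cannot be deep-copied


@torch.no_grad()
def sync_params(online: nn.Module, shadow: nn.Module) -> None:
    """Copy every parameter of `online` into the matching parameter of `shadow`."""
    for p_shadow, p_online in zip(shadow.parameters(), online.parameters()):
        p_shadow.copy_(p_online)


@torch.no_grad()
def ema_update(online: nn.Module, shadow: nn.Module, beta: float = 0.999) -> None:
    """In place: shadow = beta * shadow + (1 - beta) * online."""
    for p_shadow, p_online in zip(shadow.parameters(), online.parameters()):
        p_shadow.lerp_(p_online, 1.0 - beta)


online = LockedNet()
shadow = LockedNet()          # second instance, built instead of deep-copied
shadow.requires_grad_(False)  # the EMA copy is never trained directly

sync_params(online, shadow)   # start the EMA copy from the online weights
ema_update(online, shadow)    # would be called after each optimizer step
```

The design choice is simply to replace the deepcopy with a second constructor call, so anything the constructor can build can also be tracked by an EMA.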

Thanks and all the best!

Any news on this?
I was searching for exactly this issue. Muse uses this library for loading the EMA models, and the MaskGit training is supposed to use the EMA model instead of the non-EMA model for the VAE. After trying a few things I was not able to get it working, so I did a bit of searching and ended up here. It seems that if we want to train the MaskGit for Muse with the EMA VAE, we might need to find a way to add support for non-deepcopyable models.

hey Alejandro

what error are you seeing? I thought all the modules from muse should be deepcopyable?

Hey, sorry I didn't answer sooner, I was a bit busy until recently. I think the issue I'm getting is most likely caused by how I'm loading the VAE; I'm not 100% sure how I should be loading the EMA VAE. I ran several tests, and in some of them I got an error saying that the model was not deepcopyable. I think that was the test where I first loaded the VAE with the vae.load() function from Muse itself and then used ema = EMA(vae) to init the EMA. That's probably not the way to do it, though, and I think it is more related to Muse than to this repo. The main issue is that I have no idea how to properly load the EMA VAE from a checkpoint, either to resume training or to start the MaskGit training with it. The only time I managed to get it to sort of load, the model was broken and did not train properly.

@ZeroCool940711 oh got it

well, it turns out i already addressed this issue but did not close the issue

you just have to pass in your instantiated EMA model through the ema_model kwarg

oh, nice, will give it a try and see if I can make it work on Muse :)