Error when loading converted ada-pytorch model
TheodoreGalanos opened this issue · 13 comments
Hello and thank you for sharing your fascinating work!
I'm trying to use restyle with a pretrained stylegan-ada-pytorch model. I followed the conversion script (thanks btw!) and have my .pt model file ready. Unfortunately, when I'm trying to run training using the following command
python scripts/train_restyle_psp.py --dataset_type=buildings --encoder_type=BackboneEncoder --exp_dir=experiment/restyle_psp_ffhq_encode --workers=8 --batch_size=8 --test_batch_size=8 --test_workers=8 --val_interval=5000 --save_interval=10000 --start_from_latent_avg --lpips_lambda=0.8 --l2_lambda=1 --w_norm_lambda=0 --id_lambda=0.1 --input_nc=6 --n_iters_per_batch=5 --output_size=512 --stylegan_weights=F:\Experimentation\Generative_models\GANs\StyleGAN2\pretrained_models\rosalinity\buildings_5kimg_upsampled.pt
I get the following error when loading the pretrained model
File "scripts/train_restyle_psp.py", line 30, in <module>
main()
File "scripts/train_restyle_psp.py", line 25, in main
coach = Coach(opts)
File "F:\Experimentation\Generative_models\GANs\StyleGAN2\GAN_editing\restyle-encoder\training\coach_restyle_psp.py", line 31, in __init__
self.net = pSp(self.opts).to(self.device)
File "F:\Experimentation\Generative_models\GANs\StyleGAN2\GAN_editing\restyle-encoder\models\psp.py", line 25, in __init__
self.load_weights()
File "F:\Experimentation\Generative_models\GANs\StyleGAN2\GAN_editing\restyle-encoder\models\psp.py", line 52, in load_weights
self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
File "C:\Users\user\miniconda3\envs\archelites\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias".
Any idea why I get a mismatch here? Thanks!
Hi @TheodoreGalanos ,
I actually haven't tried using the scripts myself, but thought it would be helpful to reference them in the repo.
I noticed that there is a similar issue in dvschultz's repo (dvschultz/stylegan2-ada-pytorch#6) when trying to convert a model at 256x256 resolution so it seems like other people are also facing the same issue.
I will try playing with the script myself within the next couple of days to see if anything interesting pops up, but posting a comment in that thread could help as well.
Regardless, I noticed you're running on a buildings StyleGAN if I'm not mistaken, but I see that you're using the BackboneEncoder and the ID loss, which are specialized for the human faces domain. You should take a look at the other encoder and the MOCO-based loss we have in the repo. There are more details here: https://github.com/yuval-alaluf/restyle-encoder#additional-notes
Thanks @yuval-alaluf, let me know what you find once you find time to look into it! I hadn't seen the thread, will follow that one as well!
And thanks for the feedback on the options, I started with a more or less copy/paste :) Will adjust it with your recommendations.
Any chance you can check how many mapping layers you used when training your generator?
I see in the stylegan-ada-pytorch repo the following configs:
cfg_specs = {
'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
'cifar': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
}
If you used auto, this may explain why you're missing the keys related to styles 3 - 8: the auto config builds the mapping network with only 2 layers (map=2), while the generator defined in this repo expects 8, so style.3 through style.8 have nothing to load from.
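If you want to double-check from the checkpoint itself, something like the following should show how many mapping layers ended up in the converted file (just a quick sketch against your converted rosinality-format .pt; the filename is only an example):
import torch

ckpt = torch.load("buildings_5kimg_upsampled.pt", map_location="cpu")
# In rosinality's Generator the mapping network is stored as style.1 ... style.n_mlp
# (style.0 is a parameter-free PixelNorm), so counting the linear weights gives n_mlp.
mapping_layers = [k for k in ckpt["g_ema"] if k.startswith("style.") and k.endswith(".weight")]
print(len(mapping_layers), "mapping layers")  # expect 2 for auto, 8 for the paper/stylegan2 configs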
If this is the case, the easiest fix I see is changing the definition of the generator in this repo (using the generator you converted to rosinality's format)
Line 22 in ecc797c
Try changing the above line to:
self.decoder = Generator(self.opts.output_size, 512, 2, channel_multiplier=2)
The third argument here is n_mlp, the number of mapping-network layers, so passing 2 should match a generator trained with the auto config.
Interested to see if this is indeed the difference we're seeing.
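As a quick sanity check outside of training, you could also try loading the checkpoint into a standalone generator (a sketch, assuming the rosinality-style Generator class under models/stylegan2/model.py and your converted .pt):
import torch
from models.stylegan2.model import Generator

decoder = Generator(512, 512, 2, channel_multiplier=2)  # n_mlp=2 to match a model trained with `auto`
ckpt = torch.load("buildings_5kimg_upsampled.pt", map_location="cpu")
decoder.load_state_dict(ckpt["g_ema"], strict=True)  # should no longer complain about style.3 ... style.8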
Hi @yuval-alaluf, I just realized this today from a discussion in a Discord server I'm in. I think I trained my model with 'auto', so I have 2 instead of 8 mapping layers! Good catch! Will go ahead and try that in a bit (after the morning coffee).
So I tried this and it obviously bypassed the issue; however, it apparently fails to find the latent_avg:
self.avg_image = self.net(self.net.latent_avg.unsqueeze(0),
AttributeError: 'NoneType' object has no attribute 'unsqueeze'
Ok so it seems like the stylegan-ada-pytorch conversion script is missing some features that rosinality implemented when converting from the tensorflow versions.
If we take a look at rosinality's conversion script, he does the following:
with open(args.path, "rb") as f:
generator, discriminator, g_ema = pickle.load(f)
...
latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())
...
ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
# some more conversions
You can see that he saves latent_avg in the checkpoint, and this is what we try loading in this repo.
It seems like the pytorch conversion script is missing this step. They simply have:
state_dict = {"g_ema": state_ros}
torch.save(state_dict, output_file)
Taking a look at the pytorch code, it seems like the latent average is stored in
g_ema.mapping.w_avg  # it seems like `g_ema` is stored as `G_nvidia` in the export script
So you can try pulling that value into a latent_avg variable and saving:
state_dict = {"g_ema": state_ros, "latent_avg": latent_avg}
torch.save(state_dict, output_file)
and see if this solves the problem with the latent average.
Let me know if this works. I didn't test it and just looked at the code so I may have missed something.
Once we're able to solve all these edge cases, I think it would be really helpful to add a PR somewhere. It seems like more people will come across these issues.
Thank you, will try this soon! Apologies, I've been swamped the last 2 days and forgot about this.
Changing export_weights.py with these lines seems to have worked for me:
latent_avg = state_nv['mapping.w_avg']
state_dict = {"g_ema": state_ros, "latent_avg": latent_avg}
torch.save(state_dict, output_file)
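For anyone else hitting this, a quick way to confirm the re-exported checkpoint has what the coach expects is to inspect its keys (the filename is just an example):
import torch

ckpt = torch.load("buildings_5kimg_upsampled.pt", map_location="cpu")
print(ckpt.keys())               # should now include both 'g_ema' and 'latent_avg'
print(ckpt["latent_avg"].shape)  # the average W-space latent, typically of size 512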
Thanks, this actually solved the problem with latent_avg
Thanks @AleksiKnuutila for the snippet!
I provided a link to this issue in the README in case other people come across similar issues with the conversion.
Since it seems like this solves the problem, I'm closing this issue for now.
Thanks @yuval-alaluf for your amazing work.
Since your coaches require training from latent_avg, why not just estimate the initial latent_avg from a dense sampling when latent_avg is not available?
It may look like this:
# Initialize network
self.net = pSp(self.opts).to(self.device)
if self.net.latent_avg is None:
    self.net.latent_avg = self.net.decoder.mean_latent(int(1e5))[0]  # set a very large number for estimation
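For context, mean_latent just approximates the average style by sampling many random z vectors, mapping them to W, and averaging; roughly something like this (a sketch of the idea, not the exact implementation in the repo):
import torch

def estimate_latent_avg(decoder, n_samples=100_000, style_dim=512, device="cuda"):
    # Sample random z, push them through the mapping network, and average the W codes.
    with torch.no_grad():
        z = torch.randn(n_samples, style_dim, device=device)
        w = decoder.style(z)  # `style` is the mapping network in rosinality's Generator
        return w.mean(dim=0)  # shape: (style_dim,)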
Thanks all. I've added this to the export_weights.py script so you shouldn't need to edit your own copy anymore.
Thanks for the update @dvschultz!