yuval-alaluf/restyle-encoder

Error when loading converted ada-pytorch model

TheodoreGalanos opened this issue · 13 comments

Hello and thank you for sharing your fascinating work!

I'm trying to use restyle with a pretrained stylegan-ada-pytorch model. I followed the conversion script (thanks btw!) and have my .pt model file ready. Unfortunately, when I try to run training with the following command:

python scripts/train_restyle_psp.py --dataset_type=buildings --encoder_type=BackboneEncoder --exp_dir=experiment/restyle_psp_ffhq_encode --workers=8 --batch_size=8 --test_batch_size=8 --test_workers=8 --val_interval=5000 --save_interval=10000 --start_from_latent_avg --lpips_lambda=0.8 --l2_lambda=1 --w_norm_lambda=0 --id_lambda=0.1 --input_nc=6 --n_iters_per_batch=5 --output_size=512 --stylegan_weights=F:\Experimentation\Generative_models\GANs\StyleGAN2\pretrained_models\rosalinity\buildings_5kimg_upsampled.pt

I get the following error when loading the pretrained model:

  File "scripts/train_restyle_psp.py", line 30, in <module>
    main()
  File "scripts/train_restyle_psp.py", line 25, in main
    coach = Coach(opts)
  File "F:\Experimentation\Generative_models\GANs\StyleGAN2\GAN_editing\restyle-encoder\training\coach_restyle_psp.py", line 31, in __init__
    self.net = pSp(self.opts).to(self.device)
  File "F:\Experimentation\Generative_models\GANs\StyleGAN2\GAN_editing\restyle-encoder\models\psp.py", line 25, in __init__
    self.load_weights()
  File "F:\Experimentation\Generative_models\GANs\StyleGAN2\GAN_editing\restyle-encoder\models\psp.py", line 52, in load_weights
    self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
  File "C:\Users\user\miniconda3\envs\archelites\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
        Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias".

Any idea why I get a mismatch here? Thanks!

Hi @TheodoreGalanos ,
I actually haven't tried using the scripts myself, but thought it would be helpful to reference them in the repo.
I noticed that there is a similar issue in dvschultz's repo (dvschultz/stylegan2-ada-pytorch#6) when trying to convert a model at 256x256 resolution, so it seems other people are facing the same issue.
I will try playing with the script myself within the next couple of days to see if anything interesting pops up, but posting a comment in that thread could help as well.

Regardless, I noticed you're training on a buildings StyleGAN, if I'm not mistaken. However, you're using the BackboneEncoder and the ID loss, which are specialized for the human facial domain. You should take a look at the other encoder and the MOCO loss we have in the repo. There are more details here: https://github.com/yuval-alaluf/restyle-encoder#additional-notes

Thanks @yuval-alaluf, let me know what you find once you find time to look into it! I hadn't seen the thread, will follow that one as well!

And thanks for the feedback on the options, I started with a more or less copy/paste :) Will adjust it with your recommendations.

Any chance you can check how many mapping layers you used when training your generator?
I see in the stylegan-ada-pytorch repo the following configs:

    cfg_specs = {
        'auto':      dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
        'stylegan2': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=0.5, lrate=0.0025, gamma=1,    ema=20,  ramp=None, map=8),
        'paper512':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.0025, gamma=0.5,  ema=20,  ramp=None, map=8),
        'paper1024': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=2,    ema=10,  ramp=None, map=8),
        'cifar':     dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, fmaps=1,   lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
    }

If you used auto, this may explain why you're missing the keys related to styles 3 - 8: auto uses 2 mapping layers instead of 8.
If this is the case, the easiest fix I see is to change the definition of the generator in this repo (since you converted your generator to rosinality's format):

self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2)

Try changing the above line to:

self.decoder = Generator(self.opts.output_size, 512, 2, channel_multiplier=2) 

Interested to see if this is indeed the difference we're seeing.
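
One way to check the mapping depth without digging up the training config is to count the mapping layers directly in the converted checkpoint. This sketch assumes the rosinality layout, where `style` is an `nn.Sequential` holding a weightless `PixelNorm` at index 0 followed by `n_mlp` `EqualLinear` layers at indices 1..n_mlp; `count_mapping_layers` is a hypothetical helper, not part of either repo.

```python
import re

# Matches keys like "style.1.weight" (one per EqualLinear layer in the mapping MLP).
_STYLE_WEIGHT = re.compile(r"^style\.(\d+)\.weight$")


def count_mapping_layers(state_dict):
    """Infer n_mlp from a rosinality-format generator state_dict.

    Index 0 of `style` is PixelNorm, which has no parameters, so every
    "style.<i>.weight" key corresponds to exactly one mapping layer.
    """
    indices = set()
    for key in state_dict:
        match = _STYLE_WEIGHT.match(key)
        if match:
            indices.add(int(match.group(1)))
    return len(indices)
```

For a checkpoint like the one in the traceback above (presumably only `style.1.*` and `style.2.*` present), this should return 2. Assuming the weights are loaded with something like `torch.load(path, map_location="cpu")["g_ema"]`, the result would go in as the third argument, e.g. `Generator(self.opts.output_size, 512, n_mlp, channel_multiplier=2)`.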

Hi @yuval-alaluf, I just realized this today from a discussion on a Discord server I'm in. I think I trained my model with 'auto', so I have 2 mapping layers instead of 8! Good catch! I'll go ahead and try that in a bit (after the morning coffee).

So I tried this and it bypassed the issue; however, it apparently fails to find the latent_avg:

    self.avg_image = self.net(self.net.latent_avg.unsqueeze(0),
AttributeError: 'NoneType' object has no attribute 'unsqueeze'

OK, so it seems like the stylegan-ada-pytorch conversion script is missing a step that rosinality implemented when converting from the TensorFlow versions.

If we take a look at rosinality's conversion script, he does the following:

with open(args.path, "rb") as f:
    generator, discriminator, g_ema = pickle.load(f)
...
latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())
...
ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
# some more conversions

You can see that he saves latent_avg in the checkpoint, and this is what we try to load in this repo.
It seems like the pytorch conversion script is missing this step. They simply have:

state_dict = {"g_ema": state_ros}
torch.save(state_dict, output_file)

Taking a look at the pytorch code, it seems like the latent average is stored in:

    g_ema.mapping.w_avg    # seems like `g_ema` is stored as `G_nvidia` in the export script

So you can try reading that tensor into `latent_avg` and adding it to the saved checkpoint:

state_dict = {"g_ema": state_ros, "latent_avg": latent_avg}
torch.save(state_dict, output_file)

and see if this solves the problem with the latent average.

Let me know if this works. I didn't test it, I just looked at the code, so I may have missed something.
Once we're able to solve all these edge cases, I think it would be really helpful to open a PR somewhere. It seems like more people will come across these issues.

Thank you, will try this soon! Apologies, I've been swamped the last 2 days and forgot about this.

Changing export_weights.py with these lines seems to have worked for me:

    latent_avg = state_nv['mapping.w_avg']
    state_dict = {"g_ema": state_ros, "latent_avg": latent_avg}
    torch.save(state_dict, output_file)
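
The same change can be wrapped so a missing average fails at export time rather than later as a `NoneType` error during training. This is a minimal sketch assuming `state_ros` and `state_nv` are the dicts from the snippet above (the rosinality-format weights and the NVIDIA state_dict, where stylegan2-ada-pytorch registers `w_avg` as a buffer of the mapping network); `build_restyle_checkpoint` is a hypothetical name, not a function in export_weights.py.

```python
def build_restyle_checkpoint(state_ros, state_nv):
    """Bundle converted weights with the average latent that restyle loads.

    state_ros: generator weights already renamed to rosinality's layout.
    state_nv:  the stylegan2-ada-pytorch generator state_dict, where the
               mapping network's running average of W is "mapping.w_avg".
    """
    key = "mapping.w_avg"
    if key not in state_nv:
        # Fail loudly here instead of with latent_avg = None inside the coach.
        raise KeyError(f"{key!r} not found in the NVIDIA state_dict")
    return {"g_ema": state_ros, "latent_avg": state_nv[key]}
```

In the export script this would replace the last two lines with `torch.save(build_restyle_checkpoint(state_ros, state_nv), output_file)`.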

Thanks, this actually solved the problem with latent_avg

Thanks @AleksiKnuutila for the snippet! 😄
I provided a link to this issue in the README in case other people come across similar issues with the conversion.
Since it seems like this solves the problem, I'm closing this issue for now.

Thanks @yuval-alaluf for your amazing work.
Since your coaches require training from latent_avg, why not estimate the initial latent_avg by dense sampling when latent_avg is not available?
It may look like this:

# Initialize network
self.net = pSp(self.opts).to(self.device)
if self.net.latent_avg is None:
    # Estimate latent_avg as the mean of many mapped samples; a large count gives a more accurate estimate
    self.net.latent_avg = self.net.decoder.mean_latent(int(1e5))[0]
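
The idea behind this fallback is a Monte Carlo estimate: draw many z ~ N(0, I), push them through the mapping network, and average the outputs. rosinality's `Generator.mean_latent(n_latent)` does this with `torch.randn` and the `style` MLP. The stdlib sketch below shows the same estimator on a hypothetical toy mapping (elementwise ReLU, whose true mean is 1/sqrt(2π) ≈ 0.399 per dimension) so the effect of the sample count is easy to see; it is an illustration, not the generator's code.

```python
import random


def toy_mapping(z):
    # Hypothetical stand-in for a mapping network: elementwise ReLU.
    return [max(0.0, zi) for zi in z]


def estimate_mean_latent(mapping, dim, n_samples, seed=0):
    """Average mapping(z) over n_samples draws of z ~ N(0, I)."""
    rng = random.Random(seed)
    totals = [0.0] * dim
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        for i, wi in enumerate(mapping(z)):
            totals[i] += wi
    return [t / n_samples for t in totals]
```

The standard error of each component shrinks roughly as 1/sqrt(n_samples), which is why a large count such as `int(1e5)` is used above. The sampled estimate and the `w_avg` tracked during training should approximate the same quantity, but they will not match exactly.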

Thanks all. I've added this to the export_weights.py script so you shouldn't need to edit your own copy anymore.

Thanks for the update @dvschultz!