POSTECH-CVLab/PyTorch-StudioGAN

RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict:

goongzi-leean opened this issue

When I load ACGAN-Mod, I get this error:

RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "blocks.0.0.bn1.gain.weight", "blocks.0.0.bn1.bias.weight", "blocks.0.0.bn2.gain.weight", "blocks.0.0.bn2.bias.weight", "blocks.1.0.bn1.gain.weight", "blocks.1.0.bn1.bias.weight", "blocks.1.0.bn2.gain.weight", "blocks.1.0.bn2.bias.weight", "blocks.2.0.bn1.gain.weight", "blocks.2.0.bn1.bias.weight", "blocks.2.0.bn2.gain.weight", "blocks.2.0.bn2.bias.weight".
Unexpected key(s) in state_dict: "blocks.0.0.bn1.embed0.weight", "blocks.0.0.bn1.embed1.weight", "blocks.0.0.bn2.embed0.weight", "blocks.0.0.bn2.embed1.weight", "blocks.1.0.bn1.embed0.weight", "blocks.1.0.bn1.embed1.weight", "blocks.1.0.bn2.embed0.weight", "blocks.1.0.bn2.embed1.weight", "blocks.2.0.bn1.embed0.weight", "blocks.2.0.bn1.embed1.weight", "blocks.2.0.bn2.embed0.weight", "blocks.2.0.bn2.embed1.weight".

So I went to find out why.
The generator built by my code has the following structure:

Generator(
  (linear0): Linear(in_features=128, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GenBlock(
        (bn1): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (gain): Linear(in_features=10, out_features=256, bias=False)
          (bias): Linear(in_features=10, out_features=256, bias=False)
        )
        (bn2): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (gain): Linear(in_features=10, out_features=256, bias=False)
          (bias): Linear(in_features=10, out_features=256, bias=False)
        )
        (activation): ReLU(inplace=True)
        (conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )

And the corresponding structure in the author's training log looks like this:

Generator(
  (linear0): Linear(in_features=128, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GenBlock(
        (bn1): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (embed0): Embedding(10, 256)
          (embed1): Embedding(10, 256)
        )
        (bn2): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (embed0): Embedding(10, 256)
          (embed1): Embedding(10, 256)
        )
        (activation): ReLU(inplace=True)
        (conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )

Then I looked at ConditionalBatchNorm2d (in ops.py) in the latest code and found:

self.gain = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
self.bias = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)

but config.py sets g_linear = ops.linear, so gain and bias are created as Linear layers rather than Embedding layers.

This is where the above error comes from: the checkpoint stores embed0/embed1 Embedding weights, while the current code expects gain/bias Linear weights.
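
To make the difference concrete, here is a minimal sketch of the two ConditionalBatchNorm2d variants, reconstructed only from the module printouts above; it is not the repository's actual code (for instance, the real forward pass may compute the gain as 1 + gain(y), BigGAN-style):

import torch
import torch.nn as nn

class EmbeddingCBN2d(nn.Module):
    # Variant the pre-trained checkpoint was saved from:
    # per-class gain/bias looked up from Embedding tables.
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1,
                                 affine=False, track_running_stats=True)
        self.embed0 = nn.Embedding(num_classes, num_features)  # per-class gain
        self.embed1 = nn.Embedding(num_classes, num_features)  # per-class bias

    def forward(self, x, y):
        # y: LongTensor of class indices, shape (N,)
        gain = self.embed0(y).view(-1, x.size(1), 1, 1)
        bias = self.embed1(y).view(-1, x.size(1), 1, 1)
        return self.bn(x) * gain + bias

class LinearCBN2d(nn.Module):
    # Variant the current code builds: gain/bias produced by bias-free
    # Linear layers from a conditioning vector (here, in_features=10).
    def __init__(self, num_features, in_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1,
                                 affine=False, track_running_stats=True)
        self.gain = nn.Linear(in_features, num_features, bias=False)
        self.bias = nn.Linear(in_features, num_features, bias=False)

    def forward(self, x, y):
        # y: FloatTensor conditioning vector, shape (N, in_features)
        gain = self.gain(y).view(-1, x.size(1), 1, 1)
        bias = self.bias(y).view(-1, x.size(1), 1, 1)
        return self.bn(x) * gain + bias

The parameter names differ (embed0/embed1 vs. gain/bias), so load_state_dict fails even though, for a one-hot y, a Linear(10, 256) can represent exactly the same mapping as an Embedding(10, 256).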

If loading the author's pre-trained generator is required, ConditionalBatchNorm2d will need to be modified (or the checkpoint keys remapped, as sketched below); otherwise you have to retrain. This is true for all conditional GANs.
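
If retraining is not an option, a workaround is to remap the checkpoint keys before loading. This is a minimal sketch under three assumptions I have not verified against the repo: embed0 corresponds to gain and embed1 to bias (their pairing in the error message suggests this); the Linear layers receive a one-hot class vector, so the Embedding weight of shape (10, 256) transposes into the Linear weight of shape (256, 10); and the checkpoint stores the weights under a "state_dict" key. G.pth and generator are placeholders:

import torch

ckpt = torch.load("G.pth", map_location="cpu")  # placeholder checkpoint path
state_dict = ckpt.get("state_dict", ckpt)       # unwrap if the weights are nested

remapped = {}
for key, value in state_dict.items():
    if ".embed0." in key:
        # Embedding weight (num_classes, num_features) ->
        # Linear weight (num_features, num_classes)
        remapped[key.replace("embed0", "gain")] = value.t()
    elif ".embed1." in key:
        remapped[key.replace("embed1", "bias")] = value.t()
    else:
        remapped[key] = value

generator.load_state_dict(remapped)  # generator: model built from the current code

Note that this only patches loading: with one-hot conditioning the remapped Linear layers reproduce the original Embedding lookups exactly, but if the current code feeds something other than a one-hot vector into gain/bias, the module itself has to be changed back.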

Of course, I hope the author can look into this problem.

Best!

Leean