RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict:
goongzi-leean opened this issue
When I load the pre-trained AGGAN-Mod generator, I get this error:
RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "blocks.0.0.bn1.gain.weight", "blocks.0.0.bn1.bias.weight", "blocks.0.0.bn2.gain.weight", "blocks.0.0.bn2.bias.weight", "blocks.1.0.bn1.gain.weight", "blocks.1.0.bn1.bias.weight", "blocks.1.0.bn2.gain.weight", "blocks.1.0.bn2.bias.weight", "blocks.2.0.bn1.gain.weight", "blocks.2.0.bn1.bias.weight", "blocks.2.0.bn2.gain.weight", "blocks.2.0.bn2.bias.weight".
Unexpected key(s) in state_dict: "blocks.0.0.bn1.embed0.weight", "blocks.0.0.bn1.embed1.weight", "blocks.0.0.bn2.embed0.weight", "blocks.0.0.bn2.embed1.weight", "blocks.1.0.bn1.embed0.weight", "blocks.1.0.bn1.embed1.weight", "blocks.1.0.bn2.embed0.weight", "blocks.1.0.bn2.embed1.weight", "blocks.2.0.bn1.embed0.weight", "blocks.2.0.bn1.embed1.weight", "blocks.2.0.bn2.embed0.weight", "blocks.2.0.bn2.embed1.weight".
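To see exactly which names disagree without hitting the RuntimeError, you can load with strict=False. In that mode PyTorch's load_state_dict returns the missing and unexpected keys instead of raising. The sketch below reproduces the mismatch on a single layer. OldStyleCBN and NewStyleCBN are hypothetical stand-ins that mirror the two ConditionalBatchNorm2d layouts printed in this issue; they are not the repository's classes.

```python
import torch.nn as nn


class OldStyleCBN(nn.Module):
    """Hypothetical stand-in for the checkpoint's ConditionalBatchNorm2d (Embedding layers)."""

    def __init__(self, num_classes: int = 10, num_features: int = 256):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1, affine=False)
        self.embed0 = nn.Embedding(num_classes, num_features)
        self.embed1 = nn.Embedding(num_classes, num_features)


class NewStyleCBN(nn.Module):
    """Hypothetical stand-in for the latest ConditionalBatchNorm2d (bias-free Linear layers)."""

    def __init__(self, num_classes: int = 10, num_features: int = 256):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1, affine=False)
        self.gain = nn.Linear(num_classes, num_features, bias=False)
        self.bias = nn.Linear(num_classes, num_features, bias=False)


def report_key_mismatch(model: nn.Module, state_dict: dict):
    # strict=False loads every matching tensor and returns the names that did
    # not match, instead of raising the RuntimeError from the issue.
    result = model.load_state_dict(state_dict, strict=False)
    return sorted(result.missing_keys), sorted(result.unexpected_keys)


missing, unexpected = report_key_mismatch(NewStyleCBN(), OldStyleCBN().state_dict())
print("missing:", missing)
print("unexpected:", unexpected)
```

On the real generator and checkpoint, the same call should list the bn1/bn2 gain and bias keys as missing and the embed0/embed1 keys as unexpected, as in the error above.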
So I went to find out why.
The structure of my generator, built with the latest code, looks like this:
Generator(
  (linear0): Linear(in_features=128, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GenBlock(
        (bn1): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (gain): Linear(in_features=10, out_features=256, bias=False)
          (bias): Linear(in_features=10, out_features=256, bias=False)
        )
        (bn2): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (gain): Linear(in_features=10, out_features=256, bias=False)
          (bias): Linear(in_features=10, out_features=256, bias=False)
        )
        (activation): ReLU(inplace=True)
        (conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
And the structure recorded in the author's training log looks like this:
Generator(
  (linear0): Linear(in_features=128, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GenBlock(
        (bn1): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (embed0): Embedding(10, 256)
          (embed1): Embedding(10, 256)
        )
        (bn2): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (embed0): Embedding(10, 256)
          (embed1): Embedding(10, 256)
        )
        (activation): ReLU(inplace=True)
        (conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
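The two printouts differ only inside ConditionalBatchNorm2d: Embedding(10, 256) versus Linear(10, 256, bias=False). Both hold 10 x 256 parameters, but stored transposed. If the label reaches the Linear layer as a one-hot vector (an assumption; the issue does not show how the latest Generator passes labels), the two layers compute the same values. A minimal check:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes, num_features = 10, 256

embed = nn.Embedding(num_classes, num_features)            # checkpoint layout
linear = nn.Linear(num_classes, num_features, bias=False)  # latest-code layout

# Same parameter count, transposed storage.
print(embed.weight.shape)   # torch.Size([10, 256])
print(linear.weight.shape)  # torch.Size([256, 10])

# Copy the embedding table into the linear layer, transposed.
with torch.no_grad():
    linear.weight.copy_(embed.weight.t())

labels = torch.tensor([0, 3, 9])
one_hot = F.one_hot(labels, num_classes=num_classes).float()

# A bias-free Linear applied to a one-hot vector selects one column of its
# weight, which is one row of the original embedding table.
same = torch.allclose(linear(one_hot), embed(labels))
print(same)  # True
```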
I looked up ConditionalBatchNorm2d (in ops.py) in the latest code and found:
self.gain = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
self.bias = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
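and g_linear is set to ops.linear in config.py.
This is where the error above comes from: the latest code builds gain and bias as bias-free Linear layers, whereas the author's checkpoint was saved from a version that used the Embedding layers embed0 and embed1.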
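One way to reuse the released weights without retraining is to rewrite the checkpoint's keys before loading. The helper below, convert_cbn_state_dict, is hypothetical. It rests on two assumptions the issue does not confirm: embed0 corresponds to gain and embed1 to bias, and the new Linear layers receive one-hot labels, so each Embedding weight has to be transposed.

```python
import re

import torch

# Matches keys such as "blocks.0.0.bn1.embed0.weight"; restricted to bn1/bn2
# so that no other parameter is touched.
_CBN_KEY = re.compile(r"^(.*\.bn[12])\.(embed[01])\.weight$")
# Assumed mapping from legacy Embedding names to the new Linear names.
_RENAME = {"embed0": "gain", "embed1": "bias"}


def convert_cbn_state_dict(state_dict: dict) -> dict:
    """Hypothetical helper: rename and transpose legacy ConditionalBatchNorm2d weights."""
    converted = {}
    for key, value in state_dict.items():
        match = _CBN_KEY.match(key)
        if match is None:
            converted[key] = value  # every other tensor is copied unchanged
            continue
        prefix, old_name = match.groups()
        new_key = f"{prefix}.{_RENAME[old_name]}.weight"
        # Embedding weight: (num_classes, num_features).
        # Linear weight: (out_features, in_features) = (num_features, num_classes).
        converted[new_key] = value.t().contiguous()
    return converted


legacy = {
    "blocks.0.0.bn1.embed0.weight": torch.randn(10, 256),
    "blocks.0.0.bn1.embed1.weight": torch.randn(10, 256),
    "linear0.weight": torch.randn(4096, 128),
}
converted = convert_cbn_state_dict(legacy)
print(sorted(converted))
print(converted["blocks.0.0.bn1.gain.weight"].shape)  # torch.Size([256, 10])
```

With a real checkpoint you would pass the generator's state dict through the helper and then call load_state_dict as usual; which key of the saved file holds that state dict is not shown in the issue. Sampling a few images afterwards is a cheap way to check that gain and bias were not swapped.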
To load the author's pre-trained generator, ConditionalBatchNorm2d will need to be modified. Alternatively, you can retrain the model. This applies to every conditional GAN.
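If you would rather modify ConditionalBatchNorm2d, one option is a variant that keeps the checkpoint's embed0/embed1 names, so the saved state dict loads with no renaming. LegacyConditionalBatchNorm2d is a hypothetical name. The forward formula bn(x) * (1 + embed0(y)) + embed1(y) is an assumption about how the original layer used its embeddings, so check it against the code the checkpoint was trained with. The sketch accepts either integer class indices or one-hot label vectors, since the issue does not show which one the generator passes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LegacyConditionalBatchNorm2d(nn.Module):
    """Hypothetical drop-in keeping the checkpoint's embed0/embed1 parameter names."""

    def __init__(self, num_classes: int, num_features: int):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1, affine=False)
        self.embed0 = nn.Embedding(num_classes, num_features)  # assumed: gain
        self.embed1 = nn.Embedding(num_classes, num_features)  # assumed: bias

    @staticmethod
    def _lookup(table: nn.Embedding, y: torch.Tensor) -> torch.Tensor:
        if y.dim() == 1:         # integer class indices, shape (batch,)
            return table(y)
        return y @ table.weight  # one-hot labels, shape (batch, num_classes)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Assumed formula: scale by (1 + gain) and shift by bias, per channel.
        gain = (1 + self._lookup(self.embed0, y)).view(x.size(0), -1, 1, 1)
        bias = self._lookup(self.embed1, y).view(x.size(0), -1, 1, 1)
        return self.bn(x) * gain + bias


cbn = LegacyConditionalBatchNorm2d(num_classes=10, num_features=256).eval()
x = torch.randn(4, 256, 8, 8)
labels = torch.tensor([0, 1, 2, 3])
out_idx = cbn(x, labels)
out_onehot = cbn(x, F.one_hot(labels, num_classes=10).float())
print(out_idx.shape)
```

Because the parameter names match the checkpoint (for example "embed0.weight" under each bn1/bn2), swapping this class in should make the unexpected and missing keys from the error disappear; whether the outputs match the author's results depends on the assumed forward formula.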
Of course, I hope the author can look into this problem.
Best!
Leean