alsdudrla10/DG

size mismatch

bottle0228 opened this issue · 5 comments

Hello,
thank you very much for sharing the code!

I have downloaded your code and everything went smoothly until I reached step five. I tried running this command:
python3 train.py --savedir=/checkpoints/discriminator/cifar_cond --gendir=/samples/cifar_cond_vanilla --datadir=/data/true_data_label.npz --cond=1
but it fails with what appears to be a state_dict mismatch:

Traceback (most recent call last):
  File "/home/lrp/DG/train.py", line 179, in <module>
    main()
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/home/lrp/DG/train.py", line 124, in main
    pretrained_classifier = classifier_lib.load_classifier(opts.pretrained_classifier_ckpt, opts.img_resolution, opts.device, eval=False)
  File "/home/lrp/DG/classifier_lib.py", line 33, in load_classifier
    classifier.load_state_dict(classifier_state)
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for EncoderUNetModel:
        Missing key(s) in state_dict: "input_blocks.3.1.norm.weight", "input_blocks.3.1.norm.bias", "input_blocks.3.1.qkv.weight", "input_blocks.3.1.qkv.bias", "input_blocks.3.1.proj_out.weight", "input_blocks.3.1.proj_out.bias",
        "input_blocks.6.0.skip_connection.weight", "input_blocks.6.0.skip_connection.bias", "input_blocks.6.1.norm.weight", "input_blocks.6.1.norm.bias", "input_blocks.6.1.qkv.weight", "input_blocks.6.1.qkv.bias", "input_blocks.6.1.proj_out.weight", "input_blocks.6.1.proj_out.bias",
        "input_blocks.9.0.in_layers.0.weight", "input_blocks.9.0.in_layers.0.bias", "input_blocks.9.0.in_layers.2.weight", "input_blocks.9.0.in_layers.2.bias", "input_blocks.9.0.emb_layers.1.weight", "input_blocks.9.0.emb_layers.1.bias", "input_blocks.9.0.out_layers.0.weight", "input_blocks.9.0.out_layers.0.bias", "input_blocks.9.0.out_layers.3.weight", "input_blocks.9.0.out_layers.3.bias", "input_blocks.9.1.norm.weight", "input_blocks.9.1.norm.bias", "input_blocks.9.1.qkv.weight", "input_blocks.9.1.qkv.bias", "input_blocks.9.1.proj_out.weight", "input_blocks.9.1.proj_out.bias",
        "input_blocks.10.0.in_layers.0.weight", "input_blocks.10.0.in_layers.0.bias", "input_blocks.10.0.in_layers.2.weight", "input_blocks.10.0.in_layers.2.bias", "input_blocks.10.0.emb_layers.1.weight", "input_blocks.10.0.emb_layers.1.bias", "input_blocks.10.0.out_layers.0.weight", "input_blocks.10.0.out_layers.0.bias", "input_blocks.10.0.out_layers.3.weight", "input_blocks.10.0.out_layers.3.bias",
        "input_blocks.11.0.in_layers.0.weight", "input_blocks.11.0.in_layers.0.bias", "input_blocks.11.0.in_layers.2.weight", "input_blocks.11.0.in_layers.2.bias", "input_blocks.11.0.emb_layers.1.weight", "input_blocks.11.0.emb_layers.1.bias", "input_blocks.11.0.out_layers.0.weight", "input_blocks.11.0.out_layers.0.bias", "input_blocks.11.0.out_layers.3.weight", "input_blocks.11.0.out_layers.3.bias", "input_blocks.11.0.skip_connection.weight", "input_blocks.11.0.skip_connection.bias", "input_blocks.11.1.norm.weight", "input_blocks.11.1.norm.bias", "input_blocks.11.1.qkv.weight", "input_blocks.11.1.qkv.bias", "input_blocks.11.1.proj_out.weight", "input_blocks.11.1.proj_out.bias",
        "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias", "input_blocks.12.1.norm.weight", "input_blocks.12.1.norm.bias", "input_blocks.12.1.qkv.weight", "input_blocks.12.1.qkv.bias", "input_blocks.12.1.proj_out.weight", "input_blocks.12.1.proj_out.bias",
        "input_blocks.13.0.in_layers.0.weight", "input_blocks.13.0.in_layers.0.bias", "input_blocks.13.0.in_layers.2.weight", "input_blocks.13.0.in_layers.2.bias", "input_blocks.13.0.emb_layers.1.weight", "input_blocks.13.0.emb_layers.1.bias", "input_blocks.13.0.out_layers.0.weight", "input_blocks.13.0.out_layers.0.bias", "input_blocks.13.0.out_layers.3.weight", "input_blocks.13.0.out_layers.3.bias", "input_blocks.13.1.norm.weight", "input_blocks.13.1.norm.bias", "input_blocks.13.1.qkv.weight", "input_blocks.13.1.qkv.bias", "input_blocks.13.1.proj_out.weight", "input_blocks.13.1.proj_out.bias",
        "input_blocks.14.0.in_layers.0.weight", "input_blocks.14.0.in_layers.0.bias", "input_blocks.14.0.in_layers.2.weight", "input_blocks.14.0.in_layers.2.bias", "input_blocks.14.0.emb_layers.1.weight", "input_blocks.14.0.emb_layers.1.bias", "input_blocks.14.0.out_layers.0.weight", "input_blocks.14.0.out_layers.0.bias", "input_blocks.14.0.out_layers.3.weight", "input_blocks.14.0.out_layers.3.bias", "input_blocks.14.1.norm.weight", "input_blocks.14.1.norm.bias", "input_blocks.14.1.qkv.weight", "input_blocks.14.1.qkv.bias", "input_blocks.14.1.proj_out.weight", "input_blocks.14.1.proj_out.bias".
        Unexpected key(s) in state_dict: "input_blocks.4.0.skip_connection.weight", "input_blocks.4.0.skip_connection.bias", "input_blocks.5.1.norm.weight", "input_blocks.5.1.norm.bias", "input_blocks.5.1.qkv.weight", "input_blocks.5.1.qkv.bias", "input_blocks.5.1.proj_out.weight", "input_blocks.5.1.proj_out.bias", "input_blocks.7.0.skip_connection.weight", "input_blocks.7.0.skip_connection.bias".
        size mismatch for input_blocks.0.0.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 3, 3, 3]).
        size mismatch for input_blocks.4.0.in_layers.2.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
        size mismatch for input_blocks.4.0.in_layers.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).

I would appreciate your help!

Thank you very much!

Hello, I am very happy that you are interested in our work.

I have checked our code again, but I couldn't find any errors on our side.
I suspect you downloaded the wrong file for "DG/checkpoints/ADM_classifier/32x32_classifier.pt".

Perhaps you downloaded the one from "DG_imagenet/pretrained_models/ADM_classifier/32x32_classifier.pt" instead.

Please check it!
Thank you very much :)
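
(For anyone hitting the same error: a quick way to check which file you have is to inspect the first convolution weight stored in the checkpoint. This is a minimal sketch, assuming the .pt file holds a plain state_dict, as the key names in the traceback suggest, and that the path matches the step-five setup:)

import torch

# Inspect the checkpoint's first conv weight to see which model it belongs to.
state = torch.load("checkpoints/ADM_classifier/32x32_classifier.pt", map_location="cpu")
print(state["input_blocks.0.0.weight"].shape)
# torch.Size([128, 3, 3, 3]) -> 3 input channels, matches the CIFAR-10 classifier
# torch.Size([128, 4, 3, 3]) -> 4 input channels, the wrong file (as in the error above)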

Hello,
Thank you very much for your reply. I checked, and that was indeed the problem.
Thank you very much !

Hello,
I'm sorry to bother you again.
When I reached step six, another error occurred. I tried running:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl --outdir=samples/cifar_cond --dg_weight_1st_order=1 --cond=1 --discriminator_ckpt=/checkpoints/discriminator/cifar_cond/discriminator_250.pt --boosting=1
An error occurred:

2023-05-07 11:24:49.969188: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-07 11:24:50.010555: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-07 11:24:50.615057: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Loading network from "checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl"...
<function get_discriminator.<locals>.evaluate at 0x7fea2b0f01f0>
Generating 50000 images to "samples/cifar_cond"...
  0%|                                                                                                                                                               | 0/501 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/home/lrp/DG/generate.py", line 206, in <module>
    main()
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/anaconda/envs/dg/lib/python3.9/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/home/lrp/DG/generate.py", line 176, in main
    images = edm_sampler(boosting, time_min, time_max, vpsde, dg_weight_1st_order, dg_weight_2nd_order, discriminator, net, latents, class_labels, randn_like=torch.randn_like, **sampler_kwargs)
  File "/home/lrp/DG/generate.py", line 75, in edm_sampler
    discriminator_guidance, log_ratio = classifier_lib.get_grad_log_ratio(discriminator, vpsde, x_hat, t_hat, net.img_resolution, time_min, time_max, class_labels, log=True)
  File "/home/lrp/DG/classifier_lib.py", line 69, in get_grad_log_ratio
    input = mean_vp_tau * unnormalized_input
RuntimeError: The size of tensor a (100) must match the size of tensor b (32) at non-singleton dimension 3

May I ask how I can solve it?
Thank you very much!

See #3

We fixed some code in classifier_lib.py.
Good Luck :)
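
(For context, the error is a standard PyTorch broadcasting mismatch: the traceback shows a per-sample coefficient being multiplied directly against a batch of 32x32 images, so the size-100 tensor fails to broadcast against dimension 3 of size 32. A minimal sketch of the failure and the usual kind of fix, using illustrative shapes taken from the traceback; the variable names follow the traceback, and the actual patched code is in the repo's classifier_lib.py, see #3:)

import torch

# Illustrative shapes from the traceback: a batch of 100 images at 32x32.
unnormalized_input = torch.randn(100, 3, 32, 32)
mean_vp_tau = torch.rand(100)  # one coefficient per sample

# The failing line: a [100] tensor cannot broadcast against dim 3 (size 32).
# input = mean_vp_tau * unnormalized_input  # RuntimeError, as above

# Reshaping to [100, 1, 1, 1] makes the per-sample scaling broadcast correctly.
input = mean_vp_tau[:, None, None, None] * unnormalized_input
print(input.shape)  # torch.Size([100, 3, 32, 32])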

Thank you for your answer!