alsdudrla10/DG

Dimension mismatch error

chrundle opened this issue · 2 comments

Hello,

I hope you are doing well. Firstly, I want to thank you for making your code available.

I cloned your repo today and downloaded all of the files outlined in your README. Following step 2 in the README, I was able to use the vanilla model to generate CIFAR-10 samples via the following command:

$ python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl --outdir=samples/cifar_cond_vanilla --dg_weight_1st_order=0

I then tried to use your pretrained model to generate discriminator-guided CIFAR-10 samples via the following command (from your README):

$ python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl --outdir=samples/cifar_cond --dg_weight_1st_order=1 --cond=1 --discriminator_ckpt=/checkpoints/discriminator/cifar_cond/discriminator_250.pt --boosting=1

However, this returned the following error:

Loading network from "checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl"...
<function get_discriminator.<locals>.evaluate at 0x2001ac552820>
Generating 50000 images to "samples/cifar_cond"...
  0%|                                                                                                                | 0/501 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "generate.py", line 206, in <module>
    main()
  File "/usr/workspace/zeroml/LC_utilities/lassen/conda_env/envs/mad/lib/python3.8/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/usr/workspace/zeroml/LC_utilities/lassen/conda_env/envs/mad/lib/python3.8/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/usr/workspace/zeroml/LC_utilities/lassen/conda_env/envs/mad/lib/python3.8/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/workspace/zeroml/LC_utilities/lassen/conda_env/envs/mad/lib/python3.8/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "generate.py", line 176, in main
    images = edm_sampler(boosting, time_min, time_max, vpsde, dg_weight_1st_order, dg_weight_2nd_order, discriminator, net, latents, class_labels, randn_like=torch.randn_like, **sampler_kwargs)
  File "generate.py", line 75, in edm_sampler
    discriminator_guidance, log_ratio = classifier_lib.get_grad_log_ratio(discriminator, vpsde, x_hat, t_hat, net.img_resolution, time_min, time_max, class_labels, log=True)
  File "/p/gpfs1/chrundle/zoo-er/dg-discriminator-guidance/classifier_lib.py", line 69, in get_grad_log_ratio
    input = mean_vp_tau * unnormalized_input
RuntimeError: The size of tensor a (100) must match the size of tensor b (32) at non-singleton dimension 3

I suspect the issue is minor, but I wanted to reach out to see whether this is something you have seen before or know how to resolve. Thank you for your time.

Hello, I am very happy that you are interested in our work.

I noticed that the same error occurred in my environment.
I just modified get_grad_log_ratio() in classifier_lib.py, and now it works well.

Thank you very much :)

Thank you very much! I have adopted the changes and it is working for me as well.
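
Note for readers who hit the same error: the thread does not show the actual change made to get_grad_log_ratio(). The message (size 100 vs. size 32 at dimension 3) is, however, consistent with a broadcasting mismatch. The sketch below is not the author's fix. It is a pure-Python illustration of the right-aligned broadcasting rule that PyTorch and NumPy share, under the assumption that mean_vp_tau holds one value per sample (shape (100,)) and unnormalized_input is an image batch of shape (100, 3, 32, 32). The helper name broadcast_shape is hypothetical.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes with the right-aligned rule used by PyTorch/NumPy.

    Returns the resulting shape, or raises ValueError if two sizes differ
    and neither of them is 1.
    """
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"size {x} must match size {y} at dimension {dim}")
        # A size-1 axis stretches to the other operand's size.
        out.append(y if x == 1 else x)
    return tuple(out)


# Assumed shapes, inferred from the error message (not confirmed by the thread).
per_sample = (100,)            # one scalar per sample, e.g. mean_vp_tau
images = (100, 3, 32, 32)      # batch of CIFAR-10-sized images

try:
    broadcast_shape(per_sample, images)
except ValueError as err:
    # The 100 is aligned with the last image axis (width 32), not the batch axis.
    print("mismatch:", err)

# Adding three trailing singleton axes aligns the 100 with the batch axis.
print(broadcast_shape((100, 1, 1, 1), images))
```

Under the same assumption, the PyTorch equivalent of adding the trailing axes would be an expression such as mean_vp_tau[:, None, None, None] * unnormalized_input. Check the shapes in your own checkout before relying on this.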