forever208/DDPM-IP

About using ddim50 on face dataset

Suimingzhe opened this issue · 15 comments

I found an issue when using ddim50 sampling after training DDPM-IP on my own face dataset: the sampled images are very noisy, whereas ddpm50 sampling is fine. I tried the pre-trained CelebA checkpoint you provide and found the same problem.

mpiexec -n 4 python scripts/image_sample.py \
    --image_size 32 --timestep_respacing ddim50 --use_ddim True \
    --model_path DDPM_IP_celeba64.pt \
    --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --attention_resolutions 32,16,8 \
    --resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.1 \
    --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True --batch_size 256 --num_samples 50000

@Suimingzhe Hi, for the DDIM sampling, I use the DDIM official code to do training and sampling. I remember that I also had some issues when using ADM code to do DDIM sampling.

Could you please check if the ADM-IP celeba checkpoint can do normal DDPM sampling?

@forever208 Thanks for your reply. I checked the sampled images again and here are the results.

For CIFAR-10, using pre-trained ADM-IP, ddpm50 is normal and ddim50 seems normal.
For CelebA, using pre-trained ADM-IP, ddpm50 is normal but ddim50 is abnormal (very noisy).
For my own face dataset (similar to CelebA), after training ADM-IP, ddpm50 is normal but ddim50 is abnormal (very noisy).

The hyperparameters for my own dataset are the same as for CelebA training, except for '--img_size=256*256'.

@Suimingzhe Hi, thanks for the info. I will check the DDIM sampling results today using pre-trained ADM-IP and let you know.
For now, my only suggestion is that if you want to test the performance of DDIM-IP, you should use the official DDIM code for training and inference. (One major drawback is that training with the DDIM code is very slow.)

@forever208 Thanks for your suggestion.

Hi @Suimingzhe, I confirmed your noisy DDIM samples using ADM-IP on the ADM code. So I suggest you train DDIM-IP on CelebA.

Hi, can you tell me which line I should modify to apply your idea?

@john09282922
refer to the section "Simple to implement Input Perturbation in diffusion models" of the readme file

Hi, thanks for replying to my question. I saw that section, but I want to apply your idea to the original DDIM code, not ADM. Can you tell me which line you modified in the original DDIM code?
Thanks

@john09282922 you should make the modification on their script losses.py

more specifically, line 10

Thanks for the detailed info. I like your code more than the original DDIM code.
e in DDIM is the same as noise in your code, right? But I am confused that you are using x_start after the idea equation, while DDIM does not do the same. If you have the loss function code you used for the paper, could you send it to me? My email is jungminhwang0919@gmail.com

thanks,

@john09282922 e in DDIM is epsilon in DDPM. I am not clear about 'you are using x_start after the idea equation', what does it mean?

this is original ddim code.
    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
    output = model(x, t.float())
    if keepdim:
        return (e - output).square().sum(dim=(1, 2, 3))
    else:
        return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)

I am not sure which part to change: x or x0?
Do they have the same dimensions?

thanks

@john09282922 replace

    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()

with

    x = x0 * a.sqrt() + (e + gamma * w) * (1.0 - a).sqrt()

where w ~ N(0, I)

keep everything else unchanged
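
Put together, the one-line change might look like the sketch below. The function name noise_estimation_loss_ip, the signature, and the computation of a are assumptions modeled on the DDIM repository's losses.py (they are not shown in this thread), and gamma=0.1 is only a placeholder default; use the value recommended in the DDPM-IP README. Note that only the network input is perturbed: the regression target remains e.

```python
import torch


def noise_estimation_loss_ip(model, x0, t, e, b, gamma=0.1, keepdim=False):
    """Noise-prediction loss with input perturbation (illustrative sketch).

    x0: clean images (B, C, H, W); t: timestep indices (B,);
    e: Gaussian noise with the same shape as x0; b: 1-D beta schedule.
    """
    # Assumed from the DDIM code: cumulative product of (1 - beta) at step t,
    # reshaped so it broadcasts over (C, H, W).
    a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    # Extra noise w ~ N(0, I), same shape as e, drawn anew on every call.
    w = torch.randn_like(e)
    # The only modified line: e is replaced by (e + gamma * w) in the input.
    x = x0 * a.sqrt() + (e + gamma * w) * (1.0 - a).sqrt()
    output = model(x, t.float())
    # The target is still the unperturbed noise e.
    per_sample = (e - output).square().sum(dim=(1, 2, 3))
    return per_sample if keepdim else per_sample.mean(dim=0)
```

Setting gamma=0.0 recovers the original loss, which makes it easy to compare the two variants in the same training script.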

Thank you very much! Sorry, but what is w? Is it the original noise, i.e. e in the original DDIM code? Or is w equal to th.rand_like(e)? And how do I update w during training?

thanks,

@john09282922 w is equal to th.randn_like(e), i.e. Gaussian noise with the same shape as e (randn rather than rand, since w ~ N(0, I))
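
Since the replacement above specifies w ~ N(0, I), a standard-normal sampler is what matches it. The quick check below is only an illustration (the tensor size and seed are arbitrary choices); it compares the two similarly named PyTorch functions, where th is assumed to be the usual alias for torch.

```python
import torch as th

th.manual_seed(0)
e = th.zeros(100_000)  # stand-in tensor; only its shape and dtype are used

u = th.rand_like(e)   # uniform on [0, 1): mean near 0.5, variance near 1/12
w = th.randn_like(e)  # standard normal: mean near 0, variance near 1

print(f"rand_like : mean={u.mean().item():.3f}, var={u.var().item():.3f}")
print(f"randn_like: mean={w.mean().item():.3f}, var={w.var().item():.3f}")
```

Using the uniform sampler would add a positive bias to the perturbation, which is not what w ~ N(0, I) describes.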