About using ddim50 on face dataset
Suimingzhe opened this issue · 15 comments
I found an issue when using ddim50 sampling after training DDPM-IP on my own face dataset. The sampled images are very noisy (ddpm50 sampling is fine, however). I tried the pre-trained CelebA checkpoint you provide and found the same problem.
mpiexec -n 4 python scripts/image_sample.py
--image_size 32 --timestep_respacing ddim50 --use_ddim True
--model_path DDPM_IP_celeba64.pt
--num_channels 192 --num_head_channels 64 --num_res_blocks 3 --attention_resolutions 32,16,8
--resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.1
--diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True --batch_size 256 --num_samples 50000
@Suimingzhe Hi, for the DDIM sampling, I use the DDIM official code to do training and sampling. I remember that I also had some issues when using ADM code to do DDIM sampling.
Could you please check if the ADM-IP celeba checkpoint can do normal DDPM sampling?
@forever208 Thanks for your reply. I checked the sampled images again and here are the results.
For CIFAR-10, using pre-trained ADM-IP, ddpm50 is normal and ddim50 seems normal.
For CelebA, using pre-trained ADM-IP, ddpm50 is normal but ddim50 is abnormal (very noisy).
For my own face dataset (similar to CelebA), after training ADM-IP, ddpm50 is normal but ddim50 is abnormal (very noisy).
The hyperparameters on my own dataset are the same as for training CelebA, except for '--img_size=256*256'.
@Suimingzhe hi, thanks for your info, I will check the DDIM sampling results today using pre-trained ADM-IP and let you know.
Currently, my only suggestion is that, if you wanna test the performance of DDIM-IP, you'd better use DDIM official code for training and inference. (but one major problem is that training on DDIM code is very slow)
@forever208 Thanks for your suggestion.
Hi @Suimingzhe, I can confirm that DDIM sampling with ADM-IP on the ADM code gives noisy samples, so I suggest you train DDIM-IP on CelebA instead.
Hi, could you tell me which line I should modify to apply your idea?
@john09282922 refer to the section "Simple to implement Input Perturbation in diffusion models" of the readme file
Hi, thanks for replying to my question. I saw that, but I want to modify the original DDIM code with your idea, not ADM. Can you tell me which line you modified in the original DDIM code?
thanks
@john09282922 you should make the modification on their script losses.py
more specifically, line 10
Thanks for the detailed info. I like your code more than the original DDIM code.
e in DDIM is the same as noise in your code, right? But I am confused that you are using x_start after the idea equation, while DDIM does not do the same. If you have the loss function code used for the paper, could you send it to me? My email is jungminhwang0919@gmail.com
thanks,
@john09282922 e in DDIM is epsilon in DDPM. I am not clear about 'you are using x_start after the idea equation', what does it mean?
This is the original DDIM code:
x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
output = model(x, t.float())
if keepdim:
    return (e - output).square().sum(dim=(1, 2, 3))
else:
    return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
I am not sure which part to change: x or x0?
Does it keep the same dimensions?
thanks
@john09282922 replace x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
with x = x0 * a.sqrt() + (e+gamma*w) * (1.0 - a).sqrt()
where w~N(0, I)
keep everything else unchanged
Thank you very much! Sorry, but what is w? Is it the original noise, i.e. e in the original DDIM code? Or is w equal to th.rand_like(e)? And how do I update w during training?
thanks,
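@john09282922 w is equal to th.randn_like(e), i.e. Gaussian noise with the same shape as e, consistent with w~N(0, I). (Note randn_like, not rand_like: th.rand_like samples uniformly from [0, 1).)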
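
Putting the replies above together, here is a minimal sketch of what the perturbed DDIM loss might look like. The thread only shows the body of the loss, so the function name, the signature, the way a is derived from a beta schedule b, and the default gamma=0.1 are all assumptions made for illustration. w is drawn with torch.randn_like so that it follows N(0, I) as stated in the thread; torch.rand_like would give uniform noise on [0, 1) instead.

```python
import torch


def noise_estimation_loss_ip(model, x0, t, e, b, gamma=0.1, keepdim=False):
    """Noise-prediction loss with input perturbation (illustrative sketch).

    The name, signature and gamma default are hypothetical; only the
    formula for x comes from the thread.
    """
    # Cumulative product of (1 - beta) at each sample's timestep,
    # reshaped so it broadcasts over (N, C, H, W) images.
    a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    # w ~ N(0, I), freshly sampled on every call; it is not a learned parameter.
    w = torch.randn_like(e)
    # Only the network input is perturbed ...
    x = x0 * a.sqrt() + (e + gamma * w) * (1.0 - a).sqrt()
    output = model(x, t.float())
    # ... while the regression target is still the unperturbed e.
    per_sample = (e - output).square().sum(dim=(1, 2, 3))
    return per_sample if keepdim else per_sample.mean(dim=0)


if __name__ == "__main__":
    torch.manual_seed(0)
    x0 = torch.randn(4, 3, 8, 8)            # toy image batch
    b = torch.linspace(1e-4, 2e-2, 1000)    # toy linear beta schedule
    t = torch.randint(0, 1000, (4,))        # one timestep per sample
    e = torch.randn_like(x0)
    toy_model = lambda x, t: torch.zeros_like(x)  # stand-in for the U-Net
    print(noise_estimation_loss_ip(toy_model, x0, t, e, b, keepdim=True).shape)
```

Since w has the same shape as e, x keeps the same dimensions as before, and setting gamma=0 recovers the original loss.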