voletiv/mcvd-pytorch

Question about DDPM and DDIM sampling.

e4s2022 opened this issue · 1 comment

Hi, thanks for sharing your excellent work!

I just walked through the codebase and noticed that during sampling you iterate the timestep t from 0 to 999 (see here). I think the reverse pass should start at 999 and go down to 0, so I'm a little confused about this.

My other question: what does the denoise option do at the last sampling step? Please check here.

Both questions apply equally to the DDPM and DDIM samplers. I'd really appreciate an explanation.

Hello, here is how I see it.

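The author uses an NCSN diffusion model, which is based on Langevin dynamics. In the original paper's parameter settings, the starting parameter is larger than the ending parameter, so the betas decrease in list order. During denoising the betas should therefore go from large to small, which means the list index runs from 0 to 999. X_T is the real image.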

sigma_begin: 0.02
sigma_end: 0.0001

elif config.model.sigma_dist == 'linear':
    return torch.linspace(config.model.sigma_begin, config.model.sigma_end,
                          T).to(config.device)

if self.schedule == 'linear':
    self.register_buffer('betas', get_sigmas(config))
    self.register_buffer('alphas', torch.cumprod(1 - self.betas.flip(0), 0).flip(0))
    self.register_buffer('alphas_prev', torch.cat([self.alphas[1:], torch.tensor([1.0]).to(self.alphas)]))

for i, step in enumerate(steps):
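A minimal, dependency-free sketch of the MCVD-style schedule above, assuming T = 1000 and the sigma_begin / sigma_end values from the config. It mirrors the flip / cumprod / flip expression with plain lists instead of tensors; the `linspace` helper is hypothetical, written here only so the block runs without PyTorch. Under these assumptions `alphas[0]` is the smallest cumulative product (the noisiest level) and `alphas[-1]` the largest, so a loop over indices 0 to 999 moves from noise toward the clean image.

```python
from itertools import accumulate
import operator

T = 1000
SIGMA_BEGIN, SIGMA_END = 0.02, 0.0001  # values from the config above


def linspace(start, end, n):
    """Hypothetical helper: n evenly spaced values from start to end inclusive."""
    return [start + (end - start) * k / (n - 1) for k in range(n)]


# MCVD-style: betas DECREASE along the list.
betas = linspace(SIGMA_BEGIN, SIGMA_END, T)

# List version of torch.cumprod(1 - betas.flip(0), 0).flip(0):
# flip, take the running product of (1 - beta), flip back.
# Result: alphas[i] = product of (1 - betas[m]) for all m >= i.
flipped = [1.0 - b for b in reversed(betas)]
alphas = list(reversed(list(accumulate(flipped, operator.mul))))

# List version of torch.cat([alphas[1:], [1.0]]): the next, cleaner level.
alphas_prev = alphas[1:] + [1.0]

# Index 0 is the noisiest level and index T-1 the cleanest,
# so sampling visits i = 0, 1, ..., T-1.
for i in (0, T // 2, T - 1):
    print(f"i={i:4d}  beta={betas[i]:.6f}  alpha_bar={alphas[i]:.6f}")
```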

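For comparison: in the original DDPM paper, the starting parameter is smaller than the ending parameter, so the betas increase in list order, and the denoising timesteps run from list index 999 down to 0. X_0 is the real image.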

https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/scripts/run_celebahq.py#L132-L137

def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
    optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfds_data_dir='tensorflow_datasets', log_dir='logs'

https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L26-L27

  elif beta_schedule == 'linear':
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L205-L217

    i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
    img_0 = noise_fn(shape=shape, dtype=tf.float32)
    _, img_final = tf.while_loop(
      cond=lambda i_, _: tf.greater_equal(i_, 0),
      body=lambda i_, img_: [
        i_ - 1,
        self.p_sample(
          denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn, return_pred_xstart=False)
      ],
      loop_vars=[i_0, img_0],
      shape_invariants=[i_0.shape, img_0.shape],
      back_prop=False
    )
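To make the comparison concrete, here is a small pure-Python sketch (the `linspace` and `cumprod` helpers are hypothetical stand-ins for the NumPy/PyTorch calls) that builds both conventions with T = 1000 and the endpoints quoted in this thread: DDPM's increasing betas walked from t = 999 down to 0, and MCVD's decreasing betas walked from i = 0 up to 999. Under these assumptions both loops visit the same sequence of cumulative products alpha_bar, from noisiest to cleanest, up to floating-point rounding.

```python
T = 1000


def linspace(start, end, n):
    """Hypothetical helper: n evenly spaced values from start to end inclusive."""
    return [start + (end - start) * k / (n - 1) for k in range(n)]


def cumprod(xs):
    """Running product, like np.cumprod / torch.cumprod on a 1-D list."""
    out, p = [], 1.0
    for x in xs:
        p *= x
        out.append(p)
    return out


# DDPM convention: betas increase, alpha_bar[t] = prod_{s <= t} (1 - betas[s]).
ddpm_betas = linspace(0.0001, 0.02, T)
ddpm_alpha_bar = cumprod([1.0 - b for b in ddpm_betas])
ddpm_order = list(range(T - 1, -1, -1))      # t = 999, 998, ..., 0

# MCVD convention: betas decrease, alphas[i] = prod_{m >= i} (1 - betas[m]).
mcvd_betas = linspace(0.02, 0.0001, T)
mcvd_alphas = cumprod([1.0 - b for b in reversed(mcvd_betas)])[::-1]
mcvd_order = list(range(T))                  # i = 0, 1, ..., 999

# Noise levels visited by each sampler, from the first step to the last.
ddpm_path = [ddpm_alpha_bar[t] for t in ddpm_order]
mcvd_path = [mcvd_alphas[i] for i in mcvd_order]

max_gap = max(abs(a - b) for a, b in zip(ddpm_path, mcvd_path))
print(f"largest difference between the two paths: {max_gap:.2e}")
```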

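Latent diffusion (LDM) is similar.
In LDM's DDPM, the starting parameter is smaller than the ending parameter, so the betas increase in list order, and the denoising timesteps run from list index 999 down to 0. X_0 is the real image.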

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/configs/latent-diffusion/cin256-v2.yaml#L5-L6

    linear_start: 0.0015
    linear_end: 0.0195

https://github.com/CompVis/latent-diffusion/blob/171cf29fb54afe048b03ec73da8abb9d102d0614/ldm/modules/diffusionmodules/util.py#L22-L25

    if schedule == "linear":
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
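In LDM's "linear" schedule the interpolation happens in square-root space, so the betas still increase from linear_start to linear_end but, by convexity of squaring, stay below a plain linear ramp in between. A dependency-free sketch, assuming T = 1000 and the cin256-v2.yaml values quoted above:

```python
T = 1000
LINEAR_START, LINEAR_END = 0.0015, 0.0195  # values from the config above

# Interpolate linearly between the square roots, then square:
# mirrors torch.linspace(start ** 0.5, end ** 0.5, n) ** 2.
lo, hi = LINEAR_START ** 0.5, LINEAR_END ** 0.5
betas = [(lo + (hi - lo) * k / (T - 1)) ** 2 for k in range(T)]

# Plain linear ramp with the same endpoints, for comparison.
plain = [LINEAR_START + (LINEAR_END - LINEAR_START) * k / (T - 1) for k in range(T)]

print("sqrt-linear:", betas[0], betas[T // 2], betas[-1])
print("linear:     ", plain[0], plain[T // 2], plain[-1])
```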

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddpm.py#L258-L260

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)
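The reversed loop can be sketched end to end on a single scalar. This is a minimal illustration, not LDM's `p_sample`: `eps_model` is a hypothetical stand-in for the trained noise predictor, and the variance is fixed to beta_t (one of the two choices discussed in the DDPM paper). As in the `p_sample` functions of both linked repositories, which zero the noise term with a `nonzero_mask` when t == 0, the final step returns the mean without adding fresh noise.

```python
import math
import random


def ddpm_sample(eps_model, T=1000, beta_start=1e-4, beta_end=0.02, seed=0):
    """Ancestral DDPM sampling of one scalar, walking t = T-1 down to 0."""
    rng = random.Random(seed)
    betas = [beta_start + (beta_end - beta_start) * k / (T - 1) for k in range(T)]
    alpha_bars, prod = [], 1.0
    for b in betas:                      # alpha_bar[t] = prod_{s <= t} (1 - beta_s)
        prod *= 1.0 - b
        alpha_bars.append(prod)

    x = rng.gauss(0.0, 1.0)              # x_T ~ N(0, 1)
    visited, noise_draws = [], 0
    for t in reversed(range(T)):         # same order as reversed(range(0, num_timesteps))
        visited.append(t)
        beta = betas[t]
        eps = eps_model(x, t)            # predicted noise (hypothetical model)
        mean = (x - beta / math.sqrt(1.0 - alpha_bars[t]) * eps) / math.sqrt(1.0 - beta)
        if t > 0:                        # plays the role of nonzero_mask
            x = mean + math.sqrt(beta) * rng.gauss(0.0, 1.0)
            noise_draws += 1
        else:                            # last step: return the mean, no fresh noise
            x = mean
    return x, visited, noise_draws


# Toy run with a dummy predictor that always returns zero noise.
x0, visited, noise_draws = ddpm_sample(lambda x, t: 0.0)
print(visited[:3], visited[-3:], noise_draws)
```

With a real model, `eps_model` would wrap the network call; the dummy predictor here only exercises the control flow and the timestep order.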

LDM DDIM

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddim.py#L133-L160

        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)
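The `index = total_steps - i - 1` bookkeeping in the DDIM loop can be checked in isolation. The sketch below assumes a uniformly spaced subsequence of 50 steps out of 1000 (a hypothetical choice for illustration; the linked sampler builds its own `timesteps`) and reproduces the `np.flip` iteration with `reversed`. Here `step` is the training timestep passed to the model, running from high to low, and `index` is that step's position in the ascending list, which also counts down.

```python
def ddim_plan(num_train_steps=1000, num_ddim_steps=50):
    """Return (index, step) pairs in the order a DDIM loop like the one above visits them."""
    stride = num_train_steps // num_ddim_steps
    # Hypothetical uniform spacing: 0, 20, 40, ..., 980 (ascending).
    timesteps = list(range(0, num_train_steps, stride))
    total_steps = len(timesteps)
    plan = []
    for i, step in enumerate(reversed(timesteps)):  # like np.flip(timesteps)
        index = total_steps - i - 1                  # position of `step` in the ascending list
        assert timesteps[index] == step
        plan.append((index, step))
    return plan


plan = ddim_plan()
print(plan[:2], plan[-2:])
```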