Question about DDPM and DDIM sampling.
e4s2022 opened this issue · 1 comment
Hi, thanks for sharing your excellent work!
I just walked through the code base and noticed that during sampling you use timestep t from 0 to 999 (see here). I think the reverse pass should start from 999 and go down to 0, so I'm a little confused about this.
Another question: what does the denoise option mean for the last sampling step? Please check here.
These two questions apply to both the DDPM and the DDIM sampler. I'd really appreciate your explanation.
Hello, here is my take.
The author uses an NCSN diffusion model based on Langevin dynamics. In the original paper, the start parameter is set larger than the end parameter, so the betas go from large to small in list order. During the denoising process the betas should go from large to small, which means the list index runs from 0 to 999. X_T is the real image. See the permalinks below, and the sketch that follows them:
- mcvd-pytorch/configs/kth64_big.yml, lines 81–82 (at 451da2e)
- mcvd-pytorch/models/__init__.py, lines 24–26 (at 226a3fd)
- mcvd-pytorch/models/better/ncsnpp_more.py, lines 736–739 (at 226a3fd)
- mcvd-pytorch/models/__init__.py, line 267 (at 451da2e)
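Here is a minimal sketch of that NCSN-style convention. The sigma values and the geometric spacing are illustrative assumptions, not the repo's exact config:

```python
import numpy as np

# Illustrative values only (NOT the values in kth64_big.yml).
sigma_begin, sigma_end, num_scales = 50.0, 0.01, 1000

# NCSN-style: noise levels stored from LARGE to SMALL,
# here spaced geometrically as in the NCSN paper.
sigmas = np.exp(np.linspace(np.log(sigma_begin), np.log(sigma_end), num_scales))
assert sigmas[0] > sigmas[-1]  # index 0 holds the LARGEST noise level

# A loop that counts t UP from 0 to 999 therefore still moves from
# high noise to low noise, i.e. it already IS the reverse/denoising pass.
for t in range(num_scales):
    sigma_t = sigmas[t]  # decreases as t grows
```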
For comparison: in the original DDPM paper, the start parameter is smaller than the end parameter, so the betas go from small to large in list order. The denoising timesteps run from list index 999 down to 0, and X_0 is the real image.
```python
# From the original DDPM codebase: beta_start (0.0001) < beta_end (0.02),
# so the betas list ascends with its index.
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
    optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfds_data_dir='tensorflow_datasets', log_dir='logs'
):
    ...
```

```python
elif beta_schedule == 'linear':
    # index 0 holds beta_start, the smallest beta
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
```

```python
# Sampling: tf.while_loop counts i DOWN from num_timesteps - 1 (i.e. 999) to 0.
i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
img_0 = noise_fn(shape=shape, dtype=tf.float32)
_, img_final = tf.while_loop(
    cond=lambda i_, _: tf.greater_equal(i_, 0),  # stop once i drops below 0
    body=lambda i_, img_: [
        i_ - 1,                                  # decrement the timestep
        self.p_sample(
            denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn,
            return_pred_xstart=False)
    ],
    loop_vars=[i_0, img_0],
    shape_invariants=[i_0.shape, img_0.shape],
    back_prop=False
)
```
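For contrast with the NCSN convention, a tiny self-contained check of the DDPM ordering (the values mirror the defaults above; the snippet itself is only an illustration):

```python
import numpy as np

# DDPM convention: betas ASCEND with the index, so the reverse pass must
# count the index DOWN from 999 to 0 to move from high noise to low noise.
betas = np.linspace(0.0001, 0.02, 1000)
assert betas[0] < betas[-1]      # index 0 holds the SMALLEST beta

for t in reversed(range(1000)):  # t = 999, 998, ..., 0
    beta_t = betas[t]            # shrinks as sampling proceeds
```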
LDM (latent diffusion) is similar. In LDM's DDPM, the start parameter is smaller than the end parameter, the betas go from small to large in list order, the denoising timesteps run from list index 999 down to 0, and X_0 is the real image.
```yaml
linear_start: 0.0015
linear_end: 0.0195
```

```python
if schedule == "linear":
    # linspace in sqrt-space, then squared: betas still ascend with the index
    betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
```
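A quick sanity check of that sqrt-space schedule (values copied from the config above; this is a sketch, not LDM code):

```python
import torch

linear_start, linear_end, n_timestep = 0.0015, 0.0195, 1000
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep,
                       dtype=torch.float64) ** 2
# Squaring a monotonically increasing positive sequence keeps it increasing,
# so index 0 again holds the smallest beta.
assert betas[0] < betas[-1]
```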
```python
# LDM DDPM sampling: the index runs 999 -> 0 via reversed(range(...)).
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
    img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                        clip_denoised=self.clip_denoised)
```
LDM DDIM:

```python
# The (possibly subsampled) timesteps are flipped so the loop also walks
# from the largest timestep down to 0.
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")

iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

for i, step in enumerate(iterator):
    index = total_steps - i - 1  # position in the ascending DDIM alpha tables
    ts = torch.full((b,), step, device=device, dtype=torch.long)

    if mask is not None:
        assert x0 is not None
        img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
        img = img_orig * mask + (1. - mask) * img

    outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                              quantize_denoised=quantize_denoised, temperature=temperature,
                              noise_dropout=noise_dropout, score_corrector=score_corrector,
                              corrector_kwargs=corrector_kwargs,
                              unconditional_guidance_scale=unconditional_guidance_scale,
                              unconditional_conditioning=unconditional_conditioning)
    img, pred_x0 = outs

    if callback: callback(i)
    if img_callback: img_callback(pred_x0, i)

    if index % log_every_t == 0 or index == total_steps - 1:
        intermediates['x_inter'].append(img)
        intermediates['pred_x0'].append(pred_x0)
```
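To make the index bookkeeping concrete, here is a small sketch; the subsampling stride and variable names are assumptions for illustration, not LDM's exact internals:

```python
import numpy as np

# Suppose DDIM subsamples 1000 training steps down to 20 sampling steps.
ddim_timesteps = np.arange(0, 1000, 50)  # ascending: 0, 50, ..., 950
time_range = np.flip(ddim_timesteps)     # iterate 950 -> 0
total_steps = ddim_timesteps.shape[0]

for i, step in enumerate(time_range):
    index = total_steps - i - 1
    # 'step' is the model timestep t fed to the network (950 down to 0);
    # 'index' walks the 20-entry DDIM alpha tables from the end to the start.
    assert ddim_timesteps[index] == step
```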