Question about your code: the claimed consistent noise seems unused in the demo code?
Yuxuan-W opened this issue · 2 comments
def ddcm_sampler(scheduler, x_s, x_t, timestep, e_s, e_t, x_0, noise, eta, to_next=True):
    if scheduler.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)
    prev_step_index = scheduler.step_index + 1
    if prev_step_index < len(scheduler.timesteps):
        prev_timestep = scheduler.timesteps[prev_step_index]
    else:
        prev_timestep = timestep
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
    )
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    variance = beta_prod_t_prev
    std_dev_t = eta * variance
    noise = std_dev_t ** (0.5) * noise
    e_c = (x_s - alpha_prod_t ** (0.5) * x_0) / (1 - alpha_prod_t) ** (0.5)
    pred_x0 = x_0 + ((x_t - x_s) - beta_prod_t ** (0.5) * (e_t - e_s)) / alpha_prod_t ** (0.5)  # + mv_offset
    eps = (e_t - e_s) + e_c
    dir_xt = (beta_prod_t_prev - std_dev_t) ** (0.5) * eps
    # Noise is not used for one-step sampling.
    if len(scheduler.timesteps) > 1:
        prev_xt = alpha_prod_t_prev ** (0.5) * pred_x0 + dir_xt + noise
        prev_xs = alpha_prod_t_prev ** (0.5) * x_0 + dir_xt + noise
    else:
        prev_xt = pred_x0
        prev_xs = x_0
    if to_next:
        scheduler._step_index += 1
    return prev_xs, prev_xt, pred_x0
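Here eta is set to 1 in your code, but this makes dir_xt always 0: std_dev_t = eta * beta_prod_t_prev then equals beta_prod_t_prev, so (beta_prod_t_prev - std_dev_t) ** 0.5 is zero.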
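This can be checked with scalar arithmetic. The sketch below re-implements only the two step coefficients from ddcm_sampler; the helper name `ddcm_step_coefficients` and the value `alpha_prev = 0.7` are hypothetical stand-ins, not part of the original code.

```python
import math


def ddcm_step_coefficients(alpha_prod_t_prev: float, eta: float) -> tuple[float, float]:
    """Hypothetical helper: return (dir_coef, noise_coef) as computed in ddcm_sampler.

    prev_xt = sqrt(alpha_prod_t_prev) * pred_x0 + dir_coef * eps + noise_coef * noise
    """
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    # The snippet names this std_dev_t, but it plays the role of a variance (sigma^2).
    std_dev_t = eta * beta_prod_t_prev
    dir_coef = (beta_prod_t_prev - std_dev_t) ** 0.5
    noise_coef = std_dev_t ** 0.5
    return dir_coef, noise_coef


alpha_prev = 0.7  # hypothetical value of scheduler.alphas_cumprod[prev_timestep]
for eta in (0.0, 0.5, 1.0):
    dir_coef, noise_coef = ddcm_step_coefficients(alpha_prev, eta)
    # The two coefficients always split the total variance 1 - alpha_prev between them.
    assert math.isclose(dir_coef**2 + noise_coef**2, 1 - alpha_prev)
    print(f"eta={eta}: dir_coef={dir_coef:.4f}, noise_coef={noise_coef:.4f}")
```

With eta = 0 the step is the deterministic DDIM update (no fresh noise); with eta = 1 the eps term vanishes and the step reduces to re-noising pred_x0 with fresh noise scaled by sqrt(1 - alpha_prod_t_prev), which appears to match the multistep update of consistency-model (LCM-style) schedulers.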
Besides, I'm a bit confused about the computation of pred_x0: it seems to add the target-branch latent to the original image's latent and then subtract the source-branch latent.
I would appreciate your reply!
Hi Yuxuan,
Thanks for your question. In fact, we do use the consistent noise, in the following line:
pred_x0 = x_0 + ((x_t - x_s) - beta_prod_t ** (0.5) * (e_t - e_s)) / alpha_prod_t ** (0.5)
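As a quick numerical check that this line is the usual DDIM x0 prediction with the consistent noise folded in, the scalar sketch below compares it against a version that builds e_c explicitly. The function names are hypothetical, and random floats stand in for the latents and noise predictions; the identity holds elementwise, so scalars suffice.

```python
import random


def pred_x0_expanded(x_0, x_s, x_t, e_s, e_t, alpha_prod_t):
    """Hypothetical scalar version of pred_x0 exactly as written in ddcm_sampler."""
    beta_prod_t = 1 - alpha_prod_t
    return x_0 + ((x_t - x_s) - beta_prod_t ** 0.5 * (e_t - e_s)) / alpha_prod_t ** 0.5


def pred_x0_explicit(x_0, x_s, x_t, e_s, e_t, alpha_prod_t):
    """Build the consistent noise e_c first, then apply the standard DDIM x0 prediction."""
    beta_prod_t = 1 - alpha_prod_t
    e_c = (x_s - alpha_prod_t ** 0.5 * x_0) / beta_prod_t ** 0.5
    eps = (e_t - e_s) + e_c
    return (x_t - beta_prod_t ** 0.5 * eps) / alpha_prod_t ** 0.5


rng = random.Random(0)
max_diff = 0.0
for _ in range(10_000):
    # Random stand-ins for latents / noise predictions and a cumulative alpha in (0, 1).
    x_0, x_s, x_t, e_s, e_t = (rng.gauss(0.0, 1.0) for _ in range(5))
    alpha_prod_t = rng.uniform(0.01, 0.99)
    a = pred_x0_expanded(x_0, x_s, x_t, e_s, e_t, alpha_prod_t)
    b = pred_x0_explicit(x_0, x_s, x_t, e_s, e_t, alpha_prod_t)
    max_diff = max(max_diff, abs(a - b))

print(f"max |expanded - explicit| = {max_diff:.2e}")
```

The two forms agree up to rounding. The explicit form divides by and then multiplies by sqrt(1 - alpha_prod_t), which is presumably the kind of floating-point error the expanded form avoids.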
We don't compute the consistent noise explicitly here because we want to avoid the floating-point error it would introduce, so we simplify the expression so that it no longer appears on its own. Here's how it works:
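A sketch of the algebra, assuming the target branch uses the standard DDIM x0 prediction with the noise estimate eps = (e_t - e_s) + e_c from the snippet (writing $\bar\alpha_t$ for alpha_prod_t):

```math
\begin{aligned}
e_c &= \frac{x_s - \sqrt{\bar\alpha_t}\,x_0}{\sqrt{1-\bar\alpha_t}}, \qquad \epsilon = (e_t - e_s) + e_c,\\
\hat{x}_0 &= \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon}{\sqrt{\bar\alpha_t}}
= \frac{x_t - \sqrt{1-\bar\alpha_t}\,(e_t - e_s) - \left(x_s - \sqrt{\bar\alpha_t}\,x_0\right)}{\sqrt{\bar\alpha_t}}\\
&= x_0 + \frac{(x_t - x_s) - \sqrt{1-\bar\alpha_t}\,(e_t - e_s)}{\sqrt{\bar\alpha_t}}.
\end{aligned}
```

So the $x_0$ and $-x_s$ terms, which look like mixing latents from different branches, are just $\sqrt{1-\bar\alpha_t}\,e_c$ divided through by $\sqrt{\bar\alpha_t}$. Since eta = 1 zeroes dir_xt, pred_x0 is the path through which e_c still reaches prev_xt.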
As for eta, I'm not sure whether the consistency model scheduler can use it as well.
Understood. Thanks for the clear clarification!