ankanbhunia/PIDM

Can you explain how "frozen_out" works? Thanks!

gouchaonijiao opened this issue · 2 comments

frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
    model=lambda *args, r=frozen_out: r,
    x_start=x_start,
    x_t=x_t,
    t=t,
    clip_denoised=False,
)["output"]

The self._vb_terms_bpd function takes model as one of its inputs (note that model must be a function). We have already computed the output of model(), so we save it as frozen_out and pass a dummy lambda function that simply returns it (line 997). That way the model is not evaluated a second time inside self._vb_terms_bpd. Hope it helps.
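
To make the dummy-lambda idea concrete, here is a minimal sketch in plain Python. It is not the repository's code: `expensive_model` and `vb_terms` are hypothetical stand-ins for the network and for `_vb_terms_bpd`, and lists are used in place of tensors. The point it illustrates is that the lambda is itself a function, so it can be called like `model(x, t)`; it ignores its positional arguments and returns the value that was bound to `r` when the lambda was defined.

```python
def expensive_model(x, t):
    """Hypothetical stand-in for the real network forward pass."""
    expensive_model.calls += 1
    return [xi * 2 + t for xi in x]


expensive_model.calls = 0


def vb_terms(model, x_t, t, model_kwargs=None):
    """Hypothetical stand-in for _vb_terms_bpd: it only needs a callable."""
    if model_kwargs is None:
        model_kwargs = {}
    # Called the same way the real code calls the network.
    return {"output": model(x_t, t, **model_kwargs)}


x_t, t = [1.0, 2.0], 3

# The forward pass runs once, outside vb_terms.
frozen_out = expensive_model(x_t, t)

# `*args` swallows (x_t, t); the default value `r=frozen_out` is evaluated
# once, when the lambda is defined, so every call returns the same object.
dummy_model = lambda *args, r=frozen_out: r

result = vb_terms(dummy_model, x_t, t)["output"]
```

Note that this lambda accepts only positional arguments (plus `r`), so the sketch works because `model_kwargs` is empty here; passing other keyword arguments to it would raise a `TypeError`.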

Thank you for your prompt reply! But I still don't quite understand. frozen_out is not a function, so how does this code on line 931 run:
    out = self.p_mean_variance(
        model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
    )
and this code on line 326:
    if x_cond is None:
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)