About the loss of image quality after adding diffusion
Spring-lovely opened this issue · 8 comments
Hi, dear author,
Thank you so much for your open-source work. I ran into the following questions while running the code, and I hope you can take some time to answer them. My setting is 6 input frames predicting 6 output frames.
In the picture below, the first row is the input frames, the second row is the label, the third row is the final prediction, and the fourth row is the backbone output. After adding the diffusion part, the image quality decreases significantly (see row 3). I am not sure whether there is a problem in my denoising function; please help me take a look.
I only added the following functions to diffcast.py:
```python
def predict(self, frames_in, compute_loss=False, frames_gt=None, **kwargs):
    T_out = default(kwargs.get('T_out'), 6)
    if compute_loss:
        B, T_in, c, h, w = frames_in.shape
        device = self.device
        backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt,
                                                                   compute_loss=compute_loss, **kwargs)
        residual = frames_gt - backbone_output
        global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))
        pre_frag = frames_in
        pre_mu = None
        pred_ress = []
        diff_loss = 0.
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        for frag_idx in range(T_out // T_in):
            mu = backbone_output[:, frag_idx * T_in : (frag_idx + 1) * T_in]
            res = residual[:, frag_idx * T_in : (frag_idx + 1) * T_in]
            cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
            res_pred, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                                 idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
            diff_loss += noise_loss
            frag_pred = res_pred + mu
            pre_frag = frag_pred
            pre_mu = mu
        alpha = torch.tensor(0.5)
        loss = (1 - alpha) * backbone_loss + alpha * diff_loss / 3.
        # backbone_output = self.unnormalize(backbone_output)
        return backbone_output, loss
    else:
        pred, mu, y = self.sample(frames_in=frames_in, T_out=T_out)
        loss = None
        backbone_loss = None
        diff_loss = None
        # return pred, mu, y, loss
        return pred, mu

def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
    b, _, c, h, w = x_start.shape
    noise = default(noise, lambda: torch.randn_like(x_start))
    # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
    offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
    if offset_noise_strength > 0.:
        offset_noise = torch.randn(x_start.shape[:2], device=self.device)
        noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
    # noise sample
    x = self.predict_v(x_start=x_start, t=t, noise=noise)
    model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)
    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        v = self.predict_v(x_start, t, noise)
        target = v
    else:
        raise ValueError(f'unknown objective {self.objective}')
    loss = F.mse_loss(model_out, target, reduction='none')
    loss = reduce(loss, 'b ... -> b', 'mean')
    loss = loss * extract(self.loss_weight, t, loss.shape)
    return model_out, loss.mean()
```
Hi, thanks for following our work.
From your code, I notice that you use `model_out` as the `residual_pred`, which is an error for the diffusion training strategy. You can find a demo implementation of the `p_losses` func at Line 767 of this.
Typically, diffusion models do not generate the target during training.
Finally, you can determine whether the diffusion model is being properly trained by observing the trend of the `diff_loss` during training.
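For reference, a minimal sketch of the distinction (simplified from lucidrains' denoising-diffusion-pytorch; the broadcast schedule tensors here are assumptions, not this repo's exact API). `q_sample` produces the noised sample the UNet must consume during training, whereas `predict_v` computes the v-prediction regression target, which is not a valid UNet input:

```python
import torch

# Simplified sketch of the two functions at issue. Here sqrt_ab = sqrt(alpha_bar_t)
# and sqrt_1mab = sqrt(1 - alpha_bar_t), already broadcast to x_start's shape.

def q_sample(x_start, noise, sqrt_ab, sqrt_1mab):
    # Forward noising: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
    # This noised sample is what the UNet consumes during training.
    return sqrt_ab * x_start + sqrt_1mab * noise

def predict_v(x_start, noise, sqrt_ab, sqrt_1mab):
    # v-parameterization: v = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x_0.
    # This is a regression *target* for the 'pred_v' objective, not a UNet input.
    return sqrt_ab * noise - sqrt_1mab * x_start

# Tiny illustration with made-up shapes and a single timestep's coefficients
x0, eps = torch.randn(2, 6, 1, 64, 64), torch.randn(2, 6, 1, 64, 64)
x_t = q_sample(x0, eps, sqrt_ab=torch.tensor(0.9), sqrt_1mab=torch.tensor(0.436))
```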
@DeminYu98 But from the code @Spring-lovely posted, it seems the diffusion model's output is `residual_pred`: since `p_losses`'s target is the residual itself, it looks the same as the link you posted. If `model_out` is not `residual_pred`, then what is it? If you could give a more specific case or point out which line should be modified to make things right, I would be very grateful. Thanks for your time.
@Yager-42 Well, thanks for pointing out your confusion.
I think it would be better to consider this question from the basic theory of DDPM.
- Firstly, forget the residual, fragment, or whatever; we divide the backbone and the diffusion model into two independent components. The diffusion part in DiffCast is simply applied to generate a sequence.
- Secondly, imagine we let the diffusion model generate the overall residual sequence given nothing, all at once. That is just an unconditional generation task, and it is feasible, right?
- Then we can think about which distribution is the learning target of the diffusion model, what the input/output of the UNet is during training, and what the input/output of the UNet is during inference. That is just what this repo does.
- In this way we just use `noise_loss = self.p_losses(residual, t)` in `compute_loss` for training the diffusion model, just as shown in the DDPM repo.
- Finally, as for "the p_losses's target is the residual itself", I still cannot get it. In the `p_losses` func, the loss target is actually `F.mse_loss(model_out, noise)`, which is just the learning target of the original DDPM.
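To make the training/inference asymmetry concrete, here is a hedged sketch of plain ancestral DDPM sampling (simplified textbook form, not DiffCast's actual sampler): generation happens only here, by iteratively denoising from pure noise, never during training:

```python
import torch

@torch.no_grad()
def ddpm_sample(model, shape, betas, device='cpu'):
    """Sketch of ancestral DDPM sampling for the 'pred_noise' objective.
    betas: (T,) noise schedule; model(x, t) predicts the injected noise."""
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape, device=device)  # start from pure noise
    for t in reversed(range(betas.shape[0])):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        # posterior mean: (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / torch.sqrt(1 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])
        # add noise on every step except the last
        x = mean + torch.sqrt(betas[t]) * torch.randn_like(x) if t > 0 else mean
    return x
```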
@DeminYu98 Thanks for your reply. I understand now. Best wishes to you!
Hello, I apologize for disturbing you again:
Following your last guidance, during the training phase I compute the loss with `cond = frame_in - gt` and the UNet predicting x0, and I finally get the prediction results shown in the third row. I find that my final result is more similar to the input (first row) than to the ground truth (second row). Is there something wrong with my code? Although it indeed looks more refined than the backbone output, I found that after adding diffusion, metrics such as CSI and HSS decreased compared to the backbone. Where might the problem be? Thank you; I look forward to your reply.
@Spring-lovely I apologize for the delayed response.
From a visual standpoint, your prediction results appear normal. To make a more comprehensive assessment, it would be helpful to examine the training loss and residual visualizations.
Given that I'm not fully aware of your specific training process, I'm unable to determine if there are any issues definitively.
Regarding the metrics you mentioned, I noticed that you're only predicting 6 frames. For such short-term predictions, deterministic backbones indeed tend to perform better. Additionally, I'm curious whether your 6-frame prediction is produced autoregressively or as a single-segment prediction. The segment size can significantly impact DiffCast's performance, as we've noted in our paper's appendix.
If you have any further questions or need clarification on any point, please don't hesitate to ask.
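For intuition, a tiny sketch of how the fragment loop count falls out of `T_in` and `T_out` (the 5-in-20-out figures below are illustrative, not a claim about your setup): with 6-in-6-out the loop runs exactly once, i.e. a single segment with no autoregressive conditioning.

```python
# Hypothetical illustration of how T_in / T_out determine the number of
# autoregressive segments in the fragment loop discussed in this thread.
def num_segments(T_in: int, T_out: int) -> int:
    assert T_out % T_in == 0, "T_out must be a multiple of the fragment length"
    return T_out // T_in

print(num_segments(6, 6))    # 1 -> single segment, no autoregression
print(num_segments(5, 20))   # 4 -> four chained segments
```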
Thanks, @Spring-lovely, for providing the loss function. However, I have a few suggestions:
- We need to set the `auto_normalize` variable in the `GaussianDiffusion` class to `False` for the loss function to work correctly.
- `x` in `p_losses` in the code above should be updated to `x = self.q_sample(x_start=x_start, t=t, noise=noise)`, as suggested in the previous discussion.
- The task of @Spring-lovely seems to be a 1-in-1-out task, whereas the paper describes an autoregressive 1-in-4-out task. This means the loss function above does not fit the autoregressive task, potentially leading to discontinuity between segments. Thus, I recommend using `cond = residual[:, (frag_idx - 1) * T_in : frag_idx * T_in] if pre_mu is not None else torch.zeros_like(pre_frag)`.
I've attached the final loss function below. I would appreciate any feedback or suggestions from anyone.
```python
def compute_loss(self, frames_in, frames_gt):
    compute_loss = True
    B, T_in, c, h, w = frames_in.shape
    T_out = frames_gt.shape[1]
    device = frames_in.device
    # Deterministic backbone prediction and its own loss
    backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt,
                                                               compute_loss=compute_loss)
    frames_in = self.normalize(frames_in)
    backbone_output = self.normalize(backbone_output)
    frames_gt = self.normalize(frames_gt)
    # The diffusion model learns the residual between ground truth and backbone output
    residual = frames_gt - backbone_output
    global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))
    pre_frag = frames_in
    pre_mu = None
    pred_ress = []
    diff_loss = 0.
    t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
    for frag_idx in range(T_out // T_in):
        mu = backbone_output[:, frag_idx * T_in : (frag_idx + 1) * T_in]
        res = residual[:, frag_idx * T_in : (frag_idx + 1) * T_in]
        # Condition on the previous fragment's ground-truth residual (pre_frag - pre_mu);
        # zeros for the first fragment
        cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
        res_pred, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                             idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
        diff_loss += noise_loss
        pre_frag = frames_gt[:, frag_idx * T_in : (frag_idx + 1) * T_in]
        pre_mu = mu
    # Average the diffusion loss over the number of fragments
    diff_loss /= (T_out // T_in)
    alpha = torch.tensor(0.5)
    loss = (1 - alpha) * backbone_loss + alpha * diff_loss
    return loss

# Reference: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763
@autocast(enabled=False)
def q_sample(self, x_start, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    return (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
    b, _, c, h, w = x_start.shape
    noise = default(noise, lambda: torch.randn_like(x_start))
    # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
    offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
    if offset_noise_strength > 0.:
        offset_noise = torch.randn(x_start.shape[:2], device=self.device)
        noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
    # noise sample
    x = self.q_sample(x_start=x_start, t=t, noise=noise)  # use q_sample here for updating
    model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)
    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        v = self.predict_v(x_start, t, noise)
        target = v
    else:
        raise ValueError(f'unknown objective {self.objective}')
    loss = F.mse_loss(model_out, target, reduction='none')
    loss = reduce(loss, 'b ... -> b', 'mean')
    loss = loss * extract(self.loss_weight, t, loss.shape)
    return model_out, loss.mean()
```
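For anyone wanting to smoke-test the function above, a minimal usage sketch might look like the following (the shapes and the `model` object are hypothetical assumptions, not DiffCast's exact interface):

```python
# Hypothetical smoke test for compute_loss; `model` is assumed to be a
# DiffCast-style module exposing the compute_loss method defined above.
import torch

B, T_in, T_out, C, H, W = 2, 5, 20, 1, 128, 128
frames_in = torch.rand(B, T_in, C, H, W)
frames_gt = torch.rand(B, T_out, C, H, W)

loss = model.compute_loss(frames_in, frames_gt)  # scalar tensor
loss.backward()
print(loss.item())
```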