DeminYu98/DiffCast

About the loss of image quality after adding diffusion

Spring-lovely opened this issue · 8 comments

Hi, dear author,
Thank you so much for your open-source work. I have the following questions from running the code and hope you can take some time to answer them. I am predicting 6 frames from 6 input frames.

[Image: prediction visualization]

In the picture, the first row is the input frames, the second row is the ground truth (label), the third row is the final prediction, and the fourth row is the backbone output. After adding the diffusion part, the image quality decreases significantly (see row 3). I am not sure whether there is a problem with my denoising function; could you please take a look?

I only added the following functions to diffcast.py:
```python
def predict(self, frames_in, compute_loss=False, frames_gt=None, **kwargs):
    T_out = default(kwargs.get('T_out'), 6)

    if compute_loss:
        B, T_in, c, h, w = frames_in.shape
        device = self.device

        backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt,
                                                                   compute_loss=compute_loss, **kwargs)

        residual = frames_gt - backbone_output
        global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))

        pre_frag = frames_in
        pre_mu = None
        pred_ress = []
        diff_loss = 0.
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        for frag_idx in range(T_out // T_in):
            mu = backbone_output[:, frag_idx * T_in : (frag_idx + 1) * T_in]
            res = residual[:, frag_idx * T_in : (frag_idx + 1) * T_in]

            cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
            res_pred, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                                 idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
            diff_loss += noise_loss

            frag_pred = res_pred + mu
            pre_frag = frag_pred
            pre_mu = mu

        alpha = torch.tensor(0.5)
        loss = (1 - alpha) * backbone_loss + alpha * diff_loss / 3.

        #backbone_output = self.unnormalize(backbone_output)
        return backbone_output, loss
    else:
        pred, mu, y = self.sample(frames_in=frames_in, T_out=T_out)
        loss = None
        backbone_loss = None
        diff_loss = None

        # return pred, mu, y, loss
        return pred, mu

def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
    b, _, c, h, w = x_start.shape

    noise = default(noise, lambda: torch.randn_like(x_start))

    # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
    offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

    if offset_noise_strength > 0.:
        offset_noise = torch.randn(x_start.shape[:2], device = self.device)
        noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

    # noise sample
    x = self.predict_v(x_start=x_start, t=t, noise=noise)
    model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)

    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        v = self.predict_v(x_start, t, noise)
        target = v
    else:
        raise ValueError(f'unknown objective {self.objective}')

    loss = F.mse_loss(model_out, target, reduction = 'none')
    loss = reduce(loss, 'b ... -> b', 'mean')

    loss = loss * extract(self.loss_weight, t, loss.shape)
    return model_out, loss.mean()
```

Hi, thanks for following our work.
From your code, I noticed that you use model_out as the residual prediction (res_pred), which is incorrect for the diffusion training strategy.
You can find a reference implementation of the p_losses function at Line 767 here.
Typically, diffusion models do not generate the target directly during training.
Finally, you can check whether the diffusion model is being trained properly by monitoring the trend of diff_loss during training.
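To make this concrete: in a DDPM-style setup with the `pred_noise` objective, the network is shown a noised sample x_t and regressed onto the noise that was added, so its output is not the residual itself. Below is a minimal sketch of one training step; the linear beta schedule and the generic `model(x_t, t)` callable are illustrative assumptions, not DiffCast's actual configuration.

```python
import torch
import torch.nn.functional as F


def linear_alphas_cumprod(num_timesteps: int) -> torch.Tensor:
    # Assumed linear beta schedule (as in the original DDPM paper).
    betas = torch.linspace(1e-4, 0.02, num_timesteps)
    return torch.cumprod(1.0 - betas, dim=0)


def q_sample(x_start, t, noise, alphas_cumprod):
    # Forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps,
    # with a_bar_t broadcast over all non-batch dimensions.
    a_bar = alphas_cumprod[t].view(-1, *([1] * (x_start.dim() - 1)))
    return a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise


def ddpm_training_loss(model, x_start, alphas_cumprod):
    # One pred_noise training step: the target is the injected noise,
    # not x_start (here x_start would be the residual sequence).
    b = x_start.shape[0]
    t = torch.randint(0, alphas_cumprod.shape[0], (b,), device=x_start.device)
    noise = torch.randn_like(x_start)
    x_t = q_sample(x_start, t, noise, alphas_cumprod)
    model_out = model(x_t, t)  # an estimate of `noise`
    return F.mse_loss(model_out, noise)
```

Under this objective, adding `model_out` to `mu` (as `res_pred + mu` does in the `predict` function above) would add a noise estimate to the backbone output rather than a residual.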


@DeminYu98, but from the code @Spring-lovely posted, it seems the diffusion model's output is res_pred, since the target of p_losses is the residual itself, and it looks the same in the link you posted. If model_out is not res_pred, then what is it? If you could give me a more specific example or point out which line should be modified to make it right, I would be very grateful. Thanks for your time.

@Yager-42 Well, thanks for pointing out your confusion.
I think it is better to consider this question from the basic theory of DDPM.

  • Firstly, forget about the residual, fragments, and so on, and treat the backbone and the diffusion model as two independent components. The diffusion part in DiffCast is simply used to generate a sequence.
  • Secondly, imagine we let the diffusion model generate the whole residual sequence at once, conditioned on nothing. That is just an unconditional generation task, and it is feasible, right?
  • Then consider which distribution the diffusion model is learning, what the input/output of the UNet is during training, and what the input/output of the UNet is during inference. That is exactly what this repo does.
  • In this way, we simply use noise_loss = self.p_losses(residual, t) when computing the loss for training the diffusion model, just as in the DDPM repo.
  • Finally, regarding the claim that the target of p_losses is the residual itself, I still don't follow. In the p_losses function, the loss is actually F.mse_loss(model_out, noise) (with the pred_noise objective), which is exactly the learning target of the original DDPM.
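One way to see why `model_out` is not the residual: under `pred_noise`, an estimate of x_0 (the residual) is only obtained by inverting the forward process with the predicted noise. The sketch below mirrors the idea of the `predict_start_from_noise` helper in lucidrains' denoising-diffusion-pytorch, written here as a standalone function with an explicit `alphas_cumprod` tensor (an assumption for illustration).

```python
import torch


def predict_start_from_noise(x_t, t, eps_pred, alphas_cumprod):
    # Solve x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps for x_0,
    # substituting the network's noise estimate for the true eps.
    a_bar = alphas_cumprod[t].view(-1, *([1] * (x_t.dim() - 1)))
    return (x_t - (1.0 - a_bar).sqrt() * eps_pred) / a_bar.sqrt()
```

If `eps_pred` equals the true noise, this recovers x_0 exactly; any error in the noise estimate is scaled by sqrt(1 - a_bar) / sqrt(a_bar), which is large at high noise levels.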

@DeminYu98 Thanks for your reply, I understand now. Best wishes to you.

Hello, sorry to disturb you:
Following your last guidance, in the training phase when computing the loss, I use cond = frame_in - gt and the UNet predicts x0, and I finally get the prediction result shown in the third row. I find that my final result is more similar to the input (first row) than to the ground truth (second row). Is there something wrong with my code? Although it does look more refined than the backbone output, I found that after adding diffusion, metrics such as CSI and HSS decreased compared to the backbone. Where might the problem be? Thank you, and I look forward to your reply.
[Image: prediction visualization]

@Spring-lovely I apologize for the delayed response.
From a visual standpoint, your prediction results appear normal. For a more thorough assessment, it would help to look at the training loss curves and the residual visualizations.
Since I don't know the details of your training process, I can't say definitively whether there are any issues.
Regarding the metrics you mentioned, I noticed that you're only predicting 6 frames. For such short-term predictions, deterministic backbones do tend to perform better. I'm also curious whether your 6-frame prediction is produced autoregressively or as a single-segment prediction. The segment size can significantly affect DiffCast's performance, as we noted in our paper's appendix.
If you have any further questions or need clarification on any point, please don't hesitate to ask.
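The segment question can be read directly off the loop `for frag_idx in range(T_out // T_in)` used in the loss functions in this thread. The hypothetical helper below just lists the frame ranges that loop covers; the 20-out/5-in case is an illustrative example, not a setting taken from the paper.

```python
def fragment_slices(T_out: int, T_in: int):
    # Frame ranges [start, end) visited by `for frag_idx in range(T_out // T_in)`;
    # each fragment spans T_in frames of the output sequence.
    return [(k * T_in, (k + 1) * T_in) for k in range(T_out // T_in)]


print(fragment_slices(6, 6))   # 6-in / 6-out: a single segment
print(fragment_slices(20, 5))  # 4 autoregressive segments
print(fragment_slices(6, 4))   # trailing frames 4..5 are never visited
```

With 6 input and 6 output frames there is only one fragment, so the loop never conditions on a previous segment (`cond` stays zeros and `local_ctx` is used). Also note that when T_out is not a multiple of T_in, the trailing frames are silently skipped.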

Thanks, @Spring-lovely, for providing the loss function. However, I have a few suggestions:

  1. We need to set the auto_normalize variable of the GaussianDiffusion class to False for the loss function to work correctly.
  2. In p_losses in the code above, x should be computed as x = self.q_sample(x_start=x_start, t=t, noise=noise), as suggested in the previous discussion.
  3. The task of @Spring-lovely seems to be a 1-in-1-out task, whereas the paper describes an autoregressive 1-in-4-out task. This means the loss function above does not fit the autoregressive setting, which can lead to discontinuities between segments. I therefore recommend using:
    cond = [:, (frag_idx-1) * T_in : (frag_idx) * T_in] if pre_mu is not None else torch.zeros_like(pre_frag).
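Point 2 matters because `predict_v` and `q_sample` compute different quantities: `q_sample` builds the noised network input, while `predict_v` builds the target for the `pred_v` objective. A minimal sketch of both formulas, simplified to a single scalar ᾱ_t instead of per-sample timestep lookups (an assumption for brevity):

```python
import torch


def q_sample(x_start, noise, a_bar):
    # Noised input fed to the network: sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps
    return a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise


def predict_v(x_start, noise, a_bar):
    # v-parameterization target: sqrt(a_bar) * eps - sqrt(1 - a_bar) * x_0
    return a_bar.sqrt() * noise - (1.0 - a_bar).sqrt() * x_start
```

Feeding `predict_v(...)` to the UNet as its input means it never sees the x_t distribution it is sampled from at inference. The two are related by x_0 = sqrt(ᾱ_t) · x_t − sqrt(1 − ᾱ_t) · v, which is one quick way to check an implementation.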

I've attached the final loss function below. I would appreciate any feedback or suggestions from anyone.

```python
import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from torch.cuda.amp import autocast

# Methods of the model class in diffcast.py; `default` and `extract` are
# helper functions defined elsewhere in the repo (not shown).

def compute_loss(self, frames_in, frames_gt):
    compute_loss = True
    B, T_in, c, h, w = frames_in.shape
    T_out = frames_gt.shape[1]
    device = frames_in.device

    backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt, compute_loss=compute_loss)

    frames_in = self.normalize(frames_in)
    backbone_output = self.normalize(backbone_output)
    frames_gt = self.normalize(frames_gt)

    residual = frames_gt - backbone_output
    global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))

    pre_frag = frames_in
    pre_mu = None
    pred_ress = []
    diff_loss = 0.
    t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
    for frag_idx in range(T_out // T_in):
        mu = backbone_output[:, frag_idx * T_in : (frag_idx + 1) * T_in]
        res = residual[:, frag_idx * T_in : (frag_idx + 1) * T_in]

        cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
        res_pred, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                                idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
        diff_loss += noise_loss

        pre_frag = frames_gt[:, frag_idx * T_in : (frag_idx + 1) * T_in]
        pre_mu = mu
    diff_loss /= (T_out // T_in)

    alpha = torch.tensor(0.5)
    loss = (1 - alpha) * backbone_loss + alpha * diff_loss
    return loss

# Reference: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763
@autocast(enabled = False)
def q_sample(self, x_start, t, noise = None):
    noise = default(noise, lambda: torch.randn_like(x_start))

    return (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
    b, _, c, h, w = x_start.shape

    noise = default(noise, lambda: torch.randn_like(x_start))

    # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
    offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

    if offset_noise_strength > 0.:
        offset_noise = torch.randn(x_start.shape[:2], device = self.device)
        noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

    # noise sample
    x = self.q_sample(x_start=x_start, t=t, noise=noise) # Use q_sample here for updating

    model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)

    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        v = self.predict_v(x_start, t, noise)
        target = v
    else:
        raise ValueError(f'unknown objective {self.objective}')

    loss = F.mse_loss(model_out, target, reduction = 'none')
    loss = reduce(loss, 'b ... -> b', 'mean')

    loss = loss * extract(self.loss_weight, t, loss.shape)
    return model_out, loss.mean()
```
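In this `compute_loss`, `pre_frag` is set to the ground-truth fragment and `pre_mu` to the backbone fragment after each step, so for every fragment after the first, `cond = pre_frag - pre_mu` is exactly the previous slice of `residual` (a teacher-forcing style conditioning). The hypothetical standalone helper below reproduces only that bookkeeping on plain tensors, assuming inputs are already normalized, so the equivalence can be checked in isolation.

```python
import torch


def teacher_forced_conds(frames_gt, backbone_output, T_in):
    # Reproduces how `cond` evolves in the compute_loss loop:
    # fragment 0 gets zeros; fragment k gets (gt - mu) of fragment k - 1.
    T_out = frames_gt.shape[1]
    conds = []
    pre_frag, pre_mu = None, None
    for k in range(T_out // T_in):
        sl = slice(k * T_in, (k + 1) * T_in)
        if pre_mu is None:
            conds.append(torch.zeros_like(frames_gt[:, sl]))
        else:
            conds.append(pre_frag - pre_mu)
        pre_frag = frames_gt[:, sl]
        pre_mu = backbone_output[:, sl]
    return conds
```

At inference there is no ground truth, so the previous fragment has to come from the model's own prediction instead; that mismatch between training and sampling is worth keeping in mind when judging segment continuity.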


Hello, could you please share your final training loss code? Thank you very much.