hunto/DiffKD

About the loss

TungyuYoung opened this issue · 5 comments

Dear Hunto,
I tried your method on my task. I found that while the original loss (between the student's output and the ground truth) decreased, the other losses, such as the autoencoder loss and the diffusion loss, did not seem to change. What do you think the potential reason might be?

Best!

Hi, you may check whether you have added the DiffKD module to your optimizer.
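
For example, something along these lines (just a minimal sketch; `student`, `diffkd_module` and the hyperparameters are placeholders, not code from this repo):

```python
import torch

# Hypothetical names: `student` is the student network, `diffkd_module` is the
# DiffKD module used during distillation. If its parameters are not passed to
# the optimizer, the diffusion / autoencoder losses have nothing to update and
# will stay flat.
optimizer = torch.optim.AdamW(
    list(student.parameters()) + list(diffkd_module.parameters()),
    lr=1e-3,
)
```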

Hello, I checked my code and it seems that I indeed have not added the DiffKD module to the optimizer. However, I also checked the train.py you provided, and there does not seem to be such an operation there either. Have I missed something? Could you suggest a solution? A million thanks! The DiffKD code I created is below:

```python
# Imports for the snippet. torch / torch.nn are standard; the DiffKD building blocks
# (AutoEncoder, NoiseAdapter, DiffusionModel, DDIMPipeline, DDIMScheduler), the
# FeatureLoss criterion and the layer_feature_extraction helper are assumed to be
# importable from my own project code.
import torch
import torch.nn as nn
import torch.nn.functional as F

feature_loss_func = FeatureLoss


class DiffKD(nn.Module):
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 kernel_size=3,
                 inference_steps=5,
                 num_train_time_steps=1000,
                 use_ae=False,  # use autoencoder
                 ae_channels=None):
        super().__init__()
        self.use_ae = use_ae
        self.diffusion_inference_steps = inference_steps

        # AutoEncoder for compressing the teacher feature
        if use_ae:
            if ae_channels is None:
                ae_channels = teacher_channels // 2  # 16 * 2 * 2 = 64
            self.ae = AutoEncoder(teacher_channels, ae_channels)
            teacher_channels = ae_channels
        # transform student feature to the same dimension as teacher
        self.trans = nn.Conv2d(student_channels, teacher_channels, 1)

        # diffusion model - predict noise
        self.scheduler = DDIMScheduler(num_train_time_steps=num_train_time_steps, clip_sample=False,
                                       beta_schedule="linear")
        self.noise_adapter = NoiseAdapter(teacher_channels, kernel_size)

        # pipeline for denoising student feature
        self.model = DiffusionModel(in_channels=teacher_channels, kernel_size=kernel_size)
        self.pipeline = DDIMPipeline(self.model, self.scheduler, self.noise_adapter)
        self.proj = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels, 1),
                                  nn.BatchNorm2d(teacher_channels))

    def forward(self, student_feat, teacher_feat):
        # student_feat: [B, 16, T, F_c]
        student_feat = self.trans(student_feat)  # -> student_feat: [B, 32, T, F_c]

        if self.use_ae:
            hidden_teacher_feat, rec_teacher_feat = self.ae(teacher_feat)
            # rec_loss = F.mse_loss(teacher_feat, rec_teacher_feat)
            rec_loss = feature_loss_func(rec_teacher_feat, teacher_feat)
            teacher_feat = hidden_teacher_feat.detach()
        else:
            rec_loss = None

        # denoise student feature
        refined_feature = self.pipeline(
            batch_size=student_feat.shape[0],
            device=student_feat.device,
            dtype=student_feat.dtype,
            shape=student_feat.shape[1:],
            feat=student_feat,
            num_inference_steps=self.diffusion_inference_steps,
            proj=self.proj
        )
        refined_feature = self.proj(refined_feature)

        # train diffusion model
        ddim_loss = self.ddim_loss(teacher_feat)

        # Return: denoised student feature, teacher feature, diffusion loss, AutoEncoder loss
        return refined_feature, teacher_feat, ddim_loss, rec_loss

    def ddim_loss(self, gt_feat):  # diffusion loss
        noise = torch.randn(gt_feat.shape, device=gt_feat.device)
        bs = gt_feat.shape[0]  # batch size

        # Sample a random timestep for each feature
        time_step = torch.randint(0, self.scheduler.num_train_time_steps, (bs,), device=gt_feat.device).long()

        # Add noise to the clean feature according to the noise magnitude at each timestep
        noisy_feature = self.scheduler.add_noise(gt_feat, noise, time_step)
        noise_pred = self.model(noisy_feature, time_step)
        loss = F.mse_loss(noise_pred, noise)
        return loss

class DiffPro(nn.Module):
    def __init__(self, student, teacher):
        super().__init__()
        self.student = student.train()
        self.teacher = teacher.eval()

        self.student_features = None
        self.teacher_features = None

        ae_channels = 32
        use_ae = True

        # 11 DiffKD blocks for the 16->64-channel features, plus 1 block for the 2->2-channel feature
        self.DiffusionProcess = nn.ModuleList(
            [DiffKD(16, 64, kernel_size=1, use_ae=use_ae, ae_channels=ae_channels) for _ in range(11)]
            + [DiffKD(2, 2, kernel_size=1, use_ae=use_ae, ae_channels=8)]
        ).cuda()

    def forward(self, noisy_specs):
        with torch.no_grad():
            teacher_features = self._feature_extractor(self.teacher, noisy_specs)

        student_features = self._feature_extractor(self.student, noisy_specs)

        refined_features = []
        teacher_out_features = []
        diff_loss = 0.
        ae_loss = 0.

        for i in range(len(self.DiffusionProcess)):
            refined_feature_, teacher_out_feature_, diff_loss_, ae_loss_ = self.DiffusionProcess[i](
                student_features[i], teacher_features[i])
            refined_features.append(refined_feature_)
            teacher_out_features.append(teacher_out_feature_)
            diff_loss += diff_loss_
            ae_loss += ae_loss_

        return refined_features, teacher_out_features, diff_loss, ae_loss

    def _feature_extractor(self, model, noisy_specs):
        feature_extractor = layer_feature_extraction.GTCRN_fe(model)
        features = []

        features_map = feature_extractor.extract_feature_maps(noisy_specs)

        encoder_f, decoder_f, dpgrnn1_f, dpgrnn2_f = (features_map["encoder"],
                                                      features_map["decoder"],
                                                      features_map["dpgrnn1"][0][0],
                                                      features_map["dpgrnn2"][0][0])

        for i in range(len(encoder_f)):
            features.append(encoder_f[i])
        features.append(dpgrnn1_f)
        features.append(dpgrnn2_f)
        for j in range(len(decoder_f)):
            features.append(decoder_f[j])

        feature_extractor.remove_hook()

        return features

```

I added the module into the student via the following code:
https://github.com/hunto/image_classification_sota/blob/6cb144105fc5c2f778e51cc66e35314938f96fae/lib/models/losses/kd_loss.py#L94
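
In other words, the pattern is roughly the sketch below (illustrative only, not the exact code behind the link): the DiffKD blocks are registered as submodules of the student, so that the optimizer built from `student.parameters()` also covers them.

```python
import torch
import torch.nn as nn

# Illustrative sketch: attach the DiffKD blocks (from the snippet above) to the
# student so that student.parameters() already includes their weights when the
# optimizer is constructed.
student.diffusion_process = nn.ModuleList(
    [DiffKD(16, 64, kernel_size=1, use_ae=True, ae_channels=32) for _ in range(11)]
    + [DiffKD(2, 2, kernel_size=1, use_ae=True, ae_channels=8)]
)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)
```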

Sorry about this, I'll consider a better way to achieve it.

Understood. I will try to fix it. Thanks again!

Dear Hunto,
When I tried to save the trained student model, it seemed that the DiffKD module would be saved together with it. Is there any way to avoid this? Or do I need to use the DiffKD module during inference as well?
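
To make the question concrete, here is a minimal sketch (with `student` and `teacher` as placeholders) of the two options I am considering:

```python
import torch

model = DiffPro(student, teacher)  # `student` / `teacher` are placeholders
# ... distillation training ...

# Option 1: save the whole wrapper - this also serializes all DiffKD weights
torch.save(model.state_dict(), "student_with_diffkd.pt")

# Option 2: save only the student's weights and drop the DiffKD modules
torch.save(model.student.state_dict(), "student_only.pt")
```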