About the loss
TungyuYoung opened this issue · 5 comments
Dear Hunto,
I tried your method on my task. I found that while the original loss (student's output vs. ground truth) descended, the other losses, such as the autoencoder loss and the diffusion loss, did not seem to change. What do you think is the potential reason?
Best!
Hi, you may check whether you have added the DiffKD module to your optimizer.
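For example, something along these lines (a minimal sketch; `student`, `diffkd`, and the learning rate are placeholders):

```python
import torch

# The diffusion/AE losses can only decrease if the optimizer also updates
# the DiffKD parameters, not just the student's.
optimizer = torch.optim.Adam(
    list(student.parameters()) + list(diffkd.parameters()),
    lr=1e-3,  # illustrative value
)
```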
Hello, I checked the code and found that it seems I did not add the DiffKD module to the optimizer. However, I also checked the train.py you provided, and it seems there is no such operation there either. Have I missed something? Could you suggest a solution? A million thanks! The DiffKD code I wrote is below:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# AutoEncoder, NoiseAdapter, DiffusionModel, DDIMScheduler, DDIMPipeline,
# FeatureLoss and layer_feature_extraction are assumed to be imported from
# the surrounding project.

feature_loss_func = FeatureLoss  # assumed to be a callable feature loss


class DiffKD(nn.Module):
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 kernel_size=3,
                 inference_steps=5,
                 num_train_time_steps=1000,
                 use_ae=False,  # use autoencoder
                 ae_channels=None):
        super().__init__()
        self.use_ae = use_ae
        self.diffusion_inference_steps = inference_steps
        # AutoEncoder for compressing the teacher feature
        if use_ae:
            if ae_channels is None:
                ae_channels = teacher_channels // 2  # 16 * 2 * 2 = 64
            self.ae = AutoEncoder(teacher_channels, ae_channels)
            teacher_channels = ae_channels
        # transform student feature to the same dimension as teacher
        self.trans = nn.Conv2d(student_channels, teacher_channels, 1)
        # diffusion model - predicts noise
        self.scheduler = DDIMScheduler(num_train_time_steps=num_train_time_steps,
                                       clip_sample=False, beta_schedule="linear")
        self.noise_adapter = NoiseAdapter(teacher_channels, kernel_size)
        # pipeline for denoising the student feature
        self.model = DiffusionModel(in_channels=teacher_channels, kernel_size=kernel_size)
        self.pipeline = DDIMPipeline(self.model, self.scheduler, self.noise_adapter)
        self.proj = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels, 1),
                                  nn.BatchNorm2d(teacher_channels))

    def forward(self, student_feat, teacher_feat):
        # student_feat: [B, 16, T, F_c]
        student_feat = self.trans(student_feat)  # -> [B, 32, T, F_c]
        if self.use_ae:
            hidden_teacher_feat, rec_teacher_feat = self.ae(teacher_feat)
            # rec_loss = F.mse_loss(teacher_feat, rec_teacher_feat)
            rec_loss = feature_loss_func(rec_teacher_feat, teacher_feat)
            teacher_feat = hidden_teacher_feat.detach()
        else:
            rec_loss = None
        # denoise the student feature
        refined_feature = self.pipeline(
            batch_size=student_feat.shape[0],
            device=student_feat.device,
            dtype=student_feat.dtype,
            shape=student_feat.shape[1:],
            feat=student_feat,
            num_inference_steps=self.diffusion_inference_steps,
            proj=self.proj,
        )
        refined_feature = self.proj(refined_feature)
        # train the diffusion model
        ddim_loss = self.ddim_loss(teacher_feat)
        # returns: denoised student feature, teacher feature, diffusion loss, AutoEncoder loss
        return refined_feature, teacher_feat, ddim_loss, rec_loss

    def ddim_loss(self, gt_feat):  # diffusion loss
        noise = torch.randn(gt_feat.shape, device=gt_feat.device)
        bs = gt_feat.shape[0]  # batch size
        # sample a random timestep for each feature
        time_step = torch.randint(0, self.scheduler.num_train_time_steps, (bs,),
                                  device=gt_feat.device).long()
        # add noise to the clean feature according to the noise magnitude at each timestep
        noisy_feature = self.scheduler.add_noise(gt_feat, noise, time_step)
        noisy_pred = self.model(noisy_feature, time_step)
        loss = F.mse_loss(noisy_pred, noise)
        return loss


class DiffPro(nn.Module):
    def __init__(self, student, teacher):
        super().__init__()
        self.student = student.train()
        self.teacher = teacher.eval()
        self.student_features = None
        self.teacher_features = None
        ae_channels = 32
        use_ae = True
        # 11 identical DiffKD stages for the intermediate features,
        # plus one stage for the 2-channel output feature
        self.DiffusionProcess = nn.Sequential(
            *[DiffKD(16, 64, kernel_size=1, use_ae=use_ae, ae_channels=ae_channels)
              for _ in range(11)],
            DiffKD(2, 2, kernel_size=1, use_ae=use_ae, ae_channels=8),
        ).cuda()

    def forward(self, noisy_specs):
        with torch.no_grad():
            teacher_features = self._feature_extractor(self.teacher, noisy_specs)
        # student features must be extracted with gradients enabled
        student_features = self._feature_extractor(self.student, noisy_specs)
        refined_features = []
        teacher_out_features = []
        diff_loss = 0.
        ae_loss = 0.
        for i in range(len(self.DiffusionProcess)):
            refined_feature_, teacher_out_feature_, diff_loss_, ae_loss_ = \
                self.DiffusionProcess[i](student_features[i], teacher_features[i])
            refined_features.append(refined_feature_)
            teacher_out_features.append(teacher_out_feature_)
            diff_loss += diff_loss_
            # rec_loss is never None here since every stage uses use_ae=True
            ae_loss += ae_loss_
        return refined_features, teacher_out_features, diff_loss, ae_loss

    def _feature_extractor(self, model, noisy_specs):
        feature_extractor = layer_feature_extraction.GTCRN_fe(model)
        features = []
        features_map = feature_extractor.extract_feature_maps(noisy_specs)
        encoder_f, decoder_f, dpgrnn1_f, dpgrnn2_f = (features_map["encoder"],
                                                      features_map["decoder"],
                                                      features_map["dpgrnn1"][0][0],
                                                      features_map["dpgrnn2"][0][0])
        for i in range(len(encoder_f)):
            features.append(encoder_f[i])
        features.append(dpgrnn1_f)
        features.append(dpgrnn2_f)
        for j in range(len(decoder_f)):
            features.append(decoder_f[j])
        feature_extractor.remove_hook()
        return features
```
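For context, my training step combines the losses roughly like this (a simplified sketch; `original_loss`, `diff_pro`, `optimizer`, and the unit loss weights are placeholders):

```python
import torch.nn.functional as F

# Sketch of one training step: the DiffKD losses join the task loss,
# but they can only decrease if their parameters are in the optimizer.
refined_feats, teacher_feats, diff_loss, ae_loss = diff_pro(noisy_specs)
kd_loss = sum(F.mse_loss(s, t.detach()) for s, t in zip(refined_feats, teacher_feats))
total_loss = original_loss + kd_loss + diff_loss + ae_loss  # weights omitted

optimizer.zero_grad()
total_loss.backward()
optimizer.step()
```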
I added the module to the student via the following line:
https://github.com/hunto/image_classification_sota/blob/6cb144105fc5c2f778e51cc66e35314938f96fae/lib/models/losses/kd_loss.py#L94
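If I read that line right, it registers the KD module as a submodule of the student model, so an optimizer built from the model's parameters picks up the DiffKD weights automatically. A minimal sketch of the same idea (the attribute name and learning rate are illustrative):

```python
import torch

# Attaching DiffKD to the student makes student.parameters() include it,
# so one optimizer covers both the student and the DiffKD module.
student.diffkd = DiffKD(16, 64, kernel_size=1, use_ae=True, ae_channels=32)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
```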
Sorry about this, I'll consider a better way to achieve it.
Understood. I will try to fix it. Thanks again!
Dear Hunto,
When I tried to save the trained student model, it seemed that the DiffKD module was saved along with it. Is there any way to avoid this? Or do I need to use the DiffKD module during inference as well?
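For example, would filtering the checkpoint like this be reasonable (a sketch; the `diffkd.` key prefix is illustrative and depends on how the module was attached)?

```python
import torch

# Drop DiffKD weights when saving: the diffusion module guides training
# only, so the plain student should suffice at inference time.
state = {k: v for k, v in student.state_dict().items()
         if not k.startswith("diffkd.")}
torch.save(state, "student.pth")
```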