How can we do transfer learning?
rengotj opened this issue · 1 comment
rengotj commented
Hello,
Thank you very much for this really interesting project.
I am wondering about transfer learning: is it possible to specify in the model configuration which previously trained model the new training should start from?
Thank you for your help.
Best regards.
LeDat98 commented
Add a configuration parameter for the pre-trained weights in `config/config.yaml`:
```yaml
pretrained_model_path: "path/to/pretrained/model.h5"
```
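Note that the loading code below expects the checkpoint to be a dictionary holding the generator weights under a `'model'` key (which, as far as I can tell, matches how this repo's trainer saves its `.h5` checkpoints). If your weights were saved as a bare state dict, a minimal re-save along these lines makes them compatible (paths are placeholders):

```python
import torch

# Hypothetical example: wrap a bare state dict under the 'model' key
# that the loading code below looks for. Paths are placeholders.
state_dict = torch.load('path/to/raw_weights.pth')
torch.save({'model': state_dict}, 'path/to/pretrained/model.h5')
```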
Then change the `_init_params` function in `train.py` to the following (make sure `os` and `torch` are imported at the top of the file):
```python
def _init_params(self):
    self.criterionG, criterionD = get_loss(self.config['model'])
    self.netG, netD = get_nets(self.config['model'])
    self.netG.cuda()  # Move the generator to GPU
    # Wrap in DataParallel when several GPUs are available (recommended)
    if torch.cuda.device_count() > 1:
        self.netG = torch.nn.DataParallel(self.netG)
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
    # Load pre-trained weights if a path is configured
    pretrained_path = self.config.get('pretrained_model_path', None)
    if pretrained_path and os.path.isfile(pretrained_path):
        checkpoint = torch.load(pretrained_path)
        state_dict = checkpoint['model']
        # A checkpoint saved from a DataParallel model has keys prefixed with 'module.'
        if not isinstance(self.netG, torch.nn.DataParallel) and any(k.startswith('module.') for k in state_dict.keys()):
            # Strip the 'module.' prefix before loading into a plain module
            state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
        elif isinstance(self.netG, torch.nn.DataParallel) and not any(k.startswith('module.') for k in state_dict.keys()):
            # Add the prefix when loading a plain checkpoint into a DataParallel wrapper
            state_dict = {'module.' + k: v for k, v in state_dict.items()}
        self.netG.load_state_dict(state_dict)
        print(f"Loaded pre-trained model from {pretrained_path}")
    self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
    self.model = get_model(self.config['model'])
    self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
    self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
    self.scheduler_G = self._get_scheduler(self.optimizer_G)
    self.scheduler_D = self._get_scheduler(self.optimizer_D)
```
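Since `optimizer_G` is built with `filter(lambda p: p.requires_grad, ...)`, you can also freeze part of the pre-trained generator for a more classic transfer-learning setup; frozen parameters are then excluded from the optimizer automatically. A minimal sketch, to be placed right after the checkpoint is loaded — the `'enc'` substring is a hypothetical example, so adapt it to the actual layer names of your generator:

```python
# Optional: freeze early layers of the pre-trained generator so that only
# the remaining layers are fine-tuned. 'enc' is a hypothetical name filter;
# print(name) inside this loop to inspect your generator's real layer names.
for name, param in self.netG.named_parameters():
    if 'enc' in name:
        param.requires_grad = False
```

Because the optimizer is created after this point with the `requires_grad` filter, frozen parameters are skipped during fine-tuning without any further changes.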