The picture has to be pre-trained
Aaron198513 opened this issue · 1 comment
Aaron198513 commented
The model has to be pre-trained on images. How do I convert the parameters of the 2D convolutions into the parameters of the 3D convolutions?
Andy1621 commented
Please check
UniFormer/video_classification/slowfast/models/uniformer.py
Lines 387 to 421 in f92e423
def inflate_weight(self, weight_2d, time_dim, center=False):
    """Inflate a 2D weight tensor into a 3D one by adding a temporal axis.

    A new axis of length ``time_dim`` is inserted at position 2, so a 4-D
    input of shape ``(d0, d1, d2, d3)`` becomes ``(d0, d1, time_dim, d2, d3)``.
    (Presumably ``weight_2d`` is a Conv2d kernel laid out as
    ``(out_channels, in_channels, kH, kW)`` -- confirm against the caller.)

    Args:
        weight_2d: 4-D tensor holding the 2D weights.
        time_dim: Length of the temporal axis to create.
        center: If True, copy ``weight_2d`` only into the middle temporal
            slice (index ``time_dim // 2``) and leave every other slice
            zero. If False, replicate ``weight_2d`` along the temporal axis
            and divide by ``time_dim`` so that summing over that axis gives
            back ``weight_2d``.

    Returns:
        The inflated 5-D tensor.
    """
    if center:
        # zeros_like keeps the dtype and device of the source weights;
        # torch.zeros(*shape) would always produce a default-dtype tensor
        # on the default device.
        weight_3d = torch.zeros_like(weight_2d)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
    else:
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        # Scale so the temporal sum of the 3D kernel equals the 2D kernel.
        weight_3d = weight_3d / time_dim
    return weight_3d
def get_pretrained_model(self, cfg):
    """Load the configured pretrained checkpoint and adapt it to this model.

    The checkpoint file is looked up in ``model_path`` using
    ``cfg.UNIFORMER.PRETRAIN_NAME``. If the loaded object wraps its weights
    under a ``'model'`` or ``'model_state'`` key, that inner dict is used.
    Each entry whose shape differs from the matching entry of
    ``self.state_dict()`` is then handled as follows:

    * if the model's tensor has at most 2 dimensions, the entry is logged
      as ignored and left unchanged;
    * otherwise it is inflated with ``self.inflate_weight``, taking the
      temporal length from dimension 2 of the model's tensor.

    Checkpoint entries that have no counterpart in the model are logged and
    left untouched instead of raising ``KeyError``. When the checkpoint has
    a ``'head.weight'`` whose first dimension differs from
    ``self.num_classes``, the head weight (and bias, if present) is removed
    from the returned dict.

    Args:
        cfg: Config object exposing ``cfg.UNIFORMER.PRETRAIN_NAME``.

    Returns:
        The adapted state dict, or ``None`` when no pretrain name is set.
    """
    pretrain_name = cfg.UNIFORMER.PRETRAIN_NAME
    if not pretrain_name:
        return None

    # NOTE(review): torch.load unpickles the file, so only point
    # model_path at checkpoints from a trusted source.
    checkpoint = torch.load(model_path[pretrain_name], map_location='cpu')
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    elif 'model_state' in checkpoint:
        checkpoint = checkpoint['model_state']

    state_dict_3d = self.state_dict()
    # Snapshot the keys: entries are reassigned while looping.
    for k in list(checkpoint.keys()):
        if k not in state_dict_3d:
            # The checkpoint carries a parameter this model does not have.
            logger.info(f'Skip (not in model): {k}')
            continue
        target_shape = state_dict_3d[k].shape
        if checkpoint[k].shape == target_shape:
            continue
        if len(target_shape) <= 2:
            logger.info(f'Ignore: {k}')
            continue
        logger.info(f'Inflate: {k}, {checkpoint[k].shape} => {target_shape}')
        time_dim = target_shape[2]
        checkpoint[k] = self.inflate_weight(checkpoint[k], time_dim)

    # Drop the classification head when its class count differs; tolerate
    # checkpoints that ship without a head or without a head bias.
    head_weight = checkpoint.get('head.weight')
    if head_weight is not None and self.num_classes != head_weight.shape[0]:
        del checkpoint['head.weight']
        checkpoint.pop('head.bias', None)
    return checkpoint