Sense-X/UniFormer

The picture has to be pre-trained

Aaron198513 opened this issue · 1 comment

The model has to be pre-trained on images. How do I convert the parameters of the 2D convolutions into parameters for the 3D convolutions?

Please check the following code:

def inflate_weight(self, weight_2d, time_dim, center=False):
    """Inflate a 2D convolution weight into a 3D convolution weight.

    A new temporal axis of length ``time_dim`` is inserted at dimension 2,
    so a 4D input of shape ``(d0, d1, d2, d3)`` becomes
    ``(d0, d1, time_dim, d2, d3)``.

    Args:
        weight_2d: 4D weight tensor to inflate.
        time_dim: Length of the temporal axis to create.
        center: If True, copy ``weight_2d`` into the middle temporal slice
            (index ``time_dim // 2``) and leave every other slice at zero.
            If False, replicate ``weight_2d`` along the temporal axis and
            divide by ``time_dim`` so that summing over that axis gives
            back ``weight_2d``.

    Returns:
        The inflated 5D tensor, with the same dtype and device as
        ``weight_2d``.
    """
    if center:
        # zeros_like keeps the dtype/device of the source weight; a plain
        # torch.zeros(...) would silently fall back to the default dtype on
        # the CPU regardless of what was passed in.
        weight_3d = torch.zeros_like(weight_2d).unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        weight_3d[:, :, time_dim // 2, :, :] = weight_2d
    else:
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        # Rescale so the temporal sum of the inflated kernel equals the 2D kernel.
        weight_3d = weight_3d / time_dim
    return weight_3d
def get_pretrained_model(self, cfg):
    """Load a pretrained checkpoint and adapt it to this model's parameter shapes.

    The checkpoint path is looked up in ``model_path`` under
    ``cfg.UNIFORMER.PRETRAIN_NAME``. If the loaded object contains a
    ``'model'`` or ``'model_state'`` entry, that entry is used as the state
    dict. Every entry whose shape differs from the matching entry of
    ``self.state_dict()`` is then handled as follows:

    * target tensors with at most 2 dimensions are left untouched
      (logged as ``Ignore``);
    * all others are passed through ``self.inflate_weight`` with the target's
      size along dimension 2 as ``time_dim`` (logged as ``Inflate``).

    Finally ``head.weight`` / ``head.bias`` are dropped when the checkpoint's
    head size does not match ``self.num_classes``.

    Args:
        cfg: Config object exposing ``cfg.UNIFORMER.PRETRAIN_NAME``.

    Returns:
        The adapted state dict, or ``None`` when ``PRETRAIN_NAME`` is falsy.
    """
    if not cfg.UNIFORMER.PRETRAIN_NAME:
        return None

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(model_path[cfg.UNIFORMER.PRETRAIN_NAME], map_location='cpu')
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    elif 'model_state' in checkpoint:
        checkpoint = checkpoint['model_state']

    state_dict_3d = self.state_dict()
    # Only existing keys are reassigned below, so the dict size never changes
    # while it is being iterated.
    for k in checkpoint:
        if k not in state_dict_3d:
            # The checkpoint holds an entry this model does not have; leave it
            # in place instead of failing with KeyError on the shape lookup.
            logger.info(f'Skip (not in model): {k}')
            continue
        target_shape = state_dict_3d[k].shape
        if checkpoint[k].shape == target_shape:
            continue
        if len(target_shape) <= 2:
            logger.info(f'Ignore: {k}')
            continue
        logger.info(f'Inflate: {k}, {checkpoint[k].shape} => {target_shape}')
        checkpoint[k] = self.inflate_weight(checkpoint[k], target_shape[2])

    # Drop the classifier head when its size does not match this model.
    # Guarded lookups tolerate checkpoints saved without a head or head bias.
    head_weight = checkpoint.get('head.weight')
    if head_weight is not None and self.num_classes != head_weight.shape[0]:
        checkpoint.pop('head.weight', None)
        checkpoint.pop('head.bias', None)
    return checkpoint