RetinaFace(net=mobilenet0.25):在 ONNX 导出时加入 boxes 解码和 landmark 解码,使其直接作为 ONNX 输出;boxes 分支可以正常输出,但 landmark 分支在 ONNX 导出时报错。
EricHuiK opened this issue · 0 comments
import torch
import torch.nn as nn
import torchvision.models.detection.backbone_utils as backbone_utils
import torchvision.models._utils as _utils
import torch.nn.functional as F
from collections import OrderedDict
from layers.functions.prior_box import PriorBox
from models.net import MobileNetV1 as MobileNetV1
from models.net import FPN as FPN
from models.net import SSH as SSH
import numpy as np
def decode_fixed(loc, priors, variances=(0.1, 0.2)):
    """Decode box regressions against priors, keeping center-size form.

    Undoes the offset encoding applied to box targets at train time. The
    result stays in ``(cx, cy, w, h)`` form; converting to corner form
    ``(xmin, ymin, xmax, ymax)`` is left to the caller.

    Indexing is done on the last dimension only, so both un-batched
    ``[num_priors, 4]`` and batched ``[batch, num_priors, 4]`` tensors are
    accepted, as long as ``loc`` and ``priors`` broadcast against each other.

    Args:
        loc (tensor): location predictions, last dimension of size 4.
        priors (tensor): prior boxes in center-offset form, last dimension
            of size 4.
        variances (sequence[float]): ``variances[0]`` scales the center
            offsets and ``variances[1]`` scales the log-size offsets.
            Defaults to ``(0.1, 0.2)``.

    Returns:
        tensor: decoded boxes with the broadcast shape of the inputs and a
        last dimension of ``(cx, cy, w, h)``.
    """
    center_var = variances[0]
    size_var = variances[1]
    # Each term is a width-1 slice so that the concatenation along the last
    # axis rebuilds the 4 box components in order.
    boxes = torch.cat((
        priors[..., 0:1] + loc[..., 0:1] * center_var * priors[..., 2:3],
        priors[..., 1:2] + loc[..., 1:2] * center_var * priors[..., 3:4],
        priors[..., 2:3] * torch.exp(loc[..., 2:3] * size_var),
        priors[..., 3:4] * torch.exp(loc[..., 3:4] * size_var),
    ), dim=-1)
    return boxes
def decode_landm(pre, priors, variances):
    """Decode landmark regressions against priors.

    Undoes the offset encoding applied to landmark targets at train time.
    Each of the five ``(x, y)`` landmark pairs is decoded as
    ``prior_center + offset * variances[0] * prior_size``.

    Indexing is done on the last dimension only, so both un-batched
    ``[num_priors, 10]`` / ``[num_priors, 4]`` and batched
    ``[batch, num_priors, 10]`` / ``[batch, num_priors, 4]`` tensors are
    accepted, as long as ``pre`` and ``priors`` broadcast against each other.

    Args:
        pre (tensor): landmark predictions, last dimension of size 10.
        priors (tensor): prior boxes in center-offset form, last dimension
            of size 4.
        variances (list[float]): variances of the prior boxes; only
            ``variances[0]`` is used here.

    Returns:
        tensor: decoded landmarks, last dimension of size 10 laid out as
        ``(x1, y1, ..., x5, y5)``.
    """
    centers = priors[..., :2]
    sizes = priors[..., 2:]
    # One (x, y) pair per landmark; the operand order matches the original
    # expression so results are bit-identical for 2-D inputs.
    points = [
        centers + pre[..., start:start + 2] * variances[0] * sizes
        for start in range(0, 10, 2)
    ]
    return torch.cat(points, dim=-1)
class ClassHead(nn.Module):
    """1x1 convolution head producing 2 class scores per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        """
        :param inchannels: number of channels of the incoming feature map.
        :param num_anchors: anchors predicted at each spatial location.
        """
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        """Map a ``[B, C, H, W]`` feature map to ``[B, H*W*num_anchors, 2]``."""
        out = self.conv1x1(x)
        # Move channels last so that the view below groups the 2 scores of
        # each anchor together.
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 2)
class BboxHead(nn.Module):
    """1x1 convolution head producing 4 box regression values per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        """
        :param inchannels: number of channels of the incoming feature map.
        :param num_anchors: anchors predicted at each spatial location.
        """
        super(BboxHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        """Map a ``[B, C, H, W]`` feature map to ``[B, H*W*num_anchors, 4]``."""
        out = self.conv1x1(x)
        # Move channels last so that the view below groups the 4 values of
        # each anchor together.
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 4)
class LandmarkHead(nn.Module):
    """1x1 convolution head producing 10 landmark values per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        """
        :param inchannels: number of channels of the incoming feature map.
        :param num_anchors: anchors predicted at each spatial location.
        """
        super(LandmarkHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        """Map a ``[B, C, H, W]`` feature map to ``[B, H*W*num_anchors, 10]``."""
        out = self.conv1x1(x)
        # Move channels last so that the view below groups the 10 values of
        # each anchor together.
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 10)
class RetinaFace(nn.Module):
    """RetinaFace detector whose forward pass also decodes boxes and landmarks.

    The network is backbone -> FPN -> three SSH modules -> per-level class,
    box and landmark heads. ``forward`` concatenates the per-level
    predictions and decodes them against prior boxes generated once in
    ``__init__`` for a fixed input size, so the decoded tensors can be
    produced directly by the exported graph.
    """

    def __init__(self, cfg=None, phase='train', image_size=(320, 320)):
        """
        :param cfg: Network related settings (required). Keys read here:
            'name', 'pretrain', 'return_layers', 'in_channel', 'out_channel';
            the whole dict is also handed to PriorBox.
        :param phase: train or test (stored and forwarded to PriorBox).
        :param image_size: (height, width) the priors are generated for.
            Inputs to ``forward`` are expected to have this size.
        :raises ValueError: if ``cfg['name']`` is not a supported backbone.
        """
        super(RetinaFace, self).__init__()
        self.phase = phase
        self.image_size = tuple(image_size)

        if cfg['name'] == 'mobilenet0.25':
            backbone = MobileNetV1()
            if cfg['pretrain']:
                # NOTE: torch.load unpickles the file; only load trusted
                # checkpoints.
                checkpoint = torch.load(
                    "./weights/mobilenetV1X0.25_pretrain.tar",
                    map_location=torch.device('cpu'))
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    # Drop the "module." prefix, but only when it is
                    # actually present.
                    name = k[7:] if k.startswith('module.') else k
                    new_state_dict[name] = v
                # load params
                backbone.load_state_dict(new_state_dict)
        elif cfg['name'] == 'Resnet50':
            import torchvision.models as models
            backbone = models.resnet50(pretrained=cfg['pretrain'])
        else:
            raise ValueError("Unsupported backbone name: %r" % (cfg['name'],))

        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
        in_channels_stage2 = cfg['in_channel']
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

        # Priors are generated once for the fixed input size.
        # NOTE(review): assumes PriorBox.forward() returns a tensor that can
        # be reshaped to [1, total_anchors, 4] - confirm against PriorBox.
        self.priorbox = PriorBox(cfg, image_size=self.image_size, phase=self.phase)
        self.priors = self.priorbox.forward()

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one ClassHead per FPN level."""
        return nn.ModuleList(
            [ClassHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one BboxHead per FPN level."""
        return nn.ModuleList(
            [BboxHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one LandmarkHead per FPN level."""
        return nn.ModuleList(
            [LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def forward(self, inputs):
        """Run detection and decode the predictions against the priors.

        :param inputs: image batch; expected to match ``image_size``. The
            landmark branch squeezes the batch dimension, so a batch size of
            1 is expected.
        :return: tuple ``(boxes, conf, landm_mark)`` where ``boxes`` is
            ``[B, total_anchors, 4]`` in center-size form, ``conf`` is the
            softmax of the class scores ``[B, total_anchors, 2]`` and
            ``landm_mark`` is ``[total_anchors, 10]``.
        """
        from math import ceil

        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH
        features = [self.ssh1(fpn[0]), self.ssh2(fpn[1]), self.ssh3(fpn[2])]

        loc = [head(x) for x, head in zip(features, self.BboxHead)]
        conf = [head(x) for x, head in zip(features, self.ClassHead)]
        landm = [head(x) for x, head in zip(features, self.LandmarkHead)]

        # Anchor count per pyramid level: 2 anchors per cell at strides
        # 8, 16 and 32. Computed in Python so the views below use static
        # sizes.
        height, width = self.image_size
        feature_anchors = [
            2 * ceil(height / stride) * ceil(width / stride)
            for stride in (8, 16, 32)
        ]

        bbox_regressions = torch.cat(
            [o.view(-1, n, 4) for n, o in zip(feature_anchors, loc)], 1)
        classifications = torch.cat(
            [o.view(-1, n, 2) for n, o in zip(feature_anchors, conf)], 1)
        ldm_regressions = torch.cat(
            [o.view(-1, n, 10) for n, o in zip(feature_anchors, landm)], 1)

        anchor_num = int(sum(feature_anchors))
        conf = F.softmax(classifications, dim=-1)

        # Use a local instead of overwriting self.priors on every call, and
        # make sure the priors live on the same device as the predictions.
        priors = self.priors.reshape([-1, anchor_num, 4]).to(bbox_regressions.device)

        boxes = decode_fixed(bbox_regressions, priors)
        # Landmarks are decoded against the priors (not the raw box
        # regressions); the batch dimension is squeezed on both operands.
        landm_mark = decode_landm(
            ldm_regressions.squeeze(0), priors.squeeze(0), [0.1, 0.2])

        return (boxes, conf, landm_mark)