How to run inference on a batch
rzamarefat opened this issue · 7 comments
Hi, thank you for this awesome implementation. Is it possible to run inference on a batch of images at the same time?
I think it is easy to run on batches with small modifications, if you are familiar with PyTorch.

> I think it is easy to run on batches with small modifications, if you are familiar with PyTorch.

But the model requires cropped images. Any suggestions for batched inference on aligned (but not cropped) images?
I wrote a batched version for a single box.
This is intended for videos containing a single person whose face stays inside the box the whole time.
Most of the functions can be rewritten in PyTorch to work with batches.

> I wrote a batched version for a single box. This is intended for videos containing a single person whose face stays inside the box the whole time. Most of the functions can be rewritten in PyTorch to work with batches.

Sounds cool! Can you share these functions with us?
"""
Source code from cleardusk
Optimized by Juan F. Montesinos to work with batches in GPU
"""
import os.path as osp
import torch
from torch import nn
from torchvision.transforms import Compose
import models
from bfm import BFMModel
from utils.io import _load
from utils.functions import (
crop_video, reshape_fortran, parse_roi_box_from_bbox,
)
from utils.tddfa_util import (
load_model, _batched_parse_param, batched_similar_transform,
ToTensorGjz, NormalizeGjz
)
def make_abs_path(fn):
    """Return ``fn`` joined onto the directory that contains this source file.

    The file's own path is passed through ``realpath`` first, so relative config
    paths resolve next to the real module even when it is imported via a symlink.
    If ``fn`` is already absolute, ``os.path.join`` returns it unchanged.
    """
    return osp.join(osp.dirname(osp.realpath(__file__)), fn)
class TDDFA(nn.Module):
    """TDDFA: named Three-D Dense Face Alignment (TDDFA).

    Batched variant: a single ROI box is shared by every frame of the input batch,
    and the computation is done with torch ops so the module can be moved to a
    device with ``.to(device)`` / ``.cuda()``.
    """

    def __init__(self, **kvs):
        """Build the BFM, the regression network and the parameter normalization stats.

        Keyword options read here (defaults in parentheses): ``size`` (120),
        ``bfm_fp``, ``shape_dim`` (40), ``exp_dim`` (10), ``param_mean_std_fp``,
        ``arch`` (no default; looked up as an attribute of ``models``),
        ``num_params`` (62), ``widen_factor`` (1), ``mode`` ('small'),
        ``checkpoint_fp``.
        """
        # nn.Module.__init__ must run before any submodule/buffer is assigned;
        # without it ``self.model = model`` below raises AttributeError.
        super().__init__()
        self.size = kvs.get('size', 120)

        # load BFM
        self.bfm = BFMModel(
            bfm_fp=kvs.get('bfm_fp', make_abs_path('configs/bfm_noneck_v3.pkl')),
            shape_dim=kvs.get('shape_dim', 40),
            exp_dim=kvs.get('exp_dim', 10),
        )
        self.tri = self.bfm.tri

        param_mean_std_fp = kvs.get(
            'param_mean_std_fp',
            make_abs_path(f'configs/param_mean_std_62d_{self.size}x{self.size}.pkl'),
        )

        # load model, default output is dimension with length 62 = 12(pose) + 40(shape) + 10(expression)
        model = getattr(models, kvs.get('arch'))(
            num_classes=kvs.get('num_params', 62),
            widen_factor=kvs.get('widen_factor', 1),
            size=self.size,
            mode=kvs.get('mode', 'small'),
        )
        self.model = load_model(model, kvs.get('checkpoint_fp'))

        # data normalization
        self.transform_normalize = NormalizeGjz(mean=127.5, std=128)
        self.transform = Compose([ToTensorGjz(), self.transform_normalize])

        # params normalization config. Registered as buffers (not plain attributes)
        # so .to(device)/.cuda() moves them together with the network; otherwise the
        # re-scale in batched_inference mixes devices. persistent=False keeps them
        # out of state_dict, matching the previous attribute-only behavior.
        r = _load(param_mean_std_fp)
        self.register_buffer('param_mean', torch.from_numpy(r.get('mean')), persistent=False)
        self.register_buffer('param_std', torch.from_numpy(r.get('std')), persistent=False)
        # NOTE(review): self.bfm only follows the module's device if BFMModel is
        # itself an nn.Module holding its bases as buffers — confirm.

    def batched_inference(self, video_ori, bbox, **kvs):
        """Regress de-normalized 3DMM params for every frame, using one shared box.

        :param video_ori: batch of frames; assumed (bs, c, h, w) as unpacked by
            ``crop_video``, and floating point (bilinear ``interpolate`` needs it)
        :param bbox: left, top, right, bottom = bbox (think in lines like y=25,
            not points); the same box is applied to every frame
        :param kvs: options (currently unused)
        :return: ``(param, roi_box)`` — the network output re-scaled with
            ``param_std``/``param_mean`` (presumably (bs, num_params)), and the ROI
            box produced by ``parse_roi_box_from_bbox`` that was used for cropping
        """
        roi_box = parse_roi_box_from_bbox(bbox)
        video = crop_video(video_ori, roi_box)
        # Resize the crop to the network's fixed input resolution.
        img = torch.nn.functional.interpolate(
            video, size=(self.size, self.size), mode='bilinear', align_corners=False
        )
        inp = self.transform_normalize(img)
        param = self.model(inp)
        param = param * self.param_std + self.param_mean  # re-scale
        return param, roi_box

    def batched_recon_vers(self, param, roi_box, **kvs):
        """Reconstruct 3D vertices for a batch of 3DMM params, in image coordinates.

        :param param: parameters as returned by ``batched_inference``
        :param roi_box: the ROI box returned by ``batched_inference``
        :param kvs: ``dense_flag`` (default False) selects the ``u``/``w_shp``/
            ``w_exp`` basis; otherwise the ``*_base`` basis is used (presumably
            the sparse landmark subset — confirm against BFMModel)
        :return: vertices of shape (bs, 3, N), mapped back from the ``size`` x
            ``size`` crop into the frame via ``batched_similar_transform``
        """
        dense_flag = kvs.get('dense_flag', False)
        R, offset, alpha_shp, alpha_exp = _batched_parse_param(param)
        if dense_flag:
            u, w_shp, w_exp = self.bfm.u, self.bfm.w_shp, self.bfm.w_exp
        else:
            u, w_shp, w_exp = self.bfm.u_base, self.bfm.w_shp_base, self.bfm.w_exp_base
        tensor = u + w_shp @ alpha_shp + w_exp @ alpha_exp
        # The Fortran-order reshape sends flat index 3*j + i to row i, column j,
        # i.e. it reads the flat basis as interleaved (x, y, z) per vertex.
        pts3d = R @ reshape_fortran(tensor, (param.shape[0], 3, -1)) + offset
        return batched_similar_transform(pts3d, roi_box, self.size)
So it's basically a matter of getting rid of the framework the author proposes.
I don't find it very well designed, since PyTorch already lets you switch between CPU and GPU with a tiny change such as `.to(device)`.
Besides, a few helper functions are needed, such as a Fortran-order reshape, which PyTorch does not provide:
def reshape_fortran(x, shape):
    """Reshape ``x`` to ``shape`` using Fortran (column-major) element order.

    ``Tensor.reshape`` only works in C (row-major) order. Reversing every axis
    turns a column-major layout into a row-major one, so the input's axes are
    reversed, it is reshaped to the reversed target shape, and the result's axes
    are reversed back. This mirrors ``numpy.reshape(..., order='F')``; ``shape``
    may contain a single ``-1``.

    :param x: input tensor (a 0-d tensor is reshaped without being permuted)
    :param shape: target shape as a sequence of ints
    :return: tensor of shape ``shape`` (typically non-contiguous)
    """
    in_rank = len(x.shape)
    if in_rank > 0:
        x = x.permute(*range(in_rank - 1, -1, -1))
    flipped = x.reshape(*tuple(reversed(shape)))
    out_rank = len(shape)
    return flipped.permute(*range(out_rank - 1, -1, -1))
def crop_video(video, roi_box):
    """Crop the same box out of every frame of a batch, zero-padding outside the frame.

    :param video: tensor unpacked as (bs, c, h, w)
    :param roi_box: (sx, sy, ex, ey) box corners; each value is rounded to the
        nearest int with Python's ``round``
    :return: tensor of shape (bs, c, ey - sy, ex - sx) with ``video``'s dtype and
        device. Parts of the box that fall outside the frame are 0; a box that does
        not overlap the frame at all yields an all-zero tensor.
    """
    bs, c, h, w = video.shape
    sx, sy, ex, ey = [int(round(v)) for v in roi_box]
    res = torch.zeros(bs, c, ey - sy, ex - sx, dtype=video.dtype, device=video.device)

    # Intersection of the box with the frame, in frame coordinates.
    x0, x1 = max(sx, 0), min(ex, w)
    y0, y1 = max(sy, 0), min(ey, h)
    if x1 <= x0 or y1 <= y0:
        # No overlap: slicing with these bounds would give empty/negative-index
        # views of mismatched size, so just return the zero padding.
        return res

    # The same region in output coordinates is offset by the box origin.
    res[..., y0 - sy:y1 - sy, x0 - sx:x1 - sx] = video[..., y0:y1, x0:x1]
    return res
def batched_similar_transform(pts3d, roi_box, size):
    """Map vertices from the ``size`` x ``size`` crop frame back into image coordinates.

    Unlike the previous version, the input tensor is NOT modified in place; a new
    tensor is returned.

    :param pts3d: tensor of shape (bs, 3, N); rows are x, y, z
    :param roi_box: (sx, sy, ex, ey) box the crop was taken from
    :param size: side length of the square network input
    :return: new contiguous tensor of shape (bs, 3, N):
        x -> (x - 1) * scale_x + sx, y -> (size - y) * scale_y + sy (vertical flip),
        z -> (z - 1) * mean(scale_x, scale_y), then shifted so each sample's min z is 0
    """
    sx, sy, ex, ey = roi_box
    scale_x = (ex - sx) / size
    scale_y = (ey - sy) / size
    s = (scale_x + scale_y) / 2

    # "- 1" kept from the original ("for Python compatibility"); presumably a
    # 1-based -> 0-based index shift inherited from the reference implementation.
    x = (pts3d[:, 0, :] - 1) * scale_x + sx
    y = (size - pts3d[:, 1, :]) * scale_y + sy
    z = (pts3d[:, 2, :] - 1) * s
    z = z - torch.min(z, dim=-1)[0].unsqueeze(-1)
    return torch.stack((x, y, z), dim=1)
def _batched_parse_param(param):
"""matrix pose form
param: shape=(trans_dim+shape_dim+exp_dim,), i.e., 62 = 12 + 40 + 10
"""
assert param.ndim == 2
bs, n = param.shape
if n == 62:
trans_dim, shape_dim, exp_dim = 12, 40, 10
elif n == 72:
trans_dim, shape_dim, exp_dim = 12, 40, 20
elif n == 141:
trans_dim, shape_dim, exp_dim = 12, 100, 29
else:
raise Exception(f'Undefined templated param parsing rule')
R_ = param[:, :trans_dim].reshape(bs, 3, -1)
R = R_[..., :3]
offset = R_[..., -1].reshape(bs, 3, 1)
alpha_shp = param[:, trans_dim:trans_dim + shape_dim].reshape(bs, -1, 1)
alpha_exp = param[:, trans_dim + shape_dim:].reshape(bs, -1, 1)
return R, offset, alpha_shp, alpha_exp
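Lastly, rewrite some items as `nn.Module`s so that automatic device placement works. (Note: the snippet below appears to repeat `_batched_parse_param` from above rather than show those `nn.Module` rewrites.)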
def _batched_parse_param(param):
"""matrix pose form
param: shape=(trans_dim+shape_dim+exp_dim,), i.e., 62 = 12 + 40 + 10
"""
assert param.ndim == 2
bs, n = param.shape
if n == 62:
trans_dim, shape_dim, exp_dim = 12, 40, 10
elif n == 72:
trans_dim, shape_dim, exp_dim = 12, 40, 20
elif n == 141:
trans_dim, shape_dim, exp_dim = 12, 100, 29
else:
raise Exception(f'Undefined templated param parsing rule')
R_ = param[:, :trans_dim].reshape(bs, 3, -1)
R = R_[..., :3]
offset = R_[..., -1].reshape(bs, 3, 1)
alpha_shp = param[:, trans_dim:trans_dim + shape_dim].reshape(bs, -1, 1)
alpha_exp = param[:, trans_dim + shape_dim:].reshape(bs, -1, 1)
return R, offset, alpha_shp, alpha_exp
And that's more or less all.
I'll eventually upload ready-to-use code, but I have other priorities at the moment.