cleardusk/3DDFA_V2

How to infer on batch

rzamarefat opened this issue · 7 comments

Hi, thank you for this awesome implementation. Is it possible to infer on a batch of images at the same time?

I think it is easy to run on batches with small modifications, if you are familiar with PyTorch.

But the model requires cropped images. Any suggestions on batch inference for aligned (but not cropped) images?

I did a batched version for a single box.
This is intended to be used in videos where there is a single person whose face remains inside the box the whole time.
Most of the funcs can be rewritten to work with batches and in PyTorch.

Sounds cool! Can you share these funcs with us?

"""
Source code from cleardusk
Optimized by Juan F. Montesinos to work with batches in GPU
"""
import os.path as osp
import torch
from torch import nn
from torchvision.transforms import Compose

import models
from bfm import BFMModel
from utils.io import _load
from utils.functions import (
    crop_video, reshape_fortran, parse_roi_box_from_bbox,
)
from utils.tddfa_util import (
    load_model, _batched_parse_param, batched_similar_transform,
    ToTensorGjz, NormalizeGjz
)

make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)


class TDDFA(nn.Module):
    """TDDFA: Three-D Dense Face Alignment."""

    def __init__(self, **kvs):
        super().__init__()
        self.size = kvs.get('size', 120)

        # load BFM
        self.bfm = BFMModel(
            bfm_fp=kvs.get('bfm_fp', make_abs_path('configs/bfm_noneck_v3.pkl')),
            shape_dim=kvs.get('shape_dim', 40),
            exp_dim=kvs.get('exp_dim', 10)
        )
        self.tri = self.bfm.tri

        param_mean_std_fp = kvs.get(
            'param_mean_std_fp', make_abs_path(f'configs/param_mean_std_62d_{self.size}x{self.size}.pkl')
        )

        # load model; the default output is a 62-dim vector: 12 (pose) + 40 (shape) + 10 (expression)
        model = getattr(models, kvs.get('arch'))(
            num_classes=kvs.get('num_params', 62),
            widen_factor=kvs.get('widen_factor', 1),
            size=self.size,
            mode=kvs.get('mode', 'small')
        )
        model = load_model(model, kvs.get('checkpoint_fp'))

        self.model = model

        # data normalization
        self.transform_normalize = NormalizeGjz(mean=127.5, std=128)
        transform_to_tensor = ToTensorGjz()
        transform = Compose([transform_to_tensor, self.transform_normalize])
        self.transform = transform

        # params normalization config
        r = _load(param_mean_std_fp)
        self.param_mean = torch.from_numpy(r.get('mean'))
        self.param_std = torch.from_numpy(r.get('std'))

    def batched_inference(self, video_ori, bbox, **kvs):
        """The main call of TDDFA: given a batch of frames and a face box, return the 3DMM params and roi_box.
        :param video_ori: the input frames, a (batch, channels, height, width) tensor
        :param bbox: left, top, right, bottom (think in lines like y=25, not points)
        :param kvs: options
        :return: batched param tensor and roi_box list
        """
        roi_box = parse_roi_box_from_bbox(bbox)
        video = crop_video(video_ori, roi_box)
        img = torch.nn.functional.interpolate(video, size=(self.size, self.size), mode='bilinear', align_corners=False)

        inp = self.transform_normalize(img)
        param = self.model(inp)

        param = param * self.param_std + self.param_mean  # re-scale

        return param, roi_box

    def batched_recon_vers(self, param, roi_box, **kvs):
        dense_flag = kvs.get('dense_flag', False)
        size = self.size
        R, offset, alpha_shp, alpha_exp = _batched_parse_param(param)
        if dense_flag:
            tensor = self.bfm.u + self.bfm.w_shp @ alpha_shp + self.bfm.w_exp @ alpha_exp
        else:
            tensor = self.bfm.u_base + self.bfm.w_shp_base @ alpha_shp + self.bfm.w_exp_base @ alpha_exp
        pts3d = R @ reshape_fortran(tensor, (param.shape[0], 3, -1)) + offset
        pts3d = batched_similar_transform(pts3d, roi_box, size)

        return pts3d
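
To show how this is meant to be driven, here is a rough usage sketch. The arch name, checkpoint path, bbox, and tensor shapes are placeholders, and it assumes the BFM constants and param_mean/param_std have been made device-aware (see the nn.Module note further down), so take it as an illustration rather than verified code:

import torch

# Placeholders throughout: arch name, checkpoint path, frame count, resolution and bbox.
tddfa = TDDFA(
    arch='mobilenet',
    checkpoint_fp='weights/mb1_120x120.pth',
).eval().to('cuda')

# One face box shared by all frames: left, top, right, bottom.
video = torch.rand(32, 3, 480, 640, device='cuda') * 255  # 32 frames with values in [0, 255]
bbox = [180, 120, 420, 360]

with torch.no_grad():
    param, roi_box = tddfa.batched_inference(video, bbox)               # param: (32, 62)
    pts3d = tddfa.batched_recon_vers(param, roi_box, dense_flag=False)  # (32, 3, n_landmarks)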

So it's basically a matter of getting rid of the framework the author proposes.
I find it's not very well designed, since PyTorch already lets you work on CPU or GPU just by writing a couple of words.
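
For reference (nothing repository-specific here), that is just:

import torch
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(8, 4).to(device)     # any nn.Module: parameters and buffers move together
frames = torch.rand(32, 8).to(device)  # inputs are moved to the same device
out = model(frames)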

Besides, a few helper functions are needed, such as a Fortran-order reshape, which is not implemented in PyTorch:

def reshape_fortran(x, shape):
    """Reshape a tensor in Fortran (column-major) order, like numpy's reshape(..., order='F')."""
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))


def crop_video(video, roi_box):
    """Crop every frame of a (bs, c, h, w) video to roi_box, zero-padding where the box leaves the frame."""
    bs, c, h, w = video.shape

    sx, sy, ex, ey = [int(round(_)) for _ in roi_box]
    dh, dw = ey - sy, ex - sx
    res = torch.zeros(bs, c, dh, dw, dtype=video.dtype, device=video.device)

    if sx < 0:
        sx, dsx = 0, -sx
    else:
        dsx = 0

    if ex > w:
        ex, dex = w, dw - (ex - w)
    else:
        dex = dw

    if sy < 0:
        sy, dsy = 0, -sy
    else:
        dsy = 0

    if ey > h:
        ey, dey = h, dh - (ey - h)
    else:
        dey = dh

    res[..., dsy:dey, dsx:dex] = video[..., sy:ey, sx:ex]
    return res


def batched_similar_transform(pts3d, roi_box, size):
    """Map landmarks from the (size x size) crop back to original image coordinates."""
    pts3d[:, 0, :] -= 1  # for Python compatibility
    pts3d[:, 2, :] -= 1
    pts3d[:, 1, :] = size - pts3d[:, 1, :]

    sx, sy, ex, ey = roi_box
    scale_x = (ex - sx) / size
    scale_y = (ey - sy) / size
    pts3d[:, 0, :] = pts3d[:, 0, :] * scale_x + sx
    pts3d[:, 1, :] = pts3d[:, 1, :] * scale_y + sy
    s = (scale_x + scale_y) / 2
    pts3d[:, 2, :] *= s
    pts3d[:, 2, :] -= torch.min(pts3d[:, 2, :], dim=-1)[0].unsqueeze(-1)
    return pts3d.contiguous()


def _batched_parse_param(param):
    """Parse a batch of 3DMM parameter vectors into pose matrix, offset and shape/expression coefficients.
    param: shape=(bs, trans_dim + shape_dim + exp_dim), e.g. 62 = 12 + 40 + 10
    """
    assert param.ndim == 2
    bs, n = param.shape
    if n == 62:
        trans_dim, shape_dim, exp_dim = 12, 40, 10
    elif n == 72:
        trans_dim, shape_dim, exp_dim = 12, 40, 20
    elif n == 141:
        trans_dim, shape_dim, exp_dim = 12, 100, 29
    else:
        raise Exception(f'Undefined templated param parsing rule for length {n}')

    R_ = param[:, :trans_dim].reshape(bs, 3, -1)
    R = R_[..., :3]
    offset = R_[..., -1].reshape(bs, 3, 1)
    alpha_shp = param[:, trans_dim:trans_dim + shape_dim].reshape(bs, -1, 1)
    alpha_exp = param[:, trans_dim + shape_dim:].reshape(bs, -1, 1)

    return R, offset, alpha_shp, alpha_exp
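
As a quick sanity check (my own snippet, not part of the shared code), reshape_fortran agrees with numpy's Fortran-order reshape:

import numpy as np
import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
ours = reshape_fortran(x, (4, 3, 2))
ref = torch.tensor(x.numpy().reshape((4, 3, 2), order='F'))
assert torch.equal(ours, ref)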

Lastly, some items need to be rewritten as nn.Modules so that the automatic device allocation works.
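
I don't have that part cleaned up to share, but the idea is roughly the following. A minimal sketch, assuming the BFM arrays are available as numpy float arrays; the class name BatchedBFM and its constructor signature are mine, not the repo's:

import torch
from torch import nn


class BatchedBFM(nn.Module):
    """Sketch: hold the BFM constants as registered buffers so that
    .to(device) / .cuda() moves them together with the rest of the model."""

    def __init__(self, u, w_shp, w_exp, u_base, w_shp_base, w_exp_base):
        super().__init__()
        # Buffers are part of the module's state (moved by .to, saved in state_dict)
        # but are not trainable parameters.
        self.register_buffer('u', torch.from_numpy(u))
        self.register_buffer('w_shp', torch.from_numpy(w_shp))
        self.register_buffer('w_exp', torch.from_numpy(w_exp))
        self.register_buffer('u_base', torch.from_numpy(u_base))
        self.register_buffer('w_shp_base', torch.from_numpy(w_shp_base))
        self.register_buffer('w_exp_base', torch.from_numpy(w_exp_base))

The same trick applies to param_mean and param_std in TDDFA.__init__: registering them as buffers instead of plain tensor attributes means a single tddfa.to('cuda') moves everything the batched methods need.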

And that's more or less all.

I'll eventually upload the code ready to use, but I have other priorities at the moment.