jeonggg119/DL_paper

[CV_3D] MVSNet: Depth Inference for Unstructured Multi-view Stereo

jeonggg119 opened this issue

MVSNet: Depth Inference for Unstructured Multi-view Stereo

Paper Review

Abstract

  • MVSNet : E2E DL model for depth map inference from multi-view imgs
    • (1) Extract deep visual img features
    • (2) Build 3D Cost Volume upon reference camera frustum via differentiable homography warping
    • (3) Apply 3D conv to regularize and regress initial Depth Map → Refine with reference img
  • A variance-based metric maps multiple features into one cost feature → can handle an arbitrary number (N) of input views
  • Experiments
    • DTU dataset: outperforms SOTA & faster in runtime → benchmarking
    • T&T dataset: ranks first without fine-tuning → strong generalization

Introduction

  • Multi-View Stereo (MVS) : estimating dense representation from overlapping imgs
  • Traditional methods
    • How : using hand-crafted similarity metrics & engineered regularizations
    • Limitation : hand-crafted metrics cannot exploit global semantic information → dense matching is intractable in low-textured, specular, and reflective regions → incomplete reconstruction
  • Learnable CNN-based methods for 2-view stereo matching
    • Can exploit global semantic information (addresses the limitation above)
    • How : in 2-view stereo, img pairs are rectified in advance, so the problem becomes horizontal pixel-wise disparity estimation without bothering with camera params
    • Limitation : in MVS, input imgs can have arbitrary camera geometry → learning methods are hard to apply
  • Learnable CNN-based methods for MVS recon
    • Because of the limitation above, MVS and CNNs are a poor fit → rarely attempted
    • Ex. SurfaceNet using CVC (Colored Voxel Cubes), LSM (Learned Stereo Machine)
    • Limitation : both use a volumetric representation on regular grids → the huge memory consumption of 3D volumes makes it hard to scale up the network (takes a long time OR works only for synthetic objects at low volume resolution)
  • MVSNet
    • How : computes one depth map at a time (not the whole 3D scene at once)
    • Input : one reference img and several source imgs → to infer depth map for reference img
    • Key insight : Differentiable homography warping operation
      • to encode camera geometries implicitly to build 3D Cost Volumes from 2D img features
    • Next step : Multi-scale 3D conv
      • to regularize and regress initial Depth Map → Refine with reference img
    • Major differences
      • 3D Cost Volume is built upon camera frustum instead of regular Euclidean space
      • Decouples MVS recon into smaller problems of per-view depth map estimation → large-scale recon possible!

Related work

MVS Reconstruction

(Classification criterion : output representation)

  • Direct point cloud recon : operates directly on 3D points → relies on sequential propagation, so it is hard to fully parallelize and takes a long time
  • Volumetric recon : divides 3D space into a regular grid, then estimates whether each voxel adheres to the surface → space discretization error, high memory consumption
  • Depth map recon : decouples MVS into small per-view estimation problems, each focusing on only one reference img and a few source imgs + depth maps are easily fused into a PC or volumetric recon

Learned Stereo

DL models start to replace traditional stereo methods!

  • Pair-wise patch matching
    • DL network to match two img patches
    • Learned features for stereo matching and semi-global matching(SGM) for post-processing
  • Cost regularization
    • SGMNet, CNN-CRF, GCNet
    • GCNet (SOTA) : E2E model that regularizes the cost volume with a 3D CNN and regresses disparity

Learned MVS

Far fewer attempts than learned stereo ...

  • Multi-patch similarity (new metric for MVS)
    • SurfaceNet : builds the cost volume with sophisticated voxel-wise view selection → regularizes it with a 3D CNN and infers surface voxels
    • LSM : camera parameters are encoded as projection operations to build the cost volume → a 3D CNN classifies whether each voxel belongs to a surface
    • But both are limited to small-scale recon by the volumetric representation

MVSNet


(1) Image Feature Extraction

  • Goal : to extract deep features $\{F_i\}_{i=1}^N$ of the N input imgs $\{I_i\}_{i=1}^N$
  • 2D Network : 8-layer 2D CNN
    • each layer = Conv + BN + ReLU, except the last layer
    • layer 1,2 & 4,5 : extract higher-level representation
    • layer 3 & layer 6 : s=2 → divide feature towers into 3 scales (original input size, 1/2, 1/4)
  • Output : N 32-channel feature maps, downsized by 4 in each dim
    • The original neighboring information of each remaining pixel is already encoded in its 32-channel pixel descriptor → no risk of losing useful context information during dense matching
  • Ablation study : dense matching on the extracted feature maps gives much better recon quality than dense matching on the original imgs
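The "downsized by 4" output follows from the two stride-2 layers. A minimal shape-tracking sketch, assuming 'same' padding (an assumption here; the notes do not state the padding), applied to the 640x512 training crop:

```python
import math

# Strides of the 8 conv layers listed above (layers 3 and 6 use stride 2).
STRIDES = [1, 1, 2, 1, 1, 2, 1, 1]


def feature_map_shape(height, width, strides=STRIDES, out_channels=32):
    """(H, W, C) after a stack of 'same'-padded 2D convs with the given strides."""
    for s in strides:
        # With 'same' padding, a stride-s conv outputs ceil(size / s) positions.
        height = math.ceil(height / s)
        width = math.ceil(width / s)
    return height, width, out_channels


print(feature_map_shape(512, 640))  # (128, 160, 32)
```

Only the stride-2 layers change the spatial size, so any input is reduced by 4 in each dim while the channel count ends at 32.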
class UniNetDS2(Network):
    """Simple UniNet, as described in the paper."""

    def setup(self):
        print ('2D with 32 filters')
        base_filter = 8
        (self.feed('data')
        .conv_bn(3, base_filter, 1, center=True, scale=True, name='conv0_0')
        .conv_bn(3, base_filter, 1, center=True, scale=True, name='conv0_1')
        .conv_bn(5, base_filter * 2, 2, center=True, scale=True, name='conv1_0')
        .conv_bn(3, base_filter * 2, 1, center=True, scale=True, name='conv1_1')
        .conv_bn(3, base_filter * 2, 1, center=True, scale=True, name='conv1_2')
        .conv_bn(5, base_filter * 4, 2, center=True, scale=True, name='conv2_0')
        .conv_bn(3, base_filter * 4, 1, center=True, scale=True, name='conv2_1')
        .conv(3, base_filter * 4, 1, biased=False, relu=False, name='conv2_2'))

### model.py -> def inference
    # image feature extraction    
    if is_master_gpu:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=False)
    else:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=True)
    view_towers = []
    for view in range(1, FLAGS.view_num):
        view_image = tf.squeeze(tf.slice(images, [0, view, 0, 0, 0], [-1, 1, -1, -1, -1]), axis=1)
        view_tower = UNetDS2GN({'data': view_image}, is_training=True, reuse=True)
        view_towers.append(view_tower)

(2) Cost Volume

  • Goal : To build 3D Cost Volume from extracted feature maps and input cameras
  • How : instead of dividing space into a regular grid, build the cost volume upon the reference camera frustum
  • Notations
    • $I_1$ : reference img → $F_1$ : reference feature map
    • $I_i$ (i=2~N) : source imgs → $F_i$ : feature map
    • ${K_i, R_i, t_i}$ (i=1~N) : camera intrinsics, rotations, translations
    • $n_1$ : principal axis of the reference camera

Differentiable Homography

  • Warp all feature maps $F_i$ onto different fronto-parallel planes of the reference camera → N feature volumes $\{V_i\}_{i=1}^N$
  • Coordinate mapping from the warped $V_i(d)$ to $F_i$ at depth $d$ is given by the planar transformation $x' \sim H_i(d) \cdot x$
    • $\sim$ : projective equality (equal up to scale)
    • $H_i(d)$ : 3x3 homography matrix between the i-th feature map $F_i$ and the reference feature map $F_1$ at depth $d$
  • ≈ Classical plane sweeping stereo, except that differentiable bilinear interpolation samples pixels from feature maps (not imgs)
  • Differentiable warping operation : connects 2D feature extraction and the 3D regularization network → E2E depth map inference!
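A NumPy sketch of the plane-induced homography, derived here under an assumed convention (the notes do not fix one): world-to-camera $X_{cam} = R X_{world} + t$, camera centers $c_j = -R_j^T t_j$, and $n_1$ taken as the third row of $R_1$ (the reference principal axis in world coordinates). Under these assumptions $H_i(d) = K_i R_i \left(I - \frac{(c_i - c_1)\, n_1^T}{d}\right) R_1^T K_1^{-1}$, which maps a reference pixel to the source pixel seen through the plane at depth $d$. The function name `plane_homography` is hypothetical.

```python
import numpy as np


def plane_homography(K_ref, R_ref, t_ref, K_src, R_src, t_src, d):
    """3x3 homography mapping reference pixels to source pixels for the
    fronto-parallel plane at depth d in front of the reference camera.

    Assumed convention (world -> camera): X_cam = R @ X_world + t.
    """
    c_ref = -R_ref.T @ t_ref            # camera centers in world coordinates
    c_src = -R_src.T @ t_src
    n_ref = R_ref[2:3, :]               # ref principal axis as a 1x3 row (world frame)
    middle = np.eye(3) - (c_src - c_ref).reshape(3, 1) @ n_ref / d
    return K_src @ R_src @ middle @ R_ref.T @ np.linalg.inv(K_ref)
```

Warping then samples the source feature map (bilinearly) at $H_i(d)\,x$ for every reference pixel $x$; repeating this for each depth hypothesis stacks the slices into $V_i$.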

Cost Metric : Variance-based Metric $M$

  • Notations
    • $W$(img width), $H$(img height), $D$(depth sample #), $F$(feature map channel #)
    • Size of each feature volume $V_i$ : $\frac{W}{4} \cdot \frac{H}{4} \cdot D \cdot F$
    • $\overline{V_i}$ : Average volume of all feature volumes
  • Mapping : N개의 feature volume $V_i$ → 1개의 cost volume $C$
    $$C = \mathcal{M}(V_1, \cdots, V_N) = \frac{\sum_{i=1}^{N} (V_i - \overline{V_i})^2}{N}$$
  • Matching cost
    • Traditional MVS methods : compute pairwise costs between the reference img and each source img in a heuristic way
    • MVSNet : all views contribute equally to the matching cost & no preference for the reference img
  • Mean vs Variance
    • Prior research using the mean operation : infers multi-patch similarity with additional pre- and post-CNN layers
    • MVSNet using Variance operation : measure multi-view feature difference explicitly
### model.py -> def inference
    # build cost volume by differentiable homography
    with tf.name_scope('cost_volume_homography'):
        depth_costs = []
        for d in range(depth_num):
            # compute cost (variance metric)
            ave_feature = ref_tower.get_output()
            ave_feature2 = tf.square(ref_tower.get_output())
            for view in range(0, FLAGS.view_num - 1):
                homography = tf.slice(view_homographies[view], begin=[0, d, 0, 0], size=[-1, 1, 3, 3])
                homography = tf.squeeze(homography, axis=1)
                warped_view_feature = tf_transform_homography(view_towers[view].get_output(), homography)
                ave_feature = ave_feature + warped_view_feature
                ave_feature2 = ave_feature2 + tf.square(warped_view_feature)
            ave_feature = ave_feature / FLAGS.view_num
            ave_feature2 = ave_feature2 / FLAGS.view_num
            cost = ave_feature2 - tf.square(ave_feature)
            depth_costs.append(cost)
        cost_volume = tf.stack(depth_costs, axis=1)
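The snippet above accumulates running sums and computes $E[V^2] - E[V]^2$ per element. A minimal NumPy sketch (not the repo code; both function names are hypothetical) checking that this running-sum form equals the variance metric, i.e. the mean squared deviation of the N warped volumes from their average:

```python
import numpy as np


def variance_cost(feature_volumes):
    """Variance-based cost over N warped feature volumes, shape (N, ...) -> (...)."""
    v = np.asarray(feature_volumes, dtype=np.float64)
    mean = v.mean(axis=0)                     # average volume over all views
    return ((v - mean) ** 2).mean(axis=0)     # mean squared deviation per element


def variance_cost_running(feature_volumes):
    """Same metric via running sums, E[V^2] - E[V]^2 (the form used in the loop)."""
    n = len(feature_volumes)
    s = sum(feature_volumes)
    s2 = sum(f ** 2 for f in feature_volumes)
    return s2 / n - (s / n) ** 2
```

The running-sum form needs only two accumulators regardless of N, which is why the loop never stores all warped volumes at once; identical views give zero cost.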

Cost Volume Regularization

  • What : raw cost volume $C$ → regularized probability volume $P$
  • Why : $C$ is computed from img features, so it may be noise-contaminated → needs to be combined with smoothness constraints
  • How : multi-scale 3D CNN (4-scale network)
    • ≈ 3D UNet encoder-decoder structure (aggregates neighboring information from a large receptive field)
    • +) To reduce computation : the 32-channel cost volume is reduced to 8 channels after the first 3D conv layer, and the convs per scale are reduced from 3 to 2 layers
  • Output : 1-channel volume → softmax along the depth direction for probability normalization
  • Usages : per-pixel depth estimation, measuring estimation confidence
    => recon quality can be judged from the probability distribution → outlier filtering
class RegNetUS0(Network):
    """network for regularizing 3D cost volume in a encoder-decoder style. Keeping original size."""

    def setup(self):
        print ('Shallow 3D UNet with 8 channel input')
        base_filter = 8
        (self.feed('data')
        .conv_bn(3, base_filter * 2, 2, center=True, scale=True, name='3dconv1_0')
        .conv_bn(3, base_filter * 4, 2, center=True, scale=True, name='3dconv2_0')
        .conv_bn(3, base_filter * 8, 2, center=True, scale=True, name='3dconv3_0'))

        (self.feed('data')
        .conv_bn(3, base_filter, 1, center=True, scale=True, name='3dconv0_1'))

        (self.feed('3dconv1_0')
        .conv_bn(3, base_filter * 2, 1, center=True, scale=True, name='3dconv1_1'))

        (self.feed('3dconv2_0')
        .conv_bn(3, base_filter * 4, 1, center=True, scale=True, name='3dconv2_1'))

        (self.feed('3dconv3_0')
        .conv_bn(3, base_filter * 8, 1, center=True, scale=True, name='3dconv3_1')
        .deconv_bn(3, base_filter * 4, 2, center=True, scale=True, name='3dconv4_0'))

        (self.feed('3dconv4_0', '3dconv2_1')
        .add(name='3dconv4_1')
        .deconv_bn(3, base_filter * 2, 2, center=True, scale=True, name='3dconv5_0'))

        (self.feed('3dconv5_0', '3dconv1_1')
        .add(name='3dconv5_1')
        .deconv_bn(3, base_filter, 2, center=True, scale=True, name='3dconv6_0'))

        (self.feed('3dconv6_0', '3dconv0_1')
        .add(name='3dconv6_1')
        .conv(3, 1, 1, biased=False, relu=False, name='3dconv6_2'))

(3) Depth Map

Initial Estimation

  • What : regularized probability volume $P$ → inferred depth map $D$
  • How : expectation along the depth direction = probability-weighted sum over all depth hypotheses
    = soft argmin → fully differentiable operation that approximates the argmax result
    $$D = \sum_{d=d_{min}}^{d_{max}} d \times P(d)$$
    • $P(d)$ : probability estimation for all pixels at depth $d$
    • $d$ : depth hypothesis uniformly sampled within [ $d_{min}$ , $d_{max}$ ]
  • Output : depth map of the same size as the 2D img feature maps (= 1/4 of the input img in each dim)
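A minimal NumPy sketch of the soft argmin for one view, mirroring the TF code in this section (softmax of the negated cost along depth, then the expectation over uniformly spaced hypotheses). The function name and the (D, H, W) layout without a batch axis are assumptions for illustration:

```python
import numpy as np


def soft_argmin_depth(cost_volume, depth_min, depth_max):
    """Depth regression by soft argmin.

    cost_volume: regularized cost, shape (D, H, W); lower cost = better match.
    Returns (depth_map (H, W), probability_volume (D, H, W)).
    """
    num_depth = cost_volume.shape[0]
    # Softmax of the negated cost along the depth axis (max-subtracted for stability).
    logits = -cost_volume
    logits = logits - logits.max(axis=0, keepdims=True)
    prob = np.exp(logits)
    prob /= prob.sum(axis=0, keepdims=True)
    # Depth hypotheses uniformly sampled in [depth_min, depth_max].
    depths = np.linspace(depth_min, depth_max, num_depth).reshape(-1, 1, 1)
    # Expectation over hypotheses: D = sum_d d * P(d).
    return (depths * prob).sum(axis=0), prob
```

With a sharply peaked distribution the result equals the argmax hypothesis; with a flat one it collapses to the middle of the depth range, which is why the probability itself is useful as a confidence signal.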

Probability Map

  • Why (observation) : the multi-scale 3D CNN can regularize the probability toward a single-modal distribution, but falsely matched pixels have a scattered distribution that cannot concentrate on one peak
  • Definition : quality of a depth estimate $\hat{d}$ = probability that the GT depth lies within a small range near the estimate
  • How : the probability sum over the 4 nearest depth hypotheses measures the estimation quality
  • Effect : enables probability-based depth map filtering (outlier removal)
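A NumPy sketch of the same idea as `get_propability_map`, for a single view. Two choices here are assumptions that differ slightly from the TF code: the neighbors are taken as floor-1 ... floor+2 (so four distinct hypotheses are summed even when the estimate falls exactly on a hypothesis, where the TF floor/ceil pair coincides), and out-of-range neighbors are masked out instead of clipped (so edge bins are not counted twice). The function name is hypothetical.

```python
import numpy as np


def probability_map(prob_volume, depth_map, depth_start, depth_interval):
    """Confidence = probability mass of the 4 depth hypotheses nearest to the estimate.

    prob_volume: (D, H, W) softmax output; depth_map: (H, W) soft-argmin depth.
    """
    num_depth = prob_volume.shape[0]
    idx = (depth_map - depth_start) / depth_interval    # fractional hypothesis index
    lo = np.floor(idx).astype(int)
    rows, cols = np.indices(depth_map.shape)
    conf = np.zeros(depth_map.shape)
    # Two hypotheses at/below the estimate and two above it.
    for k in (lo - 1, lo, lo + 1, lo + 2):
        valid = (k >= 0) & (k < num_depth)
        picked = prob_volume[np.clip(k, 0, num_depth - 1), rows, cols]
        conf += np.where(valid, picked, 0.0)             # ignore out-of-range bins
    return conf
```

Pixels whose probability is concentrated near the estimate score close to 1, while scattered (likely mismatched) pixels score low and can be thresholded away.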
def get_propability_map(cv, depth_map, depth_start, depth_interval):
    """ get probability map from cost volume """

    def _repeat_(x, num_repeats):
        """ repeat each element num_repeats times """
        x = tf.reshape(x, [-1])
        ones = tf.ones((1, num_repeats), dtype='int32')
        x = tf.reshape(x, shape=(-1,1))
        x = tf.matmul(x, ones)
        return tf.reshape(x, [-1])

    shape = tf.shape(depth_map)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    depth = tf.shape(cv)[1]

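    # byx coordinate, batched & flattened
    # ('ij' indexing yields (B, H, W) grids whose flattened order matches depth_map,
    #  so no extra repetition is needed; the default 'xy' order only lines up when B == 1)
    b_coordinates = tf.range(batch_size)
    y_coordinates = tf.range(height)
    x_coordinates = tf.range(width)
    b_coordinates, y_coordinates, x_coordinates = tf.meshgrid(
        b_coordinates, y_coordinates, x_coordinates, indexing='ij')
    b_coordinates = tf.reshape(b_coordinates, [-1])
    y_coordinates = tf.reshape(y_coordinates, [-1])
    x_coordinates = tf.reshape(x_coordinates, [-1])

    # d coordinate (floored and ceiled), batched & flattened
    # (depth_start / depth_interval are per-sample [B] tensors: reshape so they broadcast per batch)
    depth_start = tf.reshape(depth_start, [-1, 1, 1, 1])
    depth_interval = tf.reshape(depth_interval, [-1, 1, 1, 1])
    d_coordinates = tf.reshape((depth_map - depth_start) / depth_interval, [-1])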
    d_coordinates_left0 = tf.clip_by_value(tf.cast(tf.floor(d_coordinates), 'int32'), 0, depth - 1)
    d_coordinates_left1 = tf.clip_by_value(d_coordinates_left0 - 1, 0, depth - 1)
    d_coordinates1_right0 = tf.clip_by_value(tf.cast(tf.ceil(d_coordinates), 'int32'), 0, depth - 1)
    d_coordinates1_right1 = tf.clip_by_value(d_coordinates1_right0 + 1, 0, depth - 1)

    # voxel coordinates
    voxel_coordinates_left0 = tf.stack(
        [b_coordinates, d_coordinates_left0, y_coordinates, x_coordinates], axis=1)
    voxel_coordinates_left1 = tf.stack(
        [b_coordinates, d_coordinates_left1, y_coordinates, x_coordinates], axis=1)
    voxel_coordinates_right0 = tf.stack(
        [b_coordinates, d_coordinates1_right0, y_coordinates, x_coordinates], axis=1)
    voxel_coordinates_right1 = tf.stack(
        [b_coordinates, d_coordinates1_right1, y_coordinates, x_coordinates], axis=1)

    # get probability image by gathering and interpolation
    prob_map_left0 = tf.gather_nd(cv, voxel_coordinates_left0)
    prob_map_left1 = tf.gather_nd(cv, voxel_coordinates_left1)
    prob_map_right0 = tf.gather_nd(cv, voxel_coordinates_right0)
    prob_map_right1 = tf.gather_nd(cv, voxel_coordinates_right1)
    prob_map = prob_map_left0 + prob_map_left1 + prob_map_right0 + prob_map_right1
    prob_map = tf.reshape(prob_map, [batch_size, height, width, 1])

    return prob_map


### model.py -> def inference
    # depth map by softArgmin
    with tf.name_scope('soft_arg_min'):
        # probability volume by soft max
        probability_volume = tf.nn.softmax(
            tf.scalar_mul(-1, filtered_cost_volume), axis=1, name='prob_volume')
        # depth image by soft argmin
        volume_shape = tf.shape(probability_volume)
        soft_2d = []
        for i in range(FLAGS.batch_size):
            soft_1d = tf.linspace(depth_start[i], depth_end[i], tf.cast(depth_num, tf.int32))
            soft_2d.append(soft_1d)
        soft_2d = tf.reshape(tf.stack(soft_2d, axis=0), [volume_shape[0], volume_shape[1], 1, 1])
        soft_4d = tf.tile(soft_2d, [1, 1, volume_shape[2], volume_shape[3]])
        estimated_depth_map = tf.reduce_sum(soft_4d * probability_volume, axis=1)
        estimated_depth_map = tf.expand_dims(estimated_depth_map, axis=3)

    # probability map
    prob_map = get_propability_map(probability_volume, estimated_depth_map, depth_start, depth_interval)

    return estimated_depth_map, prob_map # filtered_depth_map, probability_volume

Depth Map Refinement

  • Why : the large receptive field causes oversmoothing at reconstruction boundaries
  • How : the reference img contains boundary information → use it as guidance for refinement
    • MVSNet + Depth residual learning network
      • Pre-scale the initial depth magnitude to [0, 1] → scale back after refinement (prevents bias toward a certain depth scale)
      • Input : initial depth map & resized reference img, concatenated into a 4-channel input
      • → three 32-channel 2D conv layers + one 1-channel conv layer learn the depth residual, which is added to the normalized initial depth map
      • Last layer : no BN and no ReLU, so that negative residuals can be learned
class RefineNet(Network):
    """network for depth map refinement using original image."""

    def setup(self):

        (self.feed('color_image', 'depth_image')
        .concat(axis=3, name='concat_image'))

        (self.feed('concat_image')
        .conv_bn(3, 32, 1, name='refine_conv0')
        .conv_bn(3, 32, 1, name='refine_conv1')
        .conv_bn(3, 32, 1, name='refine_conv2')
        .conv(3, 1, 1, relu=False, name='refine_conv3'))

        (self.feed('refine_conv3', 'depth_image')
        .add(name='refined_depth_image'))

## model.py
def depth_refine(init_depth_map, image, depth_num, depth_start, depth_interval, is_master_gpu=True):
    """ refine depth image with the image """

    # normalization parameters
    depth_shape = tf.shape(init_depth_map)
    depth_end = depth_start + (tf.cast(depth_num, tf.float32) - 1) * depth_interval
    depth_start_mat = tf.tile(tf.reshape(
        depth_start, [depth_shape[0], 1, 1, 1]), [1, depth_shape[1], depth_shape[2], 1])
    depth_end_mat = tf.tile(tf.reshape(
        depth_end, [depth_shape[0], 1, 1, 1]), [1, depth_shape[1], depth_shape[2], 1])
    depth_scale_mat = depth_end_mat - depth_start_mat

    # normalize depth map (to 0~1)
    init_norm_depth_map = tf.div(init_depth_map - depth_start_mat, depth_scale_mat)

    # resize normalized image to the same size of depth image
    resized_image = tf.image.resize_bilinear(image, [depth_shape[1], depth_shape[2]])

    # refinement network
    if is_master_gpu:
        norm_depth_tower = RefineNet({'color_image': resized_image, 'depth_image': init_norm_depth_map},
                                        is_training=True, reuse=False)
    else:
        norm_depth_tower = RefineNet({'color_image': resized_image, 'depth_image': init_norm_depth_map},
                                        is_training=True, reuse=True)
    norm_depth_map = norm_depth_tower.get_output()

    # denormalize depth map
    refined_depth_map = tf.multiply(norm_depth_map, depth_scale_mat) + depth_start_mat

    return refined_depth_map
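The normalize → residual → denormalize flow of `depth_refine` can be sketched in NumPy as below. `residual_fn` is a hypothetical stand-in for `RefineNet` (which additionally sees the resized color image); only the depth-scaling logic is shown.

```python
import numpy as np


def refine_with_residual(init_depth, residual_fn, depth_start, depth_end):
    """Depth-residual refinement in normalized depth space.

    residual_fn (hypothetical) stands in for the refinement CNN: it takes the
    normalized depth map and returns a residual of the same shape.
    """
    scale = depth_end - depth_start
    norm = (init_depth - depth_start) / scale       # pre-scale to [0, 1]
    refined_norm = norm + residual_fn(norm)         # learned residual (may be negative)
    return refined_norm * scale + depth_start       # map back to metric depth
```

Working in [0, 1] keeps the residual's magnitude independent of the scene's depth range; for example, a normalized residual of -0.1 over a 425-935 mm range moves the depth by -51 mm.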

Loss Function

$$Loss = \sum_{p \in p_{valid}} \|d(p) - \hat{d_i}(p)\|_1 + \lambda \cdot \|d(p) - \hat{d_r}(p)\|_1$$

  • Losses for both estimated depth maps (initial & refined) are considered
  • Mean absolute difference between the GT and estimated depth maps
  • Only pixels with valid GT depth labels are considered (not the whole img)
  • Notations
    • $p_{valid}$ : set of valid GT pixels
    • $d(p)$ : GT depth value of pixel $p$
    • $\hat{d_i}(p)$ : Initial depth estimation
    • $\hat{d_r}(p)$ : Refined depth map estimation
    • $λ$ = 1.0
def non_zero_mean_absolute_diff(y_true, y_pred, interval):
    """ non zero mean absolute loss for one batch """
    with tf.name_scope('MAE'):
        shape = tf.shape(y_pred)
        interval = tf.reshape(interval, [shape[0]])
        mask_true = tf.cast(tf.not_equal(y_true, 0.0), dtype='float32')
        denom = tf.reduce_sum(mask_true, axis=[1, 2, 3]) + 1e-7
        masked_abs_error = tf.abs(mask_true * (y_true - y_pred))            # 4D
        masked_mae = tf.reduce_sum(masked_abs_error, axis=[1, 2, 3])        # 1D
        masked_mae = tf.reduce_sum((masked_mae / interval) / denom)         # 1
    return masked_mae

def mvsnet_regression_loss(estimated_depth_image, depth_image, depth_interval):
    """ compute loss and accuracy """
    # non zero mean absolute loss
    masked_mae = non_zero_mean_absolute_diff(depth_image, estimated_depth_image, depth_interval)
    # less one accuracy
    less_one_accuracy = less_one_percentage(depth_image, estimated_depth_image, depth_interval)
    # less three accuracy
    less_three_accuracy = less_three_percentage(depth_image, estimated_depth_image, depth_interval)

    return masked_mae, less_one_accuracy, less_three_accuracy
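The code above computes the masked MAE for one depth map. A minimal NumPy sketch of the combined objective (initial-estimate loss + λ × refined-estimate loss over valid GT pixels); it averages rather than sums, following the "mean absolute difference" description, and omits the repo's division by the depth interval. Function names are hypothetical.

```python
import numpy as np


def masked_l1(gt, pred):
    """Mean absolute error over pixels with a valid (non-zero) GT depth."""
    valid = gt != 0
    return np.abs(gt[valid] - pred[valid]).mean() if valid.any() else 0.0


def mvsnet_loss(gt, init_depth, refined_depth, lam=1.0):
    """Loss on the initial estimate + lambda * loss on the refined estimate."""
    return masked_l1(gt, init_depth) + lam * masked_l1(gt, refined_depth)
```

Pixels with GT depth 0 (no label) contribute nothing, so predictions there are unconstrained.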

Implementations

Training

Data Preparation

  • DTU dataset (GT point clouds with normal information) + generated GT depth maps
    • DTU dataset : large-scale MVS dataset containing over 100 scenes with different lighting conditions
    • Point cloud with normal information → Mesh by SPSR → Depth maps by rendering mesh to each viewpoint
      • SPSR(screened Poisson surface reconstruction) : depth-of-tree = 11 (to acquire high quality mesh result)
      • Mesh trimming-factor = 9.5 (to alleviate mesh artifacts)
  • 49 imgs with 7 different lighting conditions for each scan => Total # of training samples : 27097

View Selection

  • Training input : 1 reference img + 2 source imgs (N = 3)
  • Considering the downsizing in feature extraction and 3D regularization, img resolution is reduced from 1600x1200 to 800x600, then a patch with W=640, H=512 is cropped from the center => since the img resolution changed, the input camera parameters are adjusted accordingly
  • Depth hypotheses are uniformly sampled from [425mm ~ 935mm] with 2mm resolution
  • Environment : TensorFlow, Tesla P100
  • 100,000 iterations
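A minimal sketch of the preprocessing arithmetic above: adjusting a 3x3 intrinsic matrix for the 0.5x resize and the centered crop, and enumerating the depth hypotheses. It assumes a pinhole model and ignores the half-pixel offset some pipelines apply when resizing; `scale_and_crop_intrinsics` is a hypothetical helper.

```python
import numpy as np


def scale_and_crop_intrinsics(K, scale, crop_x, crop_y):
    """Adjust intrinsics for a resize by `scale` followed by a crop whose
    top-left corner is (crop_x, crop_y) in the resized image."""
    K = K.astype(np.float64).copy()
    K[:2, :] *= scale          # fx, fy, cx, cy (and skew) scale with the image
    K[0, 2] -= crop_x          # principal point shifts by the crop offset
    K[1, 2] -= crop_y
    return K


# 1600x1200 -> 800x600, then a centered 640x512 crop.
crop_x, crop_y = (800 - 640) // 2, (600 - 512) // 2   # (80, 44)

# Depth hypotheses: [425 mm, 935 mm] at 2 mm steps -> depth_start + i * interval.
num_depth = int(round((935.0 - 425.0) / 2.0)) + 1     # 256
depths = 425.0 + 2.0 * np.arange(num_depth)
```

The hypothesis count works out to 256, and the last hypothesis equals `depth_start + (depth_num - 1) * depth_interval`, the same `depth_end` formula used in the TF code.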

Post-processing


Depth Map Filter

  • Goal : to filter out outliers in background and occluded areas before converting depth maps to dense point clouds
  • Criteria : Photometric consistency & Geometric consistency
    • Photometric consistency : measuring matching quality
      • (Experiment) Pixels with probability lower than 0.8 = Outliers
    • Geometric consistency : measuring depth consistency among multiple view
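      • A reference pixel $p_1$ is projected with its depth $d_1$ to pixel $p_i$ in another view, then $p_i$ is reprojected back to the reference img with its own depth estimate $d_i$; $d_1$ is two-view consistent if the reprojected coordinate and depth satisfy $|p_{reproj} - p_1| < 1$ and $|d_{reproj} - d_1| / d_1 < 0.01$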
      • (Experiment) All depths should be at least 3-view consistent
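A NumPy sketch of the two-view geometric consistency check for a single reference pixel. Assumptions: world-to-camera convention $X_{cam} = R X_{world} + t$, cameras passed as `(K, R, t)` tuples, nearest-pixel lookup of the source depth, and 1-pixel / 1% relative-depth thresholds as defaults. All function names are hypothetical.

```python
import numpy as np


def backproject(u, v, depth, K, R, t):
    """Pixel (u, v) with depth -> world point (convention X_cam = R X_world + t)."""
    X_cam = depth * np.linalg.inv(K) @ np.array([u, v, 1.0])
    return R.T @ (X_cam - t)


def project(X_world, K, R, t):
    """World point -> (pixel (u, v), depth) in the given camera."""
    X_cam = R @ X_world + t
    p = K @ X_cam
    return p[:2] / p[2], X_cam[2]


def is_two_view_consistent(u, v, depth_ref, depth_map_src, cam_ref, cam_src,
                           pix_thresh=1.0, rel_depth_thresh=0.01):
    """Project a ref pixel into the source view with its depth, read the source
    depth there, reproject back, and compare pixel and depth errors."""
    X = backproject(u, v, depth_ref, *cam_ref)
    (us, vs), _ = project(X, *cam_src)
    ui, vi = int(round(us)), int(round(vs))          # nearest-pixel lookup (assumption)
    h, w = depth_map_src.shape
    if not (0 <= ui < w and 0 <= vi < h):
        return False                                 # falls outside the source view
    X_back = backproject(us, vs, depth_map_src[vi, ui], *cam_src)
    (ur, vr), d_reproj = project(X_back, *cam_ref)
    pix_err = np.hypot(ur - u, vr - v)
    return pix_err < pix_thresh and abs(d_reproj - depth_ref) / depth_ref < rel_depth_thresh
```

Counting, for each reference pixel, how many source views pass this check gives the "at least 3-view consistent" criterion.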

Depth Map Fusion

  • Goal : To integrate depth maps from different views to a unified pc representation
  • Visibility-based fusion → minimizes depth occlusions and violations
  • Visible views for each pixel are selected in the filtering step, and all reprojected depths are averaged → suppresses recon noise
  • Fused depth maps are reprojected into 3D space to generate the point cloud
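A simplified NumPy sketch of the last two steps: averaging the reference depth with the depths reprojected from visible views (NaN marks a view that is not visible or failed the filter), then back-projecting the fused depth map to world-space points. The NaN convention, the world-to-camera convention $X_{cam} = R X_{world} + t$, and both function names are assumptions for illustration.

```python
import numpy as np


def average_depths(ref_depth, reprojected_depths):
    """Average the ref depth with reprojected depths from visible views (NaN = not visible)."""
    stack = np.stack([ref_depth] + list(reprojected_depths), axis=0)
    return np.nanmean(stack, axis=0)


def depth_map_to_points(depth, K, R, t, mask=None):
    """Back-project a (fused) depth map into world-space 3D points, shape (M, 3).

    Pixels where mask is False (e.g. filtered outliers) are skipped.
    """
    h, w = depth.shape
    v, u = np.mgrid[0:h, 0:w]
    if mask is None:
        mask = depth > 0
    pix = np.stack([u[mask], v[mask], np.ones(mask.sum())], axis=0)   # (3, M) homogeneous
    X_cam = (np.linalg.inv(K) @ pix) * depth[mask]                    # scale rays by depth
    return (R.T @ (X_cam - t.reshape(3, 1))).T                         # camera -> world
```

Running this per reference view and concatenating the results yields the final dense point cloud.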

Experiments

Benchmarking on DTU dataset

  • MVSNet outperforms all methods in both the completeness & overall quality with a significant margin

Generalization on T&T dataset

  • Using MVSNet trained on DTU without any fine-tuning

Ablations

  • View Number
  • Image Features
  • Cost Metric
  • Depth Refinement

Conclusion

  • MVSNet : E2E DL network that takes unstructured imgs as input and infers the depth map of the reference img
  • Core contribution of MVSNet : encodes camera parameters as a differentiable homography to build the cost volume upon the camera frustum → bridges 2D feature extraction and 3D cost regularization
  • Results : outperforms on DTU & efficient in speed / SOTA on T&T without fine-tuning → generalization ability

Code Review

## model.py
def get_propability_map(cv, depth_map, depth_start, depth_interval):
    """ get probability map from cost volume """

    def _repeat_(x, num_repeats):
        """ repeat each element num_repeats times """
        x = tf.reshape(x, [-1])
        ones = tf.ones((1, num_repeats), dtype='int32')
        x = tf.reshape(x, shape=(-1,1))
        x = tf.matmul(x, ones)
        return tf.reshape(x, [-1])

    shape = tf.shape(depth_map)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    depth = tf.shape(cv)[1]

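    # byx coordinate, batched & flattened
    # ('ij' indexing yields (B, H, W) grids whose flattened order matches depth_map,
    #  so no extra repetition is needed; the default 'xy' order only lines up when B == 1)
    b_coordinates = tf.range(batch_size)
    y_coordinates = tf.range(height)
    x_coordinates = tf.range(width)
    b_coordinates, y_coordinates, x_coordinates = tf.meshgrid(
        b_coordinates, y_coordinates, x_coordinates, indexing='ij')
    b_coordinates = tf.reshape(b_coordinates, [-1])
    y_coordinates = tf.reshape(y_coordinates, [-1])
    x_coordinates = tf.reshape(x_coordinates, [-1])

    # d coordinate (floored and ceiled), batched & flattened
    # (depth_start / depth_interval are per-sample [B] tensors: reshape so they broadcast per batch)
    depth_start = tf.reshape(depth_start, [-1, 1, 1, 1])
    depth_interval = tf.reshape(depth_interval, [-1, 1, 1, 1])
    d_coordinates = tf.reshape((depth_map - depth_start) / depth_interval, [-1])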
    d_coordinates_left0 = tf.clip_by_value(tf.cast(tf.floor(d_coordinates), 'int32'), 0, depth - 1)
    d_coordinates_left1 = tf.clip_by_value(d_coordinates_left0 - 1, 0, depth - 1)
    d_coordinates1_right0 = tf.clip_by_value(tf.cast(tf.ceil(d_coordinates), 'int32'), 0, depth - 1)
    d_coordinates1_right1 = tf.clip_by_value(d_coordinates1_right0 + 1, 0, depth - 1)

    # voxel coordinates
    voxel_coordinates_left0 = tf.stack(
        [b_coordinates, d_coordinates_left0, y_coordinates, x_coordinates], axis=1)
    voxel_coordinates_left1 = tf.stack(
        [b_coordinates, d_coordinates_left1, y_coordinates, x_coordinates], axis=1)
    voxel_coordinates_right0 = tf.stack(
        [b_coordinates, d_coordinates1_right0, y_coordinates, x_coordinates], axis=1)
    voxel_coordinates_right1 = tf.stack(
        [b_coordinates, d_coordinates1_right1, y_coordinates, x_coordinates], axis=1)

    # get probability image by gathering and interpolation
    prob_map_left0 = tf.gather_nd(cv, voxel_coordinates_left0)
    prob_map_left1 = tf.gather_nd(cv, voxel_coordinates_left1)
    prob_map_right0 = tf.gather_nd(cv, voxel_coordinates_right0)
    prob_map_right1 = tf.gather_nd(cv, voxel_coordinates_right1)
    prob_map = prob_map_left0 + prob_map_left1 + prob_map_right0 + prob_map_right1
    prob_map = tf.reshape(prob_map, [batch_size, height, width, 1])

    return prob_map

def inference(images, cams, depth_num, depth_start, depth_interval, is_master_gpu=True):
    """ infer depth image from multi-view images and cameras """

    # dynamic gpu params
    depth_end = depth_start + (tf.cast(depth_num, tf.float32) - 1) * depth_interval

    # reference image
    ref_image = tf.squeeze(tf.slice(images, [0, 0, 0, 0, 0], [-1, 1, -1, -1, 3]), axis=1)
    ref_cam = tf.squeeze(tf.slice(cams, [0, 0, 0, 0, 0], [-1, 1, 2, 4, 4]), axis=1)

    # image feature extraction    
    if is_master_gpu:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=False)
    else:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=True)
    view_towers = []
    for view in range(1, FLAGS.view_num):
        view_image = tf.squeeze(tf.slice(images, [0, view, 0, 0, 0], [-1, 1, -1, -1, -1]), axis=1)
        view_tower = UNetDS2GN({'data': view_image}, is_training=True, reuse=True)
        view_towers.append(view_tower)

    # get all homographies
    view_homographies = []
    for view in range(1, FLAGS.view_num):
        view_cam = tf.squeeze(tf.slice(cams, [0, view, 0, 0, 0], [-1, 1, 2, 4, 4]), axis=1)
        homographies = get_homographies(ref_cam, view_cam, depth_num=depth_num,
                                        depth_start=depth_start, depth_interval=depth_interval)
        view_homographies.append(homographies)

    # build cost volume by differentiable homography
    with tf.name_scope('cost_volume_homography'):
        depth_costs = []
        for d in range(depth_num):
            # compute cost (variance metric)
            ave_feature = ref_tower.get_output()
            ave_feature2 = tf.square(ref_tower.get_output())
            for view in range(0, FLAGS.view_num - 1):
                homography = tf.slice(view_homographies[view], begin=[0, d, 0, 0], size=[-1, 1, 3, 3])
                homography = tf.squeeze(homography, axis=1)
                # warped_view_feature = homography_warping(view_towers[view].get_output(), homography)
                warped_view_feature = tf_transform_homography(view_towers[view].get_output(), homography)
                ave_feature = ave_feature + warped_view_feature
                ave_feature2 = ave_feature2 + tf.square(warped_view_feature)
            ave_feature = ave_feature / FLAGS.view_num
            ave_feature2 = ave_feature2 / FLAGS.view_num
            cost = ave_feature2 - tf.square(ave_feature)
            depth_costs.append(cost)
        cost_volume = tf.stack(depth_costs, axis=1)

    # filtered cost volume, size of (B, D, H, W, 1)
    if is_master_gpu:
        filtered_cost_volume_tower = RegNetUS0({'data': cost_volume}, is_training=True, reuse=False)
    else:
        filtered_cost_volume_tower = RegNetUS0({'data': cost_volume}, is_training=True, reuse=True)
    filtered_cost_volume = tf.squeeze(filtered_cost_volume_tower.get_output(), axis=-1)

    # depth map by softArgmin
    with tf.name_scope('soft_arg_min'):
        # probability volume by soft max
        probability_volume = tf.nn.softmax(
            tf.scalar_mul(-1, filtered_cost_volume), axis=1, name='prob_volume')
        # depth image by soft argmin
        volume_shape = tf.shape(probability_volume)
        soft_2d = []
        for i in range(FLAGS.batch_size):
            soft_1d = tf.linspace(depth_start[i], depth_end[i], tf.cast(depth_num, tf.int32))
            soft_2d.append(soft_1d)
        soft_2d = tf.reshape(tf.stack(soft_2d, axis=0), [volume_shape[0], volume_shape[1], 1, 1])
        soft_4d = tf.tile(soft_2d, [1, 1, volume_shape[2], volume_shape[3]])
        estimated_depth_map = tf.reduce_sum(soft_4d * probability_volume, axis=1)
        estimated_depth_map = tf.expand_dims(estimated_depth_map, axis=3)

    # probability map
    prob_map = get_propability_map(probability_volume, estimated_depth_map, depth_start, depth_interval)

    return estimated_depth_map, prob_map#, filtered_depth_map, probability_volume

def inference_mem(images, cams, depth_num, depth_start, depth_interval, is_master_gpu=True):
    """ infer depth image from multi-view images and cameras """

    # dynamic gpu params
    depth_end = depth_start + (tf.cast(depth_num, tf.float32) - 1) * depth_interval
    feature_c = 32
    feature_h = FLAGS.max_h // 4   # integer division: tf.zeros needs an integer shape
    feature_w = FLAGS.max_w // 4

    # reference image
    ref_image = tf.squeeze(tf.slice(images, [0, 0, 0, 0, 0], [-1, 1, -1, -1, 3]), axis=1)
    ref_cam = tf.squeeze(tf.slice(cams, [0, 0, 0, 0, 0], [-1, 1, 2, 4, 4]), axis=1)

    # image feature extraction    
    if is_master_gpu:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=False)
    else:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=True)
    ref_feature = ref_tower.get_output()
    ref_feature2 = tf.square(ref_feature)

    view_features = []
    for view in range(1, FLAGS.view_num):
        view_image = tf.squeeze(tf.slice(images, [0, view, 0, 0, 0], [-1, 1, -1, -1, -1]), axis=1)
        view_tower = UNetDS2GN({'data': view_image}, is_training=True, reuse=True)
        view_features.append(view_tower.get_output())
    view_features = tf.stack(view_features, axis=0)

    # get all homographies
    view_homographies = []
    for view in range(1, FLAGS.view_num):
        view_cam = tf.squeeze(tf.slice(cams, [0, view, 0, 0, 0], [-1, 1, 2, 4, 4]), axis=1)
        homographies = get_homographies(ref_cam, view_cam, depth_num=depth_num,
                                        depth_start=depth_start, depth_interval=depth_interval)
        view_homographies.append(homographies)
    view_homographies = tf.stack(view_homographies, axis=0)

    # build cost volume by differentiable homography
    with tf.name_scope('cost_volume_homography'):
        depth_costs = []

        for d in range(depth_num):
            # compute cost (variance feature)
            ave_feature = tf.Variable(tf.zeros(
                [FLAGS.batch_size, feature_h, feature_w, feature_c]),
                name='ave', trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
            ave_feature2 = tf.Variable(tf.zeros(
                [FLAGS.batch_size, feature_h, feature_w, feature_c]),
                name='ave2', trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
            ave_feature = tf.assign(ave_feature, ref_feature)
            ave_feature2 = tf.assign(ave_feature2, ref_feature2)

            def body(view, ave_feature, ave_feature2):
                """Loop body."""
                homography = tf.slice(view_homographies[view], begin=[0, d, 0, 0], size=[-1, 1, 3, 3])
                homography = tf.squeeze(homography, axis=1)
                # warped_view_feature = homography_warping(view_features[view], homography)
                warped_view_feature = tf_transform_homography(view_features[view], homography)
                ave_feature = tf.assign_add(ave_feature, warped_view_feature)
                ave_feature2 = tf.assign_add(ave_feature2, tf.square(warped_view_feature))
                view = tf.add(view, 1)
                return view, ave_feature, ave_feature2

            view = tf.constant(0)
            cond = lambda view, *_: tf.less(view, FLAGS.view_num - 1)
            _, ave_feature, ave_feature2 = tf.while_loop(
                cond, body, [view, ave_feature, ave_feature2], back_prop=False, parallel_iterations=1)

            # ave_feature <- E[x]^2, ave_feature2 <- E[x^2] - E[x]^2 (per-element variance)
            ave_feature = tf.assign(ave_feature, tf.square(ave_feature) / (FLAGS.view_num * FLAGS.view_num))
            ave_feature2 = tf.assign(ave_feature2, ave_feature2 / FLAGS.view_num - ave_feature)
            depth_costs.append(ave_feature2)
        cost_volume = tf.stack(depth_costs, axis=1)

    # filtered cost volume, size of (B, D, H, W, 1)
    if is_master_gpu:
        filtered_cost_volume_tower = RegNetUS0({'data': cost_volume}, is_training=True, reuse=False)
    else:
        filtered_cost_volume_tower = RegNetUS0({'data': cost_volume}, is_training=True, reuse=True)
    filtered_cost_volume = tf.squeeze(filtered_cost_volume_tower.get_output(), axis=-1)

    # depth map by softArgmin
    with tf.name_scope('soft_arg_min'):
        # probability volume by soft max
        probability_volume = tf.nn.softmax(tf.scalar_mul(-1, filtered_cost_volume),
                                           axis=1, name='prob_volume')

        # depth image by soft argmin
        volume_shape = tf.shape(probability_volume)
        soft_2d = []
        for i in range(FLAGS.batch_size):
            soft_1d = tf.linspace(depth_start[i], depth_end[i], tf.cast(depth_num, tf.int32))
            soft_2d.append(soft_1d)
        soft_2d = tf.reshape(tf.stack(soft_2d, axis=0), [volume_shape[0], volume_shape[1], 1, 1])
        soft_4d = tf.tile(soft_2d, [1, 1, volume_shape[2], volume_shape[3]])
        estimated_depth_map = tf.reduce_sum(soft_4d * probability_volume, axis=1)
        estimated_depth_map = tf.expand_dims(estimated_depth_map, axis=3)

    # probability map
    prob_map = get_propability_map(probability_volume, estimated_depth_map, depth_start, depth_interval)

    # return the soft-argmin depth map and its probability (confidence) map
    return estimated_depth_map, prob_map


def inference_prob_recurrent(images, cams, depth_num, depth_start, depth_interval, is_master_gpu=True):
    """ infer the depth probability volume from multi-view images and cameras (recurrent GRU regularization) """

    # dynamic gpu params
    depth_end = depth_start + (tf.cast(depth_num, tf.float32) - 1) * depth_interval

    # reference image
    ref_image = tf.squeeze(tf.slice(images, [0, 0, 0, 0, 0], [-1, 1, -1, -1, 3]), axis=1)
    ref_cam = tf.squeeze(tf.slice(cams, [0, 0, 0, 0, 0], [-1, 1, 2, 4, 4]), axis=1)

    # image feature extraction    
    if is_master_gpu:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=False)
    else:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=True)
    view_towers = []
    for view in range(1, FLAGS.view_num):
        view_image = tf.squeeze(tf.slice(images, [0, view, 0, 0, 0], [-1, 1, -1, -1, -1]), axis=1)
        view_tower = UNetDS2GN({'data': view_image}, is_training=True, reuse=True)
        view_towers.append(view_tower)

    # get all homographies
    view_homographies = []
    for view in range(1, FLAGS.view_num):
        view_cam = tf.squeeze(tf.slice(cams, [0, view, 0, 0, 0], [-1, 1, 2, 4, 4]), axis=1)
        homographies = get_homographies(ref_cam, view_cam, depth_num=depth_num,
                                        depth_start=depth_start, depth_interval=depth_interval)
        view_homographies.append(homographies)

    gru1_filters = 16
    gru2_filters = 4
    gru3_filters = 2
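    # feature maps are 1/4 of the input resolution; integer division keeps the shape entries ints
    feature_shape = [FLAGS.batch_size, FLAGS.max_h // 4, FLAGS.max_w // 4, 32]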
    gru_input_shape = [feature_shape[1], feature_shape[2]]
    state1 = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], gru1_filters])
    state2 = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], gru2_filters])
    state3 = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], gru3_filters])
    conv_gru1 = ConvGRUCell(shape=gru_input_shape, kernel=[3, 3], filters=gru1_filters)
    conv_gru2 = ConvGRUCell(shape=gru_input_shape, kernel=[3, 3], filters=gru2_filters)
    conv_gru3 = ConvGRUCell(shape=gru_input_shape, kernel=[3, 3], filters=gru3_filters)

    exp_div = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], 1])
    soft_depth_map = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], 1])

    with tf.name_scope('cost_volume_homography'):

        # forward cost volume
        depth_costs = []
        for d in range(depth_num):

            # compute cost (variation metric)
            ave_feature = ref_tower.get_output()
            ave_feature2 = tf.square(ref_tower.get_output())

            for view in range(0, FLAGS.view_num - 1):
                homography = tf.slice(
                    view_homographies[view], begin=[0, d, 0, 0], size=[-1, 1, 3, 3])
                homography = tf.squeeze(homography, axis=1)
                # warped_view_feature = homography_warping(view_towers[view].get_output(), homography)
                warped_view_feature = tf_transform_homography(view_towers[view].get_output(), homography)
                ave_feature = ave_feature + warped_view_feature
                ave_feature2 = ave_feature2 + tf.square(warped_view_feature)
            ave_feature = ave_feature / FLAGS.view_num
            ave_feature2 = ave_feature2 / FLAGS.view_num 
            cost = ave_feature2 - tf.square(ave_feature) 
            
            # gru
            reg_cost1, state1 = conv_gru1(-cost, state1, scope='conv_gru1')
            reg_cost2, state2 = conv_gru2(reg_cost1, state2, scope='conv_gru2')
            reg_cost3, state3 = conv_gru3(reg_cost2, state3, scope='conv_gru3')
            reg_cost = tf.layers.conv2d(
                reg_cost3, 1, 3, padding='same', reuse=tf.AUTO_REUSE, name='prob_conv')
            depth_costs.append(reg_cost)
            
        prob_volume = tf.stack(depth_costs, axis=1)
        prob_volume = tf.nn.softmax(prob_volume, axis=1, name='prob_volume')

    return prob_volume

def inference_winner_take_all(images, cams, depth_num, depth_start, depth_end, 
                              is_master_gpu=True, reg_type='GRU', inverse_depth=False):
    """ infer the depth map and its confidence (winner-take-all over depth) from multi-view images and cameras """

    if not inverse_depth:
        depth_interval = (depth_end - depth_start) / (tf.cast(depth_num, tf.float32) - 1)

    # reference image
    ref_image = tf.squeeze(tf.slice(images, [0, 0, 0, 0, 0], [-1, 1, -1, -1, 3]), axis=1)
    ref_cam = tf.squeeze(tf.slice(cams, [0, 0, 0, 0, 0], [-1, 1, 2, 4, 4]), axis=1)

    # image feature extraction    
    if is_master_gpu:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=False)
    else:
        ref_tower = UNetDS2GN({'data': ref_image}, is_training=True, reuse=True)
    view_towers = []
    for view in range(1, FLAGS.view_num):
        view_image = tf.squeeze(tf.slice(images, [0, view, 0, 0, 0], [-1, 1, -1, -1, -1]), axis=1)
        view_tower = UNetDS2GN({'data': view_image}, is_training=True, reuse=True)
        view_towers.append(view_tower)

    # get all homographies
    view_homographies = []
    for view in range(1, FLAGS.view_num):
        view_cam = tf.squeeze(tf.slice(cams, [0, view, 0, 0, 0], [-1, 1, 2, 4, 4]), axis=1)
        if inverse_depth:
            homographies = get_homographies_inv_depth(ref_cam, view_cam, depth_num=depth_num,
                                depth_start=depth_start, depth_end=depth_end)
        else:
            homographies = get_homographies(ref_cam, view_cam, depth_num=depth_num,
                                            depth_start=depth_start, depth_interval=depth_interval)
        view_homographies.append(homographies)

    # gru unit
    gru1_filters = 16
    gru2_filters = 4
    gru3_filters = 2
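    # feature maps are 1/4 of the input resolution; integer division keeps the shape entries ints
    feature_shape = [FLAGS.batch_size, FLAGS.max_h // 4, FLAGS.max_w // 4, 32]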
    gru_input_shape = [feature_shape[1], feature_shape[2]]
    state1 = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], gru1_filters])
    state2 = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], gru2_filters])
    state3 = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], gru3_filters])
    conv_gru1 = ConvGRUCell(shape=gru_input_shape, kernel=[3, 3], filters=gru1_filters)
    conv_gru2 = ConvGRUCell(shape=gru_input_shape, kernel=[3, 3], filters=gru2_filters)
    conv_gru3 = ConvGRUCell(shape=gru_input_shape, kernel=[3, 3], filters=gru3_filters)

    # initialize variables
    exp_sum = tf.Variable(tf.zeros(
        [FLAGS.batch_size, feature_shape[1], feature_shape[2], 1]),
        name='exp_sum', trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
    depth_image = tf.Variable(tf.zeros(
        [FLAGS.batch_size, feature_shape[1], feature_shape[2], 1]),
        name='depth_image', trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
    max_prob_image = tf.Variable(tf.zeros(
        [FLAGS.batch_size, feature_shape[1], feature_shape[2], 1]),
        name='max_prob_image', trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
    init_map = tf.zeros([FLAGS.batch_size, feature_shape[1], feature_shape[2], 1])

    # define winner take all loop
    def body(depth_index, state1, state2, state3, depth_image, max_prob_image, exp_sum, incre):
        """Loop body."""

        # calculate cost 
        ave_feature = ref_tower.get_output()
        ave_feature2 = tf.square(ref_tower.get_output())
        for view in range(0, FLAGS.view_num - 1):
            homographies = view_homographies[view]
            homographies = tf.transpose(homographies, perm=[1, 0, 2, 3])
            homography = homographies[depth_index]
            # warped_view_feature = homography_warping(view_towers[view].get_output(), homography)
            warped_view_feature = tf_transform_homography(view_towers[view].get_output(), homography)
            ave_feature = ave_feature + warped_view_feature
            ave_feature2 = ave_feature2 + tf.square(warped_view_feature)
        ave_feature = ave_feature / FLAGS.view_num
        ave_feature2 = ave_feature2 / FLAGS.view_num
        cost = ave_feature2 - tf.square(ave_feature)
        cost.set_shape([FLAGS.batch_size, feature_shape[1], feature_shape[2], 32])

        # gru
        reg_cost1, state1 = conv_gru1(-cost, state1, scope='conv_gru1')
        reg_cost2, state2 = conv_gru2(reg_cost1, state2, scope='conv_gru2')
        reg_cost3, state3 = conv_gru3(reg_cost2, state3, scope='conv_gru3')
        reg_cost = tf.layers.conv2d(
            reg_cost3, 1, 3, padding='same', reuse=tf.AUTO_REUSE, name='prob_conv')
        prob = tf.exp(reg_cost)

        # index
        d_idx = tf.cast(depth_index, tf.float32) 
        if inverse_depth:
            inv_depth_start = tf.div(1.0, depth_start)
            inv_depth_end = tf.div(1.0, depth_end)
            inv_interval = (inv_depth_start - inv_depth_end) / (tf.cast(depth_num, 'float32') - 1)
            inv_depth = inv_depth_start - d_idx * inv_interval
            depth = tf.div(1.0, inv_depth)
        else:
            depth = depth_start + d_idx * depth_interval
        temp_depth_image = tf.reshape(depth, [FLAGS.batch_size, 1, 1, 1])
        temp_depth_image = tf.tile(
            temp_depth_image, [1, feature_shape[1], feature_shape[2], 1])

        # update the best
        update_flag_image = tf.cast(tf.less(max_prob_image, prob), dtype='float32')
        new_max_prob_image = update_flag_image * prob + (1 - update_flag_image) * max_prob_image
        new_depth_image = update_flag_image * temp_depth_image + (1 - update_flag_image) * depth_image
        max_prob_image = tf.assign(max_prob_image, new_max_prob_image)
        depth_image = tf.assign(depth_image, new_depth_image)

        # update counter
        exp_sum = tf.assign_add(exp_sum, prob)
        depth_index = tf.add(depth_index, incre)

        return depth_index, state1, state2, state3, depth_image, max_prob_image, exp_sum, incre
    
    # run forward loop
    exp_sum = tf.assign(exp_sum, init_map)
    depth_image = tf.assign(depth_image, init_map)
    max_prob_image = tf.assign(max_prob_image, init_map)
    depth_index = tf.constant(0)
    incre = tf.constant(1)
    cond = lambda depth_index, *_: tf.less(depth_index, depth_num)
    _, state1, state2, state3, depth_image, max_prob_image, exp_sum, incre = tf.while_loop(
        cond, body,
        [depth_index, state1, state2, state3, depth_image, max_prob_image, exp_sum, incre],
        back_prop=False, parallel_iterations=1)

    # get output
    forward_exp_sum = exp_sum + 1e-7
    forward_depth_map = depth_image
    return forward_depth_map, max_prob_image / forward_exp_sum

def depth_refine(init_depth_map, image, depth_num, depth_start, depth_interval, is_master_gpu=True):
    """ refine depth image with the image """

    # normalization parameters
    depth_shape = tf.shape(init_depth_map)
    depth_end = depth_start + (tf.cast(depth_num, tf.float32) - 1) * depth_interval
    depth_start_mat = tf.tile(tf.reshape(
        depth_start, [depth_shape[0], 1, 1, 1]), [1, depth_shape[1], depth_shape[2], 1])
    depth_end_mat = tf.tile(tf.reshape(
        depth_end, [depth_shape[0], 1, 1, 1]), [1, depth_shape[1], depth_shape[2], 1])
    depth_scale_mat = depth_end_mat - depth_start_mat

    # normalize depth map (to 0~1)
    init_norm_depth_map = tf.div(init_depth_map - depth_start_mat, depth_scale_mat)

    # resize the (normalized) input image to the same size as the depth map
    resized_image = tf.image.resize_bilinear(image, [depth_shape[1], depth_shape[2]])

    # refinement network
    if is_master_gpu:
        norm_depth_tower = RefineNet({'color_image': resized_image, 'depth_image': init_norm_depth_map},
                                        is_training=True, reuse=False)
    else:
        norm_depth_tower = RefineNet({'color_image': resized_image, 'depth_image': init_norm_depth_map},
                                        is_training=True, reuse=True)
    norm_depth_map = norm_depth_tower.get_output()

    # denormalize depth map
    refined_depth_map = tf.multiply(norm_depth_map, depth_scale_mat) + depth_start_mat

    return refined_depth_map
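The `get_homographies` helper called in the code above is not defined in this excerpt. MVSNet builds its cost volume by differentiable homography warping, one plane per depth hypothesis. The NumPy sketch below is one standard way to write that plane-induced homography, `H(d) = K_src R_src (I - (C_src - C_ref) n^T / d) R_ref^T K_ref^{-1}`, and is not the repository's implementation. It assumes world-to-camera extrinsics `X_cam = R X_world + t` (so the camera center is `C = -R^T t`) and fronto-parallel planes perpendicular to the reference principal axis `n`; `plane_sweep_homography` and `project` are hypothetical names.

```python
import numpy as np


def plane_sweep_homography(K_ref, R_ref, t_ref, K_src, R_src, t_src, depth):
    """3x3 homography mapping reference pixels to source pixels through the
    fronto-parallel plane `depth` units in front of the reference camera.

    Assumes world-to-camera extrinsics X_cam = R @ X_world + t, so each
    camera center is C = -R^T t.
    """
    c_ref = -R_ref.T @ t_ref
    c_src = -R_src.T @ t_src
    n = R_ref[2:3, :]                          # reference principal axis in world frame (1x3)
    baseline = (c_src - c_ref).reshape(3, 1)   # reference-to-source center offset (3x1)
    middle = np.eye(3) - baseline @ n / depth
    return K_src @ R_src @ middle @ R_ref.T @ np.linalg.inv(K_ref)


def project(K, R, t, X_world):
    """Pinhole projection of a 3D world point to pixel coordinates."""
    x = K @ (R @ X_world + t)
    return x[:2] / x[2]
```

For a 3D point lying at depth `d` in front of the reference camera, applying `H(d)` to its homogeneous reference pixel and dehomogenizing gives the same pixel as projecting the point into the source camera directly, which is why warping source features with `H(d)` aligns them to the reference view at that depth.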
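All three cost-volume loops above compute the same variance-based metric, `C = (1/N) * sum_i (V_i - mean(V))^2`, in the one-pass form `E[V^2] - E[V]^2`; this is what lets MVSNet map an arbitrary number N of views to a single cost feature. A minimal NumPy sketch for one depth hypothesis, assuming the source features have already been warped into the reference view (the function name is hypothetical):

```python
import numpy as np


def variance_cost(ref_feature, warped_src_features):
    """Per-element variance of the features across all N views.

    ref_feature: (H, W, C) reference feature map.
    warped_src_features: list of (H, W, C) source features, already warped
    to the same depth hypothesis.
    Uses Var = E[x^2] - E[x]^2, the one-pass form used in the TF loops.
    """
    feats = [ref_feature] + list(warped_src_features)
    n = len(feats)
    mean = sum(feats) / n
    mean_sq = sum(f * f for f in feats) / n
    return mean_sq - mean * mean
```

Because the metric is a symmetric mean over views, adding or dropping a source view changes only N, not the shape of the cost fed to the regularization network.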
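The `soft_arg_min` block turns the regularized cost volume into a depth map by taking the expected depth under `softmax(-cost)` along the depth axis. A NumPy sketch of the same computation for one sample, assuming a `(D, H, W)` cost array and uniformly spaced hypotheses `depth_start + k * depth_interval` (equivalent to the `tf.linspace(depth_start, depth_end, depth_num)` call); subtracting the per-pixel maximum before `exp` is a standard numerical-stability step that leaves the softmax unchanged. The function name is hypothetical.

```python
import numpy as np


def soft_argmin_depth(cost_volume, depth_start, depth_interval):
    """cost_volume: (D, H, W) regularized costs, lower = more likely.

    Returns (depth_map, prob_volume): the expected depth per pixel under
    softmax(-cost) over the depth axis, and the probability volume itself.
    """
    num_depths = cost_volume.shape[0]
    logits = -cost_volume
    logits = logits - logits.max(axis=0, keepdims=True)  # numerical stability
    prob = np.exp(logits)
    prob /= prob.sum(axis=0, keepdims=True)
    depths = depth_start + depth_interval * np.arange(num_depths)  # (D,)
    depth_map = (prob * depths[:, None, None]).sum(axis=0)
    return depth_map, prob
```

Unlike a hard argmax, the expectation is differentiable, so the depth loss can be back-propagated through it, and it can return values between two sampled hypotheses.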
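`inference_winner_take_all` never materializes the `(D, H, W)` probability volume: it walks the depth hypotheses inside `tf.while_loop`, keeps the running maximum of `exp(reg_cost)` together with its depth, and accumulates the sum of the `exp` values, so the returned confidence is the maximum softmax probability. The NumPy sketch below reproduces that streaming idea under hypothetical names. It swaps in the online-softmax rescaling (tracking a running maximum score) so that `exp` cannot overflow, which the TF loop does not do.

```python
import numpy as np


def streaming_winner_take_all(score_slices, depth_values):
    """Winner-take-all depth and confidence without storing the full volume.

    score_slices: iterable of (H, W) score maps, one per depth hypothesis
    (higher = more likely; the TF code scores with the regularized -cost).
    depth_values: depth of each hypothesis, in the same order.
    Returns (depth_map, confidence), with confidence = max softmax probability.
    Assumes at least one hypothesis.
    """
    run_max = run_sum = depth_map = None
    for score, depth in zip(score_slices, depth_values):
        if run_max is None:
            run_max = score.astype(np.float64)
            run_sum = np.ones_like(run_max)            # exp(score - run_max) = 1
            depth_map = np.full_like(run_max, depth)
            continue
        better = score > run_max                       # strict: earlier depth wins ties
        new_max = np.where(better, score, run_max)
        # rescale the running sum of exp to the new maximum (online softmax)
        run_sum = run_sum * np.exp(run_max - new_max) + np.exp(score - new_max)
        depth_map = np.where(better, depth, depth_map)
        run_max = new_max
    return depth_map, 1.0 / run_sum
```

With `inverse_depth=True` the TF code samples hypotheses uniformly in inverse depth; the same spacing can be passed here as `depth_values = 1.0 / np.linspace(1.0 / depth_start, 1.0 / depth_end, depth_num)`.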
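`get_propability_map` is also not defined in this excerpt. The MVSNet paper measures estimation quality as the probability sum over the four depth hypotheses nearest the estimate. A hedged NumPy sketch with a hypothetical name; the window `floor(k) - 1` to `floor(k) + 2` around the continuous index `k = (depth - depth_start) / depth_interval`, and skipping out-of-range hypotheses, are both assumptions rather than details taken from the repository.

```python
import numpy as np


def probability_map(prob_volume, depth_map, depth_start, depth_interval):
    """Sum of probabilities over four hypotheses around each depth estimate.

    prob_volume: (D, H, W) softmax probabilities over depth.
    depth_map: (H, W) estimated depths.
    Assumed window: floor(k)-1 .. floor(k)+2; out-of-range entries add 0.
    """
    num_depths = prob_volume.shape[0]
    idx = np.floor((depth_map - depth_start) / depth_interval).astype(int)
    rows, cols = np.indices(depth_map.shape)
    total = np.zeros_like(depth_map, dtype=float)
    for offset in (-1, 0, 1, 2):
        k = idx + offset
        valid = (k >= 0) & (k < num_depths)
        k_safe = np.clip(k, 0, num_depths - 1)   # only used to index safely
        total += np.where(valid, prob_volume[k_safe, rows, cols], 0.0)
    return total
```

A sharp, unimodal distribution concentrates its mass inside this window and yields a value near 1, while a flat or multi-modal one yields a low value, so the map can be thresholded to filter unreliable depths.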
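`depth_refine` rescales the initial depth map to [0, 1] with each sample's depth range before `RefineNet` and maps the network output back afterwards, so the refinement is not biased toward a particular depth scale. The sketch below shows the same normalize/denormalize pair in NumPy; reshaping the per-sample `depth_start` and `depth_interval` to `(B, 1, 1, 1)` lets broadcasting stand in for the `tf.tile` calls. The helper names are hypothetical and `RefineNet` itself is not modeled.

```python
import numpy as np


def depth_range_per_sample(depth_start, depth_interval, depth_num):
    """Per-sample (start, end) reshaped to (B, 1, 1, 1) for broadcasting over (B, H, W, 1)."""
    start = np.asarray(depth_start, dtype=float).reshape(-1, 1, 1, 1)
    interval = np.asarray(depth_interval, dtype=float).reshape(-1, 1, 1, 1)
    end = start + (depth_num - 1) * interval
    return start, end


def normalize_depth(depth, start, end):
    """Map depths in [start, end] to [0, 1]."""
    return (depth - start) / (end - start)


def denormalize_depth(norm_depth, start, end):
    """Inverse of normalize_depth."""
    return norm_depth * (end - start) + start
```

In use, an initial depth map of shape `(B, H, W, 1)` goes through `normalize_depth`, the refinement network operates in [0, 1], and `denormalize_depth` restores metric depth with the same per-sample range.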

Reference