uber-research/sbnet

Is there any way to reduce computation time of sbnet_module.reduce_mask?


I have built a yolov2 object detector using sbnet, but computing sbnet_module.reduce_mask takes too long.
I need to compute sbnet_module.reduce_mask every frame because the mask changes every frame.

# yolov2 with dense convnet
Forwarding 1 inputs ...
Forwarding time = 0.0349409580231 sec

# yolov2 with sbnet (sparsity = 0.92)
Forwarding 1 inputs ...
Forwarding time = 0.0198512077332 sec
 + time(sbnet_module.reduce_mask) = 0.0325801372528 sec

When I applied sbnet to the yolov2 (darknet) model, forwarding was about 1.7 times faster. However, computing the reduce_mask results, which are needed to perform sparse_gather and sparse_scatter, took longer than expected.

Below is my code to compute reduce_mask for conv1s ~ conv5.
Currently it takes 0.03 sec to execute, which is too slow given that the detector's forwarding time is itself almost 0.03 sec.
Is there a faster way to compute sbnet_module.reduce_mask?

# compute block_params for all different size in conv1s ~ conv5
block_params_p0_k3 = calc_block_params([1, 416, 864, None],
                                       [1, 34, 34, 1],
                                       [3, 3, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p1_k3 = calc_block_params([1, 208, 432, None],
                                       [1, 18, 18, 1],
                                       [3, 3, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p2_k3 = calc_block_params([1, 104, 216, None],
                                       [1, 10, 10, 1],
                                       [3, 3, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p2_k1 = calc_block_params([1, 104, 216, None],
                                       [1, 8, 8, 1],
                                       [1, 1, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p3_k3 = calc_block_params([1, 52, 108, None],
                                       [1, 6, 6, 1],
                                       [3, 3, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p3_k1 = calc_block_params([1, 52, 108, None],
                                       [1, 4, 4, 1],
                                       [1, 1, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p4_k3 = calc_block_params([1, 26, 54, None],
                                       [1, 4, 4, 1],
                                       [3, 3, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p4_k1 = calc_block_params([1, 26, 54, None],
                                       [1, 2, 2, 1],
                                       [1, 1, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p5_k3 = calc_block_params([1, 13, 27, None],
                                       [1, 3, 3, 1],
                                       [3, 3, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')
block_params_p5_k1 = calc_block_params([1, 13, 27, None],
                                       [1, 1, 1, 1],
                                       [1, 1, 1, 1],
                                       [1, 1, 1, 1],
                                       padding='VALID')

# compute random binary mask depending on bndbox for conv1s ~ conv5
mask_p0 = np.zeros([1, 416, 864], dtype=np.float32)
if bndbox:
    for bbox in bndbox:
        xmin = max(0, int(round(bbox[0] * 416)))
        ymin = max(0, int(round(bbox[1] * 864)))
        xmax = min(int(round(bbox[2] * 416)), 415)
        ymax = min(int(round(bbox[3] * 864)), 863)
        mask_p0[:, xmin:xmax, ymin:ymax] = 1.0

mask_p5 = block_reduce(mask_p0, (1, 32, 32), np.max)
mask_p4 = mask_p5.repeat(2, axis=1).repeat(2, axis=2)
mask_p3 = mask_p4.repeat(2, axis=1).repeat(2, axis=2)
mask_p2 = mask_p3.repeat(2, axis=1).repeat(2, axis=2)
mask_p1 = mask_p2.repeat(2, axis=1).repeat(2, axis=2)
mask_p0 = mask_p1.repeat(2, axis=1).repeat(2, axis=2)

with tf.Graph().as_default():
    mask_p0_tf = tf.constant(mask_p0, dtype=tf.float32)
    mask_p1_tf = tf.constant(mask_p1, dtype=tf.float32)
    mask_p2_tf = tf.constant(mask_p2, dtype=tf.float32)
    mask_p3_tf = tf.constant(mask_p3, dtype=tf.float32)
    mask_p4_tf = tf.constant(mask_p4, dtype=tf.float32)
    mask_p5_tf = tf.constant(mask_p5, dtype=tf.float32)

    # compute sbnet for conv1s ~ conv5
    sbnet_p0_k3 = sbnet_module.reduce_mask(mask_p0_tf,
         tf.constant(block_params_p0_k3.bcount, dtype=tf.int32),
         bsize=block_params_p0_k3.bsize,
         boffset=block_params_p0_k3.boffset,
         bstride=block_params_p0_k3.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p1_k3 = sbnet_module.reduce_mask(mask_p1_tf,
         tf.constant(block_params_p1_k3.bcount, dtype=tf.int32),
         bsize=block_params_p1_k3.bsize,
         boffset=block_params_p1_k3.boffset,
         bstride=block_params_p1_k3.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p2_k3 = sbnet_module.reduce_mask(mask_p2_tf,
         tf.constant(block_params_p2_k3.bcount, dtype=tf.int32),
         bsize=block_params_p2_k3.bsize,
         boffset=block_params_p2_k3.boffset,
         bstride=block_params_p2_k3.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p2_k1 = sbnet_module.reduce_mask(mask_p2_tf,
         tf.constant(block_params_p2_k1.bcount, dtype=tf.int32),
         bsize=block_params_p2_k1.bsize,
         boffset=block_params_p2_k1.boffset,
         bstride=block_params_p2_k1.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p3_k3 = sbnet_module.reduce_mask(mask_p3_tf,
         tf.constant(block_params_p3_k3.bcount, dtype=tf.int32),
         bsize=block_params_p3_k3.bsize,
         boffset=block_params_p3_k3.boffset,
         bstride=block_params_p3_k3.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p3_k1 = sbnet_module.reduce_mask(mask_p3_tf,
         tf.constant(block_params_p3_k1.bcount, dtype=tf.int32),
         bsize=block_params_p3_k1.bsize,
         boffset=block_params_p3_k1.boffset,
         bstride=block_params_p3_k1.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p4_k3 = sbnet_module.reduce_mask(mask_p4_tf,
         tf.constant(block_params_p4_k3.bcount, dtype=tf.int32),
         bsize=block_params_p4_k3.bsize,
         boffset=block_params_p4_k3.boffset,
         bstride=block_params_p4_k3.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p4_k1 = sbnet_module.reduce_mask(mask_p4_tf,
         tf.constant(block_params_p4_k1.bcount, dtype=tf.int32),
         bsize=block_params_p4_k1.bsize,
         boffset=block_params_p4_k1.boffset,
         bstride=block_params_p4_k1.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p5_k3 = sbnet_module.reduce_mask(mask_p5_tf,
         tf.constant(block_params_p5_k3.bcount, dtype=tf.int32),
         bsize=block_params_p5_k3.bsize,
         boffset=block_params_p5_k3.boffset,
         bstride=block_params_p5_k3.bstrides,
         tol=0.0,
         avgpool=True)
    sbnet_p5_k1 = sbnet_module.reduce_mask(mask_p5_tf,
         tf.constant(block_params_p5_k1.bcount, dtype=tf.int32),
         bsize=block_params_p5_k1.bsize,
         boffset=block_params_p5_k1.boffset,
         bstride=block_params_p5_k1.bstrides,
         tol=0.0,
         avgpool=True)

    with tf.Session() as sess:
        ind_val_p0_k3, ind_val_p1_k3, \
        ind_val_p2_k3, ind_val_p2_k1, \
        ind_val_p3_k3, ind_val_p3_k1, \
        ind_val_p4_k3, ind_val_p4_k1, \
        ind_val_p5_k3, ind_val_p5_k1, \
        bin_val_p0_k3, bin_val_p1_k3, \
        bin_val_p2_k3, bin_val_p2_k1, \
        bin_val_p3_k3, bin_val_p3_k1, \
        bin_val_p4_k3, bin_val_p4_k1, \
        bin_val_p5_k3, bin_val_p5_k1 = \
            sess.run([sbnet_p0_k3.active_block_indices,
                      sbnet_p1_k3.active_block_indices,
                      sbnet_p2_k3.active_block_indices,
                      sbnet_p2_k1.active_block_indices,
                      sbnet_p3_k3.active_block_indices,
                      sbnet_p3_k1.active_block_indices,
                      sbnet_p4_k3.active_block_indices,
                      sbnet_p4_k1.active_block_indices,
                      sbnet_p5_k3.active_block_indices,
                      sbnet_p5_k1.active_block_indices,
                      sbnet_p0_k3.bin_counts,
                      sbnet_p1_k3.bin_counts,
                      sbnet_p2_k3.bin_counts,
                      sbnet_p2_k1.bin_counts,
                      sbnet_p3_k3.bin_counts,
                      sbnet_p3_k1.bin_counts,
                      sbnet_p4_k3.bin_counts,
                      sbnet_p4_k1.bin_counts,
                      sbnet_p5_k3.bin_counts,
                      sbnet_p5_k1.bin_counts])

# After that, these values go into feed_dict.

Generally, reduce_mask is a very inexpensive op and should be applied to a small tensor of NHW1 dimensions. In our experience it takes around 20 microseconds per call, IIRC. Running a separate session to compute the mask will add a lot of overhead, though, so I recommend keeping the reduce_mask nodes inside the same graph/session run as the main network. It's not entirely clear from your example whether you already do that, but it seems you pass the outputs to another session using feed_dict, which would incur a lot of session overhead.
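For intuition about why the op itself is cheap, here is a NumPy-only sketch of a block-wise mask reduction: pool each block, then keep the (n, y, x) indices of blocks whose pooled value exceeds `tol`. This is an illustration only, not sbnet's CUDA kernel. The function name `reduce_mask_reference` is hypothetical, the pooling-then-threshold semantics are my reading of how the op behaves, and the sketch assumes a zero block offset with blocks that fit entirely inside the mask.

```python
import numpy as np


def reduce_mask_reference(mask, bsize, bstride, tol=0.0, avgpool=True):
    """Return an (num_active, 3) array of (n, block_y, block_x) indices.

    Hypothetical reference helper: a block is "active" when its pooled
    mask value is strictly greater than tol. Assumes zero block offset.
    """
    n, h, w = mask.shape
    bh, bw = bsize
    sh, sw = bstride
    # Number of blocks that fit entirely inside the mask along each axis.
    by_count = (h - bh) // sh + 1
    bx_count = (w - bw) // sw + 1
    pool = np.mean if avgpool else np.max
    active = []
    for i in range(n):
        for by in range(by_count):
            for bx in range(bx_count):
                block = mask[i, by * sh:by * sh + bh, bx * sw:bx * sw + bw]
                if pool(block) > tol:
                    active.append((i, by, bx))
    return np.array(active, dtype=np.int16).reshape(-1, 3)


# Tiny example: an 8x8 mask split into four non-overlapping 4x4 blocks.
mask = np.zeros((1, 8, 8), dtype=np.float32)
mask[0, 0:2, 0:2] = 1.0   # touches the top-left block only
mask[0, 5, 6] = 1.0       # touches the bottom-right block only
indices = reduce_mask_reference(mask, bsize=(4, 4), bstride=(4, 4))
print(indices)
```

The work is proportional to the mask size, which is tiny compared with the convolutions themselves, so a cost of tens of milliseconds points at overhead around the op rather than at the reduction.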

Another question is how you measure the timings. You may want to check out my post here:
https://stackoverflow.com/questions/34293714/can-i-measure-the-execution-time-of-individual-operations-with-tensorflow
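Separately from per-op timing, a common pitfall with end-to-end numbers is including the first call, which usually carries one-time setup cost. A minimal sketch of a wall-clock helper with warm-up runs is below; `time_call` is a hypothetical name, and in this issue's setting `fn` would wrap a single `sess.run(...)` on an already-built graph.

```python
import time


def time_call(fn, warmup=5, iters=50):
    """Average wall-clock seconds per call of fn(), excluding warm-up runs."""
    for _ in range(warmup):
        fn()  # discard the first calls, which may include one-time setup
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters


# Stand-in workload; replace the lambda with the call you want to measure.
calls = []
avg = time_call(lambda: calls.append(sum(range(1000))), warmup=2, iters=10)
print("avg seconds per call: %.6f" % avg)
```

This only measures the whole call, so it cannot tell you which op is slow; it is mainly useful for checking whether merging the sessions changed the total.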

One possible optimization is to support a batched reduce_mask, but in the use cases we've seen that would only save on the order of tens of microseconds.

I also recommend using NVIDIA's nvprof/nvvp tools for profiling.

The time spent in sbnet_module.reduce_mask dropped a lot when I merged the sessions following your advice and passed only the random binary mask via feed_dict. Thank you :)