polarizationpruning/PolarizationPruning

A question about SparseGate before the ConvBNReLU block in MobileNetV2

cys4 opened this issue · 8 comments

cys4 commented

Thank you for sharing this great work!

I have a question about pruning the MobileNetV2.
After following the 3-step instructions, I got a checkpoint that has a 'SparseGate' ('input_gate') between 'select' and 'conv' within every 'InvertedResidual' block, like:

(17): InvertedResidual(
  (select): ChannelSelect(channel_num=160)
  (input_gate): SparseGate(channel_num=160)
  (conv): Sequential(
    (0): ConvBNReLU(
      (0): Conv2d(160, 774, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(774, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): ConvBNReLU(
      (0): Conv2d(774, 774, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=774, bias=False)
      (1): BatchNorm2d(774, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (2): Conv2d(774, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Identity()
    (5): ChannelExpand(channel_num=320)
  )
)

But, I cannot see the SparseGate (input_gate) in the pruned checkpoint you provided.

I wonder whether you performed post-processing, such as fusing the weights of the SparseGate into the following convolution, to obtain these checkpoints, or whether I missed a step needed to produce the final checkpoint.

Thank you.

Thank you for your interest in our work.

As you mentioned, our pruning pipeline has three stages: training, pruning, and fine-tuning. The provided checkpoint is a fine-tuned checkpoint. We try not to introduce extra parameters into the fine-tuned model, to keep the comparison fair, so there is no SparseGate in the model during the fine-tuning stage.

During pruning, the SparseGate is replaced by the scaling factor and bias of the BatchNorm layers:

Note: if the sparse_layer is SparseGate, the gate will be replaced by BatchNorm
scaling factor. The value of the gate will be set to all ones.

The related code is here:

# prune the gate
if isinstance(sparse_layer_out, SparseGate):
    sparse_layer_out.prune(idx_out)
    # multiply the bn weight and SparseGate weight
    sparse_weight_out: torch.Tensor = sparse_layer_out.weight.view(-1)
    bn_layer.weight.data = (bn_layer.weight.data * sparse_weight_out).clone()
    bn_layer.bias.data = (bn_layer.bias.data * sparse_weight_out).clone()
    # the function of the SparseGate is now replaced by bn layers
    # the SparseGate should be disabled
    sparse_layer_out.set_ones()
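
The fusion works because a per-channel gate applied after a BatchNorm layer is equivalent to scaling the BatchNorm weight and bias by the gate values. A minimal standalone sketch of this equivalence (not code from the repository; the gate tensor stands in for the SparseGate weights):

import torch
import torch.nn as nn

torch.manual_seed(0)

channels = 8
bn = nn.BatchNorm2d(channels).eval()   # use running statistics, as at inference time
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)
gate = torch.rand(channels)            # stands in for the SparseGate weights

x = torch.randn(2, channels, 4, 4)

# reference: BatchNorm followed by a per-channel gate
ref = bn(x) * gate.view(1, -1, 1, 1)

# fused: fold the gate into the BatchNorm scaling factor and bias
fused_bn = nn.BatchNorm2d(channels).eval()
fused_bn.load_state_dict(bn.state_dict())
fused_bn.weight.data *= gate
fused_bn.bias.data *= gate

print(torch.allclose(ref, fused_bn(x), atol=1e-6))  # True: the gate is absorbed into the BN layer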

cys4 commented

Thank you for the answer.
Does the code affect the 'SparseGate' right after 'ChannelSelect' at the beginning of 'InvertedResidual'?
It looks like it only sets the 'SparseGate' right after 'BatchNorm2d' to ones.
Of course, I guess I can fuse the weights of the 'SparseGate' at the beginning into the following 'Conv2d' in a similar way.

You are right. The code I mentioned above only sets the SparseGate right after BatchNorm2d to ones. We did not remove the SparseGate at the beginning of InvertedResidual, since there is no BatchNorm2d layer there to fuse it into.

Fusing the weights of the SparseGate at the beginning into the following Conv2d in a similar way is a good idea, and I think it's worth trying. Personally, I suppose it will make little difference whether the SparseGate is removed or not.
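
For reference, a minimal sketch of that fusion (not code from the repository; the gate tensor stands in for the input SparseGate weights). Because there is no nonlinearity between the gate and the first 1x1 convolution, the gate values can be folded into the input-channel dimension of the convolution weight:

import torch
import torch.nn as nn

torch.manual_seed(0)

in_ch, out_ch = 160, 774
conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
gate = torch.rand(in_ch)               # stands in for the input SparseGate weights

x = torch.randn(1, in_ch, 7, 7)

# reference: gate the input channels, then convolve
ref = conv(x * gate.view(1, -1, 1, 1))

# fused: scale the conv weight along its input-channel dimension instead
fused_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
fused_conv.weight.data = conv.weight.data * gate.view(1, -1, 1, 1)

print(torch.allclose(ref, fused_conv(x), atol=1e-4))  # True: the input gate folds into the conv weight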

cys4 commented

Thank you for the clarification!

@cys4 Sorry, I missed one thing! Our released checkpoints did not enable the input_gate option during sparsity training, so there are no SparseGate layers before the ConvBNReLU. If you run the sparsity training stage with input_gate enabled, there will be a remaining SparseGate in the final fine-tuned checkpoint.

The current implementation keeps the input gate at the fine-tuning stage. As you mentioned, the SparseGate layer in the fine-tuned model can be fused with the first conv layer in the block.

Sorry for the confusion. I hope this clarification helps!

cys4 commented

No problem!
The 'input_gate' you mentioned does not seem to be an explicit command-line option, but is hardcoded in the current code.
I think I should remove the relevant code, for example the following two lines, to disable 'input_gate'.

                if use_gate or input_mask:
                    self.input_gate = SparseGate(conv_in) 

Am I right?
As you said, fusing with the following conv in the fine-tuning stage would be another option.

To disable input gates

I suppose that will work:

  1. Change this part:

                if self.pw:
                    if use_gate or input_mask:
                        self.input_gate = SparseGate(conv_in)

    to

                if self.pw:
                    if use_gate and input_mask:
                        # only enable the input_gate when input_mask option is True
                        self.input_gate = SparseGate(conv_in)
  2. Add a command line option here:

    parser.add_argument('--target-flops', type=float, default=None,
                        help='Stop when pruned model archive the target FLOPs')

    parser.add_argument('--input-gate-mbv2', action='store_true',
                        help="Use an extra input gate at the beginning of each blocks of MobileNet v2.")

    and use it at

                model = mobilenet_v2(inverted_residual_setting=refine_checkpoint['cfg'],
                                     width_mult=args.width_multiplier,
                                     use_gate=args.gate, input_mask=args.gate)

    changing it to

                model = mobilenet_v2(inverted_residual_setting=refine_checkpoint['cfg'],
                                     width_mult=args.width_multiplier,
                                     use_gate=args.gate, input_mask=args.input_gate_mbv2)

After that, you can use --input-gate-mbv2 as a command line option.

How input gates work

The input gates at the beginning of each block make it possible to prune the input channels of the first convolution layer in each block, which further reduces the FLOPs of the model. Without input gates, the input channels of the first convolution layer in each block are kept. The relevant code can be seen here:

if in_channel_mask is not None:
    # prune the input channel according to the in_channel_mask
    # convert mask to channel indexes
    idx_in = np.squeeze(np.argwhere(np.asarray(in_channel_mask)))
    if len(idx_in.shape) == 0:
        # expand the single scalar to array
        idx_in = np.expand_dims(idx_in, 0)
    elif len(idx_in.shape) == 1 and idx_in.shape[0] == 0:
        # nothing left, prune the whole block
        out_channel_mask = np.full(conv_layer.out_channels, False)
        return in_channel_mask, out_channel_mask
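
To make the mask handling concrete, here is a small standalone sketch (hypothetical shapes, not code from the repository) of how a boolean in_channel_mask is converted to channel indexes and used to drop input channels from a convolution weight:

import numpy as np
import torch
import torch.nn as nn

# boolean mask over the input channels: True = keep, False = prune
in_channel_mask = np.array([True, False, True, False, False, True])

# convert the mask to channel indexes, as in prune_conv_layer
idx_in = np.squeeze(np.argwhere(np.asarray(in_channel_mask)))
if len(idx_in.shape) == 0:
    # a single remaining channel: expand the scalar to an array
    idx_in = np.expand_dims(idx_in, 0)

conv = nn.Conv2d(in_channels=6, out_channels=4, kernel_size=1, bias=False)

# keep only the selected input channels of the weight tensor
pruned_weight = conv.weight.data[:, idx_in.tolist(), :, :]
print(idx_in)               # [0 2 5]
print(pruned_weight.shape)  # torch.Size([4, 3, 1, 1])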

The in_channel_mask will be None if there is no input gate, as shown here:

in_channel_mask, input_gate_mask = prune_conv_layer(conv_layer=pw_layer[0],
                                                    bn_layer=pw_layer[1],
                                                    sparse_layer_in=self.input_gate if self.has_input_mask else None,
                                                    sparse_layer_out=pw_layer.sparse_layer,
                                                    in_channel_mask=None if self.has_input_mask else in_channel_mask,
                                                    pruner=pruner,
                                                    prune_output_mode="prune",
                                                    prune_mode='default')

We did not enable the input gates when training the released checkpoints, so you can see that the input channels of the first conv layer in each block are unpruned. I think that could be improved by enabling the input gates.

Hope that helps! Let me know if you have any questions.

cys4 commented

It definitely helps.
Things are clearer now.

Thank you very much.