A question about SparseGate before the ConvBNReLU block in MobileNetV2
cys4 opened this issue · 8 comments
Thank you for sharing this great work!
I have a question about pruning the MobileNetV2.
After following the 3-step instructions, I got a checkpoint that has a SparseGate (input_gate) between 'select' and 'conv' within every 'InvertedResidual' block, like:
(17): InvertedResidual(
(select): ChannelSelect(channel_num=160)
(input_gate): SparseGate(channel_num=160)
(conv): Sequential(
(0): ConvBNReLU(
(0): Conv2d(160, 774, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(774, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): ConvBNReLU(
(0): Conv2d(774, 774, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=774, bias=False)
(1): BatchNorm2d(774, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(774, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): Identity()
(5): ChannelExpand(channel_num=320)
)
)
But I cannot see the SparseGate (input_gate) in the pruned checkpoint you provided.
I wonder whether you performed post-processing, such as fusing the weights of the SparseGate into the following convolution, to get these checkpoints, or whether I missed something needed to obtain the final checkpoint.
Thank you.
Thank you for your interest in our work.
As you mentioned, our pruning pipeline has three stages: training, pruning, and fine-tuning. The provided checkpoint is a fine-tuned checkpoint. We try not to introduce extra parameters into the fine-tuned model for a fair comparison, so there is no SparseGate in the model during the fine-tuning stage.
During pruning, the SparseGate is replaced by the scaling factors and bias of the BN layers:
PolarizationPruning/imagenet/models/common.py
Lines 127 to 128 in f19f3fd
The related code is here:
PolarizationPruning/imagenet/models/common.py
Lines 246 to 255 in f19f3fd
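The arithmetic behind absorbing a gate into a preceding BN layer is just per-channel rescaling of the BN affine parameters. A minimal numpy sketch (illustrative only, not the repository's actual code; `fuse_gate_into_bn` and all variable names are hypothetical):

```python
import numpy as np

def fuse_gate_into_bn(gamma, beta, gate):
    # g * (gamma * x_hat + beta) == (g * gamma) * x_hat + (g * beta),
    # so the gate factor g can be folded into BN's weight and bias.
    return gamma * gate, beta * gate

gamma = np.array([1.0, 2.0, 0.5])   # BN weight (scaling factors)
beta = np.array([0.1, -0.2, 0.3])   # BN bias
gate = np.array([1.0, 0.0, 0.7])    # a zero gate prunes that channel outright
x_hat = np.array([0.5, 1.5, -1.0])  # normalized activations

fused_gamma, fused_beta = fuse_gate_into_bn(gamma, beta, gate)
# Gating after BN equals BN with fused parameters:
assert np.allclose(gate * (gamma * x_hat + beta),
                   fused_gamma * x_hat + fused_beta)
```

Channels whose fused scaling factor is zero can then be removed entirely, which is what makes the gate-free fine-tuned checkpoint possible.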
Thank you for the answer.
Does the code affect the 'SparseGate' right after 'ChannelSelect' at the beginning of 'InvertedResidual'?
It looks like it only sets the 'SparseGate' right after 'BatchNorm2d' to one.
Of course, I guess I can fuse the weights of the 'SparseGate' at the beginning into the following 'Conv2d' in a similar way.
You are right. The code I mentioned above only sets the SparseGate right after BatchNorm2d to one. We did not remove the SparseGate at the beginning of InvertedResidual, since there is no BatchNorm2d layer there.
Fusing the weights of the SparseGate at the beginning into the following Conv2d in a similar way is a good idea, and I think it's worth trying. Personally, I suppose removing the SparseGate or not will make little difference.
Thank you for the clarification!
@cys4 Sorry, I missed one thing! Our released checkpoints did not enable the input_gate option in the sparsity training, so there are no SparseGate layers at the beginning of the ConvBNReLU. If you enable input_gate in the sparsity training stage, a SparseGate will remain in the final fine-tuned checkpoint.
The current implementation keeps the input gate at the fine-tuning stage. As you mentioned, the SparseGate layer in the fine-tuned model can be fused with the first conv layer in the block.
Sorry for the confusion. I hope this clarification helps!
No problem!
The 'input_gate' you mentioned does not seem to be an explicit command-line option; it is hardcoded in the current code.
I think I should remove the relevant code, for example the following two lines, to disable 'input_gate':
if use_gate or input_mask:
self.input_gate = SparseGate(conv_in)
Am I right?
As you said, fusing with the following conv in the fine-tuning stage would be another option.
To disable input gates
I suppose the following will work:
- Change this part:
PolarizationPruning/imagenet/models/mobilenet.py
Lines 219 to 221 in f19f3fd
to
if self.pw:
    if use_gate and input_mask:
        # only enable the input_gate when the input_mask option is True
        self.input_gate = SparseGate(conv_in)
- Add a command line option here:
PolarizationPruning/imagenet/main.py
Lines 194 to 195 in f19f3fd
parser.add_argument('--input-gate-mbv2', action='store_true',
                    help="Use an extra input gate at the beginning of each block of MobileNet v2.")
and use it at
PolarizationPruning/imagenet/main.py
Lines 426 to 428 in f19f3fd
model = mobilenet_v2(inverted_residual_setting=refine_checkpoint['cfg'],
                     width_mult=args.width_multiplier,
                     use_gate=args.gate,
                     input_mask=args.input_gate_mbv2)
After that, you can use --input-gate-mbv2
as a command line option.
How input gates work
The input gates at the beginning of each block make it possible to prune the input channels of the first convolution layer in each block, which further reduces the FLOPs of the model. Without input gates, the input channels of the first convolution layer in each block are kept. The relevant code can be seen here:
PolarizationPruning/imagenet/models/common.py
Lines 177 to 187 in f19f3fd
The in_channel_mask will be None if there is no input gate, as shown here:
PolarizationPruning/imagenet/models/mobilenet.py
Lines 282 to 289 in f19f3fd
We did not enable the input gates when training the released checkpoints, so you can see that the input channels of the first conv layer in each block are unpruned. I think this could be improved by enabling the input gates.
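The channel pruning enabled by an input gate can be sketched in a few lines of numpy (illustrative only; the mask name mirrors the in_channel_mask idea but the code is not the repository's implementation): zeroed gate channels are dropped from both the input and the weight's input-channel axis, leaving the output unchanged.

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.standard_normal((4, 3))        # 1x1 conv as (out_channels, in_channels)
gate = np.array([1.0, 0.0, 0.5])       # input gate; channel 1 is zeroed
x = rng.standard_normal(3)

keep = gate != 0                        # analogue of in_channel_mask
W_pruned = (W * gate)[:, keep]          # fuse gate, drop dead input channels
x_pruned = x[keep]

# The pruned conv computes the same output with fewer channels (fewer FLOPs):
assert np.allclose(W @ (gate * x), W_pruned @ x_pruned)
```

Without the gate there is nothing marking input channels as removable, which is why the first conv's input channels stay unpruned in the released checkpoints.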
Hope that helps! Let me know if there is any question.
It definitely helps.
Things are clearer now.
Thank you very much.