Linwei-Chen/FreqFusion

How can FreqFusion be used in the SegFormer decoder?

Opened this issue · 5 comments

Impressive work! Could you share how FreqFusion is implemented in the SegFormer decoder? In particular, how are the 4x and 8x upsampling steps handled? Many thanks!

Thanks for your interest! You can refer to the SegNeXt implementation, which is essentially the same:
https://github.com/Linwei-Chen/FreqFusion/blob/main/SegNeXt/mmseg/models/decode_heads/ham_head.py

As prior work (and our own experiments) have noted, SegFormer results fluctuate somewhat between runs, so you may need to try several times.
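Regarding the 4x/8x question above: each fusion step only bridges a single 2x gap between adjacent stages, so larger upsampling ratios fall out of cascading the per-stage fusions from coarse to fine. A minimal PyTorch sketch of that cascade, using a hypothetical `FusionStep` stand-in (a 1x1 conv over concatenated features, not the real FreqFusion operator):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in: in practice, replace FusionStep with FreqFusion.
class FusionStep(nn.Module):
    def __init__(self, hr_c, lr_c, out_c):
        super().__init__()
        self.conv = nn.Conv2d(hr_c + lr_c, out_c, kernel_size=1)

    def forward(self, hr_feat, lr_feat):
        # Bring the low-res feature up to the high-res spatial size.
        lr_up = F.interpolate(lr_feat, size=hr_feat.shape[-2:], mode='nearest')
        return self.conv(torch.cat([hr_feat, lr_up], dim=1))

c = 32
# Backbone-style pyramid: 1/4, 1/8, 1/16, 1/32 resolution maps.
feats = [torch.randn(1, c, s, s) for s in (64, 32, 16, 8)]
steps = nn.ModuleList(FusionStep(c, c, c) for _ in range(3))

out = feats[-1]                    # start from the coarsest map
for hr, step in zip(reversed(feats[:-1]), steps):
    out = step(hr, out)            # each step only covers a 2x gap

assert out.shape[-2:] == feats[0].shape[-2:]  # back at 1/4 resolution
```

Three chained 2x steps give the 8x total, which matches how the head below iterates over stages rather than doing one big upsample.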

I tried adding FreqFusion to the SegFormer head, replacing the upsampling with three FreqFusion modules, but it needs quite a lot of GPU memory and I get CUDA out of memory. I suspect my integration is wrong. Could you share the SegFormer-head implementation used in the paper for reference? Thanks!

Memory usage does increase. We suggest using mmcv's CARAFE to lower it; setting use_checkpoint=True also saves some memory.
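On the checkpointing suggestion: a minimal sketch of wrapping a fusion step in `torch.utils.checkpoint`, so its activations are recomputed during backward instead of being stored (the `DummyFusion` module here is a hypothetical stand-in, not the real FreqFusion):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

# Hypothetical fusion module standing in for FreqFusion.
class DummyFusion(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.conv = nn.Conv2d(2 * c, c, kernel_size=3, padding=1)

    def forward(self, hr_feat, lr_feat):
        lr_up = F.interpolate(lr_feat, size=hr_feat.shape[-2:], mode='nearest')
        return self.conv(torch.cat([hr_feat, lr_up], dim=1))

fusion = DummyFusion(16)
hr = torch.randn(1, 16, 32, 32, requires_grad=True)
lr = torch.randn(1, 16, 16, 16, requires_grad=True)

# Recompute the fusion's activations in backward instead of caching them,
# trading extra compute for lower peak memory.
out = checkpoint(fusion, hr, lr, use_reentrant=False)
out.sum().backward()
```

The same trade-off is what `use_checkpoint=True` toggles in the head below.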

import torch
import torch.nn as nn


@HEADS.register_module()
class SegformerHeadFreqFusion(SegformerHead):
    """SegFormer head with the nearest upsampling replaced by FreqFusion.

    Note: early implementation, not tested.
    """
    def __init__(self, 
                compress_ratio=8, 
                lowpass_kernel=5, 
                highpass_kernel=3, 
                lowpass_pad=0,
                highpass_pad=0,
                padding_mode='replicate',
                hamming_window=False,
                feature_align=False,
                feature_align_group=4,
                comp_feat_upsample=True,
                compressed_channel=None,
                semi_conv=True,
                hf_att=False,
                use_global_context=False,
                use_channel_att=False,
                use_dyedgeconv=False,
                **kwargs):
        super().__init__(**kwargs)
        self.freqfusions = nn.ModuleList()
        pre_c = self.channels
        self.feature_align_group = feature_align_group
        # Build one FreqFusion module per adjacent stage pair, coarse to fine.
        for idx in range(len(self.in_channels) - 1, 0, -1):
            freqfusion = FreqFusion2(
                hr_channels=self.channels,
                lr_channels=pre_c,
                scale_factor=1,
                lowpass_kernel=lowpass_kernel,
                highpass_kernel=highpass_kernel,
                lowpass_pad=lowpass_pad,
                highpass_pad=highpass_pad,
                padding_mode=padding_mode,
                hamming_window=hamming_window,
                comp_feat_upsample=comp_feat_upsample,
                feature_align=feature_align,
                feature_align_group=feature_align_group,
                use_channel_att=use_channel_att,
                up_group=1,
                upsample_mode='nearest',
                align_corners=False,
                hr_residual=True,
                compressed_channels=((pre_c + self.channels) // compress_ratio
                                     if compressed_channel is None
                                     else compressed_channel),
                use_high_pass=True,
                use_low_pass=True,
                semi_conv=semi_conv)
            # The fused low-res branch is concatenated with the high-res one,
            # so its channel count grows by `channels` at every step.
            pre_c += self.channels
            self.freqfusions.append(freqfusion)

    def _forward_feature(self, inputs):
        # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(conv(x.contiguous()))

        # Fuse progressively from the lowest to the highest resolution stage.
        outs = outs[::-1]
        lowres_feat = outs[0]
        for hires_feat, freqfusion in zip(outs[1:], self.freqfusions):
            _, hires_feat, lowres_feat = freqfusion(
                hr_feat=hires_feat, lr_feat=lowres_feat, use_checkpoint=False)
            lowres_feat = torch.cat([lowres_feat, hires_feat], dim=1)
        out = self.fusion_conv(lowres_feat)
        return out
        
    def forward(self, inputs):
        out = self._forward_feature(inputs)
        out = self.cls_seg(out)
        return out

Here is an early implementation; the parameter names may have changed since.
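For completeness, a hypothetical mmseg-style config fragment showing how such a head could be selected (all channel and class values are illustrative assumptions for a MiT-B2 / ADE20K setup, not taken from the repository):

```python
# Hypothetical decode_head config; values are illustrative assumptions.
decode_head = dict(
    type='SegformerHeadFreqFusion',
    in_channels=[64, 128, 320, 512],   # MiT-B2 stage channels (assumed)
    in_index=[0, 1, 2, 3],
    channels=256,
    dropout_ratio=0.1,
    num_classes=150,                   # ADE20K (assumed)
    # FreqFusion-specific options forwarded by the head's __init__:
    compress_ratio=8,
    lowpass_kernel=5,
    highpass_kernel=3,
    align_corners=False,
    loss_decode=dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
```

This only sketches where the extra keyword arguments would live; check them against the released code before use.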

Great, thanks for your answer!