How is FreqFusion used in the SegFormer decoder?
Opened this issue · 5 comments
StephenZhou98 commented
Great work! Could you share how FreqFusion is implemented in the SegFormer decoder? And how should the 4x and 8x upsampling be handled? Many thanks!
Linwei-Chen commented
Thanks for your interest! You can refer to the SegNeXt implementation, which is essentially the same:
https://github.com/Linwei-Chen/FreqFusion/blob/main/SegNeXt/mmseg/models/decode_heads/ham_head.py
As noted both in prior work and in our own experiments, SegFormer's results fluctuate somewhat between runs, so you may need to try several times.
SwordShiSan commented
I tried adding FreqFusion to the SegFormer head, replacing the upsampling with three FreqFusion modules, but it requires quite a lot of GPU memory and I hit a CUDA out-of-memory error. I suspect my integration is wrong. Could you share the SegFormer-head implementation used in the paper for reference? Thanks!
Linwei-Chen commented
GPU memory usage does increase. I recommend using mmcv's CARAFE to lower the memory footprint; setting use_checkpoint=True can also save some memory.
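As a rough illustration of what the use_checkpoint option does (FreqFusion threads this flag through its forward pass), gradient checkpointing via torch.utils.checkpoint trades extra compute for memory by not storing intermediate activations. The FusionBlock below is a hypothetical stand-in for a memory-hungry fusion module, not the real FreqFusion:

```python
# Minimal sketch of gradient checkpointing, the mechanism behind
# use_checkpoint=True. FusionBlock is a hypothetical stand-in module,
# NOT the actual FreqFusion implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class FusionBlock(nn.Module):
    def __init__(self, channels=16):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, hr_feat, lr_feat):
        # Upsample the low-res feature to the high-res size and fuse.
        lr_up = F.interpolate(lr_feat, size=hr_feat.shape[-2:], mode='nearest')
        return self.conv(hr_feat + lr_up)

block = FusionBlock()
hr = torch.randn(1, 16, 32, 32, requires_grad=True)
lr = torch.randn(1, 16, 16, 16, requires_grad=True)

# With checkpointing, activations inside `block` are discarded in the
# forward pass and recomputed during backward, reducing peak memory.
out = checkpoint(block, hr, lr, use_reentrant=False)
out.sum().backward()
print(out.shape)  # torch.Size([1, 16, 32, 32])
```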
Linwei-Chen commented
import torch
import torch.nn as nn

# HEADS, SegformerHead, and FreqFusion2 are assumed to come from the
# mmseg / FreqFusion codebase.


@HEADS.register_module()
class SegformerHeadFreqFusion(SegformerHead):
    """SegFormer head with FreqFusion replacing the bilinear upsampling.

    Note: not tested.
    """
    def __init__(self,
                 compress_ratio=8,
                 lowpass_kernel=5,
                 highpass_kernel=3,
                 lowpass_pad=0,
                 highpass_pad=0,
                 padding_mode='replicate',
                 hamming_window=False,
                 feature_align=False,
                 feature_align_group=4,
                 comp_feat_upsample=True,
                 compressed_channel=None,
                 semi_conv=True,
                 hf_att=False,
                 use_global_context=False,
                 use_channel_att=False,
                 use_dyedgeconv=False,
                 **kwargs):
        super().__init__(**kwargs)
        # hf_att, use_global_context, and use_dyedgeconv are accepted but
        # unused in this early version.
        # channels = kwargs.get('channels', 256)
        # in_channels = kwargs.get('in_channels', None)
        self.freqfusions = nn.ModuleList()
        pre_c = self.channels
        self.feature_align_group = feature_align_group
        # Build one FreqFusion module per adjacent feature pair, starting
        # from the lowest-resolution stage.
        for idx in range(len(self.in_channels) - 1, 0, -1):
            freqfusion = FreqFusion2(
                hr_channels=self.channels, lr_channels=pre_c,
                scale_factor=1,
                lowpass_kernel=lowpass_kernel,
                highpass_kernel=highpass_kernel,
                lowpass_pad=lowpass_pad,
                highpass_pad=highpass_pad,
                padding_mode=padding_mode,
                hamming_window=hamming_window,
                comp_feat_upsample=comp_feat_upsample,
                feature_align=feature_align,
                # feature_align_group=feature_align_group * (len(self.freqfusions) + 1),
                feature_align_group=feature_align_group,
                use_channel_att=use_channel_att,
                up_group=1,
                upsample_mode='nearest',
                align_corners=False,
                hr_residual=True,
                compressed_channels=(pre_c + self.channels) // compress_ratio
                    if compressed_channel is None else compressed_channel,
                use_high_pass=True, use_low_pass=True, semi_conv=semi_conv)
            pre_c += self.channels
            self.freqfusions.append(freqfusion)

    def _forward_feature(self, inputs):
        # Receive the 4-stage backbone feature maps: 1/4, 1/8, 1/16, 1/32.
        inputs = self._transform_inputs(inputs)
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(conv(x.contiguous()))
        # Reorder from low to high resolution.
        outs = outs[::-1]
        lowres_feat = outs[0]
        for hires_feat, freqfusion in zip(outs[1:], self.freqfusions):
            _, hires_feat, lowres_feat = freqfusion(
                hr_feat=hires_feat, lr_feat=lowres_feat, use_checkpoint=False)
            # Concatenate fused features; the channel count grows by
            # self.channels at every step, matching pre_c in __init__.
            lowres_feat = torch.cat([lowres_feat, hires_feat], dim=1)
        out = self.fusion_conv(lowres_feat)
        return out

    def forward(self, inputs):
        out = self._forward_feature(inputs)
        out = self.cls_seg(out)
        return out
This is an early implementation; some parameter names may have changed since.
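For reference, wiring this head into an mmseg-style config might look like the following sketch. The dataset-dependent fields (in_channels, num_classes, etc.) are illustrative placeholders, and the FreqFusion-specific arguments shown are simply the defaults of the class above:

```python
# Hypothetical mmseg decode_head config fragment for the class above.
# The stage channels and class count are example values, not prescriptions.
decode_head = dict(
    type='SegformerHeadFreqFusion',
    in_channels=[32, 64, 160, 256],  # e.g. MiT-B0 stage channels
    in_index=[0, 1, 2, 3],
    channels=256,
    num_classes=19,
    # FreqFusion-specific options (defaults from the class definition)
    compress_ratio=8,
    lowpass_kernel=5,
    highpass_kernel=3,
    comp_feat_upsample=True,
    semi_conv=True,
)
```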
SwordShiSan commented
OK, thanks for your answer.