It is recommended that you write like this to support more backbones
Opened this issue · 1 comment
Jacky-Android commented
It is recommended that you write the UNetV2 class like this to support more backbones (tested with timm==0.9.12):
import timm
class UNetV2(nn.Module):
    """
    U-Net v2 with a swappable timm backbone encoder.

    Uses channel attention (ChannelAtt) + spatial attention (SpatialAtt) on each
    encoder stage, projects every stage to a common width with 1x1 convs, fuses
    them with SDI modules, and (optionally) emits deep-supervision outputs.

    Args:
        channel: common channel width the four encoder stages are projected to.
        n_classes: number of output segmentation classes.
        deep_supervision: if True, auxiliary heads supervise intermediate scales.
        backbone: any timm model name that supports ``features_only=True``
            (e.g. 'pvt_v2_b2', 'resnet50', ...).
        pretrained: load timm's pretrained weights for the backbone.
    """
    def __init__(self, channel=32, n_classes=1, deep_supervision=True,
                 backbone='pvt_v2_b2', pretrained=False):
        super().__init__()
        self.deep_supervision = deep_supervision

        # features_only gives the 4 intermediate feature maps; feature_info
        # reports their channel counts, so any timm backbone works unchanged.
        self.encoder = timm.create_model(backbone, pretrained=pretrained,
                                         features_only=True, out_indices=(0, 1, 2, 3))
        channel1, channel2, channel3, channel4 = self.encoder.feature_info.channels()

        # One channel-attention + spatial-attention pair per encoder stage.
        self.ca_1 = ChannelAttention(channel1)
        self.sa_1 = SpatialAttention()
        self.ca_2 = ChannelAttention(channel2)
        self.sa_2 = SpatialAttention()
        self.ca_3 = ChannelAttention(channel3)
        self.sa_3 = SpatialAttention()
        self.ca_4 = ChannelAttention(channel4)
        self.sa_4 = SpatialAttention()

        # 1x1 convs projecting each stage to the common `channel` width.
        self.Translayer_1 = BasicConv2d(channel1, channel, 1)
        self.Translayer_2 = BasicConv2d(channel2, channel, 1)
        self.Translayer_3 = BasicConv2d(channel3, channel, 1)
        self.Translayer_4 = BasicConv2d(channel4, channel, 1)

        self.sdi_1 = SDI(channel)
        self.sdi_2 = SDI(channel)
        self.sdi_3 = SDI(channel)
        self.sdi_4 = SDI(channel)

        # BUG FIX: `[nn.Conv2d(...)] * 4` replicates the SAME module object four
        # times, so all four segmentation heads would share one set of weights.
        # A comprehension creates four independent heads.
        self.seg_outs = nn.ModuleList(
            [nn.Conv2d(channel, n_classes, 1, 1) for _ in range(4)])

        # 2x upsampling transposed convs used along the decoder path.
        self.deconv2 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
                                          padding=1, bias=False)
        self.deconv3 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
                                          padding=1, bias=False)
        self.deconv4 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
                                          padding=1, bias=False)
        self.deconv5 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
                                          padding=1, bias=False)
yaoppeng commented
Thanks for your valuable recommendation.
I will definitely modify it later and make it more general, especially for 3D volumes.