qubvel-org/segmentation_models.pytorch

Guidance required to add CBAM Module

jawi289o opened this issue · 1 comment

Could you please guide me on how to add the CBAM module to UNet and Deeplabv3p?

[Image: cbam]

Hi @jawi289o, there is a similar attention module, SCSE:

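    import torch.nn as nn

    class SCSEModule(nn.Module):
        def __init__(self, in_channels, reduction=16):
            super().__init__()
            # Channel gate: global average pool, bottleneck, sigmoid -> (N, C, 1, 1)
            self.cSE = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, in_channels // reduction, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels // reduction, in_channels, 1),
                nn.Sigmoid(),
            )
            # Spatial gate: 1x1 conv down to one channel, sigmoid -> (N, 1, H, W)
            self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

        def forward(self, x):
            # Both gates broadcast against x; the two gated tensors are summed
            return x * self.cSE(x) + x * self.sSE(x)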

    class Attention(nn.Module):
        def __init__(self, name, **params):
            super().__init__()
            if name is None:
                self.attention = nn.Identity(**params)
            elif name == "scse":
                self.attention = SCSEModule(**params)
            else:
                raise ValueError("Attention {} is not implemented".format(name))

        def forward(self, x):
            return self.attention(x)

So you can just specify decoder_attention_type="scse" when creating the model, or, if the difference from CBAM matters for your use case, you can add CBAM to the file as another attention option.
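
For anyone who wants to try the second route, here is a minimal sketch of what a CBAM block and the extra branch could look like. It is not code from the library: `CBAMModule`, the `"cbam"` name, and the `kernel_size` parameter are hypothetical, and the `Attention` class below is a trimmed copy of the one quoted above with the `"scse"` branch left out to keep the example self-contained. It assumes the usual CBAM layout (channel attention from average- and max-pooled descriptors through a shared MLP, then spatial attention from a 7x7 conv over channel-wise mean and max maps) and assumes the module is constructed with `in_channels` the same way `SCSEModule` is.

```python
import torch
import torch.nn as nn


class CBAMModule(nn.Module):
    """Hypothetical CBAM sketch: channel attention, then spatial attention."""

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super().__init__()
        hidden = max(in_channels // reduction, 1)  # guard against 0 channels
        # Shared MLP (as 1x1 convs) applied to both pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, 1, bias=False),
        )
        # Spatial gate: conv over the stacked [channel-mean, channel-max] maps.
        self.spatial = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False
        )

    def forward(self, x):
        # Channel gate, shape (N, C, 1, 1).
        avg = self.mlp(torch.mean(x, dim=(2, 3), keepdim=True))
        mx = self.mlp(torch.amax(x, dim=(2, 3), keepdim=True))
        x = x * torch.sigmoid(avg + mx)
        # Spatial gate, shape (N, 1, H, W).
        pooled = torch.cat(
            [x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1
        )
        return x * torch.sigmoid(self.spatial(pooled))


class Attention(nn.Module):
    """Trimmed copy of the wrapper quoted above, with an assumed "cbam" branch."""

    def __init__(self, name, **params):
        super().__init__()
        if name is None:
            self.attention = nn.Identity(**params)
        elif name == "cbam":  # hypothetical branch, not in the quoted code
            self.attention = CBAMModule(**params)
        else:
            raise ValueError("Attention {} is not implemented".format(name))

    def forward(self, x):
        return self.attention(x)


if __name__ == "__main__":
    x = torch.randn(2, 32, 16, 16)
    out = Attention("cbam", in_channels=32)(x)
    print(out.shape)  # same shape as the input
```

Like SCSE, the block keeps the tensor shape unchanged, so it can sit wherever the decoder already applies its attention module. Whether other parts of the library also need to learn the new name is not covered here and would need checking against the actual source.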