Update for multichannel input
Opened this issue · 5 comments
Not an issue, but in case anyone else is interested in training on more than 3 bands (e.g. Sentinel-2), here are the changes:
class ChangeClassifier(Module):
    """Classifier built from a torchvision backbone plus mixing blocks.

    Only the beginning of ``__init__`` is shown in this snippet; the rest of
    the constructor is elided by the ``...`` placeholder below.

    Args:
        input_channels: Number of channels per input image (default 3). It is
            forwarded to ``_get_backbone`` and used to size the first mixing
            block.
        bkbn_name: Name of the backbone builder, forwarded to ``_get_backbone``.
        pretrained: Forwarded to ``_get_backbone``.
        output_layer_bkbn: Forwarded to ``_get_backbone``.
        freeze_backbone: Forwarded to ``_get_backbone``.
    """

    def __init__(
        self,
        # New argument. NOTE(review): it is the first positional parameter, so
        # callers that passed bkbn_name positionally must be updated.
        input_channels: int = 3,
        bkbn_name: str = "efficientnet_b4",
        pretrained: bool = True,
        output_layer_bkbn: str = "3",
        freeze_backbone: bool = False,
    ):
        super().__init__()
        # Load the pretrained backbone according to parameters:
        self._backbone = _get_backbone(
            bkbn_name, pretrained, output_layer_bkbn, freeze_backbone, input_channels
        )
        # Initialize mixing blocks passing input_channels.
        # The first argument is twice the per-image channel count - presumably
        # because the two input images are combined before mixing; confirm
        # against MixingMaskAttentionBlock.
        self._first_mix = MixingMaskAttentionBlock(
            input_channels * 2, input_channels, [input_channels, 10, 5], [10, 5, 1]
        )
        ...
# pass input_channels to _get_backbone
def _get_backbone(
    bkbn_name, pretrained, output_layer_bkbn, freeze_backbone, input_channels=3
) -> ModuleList:
    """Build the truncated torchvision feature extractor used as backbone.

    Args:
        bkbn_name: Name of a builder in ``torchvision.models`` whose model
            exposes a ``features`` sequence (e.g. ``"efficientnet_b4"``).
        pretrained: If true, load that model's ``IMAGENET1K_V1`` weights;
            otherwise the model keeps its random initialisation.
        output_layer_bkbn: Index (as a string) of the last ``features`` layer
            to keep. If no index matches, every layer is kept.
        freeze_backbone: If true, gradients are disabled for all kept
            parameters.
        input_channels: Number of channels of the input images. When it
            differs from what the stem convolution expects, the stem weights
            are rebuilt for the new channel count.

    Returns:
        A ``ModuleList`` with layers ``0..output_layer_bkbn`` of ``features``.

    Raises:
        ValueError: If ``input_channels`` is smaller than 1.
    """
    if input_channels < 1:
        raise ValueError(f"input_channels must be >= 1, got {input_channels}")

    # The whole model. The weights are selected by name rather than through
    # EfficientNet_B4_Weights so that the lookup matches whichever backbone
    # `bkbn_name` selects instead of only efficientnet_b4.
    entire_model = getattr(torchvision.models, bkbn_name)(
        weights="IMAGENET1K_V1" if pretrained else None
    ).features

    # Adapt the stem convolution to the requested number of input channels.
    # Assumes features[0][0] is a plain (groups=1) Conv2d, as in EfficientNet.
    first_conv = entire_model[0][0]
    old_channels = first_conv.in_channels
    if input_channels != old_channels:
        old_weight = first_conv.weight.detach()
        if pretrained:
            # Reuse the pretrained filters: tile them across the new channels
            # and rescale so the expected activation magnitude is preserved.
            repeats = -(-input_channels // old_channels)  # ceil division
            new_weight = old_weight.repeat(1, repeats, 1, 1)[:, :input_channels]
            new_weight = new_weight * (old_channels / input_channels)
        else:
            # No pretrained filters to reuse: use a fan-out Kaiming init
            # instead of unit-variance noise, which is far too large for a
            # convolution weight.
            new_weight = torch.empty(
                first_conv.out_channels,
                input_channels,
                *first_conv.kernel_size,
                dtype=old_weight.dtype,
                device=old_weight.device,
            )
            torch.nn.init.kaiming_normal_(new_weight, mode="fan_out")
        # Install a fresh Parameter rather than assigning through `.data`.
        first_conv.weight = torch.nn.Parameter(new_weight.contiguous())
        first_conv.in_channels = input_channels
    # When input_channels already matches, the (possibly pretrained) stem is
    # left untouched.

    # Slicing the model: keep layers up to and including `output_layer_bkbn`.
    derived_model = ModuleList([])
    for index, layer in enumerate(entire_model):
        derived_model.append(layer)
        if str(index) == output_layer_bkbn:
            break

    # Freezing the backbone weights
    if freeze_backbone:
        for param in derived_model.parameters():
            param.requires_grad = False
    return derived_model
Cheers
Hi Robin,
thanks for sharing this extension. This could be a very interesting application. Do you have any examples showing how the model performs on this type of image? I'm curious about that :)
@AndreaCodegoni I've seen a modest bump in metrics on the Onera (OSCD) Sentinel-2 change detection dataset. If you are curious, I took the datamodule etc. from https://github.com/Dibz15/OpenMineChangeDetection
Wow, very interesting repo, thank you! I'll have a look at it and I will add a link to the README!
@AndreaCodegoni another ref for you: https://github.com/developmentseed/chabud2023/tree/main
@robmarkcole thank you! It is nice to see that TinyCD is used by other researchers :)