AndreaCodegoni/Tiny_model_4_CD

Update for multichannel input


Not an issue, but in case anyone else is interested in training on more than 3 bands (e.g. Sentinel-2), here are the changes:

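import torchvision
from torch import nn
from torch.nn import Module, ModuleList

from models.layers import MixingMaskAttentionBlock  # TinyCD's layers module


class ChangeClassifier(Module):
    def __init__(
        self,
        input_channels: int = 3,  # new arg: number of bands per image
        bkbn_name: str = "efficientnet_b4",
        pretrained: bool = True,
        output_layer_bkbn: str = "3",
        freeze_backbone: bool = False,
    ):
        super().__init__()

        # Load the pretrained backbone according to parameters:
        self._backbone = _get_backbone(
            bkbn_name, pretrained, output_layer_bkbn, freeze_backbone, input_channels
        )

        # Initialize mixing blocks passing input_channels: the first block
        # receives both images (2 * input_channels channels) and mixes them
        # down to input_channels before its pixelwise linear layers.
        self._first_mix = MixingMaskAttentionBlock(
            input_channels * 2, input_channels, [input_channels, 10, 5], [10, 5, 1]
        )

...

# pass input_channels to _get_backbone
def _get_backbone(
    bkbn_name, pretrained, output_layer_bkbn, freeze_backbone, input_channels
) -> ModuleList:
    # The whole model. "DEFAULT" picks the pretrained weights of whichever
    # backbone bkbn_name names, instead of hard-coding EfficientNet-B4 weights.
    entire_model = getattr(torchvision.models, bkbn_name)(
        weights="DEFAULT" if pretrained else None
    ).features

    # Replace the first conv layer so it accepts input_channels bands. Skip this
    # for plain RGB input, otherwise the pretrained stem weights are discarded
    # for nothing. The new layer uses PyTorch's default initialisation.
    if input_channels != 3:
        first_conv = entire_model[0][0]
        entire_model[0][0] = nn.Conv2d(
            input_channels,
            first_conv.out_channels,
            kernel_size=first_conv.kernel_size,
            stride=first_conv.stride,
            padding=first_conv.padding,
            bias=first_conv.bias is not None,
        )

    # Slicing the model
    derived_model = ModuleList([])
    for idx, layer in enumerate(entire_model):
        derived_model.append(layer)
        if str(idx) == output_layer_bkbn:
            break

    # Freezing the backbone weights
    if freeze_backbone:
        for param in derived_model.parameters():
            param.requires_grad = False
        # A re-initialised stem has no pretrained weights to preserve, so keep
        # it trainable even when the rest of the backbone is frozen.
        if input_channels != 3:
            for param in derived_model[0][0].parameters():
                param.requires_grad = True

    return derived_model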
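The snippet above gives the new stem convolution fresh random weights, so the pretrained ImageNet filters of the first layer are lost. One common alternative is to "inflate" the pretrained RGB filters to the new band count: band `i` reuses RGB filter `i % 3`, and all weights are scaled by `3 / input_channels` so the sum over input channels stays on roughly the same scale. This is a minimal sketch, not part of TinyCD: the helper name `inflate_stem_conv` is hypothetical, it assumes the stem conv has `groups=1` (true for torchvision's EfficientNet stem), and whether it beats random initialization on Sentinel-2 data has not been tested here.

```python
import torch
from torch import nn


def inflate_stem_conv(conv: nn.Conv2d, input_channels: int) -> nn.Conv2d:
    """Hypothetical helper: copy of `conv` that takes `input_channels` bands,
    initialised from its pretrained filters. Assumes conv.groups == 1."""
    new_conv = nn.Conv2d(
        input_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        old_w = conv.weight  # shape: (out_channels, old_in, kh, kw)
        old_in = old_w.shape[1]
        # Band i reuses pretrained filter i % old_in (cycling through R, G, B).
        idx = torch.arange(input_channels) % old_in
        # Rescale so the sum over input channels keeps roughly the same magnitude.
        new_conv.weight.copy_(old_w[:, idx] * (old_in / input_channels))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


if __name__ == "__main__":
    # Stand-in for EfficientNet-B4's RGB stem (48 filters, 3x3, stride 2),
    # with random weights here instead of pretrained ones.
    rgb_stem = nn.Conv2d(3, 48, kernel_size=3, stride=2, padding=1, bias=False)
    s2_stem = inflate_stem_conv(rgb_stem, 13)  # e.g. 13 Sentinel-2 bands
    out = s2_stem(torch.randn(2, 13, 64, 64))
    print(tuple(s2_stem.weight.shape), tuple(out.shape))
    # prints (48, 13, 3, 3) (2, 48, 32, 32)
```

To try it, the stem replacement inside `_get_backbone` could become `entire_model[0][0] = inflate_stem_conv(entire_model[0][0], input_channels)` when `pretrained` is true. Note that the cyclic mapping ignores which physical band is which (Sentinel-2 stores blue, green, red as B02, B03, B04), so reordering bands to put red, green, blue first would be a natural refinement; either way it is only a starting point for fine-tuning.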

Cheers

Hi Robin,

Thanks for sharing this extension. This could be a very interesting application. Do you have any examples showing how the model performs on this type of imagery? I'm curious about that :)

@AndreaCodegoni I've seen a modest bump in metrics on the Onera (OSCD) Sentinel-2 change detection dataset. If you are curious, I took the datamodule etc. from https://github.com/Dibz15/OpenMineChangeDetection

Wow, very interesting repo, thank you! I'll have a look at it and add a link to the README!

@robmarkcole thank you! It is nice to see that TinyCD is used by other researchers :)