mhamilton723/STEGO

Integrating DINOv2 as Backbone for STEGO

benearnthof opened this issue

Hello & thank you for this great repository. I've been experimenting a lot with STEGO over the past couple of weeks and would like to integrate DINOv2 as a new backbone for it. I've simply added a custom DINOv2 Featurizer like this:

import torch
import torch.nn as nn


class DinoV2Featurizer(nn.Module):
    def __init__(self, dim, cfg, freeze_backbone=True):
        super().__init__()
        self.cfg = cfg
        self.dim = dim
        # Same as the DINOv1 featurizer, but loading the correct model of course,
        # e.g. via torch.hub (arch is the hub model name such as "dinov2_vits14"):
        arch = cfg.model_type
        self.model = torch.hub.load("facebookresearch/dinov2", arch)
        self.patch_size = self.model.patch_size  # 14 for all DINOv2 variants
        if freeze_backbone:
            for p in self.model.parameters():
                p.requires_grad = False

        self.dropout = torch.nn.Dropout2d(p=.1)
        if "s" in arch:  # small model
            self.n_feats = 384
        elif "b" in arch:  # base model
            self.n_feats = 768
        elif "l" in arch:  # large model
            self.n_feats = 1024
        else:  # giant model
            self.n_feats = 1536

        self.cluster1 = self.make_clusterer(self.n_feats)
        self.proj_type = cfg.projection_type
        if self.proj_type == "nonlinear":
            self.cluster2 = self.make_nonlinear_clusterer(self.n_feats)

    def make_clusterer(self, in_channels):
        # Linear projection from backbone features down to the code dimension
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)))

    def make_nonlinear_clusterer(self, in_channels):
        # Two-layer nonlinear projection head with the same output dimension
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)))

    def forward(self, img, n=1):
        self.model.eval()
        with torch.no_grad():
            # DINOv2 uses 14x14 patches, so spatial dims must be divisible by the patch size
            assert (img.shape[2] % self.patch_size == 0)
            assert (img.shape[3] % self.patch_size == 0)
            # reshape=True yields the patch tokens as a (B, C, H/patch, W/patch) feature map
            image_feat = self.model.get_intermediate_layers(x=img, n=1, reshape=True)[0]

        if self.proj_type is not None:
            code = self.cluster1(self.dropout(image_feat))
            if self.proj_type == "nonlinear":
                code += self.cluster2(self.dropout(image_feat))
        else:
            code = image_feat

        return image_feat, code
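
For reference, a minimal smoke test of this featurizer could look like the sketch below. The cfg stub and the dim value are placeholders for illustration (the real STEGO config carries many more fields), and 224 is chosen because it is divisible by DINOv2's 14-pixel patch size:

import types

import torch

# Hypothetical minimal config stub, just enough for the featurizer above
cfg = types.SimpleNamespace(model_type="dinov2_vits14", projection_type="nonlinear")

featurizer = DinoV2Featurizer(dim=64, cfg=cfg)  # 64 is an arbitrary code dimension
img = torch.randn(1, 3, 224, 224)  # 224 / 14 = 16 patches per side

image_feat, code = featurizer(img)
print(image_feat.shape)  # torch.Size([1, 384, 16, 16]) for the small model
print(code.shape)        # torch.Size([1, 64, 16, 16])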

But the performance of this is horrendous: I'm not even getting above 65% validation accuracy on Potsdam3, where STEGO with DINOv1 saturates at around 81% after 5000 training steps. What I've gathered from the paper and the supplementary material linked in other issues is that STEGO seems to be quite sensitive to its training hyperparameters. Could you provide resources on how best to tune them for a new backbone, or point me in the right direction for obtaining a set of usable hyperparameters? Best regards & thank you in advance.

Update: On the Cityscapes and COCO datasets this works out of the box and yields very good results; only Potsdam3 seems highly susceptible to the hyperparameter settings.

Closing this, as I've found that a greedy grid search over the pos_inter_shift and pos_intra_shift hyperparameters is enough to obtain models that outperform STEGO with DINOv1 embeddings.
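
For anyone landing here later, a greedy coordinate search over those two shifts is cheap to script. A minimal sketch, assuming a hypothetical train_and_eval helper (not part of the STEGO repo) and purely illustrative grid values:

# Hypothetical helper: launch a short STEGO training run with the given
# shift values and return validation accuracy; wire this up to the repo's
# training entry point yourself.
def train_and_eval(pos_inter_shift, pos_intra_shift):
    return 0.0  # placeholder

grid = [0.02, 0.06, 0.12, 0.18, 0.25]  # illustrative values, not canonical ones

# Greedy search: tune pos_inter_shift with pos_intra_shift held fixed,
# then tune pos_intra_shift with the best pos_inter_shift found.
intra = grid[0]
inter = max(grid, key=lambda s: train_and_eval(s, intra))
intra = max(grid, key=lambda s: train_and_eval(inter, s))
print(f"best shifts: pos_inter_shift={inter}, pos_intra_shift={intra}")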