Integrating DINOv2 as Backbone for STEGO
benearnthof opened this issue · 2 comments
Hello & thank you for this great repository. I've been experimenting a lot with STEGO over the past couple of weeks and would like to integrate DINOv2 as a new backbone. I've simply added a custom DINOv2 Featurizer like this:
import torch
import torch.nn as nn


class DinoV2Featurizer(nn.Module):
    def __init__(self, dim, cfg, freeze_backbone=True):
        super().__init__()
        self.cfg = cfg
        self.dim = dim

        # Same as the DINOv1 featurizer, but loading the corresponding DINOv2 model,
        # e.g. via torch.hub ("dinov2_vits14", "dinov2_vitb14", ...)
        arch = cfg.model_type
        self.model = torch.hub.load("facebookresearch/dinov2", arch)
        self.patch_size = self.model.patch_size

        if freeze_backbone:
            for p in self.model.parameters():
                p.requires_grad = False

        self.dropout = torch.nn.Dropout2d(p=.1)

        if "s" in arch:    # small model
            self.n_feats = 384
        elif "b" in arch:  # base model
            self.n_feats = 768
        elif "l" in arch:  # large model
            self.n_feats = 1024
        else:              # giant model
            self.n_feats = 1536

        self.cluster1 = self.make_clusterer(self.n_feats)
        self.proj_type = cfg.projection_type
        if self.proj_type == "nonlinear":
            self.cluster2 = self.make_nonlinear_clusterer(self.n_feats)

    def make_clusterer(self, in_channels):
        # 1x1 convolution projecting backbone features down to the code dimension
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)))

    def make_nonlinear_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)))

    def forward(self, img, n=1):
        self.model.eval()
        with torch.no_grad():
            assert (img.shape[2] % self.patch_size == 0)
            assert (img.shape[3] % self.patch_size == 0)
            # (B, n_feats, H / patch_size, W / patch_size) patch tokens from the last block
            image_feat = self.model.get_intermediate_layers(x=img, n=1, reshape=True)[0]

        # the projection head stays outside no_grad so the segmentation code can be trained
        if self.proj_type is not None:
            code = self.cluster1(self.dropout(image_feat))
            if self.proj_type == "nonlinear":
                code += self.cluster2(self.dropout(image_feat))
        else:
            code = image_feat
        return image_feat, code
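For context, a minimal way to smoke-test the featurizer looks roughly like the snippet below. The SimpleNamespace config is only a stand-in carrying the two fields the class reads, and dim=70 / "dinov2_vits14" are example values; DINOv2 uses a patch size of 14, so input dimensions must be multiples of 14.

from types import SimpleNamespace
import torch

# hypothetical minimal config, only for exercising the featurizer
cfg = SimpleNamespace(model_type="dinov2_vits14", projection_type="nonlinear")
featurizer = DinoV2Featurizer(dim=70, cfg=cfg)

img = torch.randn(2, 3, 224, 224)  # 224 is divisible by the patch size of 14
image_feat, code = featurizer(img)
print(image_feat.shape)  # (2, 384, 16, 16) backbone features
print(code.shape)        # (2, 70, 16, 16) projected segmentation codes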
But the performance of this is horrendous. I'm not even getting above 65% validation accuracy on Potsdam3, where STEGO with DINOv1 saturates at around 81% after 5000 training steps. What I've gathered from the paper and the supplementary material linked in other issues is that STEGO seems to be quite sensitive to the training hyperparameters. Could you provide resources on how to best tune them for a new backbone, or point me in the right direction on how to obtain a set of usable hyperparameters? Best regards & thank you in advance.
Update: On the Cityscapes & COCO datasets this works out of the box and yields very good results; only Potsdam3 seems very susceptible to the hyperparameters.
Closing this, as I've found that a greedy grid search over the pos_inter_shift and pos_intra_shift hyperparameters is enough to obtain models that outperform STEGO with DINOv1 embeddings. A rough sketch of that search is below.
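The sketch assumes the starting point, the candidate grid, and the val_accuracy helper are placeholders; in practice each candidate corresponds to a short STEGO training run scored by unsupervised validation accuracy on Potsdam3.

import random

# placeholder starting point for the two shift hyperparameters
best = {"pos_inter_shift": 0.10, "pos_intra_shift": 0.10}

def val_accuracy(hparams):
    # stand-in: in practice this launches a short STEGO training run with the
    # given shifts and returns unsupervised validation accuracy on Potsdam3
    return random.random()

# greedy: tune one shift at a time, keeping the best value found so far fixed
for name in ("pos_inter_shift", "pos_intra_shift"):
    candidates = [0.02, 0.06, 0.10, 0.14, 0.18, 0.22, 0.26]  # placeholder grid
    scores = {v: val_accuracy({**best, name: v}) for v in candidates}
    best[name] = max(scores, key=scores.get)

print(best)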