sinahmr/DIaM

About the training code: BAM's train_base.py does not match this repo's src/model/pspnet.py



Dear Sina,
Hello, I am very interested in your work; it is an excellent contribution to the field of GFSS. However, I ran into a problem while running the code. I used train_base.py to train the model on my dataset, but the trained model can't be loaded in this repo. I found that BAM's PSPNet and DIaM's PSPNet are not written in the same way, so the checkpoint could not be loaded properly. Could you please provide the training code that you used?

Hi,
Thanks for your interest in our work!

I used the same train_base.py code that you are using, but you're right, the module names in BAM differ from ours. To resolve this, I ran the following script once on BAM's trained model to align its key names with ours:

import os
from collections import OrderedDict

import torch

# `checkpoint` and `root` come from the surrounding BAM code where this snippet is pasted.
state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    # If BAM was run on multiple GPUs, keys already start with 'module.'.
    # Strip it here and re-add it below, so both cases end up with the prefix.
    k = k.replace('module.', '')
    if k.startswith('cls.4.'):
        k = k.replace('cls.4', 'classifier')  # BAM's cls.4 is named classifier here
    elif k.startswith('cls.'):
        k = k.replace('cls', 'bottleneck')  # the other cls layers are named bottleneck here
    elif k.startswith('encoder.'):
        continue  # encoder weights are not copied over
    k = 'module.' + k
    state_dict[k] = v

filename = os.path.join(root, 'model.pth')
torch.save({'epoch': checkpoint['epoch'], 'state_dict': state_dict, 'optimizer': checkpoint['optimizer']}, filename)

You can paste it somewhere after the BAM model is loaded (for example, after this line) and run the code until the end of the snippet (maybe add an exit() after it). Then delete the snippet and run the code once again, this time with model.pth as the checkpoint (set ckpt_used in the config file to model).
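To sanity-check the renaming rules before touching a real checkpoint, the same logic can be tried on a toy dictionary. This is only a minimal sketch: the function name remap_bam_keys and the example key names (layer0.0.weight and so on) are hypothetical, and plain strings stand in for the weight tensors.

```python
from collections import OrderedDict


def remap_bam_keys(bam_state_dict):
    """Apply the cls -> bottleneck/classifier renaming from the script above."""
    remapped = OrderedDict()
    for key, value in bam_state_dict.items():
        # Drop a leading 'module.' (present after multi-GPU training), if any.
        if key.startswith('module.'):
            key = key[len('module.'):]
        if key.startswith('cls.4.'):
            key = key.replace('cls.4', 'classifier', 1)
        elif key.startswith('cls.'):
            key = key.replace('cls', 'bottleneck', 1)
        elif key.startswith('encoder.'):
            continue  # encoder entries are skipped, as in the script above
        # Always store the key with the 'module.' prefix.
        remapped['module.' + key] = value
    return remapped


# Hypothetical key names; strings stand in for tensors.
toy = OrderedDict([
    ('module.layer0.0.weight', 'w0'),
    ('cls.0.weight', 'w1'),
    ('module.cls.4.weight', 'w2'),
    ('module.encoder.0.weight', 'w3'),
])
remapped = remap_bam_keys(toy)
for new_key in remapped:
    print(new_key)
```

Printing the remapped keys and comparing them against the key names the target model expects is a quick way to spot any layer that still fails to line up.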

Let me know if this doesn't resolve the issue.

Thanks for the answer! It perfectly solved the problem I was having. In principle, the code you provided loads the original PSPNet model parameters into the structure corresponding to the current repo's PSPNet. Very ingenious. Thank you again!

No worries, I'm happy it helped!

@Nevaeh7, are you available for a quick talk, please?

Hello, I would like to ask whether I can use a network other than PSPNet, such as SegFormer from the MMSEG framework, since you mention in the README that any model trained on the base classes can be used. If so, how should I modify the code in your repository?

Hi,
Sure, you should be able to use other networks. Our contribution is what is implemented in the src/classifier.py file. Functions in the Classifier class take as input the base classes' prototypes, alongside support and query features, and perform the inference. You can use another network to generate features for the support and query images, and use the weights of its final classifier (Linear) layer as base prototypes. Our proposed approach runs from this line to this line in src/test.py, so you can change the rest.
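As a rough illustration of the idea of treating final classifier weights as base prototypes, here is a minimal sketch. It is not the interface of the Classifier class in this repo: the function names, the shapes (16 classes, 512-dimensional features, a 60x60 feature map) and the use of NumPy arrays in place of tensors are all assumptions made for the example.

```python
import numpy as np


def classifier_weights_to_prototypes(weight):
    """Flatten a final classifier weight into one prototype vector per class.

    Accepts a Linear-style weight of shape (num_classes, dim) or a
    1x1-conv-style weight of shape (num_classes, dim, 1, 1).
    """
    weight = np.asarray(weight)
    return weight.reshape(weight.shape[0], -1)


def base_logits(features, prototypes):
    """Score every pixel against every prototype with a dot product.

    features:   (dim, H, W) feature map produced by any backbone
    prototypes: (num_classes, dim)
    returns:    (num_classes, H, W)
    """
    return np.einsum('cd,dhw->chw', prototypes, features)


rng = np.random.default_rng(0)
conv_weight = rng.normal(size=(16, 512, 1, 1))  # hypothetical classifier weight
features = rng.normal(size=(512, 60, 60))       # hypothetical backbone features

prototypes = classifier_weights_to_prototypes(conv_weight)
logits = base_logits(features, prototypes)
print(prototypes.shape, logits.shape)
```

With a different backbone, the main things to keep consistent are that the feature dimension of the support and query features matches the prototype dimension, and that the features are taken from the layer that feeds the final classifier.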