NVlabs/RADIO

radio-v2 update

PinxueGuo opened this issue · 3 comments

Hi, thanks for the excellent work. I noticed that the README update says that v2's teachers include SAM, but v1's do not. Is this a typo? Or does it mean that the RADIO model I previously loaded via torch.hub was not trained with SAM? (That was in January, before v2 was released, so torch.hub didn't require specifying v1/v2.)

Hi Pinxue. Thanks for reaching out again. I've actually been working pretty hard since you reached out to us in January to get SAM support working properly in the released code. When you first reached out, this was our response:

The initial model+code release won't include the adapter heads that you'd need to replace the vision encoder in SAM, however I will add it to the roadmap to include a version of RADIO that includes adapter heads and SAM compatibility.

So what's happening with RADIOv2 is actually that we're getting quite a bit closer to making the SAM drop-in replacement possible. For RADIOv1, we were having issues with the SAM adaptor weights, and we weren't able to resolve them without re-training the model. So unfortunately, RADIOv1 won't get SAM support.

However, I just pushed a code update that allows you to load the SAM head by passing adaptor_names='sam' to the hub.load(version='radio_v2', adaptor_names='sam', ...) call. What I don't have wired together is the full-blown replacement of SAM's encoder, because we think different users may have different requirements for the adaptor framework, and we're working with internal API design teams on how best to handle this. Here's a sketch of what should work to serve as a drop-in SAM replacement:

import torch
import torch.nn as nn

from einops import rearrange
from segment_anything.build_sam import sam_model_registry

class RADIOtoSAM(nn.Module):
    def __init__(self, radio: nn.Module, sam_ve_neck: nn.Module):
        super().__init__()
        self.radio = radio
        self.neck = sam_ve_neck
        # SAM's `preprocess` and `SamPredictor` read `image_encoder.img_size`,
        # so the wrapper needs to expose it
        self.img_size = 1024

    def forward(self, x: torch.Tensor):
        assert x.shape[-3:] == (3, 1024, 1024), 'Invalid input shape'

        # `outputs` will be a dictionary with two entries: { 'backbone': ..., 'sam': ... }
        # We want the sam adaptor
        outputs = self.radio(x)
        sam_feats = outputs['sam'].features  # `.summary` is also available, but we don't use it for SAM

        # Reshape the flattened token sequence into a 2D feature map (1024px / 16px patches = 64)
        features = rearrange(sam_feats, 'b (r c) d -> b d r c', r=64, c=64)

        # Project to the channels SAM's mask decoder expects, using SAM's original neck
        features = self.neck(features.float())
        return features

# This will allow us to replace the vision encoder, up until the neck
radio_v2 = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2',
                          adaptor_names='sam', vitdet_window_size=16)

sam = sam_model_registry['vit_h'](checkpoint=<path_to_checkpoint>, ...)
sam.image_encoder = RADIOtoSAM(radio_v2, sam.image_encoder.neck)

# Now you can use SAM like usual, but with RADIOv2 as the image encoder.
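For completeness, here's a minimal usage sketch (mine, not part of the original answer) that wraps the patched sam object in SamPredictor and prompts it with a single point. The random image and the point coordinates are placeholders; this assumes the radio_v2/sam objects built above.

import numpy as np
from segment_anything import SamPredictor

predictor = SamPredictor(sam)

# Stand-in for a real RGB image; SamPredictor resizes the longest side to 1024
# and SAM pads it to a 1024x1024 square before it reaches the encoder.
image = np.random.randint(0, 255, (768, 1024, 3), dtype=np.uint8)
predictor.set_image(image)

# Prompt with a single foreground point (x, y) and get candidate masks.
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[512, 384]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)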

Hi Mike, that's so cool!!! Thank you again for your excellent research and hard work! This is great news for RADIOv2.

By the way, after using it for the past month or so, I'd say that RADIOv1 already has a lot of good capabilities for downstream tasks. Kudos to you!

Hi @PinxueGuo, thanks for the feedback!

I'd say that RADIOv1 already has a lot of good capabilities for downstream tasks

If that's OK with you, do you mind sharing some details on this? We're always happy to know when our work is useful to others.