NVlabs/RADIO

Use RADIOv2 as VLM's vision encoder.

echo840 opened this issue · 16 comments

Hello, thank you for your great work!
We are currently exploring the use of RADIO as a vision encoder for vision-language models. In our setup, we compare SigLIP and RADIOv2 as the vision encoder, with Phi-2 as the language model. The results are as follows:
[Image: benchmark results]

Both runs use the same data and configuration; the only difference is the vision encoder. Is it normal to observe worse performance with RADIOv2 than with SigLIP?

# Feature extraction
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

class RadioVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name,trust_remote_code=True)

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.image_processor.do_resize = True
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, trust_remote_code=True)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                # RADIO returns a (summary, spatial_features) tuple; unpack it before casting.
                _, image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_features.append(image_feature.to(self.dtype))
        else:
            _, image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
            image_features = image_features.to(self.dtype)
        return image_features
# Image processing: we resize the image to 432, which is RADIO's preferred_resolution.
if self.data_args.image_aspect_ratio == 'pad':
    def expand2square(pil_img, background_color):
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result
    image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
    # Stretch the image to a square with side equal to the longer edge (this distorts the aspect ratio).
    width, height = image.size
    max_size = max(width, height)
    image = image.resize((max_size, max_size))
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

Could you give me some suggestions?

Hello, in the experiments we published in our paper, we used an image pre-processor that resizes the longest edge to 432, adjusting the shortest edge to keep the original image aspect ratio, followed by a crop along the shortest edge to the nearest multiple of the patch size. This should be mostly equivalent to expand2square followed by a resize to 432x432, only without the padding along the shortest dimension. This requires support for variable-size, non-square images.
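
For illustration, that pre-processing looks roughly like the sketch below (assuming PIL and torchvision and a patch size of 16; this is not the exact pre-processor used for the paper):

# Rough sketch of the resize-then-crop pre-processing described above.
from PIL import Image
from torchvision import transforms

def radio_preprocess(pil_img, longest_edge=432, patch_size=16):
    # Resize so the longest edge equals `longest_edge`, keeping the aspect ratio.
    width, height = pil_img.size
    scale = longest_edge / max(width, height)
    new_w, new_h = round(width * scale), round(height * scale)
    pil_img = pil_img.resize((new_w, new_h), Image.BICUBIC)

    # Crop the shortest edge down to the nearest multiple of the patch size
    # (the longest edge, 432, is already a multiple of 16).
    crop_w = new_w - (new_w % patch_size)
    crop_h = new_h - (new_h % patch_size)
    pil_img = pil_img.crop((0, 0, crop_w, crop_h))

    # Convert to a [0, 1] CHW tensor; RADIO's built-in input conditioner handles
    # normalization unless make_preprocessor_external() has been called.
    return transforms.ToTensor()(pil_img)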

Are you using image_aspect_ratio == 'pad', as I suspect? Otherwise we might end up cropping actual pixels along the longest edge.

Thank you for your response! Yes, during finetuning, we used image_aspect_ratio == 'pad'. I'm now trying the experiment according to your instructions. Thank you very much!

Hello, RADIOv2 still scores lower than SigLIP. I would like to know whether I have missed anything in the feature extraction code below. Do I need to extract features from the second-to-last layer of the vision tower, as LLaVA does? Have I overlooked a normalization step? Or do I need to add the summary token?

# Feature extraction
class RadioVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name,trust_remote_code=True)

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.image_processor.do_resize = True
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, trust_remote_code=True)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                # RADIO returns a (summary, spatial_features) tuple; unpack it before casting.
                _, image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_features.append(image_feature.to(self.dtype))
        else:
            _, image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
            image_features = image_features.to(self.dtype)
        return image_features

Hello, I have not worked with the HuggingFace model in LLaVA; however, you should equivalently be able to use the TorchHub model. In my LLaVA integration I use standard normalization instead of the built-in input conditioner (i.e. I make a call to vision_tower.make_preprocessor_external()).
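
Condensed, the key TorchHub steps look roughly like this (a sketch only; 'radio_v2' is a placeholder version string, and the normalization statistics shown are simply CLIPImageProcessor's OpenAI CLIP defaults):

# Minimal sketch: load RADIO from TorchHub and normalize inputs externally.
import torch
from torchvision import transforms

# 'radio_v2' is a placeholder for whichever TorchHub version or checkpoint path you use.
model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2', progress=True)
model.make_preprocessor_external()  # disable the built-in input conditioner
model.eval()

# With the conditioner externalized, normalize the [0, 1] input yourself,
# e.g. with OpenAI CLIP statistics (CLIPImageProcessor's defaults).
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
x = normalize(torch.rand(1, 3, 432, 432))   # dummy input at the preferred resolution
summary, spatial_features = model(x)        # RADIO returns a (summary, spatial features) tuple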

This is my code (pardon the untidiness):

from argparse import Namespace
import os
import torch
import torch.nn as nn
from typing import Any, Dict
import warnings

from transformers import CLIPVisionConfig
from transformers import CLIPImageProcessor, SamImageProcessor
from PIL import Image
import numpy as np


class RADIOVisionTower(nn.Module):
    """
    Vision Tower for the RADIO model.

    Args:
        vision_tower (str): Vision tower name. This is passed on
            the command line with the `--vision_tower` argument.
            The string is expected in the pattern of:
            `radio:<image_size>:<checkpoint_or_version>:<extra_config>`.
            Where <extra_config> is a comma-separated list of key=value pairs.
            <image_size> is the image resolution.
            <checkpoint> is a TorchHub version or path to a checkpoint.
        args (Namespace): Arguments.
        delay_load (bool): Delay loading the model.
    """
    def __init__(self, vision_tower, args, delay_load=False):
        """Initialization Routine."""

        super().__init__()

        self.vision_tower_name = vision_tower[len("radio:"):]
        config_items = self.vision_tower_name.split(":")
        self.image_sizes = [int(x) for x in config_items[0].split(",")]
        if len(self.image_sizes) == 0:
            raise ValueError("Expected at least one image size!")
        self.image_size = self.image_sizes[0]
        self.do_center_crop = args.mm_im_crop

        self.vision_tower_checkpoint = config_items[1]

        extra_config = {}
        if len(config_items) > 2:
            # Parse extra config items. These are provided as a comma-separated list
            # of key=value pairs.
            extra_config_items = config_items[2].split(",")

            for item in extra_config_items:
                key, value = item.split("=")
                extra_config[key] = value

        self.adaptor_name = extra_config.get("adaptor", "backbone")
        self.fuse_adaptor_with_backbone = eval(extra_config.get("fuse_adaptor_with_backbone", "False"))
        self.skip_layer_norm = eval(extra_config.get("skip_layer_norm", "False"))

        self.is_loaded = False

        if not delay_load:
            self.load_model()
        else:
            # FIXME: This is a hack to avoid having to load the config from the checkpoint.
            hidden_size = self.get_hidden_size()
            patch_size = 16

            self.cfg_only = CLIPVisionConfig(
                **{

                    "hidden_size": hidden_size,
                    "image_size": self.image_size,
                    "model_type": "radio_vision_model",
                    "num_attention_heads": None,
                    "num_channels": 3,
                    "num_hidden_layers": None,
                    "patch_size": patch_size,
                }
            )

    def get_hidden_size(self):
        if self.adaptor_name == "openai_clip":
            hidden_size = 1024
        elif self.adaptor_name == "clip":
            hidden_size = 1280
        elif self.adaptor_name == "rtx-translate":
            hidden_size = 2048
        elif self.adaptor_name == "backbone":
            hidden_size = 1280
        else:
            raise ValueError(f"Unknown adaptor name: {self.adaptor_name}")

        if self.fuse_adaptor_with_backbone:
            hidden_size += 1280

        return hidden_size

    @property
    def hidden_size(self):
        return self.get_hidden_size()

    def load_model(self):

        crop_size={'height': self.image_size, 'width': self.image_size}

        if self.do_center_crop:
            self.image_processor = CLIPImageProcessor(
                size={"shortest_edge": self.image_size},
                crop_size=crop_size,
                do_center_crop=self.do_center_crop,
                do_normalize=True,
            )
        else:
            self.image_processor = SamImageProcessor(
                    size={"longest_edge": self.image_size},
                    pad_size={'height': self.image_size, 'width': self.image_size},
                    do_pad=False,
                    do_normalize=True,
            )
            # Add a crop_size attribute to the image processor, since the
            # train.py script needs this to generate fake images of zeros
            # with the right size, when the sample does not have an
            # associated image.
            self.image_processor.crop_size = crop_size

        # For compatibility with CLIP Image Processor: the data loader uses width/height to
        # create dummy blank images for samples that don't have an image.
        self.image_processor.crop_size = {"width": self.image_size, "height": self.image_size}

        checkpoint_path_or_version = self.vision_tower_checkpoint

        # NOTE: do a lazy import of Timm to avoid issues with
        # DeepSpeed's ZeRO-3.
        from timm.models.vision_transformer import VisionTransformer

        self.vision_tower = torch.hub.load('NVlabs/RADIO',
                                           'radio_model',
                                           version=checkpoint_path_or_version,
                                           progress=True,
                                           adaptor_names=self.adaptor_name if self.adaptor_name != "backbone" else None)

        if isinstance(self.vision_tower.model, VisionTransformer):
            hidden_size = self.vision_tower.model.embed_dim
        else:
            raise ValueError(f"Unknown model type: {self.vision_tower}")

        # Override hidden size for OpenAI CLIP.
        hidden_size = self.get_hidden_size()

        if hasattr(self.vision_tower.model, "patch_generator"):
            patch_gen = self.vision_tower.model.patch_generator
            # Cropped Positional Embedding (CPE) case.
            patch_size = patch_gen.patch_size
        else:
            # Standard ViT case.
            patch_size = self.vision_tower.model.patch_embed.patch_size[0]

        self.vision_tower.config = CLIPVisionConfig(
                **{
                    "hidden_size": hidden_size,
                    "image_size": self.image_size,
                    "model_type": "radio_vision_model",
                    "num_attention_heads": None,
                    "num_channels": 3,
                    "num_hidden_layers": None,
                    "patch_size": patch_size,
                }
            )

        self.vision_tower.make_preprocessor_external()
        self.vision_tower.eval()
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True
        self._to_dtype = None

        if self.skip_layer_norm:
            self.vision_tower.model.norm = torch.nn.Identity()


    def to(self, *args, **kwargs):
        # Prevent casting the RADIO model's weights
        kwargs = dict(kwargs)
        self._to_dtype = kwargs.pop('dtype', None)
        super().to(*args, **kwargs)
        pass

    def train(self, mode=True):
        """Intercept call."""
        # Drop a warning if mode is True.
        if mode:
            warnings.warn("RADIOEncoder is always in eval mode.")
        pass

    @torch.no_grad()
    def get_features(self, x: torch.Tensor):
        output = self.vision_tower(x)
        if isinstance(output, dict):
            _, features = output[self.adaptor_name]
            if self.fuse_adaptor_with_backbone:
                _, backbone_features = output["backbone"]
                features = torch.cat([features, backbone_features], dim=2)
        else:
            _, features = output
        return features

    @torch.no_grad()
    def forward(self, images: torch.Tensor):
        """Main forward pass."""
        input_shape = images.shape

        x = images

        # Add a batch dimension if necessary.
        if len(input_shape) == 3:
            x = x.unsqueeze(0)

        # Convert the input to the model's dtype (we assume
        # that the model only has one dtype for all parameters).
        param0 = next(self.vision_tower.parameters())
        x = x.to(dtype=param0.dtype, device=param0.device)

        patch_size = self.vision_tower.config.patch_size

        if self.do_center_crop:
            # Crop the input to a multiple of patch size.
            _, _, H, W = x.shape

            H = H - (H % patch_size)
            W = W - (W % patch_size)

            x = x[:, :, :H, :W]
        else:
            # Pad to nearest multiple of patch size
            _, _, H, W = x.shape
            H = H + (patch_size - (H % patch_size)) % patch_size
            W = W + (patch_size - (W % patch_size)) % patch_size
            x = nn.functional.pad(x, (0, W - x.shape[3], 0, H - x.shape[2]), mode="constant", value=0)

        features = self.get_features(x) # B, T, C

        B, _, H, W = x.shape
        _, _, C = features.shape

        # Remove the batch dimension if we added it.
        if len(input_shape) == 3:
            features = features.squeeze(0)

        # Cast back to the input's dtype.
        features = features.to(images.dtype)

        assert features.shape[-1] == self.get_hidden_size()

        return features
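
For reference, a hypothetical usage sketch of this tower (the args namespace and version string are illustrative, not taken from the actual LLaVA training code):

# Hypothetical usage of RADIOVisionTower; 'radio_v2' is a placeholder version.
from argparse import Namespace
from PIL import Image

args = Namespace(mm_im_crop=False)  # False -> SamImageProcessor path (resize longest edge, no crop)
tower = RADIOVisionTower("radio:432:radio_v2", args)

image = Image.open("example.jpg").convert("RGB")
pixel_values = tower.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
features = tower(pixel_values)  # (num_patches, hidden_size) spatial features for the LLM projector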

Thank you! I'm also curious about the settings of "extra_config" and "config_items". Should the following parameters be set to True or False?

self.adaptor_name = extra_config.get("adaptor", "backbone")
self.fuse_adaptor_with_backbone = eval(extra_config.get("fuse_adaptor_with_backbone", "False"))
self.skip_layer_norm = eval(extra_config.get("skip_layer_norm", "False"))

Hi, in my standard configuration the adaptor is backbone, and fuse_adaptor_with_backbone and skip_layer_norm are both False.
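
For reference, with those defaults the --vision_tower string (following the radio:<image_size>:<checkpoint_or_version>:<extra_config> pattern from the docstring above) would look something like this; 'radio_v2' again stands in for the actual TorchHub version or checkpoint path:

# Hypothetical --vision_tower values under the default configuration.
vision_tower = "radio:432:radio_v2"  # defaults: adaptor=backbone, no fusion, keep the final layer norm
# Equivalent, with the defaults spelled out explicitly:
vision_tower = "radio:432:radio_v2:adaptor=backbone,fuse_adaptor_with_backbone=False,skip_layer_norm=False"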

Thank you for your prompt response and your great work!

Hello, have you been able to get RADIO to perform well in your VLM setup?

I'm sorry; to be honest, I can't achieve better results than SigLIP under the same settings. SigLIP has a resolution of 384, while RADIO's resolution is dynamic (with a maximum size set to 1280). To save time, we use Qwen2-0.5B as the LLM, and we also add some OCR data such as DocVQA and TextVQA. Both experiments use the same settings.

[Image: benchmark results]
