google-research/big_vision

Question about SigLIP

maximzubkov opened this issue · 7 comments

Hello, Google Research team!

Thanks a lot for your work! I came across your SigLIP paper and was curious to reproduce the results myself on another dataset. I checked the README, and it says that the SigLIT code is in TODO status. However, in the codebase both sigmoid_loss and chunked_sigmoid_loss are implemented and integrated into the training script, and the config is defined in this .ipynb. So my question is: is something still missing for SigLIP, or can I already try to run it with a command like this:

python -m big_vision.trainers.proj.image_text.contrastive \
    --config ... \
    --workdir ...
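For reference, the pairwise sigmoid loss described in the SigLIP paper treats every image-text pair in a batch as an independent binary classification: the matching pair (j = i) is a positive and all other pairs are negatives, with logits t·⟨x_i, y_j⟩ + b, where t = exp(t'). The NumPy sketch below is written from the paper's pseudocode, not from big_vision's sigmoid_loss or chunked_sigmoid_loss, and the function names are hypothetical. The blockwise function only illustrates, on a single host, that the loss can be accumulated one block of texts at a time, which is the idea behind the paper's chunked multi-device variant. The initial values t' = log 10 and b = -10 follow the paper.

```python
import numpy as np


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))."""
    return -np.logaddexp(0.0, -x)


def pairwise_sigmoid_loss(img_emb, txt_emb, t_prime, b):
    """Sigmoid loss following the SigLIP paper's pseudocode (NumPy sketch).

    img_emb, txt_emb: [n, d] L2-normalised embeddings; row i of each forms a positive pair.
    t_prime: log-temperature (t = exp(t_prime)); b: scalar bias.
    """
    n = img_emb.shape[0]
    logits = np.exp(t_prime) * img_emb @ txt_emb.T + b   # [n, n]
    labels = 2.0 * np.eye(n) - 1.0                        # +1 on the diagonal, -1 elsewhere
    return -np.sum(log_sigmoid(labels * logits)) / n


def blockwise_sigmoid_loss(img_emb, txt_emb, t_prime, b, num_blocks):
    """Same loss, accumulated over blocks of texts (single-host illustration of chunking)."""
    n = img_emb.shape[0]
    total = 0.0
    for idx in np.array_split(np.arange(n), num_blocks):
        logits = np.exp(t_prime) * img_emb @ txt_emb[idx].T + b   # [n, len(idx)]
        labels = -np.ones_like(logits)
        labels[idx, np.arange(len(idx))] = 1.0   # image idx[k] matches text idx[k]
        total += -np.sum(log_sigmoid(labels * logits))
    return total / n


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(8, 4))
    y = rng.normal(size=(8, 4))
    x /= np.linalg.norm(x, axis=1, keepdims=True)
    y /= np.linalg.norm(y, axis=1, keepdims=True)
    t_prime, b = np.log(10.0), -10.0   # initial values from the SigLIP paper
    print(pairwise_sigmoid_loss(x, y, t_prime, b))
    print(blockwise_sigmoid_loss(x, y, t_prime, b, num_blocks=3))
```

Because the full loss is a plain sum over all (image, text) cells, splitting the text axis into blocks changes only memory use, not the result.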

I also have another question about the paper itself, or rather a request for a recommendation. You pre-trained some models with a frozen image encoder and achieved very competitive scores in 2 days. However, with a 20k batch size and 107k total steps, the model saw a total of 2.14B image-text examples, with the text model initialized from scratch and a huge ViT-g/14 image tower. What do you think about the inverse experiment: how long would it take to train a ViT from scratch given a good pre-trained text representation? The reason I'm asking is that in my research I deal with pretty unconventional images but regular texts.
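The 2.14B figure is simply batch size times number of steps. A trivial check, with a hypothetical helper name, also shows how much smaller the coco_captions run later in this thread is (batch size 32, 5,000 steps):

```python
def examples_seen(batch_size: int, total_steps: int) -> int:
    """Number of image-text examples processed, assuming one batch per training step."""
    return batch_size * total_steps


print(f"{examples_seen(20_000, 107_000):,}")  # 2,140,000,000
print(f"{examples_seen(32, 5_000):,}")        # 160,000
```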

Looking forward to more research, thank you!

Hi Maxim,

Question 1: good catch. The latest contrastive trainer was actually open-sourced, but it has not yet been well tested in the OSS setting (so far I have focused my effort on the main trainer). Maybe just try it and let us know how it goes.

Question 2: I guess it ultimately depends on the data. For our data and models, locking the text representation was really bad; see Figure 3 in https://arxiv.org/abs/2111.07991.

Thank you for the fast reply, @akolesnikoff! Will try it

Hello, @akolesnikoff!

I've made some progress since our last conversation. As we agreed, I tried to run contrastive pre-training on the coco_captions dataset using a server with a single NVIDIA A100 (40 GB). The first thing I did was wrap everything in Docker. Unfortunately, the standard big_vision installation described in the README didn't work for me, so I'll describe my path to a working Docker image here in case other people run into the same problems:

  • I installed clu with the --no-deps flag, since otherwise pip attempted to install every possible version of jaxlib.
  • As a base image, I first used the official tensorflow/tensorflow:2.14.0-gpu Docker image with the latest version of jax; however, I hit a segmentation fault during the execution of line 293 of contrastive.py.
  • Then I ran the SigLIP demo in Colab, and it worked on GPU with jax==0.4.16+cu11 and tensorflow==2.13.0, so I reproduced the same environment in the Dockerfile.
  • Unfortunately, the official tensorflow/tensorflow:2.13.0-gpu image is built with Python 3.8, while jax==0.4.16 requires Python >= 3.9, so I used nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 as the base image, installed Python 3.11, and on top of it installed tensorflow[and-cuda]==2.13.1 and jax[cuda11_pip]==0.4.16.
  • Finally, I had to downgrade tensorflow-addons and tensorflow-text, since otherwise they were incompatible with tensorflow[and-cuda]==2.13.1.

So the final, working Dockerfile is:

FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04

# Set these before the first apt-get call so that packages such as tzdata do not prompt.
ENV DEBIAN_FRONTEND="noninteractive" TZ="Europe/London"
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8

RUN apt-get update -y && apt-get install -y --no-install-recommends build-essential \
                       ca-certificates \
                       wget \
                       curl \
                       unzip \
                       ssh \
                       git \
                       vim \
                       jq

RUN apt-get install -y software-properties-common
RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt-get install -y python3.11-full
RUN apt-get install -y python3.11-dev
RUN apt-get clean
RUN ln -s /usr/bin/python3.11 /usr/bin/python
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
RUN python -m pip install --upgrade pip

WORKDIR /code

RUN pip install tensorflow[and-cuda]==2.13.1
RUN pip install "jax[cuda11_pip]==0.4.16" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN git clone https://github.com/google-research/big_vision
WORKDIR /code/big_vision

RUN pip install absl-py
RUN pip install einops
RUN pip install optax
RUN pip install tensorstore
RUN pip install flax
RUN pip install git+https://github.com/google/flaxformer
RUN pip install git+https://github.com/akolesnikoff/panopticapi.git@mute
RUN pip install overrides

RUN pip install clu --no-deps

RUN pip install tfds-nightly
RUN pip install tensorflow-addons==0.21.0
RUN pip install tensorflow-text==2.13.0rc
RUN pip install tensorflow-gan

RUN pip install ml_collections

RUN pip cache purge
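To confirm that the container ends up with the versions pinned above, a small Python check could be run inside it. This is a sketch: the expected versions are copied from the Dockerfile, the helper names (installed_version, check_versions, gpu_devices) are hypothetical and not part of big_vision, and it assumes pip normalises the tensorflow-text pin "2.13.0rc" to "2.13.0rc0".

```python
from importlib import metadata

# Versions pinned in the Dockerfile above (distribution name -> version).
EXPECTED = {
    "tensorflow": "2.13.1",
    "jax": "0.4.16",
    "tensorflow-addons": "0.21.0",
    "tensorflow-text": "2.13.0rc0",  # assumed normalised form of "2.13.0rc"
}


def installed_version(dist):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_versions(expected=EXPECTED):
    """Map each distribution to (installed, expected, matches)."""
    report = {}
    for dist, want in expected.items():
        have = installed_version(dist)
        # Drop any local version suffix (e.g. "+cuda11...") before comparing.
        base = have.split("+")[0] if have is not None else None
        report[dist] = (have, want, base == want)
    return report


def gpu_devices():
    """List JAX GPU devices, or an empty list if JAX or a GPU backend is unavailable."""
    try:
        import jax
        return jax.devices("gpu")
    except (ImportError, RuntimeError):
        return []


if __name__ == "__main__":
    for dist, (have, want, ok) in check_versions().items():
        print(f"{dist}: installed={have} expected={want} {'OK' if ok else 'MISMATCH'}")
    print("JAX GPU devices:", gpu_devices())
```

An empty GPU device list inside the container would point at the CUDA wheels rather than at the big_vision config.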

Next, I tried to run a contrastive experiment with a B/16 model at a resolution of 224, using the following config inspired by both LiT and the SigLIP demo:

import big_vision.configs.common as bvcc
from big_vision.configs.proj.image_text import common
from ml_collections import ConfigDict

VARIANT, RES = 'B/16', 224

MODEL2PARAMS = {
    ('B/16', 224): ('webli_en_b16_224_63724782.npz', 'B', 768, 64, 32_000),
    ('B/16', 256): ('webli_en_b16_256_60500360.npz', 'B', 768, 64, 32_000),
    ('B/16', 384): ('webli_en_b16_384_68578854.npz', 'B', 768, 64, 32_000),
    ('B/16', 512): ('webli_en_b16_512_68580893.npz', 'B', 768, 64, 32_000),
    ('L/16', 256): ('webli_en_l16_256_60552751.npz', 'L', 1024, 64, 32_000),
    ('L/16', 384): ('webli_en_l16_384_63634585.npz', 'L', 1024, 64, 32_000),
    ('So400m/14', 224): ('webli_en_so400m_224_57633886.npz', 'So400m', 1152, 16, 32_000),
    ('So400m/14', 384): ('webli_en_so400m_384_58765454.npz', 'So400m', 1152, 64, 32_000),
    ('B/16-i18n', 256): ('webli_i18n_b16_256_66117334.npz', 'B', 768, 64, 250_000),
}

VOCAB2TOKENIZER = {32_000: 'c4_en', 250_000: 'mc4'}

def params_from_model(variant: str, res: int):
  variant_ = variant[:-len('-i18n')] if variant.endswith('-i18n') else variant
  return (*MODEL2PARAMS[(variant, res)], variant_)


def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg,
      variant=VARIANT,
      res=RES,
  )
  ckpt, txt_variant, emb_dim, seq_len, vocab, img_variant = params_from_model(
      variant=arg.variant, res=arg.res)
  tokenizer_name = VOCAB2TOKENIZER[vocab]

  config = ConfigDict()

  config.input = {}
  config.input.data = dict(name='coco_captions', split='train')
  config.input.batch_size = 32
  config.input.shuffle_buffer_size = 5_000

  config.total_steps = 5_000

  config.init_shapes = [(1, arg.res, arg.res, 3), (1, seq_len)]
  config.init_types = ['float32', 'int32']

  tokenizer = lambda inkey: \
    f'tokenize(inkey="{inkey}", max_len={seq_len}, model="{tokenizer_name}", eos="sticky", pad_value=1)'
  config.input.pp = (
      f'decode|resize({arg.res})|value_range(-1, 1)'
      f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")'
  )
  pp_eval = (
      f'decode|resize({arg.res})|value_range(-1,1)'
      f'|flatten|{tokenizer("captions/text")}'
      '|keep("image", "labels")'
  )
  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.flaxformer.bert_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000

  # Model section
  config.model_name = 'proj.image_text.two_towers'
  config.model_load = {}
  config.model_init = f'/tmp/{ckpt}'
  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.text_model = 'proj.image_text.text_transformer'
  config.model.image = ConfigDict({'variant': img_variant, 'pool_type': 'map'})
  config.model.text = ConfigDict({'variant': txt_variant, 'vocab_size': vocab})
  config.model.out_dim = (None, emb_dim)
  config.model.bias_init = -10.0
  config.model.temperature_init = 10.0

  config.optax_name = 'scale_by_adam'

  config.lr = 0.001
  config.wd = 0.01
  warmup_steps = max(int(0.03 * config.total_steps), 100)
  config.schedule = [
      ('img/.*', None),  # Freezes image tower.
      ('.*', dict(decay_type='cosine', warmup_steps=warmup_steps)),
  ]

  config.grad_clip_norm = 1.0

  # Eval section (Both few-shot and zero-shot)
  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=True,
      log_steps=5,
  )
  config.evals = {}
  sub = '[:5000]'
  config.evals.val = {
      **eval_common,
      'data': dict(name=config.input.data.name, split=f'val{sub}'),
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'decode|resize({arg.res})|value_range(-1,1)'
          f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")'),
  }
  config.evals.imagenet = {
      **eval_common,
      'data': dict(name='imagenet2012', split=f'validation{sub}'),
      'pp_fn': (
          f'decode|resize({arg.res})|value_range(-1,1)'
          '|clip_i1k_label_names'
          f'|{tokenizer("labels")}|keep("image", "labels")'),
  }

  config.evals.disclf = {}
  config.evals.disclf.pp_img = f'resize({arg.res})|value_range(-1,1)'
  config.evals.disclf.pp_txt = tokenizer('texts')
  config.evals.disclf.type = 'proj.image_text.discriminative_classifier'
  config.evals.disclf.prefix = 'z/0shot/'
  config.evals.disclf.log_steps = eval_common['log_steps']
  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=config.evals.disclf.log_steps,
  )

  config.seed = 0
  config.l = config.m = 0

  return config
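Two parts of this config are easier to follow with numbers. With total_steps = 5_000, warmup_steps = max(int(0.03 * 5000), 100) = 150, and config.schedule is an ordered list of (regex, setting) rules in which None freezes the matched parameters. The sketch below assumes first-match-wins, full-string regex matching on '/'-joined parameter names, and a plain linear-warmup-then-cosine shape decaying to zero; the parameter paths and helper names are hypothetical, and this is not big_vision's implementation.

```python
import math
import re

TOTAL_STEPS = 5_000
BASE_LR = 0.001
WARMUP_STEPS = max(int(0.03 * TOTAL_STEPS), 100)  # 150

SCHEDULE = [
    ("img/.*", None),  # None: parameter is frozen
    (".*", dict(decay_type="cosine", warmup_steps=WARMUP_STEPS)),
]


def rule_for(param_path, schedule=SCHEDULE):
    """Return the setting of the first rule whose regex fully matches the path."""
    for pattern, setting in schedule:
        if re.fullmatch(pattern, param_path):
            return setting
    raise ValueError(f"no schedule rule matches {param_path!r}")


def cosine_lr(step, base_lr=BASE_LR, warmup=WARMUP_STEPS, total=TOTAL_STEPS):
    """Linear warmup from 0 to base_lr, then cosine decay to 0 at `total`."""
    if step < warmup:
        return base_lr * step / warmup
    progress = min((step - warmup) / (total - warmup), 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Hypothetical parameter paths of a two-tower model.
for path in ["img/embedding/kernel", "txt/Encoder_0/kernel", "t", "b"]:
    setting = rule_for(path)
    print(path, "-> frozen" if setting is None else f"-> {setting['decay_type']}")

for step in [0, 75, 150, 2_575, 5_000]:
    print(step, round(cosine_lr(step), 6))
```

The order of the rules matters under this reading: placing ('.*', ...) first would match the image tower too and leave nothing frozen.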

However, now I'm facing the following error, which might be caused by the recent refactoring you described in the issue.

I1105 19:43:27.777555 139669404165952 contrastive.py:236] NOTE: Init evaluator: disclf…
Steps:0/5000 [0.0%]
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/code/big_vision/big_vision/trainers/proj/image_text/contrastive.py", line 514, in <module>
    app.run(main)
  File "/usr/local/lib/python3.11/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.11/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/code/big_vision/big_vision/trainers/proj/image_text/contrastive.py", line 427, in main
    for (name, evaluator, _, prefix) in evaluators():
                                        ^^^^^^^^^^^^
  File "/code/big_vision/big_vision/trainers/proj/image_text/contrastive.py", line 369, in evaluators
    return eval_common.from_config(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/big_vision/big_vision/evaluators/common.py", line 64, in from_config
    raise RuntimeError(
RuntimeError: You are seemingly using new jit-based evaluator, but with old pmap-based train loop, see (internal link) for more details.
Exception ignored in: <function Pool.__del__ at 0x7f0750f691c0>
Traceback (most recent call last):
  File "/usr/lib/python3.11/multiprocessing/pool.py", line 271, in __del__
  File "/usr/lib/python3.11/multiprocessing/queues.py", line 371, in put
AttributeError: 'NoneType' object has no attribute 'dumps'

I'm not very experienced with JAX, so I would be very grateful if you could check whether my Dockerfile and config are correct and advise me on how to solve this issue.

Sorry, we forgot to push some code. I will ping you in this thread once we fix the issue.

Great, thanks!

This config should work for SigLIP training right now: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/siglip_lit_coco.py.

Note that it is a minimal example that runs training on the coco_captions dataset; do not expect great results.

Thanks a lot, @akolesnikoff! I'll have a look this week.
I think we can close the issue for now; if I stumble across any problems, I'll reopen it.