google-research/big_vision

Add chunked implementation of siglip sigmoid loss


Thank you for the great work on the SigLIP paper. In particular, the metric gains at small batch sizes are impressive.
Sigmoid Loss for Language Image Pre-Training

I tried the PyTorch-based distributed chunked implementation from the OpenCLIP repo, which uses torch.distributed.P2POp:
https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/loss.py#L307

I took a ViT-B vision encoder + XLM-RoBERTa text encoder and trained them with both the CLIP softmax loss and the SigLIP sigmoid loss on an in-house dataset of 10M image-text pairs, at an effective batch size of 9k (on V100 GPUs). CLIP softmax still performs better than the SigLIP sigmoid loss on the nDCG metric.

I was wondering if there is any error in the P2POp-based implementation above. I also tried an all_gather to collect the negative text features from the other GPUs, but the behavior appears to be the same:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# global_manager and all_gather_objects below are our in-house distributed
# helpers (rank/world_size bookkeeping and an object all-gather wrapper).

class SigLipLossAllGather(nn.Module):
    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
    """
    def __init__(
            self,
            logit_scale=np.log(10),
            logit_bias=-10
    ):
        super().__init__()
        self.logit_scale = logit_scale
        self.logit_bias = logit_bias

        self.labels = {}

    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
        if not negative_only:
            labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
        return labels

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        logits = logit_scale * image_features @ text_features.T
        if logit_bias is not None:
            logits += logit_bias
        return logits

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            image_features.shape[0],
            negative_only=negative_only,
        )
        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
        return loss

    def forward(self, image_features, text_features):
        
        loss = self._loss(image_features, text_features, self.logit_scale, self.logit_bias)

        if global_manager.world_size > 1:

            # Gather text features from all ranks
            text_features_dict = {'text_features': text_features, 'global_rank': global_manager.rank}
            all_text_features_dict = all_gather_objects(text_features_dict)

            # Compute loss against negative text features from all other ranks
            for i in range(len(all_text_features_dict)):
                neigh_rank = all_text_features_dict[i]['global_rank']
                neigh_text_features = all_text_features_dict[i]['text_features']
                if neigh_rank != global_manager.rank:
                    loss += self._loss(image_features, neigh_text_features, self.logit_scale, self.logit_bias, negative_only=True)
                    
        return {"loss": loss}

The implementation in this repo seems to be the non-chunked version:

def loss_fn(params):
  zimg, ztxt, extras = model.apply(
      {"params": params}, images, labels,
      train=True, rngs={"dropout": rng_model})
  logits = jnp.dot(zimg, ztxt.T)
  logits = logits * extras["t"] + extras["b"]
  eye = jnp.eye(zimg.shape[0])
  # Standard sigmoid computes everything twice, once assuming positive
  # labels and once assuming negative ones. But here we know exactly where
  # to find positives (on "me" diagonal) and negatives (everywhere else),
  # so compute each one's loss only once:
  m1_diag1 = -jnp.ones_like(logits) + 2 * eye
  loglik = jax.nn.log_sigmoid(m1_diag1 * logits)
  # Normalize by npos per column, but that's one, so just sum.
  nll = -jnp.sum(loglik, axis=-1)
  # NOTE: same as concat'ing me/ot along axis -1 above.
  l = jnp.mean(nll)
  return l

I am wondering if you could add a chunked implementation of the SigLIP sigmoid loss.
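For concreteness, this is roughly what I had in mind: an untested sketch under pmap that rotates the text representations across devices with jax.lax.ppermute, so each device only ever materializes one local block of the logit matrix. The function name, arguments, and the way the device count is passed in are my own guesses, not this repo's API:

import jax
import jax.numpy as jnp

def chunked_sigmoid_loss(zimg, ztxt, t, b, axis_name="batch", ndev=1):
  # `ndev` is the size of the pmapped axis (e.g. jax.device_count()),
  # passed in statically so the Python loop below has a concrete length.

  def block_nll(zimg, ztxt, positive):
    logits = jnp.dot(zimg, ztxt.T) * t + b
    if positive:
      # +1 on the diagonal (the "me" positives), -1 everywhere else.
      labels = 2 * jnp.eye(zimg.shape[0]) - jnp.ones_like(logits)
    else:
      # A block received from another device holds only negatives.
      labels = -jnp.ones_like(logits)
    return -jnp.sum(jax.nn.log_sigmoid(labels * logits), axis=-1)

  # Local block first: positives plus local negatives.
  nll = block_nll(zimg, ztxt, positive=True)

  # Rotate text representations one device to the right, (ndev - 1) times,
  # accumulating negative-only blocks. ppermute is differentiable, so
  # gradients flow back to the device that produced each text chunk.
  perm = [(i, (i + 1) % ndev) for i in range(ndev)]
  ztxt_rot = ztxt
  for _ in range(ndev - 1):
    ztxt_rot = jax.lax.ppermute(ztxt_rot, axis_name, perm)
    nll = nll + block_nll(zimg, ztxt_rot, positive=False)

  return jnp.mean(nll)

If I am reading the non-chunked loss_fn correctly, t and b here would correspond to extras["t"] and extras["b"], and the per-device mean would then be combined across devices the same way as for the existing loss.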