sthalles/SimCLR-tensorflow

Clarification regarding the shapes expected by the losses

Closed this issue · 0 comments

Hi.

Thanks for your article on SimCLR and also for coming up with a TensorFlow version of the code. Really appreciate your efforts here.

I am trying to put together a minimal implementation of SimCLR on the CIFAR-10 dataset, and I am reusing a number of the utilities you have already written. To that end, I wanted some clarification regarding the shapes the loss functions expect, even though you have noted them in the code.

My training loop is identical to yours:

def train_step(xis, xjs, model, optimizer, criterion, temperature):
    with tf.GradientTape() as tape:
        zis = model(xis)
        zjs = model(xjs)

        # normalize projection feature vectors
        zis = tf.math.l2_normalize(zis, axis=1)
        zjs = tf.math.l2_normalize(zjs, axis=1)

        # tf.summary.histogram('zis', zis, step=optimizer.iterations)
        # tf.summary.histogram('zjs', zjs, step=optimizer.iterations)

        # similarity between each positive pair zis[i] . zjs[i] -> reshaped to [N, 1] below
        l_pos = sim_func_dim1(zis, zjs)
        l_pos = tf.reshape(l_pos, (config['batch_size'], 1))
        l_pos /= temperature
        assert l_pos.shape == (config['batch_size'], 1), "l_pos shape not valid" + str(l_pos.shape)  # [N,1]

        # 2N candidate vectors; after masking, each anchor keeps 2(N - 1) of them as negatives
        negatives = tf.concat([zjs, zis], axis=0)

        loss = 0

        for positives in [zis, zjs]:
            # similarity of each anchor against all 2N candidates -> [N, 2N] before masking
            l_neg = sim_func_dim2(positives, negatives)

            # the positive logit sits in column 0, so every label is class 0
            labels = tf.zeros(config['batch_size'], dtype=tf.int32)

            l_neg = tf.boolean_mask(l_neg, negative_mask)
            l_neg = tf.reshape(l_neg, (config['batch_size'], -1))
            l_neg /= config['temperature']

            assert l_neg.shape == (config['batch_size'], 2 * (config['batch_size'] - 1)), \
                "Shape of negatives not expected." + str(l_neg.shape)
            logits = tf.concat([l_pos, l_neg], axis=1)  # [N,K+1]
            loss += criterion(y_pred=logits, y_true=labels)

        loss = loss / (2 * config['batch_size'])  # average over the 2N augmented examples

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss
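
For reference, this is roughly how I have the similarity helpers and the negative mask wired up (a minimal sketch of what I copied over; if it no longer matches your utils, the discrepancy is likely on my end):

import numpy as np
import tensorflow as tf

def _dot_simililarity_dim1(x, y):
    # row-wise dot product of matching pairs: (N, C) x (N, C) -> (N, 1, 1)
    return tf.matmul(tf.expand_dims(x, 1), tf.expand_dims(y, 2))

def _dot_simililarity_dim2(x, y):
    # dot product of every row of x against every row of y: (N, C) x (2N, C) -> (N, 2N)
    return tf.tensordot(tf.expand_dims(x, 1), tf.expand_dims(tf.transpose(y), 0), axes=2)

def get_negative_mask(batch_size):
    # zero out the self-similarity and the positive-pair entry in each row,
    # leaving 2 * (batch_size - 1) negatives per anchor
    mask = np.ones((batch_size, 2 * batch_size), dtype=bool)
    for i in range(batch_size):
        mask[i, i] = 0
        mask[i, i + batch_size] = 0
    return tf.constant(mask)

sim_func_dim1 = _dot_simililarity_dim1
sim_func_dim2 = _dot_simililarity_dim2
negative_mask = get_negative_mask(config['batch_size'])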

Now, the shapes of zis and zjs are (128, 8, 8, 128), and when execution reaches _dot_simililarity_dim1 it throws:

InvalidArgumentError: In[0] mismatch In[1] shape: 128 vs. 8: [128,1,8,8,128] [128,8,1,8,128] 0 0 [Op:BatchMatMulV2]
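
From these shapes it looks like my base encoder is still returning the raw 4-D conv feature map. My current guess is that the model passed into train_step is expected to pool or flatten the spatial dimensions before the projection head, so that zis and zjs come out as (batch_size, out_dim). Something like the sketch below is what I have in mind (the backbone and layer sizes are placeholders of mine that happen to produce an (8, 8, 128) feature map, not anything taken from your code):

import tensorflow as tf

def build_simclr_model(out_dim=64):
    # toy CIFAR-10 backbone; the important part is the pooling + projection head at the end
    inputs = tf.keras.Input(shape=(32, 32, 3))
    h = tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu')(inputs)
    h = tf.keras.layers.MaxPool2D()(h)
    h = tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu')(h)
    h = tf.keras.layers.MaxPool2D()(h)              # feature map here is (N, 8, 8, 128)
    # collapse the spatial dimensions so the representation is (N, 128) rather than 4-D
    h = tf.keras.layers.GlobalAveragePooling2D()(h)
    # projection head g(.): small MLP whose (N, out_dim) output feeds the NT-Xent loss
    x = tf.keras.layers.Dense(128, activation='relu')(h)
    z = tf.keras.layers.Dense(out_dim)(x)
    return tf.keras.Model(inputs=inputs, outputs=z)

With a 2-D output like that, l_pos would come out as (N, 1) and l_neg as (N, 2(N - 1)), which matches the asserts above.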

I was wondering if you could shed some light on this. Here is a Colab Gist in case you want to reproduce the issue.