Clarification regarding the shapes expected by the losses
sayakpaul commented
Hi.
Thanks for your article on SimCLR and also for coming up with a TensorFlow version of the code. Really appreciate your efforts here.
I am trying to come up with a minimal implementation of SimCLR on the CIFAR10 dataset, reusing a number of the utilities you have already put together. To that end, I wanted to ask for some clarification regarding the shapes expected by the loss functions, even though you have stated them.
My training loop is essentially identical to yours:

```python
def train_step(xis, xjs, model, optimizer, criterion, temperature):
    with tf.GradientTape() as tape:
        zis = model(xis)
        zjs = model(xjs)

        # normalize projection feature vectors
        zis = tf.math.l2_normalize(zis, axis=1)
        zjs = tf.math.l2_normalize(zjs, axis=1)
        # tf.summary.histogram('zis', zis, step=optimizer.iterations)
        # tf.summary.histogram('zjs', zjs, step=optimizer.iterations)

        # similarity of each positive pair, reshaped to [N, 1]
        l_pos = sim_func_dim1(zis, zjs)
        l_pos = tf.reshape(l_pos, (config['batch_size'], 1))
        l_pos /= temperature
        assert l_pos.shape == (config['batch_size'], 1), "l_pos shape not valid" + str(l_pos.shape)  # [N, 1]

        negatives = tf.concat([zjs, zis], axis=0)

        loss = 0
        for positives in [zis, zjs]:
            # similarity of each sample against all 2N candidates, then mask
            # out the self- and positive-pair entries, leaving [N, 2(N - 1)]
            l_neg = sim_func_dim2(positives, negatives)
            labels = tf.zeros(config['batch_size'], dtype=tf.int32)  # the positive logit sits at column 0
            l_neg = tf.boolean_mask(l_neg, negative_mask)
            l_neg = tf.reshape(l_neg, (config['batch_size'], -1))
            l_neg /= config['temperature']
            assert l_neg.shape == (config['batch_size'], 2 * (config['batch_size'] - 1)), \
                "Shape of negatives not expected." + str(l_neg.shape)

            logits = tf.concat([l_pos, l_neg], axis=1)  # [N, K+1]
            loss += criterion(y_pred=logits, y_true=labels)

        loss = loss / (2 * config['batch_size'])  # average over the 2N augmented samples

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss
```
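For context, these are the dot-similarity helpers I am using as `sim_func_dim1` and `sim_func_dim2` — I believe I copied them from your repository unchanged (correct me if not). The shape comments assume rank-2 inputs of shape `(N, C)`:

```python
import tensorflow as tf


def _dot_simililarity_dim1(x, y):
    # x shape: (N, 1, C); y shape: (N, C, 1) -> v shape: (N, 1, 1)
    v = tf.matmul(tf.expand_dims(x, 1), tf.expand_dims(y, 2))
    return v


def _dot_simililarity_dim2(x, y):
    # x shape: (N, 1, C); y shape: (1, C, 2N) -> v shape: (N, 2N)
    v = tf.tensordot(tf.expand_dims(x, 1),
                     tf.expand_dims(tf.transpose(y), 0),
                     axes=2)
    return v
```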
Now, the shapes of `zis` and `zjs` are `(128, 8, 8, 128)`, and when execution reaches `_dot_simililarity_dim1` it throws:

```
InvalidArgumentError: In[0] mismatch In[1] shape: 128 vs. 8: [128,1,8,8,128] [128,8,1,8,128] 0 0 [Op:BatchMatMulV2]
```
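If I am decoding this correctly, `tf.expand_dims` inside `_dot_simililarity_dim1` turns my rank-4 activations into the rank-5 tensors `[128,1,8,8,128]` and `[128,8,1,8,128]` that `BatchMatMulV2` rejects. My current guess is that the model should emit rank-2 projections of shape `(batch_size, out_dim)` before `tf.math.l2_normalize` is applied. Here is a minimal sketch of what I am thinking of trying — the pooling layer and the layer sizes here are my assumptions, not taken from your code:

```python
import tensorflow as tf


def build_simclr_model(base_encoder, out_dim=64):
    # Hypothetical projection head: collapse the (8, 8, 128) feature maps
    # to a flat vector so the model emits (batch_size, out_dim) projections.
    inputs = tf.keras.Input(shape=(32, 32, 3))        # CIFAR10 images
    h = base_encoder(inputs)                          # (batch_size, 8, 8, 128) in my case
    h = tf.keras.layers.GlobalAveragePooling2D()(h)   # -> (batch_size, 128)
    z = tf.keras.layers.Dense(256, activation="relu")(h)
    z = tf.keras.layers.Dense(out_dim)(z)             # -> (batch_size, out_dim)
    return tf.keras.Model(inputs=inputs, outputs=z)
```

With that change, `zis` and `zjs` would be `(128, 64)`, and the `[N, 1]` and `[N, 2(N - 1)]` assertions in the training loop should hold.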
I was wondering if you could shed some light on this. Here is a Colab Gist in case you want to reproduce the issue.