Dimensional Collapse Tracking
Opened this issue · 1 comments
zankerx commented
Hello, first of all, thank you for sharing this kind of model.
I have a question about the development of this model, particularly regarding dimensional collapse. Do you use any indicators besides the loss during training to select hyper-parameters and maximize the amount of information contained in the embedding (k-NN probes, eigenvalues of the embedding covariance, or others)?
Thank you in advance for your response :)
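One common way to track the indicators mentioned above is to monitor the effective rank of the embedding, i.e. the entropy of the normalized singular-value spectrum. This is a sketch of that idea, not code from this repository; the shapes and thresholds are illustrative assumptions:

```python
import torch

def effective_rank(z, eps=1e-7):
    """Effective rank of embeddings z with shape (N, D): the exponential
    of the entropy of the normalized singular-value spectrum.
    Values near D mean all dimensions carry information;
    values near 1 indicate dimensional collapse."""
    z = z - z.mean(dim=0, keepdim=True)      # center the batch
    s = torch.linalg.svdvals(z)              # singular values, length min(N, D)
    p = s / (s.sum() + eps)                  # normalize to a distribution
    entropy = -(p * (p + eps).log()).sum()
    return entropy.exp().item()

# Healthy embeddings: full-rank Gaussian batch
z_healthy = torch.randn(512, 64)
# Collapsed embeddings: all variance concentrated in 2 directions
z_collapsed = torch.randn(512, 2) @ torch.randn(2, 64)

print(effective_rank(z_healthy))    # close to 64
print(effective_rank(z_collapsed))  # close to 2
```

Logging this quantity alongside the loss makes collapse visible even when the prediction loss keeps decreasing.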
zankerx commented
In the code, I noticed that there is also a regularization term that is unused by default, since its coefficient is set to 0:
```python
def reg_fn(z):
    return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z)

# Step 1. Forward
loss_jepa, loss_reg = 0., 0.
with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision):
    h = forward_target(clips)
    z = forward_context(clips, h)
    loss_jepa = loss_fn(z, h)  # jepa prediction loss
    pstd_z = reg_fn(z)  # predictor variance across patches
    loss_reg += torch.mean(F.relu(1. - pstd_z))
loss = loss_jepa + reg_coeff * loss_reg
```
Have you studied the impact of this regularization on the model?
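For context, the term above is a hinge on the per-feature standard deviation across patches: it only activates when predictions collapse toward a constant. A minimal standalone check (the tensor shapes are illustrative assumptions, not taken from the repository):

```python
import torch
import torch.nn.functional as F

def reg_fn(z):
    # z: list of tensors shaped (batch, num_patches, dim);
    # std of each feature computed across the patch axis
    return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z)

# Diverse predictions: per-feature std across patches is well above 1,
# so the hinge relu(1 - std) is inactive
z_diverse = [torch.randn(8, 196, 384) * 2.0]
# Collapsed predictions: every patch maps to the same vector,
# so the std is ~0 and the hinge is fully active
z_collapsed = [torch.randn(8, 1, 384).expand(8, 196, 384)]

for name, z in [("diverse", z_diverse), ("collapsed", z_collapsed)]:
    pstd_z = reg_fn(z)
    loss_reg = torch.mean(F.relu(1.0 - pstd_z))
    print(name, loss_reg.item())  # ~0.0 for diverse, ~0.99 for collapsed
```

So with `reg_coeff > 0` the term would push the predictor away from constant outputs while adding no gradient once the per-patch variance is healthy.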