facebookresearch/jepa

Dimensional Collapse Tracking

Opened this issue · 1 comments

Hello, first of all, thank you for sharing this model.

I have a question regarding the development of this model, particularly about dimensional collapse. During training, did you use any indicators besides the loss to select hyper-parameters and maximize the amount of information contained in the embeddings (k-NN accuracy, eigenvalues of the embedding covariance, or others)?
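For context, a common way to quantify dimensional collapse (not necessarily what the authors used) is to look at the eigenvalue spectrum of the embedding covariance: a collapsed representation concentrates variance in a few directions, which an entropy-based "effective rank" makes visible as a single number. A minimal sketch, assuming embeddings arrive as an `(N, D)` batch:

```python
import torch

def covariance_spectrum(z: torch.Tensor) -> torch.Tensor:
    """Eigenvalues of the embedding covariance, sorted descending.

    z: (N, D) batch of embeddings. Collapse shows up as most
    eigenvalues being near zero.
    """
    z = z - z.mean(dim=0, keepdim=True)
    cov = (z.T @ z) / (z.shape[0] - 1)
    return torch.linalg.eigvalsh(cov).flip(0)

def effective_rank(z: torch.Tensor) -> float:
    """Entropy-based effective rank of the covariance spectrum:
    exp of the entropy of the normalized eigenvalues, ranging
    from ~1 (fully collapsed) to D (isotropic embeddings)."""
    ev = covariance_spectrum(z).clamp(min=0)
    p = ev / ev.sum()
    entropy = -(p * (p + 1e-12).log()).sum()
    return float(entropy.exp())
```

Logging this alongside the loss (or a k-NN probe) gives an early warning that the predictor outputs are collapsing even while the loss keeps decreasing.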

Thank you in advance for your response :)

In the code, I noticed that you also have a regularization term that is unused by default, since its coefficient is set to 0:

```python
def reg_fn(z):
    return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z)

# Step 1. Forward
loss_jepa, loss_reg = 0., 0.
with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision):
    h = forward_target(clips)
    z = forward_context(clips, h)
    loss_jepa = loss_fn(z, h)  # jepa prediction loss
    pstd_z = reg_fn(z)  # predictor variance across patches
    loss_reg += torch.mean(F.relu(1.-pstd_z))
loss = loss_jepa + reg_coeff * loss_reg
```
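To illustrate what this hinge does, here is a minimal sketch reusing the snippet's `reg_fn` on toy tensors. I am assuming the predictor outputs in `z` are shaped `(batch, patches, dim)`, so `var(dim=1)` measures variance across patches; the `F.relu(1. - pstd_z)` term then only activates when the per-dimension std drops below 1, i.e. when the outputs start collapsing:

```python
import torch
import torch.nn.functional as F

def reg_fn(z):
    # per-sample, per-dim std across the patch dimension, averaged over the list
    return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z)

torch.manual_seed(0)
# healthy predictor outputs: std across patches well above 1 -> hinge inactive
healthy = [torch.randn(4, 16, 8) * 2.0]
# collapsed outputs: identical across patches -> std ~0, hinge saturates near 1
collapsed = [torch.zeros(4, 16, 8)]

loss_healthy = torch.mean(F.relu(1. - reg_fn(healthy)))
loss_collapsed = torch.mean(F.relu(1. - reg_fn(collapsed)))
```

So with `reg_coeff > 0` the term would push the predictor to keep at least unit variance across patches, similar in spirit to the variance term in VICReg, but as shipped (`reg_coeff = 0`) it only gets logged.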

Have you studied the impact of this regularization on the model?