blt2114/ProtDiff_SMCDiff

What does "Kt_1 = self._diffuser.K ** t" mean?

In torch_train_diffusion.py, Line 279

Kt_1 = self._diffuser.K ** t

I don't think the Diffuser class has an attribute 'K'.
Is it bb_mask_2d or something similar?
What does the attribute K represent?

Hi Buddha7771,

Thank you for your question. When we trained the model, self._diffuser.K was a batch of NxN identity matrices, which had no effect on the loss, so I have just removed it from the training code.
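
To illustrate why an identity K is a no-op in an expression like K ** t, here is a minimal NumPy sketch. It is not the repository's code: the shapes, the variable names, and the choice of a positive integer t are assumptions made for the example. Whether ** is read as an elementwise power (as it is for arrays and tensors) or as a matrix power, a batch of identity matrices is returned unchanged.

```python
import numpy as np

batch, n = 2, 4  # hypothetical batch size and number of residues
# Stand-in for self._diffuser.K: a batch of NxN identity matrices.
K = np.broadcast_to(np.eye(n), (batch, n, n)).copy()

t = 3  # assumed: a positive integer time step

# `**` on an array is an elementwise power: 1**t == 1 and 0**t == 0 for t >= 1.
Kt_elementwise = K ** t

# A true matrix power, applied to each matrix in the batch, also returns I.
Kt_matrix = np.linalg.matrix_power(K, t)

assert np.array_equal(Kt_elementwise, K)
assert np.array_equal(Kt_matrix, K)
```

Note that t = 0 is a different case for the elementwise power, because 0 ** 0 evaluates to 1 and the off-diagonal entries are no longer zero.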

It was there because we had explored other losses that, instead of simply using squared L2 error, used a squared K_t-quadratic norm loss, where K_t was a time-step-dependent positive semidefinite matrix.
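
For readers unfamiliar with the term, a squared K-quadratic norm of an error e is e^T K e, which reduces to the squared L2 error when K is the identity. The sketch below shows this relationship in NumPy. The function name k_quadratic_loss, the (batch, N, 3) coordinate shape, and the random inputs are hypothetical and chosen only for illustration; the loss we actually experimented with is not reproduced here.

```python
import numpy as np

def k_quadratic_loss(pred, target, K):
    """Squared K-quadratic norm of the error, one value per batch element.

    pred, target: (batch, n, d) arrays; K: (batch, n, n) array.
    Computes sum over d of err[:, :, d]^T K err[:, :, d].
    """
    err = pred - target
    return np.einsum("bid,bij,bjd->b", err, K, err)

rng = np.random.default_rng(0)
batch, n, d = 2, 5, 3  # hypothetical sizes: n residues, d = 3 coordinates
pred = rng.normal(size=(batch, n, d))
target = rng.normal(size=(batch, n, d))

# With K = I the quadratic norm is exactly the squared L2 error.
K_identity = np.broadcast_to(np.eye(n), (batch, n, n))
squared_l2 = ((pred - target) ** 2).sum(axis=(1, 2))
assert np.allclose(k_quadratic_loss(pred, target, K_identity), squared_l2)

# Any K of the form A A^T is positive semidefinite, so the loss stays >= 0.
A = rng.normal(size=(batch, n, n))
K_psd = A @ np.swapaxes(A, 1, 2)
assert np.all(k_quadratic_loss(pred, target, K_psd) >= -1e-9)
```

A non-identity K weights and couples the per-residue errors, which is why the identity case could be kept in the code as a special case of the more general loss.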

As this was too unstable to reliably improve training, we did not pursue this direction for long. But since squared L2 error was initially implemented as a special case, we forgot to remove all references to the implementation of this idea.