tung-nd/TNP-pytorch

About the attention mask for TNP-A

Closed this issue · 5 comments

Thanks for your awesome work and codes!

I notice that the constructed attention mask allows each fake target point to attend to the preceding real target points in

mask[num_all:, num_ctx:num_all].triu_(diagonal=0) # each fake target point attends to preceding real target points

I am confused about whether it is correct to use the real target points' information (labels) when testing the model on the meta-valid/test set. Did I miss something?
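For readers following along, here is a minimal, dependency-free sketch of the attention pattern that the quoted line seems to imply. Only the fake-target rule comes from the quoted line. The sequence layout [context, real targets, fake targets], the rule that every point attends to all context points, and the rule that each real target attends to itself and the preceding real targets are assumptions made for illustration. The function name `build_autoreg_attention_pattern` is hypothetical, and `True` stands for "may attend" (0.0 in an additive mask), `False` for "blocked" (-inf).

```python
def build_autoreg_attention_pattern(num_ctx, num_tar):
    """Return a boolean matrix; allowed[row][col] is True if row may attend to col.

    Assumed layout (not confirmed by the thread): [context | real targets | fake targets].
    """
    num_all = num_ctx + num_tar          # context + real targets
    size = num_all + num_tar             # plus one fake target per real target
    allowed = [[False] * size for _ in range(size)]

    # Assumption: every point attends to all context points.
    for row in range(size):
        for col in range(num_ctx):
            allowed[row][col] = True

    for i in range(num_tar):
        for j in range(num_tar):
            # Assumption: real target i attends to real targets j <= i (itself included).
            if j <= i:
                allowed[num_ctx + i][num_ctx + j] = True
            # From the quoted line: fake target i attends to real targets j < i only,
            # so it never sees its own label.
            if j < i:
                allowed[num_all + i][num_ctx + j] = True
    return allowed


def attended_columns(allowed, row):
    """List the column indices that the given row may attend to."""
    return [col for col, ok in enumerate(allowed[row]) if ok]


pattern = build_autoreg_attention_pattern(num_ctx=2, num_tar=3)
for row in pattern:
    print("".join("x" if ok else "." for ok in row))
```

In this sketch the fake target in position i sees the labels of targets before i but not label i itself, which is the property the question is about.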

Thanks a lot for your reply

Hi, this corresponds to the autoregressive model, where we factorize the joint distribution over the target points as p(y_{1:N} | x_{1:N}, C) = p(y_1 | x_1, C) p(y_2 | x_2, x_1, y_1, C) ... p(y_N | x_N, x_{1:N-1}, y_{1:N-1}, C). In other words, we condition on both the context and the preceding target points when making a prediction for a target point.

The practice of conditioning on the ground-truth values of previous tokens is often referred to as teacher forcing in autoregressive models. Note that this applies only when evaluating the likelihood of a target sequence. During generation, we have to generate the tokens one by one, as we don't know the entire sequence beforehand.
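As a rough illustration of the factorization under teacher forcing, the sketch below sums the per-point log-densities while feeding the true y of each earlier target into the conditioning set. The predictive `toy_predictive` is a hypothetical stand-in (a unit-variance Gaussian centred on the mean of all y values seen so far), not the TNP model; only the chain-rule structure is the point.

```python
import math


def gaussian_logpdf(y, mean, std):
    """Log-density of a univariate Gaussian."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (y - mean) ** 2 / (2 * std ** 2)


def toy_predictive(x, history):
    """Hypothetical model: mean of all y values seen so far, unit std (ignores x)."""
    ys = [y for _, y in history]
    return sum(ys) / len(ys), 1.0


def teacher_forced_loglik(context, targets):
    """log p(y_1:N | x_1:N, C) as a sum of conditionals, conditioning on true y's."""
    history = list(context)              # start from the context set C
    total = 0.0
    for x, y in targets:
        mean, std = toy_predictive(x, history)
        total += gaussian_logpdf(y, mean, std)
        history.append((x, y))           # teacher forcing: append the TRUE y
    return total


context = [(0.0, 1.0), (1.0, 3.0)]
targets = [(2.0, 2.0), (3.0, 2.0)]
loglik = teacher_forced_loglik(context, targets)
print(loglik)
```

Because the whole target sequence is known when scoring it, every term can use the ground-truth prefix, which is what lets a masked model compute all terms in one pass.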


Thank you very much for your reply. Conditioning on the real target points is reasonable for autoregressive models during training, but I am a little confused about whether the model should still condition on the real target points at test time. When testing the model, should we condition on the preceding predicted target points or on the preceding real target points?

During testing, if you want to evaluate the likelihood of a sequence, you can condition on the preceding real target points, because you know the entire sequence beforehand and you just want to see whether this sequence is likely under the model.

If you want to evaluate the accuracy, however, it becomes a prediction task, which means you only know the x's and not the y's. The model must then condition on the preceding predicted target points, as the real points are unknown.
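The two test-time modes described above can be contrasted with a small sketch. The predictor `toy_mean` is hypothetical (it returns the last observed y plus 1 and ignores x) and is not the TNP model. Feeding back the predicted mean, rather than a sample from the predictive distribution, is also an assumption made here for simplicity.

```python
def toy_mean(x, history):
    """Hypothetical predictor: last y in the history plus 1.0 (ignores x)."""
    return history[-1][1] + 1.0


def teacher_forced_means(context, targets):
    """Likelihood-style evaluation: each step conditions on the TRUE preceding y's."""
    history = list(context)
    means = []
    for x, y in targets:
        means.append(toy_mean(x, history))
        history.append((x, y))           # true label is known, so use it
    return means


def rollout_means(context, target_xs):
    """Prediction task: y's are unknown, so each step conditions on its own outputs."""
    history = list(context)
    preds = []
    for x in target_xs:
        y_hat = toy_mean(x, history)
        preds.append(y_hat)
        history.append((x, y_hat))       # feed the prediction back in
    return preds


context = [(0.0, 1.0), (1.0, 3.0)]
targets = [(2.0, 2.0), (3.0, 2.0), (4.0, 2.0)]

forced = teacher_forced_means(context, targets)
rolled = rollout_means(context, [x for x, _ in targets])
print(forced, rolled)
```

The first prediction is identical in both modes because only the context is available; the later ones diverge once the rollout starts conditioning on its own outputs instead of the true labels.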

I get it. Thank you very much.