yxonic/DTransformer

Is the final concatenation of the question for t+1 time step or t time step in the code?


The model description says the final fusion should use $q_{t+1}$. In the code, however, only $q_t$ seems to be concatenated, and I see no shift operation:

```python
n = 1
query = q_emb[:, n - 1 :, :]  # with n = 1 this is q_emb[:, 0:, :], i.e. q[t] at every step t
y = self.out(torch.cat([query, h], dim=-1)).squeeze(-1)
```

The loss computation below also supports this observation. The model is supposed to predict the label at time step t+1 (i.e., starting from the second label), but the targets in `s` start from the first label.

```python
masked_logits = logits[s >= 0]     # keep only positions with a valid (non-negative) label
masked_labels = s[s >= 0].float()
pred_loss = F.binary_cross_entropy_with_logits(
    masked_logits, masked_labels, reduction="mean"
)
```
Are there any details that I may have missed? If you can provide any insights, I would greatly appreciate it.
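To make the alignment in that loss explicit, here is a plain-Python sketch of the same computation. It pairs `logits[t]` with `s[t]` index by index and skips padded positions. It assumes negative values in `s` mark padding, as the `s >= 0` filter suggests; `masked_bce_with_logits` is a hypothetical helper, not part of the repository.

```python
import math
from typing import List


def masked_bce_with_logits(logits: List[float], s: List[int]) -> float:
    """Mean binary cross-entropy over positions whose label is >= 0.

    Plain-Python analogue of
        F.binary_cross_entropy_with_logits(logits[s >= 0], s[s >= 0].float(), reduction="mean")
    Assumption: negative labels are padding and are excluded.
    """
    total, count = 0.0, 0
    for x, y in zip(logits, s):
        if y < 0:  # excluded by the s >= 0 mask
            continue
        # numerically stable BCE-with-logits: max(x, 0) - x*y + log(1 + exp(-|x|))
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        count += 1
    return total / count


# logits[t] is compared with s[t] at the same index; the last position is padding
loss = masked_bce_with_logits([0.0, 0.0, 5.0], [1, 0, -1])
print(loss)  # about 0.6931, i.e. log(2)
```

The point of the sketch is that the loss itself never shifts anything: whether `logits[t]` is a "next-step" prediction depends only on what history was available when it was computed.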

Has this question been resolved? I would like to discuss it with you. While reading through the code, I ran into the same question raised above: "Is the final concatenation of the question for t+1 time step or t time step in the code?" Thank you for taking time out of your busy schedule to answer.

Thanks for your interest in our work! To answer your question, the `peek_cur` argument of the `DTransformerLayer` class does the "shifting". With the last layer set up with `peek_cur=False`, `z[t]` contains only history strictly before time step t. This is implemented through mask manipulation. You may refer to

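```python
# excerpt: arguments of a DTransformerLayer call with peek_cur=False
query, hq, p, torch.repeat_interleave(lens, n_know), peek_cur=False
```

and

```python
# diagonal 0 keeps the current step visible; diagonal -1 hides it
mask = torch.ones(seqlen, seqlen).tril(0 if peek_cur else -1)
```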
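To see what the two diagonals do, here is a plain-Python sketch that builds the same 0/1 pattern as `torch.ones(seqlen, seqlen).tril(k)`. `causal_mask` is a hypothetical helper for illustration only; row t lists which positions step t is allowed to attend to.

```python
from typing import List


def causal_mask(seqlen: int, peek_cur: bool) -> List[List[int]]:
    """Same 0/1 pattern as torch.ones(seqlen, seqlen).tril(0 if peek_cur else -1).

    tril(k) keeps entry (i, j) when j - i <= k, so row i marks the steps that i can see.
    """
    k = 0 if peek_cur else -1
    return [[1 if j - i <= k else 0 for j in range(seqlen)] for i in range(seqlen)]


for row in causal_mask(4, peek_cur=False):
    print(row)
# [0, 0, 0, 0]
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
```

With `peek_cur=True` the diagonal is kept, so step t also sees itself. With `peek_cur=False` row t covers only steps 0 to t-1, and the first row is empty.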

With this trick, concatenating z[t] and q[t] in code is effectively concatenating $q_t$ and $z_{t-1}$, or equivalently, $q_{t+1}$ and $z_t$, in math.
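The index equivalence can be checked on a toy example. The sketch below is an assumption-laden stand-in: a running sum replaces the attention-based knowledge state, and `strict_history` / `inclusive_history` are hypothetical helpers. It compares the code view (q[t] paired with a history that excludes step t) against the math view ($q_{t+1}$ paired with $z_t$, a history that includes step t).

```python
from typing import List, Tuple


def strict_history(x: List[float]) -> List[float]:
    """z_code[t]: running sum of x[0..t-1] only (like peek_cur=False); z_code[0] sees nothing."""
    out, acc = [], 0.0
    for v in x:
        out.append(acc)  # record before adding the current step
        acc += v
    return out


def inclusive_history(x: List[float]) -> List[float]:
    """z_math[t]: running sum of x[0..t], i.e. history up to and including step t."""
    out, acc = [], 0.0
    for v in x:
        acc += v
        out.append(acc)
    return out


q = ["q0", "q1", "q2", "q3"]  # toy question ids
x = [1.0, 2.0, 4.0, 8.0]      # toy per-step interaction features

# code view: q[t] concatenated with z_code[t] at the same index
code_pairs: List[Tuple[str, float]] = list(zip(q, strict_history(x)))

# math view: q_{t+1} concatenated with z_t
z_math = inclusive_history(x)
shifted_pairs = [(q[t + 1], z_math[t]) for t in range(len(q) - 1)]

print(code_pairs)     # [('q0', 0.0), ('q1', 1.0), ('q2', 3.0), ('q3', 7.0)]
print(shifted_pairs)  # [('q1', 1.0), ('q2', 3.0), ('q3', 7.0)]
```

From index 1 onward the two views coincide. Index 0 of the code view pairs the first question with an empty history, which in this toy is consistent with a same-index loss that includes the first label of `s`.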

Your response is very clear, thank you for your explanation!