Question about the max-pooling operation.
rongaoli opened this issue · 2 comments
Hi, congratulations on completing this great work!
I have some questions about the Critic model while reading your code:
In your paper, you mentioned that "The contextual hidden states of the program tokens (h1, . . . , hT) obtained from the critic model decoder are max-pooled along the sequence length dimension." However, in your code the max-pooling is not applied to the contextual hidden states h (last dim size: config.d_model) along the sequence-length dimension; instead, it is applied to error_states (the output of error_head, last dim size: 4) by taking the maximum over dim 1. In other words, the linear head is applied to every time step first, and the pooling happens afterwards.
Location:
file: CodeRL/transformers/src/transformers/models/t5/modeling_t5.py
class: T5ForConditionalGeneration
method: forward
self.error_head = nn.Sequential(
    nn.Linear(config.d_model, 128),
    nn.ReLU(),
    nn.Linear(128, 4)
)

if error_types is not None:
    # the error head is applied to every decoder time step first ...
    error_states = self.error_head(sequence_output)   # (batch, seq_len, 4)
    # ... and the max-pooling is then taken over dim 1, the sequence dimension
    error_logits, _ = error_states.max(1)             # (batch, 4)
    error_pred_loss_fct = CrossEntropyLoss()
    error_pred_loss = error_pred_loss_fct(error_logits.view(-1, error_logits.size(-1)), error_types.view(-1))
    _, error_preds = torch.max(error_logits, dim=-1)
    if return_error_hidden_states:
        return error_pred_loss, error_preds, error_states
    return error_pred_loss, error_preds
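For reference, the ordering described in the paper would look roughly like the sketch below: the decoder hidden states are max-pooled over the sequence-length dimension first, and the pooled vector is then passed to the error head. The tensor shapes and variable names here are assumptions for illustration, not code from the repository.

import torch
import torch.nn as nn

d_model, num_error_types = 768, 4                       # assumed sizes for illustration
error_head = nn.Sequential(
    nn.Linear(d_model, 128),
    nn.ReLU(),
    nn.Linear(128, num_error_types),
)

sequence_output = torch.randn(2, 10, d_model)           # decoder hidden states: (batch, seq_len, d_model)

# Paper wording: max-pool (h1, ..., hT) along the sequence dimension first ...
pooled_states, _ = sequence_output.max(dim=1)           # (batch, d_model)
# ... then project the pooled vector onto the 4 error classes
error_logits = error_head(pooled_states)                # (batch, num_error_types)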
Thanks for raising this concern.
In my experience, I didn't notice large performance gaps between applying the max-pooling to the decoder hidden states and applying it to the hidden states after the linear transformation.
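Both variants collapse the sequence dimension into a single logit vector per program, so the downstream loss computation is unchanged either way. A small self-contained sanity check, again with assumed shapes and labels rather than repository code:

import torch
import torch.nn as nn

d_model, num_error_types = 768, 4
error_head = nn.Sequential(nn.Linear(d_model, 128), nn.ReLU(), nn.Linear(128, num_error_types))

sequence_output = torch.randn(2, 10, d_model)            # (batch, seq_len, d_model)
error_types = torch.randint(0, num_error_types, (2,))    # assumed ground-truth error labels
loss_fct = nn.CrossEntropyLoss()

# Variant A (paper wording): pool the hidden states, then project
logits_a = error_head(sequence_output.max(dim=1)[0])     # (batch, num_error_types)
# Variant B (repository code): project every time step, then pool
logits_b = error_head(sequence_output).max(dim=1)[0]     # (batch, num_error_types)

loss_a = loss_fct(logits_a, error_types)
loss_b = loss_fct(logits_b, error_types)
print(logits_a.shape, logits_b.shape, loss_a.item(), loss_b.item())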
This helps a lot, thank you for your reply :)