lucidrains/PaLM-rlhf-pytorch

Why do the value calculations in generate and learn use different masks?

Nightbringers opened this issue · 1 comment

I'm very confused about how the value is calculated: why do the two methods use different masks? In the generate method, the mask includes the prompt. But during training, in the learn method, the mask does not include the prompt.
This is in the learn method:
action_masks = ~prompt_masks & masks
action_logits, values = self.actor_critic(
    sequences,
    mask = action_masks
)
And this is in the generate method:
mask = None
if exists(eos_token):
    mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
    mask = F.pad(mask, (1, -1), value = True) # include eos token
action_logits, value = self.forward(
    sequence,
    mask = mask,
    return_values = return_values
)
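
To make the difference concrete, here is a minimal sketch that builds both masks for a single toy sequence. The token ids, `eos_token = 0`, and `prompt_len = 2` are made-up values for illustration, and `prompt_mask` is constructed by hand rather than taken from the repository. It only shows how the two masks differ; it does not reproduce the change made in the referenced commit.

```python
import torch
import torch.nn.functional as F

# Hypothetical values, chosen only for this illustration.
eos_token = 0
prompt_len = 2

# One sequence: 2 prompt tokens, 3 generated tokens (the last is EOS),
# then one leftover slot after EOS.
sequence = torch.tensor([[5, 6, 7, 8, 0, 0]])

# Mask as built in generate: True for every position before the first EOS ...
mask = (sequence == eos_token).cumsum(dim=-1) == 0
# ... then shifted right by one so the EOS token itself is included.
mask = F.pad(mask, (1, -1), value=True)

# Hand-built prompt mask: True for the first `prompt_len` positions.
prompt_mask = (torch.arange(sequence.shape[-1]) < prompt_len).unsqueeze(0)

# Mask as built in learn: valid positions that are not part of the prompt.
action_mask = ~prompt_mask & mask

print(mask.tolist())         # prompt and response visible, slot after EOS hidden
print(action_mask.tolist())  # only the response tokens (including EOS) visible
```

With these inputs, `mask` covers the prompt and the response up to and including EOS, while `action_mask` covers only the response. Passing `action_mask` as the model's mask therefore hides the prompt from the forward pass, which is the mismatch described above.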

@Nightbringers yes you are correct! thank you for catching this! a0b9774