[Bug] Incorrect categorical distribution setting
lzl65825 opened this issue · 0 comments
lzl65825 commented
In RLTF, the code uses torch.distributions.Categorical to sample actions. For example, line 142 of benchmark_tasks/rltf/rltf_schema_flan_t5.py reads:
action = torch.distributions.Categorical(torch.stack(log_prob).detach()).sample()
However, when the tensor is passed positionally, Categorical interprets it as probs rather than logits. Since log_prob holds log-probabilities, the call should be:
action = torch.distributions.Categorical(logits=torch.stack(log_prob).detach()).sample()
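For illustration, here is a minimal standalone sketch (not RLTF's actual code; the log_prob tensor below is a stand-in) showing how the two constructions diverge when the input holds log-probabilities:

```python
# Minimal sketch of the probs-vs-logits pitfall; `log_prob` is a stand-in
# tensor of log-probabilities, analogous to torch.stack(log_prob).detach().
import torch
from torch.distributions import Categorical

probs = torch.tensor([0.7, 0.2, 0.1])   # intended action distribution
log_prob = probs.log()                  # log-probabilities of the actions

# Correct: `logits=` applies a softmax internally, recovering the intended probs.
print(Categorical(logits=log_prob).probs)   # ~[0.70, 0.20, 0.10]

# Buggy: a positional tensor is treated as unnormalized `probs`. The negative
# log-probs get renormalized into a distorted distribution (here the
# least-likely action ends up with the largest weight).
print(Categorical(log_prob).probs)          # ~[0.08, 0.38, 0.54]
```

So with the positional form the sampled actions no longer follow the intended policy distribution, which is why the logits= keyword is needed here.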