agiresearch/OpenAGI

[Bug] Incorrect categorical distribution setting

lzl65825 opened this issue

In RLTF, the code uses torch.distributions.Categorical to sample actions, for example at line 142 of benchmark_tasks/rltf/rltf_schema_flan_t5.py:

action = torch.distributions.Categorical(torch.stack(log_prob).detach()).sample()

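However, when the tensor is passed positionally, without a keyword, Categorical interprets it as probs rather than logits, because probs is the first parameter of its constructor. Since log_prob contains log-probabilities, it should be passed through the logits keyword instead: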

action = torch.distributions.Categorical(logits=torch.stack(log_prob).detach()).sample()
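A minimal sketch of why the positional call is a silent bug, assuming log_prob holds log-probabilities such as those produced by log_softmax (the scores tensor below is hypothetical). When given probs, Categorical divides them by their sum. Because every log-probability is negative, that division yields non-negative weights that sum to 1, so no validation error is raised, but the ranking is inverted: the most likely action receives the smallest weight.

```python
import torch
import torch.nn.functional as F

# Hypothetical action scores; log_softmax turns them into log-probabilities,
# standing in for the stacked log_prob values in the issue (assumption).
scores = torch.tensor([2.0, 0.5, -1.0])
log_prob = F.log_softmax(scores, dim=-1)

# Positional argument: treated as probs and normalized by its (negative) sum.
# The result is a valid distribution, so nothing fails, but it is the wrong one.
wrong = torch.distributions.Categorical(log_prob)

# Keyword argument: treated as logits, so Categorical applies softmax internally
# and recovers the intended probabilities.
right = torch.distributions.Categorical(logits=log_prob)

print("intended:", log_prob.exp())
print("positional (probs):", wrong.probs)   # highest weight on the least likely action
print("keyword (logits):", right.probs)     # matches log_prob.exp()

action = right.sample()
```

Comparing the .probs attributes rather than individual samples makes the difference deterministic and easy to check.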