alexhernandezgarcia/gflownet

`sample_batch` args in `test_top_k` in gflownet.py


I just cloned a fresh copy of the repo in WSL and tried a run with all default configs. The `sample_batch` call on line 906 of gflownet.py gets the wrong arguments and crashes every time. `self.env` is passed as the first positional argument, so it binds to `n_forward`, and `len(b)` binds to `n_train`. The parent function, `test_top_k`, appears to have been added on August 1.

Call:

```python
for b in batch_with_rest(0, self.logger.test.n_top_k, self.batch_size.forward):
    gfn_states += self.sample_batch(
        self.env, len(b), train=False, progress=progress
    )[0]
```

Function definition:

```python
def sample_batch(
    self,
    n_forward: int = 0,
    n_train: int = 0,
    n_replay: int = 0,
    train=True,
    progress=False,
):
    ...  # body omitted in this excerpt
```
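The binding problem can be reproduced without the rest of the repo. The sketch below uses a hypothetical `StubAgent` that copies only the `sample_batch` signature shown above and returns how the arguments were bound instead of sampling anything; `env` and `b` are placeholders for `self.env` and one batch from `batch_with_rest`. In the real method, the crash presumably happens wherever `n_forward` is used as a number.

```python
class StubAgent:
    """Hypothetical stand-in that copies only the signature of sample_batch."""

    def sample_batch(
        self,
        n_forward: int = 0,
        n_train: int = 0,
        n_replay: int = 0,
        train=True,
        progress=False,
    ):
        # Report how Python bound each argument instead of sampling anything.
        return {
            "n_forward": n_forward,
            "n_train": n_train,
            "n_replay": n_replay,
            "train": train,
        }


agent = StubAgent()
env = object()  # placeholder for self.env
b = range(10)   # placeholder for one batch yielded by batch_with_rest

# Same positional pattern as the call in test_top_k.
bound = agent.sample_batch(env, len(b), train=False, progress=False)
print(bound)
# n_forward receives the environment object and n_train receives 10
```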

You were right. This crash was caused by a large, cascading multi-way merge of branches into main. This PR fixes the issue.
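One possible shape of a fix, assuming the intent in `test_top_k` is to sample `len(b)` forward trajectories, is to pass the count by keyword. This is a sketch, not necessarily what the PR does. `StubAgent`, its `test_top_k(n_top_k, batch_size_forward)` parameters, and this `batch_with_rest` are hypothetical stand-ins for the repo's agent, `self.logger.test.n_top_k`, `self.batch_size.forward`, and the repo's helper. The `*` in the stub's signature makes the counts keyword-only, which the original signature does not do. A positional call like the one in the issue then fails immediately with a `TypeError` at the call site.

```python
def batch_with_rest(start, stop, step):
    """Hypothetical helper: yield index ranges of length `step`, last one shorter."""
    for i in range(start, stop, step):
        yield range(i, min(i + step, stop))


class StubAgent:
    """Hypothetical stand-in for the agent; only the call pattern matters."""

    def sample_batch(self, *, n_forward=0, n_train=0, n_replay=0, train=True, progress=False):
        # Keyword-only counts: sample_batch(env, 10) now raises TypeError
        # at the call site instead of binding env to n_forward.
        states = [f"state_{i}" for i in range(n_forward)]
        return states, {"train": train}

    def test_top_k(self, n_top_k, batch_size_forward, progress=False):
        gfn_states = []
        for b in batch_with_rest(0, n_top_k, batch_size_forward):
            # Pass the batch size by name so it lands in n_forward.
            gfn_states += self.sample_batch(
                n_forward=len(b), train=False, progress=progress
            )[0]
        return gfn_states


agent = StubAgent()
states = agent.test_top_k(n_top_k=25, batch_size_forward=10)
print(len(states))  # 25 (batches of 10, 10 and 5)
```

Passing counts by name costs nothing at the call site and makes the intended parameter explicit even if the signature is later reordered.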