summary batch_size parameter
Pikauba opened this issue
Pikauba commented
It seems like the batch size is hardcoded as `2` here (this line). Shouldn't it be `x = [torch.rand(batch_size, *in_size).type(dtype).to(device=device) for in_size, dtype in zip(input_size, dtypes)]` instead? Note the `batch_size` variable passed to `torch.rand` in place of the hardcoded `2`. The `batch_size=-1` case could then be handled by falling back to a predefined batch of 2.
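A minimal sketch of the suggested change, assuming the dummy-input construction is factored out into its own helper. `DEFAULT_DUMMY_BATCH`, `resolve_batch_size`, and `make_dummy_inputs` are hypothetical names for illustration, not part of torchsummary; only the list comprehension comes from the suggestion in this issue. Rejecting values other than `-1` that are below 1 is likewise an assumption.

```python
# Hypothetical default; mirrors the batch of 2 that is currently hardcoded.
DEFAULT_DUMMY_BATCH = 2


def resolve_batch_size(batch_size: int, default: int = DEFAULT_DUMMY_BATCH) -> int:
    """Map the sentinel batch_size=-1 to a concrete default; pass positive values through."""
    if batch_size == -1:
        return default
    if batch_size < 1:
        # Assumption: anything other than -1 or a positive integer is an error.
        raise ValueError(f"batch_size must be -1 or a positive integer, got {batch_size}")
    return batch_size


def make_dummy_inputs(input_size, dtypes, batch_size=-1, device="cpu"):
    """Build one random tensor per input, using the requested batch size.

    Hypothetical helper: input_size is a list of per-input shapes (without the
    batch dimension) and dtypes is a matching list of tensor types.
    """
    import torch  # imported here so resolve_batch_size can be used without torch

    n = resolve_batch_size(batch_size)
    return [
        torch.rand(n, *in_size).type(dtype).to(device=device)
        for in_size, dtype in zip(input_size, dtypes)
    ]
```

For example, `make_dummy_inputs([(3, 32, 32)], [torch.FloatTensor], batch_size=8)` would produce a single tensor of shape `(8, 3, 32, 32)`, while leaving `batch_size` at `-1` keeps the current behaviour of a batch of 2.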