pfnet/pytorch-pfn-extras

batch_idx rules for train_step


In the docs of `Logic.train_step`, `batch_idx` is described as the number of training steps already finished, but in reality it is the iteration index within the current epoch.
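
A quick way to see this (an untested sketch; the toy model, `PrintingLogic`, and the data are all made up for illustration):

import torch
import pytorch_pfn_extras as ppe


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x, t):
        # ppe's default Logic backpropagates through the returned dict.
        return {'loss': torch.nn.functional.cross_entropy(self.fc(x), t)}


class PrintingLogic(ppe.handler.Logic):
    def train_step(self, models, optimizers, batch_idx, batch):
        print('batch_idx =', batch_idx)  # restarts from 0 every epoch
        return super().train_step(models, optimizers, batch_idx, batch)


model = Model()
optim = torch.optim.SGD(model.parameters(), lr=0.1)
data = [{'x': torch.randn(4), 't': torch.tensor(0)} for _ in range(8)]
loader = torch.utils.data.DataLoader(data, batch_size=4)
trainer = ppe.engine.create_trainer(
    model, optim, max_epochs=2, logic=PrintingLogic(), device='cpu')
trainer.run(loader)  # prints 0, 1, 0, 1 rather than 0, 1, 2, 3

For reference, the docstring currently reads: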

def train_step(
    self,
    models: Mapping[str, torch.nn.Module],
    optimizers: Mapping[str, torch.optim.Optimizer],
    batch_idx: int,
    batch: Any,
) -> Any:
    """A method invokes the model forward and backward passes.

    Optimizing is left to `train_step_optimizers` since maybe the user
    would like to aggregate the gradients of several iterations.

    Args:
        models (dict of torch.nn.Module):
            The models.
        optimizers (dict of torch.optim.Optimizer):
            The optimizers.
        batch_idx (int):
            Number of training steps already finished.
        batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
            Input tensors feeded to the model of the current step.
    """

In the trainer, however, the index passed to `handler.train_step` comes from a loop counter that restarts every epoch:

for idx in range(train_len):
    ...
    self.handler.train_step(
        self,
        idx,
        x,
        complete_fn=self._complete_step,
    )

So `batch_idx` is reset to 0 at the start of each epoch, which contradicts the docstring.
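
Until the docs and the implementation agree, one workaround is to track the finished-step count manually in a Logic subclass (a sketch; `GlobalStepLogic` and `_global_step` are names I made up):

import pytorch_pfn_extras as ppe


class GlobalStepLogic(ppe.handler.Logic):
    """Keeps a counter with the semantics the docstring describes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._global_step = 0  # training steps finished across all epochs

    def train_step(self, models, optimizers, batch_idx, batch):
        # batch_idx restarts every epoch; _global_step does not.
        out = super().train_step(models, optimizers, batch_idx, batch)
        self._global_step += 1
        return out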