The docstring of `Logic.train_step` describes `batch_idx` as the number of training steps already finished, but in reality it is the index of the iteration within the current epoch, so it restarts from 0 every epoch.

    def train_step(
        self,
        models: Mapping[str, torch.nn.Module],
        optimizers: Mapping[str, torch.optim.Optimizer],
        batch_idx: int,
        batch: Any,
    ) -> Any:
        """A method invokes the model forward and backward passes.

        Optimizing is left to `train_step_optimizers` since maybe the user
        would like to aggregate the gradients of several iterations.

        Args:
            models (dict of torch.nn.Module):
                The models.
            optimizers (dict of torch.optim.Optimizer):
                The optimizers.
            batch_idx (int):
                Number of training steps already finished.
            batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
                Input tensors feeded to the model of the current step.
        """

    self.handler.train_step(
        self,
        idx,
        x,
        complete_fn=self._complete_step,
    )

    for idx in range(train_len):
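
To make the discrepancy concrete, here is a minimal sketch in plain Python. It is not the library's trainer code: `run_epochs`, `epochs`, and `finished_steps` are hypothetical names, and the only thing it assumes from the snippets above is that the index handed to `train_step` comes from a `range(train_len)` loop that restarts each epoch.

```python
def run_epochs(epochs: int, train_len: int):
    """Record (batch_idx, finished_steps) pairs for a toy training loop.

    Hypothetical stand-in for the trainer loop: `batch_idx` comes from
    `range(train_len)` and restarts every epoch, while `finished_steps`
    counts every iteration completed so far across all epochs.
    """
    seen = []
    finished_steps = 0
    for _ in range(epochs):
        for batch_idx in range(train_len):
            # What train_step would receive vs. what the docstring describes.
            seen.append((batch_idx, finished_steps))
            finished_steps += 1
    return seen


pairs = run_epochs(epochs=2, train_len=3)
for batch_idx, finished_steps in pairs:
    print(batch_idx, finished_steps)
```

In this sketch the two values agree only during the first epoch; from the second epoch on `batch_idx` starts again at 0 while the count of finished steps keeps growing, which is the mismatch between the docstring and the value actually passed.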