pytorch/tnt

Increment epoch should not reset num_steps_completed_in_epoch


train_unit.train_progress.increment_epoch()
train_unit.on_train_epoch_end(state)
callback_handler.on_train_epoch_end(state, train_unit)

def increment_epoch(self) -> None:
"""Increment the epochs completed and resets the steps completed within the epoch."""
self._num_epochs_completed += 1
self._num_steps_completed_in_epoch = 0

Because increment_epoch is called before on_train_epoch_end (see the call order above), resetting num_steps_completed_in_epoch inside increment_epoch means the on_train_epoch_end callbacks can never observe how many steps were completed in the epoch that just ended.
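A minimal sketch of the problem, using a simplified stand-in for TNT's progress tracker (the Progress class, increment_step, and the standalone on_train_epoch_end callback below are hypothetical names, not TNT's actual API):

```python
class Progress:
    """Simplified stand-in for TNT's train progress tracker (hypothetical)."""

    def __init__(self) -> None:
        self.num_epochs_completed = 0
        self.num_steps_completed_in_epoch = 0

    def increment_step(self) -> None:
        self.num_steps_completed_in_epoch += 1

    def increment_epoch(self) -> None:
        """Increments the epochs completed and resets the steps completed within the epoch."""
        self.num_epochs_completed += 1
        self.num_steps_completed_in_epoch = 0  # the reset at issue


def on_train_epoch_end(progress: Progress) -> int:
    # Fired *after* increment_epoch, per the call order above, so the
    # per-epoch step count has already been zeroed by the time it runs.
    return progress.num_steps_completed_in_epoch


progress = Progress()
for _ in range(5):
    progress.increment_step()  # 5 steps completed in this epoch

progress.increment_epoch()               # counter reset to 0 here
steps_seen = on_train_epoch_end(progress)
print(steps_seen)  # 0, even though 5 steps were completed in the epoch
```

Deferring the reset (for example, resetting the counter at the start of the next epoch instead of inside increment_epoch) would let the callback read the finished epoch's step count.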