pytorch/tnt

How to check `is_last_batch` in torchtnt==0.1.0?

yiminglin-ai opened this issue · 2 comments

๐Ÿ› Describe the bug

In the train_step method of a Callback class:

def train_step(self, state: State, data: TrainBatch) -> None:
    global_step = state.train_state.progress.num_steps_completed
    is_last_batch = state.train_state.is_last_batch  # raises AttributeError in 0.1.0

Error:

AttributeError: 'PhaseState' object has no attribute 'is_last_batch'
train_state._step_output = train_unit.train_step(state, step_input)
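
For reference, one quick way to see which attributes the phase state still exposes is plain introspection from inside the callback (generic Python, not a torchtnt API):

# List the public attributes available on the phase state object.
train_state = state.train_state
print(sorted(a for a in dir(train_state) if not a.startswith("_")))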

Versions

Hi @daniellepintz
Why was is_last_batch removed in #367?
Removing it also breaks the test defined in

def test_is_last_batch(self) -> None:

What is the correct way to check is_last_batch? Thank you in advance!

The env:

torchtnt==0.1.0
torcheval==0.0.6
torchsnapshot==0.1.0

Hi @yiminglin-ai, could you describe your use case in more detail?

For context, that field was added to support gradient accumulation in the AutoUnit extension. However, subsequent PRs made it possible to deduce this information entirely within the AutoUnit, so to keep the state as minimal as possible we removed the is_last_batch attribute from it.
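
In the meantime, you can deduce the same thing in your own callback from the step counter your snippet already reads, given the number of batches per epoch. A rough sketch (assuming a fixed-length, map-style dataloader with no max_steps_per_epoch cap; the hook name and import paths below are assumptions and may differ slightly by version):

from torchtnt.framework.callback import Callback  # import paths may differ by version
from torchtnt.framework.state import State

class LastBatchCallback(Callback):
    def __init__(self, train_dataloader) -> None:
        # Assumes len(train_dataloader) equals the number of batches per epoch.
        self._steps_per_epoch = len(train_dataloader)

    def on_train_step_start(self, state: State, unit) -> None:
        progress = state.train_state.progress
        # num_steps_completed counts finished steps, so the step about to run
        # has this zero-based index within the current epoch:
        step_in_epoch = progress.num_steps_completed % self._steps_per_epoch
        is_last_batch = step_in_epoch == self._steps_per_epoch - 1
        if is_last_batch:
            ...  # e.g. flush metrics, log, etc.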

Knowing more about how you'd like to use this information will help us out!