How to check `is_last_batch` in torchtnt==0.1.0?
yiminglin-ai opened this issue · 2 comments
🐛 Describe the bug
In the `train_step` method of a Callback class:
def train_step(self, state: State, data: TrainBatch) -> None:
    global_step = state.train_state.progress.num_steps_completed
    is_last_batch = state.train_state.is_last_batch  # error
Error:
AttributeError: 'PhaseState' object has no attribute 'is_last_batch'
train_state._step_output = train_unit.train_step(state, step_input)
Versions
Hi @daniellepintz,
Why was `is_last_batch` removed in #367? This won't pass the test defined in tnt/tests/framework/test_auto_unit.py, line 901 in 9b3b7b1.
What is the correct way to check `is_last_batch`? Thank you in advance!
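If the training dataloader has a known length, one possible workaround is to derive the flag from a per-epoch step counter. This is a plain-Python sketch, not a torchtnt API: `is_last_batch_sized` and its parameters are hypothetical names. It assumes the counter holds the number of batches already finished in the current epoch (read before it is incremented for the current step) and that `len()` of the loader is accurate, which does not hold for iterable-style datasets without `__len__`.

```python
from typing import Sized


def is_last_batch_sized(steps_completed_in_epoch: int, dataloader: Sized) -> bool:
    """Return True if the batch about to be processed is the final one of the epoch.

    Hypothetical helper: `steps_completed_in_epoch` counts batches already
    finished in the current epoch (0 for the first batch), and `dataloader`
    must support len().
    """
    return steps_completed_in_epoch + 1 == len(dataloader)


# Usage with a plain list standing in for a sized dataloader of 3 batches.
batches = ["b0", "b1", "b2"]
flags = [is_last_batch_sized(i, batches) for i in range(len(batches))]
print(flags)  # [False, False, True]
```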
The env:
torchtnt==0.1.0
torcheval==0.0.6
torchsnapshot==0.1.0
CC @ananthsub
Hi @yiminglin-ai, could you describe your use case in more detail?
For context, that field was added to support gradient accumulation in the `AutoUnit` extension. However, subsequent PRs made it possible to deduce this information entirely within the `AutoUnit`. Accordingly, to keep the state as minimal as possible, we removed the `is_last_batch` attribute from it.
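The reply doesn't spell out how `AutoUnit` now deduces this internally. As a general illustration only (not the torchtnt implementation), a unit can recover the flag by looking one batch ahead in the data iterator, which works even when the dataset has no length. Gradient accumulation presumably needs the flag so that the final, possibly partial, accumulation window still triggers an optimizer step. `with_last_flag` and `should_step_optimizer` are hypothetical helper names.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def with_last_flag(iterable: Iterable[T]) -> Iterator[Tuple[T, bool]]:
    """Yield (item, is_last) pairs by looking one element ahead.

    Works for unsized iterables (e.g. streaming datasets) because it only
    needs to know whether another element follows.
    """
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return  # empty input: nothing to yield
    for nxt in it:
        yield current, False
        current = nxt
    yield current, True


def should_step_optimizer(step_in_epoch: int, accum_steps: int, is_last: bool) -> bool:
    """Decide whether to call optimizer.step() under gradient accumulation.

    Step every `accum_steps` batches, and also on the final batch so that
    gradients from a trailing partial window are not dropped.
    """
    return (step_in_epoch + 1) % accum_steps == 0 or is_last


# Five batches with accumulation over 2 steps: the optimizer steps after
# batches 1 and 3, and again after batch 4 to flush the partial window.
steps = [
    i
    for i, (_batch, is_last) in enumerate(with_last_flag(range(5)))
    if should_step_optimizer(i, accum_steps=2, is_last=is_last)
]
print(steps)  # [1, 3, 4]
```

The look-ahead costs one extra buffered batch, but it avoids depending on `len()` or on any flag stored in the framework state.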
Knowing more information on how you'd like to use this data will help us out!