Support for `pytorch_lightning.Trainer.predict`
RR-28023 opened this issue · 1 comment
🚀 Feature
Support for using PyTorch Lightning's `pytorch_lightning.Trainer.predict` method with lightning-transformers models and datamodules.
Motivation
`Trainer.predict` is a very convenient method for running inference on a large dataset, since it leverages all the device management and parallelization functionality of `Trainer`. However, I recently trained a lightning-transformers `TextClassificationTransformer` and, although I was able to train it using `Trainer.fit`, I can't run inference using `Trainer.predict`.
Pitch
Only a couple of small changes are needed: mainly, define `predict_dataloader` in `TransformerDataModule` and then define `predict_step` on each child of `TaskTransformer`. I can submit a PR showing what would need to change (since I did it anyway for my project).
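To make the shape of those two hooks concrete, here is a minimal, dependency-free sketch. It is an illustration only, not the lightning-transformers code. `SketchDataModule`, `SketchClassifier`, and `run_predict` are hypothetical stand-ins for `TransformerDataModule`, a `TaskTransformer` child, and the loop that `Trainer.predict` drives. Only the hook names `predict_dataloader` and `predict_step` come from Lightning, and the keyword-counting "model" is a toy assumption.

```python
from typing import Iterator, List, Sequence


class SketchDataModule:
    """Hypothetical stand-in for TransformerDataModule."""

    def __init__(self, texts: Sequence[str], batch_size: int = 2) -> None:
        self.texts = list(texts)
        self.batch_size = batch_size

    def predict_dataloader(self) -> Iterator[List[str]]:
        # In Lightning this hook would return a torch DataLoader; here it
        # simply yields fixed-size batches of raw strings.
        for start in range(0, len(self.texts), self.batch_size):
            yield self.texts[start:start + self.batch_size]


class SketchClassifier:
    """Hypothetical stand-in for a TaskTransformer child."""

    def forward(self, batch: List[str]) -> List[int]:
        # Toy "model" (an assumption): count "good" minus "bad" per text.
        return [t.lower().count("good") - t.lower().count("bad") for t in batch]

    def predict_step(self, batch: List[str], batch_idx: int,
                     dataloader_idx: int = 0) -> List[int]:
        # Convert raw scores to class ids, since predictions are what
        # the caller wants back.
        return [int(score > 0) for score in self.forward(batch)]


def run_predict(model: SketchClassifier, datamodule: SketchDataModule) -> List[List[int]]:
    # Stand-in for Trainer.predict: call predict_step once per batch and
    # collect one result per batch. No device handling is modelled here.
    return [
        model.predict_step(batch, batch_idx)
        for batch_idx, batch in enumerate(datamodule.predict_dataloader())
    ]


dm = SketchDataModule(["good film", "bad plot", "good cast"], batch_size=2)
predictions = run_predict(SketchClassifier(), dm)
print(predictions)  # [[1, 0], [1]]
```

With the real library, the call site would presumably become `trainer.predict(model, datamodule=dm)`, which likewise returns one entry per batch.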
Alternatives
Currently, inference can be run using a `transformers.pipeline`, as shown in predict.py. However, that does not leverage the `Trainer` functionality, and it does not fit the main purpose of this project, which is to enable the use of `Trainer` in combination with transformers.
Note that using `Trainer.test` would not work because it does not return predictions.
Thanks!
Thanks @RR-28023, we can get this in for sure! We're slowly working towards removing the rest of the Hydra code (and then relying on pure Lightning; we'll provide examples to show how to do this).
We'll be making a release before this change though, so we'll make sure this is in before then.