Lightning-Universe/lightning-flash

Support generation kwargs within Seq2SeqTasks

JohnGiorgi opened this issue · 0 comments

🚀 Feature

Seq2SeqTask (and tasks that inherit from it, like SummarizationTask) only allows a user to specify a couple of arguments to model.generate:

https://github.com/Lightning-AI/lightning-flash/blob/651e85851509fd04f723caedfef8d487d77df4e0/flash/text/seq2seq/core/model.py#L139-L144

However, the generate method from HF supports many more arguments and decoding strategies, which can be specified through a generation_config. A lot of flexibility could be unlocked by allowing Seq2SeqTask to accept a generation_config.

Motivation

Seq2SeqTask appears to be the main interface to text generation within Flash. It would really open up a lot of flexibility for this class of tasks if a user could easily specify the decoding strategy.

Pitch

I think the change is quite straightforward:

  1. Update Seq2SeqTask to accept a new argument, generation_config, matching the HuggingFace object
  2. Remove any arguments to Seq2SeqTask covered by this config (e.g. num_beams)
  3. Update Seq2SeqTask.forward so that it provides this config to model.generate
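
The three steps above could be sketched roughly as follows. This is a dependency-free illustration, not Flash's actual code: SketchSeq2SeqTask, StubGenerationConfig and StubModel are hypothetical stand-ins (for Seq2SeqTask, the HuggingFace generation config object, and a HuggingFace model), and the batch keys input_ids and attention_mask are assumed.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class StubGenerationConfig:
    """Hypothetical stand-in for the HuggingFace generation config object."""

    num_beams: int = 1
    do_sample: bool = False
    max_new_tokens: Optional[int] = None


class StubModel:
    """Hypothetical stand-in for a HuggingFace seq2seq model."""

    def __init__(self) -> None:
        self.last_call: Dict[str, Any] = {}

    def generate(
        self,
        input_ids: List[List[int]],
        attention_mask: Optional[List[List[int]]] = None,
        generation_config: Optional[StubGenerationConfig] = None,
    ) -> List[List[int]]:
        # Record the arguments so the wiring can be inspected; echo the input back.
        self.last_call = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "generation_config": generation_config,
        }
        return input_ids


class SketchSeq2SeqTask:
    """Hypothetical, stripped-down version of Seq2SeqTask."""

    def __init__(
        self,
        model: StubModel,
        generation_config: Optional[StubGenerationConfig] = None,
    ) -> None:
        self.model = model
        # Step 1: accept a generation_config.
        # Step 2: no separate num_beams argument, since the config covers it.
        self.generation_config = generation_config

    def forward(self, batch: Dict[str, List[List[int]]]) -> List[List[int]]:
        # Step 3: hand the config to model.generate.
        return self.model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            generation_config=self.generation_config,
        )


config_task = SketchSeq2SeqTask(
    StubModel(), StubGenerationConfig(num_beams=4, max_new_tokens=32)
)
config_output = config_task.forward(
    {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
)
```

With the real library, the stub config would be replaced by the HuggingFace object and the stub model by the task's underlying model; everything else about the wiring would stay the same under these assumptions.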

Alternatives

I believe something similar could be achieved by adding a new argument, generation_kwargs, which, similar to the above strategy, would be provided to Seq2SeqTask and passed as **generation_kwargs to model.generate via Seq2SeqTask.forward.
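
A minimal sketch of this alternative, again with hypothetical stand-ins (KwargsSeq2SeqTask for Seq2SeqTask, RecordingModel for a HuggingFace model) so that it runs without any dependencies. The keyword arguments shown (num_beams, do_sample, top_p) are just examples of what a user might pass.

```python
from typing import Any, Dict, List, Optional


class RecordingModel:
    """Hypothetical stand-in for a HuggingFace model; it only records the call."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def generate(self, input_ids: List[List[int]], **kwargs: Any) -> List[List[int]]:
        # Keep the extra keyword arguments so they can be inspected afterwards.
        self.calls.append(kwargs)
        return input_ids


class KwargsSeq2SeqTask:
    """Hypothetical sketch of the generation_kwargs alternative."""

    def __init__(
        self,
        model: RecordingModel,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.model = model
        # Copy so that later changes to the caller's dict do not affect the task.
        self.generation_kwargs = dict(generation_kwargs or {})

    def forward(self, batch: Dict[str, List[List[int]]]) -> List[List[int]]:
        # Unpack the stored options straight into model.generate.
        return self.model.generate(
            input_ids=batch["input_ids"], **self.generation_kwargs
        )


kwargs_model = RecordingModel()
kwargs_task = KwargsSeq2SeqTask(
    kwargs_model, {"num_beams": 4, "do_sample": True, "top_p": 0.9}
)
kwargs_output = kwargs_task.forward({"input_ids": [[1, 2, 3]]})
```

Compared with a config object, a plain dict needs no extra import on the user's side, but typos in option names would only surface when model.generate is called.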

Additional context

Would be happy to work on a PR if the maintainers agree!