Lightning-Universe/lightning-transformers

Custom data file for classification seems to be failing

zippeurfou opened this issue · 7 comments

🐛 Bug

When training a classification model on a custom data file, training fails because the model expects a num_classes attribute, which is not set.

To Reproduce

Use this Colab:
https://colab.research.google.com/drive/1uamw6SNaOr_4ch24JNxAj2yfgLUKfJqO?usp=sharing
Error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/cli/train.py", line 84, in hydra_entry
    main(cfg)
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/cli/train.py", line 78, in main
    logger=logger,
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/cli/train.py", line 61, in run
    trainer.fit(model, datamodule=data_module)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self.call_setup_hook(model)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1066, in call_setup_hook
    model.setup(stage_name)
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/core/model.py", line 88, in setup
    self.configure_metrics(stage)
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/task/nlp/text_classification/model.py", line 61, in configure_metrics
    self.prec = Precision(num_classes=self.num_classes)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 948, in __getattr__
    type(self).__name__, name))
AttributeError: 'TextClassificationTransformer' object has no attribute 'num_classes'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Expected behavior

It should start training.

Environment

See the notebook.

same here~

$ python train.py task=nlp/text_classification dataset=nlp/text_classification/emotion
...
torch.nn.modules.module.ModuleAttributeError: 'TextClassificationTransformer' object has no attribute 'num_classes'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Thanks guys! This makes sense. I think the right fix is to either infer the number of classes from the data (by collecting all unique labels, which I think HF Datasets supports) or allow the user to pass it in.
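A minimal sketch of that idea, not the library's actual fix: the helper name infer_num_classes and the sample labels below are hypothetical. It uses the user's value when one is passed and otherwise counts the unique labels. With HF Datasets, the label list could come from something like Dataset.unique on the label column; here a plain Python list stands in for it.

```python
from typing import Hashable, Iterable, Optional


def infer_num_classes(labels: Iterable[Hashable], num_classes: Optional[int] = None) -> int:
    """Return the user-supplied class count, or infer it from the unique labels."""
    unique_labels = set(labels)
    if num_classes is None:
        # Nothing passed in: fall back to counting the distinct labels in the data.
        return len(unique_labels)
    if num_classes < len(unique_labels):
        # A user-supplied value smaller than what the data contains is almost certainly a mistake.
        raise ValueError(
            f"num_classes={num_classes} but the data contains {len(unique_labels)} unique labels"
        )
    return num_classes


# Hypothetical label column read from a custom data file.
train_labels = ["joy", "anger", "joy", "sadness", "anger"]
print(infer_num_classes(train_labels))                  # inferred from the data
print(infer_num_classes(train_labels, num_classes=6))   # explicit user override
```

Letting an explicit value win keeps the count stable when a split happens to be missing some classes, while the inferred value covers the custom-file case where nothing is configured.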

This error only occurs with the CLI.
When will this be fixed, or how can I fix it myself?

But when I run predict, I get the following error:
image

I am running into the same issue. Are there any workarounds?
'''
'TextClassificationTransformer' object has no attribute 'num_classes'
'''

stale commented

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

see #216 and #215