Custom data file for classification seems to be failing
zippeurfou opened this issue · 7 comments
🐛 Bug
When training a classification model on a custom data file, training fails because the model expects a num_classes attribute that is never set.
To Reproduce
Use this Colab notebook:
https://colab.research.google.com/drive/1uamw6SNaOr_4ch24JNxAj2yfgLUKfJqO?usp=sharing
Error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/cli/train.py", line 84, in hydra_entry
    main(cfg)
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/cli/train.py", line 78, in main
    logger=logger,
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/cli/train.py", line 61, in run
    trainer.fit(model, datamodule=data_module)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self.call_setup_hook(model)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1066, in call_setup_hook
    model.setup(stage_name)
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/core/model.py", line 88, in setup
    self.configure_metrics(stage)
  File "/usr/local/lib/python3.7/dist-packages/lightning_transformers/task/nlp/text_classification/model.py", line 61, in configure_metrics
    self.prec = Precision(num_classes=self.num_classes)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 948, in __getattr__
    type(self).__name__, name))
AttributeError: 'TextClassificationTransformer' object has no attribute 'num_classes'
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Expected behavior
It should start training.
Environment
See the notebook linked above.
same here~
$ python train.py task=nlp/text_classification dataset=nlp/text_classification/emotion
...
torch.nn.modules.module.ModuleAttributeError: 'TextClassificationTransformer' object has no attribute 'num_classes'
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Thanks guys! This makes sense. I think the right fix is either to infer the number of classes from the data (by collecting all unique labels, which I think HF Datasets supports) or to let the user pass it in.
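To make the two options concrete, here is a minimal sketch in plain Python. It is not the library's code: `resolve_num_classes` is a hypothetical helper name, and the label list stands in for a dataset's label column. With HF Datasets, the unique labels could presumably be collected with `Dataset.unique("label")`, assuming the column is named `label`.

```python
from typing import Hashable, Iterable, Optional


def resolve_num_classes(
    labels: Iterable[Hashable], num_classes: Optional[int] = None
) -> int:
    """Hypothetical helper: use the user-supplied class count if given,
    otherwise infer it from the unique labels in the data."""
    inferred = len(set(labels))  # number of distinct labels seen
    if num_classes is None:
        if inferred == 0:
            raise ValueError("cannot infer num_classes from an empty label column")
        return inferred
    if num_classes < inferred:
        # The user's value cannot be smaller than what the data contains.
        raise ValueError(
            f"num_classes={num_classes} but the data has {inferred} distinct labels"
        )
    return num_classes


# Stand-in for a dataset's label column (illustrative values only).
train_labels = ["joy", "anger", "joy", "sadness"]
print(resolve_num_classes(train_labels))     # 3 (inferred from the data)
print(resolve_num_classes(train_labels, 6))  # 6 (user-supplied value wins)
```

Letting an explicit value override the inferred one covers the case where the training split does not contain every class.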
This error only occurs with the CLI.
When will this be fixed, or how can I fix it myself?
I am running into the same issue. Are there any workarounds?
'''
'TextClassificationTransformer' object has no attribute 'num_classes'
'''
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.