ThilinaRajapakse/simpletransformers

Can't train T5 for classification

RanaBan opened this issue · 2 comments

model = MultiLabelClassificationModel('t5', 't5-base', ...)
fails with:

KeyError Traceback (most recent call last)

in <cell line: 1>()
----> 1 model = MultiLabelClassificationModel('t5', 't5-base',
2 use_cuda=True,
3 num_labels=9,
4 args={'train_batch_size':8,
5 # 'gradient_accumulation_steps':16,

/usr/local/lib/python3.9/dist-packages/simpletransformers/classification/multi_label_classification_model.py in __init__(self, model_type, model_name, num_labels, pos_weight, args, use_cuda, cuda_device, **kwargs)
217 self.args.fp16 = False
218
--> 219 config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
220 if num_labels:
221 self.config = config_class.from_pretrained(

KeyError: 't5'

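This happens because MultiLabelClassificationModel does not support the T5 model type. Its constructor looks up model_type in its MODEL_CLASSES mapping, and 't5' is not one of the keys, hence the KeyError. You can only use one of the model types it supports. If you want to use T5 for your experiment, please refer to https://simpletransformers.ai/docs/t5-minimal-start/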
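T5 is a text-to-text model, so a multi-label task has to be rephrased as "input text -> target text" before it can be trained that way. The T5 page linked above trains on rows with `prefix`, `input_text` and `target_text` columns. Below is a minimal sketch, assuming you encode the active labels as a comma-separated string of label names. The label names, the prefix string and both helper functions (`to_t5_rows`, `from_t5_output`) are hypothetical and only illustrate the conversion; they are not part of simpletransformers.

```python
from typing import Dict, List, Sequence

# Hypothetical label names; the issue uses num_labels=9, so substitute your own.
# Names must not contain commas, since commas separate labels in the target text.
LABELS: List[str] = [f"label_{i}" for i in range(9)]

# Assumed task prefix; any fixed string used consistently for training and prediction works.
PREFIX = "multilabel classification"


def to_t5_rows(texts: Sequence[str], label_vectors: Sequence[Sequence[int]],
               labels: Sequence[str] = LABELS,
               prefix: str = PREFIX) -> List[Dict[str, str]]:
    """Convert (text, multi-hot label vector) pairs into prefix/input_text/target_text rows."""
    if len(texts) != len(label_vectors):
        raise ValueError("texts and label_vectors must have the same length")
    rows = []
    for text, vec in zip(texts, label_vectors):
        if len(vec) != len(labels):
            raise ValueError(f"expected {len(labels)} labels, got {len(vec)}")
        active = [name for name, flag in zip(labels, vec) if flag]
        # An example with no active labels gets an explicit "none" target.
        target = ", ".join(active) if active else "none"
        rows.append({"prefix": prefix, "input_text": text, "target_text": target})
    return rows


def from_t5_output(generated: str, labels: Sequence[str] = LABELS) -> List[int]:
    """Parse a generated target string back into a multi-hot vector.

    Names outside `labels` (a generative model can emit them) are ignored,
    and "none" maps to an all-zero vector.
    """
    predicted = {part.strip() for part in generated.split(",")}
    return [1 if name in predicted else 0 for name in labels]
```

For example, `to_t5_rows(["great battery life"], [[1, 0, 0, 0, 0, 0, 0, 0, 1]])[0]["target_text"]` is `"label_0, label_8"`, and `from_t5_output("label_0, label_8")` turns it back into the original vector. Decoding by name lookup rather than by position keeps the parser tolerant of extra whitespace or label names the model invents.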

stale commented

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.