microsoft/LoRA

Fine-tuning RoBERTa-base with LoRA (ValueError: Classification metrics can't handle a mix of binary and multilabel-indicator targets)

rozhix opened this issue · 0 comments

Hello, I'm trying to fine-tune RoBERTa-base with LoRA for a multi-classification task. My compute_metrics function is as follows:

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from transformers import EvalPrediction

def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average='micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    # some models return a tuple; the logits are its first element
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds,
        labels=p.label_ids)
    return result

And this is my training code:

trainer = transformers.Trainer(
    model=model,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['test'],
    compute_metrics=compute_metrics,
    args=transformers.TrainingArguments(
        evaluation_strategy="epoch",
        save_strategy="epoch",
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=batch_size,
        num_train_epochs=5,
        weight_decay=0.01,
        learning_rate=2e-4,
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
        fp16=True,
        logging_steps=1,
        output_dir='outputs'
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False
trainer.train()

When I run it, I get this error:
ValueError: Classification metrics can't handle a mix of binary and multilabel-indicator targets
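
For reference, scikit-learn raises this message when `y_true` and `y_pred` are inferred to be different target types. The sketch below reproduces it in isolation with made-up arrays: a 1-D vector of class indices as `y_true` and a 2-D 0/1 matrix as `y_pred`. Whether this is the exact situation in the run above is an assumption, since the post does not show the shape of `p.label_ids`.

```python
import numpy as np
from sklearn.metrics import f1_score
from sklearn.utils.multiclass import type_of_target

# Illustrative data only (not taken from the issue).
y_true = np.array([0, 1, 1, 0])                      # 1-D: one class index per example
y_pred = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])  # 2-D: one 0/1 column per label

# scikit-learn infers a target type for each argument separately.
true_type = type_of_target(y_true)
pred_type = type_of_target(y_pred)
print(true_type, pred_type)

# Mixing the two types makes the metric refuse to compute a score.
try:
    f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    raised = False
except ValueError as err:
    raised = True
    print(err)
```

Printing `type_of_target(p.label_ids)` and `p.label_ids.shape` inside `compute_metrics` would show which side is 1-D in the actual run.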

How should I change my multi_label_metrics function?
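
One possible direction, sketched under assumptions: make the predictions take the same form as the labels before calling any metric. The helper `align_predictions` below is hypothetical (it is not part of transformers or scikit-learn) and uses only NumPy. It assumes that 1-D labels mean a single-label task, where the top logit is the prediction, and that 2-D labels mean a multi-label task, where each sigmoid output is thresholded.

```python
import numpy as np

def align_predictions(logits, labels, threshold=0.5):
    """Hypothetical helper: return (y_true, y_pred) with matching shapes."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    if labels.ndim == 1:
        # Assumed single-label case: one class index per example,
        # so the prediction is the index of the largest logit.
        return labels.astype(int), logits.argmax(axis=-1)
    # Assumed multi-label case: one 0/1 column per label,
    # so apply a sigmoid and threshold each column independently.
    probs = 1.0 / (1.0 + np.exp(-logits))
    return labels.astype(int), (probs >= threshold).astype(int)

# Illustrative logits for 2 examples and 3 labels.
logits = np.array([[2.0, -1.0, 0.5], [-0.3, 1.2, -2.0]])

# 1-D labels give 1-D predictions.
y_true_sl, y_pred_sl = align_predictions(logits, np.array([0, 1]))

# 2-D multi-hot labels give a 2-D indicator matrix of predictions.
y_true_ml, y_pred_ml = align_predictions(logits, np.array([[1, 0, 1], [0, 1, 0]]))
```

With both arrays in the same form, `f1_score` and `accuracy_score` receive a consistent target type. Note that `roc_auc_score` is normally given scores or probabilities rather than thresholded 0/1 predictions, so that call may need separate treatment.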