salesforce/CodeRL

Does `Trainer_Critic` class mimic `transformers`'s `Trainer` class?

cwarny opened this issue · 1 comment

At first glance, it seems like the `Trainer_Critic` class mimics the `transformers` library's `Trainer` class. I'm just curious: why did you feel the need to do that instead of using the `Trainer` class directly?

@cwarny This critic trainer was modified to take test-outcome annotations as inputs and to track prediction accuracy while training a critic model. The main changes are here:

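```python
if self.tuning_mode in ['critic']:
    outputs = None
    # Pass the token ids, the test-outcome annotations (error_types), and labels to the critic.
    curr_inputs = {'input_ids': inputs['input_ids'], 'error_types': inputs['error_types'],
                   'labels': inputs['labels']}
    # The critic model returns its classification loss and the predicted error types.
    error_pred_loss, error_preds = model(**curr_inputs)
    # Batch accuracy: fraction of samples whose predicted error type matches the annotation.
    error_pred_acc = (inputs['error_types'].squeeze(1) == error_preds).sum()/len(inputs['error_types'])
```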

and here, where a new accuracy-tracking variable is initialized to log training progress:

```python
tr_acc = torch.tensor(0.0).to(args.device)
```
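The line above only initializes `tr_acc`. Assuming it is used the way `Trainer` handles `tr_loss` (summed once per step, then divided by the number of steps since the last log), the bookkeeping can be sketched in plain Python with floats standing in for tensors. `RunningAccuracy` and its methods are hypothetical names for illustration, not part of CodeRL or `transformers`.

```python
class RunningAccuracy:
    """Hypothetical helper (not CodeRL code): accumulate per-step accuracy and
    report the mean over the steps since the last log, analogous to tr_loss."""

    def __init__(self) -> None:
        self.sum_since_log = 0.0  # plays the role of tr_acc
        self.steps_since_log = 0

    def update(self, step_acc: float) -> None:
        # Called once per training step with that step's batch accuracy.
        self.sum_since_log += step_acc
        self.steps_since_log += 1

    def log(self) -> float:
        # Mean accuracy since the previous log; then reset the accumulator.
        if self.steps_since_log == 0:
            return 0.0
        mean = self.sum_since_log / self.steps_since_log
        self.sum_since_log = 0.0
        self.steps_since_log = 0
        return mean


tracker = RunningAccuracy()
for acc in (0.5, 0.75, 1.0):
    tracker.update(acc)
print(tracker.log())  # mean of the three steps: 0.75
```

Resetting after each log keeps the reported value a per-interval average rather than a cumulative one, which is what makes it useful for watching training progress.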

The original T5 model and `Trainer` code do not directly support this kind of classification task, so we included the modified code here.
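To make the `error_pred_acc` line concrete, here is a plain-Python sketch with lists standing in for tensors; `error_pred_accuracy` is a hypothetical name. It assumes `inputs['error_types']` has shape `(batch, 1)` and `error_preds` has shape `(batch,)`, which is why the original calls `.squeeze(1)`: in PyTorch, comparing those two shapes directly would broadcast to a `(batch, batch)` matrix instead of an elementwise comparison.

```python
from typing import Sequence


def error_pred_accuracy(error_types: Sequence[Sequence[int]],
                        error_preds: Sequence[int]) -> float:
    """Hypothetical sketch: fraction of samples whose predicted error type
    matches the annotated one. error_types is (batch, 1), error_preds is (batch,)."""
    if len(error_types) != len(error_preds):
        raise ValueError("batch sizes differ")
    # Equivalent of .squeeze(1): drop the trailing singleton dimension.
    labels = [row[0] for row in error_types]
    # Elementwise comparison, then sum of matches divided by batch size.
    correct = sum(1 for y, p in zip(labels, error_preds) if y == p)
    return correct / len(labels)


print(error_pred_accuracy([[0], [2], [3], [1]], [0, 2, 1, 1]))  # 3 of 4 match: 0.75
```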