jessevig/bertviz

How to fine-tune my pre-trained model on a GLUE-style dataset (e.g. the IMDB dataset)

ShaobinChen-AH opened this issue · 2 comments

I fine-tuned BERT on unlabeled Wikipedia data to obtain my own pre-trained model. Now I want to fine-tune this pre-trained model on a GLUE-style downstream dataset such as IMDB to perform a sentiment classification task.
Specifically, the core code is shown below:
(in run_classifier_single_layer.py)
model = BertForSequenceClassification(bert_config, len(label_list), args.layers, pooling=args.pooling_type)
if args.init_checkpoint is not None:
    model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))

(in finetune.sh)
CUDA_VISIBLE_DEVICES=0 python run_classifier_single_layer.py \
  --task_name imdb \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir ./IMDB_data/ \
  --vocab_file "./pre-trained model/vocab.txt" \
  --bert_config_file "./pre-trained model/config.json" \
  --init_checkpoint "./pre-trained model/pytorch_model.bin" \
  --max_seq_length 512 \
  --train_batch_size 24 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --output_dir ./imdb \
  --seed 42 \
  --layers 11 10 \
  --trunc_medium -1
I ran the above finetune.sh script to perform the classification task, but I get the following error:
[error screenshot]
I also tried:

model_state_dict = torch.load(args.init_checkpoint)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", state_dict=model_state_dict)

but it does not work either.
What should I do to fix this error? Could you give me some advice? Thanks a lot!

Hi @bin199, have you had any success since then? I think the recommended approach is to use from_pretrained and just point it to your local directory, as in this example:

model = BertModel.from_pretrained("./test/saved_model/")

This is from https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.from_pretrained.example

But this may also depend on how you saved the model. It assumes you saved it using save_pretrained.
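For reference, here is a minimal sketch of that save/load round trip (the directory path and num_labels are just example values, and I'm assuming a BertForSequenceClassification model):

from transformers import BertForSequenceClassification, BertTokenizer

# load (or fine-tune) a model and tokenizer as usual
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# save both into the same directory with save_pretrained
model.save_pretrained("./test/saved_model/")
tokenizer.save_pretrained("./test/saved_model/")

# later, reload from the local directory instead of a hub name
model = BertForSequenceClassification.from_pretrained("./test/saved_model/")
tokenizer = BertTokenizer.from_pretrained("./test/saved_model/")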

I'm going to close this one for now, but feel free to reopen if you have any further questions or updates.