OptimalScale/LMFlow

Evaluation error on PubMedQA dataset

Closed this issue · 3 comments

I can't evaluate a model on the PubMedQA dataset. I use a command such as

./scripts/run_benchmark.sh --model_name_or_path /base_models/llama2-7b-hf --dataset_name PubMedQA

The error is "NotImplementedError: benchmarking dataset PubMedQA is not supported".

Simply adding the PubMedQA dataset to LOCAL_DATSET_GROUP_MAP in benchmarking.py does not solve it. The new problem is

Running tokenizer on dataset:   0%|                                                                                          | 0/1000 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Running tokenizer on dataset:   0%|                                                                                          | 0/1000 [00:00<?, ? examples/s]Traceback (most recent call last):
  File "/data/wangjie/projects/LLM/LMFlow/examples/benchmarking.py", line 233, in <module>
    main()
  File "/data/wangjie/projects/LLM/LMFlow/examples/benchmarking.py", line 222, in main
    run_lmflow_local_benchmarking(dataset_name,pipeline_name,model_args,pipeline_args,model)  # Pass args TODO (@Jipeng)
  File "/data/wangjie/projects/LLM/LMFlow/examples/benchmarking.py", line 176, in run_lmflow_local_benchmarking
    result = evaluator.evaluate(model=model, dataset=dataset, metric=local_metric,verbose=True)
  File "/data/wangjie/projects/LLM/LMFlow/src/lmflow/pipeline/evaluator.py", line 149, in evaluate
    nll = self._evaluate_nll(model, dataset, verbose=verbose)
  File "/data/wangjie/projects/LLM/LMFlow/src/lmflow/pipeline/evaluator.py", line 397, in _evaluate_nll
    tokenized_dataset = model.tokenize(dataset, add_special_tokens=False)
  File "/data/wangjie/projects/LLM/LMFlow/src/lmflow/models/hf_decoder_model.py", line 588, in tokenize
    tokenized_datasets = raw_datasets.map(
  File "/data/wangjie/projects/LLM/LMFlow/src/lmflow/datasets/dataset.py", line 357, in map
    mapped_backend_dataset = self.backend_dataset.map(*args, **kwargs)
  File "/home/wangjie/anaconda3/envs/lmflow_test/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 592, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/wangjie/anaconda3/envs/lmflow_test/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 557, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/wangjie/anaconda3/envs/lmflow_test/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3097, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/home/wangjie/anaconda3/envs/lmflow_test/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3474, in _map_single
    batch = apply_function_on_filtered_inputs(
  File "/home/wangjie/anaconda3/envs/lmflow_test/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3353, in apply_function_on_filtered_inputs
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
  File "/data/wangjie/projects/LLM/LMFlow/src/lmflow/models/hf_decoder_model.py", line 544, in tokenize_function
    max_length = min(block_size, self.get_max_length())
TypeError: '<' not supported between instances of 'int' and 'NoneType'

It seems that something goes wrong while tokenizing the test dataset.

Could you please tell me how to fix this bug? Thanks very much.
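For reference, the TypeError in the traceback comes from `min(block_size, self.get_max_length())` when `get_max_length()` returns `None` (the model config defines no maximum length, as the truncation warning above also suggests). A minimal guard, sketched as a standalone helper rather than the actual LMFlow code, might look like:

```python
def safe_max_length(block_size, model_max_length):
    """Return a usable max_length, falling back to block_size when the
    model reports no predefined maximum length (None).

    Hypothetical helper illustrating the fix; LMFlow's tokenize_function
    would need an equivalent check before calling min()."""
    if model_max_length is None:
        return block_size
    return min(block_size, model_max_length)
```

With this guard, a model whose tokenizer reports no maximum length simply truncates to `block_size` instead of raising.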

Thanks for your interest in LMFlow! The LMFlow benchmark doesn't support automatic PubMedQA evaluation yet, but adding it should not be that difficult. @2003pro I am wondering if you could take a look?

Here are the key regexes for extracting the answer from responses generated by the LoRA-tuned model. You can use this as a reference for your evaluation script:

    elif args.dataset == "pubmedqa":
        # pattern = r"Output: (yes|no|maybe)"
        # sttr = re.search(pattern, temp)
        # answer = sttr.group(0)[8:] if sttr is not None else "N/A"
        answer_map = {"a": "yes", "b": "no", "c": "maybe",
                      "A": "yes", "B": "no", "C": "maybe", "N/A": "N/A"}
        # First try an explicit "Answer: X" / "Output: X" style prefix.
        pattern = r"(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(A|B|C|a|b|c)"
        sttr = re.search(pattern, pred)
        if sttr is not None:
            mid_answer = sttr.group(0)
            answer = mid_answer[-1].lower()
        else:
            # Fall back to any bare or parenthesized option letter.
            pattern = r"\(*(A|B|C|a|b|c)\)*(\.|\s)"
            sttr = re.search(pattern, pred)
            if sttr is not None:
                if '(' in sttr.group(0):
                    answer = sttr.group(0)[1].lower()
                else:
                    answer = sttr.group(0)[0].lower()
            else:
                answer = "N/A"
        return answer_map[answer]
    elif args.dataset == "medmcqa":
        # pattern = r"Output: (A|B|C|D)."
        # sttr = re.search(pattern, temp)
        # answer = sttr.group(0)[8:-1].lower() if sttr is not None else "N/A"
        pattern = r"(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(A|B|C|D|a|b|c|d)"
        sttr = re.search(pattern, pred)
        if sttr is not None:
            mid_answer = sttr.group(0)
            answer = mid_answer[-1].lower()
        else:
            pattern = r"\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)"
            sttr = re.search(pattern, pred)
            if sttr is not None:
                if '(' in sttr.group(0):
                    answer = sttr.group(0)[1].lower()
                else:
                    answer = sttr.group(0)[0].lower()
            else:
                answer = "N/A"
        return answer

    elif args.dataset == "usmle":
        # pattern = r"Output: (A|B|C|D)."
        # sttr = re.search(pattern, temp)
        # answer = sttr.group(0)[8:-1].lower() if sttr is not None else "N/A"
        pattern = r"(Answer|Output|A): \(*(A|B|C|D|a|b|c|d)"
        sttr = re.search(pattern, pred)
        if sttr is not None:
            mid_answer = sttr.group(0)
            answer = mid_answer[-1].lower()
        else:
            pattern = r"\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)"
            sttr = re.search(pattern, pred)
            if sttr is not None:
                if '(' in sttr.group(0):
                    answer = sttr.group(0)[1].lower()
                else:
                    answer = sttr.group(0)[0].lower()
            else:
                answer = "N/A"
        return answer
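As a quick sanity check, the PubMedQA branch can be lifted into a standalone function and exercised on a few hypothetical model outputs (the sample strings below are made up for illustration, not taken from actual model responses):

```python
import re

def extract_pubmedqa_answer(pred):
    """Standalone copy of the PubMedQA extraction branch, for quick testing."""
    answer_map = {"a": "yes", "b": "no", "c": "maybe", "N/A": "N/A"}
    # First try an explicit "Answer: X" / "Output: X" style prefix.
    sttr = re.search(r"(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(A|B|C|a|b|c)", pred)
    if sttr is not None:
        answer = sttr.group(0)[-1].lower()
    else:
        # Fall back to any bare or parenthesized option letter.
        sttr = re.search(r"\(*(A|B|C|a|b|c)\)*(\.|\s)", pred)
        if sttr is not None:
            if "(" in sttr.group(0):
                answer = sttr.group(0)[1].lower()
            else:
                answer = sttr.group(0)[0].lower()
        else:
            answer = "N/A"
    return answer_map[answer]

print(extract_pubmedqa_answer("Answer: (A)"))          # yes
print(extract_pubmedqa_answer("Output: b."))           # no
print(extract_pubmedqa_answer("The answer is (C)."))   # maybe
print(extract_pubmedqa_answer("no option given"))      # N/A
```

Note that the fallback pattern can also match a stray option letter inside ordinary words, so the first, prefix-anchored pattern is doing most of the work here.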

Thanks for your reply. I can use lm_eval to evaluate.