OSU-NLP-Group/TravelPlanner

hard code in eval.py


Hi folks,
May I ask why some values are hard-coded in lines 180-194 here?
https://github.com/OSU-NLP-Group/TravelPlanner/blob/main/evaluation/eval.py#L180

Is there a specific reason for this?

Can I change it locally to the code below, so that it also works for the train set_type?

    # Load the requested split and keep only as many queries as there are tested plans.
    query_data_list_ = load_dataset('osunlp/TravelPlanner', set_type)[set_type]
    query_data_list = [x for x in query_data_list_][:len(tested_plans)]
    # query_data_list is a plain Python list here, so use len() rather than .num_rows.
    num_queries = len(query_data_list)

    result['Delivery Rate'] = delivery_cnt / num_queries
    result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / constraint_dis_record['commonsense']['total']
    result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / num_queries
    result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / constraint_dis_record['hard']['total']
    result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / num_queries
    result['Final Pass Rate'] = final_all_cnt / num_queries

thanks a lot.

Hi,

If I remember correctly, they were introduced to address inaccuracies caused by failed plan deliveries during evaluation.

Apologies for the oversight on training set evaluation support; we're working to add this feature ASAP. In the meantime, you can use these temporary hard-coded values for the training set: 45, 360, 45, 105, 45, 45.
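To illustrate how failed deliveries could skew a micro pass rate, here is a minimal sketch with made-up counts. It assumes (this is not confirmed in the thread) that an undelivered plan adds nothing to the running `total`, so a denominator computed on the fly shrinks. The helper `micro_pass_rate` and all counts except 45 and 360 are hypothetical.

```python
def micro_pass_rate(passed: int, total: int) -> float:
    """Return passed / total, or 0.0 when there is nothing to count."""
    return passed / total if total else 0.0


# From the thread: 45 training queries and a fixed commonsense total of 360,
# which works out to 8 commonsense checks per query.
NUM_QUERIES = 45
FIXED_COMMONSENSE_TOTAL = 360
CHECKS_PER_QUERY = FIXED_COMMONSENSE_TOTAL // NUM_QUERIES

# Hypothetical run: only 40 of the 45 plans were delivered, and the
# delivered plans passed 240 of their commonsense checks.
delivered_plans = 40
passed_checks = 240

# Assumption: undelivered plans contribute nothing to the counted total.
counted_total = delivered_plans * CHECKS_PER_QUERY  # 320

dynamic_rate = micro_pass_rate(passed_checks, counted_total)
fixed_rate = micro_pass_rate(passed_checks, FIXED_COMMONSENSE_TOTAL)

print(f"dynamic denominator: {dynamic_rate:.3f}")
print(f"fixed denominator:   {fixed_rate:.3f}")
```

Under that assumption, the dynamic denominator ignores the five missing plans and reports a higher rate than the fixed one, which charges them as failures.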

@hsaest

Thanks a lot.
I added this block on my side:

    elif set_type == 'train':
        result['Delivery Rate'] = delivery_cnt / 45
        result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 360
        result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 45
        result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 105
        result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 45
        result['Final Pass Rate'] = final_all_cnt / 45
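As an alternative to one `elif` branch per split, the denominators could live in a lookup table. The sketch below is only an illustration, not the code in eval.py: `DENOMINATORS` and `compute_rates` are hypothetical names, and only the train values quoted in this thread are filled in.

```python
# Fixed denominators per split. Only the 'train' values come from this thread;
# other splits would need the values already hard-coded in eval.py.
DENOMINATORS = {
    "train": {
        "Delivery Rate": 45,
        "Commonsense Constraint Micro Pass Rate": 360,
        "Commonsense Constraint Macro Pass Rate": 45,
        "Hard Constraint Micro Pass Rate": 105,
        "Hard Constraint Macro Pass Rate": 45,
        "Final Pass Rate": 45,
    },
}


def compute_rates(set_type: str, numerators: dict) -> dict:
    """Divide each metric's numerator by the fixed denominator for set_type."""
    if set_type not in DENOMINATORS:
        raise ValueError(f"no denominators registered for set_type={set_type!r}")
    denoms = DENOMINATORS[set_type]
    return {name: numerators[name] / denom for name, denom in denoms.items()}


# Usage with made-up counts standing in for delivery_cnt,
# constraint_dis_record[...]['pass'], final_commonsense_cnt, and so on.
example = compute_rates(
    "train",
    {
        "Delivery Rate": 45,
        "Commonsense Constraint Micro Pass Rate": 180,
        "Commonsense Constraint Macro Pass Rate": 9,
        "Hard Constraint Micro Pass Rate": 21,
        "Hard Constraint Macro Pass Rate": 9,
        "Final Pass Rate": 0,
    },
)
```

An unknown split raises a `ValueError` here instead of silently leaving `result` unfilled.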