GanjinZero/math401-llm

评测代码

cobraheleah opened this issue · 1 comment

请问评测代码在哪里?

完整的评测代码还没有整理出来,核心的评测函数如下:

def cal_accu(pred, gtrue):
    """Return the fraction of predictions that match the ground truth.

    A prediction counts as correct when it is truthy (i.e. not ``None`` or
    an empty string) and its value, converted with ``float``, is within
    ``1e-3`` (absolute) of the corresponding ground-truth value.

    Args:
        pred: Sequence of predicted values; each item is either something
            ``float()`` accepts or a falsy placeholder for "no number".
        gtrue: Sequence of ground-truth values, aligned with ``pred``.

    Returns:
        The number of correct predictions divided by ``len(pred)``, or
        ``0.0`` when ``pred`` is empty.
    """
    total = len(pred)
    # Guard against ZeroDivisionError when there is nothing to evaluate.
    if total == 0:
        return 0.0

    accu_count = sum(
        1
        for p, gt in zip(pred, gtrue)
        # Falsy predictions are treated as wrong without being parsed.
        if p and abs(float(p) - float(gt)) < 1e-3
    )
    return accu_count / total

def average_error(pred, gtrue):
    """Return the mean capped relative error of the predictions.

    For each truthy prediction the error is
    ``|pred - truth| / max(|truth|, 1)``, capped at 10. A falsy prediction
    (``None`` or an empty string) is charged the maximum error of 10.

    Args:
        pred: Sequence of predicted values; each item is either something
            ``float()`` accepts or a falsy placeholder for "no number".
        gtrue: Sequence of ground-truth values, aligned with ``pred``.

    Returns:
        The summed per-item errors divided by ``len(pred)``, or ``0.0``
        when ``pred`` is empty.
    """
    total = len(pred)
    # Guard against ZeroDivisionError when there is nothing to evaluate.
    if total == 0:
        return 0.0

    max_error = 10  # cap per item, also the penalty for a missing prediction

    def item_error(p, gt):
        if not p:
            return max_error
        truth = float(gt)
        # The denominator is floored at 1 so tiny ground truths do not
        # blow up the relative error.
        return min(max_error, abs(float(p) - truth) / max(abs(truth), 1))

    errors = sum(item_error(p, gt) for p, gt in zip(pred, gtrue))
    return errors / total

def non_number_rate(pred, gtrue):
    """Return the fraction of predictions that are missing (falsy).

    Args:
        pred: Sequence of predicted values; falsy items (``None`` or an
            empty string) mark outputs from which no number was obtained.
        gtrue: Sequence of ground-truth values. Its contents are not
            inspected; it only bounds the iteration, since predictions are
            paired with it via ``zip``.

    Returns:
        The number of falsy predictions divided by ``len(pred)``, or
        ``0.0`` when ``pred`` is empty.
    """
    total = len(pred)
    # Guard against ZeroDivisionError when there is nothing to evaluate.
    if total == 0:
        return 0.0

    count = sum(1 for p, _ in zip(pred, gtrue) if not p)
    return count / total