Ber666/ToolkenGPT

Implement ChatGPT+ReAct and Vicuna+ReAct

siyuyuan opened this issue · 4 comments

Hi, I tried to implement ChatGPT+ReAct and Vicuna+ReAct using the methods in sections A.2.3 and A.2.4 of your paper, but I am confused about exactly how you do this. Your paper says:

'ReAct incorporates special syntax to call operators, e.g., "... The cost is 50*3.2=<multiply>(50,3.2)=160". Once the syntax is detected during inference, the tool would be called to calculate the result.'
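If I understand correctly, the detection step would look roughly like the sketch below (my own guess based on the quote, not your code):

```python
import re

# Rough sketch: when the generation ends with "<op>(args)=",
# call the corresponding tool and append its result to the text.
def maybe_call_tool(text, tools):
    match = re.search(r"<(\w+)>\(([^)]*)\)=$", text)
    if match is None:
        return text  # no pending tool call
    op, args = match.group(1), match.group(2)
    result = tools[op](*[float(a) for a in args.split(",")])
    return text + str(result)

print(maybe_call_tool("... The cost is 50*3.2=<multiply>(50, 3.2)=",
                      {"multiply": lambda a, b: a * b}))
# -> "... The cost is 50*3.2=<multiply>(50, 3.2)=160.0"
```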

However, in the paper ReAct: Synergizing Reasoning and Acting in Language Models, the authors describe ReAct as interleaving thoughts and actions, as follows:

Question What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought 1 I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
Action 1 Search[Colorado orogeny]
Observation 1 The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.

So, how is this implemented with ChatGPT and Vicuna if I use your prompt as described in your paper? Could you share your LLaMA-30B+ReAct code? I find the FuncQA dataset interesting, and I hope to make a fairer comparison.

Ber666 commented

We didn't follow the exact format of ReAct. In terms of formatting, it's more like an in-context-learning version of Toolformer. The idea is the same, though: combining chain-of-thought reasoning and actions through in-context learning. We chose this format to align with ToolkenGPT's.

I am not sure whether you can directly apply our prompt to a chat LLM (Vicuna or ChatGPT), but it should work with minor modifications (perhaps some additional instructions).
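For illustration, a rough sketch of applying the prompt to ChatGPT might look like this (the system instruction, model name, and stop string here are assumptions, not something we have tested):

```python
from openai import OpenAI  # assumes the openai>=1.0 Python client

client = OpenAI()

def chatgpt_react_step(few_shot_prompt, question, partial=""):
    # One decoding step that halts at ")=" so a tool result can be spliced in,
    # mirroring the detection loop used for the LLaMA baseline.
    resp = client.chat.completions.create(
        model="gpt-3.5-turbo",
        temperature=0,
        stop=[")="],  # the API strips the stop string, so re-append it before parsing
        messages=[
            {"role": "system",
             "content": "Solve the question step by step. When arithmetic is needed, "
                        "write the operation as <op>(arg1, arg2)= and stop."},
            {"role": "user",
             "content": few_shot_prompt.replace("[QUESTION]", question) + partial},
        ],
    )
    return resp.choices[0].message.content
```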

@Leolty could you share our code for ReAct?

Leolty commented

Hi @siyuyuan,

Apologies for the delayed response; I've been preoccupied with other things and unfortunately overlooked this. Sorry for any inconvenience!

Below is the ReAct inference code that you can integrate. You can place this function in inference_modes.py, alongside the other inference functions:

```python
import re  # required for re.findall below (skip if inference_modes.py already imports it)


def react_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len):
    funcmodel.inference_mode = "react"
    # get func list
    func_map = list(funcmodel.func_dict.keys())

    cur_generation = ""
    try:
        results = []
        func_calls = []
        while True:
            prompt = templates["func"].replace("[QUESTION]", question) + cur_generation
            results = funcmodel.generate([prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, stop_token=[13, 29897, 3892])

            cur_generation = results[0].replace(templates["func"].replace("[QUESTION]", question), "")
            
            endflag = True

            if cur_generation.endswith(")") or cur_generation.endswith(")="):
                endflag = False

                # extract the operator and arguments from the trailing <op>(args)= call
                pattern = r"\<(.*?)\>\((.*?)\)\="

                args = ""

                matches = re.findall(pattern, cur_generation)

                if len(matches) == 0:
                    raise Exception("invalid args")
                else:
                    op, args = matches[-1]

                op = "<" + op.strip() + ">"
                args = "(" + args.strip() + ")"
                
                if op not in func_map:
                    raise Exception(f"invalid func -- {op}")

                if args == "":
                    raise Exception("invalid args")
                
                
                args = args.replace("=", "").replace(">", "").replace("((", "(").replace("))", ")")


                # strip thousands-separator commas while keeping ", " argument separators
                if ", " in args:
                    args = args.replace(", ", ";").replace(",", "").replace(";", ", ")

                args = args.replace(" ", "")

                if "(" not in args or ")" not in args:
                    raise Exception("invalid args")

                # handle % and / in args
                if '%' in args or '/' in args:
                    temp = args.split("(")[1].split(")")[0].split(",")

                    for arg_i, arg in enumerate(temp):
                        # if have percentage, convert to decimal
                        if "%" in arg:
                            arg = arg.replace("%", "").strip()
                            arg = str(float(arg) / 100)
                        # if have fraction, convert to decimal
                        if "/" in arg:
                            numerator, denominator = [a.strip() for a in arg.split("/")]
                            arg = str(float(numerator) / float(denominator))
                        
                        temp[arg_i] = arg
                    
                    args = f"({', '.join(temp)})"
                
                try:
                    # dispatch to the corresponding tool implementation (a function named _<op>_)
                    res = eval(f"_{op[1:-1]}_{args}")
                    func_calls.append(f"{op}{args} = {res}")
                    cur_generation = cur_generation + str(res)
                    # generate only the next token, with all digit tokens disabled,
                    # so the model continues past the inserted result instead of extending the number
                    prompt = templates["func"].replace("[QUESTION]", question) + cur_generation
                    results = funcmodel.generate([prompt], max_gen_len=1, temperature=temperature, top_p=top_p, stop_token=[13],
                                                    disable_token=[29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929])  # LLaMA token ids for digits 0-9
               
                    cur_generation = results[0].replace(templates["func"].replace("[QUESTION]", question), "")
                except Exception as e:
                    raise Exception(f"error -- {e}")

            if endflag:
                break

        log = {
            "case_idx": case_idx,
            "question": question,
            "func_calls": func_calls,
            "generation": cur_generation,
            "status": "success"
        }

    except Exception as e:
        # if local_rank == 0:
        log = {
            "case_idx": case_idx,
            "question": question,
            "generation": cur_generation,
            "status": str(e)
        }
    return log
```
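For reference, a hypothetical call could look like this; the templates and funcmodel objects follow the conventions of the other functions in inference_modes.py, and the generation settings are just placeholders:

```python
# Hypothetical invocation; argument values are placeholders, not the settings used in the paper.
log = react_inference(templates, case_idx=0, question=question,
                      funcmodel=funcmodel, temperature=0.0, top_p=0.95, max_gen_len=512)
print(log["status"])
print(log["generation"])
```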

siyuyuan commented

Thank you so much! Merry Christmas in advance :)

Leolty commented

Thank you, @siyuyuan!! Merry Christmas to you too! 🎄