
Primary language: Python. License: MIT.

Mini-Chain

A tiny library for large language models.

[Documentation and Examples]

Write apps that can easily and efficiently call multiple language models.

from minichain import SimplePrompt, TemplatePrompt, start_chain

# A prompt built from the Jinja template below (math.pmpt.tpl).
class MathPrompt(TemplatePrompt[str]):
    template_file = "math.pmpt.tpl"

with start_chain("math") as backend:
    # MathPrompt with the OpenAI backend
    p1 = MathPrompt(backend.OpenAI())
    # A prompt that simply runs the generated Python
    p2 = SimplePrompt(backend.Python())
    # Chain them together
    prompt = p1.chain(p2)
    # Call the chain with a question.
    question = "What is the sum of the powers of 3 (3^i) that are smaller than 100?"
    print(prompt({"question": question}))
Contents of math.pmpt.tpl (earlier few-shot examples elided):

...
Question:
A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?
Code:
2 + 2/2

Question:
{{question}}
Code:
Install and execute:
> pip install git+https://github.com/srush/MiniChain/
> export OPENAI_KEY="sk-***"
> python math.py
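To make the data flow of the math chain concrete, here is a dependency-free sketch: the question is placed into the template, a stand-in function plays the language model, and the "generated" expression is evaluated. The names `build_prompt`, `fake_llm`, `run_python`, and `math_chain` are hypothetical, and the fake model always returns the same expression; this is not how MiniChain's OpenAI or Python backends are implemented.

```python
# Hypothetical, dependency-free sketch of the math chain's data flow.

def build_prompt(question: str) -> str:
    # Fill the single {{question}} slot of the few-shot template (toy version).
    return f"Question:\n{question}\nCode:"

def fake_llm(prompt: str) -> str:
    # Stand-in for the OpenAI backend: always "writes" the same expression.
    return "sum(3**i for i in range(5))"

def run_python(code: str):
    # Stand-in for the Python backend: evaluate a single expression with only
    # `sum` and `range` available. Never eval untrusted model output unsandboxed.
    return eval(code, {"__builtins__": {"sum": sum, "range": range}})

def math_chain(question: str):
    # Chaining: each step's output becomes the next step's input.
    return run_python(fake_llm(build_prompt(question)))

print(math_chain("What is the sum of the powers of 3 (3^i) that are smaller than 100?"))
# prints 121, i.e. 1 + 3 + 9 + 27 + 81
```

The same `run_python` step applied to the few-shot answer `2 + 2/2` gives `3.0`.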

Examples

This library allows us to implement several popular approaches in a few lines of code.

It currently supports the following backends:

  • OpenAI (Completions / Embeddings)
  • Hugging Face 🤗
  • Google Search
  • Python
  • Manifest-ML (AI21, Cohere, Together)
  • Bash

Why Mini-Chain?

There are several very popular libraries for prompt chaining, notably LangChain, Promptify, and GPTIndex. These libraries are useful, but they are extremely large and complex. MiniChain aims to implement the core prompt-chaining functionality in a tiny, digestible library.

Tutorial

MiniChain is built around Prompts.


You can write your own prompts by overriding the prompt and parse functions of the Prompt[Input, Output] class.

from minichain import Prompt

class ColorPrompt(Prompt[str, bool]):
    def prompt(self, inp: str) -> str:
        # Encode the prompting logic
        return f"Answer 'Yes' if this is a color, {inp}. Answer:"

    def parse(self, out: str, inp: str) -> bool:
        # Encode the parsing logic (ignore surrounding whitespace)
        return out.strip() == "Yes"

The LLM for a Prompt is specified by its backend. To run a prompt, construct it with a backend and then call it like a function. Backends are accessed through start_chain, which also manages logging.

with start_chain("color") as backend:
    prompt1 = ColorPrompt(backend.OpenAI())
    if prompt1("blue"):
        print("It's a color!")
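Calling a prompt amounts to three steps: build the prompt text, send it to the backend, and parse the reply. The sketch below models that flow with a hypothetical `SketchPrompt` base class and a `fake_backend` function in place of OpenAI; it illustrates the idea and is not MiniChain's actual `Prompt` implementation.

```python
from typing import Callable, Generic, TypeVar

Input = TypeVar("Input")
Output = TypeVar("Output")

class SketchPrompt(Generic[Input, Output]):
    """Hypothetical stand-in for Prompt[Input, Output]."""

    def __init__(self, backend: Callable[[str], str]):
        self.backend = backend  # any function from prompt text to model text

    def prompt(self, inp: Input) -> str:
        raise NotImplementedError

    def parse(self, out: str, inp: Input) -> Output:
        raise NotImplementedError

    def __call__(self, inp: Input) -> Output:
        # Build the prompt, query the backend, then parse its reply.
        return self.parse(self.backend(self.prompt(inp)), inp)

class ColorPrompt(SketchPrompt[str, bool]):
    def prompt(self, inp: str) -> str:
        return f"Answer 'Yes' if this is a color, {inp}. Answer:"

    def parse(self, out: str, inp: str) -> bool:
        return out.strip() == "Yes"

def fake_backend(text: str) -> str:
    # Pretend model: answers "Yes" only when the prompt mentions blue.
    return " Yes" if "blue" in text else " No"

prompt1 = ColorPrompt(fake_backend)
print(prompt1("blue"), prompt1("seven"))  # prints: True False
```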

You can write a standard Python program just by calling these prompts. Alternatively you can chain prompts together.


with start_chain("mychain") as backend:
    prompt0 = SimplePrompt(backend.OpenAI())
    chained_prompt = prompt0.chain(prompt1)
    if chained_prompt("..."):
        ...

SimplePrompt simply passes its input string to the language model and returns the model's output string.
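Conceptually, `prompt0.chain(prompt1)` is function composition: the output of the first prompt becomes the input of the second. A minimal sketch with plain functions, where `chain`, `simple_prompt`, and `color_prompt` are hypothetical stand-ins rather than MiniChain objects:

```python
from typing import Callable, TypeVar

A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")

def chain(first: Callable[[A], B], second: Callable[[B], C]) -> Callable[[A], C]:
    # The composed prompt runs `first`, then hands its result to `second`.
    return lambda inp: second(first(inp))

def simple_prompt(text: str) -> str:
    # SimplePrompt-like step: this fake "model" always names a color.
    return "blue"

def color_prompt(word: str) -> bool:
    # ColorPrompt-like step using a hard-coded color list instead of an LLM.
    return word in {"red", "green", "blue"}

chained_prompt = chain(simple_prompt, color_prompt)
print(chained_prompt("Name a color."))  # prints: True
```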

We also include TemplatePrompt[Output], which builds its prompt from a Jinja template file named by template_file.

class MathPrompt(TemplatePrompt[str]):
    template_file = "math.pmpt.tpl"
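To show what the template step does, here is a toy renderer that understands only `{{ name }}` placeholders. It is a hypothetical stand-in for Jinja (which TemplatePrompt relies on) and ignores Jinja's loops, filters, and escaping; the template string mirrors the end of math.pmpt.tpl.

```python
import re

def render(template: str, **variables: str) -> str:
    # Replace each {{ name }} with the matching keyword argument.
    # Only plain variables are handled; no loops, filters, or escaping.
    return re.sub(r"\{\{\s*(\w+)\s*\}\}",
                  lambda m: str(variables[m.group(1)]),
                  template)

math_template = "Question:\n{{question}}\nCode:"
print(render(math_template, question="What is 2 + 2?"))
```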

Logging is done automatically based on the name of your chain using the eliot logging framework. You can run the following command to get the full output of your system.

show_log("mychain.log")

Memory

MiniChain does not build in an explicit stateful memory class. We recommend implementing it as a queue.
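A bounded queue fits this well because only the most recent exchanges are kept. A minimal sketch using the standard library's `collections.deque`, assuming a window of two (human input, response) pairs; the conversation data is made up:

```python
from collections import deque

# Keep only the two most recent (human_input, response) pairs.
memory = deque(maxlen=2)

for turn, reply in [("hi", "hello"), ("how are you?", "fine"), ("bye", "goodbye")]:
    memory.append((turn, reply))  # the oldest pair is dropped automatically

print(list(memory))
# prints: [('how are you?', 'fine'), ('bye', 'goodbye')]
```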


Here is a class you might find useful to keep track of responses.

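from dataclasses import dataclass
from typing import List, Tuple

MEMORY = 2  # number of past exchanges to keep

@dataclass
class State:
    memory: List[Tuple[str, str]]
    human_input: str = ""

    def push(self, response: str) -> "State":
        # Drop the oldest exchange once the window is full, then append.
        memory = self.memory if len(self.memory) < MEMORY else self.memory[1:]
        return State(memory + [(self.human_input, response)])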

See the full Chat example. It keeps track of the last two responses that it has seen.
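The following usage sketch repeats the State class so it runs on its own, with MEMORY assumed to be 2, and replaces the language model with hypothetical "reply to ..." strings. After three turns only the last two exchanges remain.

```python
from dataclasses import dataclass
from typing import List, Tuple

MEMORY = 2  # assumed window size, matching "the last two responses"

@dataclass
class State:
    memory: List[Tuple[str, str]]
    human_input: str = ""

    def push(self, response: str) -> "State":
        memory = self.memory if len(self.memory) < MEMORY else self.memory[1:]
        return State(memory + [(self.human_input, response)])

state = State([])
for text in ["hi", "what is 2+2?", "thanks"]:
    # Record the new human input, then store the (fake) model reply.
    state = State(state.memory, human_input=text)
    state = state.push(f"reply to {text}")

print(state.memory)
# prints: [('what is 2+2?', 'reply to what is 2+2?'), ('thanks', 'reply to thanks')]
```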

Documents and Embeddings

MiniChain is agnostic to how you manage documents and embeddings. We recommend using the Hugging Face Datasets library with its built-in FAISS indexing.


Here is the implementation.

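import datasets
import numpy as np
from minichain import EmbeddingPrompt

# Load and index a dataset
olympics = datasets.load_from_disk("olympics.data")
olympics.add_faiss_index("embeddings")

class KNNPrompt(EmbeddingPrompt):
    def find(self, out, inp):
        # Return the 3 stored examples closest to the query embedding `out`
        return olympics.get_nearest_examples("embeddings", np.array(out), 3)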

This creates a k-nearest-neighbors (KNN) Prompt that looks up the 3 closest documents based on the embedding of the question asked. See the full Retrieval-Augmented QA example.
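For intuition, here is a dependency-free sketch of what a KNN lookup computes: score every document by the Euclidean (L2) distance between its embedding and the query embedding, then keep the k closest. A flat FAISS index with L2 distance returns the same neighbors, only much faster. The function name, documents, and 2-d vectors are made up for illustration.

```python
import math
from typing import List, Sequence, Tuple

def nearest_examples(query: Sequence[float],
                     docs: List[Tuple[str, Sequence[float]]],
                     k: int = 3) -> List[Tuple[float, str]]:
    # Brute-force search: compute the L2 distance to every document
    # embedding and keep the k smallest.
    scored = [(math.dist(query, emb), text) for text, emb in docs]
    scored.sort()
    return scored[:k]

# Toy 2-d "embeddings"; real ones come from an embedding model.
docs = [
    ("Athens 2004", [0.9, 0.1]),
    ("Beijing 2008", [0.8, 0.3]),
    ("Cooking pasta", [0.0, 1.0]),
    ("London 2012", [0.6, 0.2]),
]
for score, text in nearest_examples([1.0, 0.0], docs, k=3):
    print(f"{score:.3f} {text}")
```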

We recommend creating these embeddings offline using the batch map functionality of the datasets library.

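import numpy as np
import openai

# `dataset`, EMBEDDING_MODEL and BATCH_SIZE are assumed to be defined elsewhere.
def embed(x):
    # x["content"] is a batch (list) of documents, embedded in one API call
    emb = openai.Embedding.create(input=x["content"], engine=EMBEDDING_MODEL)
    return {"embeddings": [np.array(emb["data"][i]["embedding"])
                           for i in range(len(emb["data"]))]}

x = dataset.map(embed, batch_size=BATCH_SIZE, batched=True)
x.save_to_disk("olympics.data")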
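With `batched=True`, the function receives a dictionary of columns covering up to `batch_size` rows (here `{"content": [...]}`) and returns a dictionary of equally long columns. The sketch below imitates that contract in plain Python so the shapes are visible without an API key; `batched_map` and `fake_embed` are hypothetical and are not part of the datasets library.

```python
from typing import Callable, Dict, List

def batched_map(rows: List[dict],
                fn: Callable[[Dict[str, list]], Dict[str, list]],
                batch_size: int) -> List[dict]:
    out = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        # Turn a list of rows into a dict of columns.
        columns = {key: [row[key] for row in batch] for key in batch[0]}
        new_columns = fn(columns)
        # Merge the returned columns back into each row.
        for i, row in enumerate(batch):
            out.append({**row, **{k: v[i] for k, v in new_columns.items()}})
    return out

def fake_embed(x: Dict[str, list]) -> Dict[str, list]:
    # Stand-in for the embedding API: one toy vector per document.
    return {"embeddings": [[float(len(text)), 1.0] for text in x["content"]]}

rows = [{"content": "Athens"}, {"content": "Rome"}, {"content": "Paris"}]
result = batched_map(rows, fake_embed, batch_size=2)
print(result[1])  # prints: {'content': 'Rome', 'embeddings': [4.0, 1.0]}
```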

There are other ways to do this, such as SQLite or Weaviate.

Advanced

Asynchronous Calls

Prompt chains make it easier to manage asynchronous execution. Prompt has an arun method that makes the language-model call asynchronously. Async calls require the trio library.

import trio

async def fn1(prompt1):
    if await prompt1.arun("blue"):
        ...

# trio.run takes the async function followed by its arguments
trio.run(fn1, prompt1)

A convenient construct is the map method, which runs a prompt over a list of inputs.


This code runs a summarization prompt with asynchronous calls to the API.

with start_chain("summary") as backend:
    list_prompt = SummaryPrompt(backend.OpenAI()).map()
    out = trio.run(list_prompt.arun, documents)
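MiniChain uses trio for this. Purely to illustrate why mapping a prompt over documents with asynchronous calls helps, here is an analogous sketch that swaps in the standard library's `asyncio`; `fake_summarize` is a hypothetical stand-in for an API call that takes about 0.1 s.

```python
import asyncio
import time
from typing import List

async def fake_summarize(doc: str) -> str:
    await asyncio.sleep(0.1)  # simulate the latency of one API call
    return doc.split(".")[0] + "."  # "summary" = first sentence

async def map_prompt(docs: List[str]) -> List[str]:
    # Start every call at once; gather returns results in input order.
    return list(await asyncio.gather(*(fake_summarize(d) for d in docs)))

documents = ["One. Two.", "Three. Four.", "Five. Six."]
start = time.perf_counter()
out = asyncio.run(map_prompt(documents))
elapsed = time.perf_counter() - start
print(out, f"{elapsed:.2f}s")  # calls overlap: roughly 0.1 s total, not 0.3 s
```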

Parsing

MiniChain lets you use whatever parser you like. One example is parsita, a parser combinator library. This example builds a small state machine from the LLM response, with error handling.

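from minichain import TemplatePrompt
from parsita import TextParsers, lit, reg

# IntermediateState and FinalState are defined elsewhere (not shown here).
# Using `A | B` between classes at runtime needs Python 3.10+.
class SelfAsk(TemplatePrompt[IntermediateState | FinalState]):
    template_file = "selfask.pmpt.tpl"

    class Parser(TextParsers):
        follow = (lit("Follow up:") >> reg(r".*")) > IntermediateState
        finish = (lit("So the final answer is: ") >> reg(r".*")) > FinalState
        response = follow | finish

    def parse(self, response: str, inp):
        # Raise an error if the response matches neither pattern
        return self.Parser.response.parse(response).or_die()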
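If you would rather not depend on parsita, the same two-way branch can be sketched with the standard `re` module. The source does not show `IntermediateState` and `FinalState`, so the single `s` field used here is an assumption, and raising `ValueError` plays the role of `or_die()`.

```python
import re
from dataclasses import dataclass
from typing import Union

@dataclass
class IntermediateState:
    s: str  # the follow-up question to ask next (assumed field)

@dataclass
class FinalState:
    s: str  # the final answer (assumed field)

def parse(response: str) -> Union[IntermediateState, FinalState]:
    text = response.strip()
    # Try each pattern in turn, mirroring `follow | finish`.
    if m := re.fullmatch(r"Follow up:(.*)", text):
        return IntermediateState(m.group(1).strip())
    if m := re.fullmatch(r"So the final answer is: (.*)", text):
        return FinalState(m.group(1).strip())
    raise ValueError(f"Unparseable response: {response!r}")

print(parse("Follow up: Who won in 2004?"))
# prints: IntermediateState(s='Who won in 2004?')
print(parse("So the final answer is: Greece"))
# prints: FinalState(s='Greece')
```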

[Full Examples]