/LLM-finetune-demos

Supervised finetuning (sft) of instruction-following LLMs.

Primary language: Jupyter Notebook. License: Apache-2.0.

Supervised finetuning of instruction-following LLMs

Python 3.9+. Code style: black.

This repo contains lightweight demos for supervised finetuning (SFT) of large language models, like MosaicML's MPT-7B. In particular, we focus on short-form instruction following.

Instruction tuning background

Interest has grown in recent years in models that can follow natural language instructions to perform a wide range of tasks. These "instruction-tuned" language models have been shown to generalize to tasks they were not explicitly finetuned on.

The figure below is taken from Andrej Karpathy's "State of GPT" talk. Its key points on SFT:

  • Collect a small but high-quality dataset of "prompt" and "ideal response" pairs.
  • Do language modeling on this data; nothing changes algorithmically from pretraining.
  • After training, we get an SFT model that can be deployed as an assistant (and it works to some extent).

[Figure: training_pipeline]

For more background, see any number of excellent papers on the subject, including Self-Instruct (2023), Orca (2023), and InstructGPT (2022).
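
To make the "nothing changes algorithmically from pretraining" point concrete, here is a minimal sketch of how a prompt and ideal response become an ordinary next-token prediction example. It is not the repo's training code: the whitespace "tokenizer", the `<eos>` marker, and the function names are all illustrative assumptions.

```python
import math


def build_example(prompt: str, response: str, eos: str = "<eos>") -> list:
    """Join prompt and ideal response into one token sequence.

    A whitespace split stands in for a real tokenizer (assumption).
    """
    return (prompt + " " + response).split() + [eos]


def next_token_pairs(tokens: list) -> tuple:
    """Causal LM shift: the target at position t is the token at t + 1."""
    return tokens[:-1], tokens[1:]


def mean_nll(target_probs: list) -> float:
    """Average negative log-likelihood of the target tokens.

    `target_probs` holds the probability a model assigned to each target.
    """
    return -sum(math.log(p) for p in target_probs) / len(target_probs)


tokens = build_example("Name a color.", "Blue.")
inputs, targets = next_token_pairs(tokens)

# A model that assigns probability 0.5 to every target token has a loss of ln(2).
loss = mean_nll([0.5] * len(targets))
```

The loss here is the same cross-entropy objective used in pretraining; only the data (curated prompt and response pairs) differs.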

Code assets

  • See the ./sft folder for finetuning scripts and postprocessing notebooks.
  • See the ./runs folder for the raw results from each sft experiment.
  • See the ./inf_tests folder for runtime tests on different models.

NB: The scripts here perform full-parameter SFT, i.e., they update all weights in the network. Other options exist, such as parameter-efficient finetuning (see HuggingFace's peft).
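
For a rough sense of the difference, the sketch below compares trainable parameter counts for a single weight matrix under full finetuning and under LoRA, one of the methods peft supports. This repo does not use LoRA; the matrix size and rank are illustrative assumptions, not values taken from its scripts.

```python
def full_finetune_params(d_in: int, d_out: int) -> int:
    """Full-parameter finetuning updates every entry of the d_in x d_out matrix."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """LoRA trains two low-rank factors: (d_in x rank) and (rank x d_out)."""
    return rank * (d_in + d_out)


d_model = 4096  # illustrative hidden size (assumption)
rank = 8        # illustrative LoRA rank (assumption)

full = full_finetune_params(d_model, d_model)
lora = lora_params(d_model, d_model, rank)
fraction = lora / full

print(f"full: {full:,} params, LoRA: {lora:,} params ({fraction:.2%} of full)")
```

Fewer trainable parameters mainly reduces optimizer-state and gradient memory, which is why parameter-efficient methods are often chosen when a full-parameter run does not fit on the available GPU.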

Base models and datasets

In this repo, we finetune the following base models. Each is open-source and licensed for commercial use.

We test the following instruction datasets. Each is open-source and licensed for commercial use.

Results

1. mpt-7b-instruct-orca

This short-form instruction following model was built by finetuning MPT-7B on the first 100k rows of the ehartford/dolphin dataset (an open-source implementation of Microsoft's Orca). It was trained on a single H100 (80 GB PCIe) for about 12 hours using the Lambda Labs Platform.

For more details, see the model card.
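
As a rough illustration of the data preparation described above, the sketch below takes the first N rows of a dataset and formats each one in the style of the example prompts shown later in this README (system prompt, user request, trailing newline). The field names `question` and `answer` and the toy rows are hypothetical; the actual column names and preprocessing live in the ./sft folder and may differ.

```python
from itertools import islice
from typing import Iterable, Iterator

SYSTEM_PROMPT = "You are a helpful assistant."


def first_rows(rows: Iterable[dict], n: int) -> Iterator[dict]:
    """Lazily yield only the first n rows, without loading the rest."""
    return islice(rows, n)


def to_prompt_response(row: dict) -> dict:
    """Format one row as a prompt/response pair.

    The keys "question" and "answer" are hypothetical field names.
    """
    prompt = f"{SYSTEM_PROMPT} {row['question']}\n"
    return {"prompt": prompt, "response": row["answer"]}


# Toy stand-in for the real dataset (assumption).
toy_rows = [{"question": f"What is {i} + {i}?", "answer": str(2 * i)} for i in range(5)]

subset = [to_prompt_response(row) for row in first_rows(toy_rows, 3)]
```

Using `islice` keeps the selection lazy, so the same helper works on a streamed dataset where only the first 100k rows are needed.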

[Figure: loss_curves]

Plot derived from the tfevents log at runs/jul_5_23_3_15_00_sft-instruction-mpt-7b-orca.

Example prompts and responses

Example 1:

User:

You are a helpful assistant. Write me a numbered list of things to do in New York City.\n

mpt-7b-instruct-orca:

  1. Visit the Statue of Liberty: The Statue of Liberty is a symbol of freedom and democracy, and it is a must-see attraction in New York City.
  2. Take a walk in Central Park: Central Park is a massive green space in the heart of Manhattan, and it is a great place to relax, enjoy nature, and people-watch.
  3. Visit the Metropolitan Museum of Art: The Met is one of the world's largest and most renowned art museums, housing an extensive collection of art from around the world.
  4. Take a stroll through Times Square: Times Square is the heart of the city, and it is a must-see for its iconic billboards, bright lights, and bustling energy.
  5. Visit the Empire State Building: The Empire State Building is a famous skyscraper in Midtown Manhattan, offering panoramic views of the city from its observation deck.
  6. Explore the Museum of Modern Art (MoMA): MoMA is a renowned art museum in Midtown Manhattan, showcasing modern and contemporary art from around the world.
  7. Visit the 9/11 Memorial and Museum: The 9/11 Memorial and Museum is a somber but important attraction in Lower Manhattan, commemorating the victims of the September 11th attacks.
  8. Take a ferry ride to the Statue of Liberty: A ferry ride to the Statue of Liberty is a great way to see the city from the water

Example 2:

User:

You are a helpful assistant. Daniel is in need of a haircut. His barber works Mondays, Wednesdays, and Fridays. So, Daniel went in for a haircut on Sunday. Does this make logical sense? Let's work this out in a step by step fashion to make sure we have the right answer.\n

mpt-7b-instruct-orca:

Step 1: Identify the days of the week when the barber works.

The barber works on Mondays, Wednesdays, and Fridays.

Step 2: Determine the day of the week on which Daniel went for a haircut.

Daniel went for a haircut on Sunday.

Step 3: Compare the day of the week on which Daniel went for a haircut to the days when the barber works.

Sunday is not one of the days when the barber works.

Step 4: Determine if the statement makes logical sense.

The statement does not make logical sense because Daniel went for a haircut on a day when the barber does not work.

Answer: No, the statement does not make logical sense. Daniel should have gone for a haircut on a day when the barber works, such as Monday, Wednesday, or Friday.


Example 3:

User:

You are a helpful assistant. Write a short email inviting my friends to a dinner party on Friday. Respond succinctly.\n

mpt-7b-instruct-orca:

Subject: Invitation to Dinner Party on Friday

Dear Friend,

I hope this email finds you well. I am writing to invite you to a dinner party that I am hosting on Friday. The party will be a casual gathering, and I would love to have you join us.

The dinner will be held at my home, and the time is from 7:00 pm to 10:00 pm. The address is 123 Main Street, San Francisco, CA 94122.

The dinner will be a potluck, so please bring a dish to share. I will provide the main course, so you can focus on bringing a side dish, dessert, or appetizer.

Please RSVP by replying to this email or by calling me at 555-555-5555. I look forward to seeing you on Friday!

Yours,

Your Friendly Assistant


Runtime tests

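Runtime per 50 tokens (sec) | GPU type              | Attn implementation | torch dtype | VRAM consumption (GB)
--------------------------- | --------------------- | ------------------- | ----------- | ---------------------
0.61                        | 1x H100 (80 GB PCIe)  | triton              | bfloat16    | 12
0.67                        | 1x H100 (80 GB PCIe)  | torch               | bfloat16    | 12
1.17                        | 1x A100 (40 GB SXM)   | triton              | bfloat16    | 13
1.36                        | 1x A100 (40 GB SXM)   | torch               | bfloat16    | 13
4.84                        | 1x Tesla T4 (15 GB)   | torch               | bfloat16    | 13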

The runtime statistics above (leftmost column) were generated with the following code for each test, as per the corresponding notebook.

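import time

import torch
import tqdm

# `model`, `tokenizer`, and `mpt_generate` are defined earlier in the notebook.
prompt = "You are a helpful assistant. Write me a long list of things to do in San Francisco:\n"

runtimes = []
for _ in tqdm.tqdm(range(100)):
    # Time one generation of 50 new tokens (wall-clock).
    start = time.time()
    response = mpt_generate(
        model,
        tokenizer,
        prompt,
        max_new_tokens=50,
        temperature=0.92,
    )
    end = time.time()
    runtimes.append(end - start)
    # Sanity check: the response should tokenize to exactly 50 tokens.
    assert len(tokenizer.encode(response)) == 50

# Average over all 100 runs.
avg_runtime = torch.mean(torch.tensor(runtimes)).item()
print(f"Runtime avg in seconds: {avg_runtime}")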
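
The measurement above can also be expressed as a small reusable harness. The sketch below is a stdlib-only variant, not the notebook's code: `benchmark`, `tokens_per_second`, and `fake_generate` are hypothetical names, and `fake_generate` is a CPU stand-in for a real `mpt_generate(...)` call. It adds warm-up runs and converts the table's "runtime per 50 tokens" figures into tokens per second.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], n_runs: int = 10, n_warmup: int = 2) -> dict:
    """Time fn() n_runs times after n_warmup untimed calls."""
    for _ in range(n_warmup):
        fn()
    runtimes = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        runtimes.append(time.perf_counter() - start)
    return {
        "n_runs": n_runs,
        "mean": statistics.mean(runtimes),
        "median": statistics.median(runtimes),
    }


def tokens_per_second(seconds: float, n_tokens: int = 50) -> float:
    """Convert a runtime for n_tokens generated tokens into throughput."""
    return n_tokens / seconds


def fake_generate() -> int:
    """Hypothetical stand-in for a model generation call."""
    return sum(i * i for i in range(10_000))


stats = benchmark(fake_generate, n_runs=5)

# Runtimes per 50 tokens, copied from the table above.
table = {
    ("1x H100 (80 GB PCIe)", "triton"): 0.61,
    ("1x H100 (80 GB PCIe)", "torch"): 0.67,
    ("1x A100 (40 GB SXM)", "triton"): 1.17,
    ("1x A100 (40 GB SXM)", "torch"): 1.36,
    ("1x Tesla T4 (15 GB)", "torch"): 4.84,
}
throughput = {key: tokens_per_second(sec) for key, sec in table.items()}
fastest = max(throughput, key=throughput.get)
```

When timing real GPU work, it is common to call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously; whether that matters here depends on how the generation helper returns its result.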