MLAB Transformers From Scratch

A documented and unit-tested repo to help you learn how to build transformer neural network models from scratch.

[Figure: The Transformer model architecture from Vaswani et al., 2017]

Introduction

What is this?

  • It seems useful to gain a deeper understanding of transformers by building some key models "from scratch" in a library like PyTorch, especially if you want to do AI safety research on them.
  • Redwood Research runs a Machine Learning for Alignment Bootcamp (MLAB) in which one week consists of building BERT and GPT-2 from scratch, fine-tuning them, and exploring some interpretability and training techniques.
  • This repository takes the code from the original MLAB repository and cleans it up to make it easier for others to do this independently. The key differences between this repo and the original repo are:
    • Removed almost all the content besides the days about building BERT and GPT-2.
    • Created a starter student file that has all the class and function stubs for the parts you'd need to build a transformer but without the implementation.
    • Migrated the original tests into a proper unittest test suite and added several more unit tests for various functionality.
    • Added docstrings to document what each part should do and give hints for the trivial-but-annoying parts.
    • Implemented a new solution file, with all the new documentation, to verify that all the tests pass.
    • Various renaming and reorganizing of files to make the repo a bit cleaner.

Status

  • BERT: Fully operational, tested, documented, and ready for you to build.
  • GPT-2: In progress, but I'm currently not prioritizing development on this.

Getting Started

Prerequisites

Installation

  1. (Optional) Create a virtual environment with venv or Anaconda and activate it.
  2. Install PyTorch with the appropriate configuration for your environment.
  3. Run pip install -r requirements.txt to install the other requirements for this repository.
  4. Run pip install -e . to install the mlab_tfs package in editable mode (a quick import check is sketched below).
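
After installation, a quick sanity check (a minimal sketch; it assumes only that the steps above succeeded) is to import both packages from a Python shell:

    import torch
    import mlab_tfs  # installed in editable mode by `pip install -e .`

    print("PyTorch version:", torch.__version__)

If both imports work, you're ready to run the tests.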

Testing

This repo uses the built-in unittest framework for evaluating your code. You have a few options for running the tests (a programmatic variant is also sketched after this list):

  1. python -m unittest mlab_tfs.tests.test_bert
  2. python ./mlab_tfs/tests/test_bert.py
  3. If using an IDE like Visual Studio Code with the Python extension, the unit tests should already be discovered and show up in the Testing pane.
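
If you'd rather drive unittest from Python yourself (for example, from a notebook), a minimal sketch using only the standard library's test loader looks like this; it loads the same mlab_tfs.tests.test_bert module as option 1:

    import unittest

    from mlab_tfs.tests import test_bert

    # Build a suite from the test_bert module (same tests as option 1) and run it.
    suite = unittest.defaultTestLoader.loadTestsFromModule(test_bert)
    unittest.TextTestRunner(verbosity=2).run(suite)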

Most of the tests follow the pattern "randomly initialize the student class and the reference class (from PyTorch or a solution file) with the same seed, pass the same input through both, and check that the outputs match" (sketched below). There are also tests that ask:

  • "Are the student class's attributes of the correct types?"
  • "Did the student cheat by calling the PyTorch function they're supposed to be implementing themselves?"

Finally, one test class at the end instantiates the whole student transformer model and the reference transformer model, loads the real BERT/GPT-2 weights from HuggingFace Transformers, passes a sequence of input tokens through both, and checks that the output logits match.
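
As a concrete illustration of the seeded-comparison pattern, here is a minimal, hypothetical test written in the same unittest style. The real tests live in mlab_tfs/tests/test_bert.py, and the student class below is a placeholder (the sketch compares PyTorch's LayerNorm to itself so that it runs as-is):

    import unittest

    import torch

    class TestLayerNormMatchesReference(unittest.TestCase):
        def test_same_output_as_torch(self):
            hidden_size = 16
            # Seed, then build the reference implementation from PyTorch.
            torch.manual_seed(0)
            reference = torch.nn.LayerNorm(hidden_size)
            # Re-seed, then build the student implementation the same way.
            # Placeholder: swap in your own class, e.g. bert_student.LayerNorm(hidden_size).
            torch.manual_seed(0)
            student = torch.nn.LayerNorm(hidden_size)
            # Pass the same input through both and check the outputs match.
            x = torch.randn(2, 5, hidden_size)
            self.assertTrue(torch.allclose(student(x), reference(x)))

    if __name__ == "__main__":
        unittest.main()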

Implementing Transformers

Now that you know how to test your code, go implement some transformers!

  • Go to the BERT folder.
  • Read the instructions in the README file there.
  • Reimplement the stubbed BERT classes and functions and pass the tests (a small illustrative sketch follows this list).
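
To give a flavor of what filling in a stub involves, here is a minimal sketch of two of the simpler BERT building blocks. The actual stub names, signatures, and docstrings are defined in the starter file in the BERT folder, so treat these as illustrative only:

    import torch

    def gelu(x: torch.Tensor) -> torch.Tensor:
        # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
        return 0.5 * x * (1.0 + torch.erf(x / 2.0 ** 0.5))

    class LayerNorm(torch.nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-12):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Normalize over the last dimension, then apply a learned scale and shift.
            mean = x.mean(dim=-1, keepdim=True)
            var = x.var(dim=-1, unbiased=False, keepdim=True)
            return self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias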

Note: Only the BERT folder is fully tested and documented, but you can also try writing your own gpt2_student.py and integrating it into the testing framework (please make a pull request to share this with others!).

Known issues

  • GPT-2 needs a starter file, better documentation (including a readme), and unit tests

Further Exploration

(Copied from "Week 2: Implementing transformers" of the original MLAB repo)

TODO for this repository

General

  • Delete more non-transformers or non-essential stuff
  • Redo requirements by creating a fresh venv
  • Remove old git stuff (prune?) so it's a smaller download
  • Update this main readme
  • Set up Pylint with a config
  • Remove commented requirements now that requirements are verified

BERT

  • Delete duplicated BERT solution file
  • Organize files to be much simpler (BERT and GPT2 folders)
  • Run testing code for BERT
  • Refactor testing code into unittest
  • Rename files to bert_reference, bert_student
  • Create starter file for you with empty stubs
  • Make testing code call starter code and compare to HF BERT and maybe MLAB solution
  • Update BERT readme to be more clear about what to do (e.g. no tokenizer)
    • Say it should be about 200 (or 150-300) lines of code
  • Include config or hyperparams or code to load weights
  • Change TODO into a changelist to describe differences from upstream (wrote some descriptions above)
  • Do BERT
  • Try removing __init__.py and other files if not used
  • Replace existing # Tensor[...] comments with TensorType type hints
  • Rewrite/fix the 2 commented out tests
  • Integrate tests from the other archived files
  • Investigate mocking to check that the student didn't use methods from torch.nn instead of implementing their own
  • Add type hints to bert_student.py

GPT

  • Write GPT-2 readme (can reuse or adapt content from the BERT folder's readme)
  • Clean up GPT-2 folder (might not need to do much)
  • Run testing code for GPT-2
  • Refactor testing code into unittest
  • Create starter file for you with empty stubs
  • Make testing code call starter code and compare to HF GPT-2 and maybe MLAB solution