This library is used to train and evaluate Sparse Autoencoders (SAEs). It handles the following training types:
- e2e (end-to-end): Loss function includes sparsity and final model kl_divergence.
- e2e + downstream reconstruction: Loss function includes sparsity, final model kl_divergence, and MSE at downstream layers.
- local (i.e. vanilla SAEs): Loss function includes sparsity and MSE at the SAE layer.
- Any combination of the above.
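As a rough sketch of how these objectives relate (plain Python, not the library's actual training code; the function and argument names here are illustrative only):

```python
# Illustrative sketch only -- not the library's API. Shows how the loss
# terms listed above combine for each training type.
def sae_loss(sparsity: float, kl: float, mse_sae: float, mse_downstream: float,
             train_type: str, sparsity_coeff: float = 1.0) -> float:
    loss = sparsity_coeff * sparsity       # every variant penalizes sparsity
    if train_type == "local":
        loss += mse_sae                    # reconstruct activations at the SAE layer
    elif train_type == "e2e":
        loss += kl                         # match the original model's output distribution
    elif train_type == "e2e+downstream":
        loss += kl + mse_downstream        # also reconstruct downstream-layer activations
    else:
        raise ValueError(f"unknown train_type: {train_type}")
    return loss
```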
See our paper, which argues for training SAEs e2e rather than locally. All SAEs presented in the paper can be found at https://wandb.ai/sparsify/gpt2 and can be loaded using this library.
```bash
pip install e2e_sae
```
Train SAEs on any TransformerLens model
If you would like to track your run with Weights and Biases, place your API key and entity name in a new file called `.env`. An example is provided in `.env.example`.
Create a config file (see gpt2 configs here for examples). Then run
```bash
python e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py <path_to_config>
```
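For orientation, a training config is a YAML file. The field names below are hypothetical placeholders, not the library's real schema; start from one of the gpt2 example configs linked above instead.

```yaml
# Hypothetical layout only; these field names are NOT the library's
# real schema. Copy a linked gpt2 example config for the real keys.
model_name: gpt2
sae_position: blocks.6.hook_resid_pre  # which activation the SAE attaches to
dict_size_ratio: 60                    # dictionary size relative to d_model
sparsity_coeff: 1.5
train_type: e2e                        # local, e2e, or e2e + downstream
```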
If using a Colab notebook, see this example.
Sample wandb sweep configs are provided in `e2e_sae/scripts/train_tlens_saes/`.
The library also contains scripts for training mlps and SAEs on mlps, as well as training custom transformerlens models and SAEs on these models (see here).
You can load any pre-trained SAE (and accompanying TransformerLens model) trained using this library from Weights and Biases or locally by running
```python
from e2e_sae import SAETransformer

model = SAETransformer.from_wandb("<entity/project/run_id>")
# or, if stored locally
model = SAETransformer.from_local_path("/path/to/checkpoint/dir")
```
All runs in our paper can be loaded this way (e.g. `sparsify/gpt2/tvj2owza`).
This will instantiate an `SAETransformer` class, which contains a TransformerLens model with SAEs attached. To do a forward pass without SAEs, use the `forward_raw` method; to do a forward pass with SAEs, use the `forward` method (or simply call the `SAETransformer` instance).
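The split between the two forward passes can be pictured with a toy stand-in (pure Python, not the library's implementation; every name below is invented for illustration):

```python
# Toy stand-in for the pattern described above: `forward_raw` runs the
# base model untouched, while `forward` routes the hidden activation
# through a (deliberately lossy) reconstruction first.
class ToyBaseModel:
    def hidden(self, x: int) -> int:
        return 2 * x + 1            # stand-in hidden activation

    def head(self, h: int) -> int:
        return h + 1                # stand-in output layer

class ToySAEWrapper:
    def __init__(self, model: ToyBaseModel):
        self.model = model

    def forward_raw(self, x: int) -> int:
        # Forward pass without the SAE: base model only.
        return self.model.head(self.model.hidden(x))

    def forward(self, x: int) -> int:
        # Forward pass with the SAE spliced in at the hidden layer.
        h = self.model.hidden(x)
        h_recon = (h // 2) * 2      # stand-in lossy reconstruction
        return self.model.head(h_recon)

    __call__ = forward              # calling the instance runs `forward`
```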
The dictionary elements of an SAE can be accessed via `SAE.dict_elements`. This will normalize the decoder elements to have norm 1.
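A minimal sketch of that unit-norm normalization (dependency-free Python; `dict_elements` itself may be implemented differently):

```python
import math

# Sketch of the normalization described above: scale each decoder
# column (dictionary element) to unit L2 norm.
def normalize_dict_elements(decoder_cols):
    """decoder_cols: list of columns (lists of floats) -> unit-norm columns."""
    normalized = []
    for col in decoder_cols:
        norm = math.sqrt(sum(v * v for v in col))
        normalized.append([v / norm for v in col])
    return normalized
```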
To reproduce all of the analysis in our paper, use the scripts in `e2e_sae/scripts/analysis/`.
Developer dependencies are installed with `make install-dev`, which will also install pre-commit hooks.
Suggested extensions and settings for VSCode are provided in `.vscode/`. To use the suggested settings, copy `.vscode/settings-example.json` to `.vscode/settings.json`.
There are various `make` commands that may be helpful:

```bash
make check     # Run pre-commit checks on all files (i.e. pyright, ruff linter, and ruff formatter)
make type      # Run pyright on all files
make format    # Run ruff linter and formatter on all files
make test      # Run tests that aren't marked `slow`
make test-all  # Run all tests
```
This library is maintained by Dan Braun.
Join the Open Source Mechanistic Interpretability Slack to chat about this library and other projects in the space!