XJDR-ALT has some really cool work on entropy-based sampling, but I couldn't get it running on Apple silicon with JAX or Torch MPS compatibility. So I'm forking it to work with MLX.
Entropy Based Sampling and Parallel CoT Decoding
The goal is to use entropy to make sampling context-aware. This should allow us to simulate something similar to o1's CoT or Anthropic's approach to get much better results using inference-time compute. This is a research project and a work in progress. It comprises an inference stack, the sampler, and a UI (future). Please reach out to me on X if you have any questions or concerns: @_xjdr (original idea and implementation), @samefarrar (MLX implementation).
Generally, when LLMs pick tokens to output, they do so with a set of fixed parameters. You might vary the temperature or top_k, or add logit or repetition penalties, but these are fixed for that generation. This means that for every token in the response to a question, the way the model samples from the logits is the same.
This doesn't necessarily make sense: some tokens are very straightforward, whereas others might benefit from different sampling to scale inference-time compute. As a concrete example, when you ask a model to compare 9.9 and 9.11, the token "." is very "certain". Everywhere in the response to the question, " 9" will likely be followed by ".". Here, scaling inference-time compute is wasted because the most likely token is almost certainly the right one. This is a perfect example of a token where taking the argmax of the logits makes sense.
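To make the contrast concrete, here is a minimal sketch of a fixed-parameter sampler in MLX. This is illustrative only: `sample_fixed`, its default temperature and top_k values, and the 1-D logits assumption are mine, not the sampler used in this repo.

```python
import mlx.core as mx

def sample_fixed(logits: mx.array, temperature: float = 0.7, top_k: int = 40) -> mx.array:
    """Sample a token using the same temperature/top_k at every position.

    Assumes logits has shape (vocab_size,).
    """
    logits = logits / temperature
    # Keep only the top_k logits; everything else gets ~0 probability.
    kth_largest = mx.sort(logits, axis=-1)[..., -top_k]
    logits = mx.where(logits < kth_largest, mx.array(-float("inf")), logits)
    # mx.random.categorical treats its input as unnormalized log-probabilities.
    return mx.random.categorical(logits)
```

However the sampler is tuned, these parameters are applied identically at every decoding step.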
However, there are tokens that are less "clear", and we think that we can detect this through statistics of the distribution of the logits and the attention scores. For example:
We can see in this example that "compare" is acting as a kind of "uncertainty sink": it is sampled at a point where the logit varentropy is quite high. To scale inference-time compute, in the quadrants above, this would be a token well suited to branching. For now, we sample at that token with a high temperature to discourage the model from answering quickly, wrongly, and confidently, and instead to mimic chain-of-thought thinking, making it more likely to arrive at the right answer.
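As a rough sketch of the idea (not the actual entropix quadrant logic), you can compute the entropy and varentropy of the logits at each token and switch sampling strategy based on them. The thresholds, temperatures, and the 1-D logits assumption below are placeholders:

```python
import mlx.core as mx

def entropy_varentropy(logits: mx.array) -> tuple[float, float]:
    """Entropy and varentropy (variance of -log p) of the next-token distribution."""
    log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    probs = mx.exp(log_probs)
    entropy = -mx.sum(probs * log_probs, axis=-1)
    varentropy = mx.sum(probs * (-log_probs - entropy) ** 2, axis=-1)
    return entropy.item(), varentropy.item()

def sample_adaptive(logits: mx.array, low: float = 0.1, high: float = 3.0) -> mx.array:
    """Pick a per-token strategy from the shape of the logit distribution."""
    ent, vent = entropy_varentropy(logits)
    if ent < low and vent < low:
        # Confident token (e.g. the "." after " 9"): just take the argmax.
        return mx.argmax(logits, axis=-1)
    if vent > high:
        # High-varentropy "uncertainty sink" (e.g. "compare"): sample hot to
        # encourage exploration instead of a quick, confident answer.
        return mx.random.categorical(logits / 1.5)
    # Otherwise, ordinary temperature sampling.
    return mx.random.categorical(logits / 0.7)
```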
Currently supported models: llama3.1+
- Clean up UI (make it look nicer)
- Introduce frog branch sampling parameters
- Allow comparison of metrics from multiple timesteps - we see that attention entropy gradually increases as the model comes to a "decision phrase", e.g. "9.9 is " (see the sketch below).
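As a hypothetical illustration of that metric (the function name and shapes below are assumptions, not part of the current codebase), the attention entropy at each generation step could be computed like this:

```python
import mlx.core as mx

def attention_entropy(attn_weights: mx.array) -> mx.array:
    """Shannon entropy of the attention distribution for each head.

    attn_weights: (num_heads, seq_len) attention probabilities for the
    current query position; returns one entropy value per head. Tracked
    across timesteps, this should rise as the model approaches a
    "decision phrase" such as "9.9 is ".
    """
    eps = 1e-10  # avoid log(0) for fully masked positions
    return -mx.sum(attn_weights * mx.log(attn_weights + eps), axis=-1)
```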
- Install bun if you want to use the local server.
- `uv sync`
- Download the weights (Instruct) - you need to have set up your Hugging Face CLI for this: `uv run mlx_download_weights.py`
- Run `uv run mlx_main.py` with one of the following options:
  - `--prompts`: Use predefined prompts from `mlx_entropix.prompts`
  - `--prompt_csv`: Use prompts from `data/prompts.csv`
  - `--input TEXT`: Provide a custom input prompt
  - `--normal`: Use default MLX Llama for generation
To run the UI:

`cd ui`
`bun run dev`

This will call `uv run mlx_server.py` in the background, as well as the web server.
- `--normal`: Use the normal model for generation (as opposed to the Entropix model)
- Model Loading:
  - Loads either a standard language model or an Entropix model based on the specified options.
  - Uses the Llama-3.2-1B-Instruct model by default.
- Text Generation:
  - Generates text using either the mlx_lm `generate_mlx_lm` function or the Entropix `generate` function.
  - Supports a maximum token limit of 4096.
- Command line or Server:
  - Use the model with the command line or the server.
- Use predefined prompts: `uv run mlx_main.py --prompts`
- Use a custom input: `uv run mlx_main.py --input "What is the capital of France?"`
- Use normal sampling instead of Entropix: `uv run mlx_main.py --normal --input "Explain quantum computing"`
- Ensure all required dependencies are installed and the model weights are downloaded before running the script.
- The Entropix model is used by default unless the `--normal` flag is specified.