
GPT2# is a zero-dependency, sub-1,000-LOC implementation of GPT-2 inference, batteries included.

Primary language: C#, license: MIT

GPT2#

GPT2# is a zero-dependency GPT-2 inference implementation written in under 1,000 lines of C# 11, capable of loading the original weights released by OpenAI.

⚙️ Build

To get started with GPT2#, make sure dotnet --version reports 7.0 or newer. If it doesn't, install the .NET 7 (or newer) SDK first.
Then run the following commands:

  • git clone https://github.com/lofcz/gpt2sharp && cd gpt2sharp
  • dotnet run --project SharpGpt

🔮 About GPT2#

GPT2# is a minimal model inference implementation intended for ML hackers. Have you ever wondered how LLMs work behind the scenes?
This repository can help you understand the process. The code is written to be clean and easy to read, and the entire forward pass is only about 100 lines long.
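To give a taste of what that forward pass involves, here is a minimal sketch of single-head causal self-attention, the step in which each token looks back at the tokens before it. It is not taken from GPT2#: the function name CausalSelfAttention and the jagged [position][dimension] layout are assumptions made for illustration. A real GPT-2 layer also splits this work across several heads and applies learned projections before and after it.

```csharp
using System;
using System.Linq;

// Hypothetical sketch, not GPT2#'s code: single-head causal self-attention.
// q, k, v are [seqLen][d] arrays (one row per token, already projected).
// Position t may only attend to positions 0..t (the causal mask).
static float[][] CausalSelfAttention(float[][] q, float[][] k, float[][] v)
{
    int seqLen = q.Length, d = q[0].Length;
    float scale = 1f / MathF.Sqrt(d);
    var output = new float[seqLen][];
    for (int t = 0; t < seqLen; t++)
    {
        // Scaled dot-product scores against every visible (past or current) position.
        var scores = new float[t + 1];
        for (int s = 0; s <= t; s++)
        {
            float dot = 0f;
            for (int j = 0; j < d; j++) dot += q[t][j] * k[s][j];
            scores[s] = dot * scale;
        }

        // Numerically stable softmax turns the scores into attention weights.
        float max = scores.Max();
        float[] w = scores.Select(x => MathF.Exp(x - max)).ToArray();
        float sum = w.Sum();

        // The output for position t is the weighted sum of the visible value vectors.
        output[t] = new float[d];
        for (int s = 0; s <= t; s++)
            for (int j = 0; j < d; j++)
                output[t][j] += w[s] / sum * v[s][j];
    }
    return output;
}
```

Because of the mask, the first position can only attend to itself, so its output is exactly its own value vector; later positions blend all earlier ones.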

✖️ The Math Behind GPT2#

GPT2# defines only a handful of functions, all of which can be found in MathExt.cs. For more information on specific functions, follow the links attached to them.
These functions are deliberately implemented by hand, without third-party libraries, because understanding them is an important step toward understanding the inference process.
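This README does not list the contents of MathExt.cs, so the following is a hypothetical sketch of four operations any GPT-2 forward pass needs: a numerically stable softmax, the tanh approximation of GELU used by GPT-2, layer normalization, and a naive matrix multiplication parallelized with Parallel.For. The names and signatures are illustrative assumptions, not the repository's actual API.

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;

// Hypothetical helpers in the spirit of MathExt.cs; names and signatures are assumptions.

// Numerically stable softmax: subtracting the max before Exp avoids overflow
// without changing the result, since softmax(x) == softmax(x - c) for any constant c.
static float[] Softmax(float[] x)
{
    float max = x.Max();
    float[] exps = x.Select(v => MathF.Exp(v - max)).ToArray();
    float sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

// GELU, tanh approximation as used by GPT-2:
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
static float Gelu(float x) =>
    0.5f * x * (1f + MathF.Tanh(MathF.Sqrt(2f / MathF.PI) * (x + 0.044715f * x * x * x)));

// Layer normalization of one vector, then the learned per-element
// scale (gamma) and shift (beta). GPT-2 uses eps = 1e-5.
static float[] LayerNorm(float[] x, float[] gamma, float[] beta, float eps = 1e-5f)
{
    float mean = x.Average();
    float variance = x.Select(v => (v - mean) * (v - mean)).Average();
    float invStd = 1f / MathF.Sqrt(variance + eps);
    return x.Select((v, i) => (v - mean) * invStd * gamma[i] + beta[i]).ToArray();
}

// Naive matrix multiplication C = A (n x k) * B (k x m);
// each row of C is computed on its own Parallel.For iteration.
static float[,] MatMul(float[,] a, float[,] b)
{
    int n = a.GetLength(0), k = a.GetLength(1), m = b.GetLength(1);
    if (b.GetLength(0) != k) throw new ArgumentException("Inner dimensions must match.");
    var c = new float[n, m];
    Parallel.For(0, n, i =>
    {
        for (int j = 0; j < m; j++)
        {
            float acc = 0f;
            for (int p = 0; p < k; p++) acc += a[i, p] * b[p, j];
            c[i, j] = acc;
        }
    });
    return c;
}
```

Parallelizing over output rows means each iteration writes only to its own row of the result, so no locking is needed.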

📸 Inference example

[Screenshot: example GPT2# run]
Text in cyan is generated by the model.

ℹ️ Limitations

  • Currently, only greedy decoding is implemented (no temperature or top-k sampling). To add these, you can replace the implementation of the GetBest function in Gpt.cs. For guidance, consult the "Sampling" section here.
  • The math in GPT2# is not accelerated with a GPU/TPU or SIMD vectorization. Apart from a Parallel.For() in matrix multiplication, it is written to be simple and easy to understand rather than fast. Replacing the naive matrix multiplication with a faster algorithm, such as Strassen's, could provide a noticeable speedup for large matrices.
  • The model used is small and dated, so don't expect miracles. The first few generated tokens are generally OK, but because there is no repetition penalty and decoding is greedy, the model quickly gets stuck in a token loop.
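As a possible starting point for replacing greedy decoding, here is a sketch of a sampler that combines temperature, top-k filtering, and a repetition penalty, touching on both the first and the last limitation listed above. The SampleNext name and its parameters are hypothetical; the real signature of GetBest in Gpt.cs may differ. The repetition penalty follows one common formulation (divide positive logits by the penalty, multiply negative ones), which is an assumption rather than anything GPT2# prescribes.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical replacement for greedy decoding: picks the next token id from raw logits.
// temperature < 1 sharpens the distribution, > 1 flattens it; topK keeps only the K
// highest-scoring tokens; repetitionPenalty > 1 discourages tokens already generated.
static int SampleNext(float[] logits, IReadOnlyCollection<int> previousTokens,
                      float temperature, int topK, float repetitionPenalty, Random rng)
{
    if (temperature <= 0f) throw new ArgumentOutOfRangeException(nameof(temperature));
    var scores = (float[])logits.Clone();

    // Repetition penalty: push already-seen tokens toward lower scores
    // (divide positive logits, multiply negative ones).
    foreach (int token in previousTokens.Distinct())
        scores[token] = scores[token] > 0 ? scores[token] / repetitionPenalty
                                          : scores[token] * repetitionPenalty;

    // Top-k: keep the indices of the K largest scores.
    int[] candidates = Enumerable.Range(0, scores.Length)
        .OrderByDescending(i => scores[i])
        .Take(Math.Max(1, topK))
        .ToArray();

    // Temperature-scaled, numerically stable softmax over the candidates only.
    float max = candidates.Max(i => scores[i]);
    double[] weights = candidates
        .Select(i => Math.Exp((scores[i] - max) / temperature))
        .ToArray();
    double total = weights.Sum();

    // Draw one candidate with probability proportional to its weight.
    double r = rng.NextDouble() * total;
    for (int c = 0; c < candidates.Length; c++)
    {
        r -= weights[c];
        if (r <= 0) return candidates[c];
    }
    return candidates[^1]; // guard against floating-point round-off
}
```

With topK = 1 the sampler always returns the highest-scoring token after the penalty is applied, so it reduces to greedy decoding with a repetition penalty.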

Literature & Related projects