
PyTorch model to RTL flow for low latency inference


BraggHLS

This is a framework for lowering PyTorch models to RTL using high-level synthesis (HLS) techniques. Crucially, we do not use any existing HLS tools (such as Xilinx's Vitis). In addition, there are no fixed architecture designs (with the exception of FloPoCo's floating point IPs): no systolic arrays, no matrix multipliers; i.e., models are actually compiled. The driving use case is low-latency Bragg peak detection for high-energy diffraction microscopy (HEDM).

The "flow" is

TL;DR

Turn this

BraggNN(
  (cnn_layers_1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
  (nlb): NLB(
    (theta_layer): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
    (phi_layer): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
    (g_layer): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
    (out_cnn): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
    (soft): Softmax(
      (exp): Exp()
    )
  )
  (cnn_layers_2): Sequential(
    (0): ReLU()
    (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1))
    (2): ReLU()
    (3): Conv2d(8, 2, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
  )
  (dense_layers): Sequential(
    (0): Linear(in_features=50, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=8, bias=True)
    (3): ReLU()
    (4): Linear(in_features=8, out_features=4, bias=True)
    (5): ReLU()
    (6): Linear(in_features=4, out_features=2, bias=True)
    (7): ReLU()
  )
)

into this

1200 intervals at ~100 MHz on Xilinx Alveo U280 with pipeline depth 2 (and so the throughput is actually ~4.7μs/sample)

(Red represents FMUL logic, green represents FADD logic)
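As a sanity check on the printed model, the layer shapes can be worked out by hand. The sketch below assumes an 11x11 single-channel input patch (an assumption; the printout does not state the input size) and zero padding (the printout shows no padding argument). Under those assumptions the flattened output of the last convolution matches the in_features=50 of the first Linear layer.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square 2D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 11                  # ASSUMPTION: 11x11 input patch, not stated in the printout
size = conv_out(size, 3)   # cnn_layers_1: 3x3 kernel -> 9
size = conv_out(size, 1)   # NLB uses only 1x1 kernels -> still 9
size = conv_out(size, 3)   # cnn_layers_2[1]: 3x3 kernel -> 7
size = conv_out(size, 3)   # cnn_layers_2[3]: 3x3 kernel -> 5

channels = 2               # out_channels of cnn_layers_2[3]
flattened = channels * size * size
print(flattened)           # 50, the in_features of dense_layers[0]
```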

Repo structure

This project has a lot of moving parts; the directory structure tells the tale:

  • bragghls/ - the core python library
    • compiler/compiler.py - python script that drives the entire flow
    • flopoco/ - functionality related to converting between FloPoCo's nonstandard floating point representation and IEEE754 (for purposes of RTL generation and simulation)
    • ip_cores/ - FloPoCo cores for 4,4 and 5,5 floating point addition and multiplication along with testbench generation
    • ir/ - functionality related to parsing, transforming, and interpreting MLIR representations of PyTorch models.
    • rtl/ - functionality related to emitting RTL (SystemVerilog)
    • testbench/ - testbench runners via cocotb and iverilog
  • bragghls_translate/ - MLIR parser/emitter translation library for translating MLIR to python
  • examples/ - obviously...
  • extensions/flopoco_convert_ext/ - pybind-ed extension for converting between IEEE754 and FloPoCo's floating point representation
  • scripts/ - helper scripts for things like generating new FloPoCo IPs and building the entire project
  • tests/ - obviously...


Docker

The easiest way to get started is to build a docker container using the Dockerfile:

docker build . -t bragghls && docker run -it bragghls

or just pull the latest image from dockerhub:

docker pull mlevental/bragghls:latest && docker run -it mlevental/bragghls

Building

The build steps are many and tortuous.

Requirements

  1. A compiler (GCC or Clang)
  2. Python (>= 3.10) (recommended to use conda)
  3. GNU MP Bignum Library
  4. GNU Multiple Precision Floating-Point Reliable Library
  5. Multiple Precision Floating-point Interval library
  6. Icarus Verilog (sudo apt-get install iverilog)
  7. Patience

On Linux you can bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)" for Clang and

sudo apt-get install libgmp3-dev libmpfr-dev libmpfi-dev iverilog

On Mac you can

brew install llvm gmp mpfr mpfi icarus-verilog

Build steps

  1. First make sure you have all the submodules checked out by running scripts/clone_externals.sh. This will take a while due to our dependency on LLVM.
  2. pip install -r requirements.txt to get PyTorch and TorchMLIR; if this fails due to unsupported platform for TorchMLIR it's probably because wheels are broken (cf. releases). Consult scripts/build.sh for how to build TorchMLIR from source, or feel free to reach out to me.
  3. pip install . -vvvv. This will:
    1. Build all of LLVM
    2. Build CIRCT against LLVM
    3. Build bragghls_translate and flopoco_converter
    4. Put all of the things in the correct places (circt-opt and bragghls_translate in venv/bin)

If you get some mysterious errors like

The C++ compiler "/usr/bin/c++" is not able to compile a simple test program

make sure you have either GCC or Clang in your path, and also set the C_COMPILER and CXX_COMPILER environment variables like this:

C_COMPILER=clang CXX_COMPILER=clang++ pip install . -vvvv
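Before starting the (long) build, it can help to confirm that the tools from the requirements list are discoverable. The following is a minimal sketch, not part of this repo: the function name preflight and the candidate executable names are assumptions, and it cannot detect the GMP/MPFR/MPFI libraries, only executables on the PATH.

```python
import shutil
import sys


def preflight() -> dict:
    """Report which build prerequisites are visible from this environment."""
    return {
        "python>=3.10": sys.version_info >= (3, 10),
        # ASSUMPTION: common executable names; adjust for your toolchain.
        "c compiler": any(shutil.which(c) for c in ("gcc", "clang", "cc")),
        "c++ compiler": any(shutil.which(c) for c in ("g++", "clang++", "c++")),
        "iverilog": shutil.which("iverilog") is not None,
    }


if __name__ == "__main__":
    for name, ok in preflight().items():
        print(f"{name:14s} {'ok' if ok else 'MISSING'}")
```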

Running

There is a small config file that sets necessary parameters (cf. bragghls_config.ini). The path to this config file needs to be set as an environment variable before invoking any of the bragghls scripts, as such:

BRAGGHLS_CONFIG_FP=$(pwd)/bragghls_config.ini bragghls_compiler 

Note the $(pwd) which is necessary if you're running the testbenches.

Assuming everything built successfully and you have all of the correct paths and environment variables, run any of the scripts in examples to generate MLIR. Then the main compiler driver can be run with the following arguments:

usage: BraggHLS compiler driver [-h] [-t] [-r] [-s] [-v] [-b] [-n N_TEST_VECTORS] [--threshold THRESHOLD] fp

positional arguments:
  fp                    Filepath of top-level MLIR file

options:
  -h, --help            show this help message and exit
  -t, --translate       Translate MLIR to python
  -r, --rewrite         Transform/rewrite python
  -s, --schedule        Schedule the model using CIRCT
  -v, --verilog         Emit verilog
  -b, --testbench       Run autogenerated testbench
  -n N_TEST_VECTORS, --n_test_vectors N_TEST_VECTORS
                        Number of test vectors for testbench
  --threshold THRESHOLD
                        Test for average number of testbench failures instead of absolute
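If you want to drive the compiler from Python rather than a shell, the invocation can be assembled as below. This is a sketch under stated assumptions: build_invocation and STAGE_FLAGS are hypothetical names introduced here, and only the flags and the BRAGGHLS_CONFIG_FP variable come from the text above. Both paths are made absolute, mirroring the $(pwd) advice.

```python
import os
from pathlib import Path

# Stage flags copied from the usage text, in flow order.
STAGE_FLAGS = {
    "translate": "-t",
    "rewrite": "-r",
    "schedule": "-s",
    "verilog": "-v",
    "testbench": "-b",
}


def build_invocation(mlir_fp, config_fp, stages, n_test_vectors=None):
    """Return (argv, env) suitable for subprocess.run(argv, env=env)."""
    argv = ["bragghls_compiler", str(Path(mlir_fp).resolve())]
    argv += [STAGE_FLAGS[s] for s in stages]
    if n_test_vectors is not None:
        argv += ["-n", str(n_test_vectors)]
    # Copy the current environment and add an absolute config path.
    env = dict(os.environ, BRAGGHLS_CONFIG_FP=str(Path(config_fp).resolve()))
    return argv, env


argv, env = build_invocation(
    "examples/linear_bragghls_artifacts/linear.mlir",
    "bragghls_config.ini",
    ["translate", "rewrite", "schedule", "verilog", "testbench"],
)
# import subprocess; subprocess.run(argv, env=env, check=True)  # once bragghls is installed
```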

For example,

BRAGGHLS_CONFIG_FP=$(pwd)/bragghls_config.ini python $(pwd)/examples/simple_nns.py linear --size 11

which runs examples/simple_nns.py and produces an artifacts folder at examples/linear_bragghls_artifacts, which contains a linear.mlir file that looks like

module attributes {torch.debug_module_name = "Linear"} {
  memref.global "private" constant @__constant_8x8xf32 : memref<8x8xf32> = dense<[...]>
  memref.global "private" constant @__constant_8xf32 : memref<8xf32> = dense<[...]>
  func.func @forward(%arg0: memref<1x8xf32>) -> memref<f32> {
    
    ...
    
    scf.for %arg1 = %c0 to %c1 step %c1 {
      scf.for %arg2 = %c0 to %c8 step %c1 {
        %7 = memref.load %4[%arg1, %arg2] : memref<1x8xf32>
        %8 = memref.load %6[] : memref<f32>
        %9 = arith.addf %7, %8 : f32
        memref.store %9, %6[] : memref<f32>
      }
    }
    return %6 : memref<f32>
  }
}
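For readers unfamiliar with MLIR, the visible loop nest is a plain accumulation. The sketch below is a hand-written Python rendering of only that visible tail, assuming that %c0, %c1 and %c8 are the index constants 0, 1 and 8 and that the accumulator %6 was zeroed in the elided part; forward_tail is a hypothetical name, and Python floats are double precision rather than f32.

```python
def forward_tail(x):
    """x plays the role of %4 (memref<1x8xf32>); the result plays %6 (memref<f32>)."""
    acc = 0.0  # ASSUMPTION: %6 is zero-initialized in the elided portion
    for i in range(0, 1, 1):       # scf.for %arg1 = %c0 to %c1 step %c1
        for j in range(0, 8, 1):   # scf.for %arg2 = %c0 to %c8 step %c1
            # memref.load %4[i, j]; memref.load %6[]; arith.addf; memref.store
            acc = x[i][j] + acc
    return acc


print(forward_tail([[1.0] * 8]))  # 8.0
```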

Then running

BRAGGHLS_CONFIG_FP=$(pwd)/bragghls_config.ini bragghls_compiler $(pwd)/examples/linear_bragghls_artifacts/linear.mlir -t -r -s -v -b

will generate linear.sv and run the automatically generated (no artifact) testbench, and produce the following output:

INFO: Running command: iverilog "-o "examples/linear_bragghls_artifacts/sim.vvp "-D "COCOTB_SIM=1 "-g2012 ...
     0.00ns INFO     Running on Icarus Verilog version 11.0 (stable)
     0.00ns INFO     Running tests with cocotb v1.6.2 from ...
     0.00ns INFO     Seeding Python random module with 1659448436
     0.00ns WARNING  Pytest not found, assertion rewriting will not occur
     0.00ns INFO     Found test tb_runner.test_tb
     0.00ns INFO     running test_tb (1/1)
     
outputs {'_6': [<IEEE -4.6549486522000025> <FPNumber -4.50e0:01110010010>]}
passed 43
outputs {'_6': [<IEEE -1.2715176573999998> <FPNumber -1.31e0:01101110101>]}
passed 87
outputs {'_6': [<IEEE -7.192521898300005> <FPNumber -6.75e0:01110011011>]}
passed 131
outputs {'_6': [<IEEE -0.42565990870000003> <FPNumber -5.00e-1:01101100000>]}
passed 175

...

passed 703
outputs {'_6': [<IEEE 5.495344332200002> <FPNumber 5.00e0:01010010100>]}
passed 747
outputs {'_6': [<IEEE 4.6494865835> <FPNumber 5.25e0:01010010101>]}
passed 791
outputs {'_6': [<IEEE -2.963233154800001> <FPNumber -3.12e0:01110001001>]}
passed 835
outputs {'_6': [<IEEE 3.8036288347999996> <FPNumber 4.00e0:01010010000>]}
passed 879


  1761.00ns INFO     test_tb passed
  1761.00ns INFO     **************************************************************************************
                     ** TEST                          STATUS  SIM TIME (ns)  REAL TIME (s)  RATIO (ns/s) **
                     **************************************************************************************
                     ** tb_runner.test_tb              PASS        1761.00           1.08       1636.30  **
                     **************************************************************************************
                     ** TESTS=1 PASS=1 FAIL=0 SKIP=0               1761.00           1.12       1571.26  **
                     **************************************************************************************
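The FPNumber entries in the log pair a decimal value with an 11-bit string. The sketch below decodes such strings assuming FloPoCo's layout of 2 exception bits (00 zero, 01 normal, 10 infinity, 11 NaN), 1 sign bit, then exponent and fraction fields of 4 bits each, with bias 2^(wE-1)-1 and no subnormals. This is a reading of the format rather than the repo's own converter (which is the pybind extension); decode_flopoco is a hypothetical name. It does reproduce the values printed above, which the log shows rounded to three significant digits.

```python
def decode_flopoco(bits: str, w_e: int = 4, w_f: int = 4) -> float:
    """Decode a FloPoCo bit string: exception(2) | sign(1) | exponent(w_e) | fraction(w_f)."""
    if len(bits) != 3 + w_e + w_f or set(bits) - {"0", "1"}:
        raise ValueError(f"expected {3 + w_e + w_f} binary digits, got {bits!r}")
    exc = bits[:2]
    sign = -1.0 if bits[2] == "1" else 1.0
    exponent = int(bits[3:3 + w_e], 2)
    fraction = int(bits[3 + w_e:], 2)
    if exc == "00":
        return sign * 0.0
    if exc == "10":
        return sign * float("inf")
    if exc == "11":
        return float("nan")
    bias = 2 ** (w_e - 1) - 1
    # Normal number: implicit leading 1, no subnormals.
    return sign * (1 + fraction / 2 ** w_f) * 2.0 ** (exponent - bias)


print(decode_flopoco("01110010010"))  # -4.5, matching <FPNumber -4.50e0:01110010010>
```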

If you have a waveform viewer (gtkwave or Scansion), you can examine the generated linear.vcd.