jaymody/picoGPT

Jax is slower than NumPy

certik opened this issue · 1 comment

With #10, I get the following timings with NumPy on my Apple M1 Max:

$ time python gpt2.py "Alan Turing theorized that computers would one day become" -n 40
generating: 100%|███████████████████████████████| 40/40 [00:18<00:00,  2.13it/s]
 the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.

python gpt2.py "Alan Turing theorized that computers would one day become" -n  115.74s user 1.71s system 559% cpu 20.993 total

And with Jax:

$ time python gpt2.py "Alan Turing theorized that computers would one day become" -n 40
generating: 100%|███████████████████████████████| 40/40 [00:21<00:00,  1.85it/s]
 the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.

python gpt2.py "Alan Turing theorized that computers would one day become" -n  28.86s user 1.91s system 127% cpu 24.115 total

So Jax is slower: 24.1 s versus 21.0 s wall-clock time. According to htop, Jax uses roughly 1.3 CPU cores while NumPy uses almost 6, which matches the 127% versus 559% CPU figures reported by time. Is NumPy automatically parallel on macOS?
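A minimal sketch for checking the NumPy side of that question. It assumes NumPy's matmul is backed by a multithreaded OpenBLAS whose thread count is controlled by the OPENBLAS_NUM_THREADS environment variable. Note that the environment below installs numpy==1.24.1 through pip rather than conda; as far as we know those wheels bundle their own OpenBLAS, so the conda libopenblas entry may not be the BLAS that NumPy actually loads (np.show_config() reports which one it is). The script times one large float32 matmul, the kind of operation picoGPT's linear layers perform.

```python
import os
import time

import numpy as np


def time_matmul(n: int = 2048, repeats: int = 5) -> float:
    """Best-of-`repeats` wall-clock time (seconds) for one n x n float32 matmul."""
    rng = np.random.default_rng(0)
    a = rng.standard_normal((n, n), dtype=np.float32)
    b = rng.standard_normal((n, n), dtype=np.float32)
    _ = a @ b  # warm-up call so BLAS thread-pool start-up is not timed
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        _ = a @ b
        best = min(best, time.perf_counter() - start)
    return best


if __name__ == "__main__":
    # Prints the BLAS/LAPACK libraries this NumPy build links against.
    np.show_config()
    # Assumption: OpenBLAS reads this variable once, when NumPy is first imported.
    threads = os.environ.get("OPENBLAS_NUM_THREADS", "<unset>")
    print(f"OPENBLAS_NUM_THREADS={threads}: {time_matmul():.4f} s per matmul")
```

Run it once as `python check_blas.py` and once as `OPENBLAS_NUM_THREADS=1 python check_blas.py` (check_blas.py is a hypothetical file name). A large gap between the two runs would suggest that NumPy's ~6 cores come from multithreaded BLAS rather than from anything macOS-specific.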

Here is my Conda environment:

$ conda env export
name: pico
channels:
  - conda-forge
dependencies:
  - appdirs=1.4.4=pyh9f0ad1d_0
  - appnope=0.1.3=pyhd8ed1ab_0
  - asttokens=2.2.1=pyhd8ed1ab_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
  - brotlipy=0.7.0=py39h02fc5c5_1005
  - bzip2=1.0.8=h3422bc3_4
  - c-ares=1.18.1=h3422bc3_0
  - ca-certificates=2022.12.7=h4653dfc_0
  - cffi=1.15.1=py39h7e6b969_3
  - cryptography=39.0.1=py39he2a39a8_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - executing=1.2.0=pyhd8ed1ab_0
  - idna=3.4=pyhd8ed1ab_0
  - ipython=8.10.0=pyhd1c38e8_0
  - jax=0.4.3=pyhd8ed1ab_0
  - jaxlib=0.4.3=cpu_py39h99d3290_1
  - jedi=0.18.2=pyhd8ed1ab_0
  - libabseil=20220623.0=cxx17_h28b99d4_6
  - libblas=3.9.0=16_osxarm64_openblas
  - libcblas=3.9.0=16_osxarm64_openblas
  - libcxx=14.0.6=h2692d47_0
  - libffi=3.4.2=h3422bc3_5
  - libgfortran=5.0.0=11_3_0_hd922786_27
  - libgfortran5=11.3.0=hdaf2cc0_27
  - libgrpc=1.51.1=hb15be72_1
  - liblapack=3.9.0=16_osxarm64_openblas
  - libopenblas=0.3.21=openmp_hc731615_3
  - libprotobuf=3.21.12=hb5ab8b9_0
  - libsqlite=3.40.0=h76d750c_0
  - libzlib=1.2.13=h03a7124_4
  - llvm-openmp=15.0.7=h7cfbb63_0
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - ncurses=6.3=h07bb92c_1
  - openssl=3.0.8=h03a7124_0
  - opt_einsum=3.3.0=pyhd8ed1ab_1
  - packaging=23.0=pyhd8ed1ab_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pexpect=4.8.0=pyh1a96a4e_2
  - pickleshare=0.7.5=py_1003
  - pip=23.0=pyhd8ed1ab_0
  - pooch=1.6.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.36=pyha770c72_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.14.0=pyhd8ed1ab_0
  - pyopenssl=23.0.0=pyhd8ed1ab_0
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.9.16=hea58f1e_0_cpython
  - python_abi=3.9=3_cp39
  - re2=2023.02.01=hb7217d7_0
  - readline=8.1.2=h46ed386_0
  - scipy=1.10.0=py39h18313fe_2
  - setuptools=67.1.0=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.12=he1e0b03_0
  - traitlets=5.9.0=pyhd8ed1ab_0
  - tzdata=2022g=h191b570_0
  - urllib3=1.26.14=pyhd8ed1ab_0
  - wcwidth=0.2.6=pyhd8ed1ab_0
  - wheel=0.38.4=pyhd8ed1ab_0
  - xz=5.2.6=h57fd34a_0
  - zlib=1.2.13=h03a7124_4
  - pip:
    - absl-py==1.4.0
    - astunparse==1.6.3
    - cachetools==5.3.0
    - certifi==2022.12.7
    - charset-normalizer==2.0.12
    - fire==0.5.0
    - flatbuffers==23.1.21
    - gast==0.4.0
    - google-auth==2.16.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.51.1
    - h5py==3.8.0
    - importlib-metadata==6.0.0
    - keras==2.11.0
    - libclang==15.0.6.1
    - markdown==3.4.1
    - markupsafe==2.1.2
    - numpy==1.24.1
    - oauthlib==3.2.2
    - protobuf==3.19.6
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - regex==2017.4.5
    - requests==2.27.1
    - requests-oauthlib==1.3.1
    - rsa==4.9
    - tensorboard==2.11.2
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - tensorflow-estimator==2.11.0
    - tensorflow-macos==2.11.0
    - termcolor==2.2.0
    - tqdm==4.64.0
    - typing-extensions==4.4.0
    - werkzeug==2.2.2
    - wrapt==1.14.1
    - zipp==3.13.0
prefix: /Users/ondrej/mambaforge/envs/pico

Curious. I would've expected Jax to be faster, given that it dispatches operations asynchronously. That should effectively make this line run in parallel:

out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]

whereas NumPy executes it sequentially, since each call is eager and blocking.
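A sketch of a different way to expose the per-head parallelism, instead of relying on asynchronous dispatch: compute all heads in one batched matmul, so the library receives a single large operation rather than n_head small ones. It assumes picoGPT-style helpers, namely `attention(q, k, v, mask) = softmax(q @ k.T / sqrt(d_head) + mask) @ v` with heads split by `np.split` along the last axis; `heads_loop` and `heads_batched` are hypothetical names. Whether batching actually speeds up either backend on the M1 has not been measured here. With `import jax.numpy as np` the same code should also run under Jax, though that is untested.

```python
import numpy as np


def softmax(x):
    # Subtract the row-wise max for numerical stability.
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)


def attention(q, k, v, mask):
    # Assumed single-head attention on [n_seq, d_head] inputs, picoGPT-style.
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v


def heads_loop(q, k, v, mask, n_head):
    # One attention() call per head, like the list comprehension quoted above.
    qkv_heads = [np.split(t, n_head, axis=-1) for t in (q, k, v)]
    out_heads = [attention(qh, kh, vh, mask) for qh, kh, vh in zip(*qkv_heads)]
    return np.hstack(out_heads)  # [n_seq, n_embd]


def heads_batched(q, k, v, mask, n_head):
    # Every head in a single batched matmul instead of n_head separate calls.
    n_seq, n_embd = q.shape

    def split_heads(t):
        # [n_seq, n_embd] -> [n_head, n_seq, d_head], same column order as np.split
        return t.reshape(n_seq, n_head, n_embd // n_head).transpose(1, 0, 2)

    qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
    # [n_head, n_seq, n_seq]; the [n_seq, n_seq] mask broadcasts over heads.
    scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(qh.shape[-1]) + mask
    out = softmax(scores) @ vh  # [n_head, n_seq, d_head]
    return out.transpose(1, 0, 2).reshape(n_seq, n_embd)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_seq, n_embd, n_head = 10, 768, 12  # GPT-2 small dimensions
    q, k, v = (rng.standard_normal((n_seq, n_embd)) for _ in range(3))
    causal_mask = (1 - np.tri(n_seq)) * -1e10
    print(np.allclose(heads_loop(q, k, v, causal_mask, n_head),
                      heads_batched(q, k, v, causal_mask, n_head)))
```

Both functions compute the same result; the batched one simply gives BLAS (or XLA) one bigger unit of work per layer.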

I'm not sure how Jax handles multiple CPU cores. I know you can manually expose multiple CPU devices by setting the environment variable

export XLA_FLAGS="--xla_force_host_platform_device_count=8"

but that didn't yield a speedup for me.
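As far as we understand it, `--xla_force_host_platform_device_count=8` makes XLA expose eight virtual CPU devices, but an ordinary `jax.jit` computation still runs on a single device; work is only split across them when it is mapped over the devices explicitly, for example with `jax.pmap`. That may explain why setting the flag alone did not help. The sketch below shows the pattern under that assumption; it does not restructure gpt2.py, and whether mapping over virtual CPU devices would beat NumPy's BLAS threading on the M1 is an open question.

```python
import os

# XLA reads XLA_FLAGS when the CPU backend initializes, so set it before importing jax.
# Any flags already present in the environment are kept.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()

import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402

n_dev = len(jax.devices("cpu"))  # number of (virtual) CPU devices XLA now exposes
x = jnp.ones((n_dev, 128, 128), dtype=jnp.float32)

# Plain jit: one compiled program, executed on a single device.
single = jax.jit(lambda a: a @ a)(x)

# pmap: the leading axis (size n_dev) is split so each device gets one 128x128 slice.
mapped = jax.pmap(lambda a: a @ a)(x)

# Same numbers either way; only the placement of the work differs.
print(n_dev, bool(jnp.allclose(single, mapped)))
```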

Relevant link: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy