JAX is slower than NumPy
certik opened this issue · 1 comment
With #10, I get the following timings with NumPy on my Apple M1 Max:
$ time python gpt2.py "Alan Turing theorized that computers would one day become" -n 40
generating: 100%|███████████████████████████████| 40/40 [00:18<00:00, 2.13it/s]
the most powerful machines on the planet.
The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.
python gpt2.py "Alan Turing theorized that computers would one day become" -n 115.74s user 1.71s system 559% cpu 20.993 total
And with JAX:
$ time python gpt2.py "Alan Turing theorized that computers would one day become" -n 40
generating: 100%|███████████████████████████████| 40/40 [00:21<00:00, 1.85it/s]
the most powerful machines on the planet.
The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.
python gpt2.py "Alan Turing theorized that computers would one day become" -n 28.86s user 1.91s system 127% cpu 24.115 total
So JAX is slower. According to htop, JAX uses roughly 1.3 CPU cores, while NumPy uses almost 6. Is NumPy automatically parallel on macOS?
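One way to check (a minimal sketch, assuming the optional `threadpoolctl` package is installed; it is not part of this repo) is to inspect which BLAS backend NumPy is linked against and how many threads its pool uses:

```python
# Minimal check of NumPy's BLAS backend and its thread pool size.
# Assumes `pip install threadpoolctl`.
import numpy as np
from threadpoolctl import threadpool_info

np.show_config()                  # which BLAS/LAPACK NumPy is linked against (OpenBLAS here)
for pool in threadpool_info():    # active thread pools and their sizes
    print(pool["internal_api"], pool["num_threads"])
```

Setting `OPENBLAS_NUM_THREADS=1` (or `OMP_NUM_THREADS=1`) before running would pin NumPy's BLAS to a single core and make the comparison with JAX more even.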
Here is my Conda environment:
$ conda env export
name: pico
channels:
- conda-forge
dependencies:
- appdirs=1.4.4=pyh9f0ad1d_0
- appnope=0.1.3=pyhd8ed1ab_0
- asttokens=2.2.1=pyhd8ed1ab_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
- brotlipy=0.7.0=py39h02fc5c5_1005
- bzip2=1.0.8=h3422bc3_4
- c-ares=1.18.1=h3422bc3_0
- ca-certificates=2022.12.7=h4653dfc_0
- cffi=1.15.1=py39h7e6b969_3
- cryptography=39.0.1=py39he2a39a8_0
- decorator=5.1.1=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- idna=3.4=pyhd8ed1ab_0
- ipython=8.10.0=pyhd1c38e8_0
- jax=0.4.3=pyhd8ed1ab_0
- jaxlib=0.4.3=cpu_py39h99d3290_1
- jedi=0.18.2=pyhd8ed1ab_0
- libabseil=20220623.0=cxx17_h28b99d4_6
- libblas=3.9.0=16_osxarm64_openblas
- libcblas=3.9.0=16_osxarm64_openblas
- libcxx=14.0.6=h2692d47_0
- libffi=3.4.2=h3422bc3_5
- libgfortran=5.0.0=11_3_0_hd922786_27
- libgfortran5=11.3.0=hdaf2cc0_27
- libgrpc=1.51.1=hb15be72_1
- liblapack=3.9.0=16_osxarm64_openblas
- libopenblas=0.3.21=openmp_hc731615_3
- libprotobuf=3.21.12=hb5ab8b9_0
- libsqlite=3.40.0=h76d750c_0
- libzlib=1.2.13=h03a7124_4
- llvm-openmp=15.0.7=h7cfbb63_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- ncurses=6.3=h07bb92c_1
- openssl=3.0.8=h03a7124_0
- opt_einsum=3.3.0=pyhd8ed1ab_1
- packaging=23.0=pyhd8ed1ab_0
- parso=0.8.3=pyhd8ed1ab_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=py_1003
- pip=23.0=pyhd8ed1ab_0
- pooch=1.6.0=pyhd8ed1ab_0
- prompt-toolkit=3.0.36=pyha770c72_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pycparser=2.21=pyhd8ed1ab_0
- pygments=2.14.0=pyhd8ed1ab_0
- pyopenssl=23.0.0=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.9.16=hea58f1e_0_cpython
- python_abi=3.9=3_cp39
- re2=2023.02.01=hb7217d7_0
- readline=8.1.2=h46ed386_0
- scipy=1.10.0=py39h18313fe_2
- setuptools=67.1.0=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- stack_data=0.6.2=pyhd8ed1ab_0
- tk=8.6.12=he1e0b03_0
- traitlets=5.9.0=pyhd8ed1ab_0
- tzdata=2022g=h191b570_0
- urllib3=1.26.14=pyhd8ed1ab_0
- wcwidth=0.2.6=pyhd8ed1ab_0
- wheel=0.38.4=pyhd8ed1ab_0
- xz=5.2.6=h57fd34a_0
- zlib=1.2.13=h03a7124_4
- pip:
- absl-py==1.4.0
- astunparse==1.6.3
- cachetools==5.3.0
- certifi==2022.12.7
- charset-normalizer==2.0.12
- fire==0.5.0
- flatbuffers==23.1.21
- gast==0.4.0
- google-auth==2.16.0
- google-auth-oauthlib==0.4.6
- google-pasta==0.2.0
- grpcio==1.51.1
- h5py==3.8.0
- importlib-metadata==6.0.0
- keras==2.11.0
- libclang==15.0.6.1
- markdown==3.4.1
- markupsafe==2.1.2
- numpy==1.24.1
- oauthlib==3.2.2
- protobuf==3.19.6
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- regex==2017.4.5
- requests==2.27.1
- requests-oauthlib==1.3.1
- rsa==4.9
- tensorboard==2.11.2
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- tensorflow-estimator==2.11.0
- tensorflow-macos==2.11.0
- termcolor==2.2.0
- tqdm==4.64.0
- typing-extensions==4.4.0
- werkzeug==2.2.2
- wrapt==1.14.1
- zipp==3.13.0
prefix: /Users/ondrej/mambaforge/envs/pico
Curious, I would've expected `jax` to be faster given that it executes asynchronously (which should effectively make this line `out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]` parallel, while numpy would execute sequentially since each call is eager and blocking).
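For reference, here's a hedged sketch (not the code in gpt2.py; the shapes and the `attention` body are placeholders) of how those per-head calls could instead be batched into a single XLA computation with `jax.vmap` rather than a Python loop:

```python
# Illustrative only: batch per-head attention with vmap instead of a list
# comprehension. Shapes below are assumptions, not gpt2.py's actual ones.
import jax
import jax.numpy as jnp

def attention(q, k, v, mask):                    # q, k, v: [seq, head_dim]
    scores = q @ k.T / jnp.sqrt(q.shape[-1]) + mask
    return jax.nn.softmax(scores, axis=-1) @ v   # -> [seq, head_dim]

# q_heads, k_heads, v_heads: [n_head, seq, head_dim]; mask is shared across heads
batched_attention = jax.vmap(attention, in_axes=(0, 0, 0, None))
# out_heads = batched_attention(q_heads, k_heads, v_heads, causal_mask)
```

Even with the plain loop, though, async dispatch should let the head computations overlap, so my guess is the gap has more to do with how many CPU threads jaxlib uses than with the loop itself.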
Not sure how `jax` handles multiple CPUs. I know you can manually expose multiple CPU devices with the environment variable `export XLA_FLAGS="--xla_force_host_platform_device_count=8"`, but that didn't yield a speedup for me.
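For what it's worth, that flag only takes effect if it's set before `jax` initializes its backend, and it only creates virtual devices; ops aren't automatically spread across them. A quick sanity check (hypothetical snippet, not from the repo):

```python
# XLA_FLAGS must be set before jax is imported / its backend initializes.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())        # should list 8 CPU devices
print(jax.device_count())   # 8
```

Even with 8 devices visible, the per-head work would still have to be explicitly mapped across them (e.g. with `jax.pmap` or sharding) to use more cores, which might explain why the flag alone didn't help.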
Relevant link: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
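The FAQ's point about asynchronous dispatch mostly matters when timing individual ops rather than a whole generation run, but here's a hedged micro-benchmark sketch (arbitrary sizes, not from this repo) to isolate whether the gap comes from BLAS threading or from JAX's CPU backend:

```python
# Compare one float32 matmul in NumPy vs. JAX on CPU. block_until_ready()
# is required because JAX dispatches asynchronously; without it the JAX
# timing would only measure dispatch, not the actual computation.
import time
import numpy as np
import jax.numpy as jnp

x_np = np.random.randn(2048, 2048).astype(np.float32)
x_jx = jnp.asarray(x_np)
_ = (x_jx @ x_jx).block_until_ready()    # warm-up so compilation isn't timed

t0 = time.perf_counter()
x_np @ x_np
print("numpy:", time.perf_counter() - t0)

t0 = time.perf_counter()
(x_jx @ x_jx).block_until_ready()
print("jax  :", time.perf_counter() - t0)
```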