flax nn.tabulate Incorrectly Reports FLOPs and VJP FLOPs
Surya-77 opened this issue
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
Ubuntu 22.04.4 LTS x86_64
- Flax, jax, jaxlib versions:
Name: flax
Version: 0.8.4
---
Name: jax
Version: 0.4.30
---
Name: jaxlib
Version: 0.4.30
- Python version:
Python 3.12.4
- GPU/TPU model and memory:
NVIDIA GeForce RTX 3080 Ti
- CUDA version:
12.2
Problem you have encountered:
When running a script to tabulate the model summary, including FLOPs and VJP FLOPs, using Flax's nn.tabulate
function, the output incorrectly shows both FLOPs and VJP FLOPs as 0. This is unexpected: the model performs computations that should yield a non-zero FLOPs count, and in particular the VJP FLOPs should be non-zero given the model's structure and operations.
What you expected to happen:
The expected output should correctly calculate and display the FLOPs and VJP FLOPs for each layer in the model.
Logs, error messages, etc:
import flax.linen as nn
import jax
import jax.numpy as jnp


class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Two stacked Dense layers: [16, 9] -> [16, 4] -> [16, 2]
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)


x = jnp.ones((16, 9))
tabulate_fn = nn.tabulate(
    Foo(), jax.random.PRNGKey(0), compute_flops=True, compute_vjp_flops=True)
print(tabulate_fn(x))
Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ │ Foo │ float32[16,9] │ float32[16,2] │ 0 │ 0 │ │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼──────────────────────┤
│ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 0 │ 0 │ bias: float32[4] │
│ │ │ │ │ │ │ kernel: float32[9,4] │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 40 (160 B) │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼──────────────────────┤
│ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 0 │ 0 │ bias: float32[2] │
│ │ │ │ │ │ │ kernel: float32[4,2] │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 10 (40 B) │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼──────────────────────┤
│ │ │ │ │ │ Total │ 50 (200 B) │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴──────────────────────┘
Hmm, unfortunately I cannot reproduce this (Flax 0.8.5). My printout yields the output below, which can also be reproduced by opening any empty Colab:
Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼──────────────────────┤
│ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: float32[4] │
│ │ │ │ │ │ │ kernel: float32[9,4] │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 40 (160 B) │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼──────────────────────┤
│ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: float32[2] │
│ │ │ │ │ │ │ kernel: float32[4,2] │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 10 (40 B) │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼──────────────────────┤
│ │ │ │ │ │ Total │ 50 (200 B) │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴──────────────────────┘
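For reference, those per-layer flops match a simple hand count of the matmul and bias-add work in each Dense layer (assuming the usual 2·M·K·N count for an [M, K] × [K, N] matmul plus M·N for the bias add, which is what the numbers above appear to reflect):

dense_0 = 2 * 16 * 9 * 4 + 16 * 4  # matmul + bias add = 1152 + 64 = 1216
dense_1 = 2 * 16 * 4 * 2 + 16 * 2  # matmul + bias add = 256 + 32 = 288
total = dense_0 + dense_1          # = 1504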
The code does work with the pinned package configurations on Colab and Kaggle, but reports zero FLOPs when run with the same package versions installed on a local machine. The data I provided is from a fresh pip install of flax, jax, and jaxlib with CUDA support in a mamba environment (though that shouldn't matter).
For reference, the Colab and Kaggle runtimes use system-level CUDA packages, while the pip-installed versions ship their own CUDA wheels.
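If it helps narrow things down: as far as I understand, nn.tabulate derives these numbers from XLA's cost analysis of the lowered computation, so a quick sanity check on the local machine is whether the cost analysis itself reports any flops there. The snippet below is only a sketch of that check (the structure of the returned value can differ between jax versions and backends):

import jax
import jax.numpy as jnp

def f(x):
    # A single matmul, comparable to one Dense layer without the bias.
    return x @ jnp.ones((9, 4))

# If the 'flops' entry is missing or zero here, the problem is below Flax,
# in the local jax/XLA install rather than in nn.tabulate itself.
lowered = jax.jit(f).lower(jnp.ones((16, 9)))
print(lowered.cost_analysis())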
Here's the minimal dependency list anyway:
Package Version
------------------------ ---------
absl-py 2.1.0
asttokens 2.4.1
chex 0.1.86
decorator 5.1.1
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
flax 0.8.4
fsspec 2024.6.0
importlib_resources 6.4.0
ipython 8.25.0
jax 0.4.26
jax-cuda12-pjrt 0.4.26
jax-cuda12-plugin 0.4.26
jaxlib 0.4.26
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.7
mdurl 0.1.2
ml-dtypes 0.4.0
msgpack 1.0.8
nest-asyncio 1.6.0
numpy 2.0.0
nvidia-cublas-cu12 12.5.2.13
nvidia-cuda-cupti-cu12 12.5.39
nvidia-cuda-nvcc-cu12 12.5.40
nvidia-cuda-nvrtc-cu12 12.5.40
nvidia-cuda-runtime-cu12 12.5.39
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.2.3.18
nvidia-cusolver-cu12 11.6.2.40
nvidia-cusparse-cu12 12.4.1.24
nvidia-nccl-cu12 2.22.3
nvidia-nvjitlink-cu12 12.5.40
opt-einsum 3.3.0
optax 0.2.2
orbax-checkpoint 0.5.20
parso 0.8.4
pexpect 4.9.0
pickleshare 0.7.5
pip 24.0
prompt_toolkit 3.0.47
protobuf 5.27.2
ptyprocess 0.7.0
pure-eval 0.2.2
Pygments 2.18.0
PyYAML 6.0.1
rich 13.7.1
scipy 1.14.0
setuptools 70.1.1
six 1.16.0
stack-data 0.6.2
tensorstore 0.1.63
toolz 0.12.1
traitlets 5.14.3
typing_extensions 4.12.2
wcwidth 0.2.13
wheel 0.43.0
zipp 3.19.2