Issues
Pickling a trained model (NNX)
#4247 opened by eugene - 3
ToLinen is not hashable (Linen modules are)
#4156 opened by PhilipVinc - 0
How to use max_pool with xla_gpu_deterministic_ops?
#4233 opened by blackblitz - 1
Add NNX support for legacy `jax.random.PRNGKey()`
#4231 opened by cisprague - 3
How to freeze parameters with `nnx` and `optax`?
#4167 opened by maxencefaldor - 0
nn.scan messes with jax's sharding
#4217 opened by sokol11 - 2
nnx.Swish, jax.swish,... change the input shape
#4214 opened by leson207 - 3
`test_vmap_and_cond_passthrough[_error]` tests fail
#4211 opened by GaetanLepage - 0
Depthwise convolution in `flax.nnx.Conv` is significantly slower than PyTorch and TensorFlow
#4207 opened by YushaArif99 - 1
`nnx.vmap` uses the same random key `rngs` inside `nnx.Module` across vectorization
#4195 opened by maxencefaldor - 17
`TODO: add notes link` in flax_basics guide
#4193 opened by garymm - 0
flax counterpart for `torch.nn.Conv1d`
#4188 opened by Liyang90 - 0
MultiHeadAttention documentation missing descriptions of `None` values for optional arguments
#4182 opened by carlosgmartin - 0
Contributing to Flax Community Tutorials
#4181 opened by yilunyu - 0
CHANGELOG has not been updated since 0.8.2
#4179 opened by enolan - 0
[FLIP] JAX-style NNX Transforms
#4107 opened by cgarciae - 1
bias and kernel params are put on different gpu devices
#4116 opened by YunxiTang - 2
Flax 0.9.0 broke nnx rng splitting
#4153 opened by kasper0406 - 0
Experimental-pytree flag causes crash
#4142 opened by NeilGirdhar - 1
Support for optax lbfgs and related optimizers with NNX
#4144 opened by jlperla - 1
Add IndRNN
#4133 opened by carlosgmartin - 1
SPMD for initializing model using nnx.jit
#4129 opened by mmorinag127 - 2
Truncated Normal initializer doesn't match PyTorch
#4091 opened by DBraun - 3
Clarification for LSTMCell Documentation
#4124 opened by corentinlger - 0
MNIST tutorial broken for Colab TPU
#4122 opened by rcrowe-google - 2
Flax.linen.conv unexpected behavior.
#4113 opened by NITHISHM2410 - 2
Edit parameters of flax module in another module
#4109 opened by minhkhoi1026 - 0
`DynamicScale` behaves unexpectedly when computing per-sample gradients with `vmap`
#4114 opened by hlzl - 0
lowering / cost analysis of @nnx.jit functions
#4094 opened by cgarciae - 2
NNXWrapper
#4088 opened by PhilipVinc - 1
Dropout seems not compatible with jax.jit
#4085 opened by richardmkit - 1
Large Difference in Loss between JAX and FLAX Two-Layer Linear Autoencoder
#4087 opened by yCobanoglu - 2
GroupNorm missing from NNX normalization layer
#4086 opened by treigerm - 2
lstm error
#4032 opened by layssi - 3
Is there any way to analyze activations in Flax?
#4058 opened by mohamad-amin - 1
Opaque XLA crash when initializing model
#4054 opened by fernandopalafox - 3
typo in nnx_basics.md
#4046 opened by notnot - 1
Feature request: Mixture of Experts example
#4034 opened by SamKG - 2
Suboptimal default initialization of q/k/v projections in `nn.MultiHeadDotProductAttention`
#4027 opened by MasterSkepticista - 2
flax nn.tabulate Incorrectly Reports FLOPs and VJP FLOPs
#4023 opened by Surya-77