Issues
How do you remat GSPMD inserted all-gathers?
#25010 opened by ptoulme-aws - 1
`CUDA_ERROR_ILLEGAL_ADDRESS`
#25002 opened by PhilipVinc - 4
Numeric stability issue in higher-order tensor einsum
#25024 opened by dest1n1s - 1
An example for `ffi` input/output aliasing
#24986 opened by saeedmaleki - 1
Running the executable compiled directly from jax.jit is more than three times slower than jax.jit itself.
#25023 opened by caixiiaoyang - 8
time.sleep affects the execution time of JAX
#24941 opened by horse6 - 7
Add a way to get the tree size of a jaxpr
#24995 opened by carlosgmartin - 2
libtpu.so Not Found for JAX TPU 0.4.14
#24987 opened by LeoXinhaoLee - 5
`debug_nans` error always says the de-optimized function did not produce NaNs
#24955 opened by emilyfertig - 6
AOT compilation and serialization
#24982 opened by jkup14 - 2
FloatingPointError in jax.scipy.stats
#24939 opened by IlayMenahem - 2
Slow import time: can it be reduced?
#24967 opened by jeertmans - 4
[GPU] FlashAttention performance lags behind PyTorch
#24934 opened by neel04 - 1
⚠️ Nightly upstream-dev CI failed ⚠️
#24875 opened by github-actions - 3
Bug in the latest libtpu nightly release
#24829 opened by knyazer - 2
Small Einsum is hanging
#24929 opened by ryan112358 - 1
Rework comparisons in the docs of JAX vs. NumPy PRNG
#24927 opened by emilyfertig - 0
Allow einsum to support naive contraction strategy
#24915 opened by ryan112358 - 5
array.at.set is incredibly slow for complex128 dtype
#24872 opened by chrisrothUT - 0
Results do not match the reference. This is likely a bug/unexpected loss of precision
#24909 opened by yanboyang97 - 0
CUDA12 plugin segfaults when older version of JAX is installed
#24901 opened by dime10 - 2
pure_callback using a threadpool to accelerate vmap
#24756 opened by Joshuaalbert - 4
Missing annotations
#24888 opened by PerilousApricot - 1
CUDA_ERROR_SYSTEM_NOT_READY: system not yet initialized
#24866 opened by carlosgmartin - 5
pjit handling of static_argnames is broken when using "dynamic" python strings.
#24857 opened by etarassov - 0
Unsupported type in metal PJRT plugin with rng_bit_generator
#24867 opened by dmarro89 - 3
cond error by using Tracer as a condition
#24858 opened by 2300504237 - 1
⚠️ Nightly upstream-dev CI failed ⚠️
#24824 opened by github-actions - 0
AOT errors when setting compile options
#24869 opened by man2machine - 2
Error with `at[index].set(value)` when using `jit` and `shard_map` in multi-host setting.
#24768 opened by lollcat - 2
support q offset w.r.t k/v in flash_attention function
#24830 opened by yamingx - 5
Wrong determinant results for large batch
#24843 opened by ChenAo-Phys - 3
InconclusiveDimensionOperation: Symbolic dimension comparison 'b' < '2147483647' is inconclusive.
#24730 opened by njzjz - 1
7x7 `nnx.Conv` using `float32` parameter dtype overflows(?) to `nan` when sharded
#24848 opened by joshhansen - 1
potential jax 0.4.35 release issue?
#24826 opened by haohuanw - 3
cudaErrorSymbolNotFound : named symbol not found
#24749 opened by amanjitsk - 0
`shard_map` doesn't work with `jnp.insert`
#24762 opened by mrlazy1708 - 0
bincount rejects bool
#24813 opened by carlosgmartin - 0
Deepcopy of pjit functions failing (or not supported)
#24838 opened by jlperla - 2
`scipy.sparse.csgraph.connected_components` implementation
#24737 opened by vboussange - 2
TPU with sharding: grpc initialization failure
#24821 opened by knyazer - 1
Sparse reshape throws error when `n_dense>0` and some target dimension has size 1
#24795 opened by cherrywoods - 1
Division by self not always "1.0" on JAX GPU, but consistently gives "1.0" on JAX CPU.
#24807 opened by mattlevine22 - 1
pre-commit run --all failed for ruff and mypy
#24799 opened by apivovarov - 2
Flaky test tests/fft_test.py::FftTest::testFftfreq5
#24798 opened by apivovarov - 4
TracerBoolConversionError
#24740 opened by hangita101 - 0
Facing Scaling issue on cpu (arm and x86).
#24753 opened by choudhary-devang - 2
median FloatingPointError: invalid value (nan) encountered in jit(convert_element_type)
#24732 opened by zhangylch