Issues
unset JAX_PLATFORMS finds cuda, but JAX_PLATFORMS=gpu tries to use rocm (and fails)
#25315 opened by Joshuaalbert - 0
[Pallas] Add lowering support for tl.join
#25321 opened by ggengnv - 9
Missing `jax.core.ConcreteArray` in v0.4.36: Migration Guide and API Stability Concerns
#25314 opened by daskol - 3
About FFI failure: Failed to destroy GPU Graph
#25141 opened by victoryang00 - 5
unexpected `vmap` error due to commit `c36e1f7`
#25289 opened by marcocuturi - 0
LLVM ERROR: Failed to infer result type(s). on Jax Metal 0.1.1
#25302 opened by dlwh - 2
[//tests/mosaic:gpu_test_gpu] Wrong results on H100
#25218 opened by gpupuck - 0
Regressions in `debug_nans` (and `debug_infs`)
#25299 opened by emilyfertig - 0
Value becoming a tracer error
#25186 opened by alexandermm - 2
Difference between numpy and jax.numpy in advanced indexing axes order
#25109 opened by fzimmermann89 - 8
Possible leak in random number generation
#25069 opened by RadostW - 0
Request for jax.scipy.special.airy to be implemented
#25244 opened by elliottperryman - 0
⚠️ Nightly upstream-dev CI failed ⚠️
#25225 opened by github-actions - 0
[Pallas TPU] `pl.load` does not work inside TPU kernels
#25230 opened by ayaka14732 - 0
Allow Pallas Triton to be serialized as PTX
#25196 opened by sbodenstein - 2
Unable to run FFI example with stateful function
#25185 opened by Luthaf - 1
⚠️ Nightly upstream-dev CI failed ⚠️
#25146 opened by github-actions - 4
[Pallas] Unable to convert negative values from float16/float32 to int8/int32 in pallas
#25047 opened by shangz-ai - 2
[//tests/pallas:export_back_compat_pallas_test_gpu] INTERNAL: RET_CHECK failure (external/xla/xla/service/gpu/ir_emitter_unnested.cc:1388) triton_module Failed to parse Triton module: ML�R
#25212 opened by gpupuck - 7
jax.numpy.linalg.multi_dot is O(2^N) in the number of matrices being multiplied
#25051 opened by rohan-hitchcock - 1
`AssertionError: Unexpected XLA layout override` when adding two `from_dlpack` arrays
#25066 opened by samuela - 2
Failure to build jaxlib, AMD GPU
#25204 opened by Wintoplay - 4
[Lowering] Stable IR
#25123 opened by yliu120 - 3
Allow using jax.jit as decorator with arguments
#25194 opened by cool-RR - 0
in jax.lax.switch docstring, explain that all branch output shapes must match
#25140 opened by mattjj - 2
Weird defjvp behavior when finding grad of a scalar that depends on the primal
#25101 opened by JadM133 - 2
Support serializing Mosaic GPU to PTX
#25197 opened by sbodenstein - 1
Installation fails on a cluster
#25195 opened by PhilipVinc - 1
How do you remat GSPMD inserted all-gathers?
#25010 opened by ptoulme-aws - 0
Generating random numbers in pallas on gpu
#25188 opened by lengstrom - 0
Sum along first axis is slow
#25187 opened by ricardoV94 - 1
Cannot generate from `jax.experimental.sparse.random_bcoo` for arrays larger than the int32 max
#25182 opened by lengstrom - 4
Rank-one updates to eigenvalue decompositions
#25057 opened by mishavanbeek - 2
PRNGKey error
#25076 opened by itahang - 1
`CUDA_ERROR_ILLEGAL_ADDRESS`
#25002 opened by PhilipVinc - 3
[FFI Lowering] Compiler-only Attrs
#25124 opened by yliu120 - 0
Numeric stability issue in higher-order tensor einsum
#25024 opened by dest1n1s - 1
An example for `ffi` input/output aliasing
#24986 opened by saeedmaleki - 1
Running the executable compiled directly from jax.jit is more than three times slower than jax.jit itself.
#25023 opened by caixiiaoyang - 7
Add a way to get the tree size of a jaxpr
#24995 opened by carlosgmartin - 2
libtpu.so Not Found for JAX TPU 0.4.14
#24987 opened by LeoXinhaoLee