[xla: GPU] Multiple ppermutes are using same CUDA stream
DwaraknathT opened this issue · 8 comments
Hey all, I am working on a custom collective matmul implementation in JAX to overlap an all-gather with a matmul. I noticed that my bidirectional send/recv calls are running on the same stream and therefore end up being executed serially (please see the attached profile image). Is it possible to manually assign different ppermute calls to different CUDA streams?
Thank you!
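For readers who want to reproduce this setup, here is a minimal sketch of what a bidirectional all-gather collective matmul built from ppermute might look like in JAX. It is not the code from the original post: the mesh axis name "x", the function name bidi_allgather_matmul, the shapes, and the 4-device CPU simulation are assumptions for illustration. Each device's row block of x is split in half. On every step one half moves to the next device and the other to the previous device, while the halves already on hand are multiplied by the local column block of w.

```python
import os

# Simulate 4 devices when this runs on CPU; the flag does not affect GPU devices.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

devices = np.array(jax.devices())
n = devices.size  # ring size, a static Python int
mesh = Mesh(devices, ("x",))


def bidi_allgather_matmul(x_local, w_local):
    """Hypothetical sketch: all_gather(x, rows) @ w_local using a bidirectional ring.

    x_local: [chunk, k] row block of x; w_local: [k, f // n] column block of w.
    """
    idx = jax.lax.axis_index("x")
    chunk = x_local.shape[0]
    half = chunk // 2
    fwd, bwd = x_local[:half], x_local[half:]  # top half goes forward, bottom half backward
    to_next = [(j, (j + 1) % n) for j in range(n)]
    to_prev = [(j, (j - 1) % n) for j in range(n)]
    out = jnp.zeros((n * chunk, w_local.shape[1]), dtype=x_local.dtype)
    for step in range(n):  # statically unrolled ring
        if step < n - 1:
            # Neither transfer depends on this step's matmuls, so they may overlap.
            next_fwd = jax.lax.ppermute(fwd, "x", perm=to_next)
            next_bwd = jax.lax.ppermute(bwd, "x", perm=to_prev)
        # After `step` hops, fwd came from device idx - step and bwd from idx + step.
        src_f = (idx - step) % n
        src_b = (idx + step) % n
        out = jax.lax.dynamic_update_slice(out, fwd @ w_local, (src_f * chunk, 0))
        out = jax.lax.dynamic_update_slice(out, bwd @ w_local, (src_b * chunk + half, 0))
        if step < n - 1:
            fwd, bwd = next_fwd, next_bwd
    return out


matmul = jax.jit(shard_map(
    bidi_allgather_matmul,
    mesh=mesh,
    in_specs=(P("x", None), P(None, "x")),  # x split by rows, w split by columns
    out_specs=P(None, "x"),                 # y split by columns, like w
))

rng = np.random.default_rng(0)
x = rng.standard_normal((4 * n, 8), dtype=np.float32)  # 4 rows (two halves of 2) per device
w = rng.standard_normal((8, 3 * n), dtype=np.float32)
y = matmul(x, w)  # should match x @ w
```

The two ppermutes in each step are independent of each other, so the dataflow allows them to run concurrently. Whether they actually do on a GPU depends on XLA's stream assignment, which is the behavior this issue is about.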
I'd guess this is actually a bug against JAX, to provide an API for such control? Or both?
My apologies, I don't quite follow. Are you saying JAX shouldn't provide an API for such cases? How do you think I can achieve what I'm trying to do? The ppermute delay basically gives me no advantage from overlapping compute with communication.
If you'd like JAX to provide such an API, then this should be a feature request on JAX?
Ah, I don't necessarily want this as a JAX API. XLA should already know that the forward and backward ppermutes can run in parallel on different streams, no?
More broadly, my question was: how can I reduce the time taken by the serialized ppermutes?
A couple of notes: you can get a multi-streamed collective matmul via the SPMD partitioner if you set:
--xla_gpu_multi_streamed_windowed_einsum, --xla_gpu_threshold_for_windowed_einsum_mib=<small_value>
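As a minimal sketch, assuming the flags are passed through the XLA_FLAGS environment variable (the usual way to hand XLA flags to a JAX program), they need to be set before JAX initializes its backend. The flag name --xla_gpu_multi_streamed_windowed_einsum follows the spelling used later in this thread, and the threshold of 0 MiB is an assumed stand-in for <small_value>.

```python
import os

# Set this before `import jax`: XLA parses XLA_FLAGS when the backend is initialized.
# Note that this assignment overwrites any XLA_FLAGS that were already set.
# The threshold 0 is an assumed "<small_value>"; tune it for your GEMM sizes.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_multi_streamed_windowed_einsum=true",
    "--xla_gpu_threshold_for_windowed_einsum_mib=0",
])
```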
For the lower-level version, there are two asks:
- Most manual: JAX and XLA should offer an API for explicit stream assignment, so you can explicitly place these ops on different streams and we will respect that.
- More automatic: XLA should use multiple streams for collectives if it deems it profitable.
I suspect that starting with "most manual" is the right thing.
@hawkinsp Going with the manual approach first makes sense to me! For someone relatively new to scaling, could you roughly sketch how those annotations would propagate through the SPMD partitioner? Should they go together with the shardings?
I guess sharding could just propagate the stream attributes, which XLA already has. Those would need to be added to JAX, and we'd have to make sure the annotation doesn't get lost in XLA.
For the more automatic approach: yes, XLA should really do that. @golechwierowicz
We currently expose only one GPU stream to the scheduler for running asynchronous collective-permutes, so the two send/recv pairs running sequentially is the expected behavior under the current design.
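A back-of-the-envelope model shows why this serialization erases the benefit of the bidirectional ring. The sketch below is an idealized, hypothetical cost model, not a measurement: it assumes compute overlaps communication perfectly and that the forward and backward transfers, which use opposite link directions, would not slow each other down if run concurrently. With one communication stream the two ppermutes per step add up; with two streams only one transfer time counts.

```python
def step_time_ms(compute_ms: float, permute_ms: float, comm_streams: int) -> float:
    """Idealized time for one step of a bidirectional collective matmul.

    Hypothetical model: compute overlaps communication perfectly, and each step
    issues one forward and one backward ppermute of equal cost.
    """
    n_permutes = 2
    concurrent = max(1, min(comm_streams, n_permutes))  # transfers that can run at once
    comm_ms = permute_ms * n_permutes / concurrent
    return max(compute_ms, comm_ms)


# Hypothetical numbers: 1.0 ms of matmul and 0.8 ms per ppermute in each step.
one_stream = step_time_ms(1.0, 0.8, comm_streams=1)   # 1.6 ms: communication-bound
two_streams = step_time_ms(1.0, 0.8, comm_streams=2)  # 1.0 ms: compute-bound
```

Under these assumptions, a single stream turns a compute-bound step into a communication-bound one, which matches the symptom described at the top of the issue.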
NVIDIA has implemented collective matmul in XLA/GPU; see the discussion thread.
Here are some notes about the flags to enable the optimization.
Is this something you'd like to try out and give feedback on, instead of implementing your own custom collective matmul?
- --xla_gpu_threshold_for_windowed_einsum_mib= : controls which GEMM sizes collective matmul (CM) is enabled for
- --xla_gpu_multi_streamed_windowed_einsum: controls whether the CM loop is unrolled
Some other features that would improve performance:
- PGLE (the profile-guided latency estimator) will definitely help
- --xla_gpu_use_memcpy_local_p2p: uses CUDA memcpy p2p instead of NCCL p2p calls for intra-node communication
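To try these flags together without clobbering flags that are already set, a small helper can merge them into XLA_FLAGS before JAX starts. This is a sketch under stated assumptions: add_xla_flags is a hypothetical helper, the threshold value 0 is an assumed small value, and PGLE is left out.

```python
import os


def add_xla_flags(*flags: str) -> str:
    """Hypothetical helper: append flags to XLA_FLAGS, keeping existing ones.

    Flags already present (exact string match) are skipped. Call before `import jax`.
    """
    merged = os.environ.get("XLA_FLAGS", "").split()
    for flag in flags:
        if flag not in merged:
            merged.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(merged)
    return os.environ["XLA_FLAGS"]


add_xla_flags(
    "--xla_gpu_threshold_for_windowed_einsum_mib=0",  # assumed small threshold for CM
    "--xla_gpu_multi_streamed_windowed_einsum=true",  # whether to unroll the CM loop
    "--xla_gpu_use_memcpy_local_p2p=true",            # intra-node p2p via CUDA memcpy
)
```

Skipping exact duplicates keeps the helper safe to call more than once, for example from both a launcher script and the training script.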