michaelosthege/pytensor-federated

Implement graph rewrites to vectorize asynchronous Ops

michaelosthege opened this issue · 0 comments

Old design idea
## Design option: RPC-Aware Ops
This was the first idea.
+ [ ] Find `Op`s in the graph that use the `ArraysToArraysServiceClient` and can be parallelized (they must not depend on each other). This can be implemented by adding a mixin interface, so that candidates can be identified with `isinstance(op, ArraysToArraysOp)`.
+ [ ] Write a `ParallelArraysToArraysOp` that keeps a list of streams and runs evaluations in parallel.
+ [ ] Do a subgraph replacement in which the independent `ArraysToArraysOp` nodes are substituted by a subgraph that routes their inputs to a new `ParallelArraysToArraysOp` node and redistributes its outputs.
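
To illustrate the "must not depend on each other" condition from the first step, here is a minimal sketch in plain Python. It does not use the PyTensor/Aesara graph API: the graph is assumed to be a dict that maps each node name to the names of its direct inputs, and `ancestors` and `independent_groups` are hypothetical helper names. A real rewrite would walk `Apply` nodes instead, and would probably also need to confirm that merging a group does not introduce a cycle.

```python
from typing import Dict, List, Set


def ancestors(node: str, parents: Dict[str, List[str]]) -> Set[str]:
    """Collect every node that `node` depends on, directly or indirectly."""
    seen: Set[str] = set()
    stack = list(parents.get(node, []))
    while stack:
        current = stack.pop()
        if current not in seen:
            seen.add(current)
            stack.extend(parents.get(current, []))
    return seen


def independent_groups(
    candidates: List[str], parents: Dict[str, List[str]]
) -> List[List[str]]:
    """Greedily group candidates so that no member of a group depends on another."""
    anc = {c: ancestors(c, parents) for c in candidates}
    groups: List[List[str]] = []
    for c in candidates:
        for group in groups:
            # c may join only if it is neither ancestor nor descendant of any member.
            if all(c not in anc[g] and g not in anc[c] for g in group):
                group.append(c)
                break
        else:
            groups.append([c])
    return groups


# Toy graph (assumed for illustration): "c" consumes the output of "a",
# while "a" and "b" only consume the graph inputs "x" and "y".
parents = {"a": ["x"], "b": ["y"], "c": ["a"], "out": ["b", "c"]}
groups = independent_groups(["a", "b", "c"], parents)
```

In this toy graph `a` and `b` end up in one group and could share a parallel node, whereas `c` must wait for `a` and therefore forms its own group.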

## Design option: Async Ops (preferred)

This would be RPC-unaware and more generic overall.

  • #26
  • Implement async counterparts to the function-wrapping convenience Ops: `AsyncArraysToArraysOp`, `AsyncLogpOp`, `AsyncLogpGradOp`.
  • Write a class `ParallelAsyncOp(Op)`, similar to `aesara.graph.basic.Composite`, that parallelizes the `.perform_async()` calls of several `AsyncOp`s.
  • Write a graph optimization that finds `AsyncOp`s and merges them into a `ParallelAsyncOp`.
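
The core of the parallelization could look roughly like the following sketch, which uses only `asyncio` and deliberately does not subclass the real `Op`. `AsyncOpSketch` and `ParallelAsyncOpSketch` are hypothetical stand-ins for the classes named above. The flat input and output lists, the `n_inputs` attribute, and the toy coroutines `slow_add` and `slow_square` are assumptions made for illustration.

```python
import asyncio
from typing import Any, Awaitable, Callable, List, Sequence


class AsyncOpSketch:
    """Hypothetical stand-in for an AsyncOp; not a PyTensor/Aesara class."""

    def __init__(self, fn: Callable[..., Awaitable[Sequence[Any]]], n_inputs: int):
        self.fn = fn
        self.n_inputs = n_inputs

    async def perform_async(self, inputs: Sequence[Any]) -> List[Any]:
        return list(await self.fn(*inputs))


class ParallelAsyncOpSketch:
    """Hypothetical wrapper that awaits several independent ops concurrently."""

    def __init__(self, ops: Sequence[AsyncOpSketch]):
        self.ops = list(ops)

    async def perform_async(self, inputs: Sequence[Any]) -> List[Any]:
        # Route slices of the flat input list to the wrapped ops.
        coros = []
        start = 0
        for op in self.ops:
            coros.append(op.perform_async(inputs[start:start + op.n_inputs]))
            start += op.n_inputs
        # gather() awaits the coroutines concurrently and preserves their order.
        results = await asyncio.gather(*coros)
        # Redistribute: flatten the per-op outputs into one flat output list.
        return [out for outputs in results for out in outputs]

    def perform(self, inputs: Sequence[Any]) -> List[Any]:
        # Synchronous entry point; assumes no event loop is already running.
        return asyncio.run(self.perform_async(inputs))


async def slow_add(a, b):
    await asyncio.sleep(0.05)  # simulates waiting on a remote evaluation
    return [a + b]


async def slow_square(x):
    await asyncio.sleep(0.05)
    return [x * x]


parallel_op = ParallelAsyncOpSketch(
    [AsyncOpSketch(slow_add, 2), AsyncOpSketch(slow_square, 1)]
)
outputs = parallel_op.perform([1, 2, 3])  # inputs 1, 2 go to slow_add; 3 to slow_square
```

Because both coroutines spend their time awaiting, the two waits overlap instead of adding up. A synchronous `perform` built on `asyncio.run` is only one option and would fail inside an already running event loop.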