inngest/inngest-py

step.parallel is actually sequential?

PrathamSoni opened this issue · 2 comments

async def parallel(
        self,
        callables: tuple[typing.Callable[[], typing.Awaitable[types.T]], ...],
    ) -> tuple[types.T, ...]:
        """
        Run multiple steps in parallel.

        Args:
        ----
            callables: An arbitrary number of step callbacks to run. These are
                callables that contain the step (e.g. `lambda: step.run("my_step", my_step_fn)`).
        """

        self._inside_parallel = True

        outputs = tuple[types.T]()
        responses: list[execution.StepResponse] = []
        for cb in callables:
            try:
                output = await cb()
                outputs = (*outputs, output)
            except base.ResponseInterrupt as interrupt:
                responses = [*responses, *interrupt.responses]
            except base.SkipInterrupt:
                pass

        if len(responses) > 0:
            raise base.ResponseInterrupt(responses)

        self._inside_parallel = False
        return outputs

The callables are awaited one at a time inside the loop, so they appear to run in sequence, rather than being wrapped in asyncio.gather/as_completed with errors passed through.
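For reference, the pattern the question has in mind would look roughly like the sketch below. It is not taken from inngest-py: `gather_steps`, `_ok`, and `_boom` are hypothetical names, and `return_exceptions=True` is one possible reading of "error pass through".

```python
import asyncio
import typing

T = typing.TypeVar("T")


async def gather_steps(
    callables: tuple[typing.Callable[[], typing.Awaitable[T]], ...],
) -> tuple[typing.Union[T, BaseException], ...]:
    # Schedule every callback at once; with return_exceptions=True a failing
    # callback yields its exception as a result instead of raising it here.
    results = await asyncio.gather(
        *(cb() for cb in callables),
        return_exceptions=True,
    )
    return tuple(results)


async def _ok() -> int:
    # Hypothetical step that succeeds after a short wait.
    await asyncio.sleep(0.01)
    return 1


async def _boom() -> int:
    # Hypothetical step that fails.
    raise ValueError("step failed")


demo_results = asyncio.run(gather_steps((_ok, _boom)))
```

Results come back in the order the callables were passed, regardless of which one finishes first.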

It actually does run in parallel, even though this code reads sequentially! That code will:

  1. Iterate over the parallel steps without executing their callbacks
  2. Respond to Inngest with the "plan" of parallel steps (i.e. "Please run these steps in parallel")
  3. Inngest sends a separate request for each parallel step, executing step.run callbacks in each request
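A toy model of steps 1 and 2 may make this easier to see. It is a simplified sketch, not the SDK's implementation: `PlanInterrupt` and `ToyStep` are hypothetical stand-ins, and the model assumes that a step with no stored output raises an interrupt instead of calling its handler. It leaves out step 3, where the real SDK runs a single targeted step per request.

```python
import typing


class PlanInterrupt(Exception):
    """Hypothetical stand-in for the interrupt that carries planned steps."""

    def __init__(self, step_ids: list[str]) -> None:
        super().__init__(step_ids)
        self.step_ids = step_ids


class ToyStep:
    def __init__(self, memo: dict[str, object]) -> None:
        # Outputs of steps that already finished in earlier requests.
        self.memo = memo

    def run(self, step_id: str, handler: typing.Callable[[], object]) -> object:
        if step_id in self.memo:
            return self.memo[step_id]
        # No stored output: report the step as planned, do not call handler.
        raise PlanInterrupt([step_id])

    def parallel(self, callables: tuple[typing.Callable[[], object], ...]) -> tuple:
        outputs: list[object] = []
        planned: list[str] = []
        for cb in callables:
            try:
                outputs.append(cb())
            except PlanInterrupt as interrupt:
                planned.extend(interrupt.step_ids)
        if planned:
            # The loop only collected the plan; nothing was executed.
            raise PlanInterrupt(planned)
        return tuple(outputs)


calls: list[str] = []


def handler_a() -> str:
    calls.append("a")
    return "A"


def handler_b() -> str:
    calls.append("b")
    return "B"


# First request: nothing is memoized, so parallel() only yields a plan.
first = ToyStep(memo={})
planned_ids: list[str] = []
try:
    first.parallel(
        (lambda: first.run("a", handler_a), lambda: first.run("b", handler_b))
    )
except PlanInterrupt as plan:
    planned_ids = plan.step_ids

# Later request: both outputs are known, so parallel() returns them.
later = ToyStep(memo={"a": "A", "b": "B"})
final = later.parallel(
    (lambda: later.run("a", handler_a), lambda: later.run("b", handler_b))
)
```

In this model the `for` loop is cheap bookkeeping: the first pass finishes without running either handler, which is why iterating sequentially does not make the steps themselves sequential.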

So if you're using a threaded framework like Flask, each parallel step runs in a separate thread. Give it a try with the following function:

import inngest

# Assumes `client` is an existing inngest.Inngest instance.
@client.create_function(
    fn_id="my-fn",
    trigger=inngest.TriggerEvent(event="my-event"),
)
def fn(
    ctx: inngest.Context,
    step: inngest.StepSync,
) -> None:
    def _step_1() -> None:
        print("1")

    def _step_2() -> None:
        print("2")

    step.parallel(
        (
            lambda: step.run("1", _step_1),
            lambda: step.run("2", _step_2),
        )
    )

About 50% of the time, you'll see 2 logged before 1.

Ahh i see thank you