OpenPipe/ART

Suggestion: Potential major performance improvement for `gather_trajectory_groups` with `after_each`


The current implementation of gather_trajectory_groups() is as follows:

async def gather_trajectory_groups(
    groups: Iterable[Awaitable[TrajectoryGroup]],
    after_each: Callable[
        [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]]
    ]
    | None = None,
) -> list[TrajectoryGroup]:
    
    ...

    # FIRST AWAIT:
    # Wait for *all* trajectory groups to finish; only then proceed

    with set_gather_context(context):
        future = asyncio.gather(*[wrap_group_awaitable(g) for g in groups])
        total = sum(getattr(g, "_num_trajectories", 1) for g in groups)
        context.pbar = tqdm.tqdm(desc=pbar_desc, total=total)
        result_groups = await future

    ...

    # SECOND AWAIT:
    # Only after *ALL* trajectory groups have been constructed, call the `after_each` callback for all of them

    # If an after_each callback was provided, await it and collect its return values.
    if after_each is not None:
        ae_processed_groups = await asyncio.gather(
            *(after_each(g) for g in processed_groups)
        )
    
    ...

    return processed_groups

Notice that the current implementation makes two separate asyncio.gather() calls, and the second one, which runs the after_each callbacks, does not start until the first one has finished. This can cause a severe performance bottleneck: if we use after_each=art.rewards.ruler_score_group, for example, we can't begin scoring any group until every group is ready.
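To make the effect concrete, here is a small self-contained simulation in plain asyncio (no ART code). The names `rollout`, `score`, `two_phase` and `pipelined` are hypothetical, and the scorer is assumed to process one group at a time, a rough model of a judge model served on a fixed set of GPUs. With unlimited scoring concurrency both structures would finish at about the same time; the gap shows up once scoring throughput is limited.

```python
import asyncio
import time

# Hypothetical stand-ins, NOT ART APIs. The semaphore models a scorer with
# limited throughput (one group at a time), e.g. a judge model on fixed GPUs.
SCORE_SECONDS = 0.1


async def rollout(i: int, duration: float) -> int:
    await asyncio.sleep(duration)  # pretend to generate trajectories
    return i


async def score(
    group: int, scorer: asyncio.Semaphore, log: list[float], t0: float
) -> int:
    async with scorer:  # only one group is scored at a time
        log.append(time.perf_counter() - t0)  # when scoring of this group starts
        await asyncio.sleep(SCORE_SECONDS)
        return group


async def two_phase(durations: list[float]) -> tuple[float, float]:
    """Current structure: gather all rollouts, then gather all scoring calls."""
    scorer, log, t0 = asyncio.Semaphore(1), [], time.perf_counter()
    groups = await asyncio.gather(*(rollout(i, d) for i, d in enumerate(durations)))
    await asyncio.gather(*(score(g, scorer, log, t0) for g in groups))
    return min(log), time.perf_counter() - t0  # (first score start, total time)


async def pipelined(durations: list[float]) -> tuple[float, float]:
    """Proposed structure: each group is scored as soon as its rollout finishes."""
    scorer, log, t0 = asyncio.Semaphore(1), [], time.perf_counter()

    async def forward(i: int, d: float) -> int:
        g = await rollout(i, d)
        return await score(g, scorer, log, t0)

    await asyncio.gather(*(forward(i, d) for i, d in enumerate(durations)))
    return min(log), time.perf_counter() - t0  # (first score start, total time)


if __name__ == "__main__":
    durations = [0.1, 0.2, 0.3]
    for fn in (two_phase, pipelined):
        first, total = asyncio.run(fn(durations))
        print(f"{fn.__name__}: first score at {first:.2f}s, total {total:.2f}s")
```

On an idle machine the two-phase version should start scoring only after the slowest rollout (around 0.3 s) and finish around 0.6 s, while the pipelined version should start scoring around 0.1 s and finish around 0.4 s, because scoring overlaps with the rollouts that are still running.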

See the following image for a demonstration. Green and red are the GPUs on which a local RULER model is deployed; orange and blue are the GPUs on which rollouts are performed. Notice that there is NO OVERLAP: the RULER GPUs are idle throughout the entire rollout period, and vice versa.

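[Image: GPU activity over time; RULER GPUs (green, red) and rollout GPUs (orange, blue) are never active at the same time]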

I propose modifying gather_trajectory_groups to have the following structure (or at least the spirit of it; I'm not 100% sure what's going on with the pbars and the gather_context):

async def gather_trajectory_groups(
    groups: Iterable[Awaitable[TrajectoryGroup]],
    after_each: Callable[
        [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]]
    ]
    | None = None,
) -> list[TrajectoryGroup]:
    
    ...

    # A single coroutine that constructs the trajectory group and then invokes the callback, if one was provided
    async def forward_group(
        g: Awaitable[TrajectoryGroup],
    ) -> TrajectoryGroup | None | list[TrajectoryGroup]:
        group = await wrap_group_awaitable(g)
        if group is not None and after_each is not None:
            return await after_each(group)
        return group

    # Simultaneously await group construction and `after_each` callbacks

    with set_gather_context(context):
        future = asyncio.gather(*[forward_group(g) for g in groups])
        total = sum(getattr(g, "_num_trajectories", 1) for g in groups)
        context.pbar = tqdm.tqdm(desc=pbar_desc, total=total)
        result_groups = await future

    ...


    return processed_groups
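For reference, here is a minimal, self-contained sketch of the proposed pipelined structure that runs outside ART. `TrajectoryGroup` below is a hypothetical dataclass stand-in, `gather_pipelined`, `make_group` and `judge` are illustrative names, the progress bar and gather context are left out, and the post-processing (dropping `None` results and flattening lists returned by `after_each`) is an assumption about what the elided `...` does, inferred from the callback's return type.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Iterable, Optional, Union


@dataclass
class TrajectoryGroup:  # hypothetical stand-in for art.TrajectoryGroup
    name: str
    scores: list[float] = field(default_factory=list)


AfterEach = Callable[
    [TrajectoryGroup],
    Awaitable[Union[TrajectoryGroup, None, list[TrajectoryGroup]]],
]


async def gather_pipelined(
    groups: Iterable[Awaitable[TrajectoryGroup]],
    after_each: Optional[AfterEach] = None,
) -> list[TrajectoryGroup]:
    async def forward(aw: Awaitable[TrajectoryGroup]):
        g = await aw
        if after_each is None:
            return g
        # Run the callback as soon as this particular group is ready.
        return await after_each(g)

    # Materialize once so a generator of awaitables is not consumed twice.
    results = await asyncio.gather(*(forward(aw) for aw in list(groups)))

    # Assumed post-processing: drop None, flatten lists, keep input order.
    out: list[TrajectoryGroup] = []
    for r in results:
        if r is None:
            continue
        out.extend(r if isinstance(r, list) else [r])
    return out


async def make_group(name: str, delay: float) -> TrajectoryGroup:
    await asyncio.sleep(delay)  # pretend rollout
    return TrajectoryGroup(name)


async def judge(g: TrajectoryGroup):  # hypothetical after_each callback
    if g.name == "b":
        return None  # discard this group
    if g.name == "c":
        return [TrajectoryGroup("c1"), TrajectoryGroup("c2")]  # split it
    g.scores.append(1.0)
    return g


async def main() -> list[TrajectoryGroup]:
    groups = [make_group("a", 0.03), make_group("b", 0.01), make_group("c", 0.02)]
    return await gather_pipelined(groups, after_each=judge)


if __name__ == "__main__":
    print([g.name for g in asyncio.run(main())])  # ['a', 'c1', 'c2']
```

`asyncio.gather` returns results in input order regardless of completion order, so the output stays deterministic even though each callback fires as soon as its own group finishes.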