Suggestion: Potential major performance improvement for `gather_trajectory_groups` with `after_each`
The current implementation of `gather_trajectory_groups()` is as follows:
```python
async def gather_trajectory_groups(
    groups: Iterable[Awaitable[TrajectoryGroup]],
    after_each: Callable[
        [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]]
    ]
    | None = None,
) -> list[TrajectoryGroup]:
    ...
    # FIRST AWAIT:
    # First await all trajectory groups to finish, only then proceed
    with set_gather_context(context):
        future = asyncio.gather(*[wrap_group_awaitable(g) for g in groups])
        total = sum(getattr(g, "_num_trajectories", 1) for g in groups)
        context.pbar = tqdm.tqdm(desc=pbar_desc, total=total)
        result_groups = await future
    ...
    # SECOND AWAIT:
    # Only after *ALL* trajectory groups have been constructed, call the `after_each` callback for all of them
    # If an after_each callback was provided, await it and collect its return values.
    if after_each is not None:
        ae_processed_groups = await asyncio.gather(
            *(after_each(g) for g in processed_groups)
        )
    ...
    return processed_groups
```

Notice that the current implementation makes two separate `asyncio.gather()` calls, and the second one, which runs the `after_each` callback, does not start until the first has finished. This can cause a severe performance bottleneck: with `after_each=art.rewards.ruler_score_group`, for example, we can't begin scoring any group until every group is ready.
See the following image for a demonstration. Green and Red are the GPUs on which a local RULER model is deployed; Orange and Blue are the GPUs on which rollouts are performed. Notice that there is NO OVERLAP: the RULER GPUs are idle throughout the entire rollout period, and vice versa.
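To make the effect concrete without GPUs, here is a small self-contained simulation (not ART code). The rollout and scoring functions, their durations, and the single-slot `asyncio.Lock` standing in for a RULER judge that can only serve one group at a time are all hypothetical assumptions chosen for illustration.

```python
import asyncio
import time

ROLLOUT_SECONDS = [0.1, 0.2, 0.3]  # hypothetical per-group rollout durations
SCORE_SECONDS = 0.1  # hypothetical per-group scoring duration


async def fake_rollout(i: int) -> int:
    """Hypothetical stand-in for building one trajectory group."""
    await asyncio.sleep(ROLLOUT_SECONDS[i])
    return i


async def fake_score(group: int, judge: asyncio.Lock) -> int:
    """Hypothetical `after_each` scorer; the lock models a judge with capacity 1."""
    async with judge:
        await asyncio.sleep(SCORE_SECONDS)
    return group


async def two_phase() -> float:
    """Current structure: gather all rollouts first, then gather all scores."""
    judge = asyncio.Lock()
    start = time.perf_counter()
    groups = await asyncio.gather(
        *(fake_rollout(i) for i in range(len(ROLLOUT_SECONDS)))
    )
    await asyncio.gather(*(fake_score(g, judge) for g in groups))
    return time.perf_counter() - start


async def pipelined() -> float:
    """Proposed structure: score each group as soon as its own rollout finishes."""
    judge = asyncio.Lock()

    async def forward(i: int) -> int:
        return await fake_score(await fake_rollout(i), judge)

    start = time.perf_counter()
    await asyncio.gather(*(forward(i) for i in range(len(ROLLOUT_SECONDS))))
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"two-phase: {asyncio.run(two_phase()):.2f}s")
    print(f"pipelined: {asyncio.run(pipelined()):.2f}s")
```

With these numbers, the two-phase version should take about 0.6 s (0.3 s of rollouts, then three serialized 0.1 s scoring calls), while the pipelined version should take about 0.4 s, because scoring of the early groups overlaps with the remaining rollouts.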
I propose restructuring `gather_trajectory_groups` along the following lines (or at least in this spirit; I'm not 100% sure what's going on with the progress bars and with the gather_context):
```python
async def gather_trajectory_groups(
    groups: Iterable[Awaitable[TrajectoryGroup]],
    after_each: Callable[
        [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]]
    ]
    | None = None,
) -> list[TrajectoryGroup]:
    ...
    # A single coroutine that constructs the trajectory group and invokes the callback if needed
    async def forward_group(
        g: Awaitable[TrajectoryGroup],
    ) -> TrajectoryGroup | None | list[TrajectoryGroup]:
        group = await wrap_group_awaitable(g)
        if group is not None and after_each is not None:
            return await after_each(group)
        return group

    # Await group construction and `after_each` callbacks concurrently:
    # each group is passed to the callback as soon as it is ready
    with set_gather_context(context):
        future = asyncio.gather(*[forward_group(g) for g in groups])
        total = sum(getattr(g, "_num_trajectories", 1) for g in groups)
        context.pbar = tqdm.tqdm(desc=pbar_desc, total=total)
        result_groups = await future
    ...
    return processed_groups
```
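For reference, here is a runnable sketch of the same pattern outside of ART, under stated assumptions: `Group` is a hypothetical stand-in for `TrajectoryGroup`, the progress bar and gather context are omitted, and the handling of the callback's return value (`None` drops the group, a list is flattened) is my assumption based on the `after_each` type hint, since the real post-processing is elided above.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, Iterable, Optional, Union


@dataclass
class Group:
    """Hypothetical stand-in for TrajectoryGroup."""

    name: str
    score: Optional[float] = None


AfterEachResult = Union[Group, None, list[Group]]


async def gather_pipelined(
    groups: Iterable[Awaitable[Group]],
    after_each: Optional[Callable[[Group], Awaitable[AfterEachResult]]] = None,
) -> list[Group]:
    """Build every group and run `after_each` on it as soon as it is ready."""

    async def forward_group(awaitable: Awaitable[Group]) -> AfterEachResult:
        group = await awaitable
        if after_each is not None:
            return await after_each(group)
        return group

    results = await asyncio.gather(*(forward_group(g) for g in groups))

    # Assumed semantics: None drops the group, a list of groups is flattened.
    processed: list[Group] = []
    for r in results:
        if r is None:
            continue
        if isinstance(r, list):
            processed.extend(r)
        else:
            processed.append(r)
    return processed


async def make_group(name: str, delay: float) -> Group:
    """Hypothetical rollout that finishes after `delay` seconds."""
    await asyncio.sleep(delay)
    return Group(name)


async def score(group: Group) -> AfterEachResult:
    """Hypothetical scorer: drops one group, scores the rest by name length."""
    if group.name == "drop-me":
        return None
    group.score = float(len(group.name))
    return group


async def demo() -> list[Group]:
    return await gather_pipelined(
        [make_group("a", 0.02), make_group("drop-me", 0.01), make_group("bcd", 0.0)],
        after_each=score,
    )


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Because `asyncio.gather` returns results in input order, the output still matches the order of `groups`, even though the groups finish (and are scored) at different times.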