openxla/xla

XLA CollectiveOpsUtil does not include kReduceScatter

While debugging channel id creation in the XLA while loop unroller, we found that HloOpcode::kReduceScatter is not among the collectives for which a new channel id is generated.

While loop unroller location:

if (IsCollectiveWithChannelId(body_inst)) {

Collective ops util location:

bool IsCollectiveWithChannelId(const HloInstruction* instruction) {

We have added kReduceScatter locally and it fixed our channel id issues.
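
For illustration, the kind of change described above can be sketched in a standalone form. This is not the XLA source: `Opcode`, `Instruction`, and `AssignFreshChannelIds` are hypothetical stand-ins for this example, and the set of opcodes listed in the switch is an assumption rather than a copy of the real predicate. It only shows how an opcode missing from the switch would be skipped when fresh channel ids are handed out.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified stand-ins for HloOpcode and HloInstruction.
enum class Opcode {
  kAdd,
  kAllReduce,
  kAllGather,
  kAllToAll,
  kCollectivePermute,
  kReduceScatter,
};

struct Instruction {
  Opcode opcode;
  std::optional<int64_t> channel_id;
};

// Returns true for collectives that already carry a channel id.
// Removing the kReduceScatter case reproduces the behavior reported here.
bool IsCollectiveWithChannelId(const Instruction& inst) {
  switch (inst.opcode) {
    case Opcode::kAllReduce:
    case Opcode::kAllGather:
    case Opcode::kAllToAll:
    case Opcode::kCollectivePermute:
    case Opcode::kReduceScatter:  // the case we added locally
      return inst.channel_id.has_value();
    default:
      return false;
  }
}

// Assumed shape of the unroller step: every matching instruction in a
// cloned loop body receives a fresh id. Returns the next unused id.
int64_t AssignFreshChannelIds(std::vector<Instruction>& body,
                              int64_t next_id) {
  for (Instruction& inst : body) {
    if (IsCollectiveWithChannelId(inst)) {
      inst.channel_id = next_id++;
    }
  }
  return next_id;
}
```

In this sketch, a reduce-scatter without a channel id is left untouched, while one that has a channel id is renumbered like the other collectives.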

Questions:

  1. Why is kReduceScatter missing from the list of collective operations whose channel id is replaced?
  2. If this is a bug, can kReduceScatter be added to the list?

Can you please check whether f584acb fixes this issue?

I think it is fixed. Please let us know if this is still a problem.