openxla/xla

[xla:auto_sharding] Question about resharding costs of Reshape strategies

Closed this issue · 4 comments

Hi OpenXLA community, I have a question about the code that generates sharding strategies and computes resharding costs for the HLO reshape op.

In the code, the reshape sharding strategy output_spec is generated (reshaped) from one of the operand's sharding strategies, src_strategy_group->strategies[sid]:

std::optional<HloSharding> output_spec =
    hlo_sharding_util::ReshapeSharding(
        operand->shape(), ins->shape(),
        src_strategy_group->strategies[sid].output_sharding);

Then the communication and memory resharding costs are computed:

std::vector<double> communication_resharding_costs =
    CommunicationReshardingCostVector(
        src_strategy_group, operand->shape(),
        src_strategy_group->strategies[sid].output_sharding, cluster_env);
std::vector<double> memory_resharding_costs = MemoryReshardingCostVector(
    src_strategy_group, operand->shape(),
    src_strategy_group->strategies[sid].output_sharding, cluster_env);
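
For context, a minimal sketch of what that cost-vector call conceptually computes (the helper below is hypothetical, not the actual XLA implementation): each entry prices resharding the operand from one of its candidate strategies to the single required sharding, so the entry for the required sharding itself comes out as zero.

// Hypothetical sketch, not the actual XLA implementation: entry i is the cost
// of resharding the operand from candidate strategy i to `required_sharding`,
// so the entry whose sharding already equals `required_sharding` is 0.
std::vector<double> SketchCommunicationReshardingCostVector(
    const StrategyGroup* src_strategy_group, const Shape& operand_shape,
    const HloSharding& required_sharding,
    const ClusterEnvironment& cluster_env) {
  std::vector<double> costs;
  costs.reserve(src_strategy_group->strategies.size());
  for (const ShardingStrategy& strategy : src_strategy_group->strategies) {
    costs.push_back(cluster_env.ReshardingCost(
        operand_shape, strategy.output_sharding, required_sharding));
  }
  return costs;
}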

It looks like the resharding costs are computed between all of the operand's sharding strategies in src_strategy_group and the single operand strategy src_strategy_group->strategies[sid] that output_spec was reshaped from. I guess this is why I got a ZERO resharding cost on a reshape where an all_gather collective communication (CC) op is actually required:

%5 = stablehlo.add %4, %cst_0 {mhlo.sharding = "{devices=[2,4,1,1]0,1,2,3,4,5,6,7}", result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f16[32,16,1000,256]{3,1,2,0}"} : tensor<32x16x1000x256xf16>
%6 = stablehlo.reshape %5 {mhlo.sharding = "{devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}"} : (tensor<32x16x1000x256xf16>) -> tensor<512x1000x256xf16>

If my understanding of the Alpa algorithm is correct, the resharding cost should be computed like this:

std::vector<double> communication_resharding_costs =
    CommunicationReshardingCostVector(
        src_strategy_group, operand->shape(),
        /* required_sharding */ *output_spec, cluster_env);
std::vector<double> memory_resharding_costs = MemoryReshardingCostVector(
    src_strategy_group, operand->shape(),
    /* required_sharding */ *output_spec, cluster_env);

Hi! The memory resharding costs and the communication resharding costs above represent the costs incurred when resharding the operand of the reshape from its chosen sharding strategy to the strategy that the reshape op requires it to have. For a given operand strategy src_strategy_group->strategies[sid], we infer the corresponding strategy for the reshape (output_spec). Thus, if output_spec is chosen as the strategy for the reshape, the reshape will require its operand to have the corresponding src_strategy_group->strategies[sid] strategy that output_spec was inferred from. So the required_sharding used when computing these costs for the reshape operand is src_strategy_group->strategies[sid]. Does that make sense?
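
To make that pairing concrete, here is a condensed, hedged paraphrase of the loop above (not the exact XLA code): each reshape strategy is derived from exactly one operand strategy sid, and its resharding-cost vector prices every candidate operand strategy against that sid, so the zero entry is only "free" if the solver also selects strategy sid for the operand.

// Hedged paraphrase of the strategy-generation loop, not the exact XLA code.
for (size_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) {
  // One reshape strategy per operand strategy: output_spec is the reshape's
  // sharding when its operand is sharded as strategies[sid].
  std::optional<HloSharding> output_spec = hlo_sharding_util::ReshapeSharding(
      operand->shape(), ins->shape(),
      src_strategy_group->strategies[sid].output_sharding);
  if (!output_spec.has_value()) continue;
  // If the solver picks this reshape strategy, the operand is required to be
  // sharded as strategies[sid]; entry j of the vector is the price of moving
  // the operand there from its candidate strategy j (0 when j == sid).
  std::vector<double> communication_resharding_costs =
      CommunicationReshardingCostVector(
          src_strategy_group, operand->shape(),
          src_strategy_group->strategies[sid].output_sharding, cluster_env);
}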

This makes sense, but look at my example:

%5 = stablehlo.add %4, %cst_0 {mhlo.sharding = "{devices=[2,4,1,1]0,1,2,3,4,5,6,7}", result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f16[32,16,1000,256]{3,1,2,0}"} : tensor<32x16x1000x256xf16>
%6 = stablehlo.reshape %5 {mhlo.sharding = "{devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}"} : (tensor<32x16x1000x256xf16>) -> tensor<512x1000x256xf16>

The reshape op actually introduces an all_gather along tensor_axis==1 of the operand and mesh_axis==1. If this is not accounted for as a resharding communication cost on the operand, it should be accounted for as the communication cost of the reshape op itself, but in the code the communication cost is always set to 0:

double compute_cost = 0, communication_cost = 0;

I guess the communication cost should be:

double communication_cost = cluster_env.ReshardingCost(
          operand->shape(), src_strategy_group->strategies[sid].output_sharding,
          *output_spec);

Hi! I ran auto-sharding on the following HLO, similar to the MHLO you provided:

HloModule module

ENTRY %entry {
  %parameter1 = f16[32,16,1000,256]{3,1,2,0} parameter(0), sharding={devices=[2,4,1,1]0,1,2,3,4,5,6,7} 
  %reshape = f16[512,1000,256] reshape(parameter1), sharding={devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
}

Given that the reshape op has a user-annotated sharding, only the following ShardingStrategy is generated for it:

Strategy S0 @ 0, {devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}, compute_cost=3.2, communication_cost=0, memory_cost=1.31072e+08, communication_resharding_costs={[9.8304e+07]}, memory_resharding_costs={[9.8304e+07]}, input_shardings={[2, 1, 1, 1, 4]last_tile_dim_replicate,}

As you can see, if the reshape has a sharding of {devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}, it requires its input operand to be sharded as {[2, 1, 1, 1, 4]last_tile_dim_replicate}. The cost of resharding the operand is taken into account as the communication_resharding_cost in the strategy above. Once the operand is resharded to {[2, 1, 1, 1, 4]last_tile_dim_replicate} (from {devices=[2,4,1,1]0,1,2,3,4,5,6,7}), no communication is required for the reshape itself. Does this make sense?
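
As a rough sanity check on those numbers (assuming a simple ring all-gather cost model of (n-1)/n times the per-device bytes after the gather; that model is an assumption here, not necessarily how ClusterEnvironment computes it):

// Hypothetical back-of-the-envelope check, not XLA code. Assumes a ring
// all-gather moves (n-1)/n of the per-device bytes after the gather.
#include <cstdio>

int main() {
  const double total_bytes = 32.0 * 16 * 1000 * 256 * 2;  // f16[32,16,1000,256]
  // Resharded to {[2,1,1,1,4] last_tile_dim_replicate}, each device holds
  // half of the tensor.
  const double per_device_bytes = total_bytes / 2;  // 1.31072e+08, matches memory_cost
  // Getting there from {devices=[2,4,1,1]} is an all-gather over the 4-way
  // mesh axis, i.e. exactly the collective the question was worried about.
  const double all_gather_bytes = (4.0 - 1.0) / 4.0 * per_device_bytes;
  std::printf("%g %g\n", per_device_bytes, all_gather_bytes);
  // Prints 1.31072e+08 9.8304e+07, lining up with communication_resharding_costs above.
  return 0;
}

So the all-gather is indeed paid for, just as a resharding cost attributed to the operand edge rather than as a communication_cost on the reshape itself.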

I see, thanks!