[xla:auto_sharding] Question about resharding costs of Reshape strategies
Closed this issue · 4 comments
Hi OpenXLA community, I have a question about the code that generates sharding strategies and computes resharding costs for the HLO reshape op.
In the code, the reshape sharding strategy output_spec is generated (reshaped) from one of the operand sharding strategies, src_strategy_group->strategies[sid]:
xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
Lines 1898 to 1901 in 421f4c4
It then computes the communication and memory resharding costs:
xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
Lines 1917 to 1923 in 421f4c4
It looks like the resharding costs are computed between all operand sharding strategies in src_strategy_group and the single operand strategy src_strategy_group->strategies[sid] that output_spec was reshaped from. I guess this is why I got a ZERO resharding cost on a reshape where an all_gather collective op is actually required:
%5 = stablehlo.add %4, %cst_0 {mhlo.sharding = "{devices=[2,4,1,1]0,1,2,3,4,5,6,7}", result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f16[32,16,1000,256]{3,1,2,0}"} : tensor<32x16x1000x256xf16>
%6 = stablehlo.reshape %5 {mhlo.sharding = "{devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}"} : (tensor<32x16x1000x256xf16>) -> tensor<512x1000x256xf16>
If my understanding of the Alpa algorithm is correct, the resharding cost should be computed as:
std::vector<double> communication_resharding_costs =
    CommunicationReshardingCostVector(src_strategy_group, operand->shape(),
                                      /*required_sharding=*/output_spec, cluster_env);
std::vector<double> memory_resharding_costs =
    MemoryReshardingCostVector(src_strategy_group, operand->shape(),
                               /*required_sharding=*/output_spec, cluster_env);
Hi! The memory resharding costs and the communication resharding costs above represent the costs incurred when resharding the operand of the reshape from its chosen sharding strategy to the strategy that the reshape op requires it to have. For a given operand strategy src_strategy_group->strategies[sid], we infer the corresponding strategy for the reshape (output_spec). Thus, if output_spec is chosen as a strategy for the reshape, the reshape will require its operand to have the corresponding src_strategy_group->strategies[sid] strategy that output_spec was inferred from. Thus the required_sharding when computing these costs for the reshape operand would be src_strategy_group->strategies[sid]. Does that make sense?
This makes sense, but look at my example:
%5 = stablehlo.add %4, %cst_0 {mhlo.sharding = "{devices=[2,4,1,1]0,1,2,3,4,5,6,7}", result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f16[32,16,1000,256]{3,1,2,0}"} : tensor<32x16x1000x256xf16>
%6 = stablehlo.reshape %5 {mhlo.sharding = "{devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}"} : (tensor<32x16x1000x256xf16>) -> tensor<512x1000x256xf16>
The reshape op actually introduces an all_gather along tensor_axis==1 of the operand and mesh_axis=1. If this is not accounted for as a resharding communication cost on the operand, it should be accounted for as the communication cost of the reshape op itself, but in the code the communication cost is always set to 0:
I guess the communication cost should be:
double communication_cost = cluster_env.ReshardingCost(
    operand->shape(), src_strategy_group->strategies[sid].output_sharding,
    *output_spec);
Hi! I ran auto-sharding on the following HLO, similar to the MHLO you provided:
HloModule module
ENTRY %entry {
%parameter1 = f16[32,16,1000,256]{3,1,2,0} parameter(0), sharding={devices=[2,4,1,1]0,1,2,3,4,5,6,7}
%reshape = f16[512,1000,256] reshape(parameter1), sharding={devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
}
Given that the reshape op has a user annotated sharding, only the following ShardingStrategy is generated for the reshape:
Strategy S0 @ 0, {devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}, compute_cost=3.2, communication_cost=0, memory_cost=1.31072e+08, communication_resharding_costs={[9.8304e+07]}, memory_resharding_costs={[9.8304e+07]}, input_shardings={[2, 1, 1, 1, 4]last_tile_dim_replicate,}
As you can see, if the reshape has a sharding of {devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}, it requires its input operand to be sharded as {[2, 1, 1, 1, 4]last_tile_dim_replicate}. The cost of resharding the operand is taken into account as the communication_resharding_cost in the strategy above. Once the operand is resharded to {[2, 1, 1, 1, 4]last_tile_dim_replicate} (from {devices=[2,4,1,1]0,1,2,3,4,5,6,7}), no communication is required for the reshape. Does this make sense?
I see, thanks!