@@ -1027,7 +1027,13 @@ def gmm(
       output = output[: orig_inputs_shape[0]]
       return output
 
-    batch_logical_axis = "activation_batch"
+    # The batch is sharded by expert, except during inference decoding (where batch size == 1).
+    # In the decoding case, the expert axis is instead replicated along the tensor's batch dimension.
+    is_batch_sharded_by_expert = inputs.shape[0] > 1
+    if is_batch_sharded_by_expert:
+      batch_logical_axis = "activation_batch"
+    else:
+      batch_logical_axis = "decode_batch_moe"
 
     if self.get_tensor_transpose_parallelism_size() > 1:
       input_partition_pspec = self._logical_to_mesh_axes(
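The new flag keys the logical batch axis off the runtime batch size: a single decode row cannot be split across expert shards, so a separate logical name lets the sharding rules replicate it instead. A minimal sketch of the idea, assuming a hypothetical logical-to-mesh mapping (MaxText's real rules are resolved by `self._logical_to_mesh_axes` from the config; the `LOGICAL_RULES` table below is invented for illustration):

```python
from jax.sharding import PartitionSpec as P

# Invented logical->mesh rules, for illustration only.
LOGICAL_RULES = {
    "activation_batch": ("data", "expert"),  # batch split across data and expert shards
    "decode_batch_moe": ("data",),           # decode: batch not split over experts
}

def batch_pspec(batch_size: int) -> P:
  # Mirrors the diff: shard the batch by expert only when it has more than one row.
  axis = "activation_batch" if batch_size > 1 else "decode_batch_moe"
  return P(LOGICAL_RULES[axis], None)

print(batch_pspec(8))  # PartitionSpec(('data', 'expert'), None)
print(batch_pspec(1))  # PartitionSpec(('data',), None)
```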
@@ -1142,47 +1148,59 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       )
 
       if num_expert_parallelism > 1:
+        batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
         # get group sizes for all shards
         local_expert_size = self.config.num_experts // num_expert_parallelism
         reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
         global_group_sizes = group_sizes
 
-        all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=self._expert_parallelism_name)
-        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-            all_shards_group_sizes,
-            expert_shard_id,
-            num_expert_parallelism,
-        )
-
-        # TODO(ranran): For better performance, we could use a smaller output buffer instead of
-        # sizing it by self.get_expert_parallelism_size(), or we could apply capacity_factor
-        # to excessive experts.
-        # Note: reducing the buffer increases the risk of token dropping under an unbalanced distribution.
-
-        # In the worst case, all of the global input data is assigned to every expert on the current
-        # shard, requiring num_expert_shards * input_size * experts_per_shard assignments. However, if
-        # experts_per_shard > num_experts_per_tok, each input gets at most num_experts_per_tok experts.
-        max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
-        buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
-        output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
-
-        x = jax.lax.ragged_all_to_all(
-            x,
-            output_shape,
-            input_offsets,
-            send_sizes,
-            output_offsets,
-            recv_sizes,
-            axis_name=self._expert_parallelism_name,
-        )
-        global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
-        x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
-            x,
-            global_group_sizes,
-            local_expert_size,
-            shard_index=expert_shard_id,
-            use_custom_sort_vjp=self.config.use_custom_sort_vjp,
-        )
+        if is_batch_sharded_by_expert:
+          all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
+          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+              all_shards_group_sizes,
+              expert_shard_id,
+              num_expert_parallelism,
+          )
+
+          # TODO(ranran): For better performance, we could use a smaller output buffer instead of
+          # sizing it by self.get_expert_parallelism_size(), or we could apply capacity_factor
+          # to excessive experts.
+          # Note: reducing the buffer increases the risk of token dropping under an unbalanced distribution.
+
+          # In the worst case, all of the global input data is assigned to every expert on the current
+          # shard, requiring num_expert_shards * input_size * experts_per_shard assignments. However, if
+          # experts_per_shard > num_experts_per_tok, each input gets at most num_experts_per_tok experts.
+          max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
+          buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
+          output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+
+          x = jax.lax.ragged_all_to_all(
+              x,
+              output_shape,
+              input_offsets,
+              send_sizes,
+              output_offsets,
+              recv_sizes,
+              axis_name=self._expert_parallelism_name,
+          )
+          global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
+          x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
+              x,
+              global_group_sizes,
+              local_expert_size,
+              shard_index=expert_shard_id,
+              use_custom_sort_vjp=self.config.use_custom_sort_vjp,
+          )
+        else:
+          x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
+              x,
+              global_group_sizes[None, :],
+              local_expert_size,
+              shard_index=expert_shard_id,
+              is_offset=True,
+              global_sorted_experts=selected_experts,
+              use_custom_sort_vjp=self.config.use_custom_sort_vjp,
+          )
 
       if self.config.mlp_bias:
         w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
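For readers new to `jax.lax.ragged_all_to_all`, the sketch below reconstructs the kind of bookkeeping `get_all_to_all_params` performs. This is an illustrative guess at its semantics, not the class's actual implementation: given a matrix whose entry `[s, d]` is the number of tokens shard `s` holds for the experts owned by shard `d`, each shard derives its four per-destination vectors from its own row (sends) and its own column (receives).

```python
import numpy as np

def sketch_all_to_all_params(all_shards_group_sizes, shard_id):
  """Illustrative stand-in for RoutedMoE.get_all_to_all_params (semantics assumed).

  all_shards_group_sizes[s, d]: tokens shard s holds for experts owned by shard d.
  """
  sizes = np.asarray(all_shards_group_sizes)
  send_sizes = sizes[shard_id]       # our row: what we send to each destination
  input_offsets = np.concatenate(([0], np.cumsum(send_sizes)[:-1]))
  recv_sizes = sizes[:, shard_id]    # our column: what each source sends us
  # Our tokens land in shard d's buffer after those of every shard s < shard_id.
  prefix = np.vstack([np.zeros(sizes.shape[1], dtype=sizes.dtype),
                      np.cumsum(sizes, axis=0)[:-1]])
  output_offsets = prefix[shard_id]
  return input_offsets, send_sizes, output_offsets, recv_sizes

# Two expert shards; shard 0 holds 3 tokens for shard 0's experts and 1 for shard 1's, etc.
sizes = np.array([[3, 1],
                  [2, 4]])
print(sketch_all_to_all_params(sizes, 0))  # input_offsets=[0 3] send=[3 1] output_offsets=[0 0] recv=[3 2]
print(sketch_all_to_all_params(sizes, 1))  # input_offsets=[0 2] send=[2 4] output_offsets=[3 1] recv=[1 4]
```

Note how the two shards' outputs agree: shard 1 sends 2 tokens to shard 0 at offset 3 of its buffer, exactly after the 3 tokens shard 0 keeps for itself.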
@@ -1325,26 +1343,47 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
             dtype=intermediate_output.dtype,
         )
 
-        # locally unpermute back to the original order
-        local_output = _sort_activations(
-            intermediate_output,
-            jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
-            self.config.use_custom_sort_vjp,
-        )
-        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-            jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
-            expert_shard_id,
-            num_expert_parallelism,
-        )
-        intermediate_output = jax.lax.ragged_all_to_all(
-            local_output,
-            output_shape,
-            input_offsets,
-            send_sizes,
-            output_offsets,
-            recv_sizes,
-            axis_name=self._expert_parallelism_name,
-        )
+        if is_batch_sharded_by_expert:
+          # locally unpermute back to the original order
+          local_output = _sort_activations(
+              intermediate_output,
+              jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
+              self.config.use_custom_sort_vjp,
+          )
+          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+              jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
+              expert_shard_id,
+              num_expert_parallelism,
+          )
+          intermediate_output = jax.lax.ragged_all_to_all(
+              local_output,
+              output_shape,
+              input_offsets,
+              send_sizes,
+              output_offsets,
+              recv_sizes,
+              axis_name=self._expert_parallelism_name,
+          )
+        else:
+          # If the batch is replicated across EP shards, then each shard
+          # sends 0..local_shard_size of its data to the other shards and
+          # receives the local_shard data from all of the other shards
+          # using ragged_all_to_all.
+          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+              reshaped_group_sizes,  # pylint: disable=undefined-variable
+              expert_shard_id,
+              num_expert_parallelism,
+              is_batch_sharded=False,
+          )
+          intermediate_output = jax.lax.ragged_all_to_all(
+              intermediate_output,
+              output_shape,
+              input_offsets,
+              send_sizes,
+              output_offsets,
+              recv_sizes,
+              axis_name=self._expert_parallelism_name,
+          )
 
       output = self.unpermute(
           intermediate_output,
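Two details in this return path are easy to miss. First, transposing `all_shards_group_sizes` swaps each shard's send row with its receive column, so the same parameter helper drives the reverse shuffle. Second, the local unpermute relies on the standard inverse-permutation identity: gathering with `argsort(sort_indices)` undoes the earlier sort. A tiny self-contained check of that identity (the token values and permutation are made up):

```python
import jax.numpy as jnp

tokens = jnp.arange(6) * 10                   # stand-in activations
sort_indices = jnp.array([4, 2, 0, 5, 1, 3])  # a routing permutation (made up)

permuted = tokens[sort_indices]                 # forward sort, as in the permute step
restored = permuted[jnp.argsort(sort_indices)]  # inverse permutation
assert bool((restored == tokens).all())
print(restored)  # [ 0 10 20 30 40 50]
```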