Skip to content

Commit 50ba0a1

Browse files
Merge pull request #3643 from AI-Hypercomputer:chengnuojin-fix-moe
PiperOrigin-RevId: 899204178
2 parents 3385aa0 + c57e4ec commit 50ba0a1

6 files changed

Lines changed: 133 additions & 57 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ logical_axis_rules: [
488488
['activation_stage', 'stage'],
489489
['activation_exp', ['expert']],
490490
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
491+
['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
491492
['decode_length', ['sequence']],
492493
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
493494
['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],

src/maxtext/configs/inference/inference.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ logical_axis_rules: [
2424
['activation_stage', 'stage'],
2525
['activation_exp', ['expert', 'context_autoregressive']],
2626
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
27+
['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'context_autoregressive']],
2728
['decode_length', []],
2829
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
2930
['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],

src/maxtext/configs/inference/vllm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ logical_axis_rules: [
5454
['activation_norm_length_moe', []],
5555
['activation_exp', ['expert', 'attn_dp_expert']],
5656
['decode_batch', ['expert', 'attn_dp_expert']],
57+
['decode_batch_moe', []],
5758
['decode_length', []],
5859
['mlp', ['model', 'attn_dp']],
5960
['mlp_moe', ['model', 'attn_dp']],

src/maxtext/configs/post_train/rl_mt_jt.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ logical_axis_rules: [
3939
['activation_stage', 'stage'],
4040
['activation_exp', ['expert', 'context_autoregressive']],
4141
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
42+
['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
4243
['decode_length', []],
4344
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
4445
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],

src/maxtext/layers/moe.py

Lines changed: 96 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,13 @@ def gmm(
10271027
output = output[: orig_inputs_shape[0]]
10281028
return output
10291029

1030-
batch_logical_axis = "activation_batch"
1030+
# The batch is sharded by expert, except during inference decoding (where batch size == 1).
1031+
# In the decoding case, the expert axis is instead replicated along the tensor's batch dimension.
1032+
is_batch_sharded_by_expert = inputs.shape[0] > 1
1033+
if is_batch_sharded_by_expert:
1034+
batch_logical_axis = "activation_batch"
1035+
else:
1036+
batch_logical_axis = "decode_batch_moe"
10311037

10321038
if self.get_tensor_transpose_parallelism_size() > 1:
10331039
input_partition_pspec = self._logical_to_mesh_axes(
@@ -1142,47 +1148,59 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11421148
)
11431149

11441150
if num_expert_parallelism > 1:
1151+
batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
11451152
# get group sizes for all shards
11461153
local_expert_size = self.config.num_experts // num_expert_parallelism
11471154
reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
11481155
global_group_sizes = group_sizes
11491156

1150-
all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=self._expert_parallelism_name)
1151-
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1152-
all_shards_group_sizes,
1153-
expert_shard_id,
1154-
num_expert_parallelism,
1155-
)
1156-
1157-
# TODO(ranran): For better performance, we could update output buffer to a smaller
1158-
# size to replace self.get_expert_parallelism_size() for efficiency,
1159-
# Or we could apply capacity_factor for excessive experts.
1160-
# Note: Reducing buffer increase the risk of token dropping under unbalanced distribution.
1161-
1162-
# In the worst case, all of the global input data is assigned to each expert in the current shard.
1163-
# This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
1164-
# experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
1165-
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
1166-
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
1167-
output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
1168-
1169-
x = jax.lax.ragged_all_to_all(
1170-
x,
1171-
output_shape,
1172-
input_offsets,
1173-
send_sizes,
1174-
output_offsets,
1175-
recv_sizes,
1176-
axis_name=self._expert_parallelism_name,
1177-
)
1178-
global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
1179-
x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
1180-
x,
1181-
global_group_sizes,
1182-
local_expert_size,
1183-
shard_index=expert_shard_id,
1184-
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
1185-
)
1157+
if is_batch_sharded_by_expert:
1158+
all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
1159+
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1160+
all_shards_group_sizes,
1161+
expert_shard_id,
1162+
num_expert_parallelism,
1163+
)
1164+
1165+
# TODO(ranran): For better performance, we could update output buffer to a smaller
1166+
# size to replace self.get_expert_parallelism_size() for efficiency,
1167+
# Or we could apply capacity_factor for excessive experts.
1168+
# Note: Reducing buffer increase the risk of token dropping under unbalanced distribution.
1169+
1170+
# In the worst case, all of the global input data is assigned to each expert in the current shard.
1171+
# This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
1172+
# experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
1173+
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
1174+
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
1175+
output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
1176+
1177+
x = jax.lax.ragged_all_to_all(
1178+
x,
1179+
output_shape,
1180+
input_offsets,
1181+
send_sizes,
1182+
output_offsets,
1183+
recv_sizes,
1184+
axis_name=self._expert_parallelism_name,
1185+
)
1186+
global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
1187+
x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
1188+
x,
1189+
global_group_sizes,
1190+
local_expert_size,
1191+
shard_index=expert_shard_id,
1192+
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
1193+
)
1194+
else:
1195+
x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
1196+
x,
1197+
global_group_sizes[None, :],
1198+
local_expert_size,
1199+
shard_index=expert_shard_id,
1200+
is_offset=True,
1201+
global_sorted_experts=selected_experts,
1202+
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
1203+
)
11861204

11871205
if self.config.mlp_bias:
11881206
w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
@@ -1325,26 +1343,47 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13251343
dtype=intermediate_output.dtype,
13261344
)
13271345

1328-
# locally unpermute back to the original order
1329-
local_output = _sort_activations(
1330-
intermediate_output,
1331-
jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable
1332-
self.config.use_custom_sort_vjp,
1333-
)
1334-
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1335-
jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable
1336-
expert_shard_id,
1337-
num_expert_parallelism,
1338-
)
1339-
intermediate_output = jax.lax.ragged_all_to_all(
1340-
local_output,
1341-
output_shape,
1342-
input_offsets,
1343-
send_sizes,
1344-
output_offsets,
1345-
recv_sizes,
1346-
axis_name=self._expert_parallelism_name,
1347-
)
1346+
if is_batch_sharded_by_expert:
1347+
# locally unpermute back to the original order
1348+
local_output = _sort_activations(
1349+
intermediate_output,
1350+
jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable
1351+
self.config.use_custom_sort_vjp,
1352+
)
1353+
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1354+
jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable
1355+
expert_shard_id,
1356+
num_expert_parallelism,
1357+
)
1358+
intermediate_output = jax.lax.ragged_all_to_all(
1359+
local_output,
1360+
output_shape,
1361+
input_offsets,
1362+
send_sizes,
1363+
output_offsets,
1364+
recv_sizes,
1365+
axis_name=self._expert_parallelism_name,
1366+
)
1367+
else:
1368+
# If batch is replicated across EP shards then each shard should send
1369+
# 0..local_shard_size data to the other shards and receive the
1370+
# local_shard data from all of the other shards using
1371+
# ragged_all_to_all.
1372+
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
1373+
reshaped_group_sizes, # pylint: disable=undefined-variable
1374+
expert_shard_id,
1375+
num_expert_parallelism,
1376+
is_batch_sharded=False,
1377+
)
1378+
intermediate_output = jax.lax.ragged_all_to_all(
1379+
intermediate_output,
1380+
output_shape,
1381+
input_offsets,
1382+
send_sizes,
1383+
output_offsets,
1384+
recv_sizes,
1385+
axis_name=self._expert_parallelism_name,
1386+
)
13481387

13491388
output = self.unpermute(
13501389
intermediate_output,

tests/integration/decode_tests.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,34 @@ class DecodeTests(unittest.TestCase):
9696
"prompt=I love to",
9797
"skip_jax_distributed_system=True",
9898
],
99+
"deepseek32": [ # tests decode for deepseek3.2-671b full EP
100+
None,
101+
get_test_config_path(),
102+
"base_output_directory=gs://runner-maxtext-logs",
103+
"run_name=decode",
104+
"model_name=deepseek3.2-671b",
105+
"override_model_config=True",
106+
"base_num_decoder_layers=2",
107+
"first_num_dense_layers=1",
108+
"num_experts=16",
109+
"base_mlp_dim=128",
110+
"base_emb_dim=128",
111+
"base_moe_mlp_dim=128",
112+
"tokenizer_type=huggingface",
113+
f"hf_access_token={os.environ.get('HF_TOKEN', '')}",
114+
"tokenizer_path=deepseek-ai/DeepSeek-V3.2-Exp",
115+
"scan_layers=False",
116+
"attention=dot_product",
117+
"weight_dtype=bfloat16",
118+
"per_device_batch_size=1",
119+
"max_prefill_predict_length=8",
120+
"max_target_length=16",
121+
"ici_fsdp_parallelism=1",
122+
"ici_tensor_parallelism=1",
123+
"ici_expert_parallelism=-1",
124+
"mla_naive_kvcache=false",
125+
"prompt=I love to",
126+
],
99127
}
100128
SAMPLING_STRATEGY_CONFIG = {
101129
"greedy": [
@@ -173,6 +201,11 @@ def test_decode_topk_sampling(self):
173201
expected_output = "Input `I love to` -> ` travel and I love to write"
174202
assert expected_output in captured_out
175203

204+
@pytest.mark.tpu_only
205+
@pytest.mark.scheduled_only
206+
def test_tpu_deepseek32(self):
207+
decode_main(DecodeTests.CONFIGS["deepseek32"])
208+
176209

177210
def run_decoding(config):
178211
f = io.StringIO()

0 commit comments

Comments (0)