Skip to content

Commit 640eecb

Browse files
committed
refactor dump sharding test for customizing mesh and rules, and more shardings
1 parent 87335ad commit 640eecb

File tree

72 files changed

+5121
-65412
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

72 files changed

+5121
-65412
lines changed

docs/reference/core_concepts/moe_configuration.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ Dropping:
9696

9797
## 2. Sharding
9898

99+
`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
100+
101+
- `fsdp`: Treats the expert axis as an FSDP axis.
102+
- `context`: Treats the expert axis as a context parallelism axis, useful for long context.
103+
99104
`use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication.
100105

101106
`moe_fsdp_use_two_stage_all_gather`: If enabled, split the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.

src/maxtext/common/common_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@
6666
MODEL_MODE_PREFILL = "prefill"
6767
MODEL_MODE_TRAIN = "train"
6868

69+
# expert_shard_attention_option
70+
EP_AS_CONTEXT = "context"
71+
EP_AS_FSDP = "fsdp"
72+
6973
DECODING_ACTIVE_SEQUENCE_INDICATOR = 1
7074

7175
# A large negative mask value is used for masking to ensure that the

src/maxtext/configs/base.yml

Lines changed: 61 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ merge_gating_gmm: False
237237

238238
norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.
239239

240+
# how the expert axis is used to shard attention weights and activations
241+
# "fsdp" (ep acts as fsdp parallelism)
242+
# "context" (ep acts as context parallelism, training only)
243+
expert_shard_attention_option: "fsdp"
244+
240245
# when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
241246
moe_fsdp_use_two_stage_all_gather: false
242247
# Shard the expert dimension of the MLP weights on the FSDP axis.
@@ -448,119 +453,92 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
448453
shard_mode: "auto" # can be either auto or explicit
449454
custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
450455
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
451-
logical_axis_rules: [
452-
# ==========================================
453-
# Vocabulary Embedding
454-
# ==========================================
455-
# Vocab Activations
456+
logical_axis_rules: [
457+
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
458+
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
456459
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
457460
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
458-
['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
459-
['activation_vocab', ['tensor', 'tensor_transpose']],
460-
['activation_vocab', 'tensor_sequence'],
461-
['activation_vocab', ['sequence', 'context']],
462-
# Vocab Weights
463-
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
464-
['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
465-
# ==========================================
466-
# Attention
467-
# ==========================================
468-
# Attention Activations
469-
['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
470-
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
461+
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
462+
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
463+
['activation_length', ['sequence', 'context']],
464+
['activation_length', ['context']],
471465
['activation_attn_length', ['sequence', 'context']],
472-
# ['activation_attn_length', ['context']],
466+
['activation_attn_length', ['context']],
467+
['activation_length_moe', ['sequence', 'context']],
468+
['activation_length_moe', ['context']],
469+
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
470+
['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
473471
['activation_q_length', ['context']],
472+
['prefill_activation_length', ['sequence', 'context']],
473+
['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
474474
['activation_kv_length', []],
475475
['activation_attn_embed', ['tensor', 'tensor_transpose']],
476+
['activation_embed', ['tensor', 'tensor_transpose']],
477+
['activation_embed_moe', ['tensor', 'tensor_transpose']],
478+
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
479+
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
476480
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
481+
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
477482
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
478483
['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
479-
# Attention Weights
484+
['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
485+
['activation_vocab', ['tensor', 'tensor_transpose']],
486+
['activation_vocab', 'tensor_sequence'],
487+
['activation_vocab', ['sequence','context']],
488+
['activation_stage', 'stage'],
489+
['activation_exp', ['expert']],
490+
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
491+
['decode_length', ['sequence']],
492+
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
493+
['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
494+
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
495+
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
480496
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
481497
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
482498
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
483-
['qkv', []],
484-
['kv', []],
485-
['kv_head_dim', []],
499+
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
500+
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
501+
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
502+
['embed', ['fsdp', 'sequence', 'context', 'expert']],
503+
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
504+
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
505+
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
506+
['embed_moe', ['fsdp', 'sequence', 'context']],
507+
['embed_tensor_transpose', ['tensor_transpose']],
486508
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
487509
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
488510
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
489511
['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
490-
["q_lora_up_proj", []],
512+
["q_lora_up_proj",[]],
491513
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
492514
['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
493515
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
494516
['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
495-
["kv_lora_up_proj", []],
496-
# ==========================================
497-
# Mixture of Experts (MoE)
498-
# ==========================================
499-
# MoE Activations
500-
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
501-
['activation_length_moe', ['sequence', 'context']],
502-
# ['activation_length_moe', ['context']],
503-
['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
504-
['activation_embed_moe', ['tensor', 'tensor_transpose']],
505-
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
506-
['activation_exp', ['expert']],
507-
# MoE Weights
508-
['exp', 'expert'],
509-
['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
510-
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
511-
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
512-
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
513-
['embed_moe', ['fsdp', 'sequence', 'context']],
514-
# ==========================================
515-
# Standard MLP / Dense Layers / Model Structure
516-
# ==========================================
517-
# Dense Activations
518-
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
519-
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
520-
['activation_length', ['sequence', 'context']],
521-
# ['activation_length', ['context']],
522-
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
523-
['activation_embed', ['tensor', 'tensor_transpose']],
524-
['activation_stage', 'stage'],
525-
# General Weights
526-
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
527-
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
528-
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
529-
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
530-
['embed', ['fsdp', 'sequence', 'context', 'expert']],
517+
["kv_lora_up_proj",[]],
531518
['norm', ['tensor', 'tensor_transpose']],
532519
['layers', 'stage'],
533-
['diloco', 'diloco'],
534-
['engram_dim', ['tensor']],
535-
['dense_layers', []],
536-
['moe_layers', []],
537-
['mhc', []],
538-
# ==========================================
539-
# Inference(Prefill, Decode, Cache)
540-
# ==========================================
541-
['prefill_activation_length', ['sequence', 'context']],
542-
['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
543-
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
544-
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
545-
['decode_length', ['sequence']],
546-
['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
547-
['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
548-
['paged_kv_heads', ['tensor']],
520+
['qkv', []],
521+
['kv', []],
522+
['kv_head_dim', []],
549523
['cache_batch_prefill', []],
550524
['cache_batch', []],
551525
['cache_heads_none', []],
526+
['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
527+
['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
552528
['cache_kv', []],
553529
['cache_sequence', []],
530+
['exp', 'expert'],
531+
['exp_with_fsdp', 'fsdp'],
532+
['paged_kv_heads', ['tensor']],
554533
['num_pages', []],
555534
['tokens_per_page', []],
556535
['paged_kv_head_dim_size', []],
557-
# ==========================================
558-
# Deprecated / Scheduled for Removal
559-
# ==========================================
560-
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
561-
['embed_tensor_transpose', ['tensor_transpose']],
562-
['exp_with_fsdp', 'fsdp'],
563-
]
536+
['dense_layers', []],
537+
['moe_layers', []],
538+
['engram_dim', ['tensor']],
539+
['mhc', []],
540+
['diloco', 'diloco'],
541+
]
564542
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
565543
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
566544
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']

src/maxtext/configs/inference/inference.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ logical_axis_rules: [
2828
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
2929
['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
3030
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
31-
['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
3231
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
3332
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
3433
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],

src/maxtext/configs/post_train/rl_mt_jt.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ logical_axis_rules: [
4242
['decode_length', []],
4343
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
4444
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
45-
['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
4645
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
4746
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
4847
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],

src/maxtext/configs/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,10 @@ class MoEGeneral(BaseModel):
661661
)
662662
use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
663663
interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
664+
expert_shard_attention_option: Literal["fsdp", "context"] = Field(
665+
"fsdp",
666+
description="How the expert axis is used to shard attention weights and activations.",
667+
)
664668
moe_fsdp_use_two_stage_all_gather: bool = Field(
665669
False,
666670
description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose.",
@@ -2392,6 +2396,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
23922396
self.tensors_to_offload = [t for t in tensors if getattr(self, t) == "offload"]
23932397

23942398
cp_size = self.ici_context_parallelism * self.dcn_context_parallelism
2399+
if self.expert_shard_attention_option == "context":
2400+
cp_size *= self.ici_expert_parallelism * self.dcn_expert_parallelism
23952401
self.context_parallel_size = cp_size
23962402
if self.pipeline_parallel_layers == -1:
23972403
if self.decoder_block == DecoderBlockType.DEEPSEEK:

src/maxtext/layers/attention_mla.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
D_KV,
4949
DType,
5050
EMBED,
51+
EP_AS_CONTEXT,
5152
HEAD,
5253
Q_LORA_UP_PROJ,
5354
KV_BATCH,
@@ -900,6 +901,9 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
900901
if model_mode == MODEL_MODE_PREFILL:
901902
key_logical_name = self.prefill_key_axis_names
902903
value_logical_name = self.prefill_value_axis_names
904+
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
905+
key_logical_name = self.ep_key_axis_names
906+
value_logical_name = self.ep_value_axis_names
903907
else:
904908
key_logical_name = self.key_axis_names
905909
value_logical_name = self.value_axis_names
@@ -1220,7 +1224,10 @@ def __call__(
12201224
)
12211225

12221226
out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
1223-
out = self._maybe_shard_with_logical(out, self.out_axis_names)
1227+
if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
1228+
out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
1229+
else:
1230+
out = self._maybe_shard_with_logical(out, self.out_axis_names)
12241231

12251232
out_sharding = create_sharding(self.mesh, out_logical_name)
12261233
out = self.out_projection(out, out_sharding=out_sharding)

src/maxtext/layers/attention_op.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
DEFAULT_MASK_VALUE,
5656
DType,
5757
D_KV,
58+
EP_AS_FSDP,
5859
HEAD,
5960
KV_LENGTH,
6061
LENGTH,
@@ -1269,7 +1270,7 @@ def wrap_splash_kernel(single_head_mask):
12691270

12701271
splash_kernel = wrap_splash_kernel(single_head_mask)
12711272
segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
1272-
elif self.config.use_jax_splash:
1273+
elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP:
12731274
if self.config.use_max_logit_estimate > 0:
12741275
sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
12751276
segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,))

src/maxtext/layers/decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
736736
out_features_shape=cfg.vocab_size,
737737
weight_dtype=cfg.weight_dtype,
738738
dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability
739-
kernel_axes=("embed_vocab", "vocab"),
739+
kernel_axes=("embed", "vocab"),
740740
shard_mode=cfg.shard_mode,
741741
name="logits_dense",
742742
matmul_precision=self.config.matmul_precision,

0 commit comments

Comments
 (0)