@@ -237,6 +237,11 @@ merge_gating_gmm: False
237237
238238norm_topk_prob : false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.
239239
240+ # how the expert axis is used to shard attention weights and activations
241+ # "fsdp" (ep acts as fsdp parallelism)
242+ # "context" (ep acts as context parallelism, training only)
243+ expert_shard_attention_option : " fsdp"
244+
240245# when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
241246moe_fsdp_use_two_stage_all_gather : false
242247# Shard the expert dimension of the MLP weights on the FSDP axis.
@@ -448,119 +453,92 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
448453shard_mode : " auto" # can be either auto or explicit
449454custom_mesh_and_rule : " " # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
450455mesh_axes : ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
451- logical_axis_rules : [
452- # ==========================================
453- # Vocabulary Embedding
454- # ==========================================
455- # Vocab Activations
456+ logical_axis_rules : [
457+ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
458+ ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
456459 ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
457460 ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
458- ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
459- ['activation_vocab', ['tensor', 'tensor_transpose']],
460- ['activation_vocab', 'tensor_sequence'],
461- ['activation_vocab', ['sequence', 'context']],
462- # Vocab Weights
463- ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
464- ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
465- # ==========================================
466- # Attention
467- # ==========================================
468- # Attention Activations
469- ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
470- ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
461+ ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
462+ ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
463+ ['activation_length', ['sequence', 'context']],
464+ ['activation_length', ['context']],
471465 ['activation_attn_length', ['sequence', 'context']],
472- # ['activation_attn_length', ['context']],
466+ ['activation_attn_length', ['context']],
467+ ['activation_length_moe', ['sequence', 'context']],
468+ ['activation_length_moe', ['context']],
469+ ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
470+ ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
473471 ['activation_q_length', ['context']],
472+ ['prefill_activation_length', ['sequence', 'context']],
473+ ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
474474 ['activation_kv_length', []],
475475 ['activation_attn_embed', ['tensor', 'tensor_transpose']],
476+ ['activation_embed', ['tensor', 'tensor_transpose']],
477+ ['activation_embed_moe', ['tensor', 'tensor_transpose']],
478+ ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
479+ ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
476480 ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
481+ ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
477482 ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
478483 ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
479- # Attention Weights
484+ ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
485+ ['activation_vocab', ['tensor', 'tensor_transpose']],
486+ ['activation_vocab', 'tensor_sequence'],
487+ ['activation_vocab', ['sequence', 'context']],
488+ ['activation_stage', 'stage'],
489+ ['activation_exp', ['expert']],
490+ ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
491+ ['decode_length', ['sequence']],
492+ ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
493+ ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
494+ ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
495+ ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
480496 ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
481497 ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
482498 ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
483- ['qkv', []],
484- ['kv', []],
485- ['kv_head_dim', []],
499+ ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
500+ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
501+ ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
502+ ['embed', ['fsdp', 'sequence', 'context', 'expert']],
503+ ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
504+ ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
505+ ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
506+ ['embed_moe', ['fsdp', 'sequence', 'context']],
507+ ['embed_tensor_transpose', ['tensor_transpose']],
486508 ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
487509 ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
488510 ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
489511 ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
490- ["q_lora_up_proj", []],
512+ ["q_lora_up_proj", []],
491513 ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
492514 ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
493515 ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
494516 ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
495- ["kv_lora_up_proj", []],
496- # ==========================================
497- # Mixture of Experts (MoE)
498- # ==========================================
499- # MoE Activations
500- ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
501- ['activation_length_moe', ['sequence', 'context']],
502- # ['activation_length_moe', ['context']],
503- ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
504- ['activation_embed_moe', ['tensor', 'tensor_transpose']],
505- ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
506- ['activation_exp', ['expert']],
507- # MoE Weights
508- ['exp', 'expert'],
509- ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
510- ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
511- ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
512- ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
513- ['embed_moe', ['fsdp', 'sequence', 'context']],
514- # ==========================================
515- # Standard MLP / Dense Layers / Model Structure
516- # ==========================================
517- # Dense Activations
518- ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
519- ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
520- ['activation_length', ['sequence', 'context']],
521- # ['activation_length', ['context']],
522- ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
523- ['activation_embed', ['tensor', 'tensor_transpose']],
524- ['activation_stage', 'stage'],
525- # General Weights
526- ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
527- ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
528- ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
529- ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
530- ['embed', ['fsdp', 'sequence', 'context', 'expert']],
517+ ["kv_lora_up_proj", []],
531518 ['norm', ['tensor', 'tensor_transpose']],
532519 ['layers', 'stage'],
533- ['diloco', 'diloco'],
534- ['engram_dim', ['tensor']],
535- ['dense_layers', []],
536- ['moe_layers', []],
537- ['mhc', []],
538- # ==========================================
539- # Inference(Prefill, Decode, Cache)
540- # ==========================================
541- ['prefill_activation_length', ['sequence', 'context']],
542- ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
543- ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
544- ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
545- ['decode_length', ['sequence']],
546- ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
547- ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
548- ['paged_kv_heads', ['tensor']],
520+ ['qkv', []],
521+ ['kv', []],
522+ ['kv_head_dim', []],
549523 ['cache_batch_prefill', []],
550524 ['cache_batch', []],
551525 ['cache_heads_none', []],
526+ ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
527+ ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
552528 ['cache_kv', []],
553529 ['cache_sequence', []],
530+ ['exp', 'expert'],
531+ ['exp_with_fsdp', 'fsdp'],
532+ ['paged_kv_heads', ['tensor']],
554533 ['num_pages', []],
555534 ['tokens_per_page', []],
556535 ['paged_kv_head_dim_size', []],
557- # ==========================================
558- # Deprecated / Scheduled for Removal
559- # ==========================================
560- ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
561- ['embed_tensor_transpose', ['tensor_transpose']],
562- ['exp_with_fsdp', 'fsdp'],
563- ]
536+ ['dense_layers', []],
537+ ['moe_layers', []],
538+ ['engram_dim', ['tensor']],
539+ ['mhc', []],
540+ ['diloco', 'diloco'],
541+ ]
564542# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
565543data_sharding : [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
566544input_data_sharding_logical_axes : ['activation_embed_and_logits_batch', 'activation_norm_length']
0 commit comments