feat(archon): add moe_router_dtype config for FP32 router gate GEMM #1009
garrett4wade merged 1 commit into main
Conversation
Summary of Changes (Gemini Code Assist): This pull request improves the numerical stability of Mixture-of-Experts (MoE) models by enabling the router gate's core matrix multiplication to run in FP32 precision. This is achieved through a custom autograd function that manages data types across the forward and backward passes, balancing stability and memory usage. The PR also consolidates the MoE testing infrastructure for more comprehensive coverage and improved reliability.
Activity
Code Review
This pull request introduces a configurable moe_router_dtype to enable FP32 precision for the MoE router gate GEMM, which is a valuable feature for improving numerical stability. The implementation correctly uses a custom torch.autograd.Function to maintain memory efficiency by keeping gradients in the original data type. The changes are well-structured, and the addition of consolidated and comprehensive tests, including for torch.compile compatibility, is excellent. I have one minor suggestion to improve a docstring for clarity. Overall, this is a solid contribution.
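For context, here is a toy comparison of the precision gap the FP32 override addresses when the gate GEMM spans a large expert count. This is not code from the PR; the shapes and weight scale are made up for illustration:

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes: 4096-dim hidden states routed over 256 experts.
hidden = torch.randn(8, 4096, dtype=torch.bfloat16)
gate_w = torch.randn(256, 4096, dtype=torch.bfloat16) * 0.02

# BF16 GEMM: only ~8 mantissa bits, so the accumulated router logits lose precision.
logits_bf16 = hidden @ gate_w.t()

# FP32 GEMM: upcast the operands first, as the router override does.
logits_fp32 = hidden.float() @ gate_w.float().t()

# The per-logit error directly perturbs the softmax / top-k expert selection.
print((logits_bf16.float() - logits_fp32).abs().max())
```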
Add configurable FP32 precision for MoE router gate GEMM to improve numerical stability with large expert counts, using a Megatron-Core-style custom torch.autograd.Function.

Key changes:
- Add moe_router_dtype field to ArchonEngineConfig (default "fp32")
- Add router_dtype field to MoEArgs dataclass
- Implement RouterGatingLinearFunction with FP32 forward/backward
- Thread config from ArchonEngineConfig through to TokenChoiceTopKRouter
- None means no override (use model dtype); "fp32" runs gate GEMM in float32
- Consolidate test_moe_args.py and test_router_fp32.py into test_moe_common.py
Description
Add configurable FP32 precision for the MoE router gate GEMM to improve numerical stability with large expert counts. Uses a Megatron-Core-style custom torch.autograd.Function that computes the gate linear in FP32 while keeping gradients in the original dtype (BF16) for memory efficiency.
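A minimal sketch of that pattern follows. The class and wrapper names mirror the PR, but the body is a reconstruction under assumed shapes (2-D token-major input), not the actual diff:

```python
import torch


class RouterGatingLinearFunction(torch.autograd.Function):
    """Gate GEMM in router_dtype (FP32); grads returned in the inputs' original dtype."""

    @staticmethod
    def forward(ctx, inp, weight, router_dtype):
        # Save the low-precision tensors so saved-activation memory stays in BF16.
        ctx.save_for_backward(inp, weight)
        ctx.router_dtype = router_dtype
        # Upcast only for the GEMM itself; the output logits stay in FP32.
        return torch.nn.functional.linear(inp.to(router_dtype), weight.to(router_dtype))

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        # Compute grads in FP32 for accuracy, then cast back to the inputs'
        # dtype (e.g. BF16) so gradient memory does not grow.
        grad_input = (grad_output @ weight.to(ctx.router_dtype)).to(inp.dtype)
        grad_weight = (grad_output.t() @ inp.to(ctx.router_dtype)).to(weight.dtype)
        return grad_input, grad_weight, None


def router_gating_linear(inp, weight, router_dtype=torch.float32):
    # Thin wrapper so call sites (and torch.compile fullgraph tracing) stay simple.
    return RouterGatingLinearFunction.apply(inp, weight, router_dtype)


# Usage: FP32 logits out, BF16 gradients back.
x = torch.randn(16, 1024, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(64, 1024, dtype=torch.bfloat16, requires_grad=True)
logits = router_gating_linear(x, w)  # torch.float32
logits.sum().backward()              # x.grad and w.grad stay torch.bfloat16
```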
Also consolidates test_moe_args.py and test_router_fp32.py into a single test_moe_common.py with proper triton skip handling, so MoEArgs tests are not accidentally skipped on machines without triton.

Related Issue
N/A
Type of Change
Checklist
Breaking Change Details (if applicable):
N/A
Additional Context
Key changes:
- ArchonEngineConfig.moe_router_dtype: new field (default "fp32", accepts None / "fp32"); see the dtype-resolution sketch after this list
- MoEArgs.router_dtype: new field threaded from engine config to router
- RouterGatingLinearFunction: custom autograd function with FP32 forward and BF16 grads in backward
- router_gating_linear: convenience wrapper, torch.compile fullgraph compatible
- test_moe_common.py: consolidated MoE test file with a try/except triton guard instead of a module-level importorskip
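For illustration, the None / "fp32" semantics above amount to a small resolution step. This helper is hypothetical (the PR's actual logic may live in the config validation or the engine), but it shows the intended behavior:

```python
from typing import Optional

import torch


def resolve_router_dtype(moe_router_dtype: Optional[str], model_dtype: torch.dtype) -> torch.dtype:
    # None => no override: the router gate runs in the model's dtype.
    if moe_router_dtype is None:
        return model_dtype
    # "fp32" is the only accepted override per the config validation.
    if moe_router_dtype == "fp32":
        return torch.float32
    raise ValueError(f"unsupported moe_router_dtype: {moe_router_dtype!r}")


print(resolve_router_dtype("fp32", torch.bfloat16))  # torch.float32 (the default)
print(resolve_router_dtype(None, torch.bfloat16))    # torch.bfloat16 (no override)
```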
Files changed:
- areal/api/cli_args.py: add moe_router_dtype config field + validation
- areal/experimental/engine/archon_engine.py: thread config to MoEArgs
- areal/experimental/models/archon/moe/args.py: add router_dtype field
- areal/experimental/models/archon/moe/moe.py: pass router_dtype to router
- areal/experimental/models/archon/moe/router.py: implement FP32 gate GEMM
- areal/experimental/models/archon/qwen3/model/args.py: thread config
- docs/{en,zh}/cli_reference.md: auto-generated doc update
- tests/experimental/archon/test_moe_common.py: consolidated test suite
- tests/experimental/archon/test_moe_args.py: deleted (merged into above)