
feat(archon): add moe_router_dtype config for FP32 router gate GEMM #1009

Merged

garrett4wade merged 1 commit into main from rchardx/router_fp32 on Mar 8, 2026

Conversation

@rchardx (Collaborator) commented on Mar 8, 2026

Description

Add configurable FP32 precision for MoE router gate GEMM to improve numerical stability with large expert counts. Uses a Megatron-Core-style custom torch.autograd.Function that computes the gate linear in FP32 while keeping gradients in the original dtype (BF16) for memory efficiency.
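
Sketched below is the general shape of that pattern, assuming a bias-free gate GEMM over 2D token-major input; the actual implementation lives in areal/experimental/models/archon/moe/router.py and may differ in detail:

```python
import torch
import torch.nn.functional as F


class RouterGatingLinearFunction(torch.autograd.Function):
    """Minimal sketch: FP32 forward GEMM for the router gate, with
    gradients cast back to the original dtype. Illustrative only."""

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        ctx.inp_dtype = inp.dtype
        # Compute router logits in float32 for numerical stability
        # with large expert counts.
        return F.linear(inp.float(), weight.float())

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        # Do the backward matmuls in FP32, then cast gradients back to the
        # model dtype (e.g. BF16) to keep memory usage unchanged.
        grad_inp = (grad_output @ weight.float()).to(ctx.inp_dtype)
        grad_weight = (grad_output.t() @ inp.float()).to(weight.dtype)
        return grad_inp, grad_weight


def router_gating_linear(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Thin wrapper so call sites (and torch.compile) see a plain function.
    return RouterGatingLinearFunction.apply(inp, weight)
```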

Also consolidates test_moe_args.py and test_router_fp32.py into a single test_moe_common.py with proper triton skip handling so MoEArgs tests are not accidentally skipped on machines without triton.
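
The guard pattern looks roughly like this (names such as HAS_TRITON and requires_triton are illustrative, not taken from the diff):

```python
import pytest

try:
    import triton  # noqa: F401

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

# Only Triton-dependent tests carry the skip marker; pure-config MoEArgs
# tests still run on machines without triton. A module-level
# pytest.importorskip("triton") would have skipped the whole file.
requires_triton = pytest.mark.skipif(not HAS_TRITON, reason="triton is not installed")
```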

Related Issue

N/A

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

N/A

Additional Context

Key changes:

  • ArchonEngineConfig.moe_router_dtype: new field (default "fp32", accepts None/"fp32"); a validation sketch follows this list
  • MoEArgs.router_dtype: new field threaded from engine config to router
  • RouterGatingLinearFunction: custom autograd function — FP32 forward, BF16 grads backward
  • router_gating_linear: convenience wrapper, torch.compile fullgraph compatible
  • test_moe_common.py: consolidated MoE test file with try/except triton guard instead of module-level importorskip
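
As a sketch of the first item above: the field name, default, and accepted values come from this PR, while the surrounding dataclass shape and the exact validation are assumptions:

```python
from dataclasses import dataclass


@dataclass
class ArchonEngineConfig:
    # Only the new field is shown; the real config carries many more options.
    moe_router_dtype: str | None = "fp32"

    def __post_init__(self):
        if self.moe_router_dtype not in (None, "fp32"):
            raise ValueError(
                f"moe_router_dtype must be None or 'fp32', "
                f"got {self.moe_router_dtype!r}"
            )
```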

Files changed:

  • areal/api/cli_args.py — add moe_router_dtype config field + validation
  • areal/experimental/engine/archon_engine.py — thread config to MoEArgs
  • areal/experimental/models/archon/moe/args.py — add router_dtype field
  • areal/experimental/models/archon/moe/moe.py — pass router_dtype to router
  • areal/experimental/models/archon/moe/router.py — implement FP32 gate GEMM
  • areal/experimental/models/archon/qwen3/model/args.py — thread config
  • docs/{en,zh}/cli_reference.md — auto-generated doc update
  • tests/experimental/archon/test_moe_common.py — consolidated test suite
  • tests/experimental/archon/test_moe_args.py — deleted (merged into above)

@gemini-code-assist (Contributor) commented
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a crucial feature to bolster the numerical stability of Mixture-of-Experts (MoE) models by enabling the router gate's core matrix multiplication to operate in FP32 precision. This is achieved through a custom autograd function that intelligently manages data types across forward and backward passes, optimizing both stability and memory usage. Concurrently, the PR streamlines the MoE testing infrastructure, ensuring comprehensive coverage and improved reliability for these advanced model configurations.

Highlights

  • FP32 MoE Router Gate GEMM Configuration: Introduced a new moe_router_dtype configuration option in ArchonEngineConfig to allow specifying FP32 precision for the Mixture-of-Experts (MoE) router gate's General Matrix Multiply (GEMM) operation, enhancing numerical stability.
  • Custom Mixed-Precision Autograd Function: Implemented RouterGatingLinearFunction, a custom torch.autograd.Function, which executes the forward pass of the MoE router gate linear layer in FP32 for improved stability while efficiently computing gradients in the original model's dtype (e.g., BF16) during the backward pass.
  • Consolidated and Improved MoE Testing: Refactored and merged existing MoE test files (test_moe_args.py and test_router_fp32.py) into a single, comprehensive test_moe_common.py, which now includes robust Triton skip handling and torch.compile compatibility tests for the new FP32 router logic.
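
A plausible shape for the torch.compile compatibility check mentioned above, assuming router_gating_linear returns FP32 logits as in the earlier sketch; the test name and tensor sizes are illustrative:

```python
import torch

from areal.experimental.models.archon.moe.router import router_gating_linear


def test_router_gating_linear_compiles():
    inp = torch.randn(8, 16, dtype=torch.bfloat16, requires_grad=True)
    weight = torch.randn(4, 16, dtype=torch.bfloat16, requires_grad=True)

    compiled = torch.compile(router_gating_linear, fullgraph=True)
    out = compiled(inp, weight)
    assert out.dtype == torch.float32  # gate GEMM ran in FP32

    out.sum().backward()
    assert inp.grad.dtype == torch.bfloat16  # grads stay in the model dtype
```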


Changelog
  • areal/api/cli_args.py
    • Added moe_router_dtype field to ArchonEngineConfig with validation for 'fp32' or None.
  • areal/experimental/engine/archon_engine.py
    • Updated _create_model_structure to map the moe_router_dtype string config to torch.dtype and pass it to model arguments (see the sketch after this changelog).
  • areal/experimental/models/archon/moe/args.py
    • Added router_dtype field of type torch.dtype | None to MoEArgs dataclass.
  • areal/experimental/models/archon/moe/moe.py
    • Modified init to pass moe_args.router_dtype to the TokenChoiceTopKRouter constructor.
  • areal/experimental/models/archon/moe/router.py
    • Introduced RouterGatingLinearFunction for custom FP32 GEMM.
    • Added router_gating_linear wrapper.
    • Integrated router_dtype into TokenChoiceTopKRouter's initialization and forward pass.
  • areal/experimental/models/archon/qwen3/model/args.py
    • Modified from_hf_config to accept and apply the router_dtype argument to moe_args.
  • docs/en/cli_reference.md
    • Documented the new moe_router_dtype configuration option in the English CLI reference.
  • docs/zh/cli_reference.md
    • Documented the new moe_router_dtype configuration option in the Chinese CLI reference.
  • tests/experimental/archon/test_moe_args.py
    • Removed, as its test cases were migrated and consolidated.
  • tests/experimental/archon/test_moe_common.py
    • Added a new comprehensive test suite, incorporating tests for MoEArgs, ArchonEngineConfig, RouterGatingLinearFunction, TokenChoiceTopKRouter, and torch.compile compatibility, with conditional Triton skipping.
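
The string-to-torch.dtype mapping described for _create_model_structure could look like the following; the helper name and structure are assumptions for illustration:

```python
import torch

_ROUTER_DTYPE_MAP = {"fp32": torch.float32}


def resolve_router_dtype(moe_router_dtype: str | None) -> torch.dtype | None:
    if moe_router_dtype is None:
        # None means no override: the gate GEMM runs in the model dtype.
        return None
    return _ROUTER_DTYPE_MAP[moe_router_dtype]
```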

@gemini-code-assist (Bot) left a comment
Code Review

This pull request introduces a configurable moe_router_dtype to enable FP32 precision for the MoE router gate GEMM, which is a valuable feature for improving numerical stability. The implementation correctly uses a custom torch.autograd.Function to maintain memory efficiency by keeping gradients in the original data type. The changes are well-structured, and the addition of consolidated and comprehensive tests, including for torch.compile compatibility, is excellent. I have one minor suggestion to improve a docstring for clarity. Overall, this is a solid contribution.

Comment thread on areal/experimental/models/archon/moe/args.py (outdated)
@rchardx force-pushed the rchardx/router_fp32 branch from a2af464 to 71924a1 on March 8, 2026 at 11:39
Add configurable FP32 precision for MoE router gate GEMM to improve
numerical stability with large expert counts, using a Megatron-Core-style
custom torch.autograd.Function.

Key changes:
- Add moe_router_dtype field to ArchonEngineConfig (default "fp32")
- Add router_dtype field to MoEArgs dataclass
- Implement RouterGatingLinearFunction with FP32 forward/backward
- Thread config from ArchonEngineConfig through to TokenChoiceTopKRouter
- None means no override (use model dtype), "fp32" runs gate GEMM in float32
- Consolidate test_moe_args.py and test_router_fp32.py into test_moe_common.py
@rchardx force-pushed the rchardx/router_fp32 branch from 71924a1 to 0f49ccb on March 8, 2026 at 11:39
@rchardx added the safe-to-test label (Ready to run unit-tests in a PR) on Mar 8, 2026
@rchardx temporarily deployed to AReaL-unittests on March 8, 2026 at 11:52 with GitHub Actions (inactive)
@garrett4wade (Collaborator) left a comment

LGTM

@garrett4wade merged commit 4f5a294 into main on Mar 8, 2026
13 checks passed
@garrett4wade deleted the rchardx/router_fp32 branch on March 8, 2026 at 15:06
dingzhiqiang pushed a commit that referenced this pull request on Mar 16, 2026.
leandermaben pushed a commit to leandermaben/AReaL that referenced this pull request on Mar 24, 2026.
SathyaGnanakumar pushed a commit to danielkiely/AReaL that referenced this pull request on Apr 29, 2026.
(Each referenced commit carries the same message as the commit above.)
Labels

safe-to-test Ready to run unit-tests in a PR.