
fix(tree-attn): Handle dummy trie in Megatron and optimize attention mask memory #911

Merged

rchardx merged 1 commit into main from mzy/fix-tau2-megatron on Feb 9, 2026

Conversation

@nuzant (Collaborator) commented Feb 7, 2026

Summary

  • Fix zero-numel error in stats_tracker.denominator when using tree training with MegatronEngine and data parallel synchronization
  • Optimize attention mask building memory usage from O(N²) to O(B²) per block for long sequences
  • Fix some Markdown formatting issues

Related Issue

Fixes crash when dummy trie batches (created for DP synchronization) reach grpo_loss_fn via MegatronEngine, causing empty tensors to be passed to stats_tracker.denominator.

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Changes

1. MegatronEngine dummy trie handling (areal/engine/megatron_engine.py)

Added early return check in _compute_logprobs_and_loss for empty tries:

```python
trie_node = inputs.get("trie_node")
if trie_node is None or not trie_node.all_sequence_ids:
    return output.sum() * 0.0
```

This mirrors the existing handling in FSDPEngine and prevents empty tensors from reaching grpo_loss_fn.
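Why multiply the sum by zero rather than return a plain `0.0`? The product keeps the loss attached to the autograd graph, so a backward pass on a dummy batch still produces (zero) gradients for every parameter, and data-parallel gradient synchronization does not stall waiting on a rank that skipped backward. A minimal self-contained sketch of the trick, using an illustrative toy model rather than the actual engine code:

```python
import torch

# Hypothetical stand-in for the model; the real code runs inside
# MegatronEngine._compute_logprobs_and_loss.
layer = torch.nn.Linear(4, 4)
output = layer(torch.randn(2, 4))

# Dummy-trie path: a loss that is numerically zero but still depends on
# `output`, so backward() reaches every parameter and the distributed
# gradient all-reduce sees the same participation on every rank.
loss = output.sum() * 0.0
loss.backward()

# Gradients exist (and are all zero) instead of being None.
assert all(p.grad is not None for p in layer.parameters())
```

A detached scalar (e.g. `torch.tensor(0.0)`) would instead leave `p.grad` as `None` and desynchronize ranks that did real work.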

2. Blockwise attention mask optimization (areal/models/tree_attn/tree.py)

Replaced the tril_cache approach with blockwise processing:

  • Before: Created tril_indices(seq_len, seq_len) which uses O(N²) memory
  • After: Process in 2048-token blocks, each using O(B²) memory (~32MB per block)

The lower triangular matrix is divided into:

  • Diagonal blocks: lower triangular within each block
  • Off-diagonal blocks: fully dense blocks below the diagonal
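The decomposition above can be sketched in a few lines. This is an illustration of the blockwise idea only (NumPy for self-containment, a tiny block size, and hypothetical names); the actual implementation in areal/models/tree_attn/tree.py operates on PyTorch tensors with `_ATTN_MASK_BLOCK_SIZE = 2048`:

```python
import numpy as np

BLOCK = 4  # illustrative; the PR uses 2048-token blocks


def apply_causal_mask_blockwise(mask: np.ndarray) -> np.ndarray:
    """Zero out the upper triangle of a square boolean mask, one row-block
    at a time, so peak temporary memory is O(B^2) per block instead of the
    O(N^2) tril_indices tensor a naive implementation would allocate."""
    n = mask.shape[0]
    for start in range(0, n, BLOCK):
        end = min(start + BLOCK, n)
        size = end - start
        # Diagonal block: keep only the lower triangle within the block.
        mask[start:end, start:end] &= np.tril(np.ones((size, size), dtype=bool))
        # Columns right of the diagonal block are strictly above the
        # diagonal: clear them outright.
        mask[start:end, end:] = False
        # Columns left of the diagonal block are fully below the diagonal:
        # left untouched (dense).
    return mask


full = np.ones((10, 10), dtype=bool)
result = apply_causal_mask_blockwise(full)
```

For a 10x10 all-True input this reproduces `np.tril(np.ones((10, 10)))` exactly, while only ever materializing one BLOCK x BLOCK triangle at a time.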

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Additional Context

See .legacy/debug-megatron-zero-numel.md for detailed analysis of the root cause.

🤖 Generated with Claude Code

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @nuzant, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the MegatronEngine related to handling empty trie nodes during distributed training, ensuring stability and correct gradient propagation. Concurrently, it introduces a significant performance improvement by optimizing the memory footprint of attention mask generation for tree-based attention models, making the system more efficient for processing long sequences.

Highlights

  • MegatronEngine Dummy Trie Handling: Implemented an early return in _compute_logprobs_and_loss to gracefully handle empty or dummy trie nodes, preventing zero-numel errors during data parallel synchronization by returning a zero loss that maintains gradient connection.
  • Attention Mask Memory Optimization: Refactored the attention mask building process to use a blockwise approach, reducing memory consumption from O(N²) to O(B²) for long sequences by avoiding the creation of large tril_indices tensors and processing the mask in smaller, manageable blocks.


Changelog
  • areal/engine/megatron_engine.py
    • Added a check for trie_node being None or having no sequences at the beginning of _compute_logprobs_and_loss.
    • If a dummy trie is detected, the function now returns a zero loss to maintain gradient connectivity for distributed synchronization.
    • Updated references to inputs["trie_node"] to use the newly introduced trie_node local variable.
  • areal/models/tree_attn/tree.py
    • Introduced _ATTN_MASK_BLOCK_SIZE constant (2048) to define the block size for memory-efficient attention mask building.
    • Refactored _build_attention_mask to remove the tril_cache and delegate the causal mask application to a new helper function.
    • Added _apply_causal_mask_blockwise function which constructs the causal attention mask in a blockwise manner, processing diagonal and off-diagonal blocks separately to optimize memory usage.
  • examples/vlm_npu/README.md
    • Applied minor formatting adjustments to improve readability, including adding blank lines around code blocks and adjusting bullet point indentation.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces two valuable changes. First, it correctly handles dummy tries in the Megatron engine to prevent errors during data parallel synchronization, which is a solid bug fix. Second, it significantly optimizes the attention mask building process by using a blockwise approach, reducing memory usage from O(N²) to O(B²). This is a great performance improvement for long sequences. The implementation of the blockwise attention mask is well-structured. I have a couple of minor suggestions regarding the explicit use of del for tensor deallocation, which could be simplified for better code clarity.

Comment thread areal/models/tree_attn/tree.py
Comment thread areal/models/tree_attn/tree.py
@nuzant nuzant added the safe-to-test Ready to run unit-tests in a PR. label Feb 7, 2026
…mask memory

- Add early return for empty trie in MegatronEngine._compute_logprobs_and_loss
  to prevent zero-numel errors in stats_tracker.denominator when using tree
  training with data parallel synchronization
- Replace O(N²) tril_indices cache with blockwise processing in
  _build_attention_mask to reduce peak memory usage for long sequences
- Add _apply_causal_mask_blockwise helper for memory-efficient causal mask
  construction using 2048-token blocks

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@nuzant nuzant force-pushed the mzy/fix-tau2-megatron branch from 7df9bf0 to 73ba5f3 Compare February 9, 2026 12:29
@nuzant nuzant added safe-to-test Ready to run unit-tests in a PR. and removed safe-to-test Ready to run unit-tests in a PR. labels Feb 9, 2026
@nuzant nuzant temporarily deployed to AReaL-unittests February 9, 2026 12:55 — with GitHub Actions Inactive

@rchardx rchardx left a comment


LGTM

@rchardx rchardx merged commit 985ecb2 into main Feb 9, 2026
10 of 11 checks passed
@rchardx rchardx deleted the mzy/fix-tau2-megatron branch February 9, 2026 14:13
leandermaben pushed a commit to leandermaben/AReaL that referenced this pull request Mar 24, 2026
…mask memory (inclusionAI#911)

