fix(tree-attn): Handle dummy trie in Megatron and optimize attention mask memory #911
Conversation
Summary of Changes

Hello @nuzant, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses a critical bug in the MegatronEngine related to handling empty trie nodes during distributed training, ensuring stability and correct gradient propagation. It also introduces a significant performance improvement by optimizing the memory footprint of attention mask generation for tree-based attention models, making the system more efficient when processing long sequences.

Highlights
Activity
Code Review
This pull request introduces two valuable changes. First, it correctly handles dummy tries in the Megatron engine to prevent errors during data parallel synchronization, which is a solid bug fix. Second, it significantly optimizes the attention mask building process by using a blockwise approach, reducing memory usage from O(N²) to O(B²). This is a great performance improvement for long sequences. The implementation of the blockwise attention mask is well-structured. I have a couple of minor suggestions regarding the explicit use of del for tensor deallocation, which could be simplified for better code clarity.
fix(tree-attn): Handle dummy trie in Megatron and optimize attention mask memory

- Add early return for empty trie in MegatronEngine._compute_logprobs_and_loss to prevent zero-numel errors in stats_tracker.denominator when using tree training with data parallel synchronization
- Replace O(N²) tril_indices cache with blockwise processing in _build_attention_mask to reduce peak memory usage for long sequences
- Add _apply_causal_mask_blockwise helper for memory-efficient causal mask construction using 2048-token blocks

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Force-pushed from 7df9bf0 to 73ba5f3
Summary

- Fix zero-numel errors in stats_tracker.denominator when using tree training with MegatronEngine and data parallel synchronization

Related Issue

Fixes crash when dummy trie batches (created for DP synchronization) reach grpo_loss_fn via MegatronEngine, causing empty tensors to be passed to stats_tracker.denominator.

Type of Change
Changes

1. MegatronEngine dummy trie handling (areal/engine/megatron_engine.py)

Added an early-return check in _compute_logprobs_and_loss for empty tries. This mirrors the existing handling in FSDPEngine and prevents empty tensors from reaching grpo_loss_fn.

2. Blockwise attention mask optimization (areal/models/tree_attn/tree.py)

Replaced the tril_cache approach with blockwise processing. Before, _build_attention_mask called tril_indices(seq_len, seq_len), which uses O(N²) memory. The lower triangular matrix is now divided into 2048-token blocks, keeping peak extra memory at O(B²).
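A minimal sketch of the blockwise idea, under the assumption that the mask is a dense (N, N) tensor processed tile by tile; the real _apply_causal_mask_blockwise helper in areal/models/tree_attn/tree.py may differ in signature and details.

```python
import torch


def apply_causal_mask_blockwise(mask: torch.Tensor, block: int = 2048) -> torch.Tensor:
    """Zero the strictly-upper triangle of an (N, N) mask in-place, tile by tile.

    Materializing torch.tril_indices(N, N) needs O(N^2) index memory; working
    on (block x block) tiles keeps the peak extra memory at O(B^2).
    """
    n = mask.size(0)
    for i in range(0, n, block):
        bi = min(block, n - i)
        for j in range(i, n, block):  # only tiles on or above the diagonal
            bj = min(block, n - j)
            tile = mask[i:i + bi, j:j + bj]  # view into the original tensor
            if j > i:
                # Tile lies strictly above the diagonal: mask it entirely.
                tile.zero_()
            else:
                # Diagonal tile: zero only the entries with column > row.
                rows = torch.arange(i, i + bi).unsqueeze(1)
                cols = torch.arange(j, j + bj).unsqueeze(0)
                tile.masked_fill_(rows < cols, 0)
    return mask
```

Tiles strictly below the diagonal are never touched, so for a mask of sequence length N only one B×B index grid exists at a time, matching the O(N²) → O(B²) reduction described in the review.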
Checklist

- jb build docs
- /gemini review

Additional Context
See .legacy/debug-megatron-zero-numel.md for detailed analysis of the root cause.

🤖 Generated with Claude Code