Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
4f7686d
feat: speculative decode init
Apr 13, 2026
71cecba
feat: add config
Apr 13, 2026
ac145b8
feat: fix log
Apr 13, 2026
c82b174
feat: fix
Apr 13, 2026
3282366
feat: fix
Apr 13, 2026
fffe03e
feat: fix
Apr 13, 2026
76999a3
feat: fix
Apr 13, 2026
4682168
feat: fix
Apr 13, 2026
8a9e8d1
feat: fix code
Apr 14, 2026
e955708
feat: add log
Apr 14, 2026
985cb26
feat: fix
Apr 14, 2026
6670f3d
feat: fix
Apr 14, 2026
6fdcbc7
feat: fix code
Apr 14, 2026
261c9c7
feat: fix mtp
Apr 14, 2026
bd36e6a
feat: remove
Apr 14, 2026
4924ee4
feat: change config
Apr 14, 2026
9f7796b
feat: improve mtp_loss_scaling_factor
Apr 14, 2026
e0471c3
feat: fix config
Apr 14, 2026
0bc4616
feat: fix
Apr 14, 2026
031fa89
feat: fix local for test
Apr 14, 2026
b36ca69
feat: bug fix
Apr 14, 2026
a07dc22
feat: fix
Apr 14, 2026
3a56483
feat: add base config
Apr 14, 2026
05bdf24
feat: remove mtp keep
Apr 14, 2026
41ac1ab
feat: fix mtp loss
Apr 15, 2026
dd4eb5e
feat: fix ckpt
Apr 15, 2026
7c27544
feat: fix config oom
Apr 15, 2026
f719dad
feat: fix no mtp
Apr 15, 2026
1bcabea
feat: fix config
Apr 15, 2026
107b8f3
feat: fix
Apr 15, 2026
7cb5e49
feat: fix
Apr 15, 2026
7cbfe81
feat: add qwen
Apr 16, 2026
0a32066
feat: remove enable_draft_weights_cpu_backup
Apr 16, 2026
93baaf8
feat: fix mtp loss
Apr 16, 2026
cdeebcb
feat: fix OOM
Apr 16, 2026
f79b5c5
feat: revert
Apr 16, 2026
f553572
feat: add mem log
Apr 16, 2026
53a3b2d
feat: rm log
Apr 16, 2026
b4bbb1f
feat: sample log
Apr 16, 2026
df7d918
feat: fix
Apr 16, 2026
f50e604
feat: fix mtp gradient
Apr 17, 2026
b40a55b
feat: fix again
Apr 17, 2026
ca39b2e
feat: fix
Apr 17, 2026
1889b1a
feat(engine): add mtp weight update
Apr 18, 2026
bbc9deb
feat(mtp): fix mtp weight update
Apr 18, 2026
ee24e8d
fix(controller): fix callback
Apr 19, 2026
e031206
fix(controller): skip _NO_PROXY
Apr 19, 2026
aaa3aa5
fix(controller): fix update
Apr 19, 2026
4d04c35
feat(controller): add log
Apr 19, 2026
a1c3e82
fix(engine): cuda ipc sync
Apr 20, 2026
b23abd1
fix(megatron): add log
Apr 20, 2026
e7c3f7b
fix(engine): improve serialize
Apr 20, 2026
dd3eeea
fix(engine): skip NCCL broadcast
Apr 20, 2026
1e6a453
fix(engine): improve
Apr 20, 2026
e7a6b38
fix(engine): fix nccl block
Apr 20, 2026
7028020
refactor(rollout_controller): add log metric
Apr 20, 2026
cfd9115
fix(engine): fix CUDA stream
Apr 20, 2026
d373f03
feat(megatron): add log
Apr 20, 2026
8c070b7
fix(rollout_controller): add
Apr 20, 2026
a4b48c7
feat(megatron): fix
Apr 20, 2026
c802bee
refactor(megatron_engine): improve
Apr 20, 2026
3af6904
fix(mcore): deal eh_proj.weight
Apr 22, 2026
4b2e96a
fix(megatron_engine): remove code
Apr 22, 2026
57061fe
fix(megatron_engine): grad
Apr 22, 2026
fa88152
feat(megatron_engine): add mtp log
Apr 23, 2026
02dc326
fix: use _logger
Apr 23, 2026
f8c2dab
fix(engine): fix mtp gradient
Apr 24, 2026
7e4118a
feat(mtp): add mtp lr
Apr 24, 2026
2276771
fix(engine): add mtp clip
Apr 24, 2026
b4f5543
refactor(megatron_engine): mv
Apr 24, 2026
00e4497
feat(megatron_engine): ad
Apr 24, 2026
3117ccf
fix: h20 config
Apr 27, 2026
6dba807
perf: fix config
Apr 27, 2026
a9161e7
feat: add log
Apr 27, 2026
5feca78
fix(scheduler): worker check
Apr 27, 2026
a5177cc
fix(infra): fix net
Apr 27, 2026
4e06ba9
fix(net): add callback(need rethink)
Apr 27, 2026
63497c7
fix(engine): double scale
Apr 27, 2026
5b84634
feat(actor): fix mtp_lr_scale
Apr 27, 2026
056724c
fix(engine): fix mtp gradient numbatch
Apr 28, 2026
8b90666
fix(engine): lr
Apr 28, 2026
be8c1b0
feat(engine): megatron log
Apr 28, 2026
b55f862
feat(engine): audit log
Apr 28, 2026
23ebf0e
fix(megatron_engine): mimo weight update
Apr 28, 2026
b0d9363
fix: scale up mtp_lr_scale
Apr 29, 2026
9c29945
fix(megatron_engine): add log
Apr 29, 2026
410fb90
feat(megatron_engine): fp32 weight update
Apr 29, 2026
8c3f60d
feat(engine): add full stage log
May 2, 2026
56e5e08
fix(engine): fix mtp
May 2, 2026
918fb3f
fix(megatron_engine): mtp nccl error
May 2, 2026
8913215
fix(engine): fix again
May 2, 2026
ba01036
fix(megatron_engine): fix
May 2, 2026
4933461
feat(engine): improve
May 3, 2026
4b9a8e7
feat(megatron_engine): verify fp32 weight
May 3, 2026
3ea975f
feat(engine): fix again
May 3, 2026
d76a4ff
feat(infra): read sglang weight for verify
May 3, 2026
866e6a1
feat(engine): fix
May 3, 2026
dc6571b
feat(megatron_engine): verify log
May 3, 2026
dcb2e44
fix(engine): mtp issue
May 3, 2026
3ee9a06
feat(controller): fix
May 3, 2026
59b319c
feat(controller): fix1
May 3, 2026
832025c
fix(engine): fix
May 3, 2026
2dff0d9
fix(engine): v36
May 4, 2026
17cf72e
fix(engine): fix
May 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,29 @@ class MegatronEngineConfig:
},
)

# MTP (Multi-Token Prediction) Configuration
mtp_num_layers: int = field(
default=0,
metadata={
"help": "Number of MTP (Multi-Token Prediction) layers for speculative decoding training. "
"0 means MTP is disabled."
},
)
mtp_loss_scaling_factor: float = field(
default=0.1,
metadata={
"help": "Scaling factor for MTP auxiliary loss. Controls the weight of MTP loss "
"relative to the main RL loss."
},
)
mtp_detach_heads: bool = field(
default=True,
metadata={
"help": "Whether to detach hidden states before passing to MTP heads in MegatronEngine. "
"When True, MTP loss gradients only update MTP parameters."
},
)


class SchedulingStrategyType(str, Enum):
separation = "separation"
Expand Down Expand Up @@ -1316,6 +1339,37 @@ class PPOActorConfig(TrainEngineConfig):
metadata={"help": "Maximum number of new tokens to generate"},
)

# MTP (Multi-Token Prediction) Online Training
enable_mtp_training: bool = field(
default=False,
metadata={
"help": "Enable MTP (Multi-Token Prediction) online training during RL. "
"When enabled, MTP layers are trained alongside the main policy model "
"to keep the draft model aligned with the evolving policy."
},
)
mtp_num_layers: int = field(
default=1,
metadata={
"help": "Number of MTP layers to train. Must match the model's MTP architecture."
},
)
mtp_loss_scaling_factor: float = field(
default=0.1,
metadata={
"help": "Scaling factor for MTP auxiliary loss relative to the main RL loss."
},
)
mtp_detach_heads: bool = field(
default=True,
metadata={
"help": "Whether to detach hidden states before passing to MTP heads. "
"When True (recommended for RL), MTP loss gradients only update MTP parameters, "
"preventing the MTP auxiliary loss from corrupting the main policy gradients. "
"When False, MTP loss gradients also flow back to the main model."
},
)

def should_compute_prox_logp(self) -> bool:
"""Determine if forward pass is needed for proximal log-probabilities.

Expand Down Expand Up @@ -1373,6 +1427,19 @@ def __post_init__(self):
"Please set `actor.use_decoupled_loss=false` in your configuration."
)

# Validate MTP configuration
if self.enable_mtp_training:
if self.mtp_num_layers <= 0:
raise ValueError(
f"mtp_num_layers must be > 0 when enable_mtp_training is True, "
f"got {self.mtp_num_layers}."
)
if not (0 < self.mtp_loss_scaling_factor <= 1.0):
raise ValueError(
f"mtp_loss_scaling_factor must be in (0, 1.0], "
f"got {self.mtp_loss_scaling_factor}."
)

super().__post_init__()


Expand Down Expand Up @@ -1579,6 +1646,44 @@ class SGLangConfig:
# Internal field, not exposed to users.
enable_return_routed_experts: bool = False

# Speculative Decoding Configuration
speculative_algorithm: str | None = field(
default=None,
metadata={
"help": "Speculative decoding algorithm. Options: 'EAGLE', 'EAGLE3'. None disables speculative decoding."
},
)
speculative_draft_model_path: str | None = field(
default=None,
metadata={"help": "Path to the draft model for speculative decoding."},
)
speculative_num_steps: int = field(
default=3,
metadata={"help": "Number of speculative decoding draft steps."},
)
speculative_eagle_topk: int = field(
default=1,
metadata={"help": "Top-k value for EAGLE draft token selection."},
)
speculative_num_draft_tokens: int = field(
default=4,
metadata={"help": "Number of draft tokens per speculative step."},
)
speculative_attention_mode: str | None = field(
default=None,
metadata={
"help": "Attention mode for speculative decoding. E.g., 'full', 'sparse'."
},
)
enable_multi_layer_eagle: bool = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with other configuration fields in this dataclass, enable_multi_layer_eagle should be defined using field(). This also provides an opportunity to add a help string in the metadata for better documentation and discoverability through CLI help messages.

    enable_multi_layer_eagle: bool = field(
        default=False,
        metadata={"help": "Enable multi-layer EAGLE for speculative decoding."},
    )

enable_draft_weights_cpu_backup: bool | None = field(
default=None,
metadata={
"help": "Keep draft model weights on CPU as backup during GPU offload cycles. "
"Essential for colocated training+inference mode to prevent draft weight loss."
},
)

# Use staticmethod to make OmegaConf happy.
@staticmethod
def build_cmd(
Expand Down Expand Up @@ -1630,6 +1735,13 @@ def build_args(
)
args.pop("enable_multithread_load", None)

# enable_draft_weights_cpu_backup: pass to SGLang ServerArgs constructor if set.
# Essential for colocated training+inference mode to prevent draft weight loss
# during GPU offload cycles. If None, let SGLang use its default.
draft_cpu_backup = args.pop("enable_draft_weights_cpu_backup", None)
if draft_cpu_backup is not None:
args["enable_draft_weights_cpu_backup"] = draft_cpu_backup

args = dict(
# Model and tokenizer
tokenizer_path=sglang_config.model_path,
Expand Down
6 changes: 6 additions & 0 deletions areal/api/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class ModelResponse:
# MoE routing (only populated when return_routed_experts=True)
routed_experts: np.ndarray | None = None

# Speculative decoding statistics
spec_accept_token_num: int = 0
spec_draft_token_num: int = 0

@property
def input_len(self) -> int:
    """Return the number of prompt tokens, i.e. ``len(self.input_tokens)``."""
    return len(self.input_tokens)
Expand Down Expand Up @@ -283,6 +287,8 @@ class HttpGenerationResult:
output_logprobs: list[float]
stop_reason: str
routed_experts: np.ndarray | None = None
spec_accept_token_num: int | None = None
spec_draft_token_num: int | None = None


@dataclass
Expand Down
Loading
Loading