Skip to content

Commit aa85dfc

Browse files
committed
fix dp seq balance bugs
1 parent 3a92c53 commit aa85dfc

File tree

11 files changed

+2237
-92
lines changed

11 files changed

+2237
-92
lines changed

ajet/backbone/trainer_verl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ def _validate(self):
810810

811811
# repeat test batch
812812
test_batch = test_batch.repeat(
813-
repeat_times=self.config.ajet.rollout.val_kwargs.num_repeat,
813+
repeat_times=self.config.ajet.trainer_common.val_pass_n,
814814
interleave=True,
815815
)
816816

@@ -858,10 +858,10 @@ def _validate(self):
858858
logger.info(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
859859

860860
self.checkpoint_manager.update_weights(self.global_steps)
861-
main_val_dataset = self.get_eval_dataset()
861+
main_val_dataset = self.get_val_dataset()
862862

863863
logger.info("Starting validate rollout")
864-
context_tracker_arr, tasks, val_metrics = self.eval_dataset(
864+
context_tracker_arr, tasks, val_metrics = self._rollout_val_dataset(
865865
target_dataset=main_val_dataset,
866866
target_dataset_name="main_val_dataset",
867867
mode="validate",
@@ -920,7 +920,7 @@ def _validate(self):
920920

921921
return metric_dict
922922

923-
def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
923+
def _rollout_val_dataset(self, target_dataset, target_dataset_name, mode, epoch):
924924
"""
925925
Evaluate a dataset by running rollouts and computing task completion metrics.
926926
@@ -1005,7 +1005,7 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
10051005

10061006
return ctx_trackers, tasks, val_metrics
10071007

1008-
def get_eval_dataset(self):
1008+
def get_val_dataset(self):
10091009
from ajet.task_reader import RouterTaskReader
10101010

10111011
task_reader = RouterTaskReader(

ajet/backbone/verl/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
from .fsdp_workers import AjetActorRolloutRefWorker, AjetAsyncActorRolloutRefWorker
2-
from .actor_config import AjetActorConfig, AjetFSDPActorConfig
32
from .dp_actor import AjetDataParallelPPOActor
43

54
__all__ = [
65
"AjetActorRolloutRefWorker",
76
"AjetAsyncActorRolloutRefWorker",
8-
"AjetActorConfig",
9-
"AjetFSDPActorConfig",
107
"AjetDataParallelPPOActor",
118
]

ajet/backbone/verl/actor_config.py

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,8 @@
1-
# Copyright 2025 Alibaba Ltd. and/or its affiliates
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
15-
"""
16-
Ajet extensions for verl ActorConfig.
17-
Adds `override_ppo_mini_batch_num` field to control the number of optimizer steps per train-batch-step.
18-
"""
19-
1+
from verl.workers.config import FSDPActorConfig
202
from dataclasses import dataclass, field
21-
from typing import Optional
22-
23-
from verl.workers.config.actor import ActorConfig, FSDPActorConfig
24-
25-
26-
@dataclass
27-
class AjetActorConfig(ActorConfig):
28-
"""ActorConfig extended with ajet-specific fields.
29-
30-
Additional fields:
31-
override_ppo_mini_batch_num (Optional[int]): If > 0, overrides ppo_mini_batch_size
32-
by computing mini_batch_split_size = ceil(batch_size / override_ppo_mini_batch_num).
33-
"""
34-
35-
override_ppo_mini_batch_num: Optional[int] = None
363

374

385
@dataclass
39-
class AjetFSDPActorConfig(FSDPActorConfig):
40-
"""FSDPActorConfig extended with ajet-specific fields.
41-
42-
Additional fields:
43-
override_ppo_mini_batch_num (Optional[int]): If > 0, overrides ppo_mini_batch_size
44-
by computing mini_batch_split_size = ceil(batch_size / override_ppo_mini_batch_num).
45-
"""
46-
47-
override_ppo_mini_batch_num: Optional[int] = None
6+
class AgentJetFSDPActorConfig(FSDPActorConfig):
7+
loss_extra_scale_ratio: float = 1.0
8+
override_ppo_mini_batch_num: int = 1

ajet/backbone/verl/dp_actor.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
3333
from verl.utils.profiler import GPUMemoryLogger
3434
from verl.utils.py_functional import append_to_dict
35-
from verl.utils.seqlen_balancing import prepare_dynamic_batch
35+
# [AgentJet Change]: import from ajet's local seqlen_balancing module, which also provides restore_dynamic_batch
36+
from ajet.backbone.verl.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
3637
from verl.workers.actor.dp_actor import DataParallelPPOActor
3738

3839
__all__ = ["AjetDataParallelPPOActor"]
@@ -46,8 +47,94 @@ class AjetDataParallelPPOActor(DataParallelPPOActor):
4647
4748
1. Supports `override_ppo_mini_batch_num` to control the number of optimizer steps per train-batch-step.
4849
2. Adds debug print for tensor shapes during training.
50+
3. Uses ajet's `prepare_dynamic_batch` / `restore_dynamic_batch` for dynamic micro-batching
4951
"""
5052

53+
@GPUMemoryLogger(role="dp actor", logger=logger)
54+
def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> dict[str, torch.Tensor]:
55+
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids
56+
57+
Args:
58+
data (DataProto): a DataProto containing keys
59+
60+
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
61+
concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
62+
63+
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
64+
65+
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
66+
67+
``responses``: tensor of shape [batch_size, response_length]. torch.int64.
68+
69+
Returns:
70+
dict[str, torch.Tensor]: a dict containing keys
71+
- ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32.
72+
- ``entropys``: tensor of shape [batch_size, response_length]. torch.float32.
73+
- ``sum_pi_squared``: tensor of shape [batch_size, response_length]. torch.float32.
74+
"""
75+
calculate_sum_pi_squared = self.config.get("calculate_sum_pi_squared", False)
76+
self.actor_module.eval()
77+
78+
micro_batch_size = data.meta_info["micro_batch_size"]
79+
temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error
80+
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
81+
pad_token_id = data.meta_info.get("pad_token_id", 0)
82+
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
83+
84+
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
85+
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
86+
if self.use_prefix_grouper:
87+
select_keys += [k for k in ["prompts", "response_mask"] if k in data.batch]
88+
if "uid" in data.non_tensor_batch:
89+
non_tensor_select_keys.append("uid")
90+
91+
data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)
92+
93+
if use_dynamic_bsz:
94+
max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
95+
micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)
96+
else:
97+
micro_batches = data.split(micro_batch_size)
98+
99+
log_probs_lst = []
100+
entropy_lst = []
101+
sum_pi_squared_lst = []
102+
print(f"len(micro_batches) = {len(micro_batches)}")
103+
for micro_batch in micro_batches:
104+
micro_batch = micro_batch.to(get_device_id())
105+
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id}
106+
with torch.no_grad():
107+
outputs = self._forward_micro_batch(
108+
model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
109+
)
110+
log_probs_lst.append(outputs["log_probs"])
111+
if calculate_entropy:
112+
entropy_lst.append(outputs["entropys"])
113+
if calculate_sum_pi_squared:
114+
sum_pi_squared_lst.append(outputs["sum_pi_squared"])
115+
116+
log_probs = torch.concat(log_probs_lst, dim=0)
117+
if calculate_entropy:
118+
entropys = torch.concat(entropy_lst, dim=0)
119+
if calculate_sum_pi_squared:
120+
sum_pi_squared = torch.concat(sum_pi_squared_lst, dim=0)
121+
122+
if use_dynamic_bsz:
123+
log_probs = restore_dynamic_batch(log_probs, batch_idx_list)
124+
if calculate_entropy:
125+
entropys = restore_dynamic_batch(entropys, batch_idx_list)
126+
if calculate_sum_pi_squared:
127+
sum_pi_squared = restore_dynamic_batch(sum_pi_squared, batch_idx_list)
128+
129+
outputs = {"log_probs": log_probs}
130+
if calculate_entropy:
131+
outputs["entropys"] = entropys
132+
if calculate_sum_pi_squared:
133+
outputs["sum_pi_squared"] = sum_pi_squared
134+
return outputs
135+
136+
137+
51138
@GPUMemoryLogger(role="dp actor", logger=logger)
52139
def update_policy(self, data: DataProto):
53140
# make sure we are in training mode

ajet/backbone/verl/fsdp_workers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
283283

284284
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
285285
def init_model(self):
286-
# from verl.workers.actor import DataParallelPPOActor
286+
# [AgentJet Change]: use the custom DataParallelPPOActor which supports FSDP and other features needed for ActorRolloutRefWorker
287287
from ajet.backbone.verl.dp_actor import AjetDataParallelPPOActor as DataParallelPPOActor
288288

289289
# This is used to import external_lib into the huggingface systems
@@ -347,7 +347,8 @@ def init_model(self):
347347
log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
348348

349349
if self._is_actor:
350-
actor_cfg = self.config.actor
350+
# [AgentJet Change]: use the custom DataParallelPPOActor which supports FSDP and other features needed for ActorRolloutRefWorker
351+
actor_cfg = omega_conf_to_dataclass(self.config.actor)
351352
self.actor = DataParallelPPOActor(
352353
config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
353354
)
@@ -422,7 +423,6 @@ def init_model(self):
422423
# Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo
423424
aggressive_empty_cache(force_sync=True)
424425

425-
426426
# ================================= Async related workers =================================
427427
class AjetAsyncActorRolloutRefWorker(AjetActorRolloutRefWorker):
428428
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)

0 commit comments

Comments
 (0)