3232from verl .utils .fsdp_utils import FSDPModule , fsdp2_clip_grad_norm_
3333from verl .utils .profiler import GPUMemoryLogger
3434from verl .utils .py_functional import append_to_dict
35- from verl .utils .seqlen_balancing import prepare_dynamic_batch
35+ # ajet/backbone/verl/seqlen_balancing.py
36+ from ajet .backbone .verl .seqlen_balancing import prepare_dynamic_batch , restore_dynamic_batch
3637from verl .workers .actor .dp_actor import DataParallelPPOActor
3738
3839__all__ = ["AjetDataParallelPPOActor" ]
@@ -46,8 +47,94 @@ class AjetDataParallelPPOActor(DataParallelPPOActor):
4647
4748 1. Supports `override_ppo_mini_batch_num` to control the number of optimizer steps per train-batch-step.
4849 2. Adds debug print for tensor shapes during training.
50+ 3. Overrides `compute_log_prob` to use ajet's `prepare_dynamic_batch`/`restore_dynamic_batch` helpers.
4951 """
5052
53+ @GPUMemoryLogger (role = "dp actor" , logger = logger )
54+ def compute_log_prob (self , data : DataProto , calculate_entropy : bool = False ) -> dict [str , torch .Tensor ]:
55+ """Compute the log probability of the responses given input_ids, attention_mask and position_ids
56+
57+ Args:
58+ data (DataProto): a DataProto containing keys
59+
60+ ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
61+ concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
62+
63+ ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
64+
65+ ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
66+
67+ ``responses``: tensor of shape [batch_size, response_length]. torch.int64.
68+
69+ Returns:
70+ dict[str, torch.Tensor]: a dict containing keys
71+ - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32.
72+ - ``entropys``: tensor of shape [batch_size, response_length]. torch.float32.
73+ - ``sum_pi_squared``: tensor of shape [batch_size, response_length]. torch.float32.
74+ """
75+ calculate_sum_pi_squared = self .config .get ("calculate_sum_pi_squared" , False )
76+ self .actor_module .eval ()
77+
78+ micro_batch_size = data .meta_info ["micro_batch_size" ]
79+ temperature = data .meta_info ["temperature" ] # temperature must be in the data.meta_info to avoid silent error
80+ use_dynamic_bsz = data .meta_info ["use_dynamic_bsz" ]
81+ pad_token_id = data .meta_info .get ("pad_token_id" , 0 )
82+ has_multi_modal_inputs = "multi_modal_inputs" in data .non_tensor_batch .keys ()
83+
84+ select_keys = ["responses" , "input_ids" , "attention_mask" , "position_ids" ]
85+ non_tensor_select_keys = ["multi_modal_inputs" ] if has_multi_modal_inputs else []
86+ if self .use_prefix_grouper :
87+ select_keys += [k for k in ["prompts" , "response_mask" ] if k in data .batch ]
88+ if "uid" in data .non_tensor_batch :
89+ non_tensor_select_keys .append ("uid" )
90+
91+ data = data .select (batch_keys = select_keys , non_tensor_batch_keys = non_tensor_select_keys )
92+
93+ if use_dynamic_bsz :
94+ max_token_len = data .meta_info ["max_token_len" ] * self .ulysses_sequence_parallel_size
95+ micro_batches , batch_idx_list = prepare_dynamic_batch (data , max_token_len = max_token_len )
96+ else :
97+ micro_batches = data .split (micro_batch_size )
98+
99+ log_probs_lst = []
100+ entropy_lst = []
101+ sum_pi_squared_lst = []
102+ print (f"len(micro_batches) = { len (micro_batches )} " )
103+ for micro_batch in micro_batches :
104+ micro_batch = micro_batch .to (get_device_id ())
105+ model_inputs = {** micro_batch .batch , ** micro_batch .non_tensor_batch , "pad_token_id" : pad_token_id }
106+ with torch .no_grad ():
107+ outputs = self ._forward_micro_batch (
108+ model_inputs , temperature = temperature , calculate_entropy = calculate_entropy
109+ )
110+ log_probs_lst .append (outputs ["log_probs" ])
111+ if calculate_entropy :
112+ entropy_lst .append (outputs ["entropys" ])
113+ if calculate_sum_pi_squared :
114+ sum_pi_squared_lst .append (outputs ["sum_pi_squared" ])
115+
116+ log_probs = torch .concat (log_probs_lst , dim = 0 )
117+ if calculate_entropy :
118+ entropys = torch .concat (entropy_lst , dim = 0 )
119+ if calculate_sum_pi_squared :
120+ sum_pi_squared = torch .concat (sum_pi_squared_lst , dim = 0 )
121+
122+ if use_dynamic_bsz :
123+ log_probs = restore_dynamic_batch (log_probs , batch_idx_list )
124+ if calculate_entropy :
125+ entropys = restore_dynamic_batch (entropys , batch_idx_list )
126+ if calculate_sum_pi_squared :
127+ sum_pi_squared = restore_dynamic_batch (sum_pi_squared , batch_idx_list )
128+
129+ outputs = {"log_probs" : log_probs }
130+ if calculate_entropy :
131+ outputs ["entropys" ] = entropys
132+ if calculate_sum_pi_squared :
133+ outputs ["sum_pi_squared" ] = sum_pi_squared
134+ return outputs
135+
136+
137+
51138 @GPUMemoryLogger (role = "dp actor" , logger = logger )
52139 def update_policy (self , data : DataProto ):
53140 # make sure we are in training mode
0 commit comments