Skip to content

Commit 60acbdf

Browse files
committed
add a loss scaler
1 parent 8f59278 commit 60acbdf

File tree

10 files changed

+27
-6
lines changed

10 files changed

+27
-6
lines changed

ajet/backbone/verl/dp_actor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,13 @@ def update_policy(self, data: DataProto):
131131

132132
calculate_entropy = self.config.calculate_entropy or (entropy_coeff != 0)
133133

134-
if self.config.use_dynamic_bsz:
134+
if self.config.override_ppo_mini_batch_num > 0:
135+
loss_scale_factor = response_mask.shape[0] / mini_batch_split_size
136+
elif self.config.use_dynamic_bsz:
135137
loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
136138
else:
137139
loss_scale_factor = 1 / self.gradient_accumulation
140+
loss_scale_factor *= self.config.loss_extra_scale_ratio # [AJET] Extra scaling for loss if needed
138141

139142
# all return: (bsz, response_length)
140143
outputs = self._forward_micro_batch(

ajet/default_config/ajet_default.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@ ajet:
288288
kl_loss_coef: 0.002
289289
kl_loss_type: low_var_kl
290290

291+
# loss = loss * loss_extra_scale_ratio
292+
loss_extra_scale_ratio: 1.0
293+
291294
# Ulysses specific configs
292295
ulysses_sequence_parallel_size: 1
293296

ajet/default_config/verl/config_auto_convertion_verl.jsonc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"ajet.trainer_common.kl_loss_coef": "actor_rollout_ref.actor.kl_loss_coef",
1616
"ajet.trainer_common.kl_loss_type": "actor_rollout_ref.actor.kl_loss_type",
1717
"ajet.trainer_common.ulysses_sequence_parallel_size": "actor_rollout_ref.actor.ulysses_sequence_parallel_size",
18+
"ajet.trainer_common.loss_extra_scale_ratio": "actor_rollout_ref.actor.loss_extra_scale_ratio",
1819

1920
"ajet.trainer_common.save_freq": "trainer.save_freq",
2021
"ajet.trainer_common.test_freq": "trainer.test_freq",
@@ -30,6 +31,8 @@
3031
"actor_rollout_ref.ref.log_prob_max_token_len_per_gpu"
3132
],
3233

34+
"ajet.rollout.max_num_seqs": "actor_rollout_ref.rollout.max_num_seqs",
35+
"ajet.rollout.temperature": "actor_rollout_ref.rollout.temperature",
3336
"ajet.rollout.multi_turn": "actor_rollout_ref.rollout.multi_turn",
3437
"ajet.rollout.val_kwargs": "actor_rollout_ref.rollout.val_kwargs",
3538
"ajet.rollout.num_repeat": [

ajet/default_config/verl/verl_default.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ actor_rollout_ref:
7070
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
7171
strategy: fsdp
7272
ppo_mini_batch_size: 256
73+
loss_extra_scale_ratio: 1.0
7374
override_ppo_mini_batch_num: 1 # special in agentjet
7475
ppo_micro_batch_size: null
7576
ppo_micro_batch_size_per_gpu: null

tests/bench/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@ VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_count
2525
VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py::TestBenchmarkLearnToAsk::test_01_begin_verl
2626
VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py::TestBenchmarkFrozenLake::test_01_begin_verl
2727

28+
python -m ajet.launcher --conf tests/bench/benchmark_math/benchmark_math.yaml --autokill --db="UPP"
2829

2930
export APPWORLD_PATH="/dev/shm/pack_all_in_one"
3031
export APPWORLD_SCRIPT="bash EnvService/env_sandbox/appworld.sh"
3132
python -m ajet.launcher --conf tests/bench/benchmark_appworld/benchmark_appworld.yaml --with-appworld --backbone=debug --autokill
3233
python -m ajet.launcher --conf tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml --with-appworld --autokill --db="EXT"
3334
```
35+
36+
37+
VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py::TestBenchmarkMath::test_01_begin_verl

tests/bench/benchmark_appworld/benchmark_appworld.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,15 @@ ajet:
4747
max_prompt_length: 3000
4848
max_response_length: 15000
4949

50+
# trainer common configurations
5051
trainer_common:
5152
save_freq: 99999
5253
test_freq: 99999
5354
total_epochs: 99999
5455
nnodes: 1
5556
n_gpus_per_node: 8
57+
# loss = loss * loss_extra_scale_ratio
58+
loss_extra_scale_ratio: 10.0
5659

5760
execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT
5861
execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" #

tests/bench/benchmark_countdown/benchmark_countdown.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ ajet:
116116
kl_loss_coef: 0.002
117117
kl_loss_type: low_var_kl
118118
ulysses_sequence_parallel_size: 1
119-
119+
# loss = loss * loss_extra_scale_ratio
120+
loss_extra_scale_ratio: 10.0
120121

121122
# DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN.
122123
execute_test: True # FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN.

tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ ajet:
6969
nnodes: 1
7070
n_gpus_per_node: 8
7171
logger: swanlab
72-
72+
# loss = loss * loss_extra_scale_ratio
73+
loss_extra_scale_ratio: 10.0
7374

7475
execute_test: True
7576
execute_testing_lambda: "tests/bench/benchmark_frozenlake/benchmark_frozenlake.py->TestProbe"

tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ ajet:
4545
test_freq: 100
4646
total_epochs: 100
4747
logger: swanlab
48-
48+
# loss = loss * loss_extra_scale_ratio
49+
loss_extra_scale_ratio: 10.0
4950

5051
execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT
5152
execute_testing_lambda: "tests/bench/benchmark_learn2ask/benchmark_learn2ask.py->TestProbe" # DO NOT EDIT, THIS IS FOR TEST ROBOT

tests/bench/benchmark_math/benchmark_math.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ ajet:
2121
user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent
2222
temperature: 1.0
2323
max_env_worker: 64
24-
max_num_seqs: 256
24+
max_num_seqs: 10
2525
num_repeat: 6
2626
agent_madness_reward: 0.0
2727
tensor_model_parallel_size: 1
@@ -49,7 +49,8 @@ ajet:
4949
logger: swanlab
5050
nnodes: 1
5151
n_gpus_per_node: 4
52-
52+
# loss = loss * loss_extra_scale_ratio
53+
loss_extra_scale_ratio: 40.0
5354

5455

5556
execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT

0 commit comments

Comments
 (0)