Skip to content

Commit 23a55a2

Browse files
committed
Fix the broken CLI data loading for GSM8K; refactor the standard and agentic GRPO paths into a single, deduplicated function.
1 parent 0fac961 commit 23a55a2

5 files changed

Lines changed: 247 additions & 54 deletions

File tree

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#!/bin/bash
2+
# Copyright 2026 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
# Agentic GSM8K GRPO launcher for Qwen3 8B using
17+
# tunix/cli/base_agentic_config.yaml plus explicit CLI overrides.
18+
#
19+
# Usage:
20+
# bash /examples/rl/grpo/gsm8k/run_qwen3_8b.sh
21+
#
22+
# Run from the tunix repo root.
23+
24+
set -euo pipefail
25+
26+
export SKIP_JAX_PRECOMPILE=true
27+
28+
model_name="${model_name:-Qwen3-8B}"
29+
model_id="${model_id:-Qwen/Qwen3-8B}"
30+
tokenizer_path="${tokenizer_path:-$model_id}"
31+
32+
batch_size="${batch_size:-8}"
33+
num_batches="${num_batches:-934}"
34+
num_train_epochs="${num_train_epochs:-1}"
35+
train_fraction="${train_fraction:-1.0}"
36+
warmup_ratio="${warmup_ratio:-0.1}"
37+
38+
mini_batch_size="${mini_batch_size:-8}"
39+
train_micro_batch_size="${train_micro_batch_size:-1}"
40+
rollout_micro_batch_size="${rollout_micro_batch_size:-8}"
41+
compute_logps_micro_batch_size="${compute_logps_micro_batch_size:-1}"
42+
43+
num_generations="${num_generations:-4}"
44+
45+
train_mesh="${train_mesh:-(8,1)}"
46+
rollout_mesh="${rollout_mesh:-(1,8)}"
47+
48+
max_steps=$(awk "BEGIN {
49+
value = $num_batches * $num_train_epochs * $train_fraction;
50+
if (value < 1) value = 1;
51+
printf \"%.0f\", value;
52+
}")
53+
warmup_steps=$(awk "BEGIN {
54+
value = $warmup_ratio * $max_steps;
55+
if (value < 1) value = 1;
56+
printf \"%.0f\", value;
57+
}")
58+
vllm_max_num_seqs=$(awk "BEGIN {
59+
value = $rollout_micro_batch_size * $num_generations;
60+
if (value < 1) value = 1;
61+
printf \"%.0f\", value;
62+
}")
63+
64+
python -m tunix.cli.grpo_main \
65+
tunix/cli/base_agentic_config.yaml \
66+
\
67+
`# -- Model ------------------------------------------------------------` \
68+
model_config.model_name="$model_name" \
69+
model_config.model_id="$model_id" \
70+
model_config.model_source="huggingface" \
71+
model_config.rng_seed=42 \
72+
model_config.model_display=false \
73+
model_config.remat_config=3 \
74+
actor_model_config.mesh.shape="$train_mesh" \
75+
actor_model_config.mesh.axis_names="('fsdp','tp')" \
76+
reference_model_config.mesh=null \
77+
reference_model_config.same_mesh_as="actor" \
78+
rollout_model_config.mesh.shape="$rollout_mesh" \
79+
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
80+
\
81+
`# -- Data -------------------------------------------------------------` \
82+
data_source="huggingface" \
83+
dataset_name="openai/gsm8k:main" \
84+
\
85+
`# -- Training loop ----------------------------------------------------` \
86+
training_mode="agentic_grpo" \
87+
batch_size="$batch_size" \
88+
num_batches="$num_batches" \
89+
num_test_batches=100 \
90+
num_train_epochs="$num_train_epochs" \
91+
train_fraction="$train_fraction" \
92+
reward_functions=["tunix/cli/reward_fn/gsm8k.py"] \
93+
verl_compatible=false \
94+
\
95+
`# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \
96+
rollout_engine="vllm" \
97+
offload_to_cpu=false \
98+
\
99+
`# -- Rollout config ---------------------------------------------------` \
100+
rollout_config.max_prompt_length=256 \
101+
rollout_config.total_generation_steps=768 \
102+
rollout_config.max_tokens_to_generate=768 \
103+
rollout_config.temperature=0.9 \
104+
rollout_config.top_p=1.0 \
105+
rollout_config.top_k=50 \
106+
rollout_config.return_logprobs=true \
107+
\
108+
`# -- vLLM (used when rollout_engine=vllm) -----------------------------` \
109+
vllm_config.hbm_utilization=0.4 \
110+
vllm_config.tpu_backend_type="jax" \
111+
vllm_config.server_mode=true \
112+
vllm_config.async_scheduling=true \
113+
vllm_config.max_num_seqs="$vllm_max_num_seqs" \
114+
vllm_config.kwargs.kv_cache_metrics=true \
115+
vllm_config.kwargs.disable_log_stats=false \
116+
vllm_config.kwargs.enable_prefix_caching=true \
117+
\
118+
`# -- Tokenizer / chat parsing ----------------------------------------` \
119+
chat_parser_config.type="qwen" \
120+
tokenizer_config.tokenizer_type="huggingface" \
121+
tokenizer_config.tokenizer_path="$tokenizer_path" \
122+
tokenizer_config.add_bos=false \
123+
tokenizer_config.add_eos=false \
124+
\
125+
`# -- GRPO algorithm ---------------------------------------------------` \
126+
agentic_grpo_config.num_generations="$num_generations" \
127+
agentic_grpo_config.num_iterations=1 \
128+
agentic_grpo_config.beta=0.08 \
129+
agentic_grpo_config.epsilon=0.2 \
130+
agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using <reasoning>...</reasoning> followed by <answer>...</answer> with only the final numeric answer inside <answer>." \
131+
agentic_grpo_config.max_concurrency=128 \
132+
agentic_grpo_config.max_response_length=768 \
133+
agentic_grpo_config.max_turns=1 \
134+
agentic_grpo_config.context_ratio=1 \
135+
\
136+
`# -- Optimizer --------------------------------------------------------` \
137+
rl_training_config.actor_optimizer_config.opt_type="adamw" \
138+
rl_training_config.actor_optimizer_config.learning_rate=3e-6 \
139+
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
140+
rl_training_config.actor_optimizer_config.init_value=0.0 \
141+
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
142+
rl_training_config.actor_optimizer_config.end_value=0.0 \
143+
rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \
144+
rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \
145+
rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \
146+
rl_training_config.actor_optimizer_config.b1=0.9 \
147+
rl_training_config.actor_optimizer_config.b2=0.99 \
148+
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
149+
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
150+
\
151+
`# -- RL training ------------------------------------------------------` \
152+
rl_training_config.eval_every_n_steps=10 \
153+
rl_training_config.max_steps="$max_steps" \
154+
rl_training_config.mini_batch_size="$mini_batch_size" \
155+
rl_training_config.train_micro_batch_size="$train_micro_batch_size" \
156+
rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \
157+
rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \
158+
rl_training_config.checkpoint_root_directory="/tmp/tunix/checkpoints/gsm8k_qwen3_8b" \
159+
rl_training_config.checkpointing_options.save_interval_steps=250 \
160+
rl_training_config.checkpointing_options.max_to_keep=4 \
161+
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b" \
162+
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
163+
\
164+
"$@"

tests/cli/grpo_main_test.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,10 +357,9 @@ def test_standard_grpo_dispatches_to_standard(self):
357357
"""
358358
pipeline = _make_pipeline(extra)
359359
self.assertEqual(pipeline.config.get("training_mode", "grpo"), "grpo")
360-
# _run_standard_grpo should be called; we verify no AttributeError on dispatch
361-
with mock.patch.object(pipeline, "_run_standard_grpo") as mock_std:
360+
with mock.patch.object(pipeline, "_run") as mock_run:
362361
pipeline.run_grpo_trainer()
363-
mock_std.assert_called_once()
362+
mock_run.assert_called_once_with(mode="grpo")
364363

365364
def test_agentic_grpo_dispatches_to_agentic(self):
366365
extra = """
@@ -398,9 +397,9 @@ def test_agentic_grpo_dispatches_to_agentic(self):
398397
"""
399398
pipeline = _make_pipeline(extra)
400399
self.assertEqual(pipeline.config["training_mode"], "agentic_grpo")
401-
with mock.patch.object(pipeline, "_run_agentic_grpo") as mock_ag:
400+
with mock.patch.object(pipeline, "_run") as mock_run:
402401
pipeline.run_grpo_trainer()
403-
mock_ag.assert_called_once()
402+
mock_run.assert_called_once_with(mode="agentic_grpo")
404403

405404
def test_unknown_mode_raises(self):
406405
# Build pipeline with standard config then manually set bad mode
@@ -418,8 +417,35 @@ def test_unknown_mode_raises(self):
418417
"""
419418
pipeline = _make_pipeline(extra)
420419
pipeline.config["training_mode"] = "bad_mode"
421-
with self.assertRaisesRegex(ValueError, "Unknown training_mode"):
422-
pipeline.run_grpo_trainer()
420+
raw_dataset = mock.Mock()
421+
raw_dataset.__len__ = mock.Mock(return_value=1)
422+
with mock.patch.object(pipeline, "_setup_kubernetes"):
423+
with mock.patch.object(pipeline, "_get_tokenizer", return_value=mock.sentinel.tokenizer):
424+
with mock.patch.object(
425+
pipeline,
426+
"_create_chat_parser",
427+
return_value=mock.sentinel.chat_parser,
428+
):
429+
with mock.patch.object(
430+
pipeline,
431+
"_load_raw_dataset",
432+
return_value=(raw_dataset, None),
433+
):
434+
with mock.patch.object(pipeline, "compute_params"):
435+
with mock.patch.object(
436+
grpo_main.data_lib,
437+
"post_init_dataset",
438+
return_value=(mock.sentinel.dataset, None),
439+
):
440+
with mock.patch.object(
441+
pipeline,
442+
"create_rl_cluster",
443+
return_value=mock.sentinel.rl_cluster,
444+
):
445+
with self.assertRaisesRegex(
446+
ValueError, "Unsupported training_mode 'bad_mode'"
447+
):
448+
pipeline.run_grpo_trainer()
423449

424450

425451
# ---------------------------------------------------------------------------

tests/examples/data/math_dataset_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ def test_parse_huggingface_dataset_name_supports_gsm8k_alias(self):
148148
self.assertEqual(dataset_name, "openai/gsm8k")
149149
self.assertEqual(config_name, "default")
150150

151+
def test_parse_huggingface_dataset_name_supports_explicit_config(self):
152+
dataset_name, config_name = math_dataset._parse_huggingface_dataset_name(
153+
"openai/gsm8k:main"
154+
)
155+
156+
self.assertEqual(dataset_name, "openai/gsm8k")
157+
self.assertEqual(config_name, "main")
158+
151159
def test_create_dataset_uses_huggingface_loader(self):
152160
raw_dataset = _BaseDataset([
153161
{"question": "Q3", "answer": "#### 42"},

tunix/cli/grpo_main.py

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,8 @@
5151
from tunix.perf import metrics as perf_metrics
5252
from tunix.perf.experimental import export as perf_export_v2
5353
from tunix.rl import rl_cluster as rl_cluster_lib
54-
from tunix.rl.grpo import grpo_learner
5554
from tunix.rl.rollout import base_rollout
5655

57-
GrpoConfig = grpo_learner.GrpoConfig
5856

5957
_PATHWAYS_BNS = flags.DEFINE_string(
6058
"pathways_bns", None, "BNS address of the Pathways server."
@@ -573,16 +571,22 @@ def compute_params(self, dataset):
573571
# Standard GRPO training
574572
# ------------------------------------------------------------------
575573

576-
def _run_standard_grpo(self):
577-
"""Execute standard (non-agentic) GRPO training."""
578-
tokenizer = model_lib.create_tokenizer(
574+
def _get_tokenizer(self):
575+
return model_lib.create_tokenizer(
579576
self.config["tokenizer_config"],
580577
self.config["tokenizer_config"]["tokenizer_path"],
581578
)
582579

580+
def _get_data_module(self,):
581+
if self.data_module is None:
582+
self.data_module = importlib.import_module(self.config["data_module"])
583+
return self.data_module
584+
585+
def _get_dataset(self, tokenizer):
583586
if self.config.get("data_module", None):
587+
data_module = self.config.get("data_module", None)
584588
dataset = data_lib.get_dataset_from_module(
585-
self.config["data_module"],
589+
data_module,
586590
tokenizer,
587591
)
588592
elif self.config["data_source"] == "local":
@@ -608,23 +612,7 @@ def _run_standard_grpo(self):
608612
else:
609613
raise ValueError(f"Unsupported data_source {self.config['data_source']}")
610614

611-
self.compute_params(dataset)
612-
dataset, _ = data_lib.post_init_dataset(
613-
dataset,
614-
tokenizer,
615-
batch_size=self.config["batch_size"],
616-
num_batches=self.config.get("num_batches", None),
617-
max_prompt_length=self.config["rollout_config"].get(
618-
"max_prompt_length", None
619-
),
620-
)
621-
rl_cluster = self.create_rl_cluster(tokenizer)
622-
grpo_trainer = grpo_learner.GrpoLearner(
623-
rl_cluster=rl_cluster,
624-
reward_fns=self.obtain_reward_fn(),
625-
algo_config=GrpoConfig(**self.config["grpo_config"]),
626-
)
627-
grpo_trainer.train(dataset)
615+
return dataset
628616

629617
# ------------------------------------------------------------------
630618
# Agentic GRPO helpers
@@ -671,16 +659,17 @@ def _load_class_from_path(self, dotted_path: str) -> type:
671659
module_path, class_name = dotted_path.rsplit(".", 1)
672660
return getattr(importlib.import_module(module_path), class_name)
673661

674-
def _load_raw_dataset(self):
662+
def _load_raw_dataset(self, tokenizer):
675663
"""Load a raw grain.MapDataset from data_module.
676664
677665
The module must expose ``create_dataset(**data_config) -> grain.MapDataset``
678666
and optionally a ``batch_fn`` used as ``custom_batch_fn``.
679667
"""
680-
module = importlib.import_module(self.config["data_module"])
681-
data_config = dict(self.config.get("data_config", {}))
682-
dataset = module.create_dataset(**data_config)
683-
batch_fn = getattr(module, "batch_fn", None)
668+
dataset = self._get_dataset(tokenizer)
669+
data_module = (
670+
self._get_data_module() if self.config.get("data_module", None) else None
671+
)
672+
batch_fn = getattr(data_module, "batch_fn", None) if data_module else None
684673
return dataset, batch_fn
685674

686675
def _setup_kubernetes(self) -> None:
@@ -707,19 +696,15 @@ def _setup_kubernetes(self) -> None:
707696
# Agentic GRPO training
708697
# ------------------------------------------------------------------
709698

710-
def _run_agentic_grpo(self):
699+
def _run(self, mode: str = "grpo"):
711700
"""Execute agentic GRPO training (DeepScaleR, DeepSWE, etc.)."""
712-
from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner # pylint: disable=g-import-not-at-top
713-
714701
self._setup_kubernetes()
715702

716-
tokenizer = model_lib.create_tokenizer(
717-
self.config["tokenizer_config"],
718-
self.config["tokenizer_config"]["tokenizer_path"],
719-
)
703+
tokenizer = self._get_tokenizer()
704+
720705
chat_parser = self._create_chat_parser(tokenizer)
721706

722-
raw_dataset, custom_batch_fn = self._load_raw_dataset()
707+
raw_dataset, custom_batch_fn = self._load_raw_dataset(tokenizer)
723708
self.compute_params(raw_dataset)
724709

725710
dataset, _ = data_lib.post_init_dataset(
@@ -737,6 +722,23 @@ def _run_agentic_grpo(self):
737722
)
738723

739724
rl_cluster = self.create_rl_cluster(tokenizer)
725+
726+
if mode == "grpo":
727+
from tunix.rl.grpo import grpo_learner
728+
729+
grpo_trainer = grpo_learner.GrpoLearner(
730+
rl_cluster=rl_cluster,
731+
reward_fns=self.obtain_reward_fn(),
732+
algo_config=grpo_learner.GrpoConfig(**self.config["grpo_config"]),
733+
)
734+
grpo_trainer.train(dataset)
735+
return
736+
737+
# agentic GRPO
738+
if mode != "agentic_grpo":
739+
raise ValueError(f"Unsupported training_mode {mode!r}")
740+
741+
from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner # pylint: disable=g-import-not-at-top
740742
algo_config = self._create_agentic_grpo_config()
741743

742744
reward_fns = (
@@ -774,14 +776,7 @@ def _run_agentic_grpo(self):
774776
def run_grpo_trainer(self):
775777
"""Dispatch to standard or agentic GRPO based on training_mode."""
776778
mode = self.config.get("training_mode", "grpo")
777-
if mode == "agentic_grpo":
778-
self._run_agentic_grpo()
779-
elif mode == "grpo":
780-
self._run_standard_grpo()
781-
else:
782-
raise ValueError(
783-
f"Unknown training_mode: {mode!r}. Expected 'grpo' or 'agentic_grpo'."
784-
)
779+
self._run(mode=mode)
785780

786781

787782
def _setup_jax_pathways(pathways_bns: str):

0 commit comments

Comments
 (0)