|
| 1 | +#!/bin/bash |
| 2 | +# Copyright 2026 Google LLC |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# |
| 16 | +# Agentic GSM8K GRPO launcher for Qwen3 8B using |
| 17 | +# tunix/cli/base_agentic_config.yaml plus explicit CLI overrides. |
| 18 | +# |
| 19 | +# Usage: |
#   bash examples/rl/grpo/gsm8k/run_qwen3_8b.sh
| 21 | +# |
| 22 | +# Run from the tunix repo root. |
| 23 | + |
set -euo pipefail

# NOTE(review): presumably tells JAX to skip ahead-of-time precompilation
# to shorten startup — confirm against the tunix runtime.
export SKIP_JAX_PRECOMPILE=true

# Every knob below is overridable from the environment, e.g.:
#   batch_size=16 bash run_qwen3_8b.sh

# Model selection.
: "${model_name:=Qwen3-8B}"
: "${model_id:=Qwen/Qwen3-8B}"
: "${tokenizer_path:=$model_id}"

# Training-loop sizing.
: "${batch_size:=8}"
: "${num_batches:=934}"
: "${num_train_epochs:=1}"
: "${train_fraction:=1.0}"
: "${warmup_ratio:=0.1}"

# Micro-batch sizes for the train / rollout / logprob phases.
: "${mini_batch_size:=8}"
: "${train_micro_batch_size:=1}"
: "${rollout_micro_batch_size:=8}"
: "${compute_logps_micro_batch_size:=1}"

# GRPO generations sampled per prompt.
: "${num_generations:=4}"

# Device mesh shapes, paired with ('fsdp','tp') axis names below.
: "${train_mesh:=(8,1)}"
: "${rollout_mesh:=(1,8)}"
| 47 | + |
# scaled_count A B [C]
# Prints round(A * B * C), clamped to a minimum of 1. Operands are
# passed to awk via -v rather than interpolated into the program text,
# so an empty or malformed value cannot produce an awk syntax error.
scaled_count() {
  awk -v a="${1:-0}" -v b="${2:-0}" -v c="${3:-1}" 'BEGIN {
    value = a * b * c;
    if (value < 1) value = 1;
    printf "%.0f", value;
  }'
}

# Total optimizer steps: batches/epoch * epochs * fraction of data used.
max_steps=$(scaled_count "$num_batches" "$num_train_epochs" "$train_fraction")
# Warmup steps as a fraction of the full schedule.
warmup_steps=$(scaled_count "$warmup_ratio" "$max_steps")
# vLLM needs one sequence slot per (prompt in micro-batch x generation).
vllm_max_num_seqs=$(scaled_count "$rollout_micro_batch_size" "$num_generations")
| 63 | + |
# Launch GRPO training. The base YAML provides defaults; every argument
# after it is a dotted-path override. Caller-supplied args ("$@") come
# last so they win over anything set here.
python -m tunix.cli.grpo_main \
  tunix/cli/base_agentic_config.yaml \
  \
  `# -- Model ------------------------------------------------------------` \
  model_config.model_name="$model_name" \
  model_config.model_id="$model_id" \
  model_config.model_source="huggingface" \
  model_config.rng_seed=42 \
  model_config.model_display=false \
  model_config.remat_config=3 \
  actor_model_config.mesh.shape="$train_mesh" \
  actor_model_config.mesh.axis_names="('fsdp','tp')" \
  reference_model_config.mesh=null \
  reference_model_config.same_mesh_as="actor" \
  rollout_model_config.mesh.shape="$rollout_mesh" \
  rollout_model_config.mesh.axis_names="('fsdp','tp')" \
  \
  `# -- Data -------------------------------------------------------------` \
  data_source="huggingface" \
  dataset_name="openai/gsm8k:main" \
  \
  `# -- Training loop ----------------------------------------------------` \
  training_mode="agentic_grpo" \
  batch_size="$batch_size" \
  num_batches="$num_batches" \
  num_test_batches=100 \
  num_train_epochs="$num_train_epochs" \
  train_fraction="$train_fraction" \
  `# Quoted so the [ ] brackets are not subject to pathname expansion.` \
  reward_functions="[tunix/cli/reward_fn/gsm8k.py]" \
  verl_compatible=false \
  \
  `# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \
  rollout_engine="vllm" \
  offload_to_cpu=false \
  \
  `# -- Rollout config ---------------------------------------------------` \
  rollout_config.max_prompt_length=256 \
  rollout_config.total_generation_steps=768 \
  rollout_config.max_tokens_to_generate=768 \
  rollout_config.temperature=0.9 \
  rollout_config.top_p=1.0 \
  rollout_config.top_k=50 \
  rollout_config.return_logprobs=true \
  \
  `# -- vLLM (used when rollout_engine=vllm) -----------------------------` \
  vllm_config.hbm_utilization=0.4 \
  vllm_config.tpu_backend_type="jax" \
  vllm_config.server_mode=true \
  vllm_config.async_scheduling=true \
  vllm_config.max_num_seqs="$vllm_max_num_seqs" \
  vllm_config.kwargs.kv_cache_metrics=true \
  vllm_config.kwargs.disable_log_stats=false \
  vllm_config.kwargs.enable_prefix_caching=true \
  \
  `# -- Tokenizer / chat parsing ----------------------------------------` \
  chat_parser_config.type="qwen" \
  tokenizer_config.tokenizer_type="huggingface" \
  tokenizer_config.tokenizer_path="$tokenizer_path" \
  tokenizer_config.add_bos=false \
  tokenizer_config.add_eos=false \
  \
  `# -- GRPO algorithm ---------------------------------------------------` \
  agentic_grpo_config.num_generations="$num_generations" \
  agentic_grpo_config.num_iterations=1 \
  agentic_grpo_config.beta=0.08 \
  agentic_grpo_config.epsilon=0.2 \
  agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using <reasoning>...</reasoning> followed by <answer>...</answer> with only the final numeric answer inside <answer>." \
  agentic_grpo_config.max_concurrency=128 \
  agentic_grpo_config.max_response_length=768 \
  agentic_grpo_config.max_turns=1 \
  agentic_grpo_config.context_ratio=1 \
  \
  `# -- Optimizer --------------------------------------------------------` \
  rl_training_config.actor_optimizer_config.opt_type="adamw" \
  rl_training_config.actor_optimizer_config.learning_rate=3e-6 \
  rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
  rl_training_config.actor_optimizer_config.init_value=0.0 \
  rl_training_config.actor_optimizer_config.peak_value=3e-6 \
  rl_training_config.actor_optimizer_config.end_value=0.0 \
  rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \
  rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \
  rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \
  rl_training_config.actor_optimizer_config.b1=0.9 \
  rl_training_config.actor_optimizer_config.b2=0.99 \
  rl_training_config.actor_optimizer_config.weight_decay=0.1 \
  rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
  \
  `# -- RL training ------------------------------------------------------` \
  rl_training_config.eval_every_n_steps=10 \
  rl_training_config.max_steps="$max_steps" \
  rl_training_config.mini_batch_size="$mini_batch_size" \
  rl_training_config.train_micro_batch_size="$train_micro_batch_size" \
  rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \
  rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \
  rl_training_config.checkpoint_root_directory="/tmp/tunix/checkpoints/gsm8k_qwen3_8b" \
  rl_training_config.checkpointing_options.save_interval_steps=250 \
  rl_training_config.checkpointing_options.max_to_keep=4 \
  rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b" \
  rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
  \
  "$@"
0 commit comments