Skip to content

Commit 8f904b3

Browse files
WeiHaochengclaude
andcommitted
feat: add scaffolding rollout workflow
Key design: #818 Co-Authored-By: narutolhy Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f3d7e50 commit 8f904b3

31 files changed

Lines changed: 8285 additions & 20 deletions

areal/experimental/openai/cache.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,22 @@ def __init__(self, *args, **kwargs):
1717
self._total_reward = 0.0
1818
self._lock = threading.Lock()
1919

20+
def __deepcopy__(self, memo):
    """Deep-copy support for an *empty* cache.

    ``threading.Lock`` cannot be deep-copied. Controllers that hold
    an ``InteractionCache`` (e.g. ``ChatTracer``) are cloned via
    ``Controller.clone()`` (``copy.deepcopy``). The cache must be
    empty at clone time; a non-empty cache indicates a bug in the
    caller.

    Args:
        memo: Standard ``copy.deepcopy`` memo dict; the fresh cache is
            registered under ``id(self)`` so repeated references to
            this cache resolve to the same copy.

    Returns:
        A new, empty cache of the same (sub)class as ``self``.

    Raises:
        RuntimeError: If the cache is non-empty at copy time.
    """
    # Raise explicitly instead of using ``assert`` so the invariant
    # survives ``python -O``; silently cloning a non-empty cache would
    # drop its items.
    if len(self) != 0:
        raise RuntimeError(
            f"InteractionCache must be empty when deep-copied, but has {len(self)} items"
        )
    # ``type(self)`` keeps subclasses intact instead of pinning the
    # result to the base class.
    new = type(self)()
    memo[id(self)] = new
    return new
35+
2036
@property
2137
def last_interaction_id(self) -> str:
2238
return next(reversed(self))

areal/reward/__init__.py

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
from math_verify.metric import math_metric
4-
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
3+
import concurrent.futures
4+
5+
from math_verify.grader import verify as math_verify_verify
6+
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig, parse
57

68
from areal.utils import logging
79

@@ -29,38 +31,76 @@ def get_custom_reward_fn(path: str, **kwargs):
2931
class MathVerifyWorker:
3032
"""Thin wrapper over math_verify with configurable extraction/precision.
3133
34+
Uses ``parse()`` + ``verify()`` directly instead of ``math_metric()``
35+
so that signal-based timeouts can be disabled (``parsing_timeout=None``,
36+
``timeout_seconds=None``). This avoids ``signal.alarm()`` which only
37+
works in the main thread. A thread-safe timeout is enforced via
38+
``concurrent.futures`` instead.
39+
3240
Args:
3341
try_extract_without_anchor: When False, only answers with explicit anchors
3442
(e.g., "answer = 1", "final answer = 1") are matched. When True,
3543
any numeric string in the text may be extracted.
3644
precision: Number of significant digits that must match.
45+
timeout: Thread-safe timeout in seconds for the entire verify call
46+
(parsing + comparison). ``None`` disables the timeout.
3747
3848
Notes:
3949
Tune these knobs based on dataset format and model output style.
4050
"""
4151

42-
def __init__(self, try_extract_without_anchor=True, precision: int = 6):
43-
self.verify_func = math_metric(
44-
gold_extraction_target=(
45-
ExprExtractionConfig(
46-
try_extract_without_anchor=try_extract_without_anchor
47-
),
48-
LatexExtractionConfig(),
49-
),
50-
pred_extraction_target=(
51-
ExprExtractionConfig(
52-
try_extract_without_anchor=try_extract_without_anchor
53-
),
54-
LatexExtractionConfig(),
55-
),
56-
precision=precision,
52+
def __init__(
    self,
    try_extract_without_anchor=True,
    precision: int = 6,
    timeout: float | None = 5.0,
):
    """Configure extraction targets, precision, and the verify timeout.

    Args:
        try_extract_without_anchor: Passed through to
            ``ExprExtractionConfig``; when False only anchored answers
            are matched.
        precision: Number of significant digits that must match.
        timeout: Thread-safe timeout in seconds for the whole verify
            call; ``None`` disables it.
    """

    def _make_targets():
        # Build a fresh (Expr, Latex) extraction pair per call so gold
        # and pred each get their own config instances.
        return (
            ExprExtractionConfig(
                try_extract_without_anchor=try_extract_without_anchor
            ),
            LatexExtractionConfig(),
        )

    self.gold_extraction_target = _make_targets()
    self.pred_extraction_target = _make_targets()
    self.precision = precision
    self.timeout = timeout
68+
69+
def _verify_impl(self, response: str, ground_truth: str) -> float:
    """Parse both strings and compare them; no timeout wrapper here.

    Signal-based timeouts are disabled (``parsing_timeout=None`` /
    ``timeout_seconds=None``) so this is safe to run off the main
    thread; the caller enforces a timeout externally.
    """
    parsed_gold = parse(
        ground_truth,
        extraction_config=self.gold_extraction_target,
        parsing_timeout=None,
    )
    parsed_pred = parse(
        response,
        extraction_config=self.pred_extraction_target,
        parsing_timeout=None,
    )
    # Either side failing to parse means no credit.
    if not parsed_gold or not parsed_pred:
        return 0.0
    matched = math_verify_verify(
        parsed_gold,
        parsed_pred,
        float_rounding=self.precision,
        timeout_seconds=None,
    )
    return float(bool(matched))
5890

5991
def verify(self, response: str, ground_truth: str) -> float:
60-
# ground_truth_parsable = "\\boxed{" + ground_truth + "}"
6192
try:
62-
ret_score, _ = self.verify_func([ground_truth], [response])
63-
return float(ret_score)
93+
if self.timeout is None:
94+
return self._verify_impl(response, ground_truth)
95+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
96+
future = executor.submit(self._verify_impl, response, ground_truth)
97+
return future.result(timeout=self.timeout)
98+
except concurrent.futures.TimeoutError:
99+
logger.warning(
100+
f"Timeout ({self.timeout}s) in MathVerifyWorker.verify for "
101+
f"response={response!r} and ground_truth={ground_truth!r}",
102+
)
103+
return 0.0
64104
except Exception:
65105
logger.warning(
66106
f"Exception in MathVerifyWorker.verify for response={response} and ground_truth={ground_truth}",

examples/scaffolding/README.md

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# Scaffolding Framework Examples for AReaL
2+
3+
This directory contains examples demonstrating how to use the Scaffolding framework with
4+
AReaL for reinforcement learning training.
5+
6+
## Overview
7+
8+
The scaffolding framework provides a modular, extensible way to combine
inference-time compute methods with RL training. It decouples the inference logic
(Controllers) from the execution backend (Workers), so different rollout, reward, and
trajectory-tracing methods can be composed freely.
12+
13+
### Key Components
14+
15+
1. **Controller**: Defines the inference-time compute logic (e.g., generation, reward
16+
computation)
17+
1. **Worker**: Handles the actual execution of tasks (e.g., TRT-LLM, OpenAI API)
18+
1. **ScaffoldingLlm**: Orchestrates controllers and workers together
19+
1. **ScaffoldingWorkflow**: Wraps ScaffoldingLlm as a RolloutWorkflow for AReaL training
20+
21+
### AReaL-Specific Components
22+
23+
The following components are implemented in `examples/scaffolding/`:
24+
25+
- **`CreateWorkerFromEngine`**: Creates a scaffolding Worker from AReaL's
26+
InferenceEngine (e.g., RemoteSGLangEngine). The returned Worker is similar to
27+
scaffolding's `OpenaiWorker` but integrated with AReaL's engine.
28+
29+
- **`RLVRRewardController`**: A Controller that computes rewards for generated samples
30+
using verifiable reward functions (e.g., math answer verification).
31+
32+
- **`PipelineTrajectoryMaker`**: A Controller that composes generation and reward
33+
controllers into a pipeline that produces training trajectories.
34+
35+
- **`ScaffoldingWorkflow`**: A `RolloutWorkflow` implementation that wraps
36+
ScaffoldingLlm for integration with AReaL's training pipeline.
37+
38+
## RLVR Example with GSM8K
39+
40+
### Quick Start
41+
42+
```bash
43+
python examples/scaffolding/gsm8k_rlvr_scaffolding.py \
44+
--config examples/scaffolding/gsm8k_rlvr_scaffolding.yaml
45+
```
46+
47+
### Architecture
48+
49+
The scaffolding workflow follows this pattern from the RFC:
50+
51+
```python
52+
# Step 1: Create Worker from the SGLang engine
53+
rollout_worker = CreateWorkerFromEngine(engine)
54+
55+
# Step 2: Create controllers
56+
rollout_controller = NativeGenerationController()
57+
reward_controller = RLVRRewardController(gsm8k_reward_fn)
58+
59+
# Step 3: Create trajectory maker (composes the controllers)
60+
trajectory_maker = PipelineTrajectoryMaker(rollout_controller, reward_controller)
61+
62+
# Step 4: Create ScaffoldingLlm (orchestrates controllers with workers)
63+
scaffolding_llm = ScaffoldingLlm(
64+
trajectory_maker,
65+
{NativeGenerationController.WorkerTag.GENERATION: rollout_worker},
66+
)
67+
68+
# Step 5: Create ScaffoldingWorkflow (wraps as RolloutWorkflow)
69+
scaffolding_workflow = ScaffoldingWorkflow(scaffolding_llm)
70+
```
71+
72+
### Data Flow Diagram
73+
74+
```
75+
┌─────────────────────────────────────────────────┐
76+
│ ScaffoldingWorkflow │
77+
│ │
78+
│ ┌───────────────────────────────────────────┐ │
79+
│ │ ScaffoldingLlm │ │
80+
│ │ │ │
81+
│ │ ┌─────────────────────────────────────┐ │ │
82+
│ │ │ PipelineTrajectoryMaker │ │ │
83+
│ │ │ │ │ │
84+
│ │ │ ┌───────────────────────────────┐ │ │ │
85+
Data ─────────────────────────┼──┼──┼──► NativeGenerationController │ │ │ │
86+
│ │ │ │ (from scaffolding.core) │ │ │ │
87+
│ │ │ └───────────────┬───────────────┘ │ │ │
88+
│ │ │ │ │ │ │
89+
│ │ │ ▼ │ │ │
90+
│ │ │ ┌───────────────────────────────┐ │ │ │
91+
│ │ │ │ RLVRRewardController │ │ │ │
92+
│ │ │ │ (from areal.experimental) │ │ │ │
93+
│ │ │ └───────────────┬───────────────┘ │ │ │
94+
│ │ │ │ │ │ │
95+
│ │ └──────────────────┼──────────────────┘ │ │
96+
│ │ │ │ │
97+
│ └─────────────────────┼─────────────────────┘ │
98+
│ │ │
99+
└────────────────────────┼────────────────────────┘
100+
101+
▼ Trajectories
102+
┌─────────────────────────────┐
103+
│ PPOTrainer │
104+
│ (GRPO/PPO Training) │
105+
└─────────────────────────────┘
106+
107+
via CreateWorkerFromEngine │
108+
109+
┌─────────────────────────────────────────┐
110+
│ RemoteSGLangEngine │
111+
│ (AReaL Inference Backend) │
112+
└─────────────────────────────────────────┘
113+
```
114+
115+
### How It Works
116+
117+
1. **Engine Initialization**: `RemoteSGLangEngine` is initialized with the rollout
118+
configuration and connected to the model server.
119+
120+
1. **Worker Creation**: `CreateWorkerFromEngine(engine)` wraps the engine into a
121+
scaffolding-compatible Worker. This allows scaffolding controllers to use AReaL's
122+
inference backends.
123+
124+
1. **Controller Pipeline**:
125+
126+
- `NativeGenerationController()`: Handles text generation by yielding
127+
`GenerationTask` objects to the Worker.
128+
- `RLVRRewardController(reward_fn)`: Computes rewards for generated samples using the
129+
provided reward function.
130+
- `PipelineTrajectoryMaker(gen_ctrl, reward_ctrl)`: Composes these controllers into a
131+
pipeline that produces training trajectories.
132+
133+
1. **ScaffoldingLlm**: Orchestrates the trajectory maker with the worker, handling the
134+
async execution of tasks.
135+
136+
1. **ScaffoldingWorkflow**: Wraps the ScaffoldingLlm as a `RolloutWorkflow` that can be
137+
used directly with AReaL's `PPOTrainer`.
138+
139+
1. **Training**: The trainer calls the workflow to generate trajectories, which are then
140+
used for GRPO/PPO training.
141+
142+
### Configuration
143+
144+
See `gsm8k_rlvr_scaffolding.yaml` for the full configuration. Key options:
145+
146+
```yaml
147+
# Model configuration
148+
pretrain_path: Qwen/Qwen2.5-3B-Instruct
149+
tokenizer_path: Qwen/Qwen2.5-3B-Instruct
150+
151+
# Generation hyperparameters
152+
gconfig:
153+
max_new_tokens: 1024
154+
temperature: 1.0
155+
top_p: 1.0
156+
n_samples: 8
157+
158+
# Inference engine configuration
159+
engine:
160+
type: sglang
161+
tp: 1
162+
max_model_len: 4096
163+
```
164+
165+
## Extending the Framework
166+
167+
### Custom Reward Controllers
168+
169+
You can create custom reward controllers by subclassing the base Controller:
170+
171+
```python
172+
from examples.scaffolding._compat import Controller
173+
174+
class CustomRewardController(Controller):
175+
def __init__(self, reward_fn):
176+
super().__init__()
177+
self.reward_fn = reward_fn
178+
179+
def process(self, tasks, **kwargs):
180+
# Compute rewards for completed generation tasks
181+
for task in tasks:
182+
reward = self.reward_fn(
183+
prompt=task.input_str,
184+
completion=task.output_str,
185+
**kwargs
186+
)
187+
task.customized_result_fields["reward"] = reward
188+
yield tasks
189+
```
190+
191+
### Custom Trajectory Makers
192+
193+
For different RL algorithms, you may need different trajectory formats:
194+
195+
```python
196+
from examples.scaffolding._compat import Controller
197+
import torch
198+
199+
class CustomTrajectoryMaker(Controller):
200+
def __init__(self, generation_controller, reward_controller):
201+
super().__init__()
202+
self.generation_controller = generation_controller
203+
self.reward_controller = reward_controller
204+
205+
def process(self, tasks, **kwargs):
206+
# Run generation
207+
yield from self.generation_controller.process(tasks, **kwargs)
208+
209+
# Run reward computation
210+
yield from self.reward_controller.process(tasks, **kwargs)
211+
212+
# Build trajectories
213+
trajectories = []
214+
for task in tasks:
215+
trajectory = {
216+
"input_ids": torch.tensor(task.output_tokens),
217+
"rewards": torch.tensor(task.customized_result_fields["reward"]),
218+
}
219+
trajectories.append(trajectory)
220+
yield trajectories
221+
```
222+
223+
## References
224+
225+
- [TensorRT-LLM Scaffolding README](https://github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/scaffolding)
226+
- [AReaL Workflow Documentation](../../docs/customization/workflow.md)
227+
- [RFC: Scaffolding Integration](https://github.com/inclusionAI/AReaL/issues/818)

0 commit comments

Comments
 (0)