
Commit ab8cfe4

refactor(infra): move scheduler and rpc modules under areal/infra
Consolidate infrastructure-related modules under areal/infra/ for better code organization:

- Move areal/scheduler/ to areal/infra/scheduler/
- Move areal/scheduler/rpc/ to areal/infra/rpc/
- Update all imports across the codebase
- Update example configs to use new import paths
- Add new areal/infra/rpc/__init__.py for cleaner exports

Key changes:

- Scheduler classes (LocalScheduler, RayScheduler, SlurmScheduler) now at areal.infra.scheduler
- RPC utilities (RTensor, serialization) now at areal.infra.rpc
- All existing functionality preserved with updated import paths
1 parent b06d242 commit ab8cfe4

81 files changed

Lines changed: 232 additions & 200 deletions
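
The two path moves named in the commit message can be applied to downstream code mechanically. The following is a minimal sketch, assuming a plain text rewrite is acceptable for the files in question; `MODULE_MOVES` and `rewrite_module_paths` are hypothetical names for illustration and are not part of AReaL. The rpc rule is listed first because `areal/scheduler/rpc/` moved to `areal/infra/rpc/` rather than staying under the scheduler package.

```python
import re

# Ordered (old prefix, new prefix) pairs taken from the commit message.
# Order matters: the more specific rpc prefix must be rewritten before the
# general scheduler prefix, otherwise rpc imports would land in
# areal.infra.scheduler.rpc, which this commit does not create.
MODULE_MOVES = [
    ("areal.scheduler.rpc", "areal.infra.rpc"),
    ("areal.scheduler", "areal.infra.scheduler"),
]


def rewrite_module_paths(text: str) -> str:
    """Rewrite old dotted module paths in `text` to their new locations.

    Hypothetical helper: it does a textual substitution only and does not
    parse Python, so it also touches strings such as `python -m ...` commands.
    """
    for old, new in MODULE_MOVES:
        # The lookbehind rejects matches inside a longer dotted name
        # (e.g. `pkg.areal.scheduler`); the lookahead rejects names that
        # merely share the prefix (e.g. `areal.scheduler_api`).
        pattern = re.compile(r"(?<![\w.])" + re.escape(old) + r"(?!\w)")
        text = pattern.sub(new, text)
    return text
```

Applied to a line such as `from areal.scheduler.rpc.serialization import serialize_value`, the helper yields the `areal.infra.rpc.serialization` form that appears throughout this diff, and running it a second time leaves the text unchanged.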


.claude/agents/launcher-scheduler-expert.md

Lines changed: 7 additions & 6 deletions
@@ -22,7 +22,8 @@ allocation correctness.
 
 Use this agent **when requested** when:
 
-- **Code modifications**: User edits files in `areal/launcher/` or `areal/scheduler/`
+- **Code modifications**: User edits files in `areal/launcher/`, `areal/infra/rpc`, or
+  `areal/infra/scheduler/`
 - **Configuration changes**: User modifies `ClusterSpecConfig`, `SchedulerConfig`, or
   related dataclasses
 - **Deployment issues**: User encounters job launch failures, port conflicts, GPU
@@ -99,8 +100,8 @@ Critical utilities in `areal/utils/launcher.py`:
   hard-coded GPU indices or direct `torch.cuda` calls
 - Use `areal.utils.name_resolve` for multi-node service discovery -> not direct
   IP/hostname assumptions
-- Raise specific exceptions from `areal.scheduler.exceptions` -> not generic exception
-  types
+- Raise specific exceptions from `areal.infra.scheduler.exceptions` -> not generic
+  exception types
 - Use `areal.utils.proc.kill_process_tree()` for process termination -> not leaving
   zombie processes
 - Propagate all `BASE_ENVIRONS` variables and thread control variables -> not missing
@@ -139,9 +140,9 @@ Critical utilities in `areal/utils/launcher.py`:
 | `areal/launcher/ray.py` | Ray cluster deployment | Ray actor management, placement group allocation |
 | `areal/launcher/sglang_server.py` | SGLang inference server deployment | SGLang server process management, cache isolation |
 | `areal/launcher/vllm_server.py` | vLLM inference server deployment | vLLM server process management, cache isolation |
-| `areal/scheduler/local.py` | Local worker scheduling | GPU round-robin, port allocation, health monitoring |
-| `areal/scheduler/slurm.py` | Slurm-integrated scheduling | Job array coordination, resource reservation |
-| `areal/scheduler/ray.py` | Ray cluster scheduling | Ray placement groups, actor-based worker management |
+| `areal/infra/scheduler/local.py` | Local worker scheduling | GPU round-robin, port allocation, health monitoring |
+| `areal/infra/scheduler/slurm.py` | Slurm-integrated scheduling | Job array coordination, resource reservation |
+| `areal/infra/scheduler/ray.py` | Ray cluster scheduling | Ray placement groups, actor-based worker management |
 | `areal/utils/launcher.py` | Shared utilities | Environment variable management, configuration validation |
 
 ______________________________________________________________________

.claude/data/pr-review-change-types.md

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ ______________________________________________________________________
 | **CHECKPOINT_RECOVERY** | `areal/utils/saver.py`, `areal/utils/recover.py`, `areal/utils/fsdp/checkpoint.py` | `state_dict`, `load_state_dict`, `checkpoint` |
 | **REWARD** | `areal/reward/` | `reward_fn`, `AsyncRewardWrapper`, `MathVerifyWorker` |
 | **DATASET** | `areal/dataset/` | `get_*_dataset`, `DataLoader`, `IterableDataset` |
-| **LAUNCHER_SCHEDULER** | `areal/launcher/`, `areal/scheduler/` | `LaunchConfig`, `Scheduler`, `RayLauncher`, `SlurmLauncher` |
+| **LAUNCHER_SCHEDULER** | `areal/launcher/`, `areal/infra/scheduler/`, `areal/infra/rpc` | `LaunchConfig`, `Scheduler`, `RayLauncher`, `SlurmLauncher` |
 | **ATTENTION** | `attention.py`, `varlen_attention.py` | `flash_attn`, `sdpa`, `varlen`, `causal_mask` |
 
 ## LOW Level (Use Haiku)

.claude/data/pr-review-templates.md

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ Checklist:
 ### Launcher and Scheduler Configuration \[Sonnet\]
 
 ```
-Applicable: areal/launcher/, areal/scheduler/ directories
+Applicable: areal/launcher/, areal/infra/scheduler/, areal/infra/rpc directories
 Checklist:
 - Resource config reasonableness (GPU count, memory)
 - Process group config matches parallel strategy

AGENTS.md

Lines changed: 3 additions & 4 deletions
@@ -37,8 +37,9 @@ When unsure, leave a `TODO(agent)` comment and note the constraint in your respo
 shared utilities:
 - `areal/api/` - Contracts for workflows, engines, schedulers, IO structs, and
   CLI/config dataclasses.
-- `areal/infra/` - Single controller implementation, async orchestration primitives,
-  and hardware/platform abstractions for CPU/GPU/NPU runtimes.
+- `areal/infra/` - Core infrastructure including single controller implementation,
+  placement and allocation policies, async orchestration primitives, and
+  hardware/platform abstractions for CPU/GPU/NPU runtimes.
 - `areal/dataset/` - Stateful dataset loaders (GSM8K, Geometry3K, CLEVR, HH-RLHF,
   TORL, etc.) and utilities that feed rollout jobs safely.
 - `areal/engine/` - Training backends (FSDP2, Megatron, PPO, SFT, reward modeling) and
@@ -51,8 +52,6 @@ When unsure, leave a `TODO(agent)` comment and note the constraint in your respo
   wrappers, custom heads).
 - `areal/reward/` - Built-in reward functions (GSM8K, Geometry3K, CLEVR, etc.), math
   parsers, and helpers; wrap slow logic with `AsyncRewardWrapper`.
-- `areal/scheduler/` - Placement and allocation policies aligned with launcher
-  configs.
 - `areal/tests/` - Focused unit/integration suites (many require GPUs or mocked
   distributed backends).
 - `areal/tools/` - Developer utilities and maintenance scripts tied to the core

areal/__init__.py

Lines changed: 8 additions & 8 deletions
@@ -3,22 +3,22 @@
 from .version import __version__ # noqa
 
 from .infra import (
-    TrainController,
     RolloutController,
-    WorkflowExecutor,
     StalenessManager,
-    workflow_context,
+    TrainController,
+    WorkflowExecutor,
     current_platform,
+    workflow_context,
 )
 from .trainer import PPOTrainer, SFTTrainer
 
 __all__ = [
-    "TrainController",
+    "PPOTrainer",
     "RolloutController",
-    "WorkflowExecutor",
+    "SFTTrainer",
     "StalenessManager",
-    "workflow_context",
+    "TrainController",
+    "WorkflowExecutor",
     "current_platform",
-    "PPOTrainer",
-    "SFTTrainer",
+    "workflow_context",
 ]

areal/api/cli_args.py

Lines changed: 2 additions & 2 deletions
@@ -882,7 +882,7 @@ class TrainEngineConfig:
     # Scheduling
     scheduling_spec: tuple[SchedulingSpec, ...] = field(
         default_factory=lambda: (
-            SchedulingSpec(cmd="python -m areal.scheduler.rpc.rpc_server"),
+            SchedulingSpec(cmd="python -m areal.infra.rpc.rpc_server"),
         ),
         metadata={
             "help": "Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: "
@@ -1512,7 +1512,7 @@ class InferenceEngineConfig:
     )
     scheduling_spec: tuple[SchedulingSpec, ...] = field(
         default_factory=lambda: (
-            SchedulingSpec(cmd="python -m areal.scheduler.rpc.rpc_server"),
+            SchedulingSpec(cmd="python -m areal.infra.rpc.rpc_server"),
         ),
         metadata={
             "help": "inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: "

areal/experimental/openai/proxy/proxy_rollout_server.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
 
 from areal.api.cli_args import NameResolveConfig, OpenAIProxyConfig
 from areal.experimental.openai.client import ArealOpenAI
-from areal.scheduler.rpc.serialization import deserialize_value, serialize_value
+from areal.infra.rpc.serialization import deserialize_value, serialize_value
 from areal.utils import name_resolve, names, seeding
 from areal.utils.dynamic_import import import_from_string
 from areal.utils.hf_utils import load_hf_tokenizer

areal/infra/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,7 @@
     RemoteInfBackendProtocol,
     RemoteInfEngine,
 )
+from .scheduler import LocalScheduler, RayScheduler, SlurmScheduler
 from .staleness_manager import StalenessManager
 from .workflow_executor import (
     WorkflowExecutor,
@@ -25,4 +26,7 @@
     "Platform",
     "current_platform",
     "is_npu_available",
+    "LocalScheduler",
+    "RayScheduler",
+    "SlurmScheduler",
 ]

areal/infra/controller/rollout_callback.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import requests
 
 from areal.api.io_struct import ParamSpec, WeightUpdateMeta
-from areal.scheduler.rpc.serialization import serialize_value
+from areal.infra.rpc.serialization import serialize_value
 from areal.utils import logging
 from areal.utils.concurrent import get_executor
 
1111

areal/infra/controller/rollout_controller.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 )
 from areal.api.scheduler_api import Job, Scheduler, Worker
 from areal.api.workflow_api import RolloutWorkflow, WorkflowLike
-from areal.scheduler.rpc.serialization import deserialize_value
+from areal.infra.rpc.serialization import deserialize_value
 from areal.utils import logging, perf_tracer
 from areal.utils.concurrent import run_async_task
 from areal.utils.data import concat_padded_tensors, cycle_dataloader
