Skip to content

Commit 8d9ae66

Browse files
committed
feat(service): add external model API with safety guards
Add support for routing chat completions to external OpenAI-compatible APIs (e.g. GPT-4o) through the unified gateway/router/data-proxy stack, with interaction caching, reward assignment, and trajectory export.

Key changes:
- External model registration flow: gateway -> router -> data proxy
- Proxy forwarding for streaming and non-streaming chat completions
- ExternalSessionData for interaction caching and trajectory export
- Controller external_mode: skip inference server launch, validate config
- Fail fast on missing external_api_key to prevent admin-key leakage
- Roll back the router entry when data proxy registration fails
- Guard pause/continue against a None inf_bridge in external mode
- Only cache successful external API responses
- Assert that external mode requires workflow=None and group_size=1
- /v1/chat/completions and /v1/models OpenAI-compatible aliases
- HITL demo and online_rollout support for --external-url flags
- Comprehensive unit and integration tests (1150+ lines)
1 parent d130b99 commit 8d9ae66

16 files changed

Lines changed: 2223 additions & 164 deletions

File tree

areal/experimental/inference_service/controller/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,9 @@ class GatewayControllerConfig:
5454

5555
    # -- OpenAI proxy configuration (for agent-like workflows) ---------------
    openai: OpenAIProxyConfig = field(default_factory=lambda: OpenAIProxyConfig())

    # -- External model API ---------------------------------------------------
    # Base URL of an external OpenAI-compatible API. When non-None the
    # controller runs in "external mode" and no local inference servers are
    # launched (see GatewayInferenceController.external_mode).
    external_api_url: str | None = None
    # API key for the external provider. The controller fails fast if this is
    # unset in external mode, so the internal admin key is never leaked.
    external_api_key: str | None = None
    # Model identifier forwarded to the gateway on registration — presumably
    # the provider-side model name (e.g. "gpt-4o"); confirm against gateway.
    external_api_model: str | None = None
    # Name under which the external model is registered at the gateway.
    external_model_name: str = "ext-model"

areal/experimental/inference_service/controller/controller.py

Lines changed: 116 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,15 @@ def __init__(
8282
config: GatewayControllerConfig,
8383
scheduler: Scheduler,
8484
) -> None:
85-
from areal.api.alloc_mode import ModelAllocation
86-
8785
self.config = config
8886
self.scheduler = scheduler
8987

90-
# Parse allocation from config.backend
91-
self.rollout_alloc = ModelAllocation.from_str(config.backend)
88+
if config.external_api_url is not None:
89+
self.rollout_alloc = None
90+
else:
91+
from areal.api.alloc_mode import ModelAllocation
92+
93+
self.rollout_alloc = ModelAllocation.from_str(config.backend)
9294

9395
# Worker management
9496
self.workers: list[Worker] = []
@@ -191,6 +193,15 @@ def initialize(
191193

192194
logger.info("GatewayInferenceController initialized (role=%s)", role)
193195

196+
if self.config.external_api_url:
197+
self._register_external_model()
198+
logger.info(
199+
"External model mode: url=%s, model=%s, name=%s",
200+
self.config.external_api_url,
201+
self.config.external_api_model,
202+
self.config.external_model_name,
203+
)
204+
194205
async def _async_initialize(
195206
self,
196207
server_args: dict[str, Any] | None,
@@ -208,6 +219,8 @@ async def _async_initialize(
208219
* **server_infos is not None** — SGLang servers already exist so
209220
we only fork data proxy on every worker; fork router + gateway
210221
on worker 0.
222+
* **external_mode** — skip inference servers entirely; data proxies
223+
start with an empty ``--backend-addr``.
211224
"""
212225
from dataclasses import asdict
213226

@@ -216,30 +229,40 @@ async def _async_initialize(
216229
from areal.api.cli_args import SchedulingSpec, SchedulingStrategy
217230
from areal.api.scheduler_api import Job
218231

219-
alloc = self.rollout_alloc
220-
dp_size = alloc.parallel.dp_size
221232
cfg = self.config
222233
admin_api_key = self.config.openai.admin_api_key
223234

224-
inf_backend = alloc.backend
235+
if self.external_mode:
236+
dp_size = 1
237+
inf_backend = None
238+
else:
239+
alloc = self.rollout_alloc
240+
dp_size = alloc.parallel.dp_size
241+
inf_backend = alloc.backend
225242

226243
# ==================================================================
227244
# Step 0: Always create dp_size RPCGuard workers
228245
# ==================================================================
229-
inf_spec = SchedulingSpec(**asdict(cfg.scheduling_spec[0]))
230-
instance_size = alloc.parallel.tp_size * alloc.parallel.pp_size
231-
if server_infos is not None:
232-
# Pre-existing inference servers — RPCGuard workers only host
233-
# CPU services (data proxy, router, gateway), no GPUs needed.
234-
inf_spec.gpu = 0
246+
if self.external_mode:
247+
inf_spec = SchedulingSpec(
248+
task_type="worker",
249+
port_count=2,
250+
gpu=0,
251+
mem=8,
252+
cmd="python -m areal.experimental.inference_service.guard",
253+
)
235254
else:
236-
inf_spec.cpu *= instance_size
237-
inf_spec.mem *= instance_size
238-
if inf_spec.gpu > 0:
239-
inf_spec.gpu = instance_size
240-
241-
# Override cmd to launch RPCGuard instead of RPC server
242-
inf_spec.cmd = "python -m areal.experimental.inference_service.guard"
255+
inf_spec = SchedulingSpec(**asdict(cfg.scheduling_spec[0]))
256+
instance_size = alloc.parallel.tp_size * alloc.parallel.pp_size
257+
if server_infos is not None:
258+
inf_spec.gpu = 0
259+
else:
260+
inf_spec.cpu *= instance_size
261+
inf_spec.mem *= instance_size
262+
if inf_spec.gpu > 0:
263+
inf_spec.gpu = instance_size
264+
# Override cmd to launch RPCGuard instead of RPC server
265+
inf_spec.cmd = "python -m areal.experimental.inference_service.guard"
243266

244267
inf_role = f"{self._worker_role}{self._INF_SUFFIX}"
245268
inf_job = Job(
@@ -256,9 +279,11 @@ async def _async_initialize(
256279
logger.info("RPCGuard workers ready: %s", [w.id for w in inf_workers])
257280

258281
# ==================================================================
259-
# Step 1: Launch inference servers (skip when pre-existing)
282+
# Step 1: Launch inference servers (skip in external mode or when pre-existing)
260283
# ==================================================================
261-
if server_infos is not None:
284+
if self.external_mode:
285+
logger.info("External mode — skipping inference server launch")
286+
elif server_infos is not None:
262287
# Pre-existing servers — just record their addresses
263288
self.server_infos = server_infos
264289
self._inf_addrs = [
@@ -327,7 +352,6 @@ def _build_launch_cmd(host: str, port: int) -> list[str]:
327352
else:
328353
raise ValueError(f"Unsupported inference backend: {inf_backend!r}")
329354

330-
# For each RPCGuard worker: alloc port, build cmd, fork server
331355
for rank, worker in enumerate(inf_workers):
332356
guard_addr = (
333357
f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}"
@@ -447,12 +471,15 @@ def _build_launch_cmd(host: str, port: int) -> list[str]:
447471
f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}"
448472
)
449473
# Each data proxy connects to its corresponding inference server
450-
data_proxy_cmd = data_proxy_base_cmd + [
451-
"--backend-addr",
452-
self._inf_addrs[rank],
453-
"--backend-type",
454-
inf_backend or "sglang",
455-
]
474+
if self.external_mode:
475+
data_proxy_cmd = data_proxy_base_cmd + ["--backend-addr", ""]
476+
else:
477+
data_proxy_cmd = data_proxy_base_cmd + [
478+
"--backend-addr",
479+
self._inf_addrs[rank],
480+
"--backend-type",
481+
inf_backend or "sglang",
482+
]
456483
data_proxy_host, data_proxy_port = self._fork_on_guard(
457484
guard_addr=guard_addr,
458485
role="data-proxy",
@@ -533,6 +560,40 @@ def _register_data_proxies_in_router(self) -> None:
533560
worker_id,
534561
)
535562

563+
def _register_external_model(self) -> None:
    """Register the configured external OpenAI-compatible model with the gateway.

    Posts the external model's registered name, base URL, and provider model
    id to the gateway's ``/register_model`` endpoint, authorized with the
    internal admin API key, and raises on any non-2xx response.

    Raises:
        ValueError: if ``external_api_key`` is unset — fail fast so the
            internal admin API key can never be forwarded to the external
            provider in its place.
        requests.HTTPError: if the gateway rejects the registration.
    """
    # Local import keeps ``requests`` out of the module import path.
    import requests

    cfg = self.config
    if cfg.external_api_key is None:
        raise ValueError(
            "external_api_key must be set when using external model mode. "
            "Without it, the internal admin API key would be leaked to the "
            "external provider."
        )
    resp = requests.post(
        f"{self._gateway_addr}/register_model",
        json={
            "name": cfg.external_model_name,
            "url": cfg.external_api_url,
            "model": cfg.external_api_model,
        },
        # Admin key authorizes this *internal* gateway call only; it is not
        # sent to the external provider (guarded by the ValueError above).
        headers={"Authorization": f"Bearer {cfg.openai.admin_api_key}"},
        timeout=cfg.request_timeout,
    )
    resp.raise_for_status()
    logger.info(
        "External model registered: name=%s url=%s model=%s "
        "(requests will be sent to %s/chat/completions)",
        cfg.external_model_name,
        cfg.external_api_url,
        cfg.external_api_model,
        cfg.external_api_url.rstrip("/"),
    )
592+
593+
@property
def external_mode(self) -> bool:
    """Whether this controller routes to an external OpenAI-compatible API.

    True exactly when ``external_api_url`` is configured.
    """
    url = self.config.external_api_url
    return url is not None
596+
536597
def _start_online_callback_server(self) -> None:
537598
"""Start callback server used by the router to deliver ready trajectories."""
538599
if self._callback_server is not None:
@@ -990,11 +1051,19 @@ async def chat_completion(
9901051
if extra_body and isinstance(extra_body, dict):
9911052
body.update(extra_body)
9921053

993-
api_key = (
994-
session_api_key
995-
if session_api_key is not None
996-
else self.config.openai.admin_api_key
997-
)
1054+
if self.external_mode:
1055+
body["model"] = self.config.external_model_name
1056+
api_key = (
1057+
session_api_key
1058+
if session_api_key is not None
1059+
else self.config.external_api_key or self.config.openai.admin_api_key
1060+
)
1061+
else:
1062+
api_key = (
1063+
session_api_key
1064+
if session_api_key is not None
1065+
else self.config.openai.admin_api_key
1066+
)
9981067
url = f"{self._gateway_addr}/chat/completions"
9991068
headers = {
10001069
"Content-Type": "application/json",
@@ -1201,6 +1270,19 @@ def _resolve_workflow(
12011270
from areal.api.workflow_api import RolloutWorkflow
12021271
from areal.utils.dynamic_import import import_from_string
12031272

1273+
# External mode only supports online mode (workflow=None)
1274+
if self.external_mode and workflow is not None:
1275+
raise ValueError(
1276+
"External model mode only supports online mode (workflow=None). "
1277+
"Agent-based workflows are not supported with external models."
1278+
)
1279+
1280+
if self.external_mode and group_size > 1:
1281+
raise ValueError(
1282+
"External model mode requires group_size=1, "
1283+
f"got group_size={group_size}."
1284+
)
1285+
12041286
# (a) None → online mode: create InferenceServiceWorkflow without agent
12051287
if workflow is None:
12061288
from areal.experimental.inference_service.controller.workflow import (

areal/experimental/inference_service/controller/workflow.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def _export_interactions(
9898
session: aiohttp.ClientSession,
9999
session_id: str,
100100
trajectory_id: int | None = None,
101-
) -> dict[str, InteractionWithTokenLogpReward]:
101+
) -> dict[str, InteractionWithTokenLogpReward] | dict[str, Any]:
102102
url = f"{self.gateway_addr}/{_EXPORT_TRAJECTORIES_PATHNAME}"
103103
headers = {"Authorization": f"Bearer {self._admin_api_key}"}
104104
payload = {
@@ -110,13 +110,18 @@ async def _export_interactions(
110110
async with session.post(url, json=payload, headers=headers) as resp:
111111
resp.raise_for_status()
112112
data = await resp.json()
113+
114+
# External API trajectories are returned as-is without deserialization
115+
if data.get("external_api"):
116+
return data
117+
113118
return _deserialize_interactions(data["interactions"])
114119

115120
async def arun_episode(
116121
self,
117122
engine: InferenceEngine,
118123
data: dict[str, Any],
119-
) -> dict[str, InteractionWithTokenLogpReward] | None:
124+
) -> dict[str, InteractionWithTokenLogpReward] | dict[str, Any] | None:
120125
del engine
121126
http_session = await workflow_context.get_aiohttp_session()
122127
await self._grant_capacity(http_session)
@@ -190,23 +195,28 @@ async def _run_offline(
190195
async def _run_online(
    self,
    http_session: aiohttp.ClientSession,
) -> dict[str, InteractionWithTokenLogpReward] | dict[str, Any] | None:
    """Wait for the next ready online trajectory and export its interactions.

    Returns ``None`` when no trajectory became ready within ``self.timeout``
    or when the export yielded nothing. Payloads flagged with
    ``"external_api"`` are returned verbatim — no deserialized interaction
    objects and no reward statistics for those.
    """
    logger.debug("Waiting for next ready online trajectory")
    export_request = await self.controller.wait_for_online_trajectory(
        timeout=self.timeout
    )
    if not export_request:
        # Nothing ready in time — the caller treats None as "no episode".
        return None

    session_id = export_request["session_id"]
    trajectory_id = export_request["trajectory_id"]

    result = await self._export_interactions(
        http_session, session_id, trajectory_id=trajectory_id
    )

    # External API trajectories are plain dicts without
    # InteractionWithTokenLogpReward values, so skip reward accounting.
    if isinstance(result, dict) and result.get("external_api"):
        return result

    if not result:
        return None

    # Track the reward of the trajectory's last interaction.
    last_id = next(reversed(result))
    last_reward = result[last_id].reward
    stats_tracker.get(workflow_context.stat_scope()).scalar(reward=last_reward)
    return result

0 commit comments

Comments (0)