Skip to content

Commit 4ff2e8e

Browse files
fix: release image router models after compression
1 parent fcef84f commit 4ff2e8e

12 files changed

Lines changed: 324 additions & 75 deletions

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3636
backup is missing, strips only the Headroom-managed block and leaves
3737
surrounding user content intact). Safe no-op when run without a prior
3838
wrap. Reported by @raenaryl in Discord.
39+
- **Image compressors now release shared router models after use and proxy shutdown**
40+
the proxy/image compression path no longer keeps global `technique-router`
41+
and `SigLIP` model instances pinned in memory after one-off image
42+
optimization work. The `get_compressor()` helper now returns a fresh,
43+
caller-owned compressor instead of a process-lifetime singleton.
3944
- **`headroom learn` no longer clobbers prior recommendations on re-run**
4045
the marker block in `CLAUDE.md` / `MEMORY.md` is now merged with the
4146
prior block instead of wholesale-replaced. Sections re-surfaced by the

headroom/image/compressor.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,14 @@ def _get_router(self) -> TrainedRouter:
102102
)
103103
return self._router
104104

105+
def close(self, unload_models: bool = True) -> None:
    """Release any router-held model state."""
    router = self._router
    if router is None:
        # Never-loaded router: has_images()-style checks hold no heavyweight
        # model state, so there is nothing to release.
        return
    router.release_models(unload_registry=unload_models)
    self._router = None
112+
105113
def has_images(self, messages: list[dict[str, Any]]) -> bool:
106114
"""Check if messages contain images."""
107115
for message in messages:
@@ -563,16 +571,13 @@ def compress(
563571
return compressed_messages
564572

565573

566-
# Singleton for convenience
567-
_default_compressor: ImageCompressor | None = None
568-
569-
570574
def get_compressor() -> ImageCompressor:
    """Construct a fresh, caller-owned ImageCompressor.

    Retained for backwards-compatible imports; the caller is responsible for
    closing the instance it receives (no process-lifetime singleton is kept).
    """
    compressor = ImageCompressor()
    return compressor
576581

577582

578583
def compress_images(
@@ -588,4 +593,8 @@ def compress_images(
588593
Returns:
589594
Messages with compressed images
590595
"""
591-
return get_compressor().compress(messages, provider)
596+
compressor = ImageCompressor()
597+
try:
598+
return compressor.compress(messages, provider)
599+
finally:
600+
compressor.close()

headroom/image/trained_router.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from __future__ import annotations
1212

13+
import gc
1314
import io
1415
from dataclasses import dataclass
1516
from enum import Enum
@@ -151,6 +152,8 @@ def __init__(
151152
self._siglip_model: Any = None
152153
self._siglip_processor: Any = None
153154
self._text_embeddings: Any = None
155+
self._classifier_key: str | None = None
156+
self._siglip_key: str | None = None
154157

155158
def is_available(self) -> bool:
156159
"""Check if required models can be loaded."""
@@ -186,6 +189,7 @@ def _load_models(self) -> None:
186189
model_path=model_id,
187190
device=self.device,
188191
)
192+
self._classifier_key = f"technique_router:{model_id}"
189193

190194
if self.use_siglip and self._siglip_model is None:
191195
# Use centralized registry for shared model instances
@@ -195,10 +199,34 @@ def _load_models(self) -> None:
195199
model_name=self.siglip_model,
196200
device=self.device,
197201
)
202+
self._siglip_key = f"siglip:{self.siglip_model}"
198203

199204
# Pre-compute text embeddings for image analysis
200205
self._compute_text_embeddings()
201206

207+
def release_models(self, unload_registry: bool = True) -> None:
    """Release router-held model references and optional shared cache entries.

    Args:
        unload_registry: When True (default), also evict this router's
            classifier/SigLIP entries from the shared MLModelRegistry so the
            underlying models can actually be freed. When False, only the
            router-local references are dropped and the shared cache is left
            intact for other users.
    """
    # Capture the registry keys before clearing them so the shared entries
    # can still be evicted below.
    classifier_key = self._classifier_key
    siglip_key = self._siglip_key

    # Drop every strong reference this router holds.
    self._text_embeddings = None
    self._siglip_processor = None
    self._siglip_model = None
    self._tokenizer = None
    self._classifier = None
    self._classifier_key = None
    self._siglip_key = None

    keys = [key for key in (classifier_key, siglip_key) if key]
    if unload_registry and keys:
        # Local import keeps the registry (and its transitive deps) lazy.
        from headroom.models.ml_models import MLModelRegistry

        # unload_many() performs its own GC / allocator-cache cleanup pass.
        MLModelRegistry.unload_many(keys)
    else:
        # Either the caller wants the shared cache kept, or no keys were ever
        # recorded (models never loaded). In both cases still reclaim the
        # references dropped above; the unload_registry=True path used to
        # skip this GC pass entirely when the key list was empty.
        gc.collect()

close = release_models
229+
202230
def _compute_text_embeddings(self) -> None:
203231
"""Pre-compute SigLIP text embeddings for image analysis."""
204232
assert self._siglip_processor is not None

headroom/models/ml_models.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
from __future__ import annotations
2828

29+
import contextlib
30+
import gc
2931
import logging
3032
from threading import RLock
3133
from typing import TYPE_CHECKING, Any
@@ -76,6 +78,63 @@ def reset(cls) -> None:
7678
cls._instance._models.clear()
7779
cls._instance = None
7880

81+
@classmethod
82+
def _release_runtime_memory(cls) -> None:
83+
"""Best-effort cleanup after unloading heavyweight models."""
84+
gc.collect()
85+
try:
86+
import torch
87+
except ImportError:
88+
return
89+
90+
with contextlib.suppress(Exception):
91+
if torch.cuda.is_available():
92+
torch.cuda.empty_cache()
93+
94+
mps = getattr(torch, "mps", None)
95+
if mps is not None and hasattr(mps, "empty_cache"):
96+
mps.empty_cache()
97+
98+
@classmethod
def unload(cls, key: str) -> bool:
    """Unload one cached model entry.

    Returns True when *key* was cached and has now been removed.
    """
    removed = cls.unload_many([key])
    return len(removed) > 0
102+
103+
@classmethod
def unload_many(cls, keys: list[str]) -> list[str]:
    """Unload several cached model entries with one runtime cleanup pass."""
    registry = cls.get()
    removed: list[str] = []

    with registry._model_lock:
        for key in keys:
            if key in registry._models:
                # Drop the cached reference; the model object itself is
                # reclaimed once no caller holds it any longer.
                del registry._models[key]
                removed.append(key)

    # One cleanup pass after the lock is released, and only when something
    # was actually evicted.
    if removed:
        cls._release_runtime_memory()
    return removed
120+
121+
@classmethod
def unload_prefix(cls, prefix: str) -> list[str]:
    """Unload every cached model entry matching a prefix."""
    registry = cls.get()

    with registry._model_lock:
        # Snapshot matching keys first, then evict, so the dict is never
        # mutated while being iterated.
        matching = [key for key in registry._models if key.startswith(prefix)]
        for key in matching:
            del registry._models[key]

    if matching:
        cls._release_runtime_memory()
    return matching
79138
# =========================================================================
80139
# Sentence Transformers
81140
# =========================================================================

headroom/proxy/handlers/anthropic.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -709,16 +709,21 @@ async def _finalize_pre_upstream() -> None:
709709
and not _bypass
710710
and not is_cache_mode(self.config.mode)
711711
):
712-
compressor = _get_image_compressor()
713-
if compressor and compressor.has_images(messages):
714-
messages = compressor.compress(messages, provider="anthropic")
715-
if compressor.last_result:
716-
logger.info(
717-
f"Image compression: {compressor.last_result.technique.value} "
718-
f"({compressor.last_result.savings_percent:.0f}% saved, "
719-
f"{compressor.last_result.original_tokens} -> "
720-
f"{compressor.last_result.compressed_tokens} tokens)"
721-
)
712+
compressor = None
713+
try:
714+
compressor = _get_image_compressor()
715+
if compressor and compressor.has_images(messages):
716+
messages = compressor.compress(messages, provider="anthropic")
717+
if compressor.last_result:
718+
logger.info(
719+
f"Image compression: {compressor.last_result.technique.value} "
720+
f"({compressor.last_result.savings_percent:.0f}% saved, "
721+
f"{compressor.last_result.original_tokens} -> "
722+
f"{compressor.last_result.compressed_tokens} tokens)"
723+
)
724+
finally:
725+
if compressor and hasattr(compressor, "close"):
726+
compressor.close()
722727

723728
_compression_failed = False
724729
original_messages = messages # Preserve for 400-retry fallback

headroom/proxy/handlers/openai.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,16 +239,21 @@ async def handle_openai_chat(
239239
if self.config.image_optimize and messages and not _bypass:
240240
from headroom.proxy.helpers import _get_image_compressor
241241

242-
compressor = _get_image_compressor()
243-
if compressor and compressor.has_images(messages):
244-
messages = compressor.compress(messages, provider="openai")
245-
if compressor.last_result:
246-
logger.info(
247-
f"[{request_id}] Image: {compressor.last_result.technique.value} "
248-
f"({compressor.last_result.savings_percent:.0f}% saved, "
249-
f"{compressor.last_result.original_tokens} → "
250-
f"{compressor.last_result.compressed_tokens} tokens)"
251-
)
242+
compressor = None
243+
try:
244+
compressor = _get_image_compressor()
245+
if compressor and compressor.has_images(messages):
246+
messages = compressor.compress(messages, provider="openai")
247+
if compressor.last_result:
248+
logger.info(
249+
f"[{request_id}] Image: {compressor.last_result.technique.value} "
250+
f"({compressor.last_result.savings_percent:.0f}% saved, "
251+
f"{compressor.last_result.original_tokens} → "
252+
f"{compressor.last_result.compressed_tokens} tokens)"
253+
)
254+
finally:
255+
if compressor and hasattr(compressor, "close"):
256+
compressor.close()
252257

253258
headers = dict(request.headers.items())
254259
headers.pop("host", None)

headroom/proxy/helpers.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,31 @@ def jitter_delay_ms(base_ms: int, max_ms: int, attempt: int) -> float:
5959
return capped * (0.5 + random.random())
6060

6161

62-
# Image compression (lazy-loaded to avoid heavy dependencies at startup)
63-
_image_compressor = None
62+
# Image compression availability (do not retain a global compressor instance)
63+
_image_compressor_available: bool | None = None
6464

6565

6666
def _get_image_compressor():
    """Create a short-lived image compressor on demand.

    Returns a fresh compressor the caller must close, or None when the
    optional image stack is not importable. Only availability is memoized;
    no compressor instance is retained at module level.
    """
    global _image_compressor_available

    if _image_compressor_available is False:
        return None

    try:
        from headroom.image import ImageCompressor

        # Callers own closing the compressor; this helper only memoizes whether
        # the optional image stack is importable.
        compressor = ImageCompressor()
    except ImportError as e:
        if _image_compressor_available is not False:
            logger.warning(f"Image compression not available: {e}")
        _image_compressor_available = False
        return None

    if _image_compressor_available is None:
        logger.info("Image compression enabled (model: chopratejas/technique-router)")
    _image_compressor_available = True
    return compressor
7987

8088

8189
# Always-on file logging to the workspace logs directory for `headroom perf` analysis.

headroom/proxy/server.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import argparse
2727
import asyncio
28+
import contextlib
2829
import json
2930
import logging
3031
import os
@@ -844,6 +845,15 @@ async def shutdown(self):
844845
if self.memory_handler and hasattr(self.memory_handler, "close"):
845846
await self.memory_handler.close()
846847

848+
with contextlib.suppress(Exception):
849+
from headroom.models.ml_models import MLModelRegistry
850+
851+
released_models: list[str] = []
852+
released_models.extend(MLModelRegistry.unload_prefix("technique_router:"))
853+
released_models.extend(MLModelRegistry.unload_prefix("siglip:"))
854+
if released_models:
855+
logger.info("Released image optimizer models: %s", ", ".join(released_models))
856+
847857
# Stop all quota trackers via the registry
848858
await get_quota_registry().stop_all()
849859

0 commit comments

Comments
 (0)