Skip to content

Commit 4cb9997

Browse files
feat(event_handler): add Dependency injection with Depends() (#8128)
* feat: add Dependency injection feature * Merging from develop * feat: add Dependency injection feature * feat: add Dependency injection feature * feat: add Dependency injection feature * feat: add Dependency injection feature
1 parent f1d07ab commit 4cb9997

File tree

12 files changed

+1077
-1
lines changed

12 files changed

+1077
-1
lines changed

aws_lambda_powertools/event_handler/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
BedrockAgentFunctionResolver,
1717
BedrockFunctionResponse,
1818
)
19+
from aws_lambda_powertools.event_handler.depends import DependencyResolutionError, Depends
1920
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
2021
from aws_lambda_powertools.event_handler.http_resolver import HttpResolverLocal
2122
from aws_lambda_powertools.event_handler.lambda_function_url import (
@@ -36,6 +37,8 @@
3637
"BedrockResponse",
3738
"BedrockFunctionResponse",
3839
"CORSConfig",
40+
"Depends",
41+
"DependencyResolutionError",
3942
"HttpResolverLocal",
4043
"LambdaFunctionUrlResolver",
4144
"Request",

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ def __init__(
472472

473473
self.custom_response_validation_http_code = custom_response_validation_http_code
474474

475+
# Cache whether this route's handler declares Depends() parameters
476+
self._has_dependencies: bool | None = None
477+
475478
# Caches the name of any Request-typed parameter in the handler.
476479
# Avoids re-scanning the signature on every invocation.
477480
self.request_param_name: str | None = None
@@ -613,6 +616,15 @@ def dependant(self) -> Dependant:
613616

614617
return self._dependant
615618

619+
@property
620+
def has_dependencies(self) -> bool:
621+
"""Check if handler declares Depends() parameters without triggering full dependant computation."""
622+
if self._has_dependencies is None:
623+
from aws_lambda_powertools.event_handler.depends import _has_depends
624+
625+
self._has_dependencies = _has_depends(self.func)
626+
return self._has_dependencies
627+
616628
@property
617629
def body_field(self) -> ModelField | None:
618630
if self._body_field is None:
@@ -1428,6 +1440,17 @@ def _registered_api_adapter(
14281440
if route.request_param_name:
14291441
route_args = {**route_args, route.request_param_name: app.request}
14301442

1443+
# Resolve Depends() parameters
1444+
if route.has_dependencies:
1445+
from aws_lambda_powertools.event_handler.depends import build_dependency_tree, solve_dependencies
1446+
1447+
dep_values = solve_dependencies(
1448+
dependant=build_dependency_tree(route.func),
1449+
request=app.request,
1450+
dependency_overrides=app.dependency_overrides or None,
1451+
)
1452+
route_args.update(dep_values)
1453+
14311454
return app._to_response(next_middleware(**route_args))
14321455

14331456

@@ -1497,6 +1520,7 @@ def __init__(
14971520
function to deserialize `str`, `bytes`, `bytearray` containing a JSON document to a Python `dict`,
14981521
by default json.loads when integrating with EventSource data class
14991522
"""
1523+
self.dependency_overrides: dict[Callable, Callable] = {}
15001524
self._proxy_type = proxy_type or self._proxy_event_type
15011525
self._dynamic_routes: list[Route] = []
15021526
self._static_routes: list[Route] = []
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
"""Lightweight dependency injection primitives — no pydantic import."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Annotated, Any, get_args, get_origin, get_type_hints
6+
7+
if TYPE_CHECKING:
8+
from collections.abc import Callable
9+
10+
from aws_lambda_powertools.event_handler.openapi.params import Dependant
11+
from aws_lambda_powertools.event_handler.request import Request
12+
13+
14+
class DependencyResolutionError(Exception):
15+
"""Raised when a dependency cannot be resolved."""
16+
17+
18+
class Depends:
19+
"""
20+
Declares a dependency for a route handler parameter.
21+
22+
Dependencies are resolved automatically before the handler is called. The return value
23+
of the dependency callable is injected as the parameter value.
24+
25+
Parameters
26+
----------
27+
dependency: Callable[..., Any]
28+
A callable whose return value will be injected into the handler parameter.
29+
The callable can itself declare ``Depends()`` parameters to form a dependency tree.
30+
use_cache: bool
31+
If ``True`` (default), the dependency result is cached per invocation so that
32+
the same dependency used multiple times is only called once.
33+
34+
Examples
35+
--------
36+
37+
```python
38+
from typing import Annotated
39+
40+
from aws_lambda_powertools.event_handler import APIGatewayHttpResolver, Depends
41+
42+
app = APIGatewayHttpResolver()
43+
44+
def get_tenant() -> str:
45+
return "default-tenant"
46+
47+
@app.get("/orders")
48+
def list_orders(tenant_id: Annotated[str, Depends(get_tenant)]):
49+
return {"tenant": tenant_id}
50+
```
51+
"""
52+
53+
def __init__(self, dependency: Callable[..., Any], *, use_cache: bool = True) -> None:
54+
if not callable(dependency):
55+
raise DependencyResolutionError(
56+
f"Depends() requires a callable, got {type(dependency).__name__}: {dependency!r}",
57+
)
58+
self.dependency = dependency
59+
self.use_cache = use_cache
60+
61+
62+
class _DependencyNode:
63+
"""Lightweight node in a dependency tree — used by ``build_dependency_tree``."""
64+
65+
def __init__(self, *, param_name: str, depends: Depends, sub_tree: DependencyTree) -> None:
66+
self.param_name = param_name
67+
self.depends = depends
68+
self.dependant = sub_tree
69+
70+
71+
class DependencyTree:
72+
"""Lightweight dependency tree — no pydantic required.
73+
74+
This mirrors the shape that ``solve_dependencies`` expects (a ``.dependencies``
75+
attribute containing nodes with ``.param_name``, ``.depends``, and ``.dependant``),
76+
but can be built without importing pydantic.
77+
"""
78+
79+
def __init__(self, *, dependencies: list[_DependencyNode] | None = None) -> None:
80+
self.dependencies: list[_DependencyNode] = dependencies or []
81+
82+
83+
class DependencyParam:
84+
"""Holds a dependency's parameter name and its resolved Dependant sub-tree (OpenAPI path)."""
85+
86+
def __init__(self, *, param_name: str, depends: Depends, dependant: Dependant) -> None:
87+
self.param_name = param_name
88+
self.depends = depends
89+
self.dependant = dependant
90+
91+
92+
def _get_depends_from_annotation(annotation: Any) -> Depends | None:
93+
"""Extract a Depends instance from an Annotated[Type, Depends(...)] annotation."""
94+
if get_origin(annotation) is Annotated:
95+
for arg in get_args(annotation)[1:]:
96+
if isinstance(arg, Depends):
97+
return arg
98+
return None
99+
100+
101+
def _has_depends(func: Callable[..., Any]) -> bool:
102+
"""Check if a callable has any Depends() parameters, without importing pydantic."""
103+
try:
104+
hints = get_type_hints(func, include_extras=True)
105+
except Exception:
106+
return False
107+
108+
for annotation in hints.values():
109+
if _get_depends_from_annotation(annotation) is not None:
110+
return True
111+
return False
112+
113+
114+
def build_dependency_tree(func: Callable[..., Any]) -> DependencyTree:
115+
"""Build a lightweight dependency tree from a callable's signature.
116+
117+
This inspects the function parameters for ``Annotated[Type, Depends(...)]``
118+
annotations and recursively builds the tree — all without importing pydantic.
119+
"""
120+
try:
121+
hints = get_type_hints(func, include_extras=True)
122+
except Exception:
123+
return DependencyTree()
124+
125+
dependencies: list[_DependencyNode] = []
126+
127+
for param_name, annotation in hints.items():
128+
if param_name == "return":
129+
continue
130+
131+
depends_instance = _get_depends_from_annotation(annotation)
132+
if depends_instance is not None:
133+
sub_tree = build_dependency_tree(depends_instance.dependency)
134+
dependencies.append(
135+
_DependencyNode(
136+
param_name=param_name,
137+
depends=depends_instance,
138+
sub_tree=sub_tree,
139+
),
140+
)
141+
142+
return DependencyTree(dependencies=dependencies)
143+
144+
145+
def solve_dependencies(
146+
*,
147+
dependant: Dependant | DependencyTree,
148+
request: Request | None = None,
149+
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
150+
dependency_cache: dict[Callable[..., Any], Any] | None = None,
151+
) -> dict[str, Any]:
152+
"""
153+
Recursively resolve all ``Depends()`` parameters for a given dependant.
154+
155+
Parameters
156+
----------
157+
dependant: Dependant
158+
The dependant model containing dependency declarations
159+
request: Request, optional
160+
The current request object, injected into dependencies that declare a Request parameter
161+
dependency_overrides: dict, optional
162+
Mapping of original dependency callable to override callable (for testing)
163+
dependency_cache: dict, optional
164+
Per-invocation cache of resolved dependency values
165+
166+
Returns
167+
-------
168+
dict[str, Any]
169+
Mapping of parameter name to resolved dependency value
170+
"""
171+
from aws_lambda_powertools.event_handler.request import Request as RequestClass
172+
173+
if dependency_cache is None:
174+
dependency_cache = {}
175+
176+
values: dict[str, Any] = {}
177+
178+
for dep in dependant.dependencies:
179+
use_fn = dep.depends.dependency
180+
181+
# Apply overrides (for testing)
182+
if dependency_overrides and use_fn in dependency_overrides:
183+
use_fn = dependency_overrides[use_fn]
184+
185+
# Check cache
186+
if dep.depends.use_cache and use_fn in dependency_cache:
187+
values[dep.param_name] = dependency_cache[use_fn]
188+
continue
189+
190+
# Recursively resolve sub-dependencies
191+
sub_values = solve_dependencies(
192+
dependant=dep.dependant,
193+
request=request,
194+
dependency_overrides=dependency_overrides,
195+
dependency_cache=dependency_cache,
196+
)
197+
198+
# Inject Request if the dependency declares it
199+
if request is not None:
200+
try:
201+
hints = get_type_hints(use_fn)
202+
except Exception: # pragma: no cover - defensive for broken annotations
203+
hints = {}
204+
for param_name, annotation in hints.items():
205+
if annotation is RequestClass:
206+
sub_values[param_name] = request
207+
208+
try:
209+
solved = use_fn(**sub_values)
210+
except Exception as exc:
211+
dep_name = getattr(use_fn, "__name__", repr(use_fn))
212+
raise DependencyResolutionError(
213+
f"Failed to resolve dependency '{dep_name}' for parameter '{dep.param_name}': {exc}",
214+
) from exc
215+
216+
# Cache result
217+
if dep.depends.use_cache:
218+
dependency_cache[use_fn] = solved
219+
220+
values[dep.param_name] = solved
221+
222+
return values

aws_lambda_powertools/event_handler/openapi/dependant.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
from typing import TYPE_CHECKING, Any, ForwardRef, cast
66

7+
from aws_lambda_powertools.event_handler.depends import DependencyParam, _get_depends_from_annotation
78
from aws_lambda_powertools.event_handler.openapi.compat import (
89
ModelField,
910
create_body_model,
@@ -193,6 +194,22 @@ def get_dependant(
193194
if param.annotation is Request:
194195
continue
195196

197+
# Depends() parameters (via Annotated[Type, Depends(fn)]) are resolved at call time.
198+
depends_instance = _get_depends_from_annotation(param.annotation)
199+
if depends_instance is not None:
200+
sub_dependant = get_dependant(
201+
path=path,
202+
call=depends_instance.dependency,
203+
)
204+
dependant.dependencies.append(
205+
DependencyParam(
206+
param_name=param_name,
207+
depends=depends_instance,
208+
dependant=sub_dependant,
209+
),
210+
)
211+
continue
212+
196213
# If the parameter is a path parameter, we need to set the in_ field to "path".
197214
is_path_param = param_name in path_param_names
198215

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
if TYPE_CHECKING:
2424
from collections.abc import Callable
2525

26+
from aws_lambda_powertools.event_handler.depends import DependencyParam
2627
from aws_lambda_powertools.event_handler.openapi.models import Example
2728
from aws_lambda_powertools.event_handler.openapi.types import CacheKey
2829

@@ -64,6 +65,7 @@ def __init__(
6465
http_connection_param_name: str | None = None,
6566
response_param_name: str | None = None,
6667
background_tasks_param_name: str | None = None,
68+
dependencies: list[DependencyParam] | None = None,
6769
path: str | None = None,
6870
) -> None:
6971
self.path_params = path_params or []
@@ -78,6 +80,7 @@ def __init__(
7880
self.http_connection_param_name = http_connection_param_name
7981
self.response_param_name = response_param_name
8082
self.background_tasks_param_name = background_tasks_param_name
83+
self.dependencies = dependencies or []
8184
self.name = name
8285
self.call = call
8386
# Store the path to be able to re-generate a dependable from it in overrides
@@ -816,7 +819,7 @@ def get_flat_dependant(
816819
visited = []
817820
visited.append(dependant.cache_key)
818821

819-
return Dependant(
822+
flat = Dependant(
820823
path_params=dependant.path_params.copy(),
821824
query_params=dependant.query_params.copy(),
822825
header_params=dependant.header_params.copy(),
@@ -825,6 +828,18 @@ def get_flat_dependant(
825828
path=dependant.path,
826829
)
827830

831+
# Flatten sub-dependencies that declare HTTP params (query, header, etc.)
832+
for dep in dependant.dependencies:
833+
if dep.dependant.cache_key not in visited:
834+
sub_flat = get_flat_dependant(dep.dependant, visited=visited)
835+
flat.path_params.extend(sub_flat.path_params)
836+
flat.query_params.extend(sub_flat.query_params)
837+
flat.header_params.extend(sub_flat.header_params)
838+
flat.cookie_params.extend(sub_flat.cookie_params)
839+
flat.body_params.extend(sub_flat.body_params)
840+
841+
return flat
842+
828843

829844
def analyze_param(
830845
*,

0 commit comments

Comments
 (0)