|
| 1 | +"""Lightweight dependency injection primitives — no pydantic import.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING, Annotated, Any, get_args, get_origin, get_type_hints |
| 6 | + |
| 7 | +if TYPE_CHECKING: |
| 8 | + from collections.abc import Callable |
| 9 | + |
| 10 | + from aws_lambda_powertools.event_handler.openapi.params import Dependant |
| 11 | + from aws_lambda_powertools.event_handler.request import Request |
| 12 | + |
| 13 | + |
| 14 | +class DependencyResolutionError(Exception): |
| 15 | + """Raised when a dependency cannot be resolved.""" |
| 16 | + |
| 17 | + |
| 18 | +class Depends: |
| 19 | + """ |
| 20 | + Declares a dependency for a route handler parameter. |
| 21 | +
|
| 22 | + Dependencies are resolved automatically before the handler is called. The return value |
| 23 | + of the dependency callable is injected as the parameter value. |
| 24 | +
|
| 25 | + Parameters |
| 26 | + ---------- |
| 27 | + dependency: Callable[..., Any] |
| 28 | + A callable whose return value will be injected into the handler parameter. |
| 29 | + The callable can itself declare ``Depends()`` parameters to form a dependency tree. |
| 30 | + use_cache: bool |
| 31 | + If ``True`` (default), the dependency result is cached per invocation so that |
| 32 | + the same dependency used multiple times is only called once. |
| 33 | +
|
| 34 | + Examples |
| 35 | + -------- |
| 36 | +
|
| 37 | + ```python |
| 38 | + from typing import Annotated |
| 39 | +
|
| 40 | + from aws_lambda_powertools.event_handler import APIGatewayHttpResolver, Depends |
| 41 | +
|
| 42 | + app = APIGatewayHttpResolver() |
| 43 | +
|
| 44 | + def get_tenant() -> str: |
| 45 | + return "default-tenant" |
| 46 | +
|
| 47 | + @app.get("/orders") |
| 48 | + def list_orders(tenant_id: Annotated[str, Depends(get_tenant)]): |
| 49 | + return {"tenant": tenant_id} |
| 50 | + ``` |
| 51 | + """ |
| 52 | + |
| 53 | + def __init__(self, dependency: Callable[..., Any], *, use_cache: bool = True) -> None: |
| 54 | + if not callable(dependency): |
| 55 | + raise DependencyResolutionError( |
| 56 | + f"Depends() requires a callable, got {type(dependency).__name__}: {dependency!r}", |
| 57 | + ) |
| 58 | + self.dependency = dependency |
| 59 | + self.use_cache = use_cache |
| 60 | + |
| 61 | + |
| 62 | +class _DependencyNode: |
| 63 | + """Lightweight node in a dependency tree — used by ``build_dependency_tree``.""" |
| 64 | + |
| 65 | + def __init__(self, *, param_name: str, depends: Depends, sub_tree: DependencyTree) -> None: |
| 66 | + self.param_name = param_name |
| 67 | + self.depends = depends |
| 68 | + self.dependant = sub_tree |
| 69 | + |
| 70 | + |
| 71 | +class DependencyTree: |
| 72 | + """Lightweight dependency tree — no pydantic required. |
| 73 | +
|
| 74 | + This mirrors the shape that ``solve_dependencies`` expects (a ``.dependencies`` |
| 75 | + attribute containing nodes with ``.param_name``, ``.depends``, and ``.dependant``), |
| 76 | + but can be built without importing pydantic. |
| 77 | + """ |
| 78 | + |
| 79 | + def __init__(self, *, dependencies: list[_DependencyNode] | None = None) -> None: |
| 80 | + self.dependencies: list[_DependencyNode] = dependencies or [] |
| 81 | + |
| 82 | + |
| 83 | +class DependencyParam: |
| 84 | + """Holds a dependency's parameter name and its resolved Dependant sub-tree (OpenAPI path).""" |
| 85 | + |
| 86 | + def __init__(self, *, param_name: str, depends: Depends, dependant: Dependant) -> None: |
| 87 | + self.param_name = param_name |
| 88 | + self.depends = depends |
| 89 | + self.dependant = dependant |
| 90 | + |
| 91 | + |
| 92 | +def _get_depends_from_annotation(annotation: Any) -> Depends | None: |
| 93 | + """Extract a Depends instance from an Annotated[Type, Depends(...)] annotation.""" |
| 94 | + if get_origin(annotation) is Annotated: |
| 95 | + for arg in get_args(annotation)[1:]: |
| 96 | + if isinstance(arg, Depends): |
| 97 | + return arg |
| 98 | + return None |
| 99 | + |
| 100 | + |
| 101 | +def _has_depends(func: Callable[..., Any]) -> bool: |
| 102 | + """Check if a callable has any Depends() parameters, without importing pydantic.""" |
| 103 | + try: |
| 104 | + hints = get_type_hints(func, include_extras=True) |
| 105 | + except Exception: |
| 106 | + return False |
| 107 | + |
| 108 | + for annotation in hints.values(): |
| 109 | + if _get_depends_from_annotation(annotation) is not None: |
| 110 | + return True |
| 111 | + return False |
| 112 | + |
| 113 | + |
| 114 | +def build_dependency_tree(func: Callable[..., Any]) -> DependencyTree: |
| 115 | + """Build a lightweight dependency tree from a callable's signature. |
| 116 | +
|
| 117 | + This inspects the function parameters for ``Annotated[Type, Depends(...)]`` |
| 118 | + annotations and recursively builds the tree — all without importing pydantic. |
| 119 | + """ |
| 120 | + try: |
| 121 | + hints = get_type_hints(func, include_extras=True) |
| 122 | + except Exception: |
| 123 | + return DependencyTree() |
| 124 | + |
| 125 | + dependencies: list[_DependencyNode] = [] |
| 126 | + |
| 127 | + for param_name, annotation in hints.items(): |
| 128 | + if param_name == "return": |
| 129 | + continue |
| 130 | + |
| 131 | + depends_instance = _get_depends_from_annotation(annotation) |
| 132 | + if depends_instance is not None: |
| 133 | + sub_tree = build_dependency_tree(depends_instance.dependency) |
| 134 | + dependencies.append( |
| 135 | + _DependencyNode( |
| 136 | + param_name=param_name, |
| 137 | + depends=depends_instance, |
| 138 | + sub_tree=sub_tree, |
| 139 | + ), |
| 140 | + ) |
| 141 | + |
| 142 | + return DependencyTree(dependencies=dependencies) |
| 143 | + |
| 144 | + |
| 145 | +def solve_dependencies( |
| 146 | + *, |
| 147 | + dependant: Dependant | DependencyTree, |
| 148 | + request: Request | None = None, |
| 149 | + dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None, |
| 150 | + dependency_cache: dict[Callable[..., Any], Any] | None = None, |
| 151 | +) -> dict[str, Any]: |
| 152 | + """ |
| 153 | + Recursively resolve all ``Depends()`` parameters for a given dependant. |
| 154 | +
|
| 155 | + Parameters |
| 156 | + ---------- |
| 157 | + dependant: Dependant |
| 158 | + The dependant model containing dependency declarations |
| 159 | + request: Request, optional |
| 160 | + The current request object, injected into dependencies that declare a Request parameter |
| 161 | + dependency_overrides: dict, optional |
| 162 | + Mapping of original dependency callable to override callable (for testing) |
| 163 | + dependency_cache: dict, optional |
| 164 | + Per-invocation cache of resolved dependency values |
| 165 | +
|
| 166 | + Returns |
| 167 | + ------- |
| 168 | + dict[str, Any] |
| 169 | + Mapping of parameter name to resolved dependency value |
| 170 | + """ |
| 171 | + from aws_lambda_powertools.event_handler.request import Request as RequestClass |
| 172 | + |
| 173 | + if dependency_cache is None: |
| 174 | + dependency_cache = {} |
| 175 | + |
| 176 | + values: dict[str, Any] = {} |
| 177 | + |
| 178 | + for dep in dependant.dependencies: |
| 179 | + use_fn = dep.depends.dependency |
| 180 | + |
| 181 | + # Apply overrides (for testing) |
| 182 | + if dependency_overrides and use_fn in dependency_overrides: |
| 183 | + use_fn = dependency_overrides[use_fn] |
| 184 | + |
| 185 | + # Check cache |
| 186 | + if dep.depends.use_cache and use_fn in dependency_cache: |
| 187 | + values[dep.param_name] = dependency_cache[use_fn] |
| 188 | + continue |
| 189 | + |
| 190 | + # Recursively resolve sub-dependencies |
| 191 | + sub_values = solve_dependencies( |
| 192 | + dependant=dep.dependant, |
| 193 | + request=request, |
| 194 | + dependency_overrides=dependency_overrides, |
| 195 | + dependency_cache=dependency_cache, |
| 196 | + ) |
| 197 | + |
| 198 | + # Inject Request if the dependency declares it |
| 199 | + if request is not None: |
| 200 | + try: |
| 201 | + hints = get_type_hints(use_fn) |
| 202 | + except Exception: # pragma: no cover - defensive for broken annotations |
| 203 | + hints = {} |
| 204 | + for param_name, annotation in hints.items(): |
| 205 | + if annotation is RequestClass: |
| 206 | + sub_values[param_name] = request |
| 207 | + |
| 208 | + try: |
| 209 | + solved = use_fn(**sub_values) |
| 210 | + except Exception as exc: |
| 211 | + dep_name = getattr(use_fn, "__name__", repr(use_fn)) |
| 212 | + raise DependencyResolutionError( |
| 213 | + f"Failed to resolve dependency '{dep_name}' for parameter '{dep.param_name}': {exc}", |
| 214 | + ) from exc |
| 215 | + |
| 216 | + # Cache result |
| 217 | + if dep.depends.use_cache: |
| 218 | + dependency_cache[use_fn] = solved |
| 219 | + |
| 220 | + values[dep.param_name] = solved |
| 221 | + |
| 222 | + return values |
0 commit comments