#!/usr/bin/env python3
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Validate internal imports for QNN example entrypoints.

Entrypoints are discovered dynamically from `examples/qualcomm` by looking for
Python files that define a standard `if __name__ == "__main__"` block. This
keeps the check focused on runnable scripts while avoiding a hardcoded list
that drifts as examples are added, moved, or removed.
"""
| 15 | + |
| 16 | +import ast |
| 17 | +import importlib.util |
| 18 | +import sys |
| 19 | +from pathlib import Path |
| 20 | + |
| 21 | + |
| 22 | +EXAMPLE_MODULE_PREFIX = "executorch.examples.qualcomm." |
| 23 | + |
| 24 | + |
def resolve_examples_root():
    """Locate the ``examples/qualcomm`` directory relative to this file.

    Walks upward through every ancestor directory of this script and
    returns the first ``examples/qualcomm`` directory found, or ``None``
    when no ancestor contains one.
    """
    ancestors = Path(__file__).resolve().parents
    for ancestor in ancestors:
        qnn_examples = ancestor / "examples" / "qualcomm"
        if qnn_examples.is_dir():
            return qnn_examples
    return None
| 31 | + |
| 32 | + |
def is_main_guard(test: ast.AST) -> bool:
    """Return True if *test* is the condition of a standard main guard.

    Matches ``__name__ == "__main__"`` and also the reversed ("Yoda")
    form ``"__main__" == __name__`` so entrypoint discovery does not
    miss scripts written either way. Only a single ``==`` comparison
    qualifies; chained or other comparisons are rejected.
    """
    if not isinstance(test, ast.Compare):
        return False
    if len(test.ops) != 1 or len(test.comparators) != 1:
        return False
    if not isinstance(test.ops[0], ast.Eq):
        return False
    operands = (test.left, test.comparators[0])
    # Accept the name/constant pair in either order.
    for name_node, const_node in (operands, operands[::-1]):
        if (
            isinstance(name_node, ast.Name)
            and name_node.id == "__name__"
            and isinstance(const_node, ast.Constant)
            and const_node.value == "__main__"
        ):
            return True
    return False
| 44 | + |
| 45 | + |
def is_entrypoint(tree: ast.AST) -> bool:
    """Return True if *tree* contains an ``if __name__ == "__main__"`` block."""
    return any(
        isinstance(node, ast.If) and is_main_guard(node.test)
        for node in ast.walk(tree)
    )
| 51 | + |
| 52 | + |
def discover_entrypoints(examples_root: Path) -> list[str]:
    """Collect POSIX-style relative paths of runnable example scripts.

    A script counts as an entrypoint when it parses and contains a main
    guard; ``__init__.py`` files are skipped. Results follow the sorted
    file order, giving a deterministic report.
    """
    entrypoints = []
    for path in sorted(examples_root.rglob("*.py")):
        if path.name == "__init__.py":
            continue
        # Decode explicitly as UTF-8 (the Python source default) so
        # discovery does not depend on the platform's locale encoding.
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        if is_entrypoint(tree):
            entrypoints.append(path.relative_to(examples_root).as_posix())
    return entrypoints
| 62 | + |
| 63 | + |
def module_base_path(repo_root: Path, module_name: str) -> Path:
    """Map a dotted ``executorch.*`` module name onto a path under *repo_root*.

    The leading ``executorch`` component corresponds to the repository
    root itself, so it is dropped before joining the remaining parts.
    """
    _, *relative_parts = module_name.split(".")
    return repo_root.joinpath(*relative_parts)
| 66 | + |
| 67 | + |
def module_exists(repo_root: Path, module_name: str) -> bool:
    """Best-effort check that *module_name* resolves to something real.

    First look for a matching package directory or ``.py`` file inside
    the repo checkout; fall back to ``importlib`` spec resolution for
    modules provided by the installed environment.
    """
    candidate = module_base_path(repo_root, module_name)
    if candidate.is_dir() or candidate.with_suffix(".py").is_file():
        return True

    try:
        spec = importlib.util.find_spec(module_name)
    except (AttributeError, ImportError, ValueError):
        # find_spec raises for names whose parent packages are missing
        # or malformed; treat those as "does not exist".
        return False
    return spec is not None
| 77 | + |
| 78 | + |
def module_source_file(repo_root: Path, module_name: str):
    """Return the source file backing *module_name*, or ``None``.

    A plain module resolves to ``<base>.py``; a package resolves to its
    ``__init__.py``. The ``.py`` form is preferred when both exist,
    matching normal import resolution order here.
    """
    base_path = module_base_path(repo_root, module_name)
    for candidate in (base_path.with_suffix(".py"), base_path / "__init__.py"):
        if candidate.is_file():
            return candidate
    return None
| 88 | + |
| 89 | + |
def source_module_name(repo_root: Path, source_file: Path) -> str:
    """Derive the dotted ``executorch.*`` module name for *source_file*.

    ``__init__.py`` maps to its enclosing package directory; any other
    file maps to itself with the ``.py`` suffix removed.
    """
    relative = source_file.relative_to(repo_root)
    if relative.name == "__init__.py":
        parts = relative.parent.parts
    else:
        parts = relative.with_suffix("").parts
    return "executorch." + ".".join(parts)
| 97 | + |
| 98 | + |
def target_names(node):
    """Collect every plain name bound by an assignment target.

    Handles single names, (nested) tuple/list destructuring, and starred
    targets such as ``first, *rest = xs``. Attribute and subscript
    targets bind no new module-level name and yield nothing.
    """
    names = set()
    if isinstance(node, ast.Name):
        names.add(node.id)
    elif isinstance(node, ast.Starred):
        # Previously missed: `a, *rest = xs` binds `rest` as well.
        names.update(target_names(node.value))
    elif isinstance(node, (ast.Tuple, ast.List)):
        for element in node.elts:
            names.update(target_names(element))
    return names
| 107 | + |
| 108 | + |
def collect_names_from_import_from(
    repo_root: Path,
    module_name: str,
    node: ast.ImportFrom,
    export_cache: dict[Path, set[str]],
    is_package: bool,
) -> set[str]:
    """Return the local names that a ``from ... import`` statement binds.

    Regular aliases bind their ``asname`` (or plain name). A ``*``
    import expands to the exported names of the target module when it is
    an internal ``executorch`` module whose source we can locate;
    otherwise it contributes nothing we can track.
    """
    try:
        imported_module = resolve_from_module(module_name, node, is_package=is_package)
    except ImportError:
        # Unresolvable relative import: treat as an unknown (external) module.
        imported_module = ""

    bound = set()
    for alias in node.names:
        if alias.name != "*":
            bound.add(alias.asname or alias.name)
            continue
        is_internal = imported_module.startswith("executorch.")
        if is_internal and module_exists(repo_root, imported_module):
            source_file = module_source_file(repo_root, imported_module)
            if source_file is not None:
                bound |= exported_names(repo_root, source_file, export_cache)

    return bound
| 135 | + |
| 136 | + |
| 137 | +def nested_statement_bodies(node: ast.stmt) -> list[list[ast.stmt]] | None: |
| 138 | + if isinstance(node, ast.If): |
| 139 | + return [node.body, node.orelse] |
| 140 | + if isinstance(node, ast.Try): |
| 141 | + bodies = [node.body, node.orelse, node.finalbody] |
| 142 | + bodies.extend(handler.body for handler in node.handlers) |
| 143 | + return bodies |
| 144 | + if isinstance(node, (ast.For, ast.AsyncFor, ast.While, ast.With, ast.AsyncWith)): |
| 145 | + return [node.body, getattr(node, "orelse", [])] |
| 146 | + if isinstance(node, ast.Match): |
| 147 | + return [case.body for case in node.cases] |
| 148 | + return None |
| 149 | + |
| 150 | + |
def collect_names_from_node(
    repo_root: Path,
    module_name: str,
    node: ast.stmt,
    export_cache: dict[Path, set[str]],
    is_package: bool,
) -> set[str]:
    """Return the names that a single statement binds at module scope."""
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
        return {node.name}
    if isinstance(node, ast.Import):
        # `import a.b` binds only the top-level package name `a`.
        return {alias.asname or alias.name.split(".")[0] for alias in node.names}
    if isinstance(node, ast.ImportFrom):
        return collect_names_from_import_from(
            repo_root, module_name, node, export_cache, is_package
        )
    if isinstance(node, ast.Assign):
        names = set()
        for target in node.targets:
            names.update(target_names(target))
        return names
    if isinstance(node, ast.AnnAssign):
        return target_names(node.target)

    bodies = nested_statement_bodies(node)
    if bodies is None:
        return set()

    names = set()
    # Loop variables (`for x in ...`) and context-manager aliases
    # (`with ... as x`) stay bound after the statement ends, so count
    # them as module-level names alongside whatever the bodies bind.
    # (Previously these targets were dropped.)
    if isinstance(node, (ast.For, ast.AsyncFor)):
        names.update(target_names(node.target))
    elif isinstance(node, (ast.With, ast.AsyncWith)):
        for item in node.items:
            if item.optional_vars is not None:
                names.update(target_names(item.optional_vars))

    for body in bodies:
        names.update(
            collect_exported_names(
                repo_root, module_name, body, export_cache, is_package
            )
        )
    return names
| 186 | + |
| 187 | + |
def collect_exported_names(
    repo_root: Path,
    module_name: str,
    body: list[ast.stmt],
    export_cache: dict[Path, set[str]],
    is_package: bool,
) -> set[str]:
    """Union the names bound by every statement in *body*."""
    exported = set()
    for statement in body:
        exported |= collect_names_from_node(
            repo_root, module_name, statement, export_cache, is_package
        )
    return exported
| 203 | + |
| 204 | + |
def exported_names(
    repo_root: Path, source_file: Path, export_cache: dict[Path, set[str]]
) -> set[str]:
    """Return (and cache) the module-level names defined by *source_file*.

    The cache entry is inserted *before* parsing so that circular
    ``import *`` chains terminate instead of recursing forever; the
    shared set object is then filled in place as the walk completes.
    """
    cached_names = export_cache.get(source_file)
    if cached_names is not None:
        return cached_names

    names = set()
    export_cache[source_file] = names

    module_name = source_module_name(repo_root, source_file)
    # Decode explicitly as UTF-8 (the Python source default) rather than
    # relying on the platform's locale encoding.
    tree = ast.parse(
        source_file.read_text(encoding="utf-8"), filename=str(source_file)
    )
    names.update(
        collect_exported_names(
            repo_root,
            module_name,
            tree.body,
            export_cache,
            source_file.name == "__init__.py",
        )
    )
    return names
| 227 | + |
| 228 | + |
def resolve_from_module(
    module_name: str, node: ast.ImportFrom, is_package: bool = False
) -> str:
    """Resolve the absolute module targeted by an ``ImportFrom`` node.

    Absolute imports return their module string unchanged (empty string
    when absent). Relative imports are resolved against *module_name*:
    the module itself when it is a package, otherwise its parent
    package. Raises ImportError when the relative import escapes the
    top-level package.
    """
    if node.level == 0:
        return node.module or ""
    anchor = module_name if is_package else module_name.rpartition(".")[0]
    dotted = "." * node.level + (node.module or "")
    return importlib.util.resolve_name(dotted, anchor)
| 237 | + |
| 238 | + |
def validate_import_from(
    repo_root: Path,
    module_name: str,
    entrypoint: str,
    node: ast.ImportFrom,
    export_cache: dict[Path, set[str]],
) -> tuple[list[str], int]:
    """Validate one ``from ... import ...`` statement of an entrypoint.

    Returns ``(failures, checks)``: human-readable failure messages and
    the number of internal names actually verified. Imports outside the
    ``executorch`` namespace are ignored entirely (zero checks).
    """
    failures = []
    try:
        imported_module = resolve_from_module(module_name, node)
    except ImportError as error:
        failures.append(
            f"{entrypoint}:{node.lineno} relative import could not be resolved: {error}"
        )
        return failures, 0

    if not imported_module.startswith("executorch."):
        return failures, 0

    # The module itself counts as one check.
    checks = 1
    if not module_exists(repo_root, imported_module):
        failures.append(
            f"{entrypoint}:{node.lineno} missing internal module `{imported_module}`"
        )
        return failures, checks

    source_file = module_source_file(repo_root, imported_module)
    if source_file is None:
        exported = set()
    else:
        exported = exported_names(repo_root, source_file, export_cache)

    for alias in node.names:
        if alias.name == "*":
            # `*` cannot be checked name-by-name; skip it.
            continue
        if module_exists(repo_root, f"{imported_module}.{alias.name}"):
            # The imported name is itself a submodule; nothing more to verify.
            checks += 1
            continue
        if source_file is None or alias.name not in exported:
            failures.append(
                f"{entrypoint}:{node.lineno} unresolved internal import "
                f"`{alias.name}` from `{imported_module}`"
            )
        checks += 1

    return failures, checks
| 285 | + |
| 286 | + |
def validate_entrypoint(
    repo_root: Path,
    examples_root: Path,
    relative_path: str,
    export_cache: dict[Path, set[str]],
) -> tuple[list[str], int]:
    """Validate every internal import of one example entrypoint.

    Returns ``(failures, checks)`` aggregated over all ``import`` and
    ``from ... import`` statements referencing ``executorch.*``.
    """
    entrypoint_path = examples_root / relative_path
    if not entrypoint_path.is_file():
        # Entrypoints are discovered dynamically (there is no allowlist),
        # so this only fires if the file vanished between discovery and
        # validation.
        return [f"{relative_path}: entrypoint not found"], 0

    # Build the dotted module name via as_posix() so it is correct on
    # Windows too, where str(Path(...)) yields backslashes that the
    # "/" -> "." replacement would miss.
    module_name = EXAMPLE_MODULE_PREFIX + Path(relative_path).with_suffix(
        ""
    ).as_posix().replace("/", ".")
    # Decode explicitly as UTF-8 (the Python source default) rather than
    # relying on the platform's locale encoding.
    tree = ast.parse(
        entrypoint_path.read_text(encoding="utf-8"), filename=str(entrypoint_path)
    )

    failures = []
    checks = 0
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name.startswith("executorch."):
                    checks += 1
                    if not module_exists(repo_root, alias.name):
                        failures.append(
                            f"{relative_path}:{node.lineno} missing internal module `{alias.name}`"
                        )
        elif isinstance(node, ast.ImportFrom):
            import_failures, import_checks = validate_import_from(
                repo_root,
                module_name,
                relative_path,
                node,
                export_cache,
            )
            failures.extend(import_failures)
            checks += import_checks

    return failures, checks
| 325 | + |
| 326 | + |
def main():
    """Discover QNN example entrypoints and validate their internal imports.

    Exits non-zero when prerequisites are missing, when nothing was
    actually checked, or when any import failed to resolve.
    """
    if sys.version_info < (3, 10):
        print("Python 3.10+ is required to parse QNN example sources")
        sys.exit(1)

    examples_root = resolve_examples_root()
    if examples_root is None:
        print(f"QNN examples root not found from {Path(__file__).resolve()}")
        sys.exit(1)

    repo_root = examples_root.parent.parent
    entrypoints = discover_entrypoints(examples_root)
    if not entrypoints:
        print(f"No QNN example entrypoints found under {examples_root}")
        sys.exit(1)

    all_failures = []
    total_checks = 0
    export_cache = {}
    for relative_path in entrypoints:
        failures, checks = validate_entrypoint(
            repo_root, examples_root, relative_path, export_cache
        )
        all_failures.extend(failures)
        total_checks += checks

    # Zero checks means the harness silently validated nothing; fail loudly.
    if total_checks == 0:
        print("No QNN example imports were checked")
        sys.exit(1)

    if all_failures:
        print(
            f"{len(all_failures)} unresolved internal import(s) "
            f"across {len(entrypoints)} QNN example entrypoint(s):"
        )
        for failure in all_failures:
            print(f"  FAIL: {failure}")
        sys.exit(1)

    print(
        f"Validated {total_checks} internal import(s) across "
        f"{len(entrypoints)} QNN example entrypoint(s)"
    )


if __name__ == "__main__":
    main()