Skip to content

Commit 54ed6fc

Browse files
Add import smoke tests to PR CI (#19091)
Summary: Add a PR QNN import job that validates backend module imports and statically checks internal imports for runnable Qualcomm example entrypoints. Also fix the stale `ExecutorchBackendConfig` import in the QAIHub stable diffusion example so the new check passes. Differential Revision: D102218906
1 parent 7b5dcc1 commit 54ed6fc

5 files changed

Lines changed: 501 additions & 3 deletions

File tree

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Qualcomm Innovation Center, Inc.
3+
# All rights reserved
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""Validate internal imports for QNN example entrypoints.
9+
10+
Entrypoints are discovered dynamically from `examples/qualcomm` by looking for
11+
Python files that define a standard `if __name__ == "__main__"` block. This
12+
keeps the check focused on runnable scripts while avoiding a hardcoded list
13+
that drifts as examples are added, moved, or removed.
14+
"""
15+
16+
import ast
17+
import importlib.util
18+
import sys
19+
from pathlib import Path
20+
21+
22+
EXAMPLE_MODULE_PREFIX = "executorch.examples.qualcomm."
23+
24+
25+
def resolve_examples_root():
26+
for parent in Path(__file__).resolve().parents:
27+
candidate = parent / "examples" / "qualcomm"
28+
if candidate.is_dir():
29+
return candidate
30+
return None
31+
32+
33+
def is_main_guard(test: ast.AST) -> bool:
34+
if not isinstance(test, ast.Compare):
35+
return False
36+
if len(test.ops) != 1 or len(test.comparators) != 1:
37+
return False
38+
if not isinstance(test.ops[0], ast.Eq):
39+
return False
40+
if not isinstance(test.left, ast.Name) or test.left.id != "__name__":
41+
return False
42+
comparator = test.comparators[0]
43+
return isinstance(comparator, ast.Constant) and comparator.value == "__main__"
44+
45+
46+
def is_entrypoint(tree: ast.AST) -> bool:
47+
for node in ast.walk(tree):
48+
if isinstance(node, ast.If) and is_main_guard(node.test):
49+
return True
50+
return False
51+
52+
53+
def discover_entrypoints(examples_root: Path) -> list[str]:
54+
entrypoints = []
55+
for path in sorted(examples_root.rglob("*.py")):
56+
if path.name == "__init__.py":
57+
continue
58+
tree = ast.parse(path.read_text(), filename=str(path))
59+
if is_entrypoint(tree):
60+
entrypoints.append(path.relative_to(examples_root).as_posix())
61+
return entrypoints
62+
63+
64+
def module_base_path(repo_root: Path, module_name: str) -> Path:
65+
return repo_root.joinpath(*module_name.split(".")[1:])
66+
67+
68+
def module_exists(repo_root: Path, module_name: str) -> bool:
69+
base_path = module_base_path(repo_root, module_name)
70+
if base_path.is_dir() or base_path.with_suffix(".py").is_file():
71+
return True
72+
73+
try:
74+
return importlib.util.find_spec(module_name) is not None
75+
except (AttributeError, ImportError, ValueError):
76+
return False
77+
78+
79+
def module_source_file(repo_root: Path, module_name: str):
80+
base_path = module_base_path(repo_root, module_name)
81+
file_path = base_path.with_suffix(".py")
82+
if file_path.is_file():
83+
return file_path
84+
init_path = base_path / "__init__.py"
85+
if init_path.is_file():
86+
return init_path
87+
return None
88+
89+
90+
def source_module_name(repo_root: Path, source_file: Path) -> str:
91+
relative_path = source_file.relative_to(repo_root)
92+
if relative_path.name == "__init__.py":
93+
relative_path = relative_path.parent
94+
else:
95+
relative_path = relative_path.with_suffix("")
96+
return "executorch." + ".".join(relative_path.parts)
97+
98+
99+
def target_names(node):
100+
names = set()
101+
if isinstance(node, ast.Name):
102+
names.add(node.id)
103+
elif isinstance(node, (ast.Tuple, ast.List)):
104+
for element in node.elts:
105+
names.update(target_names(element))
106+
return names
107+
108+
109+
def collect_names_from_import_from(
110+
repo_root: Path,
111+
module_name: str,
112+
node: ast.ImportFrom,
113+
export_cache: dict[Path, set[str]],
114+
is_package: bool,
115+
) -> set[str]:
116+
names = set()
117+
try:
118+
imported_module = resolve_from_module(module_name, node, is_package=is_package)
119+
except ImportError:
120+
imported_module = ""
121+
122+
for alias in node.names:
123+
if alias.name == "*":
124+
if imported_module.startswith("executorch.") and module_exists(
125+
repo_root, imported_module
126+
):
127+
source_file = module_source_file(repo_root, imported_module)
128+
if source_file is not None:
129+
names.update(exported_names(repo_root, source_file, export_cache))
130+
continue
131+
132+
names.add(alias.asname or alias.name)
133+
134+
return names
135+
136+
137+
def nested_statement_bodies(node: ast.stmt) -> list[list[ast.stmt]] | None:
138+
if isinstance(node, ast.If):
139+
return [node.body, node.orelse]
140+
if isinstance(node, ast.Try):
141+
bodies = [node.body, node.orelse, node.finalbody]
142+
bodies.extend(handler.body for handler in node.handlers)
143+
return bodies
144+
if isinstance(node, (ast.For, ast.AsyncFor, ast.While, ast.With, ast.AsyncWith)):
145+
return [node.body, getattr(node, "orelse", [])]
146+
if isinstance(node, ast.Match):
147+
return [case.body for case in node.cases]
148+
return None
149+
150+
151+
def collect_names_from_node(
152+
repo_root: Path,
153+
module_name: str,
154+
node: ast.stmt,
155+
export_cache: dict[Path, set[str]],
156+
is_package: bool,
157+
) -> set[str]:
158+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
159+
return {node.name}
160+
if isinstance(node, ast.Import):
161+
return {alias.asname or alias.name.split(".")[0] for alias in node.names}
162+
if isinstance(node, ast.ImportFrom):
163+
return collect_names_from_import_from(
164+
repo_root, module_name, node, export_cache, is_package
165+
)
166+
if isinstance(node, ast.Assign):
167+
names = set()
168+
for target in node.targets:
169+
names.update(target_names(target))
170+
return names
171+
if isinstance(node, ast.AnnAssign):
172+
return target_names(node.target)
173+
174+
bodies = nested_statement_bodies(node)
175+
if bodies is None:
176+
return set()
177+
178+
names = set()
179+
for body in bodies:
180+
names.update(
181+
collect_exported_names(
182+
repo_root, module_name, body, export_cache, is_package
183+
)
184+
)
185+
return names
186+
187+
188+
def collect_exported_names(
189+
repo_root: Path,
190+
module_name: str,
191+
body: list[ast.stmt],
192+
export_cache: dict[Path, set[str]],
193+
is_package: bool,
194+
) -> set[str]:
195+
names = set()
196+
for node in body:
197+
names.update(
198+
collect_names_from_node(
199+
repo_root, module_name, node, export_cache, is_package
200+
)
201+
)
202+
return names
203+
204+
205+
def exported_names(
206+
repo_root: Path, source_file: Path, export_cache: dict[Path, set[str]]
207+
) -> set[str]:
208+
cached_names = export_cache.get(source_file)
209+
if cached_names is not None:
210+
return cached_names
211+
212+
names = set()
213+
export_cache[source_file] = names
214+
215+
module_name = source_module_name(repo_root, source_file)
216+
tree = ast.parse(source_file.read_text(), filename=str(source_file))
217+
names.update(
218+
collect_exported_names(
219+
repo_root,
220+
module_name,
221+
tree.body,
222+
export_cache,
223+
source_file.name == "__init__.py",
224+
)
225+
)
226+
return names
227+
228+
229+
def resolve_from_module(
230+
module_name: str, node: ast.ImportFrom, is_package: bool = False
231+
) -> str:
232+
if node.level == 0:
233+
return node.module or ""
234+
package_name = module_name if is_package else module_name.rpartition(".")[0]
235+
relative_name = "." * node.level + (node.module or "")
236+
return importlib.util.resolve_name(relative_name, package_name)
237+
238+
239+
def validate_import_from(
240+
repo_root: Path,
241+
module_name: str,
242+
entrypoint: str,
243+
node: ast.ImportFrom,
244+
export_cache: dict[Path, set[str]],
245+
) -> tuple[list[str], int]:
246+
failures = []
247+
try:
248+
imported_module = resolve_from_module(module_name, node)
249+
except ImportError as error:
250+
failures.append(
251+
f"{entrypoint}:{node.lineno} relative import could not be resolved: {error}"
252+
)
253+
return failures, 0
254+
255+
if not imported_module.startswith("executorch."):
256+
return failures, 0
257+
258+
checks = 1
259+
if not module_exists(repo_root, imported_module):
260+
failures.append(
261+
f"{entrypoint}:{node.lineno} missing internal module `{imported_module}`"
262+
)
263+
return failures, checks
264+
265+
source_file = module_source_file(repo_root, imported_module)
266+
exported = (
267+
exported_names(repo_root, source_file, export_cache) if source_file else set()
268+
)
269+
270+
for alias in node.names:
271+
if alias.name == "*":
272+
continue
273+
submodule_name = f"{imported_module}.{alias.name}"
274+
if module_exists(repo_root, submodule_name):
275+
checks += 1
276+
continue
277+
if source_file is None or alias.name not in exported:
278+
failures.append(
279+
f"{entrypoint}:{node.lineno} unresolved internal import "
280+
f"`{alias.name}` from `{imported_module}`"
281+
)
282+
checks += 1
283+
284+
return failures, checks
285+
286+
287+
def validate_entrypoint(
288+
repo_root: Path,
289+
examples_root: Path,
290+
relative_path: str,
291+
export_cache: dict[Path, set[str]],
292+
) -> tuple[list[str], int]:
293+
entrypoint_path = examples_root / relative_path
294+
if not entrypoint_path.is_file():
295+
return [f"{relative_path}: allowlisted entrypoint not found"], 0
296+
297+
module_name = EXAMPLE_MODULE_PREFIX + str(
298+
Path(relative_path).with_suffix("")
299+
).replace("/", ".")
300+
tree = ast.parse(entrypoint_path.read_text(), filename=str(entrypoint_path))
301+
302+
failures = []
303+
checks = 0
304+
for node in ast.walk(tree):
305+
if isinstance(node, ast.Import):
306+
for alias in node.names:
307+
if alias.name.startswith("executorch."):
308+
checks += 1
309+
if not module_exists(repo_root, alias.name):
310+
failures.append(
311+
f"{relative_path}:{node.lineno} missing internal module `{alias.name}`"
312+
)
313+
elif isinstance(node, ast.ImportFrom):
314+
import_failures, import_checks = validate_import_from(
315+
repo_root,
316+
module_name,
317+
relative_path,
318+
node,
319+
export_cache,
320+
)
321+
failures.extend(import_failures)
322+
checks += import_checks
323+
324+
return failures, checks
325+
326+
327+
def main():
328+
if sys.version_info < (3, 10):
329+
print("Python 3.10+ is required to parse QNN example sources")
330+
sys.exit(1)
331+
332+
examples_root = resolve_examples_root()
333+
if examples_root is None:
334+
print(f"QNN examples root not found from {Path(__file__).resolve()}")
335+
sys.exit(1)
336+
337+
repo_root = examples_root.parent.parent
338+
entrypoints = discover_entrypoints(examples_root)
339+
if not entrypoints:
340+
print(f"No QNN example entrypoints found under {examples_root}")
341+
sys.exit(1)
342+
343+
all_failures = []
344+
total_checks = 0
345+
export_cache = {}
346+
347+
for relative_path in entrypoints:
348+
failures, checks = validate_entrypoint(
349+
repo_root, examples_root, relative_path, export_cache
350+
)
351+
all_failures.extend(failures)
352+
total_checks += checks
353+
354+
if total_checks == 0:
355+
print("No QNN example imports were checked")
356+
sys.exit(1)
357+
358+
if all_failures:
359+
print(
360+
f"{len(all_failures)} unresolved internal import(s) "
361+
f"across {len(entrypoints)} QNN example entrypoint(s):"
362+
)
363+
for failure in all_failures:
364+
print(f" FAIL: {failure}")
365+
sys.exit(1)
366+
367+
print(
368+
f"Validated {total_checks} internal import(s) across "
369+
f"{len(entrypoints)} QNN example entrypoint(s)"
370+
)
371+
372+
373+
if __name__ == "__main__":
374+
main()

0 commit comments

Comments
 (0)