Skip to content

Commit e263427

Browse files
authored
Add ahead-of-time casadi compilation (#5448)
* Improve codegen for x-averaging, RegPower and Arcsinh2 * Add AOT casadi compilation backend for IDAKLUSolver * Add asv benchmark for vm vs aot solve and observe * Update CHANGELOG.md * Update DFN-with-particle-size-distributions.ipynb * refactor: compilation: "aot" | "vm" -> compile: True/False * generalize averaging over constants * Update CHANGELOG.md * use `sha1(usedforsecurity=False)` * set allowed subprocess compilers * subprocess comment * cov * fix regpower delta=0 fast path * compile output variables/output sens
1 parent 9209226 commit e263427

23 files changed

Lines changed: 1874 additions & 433 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Features
44

5+
- Improved the codegen performance of averaging over spatial domains. ([#5448](https://github.com/pybamm-team/PyBaMM/pull/5448))
6+
- Added ahead-of-time CasADi compilation to `IDAKLUSolver` via the `"compile": True` option. ([#5448](https://github.com/pybamm-team/PyBaMM/pull/5448))
57
- Adds EIS support via `EISSimulation` class with a restructure to the `Solution` class to support `EISSolution`. Includes examples and citation, with a ~5x solve improvement over pybamm-eis. ([#5433](https://github.com/pybamm-team/PyBaMM/pull/5433))
68
- Refactors `Simulation` class into an inherited class from a `BaseSimulation`. `BaseSimulation` is used for all non-experiment based simulations, with the `Simulation` class adding experiment support. Includes a small code cleanup. ([#5430](https://github.com/pybamm-team/PyBaMM/pull/5430))
79
- Improved pchip interpolation performance. ([#5436](https://github.com/pybamm-team/PyBaMM/pull/5436))

benchmarks/time_solve_models.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,37 @@ def time_solve_model(self, _solve_first, _parameters, _solver_class):
142142
self.solver.solve(self.model, t_eval=self.t_eval, t_interp=self.t_interp)
143143

144144

145+
class TimeRepeatedSolveAndObserveVoltage:
146+
param_names = ["model", "compile"]
147+
params = (
148+
[pybamm.lithium_ion.SPM, pybamm.lithium_ion.SPMe, pybamm.lithium_ion.DFN],
149+
[False, True],
150+
)
151+
sim: pybamm.Simulation
152+
sol: pybamm.Solution
153+
t_eval: list[float]
154+
t_interp: npt.NDArray[np.float64]
155+
156+
def setup(self, model_class, compile):
157+
set_random_seed()
158+
self.sim = pybamm.Simulation(
159+
model_class(),
160+
solver=pybamm.IDAKLUSolver(options={"compile": compile}),
161+
)
162+
self.t_eval = [0.0, 3600.0]
163+
self.t_interp = np.linspace(self.t_eval[0], self.t_eval[-1], 10000)
164+
# Warm the casadi/AOT caches and the voltage observer before timing.
165+
self.sol = self.sim.solve(self.t_eval, t_interp=self.t_interp)
166+
_ = self.sol["Voltage [V]"].data
167+
168+
def time_repeated_solve(self, _model_class, _compile):
169+
self.sim.solve(self.t_eval, t_interp=self.t_interp)
170+
171+
def time_voltage_observe(self, _model_class, _compile):
172+
self.sol._variables.clear()
173+
_ = self.sol["Voltage [V]"].data
174+
175+
145176
class TimeSolveDFN:
146177
param_names = ["solve first", "parameter", "solver_class"]
147178
params = (
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
AOT compilation
2+
===============
3+
4+
.. automodule:: pybamm.codegen.compilation
5+
:members: aot_compile
6+
7+
Environment variables
8+
---------------------
9+
10+
``PYBAMM_CASADI_AOT_CACHE``
11+
On-disk cache directory. Defaults to ``$TMPDIR/pybamm_casadi_aot``.
12+
13+
``PYBAMM_CASADI_AOT_KEEP_C``
14+
If set, retain the generated ``.c`` source next to each compiled library.
15+
Useful for debugging codegen output.

docs/source/api/codegen/index.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Code generation
2+
===============
3+
4+
PyBaMM evaluates CasADi expressions through one of two backends:
5+
6+
* ``compile=False`` (default) -- CasADi's in-process virtual machine.
7+
* ``compile=True`` -- :func:`pybamm.codegen.compilation.aot_compile` emits C
8+
source for a CasADi ``Function``, compiles it to a shared library and
9+
returns a ``casadi.external`` wrapper. Results are cached in-process and on
10+
disk, keyed by a hash of the serialised function.
11+
12+
The backend is selected via the ``compile`` option on a solver, e.g.
13+
``pybamm.IDAKLUSolver(options={"compile": True})``. The setting is forwarded
14+
to :class:`pybamm.Solution` so post-solve variable observation uses the same
15+
backend as the integration.
16+
17+
.. toctree::
18+
19+
compilation

docs/source/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ For a high-level introduction to PyBaMM, see the :ref:`user guide <user_guide>`
2323
meshes/index
2424
spatial_methods/index
2525
solvers/index
26+
codegen/index
2627
experiment/index
2728
simulation
2829
plotting/index

docs/source/examples/notebooks/models/DFN-with-particle-size-distributions.ipynb

Lines changed: 26 additions & 34 deletions
Large diffs are not rendered by default.

src/pybamm/codegen/__init__.py

Whitespace-only changes.

src/pybamm/codegen/compilation.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
import hashlib
2+
import os
3+
import re
4+
import subprocess # nosec B404 - compiler validated against _ALLOWED_COMPILERS
5+
import sys
6+
import tempfile
7+
import time
8+
import uuid
9+
10+
import casadi
11+
12+
from pybamm import logger
13+
14+
# Cache of bundle-hash -> list of ``casadi.external`` wrappers, one per
15+
# non-External input to that bundle. A single-Function call is a bundle of
16+
# size one; no separate code path.
17+
_CACHE: dict[str, list[casadi.Function]] = {}
18+
19+
# Only remove build artifacts older than this to avoid racing with another
20+
# process's in-flight compile.
21+
_STALE_TMP_AGE_S = 3600
22+
23+
# Per-attempt temp filenames have the form ``<stem>.<pid>.<32-hex-uuid>.c``
24+
# or ``...<ext>.tmp``.
25+
_PER_ATTEMPT_TOKEN = re.compile(r"\.\d+\.[0-9a-f]{32}(?:\.|$)")
26+
27+
_TMP_FILE_PREFIX = "pybamm_"
28+
29+
# ``int NAME(const casadi_real** arg, ...);`` at the top level of the
30+
# generated C marks an External sub-Function named ``NAME``. Decls for names
31+
# defined in the same TU are fine; decls for anything else mean the caller
32+
# wrapped an inner Function as an External before feeding it to a composite.
33+
_EXTERN_DECL = re.compile(
34+
r"^\s*int\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(\s*const\s+casadi_real\s*\*\*",
35+
re.MULTILINE,
36+
)
37+
38+
_swept_dirs: set[str] = set()
39+
40+
_ALLOWED_COMPILERS = frozenset({"gcc", "clang", "cc", "g++", "clang++"})
41+
42+
43+
def _default_cache_dir() -> str:
44+
d = os.environ.get("PYBAMM_CASADI_AOT_CACHE")
45+
if d:
46+
os.makedirs(d, exist_ok=True)
47+
return d
48+
d = os.path.join(tempfile.gettempdir(), f"{_TMP_FILE_PREFIX}casadi_aot")
49+
os.makedirs(d, exist_ok=True)
50+
return d
51+
52+
53+
def _shared_ext() -> str:
54+
if sys.platform == "darwin":
55+
return ".dylib"
56+
if sys.platform == "win32":
57+
return ".dll"
58+
return ".so"
59+
60+
61+
def aot_compile(fn_or_fns, **kwargs):
62+
"""Ahead-of-time compile one or more casadi ``Function`` objects to a
63+
single shared library and return ``casadi.external`` wrappers.
64+
65+
Accepts either a single ``casadi.Function`` (returns a Function) or a
66+
list/tuple of Functions (returns a list, one per input, in order). In
67+
either case everything is lowered in one ``CodeGenerator`` / ``gcc``
68+
invocation -- a single fn is a bundle of size one.
69+
70+
Intended for the *outermost* Functions a solver hands off (e.g.
71+
``rhs_algebraic``, ``jac_times_cjmass``, ``rootfn``, output-variable
72+
evaluators). Intermediate Functions should stay as MX/SX so
73+
``casadi.CodeGenerator`` can inline them into one translation unit.
74+
Wrapping inner Functions as Externals forces cross-dylib dispatch and
75+
produces unresolvable ``extern`` declarations.
76+
77+
Results are cached in-process (by a hash of the serialised forms) and on
78+
disk under ``$PYBAMM_CASADI_AOT_CACHE`` (default
79+
``$TMPDIR/pybamm_casadi_aot``). Inputs already of class ``External`` are
80+
returned unchanged. On any failure, the original Function(s) are returned
81+
and a warning is logged.
82+
83+
Parameters
84+
----------
85+
fn_or_fns : casadi.Function or list of casadi.Function
86+
**kwargs
87+
``cache_dir``, ``compiler`` and ``flags`` overrides.
88+
"""
89+
is_single = isinstance(fn_or_fns, casadi.Function)
90+
fns = [fn_or_fns] if is_single else list(fn_or_fns)
91+
try:
92+
out = _aot_compile(fns, **kwargs)
93+
except Exception as e:
94+
names = ", ".join(fn.name() for fn in fns)
95+
logger.warning(f"Failed to compile [{names}] with error: {e}")
96+
out = list(fns)
97+
return out[0] if is_single else out
98+
99+
100+
def _aot_compile(
101+
fns: list[casadi.Function],
102+
*,
103+
cache_dir: str | None = None,
104+
compiler: str | None = None,
105+
flags: tuple[str, ...] | None = None,
106+
) -> list[casadi.Function]:
107+
# Pass-through Externals; compile the rest together in one TU.
108+
result: list[casadi.Function] = list(fns)
109+
indices_to_compile = [
110+
i for i, fn in enumerate(fns) if fn.class_name() != "External"
111+
]
112+
if not indices_to_compile:
113+
return result
114+
115+
# Cache key: ordered hash of each fn's name + serialized form.
116+
hasher = hashlib.sha1(usedforsecurity=False)
117+
for idx in indices_to_compile:
118+
fn = fns[idx]
119+
hasher.update(fn.name().encode())
120+
hasher.update(b"\0")
121+
hasher.update(fn.serialize().encode())
122+
hasher.update(b"\0")
123+
key = hasher.hexdigest()[:16]
124+
125+
cached = _CACHE.get(key)
126+
if cached is not None:
127+
for idx, ext_fn in zip(indices_to_compile, cached, strict=True):
128+
result[idx] = ext_fn
129+
return result
130+
131+
if compiler is None:
132+
compiler = "gcc"
133+
if os.path.basename(compiler) not in _ALLOWED_COMPILERS:
134+
raise ValueError(
135+
f"Compiler '{compiler}' not in allowed list: {sorted(_ALLOWED_COMPILERS)}"
136+
)
137+
if flags is None:
138+
flags = ("-O3", "-march=native", "-fPIC")
139+
140+
cdir = cache_dir or _default_cache_dir()
141+
_maybe_sweep_stale(cdir)
142+
143+
# Single-fn bundles get named after the fn for readability; multi-fn
144+
# bundles are hash-only since the member list isn't knowable from the
145+
# filename anyway.
146+
fns_to_compile = [fns[idx] for idx in indices_to_compile]
147+
label = fns_to_compile[0].name() if len(fns_to_compile) == 1 else "bundle"
148+
stem = f"{_TMP_FILE_PREFIX}{label}_{key}"
149+
ext = _shared_ext()
150+
sofile = os.path.join(cdir, stem + ext)
151+
152+
if not os.path.exists(sofile):
153+
gen = casadi.CodeGenerator(stem, {"with_header": False})
154+
for fn in fns_to_compile:
155+
gen.add(fn)
156+
c_source = gen.dump()
157+
158+
bundled = {fn.name() for fn in fns_to_compile}
159+
externs = set(_EXTERN_DECL.findall(c_source)) - bundled
160+
if externs:
161+
raise RuntimeError(
162+
f"References to External sub-Function(s) {sorted(externs)} "
163+
"cannot be linked. aot_compile should only be called on "
164+
"top-level Functions; keep intermediate Functions as MX/SX."
165+
)
166+
167+
# Per-attempt temp paths so concurrent compiles of the same bundle
168+
# can't clobber each other, and so an interrupted build can be
169+
# detected and cleaned up later.
170+
suffix = f".{os.getpid()}.{uuid.uuid4().hex}"
171+
tmp_cfile = os.path.join(cdir, stem + suffix + ".c")
172+
tmp_sofile = os.path.join(cdir, stem + suffix + ext + ".tmp")
173+
try:
174+
with open(tmp_cfile, "w") as f:
175+
f.write(c_source)
176+
subprocess.run( # nosec B603 B607 - compiler validated against allowlist
177+
[compiler, *flags, "-shared", tmp_cfile, "-o", tmp_sofile],
178+
check=True,
179+
)
180+
os.replace(tmp_sofile, sofile)
181+
if os.environ.get("PYBAMM_CASADI_AOT_KEEP_C"):
182+
os.replace(tmp_cfile, os.path.join(cdir, stem + ".c"))
183+
finally:
184+
for p in (tmp_cfile, tmp_sofile):
185+
try:
186+
os.remove(p)
187+
except OSError:
188+
pass
189+
190+
ext_fns: list[casadi.Function] = []
191+
for idx, fn in zip(indices_to_compile, fns_to_compile, strict=True):
192+
ext_fn = casadi.external(fn.name(), sofile)
193+
result[idx] = ext_fn
194+
ext_fns.append(ext_fn)
195+
_CACHE[key] = ext_fns
196+
return result
197+
198+
199+
def _maybe_sweep_stale(cdir: str) -> None:
200+
# Remove leaked per-attempt artifacts and orphan .c files once per
201+
# process. Only touches files matching our naming, and only if older
202+
# than ``_STALE_TMP_AGE_S``.
203+
if cdir in _swept_dirs:
204+
return
205+
_swept_dirs.add(cdir)
206+
207+
try:
208+
entries = os.listdir(cdir)
209+
except OSError:
210+
return
211+
212+
cutoff = time.time() - _STALE_TMP_AGE_S
213+
ext = _shared_ext()
214+
have_so = {n for n in entries if n.endswith(ext) and n.startswith(_TMP_FILE_PREFIX)}
215+
216+
for name in entries:
217+
if not name.startswith(_TMP_FILE_PREFIX):
218+
continue
219+
path = os.path.join(cdir, name)
220+
try:
221+
if os.path.getmtime(path) > cutoff:
222+
continue
223+
224+
is_per_attempt = bool(_PER_ATTEMPT_TOKEN.search(name))
225+
if is_per_attempt and (name.endswith(".tmp") or name.endswith(".c")):
226+
os.remove(path)
227+
continue
228+
229+
if name.endswith(".c") and not is_per_attempt:
230+
stem = name[: -len(".c")]
231+
if (stem + ext) not in have_so:
232+
os.remove(path)
233+
except OSError:
234+
pass

0 commit comments

Comments
 (0)