Skip to content

Commit 07556c9

Browse files
authored
perf: replace all_inputs_casadi with all_inputs_stacked (#5413)
* perf: replace all_inputs_casadi with all_inputs_stacked
* Update solution.py
* add test
* Update CHANGELOG.md
* add cache tests and remove cached property
1 parent: e8975a9 · commit: 07556c9

6 files changed

Lines changed: 73 additions & 20 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Features
44

5+
- Improved the performance of processed variables by stacking inputs into NumPy vectors instead of using `casadi.vertcat`. ([#5413](https://github.com/pybamm-team/PyBaMM/pull/5413))
56
- Allow out of bounds initial state of charge to enable initialising a simulation at a voltage outside the voltage limits. ([#5386](https://github.com/pybamm-team/PyBaMM/pull/5386))
67
- Added `cache_esoh` option to `Simulation` that caches the electrode SOH computation across repeated `solve` calls, avoiding redundant recalculation when eSOH-relevant parameters have not changed. The cached eSOH solver/simulation object is also reused on cache misses to skip expensive model rebuilding. ([#5408](https://github.com/pybamm-team/PyBaMM/pull/5408))
78
- Eliminated the mass matrix inverse and temporary dense matrix objects when building the simulation. ([#5391](https://github.com/pybamm-team/PyBaMM/pull/5391))

src/pybamm/solvers/idaklu_solver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,7 @@ def reduce_solution(
10901090
)
10911091

10921092
# Propagate metadata from the original solution
1093+
new_sol._all_inputs_stacked = solution.all_inputs_stacked
10931094
new_sol._all_inputs_casadi = solution.all_inputs_casadi
10941095
new_sol.closest_event_idx = solution.closest_event_idx
10951096

src/pybamm/solvers/processed_variable.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
self.all_ys = solution.all_ys
5353
self.all_yps = solution.all_yps
5454
self.all_inputs = solution.all_inputs
55-
self.all_inputs_casadi = solution.all_inputs_casadi
55+
self.all_inputs_stacked = solution.all_inputs_stacked
5656

5757
self.mesh = base_variables[0].mesh
5858
self.domain = base_variables[0].domain
@@ -127,7 +127,7 @@ def _setup_inputs(self, t, full_range):
127127
ts = self.all_ts
128128
ys = self.all_ys
129129
yps = self.all_yps
130-
inputs = self.all_inputs_casadi
130+
inputs = self.all_inputs_stacked
131131

132132
# Remove all empty ts
133133
idxs = np.where([ti.size > 0 for ti in ts])[0]
@@ -143,7 +143,7 @@ def _setup_inputs(self, t, full_range):
143143
ys = [ys[idx] for idx in idxs]
144144
if self.hermite_interpolation:
145145
yps = [yps[idx] for idx in idxs]
146-
inputs = [self.all_inputs_casadi[idx] for idx in idxs]
146+
inputs = [inputs[idx] for idx in idxs]
147147

148148
is_f_contiguous = _is_f_contiguous(ys)
149149

@@ -422,7 +422,7 @@ def initialise_sensitivity_explicit_forward(self):
422422
for ts, ys, inputs_stacked, inputs, base_variable, dy_dp in zip(
423423
self.all_ts,
424424
self.all_ys,
425-
self.all_inputs_casadi,
425+
self.all_inputs_stacked,
426426
self.all_inputs,
427427
self.base_variables,
428428
self.all_solution_sensitivities["all"],
@@ -508,19 +508,21 @@ def _stub_solution(self):
508508
"""
509509

510510
class StubSolution:
511-
def __init__(self, ts, ys, inputs, inputs_casadi, sensitivities, t_pts):
511+
def __init__(
512+
self, ts, ys, inputs, inputs_stacked, sensitivities, t_pts
513+
):
512514
self.all_ts = ts
513515
self.all_ys = ys
514516
self.all_inputs = inputs
515-
self.all_inputs_casadi = inputs_casadi
517+
self.all_inputs_stacked = inputs_stacked
516518
self.sensitivities = sensitivities
517519
self.t = t_pts
518520

519521
return StubSolution(
520522
self.all_ts,
521523
self.all_ys,
522524
self.all_inputs,
523-
self.all_inputs_casadi,
525+
self.all_inputs_stacked,
524526
self.sensitivities,
525527
self.t_pts,
526528
)
@@ -568,7 +570,7 @@ def _observe_postfix(self, entries, t):
568570
if self.time_integral is None:
569571
return entries
570572
return self.time_integral.postfix(
571-
entries, self.t_pts, self.all_inputs_casadi[0]
573+
entries, self.t_pts, self.all_inputs_stacked[0]
572574
)
573575

574576
def _interp_setup(self, entries, t):

src/pybamm/solvers/processed_variable_computed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
self.all_ts = solution.all_ts
6464
self.all_ys = solution.all_ys
6565
self.all_inputs = solution.all_inputs
66-
self.all_inputs_casadi = solution.all_inputs_casadi
66+
self.all_inputs_stacked = solution.all_inputs_stacked
6767

6868
self.mesh = base_variables[0].mesh
6969
self.domain = base_variables[0].domain

src/pybamm/solvers/solution.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ def __init__(
155155
self.solve_time = None
156156
self.integration_time = None
157157

158+
self._all_inputs_stacked = None
159+
self._all_inputs_casadi = None
160+
158161
# initialize empty variable cache and data
159162
self._variables = {}
160163
self._data = pybamm.FuzzyDict()
@@ -337,9 +340,21 @@ def all_models(self):
337340
"""Model(s) used for solution"""
338341
return self._all_models
339342

340-
@cached_property
341-
def all_inputs_casadi(self):
342-
return [casadi.vertcat(*inp.values()) for inp in self.all_inputs]
343+
@property
344+
def all_inputs_stacked(self) -> list[np.ndarray]:
345+
if self._all_inputs_stacked is None:
346+
self._all_inputs_stacked = [
347+
np.asarray(list(inp.values())).reshape(-1) for inp in self.all_inputs
348+
]
349+
return self._all_inputs_stacked
350+
351+
@property
352+
def all_inputs_casadi(self) -> list[casadi.DM]:
353+
if self._all_inputs_casadi is None:
354+
self._all_inputs_casadi = [
355+
casadi.vertcat(inp) for inp in self.all_inputs_stacked
356+
]
357+
return self._all_inputs_casadi
343358

344359
@property
345360
def all_yps(self) -> list[np.ndarray | casadi.DM | casadi.MX] | None:
@@ -421,6 +436,7 @@ def first_state(self):
421436
all_sensitivities=sensitivities,
422437
all_yps=all_yps,
423438
)
439+
new_sol._all_inputs_stacked = self.all_inputs_stacked[:1]
424440
new_sol._all_inputs_casadi = self.all_inputs_casadi[:1]
425441
new_sol._sub_solutions = self.sub_solutions[:1]
426442

@@ -464,6 +480,7 @@ def last_state(self):
464480
all_sensitivities=sensitivities,
465481
all_yps=all_yps,
466482
)
483+
new_sol._all_inputs_stacked = self.all_inputs_stacked[-1:]
467484
new_sol._all_inputs_casadi = self.all_inputs_casadi[-1:]
468485
new_sol._sub_solutions = self.sub_solutions[-1:]
469486
new_sol.solve_time = 0
@@ -621,20 +638,23 @@ def _convert_to_casadi(self, var_pybamm, inputs, ys_shape):
621638
def process_casadi_var(self, var_pybamm, inputs, ys_shape):
622639
t_MX = casadi.MX.sym("t")
623640
y_MX = casadi.MX.sym("y", ys_shape[0])
624-
inputs_MX_dict = {
625-
key: casadi.MX.sym("input", value.shape[0]) for key, value in inputs.items()
626-
}
627-
inputs_MX = casadi.vertcat(*[p for p in inputs_MX_dict.values()])
641+
total_input_size = sum(v.size for v in inputs.values())
642+
inputs_MX = casadi.MX.sym("p", total_input_size)
643+
inputs_MX_dict = {}
644+
offset = 0
645+
for key, value in inputs.items():
646+
n = value.size
647+
inputs_MX_dict[key] = inputs_MX[offset : offset + n]
648+
offset += n
628649
var_sym = var_pybamm.to_casadi(t_MX, y_MX, inputs=inputs_MX_dict)
629650

630651
opts = {
631652
"cse": True,
632653
"inputs_check": False,
633-
"is_diff_in": [False, False, False],
634-
"is_diff_out": [False],
654+
"is_diff_in": [False, True, False],
655+
"is_diff_out": [True],
635656
"regularity_check": False,
636657
"error_on_fail": False,
637-
"enable_jacobian": False,
638658
}
639659

640660
# Casadi has a bug where it does not correctly handle arrays with
@@ -940,6 +960,7 @@ def __add__(self, other):
940960
)
941961

942962
new_sol.closest_event_idx = other.closest_event_idx
963+
new_sol._all_inputs_stacked = self.all_inputs_stacked + other.all_inputs_stacked
943964
new_sol._all_inputs_casadi = self.all_inputs_casadi + other.all_inputs_casadi
944965

945966
# Add timers (if available)
@@ -978,6 +999,7 @@ def copy(self):
978999
all_t_evals=self._all_t_evals,
9791000
variables_returned=self.variables_returned,
9801001
)
1002+
new_sol._all_inputs_stacked = self.all_inputs_stacked
9811003
new_sol._all_inputs_casadi = self.all_inputs_casadi
9821004
new_sol._sub_solutions = self.sub_solutions
9831005
new_sol.closest_event_idx = self.closest_event_idx
@@ -1105,6 +1127,7 @@ def make_cycle_solution(
11051127
if sum_sols.variables_returned:
11061128
cycle_solution._variables = sum_sols._variables
11071129

1130+
cycle_solution._all_inputs_stacked = sum_sols.all_inputs_stacked
11081131
cycle_solution._all_inputs_casadi = sum_sols.all_inputs_casadi
11091132
cycle_solution._sub_solutions = sum_sols.sub_solutions
11101133

tests/unit/test_solvers/test_solution.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ def test_add_solutions(self):
280280
sol_sum.y, np.concatenate([y1, y2[:, 1:]], axis=1)
281281
)
282282
np.testing.assert_array_equal(sol_sum.all_inputs, [{"a": 1}, {"a": 2}])
283+
assert sol_sum.all_inputs_stacked[0] is sol1.all_inputs_stacked[0]
284+
assert sol_sum.all_inputs_stacked[1] is sol2.all_inputs_stacked[0]
285+
assert sol_sum.all_inputs_casadi[0] is sol1.all_inputs_casadi[0]
286+
assert sol_sum.all_inputs_casadi[1] is sol2.all_inputs_casadi[0]
283287

284288
# Test sub-solutions
285289
assert len(sol_sum.sub_solutions) == 2
@@ -402,7 +406,8 @@ def test_copy(self):
402406
for ys_copy, ys1 in zip(sol_copy.all_ys, sol1.all_ys, strict=False):
403407
np.testing.assert_array_equal(ys_copy, ys1)
404408
assert sol_copy.all_inputs == sol1.all_inputs
405-
assert sol_copy.all_inputs_casadi == sol1.all_inputs_casadi
409+
assert sol_copy.all_inputs_stacked is sol1.all_inputs_stacked
410+
assert sol_copy.all_inputs_casadi is sol1.all_inputs_casadi
406411
assert sol_copy.set_up_time == sol1.set_up_time
407412
assert sol_copy.solve_time == sol1.solve_time
408413
assert sol_copy.integration_time == sol1.integration_time
@@ -432,6 +437,26 @@ def test_copy_with_computed_variables(self):
432437
)
433438
assert sol2.variables_returned is True
434439

440+
def test_all_inputs(self):
441+
t = [np.linspace(0, 1, 10), np.linspace(1, 2, 10)]
442+
t[1][0] = np.nextafter(t[1][0], np.inf)
443+
y = [np.tile(t[0], (5, 1)), np.tile(t[1], (5, 1))]
444+
inputs = [{"a": 1.0, "b": 2.0, "c": 3.0}, {"a": 4.0, "b": 5.0, "c": 6.0}]
445+
sol = pybamm.Solution(t, y, pybamm.BaseModel(), inputs)
446+
447+
stacked = sol.all_inputs_stacked
448+
assert len(stacked) == 2
449+
for s, inp in zip(stacked, inputs, strict=True):
450+
assert isinstance(s, np.ndarray)
451+
# check that it's a vector
452+
assert s.shape == (len(inp),)
453+
np.testing.assert_array_equal(s, np.array(list(inp.values())))
454+
455+
casadi_inputs = sol.all_inputs_casadi
456+
assert len(casadi_inputs) == 2
457+
for c, s in zip(casadi_inputs, stacked, strict=True):
458+
np.testing.assert_array_equal(np.array(c).flatten(), s)
459+
435460
def test_last_state(self):
436461
# Set up first solution
437462
t1 = [np.linspace(0, 1), np.linspace(1, 2, 5)]
@@ -447,6 +472,7 @@ def test_last_state(self):
447472
assert sol_last_state.all_ts[0] == 2
448473
np.testing.assert_array_equal(sol_last_state.all_ys[0], 2)
449474
assert sol_last_state.all_inputs == sol1.all_inputs[-1:]
475+
assert sol_last_state.all_inputs_stacked == sol1.all_inputs_stacked[-1:]
450476
assert sol_last_state.all_inputs_casadi == sol1.all_inputs_casadi[-1:]
451477
assert sol_last_state.all_models == sol1.all_models[-1:]
452478
assert sol_last_state.set_up_time == 0

0 commit comments

Comments
 (0)