Skip to content

Commit c5dc816

Browse files
Update src/pybamm/solvers/idaklu_jax.py
Co-authored-by: Saransh Chopra <saransh0701@gmail.com>
1 parent b3e0768 commit c5dc816

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/pybamm/solvers/idaklu_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ def f_isolated(*args, **kwargs):
260260
def jax_value(
261261
self,
262262
t: npt.NDArray[np.float64] | None = None,
263-
inputs: Union[dict, None] = None,
264-
output_variables: Union[list[str], None] = None,
263+
inputs: dict | None = None,
264+
output_variables: list[str] | None = None,
265265
):
266266
"""Helper function to compute the gradient of a jaxified expression
267267

0 commit comments

Comments
 (0)