Make npt.NDArray type hints more specific with dtype#4901
Make npt.NDArray type hints more specific with dtype#4901kratman merged 14 commits intopybamm-team:developfrom
Conversation
Saransh-cpp
left a comment
There was a problem hiding this comment.
Thanks, @vidipsingh! Could you attach the output of mypy run in the PR description?
@Saransh-cpp, I have attached the output of |
@vidipsingh, it would be nice if you could:
Thank you! |
Thanks for the feedback, @agriyakhetarpal! I’ll replace the image with the Just to clarify, are you referring to the "Before" vs. "After" comparison of the |
The |
|
I’ve added the Please let me know if any changes are needed! |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #4901 +/- ##
========================================
Coverage 98.57% 98.57%
========================================
Files 304 304
Lines 23645 23656 +11
========================================
+ Hits 23309 23320 +11
Misses 336 336 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Saransh-cpp
left a comment
There was a problem hiding this comment.
Thanks, @vidipsingh! See my comments below. Could you please also comment why you are using Any in all the places where you are using it? Thank you!
| method: Literal["discrete", "continuous"] | ||
| initial_condition: npt.NDArray | ||
| discrete_times: Optional[npt.NDArray] | ||
| initial_condition: npt.NDArray[np.float64] |
There was a problem hiding this comment.
| initial_condition: npt.NDArray[np.float64] | |
| initial_condition: float | npt.NDArray[np.float64] |
| def jax_value( | ||
| self, | ||
| t: npt.NDArray = None, | ||
| t: npt.NDArray[np.float64] = None, |
There was a problem hiding this comment.
| t: npt.NDArray[np.float64] = None, | |
| t: npt.NDArray[np.float64] | None = None, |
| def jax_grad( | ||
| self, | ||
| t: npt.NDArray = None, | ||
| t: npt.NDArray[np.float64] = None, |
There was a problem hiding this comment.
| t: npt.NDArray[np.float64] = None, | |
| t: npt.NDArray[np.float64] | None = None, |
| def _jax_solve( | ||
| self, | ||
| t: Union[float, npt.NDArray], | ||
| t: Union[float, npt.NDArray[np.float64]], |
There was a problem hiding this comment.
| t: Union[float, npt.NDArray[np.float64]], | |
| t: float | npt.NDArray[np.float64], |
| def _jax_jvp_impl( | ||
| self, | ||
| *args: Union[npt.NDArray], | ||
| *args: Union[npt.NDArray[np.float64]], |
There was a problem hiding this comment.
| *args: Union[npt.NDArray[np.float64]], | |
| *args: npt.NDArray[np.float64], |
| self, | ||
| y_bar: npt.NDArray, | ||
| y_bar: npt.NDArray[np.float64], | ||
| invar: Union[str, int], # index or name of input variable |
There was a problem hiding this comment.
| invar: Union[str, int], # index or name of input variable | |
| invar: str | int, # index or name of input variable |
| initial_condition: npt.NDArray | ||
| discrete_times: Optional[npt.NDArray] | ||
| initial_condition: npt.NDArray[np.float64] | ||
| discrete_times: Optional[npt.NDArray[np.float64]] |
There was a problem hiding this comment.
Will -
npt.NDArray[np.float64] | Nonework?
There was a problem hiding this comment.
Will -
npt.NDArray[np.float64] | Nonework?
I think npt.NDArray[np.float64] | None will work, let me look into it.
There was a problem hiding this comment.
Will -
npt.NDArray[np.float64] | Nonework?
npt.NDArray[np.float64] | None will not work for initial_condition because it excludes float (e.g., 0.0 or scalar from evaluate()).
It will work for discrete_times since it already matches npt.NDArray[np.float64] | None.
There was a problem hiding this comment.
How about npt.NDArray[np.float64] | None | float then?
There was a problem hiding this comment.
How about
npt.NDArray[np.float64] | None | floatthen?
Yes, npt.NDArray[np.float64] | None | float will work for initial_condition since it covers scalar float (e.g., 0.0) and arrays.
Thank you for the feedback! I used I'll also work on the other suggested changes. Appreciate it! |
|
@vidipsingh let me know when this is ready for a review again! |
@Saransh-cpp I think it is ready for review now! |
Saransh-cpp
left a comment
There was a problem hiding this comment.
Thanks, @vidipsingh! Maybe we should create a new issue that aims to narrow down the Any dtype.
|
The CI should work once you add |
Sure, I will create a new issue for it then. |
Co-authored-by: Saransh Chopra <saransh0701@gmail.com>
Co-authored-by: Saransh Chopra <saransh0701@gmail.com>
Hi @Saransh-cpp, I have added |
Co-authored-by: Saransh Chopra <saransh0701@gmail.com>
|
Hi @Saransh-cpp, Just pinging you here to check if any changes are required for this PR. |
Saransh-cpp
left a comment
There was a problem hiding this comment.
Hi @vidipsingh, sorry for taking too long. The tests are failing with the error -
ImportError while loading conftest '/Users/runner/work/PyBaMM/PyBaMM/conftest.py'.
conftest.py:3: in <module>
import pybamm
src/pybamm/__init__.py:176: in <module>
from .solvers.idaklu_jax import IDAKLUJax
src/pybamm/solvers/idaklu_jax.py:25: in <module>
class IDAKLUJax:
src/pybamm/solvers/idaklu_jax.py:261: in IDAKLUJax
t: npt.NDArray[np.float64] | None = None,
E TypeError: unsupported operand type(s) for |: 'types.GenericAlias' and 'NoneType'
nox > Command python -m pytest -m unit failed with exit code 4which looks like a missing import. Could you please fix this? Thanks!
Saransh-cpp
left a comment
There was a problem hiding this comment.
Looks good now, thanks, @vidipsingh!
Great! Is there anything else to be done, or is this PR good to be merged? |
|
Thanks! It is good to be merged, but @Saransh-cpp and I moved ourselves out of the pybamm-team/maintainers team and into another one, and the permissions for both aren't in sync yet. @kratman should be able to merge this (and can review if needed, as he previously reviewed the PR). |
Got it, thanks for the heads-up! Let's wait for @kratman to review and merge. |
Description
This PR refines
npt.NDArraytype hints in PyBaMM by adding explicitdtype(e.g.,np.float64for time/state arrays,Anyfor variable cases).Fixes: #4900
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #)
Important checks:
Please confirm the following before marking the PR as ready for review:
nox -s pre-commitnox -s testsnox -s doctestsmypy Output (Before):
mypy Output (After):