-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add additional dependencies for mypy in pre-commit #1292
Conversation
155b164
to
f61cda3
Compare
f61cda3
to
0036365
Compare
@@ -129,7 +128,7 @@ def nonbonded_block_unsummed( | |||
lj = lennard_jones(dij, sig_ij, eps_ij) | |||
|
|||
nrgs = jnp.where(dij < cutoff, es + lj, 0) | |||
return nrgs | |||
return cast(Array, nrgs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curiosity: What type is this casting from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type of jnp.where
is Array | tuple[Array, ...]
, I think because when where
is called with only one argument the behavior is different. In this case, we know it will return an Array
. (Note that cast
is a no-op at runtime; it serves as an assertion to the type checker that a value has a certain type.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good fixes!
(I agree with the notes in the PR description -- commented inline on two examples, but I don't think either are currently actionable)
def fixed_to_float(v: int | jnp.uint64) -> jnp.float64: | ||
def fixed_to_float(v: ArrayLike) -> Array: | ||
"""Meant to imitate the logic of timemachine/cpp/src/fixed_point.hpp::FIXED_TO_FLOAT""" | ||
return jnp.float64(jnp.int64(jnp.uint64(v))) / custom_ops.FIXED_EXPONENT | ||
|
||
|
||
def float_to_fixed(v: jnp.float32 | float) -> jnp.uint64: | ||
def float_to_fixed(v: ArrayLike) -> Array: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two seem more accurate but less precise
def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutoff=jnp.inf) -> float: | ||
def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutoff=jnp.inf) -> Array: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed with the PR comment -- may be desirable later to restore an annotation that functions like this return a scalar / are non-broadcasting (i.e. they need to be transformed by vmap or similar). But the docstring + context seems sufficient.
Currently mypy, when run via pre-commit, treats imports from jax and other libraries as having type
Any
, resulting in it failing to catch some type errors. This is due to the combination of 1) pre-commit running checks in isolated environments without dependencies installed by default, and 2) the mypy pre-commit hook automatically passing--ignore-missing-imports
in the mypy invocation. See this issue comment for more details.This PR:
additional_dependencies
for mypy in the pre-commit configurationNotes on typing fixes:
float
, but actually return a scalarArray
object. It's unfortunate that there doesn't seem to be a way to indicate that the return value is a scalar.ArrayLike
and outputsArray
, andjnp.asarray
used to convert inputs where necessary.Array = Any
and link to this (now closed) JAX issue: One array type to rule them all! jax-ml/jax#943. Now thatjax.typing
has been introduced we could probably make all of these cases more precise, but to stay focused this PR only changes instances where mypy checks were failing.