Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add additional dependencies for mypy in pre-commit #1292

Merged
merged 4 commits into from
Apr 25, 2024

Conversation

mcwitt
Copy link
Collaborator

@mcwitt mcwitt commented Apr 24, 2024

Currently mypy, when run via pre-commit, treats imports from jax and other libraries as having type Any, resulting in it failing to catch some type errors. This is due to the combination of 1) pre-commit running checks in isolated environments without dependencies installed by default, and 2) the mypy pre-commit hook automatically passing --ignore-missing-imports in the mypy invocation. See this issue comment for more details.

This PR:

  1. Adds the remaining dependencies that export types to additional_dependencies for mypy in the pre-commit configuration
  2. Fixes any newly-exposed type errors

Notes on typing fixes:

  • There are a few cases where functions using jax were annotated as returning float, but actually return a scalar Array object. It's unfortunate that there doesn't seem to be a way to indicate that the return value is a scalar.
  • I tried to follow the approach outlined in https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices to consistently annotate functions using jax (but only for cases where mypy checks were failing, for now). TL;DR inputs should typically be ArrayLike and outputs Array, and jnp.asarray used to convert inputs where necessary.
  • In several places we set Array = Any and link to this (now closed) JAX issue: One array type to rule them all! jax-ml/jax#943. Now that jax.typing has been introduced we could probably make all of these cases more precise, but to stay focused this PR only changes instances where mypy checks were failing.

@mcwitt mcwitt force-pushed the add-mypy-additional-deps branch from 155b164 to f61cda3 Compare April 24, 2024 23:29
@mcwitt mcwitt force-pushed the add-mypy-additional-deps branch from f61cda3 to 0036365 Compare April 24, 2024 23:50
@mcwitt mcwitt marked this pull request as ready for review April 24, 2024 23:57
@mcwitt mcwitt requested review from badisa and maxentile April 24, 2024 23:57
@@ -129,7 +128,7 @@ def nonbonded_block_unsummed(
lj = lennard_jones(dij, sig_ij, eps_ij)

nrgs = jnp.where(dij < cutoff, es + lj, 0)
return nrgs
return cast(Array, nrgs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curiosity: What type is this casting from?

Copy link
Collaborator Author

@mcwitt mcwitt Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type of jnp.where is Array | tuple[Array, ...], I think because when where is called with only one argument the behavior is different. In this case, we know it will return an Array. (Note that cast is a no-op at runtime; it serves as an assertion to the type checker that a value has a certain type.)

Copy link
Collaborator

@maxentile maxentile left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good fixes!

(I agree with the notes in the PR description -- commented inline on two examples, but I don't think either are currently actionable)

Comment on lines -6 to +13
def fixed_to_float(v: int | jnp.uint64) -> jnp.float64:
def fixed_to_float(v: ArrayLike) -> Array:
"""Meant to imitate the logic of timemachine/cpp/src/fixed_point.hpp::FIXED_TO_FLOAT"""
return jnp.float64(jnp.int64(jnp.uint64(v))) / custom_ops.FIXED_EXPONENT


def float_to_fixed(v: jnp.float32 | float) -> jnp.uint64:
def float_to_fixed(v: ArrayLike) -> Array:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two seem more accurate but less precise

Comment on lines -477 to +479
def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutoff=jnp.inf) -> float:
def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutoff=jnp.inf) -> Array:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed with the PR comment -- may be desirable later to restore an annotation that functions like this return a scalar / are non-broadcasting (i.e. they need to be transformed by vmap or similar). But the docstring + context seems sufficient.

@mcwitt mcwitt enabled auto-merge (squash) April 25, 2024 16:15
@mcwitt mcwitt merged commit 70706e3 into master Apr 25, 2024
1 check passed
@mcwitt mcwitt deleted the add-mypy-additional-deps branch April 25, 2024 17:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants