diff --git a/environment.yml b/environment.yml
index 98be2b7856..d9e078aa7a 100644
--- a/environment.yml
+++ b/environment.yml
@@ -39,12 +39,7 @@ dependencies:
   - pydot
   - ipython
   # code style
-  - black
-  - isort
-  # For linting
-  - flake8
-  - pep8
-  - pyflakes
+  - ruff
   # developer tools
   - pre-commit
   - packaging
diff --git a/pytensor/__init__.py b/pytensor/__init__.py
index 8e01416b93..c3ab34c9c3 100644
--- a/pytensor/__init__.py
+++ b/pytensor/__init__.py
@@ -70,12 +70,9 @@ def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler)
 # very rarely.
 __api_version__ = 1
 
-# isort: off
 from pytensor.graph.basic import Variable
 from pytensor.graph.replace import clone_replace, graph_replace
 
-# isort: on
-
 
 def as_symbolic(x: Any, name: Optional[str] = None, **kwargs) -> Variable:
     """Convert `x` into an equivalent PyTensor `Variable`.
@@ -115,7 +112,6 @@ def _as_symbolic(x, **kwargs) -> Variable:
     return as_tensor_variable(x, **kwargs)
 
 
-# isort: off
 from pytensor import scalar, tensor
 from pytensor.compile import (
     In,
@@ -134,8 +130,6 @@ def _as_symbolic(x, **kwargs) -> Variable:
 from pytensor.printing import pp, pprint
 from pytensor.updates import OrderedUpdates
 
-# isort: on
-
 
 def get_underlying_scalar_constant(v):
     """Return the constant scalar (i.e. 0-D) value underlying variable `v`.
@@ -157,15 +151,12 @@ def get_underlying_scalar_constant(v):
     return tensor.get_underlying_scalar_constant_value(v)
 
 
-# isort: off
-import pytensor.tensor.random.var
 import pytensor.sparse
+import pytensor.tensor.random.var
 from pytensor.scan import checkpoints
 from pytensor.scan.basic import scan
 from pytensor.scan.views import foldl, foldr, map, reduce
 
-# isort: on
-
 # Some config variables are registered by submodules. Only after all those
 # imports were executed, we can warn about remaining flags provided by the user
diff --git a/pytensor/graph/__init__.py b/pytensor/graph/__init__.py
index 189dfed237..1e2f982421 100644
--- a/pytensor/graph/__init__.py
+++ b/pytensor/graph/__init__.py
@@ -1,20 +1,17 @@
 """Graph objects and manipulation functions."""
 
-# isort: off
 from pytensor.graph.basic import (
     Apply,
-    Variable,
     Constant,
-    graph_inputs,
-    clone,
+    Variable,
     ancestors,
+    clone,
+    graph_inputs,
 )
-from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
-from pytensor.graph.op import Op
-from pytensor.graph.type import Type
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.rewriting.basic import node_rewriter, graph_rewriter
-from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.graph.op import Op
+from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
+from pytensor.graph.rewriting.basic import graph_rewriter, node_rewriter
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
-
-# isort: on
+from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.graph.type import Type
diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py
index 0a12442a97..bada854acf 100644
--- a/pytensor/link/jax/dispatch/__init__.py
+++ b/pytensor/link/jax/dispatch/__init__.py
@@ -1,18 +1,15 @@
-# isort: off
-from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
+import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.elemwise
+import pytensor.link.jax.dispatch.extra_ops
+import pytensor.link.jax.dispatch.nlinalg
+import pytensor.link.jax.dispatch.random
 
 # Load dispatch specializations
 import pytensor.link.jax.dispatch.scalar
-import pytensor.link.jax.dispatch.tensor_basic
-import pytensor.link.jax.dispatch.subtensor
+import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.shape
-import pytensor.link.jax.dispatch.extra_ops
-import pytensor.link.jax.dispatch.nlinalg
 import pytensor.link.jax.dispatch.slinalg
-import pytensor.link.jax.dispatch.random
-import pytensor.link.jax.dispatch.elemwise
-import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.sparse
-import pytensor.link.jax.dispatch.blockwise
-
-# isort: on
+import pytensor.link.jax.dispatch.subtensor
+import pytensor.link.jax.dispatch.tensor_basic
+from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py
index 9810e14178..dacdae2184 100644
--- a/pytensor/link/numba/dispatch/__init__.py
+++ b/pytensor/link/numba/dispatch/__init__.py
@@ -1,15 +1,12 @@
-# isort: off
-from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
-
-# Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
 import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
+
+# Load dispatch specializations
+import pytensor.link.numba.dispatch.scalar
 import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.slinalg
-
-# isort: on
+import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.tensor_basic
+from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py
index 3112892697..488ae6b5b0 100644
--- a/pytensor/tensor/__init__.py
+++ b/pytensor/tensor/__init__.py
@@ -113,7 +113,6 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
 
 import pytensor.tensor.rewriting
 
-# isort: off
 from pytensor.tensor import linalg  # noqa
 from pytensor.tensor import special
 
@@ -121,7 +120,6 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
 from pytensor.tensor import nlinalg  # noqa
 from pytensor.tensor import slinalg  # noqa
 
-# isort: on
 from pytensor.tensor.basic import *  # noqa
 from pytensor.tensor.blas import batched_dot, batched_tensordot  # noqa
 from pytensor.tensor.extra_ops import *
diff --git a/pytensor/tensor/random/rewriting/__init__.py b/pytensor/tensor/random/rewriting/__init__.py
index 2c32c16b33..71a6d6703d 100644
--- a/pytensor/tensor/random/rewriting/__init__.py
+++ b/pytensor/tensor/random/rewriting/__init__.py
@@ -1,10 +1,4 @@
 # TODO: This is for backward-compatibility; remove when reasonable.
-from pytensor.tensor.random.rewriting.basic import *
-
-
-# isort: off
-
 # Register JAX specializations
 import pytensor.tensor.random.rewriting.jax
-
-# isort: on
+from pytensor.tensor.random.rewriting.basic import *