remove isort mentions
juanitorduz committed Jan 11, 2024
1 parent 7280917 · commit 7f55843
Showing 7 changed files with 27 additions and 58 deletions.
7 changes: 1 addition & 6 deletions environment.yml
@@ -39,12 +39,7 @@ dependencies:
  - pydot
  - ipython
  # code style
- - black
- - isort
- # For linting
- - flake8
- - pep8
- - pyflakes
+ - ruff
  # developer tools
  - pre-commit
  - packaging
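Context for the swap above: ruff consolidates the removed tools into a single linter, reimplementing pyflakes as its F rules, pep8/pycodestyle as its E/W rules, and isort as its I rules (the dropped black is presumably covered by ruff's black-compatible formatter). A minimal sketch of that overlap, using a hypothetical module; the comments name the ruff rule codes that would fire:

    import sys          # F401: imported but unused (pyflakes' old job)
    import collections
    import argparse     # I001: this import block is unsorted (isort's old job)


    def main():
        parser = argparse.ArgumentParser()  # F841: assigned but never used (pyflakes again)
        return collections.deque()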
11 changes: 1 addition & 10 deletions pytensor/__init__.py
@@ -70,12 +70,9 @@ def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler)
 # very rarely.
 __api_version__ = 1
 
-# isort: off
 from pytensor.graph.basic import Variable
 from pytensor.graph.replace import clone_replace, graph_replace
 
-# isort: on
-
 
 def as_symbolic(x: Any, name: Optional[str] = None, **kwargs) -> Variable:
     """Convert `x` into an equivalent PyTensor `Variable`.
@@ -115,7 +112,6 @@ def _as_symbolic(x, **kwargs) -> Variable:
     return as_tensor_variable(x, **kwargs)
 
 
-# isort: off
 from pytensor import scalar, tensor
 from pytensor.compile import (
     In,
@@ -134,8 +130,6 @@ def _as_symbolic(x, **kwargs) -> Variable:
 from pytensor.printing import pp, pprint
 from pytensor.updates import OrderedUpdates
 
-# isort: on
-
 
 def get_underlying_scalar_constant(v):
     """Return the constant scalar (i.e. 0-D) value underlying variable `v`.
@@ -157,15 +151,12 @@ def get_underlying_scalar_constant(v):
     return tensor.get_underlying_scalar_constant_value(v)
 
 
-# isort: off
-import pytensor.tensor.random.var
 import pytensor.sparse
+import pytensor.tensor.random.var
 from pytensor.scan import checkpoints
 from pytensor.scan.basic import scan
 from pytensor.scan.views import foldl, foldr, map, reduce
 
-# isort: on
-
 
 # Some config variables are registered by submodules. Only after all those
 # imports were executed, we can warn about remaining flags provided by the user
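A note on the blocks above: the `# isort: off` guards marked imports whose order encodes real initialization dependencies (each submodule registers itself when first imported). The commit sorts those blocks anyway, so the guards turned out to be unnecessary here. Where order does matter, ruff's isort-compatible rules honor the same action comments, so a single line can still be pinned; a sketch only, with hypothetical module names:

    import pkg.core  # isort: skip  (must be imported before pkg.extras registers against it)
    import pkg.extras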
19 changes: 8 additions & 11 deletions pytensor/graph/__init__.py
@@ -1,20 +1,17 @@
"""Graph objects and manipulation functions."""

# isort: off
from pytensor.graph.basic import (
Apply,
Variable,
Constant,
graph_inputs,
clone,
Variable,
ancestors,
clone,
graph_inputs,
)
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter, graph_rewriter
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.graph.rewriting.basic import graph_rewriter, node_rewriter
from pytensor.graph.rewriting.db import RewriteDatabaseQuery

# isort: on
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.type import Type
21 changes: 9 additions & 12 deletions pytensor/link/jax/dispatch/__init__.py
@@ -1,18 +1,15 @@
-# isort: off
-from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
+import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.elemwise
+import pytensor.link.jax.dispatch.extra_ops
+import pytensor.link.jax.dispatch.nlinalg
+import pytensor.link.jax.dispatch.random
 
 # Load dispatch specializations
 import pytensor.link.jax.dispatch.scalar
-import pytensor.link.jax.dispatch.tensor_basic
-import pytensor.link.jax.dispatch.subtensor
+import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.shape
-import pytensor.link.jax.dispatch.extra_ops
-import pytensor.link.jax.dispatch.nlinalg
 import pytensor.link.jax.dispatch.slinalg
-import pytensor.link.jax.dispatch.random
-import pytensor.link.jax.dispatch.elemwise
-import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.sparse
-import pytensor.link.jax.dispatch.blockwise
-
-# isort: on
+import pytensor.link.jax.dispatch.subtensor
+import pytensor.link.jax.dispatch.tensor_basic
+from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
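The bare `import pytensor.link.jax.dispatch.<module>` statements above are side-effect imports: each module registers JAX translations for more Op types on `jax_funcify`, so their relative order is irrelevant and alphabetical sorting is safe. A minimal sketch of that registration pattern, with hypothetical names (`funcify`, `Exp`), assuming the `functools.singledispatch` mechanism that dispatchers like this are built on:

    from functools import singledispatch


    @singledispatch
    def funcify(op, **kwargs):
        raise NotImplementedError(f"no translation registered for {type(op).__name__}")


    class Exp:
        """Hypothetical stand-in for a graph Op."""


    # Importing a "dispatch specialization" module runs a registration like this:
    @funcify.register(Exp)
    def funcify_exp(op, **kwargs):
        import math

        return math.exp


    print(funcify(Exp())(1.0))  # 2.718281828459045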
17 changes: 7 additions & 10 deletions pytensor/link/numba/dispatch/__init__.py
@@ -1,15 +1,12 @@
-# isort: off
-from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
-
-# Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
 import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
+
+# Load dispatch specializations
+import pytensor.link.numba.dispatch.scalar
 import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.slinalg
-
-# isort: on
+import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.tensor_basic
+from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
2 changes: 0 additions & 2 deletions pytensor/tensor/__init__.py
@@ -113,15 +113,13 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
 import pytensor.tensor.rewriting
 
 
-# isort: off
 from pytensor.tensor import linalg  # noqa
 from pytensor.tensor import special
 
 # For backward compatibility
 from pytensor.tensor import nlinalg  # noqa
 from pytensor.tensor import slinalg  # noqa
 
-# isort: on
 from pytensor.tensor.basic import *  # noqa
 from pytensor.tensor.blas import batched_dot, batched_tensordot  # noqa
 from pytensor.tensor.extra_ops import *
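The "# For backward compatibility" re-exports above keep older import paths alive; a quick usage sketch, assuming an installed PyTensor (both spellings resolve to the same module object):

    from pytensor.tensor import nlinalg  # legacy path, via the re-export
    import pytensor.tensor.nlinalg as nlinalg_direct

    assert nlinalg is nlinalg_direct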
8 changes: 1 addition & 7 deletions pytensor/tensor/random/rewriting/__init__.py
@@ -1,10 +1,4 @@
-# TODO: This is for backward-compatibility; remove when reasonable.
-from pytensor.tensor.random.rewriting.basic import *
-
-
-# isort: off
-
 # Register JAX specializations
 import pytensor.tensor.random.rewriting.jax
 
-# isort: on
+from pytensor.tensor.random.rewriting.basic import *
