Skip to content

Commit

Permalink
feat: More Numpy (#40)
Browse files Browse the repository at this point in the history
* feat: much of the numpy namespace

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* fix: tests

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* fix: test

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

---------

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Apr 25, 2024
1 parent 134795f commit e379d65
Show file tree
Hide file tree
Showing 29 changed files with 2,945 additions and 280 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default_stages: [pre-commit, pre-push]

repos:
- repo: https://github.com/commitizen-tools/commitizen
rev: v3.21.3
rev: v3.24.0
hooks:
- id: commitizen
- id: commitizen-branch
Expand All @@ -18,12 +18,12 @@ repos:
- id: check-useless-excludes

- repo: https://github.com/scientific-python/cookie
rev: 2024.03.10
rev: 2024.04.23
hooks:
- id: sp-repo-review

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.5.0"
rev: "v4.6.0"
hooks:
- id: check-added-large-files
- id: check-case-conflict
Expand All @@ -47,14 +47,14 @@ repos:
- id: rst-inline-touching-normal

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.28.1
rev: 0.28.2
hooks:
- id: check-dependabot
- id: check-github-workflows
- id: check-readthedocs

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.3.5"
rev: "v0.4.1"
hooks:
# Run the linter
- id: ruff
Expand All @@ -78,7 +78,7 @@ repos:
args: [--prose-wrap=always]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.9.0"
rev: "v1.10.0"
hooks:
- id: mypy
files: src
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ ignore = [
"D103", # Missing docstring in public function # TODO
"D203", # one-blank-line-before-class
"D213", # Multi-line docstring summary should start at the second line
"ISC001", # handled by formatter
"ERA001", # Found commented-out code
"F811", # Redefinition of unused variable <- plum
"FIX002", # Line contains TODO, consider resolving the issue
Expand Down
File renamed without changes.
10 changes: 10 additions & 0 deletions src/quaxed/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import TypeVar

import quax

T = TypeVar("T")


def quaxify(func: T, *, filter_spec: bool | tuple[bool, ...] = True) -> T:
"""Quaxify, but makes mypy happy."""
return quax.quaxify(func, filter_spec=filter_spec)
3 changes: 3 additions & 0 deletions src/quaxed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
_constants,
_creation_functions,
_data_type_functions,
_dispatch,
_elementwise_functions,
_indexing_functions,
_linear_algebra_functions,
Expand All @@ -28,6 +29,7 @@
from ._constants import *
from ._creation_functions import *
from ._data_type_functions import *
from ._dispatch import *
from ._elementwise_functions import *
from ._indexing_functions import *
from ._linear_algebra_functions import *
Expand All @@ -51,3 +53,4 @@
__all__ += _sorting_functions.__all__
__all__ += _statistical_functions.__all__
__all__ += _utility_functions.__all__
__all__ += _dispatch.__all__
5 changes: 3 additions & 2 deletions src/quaxed/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
from jaxtyping import ArrayLike
from quax import Value

from quaxed._types import DType
from quaxed._utils import quaxify

from ._dispatch import dispatcher
from ._types import DType
from ._utils import quaxify

T = TypeVar("T")

Expand Down
4 changes: 2 additions & 2 deletions src/quaxed/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from jax.experimental.array_api._data_type_functions import FInfo, IInfo
from jaxtyping import ArrayLike

from ._types import DType
from ._utils import quaxify
from quaxed._types import DType
from quaxed._utils import quaxify


@quaxify
Expand Down
4 changes: 2 additions & 2 deletions src/quaxed/array_api/_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Dispatching."""
""":mod:`jax.experimental.array_api` Dispatching."""

__all__: list[str] = []
__all__ = ["dispatcher"]

import plum

Expand Down
2 changes: 1 addition & 1 deletion src/quaxed/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
2 changes: 1 addition & 1 deletion src/quaxed/array_api/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
2 changes: 1 addition & 1 deletion src/quaxed/array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
2 changes: 1 addition & 1 deletion src/quaxed/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
2 changes: 1 addition & 1 deletion src/quaxed/array_api/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
2 changes: 1 addition & 1 deletion src/quaxed/array_api/_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
2 changes: 1 addition & 1 deletion src/quaxed/array_api/_sorting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
4 changes: 2 additions & 2 deletions src/quaxed/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from jaxtyping import ArrayLike
from quax import Value

from ._types import DType
from ._utils import quaxify
from quaxed._types import DType
from quaxed._utils import quaxify


@quaxify
Expand Down
2 changes: 1 addition & 1 deletion src/quaxed/array_api/_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
10 changes: 0 additions & 10 deletions src/quaxed/array_api/_utils.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/quaxed/array_api/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from jaxtyping import ArrayLike
from quax import Value

from ._utils import quaxify
from quaxed._utils import quaxify


@quaxify
Expand Down
4 changes: 2 additions & 2 deletions src/quaxed/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
from jaxtyping import ArrayLike
from quax import Value

from ._types import DType
from ._utils import quaxify
from quaxed._types import DType
from quaxed._utils import quaxify


@quaxify
Expand Down
12 changes: 9 additions & 3 deletions src/quaxed/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@

from jaxtyping import install_import_hook

with install_import_hook("quaxed", None):
from . import _core
from ._core import *
with install_import_hook("quaxed.numpy", None):
from . import _core, _creation_functions, _dispatch, _higher_order
from ._core import * # TODO: make this lazy
from ._creation_functions import *
from ._dispatch import *
from ._higher_order import *

__all__: list[str] = []
__all__ += _core.__all__
__all__ += _higher_order.__all__
__all__ += _creation_functions.__all__
__all__ += _dispatch.__all__
Loading

0 comments on commit e379d65

Please sign in to comment.