Skip to content

Commit

Permalink
fix: import caching (#47)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Apr 26, 2024
1 parent 0acea61 commit 8bd07e0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
8 changes: 7 additions & 1 deletion src/quaxed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

__all__ = ["__version__", "array_api"]

import sys
from typing import Any

import plum
Expand All @@ -31,4 +32,9 @@ def __getattr__(name: str) -> Any: # TODO: fuller annotation
# TODO: detect if the attribute is a function or a module.
# If it is a function, quaxify it. If it is a module, return a proxy object
# that quaxifies all of its attributes.
return quaxify(getattr(jax, name))
out = quaxify(getattr(jax, name))

# Cache the function in this module
setattr(sys.modules[__name__], name, out)

return out
15 changes: 11 additions & 4 deletions src/quaxed/operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Quaxed :mod:`operator`."""

import operator
import sys
from collections.abc import Callable
from typing import Any

Expand All @@ -9,12 +10,18 @@
__all__ = operator.__all__


def __dir__() -> list[str]:
"""List the operators."""
return sorted(__all__)


# TODO: return type hint signature
def __getattr__(name: str) -> Callable[..., Any]:
"""Get the operator."""
return quaxify(getattr(operator, name))
# Quaxify the operator
out = quaxify(getattr(operator, name))

# Cache the function in this module
setattr(sys.modules[__name__], name, out)

def __dir__() -> list[str]:
"""List the operators."""
return sorted(__all__)
return out
12 changes: 8 additions & 4 deletions src/quaxed/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,16 @@
from quax import quaxify


def __dir__() -> list[str]:
return sorted(__all__)


# TODO: better return type annotation
def __getattr__(name: str) -> Callable[..., Any]:
# Quaxify the func
func = quaxify(getattr(jsp, name))
setattr(sys.modules[__name__], name, func)
return func

# Cache the function in this module
setattr(sys.modules[__name__], name, func)

def __dir__() -> list[str]:
return sorted(__all__)
return func

0 comments on commit 8bd07e0

Please sign in to comment.