From fcfdae49cba0886693edd314922f3379caadab3a Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Fri, 29 Jul 2022 19:23:01 +0900 Subject: [PATCH] cache resolve_kernel,resolve_strides, bump to post --- kernex/__init__.py | 2 +- kernex/interface/kernel_interface.py | 28 ++++++++++++++-------------- kernex/interface/resolve_utils.py | 5 ++++- kernex/src/base.py | 16 ++++++++-------- kernex/src/map.py | 8 ++++---- kernex/src/scan.py | 8 ++++---- requirements/requirements.txt | 3 +-- 7 files changed, 36 insertions(+), 34 deletions(-) diff --git a/kernex/__init__.py b/kernex/__init__.py index 7c4f6d3..54e2d43 100644 --- a/kernex/__init__.py +++ b/kernex/__init__.py @@ -8,4 +8,4 @@ "baseKernelMap", "kernelMap", "offsetKernelMap", "baseKernelScan", "kernelScan", "offsetKernelScan") -__version__ = "0.0.4" +__version__ = "0.0.4post" diff --git a/kernex/interface/kernel_interface.py b/kernex/interface/kernel_interface.py index 15f94e8..ec9329a 100644 --- a/kernex/interface/kernel_interface.py +++ b/kernex/interface/kernel_interface.py @@ -8,8 +8,8 @@ import functools from typing import Callable +import pytreeclass as pytc from jax import numpy as jnp -from pytreeclass import static_field, treeclass from kernex.interface.named_axis import named_axis_wrapper from kernex.interface.resolve_utils import ( @@ -23,17 +23,17 @@ from kernex.src.scan import kernelScan, offsetKernelScan -@treeclass(op=False) +@pytc.treeclass(op=False) class kernelInterface: - kernel_size: tuple[int, ...] | int = static_field() - strides: tuple[int, ...] | int = static_field(default=1) - border: tuple[int, ...] | tuple[tuple[int, int], ...] | int | str = static_field(default=0, repr=False) # fmt: skip - relative: bool = static_field(default=False) - inplace: bool = static_field(default=False) - use_offset: bool = static_field(default=False) - named_axis: dict[int, str] | None = static_field(default=None) - container: dict[Callable, slice | int] = static_field(default_factory=dict) + kernel_size: tuple[int, ...] | int = pytc.static_field() + strides: tuple[int, ...] | int = pytc.static_field(default=1) + border: tuple[int, ...] | tuple[tuple[int, int], ...] | int | str = pytc.static_field(default=0, repr=False) # fmt: skip + relative: bool = pytc.static_field(default=False) + inplace: bool = pytc.static_field(default=False) + use_offset: bool = pytc.static_field(default=False) + named_axis: dict[int, str] | None = pytc.static_field(default=None) + container: dict[Callable, slice | int] = pytc.static_field(default_factory=dict) def __post_init__(self): """resolve the border values and the kernel operation""" @@ -124,7 +124,7 @@ def __call__(self, *args, **kwargs): ) -@treeclass(op=False) +@pytc.treeclass(op=False) class sscan(kernelInterface): def __init__( self, kernel_size=1, strides=1, offset=0, relative=False, named_axis=None @@ -141,7 +141,7 @@ def __init__( ) -@treeclass(op=False) +@pytc.treeclass(op=False) class smap(kernelInterface): def __init__( self, kernel_size=1, strides=1, offset=0, relative=False, named_axis=None @@ -158,7 +158,7 @@ def __init__( ) -@treeclass(op=False) +@pytc.treeclass(op=False) class kscan(kernelInterface): def __init__( self, kernel_size=1, strides=1, padding=0, relative=False, named_axis=None @@ -175,7 +175,7 @@ def __init__( ) -@treeclass(op=False) +@pytc.treeclass(op=False) class kmap(kernelInterface): def __init__( self, kernel_size=1, strides=1, padding=0, relative=False, named_axis=None diff --git a/kernex/interface/resolve_utils.py b/kernex/interface/resolve_utils.py index 1a5b71c..9237ce2 100644 --- a/kernex/interface/resolve_utils.py +++ b/kernex/interface/resolve_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools from typing import Any, Callable import jax @@ -178,6 +179,7 @@ def normalize_slices( return container +@functools.lru_cache(maxsize=None) def resolve_kernel_size(arg, in_dim): kw = "kernel_size" @@ -201,12 +203,13 @@ def resolve_kernel_size(arg, in_dim): return tuple(si if wi == -1 else wi for si, wi in ZIP(in_dim, arg)) elif isinstance(arg, int): - return (arg,) * len(in_dim) + return (in_dim if arg == -1 else arg) * len(in_dim) else: raise ValueError(f"{kw} must be instance of int or tuple. Found {type(arg)}") +@functools.lru_cache(maxsize=None) def resolve_strides(arg, in_dim): kw = "strides" diff --git a/kernex/src/base.py b/kernex/src/base.py index 79cebf2..6f6755c 100644 --- a/kernex/src/base.py +++ b/kernex/src/base.py @@ -3,23 +3,23 @@ from itertools import product from typing import Any, Callable +import pytreeclass as pytc from jax import numpy as jnp -from pytreeclass import static_field, treeclass from pytreeclass.src.decorator_util import cached_property from kernex.src.utils import ZIP, general_arange, general_product, key_search -@treeclass(op=False) +@pytc.treeclass(op=False) class kernelOperation: """base class for all kernel operations""" - func_dict: dict[Callable[[Any], jnp.ndarray] : tuple[int, ...]] = static_field() # fmt: skip - shape: tuple[int, ...] = static_field() - kernel_size: tuple[int, ...] = static_field() - strides: tuple[int, ...] = static_field() - border: tuple[tuple[int, int], ...] = static_field() - relative: bool = static_field() + func_dict: dict[Callable[[Any], jnp.ndarray] : tuple[int, ...]] = pytc.static_field() # fmt: skip + shape: tuple[int, ...] = pytc.static_field() + kernel_size: tuple[int, ...] = pytc.static_field() + strides: tuple[int, ...] = pytc.static_field() + border: tuple[tuple[int, int], ...] = pytc.static_field() + relative: bool = pytc.static_field() @cached_property def pad_width(self): diff --git a/kernex/src/map.py b/kernex/src/map.py index f5f0aca..c82f125 100644 --- a/kernex/src/map.py +++ b/kernex/src/map.py @@ -2,17 +2,17 @@ from typing import Callable +import pytreeclass as pytc from jax import lax from jax import numpy as jnp from jax import vmap -from pytreeclass import treeclass from pytreeclass.src.decorator_util import cached_property from kernex.src.base import kernelOperation from kernex.src.utils import ZIP, ix_, offset_to_padding, roll_view -@treeclass(op=False) +@pytc.treeclass(op=False) class baseKernelMap(kernelOperation): def __post_init__(self): self.__call__ = ( @@ -54,7 +54,7 @@ def __multi_call__(self, array, *args, **kwargs): return result.reshape(*self.output_shape, *func_shape) -@treeclass(op=False) +@pytc.treeclass(op=False) class kernelMap(baseKernelMap): def __init__(self, func_dict, shape, kernel_size, strides, padding, relative): @@ -64,7 +64,7 @@ def __call__(self, array, *args, **kwargs): return self.__call__(array, *args, **kwargs) -@treeclass(op=False) +@pytc.treeclass(op=False) class offsetKernelMap(kernelMap): def __init__(self, func_dict, shape, kernel_size, strides, offset, relative): diff --git a/kernex/src/scan.py b/kernex/src/scan.py index cb0560e..bbcc073 100644 --- a/kernex/src/scan.py +++ b/kernex/src/scan.py @@ -2,16 +2,16 @@ from typing import Callable +import pytreeclass as pytc from jax import lax from jax import numpy as jnp -from pytreeclass import treeclass from pytreeclass.src.decorator_util import cached_property from kernex.src.base import kernelOperation from kernex.src.utils import ZIP, ix_, offset_to_padding, roll_view -@treeclass(op=False) +@pytc.treeclass(op=False) class baseKernelScan(kernelOperation): def __post_init__(self): self.__call__ = ( @@ -62,7 +62,7 @@ def scan_body(padded_array, view): ) -@treeclass(op=False) +@pytc.treeclass(op=False) class kernelScan(baseKernelScan): def __init__(self, func_dict, shape, kernel_size, strides, padding, relative): @@ -72,7 +72,7 @@ def __call__(self, array, *args, **kwargs): return self.__call__(array, *args, **kwargs) -@treeclass(op=False) +@pytc.treeclass(op=False) class offsetKernelScan(kernelScan): def __init__(self, func_dict, shape, kernel_size, strides, offset, relative): diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e954642..5e65be6 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,4 +1,3 @@ -jax>=0.1.55 -jaxlib>=0.1.37 +jax>=0.3.5 pytreeclass>=0.0.5