Skip to content

Commit

Permalink
cache resolve_kernel,resolve_strides, bump to post
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 29, 2022
1 parent a5fc60b commit fcfdae4
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 34 deletions.
2 changes: 1 addition & 1 deletion kernex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
"baseKernelMap", "kernelMap", "offsetKernelMap", "baseKernelScan",
"kernelScan", "offsetKernelScan")

__version__ = "0.0.4"
__version__ = "0.0.4post"
28 changes: 14 additions & 14 deletions kernex/interface/kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import functools
from typing import Callable

import pytreeclass as pytc
from jax import numpy as jnp
from pytreeclass import static_field, treeclass

from kernex.interface.named_axis import named_axis_wrapper
from kernex.interface.resolve_utils import (
Expand All @@ -23,17 +23,17 @@
from kernex.src.scan import kernelScan, offsetKernelScan


@treeclass(op=False)
@pytc.treeclass(op=False)
class kernelInterface:

kernel_size: tuple[int, ...] | int = static_field()
strides: tuple[int, ...] | int = static_field(default=1)
border: tuple[int, ...] | tuple[tuple[int, int], ...] | int | str = static_field(default=0, repr=False) # fmt: skip
relative: bool = static_field(default=False)
inplace: bool = static_field(default=False)
use_offset: bool = static_field(default=False)
named_axis: dict[int, str] | None = static_field(default=None)
container: dict[Callable, slice | int] = static_field(default_factory=dict)
kernel_size: tuple[int, ...] | int = pytc.static_field()
strides: tuple[int, ...] | int = pytc.static_field(default=1)
border: tuple[int, ...] | tuple[tuple[int, int], ...] | int | str = pytc.static_field(default=0, repr=False) # fmt: skip
relative: bool = pytc.static_field(default=False)
inplace: bool = pytc.static_field(default=False)
use_offset: bool = pytc.static_field(default=False)
named_axis: dict[int, str] | None = pytc.static_field(default=None)
container: dict[Callable, slice | int] = pytc.static_field(default_factory=dict)

def __post_init__(self):
"""resolve the border values and the kernel operation"""
Expand Down Expand Up @@ -124,7 +124,7 @@ def __call__(self, *args, **kwargs):
)


@treeclass(op=False)
@pytc.treeclass(op=False)
class sscan(kernelInterface):
def __init__(
self, kernel_size=1, strides=1, offset=0, relative=False, named_axis=None
Expand All @@ -141,7 +141,7 @@ def __init__(
)


@treeclass(op=False)
@pytc.treeclass(op=False)
class smap(kernelInterface):
def __init__(
self, kernel_size=1, strides=1, offset=0, relative=False, named_axis=None
Expand All @@ -158,7 +158,7 @@ def __init__(
)


@treeclass(op=False)
@pytc.treeclass(op=False)
class kscan(kernelInterface):
def __init__(
self, kernel_size=1, strides=1, padding=0, relative=False, named_axis=None
Expand All @@ -175,7 +175,7 @@ def __init__(
)


@treeclass(op=False)
@pytc.treeclass(op=False)
class kmap(kernelInterface):
def __init__(
self, kernel_size=1, strides=1, padding=0, relative=False, named_axis=None
Expand Down
5 changes: 4 additions & 1 deletion kernex/interface/resolve_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
from typing import Any, Callable

import jax
Expand Down Expand Up @@ -178,6 +179,7 @@ def normalize_slices(
return container


@functools.lru_cache(maxsize=None)
def resolve_kernel_size(arg, in_dim):

kw = "kernel_size"
Expand All @@ -201,12 +203,13 @@ def resolve_kernel_size(arg, in_dim):
return tuple(si if wi == -1 else wi for si, wi in ZIP(in_dim, arg))

elif isinstance(arg, int):
return (arg,) * len(in_dim)
return (in_dim if arg == -1 else arg) * len(in_dim)

else:
raise ValueError(f"{kw} must be instance of int or tuple. Found {type(arg)}")


@functools.lru_cache(maxsize=None)
def resolve_strides(arg, in_dim):

kw = "strides"
Expand Down
16 changes: 8 additions & 8 deletions kernex/src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
from itertools import product
from typing import Any, Callable

import pytreeclass as pytc
from jax import numpy as jnp
from pytreeclass import static_field, treeclass
from pytreeclass.src.decorator_util import cached_property

from kernex.src.utils import ZIP, general_arange, general_product, key_search


@treeclass(op=False)
@pytc.treeclass(op=False)
class kernelOperation:
"""base class for all kernel operations"""

func_dict: dict[Callable[[Any], jnp.ndarray] : tuple[int, ...]] = static_field() # fmt: skip
shape: tuple[int, ...] = static_field()
kernel_size: tuple[int, ...] = static_field()
strides: tuple[int, ...] = static_field()
border: tuple[tuple[int, int], ...] = static_field()
relative: bool = static_field()
func_dict: dict[Callable[[Any], jnp.ndarray] : tuple[int, ...]] = pytc.static_field() # fmt: skip
shape: tuple[int, ...] = pytc.static_field()
kernel_size: tuple[int, ...] = pytc.static_field()
strides: tuple[int, ...] = pytc.static_field()
border: tuple[tuple[int, int], ...] = pytc.static_field()
relative: bool = pytc.static_field()

@cached_property
def pad_width(self):
Expand Down
8 changes: 4 additions & 4 deletions kernex/src/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@

from typing import Callable

import pytreeclass as pytc
from jax import lax
from jax import numpy as jnp
from jax import vmap
from pytreeclass import treeclass
from pytreeclass.src.decorator_util import cached_property

from kernex.src.base import kernelOperation
from kernex.src.utils import ZIP, ix_, offset_to_padding, roll_view


@treeclass(op=False)
@pytc.treeclass(op=False)
class baseKernelMap(kernelOperation):
def __post_init__(self):
self.__call__ = (
Expand Down Expand Up @@ -54,7 +54,7 @@ def __multi_call__(self, array, *args, **kwargs):
return result.reshape(*self.output_shape, *func_shape)


@treeclass(op=False)
@pytc.treeclass(op=False)
class kernelMap(baseKernelMap):
def __init__(self, func_dict, shape, kernel_size, strides, padding, relative):

Expand All @@ -64,7 +64,7 @@ def __call__(self, array, *args, **kwargs):
return self.__call__(array, *args, **kwargs)


@treeclass(op=False)
@pytc.treeclass(op=False)
class offsetKernelMap(kernelMap):
def __init__(self, func_dict, shape, kernel_size, strides, offset, relative):

Expand Down
8 changes: 4 additions & 4 deletions kernex/src/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

from typing import Callable

import pytreeclass as pytc
from jax import lax
from jax import numpy as jnp
from pytreeclass import treeclass
from pytreeclass.src.decorator_util import cached_property

from kernex.src.base import kernelOperation
from kernex.src.utils import ZIP, ix_, offset_to_padding, roll_view


@treeclass(op=False)
@pytc.treeclass(op=False)
class baseKernelScan(kernelOperation):
def __post_init__(self):
self.__call__ = (
Expand Down Expand Up @@ -62,7 +62,7 @@ def scan_body(padded_array, view):
)


@treeclass(op=False)
@pytc.treeclass(op=False)
class kernelScan(baseKernelScan):
def __init__(self, func_dict, shape, kernel_size, strides, padding, relative):

Expand All @@ -72,7 +72,7 @@ def __call__(self, array, *args, **kwargs):
return self.__call__(array, *args, **kwargs)


@treeclass(op=False)
@pytc.treeclass(op=False)
class offsetKernelScan(kernelScan):
def __init__(self, func_dict, shape, kernel_size, strides, offset, relative):

Expand Down
3 changes: 1 addition & 2 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

jax>=0.1.55
jaxlib>=0.1.37
jax>=0.3.5
pytreeclass>=0.0.5

0 comments on commit fcfdae4

Please sign in to comment.