Skip to content

Commit

Permalink
update to support the new immutable pytreeclass
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 23, 2022
1 parent dbcb683 commit 36bdd5e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
2 changes: 1 addition & 1 deletion kernex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
"offsetKernelScan",
)

__version__ = "0.0.5"
__version__ = "0.0.6"
40 changes: 23 additions & 17 deletions kernex/interface/kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,25 @@ def __setitem__(self, index, func):
self.container[func] = [*self.container.get(func, []), index]

def _wrap_mesh(self, array, *args, **kwargs):
object.__setattr__(self, "shape", array.shape) # fmt: skip
object.__setattr__(self, "kernel_size", _resolve_kernel_size(self.kernel_size, self.shape)) # fmt: skip
object.__setattr__(self, "strides", _resolve_strides(self.strides, self.shape)) # fmt: skip
object.__setattr__(self, "container", _normalize_slices(self.container, self.shape)) # fmt: skip
object.__setattr__(self, "resolved_container", {})

self.shape = array.shape
self.kernel_size = _resolve_kernel_size(self.kernel_size, self.shape)
self.strides = _resolve_strides(self.strides, self.shape)
self.container = _normalize_slices(self.container, self.shape)
self.resolved_container = {}
resolved_container = {}

for (func, index) in self.container.items():

if func is not None and self.named_axis is not None:
self.resolved_container[
resolved_container[
named_axis_wrapper(self.kernel_size, self.named_axis)(func)
] = index

else:
self.resolved_container[func] = index
resolved_container[func] = index

object.__setattr__(self, "resolved_container", resolved_container)

kernel_op = (
(offsetKernelScan if self.inplace else offsetKernelMap)
Expand All @@ -87,16 +90,19 @@ def _wrap_mesh(self, array, *args, **kwargs):

def _wrap_decorator(self, func):
def call(array, *args, **kwargs):

self.shape = array.shape
self.kernel_size = _resolve_kernel_size(self.kernel_size, self.shape)
self.strides = _resolve_strides(self.strides, self.shape)

self.resolved_container = {
named_axis_wrapper(self.kernel_size, self.named_axis)(func)
if self.named_axis is not None
else func: ()
}
object.__setattr__(self, "shape", array.shape) # fmt: skip
object.__setattr__(self, "kernel_size", _resolve_kernel_size(self.kernel_size, self.shape)) # fmt: skip
object.__setattr__(self, "strides", _resolve_strides(self.strides, self.shape)) # fmt: skip

object.__setattr__(
self,
"resolved_container",
{
named_axis_wrapper(self.kernel_size, self.named_axis)(func)
if self.named_axis is not None
else func: ()
},
)

kernel_op = (
(offsetKernelScan if self.inplace else offsetKernelMap)
Expand Down

0 comments on commit 36bdd5e

Please sign in to comment.