diff --git a/kernex/__init__.py b/kernex/__init__.py index f97f23b..cdab4cf 100644 --- a/kernex/__init__.py +++ b/kernex/__init__.py @@ -17,4 +17,4 @@ "offsetKernelScan", ) -__version__ = "0.1.1" +__version__ = "0.1.2" diff --git a/kernex/interface/resolve_utils.py b/kernex/interface/resolve_utils.py index 34300f1..7b05070 100644 --- a/kernex/interface/resolve_utils.py +++ b/kernex/interface/resolve_utils.py @@ -5,73 +5,64 @@ import jax import jax.numpy as jnp -from pytreeclass._src.dispatch import dispatch from kernex._src.utils import ZIP # ---------------------- resolve_arguments ------------------------ # -@dispatch(argnum=0) def _resolve_padding_argument( input_argument: tuple[int | tuple[int, int] | str, ...] | int | str, kernel_size: tuple[int, ...], ): """Helper function to generate padding""" - ... + if isinstance(input_argument, tuple): + same = lambda wi: ((wi - 1) // 2, wi // 2) -@_resolve_padding_argument.register(tuple) -def _(input_argument, kernel_size): - same = lambda wi: ((wi - 1) // 2, wi // 2) - - assert len(input_argument) == len(kernel_size), ( - "kernel_size dimension != padding dimension.", - f"Found length(kernel_size)={len(kernel_size)} length(padding)={len(input_argument)}", - ) + assert len(input_argument) == len(kernel_size), ( + "kernel_size dimension != padding dimension.", + f"Found length(kernel_size)={len(kernel_size)} length(padding)={len(input_argument)}", + ) - padding = [[]] * len(kernel_size) + padding = [[]] * len(kernel_size) - for i, item in enumerate(input_argument): - if isinstance(item, int): - padding[i] = (item, item) + for i, item in enumerate(input_argument): + if isinstance(item, int): + padding[i] = (item, item) - elif isinstance(item, tuple): - padding[i] = item + elif isinstance(item, tuple): + padding[i] = item - elif isinstance(item, str): - if item in ["same", "SAME"]: - padding[i] = same(kernel_size[i]) + elif isinstance(item, str): + if item in ["same", "SAME"]: + padding[i] = same(kernel_size[i]) - elif item in ["valid", "VALID"]: - padding[i] = (0, 0) + elif item in ["valid", "VALID"]: + padding[i] = (0, 0) - else: - raise ValueError( - f'string argument must be in ["same","SAME","VALID","valid"].Found {item}' - ) - return tuple(padding) + else: + raise ValueError( + f'string argument must be in ["same","SAME","VALID","valid"].Found {item}' + ) + return tuple(padding) + elif isinstance(input_argument, int): + return ((input_argument, input_argument),) * len(kernel_size) -@_resolve_padding_argument.register(str) -def _(input_argument, kernel_size): - same = lambda wi: ((wi - 1) // 2, wi // 2) + elif isinstance(input_argument, str): + same = lambda wi: ((wi - 1) // 2, wi // 2) - if input_argument in ["same", "SAME", "Same"]: - return tuple(same(wi) for wi in kernel_size) + if input_argument.lower() == "same": + return tuple(same(wi) for wi in kernel_size) - elif input_argument in ["valid", "VALID", "Valid"]: - return ((0, 0),) * len(kernel_size) + elif input_argument.lower() == "valid": + return ((0, 0),) * len(kernel_size) - else: - raise ValueError( - f'string argument must be in ["same","SAME","VALID","valid"].Found {input_argument}' - ) - - -@_resolve_padding_argument.register(int) -def _(input_argument, kernel_size): - return ((input_argument, input_argument),) * len(kernel_size) + else: + raise ValueError( + f'string argument must be in ["same","SAME","VALID","valid"].Found {input_argument}' + ) def _resolve_dict_argument( @@ -100,19 +91,8 @@ def _resolve_dict_argument( def _resolve_offset_argument(input_argument, kernel_size): - @dispatch(argnum=0) - def __resolve_offset_argument(input_argument, kernel_size): - raise NotImplementedError( - "input_argument type={} is not implemented".format(type(input_argument)) - ) - - @__resolve_offset_argument.register(int) - def _(input_argument, kernel_size): - return [(input_argument, input_argument)] * len(kernel_size) - @__resolve_offset_argument.register(list) - @__resolve_offset_argument.register(tuple) - def _(input_argument, kernel_size): + if isinstance(input_argument, (tuple, list)): offset = [[]] * len(kernel_size) for i, item in enumerate(input_argument): @@ -120,48 +100,50 @@ def _(input_argument, kernel_size): return offset - return __resolve_offset_argument(input_argument, kernel_size) + elif isinstance(input_argument, int): + return [(input_argument, input_argument)] * len(kernel_size) + + else: + raise NotImplementedError( + "input_argument type={} is not implemented".format(type(input_argument)) + ) def _resolve_index(index, shape): """Resolve index to a tuple of int""" - @dispatch(argnum=0) - def __resolve_index_step(index, shape): - raise NotImplementedError(f"index type={type(index)} is not implemented") + def _resolve_single_index(index, shape): + if isinstance(index, int): + index += shape if index < 0 else 0 + return index - @__resolve_index_step.register(int) - def _(index, shape): - index += shape if index < 0 else 0 - return index + elif isinstance(index, slice): + start, end, step = index.start, index.stop, index.step - @__resolve_index_step.register(slice) - def _(index, shape): - start, end, step = index.start, index.stop, index.step + start = start or 0 + start += shape if start < 0 else 0 - start = start or 0 - start += shape if start < 0 else 0 + end = end or shape + end += shape if end < 0 else 0 - end = end or shape - end += shape if end < 0 else 0 + step = step or 1 - step = step or 1 + return (start, end, step) - return (start, end, step) + elif isinstance(index, (list, tuple)): + assert all( + isinstance(i, int) for i in jax.tree_util.tree_leaves(index) + ), "All items in tuple must be int" + return index - @__resolve_index_step.register(list) - @__resolve_index_step.register(tuple) - def _(index, shape): - assert all( - isinstance(i, int) for i in jax.tree_util.tree_leaves(index) - ), "All items in tuple must be int" - return index + else: + raise NotImplementedError(f"index type={type(index)} is not implemented") index = [index] if not isinstance(index, tuple) else index resolved_index = [[]] * len(index) for i, (item, in_dim) in enumerate(zip(index, shape)): - resolved_index[i] = __resolve_index_step(item, in_dim) + resolved_index[i] = _resolve_single_index(item, in_dim) return resolved_index diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 286d80f..f41e7d9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,3 +1,3 @@ jax>=0.3.5 -pytreeclass>=0.1.7 +pytreeclass>=0.1.9