Skip to content

Commit

Permalink
remove dispatch
Browse files Browse the repository at this point in the history
bump
  • Loading branch information
ASEM000 committed Sep 24, 2022
1 parent 2af15e8 commit f5dd7f2
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 82 deletions.
2 changes: 1 addition & 1 deletion kernex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
"offsetKernelScan",
)

__version__ = "0.1.1"
__version__ = "0.1.2"
142 changes: 62 additions & 80 deletions kernex/interface/resolve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,73 +5,64 @@

import jax
import jax.numpy as jnp
from pytreeclass._src.dispatch import dispatch

from kernex._src.utils import ZIP

# ---------------------- resolve_arguments ------------------------ #


@dispatch(argnum=0)
def _resolve_padding_argument(
input_argument: tuple[int | tuple[int, int] | str, ...] | int | str,
kernel_size: tuple[int, ...],
):
"""Helper function to generate padding"""
...

if isinstance(input_argument, tuple):
same = lambda wi: ((wi - 1) // 2, wi // 2)

@_resolve_padding_argument.register(tuple)
def _(input_argument, kernel_size):
same = lambda wi: ((wi - 1) // 2, wi // 2)

assert len(input_argument) == len(kernel_size), (
"kernel_size dimension != padding dimension.",
f"Found length(kernel_size)={len(kernel_size)} length(padding)={len(input_argument)}",
)
assert len(input_argument) == len(kernel_size), (
"kernel_size dimension != padding dimension.",
f"Found length(kernel_size)={len(kernel_size)} length(padding)={len(input_argument)}",
)

padding = [[]] * len(kernel_size)
padding = [[]] * len(kernel_size)

for i, item in enumerate(input_argument):
if isinstance(item, int):
padding[i] = (item, item)
for i, item in enumerate(input_argument):
if isinstance(item, int):
padding[i] = (item, item)

elif isinstance(item, tuple):
padding[i] = item
elif isinstance(item, tuple):
padding[i] = item

elif isinstance(item, str):
if item in ["same", "SAME"]:
padding[i] = same(kernel_size[i])
elif isinstance(item, str):
if item in ["same", "SAME"]:
padding[i] = same(kernel_size[i])

elif item in ["valid", "VALID"]:
padding[i] = (0, 0)
elif item in ["valid", "VALID"]:
padding[i] = (0, 0)

else:
raise ValueError(
f'string argument must be in ["same","SAME","VALID","valid"].Found {item}'
)
return tuple(padding)
else:
raise ValueError(
f'string argument must be in ["same","SAME","VALID","valid"].Found {item}'
)
return tuple(padding)

elif isinstance(input_argument, int):
return ((input_argument, input_argument),) * len(kernel_size)

@_resolve_padding_argument.register(str)
def _(input_argument, kernel_size):
same = lambda wi: ((wi - 1) // 2, wi // 2)
elif isinstance(input_argument, str):
same = lambda wi: ((wi - 1) // 2, wi // 2)

if input_argument in ["same", "SAME", "Same"]:
return tuple(same(wi) for wi in kernel_size)
if input_argument.lower() == "same":
return tuple(same(wi) for wi in kernel_size)

elif input_argument in ["valid", "VALID", "Valid"]:
return ((0, 0),) * len(kernel_size)
elif input_argument.lower() == "valid":
return ((0, 0),) * len(kernel_size)

else:
raise ValueError(
f'string argument must be in ["same","SAME","VALID","valid"].Found {input_argument}'
)


@_resolve_padding_argument.register(int)
def _(input_argument, kernel_size):
return ((input_argument, input_argument),) * len(kernel_size)
else:
raise ValueError(
f'string argument must be in ["same","SAME","VALID","valid"].Found {input_argument}'
)


def _resolve_dict_argument(
Expand Down Expand Up @@ -100,68 +91,59 @@ def _resolve_dict_argument(


def _resolve_offset_argument(input_argument, kernel_size):
@dispatch(argnum=0)
def __resolve_offset_argument(input_argument, kernel_size):
raise NotImplementedError(
"input_argument type={} is not implemented".format(type(input_argument))
)

@__resolve_offset_argument.register(int)
def _(input_argument, kernel_size):
return [(input_argument, input_argument)] * len(kernel_size)

@__resolve_offset_argument.register(list)
@__resolve_offset_argument.register(tuple)
def _(input_argument, kernel_size):
if isinstance(input_argument, (tuple, list)):
offset = [[]] * len(kernel_size)

for i, item in enumerate(input_argument):
offset[i] = (item, item) if isinstance(item, int) else item

return offset

return __resolve_offset_argument(input_argument, kernel_size)
elif isinstance(input_argument, int):
return [(input_argument, input_argument)] * len(kernel_size)

else:
raise NotImplementedError(
"input_argument type={} is not implemented".format(type(input_argument))
)


def _resolve_index(index, shape):
"""Resolve index to a tuple of int"""

@dispatch(argnum=0)
def __resolve_index_step(index, shape):
raise NotImplementedError(f"index type={type(index)} is not implemented")
def _resolve_single_index(index, shape):
if isinstance(index, int):
index += shape if index < 0 else 0
return index

@__resolve_index_step.register(int)
def _(index, shape):
index += shape if index < 0 else 0
return index
elif isinstance(index, slice):
start, end, step = index.start, index.stop, index.step

@__resolve_index_step.register(slice)
def _(index, shape):
start, end, step = index.start, index.stop, index.step
start = start or 0
start += shape if start < 0 else 0

start = start or 0
start += shape if start < 0 else 0
end = end or shape
end += shape if end < 0 else 0

end = end or shape
end += shape if end < 0 else 0
step = step or 1

step = step or 1
return (start, end, step)

return (start, end, step)
elif isinstance(index, (list, tuple)):
assert all(
isinstance(i, int) for i in jax.tree_util.tree_leaves(index)
), "All items in tuple must be int"
return index

@__resolve_index_step.register(list)
@__resolve_index_step.register(tuple)
def _(index, shape):
assert all(
isinstance(i, int) for i in jax.tree_util.tree_leaves(index)
), "All items in tuple must be int"
return index
else:
raise NotImplementedError(f"index type={type(index)} is not implemented")

index = [index] if not isinstance(index, tuple) else index
resolved_index = [[]] * len(index)

for i, (item, in_dim) in enumerate(zip(index, shape)):
resolved_index[i] = __resolve_index_step(item, in_dim)
resolved_index[i] = _resolve_single_index(item, in_dim)

return resolved_index

Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@

jax>=0.3.5
pytreeclass>=0.1.7
pytreeclass>=0.1.9

0 comments on commit f5dd7f2

Please sign in to comment.