Commit 8b12e2a: comments and minor edits

ASEM000 committed Sep 11, 2022
1 parent 21cde92 commit 8b12e2a
Showing 7 changed files with 112 additions and 54 deletions.
2 changes: 1 addition & 1 deletion kernex/__init__.py
@@ -17,4 +17,4 @@
"offsetKernelScan",
)

__version__ = "0.0.8"
__version__ = "0.1.0"
16 changes: 15 additions & 1 deletion kernex/_src/base.py
@@ -35,6 +35,10 @@ def pad_width(self):
Returns:
padding value passed to `pad_width` in `jnp.pad`
"""
# cached because it is called multiple times and is expensive to compute
# a negative border contributes no padding; a positive border adds its
# value to the padding of that dimension
return tuple([0, max(0, pi[0]) + max(0, pi[1])] for pi in self.border)
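For intuition, a standalone sketch of the border-to-padding rule with made-up border values:

import jax.numpy as jnp

border = ((1, 2), (-1, 3))  # (left, right) border per dimension
pad_width = tuple([0, max(0, l) + max(0, r)] for (l, r) in border)
print(pad_width)  # ([0, 3], [0, 3]) -- the negative border adds nothing

padded = jnp.pad(jnp.ones((4, 4)), pad_width)
print(padded.shape)  # (7, 7)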

@cached_property
@@ -45,6 +49,10 @@ def output_shape(self) -> tuple[int, ...]:
Returns:
tuple[int, ...]: resulting shape of the kernel operation
"""
# cached because it is called multiple times and is expensive to compute
# the output shape is the shape of the array after the kernel operation
# is applied to the input array
return tuple(
(xi + (li + ri) - ki) // si + 1
for xi, ki, si, (li, ri) in ZIP(
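A quick numeric check of the formula above, with hypothetical sizes:

def out_dim(xi, ki, si, li, ri):
    # input size xi, kernel size ki, stride si, left/right border li, ri
    return (xi + (li + ri) - ki) // si + 1

print(out_dim(5, 3, 1, 0, 0))  # 3 -- "valid"-style window count
print(out_dim(5, 3, 1, 1, 1))  # 5 -- "same"-style window count
print(out_dim(6, 2, 2, 0, 0))  # 3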
@@ -55,13 +63,16 @@
@cached_property
def views(self) -> tuple[jnp.ndarray, ...]:
"""Generate absolute sampling matrix"""
# cached because it is called multiple times and is expensive to compute
# each view holds the indices of the input entries used to compute one
# output value
dim_range = tuple(
general_arange(di, ki, si, x0, xf)
for (di, ki, si, (x0, xf)) in zip(
self.shape, self.kernel_size, self.strides, self.border
)
)

matrix = general_product(*dim_range)
return tuple(map(lambda xi, wi: xi.reshape(-1, wi), matrix, self.kernel_size))
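A minimal sketch of the per-dimension ranges that feed general_product, using plain JAX rather than the library helpers:

import jax.numpy as jnp

# window indices for size 5, kernel 3, stride 1, no border
starts = jnp.arange(0, 5 - 3 + 1)
windows = starts[:, None] + jnp.arange(3)  # row i = indices of window i
print(windows)
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]]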

@@ -86,6 +97,8 @@ def funcs(self) -> tuple[Callable[[Any], jnp.ndarray]]:

@property
def slices(self):
# the slices associated with each function in func_index_map,
# used to slice the array
return tuple(self.func_index_map.values())

def index_from_view(self, view: tuple[jnp.ndarray, ...]) -> tuple[int, ...]:
@@ -97,6 +110,7 @@ def index_from_view(self, view: tuple[jnp.ndarray, ...]) -> tuple[int, ...]:
Returns:
tuple[int, ...]: index as a tuple of int for each dimension
"""
# pick the center entry of each view axis (left of center for even kernel sizes)
return tuple(
view[i][wi // 2] if wi % 2 == 1 else view[i][(wi - 1) // 2]
for i, wi in enumerate(self.kernel_size)
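For example, a standalone check of the expression above with a hypothetical view for kernel_size (3, 2):

import jax.numpy as jnp

view = (jnp.array([2, 3, 4]), jnp.array([5, 6]))
index = tuple(
    v[wi // 2] if wi % 2 == 1 else v[(wi - 1) // 2]
    for v, wi in zip(view, (3, 2))
)
print(index)  # (3, 5) -- the exact center, then left of center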
45 changes: 39 additions & 6 deletions kernex/_src/map.py
@@ -16,34 +16,55 @@ class baseKernelMap(kernelOperation):
def __post_init__(self):

self.__call__ = (
-        self.__single_call__ if len(self.funcs) == 1 else self.__multi_call__
+        # if there is only one function, use the single call method; it is
+        # faster than the multi call method, which dispatches with lax.switch
+        self.__single_call__
+        if len(self.funcs) == 1
+        else self.__multi_call__
)
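A toy sketch of the lax.switch dispatch used by the multi-function path (illustrative functions and views, not the library's internals):

import jax
import jax.numpy as jnp
from jax import lax

funcs = (jnp.min, jnp.max)
views = jnp.array([[0, 1], [2, 3], [4, 5]])
branch = jnp.array([0, 1, 0])  # which function handles each view

out = jax.vmap(lambda b, v: lax.switch(b, funcs, v))(branch, views)
print(out)  # [0 3 4]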

def reduce_map_func(self, func, *args, **kwargs) -> Callable:
if self.relative:
# relative=True: roll the view so that the kernel center is at index 0
# before applying the function
return lambda view, array: func(
roll_view(array[ix_(*view)]), *args, **kwargs
)

else:
return lambda view, array: func(array[ix_(*view)], *args, **kwargs)

-    def __single_call__(self, array, *args, **kwargs):
+    def __single_call__(self, array: jnp.ndarray, *args, **kwargs):
padded_array = jnp.pad(array, self.pad_width)

# convert the function to a callable that takes a view and an array
# and returns the result of the function applied to the view
reduced_func = self.reduce_map_func(self.funcs[0], *args, **kwargs)

# apply the function to each view using vmap
# the leading axis of the result has one entry per view
result = vmap(lambda view: reduced_func(view, padded_array))(self.views)
-        func_shape = result.shape[1:]
-        return result.reshape(*self.output_shape, *func_shape)
+        # reshape the result to the output shape
+        # e.g. for input shape (3, 3), kernel shape (2, 2), stride 1 and
+        # no padding, the output shape is (2, 2)
+        return result.reshape(*self.output_shape, *result.shape[1:])
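Putting the pieces together, a hand-rolled version of the single-function path for a 1-D array with kernel size 3, stride 1 and no padding (a sketch, not the library code):

import jax
import jax.numpy as jnp

array = jnp.arange(5)
starts = jnp.arange(0, 5 - 3 + 1)
views = starts[:, None] + jnp.arange(3)  # one row of indices per view

result = jax.vmap(lambda v: jnp.sum(array[v]))(views)
print(result)  # [3 6 9] -- sum over each length-3 window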

def __multi_call__(self, array, *args, **kwargs):

padded_array = jnp.pad(array, self.pad_width)

# convert each function to a callable that takes a view and an array
# and returns the result of the function applied to the view
reduced_funcs = tuple(
self.reduce_map_func(func, *args, **kwargs) for func in self.funcs[::-1]
)

# apply the functions to each view using vmap
# the leading axis of the result has one entry per view
# for each view, lax.switch selects which function to apply based on
# func_index_from_view, i.e. the region the view's index falls in
result = vmap(
lambda view: lax.switch(
self.func_index_from_view(view), reduced_funcs, view, padded_array
@@ -56,6 +77,8 @@ def __multi_call__(self, array, *args, **kwargs):

@pytc.treeclass
class kernelMap(baseKernelMap):
"""A class for applying a function to a kernel map of an array"""

def __init__(self, func_dict, shape, kernel_size, strides, padding, relative):
super().__init__(func_dict, shape, kernel_size, strides, padding, relative)

@@ -65,7 +88,11 @@ def __call__(self, array, *args, **kwargs):

@pytc.treeclass
class offsetKernelMap(kernelMap):
"""A class for applying a function to a kernel map of an array"""

def __init__(self, func_dict, shape, kernel_size, strides, offset, relative):
# the offset is converted to border values, which are used to pad the
# array and to calculate the views

self.offset = offset

@@ -80,6 +107,9 @@ def __init__(self, func_dict, shape, kernel_size, strides, offset, relative):

@cached_property
def set_indices(self):
# the indices of the array that the kernel operation writes to;
# used to set the array values after the operation is applied
return tuple(
jnp.arange(x0, di - xf, si)
for di, ki, si, (x0, xf) in ZIP(
@@ -88,6 +118,9 @@ def set_indices(self):
)

def __call__(self, array, *args, **kwargs):
# apply the kernel operation; the leading axis of the result has one
# entry per view, and the result is reshaped to the output shape
result = self.__call__(array, *args, **kwargs)
assert (
result.shape <= array.shape
6 changes: 5 additions & 1 deletion kernex/_src/scan.py
@@ -13,12 +13,17 @@
@pytc.treeclass
class baseKernelScan(kernelOperation):
def __post_init__(self):
# if there is only one function, use the single call method; it is
# faster than the multi call method, which dispatches with lax.switch
self.__call__ = (
self.__single_call__ if len(self.funcs) == 1 else self.__multi_call__
)

def reduce_scan_func(self, func, *args, **kwargs) -> Callable:
if self.relative:
# relative=True: roll the view so the kernel center is at index 0,
# then write the function's result back at the view's center index
return lambda view, array: array.at[self.index_from_view(view)].set(
func(roll_view(array[ix_(*view)]), *args, **kwargs)
)
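The scan path writes each window's result back with JAX's functional update; a minimal illustration of array.at[...].set:

import jax.numpy as jnp

array = jnp.arange(5)
updated = array.at[2].set(array[1:4].sum())  # write a window reduction back
print(updated)  # [0 1 6 3 4]

Unlike the map path, the scan path threads the updated array through successive views, so later windows see earlier writes.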
@@ -64,7 +69,6 @@ def scan_body(padded_array, view):
@pytc.treeclass
class kernelScan(baseKernelScan):
def __init__(self, func_dict, shape, kernel_size, strides, padding, relative):

super().__init__(func_dict, shape, kernel_size, strides, padding, relative)

def __call__(self, array, *args, **kwargs):
8 changes: 6 additions & 2 deletions kernex/_src/utils.py
@@ -12,6 +12,8 @@


class cached_property:
"""this function is a decorator that caches the result of the function"""

def __init__(self, func):
self.name = func.__name__
self.func = func
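The rest of the descriptor is collapsed in this diff; a typical implementation of the pattern (a sketch of what __get__ usually looks like, not necessarily this library's exact code):

def __get__(self, instance, owner):
    # compute once, then store the result on the instance so the
    # instance attribute shadows this non-data descriptor afterwards
    value = self.func(instance)
    instance.__dict__[self.name] = value
    return value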
@@ -35,7 +37,8 @@ def ZIP(*args):

def _offset_to_padding(input_argument, kernel_size):
"""convert offset argument to negative border values"""

# offsets shrink the usable region of each dimension, so they are
# mapped to non-positive border values (see the examples below)
padding = [[]] * len(kernel_size)

# offset = 1 ==> padding= 0 for kernel_size =3
@@ -70,6 +73,7 @@ def roll_view(array: jnp.ndarray) -> jnp.ndarray:
[ 3 4 5 1 2]
[ 8 9 10 6 7]]
"""
# roll each axis so that the center element moves to index 0
shape = jnp.array(array.shape)
axes = tuple(range(len(shape))) # list all axes
shift = tuple(
@@ -114,6 +118,7 @@ def general_arange(di: int, ki: int, si: int, x0: int, xf: int) -> jnp.ndarray:
[1 2 3]
[2 3 4]]
"""
# compute the window indices along a single dimension
start, end = -x0 + ((ki - 1) // 2), di + xf - (ki // 2)
size = end - start
lhs = jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(size, ki), dimension=0) + (start) # fmt: skip
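broadcasted_iota fills the given dimension with its running index; a small standalone check:

import jax
import jax.numpy as jnp

lhs = jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(3, 3), dimension=0)
print(lhs)
# [[0 0 0]
#  [1 1 1]
#  [2 2 2]]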
@@ -170,7 +175,6 @@ def _index_from_view(
Returns:
tuple[int, ...]: index as a tuple of int for each dimension
"""

return tuple(
view[i][wi // 2] if wi % 2 == 1 else view[i][(wi - 1) // 2]
for i, wi in enumerate(kernel_size)
5 changes: 5 additions & 0 deletions kernex/interface/named_axis.py
@@ -15,6 +15,11 @@
class sortedDict(dict):
"""a class that sort a key before setting or getting an item"""

# this dict stores the kernel values keyed by axis names
# e.g. for a 3x3 kernel with axis names ['x', 'y'], the key is the
# sorted tuple ('x', 'y') and the value is the kernel values
def __getitem__(self, key: tuple[str, ...]):
key = (key,) if isinstance(key, str) else tuple(sorted(key))
return super().__getitem__(key)
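A quick usage sketch, assuming __setitem__ (collapsed below) sorts the key the same way the docstring describes:

d = sortedDict()
d[("y", "x")] = "kernel values"
print(d[("x", "y")])  # "kernel values" -- key order does not matter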
84 changes: 41 additions & 43 deletions kernex/interface/resolve_utils.py
@@ -99,66 +99,64 @@ def _resolve_dict_argument(
return tuple(temp)


-@dispatch(argnum=0)
-def _resolve_offset_argument(input_argument, kernel_size):
-    raise NotImplementedError(
-        "input_argument type={} is not implemented".format(type(input_argument))
-    )
-
-
-@_resolve_offset_argument.register(int)
-def _(input_argument, kernel_size):
-    return [(input_argument, input_argument)] * len(kernel_size)
-
-
-@_resolve_offset_argument.register(list)
-@_resolve_offset_argument.register(tuple)
-def _(input_argument, kernel_size):
-    offset = [[]] * len(kernel_size)
-
-    for i, item in enumerate(input_argument):
-        offset[i] = (item, item) if isinstance(item, int) else item
-
-    return offset
-
-
-@dispatch(argnum=0)
-def __resolve_index_step(index, shape):
-    raise NotImplementedError(f"index type={type(index)} is not implemented")
-
-
-@__resolve_index_step.register(int)
-def _(index, shape):
-    index += shape if index < 0 else 0
-    return index
-
-
-@__resolve_index_step.register(slice)
-def _(index, shape):
-    start, end, step = index.start, index.stop, index.step
-
-    start = start or 0
-    start += shape if start < 0 else 0
-
-    end = end or shape
-    end += shape if end < 0 else 0
-
-    step = step or 1
-
-    return (start, end, step)
-
-
-@__resolve_index_step.register(list)
-@__resolve_index_step.register(tuple)
-def _(index, shape):
-    assert all(
-        isinstance(i, int) for i in jax.tree_util.tree_leaves(index)
-    ), "All items in tuple must be int"
-    return index
-
-
-def _resolve_index(index, shape):
-    """Resolve index to a tuple of int"""
+def _resolve_offset_argument(input_argument, kernel_size):
+    @dispatch(argnum=0)
+    def __resolve_offset_argument(input_argument, kernel_size):
+        raise NotImplementedError(
+            "input_argument type={} is not implemented".format(type(input_argument))
+        )
+
+    @__resolve_offset_argument.register(int)
+    def _(input_argument, kernel_size):
+        return [(input_argument, input_argument)] * len(kernel_size)
+
+    @__resolve_offset_argument.register(list)
+    @__resolve_offset_argument.register(tuple)
+    def _(input_argument, kernel_size):
+        offset = [[]] * len(kernel_size)
+
+        for i, item in enumerate(input_argument):
+            offset[i] = (item, item) if isinstance(item, int) else item
+
+        return offset
+
+    return __resolve_offset_argument(input_argument, kernel_size)
+
+
+def _resolve_index(index, shape):
+    """Resolve index to a tuple of int"""
+
+    @dispatch(argnum=0)
+    def __resolve_index_step(index, shape):
+        raise NotImplementedError(f"index type={type(index)} is not implemented")
+
+    @__resolve_index_step.register(int)
+    def _(index, shape):
+        index += shape if index < 0 else 0
+        return index
+
+    @__resolve_index_step.register(slice)
+    def _(index, shape):
+        start, end, step = index.start, index.stop, index.step
+
+        start = start or 0
+        start += shape if start < 0 else 0
+
+        end = end or shape
+        end += shape if end < 0 else 0
+
+        step = step or 1
+
+        return (start, end, step)
+
+    @__resolve_index_step.register(list)
+    @__resolve_index_step.register(tuple)
+    def _(index, shape):
+        assert all(
+            isinstance(i, int) for i in jax.tree_util.tree_leaves(index)
+        ), "All items in tuple must be int"
+        return index
+
    index = [index] if not isinstance(index, tuple) else index
    resolved_index = [[]] * len(index)
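For intuition, the slice branch normalizes Python-style negative indices against the axis length; for example, slice(-3, None) on an axis of length 10 resolves as:

start, end, step = -3, None, None
shape = 10
start = start or 0
start += shape if start < 0 else 0  # -3 -> 7
end = end or shape                  # None -> 10
step = step or 1
print((start, end, step))           # (7, 10, 1)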

