Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add padding _kwargs #13

Merged
merged 1 commit into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kernex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
"offset_kernel_scan",
)

__version__ = "0.2.0"
__version__ = "0.2.1"
10 changes: 8 additions & 2 deletions kernex/_src/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,13 @@ def kernel_map(
relative: bool = False,
map_kind: MapKind = "vmap",
map_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
) -> Callable:

map_kwargs = map_kwargs or {}
padding_kwargs = padding_kwargs or {}
padding_kwargs.pop("pad_width", None) # handled by border

map_tranform = transform_func_map[map_kind]
pad_width = _calculate_pad_width(border)
args = (shape, kernel_size, strides, border)
Expand All @@ -82,7 +86,7 @@ def kernel_map(
slices = tuple(func_map.values())

def single_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width)
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)

# convert the function to a callable that takes a view and an array
# and returns the result of the function applied to the view
Expand All @@ -98,7 +102,7 @@ def map_func(view):
return result.reshape(*output_shape, *result.shape[1:])

def multi_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width)
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)
# convert the functions to a callable that takes a view and an array
# and returns the result of the function applied to the view
# the result is a 1D array of the same length as the number of views
Expand Down Expand Up @@ -133,6 +137,7 @@ def offset_kernel_map(
relative: bool = False,
map_kind: MapKind = "vmap",
map_kwargs: dict[str, Any] = None,
offset_kwargs: dict[str, Any] = None,
):

func = kernel_map(
Expand All @@ -144,6 +149,7 @@ def offset_kernel_map(
relative=relative,
map_kind=map_kind,
map_kwargs=map_kwargs,
padding_kwargs=offset_kwargs,
)
set_indices = _get_set_indices(shape, strides, offset)

Expand Down
9 changes: 7 additions & 2 deletions kernex/_src/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,12 @@ def kernel_scan(
relative: bool = False,
scan_kind: ScanKind = "scan", # dummy to make signature consistent with kernel_map
scan_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
):

scan_kwargs = scan_kwargs or {}
padding_kwargs = padding_kwargs or {}
padding_kwargs.pop("pad_width", None)
scan_transform = transform_func_map[scan_kind]
pad_width = _calculate_pad_width(border)
args = (shape, kernel_size, strides, border)
Expand All @@ -82,7 +85,7 @@ def kernel_scan(
slices = tuple(func_map.values())

def single_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width)
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)
func0 = next(iter(func_map))
reduced_func = _transform_scan_func(func0, kernel_size, relative)(*a, **k)

Expand All @@ -95,7 +98,7 @@ def scan_body(padded_array: jax.Array, view: jax.Array):
return result.reshape(output_shape)

def multi_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width)
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)

reduced_funcs = tuple(
_transform_scan_func(func, kernel_size, relative)(*a, **k)
Expand Down Expand Up @@ -124,6 +127,7 @@ def offset_kernel_scan(
relative: bool = False,
scan_kind: ScanKind = "scan",
scan_kwargs: dict[str, Any] | None = None,
offset_kwargs: dict[str, Any] | None = None,
):

func = kernel_scan(
Expand All @@ -135,6 +139,7 @@ def offset_kernel_scan(
relative=relative,
scan_kind=scan_kind,
scan_kwargs=scan_kwargs,
padding_kwargs=offset_kwargs,
)
set_indices = _get_set_indices(shape, strides, offset)

Expand Down
24 changes: 20 additions & 4 deletions kernex/interface/kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
)

BorderType = Union[
int,
Tuple[int, ...],
Tuple[Tuple[int, int], ...],
Literal["valid", "same", "SAME", "VALID"],
int, # single int to pad all axes before and after the array
Tuple[int, ...], # tuple of ints to pad before and after each axis
Tuple[Tuple[int, int], ...], # tuple of tuples to pad before and after each axis
Literal["valid", "same", "SAME", "VALID"], # string to use a predefined padding
]

StridesType = Union[Tuple[int, ...], int]
Expand All @@ -56,6 +56,7 @@ def __init__(
container: dict[Callable, slice | int] | None = None,
transform_kind: MapKind | ScanKind | None = None,
transform_kwargs: dict[str, Any] | None = None,
border_kwargs: dict[str, Any] | None = None,
):
self.kernel_size = kernel_size
self.strides = strides
Expand All @@ -71,6 +72,7 @@ def __init__(
)
self.transform_kind = transform_kind
self.transform_kwargs = transform_kwargs
self.border_kwargs = border_kwargs

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -124,6 +126,7 @@ def _wrap_mesh(self, array: jax.Array, *a, **k):
self.relative,
self.transform_kind,
self.transform_kwargs,
self.border_kwargs,
)(array, *a, **k)

def _wrap_decorator(self, func):
Expand Down Expand Up @@ -154,6 +157,7 @@ def call(array, *args, **kwargs):
self.relative,
self.transform_kind,
self.transform_kwargs,
self.border_kwargs,
)(array, *args, **kwargs)

return call
Expand Down Expand Up @@ -252,6 +256,7 @@ def __init__(
named_axis=named_axis,
transform_kind=scan_kind,
transform_kwargs=scan_kwargs,
border_kwargs=None,
)


Expand Down Expand Up @@ -337,6 +342,7 @@ def __init__(
named_axis=named_axis,
transform_kind=map_kind,
transform_kwargs=map_kwargs,
border_kwargs=None,
)


Expand All @@ -350,6 +356,7 @@ def __init__(
named_axis: dict[int, str] = None,
scan_kind: ScanKind = "scan",
scan_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
):
"""Apply a function to a sliding window of the input array sequentially.

Expand All @@ -368,6 +375,9 @@ def __init__(
scan_kwargs: optional kwargs to be passed to the scan function.
for example, `scan_kwargs={'reverse': True}` will reverse the
application of the function.
padding_kwargs: optional kwargs to be passed to the padding function.
for example, `padding_kwargs=dict(constant_values=10)` will pad
the input array with 10 for same padding.

Returns:
A function that takes an array as input and returns the result of
Expand Down Expand Up @@ -400,6 +410,7 @@ def __init__(
named_axis=named_axis,
transform_kind=scan_kind,
transform_kwargs=scan_kwargs,
border_kwargs=padding_kwargs,
)


Expand All @@ -413,6 +424,7 @@ def __init__(
named_axis: dict[int, str] = None,
map_kind: MapKind = "vmap",
map_kwargs: dict = None,
padding_kwargs: dict = None,
):
"""Apply a function to a sliding window of the input array in parallel.

Expand All @@ -432,6 +444,9 @@ def __init__(
map_kwargs: optional kwargs to be passed to the map function.
for example, `map_kwargs={'axis_name': 'i'}` will apply the
function along the axis named `i` for `pmap`.
padding_kwargs: optional kwargs to be passed to the padding function.
for example, `padding_kwargs=dict(constant_values=10)` will pad
the input array with 10 for same padding.

Returns:
A function that takes an array as input and applies the kernel
Expand Down Expand Up @@ -464,4 +479,5 @@ def __init__(
named_axis=named_axis,
transform_kind=map_kind,
transform_kwargs=map_kwargs,
border_kwargs=padding_kwargs,
)
18 changes: 18 additions & 0 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,21 @@ def kex_conv2d(x, w):
pred_grad = jax.grad(lambda w: jnp.sum(kex_conv2d(x, w)))(w)

np.testing.assert_allclose(true_grad[0], pred_grad, atol=1e-3)


def test_padding_kwargs():
@kex.kmap(
kernel_size=(3,),
padding=("same"),
relative=False,
padding_kwargs=dict(constant_values=10),
)
def f(x):
return x

x = jnp.array([1, 2, 3, 4, 5])

np.testing.assert_allclose(
f(x),
np.array([[10, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 10]]),
)