mode="reflect" in padding is incorrect #14

Open
clemisch opened this issue Jul 11, 2023 · 0 comments
Labels
bug Something isn't working

Comments

@clemisch

I think `mode="reflect"` in `padding_kwargs` pads incorrectly:

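```python
import jax.numpy as jnp
import kernex

@kernex.kmap(
    kernel_size=(3,),
    padding="same",
    relative=False,
    padding_kwargs=dict(mode="reflect"),
)
def f(x):
    # Identity kernel: return each 3-element window unchanged,
    # so `y` shows exactly what kmap passes to `f`.
    return x

x = jnp.array([1, 2, 3, 4, 5])
y = f(x)                            # windows produced by kernex
z = jnp.pad(x, 1, mode="reflect")   # reference reflect padding

print("x: ", x)
print("y: ", y)
print("z: ", z)
```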

This prints:

```
x:  [1 2 3 4 5]
y:  [[3 1 2]     # <-- the `3` is incorrect, should be `2`
     [1 2 3]
     [2 3 4]
     [3 4 5]
     [4 5 4]]
z:  [2 1 2 3 4 5 4]   # <-- here, the first element is `2`
```
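For reference, here is a small sketch of what the common pad modes place at the edges of `x`. It uses NumPy's `np.pad`, which `jnp.pad` mirrors for these mode names. In `reflect` mode the array is mirrored about the edge element without repeating it, so a width-1 pad puts `x[1] = 2` on the left; `symmetric` and `edge` put `1` there, and `wrap` puts `5`.

```python
import numpy as np

x = np.array([1, 2, 3, 4, 5])

# Pad by one element on each side with several NumPy modes.
padded = {
    mode: np.pad(x, 1, mode=mode)
    for mode in ("reflect", "symmetric", "edge", "wrap")
}
for mode, arr in padded.items():
    print(f"{mode:>9}: {arr}")

# A wider (width-2) reflect pad, for comparison with the left value kernex produced.
wide = np.pad(x, 2, mode="reflect")
print(f"reflect, width 2: {wide}")
```

None of the width-1 modes puts a `3` on the left edge; `3` is the outermost left value only of the width-2 reflect pad (`[3 2 1 2 3 4 5 4 3]`). Whether that is connected to the bug is not established here.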

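The first window of the Kernex output is `[3 1 2]`, but reflect padding, as computed by `jnp.pad`, gives `[2 1 2]`. The right edge (`[4 5 4]`) is correct, so in this example only the left side is affected.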
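A minimal reference check, assuming that `padding="same"` with an odd `kernel_size` pads `kernel_size // 2` elements on each side: build the windows directly from `np.pad(..., mode="reflect")` and compare them with the output reported above. `reflect_same_windows` is a hypothetical helper for illustration, not part of kernex.

```python
import numpy as np

def reflect_same_windows(x, kernel_size):
    """Windows a 'same'-padded, reflect-mode 1-D kmap should see (odd kernel_size assumed)."""
    pad = kernel_size // 2
    padded = np.pad(x, pad, mode="reflect")
    # Window i is centred on x[i] of the original array.
    return np.stack([padded[i:i + kernel_size] for i in range(len(x))])

x = np.array([1, 2, 3, 4, 5])
expected = reflect_same_windows(x, 3)

# Output reported in this issue for the kernex.kmap call above.
reported = np.array([[3, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 4]])

# Indices where the reported windows disagree with the reference.
print(np.argwhere(reported != expected))
```

Only index `(0, 0)`, the left element of the first window, differs; every other window, including the right-edge `[4 5 4]`, matches the reference.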

ASEM000 added the bug label on Jul 11, 2023

2 participants