Support Pytrees #11
Hello, meanwhile, can you try this? The key point here is to stack the arrays along some axis. I also recommend using …

```python
import jax
import jax.numpy as jnp
import kernex


# Window of size 2 along the stacking axis and 3x3 spatially, so each call
# sees both stacked arrays at the same spatial position.
@kernex.kmap(kernel_size=(2, 3, 3), padding=("valid", "valid", "valid"))
def kernel(tree):
    # `tree` has shape (2, 3, 3); unpack it along the stacking axis.
    x, y = tree
    jax.debug.print("x={x} \n\n y={y}\n", x=x, y=y)
    return jnp.sum(x * jnp.square(y))


data = jnp.arange(25).reshape(5, 5)
# Stack the two inputs on a new leading axis: shape (2, 5, 5).
out = kernel(jnp.stack([data, data], axis=0))
# x=[[ 0 1 2]
# [ 5 6 7]
# [10 11 12]]
# y=[[ 0 1 2]
# [ 5 6 7]
# [10 11 12]]
# x=[[ 1 2 3]
# [ 6 7 8]
# [11 12 13]]
# y=[[ 1 2 3]
# [ 6 7 8]
# [11 12 13]]
# x=[[ 2 3 4]
# [ 7 8 9]
# [12 13 14]]
# y=[[ 2 3 4]
# [ 7 8 9]
# [12 13 14]]
# x=[[ 5 6 7]
# [10 11 12]
# [15 16 17]]
# y=[[ 5 6 7]
# [10 11 12]
# [15 16 17]]
# x=[[ 6 7 8]
# [11 12 13]
# [16 17 18]]
# y=[[ 6 7 8]
# [11 12 13]
# [16 17 18]]
# x=[[ 7 8 9]
# [12 13 14]
# [17 18 19]]
# y=[[ 7 8 9]
# [12 13 14]
# [17 18 19]]
# x=[[10 11 12]
# [15 16 17]
# [20 21 22]]
# y=[[10 11 12]
# [15 16 17]
# [20 21 22]]
# x=[[11 12 13]
# [16 17 18]
# [21 22 23]]
# y=[[11 12 13]
# [16 17 18]
# [21 22 23]]
# x=[[12 13 14]
# [17 18 19]
# [22 23 24]]
# y=[[12 13 14]
# [17 18 19]
# [22 23 24]]
```
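To make the stacking trick concrete, here is a plain-JAX sketch (no kernex) of what the kernel above computes. It assumes that `kmap` with `"valid"` padding visits every 2x3x3 window of the (2, 5, 5) stacked input, which gives an output of shape (1, 3, 3). `stacked_window_reference` is a hypothetical helper for illustration only; the Python loops keep it readable and say nothing about how kernex actually implements the mapping.

```python
import jax.numpy as jnp


def stacked_window_reference(stacked, size=3):
    """Hypothetical loop-based reference for the stacked kmap example.

    `stacked` has shape (2, H, W). Every size x size spatial window is taken
    across both stacked arrays, unpacked into `x, y`, and reduced with the
    same body as `kernel`. Returns shape (1, H - size + 1, W - size + 1).
    """
    _, h, w = stacked.shape
    rows = []
    for i in range(h - size + 1):
        row = []
        for j in range(w - size + 1):
            # Shape (2, size, size): same unpacking as `x, y = tree`.
            x, y = stacked[:, i:i + size, j:j + size]
            row.append(jnp.sum(x * jnp.square(y)))
        rows.append(jnp.stack(row))
    # Leading axis of length 1: a size-2 "valid" window over a length-2 axis.
    return jnp.stack(rows)[None]


data = jnp.arange(25).reshape(5, 5)
ref = stacked_window_reference(jnp.stack([data, data], axis=0))
print(ref.shape)  # (1, 3, 3)
```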
Thanks, that works for me!
As a follow-up, I think it is simpler to define which … What do you think?
Thanks for the follow-up and for including me in this. To clarify, do you mean not supporting trees, but instead multiple arguments? So something like

```python
# Proposed API: map over both positional arguments.
@kernex.kmap(kernel_size=(3, 3), argnums=(0, 1))
def kernel(x, y):
    return jnp.sum(x * jnp.square(y))
```

or, for non-mapped local weights,

```python
# Proposed API: map over `x` only; `y_local` is not windowed.
@kernex.kmap(kernel_size=(3, 3), argnums=(0,))
def kernel(x, y_local):
    return jnp.sum(x * jnp.square(y_local))
```

where …

TLDR: Anything is fine for me. I think supporting trees would be slightly more powerful, but any reasonable task should be translatable to multiple …
Does kernex support Pytrees? I did not find an example. It would be very useful to support moving-window filters with "global" weights or simply multiple inputs, such as a cross-channel bilateral filter in my case.
Repro:
raises