Support Pytrees #11

Open
clemisch opened this issue Jun 19, 2023 · 5 comments
Labels
enhancement New feature or request

Comments

@clemisch

clemisch commented Jun 19, 2023

Does kernex support pytrees? I could not find an example. Pytree support would be very useful for moving-window filters with "global" weights or simply multiple inputs, such as the cross-channel bilateral filter I need.

Repro:

import jax.numpy as jnp
import kernex

@kernex.kmap(kernel_size=(3, 3))
def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(20 * 30).reshape((20, 30))
out = kernel((data, data))

raises

Traceback (most recent call last):
  File "/home/clemisch/kernex_tree.py", line 52, in <module>
    out = kernel((data, data))
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/clemisch/venvs/11/lib64/python3.11/site-packages/kernex/interface/kernel_interface.py", line 131, in call
    self.shape = array.shape
                 ^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'shape'
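
For illustration, here is a minimal pure-JAX sketch of the requested behavior: window every leaf of the input pytree with the same kernel size and pass the function a pytree of windows. `tree_window_map` and `extract_windows` are hypothetical helpers, not part of kernex; the sketch assumes 2-D leaves of the same shape, stride 1, and "valid" padding (no border handling).

```python
import jax
import jax.numpy as jnp


def extract_windows(a, kernel_size):
    """All stride-1 "valid" windows of a 2-D array, shape (oh, ow, kh, kw)."""
    kh, kw = kernel_size
    oh, ow = a.shape[0] - kh + 1, a.shape[1] - kw + 1
    rows = jnp.arange(oh)[:, None] + jnp.arange(kh)  # (oh, kh) row indices
    cols = jnp.arange(ow)[:, None] + jnp.arange(kw)  # (ow, kw) column indices
    # Broadcast to (oh, ow, kh, kw): element [i, j, p, q] is a[i + p, j + q]
    return a[rows[:, None, :, None], cols[None, :, None, :]]


def tree_window_map(fn, kernel_size):
    """Hypothetical: apply `fn` to a pytree of windows, one call per output pixel."""
    def wrapped(tree):
        windows = jax.tree_util.tree_map(lambda a: extract_windows(a, kernel_size), tree)
        # Outer vmap walks output rows, inner vmap walks output columns;
        # `fn` receives a pytree whose leaves are (kh, kw) windows.
        return jax.vmap(jax.vmap(fn))(windows)
    return wrapped


def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))


data = jnp.arange(20 * 30).reshape((20, 30))
out = tree_window_map(kernel, (3, 3))((data, data))
print(out.shape)  # (18, 28)
```

This materializes every window, so memory grows with the kernel size; it only shows the intended semantics, not an efficient implementation.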
@ASEM000
Owner

ASEM000 commented Jun 19, 2023

Hello,
Thanks for your question.
This is a reasonable request; I will try to look into it when I have time.

@ASEM000 ASEM000 added the enhancement New feature or request label Jun 20, 2023
@ASEM000 ASEM000 self-assigned this Jun 20, 2023
@ASEM000
Owner

ASEM000 commented Jun 21, 2023

Hello, in the meantime, can you try this?

The key point is to stack the arrays along some axis i, set the kernel size for axis i equal to that axis's length, and use valid padding for that axis, so that each window covers all stacked inputs and never slides along axis i. In this example, i is the first axis.

I also recommend using jax.debug.print to check that the array views are what you expect.

import jax.numpy as jnp
import kernex
import jax

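# Axis 0 holds the stacked inputs: its kernel size (2) equals its length,
# and "valid" padding keeps the window from sliding along it.
@kernex.kmap(kernel_size=(2, 3, 3), padding=("valid", "valid", "valid"))
def kernel(window):
    # window has shape (2, 3, 3); unpack the two stacked 3x3 views
    x, y = window
    jax.debug.print("x={x} \n\n y={y}\n", x=x, y=y)
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(25).reshape(5, 5)
# Stack both inputs along axis 0 -> shape (2, 5, 5)
out = kernel(jnp.stack([data, data], axis=0))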

# x=[[ 0  1  2]
#  [ 5  6  7]
#  [10 11 12]] 

#  y=[[ 0  1  2]
#  [ 5  6  7]
#  [10 11 12]]

# x=[[ 1  2  3]
#  [ 6  7  8]
#  [11 12 13]] 

#  y=[[ 1  2  3]
#  [ 6  7  8]
#  [11 12 13]]

# x=[[ 2  3  4]
#  [ 7  8  9]
#  [12 13 14]] 

#  y=[[ 2  3  4]
#  [ 7  8  9]
#  [12 13 14]]

# x=[[ 5  6  7]
#  [10 11 12]
#  [15 16 17]] 

#  y=[[ 5  6  7]
#  [10 11 12]
#  [15 16 17]]

# x=[[ 6  7  8]
#  [11 12 13]
#  [16 17 18]] 

#  y=[[ 6  7  8]
#  [11 12 13]
#  [16 17 18]]

# x=[[ 7  8  9]
#  [12 13 14]
#  [17 18 19]] 

#  y=[[ 7  8  9]
#  [12 13 14]
#  [17 18 19]]

# x=[[10 11 12]
#  [15 16 17]
#  [20 21 22]] 

#  y=[[10 11 12]
#  [15 16 17]
#  [20 21 22]]

# x=[[11 12 13]
#  [16 17 18]
#  [21 22 23]] 

#  y=[[11 12 13]
#  [16 17 18]
#  [21 22 23]]

# x=[[12 13 14]
#  [17 18 19]
#  [22 23 24]] 

#  y=[[12 13 14]
#  [17 18 19]
#  [22 23 24]]
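
Because this particular kernel is an elementwise product followed by a sum over each window, its result can be cross-checked in plain JAX with `jax.lax.reduce_window`. This is a sanity-check sketch only: it assumes stride 1 and "valid" padding, and it does not carry over to kernels that are not a plain window sum.

```python
import jax
import jax.numpy as jnp

data = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
x, y = data, data

# Elementwise x * y**2, then sum every 3x3 window with stride 1 and no padding.
expected = jax.lax.reduce_window(
    x * jnp.square(y),  # operand
    0.0,                # init value for the sum
    jax.lax.add,        # reduction applied within each window
    (3, 3),             # window dimensions
    (1, 1),             # window strides
    "VALID",            # padding
)
print(expected.shape)  # (3, 3)
# Top-left window: 0 + 1 + 8 + 125 + 216 + 343 + 1000 + 1331 + 1728 = 4752
```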

@clemisch
Author

Thanks, that works for me!

@ASEM000
Owner

ASEM000 commented Aug 29, 2023

As a follow-up, I think it would be simpler to specify which arguments (argnums) the kernel windows are generated for. For the previous example, the API might look something like kmap(..., argnums=(0, 1))(lambda x, y: ...).

What do you think?

@clemisch
Author

Thanks for the follow-up and for including me.

To clarify, do you mean supporting multiple arguments instead of pytrees? So something like

@kernex.kmap(kernel_size=(3, 3), argnums=(0, 1))
def kernel(x, y):
    return jnp.sum(x * jnp.square(y))

or, for non-mapped local weights:

@kernex.kmap(kernel_size=(3, 3), argnums=(0,))
def kernel(x, y_local):
    return jnp.sum(x * jnp.square(y_local))

where y_local would not be windowed like x but passed in as a constant (3, 3) array.

TL;DR: Either is fine for me. I think supporting pytrees would be slightly more powerful, but any reasonable task should be expressible with multiple arguments instead of a tree.
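
To make the proposed argnums semantics concrete, here is a minimal pure-JAX sketch covering both examples: positional arguments listed in `argnums` are windowed, and all others are passed unchanged to every window. `window_map_argnums` and `extract_windows` are hypothetical names, not kernex API, and a real kmap argnums option may behave differently; the sketch assumes 2-D inputs, stride 1, and "valid" padding.

```python
import jax
import jax.numpy as jnp


def extract_windows(a, kernel_size):
    """All stride-1 "valid" windows of a 2-D array, shape (oh, ow, kh, kw)."""
    kh, kw = kernel_size
    oh, ow = a.shape[0] - kh + 1, a.shape[1] - kw + 1
    rows = jnp.arange(oh)[:, None] + jnp.arange(kh)  # (oh, kh)
    cols = jnp.arange(ow)[:, None] + jnp.arange(kw)  # (ow, kw)
    return a[rows[:, None, :, None], cols[None, :, None, :]]


def window_map_argnums(fn, kernel_size, argnums):
    """Hypothetical: window only the positional args in `argnums`."""
    def wrapped(*args):
        args = [extract_windows(a, kernel_size) if i in argnums else a
                for i, a in enumerate(args)]
        # Mapped args get axis 0 vmapped; unmapped args are broadcast (None).
        in_axes = tuple(0 if i in argnums else None for i in range(len(args)))
        # Outer vmap over output rows, inner vmap over output columns.
        return jax.vmap(jax.vmap(fn, in_axes=in_axes), in_axes=in_axes)(*args)
    return wrapped


def kernel(x, y):
    return jnp.sum(x * jnp.square(y))


def kernel_local(x, y_local):
    return jnp.sum(x * jnp.square(y_local))


data = jnp.arange(25.0).reshape(5, 5)

# Both arguments windowed, like argnums=(0, 1)
both_out = window_map_argnums(kernel, (3, 3), (0, 1))(data, data)

# Only x windowed; y_local is the same constant (3, 3) array for every window
y_local = jnp.ones((3, 3))
local_out = window_map_argnums(kernel_local, (3, 3), (0,))(data, y_local)
print(both_out.shape, local_out.shape)  # (3, 3) (3, 3)
```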


2 participants