Implement array-valued signatures #56

Open
adeak opened this issue Apr 16, 2021 · 12 comments · May be fixed by #58

Comments

@adeak

adeak commented Apr 16, 2021

As of #54, the simplest scalar calls to jitted special functions should work.

However there's no support yet for array-valued inputs:

import numpy as np
from numba import njit
from scipy import special

x = np.linspace(-10, 10, 1000)

@njit
def jitted_j0(x):
    res = special.j0(x[0])  # works after PR #54
    # res = special.j0(x)  # breaks
    return res

print(jitted_j0(x))

This is not obviously a shortcoming, since looping in jitted functions should be alright. So this is just a mild suggestion to consider adding support for array-valued signatures. (This should probably be preceded by some benchmarks to see whether it would help performance at all.)
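The loop-based alternative mentioned above, and the kind of benchmark the comment suggests, can be sketched in plain Python/NumPy. This is only a sketch under stated assumptions: `math.cos` stands in for a scalar special function, `apply_elementwise` is a hypothetical helper, and the timings measure interpreted code, not Numba-compiled code. Inside an `@njit` function the loop body would instead call the scalar overload directly, e.g. `special.j0(x[i])`, which works after #54.

```python
import math
import timeit

import numpy as np


def apply_elementwise(scalar_fn, x):
    """Apply a scalar function to each element of a 1-D array in a loop.

    Hypothetical helper: under @njit the body would instead call the
    scalar overload directly, e.g. special.j0(x[i]).
    """
    out = np.empty(x.shape[0], dtype=np.float64)
    for i in range(x.shape[0]):
        out[i] = scalar_fn(x[i])
    return out


x = np.linspace(-10, 10, 1000)

looped = apply_elementwise(math.cos, x)  # explicit loop over scalar calls
vectorized = np.cos(x)                   # single ufunc call on the whole array

# Minimal benchmark harness: best of 5 repeats of 100 calls each.
t_loop = min(timeit.repeat(lambda: apply_elementwise(math.cos, x), number=100, repeat=5))
t_ufunc = min(timeit.repeat(lambda: np.cos(x), number=100, repeat=5))
print(f"loop: {t_loop:.4f} s, ufunc: {t_ufunc:.4f} s")
```

For the question in this issue, the same comparison would have to be run on the jitted versions (a jitted loop versus a jitted array call), since an interpreted loop says little about compiled performance.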

@brandonwillard
Contributor

This is not obviously a shortcoming, since looping in jitted functions should be alright.

It's definitely a shortcoming, because the corresponding scipy.special functions that are being overloaded are ufuncs and do not have this limitation.

I would say that it doesn't render the library useless, though.
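To make the ufunc point concrete: a ufunc accepts scalars and arrays alike and broadcasts multiple arguments against each other, which is the behaviour an array-valued overload would need to match. The sketch below uses NumPy's `np.vectorize` purely as a stand-in (it loops in Python and is not what Numba does), with `math.hypot` as a hypothetical two-argument scalar kernel.

```python
import math

import numpy as np


def scalar_kernel(a, b):
    # Hypothetical two-argument scalar kernel (cf. agm or jv).
    return math.hypot(a, b)


# Wrap the scalar kernel so it follows ufunc-style calling conventions.
vec_kernel = np.vectorize(scalar_kernel, otypes=[np.float64])

x = np.linspace(0.0, 1.0, 5)

r_scalar = vec_kernel(3.0, 4.0)               # scalar inputs
r_mixed = vec_kernel(3.0, x)                  # scalar broadcast against a 1-D array
r_outer = vec_kernel(x[:, None], x[None, :])  # (5, 1) with (1, 5) -> (5, 5)
print(r_scalar, r_mixed.shape, r_outer.shape)
```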

Anyway, I took a quick shot at using numba.vectorize on the functions produced by choose_kernel, but numba.extending.overload does not like the returned type of a numba.vectorize-wrapped function. Is that expected?

@esc
Member

esc commented Apr 16, 2021

Anyway, I took a quick shot at using numba.vectorize on the functions produced by choose_kernel, but numba.extending.overload does not like the returned type of a numba.vectorize-wrapped function. Is that expected?

do you have an example, perchance?

@brandonwillard
Contributor

brandonwillard commented Apr 16, 2021

modified   numba_scipy/special/overloads.py
@@ -10,7 +10,12 @@ def choose_kernel(name, all_signatures):
         for signature in all_signatures:
             if args == signature:
                 f = signatures.name_and_types_to_pointer[(name, *signature)]
-                return lambda *args: f(*args)
+
+                @numba.vectorize
+                def _f(*args):
+                    return f(*args)
+
+                return _f
 
     return choice_function

results in the following error:

E   numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E   No implementation of function Function(<ufunc 'agm'>) found for signature:
E    
E    >>> agm(float64, float64)
E    
E   There are 2 candidate implementations:
E     - Of which 2 did not match due to:
E     Overload in function 'choose_kernel.<locals>.choice_function': File: ../code/python/numba-scipy/numba_scipy/special/overloads.py: Line 9.
E       With argument(s): '(float64, float64)':
E      Rejected as the implementation raised a specific error:
E        AssertionError: Implementator function returned by `@overload` has an unexpected type.  Got <numba._DUFunc '_f'>
E     raised from ~/envs/numba-scipy-env/lib/python3.7/site-packages/numba/core/typing/templates.py:742
E   
E   During: resolving callee type: Function(<ufunc 'agm'>)
E   During: typing of call at ~/code/python/numba-scipy/numba_scipy/tests/test_special.py (76)
E   
E   
E   File "numba_scipy/tests/test_special.py", line 76:
E       def numba_func(*args):
E           return scipy_func(*args)
E           ^

Is numba.extending.overload attempting to numba.jit the function returned by choose_kernel? The error looks similar to the one produced when numba.njit-ing a function wrapped with numba.vectorize.

@brandonwillard
Contributor

The varargs could also be a problem.

@brandonwillard
Contributor

brandonwillard commented Apr 17, 2021

I have a hack to get this working in my vectorized-overloads branch. It creates a fixed-arguments function on the fly to get past some apparent varargs issues with numba.vectorize.

If anyone knows how to get past this varargs issue without creating functions in this fashion (or with any other fundamentally AST-based approach), please tell me; it would really help with the work we're doing in Aesara as well.

@stuartarchibald
Contributor

There's no public extension API in Numba for declaring this in a simple manner, but this sort of thing could serve as a workaround:

import numpy as np
from numba import njit, types, vectorize
from numba.extending import overload
from scipy import special

x = np.linspace(-10, 10, 1000)

# This is just a dummy scalar function, cf. those in numba-scipy's wrapper for
# scipy.special.*. Now that #54 is in, the standard overload for scalar j0
# should just work.
@njit
def pretend_j0_from_cython(x):
    return x + 12.34

@vectorize
def vectorize_j0(x):
    return pretend_j0_from_cython(x)

# This gets the vectorization mechanics but will end up "hiding" the NumPy ufunc
@overload(special.j0)
def ol_j0(x):
    if isinstance(x, (types.Array, types.Number)):
        def impl(x):
            return vectorize_j0(x)
        return impl

@njit
def jitted_j0(x):
    res1 = special.j0(x[0])
    res2 = special.j0(x)
    return res1, res2

print(jitted_j0(x))

@brandonwillard
Contributor

The issue I ran into above is the signature for the @vectorized function: varargs wouldn't work, so I had to construct the function via compile/AST.

@stuartarchibald
Contributor

The issue I ran into above is the signature for the @vectorized function: varargs wouldn't work, so I had to construct the function via compile/AST.

Ah, I see: I misinterpreted this as not being able to register a vectorized function as an overload. Whilst that's a problem, I can see why *args failing is also a problem if you want to do that kind of automatic generation!

Opened numba/numba#6954 to track.

@brandonwillard
Contributor

Opened numba/numba#6954 to track.

Thanks for that; it's a problem that shows up in at least a couple of places where we're trying to use Numba as a backend (e.g. here).

@brandonwillard brandonwillard linked a pull request Apr 20, 2021 that will close this issue
@PabloRdrRbl

Hello, I have been able to use the workaround by @stuartarchibald. Is there any plan to add this, so that there is no need to write a vectorized version of every function?

@esc
Member

esc commented Jun 17, 2021

@PabloRdrRbl I think a PR has already been opened: #58

@PabloRdrRbl

PabloRdrRbl commented Mar 10, 2022

There's no public extension API in Numba for declaring this in a simple manner, but this sort of thing could serve as a workaround:

Is it possible to extend it to a function like jv, which takes two arguments?

5 participants