Skip to content

Commit

Permalink
Merge pull request #114 from oyamad/overload
Browse files Browse the repository at this point in the history
Apply `@njit` to `interp` and `mlinterp`
  • Loading branch information
albop authored Mar 21, 2024
2 parents 705cbce + 7003f0f commit 2d20da5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
26 changes: 17 additions & 9 deletions interpolation/multilinear/mlinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@
# logic of multilinear interpolation


def mlinterp(grid, c, u):
def _mlinterp(grid, c, u):
pass


@overload(mlinterp)
@overload(_mlinterp)
def ol_mlinterp(grid, c, u):
if isinstance(u, UniTuple):

def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:
def mlininterp(grid, c, u):
# get indices and barycentric coordinates
tmp = fmap(get_index, grid, u)
indices, barycenters = funzip(tmp)
Expand All @@ -59,7 +59,7 @@ def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:

elif isinstance(u, Array) and u.ndim == 2:

def mlininterp(grid: Tuple, c: Array, u: Array) -> float:
def mlininterp(grid, c, u):
N = u.shape[0]
res = np.zeros(N)
for n in range(N):
Expand All @@ -76,6 +76,11 @@ def mlininterp(grid: Tuple, c: Array, u: Array) -> float:
return mlininterp


@njit
def mlinterp(grid, c, u):
return _mlinterp(grid, c, u)


### The rest of this file constrcts function `interp`

from collections import namedtuple
Expand Down Expand Up @@ -217,15 +222,13 @@ def {funname}(*args):
return source


def interp(*args):
def _interp(*args):
pass


@overload(interp)
@overload(_interp)
def ol_interp(*args):
aa = args[0].types

it = detect_types(aa)
it = detect_types(args)
if it.d == 1 and it.eval == "point":
it = itt(it.d, it.values, "cartesian")
source = make_mlinterp(it, "__mlinterp")
Expand All @@ -235,3 +238,8 @@ def ol_interp(*args):
code = compile(tree, "<string>", "exec")
eval(code, globals())
return __mlinterp


@njit
def interp(*args):
return _interp(*args)
5 changes: 5 additions & 0 deletions interpolation/multilinear/tests/test_multilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def test_mlinterp():
pp = np.random.random((2000, 2))

res0 = mlinterp((x1, x2), y, pp)
assert res0 is not None

res0 = mlinterp((x1, x2), y, (0.1, 0.2))
assert res0 is not None


def test_multilinear():
Expand All @@ -125,6 +128,8 @@ def test_multilinear():
tt = [typeof(e) for e in t]
rr = interp(*t)

assert rr is not None

try:
print(f"{tt}: {rr.shape}")
except:
Expand Down

0 comments on commit 2d20da5

Please sign in to comment.