Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix an issue with boxed type arg matching and disable np scalar warnings for numba #5599

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.deephaven.engine.table.impl.select.python.ArgumentsChunked;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import io.deephaven.util.type.TypeUtils;
import org.jetbrains.annotations.NotNull;
import org.jpy.PyModule;
import org.jpy.PyObject;
Expand Down Expand Up @@ -68,6 +69,13 @@ private static class UnsupportedPythonTypeHint {
for (Map.Entry<Character, Class<?>> classClassEntry : numpyType2JavaArrayClass.entrySet()) {
javaClass2NumpyType.put(classClassEntry.getValue(), classClassEntry.getKey());
}
javaClass2NumpyType.put(Byte.class, 'b');
javaClass2NumpyType.put(Short.class, 'h');
javaClass2NumpyType.put(Character.class, 'H');
javaClass2NumpyType.put(Integer.class, 'i');
javaClass2NumpyType.put(Long.class, 'l');
javaClass2NumpyType.put(Float.class, 'f');
javaClass2NumpyType.put(Double.class, 'd');
}

/**
Expand Down Expand Up @@ -235,15 +243,19 @@ public void parseSignature() {

}

private boolean hasSafelyCastable(Set<Class<?>> types, @NotNull Class<?> type) {
private boolean hasSafelyCastable(Set<Class<?>> types, @NotNull Class<?> argType) {
for (Class<?> t : types) {
if (t == null) {
continue;
}
if (t.isAssignableFrom(type)) {
if (t.isAssignableFrom(argType)) {
return true;
}
if (t.isPrimitive() && type.isPrimitive() && isLosslessWideningPrimitiveConversion(type, t)) {
if (t.isPrimitive() && argType.isPrimitive() && isLosslessWideningPrimitiveConversion(argType, t)) {
return true;
}
if (t.isPrimitive() && TypeUtils.isBoxedType(argType)
&& isLosslessWideningPrimitiveConversion(TypeUtils.getUnboxedType(argType), t)) {
return true;
}
}
Expand Down
5 changes: 5 additions & 0 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ def encoded(self) -> str:
def prepare_auto_arg_conv(self, encoded_arg_types: str) -> bool:
""" Determine whether the auto argument conversion should be used and set the converter functions for the
parameters."""
# numba and numpy ufuncs don't need auto argument conversion as they handle the conversion themselves and the
# arg types are already verified by the query engine
if isinstance(self.fn, (numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc)) or isinstance(self.fn, numpy.ufunc):
return False

if not self.params or not encoded_arg_types:
return False

Expand Down
14 changes: 14 additions & 0 deletions py/server/tests/test_numba_guvectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#

import unittest
import warnings

import numpy as np
from numba import guvectorize, int64, int32
Expand Down Expand Up @@ -108,6 +109,19 @@ def g(x, res):
t = empty_table(10).update(["X=i%3", "Y=(double)ii"]).group_by("X").update("Z=g(Y)")
self.assertIn("g: Expected argument (1)", str(cm.exception))

def test_boxed_type_arg(self):
    """Pass a query-scope Python scalar to a guvectorized function without warnings.

    ``lv`` is a plain Python int referenced by name in the formula and handed to ``g``
    alongside the grouped column ``Y``. The test checks that the third column of the
    result has the long-array data type and that no ``UserWarning`` is emitted while
    the query runs.
    """
    @guvectorize([(int64[:], int64, int64[:])], "(m),()->(m)", nopython=True)
    def g(x, y, res):
        for i in range(len(x)):
            res[i] = x[i] + y

    # Make sure we don't get a warning about numpy scalar used in annotation: inside
    # this context every UserWarning is raised as an error. catch_warnings() restores
    # the previous filter state on exit, even if the query or the assertion fails, so
    # the "error" filter cannot leak into other tests.
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=UserWarning)
        lv = 2
        t = empty_table(10).update(["X=ii%3", "Y=ii"]).group_by("X").update("Z=g(Y, lv)")
        self.assertEqual(t.columns[2].data_type, dtypes.long_array)


if __name__ == '__main__':
unittest.main()
21 changes: 19 additions & 2 deletions py/server/tests/test_numba_vectorized_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
#

import unittest
import warnings

from numba import vectorize, int64

import math
from deephaven import empty_table, DHError
from deephaven.html import to_html

from tests.testbase import BaseTestCase


@vectorize([int64(int64, int64)])
def vectorized_func(x, y):
    """Return the remainder of ``x`` divided by 3, plus ``y``."""
    remainder = x % 3
    return remainder + y


class NumbaVectorizedColumnTestCase(BaseTestCase):

def test_part_of_expr(self):
Expand All @@ -32,6 +33,22 @@ def test_column(self):
html_output = to_html(t)
self.assertIn("<td>9</td>", html_output)

def test_boxed_type_arg(self):
    """Call a numba-vectorized function with scalar arguments without warnings.

    Two sub-tests run the same ``float64 -> float64`` function: once with the
    query-scope Python float ``dv`` passed directly, and once with a column ``Y``
    derived from ``dv``. Both must complete without emitting a ``UserWarning``.
    """
    @vectorize(['float64(float64)'])
    def norm_cdf(x):
        """ Cumulative distribution function for the standard normal distribution """
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    # Make sure we don't get a warning about numpy scalar used in annotation: inside
    # this context every UserWarning is raised as an error. catch_warnings() puts the
    # previous filters back on exit instead of pushing an extra "default" filter onto
    # the global list.
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=UserWarning)
        with self.subTest("Boxed type 2 primitives"):
            dv = 0.05
            t = empty_table(10).update("X = norm_cdf(dv)")

        with self.subTest("Boxed type 2 primitives - 2"):
            dv = 0.05
            t = empty_table(10).update(["Y = dv*1.0", "X = norm_cdf(Y)"])

if __name__ == "__main__":
unittest.main(verbosity=2)
13 changes: 13 additions & 0 deletions py/server/tests/test_udf_scalar_args.py
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
import typing
import warnings
from datetime import datetime
from typing import Optional, Union, Any, Sequence
import unittest
Expand Down Expand Up @@ -601,5 +602,17 @@ def misuse_c(c: int) -> int:
with self.assertRaises(DHError) as cm:
t.update("V = misuse_c(C)")

def test_boxed_type_arg(self):
    """Check that a numpy scalar annotation on a UDF parameter produces a UserWarning.

    ``f`` is called from a formula with the query-scope Python float ``dv`` for both
    parameters. The call must still succeed (all 10 rows evaluate to true) and the
    last recorded warning must be a ``UserWarning`` mentioning the numpy scalar type.
    """
    # p2 is never read; presumably it exists only to carry the np.float64 annotation
    # that the expected warning is about.
    def f(p1: float, p2: np.float64) -> bool:
        return p1 == 0.05

    dv = 0.05
    with warnings.catch_warnings(record=True) as w:
        # Record every UserWarning regardless of the filters the test runner set up;
        # otherwise an ambient "ignore" filter would leave the list empty.
        warnings.simplefilter("always", category=UserWarning)
        t = empty_table(10).update("X = f(dv, dv)")

    # Fail with a readable message instead of an IndexError on w[-1].
    self.assertTrue(w, "expected a UserWarning about the numpy scalar type annotation")
    self.assertEqual(w[-1].category, UserWarning)
    self.assertRegex(str(w[-1].message), "numpy scalar type.*is used")
    self.assertEqual(10, t.to_string().count("true"))


if __name__ == "__main__":
unittest.main()
Loading