From 30140d27d4186e1e515ec88830d421d968585354 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Tue, 11 Jun 2024 09:14:46 -0600 Subject: [PATCH 1/2] Fix an issue with boxed type arg and np scaar wrn --- .../engine/util/PyCallableWrapperJpyImpl.java | 18 ++++++++++++--- py/server/deephaven/_udf.py | 8 +++++++ py/server/tests/test_numba_guvectorize.py | 14 +++++++++++ .../tests/test_numba_vectorized_column.py | 23 +++++++++++++++++-- py/server/tests/test_udf_scalar_args.py | 13 +++++++++++ 5 files changed, 71 insertions(+), 5 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index 7cb26d39a2c..fe8dc367439 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -6,6 +6,7 @@ import io.deephaven.engine.table.impl.select.python.ArgumentsChunked; import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; +import io.deephaven.util.type.TypeUtils; import org.jetbrains.annotations.NotNull; import org.jpy.PyModule; import org.jpy.PyObject; @@ -68,6 +69,13 @@ private static class UnsupportedPythonTypeHint { for (Map.Entry> classClassEntry : numpyType2JavaArrayClass.entrySet()) { javaClass2NumpyType.put(classClassEntry.getValue(), classClassEntry.getKey()); } + javaClass2NumpyType.put(Byte.class, 'b'); + javaClass2NumpyType.put(Short.class, 'h'); + javaClass2NumpyType.put(Character.class, 'H'); + javaClass2NumpyType.put(Integer.class, 'i'); + javaClass2NumpyType.put(Long.class, 'l'); + javaClass2NumpyType.put(Float.class, 'f'); + javaClass2NumpyType.put(Double.class, 'd'); } /** @@ -235,15 +243,19 @@ public void parseSignature() { } - private boolean hasSafelyCastable(Set> types, @NotNull Class type) { + private boolean hasSafelyCastable(Set> types, @NotNull Class argType) { for (Class t : types) { if (t == null) { continue; } - if (t.isAssignableFrom(type)) { + if (t.isAssignableFrom(argType)) { return true; } - if (t.isPrimitive() && type.isPrimitive() && isLosslessWideningPrimitiveConversion(type, t)) { + if (t.isPrimitive() && argType.isPrimitive() && isLosslessWideningPrimitiveConversion(argType, t)) { + return true; + } + if (t.isPrimitive() && TypeUtils.isBoxedType(argType) + && isLosslessWideningPrimitiveConversion(TypeUtils.getUnboxedType(argType), t)) { return true; } } diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index 42d44163063..db715dc4736 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -200,6 +200,11 @@ def encoded(self) -> str: def prepare_auto_arg_conv(self, encoded_arg_types: str) -> bool: """ Determine whether the auto argument conversion should be used and set the converter functions for the parameters.""" + # numba and numpy ufuncs don't need auto argument conversion as they handle the conversion themselves and the + # arg types are already verified by the query engine + if isinstance(self.fn, (numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc)) or isinstance(self.fn, numpy.ufunc): + return False + if not self.params or not encoded_arg_types: return False @@ -211,6 +216,9 @@ def prepare_auto_arg_conv(self, encoded_arg_types: str) -> bool: for arg_type_str, param in zip(arg_type_strs, self.params): param.setup_arg_converter(arg_type_str) + if isinstance(self.fn, (numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc)) or isinstance(self.fn, numpy.ufunc): + warnings.filterwarnings("default", category=UserWarning) + if all([param.arg_converter is None for param in self.params]): arg_conv_needed = False diff --git a/py/server/tests/test_numba_guvectorize.py b/py/server/tests/test_numba_guvectorize.py index adbb7119b4f..e63a155dc30 100644 --- a/py/server/tests/test_numba_guvectorize.py +++ b/py/server/tests/test_numba_guvectorize.py @@ -3,6 +3,7 @@ # import unittest +import warnings import numpy as np from numba import guvectorize, int64, int32 @@ -108,6 +109,19 @@ def g(x, res): t = empty_table(10).update(["X=i%3", "Y=(double)ii"]).group_by("X").update("Z=g(Y)") self.assertIn("g: Expected argument (1)", str(cm.exception)) + def test_boxed_type_arg(self): + @guvectorize([(int64[:], int64, int64[:])], "(m),()->(m)", nopython=True) + def g(x, y, res): + for i in range(len(x)): + res[i] = x[i] + y + + # make sure we don't get a warning about numpy scalar used in annotation + warnings.filterwarnings("error", category=UserWarning) + lv = 2 + t = empty_table(10).update(["X=ii%3", "Y=ii"]).group_by("X").update("Z=g(Y, lv)") + self.assertEqual(t.columns[2].data_type, dtypes.long_array) + warnings.filterwarnings("default", category=UserWarning) + if __name__ == '__main__': unittest.main() diff --git a/py/server/tests/test_numba_vectorized_column.py b/py/server/tests/test_numba_vectorized_column.py index eb5d2d7cfc0..0343f4d56dc 100644 --- a/py/server/tests/test_numba_vectorized_column.py +++ b/py/server/tests/test_numba_vectorized_column.py @@ -3,19 +3,22 @@ # import unittest +import warnings from numba import vectorize, int64 - +import math from deephaven import empty_table, DHError from deephaven.html import to_html + from tests.testbase import BaseTestCase +rate_risk_free = 0.05 + @vectorize([int64(int64, int64)]) def vectorized_func(x, y): return x % 3 + y - class NumbaVectorizedColumnTestCase(BaseTestCase): def test_part_of_expr(self): @@ -32,6 +35,22 @@ def test_column(self): html_output = to_html(t) self.assertIn("9", html_output) + def test_boxed_type_arg(self): + @vectorize(['float64(float64)']) + def norm_cdf(x): + """ Cumulative distribution function for the standard normal distribution """ + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + # make sure we don't get a warning about numpy scalar used in annotation + warnings.filterwarnings("error", category=UserWarning) + with self.subTest("Boxed type 2 primitives"): + dv = 0.05 + t = empty_table(10).update("X = norm_cdf(dv)") + + with self.subTest("Boxed type 2 primitives - 2"): + dv = 0.05 + t = empty_table(10).update(["Y = dv*1.0", "X = norm_cdf(Y)"]) + warnings.filterwarnings("default", category=UserWarning) if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/py/server/tests/test_udf_scalar_args.py b/py/server/tests/test_udf_scalar_args.py index f5fe4ac759e..1eebdd914a2 100644 --- a/py/server/tests/test_udf_scalar_args.py +++ b/py/server/tests/test_udf_scalar_args.py @@ -2,6 +2,7 @@ # Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending # import typing +import warnings from datetime import datetime from typing import Optional, Union, Any, Sequence import unittest @@ -601,5 +602,17 @@ def misuse_c(c: int) -> int: with self.assertRaises(DHError) as cm: t.update("V = misuse_c(C)") + def test_boxed_type_arg(self): + def f(p1: float, p2: np.float64) -> bool: + return p1 == 0.05 + + dv = 0.05 + with warnings.catch_warnings(record=True) as w: + t = empty_table(10).update("X = f(dv, dv)") + self.assertEqual(w[-1].category, UserWarning) + self.assertRegex(str(w[-1].message), "numpy scalar type.*is used") + self.assertEqual(10, t.to_string().count("true")) + + if __name__ == "__main__": unittest.main() From 4886ecdd468c6146c262a1b3f40e9b62c7e1a5c1 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Tue, 11 Jun 2024 09:53:18 -0600 Subject: [PATCH 2/2] Some code cleanup/fix --- py/server/deephaven/_udf.py | 3 --- py/server/tests/test_numba_vectorized_column.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index db715dc4736..dabe1cb2c0b 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -216,9 +216,6 @@ def prepare_auto_arg_conv(self, encoded_arg_types: str) -> bool: for arg_type_str, param in zip(arg_type_strs, self.params): param.setup_arg_converter(arg_type_str) - if isinstance(self.fn, (numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc)) or isinstance(self.fn, numpy.ufunc): - warnings.filterwarnings("default", category=UserWarning) - if all([param.arg_converter is None for param in self.params]): arg_conv_needed = False diff --git a/py/server/tests/test_numba_vectorized_column.py b/py/server/tests/test_numba_vectorized_column.py index 0343f4d56dc..9c7f95a8baf 100644 --- a/py/server/tests/test_numba_vectorized_column.py +++ b/py/server/tests/test_numba_vectorized_column.py @@ -12,8 +12,6 @@ from tests.testbase import BaseTestCase -rate_risk_free = 0.05 - @vectorize([int64(int64, int64)]) def vectorized_func(x, y):