Skip to content

Commit

Permalink
Disable vectorization on arg count mismatch (#3372)
Browse files Browse the repository at this point in the history
* Disable vectorization on arg count mismatch

* Add a new test case
  • Loading branch information
jmao-denver authored and devinrsmith committed Jan 27, 2023
1 parent 1644128 commit 8dd1978
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1891,16 +1891,18 @@ private void checkVectorizability(MethodCallExpr n, Expression[] expressions, Py
+ expressions[i]);
}
}
}

private void prepareVectorizationArgs(MethodCallExpr n, QueryScope queryScope, Expression[] expressions,
Class<?>[] argTypes,
PyCallableWrapper pyCallableWrapper) {
List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (paramTypes.size() != expressions.length) {
// note vectorization doesn't handle Python variadic arguments
throw new RuntimeException("Python function argument count mismatch: " + n + " " + paramTypes.size()
+ " vs. " + expressions.length);
}
}

private void prepareVectorizationArgs(MethodCallExpr n, QueryScope queryScope, Expression[] expressions,
Class<?>[] argTypes,
PyCallableWrapper pyCallableWrapper) {

pyCallableWrapper.initializeChunkArguments();
for (int i = 0; i < expressions.length; i++) {
Expand All @@ -1919,6 +1921,7 @@ private void prepareVectorizationArgs(MethodCallExpr n, QueryScope queryScope, E
throw new IllegalStateException("Vectorizability check failed: " + n);
}

List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) {
throw new RuntimeException("Python vectorized function argument type mismatch: " + n + " "
+ argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName());
Expand Down
27 changes: 15 additions & 12 deletions py/server/tests/test_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,16 @@
# Copyright (c) 2016-2022 Deephaven Data Labs and Patent Pending
#
import random
import time
import unittest
from types import SimpleNamespace
from typing import List, Any

import numpy as np

import deephaven

from deephaven import DHError, read_csv, empty_table, SortDirection, AsOfMatchRule, time_table, ugp, dtypes
from deephaven.agg import sum_, weighted_avg, avg, pct, group, count_, first, last, max_, median, min_, std, abs_sum, \
var, formula, partition
from deephaven.execution_context import make_user_exec_ctx
from deephaven import DHError, empty_table, dtypes
from deephaven import new_table
from deephaven.column import int_col
from deephaven.filters import Filter, and_
from deephaven.html import to_html
from deephaven.pandas import to_pandas
from deephaven.table import Table, dh_vectorize
from deephaven.table import dh_vectorize
from tests.testbase import BaseTestCase


Expand Down Expand Up @@ -54,7 +47,7 @@ def no_param_func():

with self.assertRaises(DHError) as cm:
t1 = t.update("X = auto_func(i)")
self.assertRegex(str(cm.exception), r".*count.*mismatch", )
self.assertRegex(str(cm.exception), r"missing 1 required positional argument", )

with self.subTest("can't cast return value"):
with self.assertRaises(DHError) as cm:
Expand Down Expand Up @@ -248,6 +241,16 @@ def pyfunc_obj() -> object:
t = empty_table(1).update("X = pyfunc_obj()")
self.assertEqual(t.columns[0].data_type, dtypes.PyObject)

def test_varargs_still_work(self):
cols = ["A", "B", "C", "D"]

def my_sum(*args):
return sum(args)

source = new_table([int_col(c, [0, 1, 2, 3, 4, 5, 6]) for c in cols])
result = source.update(f"X = my_sum({','.join(cols)})")
self.assertEqual(len(cols) + 1, len(result.columns))


if __name__ == "__main__":
unittest.main()

0 comments on commit 8dd1978

Please sign in to comment.