Skip to content

Commit

Permalink
Support the use of Python callable's attributes in query strings (#3509)
Browse files Browse the repository at this point in the history
* Support the use of PY callable's attr in query
  • Loading branch information
jmao-denver authored Mar 10, 2023
1 parent 3543da8 commit 36facaf
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ private Class<?>[] printArguments(Expression[] arguments, VisitArgs printer) {

printer.append('(');
for (int i = 0; i < arguments.length; i++) {
types.add(arguments[i].accept(this, printer));
types.add(arguments[i].accept(this, printer.cloneWithCastingContext(null)));

if (i != arguments.length - 1) {
printer.append(", ");
Expand Down Expand Up @@ -489,7 +489,7 @@ private Method getMethod(final Class<?> scope, final String methodName, final Cl
}
}
} else {
if (scope == org.jpy.PyObject.class) {
if (scope == org.jpy.PyObject.class || scope == PyCallableWrapper.class) {
// This is a Python method call, assume it exists and wrap in PythonScopeJpyImpl.CallableWrapper
for (Method method : PyCallableWrapper.class.getDeclaredMethods()) {
possiblyAddExecutable(acceptableMethods, method, "call", paramTypes, parameterizedTypes);
Expand Down Expand Up @@ -1656,7 +1656,7 @@ public Class<?> visit(FieldAccessExpr n, VisitArgs printer) {
try {
// For Python object, the type of the field is PyObject by default, the actual data type if
// primitive will only be known at runtime
if (scopeType == PyObject.class) {
if (scopeType == PyObject.class || scopeType == PyCallableWrapper.class) {
ret = PyObject.class;
} else {
ret = scopeType.getField(fieldName).getType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ public PyCallableWrapper(PyObject pyCallable) {
this.pyCallable = pyCallable;
}

public PyObject getAttribute(String name) {
return this.pyCallable.getAttribute(name);
}

public <T> T getAttribute(String name, Class<? extends T> valueType) {
return this.pyCallable.getAttribute(name, valueType);
}

public ArgumentsChunked buildArgumentsChunked(List<String> columnNames) {
for (ChunkArgument arg : chunkArguments) {
if (arg instanceof ColumnChunkArgument) {
Expand Down
53 changes: 52 additions & 1 deletion py/server/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from types import SimpleNamespace
from typing import List, Any

from deephaven import DHError, read_csv, empty_table, SortDirection, AsOfMatchRule, time_table, ugp
from deephaven import DHError, read_csv, empty_table, SortDirection, AsOfMatchRule, time_table, ugp, new_table, dtypes
from deephaven.agg import sum_, weighted_avg, avg, pct, group, count_, first, last, max_, median, min_, std, abs_sum, \
var, formula, partition
from deephaven.column import datetime_col
from deephaven.execution_context import make_user_exec_ctx
from deephaven.html import to_html
from deephaven.jcompat import j_hashmap
Expand Down Expand Up @@ -902,6 +903,56 @@ def make_pairs_3(tid, a, b):
self.assertEqual(x2.size, 10)
self.assertEqual(x3.size, 10)

def test_callable_attrs_in_query(self):
input_cols = [
datetime_col(name="DTCol", data=[dtypes.DateTime(1), dtypes.DateTime(10000000)]),
]
test_table = new_table(cols=input_cols)
from deephaven.time import year, TimeZone
rt = test_table.update("Year = (int)year(DTCol, TimeZone.NY)")
self.assertEqual(rt.size, test_table.size)

class Foo:
ATTR = 256

def __call__(self):
...

def do_something_instance(self, p=None):
return p if p else 1

@classmethod
def do_something_cls(cls, p=None):
return p if p else 1

@staticmethod
def do_something_static(p=None):
return p if p else 1

def do_something(p=None):
return p if p else 1

rt = empty_table(1).update("Col = Foo.ATTR")
self.assertTrue(rt.columns[0].data_type == dtypes.PyObject)

rt = empty_table(1).update("Col = (int)Foo.ATTR")
self.assertTrue(rt.columns[0].data_type == dtypes.int32)

foo = Foo()
rt = empty_table(1).update("Col = (int)foo.do_something_instance()")
self.assertTrue(rt.columns[0].data_type == dtypes.int32)

rt = empty_table(1).update("Col = (int)Foo.do_something_cls()")
self.assertTrue(rt.columns[0].data_type == dtypes.int32)

rt = empty_table(1).update("Col = (int)foo.do_something_static()")
self.assertTrue(rt.columns[0].data_type == dtypes.int32)

rt = empty_table(1).update("Col = (int)do_something((byte)Foo.ATTR)")
df = to_pandas(rt)
self.assertEqual(df.loc[0]['Col'], 1)
self.assertTrue(rt.columns[0].data_type == dtypes.int32)


if __name__ == "__main__":
unittest.main()

0 comments on commit 36facaf

Please sign in to comment.