Skip to content

Commit

Permalink
Addressed PR comments and added a test.
Browse files Browse the repository at this point in the history
  • Loading branch information
lbooker42 committed Aug 31, 2023
1 parent fe34379 commit 6caa4ce
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
14 changes: 9 additions & 5 deletions py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3596,14 +3596,16 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str


class MultiJoinTable(JObjectWrapper):
"""A MultiJoinTable represents the result of a multi-table natural join. """
"""A MultiJoinTable is an object that contains the result of a multi-table natural join. To retrieve the underlying
result Table, use the table() method. """
j_object_type = _JMultiJoinTable

@property
def j_object(self) -> jpy.JType:
return self.j_multijointable

def table(self) -> Table:
"""Returns the Table containing the multi-table natural join output. """
return Table(j_table=self.j_multijointable.table())

def __init__(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]],
Expand All @@ -3614,8 +3616,8 @@ def __init__(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence
input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects
specifying the tables and columns to include in the join.
on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality
expression that matches every input table, i.e. "col_a = col_b" to rename output column names. When
using MultiJoinInput objects, this parameter is ignored.
expression that matches every input table, i.e. "col_a = col_b" to rename output column names. Note:
When MultiJoinInput objects are supplied, this parameter must be omitted.
Raises:
DHError
Expand All @@ -3627,6 +3629,8 @@ def __init__(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence
j_tables = to_sequence(input)
self.j_multijointable = _JMultiJoinFactory.of(on, *j_tables)
elif isinstance(input, MultiJoinInput) or (isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)):
if on is not None:
raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.")
wrapped_input = to_sequence(input, wrapped=True)
tables = [ji.table for ji in wrapped_input]
with auto_locking_ctx(*tables):
Expand All @@ -3650,8 +3654,8 @@ def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[Mul
input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the
tables and columns to include in the join.
on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression
that matches every input table, i.e. "col_a = col_b" to rename output column names. When using MultiJoinInput
objects, this parameter is ignored.
that matches every input table, i.e. "col_a = col_b" to rename output column names. Note: When
MultiJoinInput objects are supplied, this parameter must be omitted.
Returns:
MultiJoinTable: the result of the multi-table natural join operation. To access the underlying Table, use the
Expand Down
11 changes: 10 additions & 1 deletion py/server/tests/test_multijoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import unittest

from deephaven import read_csv, time_table, update_graph
from deephaven import read_csv, time_table, update_graph, DHError
from deephaven.table import MultiJoinInput, MultiJoinTable, multi_join
from tests.testbase import BaseTestCase
from deephaven.execution_context import get_exec_ctx
Expand Down Expand Up @@ -115,6 +115,15 @@ def test_ticking(self):
with update_graph.exclusive_lock(self.test_update_graph):
self.assertEqual(mj_table.table().size, self.ticking_tableA.size)

def test_errors(self):
# Assert the exception is raised when providing MultiJoinInput and the on parameter is not None (omitted).
mj_input = [
MultiJoinInput(table=self.ticking_tableA, on=["key1=a","key2=b"], joins=["c1","e1"]),
MultiJoinInput(table=self.ticking_tableB, on=["key1=a","key2=b"], joins=["d2"])
]
with self.assertRaises(DHError):
mj_table = multi_join(mj_input, on=["key1=a","key2=b"])


if __name__ == '__main__':
unittest.main()

0 comments on commit 6caa4ce

Please sign in to comment.