Commit ca44182

Support merge_asof in Spark Connect.

ueshin committed Sep 26, 2023
1 parent ea3104f commit ca44182

Showing 8 changed files with 481 additions and 141 deletions.
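For context: merge_asof is the pandas-on-Spark API that this commit makes usable over Spark Connect. A minimal usage sketch (the data and column names are illustrative, not from the commit):

import pyspark.pandas as ps

# Both frames must be sorted on the as-of key.
quotes = ps.DataFrame({"time": [1, 2, 3, 7], "bid": [720.5, 51.95, 51.97, 52.0]})
trades = ps.DataFrame({"time": [2, 5], "price": [51.95, 720.77]})

# For each trade, take the most recent quote at or before its time.
merged = ps.merge_asof(trades, quotes, on="time", direction="backward")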
@@ -72,6 +72,7 @@ message Relation {
CachedLocalRelation cached_local_relation = 36;
CachedRemoteRelation cached_remote_relation = 37;
CommonInlineUserDefinedTableFunction common_inline_user_defined_table_function = 38;
AsOfJoin as_of_join = 39;

// NA functions
NAFill fill_na = 90;
@@ -1009,3 +1010,44 @@ message Parse {
PARSE_FORMAT_JSON = 2;
}
}

// Relation of type [[AsOfJoin]].
//
// `left` and `right` must be present.
message AsOfJoin {
// (Required) Left input relation for a Join.
Relation left = 1;

// (Required) Right input relation for a Join.
Relation right = 2;

// (Required) Field to join on in left DataFrame
Expression left_as_of = 3;

// (Required) Field to join on in right DataFrame
Expression right_as_of = 4;

// (Optional) The join condition. May be unset when `using_columns` is used.
//
// This field cannot be set together with `using_columns`.
Expression join_expr = 5;

// (Optional) A list of columns that must be present on both sides of the join
// inputs, which this Join will join on. For example, A JOIN B USING col_name is
// equivalent to A JOIN B ON A.col_name = B.col_name.
//
// This field cannot be set together with `join_expr`.
repeated string using_columns = 6;

// (Required) The join type.
string join_type = 7;

// (Optional) The as-of tolerance; when set, matches must be within this range.
Expression tolerance = 8;

// (Required) Whether to allow matching with the same value or not.
bool allow_exact_matches = 9;

// (Required) Whether to search for prior, subsequent, or closest matches.
string direction = 10;
}
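As an illustration (not part of the commit): the arguments of the client-side DataFrame._joinAsOf line up with the message fields above roughly as follows. The session `spark`, the frames, and the column names are assumed for the sketch:

from pyspark.sql import functions as F

left_df = spark.createDataFrame([(1, "a"), (5, "b")], ["ts", "id"])
right_df = spark.createDataFrame([(2, "a"), (6, "b")], ["ts", "id"])

joined = left_df._joinAsOf(
    right_df,
    leftAsOfColumn="ts",     # -> left_as_of
    rightAsOfColumn="ts",    # -> right_as_of
    on="id",                 # -> using_columns (a Column goes to join_expr instead)
    how="left",              # -> join_type
    tolerance=F.lit(2),      # -> tolerance
    allowExactMatches=True,  # -> allow_exact_matches
    direction="backward",    # -> direction
)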
@@ -110,6 +110,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail)
case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin)
case proto.Relation.RelTypeCase.AS_OF_JOIN => transformAsOfJoin(rel.getAsOfJoin)
case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
@@ -2275,6 +2276,42 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
}
}

private def transformAsOfJoin(rel: proto.AsOfJoin): LogicalPlan = {
val left = Dataset.ofRows(session, transformRelation(rel.getLeft))
val right = Dataset.ofRows(session, transformRelation(rel.getRight))
val leftAsOf = Column(transformExpression(rel.getLeftAsOf))
val rightAsOf = Column(transformExpression(rel.getRightAsOf))
val joinType = rel.getJoinType
val tolerance = if (rel.hasTolerance) Column(transformExpression(rel.getTolerance)) else null
val allowExactMatches = rel.getAllowExactMatches
val direction = rel.getDirection

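// Choose between the two joinAsOf overloads: USING-style column names
// versus an explicit join expression.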
val joined = if (rel.getUsingColumnsCount > 0) {
val usingColumns = rel.getUsingColumnsList.asScala.toSeq
left.joinAsOf(
other = right,
leftAsOf = leftAsOf,
rightAsOf = rightAsOf,
usingColumns = usingColumns,
joinType = joinType,
tolerance = tolerance,
allowExactMatches = allowExactMatches,
direction = direction)
} else {
val joinExprs = if (rel.hasJoinExpr) Column(transformExpression(rel.getJoinExpr)) else null
left.joinAsOf(
other = right,
leftAsOf = leftAsOf,
rightAsOf = rightAsOf,
joinExprs = joinExprs,
joinType = joinType,
tolerance = tolerance,
allowExactMatches = allowExactMatches,
direction = direction)
}
joined.logicalPlan
}

private def transformSort(sort: proto.Sort): LogicalPlan = {
assert(sort.getOrderCount > 0, "'order' must be present and contain elements.")
logical.Sort(
4 changes: 1 addition & 3 deletions python/pyspark/pandas/tests/connect/test_parity_reshape.py
@@ -22,9 +22,7 @@


 class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
-    @unittest.skip("TODO(SPARK-43662): Enable ReshapeParityTests.test_merge_asof.")
-    def test_merge_asof(self):
-        super().test_merge_asof()
+    pass


if __name__ == "__main__":
41 changes: 41 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -595,6 +595,47 @@ def join(

join.__doc__ = PySparkDataFrame.join.__doc__

def _joinAsOf(
self,
other: "DataFrame",
leftAsOfColumn: Union[str, Column],
rightAsOfColumn: Union[str, Column],
on: Optional[Union[str, List[str], Column, List[Column]]] = None,
how: Optional[str] = None,
*,
tolerance: Optional[Column] = None,
allowExactMatches: bool = True,
direction: str = "backward",
) -> "DataFrame":
if self._plan is None:
raise Exception("Cannot join when self._plan is empty.")
if other._plan is None:
raise Exception("Cannot join when other._plan is empty.")

if how is None:
how = "inner"
assert isinstance(how, str), "how should be a string"

if tolerance is not None:
assert isinstance(tolerance, Column), "tolerance should be Column"

return DataFrame.withPlan(
plan.AsOfJoin(
left=self._plan,
right=other._plan,
left_as_of=leftAsOfColumn,
right_as_of=rightAsOfColumn,
on=on,
how=how,
tolerance=tolerance,
allow_exact_matches=allowExactMatches,
direction=direction,
),
session=self._session,
)

_joinAsOf.__doc__ = PySparkDataFrame._joinAsOf.__doc__

def limit(self, n: int) -> "DataFrame":
return DataFrame.withPlan(plan.Limit(child=self._plan, limit=n), session=self._session)

94 changes: 94 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -900,6 +900,100 @@ def _repr_html_(self) -> str:
"""


class AsOfJoin(LogicalPlan):
def __init__(
self,
left: LogicalPlan,
right: LogicalPlan,
left_as_of: "ColumnOrName",
right_as_of: "ColumnOrName",
on: Optional[Union[str, List[str], Column, List[Column]]],
how: str,
tolerance: Optional[Column],
allow_exact_matches: bool,
direction: str,
) -> None:
super().__init__(left)
self.left = left
self.right = right
self.left_as_of = left_as_of
self.right_as_of = right_as_of
self.on = on
self.how = how
self.tolerance = tolerance
self.allow_exact_matches = allow_exact_matches
self.direction = direction

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.as_of_join.left.CopyFrom(self.left.plan(session))
plan.as_of_join.right.CopyFrom(self.right.plan(session))

if isinstance(self.left_as_of, Column):
plan.as_of_join.left_as_of.CopyFrom(self.left_as_of.to_plan(session))
else:
plan.as_of_join.left_as_of.CopyFrom(
ColumnReference(self.left_as_of, self.left._plan_id).to_plan(session)
)

if isinstance(self.right_as_of, Column):
plan.as_of_join.right_as_of.CopyFrom(self.right_as_of.to_plan(session))
else:
plan.as_of_join.right_as_of.CopyFrom(
ColumnReference(self.right_as_of, self.right._plan_id).to_plan(session)
)

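# `on` may be a column name, a list of names, a Column expression, or a list of
# Columns; names map to using_columns, expressions to join_expr.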
if self.on is not None:
if not isinstance(self.on, list):
if isinstance(self.on, str):
plan.as_of_join.using_columns.append(self.on)
else:
plan.as_of_join.join_expr.CopyFrom(self.on.to_plan(session))
elif len(self.on) > 0:
if isinstance(self.on[0], str):
plan.as_of_join.using_columns.extend(cast(List[str], self.on))
else:
merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on)
plan.as_of_join.join_expr.CopyFrom(cast(Column, merge_column).to_plan(session))

plan.as_of_join.join_type = self.how

if self.tolerance is not None:
plan.as_of_join.tolerance.CopyFrom(self.tolerance.to_plan(session))

plan.as_of_join.allow_exact_matches = self.allow_exact_matches
plan.as_of_join.direction = self.direction

return plan

def print(self, indent: int = 0) -> str:
assert self.left is not None
assert self.right is not None

i = " " * indent
o = " " * (indent + LogicalPlan.INDENT)
n = indent + LogicalPlan.INDENT * 2
return (
f"{i}<AsOfJoin left_as_of={self.left_as_of}, right_as_of={self.right_as_of}, "
f"on={self.on} how={self.how}>\n{o}"
f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}"
)

def _repr_html_(self) -> str:
assert self.left is not None
assert self.right is not None

return f"""
<ul>
<li>
<b>AsOfJoin</b><br />
Left: {self.left._repr_html_()}
Right: {self.right._repr_html_()}
</li>
</ul>
"""


class SetOperation(LogicalPlan):
def __init__(
self,
278 changes: 140 additions & 138 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.
