Skip to content

Commit

Permalink
fix: Add output to Comet operators equal and hashCode (#902)
Browse files Browse the repository at this point in the history
* fix: Add output to Comet operators equal and hashCode

* Update
  • Loading branch information
viirya authored Sep 2, 2024
1 parent 40b27cb commit 033fe6f
Show file tree
Hide file tree
Showing 18 changed files with 690 additions and 573 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ case class CometWindowExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometWindowExec =>
this.output == other.output &&
this.windowExpression == other.windowExpression && this.child == other.child &&
this.partitionSpec == other.partitionSpec && this.orderSpec == other.orderSpec &&
this.serializedPlanOpt == other.serializedPlanOpt
Expand All @@ -74,5 +75,5 @@ case class CometWindowExec(
}

override def hashCode(): Int =
Objects.hashCode(windowExpression, partitionSpec, orderSpec, child)
Objects.hashCode(output, windowExpression, partitionSpec, orderSpec, child)
}
36 changes: 24 additions & 12 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ case class CometProjectExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometProjectExec =>
this.output == other.output &&
this.projectList == other.projectList &&
this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
Expand All @@ -436,7 +437,7 @@ case class CometProjectExec(
}
}

override def hashCode(): Int = Objects.hashCode(projectList, child)
override def hashCode(): Int = Objects.hashCode(output, projectList, child)

override protected def outputExpressions: Seq[NamedExpression] = projectList
}
Expand All @@ -462,14 +463,15 @@ case class CometFilterExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometFilterExec =>
this.output == other.output &&
this.condition == other.condition && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(condition, child)
override def hashCode(): Int = Objects.hashCode(output, condition, child)

override def verboseStringWithOperatorId(): String = {
s"""
Expand Down Expand Up @@ -501,14 +503,15 @@ case class CometSortExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometSortExec =>
this.output == other.output &&
this.sortOrder == other.sortOrder && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(sortOrder, child)
override def hashCode(): Int = Objects.hashCode(output, sortOrder, child)

override lazy val metrics: Map[String, SQLMetric] =
CometMetricNode.baselineMetrics(sparkContext) ++
Expand Down Expand Up @@ -539,14 +542,15 @@ case class CometLocalLimitExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometLocalLimitExec =>
this.output == other.output &&
this.limit == other.limit && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(limit: java.lang.Integer, child)
override def hashCode(): Int = Objects.hashCode(output, limit: java.lang.Integer, child)
}

case class CometGlobalLimitExec(
Expand All @@ -569,14 +573,15 @@ case class CometGlobalLimitExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometGlobalLimitExec =>
this.output == other.output &&
this.limit == other.limit && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(limit: java.lang.Integer, child)
override def hashCode(): Int = Objects.hashCode(output, limit: java.lang.Integer, child)
}

case class CometExpandExec(
Expand All @@ -599,14 +604,15 @@ case class CometExpandExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometExpandExec =>
this.output == other.output &&
this.projections == other.projections && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(projections, child)
override def hashCode(): Int = Objects.hashCode(output, projections, child)

// TODO: support native Expand metrics
override lazy val metrics: Map[String, SQLMetric] = Map.empty
Expand Down Expand Up @@ -638,12 +644,14 @@ case class CometUnionExec(

override def equals(obj: Any): Boolean = {
obj match {
case other: CometUnionExec => this.children == other.children
case other: CometUnionExec =>
this.output == other.output &&
this.children == other.children
case _ => false
}
}

override def hashCode(): Int = Objects.hashCode(children)
override def hashCode(): Int = Objects.hashCode(output, children)
}

case class CometHashAggregateExec(
Expand Down Expand Up @@ -677,6 +685,7 @@ case class CometHashAggregateExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometHashAggregateExec =>
this.output == other.output &&
this.groupingExpressions == other.groupingExpressions &&
this.aggregateExpressions == other.aggregateExpressions &&
this.input == other.input &&
Expand All @@ -689,7 +698,7 @@ case class CometHashAggregateExec(
}

override def hashCode(): Int =
Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)
Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child)

override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
}
Expand Down Expand Up @@ -729,6 +738,7 @@ case class CometHashJoinExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometHashJoinExec =>
this.output == other.output &&
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
Expand All @@ -742,7 +752,7 @@ case class CometHashJoinExec(
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
Objects.hashCode(output, leftKeys, rightKeys, condition, buildSide, left, right)

override lazy val metrics: Map[String, SQLMetric] =
CometMetricNode.hashJoinMetrics(sparkContext)
Expand Down Expand Up @@ -865,6 +875,7 @@ case class CometBroadcastHashJoinExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastHashJoinExec =>
this.output == other.output &&
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
Expand All @@ -878,7 +889,7 @@ case class CometBroadcastHashJoinExec(
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
Objects.hashCode(output, leftKeys, rightKeys, condition, buildSide, left, right)

override lazy val metrics: Map[String, SQLMetric] =
CometMetricNode.hashJoinMetrics(sparkContext)
Expand Down Expand Up @@ -918,6 +929,7 @@ case class CometSortMergeJoinExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometSortMergeJoinExec =>
this.output == other.output &&
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
Expand All @@ -930,7 +942,7 @@ case class CometSortMergeJoinExec(
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, left, right)
Objects.hashCode(output, leftKeys, rightKeys, condition, left, right)

override lazy val metrics: Map[String, SQLMetric] =
CometMetricNode.sortMergeJoinMetrics(sparkContext)
Expand Down
Loading

0 comments on commit 033fe6f

Please sign in to comment.