Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add output to Comet operators equal and hashCode #902

Merged
merged 2 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ case class CometWindowExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometWindowExec =>
this.output == other.output &&
this.windowExpression == other.windowExpression && this.child == other.child &&
this.partitionSpec == other.partitionSpec && this.orderSpec == other.orderSpec &&
this.serializedPlanOpt == other.serializedPlanOpt
Expand All @@ -74,5 +75,5 @@ case class CometWindowExec(
}

override def hashCode(): Int =
Objects.hashCode(windowExpression, partitionSpec, orderSpec, child)
Objects.hashCode(output, windowExpression, partitionSpec, orderSpec, child)
}
36 changes: 24 additions & 12 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ case class CometProjectExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometProjectExec =>
this.output == other.output &&
this.projectList == other.projectList &&
this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
Expand All @@ -436,7 +437,7 @@ case class CometProjectExec(
}
}

override def hashCode(): Int = Objects.hashCode(projectList, child)
override def hashCode(): Int = Objects.hashCode(output, projectList, child)

override protected def outputExpressions: Seq[NamedExpression] = projectList
}
Expand All @@ -462,14 +463,15 @@ case class CometFilterExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometFilterExec =>
this.output == other.output &&
this.condition == other.condition && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(condition, child)
override def hashCode(): Int = Objects.hashCode(output, condition, child)

override def verboseStringWithOperatorId(): String = {
s"""
Expand Down Expand Up @@ -501,14 +503,15 @@ case class CometSortExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometSortExec =>
this.output == other.output &&
this.sortOrder == other.sortOrder && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(sortOrder, child)
override def hashCode(): Int = Objects.hashCode(output, sortOrder, child)

override lazy val metrics: Map[String, SQLMetric] =
CometMetricNode.baselineMetrics(sparkContext) ++
Expand Down Expand Up @@ -539,14 +542,15 @@ case class CometLocalLimitExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometLocalLimitExec =>
this.output == other.output &&
this.limit == other.limit && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(limit: java.lang.Integer, child)
override def hashCode(): Int = Objects.hashCode(output, limit: java.lang.Integer, child)
}

case class CometGlobalLimitExec(
Expand All @@ -569,14 +573,15 @@ case class CometGlobalLimitExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometGlobalLimitExec =>
this.output == other.output &&
this.limit == other.limit && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(limit: java.lang.Integer, child)
override def hashCode(): Int = Objects.hashCode(output, limit: java.lang.Integer, child)
}

case class CometExpandExec(
Expand All @@ -599,14 +604,15 @@ case class CometExpandExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometExpandExec =>
this.output == other.output &&
this.projections == other.projections && this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int = Objects.hashCode(projections, child)
override def hashCode(): Int = Objects.hashCode(output, projections, child)

// TODO: support native Expand metrics
override lazy val metrics: Map[String, SQLMetric] = Map.empty
Expand Down Expand Up @@ -638,12 +644,14 @@ case class CometUnionExec(

override def equals(obj: Any): Boolean = {
obj match {
case other: CometUnionExec => this.children == other.children
case other: CometUnionExec =>
this.output == other.output &&
this.children == other.children
case _ => false
}
}

override def hashCode(): Int = Objects.hashCode(children)
override def hashCode(): Int = Objects.hashCode(output, children)
}

case class CometHashAggregateExec(
Expand Down Expand Up @@ -677,6 +685,7 @@ case class CometHashAggregateExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometHashAggregateExec =>
this.output == other.output &&
this.groupingExpressions == other.groupingExpressions &&
this.aggregateExpressions == other.aggregateExpressions &&
this.input == other.input &&
Expand All @@ -689,7 +698,7 @@ case class CometHashAggregateExec(
}

override def hashCode(): Int =
Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)
Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child)

override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
}
Expand Down Expand Up @@ -729,6 +738,7 @@ case class CometHashJoinExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometHashJoinExec =>
this.output == other.output &&
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
Expand All @@ -742,7 +752,7 @@ case class CometHashJoinExec(
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
Objects.hashCode(output, leftKeys, rightKeys, condition, buildSide, left, right)

override lazy val metrics: Map[String, SQLMetric] =
CometMetricNode.hashJoinMetrics(sparkContext)
Expand Down Expand Up @@ -865,6 +875,7 @@ case class CometBroadcastHashJoinExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastHashJoinExec =>
this.output == other.output &&
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
Expand All @@ -878,7 +889,7 @@ case class CometBroadcastHashJoinExec(
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
Objects.hashCode(output, leftKeys, rightKeys, condition, buildSide, left, right)

override lazy val metrics: Map[String, SQLMetric] =
CometMetricNode.hashJoinMetrics(sparkContext)
Expand Down Expand Up @@ -918,6 +929,7 @@ case class CometSortMergeJoinExec(
override def equals(obj: Any): Boolean = {
obj match {
case other: CometSortMergeJoinExec =>
this.output == other.output &&
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
Expand All @@ -930,7 +942,7 @@ case class CometSortMergeJoinExec(
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, left, right)
Objects.hashCode(output, leftKeys, rightKeys, condition, left, right)

override lazy val metrics: Map[String, SQLMetric] =
CometMetricNode.sortMergeJoinMetrics(sparkContext)
Expand Down
Loading
Loading