Skip to content

Commit

Permalink
[SPARK-40823][CONNECT] Connect Proto should carry unparsed identifiers
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Before this PR, connect proto defines the UnresolvedRelation as
```
  message NamedTable {
    repeated string parts = 1;
  }
```
which asked clients to provide multiple name parts of the relation. For example, user could offer `a.b.c.d` as the table name. However, this actually asks clients to implement `CatalystSqlParser.parseMultipartIdentifier`.  The problem to clients is they cannot access the catalyst parser thus needs to re-invent the wheel. Another problem is clients might not be able to implement the parsing correctly.

This PR proposes to change the proto to
```
  message NamedTable {
    string unparsed_identifier = 1;
}
```
which only needs clients to provide the user's table name. Server side as it can access catalyst, can parse the identifier.

This PR also changes on the Column identifier accordingly.

### Why are the changes needed?

This proposal reduced the required work on the client side.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing UT

Closes #38264 from amaliujia/unparsed_identifier.

Authored-by: Rui Wang <rui.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
amaliujia authored and cloud-fan committed Oct 19, 2022
1 parent 646d716 commit 14d8604
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ message Expression {
// An unresolved attribute that is not explicitly bound to a specific column, but the column
// is resolved during analysis by name.
message UnresolvedAttribute {
repeated string parts = 1;
string unparsed_identifier = 1;
}

// An unresolved function is not explicitly bound to one explicit function, but the function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ message Read {
}

message NamedTable {
repeated string parts = 1;
string unparsed_identifier = 1;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import scala.language.implicitConversions
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.connect.planner.DataTypeProtoConverter

/**
Expand All @@ -33,16 +32,13 @@ package object dsl {

object expressions { // scalastyle:ignore
implicit class DslString(val s: String) {
val identifier = CatalystSqlParser.parseMultipartIdentifier(s)

def protoAttr: proto.Expression =
proto.Expression
.newBuilder()
.setUnresolvedAttribute(
proto.Expression.UnresolvedAttribute
.newBuilder()
.addAllParts(identifier.asJava)
.build())
.setUnparsedIdentifier(s))
.build()

def struct(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample, SubqueryAlias}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -102,7 +103,9 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
common: Option[proto.RelationCommon]): LogicalPlan = {
val baseRelation = rel.getReadTypeCase match {
case proto.Read.ReadTypeCase.NAMED_TABLE =>
val child = UnresolvedRelation(rel.getNamedTable.getPartsList.asScala.toSeq)
val multipartIdentifier =
CatalystSqlParser.parseMultipartIdentifier(rel.getNamedTable.getUnparsedIdentifier)
val child = UnresolvedRelation(multipartIdentifier)
if (common.nonEmpty && common.get.getAlias.nonEmpty) {
SubqueryAlias(identifier = common.get.getAlias, child = child)
} else {
Expand Down Expand Up @@ -139,7 +142,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
}

private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
UnresolvedAttribute(exp.getUnresolvedAttribute.getPartsList.asScala.toSeq)
UnresolvedAttribute(exp.getUnresolvedAttribute.getUnparsedIdentifier)
}

private def transformExpression(exp: proto.Expression): Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ trait SparkConnectPlanTest {
.setRead(
proto.Read
.newBuilder()
.setNamedTable(proto.Read.NamedTable.newBuilder().addParts("table"))
.setNamedTable(proto.Read.NamedTable.newBuilder().setUnparsedIdentifier("table"))
.build())
.build()

Expand Down Expand Up @@ -100,7 +100,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
// Invalid read without Table name.
intercept[InvalidPlanInput](transform(proto.Relation.newBuilder.setRead(read).build()))
val readWithTable = read.toBuilder
.setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build())
.setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("name").build())
.build()
val res = transform(proto.Relation.newBuilder.setRead(readWithTable).build())
assert(res !== null)
Expand All @@ -110,7 +110,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
test("Simple Project") {
val readWithTable = proto.Read
.newBuilder()
.setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build())
.setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("name").build())
.build()
val project =
proto.Project
Expand Down Expand Up @@ -139,10 +139,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
.newBuilder()
.setNulls(proto.Sort.SortNulls.SORT_NULLS_LAST)
.setDirection(proto.Sort.SortDirection.SORT_DIRECTION_DESCENDING)
.setExpression(proto.Expression.newBuilder
.setUnresolvedAttribute(
proto.Expression.UnresolvedAttribute.newBuilder.addAllParts(Seq("col").asJava).build())
.build())
.setExpression(
proto.Expression.newBuilder
.setUnresolvedAttribute(
proto.Expression.UnresolvedAttribute.newBuilder.setUnparsedIdentifier("col").build())
.build())
.build()

val res = transform(
Expand Down Expand Up @@ -192,7 +193,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
val unresolvedAttribute = proto.Expression
.newBuilder()
.setUnresolvedAttribute(
proto.Expression.UnresolvedAttribute.newBuilder().addAllParts(Seq("left").asJava).build())
proto.Expression.UnresolvedAttribute.newBuilder().setUnparsedIdentifier("left").build())
.build()

val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction(
Expand Down Expand Up @@ -236,7 +237,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
val unresolvedAttribute = proto.Expression
.newBuilder()
.setUnresolvedAttribute(
proto.Expression.UnresolvedAttribute.newBuilder().addAllParts(Seq("left").asJava).build())
proto.Expression.UnresolvedAttribute.newBuilder().setUnparsedIdentifier("left").build())
.build()

val agg = proto.Aggregate.newBuilder
Expand Down
14 changes: 7 additions & 7 deletions python/pyspark/sql/connect/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from typing import List, cast, get_args, TYPE_CHECKING, Optional, Callable, Any
from typing import cast, get_args, TYPE_CHECKING, Optional, Callable, Any


import pyspark.sql.connect.proto as proto
Expand Down Expand Up @@ -121,20 +121,20 @@ class ColumnRef(Expression):

@classmethod
def from_qualified_name(cls, name: str) -> "ColumnRef":
return ColumnRef(*name.split("."))
return ColumnRef(name)

def __init__(self, *parts: str) -> None:
def __init__(self, name: str) -> None:
super().__init__()
self._parts: List[str] = list(parts)
self._unparsed_identifier: str = name

def name(self) -> str:
"""Returns the qualified name of the column reference."""
return ".".join(self._parts)
return self._unparsed_identifier

def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
expr.unresolved_attribute.parts.extend(self._parts)
expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
return expr

def desc(self) -> "SortOrder":
Expand All @@ -144,7 +144,7 @@ def asc(self) -> "SortOrder":
return SortOrder(self, ascending=True)

def __str__(self) -> str:
return f"Column({'.'.join(self._parts)})"
return f"Column({self._unparsed_identifier})"


class SortOrder(Expression):
Expand Down
10 changes: 4 additions & 6 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ class LogicalPlan(object):
def __init__(self, child: Optional["LogicalPlan"]) -> None:
self._child = child

def unresolved_attr(self, *colNames: str) -> proto.Expression:
def unresolved_attr(self, colName: str) -> proto.Expression:
"""Creates an unresolved attribute from a column name."""
exp = proto.Expression()
exp.unresolved_attribute.parts.extend(list(colNames))
exp.unresolved_attribute.unparsed_identifier = colName
return exp

def to_attr_or_expression(
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(self, table_name: str) -> None:

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
plan = proto.Relation()
plan.read.named_table.parts.extend(self.table_name.split("."))
plan.read.named_table.unparsed_identifier = self.table_name
return plan

def print(self, indent: int = 0) -> str:
Expand Down Expand Up @@ -165,9 +165,7 @@ def withAlias(self, alias: str) -> LogicalPlan:
def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
assert self._child is not None
proj_exprs = [
c.to_plan(session)
if isinstance(c, Expression)
else self.unresolved_attr(*(c.split(".")))
c.to_plan(session) if isinstance(c, Expression) else self.unresolved_attr(c)
for c in self._raw_columns
]
common = proto.RelationCommon()
Expand Down
Loading

0 comments on commit 14d8604

Please sign in to comment.