Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
AngersZhuuuu committed Jul 22, 2020
1 parent 7916d72 commit a769aa7
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,27 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
selectClause.hints.asScala.foldRight(withWindow)(withHints)
}

// Decode and input/output format.
type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])

protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): Format = {
// TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
// expects a seq of pairs in which the old parsers' token names are used as keys.
// Transforming the result of visitRowFormatDelimited would be quite a bit messier than
// retrieving the key value pairs ourselves.
def entry(key: String, value: Token): Seq[(String, String)] = {
Option(value).map(t => key -> t.getText).toSeq
}

val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++
entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++
entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++
entry("TOK_TABLEROWFORMATLINES", ctx.linesSeparatedBy) ++
entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs)

(entries, None, Seq.empty, None)
}

/**
* Create a [[ScriptInputOutputSchema]].
*/
Expand All @@ -754,26 +775,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
outRowFormat: RowFormatContext,
recordReader: Token,
schemaLess: Boolean): ScriptInputOutputSchema = {
// Decode and input/output format.
type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])

def format(fmt: RowFormatContext): Format = fmt match {
case c: RowFormatDelimitedContext =>
// TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
// expects a seq of pairs in which the old parsers' token names are used as keys.
// Transforming the result of visitRowFormatDelimited would be quite a bit messier than
// retrieving the key value pairs ourselves.
def entry(key: String, value: Token): Seq[(String, String)] = {
Option(value).map(t => key -> t.getText).toSeq
}

val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++
entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++
entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++
entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++
entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs)

(entries, None, Seq.empty, None)
getRowFormatDelimited(c)

case c: RowFormatSerdeContext =>
throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{IntegerType, LongType, StringType}

/**
* Parser test cases for rules defined in [[CatalystSqlParser]] / [[AstBuilder]].
Expand Down Expand Up @@ -1031,4 +1031,96 @@ class PlanParserSuite extends AnalysisTest {
assertEqual("select a, b from db.c;;;", table("db", "c").select('a, 'b))
assertEqual("select a, b from db.c; ;; ;", table("db", "c").select('a, 'b))
}

test("SPARK-32106: TRANSFORM without serde") {
// verify schema less
assertEqual(
"""
|SELECT TRANSFORM(a, b, c)
|USING 'cat'
|FROM testData
""".stripMargin,
ScriptTransformation(
Seq('a, 'b, 'c),
"cat",
Seq(AttributeReference("key", StringType)(),
AttributeReference("value", StringType)()),
UnresolvedRelation(TableIdentifier("testData")),
ScriptInputOutputSchema(List.empty, List.empty, None, None,
List.empty, List.empty, None, None, true))
)

// verify without output schema
assertEqual(
"""
|SELECT TRANSFORM(a, b, c)
|USING 'cat' AS (a, b, c)
|FROM testData
""".stripMargin,
ScriptTransformation(
Seq('a, 'b, 'c),
"cat",
Seq(AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", StringType)()),
UnresolvedRelation(TableIdentifier("testData")),
ScriptInputOutputSchema(List.empty, List.empty, None, None,
List.empty, List.empty, None, None, false)))

// verify with output schema
assertEqual(
"""
|SELECT TRANSFORM(a, b, c)
|USING 'cat' AS (a int, b string, c long)
|FROM testData
""".stripMargin,
ScriptTransformation(
Seq('a, 'b, 'c),
"cat",
Seq(AttributeReference("a", IntegerType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", LongType)()),
UnresolvedRelation(TableIdentifier("testData")),
ScriptInputOutputSchema(List.empty, List.empty, None, None,
List.empty, List.empty, None, None, false)))

// verify with ROW FORMAT DELIMETED
assertEqual(
"""
|SELECT TRANSFORM(a, b, c)
|ROW FORMAT DELIMITED
|FIELDS TERMINATED BY '\t'
|COLLECTION ITEMS TERMINATED BY '\u0002'
|MAP KEYS TERMINATED BY '\u0003'
|LINES TERMINATED BY '\n'
|NULL DEFINED AS 'null'
|USING 'cat' AS (a, b, c)
|ROW FORMAT DELIMITED
|FIELDS TERMINATED BY '\t'
|COLLECTION ITEMS TERMINATED BY '\u0004'
|MAP KEYS TERMINATED BY '\u0005'
|LINES TERMINATED BY '\n'
|NULL DEFINED AS 'NULL'
|FROM testData
""".stripMargin,
ScriptTransformation(
Seq('a, 'b, 'c),
"cat",
Seq(AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", StringType)()),
UnresolvedRelation(TableIdentifier("testData")),
ScriptInputOutputSchema(
Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"),
("TOK_TABLEROWFORMATCOLLITEMS", "'\u0002'"),
("TOK_TABLEROWFORMATMAPKEYS", "'\u0003'"),
("TOK_TABLEROWFORMATLINES", "'\\n'"),
("TOK_TABLEROWFORMATNULL", "'null'")),
Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"),
("TOK_TABLEROWFORMATCOLLITEMS", "'\u0004'"),
("TOK_TABLEROWFORMATMAPKEYS", "'\u0005'"),
("TOK_TABLEROWFORMATLINES", "'\\n'"),
("TOK_TABLEROWFORMATNULL", "'NULL'")), None, None,
List.empty, List.empty, None, None, false)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
prevLine: String =>
new GenericInternalRow(
prevLine.split(outputRowFormat)
.zip(fieldWriters)
.zip(outputFieldWriters)
.map { case (data, writer) => writer(data) })
} else {
// In schema less mode, hive default serde will choose first two output column as output
Expand Down Expand Up @@ -182,7 +182,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
}
}

private lazy val fieldWriters: Seq[String => Any] = output.map { attr =>
private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
attr.dataType match {
case StringType => wrapperConvertException(data => data, converter)
Expand Down Expand Up @@ -218,10 +218,9 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
converter)
case udt: UserDefinedType[_] =>
wrapperConvertException(data => udt.deserialize(data), converter)
case ArrayType(_, _) | MapType(_, _, _) | StructType(_) =>
throw new SparkException("TRANSFORM without serde don't support" +
" ArrayType/MapType/StructType as output data type")
case _ => wrapperConvertException(data => data, converter)
case dt =>
throw new SparkException("TRANSFORM without serde does not support " +
s"${dt.getClass.getSimpleName} as output data type")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,30 +689,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
recordReader,
schemaLess)
} else {

// Decode and input/output format.
type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])

def format(
fmt: RowFormatContext,
configKey: String,
defaultConfigValue: String): Format = fmt match {
case c: RowFormatDelimitedContext =>
// TODO we should use visitRowFormatDelimited function here. However HiveScriptIOSchema
// expects a seq of pairs in which the old parsers' token names are used as keys.
// Transforming the result of visitRowFormatDelimited would be quite a bit messier than
// retrieving the key value pairs ourselves.
def entry(key: String, value: Token): Seq[(String, String)] = {
Option(value).map(t => key -> t.getText).toSeq
}

val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++
entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++
entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++
entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++
entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs)

(entries, None, Seq.empty, None)
getRowFormatDelimited(c)

case c: RowFormatSerdeContext =>
// Use a serde format.
Expand Down
2 changes: 1 addition & 1 deletion sql/core/src/test/resources/sql-tests/inputs/transform.sql
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM (
FROM t
) tmp;

-- handle schema less
-- SPARK-32388 handle schema less
SELECT TRANSFORM(a)
USING 'cat'
FROM t;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
}
}


test("SPARK-32106: TRANSFORM should respect DATETIME_JAVA8API_ENABLED (no serde)") {
assume(TestUtils.testCommandAvailable("python"))
Array(false, true).foreach { java8AapiEnable =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
}
}

test("TRANSFORM don't support ArrayType/MapType/StructType as output data type (no serde)") {
test("TRANSFORM doesn't support ArrayType/MapType/StructType as output data type (no serde)") {
assume(TestUtils.testCommandAvailable("/bin/bash"))
// check for ArrayType
val e1 = intercept[SparkException] {
Expand All @@ -73,8 +73,8 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
|FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c)
""".stripMargin).collect()
}.getMessage
assert(e1.contains("TRANSFORM without serde don't support" +
" ArrayType/MapType/StructType as output data type"))
assert(e1.contains("TRANSFORM without serde does not support" +
" ArrayType as output data type"))

// check for MapType
val e2 = intercept[SparkException] {
Expand All @@ -85,8 +85,8 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
|FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c)
""".stripMargin).collect()
}.getMessage
assert(e2.contains("TRANSFORM without serde don't support" +
" ArrayType/MapType/StructType as output data type"))
assert(e2.contains("TRANSFORM without serde does not support" +
" MapType as output data type"))

// check for StructType
val e3 = intercept[SparkException] {
Expand All @@ -97,7 +97,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
|FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c)
""".stripMargin).collect()
}.getMessage
assert(e3.contains("TRANSFORM without serde don't support" +
" ArrayType/MapType/StructType as output data type"))
assert(e3.contains("TRANSFORM without serde does not support" +
" StructType as output data type"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
)
}

private val serdeIOSchema: ScriptTransformationIOSchema = {
private val hiveIOSchema: ScriptTransformationIOSchema = {
defaultIOSchema.copy(
inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
Expand All @@ -71,7 +71,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
script = "cat",
output = Seq(AttributeReference("a", StringType)()),
child = child,
ioschema = serdeIOSchema
ioschema = hiveIOSchema
),
rowsDf.collect())
assert(uncaughtExceptionHandler.exception.isEmpty)
Expand All @@ -89,7 +89,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
script = "cat",
output = Seq(AttributeReference("a", StringType)()),
child = ExceptionInjectingOperator(child),
ioschema = serdeIOSchema
ioschema = hiveIOSchema
),
rowsDf.collect())
}
Expand All @@ -110,7 +110,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
script = "some_non_existent_command",
output = Seq(AttributeReference("a", StringType)()),
child = rowsDf.queryExecution.sparkPlan,
ioschema = serdeIOSchema)
ioschema = hiveIOSchema)
SparkPlanTest.executePlan(plan, hiveContext)
}
assert(e.getMessage.contains("Subprocess exited with status"))
Expand All @@ -131,7 +131,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
script = "cat",
output = Seq(AttributeReference("name", StringType)()),
child = child,
ioschema = serdeIOSchema
ioschema = hiveIOSchema
),
rowsDf.select("name").collect())
assert(uncaughtExceptionHandler.exception.isEmpty)
Expand All @@ -148,7 +148,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
script = "some_non_existent_command",
output = Seq(AttributeReference("a", StringType)()),
child = rowsDf.queryExecution.sparkPlan,
ioschema = serdeIOSchema)
ioschema = hiveIOSchema)
SparkPlanTest.executePlan(plan, hiveContext)
}
assert(e.getMessage.contains("Subprocess exited with status"))
Expand Down Expand Up @@ -212,7 +212,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
StructField("col1", IntegerType, false),
StructField("col2", StringType, true))))()),
child = child,
ioschema = serdeIOSchema
ioschema = hiveIOSchema
),
df.select('c, 'd, 'e).collect())
}
Expand Down Expand Up @@ -256,7 +256,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
AttributeReference("a", IntegerType)(),
AttributeReference("b", CalendarIntervalType)()),
child = df.queryExecution.sparkPlan,
ioschema = serdeIOSchema)
ioschema = hiveIOSchema)
SparkPlanTest.executePlan(plan, hiveContext)
}
assert(e1.getMessage.contains("scala.MatchError: CalendarIntervalType"))
Expand All @@ -269,7 +269,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
AttributeReference("a", IntegerType)(),
AttributeReference("c", new TestUDT.MyDenseVectorUDT)()),
child = df.queryExecution.sparkPlan,
ioschema = serdeIOSchema)
ioschema = hiveIOSchema)
SparkPlanTest.executePlan(plan, hiveContext)
}
assert(e2.getMessage.contains(
Expand Down

0 comments on commit a769aa7

Please sign in to comment.