diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 7fc7bf6b15dfa..28d59645c0d5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -19,10 +19,9 @@ package org.apache.spark.sql.sources
 
 import scala.language.implicitConversions
 import scala.util.parsing.combinator.syntactical.StandardTokenParsers
-import scala.util.parsing.combinator.{RegexParsers, PackratParsers}
+import scala.util.parsing.combinator.PackratParsers
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.execution.RunnableCommand
@@ -44,18 +43,43 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
     }
   }
 
+  def parseType(input: String): DataType = {
+    phrase(dataType)(new lexical.Scanner(input)) match {
+      case Success(r, x) => r
+      case x =>
+        sys.error(s"Unsupported dataType: $x")
+    }
+  }
+
   protected case class Keyword(str: String)
 
   protected implicit def asParser(k: Keyword): Parser[String] =
     lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
 
   protected val CREATE = Keyword("CREATE")
-  protected val DECIMAL = Keyword("DECIMAL")
   protected val TEMPORARY = Keyword("TEMPORARY")
   protected val TABLE = Keyword("TABLE")
   protected val USING = Keyword("USING")
   protected val OPTIONS = Keyword("OPTIONS")
 
+  // Data types.
+  protected val STRING = Keyword("STRING")
+  protected val FLOAT = Keyword("FLOAT")
+  protected val INT = Keyword("INT")
+  protected val TINYINT = Keyword("TINYINT")
+  protected val SMALLINT = Keyword("SMALLINT")
+  protected val DOUBLE = Keyword("DOUBLE")
+  protected val BIGINT = Keyword("BIGINT")
+  protected val BINARY = Keyword("BINARY")
+  protected val BOOLEAN = Keyword("BOOLEAN")
+  protected val DECIMAL = Keyword("DECIMAL")
+  protected val DATE = Keyword("DATE")
+  protected val TIMESTAMP = Keyword("TIMESTAMP")
+  protected val VARCHAR = Keyword("VARCHAR")
+  protected val ARRAY = Keyword("ARRAY")
+  protected val MAP = Keyword("MAP")
+  protected val STRUCT = Keyword("STRUCT")
+
   // Use reflection to find the reserved words defined in this class.
   protected val reservedWords =
     this.getClass
@@ -77,20 +101,15 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
    *   OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
    */
   protected lazy val createTable: Parser[LogicalPlan] =
-    ( CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
-      case tableName ~ provider ~ opts =>
-        CreateTableUsing(tableName, Seq.empty, provider, opts)
-      }
-    |
+  (
     CREATE ~ TEMPORARY ~ TABLE ~> ident
-      ~ tableCols ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
-      case tableName ~ tableColumns ~ provider ~ opts =>
-        CreateTableUsing(tableName, tableColumns, provider, opts)
+      ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+      case tableName ~ columns ~ provider ~ opts =>
+        val tblColumns = if(columns.isEmpty) Seq.empty else columns.get
+        CreateTableUsing(tableName, tblColumns, provider, opts)
     }
   )
 
-  protected lazy val metastoreTypes = new MetastoreTypes
-
   protected lazy val tableCols: Parser[Seq[StructField]] =  "(" ~> repsep(column, ",") <~ ")"
 
   protected lazy val options: Parser[Map[String, String]] =
@@ -101,96 +120,62 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }
 
   protected lazy val column: Parser[StructField] =
-    ( ident ~ ident ^^ { case name ~ typ =>
-      StructField(name, metastoreTypes.toDataType(typ))
+    ident ~ dataType ^^ { case columnName ~ typ =>
+      StructField(cleanIdentifier(columnName), typ)
     }
-    |
-    ident ~ (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
-      case name ~ precision ~ scale =>
-        StructField(name, DecimalType(precision.toInt, scale.toInt))
-    }
-    )
-}
 
-/**
- * :: DeveloperApi ::
- * Provides a parser for data types.
- */
-@DeveloperApi
-private[sql] class MetastoreTypes extends RegexParsers {
   protected lazy val primitiveType: Parser[DataType] =
-    "string" ^^^ StringType |
-    "float" ^^^ FloatType |
-    "int" ^^^ IntegerType |
-    "tinyint" ^^^ ByteType |
-    "smallint" ^^^ ShortType |
-    "double" ^^^ DoubleType |
-    "bigint" ^^^ LongType |
-    "binary" ^^^ BinaryType |
-    "boolean" ^^^ BooleanType |
-    fixedDecimalType |                     // decimal with precision/scale
-    "decimal" ^^^ DecimalType.Unlimited |  // decimal with no precision/scale
-    "date" ^^^ DateType |
-    "timestamp" ^^^ TimestampType |
-    "varchar\\((\\d+)\\)".r ^^^ StringType
+    STRING ^^^ StringType |
+    BINARY ^^^ BinaryType |
+    BOOLEAN ^^^ BooleanType |
+    TINYINT ^^^ ByteType |
+    SMALLINT ^^^ ShortType |
+    INT ^^^ IntegerType |
+    BIGINT ^^^ LongType |
+    FLOAT ^^^ FloatType |
+    DOUBLE ^^^ DoubleType |
+    fixedDecimalType |                     // decimal with precision/scale
+    DECIMAL ^^^ DecimalType.Unlimited |    // decimal with no precision/scale
+    DATE ^^^ DateType |
+    TIMESTAMP ^^^ TimestampType |
+    VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType
 
   protected lazy val fixedDecimalType: Parser[DataType] =
-    ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
-      case precision ~ scale =>
-        DecimalType(precision.toInt, scale.toInt)
+    (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
     }
 
   protected lazy val arrayType: Parser[DataType] =
-    "array" ~> "<" ~> dataType <~ ">" ^^ {
+    ARRAY ~> "<" ~> dataType <~ ">" ^^ {
       case tpe => ArrayType(tpe)
     }
 
   protected lazy val mapType: Parser[DataType] =
-    "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+    MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
       case t1 ~ _ ~ t2 => MapType(t1, t2)
     }
 
   protected lazy val structField: Parser[StructField] =
-    "[a-zA-Z0-9_]*".r ~ ":" ~ dataType ^^ {
-      case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
+    ident ~ ":" ~ dataType ^^ {
+      case fieldName ~ _ ~ tpe => StructField(cleanIdentifier(fieldName), tpe, nullable = true)
     }
 
   protected lazy val structType: Parser[DataType] =
-    "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ {
+    STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
       case fields => new StructType(fields)
     }
 
   private[sql] lazy val dataType: Parser[DataType] =
     arrayType |
-    mapType |
-    structType |
-    primitiveType
-
-  def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
-    case Success(result, _) => result
-    case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
-  }
-
-  def toMetastoreType(dt: DataType): String = dt match {
-    case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
-    case StructType(fields) =>
-      s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
-    case MapType(keyType, valueType, _) =>
-      s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
-    case StringType => "string"
-    case FloatType => "float"
-    case IntegerType => "int"
-    case ByteType => "tinyint"
-    case ShortType => "smallint"
-    case DoubleType => "double"
-    case LongType => "bigint"
-    case BinaryType => "binary"
-    case BooleanType => "boolean"
-    case DateType => "date"
-    case d: DecimalType => "decimal"
-    case TimestampType => "timestamp"
-    case NullType => "void"
-    case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
+      mapType |
+      structType |
+      primitiveType
+
+  protected val escapedIdentifier = "`([^`]+)`".r
+  /** Strips backticks from ident if present */
+  protected def cleanIdentifier(ident: String): String = ident match {
+    case escapedIdentifier(i) => i
+    case plainIdent => plainIdent
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/NewTableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/NewTableScanSuite.scala
index c8095b336f8e8..b860ca302cee1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/NewTableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/NewTableScanSuite.scala
@@ -93,9 +93,9 @@ class NewTableScanSuite extends DataSourceTest {
   before {
     sql(
       """
-        |CREATE TEMPORARY TABLE oneToTen(stringField string, intField int, longField bigint,
-        |floatField float, doubleField double, shortField smallint, byteField tinyint,
-        |booleanField boolean, decimalField decimal(10,2), dateField date, timestampField timestamp)
+        |CREATE TEMPORARY TABLE oneToTen(stringField stRIng, intField iNt, longField Bigint,
+        |floatField flOat, doubleField doubLE, shortField smaLlint, byteField tinyint,
+        |booleanField boolean, decimalField decimal(10,2), dateField dAte, timestampField tiMestamp)
         |USING org.apache.spark.sql.sources.AllDataTypesScanSource
         |OPTIONS (
         |  From '1',
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index accdaf591b5ea..74584c72cebbf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -38,7 +38,6 @@ import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.sources.MetastoreTypes
 import org.apache.spark.util.Utils
 
 /* Implicit conversions */
@@ -438,7 +437,7 @@ private[hive] case class MetastoreRelation
   implicit class SchemaAttribute(f: FieldSchema) {
     def toAttribute = AttributeReference(
       f.getName,
-      HiveMetastoreTypes.toDataType(f.getType),
+      sqlContext.ddlParser.parseType(f.getType),
       // Since data can be dumped in randomly with no validation, everything is nullable.
       nullable = true
     )(qualifiers = Seq(alias.getOrElse(tableName)))
@@ -459,9 +458,8 @@ private[hive] case class MetastoreRelation
   val columnOrdinals = AttributeMap(attributes.zipWithIndex)
 }
 
-
-object HiveMetastoreTypes extends MetastoreTypes {
-  override def toMetastoreType(dt: DataType): String = dt match {
+object HiveMetastoreTypes {
+  def toMetastoreType(dt: DataType): String = dt match {
     case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
     case StructType(fields) =>
       s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 86535f8dd4f58..041a36f1295ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
 import org.scalatest.FunSuite
 
 import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.sources.DDLParser
 import org.apache.spark.sql.test.ExamplePointUDT
 
 class HiveMetastoreCatalogSuite extends FunSuite {
@@ -27,7 +28,9 @@ class HiveMetastoreCatalogSuite extends FunSuite {
 
   test("struct field should accept underscore in sub-column name") {
     val metastr = "struct<a: int, b_1: string, c: string>"
-    val datatype = HiveMetastoreTypes.toDataType(metastr)
+    val ddlParser = new DDLParser
+
+    val datatype = ddlParser.parseType(metastr)
     assert(datatype.isInstanceOf[StructType])
   }
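
Note on the new parseType entry point: it runs the same dataType production that backs the column list in CREATE TEMPORARY TABLE, so Hive metastore type strings and user-written DDL types now share one grammar, with keywords matched case-insensitively via lexical.allCaseVersions. A minimal sketch of how it could be exercised (illustrative only: the object name and inputs are assumptions, not part of the patch, and the file has to live under the org.apache.spark.sql package tree because DDLParser is private[sql]):

    package org.apache.spark.sql.sources

    import org.apache.spark.sql.catalyst.types._

    // Hypothetical smoke test for the new grammar; not part of this patch.
    object ParseTypeSketch {
      def main(args: Array[String]): Unit = {
        val parser = new DDLParser

        // Keywords are built from lexical.allCaseVersions, so casing should not matter.
        assert(parser.parseType("BigInt") == LongType)
        assert(parser.parseType("DECIMAL(10,2)") == DecimalType(10, 2))

        // Complex types recurse through the same dataType production.
        // Expected to print an ArrayType wrapping a two-field StructType.
        println(parser.parseType("array<struct<name:string,age:int>>"))
      }
    }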
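
For reference, the two DDL shapes the reworked createTable rule now accepts (column list optional, type keywords case-insensitive), written the way NewTableScanSuite drives them through SQLContext.sql. This is a sketch under assumptions: the table names, column names, and the avro provider are illustrative, and the path is just the one from the scaladoc example in ddl.scala.

    import org.apache.spark.sql.SQLContext

    // Sketch only; assumes an avro data source is on the classpath.
    object CreateTableSketch {
      def registerEpisodes(sqlContext: SQLContext): Unit = {
        // Without a schema: the data source infers it, as before this patch.
        sqlContext.sql(
          """
            |CREATE TEMPORARY TABLE episodes
            |USING org.apache.spark.sql.avro
            |OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
          """.stripMargin)

        // With a user-specified schema: `strIng` and `Int` parse the same as STRING and INT.
        sqlContext.sql(
          """
            |CREATE TEMPORARY TABLE typedEpisodes(title strIng, air_date String, doctor Int)
            |USING org.apache.spark.sql.avro
            |OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
          """.stripMargin)
      }
    }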