Make sql always use spark sql parser, users of hive context can now use hql or hiveql to run queries using HiveQL instead.
marmbrus committed Apr 4, 2014
1 parent d94826b commit fbe4a54
Showing 8 changed files with 62 additions and 49 deletions.
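A note on the resulting API, with a minimal sketch of post-commit usage (the master URL and app name are illustrative; the table and data file come from the HiveFromSpark example below):

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.hive.LocalHiveContext

    val sc = new SparkContext("local", "HiveFromSparkSketch")
    val hiveContext = new LocalHiveContext(sc)
    import hiveContext._

    // HiveQL statements (DDL, LOAD DATA, and other Hive-specific syntax)
    // now go through hql, or equivalently hiveql.
    hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")

    // sql is now always parsed by the Spark SQL parser.
    sql("SELECT 1").collect().foreach(println)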
@@ -33,20 +33,20 @@ object HiveFromSpark {
val hiveContext = new LocalHiveContext(sc)
import hiveContext._

sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")

// Queries are expressed in HiveQL
println("Result of 'SELECT *': ")
sql("SELECT * FROM src").collect.foreach(println)
hql("SELECT * FROM src").collect.foreach(println)

// Aggregation queries are also supported.
- val count = sql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
+ val count = hql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
println(s"COUNT(*): $count")

// The results of SQL queries are themselves RDDs and support all normal RDD functions. The
// items in the RDD are of type Row, which allows you to access each column by ordinal.
- val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")
+ val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")

println("Result of RDD.map:")
val rddAsStrings = rddFromSql.map {
@@ -59,6 +59,6 @@ object HiveFromSpark {

// Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:")
sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
}
}
@@ -67,14 +67,13 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {
class HiveContext(sc: SparkContext) extends SQLContext(sc) {
self =>

- override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql)
- override def executePlan(plan: LogicalPlan): this.QueryExecution =
+ override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }

/**
* Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD.
*/
- def hql(hqlQuery: String): SchemaRDD = {
+ def hiveql(hqlQuery: String): SchemaRDD = {
val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery))
// We force query optimization to happen right away instead of letting it happen lazily like
// when using the query DSL. This is so DDL commands behave as expected. This is only
@@ -83,6 +82,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
result
}

+ /** An alias for `hiveql`. */
+ def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery)

// Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
@transient
protected val outputBuffer = new java.io.OutputStream {
@@ -120,7 +122,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {

/* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
- override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
+ override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
override def lookupRelation(
databaseName: Option[String],
tableName: String,
@@ -132,7 +134,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {

/* An analyzer that uses the Hive metastore. */
@transient
- override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
+ override protected[sql] lazy val analyzer =
+   new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)

/**
* Runs the specified SQL query using Hive.
@@ -214,14 +217,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}

@transient
- override val planner = hivePlanner
+ override protected[sql] val planner = hivePlanner

@transient
protected lazy val emptyResult =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)

/** Extends QueryExecution with hive specific features. */
- abstract class QueryExecution extends super.QueryExecution {
+ protected[sql] abstract class QueryExecution extends super.QueryExecution {
// TODO: Create mixin for the analyzer instead of overriding things here.
override lazy val optimizedPlan =
optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
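As the hunks above show, hiveql parses its argument with HiveQl, returns a SchemaRDD, and forces optimization to run right away so that DDL commands behave as expected; hql is a plain alias. A small sketch of both entry points, assuming a hiveContext is already in scope with its members imported:

    // Both spellings share one code path: hql forwards to hiveql.
    val longForm = hiveql("SELECT key, value FROM src")
    val shortForm = hql("SELECT key, value FROM src")

    // Since the plan is forced eagerly, this CREATE TABLE has taken effect
    // by the time the call returns, even if the SchemaRDD is never used.
    hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")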
12 changes: 6 additions & 6 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -110,10 +110,10 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {

val describedTable = "DESCRIBE (\\w+)".r

- class SqlQueryExecution(sql: String) extends this.QueryExecution {
-   lazy val logical = HiveQl.parseSql(sql)
-   def hiveExec() = runSqlHive(sql)
-   override def toString = sql + "\n" + super.toString
+ protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
+   lazy val logical = HiveQl.parseSql(hql)
+   def hiveExec() = runSqlHive(hql)
+   override def toString = hql + "\n" + super.toString
}

/**
@@ -140,8 +140,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {

case class TestTable(name: String, commands: (()=>Unit)*)

- implicit class SqlCmd(sql: String) {
-   def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit
+ protected[hive] implicit class SqlCmd(sql: String) {
+   def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit
}

/**
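With the rename, test code that used to wrap a query string in SqlQueryExecution now constructs a HiveQLQueryExecution, which makes explicit that the string is parsed as HiveQL. A sketch of how it is driven elsewhere in this commit (the query string is illustrative):

    // Parses the string with HiveQl and exposes the plan and its results.
    val exec = new TestHive.HiveQLQueryExecution("SELECT key FROM src")
    exec.logical        // the parsed logical plan
    exec.stringResult() // execute through catalyst; rows rendered as strings
    exec.hiveExec()     // run the same string directly through Hive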
@@ -125,7 +125,7 @@ abstract class HiveComparisonTest
}

protected def prepareAnswer(
- hiveQuery: TestHive.type#SqlQueryExecution,
+ hiveQuery: TestHive.type#HiveQLQueryExecution,
answer: Seq[String]): Seq[String] = {
val orderedAnswer = hiveQuery.logical match {
// Clean out non-deterministic time schema info.
@@ -227,7 +227,7 @@ abstract class HiveComparisonTest

try {
// MINOR HACK: You must run a query before calling reset the first time.
TestHive.sql("SHOW TABLES")
TestHive.hql("SHOW TABLES")
if (reset) { TestHive.reset() }

val hiveCacheFiles = queryList.zipWithIndex.map {
@@ -256,7 +256,7 @@ abstract class HiveComparisonTest
hiveCachedResults
} else {

- val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_))
+ val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_))
// Make sure we can at least parse everything before attempting hive execution.
hiveQueries.foreach(_.logical)
val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map {
@@ -302,7 +302,7 @@ abstract class HiveComparisonTest

// Run w/ catalyst
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
- val query = new TestHive.SqlQueryExecution(queryString)
+ val query = new TestHive.HiveQLQueryExecution(queryString)
try { (query, prepareAnswer(query, query.stringResult())) } catch {
case e: Exception =>
val errorMessage =
@@ -359,7 +359,7 @@ abstract class HiveComparisonTest
// When we encounter an error we check to see if the environment is still okay by running a simple query.
// If this fails then we halt testing since something must have gone seriously wrong.
try {
new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult()
new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult()
TestHive.runSqlHive("SELECT key FROM src")
} catch {
case e: Exception =>
@@ -23,6 +23,16 @@ import org.apache.spark.sql.hive.TestHive._
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
*/
class HiveQuerySuite extends HiveComparisonTest {

test("Query expressed in SQL") {
assert(sql("SELECT 1").collect() === Array(Seq(1)))
}

test("Query expressed in HiveQL") {
hql("FROM src SELECT key").collect()
hiveql("FROM src SELECT key").collect()
}

createQueryTest("Simple Average",
"SELECT AVG(key) FROM src")

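The two new tests at the top of this suite pin down the split: plain SELECT statements go through the Spark SQL parser via sql, while the FROM-first form used there is HiveQL syntax and is exercised through hql or hiveql:

    sql("SELECT 1").collect()            // Spark SQL parser
    hql("FROM src SELECT key").collect() // HiveQl; hiveql(...) is equivalent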
@@ -56,7 +56,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2)) :: Nil)
.registerAsTable("caseSensitivityTest")

sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
}

/**
@@ -136,7 +136,7 @@ class PruningSuite extends HiveComparisonTest {
expectedScannedColumns: Seq[String],
expectedPartValues: Seq[Seq[String]]) = {
test(s"$testCaseName - pruning test") {
- val plan = new TestHive.SqlQueryExecution(sql).executedPlan
+ val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan
val actualOutputColumns = plan.output.map(_.name)
val (actualScannedColumns, actualPartValues) = plan.collect {
case p @ HiveTableScan(columns, relation, _) =>
@@ -57,34 +57,34 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
}

test("SELECT on Parquet table") {
val rdd = sql("SELECT * FROM testsource").collect()
val rdd = hql("SELECT * FROM testsource").collect()
assert(rdd != null)
assert(rdd.forall(_.size == 6))
}

test("Simple column projection + filter on Parquet table") {
val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
assert(rdd.size === 5, "Filter returned incorrect number of rows")
assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value")
}

test("Converting Hive to Parquet Table via saveAsParquetFile") {
sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
- val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0))
- val rddTwo = sql("SELECT * from ptable").collect().sortBy(_.getInt(0))
+ val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0))
+ val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0))
compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String"))
}

test("INSERT OVERWRITE TABLE Parquet table") {
sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
// let's do three overwrites for good measure
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
val rddCopy = sql("SELECT * FROM ptable").collect()
val rddOrig = sql("SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
val rddCopy = hql("SELECT * FROM ptable").collect()
val rddOrig = hql("SELECT * FROM testsource").collect()
assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??")
compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames)
}
@@ -93,13 +93,13 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmp")
val rddCopy =
sql("INSERT INTO TABLE tmp SELECT * FROM src")
hql("INSERT INTO TABLE tmp SELECT * FROM src")
.collect()
.sortBy[Int](_.apply(0) match {
case x: Int => x
case _ => 0
})
val rddOrig = sql("SELECT * FROM src")
val rddOrig = hql("SELECT * FROM src")
.collect()
.sortBy(_.getInt(0))
compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String"))
@@ -108,22 +108,22 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
test("Appending to Parquet table") {
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmpnew")
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
val rddCopies = sql("SELECT * FROM tmpnew").collect()
val rddOrig = sql("SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
val rddCopies = hql("SELECT * FROM tmpnew").collect()
val rddOrig = hql("SELECT * FROM src").collect()
assert(rddCopies.size === 3 * rddOrig.size, "number of copied rows via INSERT INTO did not match correct number")
}

test("Appending to and then overwriting Parquet table") {
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmp")
sql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
sql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
val rddCopies = sql("SELECT * FROM tmp").collect()
val rddOrig = sql("SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
hql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
val rddCopies = hql("SELECT * FROM tmp").collect()
val rddOrig = hql("SELECT * FROM src").collect()
assert(rddCopies.size === rddOrig.size, "INSERT OVERWRITE did not actually overwrite")
}
