Fixed a bug with symbol conversion.
rxin committed Jan 27, 2015
1 parent 2ca74db commit 1e5e454
Showing 5 changed files with 22 additions and 17 deletions.
5 changes: 5 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Literal.scala
@@ -69,6 +69,11 @@ object Literal {
    * data type is not supported by SparkSQL.
    */
   protected[sql] def anyToLiteral(literal: Any): Column = {
+    // If the literal is a symbol, convert it into a Column.
+    if (literal.isInstanceOf[Symbol]) {
+      return dsl.symbolToColumn(literal.asInstanceOf[Symbol])
+    }
+
     val literalExpr = literal match {
       case v: Int => LiteralExpr(v, IntegerType)
       case v: Long => LiteralExpr(v, LongType)
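
Note: the patch routes Scala Symbols to dsl.symbolToColumn before the literal match is consulted, so a Symbol now becomes a column reference instead of falling through the match and failing as an unsupported literal type. A minimal standalone sketch of that dispatch, with ColumnRef and LiteralOf as invented stand-ins for Spark's Column and LiteralExpr (which the commit does not change):

    // Standalone sketch; ColumnRef and LiteralOf are illustrative stand-ins,
    // not Spark types. Only the routing decision is modeled here.
    sealed trait Expr
    case class ColumnRef(name: String) extends Expr
    case class LiteralOf(value: Any, dataType: String) extends Expr

    def anyToLiteralSketch(literal: Any): Expr = literal match {
      case s: Symbol => ColumnRef(s.name) // new behavior: Symbol => column reference
      case v: Int    => LiteralOf(v, "IntegerType")
      case v: Long   => LiteralOf(v, "LongType")
      case other     => sys.error(s"Unsupported literal type: ${other.getClass}")
    }

    // anyToLiteralSketch('amount) == ColumnRef("amount")
    // anyToLiteralSketch(42)      == LiteralOf(42, "IntegerType")

The actual patch uses an isInstanceOf test with an early return rather than a new match arm, which leaves the existing match expression untouched; the dispatch is the same.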
@@ -118,19 +118,19 @@ class DslQuerySuite extends QueryTest

     checkAnswer(
       arrayData.orderBy('data.getItem(0).asc),
-      arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
+      arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)

     checkAnswer(
       arrayData.orderBy('data.getItem(0).desc),
-      arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
+      arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)

     checkAnswer(
       arrayData.orderBy('data.getItem(1).asc),
-      arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
+      arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)

     checkAnswer(
       arrayData.orderBy('data.getItem(1).desc),
-      arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
+      arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
   }

   test("limit") {
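
Note: these tests compute their expected answers on the driver with plain Scala collections; _.getAs[Seq[Int]](0)(0) reads column 0 of a row as a Seq[Int] and takes its first element. A self-contained illustration of that expected-answer computation, with rows modeled as tuples and invented values:

    // Rows modeled as plain tuples; the first field plays the role of the
    // `data` array column. The values are made up for illustration.
    val rows = Seq(
      (Seq(3, 1), "r1"),
      (Seq(1, 2), "r2"),
      (Seq(2, 9), "r3"))

    // Equivalent of .sortBy(_.getAs[Seq[Int]](0)(0)): ascending by the first
    // element of the array in column 0; .reverse gives the descending order.
    val ascending  = rows.sortBy(_._1(0))
    val descending = ascending.reverse

    assert(ascending.map(_._2)  == Seq("r2", "r3", "r1"))
    assert(descending.map(_._2) == Seq("r1", "r3", "r2"))
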
20 changes: 10 additions & 10 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -30,11 +30,11 @@ case class TestData(key: Int, value: String)

 object TestData {
   val testData = TestSQLContext.sparkContext.parallelize(
-    (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
+    (1 to 100).map(i => TestData(i, i.toString))).toDF
   testData.registerTempTable("testData")

   val negativeData = TestSQLContext.sparkContext.parallelize(
-    (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD
+    (1 to 100).map(i => TestData(-i, (-i).toString))).toDF
   negativeData.registerTempTable("negativeData")

   case class LargeAndSmallInts(a: Int, b: Int)
@@ -45,7 +45,7 @@ object TestData {
       LargeAndSmallInts(2147483645, 1) ::
       LargeAndSmallInts(2, 2) ::
       LargeAndSmallInts(2147483646, 1) ::
-      LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD
+      LargeAndSmallInts(3, 2) :: Nil).toDF
   largeAndSmallInts.registerTempTable("largeAndSmallInts")

   case class TestData2(a: Int, b: Int)
@@ -56,7 +56,7 @@ object TestData {
       TestData2(2, 1) ::
       TestData2(2, 2) ::
       TestData2(3, 1) ::
-      TestData2(3, 2) :: Nil, 2).toSchemaRDD
+      TestData2(3, 2) :: Nil, 2).toDF
   testData2.registerTempTable("testData2")

   case class DecimalData(a: BigDecimal, b: BigDecimal)
@@ -68,7 +68,7 @@ object TestData {
       DecimalData(2, 1) ::
       DecimalData(2, 2) ::
       DecimalData(3, 1) ::
-      DecimalData(3, 2) :: Nil).toSchemaRDD
+      DecimalData(3, 2) :: Nil).toDF
   decimalData.registerTempTable("decimalData")

   case class BinaryData(a: Array[Byte], b: Int)
@@ -78,14 +78,14 @@ object TestData {
       BinaryData("22".getBytes(), 5) ::
       BinaryData("122".getBytes(), 3) ::
       BinaryData("121".getBytes(), 2) ::
-      BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD
+      BinaryData("123".getBytes(), 4) :: Nil).toDF
   binaryData.registerTempTable("binaryData")

   case class TestData3(a: Int, b: Option[Int])
   val testData3 =
     TestSQLContext.sparkContext.parallelize(
       TestData3(1, None) ::
-      TestData3(2, Some(2)) :: Nil).toSchemaRDD
+      TestData3(2, Some(2)) :: Nil).toDF
   testData3.registerTempTable("testData3")

   val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
@@ -98,7 +98,7 @@ object TestData {
       UpperCaseData(3, "C") ::
       UpperCaseData(4, "D") ::
       UpperCaseData(5, "E") ::
-      UpperCaseData(6, "F") :: Nil).toSchemaRDD
+      UpperCaseData(6, "F") :: Nil).toDF
   upperCaseData.registerTempTable("upperCaseData")

   case class LowerCaseData(n: Int, l: String)
@@ -107,7 +107,7 @@ object TestData {
       LowerCaseData(1, "a") ::
       LowerCaseData(2, "b") ::
       LowerCaseData(3, "c") ::
-      LowerCaseData(4, "d") :: Nil).toSchemaRDD
+      LowerCaseData(4, "d") :: Nil).toDF
   lowerCaseData.registerTempTable("lowerCaseData")

   case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
@@ -201,6 +201,6 @@ object TestData {
     TestSQLContext.sparkContext.parallelize(
       ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
       :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
-      :: Nil).toSchemaRDD
+      :: Nil).toDF
   complexData.registerTempTable("complexData")
 }
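
Note: every fixture in this file follows the same shape, with toSchemaRDD swapped for its DataFrame-era successor toDF: parallelize a local Seq of case-class instances, convert, and register the result under a temp-table name so both SQL and DSL tests can reach it. A sketch of the pattern with a hypothetical fixture (PersonData and "personData" are invented names; the implicit conversion backing toDF comes from the test SQLContext in this era):

    // Hypothetical fixture in the style of TestData; not part of the commit.
    case class PersonData(id: Int, name: String)

    val personData = TestSQLContext.sparkContext.parallelize(
      PersonData(1, "a") ::
      PersonData(2, "b") :: Nil).toDF
    personData.registerTempTable("personData")

    // Once registered, the fixture is visible to SQL by name:
    // sql("SELECT name FROM personData WHERE id = 1")
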
@@ -821,7 +821,7 @@ class JsonSuite extends QueryTest

     val schemaRDD1 = applySchema(rowRDD1, schema1)
     schemaRDD1.registerTempTable("applySchema1")
-    val schemaRDD2 = schemaRDD1.toSchemaRDD
+    val schemaRDD2 = schemaRDD1.toDF
     val result = schemaRDD2.toJSON.collect()
     assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
     assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
@@ -842,7 +842,7 @@ class JsonSuite extends QueryTest

     val schemaRDD3 = applySchema(rowRDD2, schema2)
     schemaRDD3.registerTempTable("applySchema2")
-    val schemaRDD4 = schemaRDD3.toSchemaRDD
+    val schemaRDD4 = schemaRDD3.toDF
     val result2 = schemaRDD4.toJSON.collect()

     assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
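
Note: toJSON, exercised by the assertions above, serializes each row of the converted DataFrame into one compact JSON string keyed by the schema's field names. A hedged round-trip sketch, assuming a Spark 1.3-era sc and sqlContext in scope (jsonRDD and toJSON existed in this period, though exact import paths were shifting during the 1.3 cycle):

    // Round trip: JSON text -> inferred schema -> back to JSON strings.
    val input = sc.parallelize("""{"f1":1,"f2":"A1"}""" :: Nil)
    val df = sqlContext.jsonRDD(input) // infer the schema from the records
    val out = df.toJSON.collect()      // one JSON string per row
    // out(0) == """{"f1":1,"f2":"A1"}""" (field order follows the inferred schema)
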
@@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest
     // Make sure the table has been updated.
     checkAnswer(
       sql("SELECT * FROM createAndInsertTest"),
-      testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq
+      testData.toDF.collect().toSeq ++ testData.toDF.collect().toSeq
     )

     // Now overwrite.
