
Commit

adding test case
scwf committed Dec 30, 2014
1 parent 7787ec7 commit 9bf12f8
Showing 4 changed files with 138 additions and 17 deletions.
15 changes: 0 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -49,21 +49,6 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

// data types
protected val STRING = Keyword("STRING")
protected val DOUBLE = Keyword("DOUBLE")
protected val BOOLEAN = Keyword("BOOLEAN")
protected val FLOAT = Keyword("FLOAT")
protected val INT = Keyword("INT")
protected val TINYINT = Keyword("TINYINT")
protected val SMALLINT = Keyword("SMALLINT")
protected val BIGINT = Keyword("BIGINT")
protected val BINARY = Keyword("BINARY")
protected val DECIMAL = Keyword("DECIMAL")
protected val DATE = Keyword("DATE")
protected val TIMESTAMP = Keyword("TIMESTAMP")
protected val VARCHAR = Keyword("VARCHAR")

protected val CREATE = Keyword("CREATE")
protected val TEMPORARY = Keyword("TEMPORARY")
protected val TABLE = Keyword("TABLE")
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources

import org.apache.spark.annotation.{Experimental, DeveloperApi}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
+import org.apache.spark.sql.{Row, SQLContext, StructType}
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}

/**
@@ -70,7 +70,7 @@ class QueryTest extends PlanTest {
""".stripMargin)
}

-if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { // issues here, sparkAnswer may be GenericRow[]
fail(s"""
|Results do not match for query:
|${rdd.logicalPlan}
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources

import org.apache.spark.sql._
import java.sql.{Timestamp, Date}
import org.apache.spark.sql.execution.RDDConversions

case class PrimaryData(
stringField: String,
intField: Int,
longField: Long,
floatField: Float,
doubleField: Double,
shortField: Short,
byteField: Byte,
booleanField: Boolean,
decimalField: BigDecimal,
date: Date,
timestampField: Timestamp)

class AllDataTypesScanSource extends SchemaRelationProvider {
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String],
schema: Option[StructType] = None): BaseRelation = {
AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
}
}

case class AllDataTypesScan(
from: Int,
to: Int,
userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
extends TableScan {

override def schema = userSpecifiedSchema.get

override def buildScan() = {
val rdd = sqlContext.sparkContext.parallelize(from to to).map { i =>
PrimaryData(
i.toString,
i,
i.toLong,
i.toFloat,
i.toDouble,
i.toShort,
i.toByte,
true,
BigDecimal(i),
new Date(12345),
new Timestamp(12345))
}

RDDConversions.productToRowRdd(rdd, schema)
}

}

class NewTableScanSuite extends DataSourceTest {
import caseInsensisitiveContext._

val records = (1 to 10).map { i =>
Row(
i.toString,
i,
i.toLong,
i.toFloat,
i.toDouble,
i.toShort,
i.toByte,
true,
BigDecimal(i),
new Date(12345),
new Timestamp(12345))
}.toSeq

before {
sql(
"""
|CREATE TEMPORARY TABLE oneToTen(stringField string, intField int, longField bigint,
|floatField float, doubleField double, shortField smallint, byteField tinyint,
|booleanField boolean, decimalField decimal, dateField date, timestampField timestamp)
|USING org.apache.spark.sql.sources.AllDataTypesScanSource
|OPTIONS (
| From '1',
| To '10'
|)
""".stripMargin)
}

sqlTest(
"SELECT * FROM oneToTen",
records)

sqlTest(
"SELECT stringField FROM oneToTen",
(1 to 10).map(i => Row(i.toString)).toSeq)

sqlTest(
"SELECT intField FROM oneToTen WHERE intField < 5",
(1 to 4).map(Row(_)).toSeq)

sqlTest(
"SELECT longField * 2 FROM oneToTen",
(1 to 10).map(i => Row(i * 2.toLong)).toSeq)

sqlTest(
"""SELECT a.floatField, b.floatField FROM oneToTen a JOIN oneToTen b
|ON a.floatField = b.floatField + 1""".stripMargin,
(2 to 10).map(i => Row(i.toFloat, i - 1.toFloat)).toSeq)

sqlTest(
"SELECT distinct(a.dateField) FROM oneToTen a",
Some(new Date(12345)).map(Row(_)).toSeq)

sqlTest(
"SELECT distinct(a.timestampField) FROM oneToTen a",
Some(new Timestamp(12345)).map(Row(_)).toSeq)

}
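
A note on the harness: the suite relies on a sqlTest helper inherited from the surrounding test infrastructure, which is not part of this diff. As a rough, hypothetical sketch (assuming a checkAnswer(query, expectedRows) assertion like the one in QueryTest shown above), such a helper could look like this:

def sqlTest(sqlString: String, expectedAnswer: Seq[Row]): Unit = {
  // Hypothetical sketch only; the real definition lives in the shared test harness.
  // Register a ScalaTest case named after the SQL text and compare the query
  // result against the expected rows via QueryTest's checkAnswer.
  test(sqlString) {
    checkAnswer(sql(sqlString), expectedAnswer)
  }
}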
