Remove Option from createRelation.
yhuai committed Jan 9, 2015
1 parent 65e9c73 commit 38f634e
Showing 4 changed files with 41 additions and 17 deletions.
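The net API change, reconstructed from the hunks below (RelationProvider itself predates this commit and is shown only for contrast):

// Before this commit: schema-aware sources received the schema as an Option.
trait SchemaRelationProvider {
  def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: Option[StructType]): BaseRelation
}

// After: the schema is required. Sources that infer their own schema
// implement the pre-existing RelationProvider instead, or in addition.
trait RelationProvider {
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
}

trait SchemaRelationProvider {
  def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation
}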
@@ -21,16 +21,27 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._
 
-private[sql] class DefaultSource extends SchemaRelationProvider {
-  /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+  /** Returns a new base relation with the parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+    JSONRelation(fileName, samplingRatio, None)(sqlContext)
+  }
+
+  /** Returns a new base relation with the given schema and parameters. */
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
-    JSONRelation(fileName, samplingRatio, schema)(sqlContext)
+    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
   }
 }
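A quick sketch of how the two overloads of the JSON source now divide the work. The path and schema are illustrative, a SQLContext named sqlContext is assumed in scope, and since DefaultSource is private[sql] this would only compile from within the org.apache.spark.sql package:

import org.apache.spark.sql.catalyst.types.{StringType, StructField, StructType}

val source = new DefaultSource

// No schema supplied: the RelationProvider overload passes None and
// JSONRelation infers a schema by sampling the data.
val inferred = source.createRelation(sqlContext, Map("path" -> "people.json"))

// Schema supplied: callers now pass a plain StructType; the Option wrapping
// happens only at the JSONRelation boundary.
val declared = source.createRelation(
  sqlContext,
  Map("path" -> "people.json"),
  StructType(StructField("name", StringType, nullable = true) :: Nil))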

sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala (22 additions, 9 deletions)
@@ -190,15 +190,28 @@ private[sql] case class CreateTableUsing(
             sys.error(s"Failed to load class for data source: $provider")
         }
     }
-    val relation = clazz.newInstance match {
-      case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options))
-      case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options), userSpecifiedSchema)
+
+    val relation = userSpecifiedSchema match {
+      case Some(schema: StructType) => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+        }
+      }
+      case None => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options))
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+        }
+      }
     }
 
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
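In DDL terms, the dispatch now hinges on whether the CREATE statement carries a column list. A hedged sketch: table and file names are illustrative, and the schema-bearing form assumes the CREATE TEMPORARY TABLE column-list syntax implied by CreateTableUsing's userSpecifiedSchema field:

// userSpecifiedSchema is None: the resolved class must be a RelationProvider.
sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonInferred
    |USING org.apache.spark.sql.json
    |OPTIONS (path 'people.json')""".stripMargin)

// userSpecifiedSchema is Some(...): the class must be a SchemaRelationProvider,
// otherwise this now fails fast with "... should extend SchemaRelationProvider."
sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonDeclared (name STRING, age INT)
    |USING org.apache.spark.sql.json
    |OPTIONS (path 'people.json')""".stripMargin)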
@@ -68,7 +68,7 @@ trait SchemaRelationProvider
   def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation
+      schema: StructType): BaseRelation
 }
 
 /**
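For third-party sources the migration is mechanical: drop the Option and the unwrapping. A minimal sketch against the new contract, modeled on the test relation below; all names here are hypothetical:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources.{BaseRelation, SchemaRelationProvider, TableScan}

// Hypothetical source that produces row numbers under whatever schema it is given.
class RowNumberSource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    // `schema` is a plain StructType here -- no .get or pattern match needed.
    RowNumberScan(parameters("rows").toInt, schema)(sqlContext)
}

case class RowNumberScan(
    rows: Int,
    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
  extends TableScan {

  override def schema = userSpecifiedSchema

  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(1 to rows).map(i => Row(i))
}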
@@ -45,18 +45,18 @@ class AllDataTypesScanSource extends SchemaRelationProvider
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
   }
 }
 
 case class AllDataTypesScan(
     from: Int,
     to: Int,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
   extends TableScan {
 
-  override def schema = userSpecifiedSchema.get
+  override def schema = userSpecifiedSchema
 
   override def buildScan() = {
     sqlContext.sparkContext.parallelize(from to to).map { i =>