Commit 53de70f

more udts...

jkbradley committed Nov 2, 2014
1 parent 982c035 commit 53de70f
Showing 5 changed files with 14 additions and 8 deletions.
@@ -52,15 +52,19 @@ object ScalaReflection {
   }
 
   /** Converts Catalyst types used internally in rows to standard Scala types */
-  def convertToScala(a: Any): Any = a match {
+  def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
+    // TODO: USE DATATYPE
     // TODO: What about Option and Product?
     case s: Seq[_] => s.map(convertToScala)
     case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
     case d: Decimal => d.toBigDecimal
+    case (udt: Any, udtType: UserDefinedType[_]) => udtType.serialize(udt)
     case other => other
   }
 
-  def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
+  def convertRowToScala(r: Row, schema: StructType): Row = {
+    new GenericRow(r.toArray.map(convertToScala(_, schema)))
+  }
 
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag](
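In the first hunk, only the new UserDefinedType case consults the added dataType parameter; the // TODO: USE DATATYPE marks that the Seq, Map, and Decimal cases still ignore it (note that s.map(convertToScala) no longer type-checks against the two-argument signature, consistent with the WIP state of the commit). A sketch of where the type-directed recursion could go once the TODO is addressed — an assumption about the direction, not what this commit implements:

  // Sketch only: recurse using the element/field types carried in dataType.
  def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
    case (s: Seq[_], arrayType: ArrayType) =>
      s.map(convertToScala(_, arrayType.elementType))
    case (m: Map[_, _], mapType: MapType) =>
      m.map { case (k, v) =>
        convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
      }
    case (d: Decimal, _: DecimalType) => d.toBigDecimal
    case (udt: Any, udtType: UserDefinedType[_]) => udtType.serialize(udt)
    case (other, _) => other
  }

  // Under the same assumption, convertRowToScala would zip row values with the
  // schema's fields rather than passing the whole StructType to every element:
  def convertRowToScala(r: Row, schema: StructType): Row =
    new GenericRow(r.toArray.zip(schema.fields).map {
      case (v, field) => convertToScala(v, field.dataType)
    })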
@@ -114,7 +114,7 @@ class SchemaRDD(
   // =========================================================================================
 
   override def compute(split: Partition, context: TaskContext): Iterator[Row] =
-    firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala)
+    firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema))
 
   override def getPartitions: Array[Partition] = firstParent[Row].partitions
 
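All of the updated call sites, starting with SchemaRDD.compute above, close over the schema with Scala's placeholder syntax. The two forms below are equivalent, which is why the one-argument method reference could be swapped out mechanically at each site (the vals are illustrative only):

  // Placeholder syntax builds a Row => Row closure over this.schema ...
  val convert: Row => Row = ScalaReflection.convertRowToScala(_, this.schema)
  // ... and is equivalent to the explicit form:
  val convertExplicit: Row => Row = row => ScalaReflection.convertRowToScala(row, this.schema)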
@@ -82,7 +82,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /**
    * Runs this query returning the result as an array.
    */
-  def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect()
+  def executeCollect(): Array[Row] = {
+    execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
+  }
 
   protected def newProjection(
       expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
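SparkPlan itself gains no schema definition in this diff; schema comes from the Catalyst QueryPlan base class, which the HiveContext hunk below mirrors explicitly. Presumably the call resolves to something like this (a sketch, assuming QueryPlan's usual definition):

  // In QueryPlan: a plan's schema is derived from its output attributes.
  def schema: StructType = StructType.fromAttributes(output)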
@@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan)
       partsScanned += numPartsToTry
     }
 
-    buf.toArray.map(ScalaReflection.convertRowToScala)
+    buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
   }
 
   override def execute() = {
@@ -180,7 +180,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   // TODO: Is this copying for no reason?
   override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    .map(ScalaReflection.convertRowToScala)
+    .map(ScalaReflection.convertRowToScala(_, this.schema))
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
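On the // TODO: Is this copying for no reason? question: a plausible answer, not stated in this commit, is that physical operators can reuse a single mutable row per partition, so buffering rows for takeOrdered without _.copy() would buffer many references to one object. A minimal illustration of that hazard using Catalyst's mutable row:

  // Hypothetical illustration: one reused mutable row, buffered three times.
  val reused = new GenericMutableRow(1)
  val buffered = Seq(1, 2, 3).map { i => reused.update(0, i); reused }
  // buffered now holds three references to `reused`, all observing the last
  // write (3); copying before buffering is what _.copy() guards against.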
@@ -372,8 +372,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   protected[sql] abstract class QueryExecution extends super.QueryExecution {
 
     override lazy val toRdd: RDD[Row] = {
-      //val dataType = StructType.fromAttributes(logical.output)
-      executedPlan.execute().map(ScalaReflection.convertRowToScala(_))
+      val schema = StructType.fromAttributes(logical.output)
+      executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
     }
 
     protected val primitiveTypes =
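StructType.fromAttributes turns the analyzed logical plan's output attributes into the schema that drives the conversion. A small illustration with made-up attribute names (the columns are hypothetical, not from this commit):

  // Two hypothetical output attributes and the schema they produce:
  val output = Seq(
    AttributeReference("id", IntegerType, nullable = false)(),
    AttributeReference("name", StringType)())
  val schema = StructType.fromAttributes(output)
  // schema == StructType(Seq(
  //   StructField("id", IntegerType, nullable = false),
  //   StructField("name", StringType, nullable = true)))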
