This repository has been archived by the owner on Dec 20, 2018. It is now read-only.

Support for logical datatypes like Decimal type #121

Open · wants to merge 6 commits into master
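A quick sketch of what this change enables (not part of the diff; the record and field names below are made up, and it assumes Avro's schema parser keeps the logicalType/precision/scale attributes as JSON props, which is what the converters in this patch read):

import org.apache.avro.Schema
import com.databricks.spark.avro.SchemaConverters

// Hypothetical record: decimal is modeled as a string carrying the "decimal" logical type,
// timestamp as a long carrying the "timestamp" logical type, matching what this patch expects.
val avroSchema = new Schema.Parser().parse(
  """{"type": "record", "name": "User", "fields": [
    |  {"name": "decimal",
    |   "type": {"type": "string", "logicalType": "decimal", "precision": 25, "scale": 9}},
    |  {"name": "timestamp", "type": {"type": "long", "logicalType": "timestamp"}}
    |]}""".stripMargin)

// With this patch applied, the two fields should map to DecimalType(25,9) and TimestampType.
println(SchemaConverters.toSqlType(avroSchema).dataType)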
11 changes: 6 additions & 5 deletions src/main/scala/com/databricks/spark/avro/AvroRelation.scala
@@ -18,11 +18,9 @@ package com.databricks.spark.avro

import java.io.FileNotFoundException
import java.util.zip.Deflater

import scala.collection.Iterator
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer

import com.google.common.base.Objects
import org.apache.avro.SchemaBuilder
import org.apache.avro.file.{DataFileConstants, DataFileReader, FileReader}
@@ -31,12 +29,12 @@ import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
import org.apache.avro.mapreduce.AvroJob
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.Logging
import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.avro.Schema.Type

private[avro] class AvroRelation(
override val paths: Array[String],
@@ -130,8 +128,11 @@ private[avro] class AvroRelation(
val firstRecord = records.next()
val superSchema = firstRecord.getSchema // the schema of the actual record
// the fields that are actually required along with their converters
val avroFieldMap = superSchema.getFields.map(f => (f.name, f)).toMap

val avroFieldMap = superSchema.getFields.map { f =>
  // Copy field-level JSON props (e.g. logicalType, precision, scale) onto the field's
  // schema so that SchemaConverters can read them later.
  f.getJsonProps.foreach(x => f.schema().addProp(x._1, x._2))
  (f.name, f)
}.toMap

new Iterator[Row] {
private[this] val baseIterator = records
private[this] var currentRecord = firstRecord
73 changes: 66 additions & 7 deletions src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
@@ -27,13 +27,23 @@ import org.apache.avro.SchemaBuilder._
import org.apache.avro.Schema.Type._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.avro.Schema.Type
import java.sql.{Date, Timestamp}

/**
* This object contains methods that are used to convert Spark SQL schemas to Avro schemas
* and vice versa.
*/
object SchemaConverters {

val LOGICAL_TYPE = "logicalType"
val DECIMAL = "decimal"
val TIMESTAMP = "timestamp"
val DATE = "date"
val PRECISION = "precision"
val SCALE = "scale"

case class SchemaType(dataType: DataType, nullable: Boolean)

/**
@@ -42,17 +52,36 @@ object SchemaConverters {
def toSqlType(avroSchema: Schema): SchemaType = {
avroSchema.getType match {
case INT => SchemaType(IntegerType, nullable = false)
case STRING => SchemaType(StringType, nullable = false)
case STRING => {
  // A string annotated with the "decimal" logical type carries its precision and scale
  // as plain JSON props on the schema.
  val logicalType = avroSchema.getJsonProp(LOGICAL_TYPE)
  if (logicalType != null && logicalType.asText().equalsIgnoreCase(DECIMAL)) {
    val precision = avroSchema.getJsonProp(PRECISION).asInt
    val scale = avroSchema.getJsonProp(SCALE).asInt
    SchemaType(DecimalType(precision, scale), nullable = false)
  } else {
    SchemaType(StringType, nullable = false)
  }
}
case BOOLEAN => SchemaType(BooleanType, nullable = false)
case BYTES => SchemaType(BinaryType, nullable = false)
case DOUBLE => SchemaType(DoubleType, nullable = false)
case FLOAT => SchemaType(FloatType, nullable = false)
case LONG => SchemaType(LongType, nullable = false)
case LONG => {
  // Both the "timestamp" and the "date" logical types are surfaced as TimestampType.
  val logicalType = avroSchema.getJsonProp(LOGICAL_TYPE)
  if (logicalType != null && logicalType.asText().equalsIgnoreCase(TIMESTAMP)) {
    SchemaType(TimestampType, nullable = false)
  } else if (logicalType != null && logicalType.asText().equalsIgnoreCase(DATE)) {
    SchemaType(TimestampType, nullable = false)
  } else {
    SchemaType(LongType, nullable = false)
  }
}
case FIXED => SchemaType(BinaryType, nullable = false)
case ENUM => SchemaType(StringType, nullable = false)

case RECORD =>
val fields = avroSchema.getFields.map { f =>
f.getJsonProps.foreach(x => f.schema().addProp(x._1, x._2))
val schemaType = toSqlType(f.schema())
StructField(f.name, schemaType.dataType, schemaType.nullable)
}
@@ -76,7 +105,9 @@ object SchemaConverters {
// In case of a union with null, eliminate it and make a recursive call
val remainingUnionTypes = avroSchema.getTypes.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
toSqlType(remainingUnionTypes.get(0)).copy(nullable = true)
val remainingSchema = remainingUnionTypes.get(0)
// Carry union-level props (e.g. logicalType) down to the non-null branch.
avroSchema.getJsonProps.foreach(x => remainingSchema.addProp(x._1, x._2))
toSqlType(remainingSchema).copy(nullable = true)
} else {
toSqlType(Schema.createUnion(remainingUnionTypes)).copy(nullable = true)
}
@@ -125,8 +156,31 @@ object SchemaConverters {
private[avro] def createConverterToSQL(schema: Schema): Any => Any = {
schema.getType match {
// Avro strings are in Utf8, so we have to call toString on them
case STRING | ENUM => (item: Any) => if (item == null) null else item.toString
case INT | BOOLEAN | DOUBLE | FLOAT | LONG => identity
case STRING | ENUM => (item: Any) => if (item == null) {
  null
} else {
  val logicalType = schema.getJsonProp(LOGICAL_TYPE)
  if (logicalType != null && logicalType.asText().equalsIgnoreCase(DECIMAL)) {
    // Decimal-annotated strings are parsed into Spark's Decimal with the declared
    // precision and scale.
    val precision = schema.getJsonProp(PRECISION).asInt
    val scale = schema.getJsonProp(SCALE).asInt
    Decimal(BigDecimal(item.toString), precision, scale)
  } else {
    item.toString
  }
}
case LONG => (item: Any) => if (item == null) {
  null
} else {
  val logicalType = schema.getJsonProp(LOGICAL_TYPE)
  if (logicalType != null && (logicalType.asText().equalsIgnoreCase(TIMESTAMP) ||
      logicalType.asText().equalsIgnoreCase(DATE))) {
    // Timestamp- and date-annotated longs are interpreted as epoch milliseconds.
    new Timestamp(item.asInstanceOf[Long])
  } else {
    item
  }
}
case INT | BOOLEAN | DOUBLE | FLOAT => identity
// Byte arrays are reused by avro, so we have to make a copy of them.
case FIXED => (item: Any) => if (item == null) {
null
@@ -142,7 +196,10 @@ object SchemaConverters {
javaBytes
}
case RECORD =>
val fieldConverters = schema.getFields.map(f => createConverterToSQL(f.schema))
val fieldConverters = schema.getFields.map { f =>
  f.getJsonProps.foreach(x => f.schema().addProp(x._1, x._2))
  createConverterToSQL(f.schema)
}
(item: Any) => if (item == null) {
null
} else {
@@ -173,7 +230,9 @@ object SchemaConverters {
if (schema.getTypes.exists(_.getType == NULL)) {
val remainingUnionTypes = schema.getTypes.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
createConverterToSQL(remainingUnionTypes.get(0))
val remainingSchema = remainingUnionTypes.get(0)
schema.getJsonProps.foreach(x => remainingSchema.addProp(x._1, x._2))
createConverterToSQL(remainingSchema)
} else {
createConverterToSQL(Schema.createUnion(remainingUnionTypes))
}
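At the value level the new converters boil down to the following (standalone sketch using values from the test below; the private[avro] createConverterToSQL is not called directly here):

import java.sql.Timestamp
import org.apache.spark.sql.types.Decimal

// A decimal-annotated string value becomes a Spark Decimal with the declared precision/scale...
val dec = Decimal(BigDecimal("55555.555550000"), 25, 9)
// ...and a timestamp- or date-annotated long is treated as epoch milliseconds.
val ts = new Timestamp(1460354720000L)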
Binary file added src/test/resources/users.avro
Binary file not shown.
29 changes: 28 additions & 1 deletion src/test/scala/com/databricks/spark/avro/AvroSuite.scala
@@ -16,10 +16,12 @@ import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.types._
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import java.sql.{Date, Timestamp}

class AvroSuite extends FunSuite with BeforeAndAfterAll {
val episodesFile = "src/test/resources/episodes.avro"
val testFile = "src/test/resources/test.avro"
val userFile = "src/test/resources/users.avro"

private var sqlContext: SQLContext = _

@@ -442,4 +444,29 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
assert(newDf.count == 8)
}
}
}

test("Logical Types") {
val df = sqlContext.read.avro(userFile)

val decimals = df.select("decimal").collect().map(x => Decimal.apply(x.getDecimal(0)))
val dec1 = Decimal.apply(BigDecimal.apply("55555.555550000"), 25, 9)
val dec2 = Decimal.apply(BigDecimal.apply("8747336654.536756000"), 25, 9)

assert(decimals.apply(0).equals(dec1))
assert(decimals.apply(1).equals(dec2))

assert(df.schema.apply("decimal").dataType == DecimalType(25,9))


val timestamps = df.select("timestamp").collect().map(x => x.getAs[Timestamp](0))
val t1 = new Timestamp(1460354720000l)
val t2 = new Timestamp(1462842320000l)

assert(timestamps.apply(0).equals(t1))
assert(timestamps.apply(1).equals(t2))

assert(df.schema.apply("timestamp").dataType == TimestampType)


}
}
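For reference, end-to-end usage with the users.avro fixture added above would look roughly like this (a sketch; the exact nullability shown in the printed schema depends on how the fixture declares its fields):

import com.databricks.spark.avro._

val df = sqlContext.read.avro("src/test/resources/users.avro")
df.printSchema()
// Expected with this patch:
//  |-- decimal: decimal(25,9)
//  |-- timestamp: timestamp
df.select("decimal", "timestamp").show()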