From 39bc512c8b0b7ab10fb4e615972e51648030e614 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 27 Nov 2016 17:51:43 -0800 Subject: [PATCH 01/22] Change version to 3.1.1-SNAPSHOT --- version.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.sbt b/version.sbt index 42cf5c4d..5a1bd5aa 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "3.1.0" \ No newline at end of file +version in ThisBuild := "3.1.1-SNAPSHOT" From 62c3c53e07188598bcb0dddc5d9b5ef37fcd6a51 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 27 Nov 2016 17:57:16 -0800 Subject: [PATCH 02/22] Remove Spark 1.x documentation from README; add links to older READMEs To avoid confusion, we should remove the documentation / linking instructions for the 2.x line of releases since the current README describes features which don't apply there. Instead, we should link to the older docs. Author: Josh Rosen Closes #199 from JoshRosen/readme-fixes. (cherry picked from commit b01a034b0e5d785609275aa97d9c5e3719613194) Signed-off-by: Josh Rosen --- README.md | 34 +++++++--------------------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 68b7d336..f002f56a 100644 --- a/README.md +++ b/README.md @@ -7,16 +7,17 @@ A library for reading and writing Avro data from [Spark SQL](http://spark.apache ## Requirements -This documentation is for Spark 1.4+ and 2.0. +This documentation is for version 3.1.0 of this library, which supports Spark 2.0+. For +documentation on earlier versions of this library, see the links below. This library has different versions for Spark 1.2, 1.3, 1.4+, and 2.0: | Spark Version | Compatible version of Avro Data Source for Spark | | ------------- | ------------------------------------------------ | | `1.2` | `0.2.0` | -| `1.3` | `1.0.0` | -| `1.4+` | `2.0.1` | -| `2.0` | `3.1.0` | +| `1.3` | [`1.0.0`](https://github.com/databricks/spark-avro/tree/v1.0.0) | +| `1.4+` | [`2.0.1`](https://github.com/databricks/spark-avro/tree/v2.0.1) | +| `2.0` | `3.1.0` (this version) | ## Linking @@ -24,33 +25,13 @@ This library is cross-published for Scala 2.11, so 2.11 users should replace 2.1 You can link against this library in your program at the following coordinates: -### For Spark 1.4+ - -Using SBT: - -``` -libraryDependencies += "com.databricks" %% "spark-avro" % "2.0.1" -``` - -Using Maven: - -```xml - - com.databricks - spark-avro_2.10 - 2.0.1 - -``` - -### For Spark 2.0 - -Using SBT: +**Using SBT:** ``` libraryDependencies += "com.databricks" %% "spark-avro" % "3.1.0" ``` -Using Maven: +**Using Maven:** ```xml @@ -66,7 +47,6 @@ This library can also be added to Spark jobs launched through `spark-shell` or ` For example, to include it when starting the spark shell: ``` -$ bin/spark-shell --packages com.databricks:spark-avro_2.10:2.0.1 $ bin/spark-shell --packages com.databricks:spark-avro_2.11:3.1.0 ``` From 29ef5f60f8ed206c4918ccef3e14541e78e71da6 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Thu, 26 Jan 2017 15:50:21 -0500 Subject: [PATCH 03/22] WIP - allow creation of Dataset from RDD[SpecificRecord] There is still an issue with ExternalMapToCatalyst to be resolved --- .../databricks/spark/avro/AvroEncoder.scala | 625 ++++++++++++ .../spark/avro/SchemaConverters.scala | 66 +- .../com/databricks/spark/avro/ByteArray.java | 142 +++ .../databricks/spark/avro/DoubleArray.java | 142 +++ .../com/databricks/spark/avro/Feature.java | 196 ++++ .../databricks/spark/avro/SimpleEnums.java | 13 + 
.../databricks/spark/avro/SimpleFixed.java | 23 + .../databricks/spark/avro/SimpleRecord.java | 195 ++++ .../databricks/spark/avro/StringArray.java | 142 +++ .../com/databricks/spark/avro/TestRecord.java | 893 ++++++++++++++++++ src/test/resources/specific.avsc | 40 + .../com/databricks/spark/avro/AvroSuite.scala | 241 ++++- 12 files changed, 2687 insertions(+), 31 deletions(-) create mode 100644 src/main/scala/com/databricks/spark/avro/AvroEncoder.scala create mode 100644 src/test/java/com/databricks/spark/avro/ByteArray.java create mode 100644 src/test/java/com/databricks/spark/avro/DoubleArray.java create mode 100644 src/test/java/com/databricks/spark/avro/Feature.java create mode 100644 src/test/java/com/databricks/spark/avro/SimpleEnums.java create mode 100644 src/test/java/com/databricks/spark/avro/SimpleFixed.java create mode 100644 src/test/java/com/databricks/spark/avro/SimpleRecord.java create mode 100644 src/test/java/com/databricks/spark/avro/StringArray.java create mode 100644 src/test/java/com/databricks/spark/avro/TestRecord.java create mode 100644 src/test/resources/specific.avsc diff --git a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala new file mode 100644 index 00000000..a488ecdc --- /dev/null +++ b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala @@ -0,0 +1,625 @@ +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.databricks.spark.avro + +/** + * A Spark-SQL Encoder for Avro objects + */ + +import java.io._ +import java.util.{Map => JMap} +import org.apache.avro.Schema.Parser +import org.apache.hadoop.conf.Configuration +import org.apache.spark.util.Utils + +import scala.collection.JavaConverters._ +import com.databricks.spark.avro.SchemaConverters.{IncompatibleSchemaException, SchemaType, resolveUnionType, toSqlType} +import org.apache.avro.Schema +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic.{GenericData, IndexedRecord} +import org.apache.avro.reflect.ReflectData +import org.apache.avro.specific.SpecificRecord +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +/** + * A Spark-SQL Encoder for Avro objects + */ +object AvroEncoder { + /** + * Provides an Encoder for Avro objects of the given class + * + * @param avroClass the class of the Avro object for which to generate the Encoder + * @tparam T the type of the Avro class, must implement SpecificRecord + * @return an Encoder for the given Avro class + */ + def of[T <: SpecificRecord](avroClass: Class[T]): Encoder[T] = { + AvroExpressionEncoder.of(avroClass) + } + + /** + * Provides an Encoder for Avro objects implementing the given schema + * + * @param avroSchema the Schema of the Avro object for which to generate the Encoder + * @tparam T the type of the Avro class that implements the Schema, must implement IndexedRecord + * @return an Encoder for the given Avro Schema + */ + def of[T <: IndexedRecord](avroSchema: Schema): Encoder[T] = { + AvroExpressionEncoder.of(avroSchema) + } +} + +object ObjectType { + def apply(cls: Class[_]): DataType = { + val ot = Class.forName("org.apache.spark.sql.types.ObjectType") + val meth = ot.getDeclaredConstructor(classOf[Class[_]]) + meth.setAccessible(true) + meth.newInstance(cls).asInstanceOf[DataType] + } + + def _isInstanceOf(obj: AnyRef): Boolean = { + val ot = Class.forName("org.apache.spark.sql.types.ObjectType") + ot.isInstance(obj) + } +} + +class SerializableSchema(@transient var value: Schema) extends Externalizable { + def this() = this(null) + override def readExternal(in: ObjectInput): Unit = { + value = new Parser().parse(in.readObject().asInstanceOf[String]) + } + override def writeExternal(out: ObjectOutput): Unit = out.writeObject(value.toString) + def resolveUnion(datum: Any): Int = GenericData.get.resolveUnion(value, datum) +} + +object AvroExpressionEncoder { + def of[T <: SpecificRecord](avroClass: Class[T]): ExpressionEncoder[T] = { + val schema = avroClass.getMethod("getClassSchema").invoke(null).asInstanceOf[Schema] + assert(toSqlType(schema).dataType.isInstanceOf[StructType]) + + val serializer = AvroTypeInference.serializerFor(avroClass, schema) + val deserializer = AvroTypeInference.deserializerFor(schema) + + new ExpressionEncoder[T]( + toSqlType(schema).dataType.asInstanceOf[StructType], + 
flat = false, + serializer.flatten, + deserializer = deserializer, + ClassTag[T](avroClass)) + } + + def of[T <: IndexedRecord](schema: Schema): ExpressionEncoder[T] = { + assert(toSqlType(schema).dataType.isInstanceOf[StructType]) + + val avroClass = Option(ReflectData.get.getClass(schema)) + .map(_.asSubclass(classOf[SpecificRecord])) + .getOrElse(classOf[GenericData.Record]) + val serializer = AvroTypeInference.serializerFor(avroClass, schema) + val deserializer = AvroTypeInference.deserializerFor(schema) + + new ExpressionEncoder[T]( + toSqlType(schema).dataType.asInstanceOf[StructType], + flat = false, + serializer.flatten, + deserializer, + ClassTag[T](avroClass)) + } +} + +/** + * Utilities for providing Avro object serializers and deserializers + */ +private object AvroTypeInference { + /** + * Translates an Avro Schema type to a proper SQL DataType. The Java Objects that back data in + * generated Generic and Specific records sometimes do not align with those suggested by Avro + * ReflectData, so we infer the proper SQL DataType to serialize and deserialize based on + * nullability and the wrapping Schema type. + */ + private def inferExternalType(avroSchema: Schema): DataType = { + toSqlType(avroSchema) match { + // the non-nullable primitive types + case SchemaType(BooleanType, false) => BooleanType + case SchemaType(IntegerType, false) => IntegerType + case SchemaType(LongType, false) => + if (avroSchema.getType == UNION) { + ObjectType(classOf[java.lang.Number]) + } else { + LongType + } + case SchemaType(FloatType, false) => FloatType + case SchemaType(DoubleType, false) => + if (avroSchema.getType == UNION) { + ObjectType(classOf[java.lang.Number]) + } else { + DoubleType + } + // the nullable primitive types + case SchemaType(BooleanType, true) => ObjectType(classOf[java.lang.Boolean]) + case SchemaType(IntegerType, true) => ObjectType(classOf[java.lang.Integer]) + case SchemaType(LongType, true) => ObjectType(classOf[java.lang.Long]) + case SchemaType(FloatType, true) => ObjectType(classOf[java.lang.Float]) + case SchemaType(DoubleType, true) => ObjectType(classOf[java.lang.Double]) + // the binary types + case SchemaType(BinaryType, _) => + if (avroSchema.getType == FIXED) { + Option(ReflectData.get.getClass(avroSchema)) + .map(ObjectType(_)) + .getOrElse(ObjectType(classOf[GenericData.Fixed])) + } else { + ObjectType(classOf[java.nio.ByteBuffer]) + } + // the referenced types + case SchemaType(ArrayType(_, _), _) => + ObjectType(classOf[java.util.List[Object]]) + case SchemaType(StringType, _) => + avroSchema.getType match { + case ENUM => + Option(ReflectData.get.getClass(avroSchema)) + .map(ObjectType(_)) + .getOrElse(ObjectType(classOf[GenericData.EnumSymbol])) + case _ => + ObjectType(classOf[CharSequence]) + } + case SchemaType(StructType(_), _) => + Option(ReflectData.get.getClass(avroSchema)) + .map(ObjectType(_)) + .getOrElse(ObjectType(classOf[GenericData.Record])) + case SchemaType(MapType(_, _, _), _) => + ObjectType(classOf[java.util.Map[Object, Object]]) + } + } + + /** + * Returns an expression that can be used to deserialize an InternalRow to an Avro object of + * type `T` that implements IndexedRecord and is compatible with the given Schema + */ + def deserializerFor[T <: IndexedRecord] (avroSchema: Schema): Expression = { + deserializerFor(avroSchema, None) + } + + private def deserializerFor(avroSchema: Schema, path: Option[Expression]): Expression = { + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, Literal(part))) 
+ .getOrElse(UnresolvedAttribute(part)) + + def getPath: Expression = path.getOrElse( + GetColumnByOrdinal(0, inferExternalType(avroSchema))) + + avroSchema.getType match { + case BOOLEAN => + NewInstance( + classOf[java.lang.Boolean], + getPath :: Nil, + ObjectType(classOf[java.lang.Boolean])) + case INT => + NewInstance( + classOf[java.lang.Integer], + getPath :: Nil, + ObjectType(classOf[java.lang.Integer])) + case LONG => + NewInstance( + classOf[java.lang.Long], + getPath :: Nil, + ObjectType(classOf[java.lang.Long])) + case FLOAT => + NewInstance( + classOf[java.lang.Float], + getPath :: Nil, + ObjectType(classOf[java.lang.Float])) + case DOUBLE => + NewInstance( + classOf[java.lang.Double], + getPath :: Nil, + ObjectType(classOf[java.lang.Double])) + + case BYTES => + StaticInvoke( + classOf[java.nio.ByteBuffer], + ObjectType(classOf[java.nio.ByteBuffer]), + "wrap", + getPath :: Nil) + case FIXED => + val fixedClass = Option(ReflectData.get.getClass(avroSchema)) + .getOrElse(classOf[GenericData.Fixed]) + if (fixedClass == classOf[GenericData.Fixed]) { + NewInstance( + fixedClass, + Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) :: + getPath :: + Nil, + ObjectType(fixedClass)) + } else { + NewInstance( + fixedClass, + getPath :: Nil, + ObjectType(fixedClass)) + } + + case STRING => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case ENUM => + val enumClass = Option(ReflectData.get.getClass(avroSchema)) + .getOrElse(classOf[GenericData.EnumSymbol]) + if (enumClass == classOf[GenericData.EnumSymbol]) { + NewInstance( + enumClass, + Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) :: + Invoke(getPath, "toString", ObjectType(classOf[String])) :: + Nil, + ObjectType(enumClass)) + } else { + StaticInvoke( + enumClass, + ObjectType(enumClass), + "valueOf", + Invoke(getPath, "toString", ObjectType(classOf[String])) :: Nil) + } + + case ARRAY => + val elementSchema = avroSchema.getElementType + val elementType = toSqlType(elementSchema).dataType + val array = Invoke( + MapObjects(element => + deserializerFor(elementSchema, Some(element)), + getPath, + elementType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + classOf[java.util.Arrays], + ObjectType(classOf[java.util.List[Object]]), + "asList", + array :: Nil) + + case MAP => + val valueSchema = avroSchema.getValueType + val valueType = inferExternalType(valueSchema) match { + case t if t == ObjectType(classOf[java.lang.CharSequence]) => + StringType + case other => other + } + + val keyData = Invoke( + MapObjects( + p => deserializerFor(Schema.create(STRING), Some(p)), + Invoke(getPath, "keyArray", ArrayType(StringType)), + StringType), + "array", + ObjectType(classOf[Array[Any]])) + val valueData = Invoke( + MapObjects( + p => deserializerFor(valueSchema, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueType)), + valueType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData.getClass, + ObjectType(classOf[JMap[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil) + + case UNION => + val (resolvedSchema, _) = resolveUnionType(avroSchema) + if (resolvedSchema.getType == RECORD && + avroSchema.getTypes.asScala.filterNot(_.getType == NULL).length > 1) { + // A Union resolved to a record that originally had more than 1 type when filtered + // of its nulls must be complex + val bottom = Literal.create(null, ObjectType(classOf[Object])).asInstanceOf[Expression] + + resolvedSchema.getFields.foldLeft(bottom) { (tree: Expression, field: Schema.Field) 
=> + val fieldValue = ObjectCast( + deserializerFor(field.schema, Some(addToPath(field.name))), + ObjectType(classOf[Object])) + + If(IsNull(fieldValue), tree, fieldValue) + } + } else { + deserializerFor(resolvedSchema, path) + } + + case RECORD => + val args = avroSchema.getFields.map { field => + val position = Literal(field.pos) + val argument = deserializerFor(field.schema, Some(addToPath(field.name))) + (position, argument) + }.toList + + val recordClass = Option(ReflectData.get.getClass(avroSchema)) + .getOrElse(classOf[GenericData.Record]) + val newInstance = if (recordClass == classOf[GenericData.Record]) { + NewInstance( + recordClass, + Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) :: Nil, + ObjectType(recordClass)) + } else { + NewInstance( + recordClass, + Nil, + ObjectType(recordClass)) + } + + val result = InitializeAvroObject(newInstance, args) + + if (path.nonEmpty) { + If(IsNull(getPath), + Literal.create(null, ObjectType(recordClass)), + result) + } else { + result + } + + case NULL => + /* + * Encountering NULL at this level implies it was the type of a Field, which should never + * be the case + */ + throw new IncompatibleSchemaException("Null type should only be used in Union types") + } + } + + /** + * Returns an expression that can be used to serialize an Avro object with a class of type `T` + * that is compatible with the given Schema to an InternalRow + */ + def serializerFor[T <: IndexedRecord](avroClass: Class[T], avroSchema: Schema): + CreateNamedStruct = { + val inputObject = BoundReference(0, ObjectType(avroClass), nullable = true) + serializerFor(inputObject, avroSchema, topLevel = true).asInstanceOf[CreateNamedStruct] + } + + def serializerFor( + inputObject: Expression, + avroSchema: Schema, + topLevel: Boolean = false): Expression = { + + def toCatalystArray(inputObject: Expression, schema: Schema): Expression = { + val elementType = inferExternalType(schema) + + if (ObjectType._isInstanceOf(elementType)) { + MapObjects(element => + serializerFor(element, schema), + Invoke( + inputObject, + "toArray", + ObjectType(classOf[Array[Object]])), + elementType) + } else { + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = ArrayType(elementType, containsNull = false)) + } + } + + def toCatalystMap(inputObject: Expression, schema: Schema): Expression = { + val valueSchema = schema.getValueType + val valueType = inferExternalType(valueSchema) + +// ExternalMapToCatalyst( +// inputObject, +// ObjectType(classOf[org.apache.avro.util.Utf8]), +// serializerFor(_, Schema.create(STRING)), +// valueType, +// serializerFor(_, valueSchema)) + ??? 
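+ // TODO: map serialization is unimplemented for now -- the ??? above throws
+ // scala.NotImplementedError at runtime, so schemas containing MAP-typed fields are
+ // unsupported until the ExternalMapToCatalyst issue noted in this patch's commit message is resolved.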
+ } + + if (!ObjectType._isInstanceOf(inputObject.dataType)) { + inputObject + } else { + avroSchema.getType match { + case BOOLEAN => + Invoke(inputObject, "booleanValue", BooleanType) + case INT => + Invoke(inputObject, "intValue", IntegerType) + case LONG => + Invoke(inputObject, "longValue", LongType) + case FLOAT => + Invoke(inputObject, "floatValue", FloatType) + case DOUBLE => + Invoke(inputObject, "doubleValue", DoubleType) + + case BYTES => + Invoke(inputObject, "array", BinaryType) + case FIXED => + Invoke(inputObject, "bytes", BinaryType) + + case STRING => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + Invoke(inputObject, "toString", ObjectType(classOf[java.lang.String])) :: Nil) + + case ENUM => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + Invoke(inputObject, "toString", ObjectType(classOf[java.lang.String])) :: Nil) + + case ARRAY => + val elementSchema = avroSchema.getElementType + toCatalystArray(inputObject, elementSchema) + + case MAP => + toCatalystMap(inputObject, avroSchema) + + case UNION => + val unionWithoutNulls = Schema.createUnion( + avroSchema.getTypes.asScala.filterNot(_.getType == NULL)) + val (resolvedSchema, nullable) = resolveUnionType(avroSchema) + if (resolvedSchema.getType == RECORD && unionWithoutNulls.getTypes.length > 1) { + // A Union resolved to a record that originally had more than 1 type when filtered + // of its nulls must be complex + val complexStruct = CreateNamedStruct( + resolvedSchema.getFields.zipWithIndex.flatMap { case (field, index) => + val unionIndex = Invoke( + Literal.fromObject( + new SerializableSchema(unionWithoutNulls), + ObjectType(classOf[SerializableSchema])), + "resolveUnion", + IntegerType, + inputObject :: Nil) + + val fieldValue = If(EqualTo(Literal(index), unionIndex), +// val fieldValue = If(EqualTo(Literal(index), Literal.fromObject(1, IntegerType)), + serializerFor( + ObjectCast( + inputObject, + inferExternalType(field.schema())), + field.schema), + Literal.create(null, toSqlType(field.schema()).dataType)) + + Literal(field.name) :: serializerFor(fieldValue, field.schema) :: Nil}) + + complexStruct + + } else { + if (nullable) { + serializerFor(inputObject, resolvedSchema) + } else { + serializerFor( + AssertNotNull(inputObject, Seq(avroSchema.getTypes.toString)), + resolvedSchema) + } + } + + case RECORD => + val createStruct = CreateNamedStruct( + avroSchema.getFields.flatMap { field => + val fieldValue = Invoke( + inputObject, + "get", + inferExternalType(field.schema), + Literal(field.pos) :: Nil) + Literal(field.name) :: serializerFor(fieldValue, field.schema) :: Nil}) + if (topLevel) { + createStruct + } else { + If(IsNull(inputObject), + Literal.create(null, createStruct.dataType), + createStruct) + } + + case NULL => + /* + * Encountering NULL at this level implies it was the type of a Field, which should never + * be the case + */ + throw new IncompatibleSchemaException("Null type should only be used in Union types") + } + } + } + + /** + * Initializes an Avro Record instance (that implements the IndexedRecord interface) by calling + * the `put` method on a the Record instance with the provided position and value arguments + * + * @param objectInstance an expression that will evaluate to the Record instance + * @param args a sequence of expression pairs that will respectively evaluate to the index of + * the record in which to insert, and the argument value to insert + */ + private case class InitializeAvroObject( + objectInstance: Expression, + args: 
List[(Expression, Expression)]) extends Expression with NonSQLExpression { + + override def nullable: Boolean = objectInstance.nullable + override def children: Seq[Expression] = objectInstance +: args.map { case (_, v) => v } + override def dataType: DataType = objectInstance.dataType + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val instanceGen = objectInstance.genCode(ctx) + + val avroInstance = ctx.freshName("avroObject") + val avroInstanceJavaType = ctx.javaType(objectInstance.dataType) + ctx.addMutableState(avroInstanceJavaType, avroInstance, "") + + val initialize = args.map { + case (posExpr, argExpr) => + val posGen = posExpr.genCode(ctx) + val argGen = argExpr.genCode(ctx) + s""" + ${posGen.code} + ${argGen.code} + $avroInstance.put(${posGen.value}, ${argGen.value}); + """ + } + + val initExpressions = ctx.splitExpressions(ctx.INPUT_ROW, initialize) + val code = + s""" + ${instanceGen.code} + $avroInstance = ${instanceGen.value}; + if (!${instanceGen.isNull}) { + $initExpressions + } + """ + ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) + } + } + + /** + * Casts an expression to another object. + * + * @param value The value to cast + * @param resultType The type the value should be cast to. + */ + private case class ObjectCast( + value : Expression, + resultType: DataType) extends Expression with NonSQLExpression { + + override def nullable: Boolean = value.nullable + override def dataType: DataType = resultType + override def children: Seq[Expression] = value :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + + val javaType = ctx.javaType(resultType) + val obj = value.genCode(ctx) + + val code = s""" + ${obj.code} + final $javaType ${ev.value} = ($javaType) ${obj.value}; + """ + + ev.copy(code = code, isNull = obj.isNull) + } + } +} diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index aa634d4c..1b8bc450 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -18,13 +18,11 @@ package com.databricks.spark.avro import java.nio.ByteBuffer import scala.collection.JavaConverters._ - import org.apache.avro.generic.GenericData.Fixed import org.apache.avro.generic.{GenericData, GenericRecord} import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.SchemaBuilder._ import org.apache.avro.Schema.Type._ - import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types._ @@ -74,38 +72,50 @@ object SchemaConverters { nullable = false) case UNION => - if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { - // In case of a union with null, eliminate it and make a recursive call - val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) - if (remainingUnionTypes.size == 1) { - toSqlType(remainingUnionTypes.head).copy(nullable = true) - } else { - toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) - } - } else avroSchema.getTypes.asScala.map(_.getType) match { - case Seq(t1) => - toSqlType(avroSchema.getTypes.get(0)) - case 
Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => - SchemaType(LongType, nullable = false) - case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => - SchemaType(DoubleType, nullable = false) - case _ => - // Convert complex unions to struct types where field names are member0, member1, etc. - // This is consistent with the behavior when converting between Avro and Parquet. - val fields = avroSchema.getTypes.asScala.zipWithIndex.map { - case (s, i) => - val schemaType = toSqlType(s) - // All fields are nullable because only one of them is set at a time - StructField(s"member$i", schemaType.dataType, nullable = true) - } - - SchemaType(StructType(fields), nullable = false) + resolveUnionType(avroSchema) match { + case (schema, nullable) => toSqlType(schema).copy(nullable = nullable) } case other => throw new IncompatibleSchemaException(s"Unsupported type $other") } } + /** + * Resolves an avro UNION type to an SQL-compatible avro type. Converts complex unions to records + * if necessary. + */ + def resolveUnionType(avroSchema: Schema, nullable: Boolean = false): (Schema, Boolean) = { + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it, and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + (remainingUnionTypes.head, true) + } else { + resolveUnionType(Schema.createUnion(remainingUnionTypes.asJava), nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + (avroSchema.getTypes.get(0), true) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + (Schema.create(LONG), false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + (Schema.create(DOUBLE), false) + case _ => + // Convert complex unions to records where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. + val record = SchemaBuilder.record(avroSchema.getName).fields() + avroSchema.getTypes.asScala.zipWithIndex.foreach { + case (s, i) => + // All fields are nullable because only one of them is set at a time + record.name(s"member$i").`type`(SchemaBuilder.unionOf() + .`type`(Schema.create(NULL)).and + .`type`(s).endUnion()) + .withDefault(null) + } + (record.endRecord(), false) + } + } + /** * This function converts sparkSQL StructType into avro schema. This method uses two other * converter methods in order to do the conversion. diff --git a/src/test/java/com/databricks/spark/avro/ByteArray.java b/src/test/java/com/databricks/spark/avro/ByteArray.java new file mode 100644 index 00000000..d29b38d7 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/ByteArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ByteArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ByteArray\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"bytes\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List value; + + /** + * Default constructor. 
Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use {@link \#newBuilder()}. + */ + public ByteArray() {} + + /** + * All-args constructor. + */ + public ByteArray(java.util.List value) { + this.value = value; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return value; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: value = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'value' field. + */ + public java.util.List getValue() { + return value; + } + + /** + * Sets the value of the 'value' field. + * @param value the value to set. + */ + public void setValue(java.util.List value) { + this.value = value; + } + + /** Creates a new ByteArray RecordBuilder */ + public static ByteArray.Builder newBuilder() { + return new ByteArray.Builder(); + } + + /** Creates a new ByteArray RecordBuilder by copying an existing Builder */ + public static ByteArray.Builder newBuilder(ByteArray.Builder other) { + return new ByteArray.Builder(other); + } + + /** Creates a new ByteArray RecordBuilder by copying an existing ByteArray instance */ + public static ByteArray.Builder newBuilder(ByteArray other) { + return new ByteArray.Builder(other); + } + + /** + * RecordBuilder for ByteArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List value; + + /** Creates a new Builder */ + private Builder() { + super(ByteArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(ByteArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ByteArray instance */ + private Builder(ByteArray other) { + super(ByteArray.SCHEMA$); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'value' field */ + public java.util.List getValue() { + return value; + } + + /** Sets the value of the 'value' field */ + public ByteArray.Builder setValue(java.util.List value) { + validate(fields()[0], value); + this.value = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'value' field has been set */ + public boolean hasValue() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'value' field */ + public ByteArray.Builder clearValue() { + value = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ByteArray build() { + try { + ByteArray record = new ByteArray(); + record.value = fieldSetFlags()[0] ? 
this.value : (java.util.List) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/DoubleArray.java b/src/test/java/com/databricks/spark/avro/DoubleArray.java new file mode 100644 index 00000000..470fdc70 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/DoubleArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class DoubleArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"DoubleArray\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"double\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List value; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use {@link \#newBuilder()}. + */ + public DoubleArray() {} + + /** + * All-args constructor. + */ + public DoubleArray(java.util.List value) { + this.value = value; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return value; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: value = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'value' field. + */ + public java.util.List getValue() { + return value; + } + + /** + * Sets the value of the 'value' field. + * @param value the value to set. + */ + public void setValue(java.util.List value) { + this.value = value; + } + + /** Creates a new DoubleArray RecordBuilder */ + public static DoubleArray.Builder newBuilder() { + return new DoubleArray.Builder(); + } + + /** Creates a new DoubleArray RecordBuilder by copying an existing Builder */ + public static DoubleArray.Builder newBuilder(DoubleArray.Builder other) { + return new DoubleArray.Builder(other); + } + + /** Creates a new DoubleArray RecordBuilder by copying an existing DoubleArray instance */ + public static DoubleArray.Builder newBuilder(DoubleArray other) { + return new DoubleArray.Builder(other); + } + + /** + * RecordBuilder for DoubleArray instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List value; + + /** Creates a new Builder */ + private Builder() { + super(DoubleArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(DoubleArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing DoubleArray instance */ + private Builder(DoubleArray other) { + super(DoubleArray.SCHEMA$); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'value' field */ + public java.util.List getValue() { + return value; + } + + /** Sets the value of the 'value' field */ + public DoubleArray.Builder setValue(java.util.List value) { + validate(fields()[0], value); + this.value = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'value' field has been set */ + public boolean hasValue() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'value' field */ + public DoubleArray.Builder clearValue() { + value = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public DoubleArray build() { + try { + DoubleArray record = new DoubleArray(); + record.value = fieldSetFlags()[0] ? this.value : (java.util.List) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/Feature.java b/src/test/java/com/databricks/spark/avro/Feature.java new file mode 100644 index 00000000..c421a183 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/Feature.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class Feature extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Feature\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"value\",\"type\":[{\"type\":\"record\",\"name\":\"DoubleArray\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"double\"}}]},{\"type\":\"record\",\"name\":\"StringArray\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"string\"}}]},{\"type\":\"record\",\"name\":\"ByteArray\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"bytes\"}}]}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.lang.CharSequence key; + @Deprecated public java.lang.Object value; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use {@link \#newBuilder()}. + */ + public Feature() {} + + /** + * All-args constructor. 
+ */ + public Feature(java.lang.CharSequence key, java.lang.Object value) { + this.key = key; + this.value = value; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return key; + case 1: return value; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: key = (java.lang.CharSequence)value$; break; + case 1: value = (java.lang.Object)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'key' field. + */ + public java.lang.CharSequence getKey() { + return key; + } + + /** + * Sets the value of the 'key' field. + * @param value the value to set. + */ + public void setKey(java.lang.CharSequence value) { + this.key = value; + } + + /** + * Gets the value of the 'value' field. + */ + public java.lang.Object getValue() { + return value; + } + + /** + * Sets the value of the 'value' field. + * @param value the value to set. + */ + public void setValue(java.lang.Object value) { + this.value = value; + } + + /** Creates a new Feature RecordBuilder */ + public static Feature.Builder newBuilder() { + return new Feature.Builder(); + } + + /** Creates a new Feature RecordBuilder by copying an existing Builder */ + public static Feature.Builder newBuilder(Feature.Builder other) { + return new Feature.Builder(other); + } + + /** Creates a new Feature RecordBuilder by copying an existing Feature instance */ + public static Feature.Builder newBuilder(Feature other) { + return new Feature.Builder(other); + } + + /** + * RecordBuilder for Feature instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.lang.CharSequence key; + private java.lang.Object value; + + /** Creates a new Builder */ + private Builder() { + super(Feature.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(Feature.Builder other) { + super(other); + if (isValidValue(fields()[0], other.key)) { + this.key = data().deepCopy(fields()[0].schema(), other.key); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.value)) { + this.value = data().deepCopy(fields()[1].schema(), other.value); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing Feature instance */ + private Builder(Feature other) { + super(Feature.SCHEMA$); + if (isValidValue(fields()[0], other.key)) { + this.key = data().deepCopy(fields()[0].schema(), other.key); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.value)) { + this.value = data().deepCopy(fields()[1].schema(), other.value); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'key' field */ + public java.lang.CharSequence getKey() { + return key; + } + + /** Sets the value of the 'key' field */ + public Feature.Builder setKey(java.lang.CharSequence value) { + validate(fields()[0], value); + this.key = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'key' field has been set */ + public boolean hasKey() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'key' field */ + public Feature.Builder clearKey() { + key = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'value' field */ + public java.lang.Object getValue() { + return value; + } + + /** Sets the value of the 'value' field */ + public Feature.Builder setValue(java.lang.Object value) { + validate(fields()[1], value); + this.value = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'value' field has been set */ + public boolean hasValue() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'value' field */ + public Feature.Builder clearValue() { + value = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public Feature build() { + try { + Feature record = new Feature(); + record.key = fieldSetFlags()[0] ? this.key : (java.lang.CharSequence) defaultValue(fields()[0]); + record.value = fieldSetFlags()[1] ? 
this.value : (java.lang.Object) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/SimpleEnums.java b/src/test/java/com/databricks/spark/avro/SimpleEnums.java new file mode 100644 index 00000000..5989c620 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/SimpleEnums.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum SimpleEnums { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"SimpleEnums\",\"namespace\":\"com.databricks.spark.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/src/test/java/com/databricks/spark/avro/SimpleFixed.java b/src/test/java/com/databricks/spark/avro/SimpleFixed.java new file mode 100644 index 00000000..8318b65a --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/SimpleFixed.java @@ -0,0 +1,23 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.FixedSize(16) +@org.apache.avro.specific.AvroGenerated +public class SimpleFixed extends org.apache.avro.specific.SpecificFixed { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"fixed\",\"name\":\"SimpleFixed\",\"namespace\":\"com.databricks.spark.avro\",\"size\":16}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + + /** Creates a new SimpleFixed */ + public SimpleFixed() { + super(); + } + + /** Creates a new SimpleFixed with the given bytes */ + public SimpleFixed(byte[] bytes) { + super(bytes); + } +} diff --git a/src/test/java/com/databricks/spark/avro/SimpleRecord.java b/src/test/java/com/databricks/spark/avro/SimpleRecord.java new file mode 100644 index 00000000..a36161ed --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/SimpleRecord.java @@ -0,0 +1,195 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class SimpleRecord extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"SimpleRecord\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"nested1\",\"type\":\"int\",\"default\":0},{\"name\":\"nested2\",\"type\":\"string\",\"default\":\"string\"}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public int nested1; + @Deprecated public java.lang.CharSequence nested2; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public SimpleRecord() {} + + /** + * All-args constructor. 
+ */ + public SimpleRecord(java.lang.Integer nested1, java.lang.CharSequence nested2) { + this.nested1 = nested1; + this.nested2 = nested2; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return nested1; + case 1: return nested2; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: nested1 = (java.lang.Integer)value$; break; + case 1: nested2 = (java.lang.CharSequence)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'nested1' field. + */ + public java.lang.Integer getNested1() { + return nested1; + } + + /** + * Sets the value of the 'nested1' field. + * @param value the value to set. + */ + public void setNested1(java.lang.Integer value) { + this.nested1 = value; + } + + /** + * Gets the value of the 'nested2' field. + */ + public java.lang.CharSequence getNested2() { + return nested2; + } + + /** + * Sets the value of the 'nested2' field. + * @param value the value to set. + */ + public void setNested2(java.lang.CharSequence value) { + this.nested2 = value; + } + + /** Creates a new SimpleRecord RecordBuilder */ + public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder() { + return new com.databricks.spark.avro.SimpleRecord.Builder(); + } + + /** Creates a new SimpleRecord RecordBuilder by copying an existing Builder */ + public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder(com.databricks.spark.avro.SimpleRecord.Builder other) { + return new com.databricks.spark.avro.SimpleRecord.Builder(other); + } + + /** Creates a new SimpleRecord RecordBuilder by copying an existing SimpleRecord instance */ + public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder(com.databricks.spark.avro.SimpleRecord other) { + return new com.databricks.spark.avro.SimpleRecord.Builder(other); + } + + /** + * RecordBuilder for SimpleRecord instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private int nested1; + private java.lang.CharSequence nested2; + + /** Creates a new Builder */ + private Builder() { + super(com.databricks.spark.avro.SimpleRecord.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(com.databricks.spark.avro.SimpleRecord.Builder other) { + super(other); + if (isValidValue(fields()[0], other.nested1)) { + this.nested1 = data().deepCopy(fields()[0].schema(), other.nested1); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested2)) { + this.nested2 = data().deepCopy(fields()[1].schema(), other.nested2); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing SimpleRecord instance */ + private Builder(com.databricks.spark.avro.SimpleRecord other) { + super(com.databricks.spark.avro.SimpleRecord.SCHEMA$); + if (isValidValue(fields()[0], other.nested1)) { + this.nested1 = data().deepCopy(fields()[0].schema(), other.nested1); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested2)) { + this.nested2 = data().deepCopy(fields()[1].schema(), other.nested2); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'nested1' field */ + public java.lang.Integer getNested1() { + return nested1; + } + + /** Sets the value of the 'nested1' field */ + public com.databricks.spark.avro.SimpleRecord.Builder setNested1(int value) { + validate(fields()[0], value); + this.nested1 = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'nested1' field has been set */ + public boolean hasNested1() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'nested1' field */ + public com.databricks.spark.avro.SimpleRecord.Builder clearNested1() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'nested2' field */ + public java.lang.CharSequence getNested2() { + return nested2; + } + + /** Sets the value of the 'nested2' field */ + public com.databricks.spark.avro.SimpleRecord.Builder setNested2(java.lang.CharSequence value) { + validate(fields()[1], value); + this.nested2 = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'nested2' field has been set */ + public boolean hasNested2() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'nested2' field */ + public com.databricks.spark.avro.SimpleRecord.Builder clearNested2() { + nested2 = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public SimpleRecord build() { + try { + SimpleRecord record = new SimpleRecord(); + record.nested1 = fieldSetFlags()[0] ? this.nested1 : (java.lang.Integer) defaultValue(fields()[0]); + record.nested2 = fieldSetFlags()[1] ? 
this.nested2 : (java.lang.CharSequence) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/StringArray.java b/src/test/java/com/databricks/spark/avro/StringArray.java new file mode 100644 index 00000000..ce980d12 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/StringArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class StringArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"StringArray\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"string\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List value; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use {@link \#newBuilder()}. + */ + public StringArray() {} + + /** + * All-args constructor. + */ + public StringArray(java.util.List value) { + this.value = value; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return value; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: value = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'value' field. + */ + public java.util.List getValue() { + return value; + } + + /** + * Sets the value of the 'value' field. + * @param value the value to set. + */ + public void setValue(java.util.List value) { + this.value = value; + } + + /** Creates a new StringArray RecordBuilder */ + public static StringArray.Builder newBuilder() { + return new StringArray.Builder(); + } + + /** Creates a new StringArray RecordBuilder by copying an existing Builder */ + public static StringArray.Builder newBuilder(StringArray.Builder other) { + return new StringArray.Builder(other); + } + + /** Creates a new StringArray RecordBuilder by copying an existing StringArray instance */ + public static StringArray.Builder newBuilder(StringArray other) { + return new StringArray.Builder(other); + } + + /** + * RecordBuilder for StringArray instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List value; + + /** Creates a new Builder */ + private Builder() { + super(StringArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(StringArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing StringArray instance */ + private Builder(StringArray other) { + super(StringArray.SCHEMA$); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'value' field */ + public java.util.List getValue() { + return value; + } + + /** Sets the value of the 'value' field */ + public StringArray.Builder setValue(java.util.List value) { + validate(fields()[0], value); + this.value = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'value' field has been set */ + public boolean hasValue() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'value' field */ + public StringArray.Builder clearValue() { + value = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public StringArray build() { + try { + StringArray record = new StringArray(); + record.value = fieldSetFlags()[0] ? this.value : (java.util.List) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/TestRecord.java b/src/test/java/com/databricks/spark/avro/TestRecord.java new file mode 100644 index 00000000..dd323bb7 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/TestRecord.java @@ -0,0 +1,893 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class TestRecord extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new 
org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"TestRecord\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"boolean\",\"type\":\"boolean\",\"default\":true},{\"name\":\"int\",\"type\":\"int\",\"default\":0},{\"name\":\"long\",\"type\":\"long\",\"default\":0},{\"name\":\"float\",\"type\":\"float\",\"default\":0.0},{\"name\":\"double\",\"type\":\"double\",\"default\":0.0},{\"name\":\"string\",\"type\":\"string\",\"default\":\"value\"},{\"name\":\"bytes\",\"type\":\"bytes\",\"default\":\"ÿ\"},{\"name\":\"nested\",\"type\":{\"type\":\"record\",\"name\":\"SimpleRecord\",\"fields\":[{\"name\":\"nested1\",\"type\":\"int\",\"default\":0},{\"name\":\"nested2\",\"type\":\"string\",\"default\":\"string\"}]},\"default\":{\"nested1\":0,\"nested2\":\"string\"}},{\"name\":\"enum\",\"type\":{\"type\":\"enum\",\"name\":\"SimpleEnums\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},\"default\":\"SPADES\"},{\"name\":\"fixed\",\"type\":{\"type\":\"fixed\",\"name\":\"SimpleFixed\",\"size\":16},\"default\":\"string_length_16\"},{\"name\":\"intArray\",\"type\":{\"type\":\"array\",\"items\":\"int\"},\"default\":[1,2,3]},{\"name\":\"stringArray\",\"type\":{\"type\":\"array\",\"items\":\"string\"},\"default\":[\"a\",\"b\",\"c\"]},{\"name\":\"recordArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleRecord\"},\"default\":[{\"nested1\":0,\"nested2\":\"value\"},{\"nested1\":0,\"nested2\":\"value\"}]},{\"name\":\"enumArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleEnums\"},\"default\":[\"SPADES\",\"HEARTS\",\"SPADES\"]},{\"name\":\"fixedArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleFixed\"},\"default\":[\"foo\",\"bar\",\"baz\"]}]}");
+ public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; }
+ @Deprecated public boolean boolean$;
+ @Deprecated public int int$;
+ @Deprecated public long long$;
+ @Deprecated public float float$;
+ @Deprecated public double double$;
+ @Deprecated public java.lang.CharSequence string;
+ @Deprecated public java.nio.ByteBuffer bytes;
+ @Deprecated public com.databricks.spark.avro.SimpleRecord nested;
+ @Deprecated public com.databricks.spark.avro.SimpleEnums enum$;
+ @Deprecated public com.databricks.spark.avro.SimpleFixed fixed;
+ @Deprecated public java.util.List intArray;
+ @Deprecated public java.util.List stringArray;
+ @Deprecated public java.util.List recordArray;
+ @Deprecated public java.util.List enumArray;
+ @Deprecated public java.util.List fixedArray;
+
+ /**
+ * Default constructor. Note that this does not initialize fields
+ * to their default values from the schema. If that is desired then
+ * one should use newBuilder().
+ */
+ public TestRecord() {}
+
+ /**
+ * All-args constructor.
+ */ + public TestRecord(java.lang.Boolean boolean$, java.lang.Integer int$, java.lang.Long long$, java.lang.Float float$, java.lang.Double double$, java.lang.CharSequence string, java.nio.ByteBuffer bytes, com.databricks.spark.avro.SimpleRecord nested, com.databricks.spark.avro.SimpleEnums enum$, com.databricks.spark.avro.SimpleFixed fixed, java.util.List intArray, java.util.List stringArray, java.util.List recordArray, java.util.List enumArray, java.util.List fixedArray) { + this.boolean$ = boolean$; + this.int$ = int$; + this.long$ = long$; + this.float$ = float$; + this.double$ = double$; + this.string = string; + this.bytes = bytes; + this.nested = nested; + this.enum$ = enum$; + this.fixed = fixed; + this.intArray = intArray; + this.stringArray = stringArray; + this.recordArray = recordArray; + this.enumArray = enumArray; + this.fixedArray = fixedArray; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return boolean$; + case 1: return int$; + case 2: return long$; + case 3: return float$; + case 4: return double$; + case 5: return string; + case 6: return bytes; + case 7: return nested; + case 8: return enum$; + case 9: return fixed; + case 10: return intArray; + case 11: return stringArray; + case 12: return recordArray; + case 13: return enumArray; + case 14: return fixedArray; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: boolean$ = (java.lang.Boolean)value$; break; + case 1: int$ = (java.lang.Integer)value$; break; + case 2: long$ = (java.lang.Long)value$; break; + case 3: float$ = (java.lang.Float)value$; break; + case 4: double$ = (java.lang.Double)value$; break; + case 5: string = (java.lang.CharSequence)value$; break; + case 6: bytes = (java.nio.ByteBuffer)value$; break; + case 7: nested = (com.databricks.spark.avro.SimpleRecord)value$; break; + case 8: enum$ = (com.databricks.spark.avro.SimpleEnums)value$; break; + case 9: fixed = (com.databricks.spark.avro.SimpleFixed)value$; break; + case 10: intArray = (java.util.List)value$; break; + case 11: stringArray = (java.util.List)value$; break; + case 12: recordArray = (java.util.List)value$; break; + case 13: enumArray = (java.util.List)value$; break; + case 14: fixedArray = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'boolean$' field. + */ + public java.lang.Boolean getBoolean$() { + return boolean$; + } + + /** + * Sets the value of the 'boolean$' field. + * @param value the value to set. + */ + public void setBoolean$(java.lang.Boolean value) { + this.boolean$ = value; + } + + /** + * Gets the value of the 'int$' field. + */ + public java.lang.Integer getInt$() { + return int$; + } + + /** + * Sets the value of the 'int$' field. + * @param value the value to set. + */ + public void setInt$(java.lang.Integer value) { + this.int$ = value; + } + + /** + * Gets the value of the 'long$' field. + */ + public java.lang.Long getLong$() { + return long$; + } + + /** + * Sets the value of the 'long$' field. + * @param value the value to set. + */ + public void setLong$(java.lang.Long value) { + this.long$ = value; + } + + /** + * Gets the value of the 'float$' field. 
+ */ + public java.lang.Float getFloat$() { + return float$; + } + + /** + * Sets the value of the 'float$' field. + * @param value the value to set. + */ + public void setFloat$(java.lang.Float value) { + this.float$ = value; + } + + /** + * Gets the value of the 'double$' field. + */ + public java.lang.Double getDouble$() { + return double$; + } + + /** + * Sets the value of the 'double$' field. + * @param value the value to set. + */ + public void setDouble$(java.lang.Double value) { + this.double$ = value; + } + + /** + * Gets the value of the 'string' field. + */ + public java.lang.CharSequence getString() { + return string; + } + + /** + * Sets the value of the 'string' field. + * @param value the value to set. + */ + public void setString(java.lang.CharSequence value) { + this.string = value; + } + + /** + * Gets the value of the 'bytes' field. + */ + public java.nio.ByteBuffer getBytes() { + return bytes; + } + + /** + * Sets the value of the 'bytes' field. + * @param value the value to set. + */ + public void setBytes(java.nio.ByteBuffer value) { + this.bytes = value; + } + + /** + * Gets the value of the 'nested' field. + */ + public com.databricks.spark.avro.SimpleRecord getNested() { + return nested; + } + + /** + * Sets the value of the 'nested' field. + * @param value the value to set. + */ + public void setNested(com.databricks.spark.avro.SimpleRecord value) { + this.nested = value; + } + + /** + * Gets the value of the 'enum$' field. + */ + public com.databricks.spark.avro.SimpleEnums getEnum$() { + return enum$; + } + + /** + * Sets the value of the 'enum$' field. + * @param value the value to set. + */ + public void setEnum$(com.databricks.spark.avro.SimpleEnums value) { + this.enum$ = value; + } + + /** + * Gets the value of the 'fixed' field. + */ + public com.databricks.spark.avro.SimpleFixed getFixed() { + return fixed; + } + + /** + * Sets the value of the 'fixed' field. + * @param value the value to set. + */ + public void setFixed(com.databricks.spark.avro.SimpleFixed value) { + this.fixed = value; + } + + /** + * Gets the value of the 'intArray' field. + */ + public java.util.List getIntArray() { + return intArray; + } + + /** + * Sets the value of the 'intArray' field. + * @param value the value to set. + */ + public void setIntArray(java.util.List value) { + this.intArray = value; + } + + /** + * Gets the value of the 'stringArray' field. + */ + public java.util.List getStringArray() { + return stringArray; + } + + /** + * Sets the value of the 'stringArray' field. + * @param value the value to set. + */ + public void setStringArray(java.util.List value) { + this.stringArray = value; + } + + /** + * Gets the value of the 'recordArray' field. + */ + public java.util.List getRecordArray() { + return recordArray; + } + + /** + * Sets the value of the 'recordArray' field. + * @param value the value to set. + */ + public void setRecordArray(java.util.List value) { + this.recordArray = value; + } + + /** + * Gets the value of the 'enumArray' field. + */ + public java.util.List getEnumArray() { + return enumArray; + } + + /** + * Sets the value of the 'enumArray' field. + * @param value the value to set. + */ + public void setEnumArray(java.util.List value) { + this.enumArray = value; + } + + /** + * Gets the value of the 'fixedArray' field. + */ + public java.util.List getFixedArray() { + return fixedArray; + } + + /** + * Sets the value of the 'fixedArray' field. + * @param value the value to set. 
+ */ + public void setFixedArray(java.util.List value) { + this.fixedArray = value; + } + + /** Creates a new TestRecord RecordBuilder */ + public static com.databricks.spark.avro.TestRecord.Builder newBuilder() { + return new com.databricks.spark.avro.TestRecord.Builder(); + } + + /** Creates a new TestRecord RecordBuilder by copying an existing Builder */ + public static com.databricks.spark.avro.TestRecord.Builder newBuilder(com.databricks.spark.avro.TestRecord.Builder other) { + return new com.databricks.spark.avro.TestRecord.Builder(other); + } + + /** Creates a new TestRecord RecordBuilder by copying an existing TestRecord instance */ + public static com.databricks.spark.avro.TestRecord.Builder newBuilder(com.databricks.spark.avro.TestRecord other) { + return new com.databricks.spark.avro.TestRecord.Builder(other); + } + + /** + * RecordBuilder for TestRecord instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean boolean$; + private int int$; + private long long$; + private float float$; + private double double$; + private java.lang.CharSequence string; + private java.nio.ByteBuffer bytes; + private com.databricks.spark.avro.SimpleRecord nested; + private com.databricks.spark.avro.SimpleEnums enum$; + private com.databricks.spark.avro.SimpleFixed fixed; + private java.util.List intArray; + private java.util.List stringArray; + private java.util.List recordArray; + private java.util.List enumArray; + private java.util.List fixedArray; + + /** Creates a new Builder */ + private Builder() { + super(com.databricks.spark.avro.TestRecord.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(com.databricks.spark.avro.TestRecord.Builder other) { + super(other); + if (isValidValue(fields()[0], other.boolean$)) { + this.boolean$ = data().deepCopy(fields()[0].schema(), other.boolean$); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int$)) { + this.int$ = data().deepCopy(fields()[1].schema(), other.int$); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long$)) { + this.long$ = data().deepCopy(fields()[2].schema(), other.long$); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float$)) { + this.float$ = data().deepCopy(fields()[3].schema(), other.float$); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double$)) { + this.double$ = data().deepCopy(fields()[4].schema(), other.double$); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.string)) { + this.string = data().deepCopy(fields()[5].schema(), other.string); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.bytes)) { + this.bytes = data().deepCopy(fields()[6].schema(), other.bytes); + fieldSetFlags()[6] = true; + } + if (isValidValue(fields()[7], other.nested)) { + this.nested = data().deepCopy(fields()[7].schema(), other.nested); + fieldSetFlags()[7] = true; + } + if (isValidValue(fields()[8], other.enum$)) { + this.enum$ = data().deepCopy(fields()[8].schema(), other.enum$); + fieldSetFlags()[8] = true; + } + if (isValidValue(fields()[9], other.fixed)) { + this.fixed = data().deepCopy(fields()[9].schema(), other.fixed); + fieldSetFlags()[9] = true; + } + if (isValidValue(fields()[10], other.intArray)) { + this.intArray = data().deepCopy(fields()[10].schema(), other.intArray); + fieldSetFlags()[10] = true; + } + if (isValidValue(fields()[11], 
other.stringArray)) { + this.stringArray = data().deepCopy(fields()[11].schema(), other.stringArray); + fieldSetFlags()[11] = true; + } + if (isValidValue(fields()[12], other.recordArray)) { + this.recordArray = data().deepCopy(fields()[12].schema(), other.recordArray); + fieldSetFlags()[12] = true; + } + if (isValidValue(fields()[13], other.enumArray)) { + this.enumArray = data().deepCopy(fields()[13].schema(), other.enumArray); + fieldSetFlags()[13] = true; + } + if (isValidValue(fields()[14], other.fixedArray)) { + this.fixedArray = data().deepCopy(fields()[14].schema(), other.fixedArray); + fieldSetFlags()[14] = true; + } + } + + /** Creates a Builder by copying an existing TestRecord instance */ + private Builder(com.databricks.spark.avro.TestRecord other) { + super(com.databricks.spark.avro.TestRecord.SCHEMA$); + if (isValidValue(fields()[0], other.boolean$)) { + this.boolean$ = data().deepCopy(fields()[0].schema(), other.boolean$); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int$)) { + this.int$ = data().deepCopy(fields()[1].schema(), other.int$); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long$)) { + this.long$ = data().deepCopy(fields()[2].schema(), other.long$); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float$)) { + this.float$ = data().deepCopy(fields()[3].schema(), other.float$); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double$)) { + this.double$ = data().deepCopy(fields()[4].schema(), other.double$); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.string)) { + this.string = data().deepCopy(fields()[5].schema(), other.string); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.bytes)) { + this.bytes = data().deepCopy(fields()[6].schema(), other.bytes); + fieldSetFlags()[6] = true; + } + if (isValidValue(fields()[7], other.nested)) { + this.nested = data().deepCopy(fields()[7].schema(), other.nested); + fieldSetFlags()[7] = true; + } + if (isValidValue(fields()[8], other.enum$)) { + this.enum$ = data().deepCopy(fields()[8].schema(), other.enum$); + fieldSetFlags()[8] = true; + } + if (isValidValue(fields()[9], other.fixed)) { + this.fixed = data().deepCopy(fields()[9].schema(), other.fixed); + fieldSetFlags()[9] = true; + } + if (isValidValue(fields()[10], other.intArray)) { + this.intArray = data().deepCopy(fields()[10].schema(), other.intArray); + fieldSetFlags()[10] = true; + } + if (isValidValue(fields()[11], other.stringArray)) { + this.stringArray = data().deepCopy(fields()[11].schema(), other.stringArray); + fieldSetFlags()[11] = true; + } + if (isValidValue(fields()[12], other.recordArray)) { + this.recordArray = data().deepCopy(fields()[12].schema(), other.recordArray); + fieldSetFlags()[12] = true; + } + if (isValidValue(fields()[13], other.enumArray)) { + this.enumArray = data().deepCopy(fields()[13].schema(), other.enumArray); + fieldSetFlags()[13] = true; + } + if (isValidValue(fields()[14], other.fixedArray)) { + this.fixedArray = data().deepCopy(fields()[14].schema(), other.fixedArray); + fieldSetFlags()[14] = true; + } + } + + /** Gets the value of the 'boolean$' field */ + public java.lang.Boolean getBoolean$() { + return boolean$; + } + + /** Sets the value of the 'boolean$' field */ + public com.databricks.spark.avro.TestRecord.Builder setBoolean$(boolean value) { + validate(fields()[0], value); + this.boolean$ = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'boolean$' 
field has been set */ + public boolean hasBoolean$() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'boolean$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearBoolean$() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int$' field */ + public java.lang.Integer getInt$() { + return int$; + } + + /** Sets the value of the 'int$' field */ + public com.databricks.spark.avro.TestRecord.Builder setInt$(int value) { + validate(fields()[1], value); + this.int$ = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int$' field has been set */ + public boolean hasInt$() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearInt$() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long$' field */ + public java.lang.Long getLong$() { + return long$; + } + + /** Sets the value of the 'long$' field */ + public com.databricks.spark.avro.TestRecord.Builder setLong$(long value) { + validate(fields()[2], value); + this.long$ = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long$' field has been set */ + public boolean hasLong$() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearLong$() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float$' field */ + public java.lang.Float getFloat$() { + return float$; + } + + /** Sets the value of the 'float$' field */ + public com.databricks.spark.avro.TestRecord.Builder setFloat$(float value) { + validate(fields()[3], value); + this.float$ = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float$' field has been set */ + public boolean hasFloat$() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearFloat$() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double$' field */ + public java.lang.Double getDouble$() { + return double$; + } + + /** Sets the value of the 'double$' field */ + public com.databricks.spark.avro.TestRecord.Builder setDouble$(double value) { + validate(fields()[4], value); + this.double$ = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double$' field has been set */ + public boolean hasDouble$() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearDouble$() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'string' field */ + public java.lang.CharSequence getString() { + return string; + } + + /** Sets the value of the 'string' field */ + public com.databricks.spark.avro.TestRecord.Builder setString(java.lang.CharSequence value) { + validate(fields()[5], value); + this.string = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'string' field has been set */ + public boolean hasString() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'string' field */ + public com.databricks.spark.avro.TestRecord.Builder clearString() { + string = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'bytes' field */ + public java.nio.ByteBuffer getBytes() { + return bytes; + } + + /** Sets the 
value of the 'bytes' field */ + public com.databricks.spark.avro.TestRecord.Builder setBytes(java.nio.ByteBuffer value) { + validate(fields()[6], value); + this.bytes = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'bytes' field has been set */ + public boolean hasBytes() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'bytes' field */ + public com.databricks.spark.avro.TestRecord.Builder clearBytes() { + bytes = null; + fieldSetFlags()[6] = false; + return this; + } + + /** Gets the value of the 'nested' field */ + public com.databricks.spark.avro.SimpleRecord getNested() { + return nested; + } + + /** Sets the value of the 'nested' field */ + public com.databricks.spark.avro.TestRecord.Builder setNested(com.databricks.spark.avro.SimpleRecord value) { + validate(fields()[7], value); + this.nested = value; + fieldSetFlags()[7] = true; + return this; + } + + /** Checks whether the 'nested' field has been set */ + public boolean hasNested() { + return fieldSetFlags()[7]; + } + + /** Clears the value of the 'nested' field */ + public com.databricks.spark.avro.TestRecord.Builder clearNested() { + nested = null; + fieldSetFlags()[7] = false; + return this; + } + + /** Gets the value of the 'enum$' field */ + public com.databricks.spark.avro.SimpleEnums getEnum$() { + return enum$; + } + + /** Sets the value of the 'enum$' field */ + public com.databricks.spark.avro.TestRecord.Builder setEnum$(com.databricks.spark.avro.SimpleEnums value) { + validate(fields()[8], value); + this.enum$ = value; + fieldSetFlags()[8] = true; + return this; + } + + /** Checks whether the 'enum$' field has been set */ + public boolean hasEnum$() { + return fieldSetFlags()[8]; + } + + /** Clears the value of the 'enum$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearEnum$() { + enum$ = null; + fieldSetFlags()[8] = false; + return this; + } + + /** Gets the value of the 'fixed' field */ + public com.databricks.spark.avro.SimpleFixed getFixed() { + return fixed; + } + + /** Sets the value of the 'fixed' field */ + public com.databricks.spark.avro.TestRecord.Builder setFixed(com.databricks.spark.avro.SimpleFixed value) { + validate(fields()[9], value); + this.fixed = value; + fieldSetFlags()[9] = true; + return this; + } + + /** Checks whether the 'fixed' field has been set */ + public boolean hasFixed() { + return fieldSetFlags()[9]; + } + + /** Clears the value of the 'fixed' field */ + public com.databricks.spark.avro.TestRecord.Builder clearFixed() { + fixed = null; + fieldSetFlags()[9] = false; + return this; + } + + /** Gets the value of the 'intArray' field */ + public java.util.List getIntArray() { + return intArray; + } + + /** Sets the value of the 'intArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setIntArray(java.util.List value) { + validate(fields()[10], value); + this.intArray = value; + fieldSetFlags()[10] = true; + return this; + } + + /** Checks whether the 'intArray' field has been set */ + public boolean hasIntArray() { + return fieldSetFlags()[10]; + } + + /** Clears the value of the 'intArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearIntArray() { + intArray = null; + fieldSetFlags()[10] = false; + return this; + } + + /** Gets the value of the 'stringArray' field */ + public java.util.List getStringArray() { + return stringArray; + } + + /** Sets the value of the 'stringArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setStringArray(java.util.List value) { 
+ validate(fields()[11], value); + this.stringArray = value; + fieldSetFlags()[11] = true; + return this; + } + + /** Checks whether the 'stringArray' field has been set */ + public boolean hasStringArray() { + return fieldSetFlags()[11]; + } + + /** Clears the value of the 'stringArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearStringArray() { + stringArray = null; + fieldSetFlags()[11] = false; + return this; + } + + /** Gets the value of the 'recordArray' field */ + public java.util.List getRecordArray() { + return recordArray; + } + + /** Sets the value of the 'recordArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setRecordArray(java.util.List value) { + validate(fields()[12], value); + this.recordArray = value; + fieldSetFlags()[12] = true; + return this; + } + + /** Checks whether the 'recordArray' field has been set */ + public boolean hasRecordArray() { + return fieldSetFlags()[12]; + } + + /** Clears the value of the 'recordArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearRecordArray() { + recordArray = null; + fieldSetFlags()[12] = false; + return this; + } + + /** Gets the value of the 'enumArray' field */ + public java.util.List getEnumArray() { + return enumArray; + } + + /** Sets the value of the 'enumArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setEnumArray(java.util.List value) { + validate(fields()[13], value); + this.enumArray = value; + fieldSetFlags()[13] = true; + return this; + } + + /** Checks whether the 'enumArray' field has been set */ + public boolean hasEnumArray() { + return fieldSetFlags()[13]; + } + + /** Clears the value of the 'enumArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearEnumArray() { + enumArray = null; + fieldSetFlags()[13] = false; + return this; + } + + /** Gets the value of the 'fixedArray' field */ + public java.util.List getFixedArray() { + return fixedArray; + } + + /** Sets the value of the 'fixedArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setFixedArray(java.util.List value) { + validate(fields()[14], value); + this.fixedArray = value; + fieldSetFlags()[14] = true; + return this; + } + + /** Checks whether the 'fixedArray' field has been set */ + public boolean hasFixedArray() { + return fieldSetFlags()[14]; + } + + /** Clears the value of the 'fixedArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearFixedArray() { + fixedArray = null; + fieldSetFlags()[14] = false; + return this; + } + + @Override + public TestRecord build() { + try { + TestRecord record = new TestRecord(); + record.boolean$ = fieldSetFlags()[0] ? this.boolean$ : (java.lang.Boolean) defaultValue(fields()[0]); + record.int$ = fieldSetFlags()[1] ? this.int$ : (java.lang.Integer) defaultValue(fields()[1]); + record.long$ = fieldSetFlags()[2] ? this.long$ : (java.lang.Long) defaultValue(fields()[2]); + record.float$ = fieldSetFlags()[3] ? this.float$ : (java.lang.Float) defaultValue(fields()[3]); + record.double$ = fieldSetFlags()[4] ? this.double$ : (java.lang.Double) defaultValue(fields()[4]); + record.string = fieldSetFlags()[5] ? this.string : (java.lang.CharSequence) defaultValue(fields()[5]); + record.bytes = fieldSetFlags()[6] ? this.bytes : (java.nio.ByteBuffer) defaultValue(fields()[6]); + record.nested = fieldSetFlags()[7] ? this.nested : (com.databricks.spark.avro.SimpleRecord) defaultValue(fields()[7]); + record.enum$ = fieldSetFlags()[8] ? 
this.enum$ : (com.databricks.spark.avro.SimpleEnums) defaultValue(fields()[8]); + record.fixed = fieldSetFlags()[9] ? this.fixed : (com.databricks.spark.avro.SimpleFixed) defaultValue(fields()[9]); + record.intArray = fieldSetFlags()[10] ? this.intArray : (java.util.List) defaultValue(fields()[10]); + record.stringArray = fieldSetFlags()[11] ? this.stringArray : (java.util.List) defaultValue(fields()[11]); + record.recordArray = fieldSetFlags()[12] ? this.recordArray : (java.util.List) defaultValue(fields()[12]); + record.enumArray = fieldSetFlags()[13] ? this.enumArray : (java.util.List) defaultValue(fields()[13]); + record.fixedArray = fieldSetFlags()[14] ? this.fixedArray : (java.util.List) defaultValue(fields()[14]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/resources/specific.avsc b/src/test/resources/specific.avsc new file mode 100644 index 00000000..dbbc1da6 --- /dev/null +++ b/src/test/resources/specific.avsc @@ -0,0 +1,40 @@ +{ + "namespace": "com.databricks.spark.avro", + "type": "record", + "name": "TestRecord", + "fields": [ + {"name": "boolean", "type": "boolean", "default": true}, + {"name": "int", "type": "int", "default": 0}, + {"name": "long", "type": "long", "default": 0}, + {"name": "float", "type": "float", "default": 0.0}, + {"name": "double", "type": "double", "default": 0.0}, + {"name": "string", "type": "string", "default": "value"}, + {"name": "bytes", "type": "bytes", "default": "\u00ff"}, + {"name": "nested", "type": { + "type": "record", "name": "SimpleRecord", "fields": [ + {"name": "nested1", "type": "int", "default": 0}, + {"name": "nested2", "type": "string", "default": "string"}]}, + "default": {"nested1": 0, "nested2": "string"}}, + {"name": "enum", "type": { + "name": "SimpleEnums", "type": "enum", "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}, + "default": "SPADES"}, + {"name": "fixed", "type": { + "name": "SimpleFixed", "type": "fixed", "size": 16}, + "default": "string_length_16"}, + {"name": "intArray", + "type": {"type": "array", "items": "int"}, + "default": [1, 2, 3]}, + {"name": "stringArray", + "type": {"type": "array", "items": "string"}, + "default": ["a", "b", "c"]}, + {"name": "recordArray", + "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleRecord"}, + "default": [{"nested1": 0, "nested2": "value"}, {"nested1": 0, "nested2": "value"}]}, + {"name": "enumArray", + "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleEnums"}, + "default": ["SPADES", "HEARTS", "SPADES"]}, + {"name": "fixedArray", + "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleFixed"}, + "default": ["foo", "bar", "baz"]} + ] +} \ No newline at end of file diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index 1b5d07aa..5779b711 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -22,18 +22,22 @@ import java.nio.file.Files import java.sql.Timestamp import java.util.UUID -import scala.collection.JavaConversions._ +import org.apache.spark.SparkConf +import scala.collection.JavaConversions._ import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException import org.apache.avro.Schema import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.SchemaBuilder import org.apache.avro.file.DataFileWriter import 
org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} -import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord, GenericRecordBuilder} import org.apache.commons.io.FileUtils - +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.scalatest.{BeforeAndAfterAll, FunSuite} class AvroSuite extends FunSuite with BeforeAndAfterAll { @@ -44,10 +48,16 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { override protected def beforeAll(): Unit = { super.beforeAll() + + val sc = new SparkConf() + sc.registerAvroSchemas(Feature.getClassSchema) + spark = SparkSession.builder() .master("local[2]") .appName("AvroSuite") .config("spark.sql.files.maxPartitionBytes", 1024) + .config(sc) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .getOrCreate() } @@ -674,4 +684,229 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { assert(input.rdd.partitions.size > 2) } } + + test("generic record converts to row and back") { + val nested = + SchemaBuilder.record("simple_record").fields() + .name("nested1").`type`("int").withDefault(0) + .name("nested2").`type`("string").withDefault("string").endRecord() + + val schema = SchemaBuilder.record("record").fields() + .name("boolean").`type`("boolean").withDefault(false) + .name("int").`type`("int").withDefault(0) + .name("long").`type`("long").withDefault(0L) + .name("float").`type`("float").withDefault(0.0F) + .name("double").`type`("double").withDefault(0.0) + .name("string").`type`("string").withDefault("string") + .name("bytes").`type`("bytes").withDefault(java.nio.ByteBuffer.wrap("bytes".getBytes)) + .name("nested").`type`(nested).withDefault(new GenericRecordBuilder(nested).build) + .name("enum").`type`( + SchemaBuilder.enumeration("simple_enums") + .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS")) + .withDefault("SPADES") + .name("int_array").`type`( + SchemaBuilder.array().items().`type`("int")) + .withDefault(java.util.Arrays.asList(1, 2, 3)) + .name("string_array").`type`( + SchemaBuilder.array().items().`type`("string")) + .withDefault(java.util.Arrays.asList("a", "b", "c")) + .name("record_array").`type`( + SchemaBuilder.array.items.`type`(nested)) + .withDefault(java.util.Arrays.asList( + new GenericRecordBuilder(nested).build, + new GenericRecordBuilder(nested).build)) + .name("enum_array").`type`( + SchemaBuilder.array.items.`type`( + SchemaBuilder.enumeration("simple_enums") + .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS"))) + .withDefault(java.util.Arrays.asList("SPADES", "HEARTS", "SPADES")) + .name("fixed_array").`type`( + SchemaBuilder.array.items().`type`( + SchemaBuilder.fixed("simple_fixed").size(3))) + .withDefault(java.util.Arrays.asList("foo", "bar", "baz")) + .name("fixed").`type`(SchemaBuilder.fixed("simple_fixed").size(16)) + .withDefault("string_length_16") + .endRecord() + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + val row = expressionEncoder.toRow(record) + val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(record == recordFromRow) + } + + test("specific record converts to row and back") { + val 
schemaPath = "src/test/resources/specific.avsc" + val schema = new Schema.Parser().parse(new File(schemaPath)) + val record = TestRecord.newBuilder().build() + + val classEncoder = AvroEncoder.of[TestRecord](classOf[TestRecord]) + val classExpressionEncoder = classEncoder.asInstanceOf[ExpressionEncoder[TestRecord]] + val classRow = classExpressionEncoder.toRow(record) + val classRecordFromRow = classExpressionEncoder.resolveAndBind().fromRow(classRow) + + assert(record == classRecordFromRow) + + val schemaEncoder = AvroEncoder.of[TestRecord](schema) + val schemaExpressionEncoder = schemaEncoder.asInstanceOf[ExpressionEncoder[TestRecord]] + val schemaRow = schemaExpressionEncoder.toRow(record) + val schemaRecordFromRow = schemaExpressionEncoder.resolveAndBind().fromRow(schemaRow) + + assert(record == schemaRecordFromRow) + } + + test("encoder resolves union types to rows") { + val schema = SchemaBuilder.record("record").fields() + .name("int_null_union").`type`( + SchemaBuilder.unionOf.`type`("null").and.`type`("int").endUnion) + .withDefault(null) + .name("string_null_union").`type`( + SchemaBuilder.unionOf.`type`("null").and.`type`("string").endUnion) + .withDefault(null) + .name("int_long_union").`type`( + SchemaBuilder.unionOf.`type`("int").and.`type`("long").endUnion) + .withDefault(0) + .name("float_double_union").`type`( + SchemaBuilder.unionOf.`type`("float").and.`type`("double").endUnion) + .withDefault(0.0) + .endRecord + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + val row = expressionEncoder.toRow(record) + val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(record.get(0) == recordFromRow.get(0)) + assert(record.get(1) == recordFromRow.get(1)) + assert(record.get(2) == recordFromRow.get(2)) + assert(record.get(3) == recordFromRow.get(3)) + + record.put(0, 0) + record.put(1, "value") + + val updatedRow = expressionEncoder.toRow(record) + val updatedRecordFromRow = expressionEncoder.resolveAndBind().fromRow(updatedRow) + + assert(record.get(0) == updatedRecordFromRow.get(0)) + assert(record.get(1) == updatedRecordFromRow.get(1)) + } + + test("encoder resolves map types to rows") { + val intMap = new java.util.HashMap[java.lang.String, java.lang.Integer] + intMap.put("foo", 1) + intMap.put("bar", 2) + intMap.put("baz", 3) + + val stringMap = new java.util.HashMap[java.lang.String, java.lang.String] + stringMap.put("foo", "a") + stringMap.put("bar", "b") + stringMap.put("baz", "c") + + val schema = SchemaBuilder.record("record").fields() + .name("int_map").`type`( + SchemaBuilder.map.values.`type`("int")).withDefault(intMap) + .name("string_map").`type`( + SchemaBuilder.map.values.`type`("string")).withDefault(stringMap) + .endRecord() + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + val row = expressionEncoder.toRow(record) + val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + val rowIntMap = recordFromRow.get(0) + assert(intMap == rowIntMap) + + val rowStringMap = recordFromRow.get(1) + assert(stringMap == rowStringMap) + } + + test("encoder resolves complex unions to rows") { + val nested = + SchemaBuilder.record("simple_record").fields() + .name("nested1").`type`("int").withDefault(0) + 
.name("nested2").`type`("string").withDefault("foo").endRecord() + val schema = SchemaBuilder.record("record").fields() + .name("int_float_string_record").`type`( + SchemaBuilder.unionOf() + .`type`("null").and() + .`type`("int").and() + .`type`("float").and() + .`type`("string").and() + .`type`(nested).endUnion() + ).withDefault(null).endRecord() + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + var row = expressionEncoder.toRow(record) + var recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, 1) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, 1F) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, "bar") + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, new GenericRecordBuilder(nested).build()) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(record == recordFromRow) + } + + test("create Dataset from SpecificRecords with unions") { + val sparkSession = spark + import sparkSession.implicits._ + + implicit val enc = AvroEncoder.of(classOf[Feature]) + + val rdd = sparkSession.sparkContext + .parallelize(Seq(1)).mapPartitions { iter => + iter.map { _ => + val t = StringArray.newBuilder().setValue(List("the title")).build() + val b = StringArray.newBuilder().setValue(List("BODY TEXT")).build() + val ls = StringArray.newBuilder().setValue(List("foo", "bar", "baz")).build() + + Feature.newBuilder().setKey("FOOBAR").setValue(ls).build() + } + } + + val ds = rdd.toDS() + assert(ds.count() == 1) + } } From fd777c679f69a82616d774c1781a02e5dddd5fd1 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Thu, 2 Feb 2017 06:02:13 -0800 Subject: [PATCH 04/22] Fix imports so that we can redeclare LambdaVariable, make unit tests pass --- .../databricks/spark/avro/AvroEncoder.scala | 192 ++++++++++++++++-- .../com/databricks/spark/avro/AvroSuite.scala | 14 +- 2 files changed, 182 insertions(+), 24 deletions(-) diff --git a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala index a488ecdc..eff31330 100644 --- 
a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala @@ -22,13 +22,10 @@ package com.databricks.spark.avro import java.io._ import java.util.{Map => JMap} -import org.apache.avro.Schema.Parser -import org.apache.hadoop.conf.Configuration -import org.apache.spark.util.Utils -import scala.collection.JavaConverters._ import com.databricks.spark.avro.SchemaConverters.{IncompatibleSchemaException, SchemaType, resolveUnionType, toSqlType} import org.apache.avro.Schema +import org.apache.avro.Schema.Parser import org.apache.avro.Schema.Type._ import org.apache.avro.generic.{GenericData, IndexedRecord} import org.apache.avro.reflect.ReflectData @@ -39,12 +36,13 @@ import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAtt import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable => _, _} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag /** @@ -74,20 +72,180 @@ object AvroEncoder { } } -object ObjectType { +private[avro] object ObjectType { + val ot = Class.forName("org.apache.spark.sql.types.ObjectType") + val meth = ot.getDeclaredConstructor(classOf[Class[_]]) + meth.setAccessible(true) + + val cls = ot.getMethod("cls") + cls.setAccessible(true) + def apply(cls: Class[_]): DataType = { - val ot = Class.forName("org.apache.spark.sql.types.ObjectType") - val meth = ot.getDeclaredConstructor(classOf[Class[_]]) - meth.setAccessible(true) meth.newInstance(cls).asInstanceOf[DataType] } def _isInstanceOf(obj: AnyRef): Boolean = { - val ot = Class.forName("org.apache.spark.sql.types.ObjectType") ot.isInstance(obj) } + + def unapply(arg: DataType): Option[Class[_]] = { + arg match { + case arg if ot.isInstance(arg) => { + Some(cls.invoke(arg).asInstanceOf[Class[_]]) + } + case _ => None + } + } } +case class LambdaVariable( + value: String, + isNull: String, + dataType: DataType, + nullable: Boolean = true) extends LeafExpression + with Unevaluable with NonSQLExpression { + + override def genCode(ctx: CodegenContext): ExprCode = { + ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") + } +} + +object ExternalMapToCatalyst { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + def apply( + inputMap: Expression, + keyType: DataType, + keyConverter: Expression => Expression, + valueType: DataType, + valueConverter: Expression => Expression, + valueNullable: Boolean): ExternalMapToCatalyst = { + val id = curId.getAndIncrement() + val keyName = "ExternalMapToCatalyst_key" + id + val valueName = "ExternalMapToCatalyst_value" + id + val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id + + ExternalMapToCatalyst( + keyName, + keyType, + keyConverter(LambdaVariable(keyName, "false", keyType, false)), + valueName, + valueIsNull, + valueType, + valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)), + inputMap + ) + } +} + +case class ExternalMapToCatalyst private( + key: String, + keyType: DataType, + keyConverter: Expression, + value: String, + 
valueIsNull: String, + valueType: DataType, + valueConverter: Expression, + child: Expression) + extends UnaryExpression with NonSQLExpression { + + override def foldable: Boolean = false + + override def dataType: MapType = MapType( + keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val inputMap = child.genCode(ctx) + val genKeyConverter = keyConverter.genCode(ctx) + val genValueConverter = valueConverter.genCode(ctx) + val length = ctx.freshName("length") + val index = ctx.freshName("index") + val convertedKeys = ctx.freshName("convertedKeys") + val convertedValues = ctx.freshName("convertedValues") + val entry = ctx.freshName("entry") + val entries = ctx.freshName("entries") + + val (defineEntries, defineKeyValue) = child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + val javaIteratorCls = classOf[java.util.Iterator[_]].getName + val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName + + val defineEntries = + s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();" + + val defineKeyValue = + s""" + final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); + ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey(); + ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + """ + + defineEntries -> defineKeyValue + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + val scalaIteratorCls = classOf[Iterator[_]].getName + val scalaMapEntryCls = classOf[Tuple2[_, _]].getName + + val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();" + + val defineKeyValue = + s""" + final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); + ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1(); + ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2(); + """ + + defineEntries -> defineKeyValue + } + + val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { + s"boolean $valueIsNull = false;" + } else { + s"boolean $valueIsNull = $value == null;" + } + + val arrayCls = classOf[GenericArrayData].getName + val mapCls = classOf[ArrayBasedMapData].getName + val convertedKeyType = ctx.boxedType(keyConverter.dataType) + val convertedValueType = ctx.boxedType(valueConverter.dataType) + val code = + s""" + ${inputMap.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${inputMap.isNull}) { + final int $length = ${inputMap.value}.size(); + final Object[] $convertedKeys = new Object[$length]; + final Object[] $convertedValues = new Object[$length]; + int $index = 0; + $defineEntries + while($entries.hasNext()) { + $defineKeyValue + $valueNullCheck + ${genKeyConverter.code} + if (${genKeyConverter.isNull}) { + throw new RuntimeException("Cannot use null as map key!"); + } else { + $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value}; + } + ${genValueConverter.code} + if (${genValueConverter.isNull}) { + $convertedValues[$index] = null; + } else { + $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value}; + } + $index++; + } + ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues)); + } + """ 
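+    // The generated loop above copies each external map entry into parallel key/value object
+    // arrays and wraps them in an ArrayBasedMapData; the resulting expression is null only
+    // when the input map itself is null, which is propagated through isNull below.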
+ ev.copy(code = code, isNull = inputMap.isNull) + } +} + + class SerializableSchema(@transient var value: Schema) extends Externalizable { def this() = this(null) override def readExternal(in: ObjectInput): Unit = { @@ -425,13 +583,13 @@ private object AvroTypeInference { val valueSchema = schema.getValueType val valueType = inferExternalType(valueSchema) -// ExternalMapToCatalyst( -// inputObject, -// ObjectType(classOf[org.apache.avro.util.Utf8]), -// serializerFor(_, Schema.create(STRING)), -// valueType, -// serializerFor(_, valueSchema)) - ??? + ExternalMapToCatalyst( + inputObject, + ObjectType(classOf[org.apache.avro.util.Utf8]), + serializerFor(_, Schema.create(STRING)), + valueType, + serializerFor(_, valueSchema), + true) } if (!ObjectType._isInstanceOf(inputObject.dataType)) { diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index 5779b711..64f2b098 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -27,11 +27,13 @@ import org.apache.spark.SparkConf import scala.collection.JavaConversions._ import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException import org.apache.avro.Schema -import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.Schema.{Field, Parser, Type} import org.apache.avro.SchemaBuilder import org.apache.avro.file.DataFileWriter import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord, GenericRecordBuilder} +import org.apache.avro.io.{DecoderFactory, JsonDecoder} +import org.apache.avro.specific.{SpecificData, SpecificDatumReader} import org.apache.commons.io.FileUtils import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -897,14 +899,12 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { val rdd = sparkSession.sparkContext .parallelize(Seq(1)).mapPartitions { iter => - iter.map { _ => - val t = StringArray.newBuilder().setValue(List("the title")).build() - val b = StringArray.newBuilder().setValue(List("BODY TEXT")).build() - val ls = StringArray.newBuilder().setValue(List("foo", "bar", "baz")).build() + iter.map { _ => + val ls = StringArray.newBuilder().setValue(List("foo", "bar", "baz")).build() - Feature.newBuilder().setKey("FOOBAR").setValue(ls).build() + Feature.newBuilder().setKey("FOOBAR").setValue(ls).build() + } } - } val ds = rdd.toDS() assert(ds.count() == 1) From e4c7a4285ae7d1d5257f5502ad3d116524a3cb3d Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Fri, 3 Feb 2017 10:30:15 -0800 Subject: [PATCH 05/22] Update SBT for Spark-version-dependent code --- build.sbt | 11 ++++- .../Spark20AvroOutputWriterFactory.scala} | 3 +- .../avro/Spark21AvroOutputWriterFactory.scala | 41 +++++++++++++++++++ .../spark/avro/AvroOutputWriter.scala | 12 ++++-- .../databricks/spark/avro/DefaultSource.scala | 24 +++++++---- 5 files changed, 76 insertions(+), 15 deletions(-) rename src/main/{scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala => scala-spark20/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala} (95%) create mode 100644 src/main/scala-spark21/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala diff --git a/build.sbt b/build.sbt index 4d3bc51c..9c47a8a9 100644 --- a/build.sbt +++ b/build.sbt @@ -8,7 +8,7 @@ crossScalaVersions := Seq("2.10.5", "2.11.7") spName := "databricks/spark-avro" 
-sparkVersion := "2.0.0" +sparkVersion := "2.1.0" val testSparkVersion = settingKey[String]("The version of Spark to test against.") @@ -34,6 +34,13 @@ spIgnoreProvided := true sparkComponents := Seq("sql") +unmanagedSourceDirectories in Compile += { + sparkVersion.value match { + case v if v.startsWith("2.0.") => baseDirectory.value / "src" / "main" / "scala-spark20" + case v => baseDirectory.value / "src" / "main" / "scala-spark21" + } +} + libraryDependencies ++= Seq( "org.slf4j" % "slf4j-api" % "1.7.5", "org.apache.avro" % "avro" % "1.7.6" exclude("org.mortbay.jetty", "servlet-api"), @@ -104,7 +111,7 @@ pomExtra := bintrayReleaseOnPublish in ThisBuild := false -import ReleaseTransformations._ +import sbtrelease.ReleasePlugin.autoImport.ReleaseTransformations._ // Add publishing to spark packages as another step. releaseProcess := Seq[ReleaseStep]( diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala b/src/main/scala-spark20/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala similarity index 95% rename from src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala rename to src/main/scala-spark20/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala index 339eb147..9c93ddda 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala +++ b/src/main/scala-spark20/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala @@ -17,11 +17,10 @@ package com.databricks.spark.avro import org.apache.hadoop.mapreduce.TaskAttemptContext - import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType -private[avro] class AvroOutputWriterFactory( +private[avro] class Spark20AvroOutputWriterFactory( schema: StructType, recordName: String, recordNamespace: String) extends OutputWriterFactory { diff --git a/src/main/scala-spark21/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala b/src/main/scala-spark21/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala new file mode 100644 index 00000000..7e7756ca --- /dev/null +++ b/src/main/scala-spark21/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala @@ -0,0 +1,41 @@ +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.databricks.spark.avro + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.StructType + +private[avro] class Spark21AvroOutputWriterFactory( + schema: StructType, + recordName: String, + recordNamespace: String) extends OutputWriterFactory { + + def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new AvroOutputWriter(path, context, schema, recordName, recordNamespace) { + override def doGetDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + } + } + + override def getFileExtension(context: TaskAttemptContext): String = ".avro" +} diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala index 24d1b5a3..9ab63af0 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala @@ -46,6 +46,13 @@ private[avro] class AvroOutputWriter( private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace) + protected def doGetDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId: TaskAttemptID = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + /** * Overrides the couple of methods responsible for generating the output streams / files so * that the data can be correctly partitioned @@ -54,10 +61,7 @@ private[avro] class AvroOutputWriter( new AvroKeyOutputFormat[GenericRecord]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId: TaskAttemptID = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + doGetDefaultWorkFile(context, extension) } @throws(classOf[IOException]) diff --git a/src/main/scala/com/databricks/spark/avro/DefaultSource.scala b/src/main/scala/com/databricks/spark/avro/DefaultSource.scala index bfbadd7c..6d90e392 100644 --- a/src/main/scala/com/databricks/spark/avro/DefaultSource.scala +++ b/src/main/scala/com/databricks/spark/avro/DefaultSource.scala @@ -20,21 +20,17 @@ import java.io._ import java.net.URI import java.util.zip.Deflater -import scala.util.control.NonFatal - import com.databricks.spark.avro.DefaultSource.{AvroSchema, IgnoreFilesWithoutExtensionProperty, SerializableConfiguration} -import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.avro.{Schema, SchemaBuilder} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import org.apache.avro.file.{DataFileConstants, DataFileReader} import org.apache.avro.generic.{GenericDatumReader, GenericRecord} import org.apache.avro.mapred.{AvroOutputFormat, FsInput} import org.apache.avro.mapreduce.AvroJob +import org.apache.avro.{Schema, SchemaBuilder} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.Job -import org.slf4j.LoggerFactory - import org.apache.spark.TaskContext 
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow @@ -43,6 +39,9 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.sources.{DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType +import org.slf4j.LoggerFactory + +import scala.util.control.NonFatal private[avro] class DefaultSource extends FileFormat with DataSourceRegister { private val log = LoggerFactory.getLogger(getClass) @@ -142,7 +141,18 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister { log.error(s"unsupported compression codec $unknown") } - new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace) + val clz = spark.version match { + case v if v.startsWith("2.0.") => { + Class.forName("com.databricks.spark.avro.Spark20AvroOutputWriterFactory") + } + case v => { + Class.forName("com.databricks.spark.avro.Spark21AvroOutputWriterFactory") + } + } + + val m = clz.getDeclaredConstructor(classOf[StructType], classOf[String], classOf[String]) + m.setAccessible(true) + m.newInstance(dataSchema, recordName, recordNamespace).asInstanceOf[OutputWriterFactory] } override def buildReader( From f661f6112f50fd1952b3059c041daa685358a188 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Fri, 3 Feb 2017 10:33:04 -0800 Subject: [PATCH 06/22] Keep spark version at 2.0.0 --- build.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 9c47a8a9..9e98610c 100644 --- a/build.sbt +++ b/build.sbt @@ -8,7 +8,7 @@ crossScalaVersions := Seq("2.10.5", "2.11.7") spName := "databricks/spark-avro" -sparkVersion := "2.1.0" +sparkVersion := "2.0.0" val testSparkVersion = settingKey[String]("The version of Spark to test against.") From cdba6536f62d6caef9e0b86879f7f895071b608a Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Mon, 6 Feb 2017 15:19:45 -0500 Subject: [PATCH 07/22] Ensures all unit tests pass on Windows --- src/test/scala/com/databricks/spark/avro/AvroSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index 64f2b098..097a3fca 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -35,6 +35,8 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord, import org.apache.avro.io.{DecoderFactory, JsonDecoder} import org.apache.avro.specific.{SpecificData, SpecificDatumReader} import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs +import org.apache.hadoop.fs.Path import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -572,9 +574,10 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { test("SQL test insert overwrite") { TestUtils.withTempDir { tempDir => - val tempEmptyDir = s"$tempDir/sqlOverwrite" + val tempEmptyDir = new Path(s"$tempDir/sqlOverwrite") // Create a temp directory for table that will be overwritten - new File(tempEmptyDir).mkdirs() + val local = fs.FileSystem.getLocal(spark.sparkContext.hadoopConfiguration) + local.mkdirs(tempEmptyDir) spark.sql( s""" |CREATE TEMPORARY TABLE episodes From f9ef6363c4540e31b92f218afee419aaded02bc9 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: 
Tue, 7 Feb 2017 15:27:38 -0500 Subject: [PATCH 08/22] Unit tests pass with different versions, but build.sbt is messy and it will fail under packaging --- .travis.yml | 20 +++++++ build.sbt | 58 ++++++++++++++++--- .../avro/Spark20AvroOutputWriterFactory.scala | 21 ++++++- .../avro/Spark21AvroOutputWriterFactory.scala | 31 ++++++---- .../spark/avro/AvroOutputWriter.scala | 12 +--- .../databricks/spark/avro/SimpleFixed.java | 21 ++++++- 6 files changed, 131 insertions(+), 32 deletions(-) rename {src/main/scala-spark20 => spark-2.0.x/src/main/scala}/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala (53%) rename {src/main/scala-spark21 => spark-2.1.x/src/main/scala}/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala (55%) diff --git a/.travis.yml b/.travis.yml index a6d4f63c..a6a952a4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,6 +15,10 @@ matrix: - jdk: openjdk7 scala: 2.11.7 env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" + # Spark 2.0.0, Scala 2.11, and Avro 1.8.x + - jdk: openjdk7 + scala: 2.11.7 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" # Spark 2.0.0, Scala 2.10, and Avro 1.7.x - jdk: openjdk7 scala: 2.10.4 @@ -23,6 +27,22 @@ matrix: - jdk: openjdk7 scala: 2.10.4 env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" + # Spark 2.1.0, Scala 2.11, and Avro 1.7.x + - jdk: openjdk7 + scala: 2.11.7 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.1.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" + # Spark 2.1.0, Scala 2.11, and Avro 1.8.x + - jdk: openjdk7 + scala: 2.11.7 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.1.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" + # Spark 2.1.0, Scala 2.10, and Avro 1.7.x + - jdk: openjdk7 + scala: 2.10.4 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.1.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" + # Spark 2.1.0, Scala 2.10, and Avro 1.8.x + - jdk: openjdk7 + scala: 2.10.4 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.1.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" script: - ./dev/run-tests-travis.sh after_success: diff --git a/build.sbt b/build.sbt index 9e98610c..b069d83a 100644 --- a/build.sbt +++ b/build.sbt @@ -1,11 +1,16 @@ -name := "spark-avro" -organization := "com.databricks" +lazy val commonSettings = Seq( + organization := "com.databricks", + scalaVersion := "2.11.7", + crossScalaVersions := Seq("2.10.5", "2.11.7") +) +organization := "com.databricks" scalaVersion := "2.11.7" - crossScalaVersions := Seq("2.10.5", "2.11.7") +name := "spark-avro" + spName := "databricks/spark-avro" sparkVersion := "2.0.0" @@ -34,12 +39,12 @@ spIgnoreProvided := true sparkComponents := Seq("sql") -unmanagedSourceDirectories in Compile += { - sparkVersion.value match { - case v if v.startsWith("2.0.") => baseDirectory.value / "src" / "main" / "scala-spark20" - case v => baseDirectory.value / "src" / "main" / "scala-spark21" - } -} +//unmanagedSourceDirectories in Compile += { +// sparkVersion.value match { +// case v if v.startsWith("2.0.") => baseDirectory.value / "src" / "main" / "scala-spark20" +// case v => baseDirectory.value / "src" / "main" / "scala-spark21" +// } +//} libraryDependencies ++= Seq( "org.slf4j" % "slf4j-api" % "1.7.5", @@ -111,6 +116,7 @@ pomExtra := bintrayReleaseOnPublish in ThisBuild := false
+import sbt.Keys.crossScalaVersions import sbtrelease.ReleasePlugin.autoImport.ReleaseTransformations._ // Add publishing to spark packages as another step. @@ -127,3 +133,37 @@ releaseProcess := Seq[ReleaseStep]( pushChanges, releaseStepTask(spPublish) ) + +lazy val spark21xProj = project.in(file("spark-2.1.x")).settings( +// organization := "com.databricks", + scalaVersion := "2.11.7", + crossScalaVersions := Seq("2.10.5", "2.11.7"), + libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.1.0" % "provided" +) +lazy val spark20xProj = project.in(file("spark-2.0.x")).settings( +// organization := "com.databricks", + scalaVersion := "2.11.7", + crossScalaVersions := Seq("2.10.5", "2.11.7"), + libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.0.0" % "provided" +) + +aggregateProjects(spark20xProj, spark21xProj) + +dependsOn(spark21xProj) +dependsOn(spark20xProj) + +//projectDependencies := { +// Seq( +// (projectID in spark20xProj).value.exclude("org.apache.spark", "spark-sql"), +// (projectID in spark21xProj).value.exclude("org.apache.spark", "spark-sql") +// ) +//} +// +// +// +//mappings in (Compile, packageBin) ++= { +// (dependencyClasspath in Runtime).value.foreach { i => +// println(s"%%%%%%%%% ${i}") +// } +// Seq() +//} \ No newline at end of file diff --git a/src/main/scala-spark20/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala b/spark-2.0.x/src/main/scala/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala similarity index 53% rename from src/main/scala-spark20/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala rename to spark-2.0.x/src/main/scala/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala index 9c93ddda..72855b43 100644 --- a/src/main/scala-spark20/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala +++ b/spark-2.0.x/src/main/scala/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala @@ -16,7 +16,8 @@ package com.databricks.spark.avro -import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{TaskAttemptContext, TaskAttemptID} import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType @@ -25,11 +26,27 @@ private[avro] class Spark20AvroOutputWriterFactory( recordName: String, recordNamespace: String) extends OutputWriterFactory { + def doGetDefaultWorkFile(path: String, context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId: TaskAttemptID = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + def newInstance( path: String, bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new AvroOutputWriter(path, context, schema, recordName, recordNamespace) + + val ot = Class.forName("com.databricks.spark.avro.AvroOutputWriter") + val meth = ot.getDeclaredConstructor( + classOf[String], classOf[TaskAttemptContext], classOf[StructType], + classOf[String], classOf[String], + classOf[Function3[String, TaskAttemptContext, String, Path]] + ) + meth.setAccessible(true) + meth.newInstance(path, context, schema, recordName, recordNamespace, doGetDefaultWorkFile _) + .asInstanceOf[OutputWriter] } } diff --git a/src/main/scala-spark21/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala 
b/spark-2.1.x/src/main/scala/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala similarity index 55% rename from src/main/scala-spark21/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala rename to spark-2.1.x/src/main/scala/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala index 7e7756ca..6bfb5eb0 100644 --- a/src/main/scala-spark21/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala +++ b/spark-2.1.x/src/main/scala/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala @@ -22,19 +22,28 @@ import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFac import org.apache.spark.sql.types.StructType private[avro] class Spark21AvroOutputWriterFactory( - schema: StructType, - recordName: String, - recordNamespace: String) extends OutputWriterFactory { + schema: StructType, + recordName: String, + recordNamespace: String) extends OutputWriterFactory { + + def doGetDefaultWorkFile(path: String, context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new AvroOutputWriter(path, context, schema, recordName, recordNamespace) { - override def doGetDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(path) - } - } + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + + val ot = Class.forName("com.databricks.spark.avro.AvroOutputWriter") + val meth = ot.getDeclaredConstructor( + classOf[String], classOf[TaskAttemptContext], classOf[StructType], + classOf[String], classOf[String], + classOf[Function3[String, TaskAttemptContext, String, Path]] + ) + meth.setAccessible(true) + meth.newInstance(path, context, schema, recordName, recordNamespace, doGetDefaultWorkFile _) + .asInstanceOf[OutputWriter] } override def getFileExtension(context: TaskAttemptContext): String = ".avro" diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala index 9ab63af0..2de1bdb9 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala @@ -42,17 +42,11 @@ private[avro] class AvroOutputWriter( context: TaskAttemptContext, schema: StructType, recordName: String, - recordNamespace: String) extends OutputWriter { + recordNamespace: String, + workPathFunc: (String, TaskAttemptContext, String) => Path) extends OutputWriter { private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace) - protected def doGetDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId: TaskAttemptID = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") - } - /** * Overrides the couple of methods responsible for generating the output streams / files so * that the data can be correctly partitioned @@ -61,7 +55,7 @@ private[avro] class AvroOutputWriter( new AvroKeyOutputFormat[GenericRecord]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - doGetDefaultWorkFile(context, extension) + workPathFunc(path, context, extension) } @throws(classOf[IOException]) diff --git a/src/test/java/com/databricks/spark/avro/SimpleFixed.java 
b/src/test/java/com/databricks/spark/avro/SimpleFixed.java index 8318b65a..184b51f5 100644 --- a/src/test/java/com/databricks/spark/avro/SimpleFixed.java +++ b/src/test/java/com/databricks/spark/avro/SimpleFixed.java @@ -3,7 +3,14 @@ * * DO NOT EDIT DIRECTLY */ -package com.databricks.spark.avro; +package com.databricks.spark.avro; + +import org.apache.avro.Schema; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + @SuppressWarnings("all") @org.apache.avro.specific.FixedSize(16) @org.apache.avro.specific.AvroGenerated @@ -20,4 +27,16 @@ public SimpleFixed() { public SimpleFixed(byte[] bytes) { super(bytes); } + + public void writeExternal(ObjectOutput out) throws IOException { + // + } + + public void readExternal(ObjectInput in) throws IOException { + // + } + + public Schema getSchema() { + return getClassSchema(); + } } From 5c3c11601451d1fc99409d7aded74fa728cb88dd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Feb 2017 15:41:16 -0800 Subject: [PATCH 09/22] Add support for Spark 2.1.x (while retaining Spark 2.0.x support) This patch builds on #206 in order to restore support for Spark 2.0.x. This means that a single binary artifact can be used with both Spark 2.0.x and 2.1.x, simplifying the builds of downstream projects which are compatible with both Spark versions. Author: Josh Rosen Closes #212 from JoshRosen/add-spark-2.1. --- .travis.yml | 14 ++++++++++++++ build.sbt | 2 +- .../databricks/spark/avro/AvroOutputWriter.scala | 10 +++++++++- .../spark/avro/AvroOutputWriterFactory.scala | 12 ++++++++++-- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 42a12162..1df09667 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,20 @@ before_cache: - find $HOME/.sbt -name "*.lock" -delete matrix: include: + # ---- Spark 2.0.x ---------------------------------------------------------------------------- + # Spark 2.0.0, Scala 2.11, and Avro 1.7.x + - jdk: openjdk7 + scala: 2.11.7 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" + # Spark 2.0.0, Scala 2.10, and Avro 1.7.x + - jdk: openjdk7 + scala: 2.10.4 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" + # Spark 2.0.0, Scala 2.10, and Avro 1.8.x + - jdk: openjdk7 + scala: 2.10.4 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" + # ---- Spark 2.1.x ---------------------------------------------------------------------------- # Spark 2.1.0, Scala 2.11, and Avro 1.7.x - jdk: openjdk7 scala: 2.11.8 diff --git a/build.sbt b/build.sbt index d8d933f3..1c9b212c 100644 --- a/build.sbt +++ b/build.sbt @@ -8,7 +8,7 @@ crossScalaVersions := Seq("2.10.6", "2.11.8") spName := "databricks/spark-avro" -sparkVersion := "2.1.0" +sparkVersion := "2.0.0" val testSparkVersion = settingKey[String]("The version of Spark to test against.") diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala index bc71564e..cf515206 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala @@ -32,6 +32,7 @@ import org.apache.avro.mapreduce.AvroKeyOutputFormat import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, TaskAttemptID} +import 
org.apache.spark.SPARK_VERSION import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types._ @@ -54,7 +55,14 @@ private[avro] class AvroOutputWriter( new AvroKeyOutputFormat[GenericRecord]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(path) + if (SPARK_VERSION.startsWith("2.0")) { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId: TaskAttemptID = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } else { + new Path(path) + } } @throws(classOf[IOException]) diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala index 3f3cbf07..84fce45d 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala @@ -26,11 +26,19 @@ private[avro] class AvroOutputWriterFactory( recordName: String, recordNamespace: String) extends OutputWriterFactory { - override def getFileExtension(context: TaskAttemptContext): String = { + def getFileExtension(context: TaskAttemptContext): String = { ".avro" } - override def newInstance( + def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + newInstance(path, dataSchema, context) + } + + def newInstance( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { From 762f7a1720010cc38416c222f618a78375deb2a7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Feb 2017 15:42:39 -0800 Subject: [PATCH 10/22] Update README in preparation for 3.2.0 release Author: Josh Rosen Closes #213 from JoshRosen/prepare-for-3.2-release. --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f002f56a..0eb977a5 100644 --- a/README.md +++ b/README.md @@ -10,14 +10,14 @@ A library for reading and writing Avro data from [Spark SQL](http://spark.apache This documentation is for version 3.1.0 of this library, which supports Spark 2.0+. For documentation on earlier versions of this library, see the links below. 
-This library has different versions for Spark 1.2, 1.3, 1.4+, and 2.0: +This library has different versions for Spark 1.2, 1.3, 1.4 through 1.6, and 2.0+: | Spark Version | Compatible version of Avro Data Source for Spark | | ------------- | ------------------------------------------------ | | `1.2` | `0.2.0` | | `1.3` | [`1.0.0`](https://github.com/databricks/spark-avro/tree/v1.0.0) | -| `1.4+` | [`2.0.1`](https://github.com/databricks/spark-avro/tree/v2.0.1) | -| `2.0` | `3.1.0` (this version) | +| `1.4`-`1.6` | [`2.0.1`](https://github.com/databricks/spark-avro/tree/v2.0.1) | +| `2.0+` | `3.2.0` (this version) | ## Linking @@ -28,7 +28,7 @@ You can link against this library in your program at the following coordinates: **Using SBT:** ``` -libraryDependencies += "com.databricks" %% "spark-avro" % "3.1.0" +libraryDependencies += "com.databricks" %% "spark-avro" % "3.2.0" ``` **Using Maven:** @@ -37,7 +37,7 @@ libraryDependencies += "com.databricks" %% "spark-avro" % "3.1.0" com.databricks spark-avro_2.10 - 3.1.0 + 3.2.0 ``` @@ -47,7 +47,7 @@ This library can also be added to Spark jobs launched through `spark-shell` or ` For example, to include it when starting the spark shell: ``` -$ bin/spark-shell --packages com.databricks:spark-avro_2.11:3.1.0 +$ bin/spark-shell --packages com.databricks:spark-avro_2.11:3.2.0 ``` Unlike using `--jars`, using `--packages` ensures that this library and its dependencies will be added to the classpath. The `--packages` argument can also be used with `bin/spark-submit`. From e28e45609b0178634a91ad8807f5f4be93590fef Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Feb 2017 10:20:59 -0800 Subject: [PATCH 11/22] Setting version to 3.2.0 --- version.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.sbt b/version.sbt index 7370c6b4..737e59e9 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "3.2.0-SNAPSHOT" \ No newline at end of file +version in ThisBuild := "3.2.0" \ No newline at end of file From 1bfe421cd110d721e1688f1e091225487e917524 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Feb 2017 10:23:23 -0800 Subject: [PATCH 12/22] Setting version to 3.2.1-SNAPSHOT --- version.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.sbt b/version.sbt index 737e59e9..0750fecd 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "3.2.0" \ No newline at end of file +version in ThisBuild := "3.2.1-SNAPSHOT" \ No newline at end of file From 51eb883769cb5d3e006da19d8ffe19d41ddd9895 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 8 Feb 2017 11:12:52 -0800 Subject: [PATCH 13/22] 3.1.0 -> 3.2.0 in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0eb977a5..ca6f168b 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A library for reading and writing Avro data from [Spark SQL](http://spark.apache ## Requirements -This documentation is for version 3.1.0 of this library, which supports Spark 2.0+. For +This documentation is for version 3.2.0 of this library, which supports Spark 2.0+. For documentation on earlier versions of this library, see the links below. 
This library has different versions for Spark 1.2, 1.3, 1.4 through 1.6, and 2.0+: From 69a057044364cf27bce084a3004389816bb78dc9 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Thu, 9 Feb 2017 10:45:01 -0500 Subject: [PATCH 14/22] Build is now working and including the classfiles in the final artifact --- build.sbt | 53 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/build.sbt b/build.sbt index b069d83a..9b92b654 100644 --- a/build.sbt +++ b/build.sbt @@ -139,31 +139,50 @@ lazy val spark21xProj = project.in(file("spark-2.1.x")).settings( scalaVersion := "2.11.7", crossScalaVersions := Seq("2.10.5", "2.11.7"), libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.1.0" % "provided" -) +).disablePlugins(SparkPackagePlugin) + + lazy val spark20xProj = project.in(file("spark-2.0.x")).settings( // organization := "com.databricks", scalaVersion := "2.11.7", crossScalaVersions := Seq("2.10.5", "2.11.7"), libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.0.0" % "provided" -) +).disablePlugins(SparkPackagePlugin) aggregateProjects(spark20xProj, spark21xProj) dependsOn(spark21xProj) dependsOn(spark20xProj) -//projectDependencies := { -// Seq( -// (projectID in spark20xProj).value.exclude("org.apache.spark", "spark-sql"), -// (projectID in spark21xProj).value.exclude("org.apache.spark", "spark-sql") -// ) -//} -// -// -// -//mappings in (Compile, packageBin) ++= { -// (dependencyClasspath in Runtime).value.foreach { i => -// println(s"%%%%%%%%% ${i}") -// } -// Seq() -//} \ No newline at end of file +projectDependencies := { + Seq( + (projectID in spark20xProj).value.excludeAll(ExclusionRule(organization = "*")), + (projectID in spark21xProj).value.excludeAll(ExclusionRule(organization = "*")) + ) +} + + +def createMappingForPackage(base: File): Seq[(File, String)] = { + import Path._ + + (base ** (-DirectoryFilter)).get.map { f => + f -> IO.relativize(base, f) + }.collect { + case (f, Some(p)) => (f, p) + } +} + +mappings in (Compile, packageBin) ++= { + import Path._ + + val base = (dependencyClasspath in Runtime).value.collect { + case i if i.get(moduleID.key).exists(_ == (projectID in spark20xProj).value) => i.data + case i if i.get(moduleID.key).exists(_ == (projectID in spark21xProj).value) => i.data + } + + val m = base.flatMap(createMappingForPackage) + + println("****** ret ********") + println(m) + m +} \ No newline at end of file From ce479068e0e0f09f89062c350c9e8c5345aabcf0 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Thu, 9 Feb 2017 19:00:48 -0500 Subject: [PATCH 15/22] build.sbt is looking better! 
--- build.sbt | 58 +++++++++++-------------------------------------------- 1 file changed, 11 insertions(+), 47 deletions(-) diff --git a/build.sbt b/build.sbt index 9b92b654..5cbf33d8 100644 --- a/build.sbt +++ b/build.sbt @@ -5,9 +5,7 @@ lazy val commonSettings = Seq( crossScalaVersions := Seq("2.10.5", "2.11.7") ) -organization := "com.databricks" -scalaVersion := "2.11.7" -crossScalaVersions := Seq("2.10.5", "2.11.7") +commonSettings name := "spark-avro" @@ -39,13 +37,6 @@ spIgnoreProvided := true sparkComponents := Seq("sql") -//unmanagedSourceDirectories in Compile += { -// sparkVersion.value match { -// case v if v.startsWith("2.0.") => baseDirectory.value / "src" / "main" / "scala-spark20" -// case v => baseDirectory.value / "src" / "main" / "scala-spark21" -// } -//} - libraryDependencies ++= Seq( "org.slf4j" % "slf4j-api" % "1.7.5", "org.apache.avro" % "avro" % "1.7.6" exclude("org.mortbay.jetty", "servlet-api"), @@ -116,7 +107,6 @@ pomExtra := bintrayReleaseOnPublish in ThisBuild := false -import sbt.Keys.crossScalaVersions import sbtrelease.ReleasePlugin.autoImport.ReleaseTransformations._ // Add publishing to spark packages as another step. @@ -134,55 +124,29 @@ releaseProcess := Seq[ReleaseStep]( releaseStepTask(spPublish) ) + lazy val spark21xProj = project.in(file("spark-2.1.x")).settings( -// organization := "com.databricks", - scalaVersion := "2.11.7", - crossScalaVersions := Seq("2.10.5", "2.11.7"), + commonSettings, libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.1.0" % "provided" ).disablePlugins(SparkPackagePlugin) lazy val spark20xProj = project.in(file("spark-2.0.x")).settings( -// organization := "com.databricks", - scalaVersion := "2.11.7", - crossScalaVersions := Seq("2.10.5", "2.11.7"), + commonSettings, libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.0.0" % "provided" ).disablePlugins(SparkPackagePlugin) -aggregateProjects(spark20xProj, spark21xProj) - -dependsOn(spark21xProj) -dependsOn(spark20xProj) - -projectDependencies := { - Seq( - (projectID in spark20xProj).value.excludeAll(ExclusionRule(organization = "*")), - (projectID in spark21xProj).value.excludeAll(ExclusionRule(organization = "*")) - ) -} - - -def createMappingForPackage(base: File): Seq[(File, String)] = { - import Path._ - - (base ** (-DirectoryFilter)).get.map { f => - f -> IO.relativize(base, f) - }.collect { - case (f, Some(p)) => (f, p) - } -} - mappings in (Compile, packageBin) ++= { import Path._ - val base = (dependencyClasspath in Runtime).value.collect { - case i if i.get(moduleID.key).exists(_ == (projectID in spark20xProj).value) => i.data - case i if i.get(moduleID.key).exists(_ == (projectID in spark21xProj).value) => i.data + def createMappingForPackage(base: File): Seq[(File, String)] = { + (base ** (-DirectoryFilter)).get.flatMap { f => + IO.relativize(base, f).map(f -> _) + } } - val m = base.flatMap(createMappingForPackage) + val compatClasses = (exportedProducts in (spark20xProj, Runtime)).value ++ + (exportedProducts in (spark21xProj, Runtime)).value - println("****** ret ********") - println(m) - m + compatClasses.flatMap { x => createMappingForPackage(x.data) } } \ No newline at end of file From 0934452d13ef54fb3d6551b5ff3e926b9f54a28f Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Fri, 10 Feb 2017 17:10:56 -0500 Subject: [PATCH 16/22] Much tighter way to include the spark version-specific modules into the final jar --- build.sbt | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/build.sbt b/build.sbt index 
5cbf33d8..974460b6 100644 --- a/build.sbt +++ b/build.sbt @@ -136,17 +136,13 @@ lazy val spark20xProj = project.in(file("spark-2.0.x")).settings( libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.0.0" % "provided" ).disablePlugins(SparkPackagePlugin) -mappings in (Compile, packageBin) ++= { - import Path._ - def createMappingForPackage(base: File): Seq[(File, String)] = { - (base ** (-DirectoryFilter)).get.flatMap { f => - IO.relativize(base, f).map(f -> _) - } - } - - val compatClasses = (exportedProducts in (spark20xProj, Runtime)).value ++ +unmanagedClasspath in Test ++= { + (exportedProducts in (spark20xProj, Runtime)).value ++ (exportedProducts in (spark21xProj, Runtime)).value +} - - compatClasses.flatMap { x => createMappingForPackage(x.data) } -} \ No newline at end of file +products in (Compile, packageBin) ++= Seq( + (classDirectory in (spark20xProj, Compile)).value, + (classDirectory in (spark21xProj, Compile)).value +) \ No newline at end of file From 73fc4d81e3e79a0f527fe1cb093753ddafb22f58 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Fri, 10 Feb 2017 18:18:58 -0500 Subject: [PATCH 17/22] Update create table to avoid deprecation warnings --- src/test/scala/com/databricks/spark/avro/AvroSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index 097a3fca..cdd1b739 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -422,7 +422,7 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY TABLE avroTable + |CREATE TEMPORARY VIEW avroTable |USING com.databricks.spark.avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -580,13 +580,13 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { local.mkdirs(tempEmptyDir) spark.sql( s""" - |CREATE TEMPORARY TABLE episodes + |CREATE TEMPORARY VIEW episodes |USING com.databricks.spark.avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY TABLE episodesEmpty + |CREATE TEMPORARY VIEW episodesEmpty |(name string, air_date string, doctor int) |USING com.databricks.spark.avro |OPTIONS (path "$tempEmptyDir") From 0e99d723a25ef0a844bec5931195bcf0e8a4590a Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Wed, 15 Feb 2017 17:44:30 -0500 Subject: [PATCH 18/22] Make unit test for GenericData.Record pass --- .../databricks/spark/avro/AvroEncoder.scala | 14 ++++++--- .../com/databricks/spark/avro/AvroSuite.scala | 31 ++++++++++++++++++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala index eff31330..fcfbc671 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala @@ -98,7 +98,7 @@ private[avro] object ObjectType { } } -case class LambdaVariable( +private[avro] case class LambdaVariable( value: String, isNull: String, dataType: DataType, @@ -110,7 +110,7 @@ } } -object ExternalMapToCatalyst { +private[avro] object ExternalMapToCatalyst { private val curId = new java.util.concurrent.atomic.AtomicInteger() def apply( @@ -138,7 +138,7 @@ object ExternalMapToCatalyst { } } -case class ExternalMapToCatalyst private( +private[avro] case
class ExternalMapToCatalyst private( key: String, keyType: DataType, keyConverter: Expression, @@ -517,7 +517,13 @@ private object AvroTypeInference { val newInstance = if (recordClass == classOf[GenericData.Record]) { NewInstance( recordClass, - Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) :: Nil, + Invoke( + Literal.fromObject( + new SerializableSchema(avroSchema), + ObjectType(classOf[SerializableSchema])), + "value", + ObjectType(classOf[Schema]), + Nil) :: Nil, ObjectType(recordClass)) } else { NewInstance( diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index cdd1b739..fdc86d31 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -38,6 +38,7 @@ import org.apache.commons.io.FileUtils import org.apache.hadoop.fs import org.apache.hadoop.fs.Path import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.types._ @@ -91,7 +92,7 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { test("request no fields") { val df = spark.read.avro(episodesFile) - df.registerTempTable("avro_table") + df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } @@ -912,4 +913,32 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { val ds = rdd.toDS() assert(ds.count() == 1) } + + test("create Dataset from GenericRecord") { + val sparkSession = spark + import sparkSession.implicits._ + + val schema: Schema = + SchemaBuilder + .record("GenericRecordTest") + .namespace("com.databricks.spark.avro") + .fields() + .requiredString("field1") + .endRecord() + + implicit val enc = AvroEncoder.of[GenericData.Record](schema) + + val genericRecords = (1 to 10) map { i => + new GenericRecordBuilder(schema) + .set("field1", "field-" + i) + .build() + } + + val rdd: RDD[GenericData.Record] = sparkSession.sparkContext + .parallelize(genericRecords) + + val ds = rdd.toDS() + + assert(ds.count() == genericRecords.size) + } } From 3a1fd20c8d662f5b6f1a17574fe8d99d2a2e3161 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Wed, 15 Feb 2017 17:54:43 -0500 Subject: [PATCH 19/22] Make tests pass with ENUM and FIXED --- .../com/databricks/spark/avro/AvroEncoder.scala | 16 ++++++++++++++-- .../com/databricks/spark/avro/AvroSuite.scala | 2 ++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala index fcfbc671..41acfa2d 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala @@ -408,7 +408,13 @@ private object AvroTypeInference { if (fixedClass == classOf[GenericData.Fixed]) { NewInstance( fixedClass, - Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) :: + Invoke( + Literal.fromObject( + new SerializableSchema(avroSchema), + ObjectType(classOf[SerializableSchema])), + "value", + ObjectType(classOf[Schema]), + Nil) :: getPath :: Nil, ObjectType(fixedClass)) @@ -428,7 +434,13 @@ private object AvroTypeInference { if (enumClass == classOf[GenericData.EnumSymbol]) { NewInstance( enumClass, - Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) :: + Invoke( + Literal.fromObject( + new 
SerializableSchema(avroSchema), + ObjectType(classOf[SerializableSchema])), + "value", + ObjectType(classOf[Schema]), + Nil) :: Invoke(getPath, "toString", ObjectType(classOf[String])) :: Nil, ObjectType(enumClass)) diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index fdc86d31..229ef46d 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -924,6 +924,8 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { .namespace("com.databricks.spark.avro") .fields() .requiredString("field1") + .name("enumVal").`type`().enumeration("letters").symbols("a", "b", "c").enumDefault("a") + .name("fixedVal").`type`().fixed("MD5").size(16).fixedDefault(ByteBuffer.allocate(16)) .endRecord() implicit val enc = AvroEncoder.of[GenericData.Record](schema) From d5c0329eed42a63df1d6a3a3290be10bfc63f4f4 Mon Sep 17 00:00:00 2001 From: Nihed MBAREK Date: Thu, 16 Feb 2017 02:47:30 -0800 Subject: [PATCH 20/22] add support for DateType Hi, based on this issue https://github.com/databricks/spark-avro/issues/67 I create this pull request Author: Nihed MBAREK Author: vlyubin Author: nihed Closes #124 from nihed/master. (cherry picked from commit c19f01af60458bf6ffd24f2505581d924fffffd2) Signed-off-by: vlyubin --- .../spark/avro/AvroOutputWriter.scala | 3 ++ .../spark/avro/SchemaConverters.scala | 2 ++ .../com/databricks/spark/avro/AvroSuite.scala | 33 +++++++++++++++---- .../spark/avro/AvroWriteBenchmark.scala | 7 ++-- 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala index cf515206..c746b50c 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala @@ -19,6 +19,7 @@ package com.databricks.spark.avro import java.io.{IOException, OutputStream} import java.nio.ByteBuffer import java.sql.Timestamp +import java.sql.Date import java.util.HashMap import org.apache.hadoop.fs.Path @@ -98,6 +99,8 @@ private[avro] class AvroOutputWriter( case _: DecimalType => (item: Any) => if (item == null) null else item.toString case TimestampType => (item: Any) => if (item == null) null else item.asInstanceOf[Timestamp].getTime + case DateType => (item: Any) => + if (item == null) null else item.asInstanceOf[Date].getTime case ArrayType(elementType, _) => val elementConverter = createConverterToAvro(elementType, structName, recordNamespace) (item: Any) => { diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index aa634d4c..7f8e20f4 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -328,6 +328,7 @@ object SchemaConverters { case BinaryType => schemaBuilder.bytesType() case BooleanType => schemaBuilder.booleanType() case TimestampType => schemaBuilder.longType() + case DateType => schemaBuilder.longType() case ArrayType(elementType, _) => val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) @@ -371,6 +372,7 @@ object SchemaConverters { case BinaryType => newFieldBuilder.bytesType() case BooleanType => newFieldBuilder.booleanType() case TimestampType => newFieldBuilder.longType() + case DateType => newFieldBuilder.longType() case 
ArrayType(elementType, _) => val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index 1b5d07aa..4843ad46 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -19,22 +19,21 @@ package com.databricks.spark.avro import java.io._ import java.nio.ByteBuffer import java.nio.file.Files -import java.sql.Timestamp -import java.util.UUID +import java.sql.{Date, Timestamp} +import java.util.{TimeZone, UUID} import scala.collection.JavaConversions._ - -import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException import org.apache.avro.Schema import org.apache.avro.Schema.{Field, Type} import org.apache.avro.file.DataFileWriter -import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils - -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.SparkContext +import org.apache.spark.sql._ import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfterAll, FunSuite} +import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException class AvroSuite extends FunSuite with BeforeAndAfterAll { val episodesFile = "src/test/resources/episodes.avro" @@ -297,6 +296,26 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { } } + test("Date field type") { + TestUtils.withTempDir { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("date", DateType, true) + )) + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, null), + Row(2f, new Date(1451948400000L)), + Row(3f, new Date(1460066400500L)) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet == + Array(null, 1451865600000L, 1459987200000L).toSet) + } + } + test("Array data types") { TestUtils.withTempDir { dir => val testSchema = StructType(Seq( diff --git a/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala b/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala index b36438c1..2ccc456b 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala @@ -16,6 +16,7 @@ package com.databricks.spark.avro +import java.sql.Date import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ @@ -23,8 +24,7 @@ import scala.util.Random import com.google.common.io.Files import org.apache.commons.io.FileUtils - -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql._ import org.apache.spark.sql.types._ /** @@ -40,6 +40,7 @@ object AvroWriteBenchmark { val testSchema = StructType(Seq( StructField("StringField", StringType, false), StructField("IntField", IntegerType, true), + StructField("dateField", DateType, true), StructField("DoubleField", DoubleType, false), StructField("DecimalField", DecimalType(10, 10), true), StructField("ArrayField", ArrayType(BooleanType), false), @@ -48,7 +49,7 @@ object AvroWriteBenchmark { private def generateRandomRow(): Row = { val rand = 
new Random() - Row(rand.nextString(defaultSize), rand.nextInt(), rand.nextDouble(), rand.nextDouble(), + Row(rand.nextString(defaultSize), rand.nextInt(), new Date(rand.nextLong()) ,rand.nextDouble(), rand.nextDouble(), TestUtils.generateRandomArray(rand, defaultSize).toSeq, TestUtils.generateRandomMap(rand, defaultSize).toMap, Row(rand.nextInt())) } From 909edcd89ff16698f0035d7aa2f90308f5a3557f Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Tue, 18 Apr 2017 15:25:36 -0400 Subject: [PATCH 21/22] Revert "Merge master" This reverts commit 5de35746d30afa72dd00fe8635a63ff480ffc4da, reversing changes made to 3a1fd20c8d662f5b6f1a17574fe8d99d2a2e3161. --- .../spark/avro/AvroOutputWriter.scala | 3 -- .../spark/avro/SchemaConverters.scala | 2 - .../com/databricks/spark/avro/AvroSuite.scala | 37 ++++++------------- .../spark/avro/AvroWriteBenchmark.scala | 7 ++-- version.sbt | 2 +- 5 files changed, 15 insertions(+), 36 deletions(-) diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala index c675ebe3..03c97792 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala @@ -19,7 +19,6 @@ package com.databricks.spark.avro import java.io.{IOException, OutputStream} import java.nio.ByteBuffer import java.sql.Timestamp -import java.sql.Date import java.util.HashMap import org.apache.avro.generic.GenericData.Record @@ -91,8 +90,6 @@ private[avro] class AvroOutputWriter( case _: DecimalType => (item: Any) => if (item == null) null else item.toString case TimestampType => (item: Any) => if (item == null) null else item.asInstanceOf[Timestamp].getTime - case DateType => (item: Any) => - if (item == null) null else item.asInstanceOf[Date].getTime case ArrayType(elementType, _) => val elementConverter = createConverterToAvro(elementType, structName, recordNamespace) (item: Any) => { diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index 033fc525..1b8bc450 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -338,7 +338,6 @@ object SchemaConverters { case BinaryType => schemaBuilder.bytesType() case BooleanType => schemaBuilder.booleanType() case TimestampType => schemaBuilder.longType() - case DateType => schemaBuilder.longType() case ArrayType(elementType, _) => val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) @@ -382,7 +381,6 @@ object SchemaConverters { case BinaryType => newFieldBuilder.bytesType() case BooleanType => newFieldBuilder.booleanType() case TimestampType => newFieldBuilder.longType() - case DateType => newFieldBuilder.longType() case ArrayType(elementType, _) => val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index 24089b57..229ef46d 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -19,25 +19,30 @@ package com.databricks.spark.avro import java.io._ import java.nio.ByteBuffer import java.nio.file.Files -import java.sql.{Date, Timestamp} -import java.util.{TimeZone, UUID} +import java.sql.Timestamp +import java.util.UUID -import org.apache.avro.Schema.{Field, Type} 
-import org.apache.avro.{Schema, SchemaBuilder} -import org.apache.avro.file.DataFileWriter import org.apache.spark.SparkConf import scala.collection.JavaConversions._ import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException +import org.apache.avro.Schema +import org.apache.avro.Schema.{Field, Parser, Type} +import org.apache.avro.SchemaBuilder +import org.apache.avro.file.DataFileWriter import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord, GenericRecordBuilder} +import org.apache.avro.io.{DecoderFactory, JsonDecoder} +import org.apache.avro.specific.{SpecificData, SpecificDatumReader} import org.apache.commons.io.FileUtils import org.apache.hadoop.fs import org.apache.hadoop.fs.Path +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.scalatest.{BeforeAndAfterAll, FunSuite} class AvroSuite extends FunSuite with BeforeAndAfterAll { @@ -307,26 +312,6 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { } } - test("Date field type") { - TestUtils.withTempDir { dir => - val schema = StructType(Seq( - StructField("float", FloatType, true), - StructField("date", DateType, true) - )) - TimeZone.setDefault(TimeZone.getTimeZone("UTC")) - val rdd = spark.sparkContext.parallelize(Seq( - Row(1f, null), - Row(2f, new Date(1451948400000L)), - Row(3f, new Date(1460066400500L)) - )) - val df = spark.createDataFrame(rdd, schema) - df.write.avro(dir.toString) - assert(spark.read.avro(dir.toString).count == rdd.count) - assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet == - Array(null, 1451865600000L, 1459987200000L).toSet) - } - } - test("Array data types") { TestUtils.withTempDir { dir => val testSchema = StructType(Seq( diff --git a/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala b/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala index 2ccc456b..b36438c1 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala @@ -16,7 +16,6 @@ package com.databricks.spark.avro -import java.sql.Date import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ @@ -24,7 +23,8 @@ import scala.util.Random import com.google.common.io.Files import org.apache.commons.io.FileUtils -import org.apache.spark.sql._ + +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.types._ /** @@ -40,7 +40,6 @@ object AvroWriteBenchmark { val testSchema = StructType(Seq( StructField("StringField", StringType, false), StructField("IntField", IntegerType, true), - StructField("dateField", DateType, true), StructField("DoubleField", DoubleType, false), StructField("DecimalField", DecimalType(10, 10), true), StructField("ArrayField", ArrayType(BooleanType), false), @@ -49,7 +48,7 @@ object AvroWriteBenchmark { private def generateRandomRow(): Row = { val rand = new Random() - Row(rand.nextString(defaultSize), rand.nextInt(), new Date(rand.nextLong()) ,rand.nextDouble(), rand.nextDouble(), + Row(rand.nextString(defaultSize), rand.nextInt(), rand.nextDouble(), rand.nextDouble(), TestUtils.generateRandomArray(rand, defaultSize).toSeq, 
TestUtils.generateRandomMap(rand, defaultSize).toMap, Row(rand.nextInt())) } diff --git a/version.sbt b/version.sbt index 0f7fe009..bcde3efb 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "4.0.0-SNAPSHOT" +version in ThisBuild := "3.2.1-SNAPSHOT" From 7869c115d2dc4dd1216c53e13a6a38412e071a0f Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Mon, 30 Oct 2017 15:54:57 -0400 Subject: [PATCH 22/22] Flakey tests? --- src/test/scala/com/databricks/spark/avro/AvroSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index b72f9f78..80ff4726 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -706,7 +706,6 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { } } -<<<<<<< HEAD test("generic record converts to row and back") { val nested = SchemaBuilder.record("simple_record").fields()