diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 71034c2c43c77..2cf241de61f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -118,7 +118,19 @@ trait ScalaReflection { case t if t <:< typeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( + s => s.isMethod && s.asMethod.isPrimaryConstructor) + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } Schema(StructType( params.head.map { p => val Schema(dataType, nullable) = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index ddc3d44869c98..7be24bea7d5a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -68,6 +68,10 @@ case class ComplexData( case class GenericData[A]( genericField: A) +case class MultipleConstructorsData(a: Int, b: String, c: Double) { + def this(b: String, a: Int) = this(a, b, c = 1.0) +} + class ScalaReflectionSuite extends FunSuite { import ScalaReflection._ @@ -253,4 +257,14 @@ class ScalaReflectionSuite extends FunSuite { Row(1, 1, 1, 1, 1, 1, true)) assert(convertToCatalyst(data, dataType) === convertedData) } + + test("infer schema from case class with multiple constructors") { + val dataType = schemaFor[MultipleConstructorsData].dataType + dataType match { + case s: StructType => + // Schema should have order: a: Int, b: String, c: Double + assert(s.fieldNames === Seq("a", "b", "c")) + assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType)) + } + } }