diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala index abcf8c7cdd726..704a02482ada8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala @@ -68,4 +68,314 @@ private[xml] trait TestXmlData { f(dir) fs.setVerifyChecksum(true) } + + def primitiveFieldValueTypeConflict: Seq[String] = + """ + | 11 + | + | 1.1 + | true + | 13.1 + | str1 + | + |""".stripMargin :: + """ + | + | + | 21474836470.9 + | + | 12 + | + | true + |""".stripMargin :: + """ + | + | 21474836470 + | 92233720368547758070 + | 100 + | false + | str1 + | false + |""".stripMargin :: + """ + | + | 21474836570 + | 1.1 + | 21474836470 + | + | 92233720368547758070 + | + |""".stripMargin :: Nil + + def xmlNullStruct: Seq[String] = + """ + | + | 27.31.100.29 + | + | 1.abc.com + | UTF-8 + | + |""".stripMargin :: + """ + | + | 27.31.100.29 + | + |""".stripMargin :: + """ + | + | 27.31.100.29 + | + |""".stripMargin :: + """ + | + | 27.31.100.29 + | + |""".stripMargin :: Nil + + def complexFieldValueTypeConflict: Seq[String] = + """ + 11 + 1 + 2 + 3 + + + + """ :: + """ + + false + + + + + + """ :: + """ + + str + 4 + 5 + 6 + 7 + 8 + 9 + + + + """ :: + """ + + str1 + str2 + 33 + 7 + + true + + + str + + """ :: Nil + + def arrayElementTypeConflict: Seq[String] = + """ + | + | + | 1 + | 1.1 + | true + | + | + | + | + | + | + | + | + | + | + | + | 2 + | 3 + | 4 + | + | + | + | + | str + | + | + | + | + | 214748364700 + | + | + | 1 + | + | + |""".stripMargin :: + """ + | + | + | str + | + | + | 1 + | + | + |""".stripMargin :: + """ + | + | 1 + | 2 + | 3 + | + |""".stripMargin :: Nil + + def missingFields: Seq[String] = + """ + true + """ :: + """ + 21474836470 + """ :: + """ + 3344 + """ :: + """ + true + """ :: + """ + str + """ :: Nil + + // XML doesn't support array of arrays + // It only supports array of structs + def complexFieldAndType1: Seq[String] = + """ + | + | + | true + | 92233720368547758070 + | + | + | 4 + | 5 + | 6 + | str1 + | str2 + | + | str1 + | str2 + | 1 + | 2147483647 + | -2147483648 + | 21474836470 + | 9223372036854775807 + | -9223372036854775808 + | 922337203685477580700 + | -922337203685477580800 + | 1.2 + | 1.7976931348623157 + | 4.9E-324 + | 2.2250738585072014E-308 + | true + | false + | true + | + | + | + | true + | str1 + | + | + | false + | + | + | + | + | + | 123 + | + | + | str1str2 + | + | + | 123 + | + | + | 1.12.13.1 + | + | + | + |""".stripMargin :: Nil + + def complexFieldAndType2: Seq[String] = + """ + | + | + | + | 5 + | + | + | + | + | 67 + | + | + | 8 + | + | + | + | + | + | str1 + | + | + | + | + | + | + | + | str3 + | str33 + | + | + | str4 + | str11 + | + | + | + | + | + | + | + | 2 + | 3 + | + | + | + | + | + | + |""".stripMargin :: Nil + + def emptyRecords: Seq[String] = + """ + + """ :: + """ + + + + """ :: + """ + + + + + + + """ :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala new file mode 100644 index 0000000000000..697bd3d8b824f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.xml + +import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{ + ArrayType, + BooleanType, + DecimalType, + DoubleType, + LongType, + StringType, + StructField, + StructType +} + +class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXmlData { + + val baseOptions = Map("rowTag" -> "ROW") + + def readData(xmlString: Seq[String], options: Map[String, String] = Map.empty): DataFrame = { + val dataset = spark.createDataset(spark.sparkContext.parallelize(xmlString))(Encoders.STRING) + spark.read.options(baseOptions ++ options).xml(dataset) + } + + // TODO: add tests for type widening + test("Type conflict in primitive field values") { + val xmlDF = readData(primitiveFieldValueTypeConflict, Map("nullValue" -> "")) + val expectedSchema = StructType( + StructField("num_bool", StringType, true) :: + StructField("num_num_1", LongType, true) :: + StructField("num_num_2", DoubleType, true) :: + StructField("num_num_3", DoubleType, true) :: + StructField("num_str", StringType, true) :: + StructField("str_bool", StringType, true) :: Nil + ) + val expectedAns = Row("true", 11L, null, 1.1, "13.1", "str1") :: + Row("12", null, 21474836470.9, null, null, "true") :: + Row("false", 21474836470L, 92233720368547758070d, 100, "str1", "false") :: + Row(null, 21474836570L, 1.1, 21474836470L, "92233720368547758070", null) :: Nil + assert(expectedSchema == xmlDF.schema) + checkAnswer(xmlDF, expectedAns) + } + + test("Type conflict in complex field values") { + val xmlDF = readData( + complexFieldValueTypeConflict, + Map("nullValue" -> "", "ignoreSurroundingSpaces" -> "true") + ) + // XML will merge an array and a singleton into an array + val expectedSchema = StructType( + StructField("array", ArrayType(LongType, true), true) :: + StructField("num_struct", StringType, true) :: + StructField("str_array", ArrayType(StringType), true) :: + StructField("struct", StructType(StructField("field", StringType, true) :: Nil), true) :: + StructField("struct_array", ArrayType(StringType), true) :: Nil + ) + + assert(expectedSchema === xmlDF.schema) + checkAnswer( + xmlDF, + Row(Seq(null), "11", Seq("1", "2", "3"), Row(null), Seq(null)) :: + Row(Seq(null), """false""", Seq(null), Row(null), Seq(null)) :: + Row(Seq(4, 5, 6), null, Seq("str"), Row(null), Seq("7", "8", "9")) :: + Row(Seq(7), null, Seq("str1", "str2", "33"), Row("str"), Seq("""true""")) :: + Nil + ) + } + + test("Type conflict in array elements") { + val xmlDF = + readData( + arrayElementTypeConflict, + Map("ignoreSurroundingSpaces" -> "true", "nullValue" -> "")) + + val expectedSchema = StructType( + StructField( + "array1", + ArrayType(StructType(StructField("element", ArrayType(StringType)) :: Nil), true), + true + ) :: + StructField( + "array2", + ArrayType(StructType(StructField("field", LongType, true) :: Nil), true), + true + ) :: + StructField("array3", ArrayType(StringType, true), true) :: Nil + ) + + assert(xmlDF.schema === expectedSchema) + checkAnswer( + xmlDF, + Row( + Seq( + Row(List("1", "1.1", "true", null, "", "")), + Row( + List( + """ + | 2 + | 3 + | 4 + | """.stripMargin, + """ + | str + | """.stripMargin + ) + ) + ), + Seq(Row(214748364700L), Row(1)), + null + ) :: + Row(null, null, Seq("""str""", """1""")) :: + Row(null, null, Seq("1", "2", "3")) :: Nil + ) + } + + test("Handling missing fields") { + val xmlDF = readData(missingFields) + + val expectedSchema = StructType( + StructField("a", BooleanType, true) :: + StructField("b", LongType, true) :: + StructField("c", ArrayType(LongType, true), true) :: + StructField("d", StructType(StructField("field", BooleanType, true) :: Nil), true) :: + StructField("e", StringType, true) :: Nil + ) + + assert(expectedSchema === xmlDF.schema) + + } + + test("Complex field and type inferring") { + val xmlDF = readData(complexFieldAndType1, Map("prefersDecimal" -> "true")) + val expectedSchema = StructType( + StructField( + "arrayOfArray1", + ArrayType(StructType(StructField("item", ArrayType(StringType, true)) :: Nil)), + true + ) :: + StructField( + "arrayOfArray2", + ArrayType(StructType(StructField("item", ArrayType(DecimalType(21, 1), true)) :: Nil), true) + ) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType(21, 0), true), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: + StructField("arrayOfInteger", ArrayType(LongType, true), true) :: + StructField("arrayOfLong", ArrayType(DecimalType(20, 0), true), true) :: + StructField("arrayOfNull", ArrayType(StringType, true), true) :: + StructField("arrayOfString", ArrayType(StringType, true), true) :: + StructField( + "arrayOfStruct", + ArrayType( + StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", StringType, true) :: + StructField("field3", StringType, true) :: Nil + ), + true + ), + true + ) :: + StructField( + "struct", + StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType(20, 0), true) :: Nil + ), + true + ) :: + StructField( + "structWithArrayFields", + StructType( + StructField("field1", ArrayType(LongType, true), true) :: + StructField("field2", ArrayType(StringType, true), true) :: Nil + ), + true + ) :: Nil + ) + assert(expectedSchema === xmlDF.schema) + } + + test("complex arrays") { + val xmlDF = readData(complexFieldAndType2) + val expectedSchemaArrayOfArray1 = new StructType().add( + "arrayOfArray1", + ArrayType( + new StructType() + .add("array", ArrayType(new StructType().add("item", ArrayType(LongType)))) + ) + ) + assert(xmlDF.select("arrayOfArray1").schema === expectedSchemaArrayOfArray1) + checkAnswer( + xmlDF.select("arrayOfArray1"), + Row( + Seq( + Row(Seq(Row(Seq(5)))), + Row(Seq(Row(Seq(6, 7)), Row(Seq(8)))) + ) + ) :: Nil + ) + val expectedSchemaArrayOfArray2 = new StructType().add( + "arrayOfArray2", + ArrayType( + new StructType() + .add( + "array", + ArrayType( + new StructType().add( + "item", + ArrayType( + new StructType() + .add("inner1", StringType) + .add("inner2", ArrayType(StringType)) + .add("inner3", ArrayType(new StructType().add("inner4", ArrayType(LongType)))) + ) + ) + ) + ) + ) + ) + assert(xmlDF.select("arrayOfArray2").schema === expectedSchemaArrayOfArray2) + checkAnswer( + xmlDF.select("arrayOfArray2"), + Row( + Seq( + Row(Seq(Row(Seq(Row("str1", null, null))))), + Row( + Seq( + Row(null), + Row(Seq(Row(null, Seq("str3", "str33"), null), Row("str11", Seq("str4"), null))) + ) + ), + Row(Seq(Row(Seq(Row(null, null, Seq(Row(Seq(2, 3)), Row(null))))))) + ) + ) :: Nil + ) + } + + test("Complex field and type inferring with null in sampling") { + val xmlDF = readData(xmlNullStruct) + val expectedSchema = StructType( + StructField( + "headers", + StructType( + StructField("Charset", StringType, true) :: + StructField("Host", StringType, true) :: Nil + ), + true + ) :: + StructField("ip", StringType, true) :: + StructField("nullstr", StringType, true) :: Nil + ) + + assert(expectedSchema === xmlDF.schema) + checkAnswer( + xmlDF.select("nullStr", "headers.Host"), + Seq(Row("", "1.abc.com"), Row("", null), Row("", null), Row("", null)) + ) + } + + test("empty records") { + val emptyDF = readData(emptyRecords) + val expectedSchema = new StructType() + .add( + "a", + new StructType() + .add( + "struct", + StructType(StructField("b", StructType(StructField("c", StringType) :: Nil)) :: Nil))) + .add( + "b", + new StructType() + .add( + "item", + ArrayType( + new StructType().add("c", StructType(StructField("struct", StringType) :: Nil))))) + assert(emptyDF.schema === expectedSchema) + } + +}