diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 993d085c75089..31d27bb4f0571 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -430,7 +430,7 @@ class SchemaRDD( * @group schema */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(logicalPlan.output, rdd))) + new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))) } // ======================================================================= diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 68dae58728a2a..c8ea01c4e1b6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -33,6 +33,12 @@ class DslQuerySuite extends QueryTest { testData.collect().toSeq) } + test("repartition") { + checkAnswer( + testData.select('key).repartition(10).select('key), + testData.select('key).collect().toSeq) + } + test("agg") { checkAnswer( testData2.groupBy('a)('a, Sum('b)),