Skip to content

Commit

Permalink
Add distinct & sorting. Improve aggregation.
Browse files Browse the repository at this point in the history
  • Loading branch information
hvanhovell committed Nov 22, 2016
1 parent edec2d8 commit b16ddb8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2348,12 +2348,36 @@ object SortMaps extends Rule[LogicalPlan] {
cmp.withNewChildren(SortMap(left) :: right :: Nil)
case cmp @ BinaryComparison(left, right) if cmp.resolved && hasUnorderedMap(right) =>
cmp.withNewChildren(left :: SortMap(right) :: Nil)
case sort: SortOrder if sort.resolved && hasUnorderedMap(sort.child) =>
sort.copy(child = SortMap(sort.child))
} transform {
case a: Aggregate if a.resolved && a.groupingExpressions.exists(hasUnorderedMap) =>
a.transformExpressionsUp {
// Modify the top level grouping expressions
val replacements = a.groupingExpressions.collect {
case a: Attribute if hasUnorderedMap(a) =>
a -> Alias(SortMap(a), a.name)(exprId = a.exprId, qualifier = a.qualifier)
case e if hasUnorderedMap(e) =>
e -> SortMap(e)
}

// Tranform the expression tree.
a.transformExpressionsUp {
case e =>
// TODO create an expression map!
replacements
.find(_._1.semanticEquals(e))
.map(_._2)
.getOrElse(e)
}

case Distinct(child) if child.resolved && child.output.exists(hasUnorderedMap) =>
val projectList = child.output.map { a =>
if (hasUnorderedMap(a)) {
Alias(SortMap(a), a.name)(exprId = a.exprId, qualifier = a.qualifier)
case e if hasUnorderedMap(e) => SortMap(e)
} else {
a
}
}
Distinct(Project(projectList, child))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,7 @@ class AnalysisErrorSuite extends AnalysisTest {
checkDataType(dataType, shouldSuccess = true)
}

val unsupportedDataTypes = Seq(
MapType(StringType, LongType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", MapType(StringType, LongType), nullable = true),
new UngroupableUDT())
val unsupportedDataTypes = Seq(new UngroupableUDT())
unsupportedDataTypes.foreach { dataType =>
checkDataType(dataType, shouldSuccess = false)
}
Expand Down Expand Up @@ -479,20 +474,6 @@ class AnalysisErrorSuite extends AnalysisTest {
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))

assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil)

val plan2 =
Join(
LocalRelation(
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))),
LocalRelation(
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
Cross,
Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))

assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
}

test("PredicateSubQuery is used outside of a filter") {
Expand Down

0 comments on commit b16ddb8

Please sign in to comment.