Skip to content

Commit

Permalink
Fix code generation. Fix joins.
Browse files Browse the repository at this point in the history
  • Loading branch information
hvanhovell committed Nov 24, 2016
1 parent e4cc4b0 commit 7d847b8
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2336,14 +2336,14 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
*/
object SortMaps extends Rule[LogicalPlan] {
private def containsUnorderedMap(e: Expression): Boolean =
MapType.containsUnorderedMap(e.dataType)
e.resolved && MapType.containsUnorderedMap(e.dataType)

override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
case cmp @ BinaryComparison(left, right) if cmp.resolved && containsUnorderedMap(left) =>
case cmp @ BinaryComparison(left, right) if containsUnorderedMap(left) =>
cmp.withNewChildren(OrderMaps(left) :: right :: Nil)
case cmp @ BinaryComparison(left, right) if cmp.resolved && containsUnorderedMap(right) =>
case cmp @ BinaryComparison(left, right) if containsUnorderedMap(right) =>
cmp.withNewChildren(left :: OrderMaps(right) :: Nil)
case sort: SortOrder if sort.resolved && containsUnorderedMap(sort.child) =>
case sort: SortOrder if containsUnorderedMap(sort.child) =>
sort.copy(child = OrderMaps(sort.child))
} transform {
case a: Aggregate if a.resolved && a.groupingExpressions.exists(containsUnorderedMap) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ class CodegenContext {
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case map: MapType if map.ordered => genComp(map, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case _ =>
throw new IllegalArgumentException(
Expand Down Expand Up @@ -567,9 +568,9 @@ class CodegenContext {
ArrayData bValues = b.valueArray();
int minLength = (lengthA > lengthB) ? lengthB : lengthA;
for (int i = 0; i < minLength; i++) {
${javaType(keyType)} keyA = ${getValue("aKeys", valueType, "i")};
${javaType(keyType)} keyB = ${getValue("bKeys", valueType, "i")};
int comp = ${genComp(valueType, "keyA", "keyB")};
${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")};
${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")};
int comp = ${genComp(keyType, "keyA", "keyB")};
if (comp != 0) {
return comp;
}
Expand All @@ -584,19 +585,13 @@ class CodegenContext {
} else {
${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")};
${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")};
int comp = ${genComp(valueType, "valueA", "valueB")};
comp = ${genComp(valueType, "valueA", "valueB")};
if (comp != 0) {
return comp;
}
}
}

if (lengthA < lengthB) {
return -1;
} else if (lengthA > lengthB) {
return 1;
}
return 0;
return lengthA - lengthB;
}
"""
addNewFunction(compareFunc, funcCode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ case class EqualTo(left: Expression, right: Expression)
case TypeCheckResult.TypeCheckSuccess =>
// Maps are only allowed when they are ordered.
if (MapType.containsUnorderedMap(left.dataType)) {
TypeCheckResult.TypeCheckFailure("Cannot use unordered map type in EqualTo, but " +
s"the actual input type is ${left.dataType.catalogString}.")
TypeCheckResult.TypeCheckFailure(
s"Cannot use unordered map type in EqualTo: ${left.dataType.catalogString}.")
} else {
TypeCheckResult.TypeCheckSuccess
}
Expand Down Expand Up @@ -452,8 +452,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
EqualNullSafe
// Maps are only allowed when they are ordered.
if (MapType.containsUnorderedMap(left.dataType)) {
TypeCheckResult.TypeCheckFailure("Cannot use unordered map type in EqualNullSafe, but " +
s"the actual input type is ${left.dataType.catalogString}.")
TypeCheckResult.TypeCheckFailure(
s"Cannot use unordered map type in EqualNullSafe: ${left.dataType.catalogString}.")
} else {
TypeCheckResult.TypeCheckSuccess
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ class AnalysisErrorSuite extends AnalysisTest {

errorTest(
"sorting by unsupported column types",
mapRelation.orderBy('map.asc),
"sort" :: "type" :: "map<int,int>" :: Nil)
intervalRelation.orderBy('interval.asc),
"sort" :: "type" :: "calendarinterval" :: Nil)

errorTest(
"sorting by attributes are not from grouping expressions",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
val e = intercept[AnalysisException] {
assertSuccess(expr)
}
assert(e.getMessage.contains(
s"cannot resolve '${expr.sql}' due to data type mismatch:"))
assert(e.getMessage.contains("cannot resolve "))
assert(e.getMessage.contains("due to data type mismatch:"))
assert(e.getMessage.contains(errorMessage))
}

Expand All @@ -51,8 +51,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}

def assertErrorForDifferingTypes(expr: Expression): Unit = {
assertError(expr,
s"differing types in '${expr.sql}'")
assertError(expr, "differing types in")
}

test("check types for unary arithmetic") {
Expand Down Expand Up @@ -99,6 +98,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(LessThanOrEqual('intField, 'stringField))
assertSuccess(GreaterThan('intField, 'stringField))
assertSuccess(GreaterThanOrEqual('intField, 'stringField))
assertSuccess(EqualTo('mapField, 'mapField))
assertSuccess(EqualNullSafe('mapField, 'mapField))

// We will transform EqualTo with numeric and boolean types to CaseKeyWhen
assertSuccess(EqualTo('intField, 'booleanField))
Expand All @@ -111,8 +112,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))

assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
assertError(LessThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
assertError(LessThanOrEqual('mapField, 'mapField),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,7 @@ object TestRelations {

val mapRelation = LocalRelation(
AttributeReference("map", MapType(IntegerType, IntegerType))())

val intervalRelation = LocalRelation(
AttributeReference("interval", CalendarIntervalType)())
}

0 comments on commit 7d847b8

Please sign in to comment.