Merge branch 'master' into SPARK-8103
squito committed Jul 14, 2015
2 parents 906d626 + c4e98ff commit a21c8b5
Showing 4 changed files with 141 additions and 50 deletions.
11 changes: 10 additions & 1 deletion build/mvn
@@ -112,10 +112,17 @@ install_scala() {
# the environment
ZINC_PORT=${ZINC_PORT:-"3030"}

+# Check for the `--force` flag dictating that `mvn` should be downloaded
+# regardless of whether the system already has a `mvn` install
+if [ "$1" == "--force" ]; then
+  FORCE_MVN=1
+  shift
+fi

# Install Maven if necessary
MVN_BIN="$(command -v mvn)"

-if [ ! "$MVN_BIN" ]; then
+if [ ! "$MVN_BIN" -o -n "$FORCE_MVN" ]; then
  install_mvn
fi

@@ -139,5 +146,7 @@ fi
# Set any `mvn` options if not already present
export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}

+echo "Using \`mvn\` from path: $MVN_BIN"

# Last, call the `mvn` command as usual
${MVN_BIN} "$@"
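
Because the flag is consumed (`shift`) before any arguments are forwarded to Maven, an invocation like `./build/mvn --force clean package` always runs the downloaded `mvn` even when one is already on the PATH, while `./build/mvn clean package` keeps the old prefer-system behavior; the new `echo` makes it easy to see which binary was chosen.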
@@ -230,24 +230,31 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
}
}

+private def evalElse(input: InternalRow): Any = {
+  if (branchesArr.length % 2 == 0) {
+    null
+  } else {
+    branchesArr(branchesArr.length - 1).eval(input)
+  }
+}

/** Written in imperative fashion for performance considerations. */
override def eval(input: InternalRow): Any = {
  val evaluatedKey = key.eval(input)
-  val len = branchesArr.length
-  var i = 0
-  // If all branches fail and an elseVal is not provided, the whole statement
-  // defaults to null, according to Hive's semantics.
-  while (i < len - 1) {
-    if (threeValueEquals(evaluatedKey, branchesArr(i).eval(input))) {
-      return branchesArr(i + 1).eval(input)
-    }
-    i += 2
-  }
-  var res: Any = null
-  if (i == len - 1) {
-    res = branchesArr(i).eval(input)
-  }
-  return res
+  // If key is null, we can just return the else part or null if there is no else.
+  // If key is not null but doesn't match any when part, we need to return
+  // the else part or null if there is no else, according to Hive's semantics.
+  if (evaluatedKey != null) {
+    val len = branchesArr.length
+    var i = 0
+    while (i < len - 1) {
+      if (evaluatedKey == branchesArr(i).eval(input)) {
+        return branchesArr(i + 1).eval(input)
+      }
+      i += 2
+    }
+  }
+  evalElse(input)
}

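Seen in isolation, the null-key rule is easy to state: a null key never matches any WHEN branch and falls through to the ELSE value (or null). A minimal, self-contained Scala sketch of these semantics (toy code, not Spark's classes; the branch sequence is laid out as when1, then1, when2, then2, ..., optional else, mirroring branchesArr):

def caseKeyWhen(key: Any, branches: IndexedSeq[Any]): Any = {
  // The ELSE value exists only when the branch count is odd.
  def elseValue: Any = if (branches.length % 2 == 0) null else branches.last

  if (key != null) {
    var i = 0
    while (i < branches.length - 1) {
      if (key == branches(i)) return branches(i + 1)
      i += 2
    }
  }
  // Null key, or no matching WHEN: fall through to ELSE (or null).
  elseValue
}

caseKeyWhen(null, IndexedSeq(null, "then", "else"))  // returns "else", not "then":
// per Hive/SQL semantics NULL equals nothing, not even a NULL branch value.
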
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -261,8 +268,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
s"""
  if (!$got) {
    ${cond.code}
-    if (!${keyEval.isNull} && !${cond.isNull}
-      && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
+    if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
      $got = true;
      ${res.code}
      ${ev.isNull} = ${res.isNull};
@@ -290,19 +296,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${keyEval.code}
-$cases
+if (!${keyEval.isNull}) {
+  $cases
+}
$other
"""
}
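
Net effect of the two codegen edits: the key's null check, previously repeated inside every generated WHEN comparison, is hoisted into a single `if (!${keyEval.isNull})` guard around `$cases`, while `$other` (the ELSE code) stays outside that guard, so a null key skips every comparison but still produces the ELSE value, matching the interpreted `eval` path above.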

-private def threeValueEquals(l: Any, r: Any) = {
-  if (l == null || r == null) {
-    false
-  } else {
-    l == r
-  }
-}

override def toString: String = {
  s"CASE $key" + branches.sliding(2, 2).map {
    case Seq(cond, value) => s" WHEN $cond THEN $value"
@@ -34,7 +34,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StringType, IntegralType}

/**
@@ -312,37 +312,41 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
  getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

-override def getPartitionsByFilter(
-    hive: Hive,
-    table: Table,
-    predicates: Seq[Expression]): Seq[Partition] = {
+/**
+ * Converts catalyst expressions to the format that Hive's getPartitionsByFilter() expects, i.e.
+ * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...".
+ *
+ * Unsupported predicates are skipped.
+ */
+def convertFilters(table: Table, filters: Seq[Expression]): String = {
  // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
  val varcharKeys = table.getPartitionKeys
    .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
    .map(col => col.getName).toSet

-  // Hive getPartitionsByFilter() takes a string that represents partition
-  // predicates like "str_key=\"value\" and int_key=1 ..."
-  val filter = predicates.flatMap { expr =>
-    expr match {
-      case op @ BinaryComparison(lhs, rhs) => {
-        lhs match {
-          case AttributeReference(_, _, _, _) => {
-            rhs.dataType match {
-              case _: IntegralType =>
-                Some(lhs.prettyString + op.symbol + rhs.prettyString)
-              case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
-                Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
-              case _ => None
-            }
-          }
-          case _ => None
-        }
-      }
-      case _ => None
-    }
+  filters.collect {
+    case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
+      s"${a.name} ${op.symbol} $v"
+    case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
+      s"$v ${op.symbol} ${a.name}"
+
+    case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType))
+        if !varcharKeys.contains(a.name) =>
+      s"""${a.name} ${op.symbol} "$v""""
+    case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute)
+        if !varcharKeys.contains(a.name) =>
+      s""""$v" ${op.symbol} ${a.name}"""
  }.mkString(" and ")
+}

+override def getPartitionsByFilter(
+    hive: Hive,
+    table: Table,
+    predicates: Seq[Expression]): Seq[Partition] = {

+  // Hive getPartitionsByFilter() takes a string that represents partition
+  // predicates like "str_key=\"value\" and int_key=1 ..."
+  val filter = convertFilters(table, predicates)
  val partitions =
    if (filter.isEmpty) {
      getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
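
The collect-based rewrite leans on the fact that `collect` takes a partial function: any predicate matching none of its cases is silently dropped, which is exactly the "unsupported predicates are skipped" contract from the doc comment. A self-contained sketch of the pattern using toy expression types (not Spark's classes):

sealed trait Expr
case class Attr(name: String) extends Expr
case class Lit(value: Any) extends Expr
case class Eq(left: Expr, right: Expr) extends Expr
case class Contains(left: Expr, right: Expr) extends Expr // not pushable to Hive

// Keep only the shapes we can render as a Hive filter string; anything
// else falls out of the result instead of throwing.
def convert(filters: Seq[Expr]): String = filters.collect {
  case Eq(Attr(n), Lit(v: Int)) => s"$n = $v"
  case Eq(Lit(v: Int), Attr(n)) => s"$v = $n"
  case Eq(Attr(n), Lit(v: String)) => s"""$n = "$v""""
}.mkString(" and ")

convert(Seq(Eq(Attr("intcol"), Lit(1)), Contains(Attr("s"), Lit("x"))))
// "intcol = 1" -- the Contains predicate was skipped
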
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.client

import scala.collection.JavaConversions._

import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

/**
* A set of tests for the filter conversion logic used when pushing partition pruning into the
* metastore.
*/
class FiltersSuite extends SparkFunSuite with Logging {
private val shim = new Shim_v0_13

private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
private val varCharCol = new FieldSchema()
varCharCol.setName("varchar")
varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
testTable.setPartCols(varCharCol :: Nil)

filterTest("string filter",
  (a("stringcol", StringType) > Literal("test")) :: Nil,
  "stringcol > \"test\"")

filterTest("string filter backwards",
  (Literal("test") > a("stringcol", StringType)) :: Nil,
  "\"test\" > stringcol")

filterTest("int filter",
  (a("intcol", IntegerType) === Literal(1)) :: Nil,
  "intcol = 1")

filterTest("int filter backwards",
  (Literal(1) === a("intcol", IntegerType)) :: Nil,
  "1 = intcol")

filterTest("int and string filter",
  (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil,
  "1 = intcol and \"a\" = strcol")

filterTest("skip varchar",
  (Literal("") === a("varchar", StringType)) :: Nil,
  "")

private def filterTest(name: String, filters: Seq[Expression], result: String) = {
  test(name) {
    val converted = shim.convertFilters(testTable, filters)
    if (converted != result) {
      fail(
        s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'")
    }
  }
}

private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
}
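
The harness keeps new coverage cheap. For instance, a hypothetical extra case (not part of this commit) asserting that a non-BinaryComparison predicate such as IN is skipped, written with the same catalyst DSL, would be:

filterTest("skip IN predicate",
  (a("intcol", IntegerType) in (Literal(1), Literal(2))) :: Nil,
  "")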
