Skip to content

Commit

Permalink
[SC-5704][REDSHIFT] Refactor and improve old Filter Pushdown code path
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Main changes:

- Move FilterPushdown.scala under the pushdown package and make it reuse some of the helper functions there (e.g. wrap, block)
- Add support for more expressions: StartsWith, EndsWith, Contains, AND, OR, NOT, IN
- Add parentheses around all basic predicates and re-approve the affected tests.

## How was this patch tested?

Ran all unit tests and `RedshiftReadIntegrationSuite.scala`

Author: Adrian Ionescu <adrian@databricks.com>

Closes apache#227 from adrian-ionescu/redshift-basic-pushdown.
  • Loading branch information
adrian-ionescu authored and rxin committed Feb 15, 2017
1 parent 118b4dd commit b7e161d
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,19 @@ class RedshiftReadSuite extends IntegrationSuiteBase {
"""
|select testbyte, testbool
|from test_table
|where testbool = true
| and teststring = "Unicode's樂趣"
| and testdouble = 1234152.12312498
| and testfloat = 1.0
| and testint = 42
|where (testbool = true)
| and (teststring = "Unicode's樂趣")
| and (testdouble = 1234152.12312498)
| and (testfloat = 1.0)
| and (testint = 42)
""".stripMargin),
Seq(Row(1, true)))
// scalastyle:on
}

test("RedshiftRelation implements Spark 1.6+'s unhandledFilters API") {
assume(org.apache.spark.SPARK_VERSION.take(3) >= "1.6")
val df = sqlContext.sql("select testbool from test_table where testbool = true")
val df = sqlContext.sql("select testbool from test_table where (testbool = true)")
val physicalPlan = df.queryExecution.sparkPlan
physicalPlan.collectFirst { case f: execution.FilterExec => f }.foreach { filter =>
fail(s"Filter should have been eliminated:\n${df.queryExecution}")
Expand Down Expand Up @@ -244,68 +244,68 @@ class RedshiftReadSuite extends IntegrationSuiteBase {

test("properly escape literals in filter pushdown (SC-5504)") {
checkAnswer(
sqlContext.sql("select count(1) from test_table where testint = 4141214"),
sqlContext.sql("select count(1) from test_table where (testint = 4141214)"),
Seq(Row(1))
)
checkAnswer(
sqlContext.sql("select count(1) from test_table where testint = 7"),
sqlContext.sql("select count(1) from test_table where (testint = 7)"),
Seq(Row(0))
)
checkAnswer(
sqlContext.sql("select testint from test_table where testint = 42"),
sqlContext.sql("select testint from test_table where (testint = 42)"),
Seq(Row(42), Row(42))
)

checkAnswer(
sqlContext.sql("select count(1) from test_table where teststring = 'asdf'"),
sqlContext.sql("select count(1) from test_table where (teststring = 'asdf')"),
Seq(Row(1))
)
checkAnswer(
sqlContext.sql("select count(1) from test_table where teststring = 'alamakota'"),
sqlContext.sql("select count(1) from test_table where (teststring = 'alamakota')"),
Seq(Row(0))
)
checkAnswer(
sqlContext.sql("select teststring from test_table where teststring = 'asdf'"),
sqlContext.sql("select teststring from test_table where (teststring = 'asdf')"),
Seq(Row("asdf"))
)

checkAnswer(
sqlContext.sql("select count(1) from test_table where teststring = 'a\\'b'"),
sqlContext.sql("select count(1) from test_table where (teststring = 'a\\'b')"),
Seq(Row(0))
)
checkAnswer(
sqlContext.sql("select teststring from test_table where teststring = 'a\\'b'"),
sqlContext.sql("select teststring from test_table where (teststring = 'a\\'b')"),
Seq()
)

// scalastyle:off
checkAnswer(
sqlContext.sql("select count(1) from test_table where teststring = 'Unicode\\'s樂趣'"),
sqlContext.sql("select count(1) from test_table where (teststring = 'Unicode\\'s樂趣')"),
Seq(Row(1))
)
checkAnswer(
sqlContext.sql("select teststring from test_table where teststring = \"Unicode's樂趣\""),
sqlContext.sql("select teststring from test_table where (teststring = \"Unicode's樂趣\")"),
Seq(Row("Unicode's樂趣"))
)
// scalastyle:on

checkAnswer(
sqlContext.sql("select count(1) from test_table where teststring = 'a\\\\b'"),
sqlContext.sql("select count(1) from test_table where (teststring = 'a\\\\b')"),
Seq(Row(0))
)
checkAnswer(
sqlContext.sql("select teststring from test_table where teststring = 'a\\\\b'"),
sqlContext.sql("select teststring from test_table where (teststring = 'a\\\\b')"),
Seq()
)

checkAnswer(
sqlContext.sql(
"select count(1) from test_table where teststring = 'Ba\\\\ckslash\\\\'"),
"select count(1) from test_table where (teststring = 'Ba\\\\ckslash\\\\')"),
Seq(Row(1))
)
checkAnswer(
sqlContext.sql(
"select teststring from test_table where teststring = \"Ba\\\\ckslash\\\\\""),
"select teststring from test_table where (teststring = \"Ba\\\\ckslash\\\\\")"),
Seq(Row("Ba\\ckslash\\"))
)
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.s3.AmazonS3Client
import com.databricks.spark.redshift.Parameters.MergedParameters
import com.databricks.spark.redshift.Utils.escapeJdbcString
import com.databricks.spark.redshift.pushdown.FilterPushdown
import org.json4s.{DefaultFormats, JValue, StreamInput}
import org.json4s.JsonAST.JValue
import org.json4s.jackson.JsonMethods
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright (C) 2016 Databricks, Inc.
*
* Portions of this software incorporate or are derived from software contained within Apache Spark,
* and this modified software differs from the Apache Spark software provided under the Apache
* License, Version 2.0, a copy of which you may obtain at
* http://www.apache.org/licenses/LICENSE-2.0
*/

package com.databricks.spark.redshift.pushdown

import java.sql.{Date, Timestamp}

import com.databricks.spark.redshift.Utils

import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

/**
* Helper methods for pushing filters into Redshift queries.
*/
private[redshift] object FilterPushdown {
  /**
   * Build a SQL WHERE clause for the given filters. If a filter cannot be pushed down then no
   * condition will be added to the WHERE clause. If none of the filters can be pushed down then
   * an empty string will be returned.
   *
   * @param schema the schema of the table being queried
   * @param filters an array of filters, the conjunction of which is the filter condition for the
   *                scan.
   */
  def buildWhereClause(schema: StructType, filters: Seq[Filter]): String = {
    val filterExpressions = filters.flatMap(f => buildFilterExpression(schema, f)).mkString(" AND ")
    if (filterExpressions.isEmpty) "" else "WHERE " + filterExpressions
  }

  /**
   * Attempt to convert the given filter into a SQL expression. Returns None if the expression
   * could not be converted, in which case the filter is simply not pushed down and Spark
   * evaluates it post-scan instead.
   */
  def buildFilterExpression(schema: StructType, filter: Filter): Option[String] = {

    // Renders a literal for SQL, escaped according to the column's expected datatype.
    def buildValueWithType(dataType: DataType, value: Any): String = {
      dataType match {
        case StringType => s"'${Utils.escapeRedshiftStringLiteral(value.toString)}'"
        case DateType => s"'${value.asInstanceOf[Date]}'"
        case TimestampType => s"'${value.asInstanceOf[Timestamp]}'"
        case _ => value.toString
      }
    }

    // Renders a literal for SQL, escaped based on the runtime type of the value itself.
    // Used where the target column type is not consulted (e.g. LIKE patterns).
    def buildValue(value: Any): String = {
      value match {
        case _: String => s"'${Utils.escapeRedshiftStringLiteral(value.toString)}'"
        case _: Date => s"'${value.asInstanceOf[Date]}'"
        case _: Timestamp => s"'${value.asInstanceOf[Timestamp]}'"
        case _ => value.toString
      }
    }

    // Builds a simple comparison string, e.g. `"col" >= 42`.
    // Returns None if the attribute cannot be resolved against the schema.
    def buildComparison(attr: String, value: Any, comparisonOp: String): Option[String] = {
      for {
        dataType <- getTypeForAttribute(schema, attr)
        sqlEscapedValue = buildValueWithType(dataType, value)
      } yield {
        s"""${wrap(attr)} $comparisonOp $sqlEscapedValue"""
      }
    }

    // Builds a string out of a binary logical operation; None if either side cannot be pushed
    // down (pushing only one side of an OR, for instance, would be incorrect).
    def buildBooleanLogicExpr(left: Filter, right: Filter, logicalOp: String) : Option[String] = {
      for {
        leftStr <- buildFilterExpression(schema, left)
        rightStr <- buildFilterExpression(schema, right)
      } yield {
        s"""$leftStr $logicalOp $rightStr"""
      }
    }

    val predicateOption = filter match {
      case EqualTo(attr, value) =>
        buildComparison(attr, value, "=")
      case LessThan(attr, value) =>
        buildComparison(attr, value, "<")
      case GreaterThan(attr, value) =>
        buildComparison(attr, value, ">")
      case LessThanOrEqual(attr, value) =>
        buildComparison(attr, value, "<=")
      case GreaterThanOrEqual(attr, value) =>
        buildComparison(attr, value, ">=")
      case In(attr, values) if values.nonEmpty =>
        // Fix: the previous implementation called `.get` on the resolved type, which threw
        // NoSuchElementException for an unresolvable attribute instead of declining pushdown
        // like every other case. An empty IN list (guarded above) is also not pushed down,
        // since `attr IN ()` is not valid SQL.
        getTypeForAttribute(schema, attr).map { dataType =>
          val valueStrings = values.map(v => buildValueWithType(dataType, v)).mkString(", ")
          s"""${wrap(attr)} IN ${block(valueStrings)}"""
        }
      case IsNull(attr) =>
        Some(s"""${wrap(attr)} IS NULL""")
      case IsNotNull(attr) =>
        Some(s"""${wrap(attr)} IS NOT NULL""")
      case And(left, right) =>
        buildBooleanLogicExpr(left, right, "AND")
      case Or(left, right) =>
        buildBooleanLogicExpr(left, right, "OR")
      case Not(child) =>
        buildFilterExpression(schema, child).map(s => s"""NOT $s""")
      case StringStartsWith(attr, value) =>
        Some(s"""${wrap(attr)} LIKE ${buildValue(value + "%")}""")
      case StringEndsWith(attr, value) =>
        Some(s"""${wrap(attr)} LIKE ${buildValue("%" + value)}""")
      case StringContains(attr, value) =>
        Some(s"""${wrap(attr)} LIKE ${buildValue("%" + value + "%")}""")
      case _ => None
    }

    // Let's be safe and wrap every individual expression in parentheses in order to avoid having
    // to reason about operator precedence rules in Redshift, which are briefly documented here:
    // http://docs.aws.amazon.com/redshift/latest/dg/r_logical_condition.html
    // Note that there's no mention of operators such as LIKE, IN, IS NULL, etc.
    predicateOption.map(block)
  }

  /**
   * Use the given schema to look up the attribute's data type. Returns None if the attribute could
   * not be resolved.
   */
  private def getTypeForAttribute(schema: StructType, attribute: String): Option[DataType] = {
    if (schema.fieldNames.contains(attribute)) {
      Some(schema(attribute).dataType)
    } else {
      None
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,12 @@ class RedshiftSourceSuite
val expectedQuery = (
"UNLOAD \\('SELECT \"testbyte\", \"testbool\" " +
"FROM \"PUBLIC\".\"test_table\" " +
"WHERE \"testbool\" = true " +
"AND \"teststring\" = \\\\'Unicode\\\\'\\\\'s樂趣\\\\' " +
"AND \"testdouble\" > 1000.0 " +
"AND \"testdouble\" < 1.7976931348623157E308 " +
"AND \"testfloat\" >= 1.0 " +
"AND \"testint\" <= 43'\\) " +
"WHERE \\(\"testbool\" = true\\) " +
"AND \\(\"teststring\" = \\\\'Unicode\\\\'\\\\'s樂趣\\\\'\\) " +
"AND \\(\"testdouble\" > 1000.0\\) " +
"AND \\(\"testdouble\" < 1.7976931348623157E308\\) " +
"AND \\(\"testfloat\" >= 1.0\\) " +
"AND \\(\"testint\" <= 43\\)'\\) " +
"TO '.*' " +
"WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
"ESCAPE").r
Expand Down
Loading

0 comments on commit b7e161d

Please sign in to comment.