[SPARK-50792][SQL] Format binary data as a binary literal in JDBC
### What changes were proposed in this pull request?

Format binary data as a binary literal when compiling filter values in the JDBC dialects.

### Why are the changes needed?

Binary data is not formatted as a binary literal in the JDBC connectors, so a filter that compares a BINARY column against binary data is compiled incorrectly. Steps to reproduce:
1. `CREATE TABLE test_binary_literal(b BINARY);`
2. `INSERT INTO test_binary_literal VALUES(x'010203');`
3. `SELECT * FROM test_binary_literal WHERE b=x'010203';`
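For reference, the core of the fix is to hex-encode the bytes and wrap them in a dialect-appropriate literal syntax. A minimal sketch of the generic form added to `JdbcDialect.compileValue` below:

```scala
// Hex-encode a byte array and wrap it as a standard SQL binary literal.
val bytes: Array[Byte] = Array(0x01, 0x02, 0x03).map(_.toByte)
val literal: String = bytes.map("%02X".format(_)).mkString("X'", "", "'")
assert(literal == "X'010203'")
```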

### Does this PR introduce _any_ user-facing change?

'No'

### How was this patch tested?

Added new integration tests

### Was this patch authored or co-authored using generative AI tooling?

'No'

Closes #49452 from sunxiaoguang/fix_binary_literal_format_in_jdbc.

Lead-authored-by: Xiaoguang Sun <sunxiaoguang@gmail.com>
Co-authored-by: Xiaoguang Sun <sunxiaoguang@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
2 people authored and cloud-fan committed Jan 16, 2025
1 parent 9b32334 commit 813591f
Showing 7 changed files with 102 additions and 1 deletion.
V2JDBCTest.scala
@@ -986,4 +986,41 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
test("scan with filter push-down with date time functions") {
testDatetime(s"$catalogAndNamespace.${caseConvert("datetime")}")
}

test("SPARK-50792: Format binary data as a binary literal in JDBC.") {
val tableName = s"$catalogName.test_binary_literal"
withTable(tableName) {
// Create a table with binary column
val binary = "X'123456'"
val lessThanBinary = "X'123455'"
val greaterThanBinary = "X'123457'"

sql(s"CREATE TABLE $tableName (binary_col BINARY)")
sql(s"INSERT INTO $tableName VALUES ($binary)")

def testBinaryLiteral(operator: String, literal: String, expected: Int): Unit = {
val sql = s"SELECT * FROM $tableName WHERE binary_col $operator $literal"
val df = spark.sql(sql)
checkFilterPushed(df)
val rows = df.collect()
assert(rows.length === expected, s"Failed to run $sql")
if (expected == 1) {
assert(rows(0)(0) === Array(0x12, 0x34, 0x56).map(_.toByte))
}
}

testBinaryLiteral("=", binary, 1)
testBinaryLiteral(">=", binary, 1)
testBinaryLiteral(">=", lessThanBinary, 1)
testBinaryLiteral(">", lessThanBinary, 1)
testBinaryLiteral("<=", binary, 1)
testBinaryLiteral("<=", greaterThanBinary, 1)
testBinaryLiteral("<", greaterThanBinary, 1)
testBinaryLiteral("<>", greaterThanBinary, 1)
testBinaryLiteral("<>", lessThanBinary, 1)
testBinaryLiteral("<=>", binary, 1)
testBinaryLiteral("<=>", lessThanBinary, 0)
testBinaryLiteral("<=>", greaterThanBinary, 0)
}
}
}
DB2Dialect.scala
@@ -82,6 +82,12 @@ private case class DB2Dialect() extends JdbcDialect with SQLConfHelper with NoLe
    }
  }

  override def compileValue(value: Any): Any = value match {
    case binaryValue: Array[Byte] =>
      // DB2 requires an explicit BLOB constructor around the hex literal.
      binaryValue.map("%02X".format(_)).mkString("BLOB(X'", "", "')")
    case other => super.compileValue(other)
  }

  override def getCatalystType(
      sqlType: Int,
      typeName: String,
JdbcDialect.scala
@@ -374,6 +374,7 @@ abstract class JdbcDialect extends Serializable with Logging {
    case dateValue: Date => "'" + dateValue + "'"
    case dateValue: LocalDate => s"'${DateFormatter().format(dateValue)}'"
    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
    // Standard SQL hex binary literal, e.g. X'010203'.
    case binaryValue: Array[Byte] => binaryValue.map("%02X".format(_)).mkString("X'", "", "'")
    case _ => value
}

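Several dialects override this hook because the generic `X'…'` form is not valid for them. Pulling the formats together from the hunks in this commit, the same bytes compile as follows (a summary sketch, not code from the patch):

```scala
// One hex encoding, wrapped per dialect; formats taken from this commit.
val hex = Array(0x12, 0x34, 0x56).map(_.toByte).map("%02X".format(_)).mkString
val generic   = s"X'$hex'"           // JdbcDialect default:  X'123456'
val db2       = s"BLOB(X'$hex')"     // DB2Dialect:           BLOB(X'123456')
val sqlServer = s"0x$hex"            // MsSqlServerDialect:   0x123456
val oracle    = s"HEXTORAW('$hex')"  // OracleDialect:        HEXTORAW('123456')
val postgres  = s"'\\x$hex'::bytea"  // PostgresDialect:      '\x123456'::bytea
```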
MsSqlServerDialect.scala
@@ -45,6 +45,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
  // scalastyle:on line.size.limit
  override def compileValue(value: Any): Any = value match {
    case booleanValue: Boolean => if (booleanValue) 1 else 0
    // SQL Server uses an unquoted 0x-prefixed hex literal for binary data.
    case binaryValue: Array[Byte] => binaryValue.map("%02X".format(_)).mkString("0x", "", "")
    case other => super.compileValue(other)
  }

OracleDialect.scala
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal

import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.OracleDialect._
@@ -61,6 +61,34 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
      } else {
        super.visitAggregateFunction(funcName, isDistinct, inputs)
      }

    // Oracle cannot compare BLOBs with plain comparison operators; rewrite
    // binary comparisons through DBMS_LOB.COMPARE, which returns 0 when the
    // two inputs are equal.
    private def compareBlob(lhs: Expression, operator: String, rhs: Expression): String = {
      val l = inputToSQL(lhs)
      val r = inputToSQL(rhs)
      val op = if (operator == "<=>") "=" else operator
      val compare = s"DBMS_LOB.COMPARE($l, $r) $op 0"
      if (operator == "<=>") {
        // Null-safe equality: two NULLs compare as equal.
        s"(($l IS NOT NULL AND $r IS NOT NULL AND $compare) OR ($l IS NULL AND $r IS NULL))"
      } else {
        compare
      }
    }

    override def build(expr: Expression): String = expr match {
      case e: GeneralScalarExpression =>
        e.name() match {
          case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
            // Take the BLOB comparison path when either operand is a binary literal.
            (e.children()(0), e.children()(1)) match {
              case (lhs: Literal[_], rhs: Expression) if lhs.dataType == BinaryType =>
                compareBlob(lhs, e.name, rhs)
              case (lhs: Expression, rhs: Literal[_]) if rhs.dataType == BinaryType =>
                compareBlob(lhs, e.name, rhs)
              case _ => super.build(expr)
            }
          case _ => super.build(expr)
        }
      case _ => super.build(expr)
    }
  }
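Oracle cannot compare BLOB values with the ordinary comparison operators, so binary comparisons are routed through `DBMS_LOB.COMPARE`, and the null-safe `<=>` operator, which has no direct Oracle equivalent, is expanded into explicit NULL checks. A sketch of the string `compareBlob` builds for `<=>` (the column name and literal here are illustrative):

```scala
// What compareBlob emits for: binary_col <=> HEXTORAW('123456')
val l = "binary_col"
val r = "HEXTORAW('123456')"
val compare = s"DBMS_LOB.COMPARE($l, $r) = 0"
val sql = s"(($l IS NOT NULL AND $r IS NOT NULL AND $compare) OR ($l IS NULL AND $r IS NULL))"
```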

  override def compileExpression(expr: Expression): Option[String] = {
@@ -138,6 +166,8 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
    case timestampValue: Timestamp => "{ts '" + timestampValue + "'}"
    case dateValue: Date => "{d '" + dateValue + "'}"
    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
    case binaryValue: Array[Byte] =>
      // Oracle converts a hex string to a RAW value via HEXTORAW.
      binaryValue.map("%02X".format(_)).mkString("HEXTORAW('", "", "')")
    case _ => value
}

PostgresDialect.scala
@@ -327,6 +327,12 @@ private case class PostgresDialect()
    }
  }

  override def compileValue(value: Any): Any = value match {
    case binaryValue: Array[Byte] =>
      // PostgreSQL bytea hex format: '\x123456'::bytea
      binaryValue.map("%02X".format(_)).mkString("'\\x", "", "'::bytea")
    case other => super.compileValue(other)
  }

  override def supportsLimit: Boolean = true

  override def supportsOffset: Boolean = true
JDBCV2Suite.scala
@@ -3097,4 +3097,24 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
    assert(rows.contains(Row(null)))
    assert(rows.contains(Row("a a a")))
  }

  test("SPARK-50792: Format binary data as a binary literal in JDBC.") {
    val tableName = "h2.test.binary_literal"
    withTable(tableName) {
      // Create a table with a binary column and insert a single row.
      val binary = "X'123456'"

      sql(s"CREATE TABLE $tableName (binary_col BINARY)")
      sql(s"INSERT INTO $tableName VALUES ($binary)")

      val select = s"SELECT * FROM $tableName WHERE binary_col = $binary"
      val df = sql(select)
      // If no Filter node remains in the optimized plan, the predicate was
      // pushed down to the JDBC source.
      val filter = df.queryExecution.optimizedPlan.collect {
        case f: Filter => f
      }
      assert(filter.isEmpty, "Filter is not pushed")
      assert(df.collect().length === 1, s"Binary literal test failed: $select")
    }
  }
}
