From 813591faddb6a7cd2e50ac13cdfa98e6d0ce6ffa Mon Sep 17 00:00:00 2001
From: Xiaoguang Sun
Date: Thu, 16 Jan 2025 15:54:58 +0800
Subject: [PATCH] [SPARK-50792][SQL] Format binary data as a binary literal in JDBC

### What changes were proposed in this pull request?

Format binary data as a binary literal when compiling filter values in the JDBC dialects.

### Why are the changes needed?

Binary data is not formatted as a binary literal when Spark pushes filters down to JDBC connectors: a byte array falls through to the default case of `JdbcDialect.compileValue` and is rendered with its Java `toString`, producing invalid SQL. These are the steps to reproduce:

1. `CREATE TABLE test_binary_literal(b BINARY);`
2. `INSERT INTO test_binary_literal VALUES(x'010203');`
3. `SELECT * FROM test_binary_literal WHERE b=x'010203';`

The final query fails instead of returning the inserted row.
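The fix adds an `Array[Byte]` case to `JdbcDialect.compileValue` (plus dialect-specific overrides) that renders each byte as two uppercase hex digits inside the literal syntax the target database expects. A minimal standalone sketch of that formatting, for illustration only (the `BinaryLiteralSketch` and `toBinaryLiteral` names are not part of the patch):

```scala
object BinaryLiteralSketch {
  // Mirrors the formatting added to compileValue: each byte becomes two
  // uppercase hex digits, wrapped in dialect-specific delimiters.
  def toBinaryLiteral(bytes: Array[Byte], start: String, end: String): String =
    bytes.map("%02X".format(_)).mkString(start, "", end)

  def main(args: Array[String]): Unit = {
    val bytes = Array(0x12, 0x34, 0x56).map(_.toByte)
    println(toBinaryLiteral(bytes, "X'", "'"))          // default:    X'123456'
    println(toBinaryLiteral(bytes, "BLOB(X'", "')"))    // DB2:        BLOB(X'123456')
    println(toBinaryLiteral(bytes, "0x", ""))           // SQL Server: 0x123456
    println(toBinaryLiteral(bytes, "HEXTORAW('", "')")) // Oracle:     HEXTORAW('123456')
    println(toBinaryLiteral(bytes, "'\\x", "'::bytea")) // PostgreSQL: '\x123456'::bytea
  }
}
```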
### Does this PR introduce _any_ user-facing change?

'No'

### How was this patch tested?

Added new integration tests.

### Was this patch authored or co-authored using generative AI tooling?

'No'

Closes #49452 from sunxiaoguang/fix_binary_literal_format_in_jdbc.

Lead-authored-by: Xiaoguang Sun
Co-authored-by: Xiaoguang Sun
Signed-off-by: Wenchen Fan
---
 .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 37 +++++++++++++++++++
 .../apache/spark/sql/jdbc/DB2Dialect.scala    |  6 +++
 .../apache/spark/sql/jdbc/JdbcDialects.scala  |  1 +
 .../spark/sql/jdbc/MsSqlServerDialect.scala   |  1 +
 .../apache/spark/sql/jdbc/OracleDialect.scala | 32 +++++++++++++++-
 .../spark/sql/jdbc/PostgresDialect.scala      |  6 +++
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   | 20 ++++++++++
 7 files changed, 102 insertions(+), 1 deletion(-)

diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index 54635f69f8b65..f97b6a6eb183d 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -986,4 +986,41 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
   test("scan with filter push-down with date time functions") {
     testDatetime(s"$catalogAndNamespace.${caseConvert("datetime")}")
   }
+
+  test("SPARK-50792: Format binary data as a binary literal in JDBC.") {
+    val tableName = s"$catalogName.test_binary_literal"
+    withTable(tableName) {
+      // Create a table with binary column
+      val binary = "X'123456'"
+      val lessThanBinary = "X'123455'"
+      val greaterThanBinary = "X'123457'"
+
+      sql(s"CREATE TABLE $tableName (binary_col BINARY)")
+      sql(s"INSERT INTO $tableName VALUES ($binary)")
+
+      def testBinaryLiteral(operator: String, literal: String, expected: Int): Unit = {
+        val sql = s"SELECT * FROM $tableName WHERE binary_col $operator $literal"
+        val df = spark.sql(sql)
+        checkFilterPushed(df)
+        val rows = df.collect()
+        assert(rows.length === expected, s"Failed to run $sql")
+        if (expected == 1) {
+          assert(rows(0)(0) === Array(0x12, 0x34, 0x56).map(_.toByte))
+        }
+      }
+
+      testBinaryLiteral("=", binary, 1)
+      testBinaryLiteral(">=", binary, 1)
+      testBinaryLiteral(">=", lessThanBinary, 1)
+      testBinaryLiteral(">", lessThanBinary, 1)
+      testBinaryLiteral("<=", binary, 1)
+      testBinaryLiteral("<=", greaterThanBinary, 1)
+      testBinaryLiteral("<", greaterThanBinary, 1)
+      testBinaryLiteral("<>", greaterThanBinary, 1)
+      testBinaryLiteral("<>", lessThanBinary, 1)
+      testBinaryLiteral("<=>", binary, 1)
+      testBinaryLiteral("<=>", lessThanBinary, 0)
+      testBinaryLiteral("<=>", greaterThanBinary, 0)
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index 2f54f1f62fde1..f33e64c859fb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -82,6 +82,12 @@ private case class DB2Dialect() extends JdbcDialect with SQLConfHelper with NoLe
     }
   }
 
+  override def compileValue(value: Any): Any = value match {
+    case binaryValue: Array[Byte] =>
+      binaryValue.map("%02X".format(_)).mkString("BLOB(X'", "", "')")
+    case other => super.compileValue(other)
+  }
+
   override def getCatalystType(
       sqlType: Int,
       typeName: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 81ad1a6d38bbf..694e60102852b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -374,6 +374,7 @@ abstract class JdbcDialect extends Serializable with Logging {
     case dateValue: Date => "'" + dateValue + "'"
     case dateValue: LocalDate => s"'${DateFormatter().format(dateValue)}'"
     case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case binaryValue: Array[Byte] => binaryValue.map("%02X".format(_)).mkString("X'", "", "'")
     case _ => value
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 7d339a90db8c8..e2a5671a7c28a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -45,6 +45,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
   // scalastyle:on line.size.limit
 
   override def compileValue(value: Any): Any = value match {
     case booleanValue: Boolean => if (booleanValue) 1 else 0
+    case binaryValue: Array[Byte] => binaryValue.map("%02X".format(_)).mkString("0x", "", "")
     case other => super.compileValue(other)
   }
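Oracle cannot compare BLOB values with the ordinary comparison operators, so the OracleDialect change below routes binary comparisons through `DBMS_LOB.COMPARE`. A standalone sketch of that rewrite, for illustration only (`inputToSQL` is replaced here by pre-rendered SQL fragments, and the object name is made up):

```scala
object OracleBlobCompareSketch {
  // Mirrors the compareBlob helper added to OracleDialect below, with the
  // operands already rendered to SQL strings.
  def compareBlob(l: String, operator: String, r: String): String = {
    // <=> is Spark's null-safe equality; Oracle has no such operator, so it
    // is expanded into an explicit NULL check on both sides.
    val op = if (operator == "<=>") "=" else operator
    val compare = s"DBMS_LOB.COMPARE($l, $r) $op 0"
    if (operator == "<=>") {
      s"(($l IS NOT NULL AND $r IS NOT NULL AND $compare) OR ($l IS NULL AND $r IS NULL))"
    } else {
      compare
    }
  }

  def main(args: Array[String]): Unit = {
    // DBMS_LOB.COMPARE(binary_col, HEXTORAW('123456')) = 0
    println(compareBlob("binary_col", "=", "HEXTORAW('123456')"))
    // ((binary_col IS NOT NULL AND HEXTORAW('123456') IS NOT NULL AND
    //   DBMS_LOB.COMPARE(binary_col, HEXTORAW('123456')) = 0) OR
    //  (binary_col IS NULL AND HEXTORAW('123456') IS NULL))
    println(compareBlob("binary_col", "<=>", "HEXTORAW('123456')"))
  }
}
```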
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index a73a34c646356..adb0da1a21264 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -24,7 +24,7 @@
 import scala.util.control.NonFatal
 import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException}
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.connector.expressions.Expression
+import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.jdbc.OracleDialect._
@@ -61,6 +61,34 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
       } else {
         super.visitAggregateFunction(funcName, isDistinct, inputs)
       }
+
+    private def compareBlob(lhs: Expression, operator: String, rhs: Expression): String = {
+      val l = inputToSQL(lhs)
+      val r = inputToSQL(rhs)
+      val op = if (operator == "<=>") "=" else operator
+      val compare = s"DBMS_LOB.COMPARE($l, $r) $op 0"
+      if (operator == "<=>") {
+        s"(($l IS NOT NULL AND $r IS NOT NULL AND $compare) OR ($l IS NULL AND $r IS NULL))"
+      } else {
+        compare
+      }
+    }
+
+    override def build(expr: Expression): String = expr match {
+      case e: GeneralScalarExpression =>
+        e.name() match {
+          case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
+            (e.children()(0), e.children()(1)) match {
+              case (lhs: Literal[_], rhs: Expression) if lhs.dataType == BinaryType =>
+                compareBlob(lhs, e.name, rhs)
+              case (lhs: Expression, rhs: Literal[_]) if rhs.dataType == BinaryType =>
+                compareBlob(lhs, e.name, rhs)
+              case _ => super.build(expr)
+            }
+          case _ => super.build(expr)
+        }
+      case _ => super.build(expr)
+    }
   }
 
   override def compileExpression(expr: Expression): Option[String] = {
@@ -138,6 +166,8 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
     case timestampValue: Timestamp => "{ts '" + timestampValue + "'}"
     case dateValue: Date => "{d '" + dateValue + "'}"
     case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case binaryValue: Array[Byte] =>
+      binaryValue.map("%02X".format(_)).mkString("HEXTORAW('", "", "')")
     case _ => value
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index c1b79f8017419..0d2e0164079b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -327,6 +327,12 @@ private case class PostgresDialect()
     }
   }
 
+  override def compileValue(value: Any): Any = value match {
+    case binaryValue: Array[Byte] =>
+      binaryValue.map("%02X".format(_)).mkString("'\\x", "", "'::bytea")
+    case other => super.compileValue(other)
+  }
+
   override def supportsLimit: Boolean = true
 
   override def supportsOffset: Boolean = true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 8d3379805e013..541b2975da1e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -3097,4 +3097,24 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     assert(rows.contains(Row(null)))
     assert(rows.contains(Row("a a a")))
   }
+
+  test("SPARK-50792: Format binary data as a binary literal in JDBC.") {
+    val tableName = "h2.test.binary_literal"
+    withTable(tableName) {
+      // Create a table with binary column
+      val binary = "X'123456'"
+
+      sql(s"CREATE TABLE $tableName (binary_col BINARY)")
+      sql(s"INSERT INTO $tableName VALUES ($binary)")
+
+      val select = s"SELECT * FROM $tableName WHERE binary_col = $binary"
+      val df = sql(select)
+
+      val filter = df.queryExecution.optimizedPlan.collect {
+        case f: Filter => f
+      }
+      assert(filter.isEmpty, "Filter is not pushed")
+      assert(df.collect().length === 1, s"Binary literal test failed: $select")
+    }
+  }
 }
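For manual verification along the lines of the reproduction steps above, a hedged sketch for a spark-shell session (the `postgres` catalog and table names are hypothetical and must match a JDBC catalog configured in your own session):

```scala
// Hypothetical manual check against a configured JDBC catalog named "postgres":
// with this patch, the binary predicate compiles to '\x010203'::bytea and is
// pushed down, so the inserted row comes back.
val df = spark.sql("SELECT * FROM postgres.test_binary_literal WHERE b = x'010203'")
df.explain()            // the scan node should report the binary filter as pushed
assert(df.count() == 1) // previously this query failed instead of returning the row
```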