diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index d9b636e3745aa..ab627d73bee28 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -168,7 +168,12 @@ object functions { data: Column, messageClassName: String, options: java.util.Map[String, String]): Column = { - from_protobuf(data, messageClassName, "", options) + Column.internalFnWithOptions( + "to_protobuf", + options.asScala.iterator, + data, + lit(messageClassName) + ) } /** @@ -309,6 +314,11 @@ object functions { @Experimental def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) : Column = { - to_protobuf(data, messageClassName, "", options) + Column.internalFnWithOptions( + "to_protobuf", + options.asScala.iterator, + data, + lit(messageClassName) + ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala index 5c6bd3ff64b95..046bab6c3aaf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala @@ -88,8 +88,19 @@ case class FromProtobuf( descFilePath: Expression, options: Expression) extends QuaternaryExpression with RuntimeReplaceable { - def this(data: Expression, messageName: Expression, descFilePath: Expression) = { - this(data, messageName, descFilePath, Literal(null)) + def this(data: Expression, messageName: Expression, descFilePathOrOptions: Expression) = { + this( + data, + messageName, + descFilePathOrOptions.dataType match { + case _: StringType | BinaryType => descFilePathOrOptions + case _ => Literal(null) + }, + descFilePathOrOptions.dataType match { + case _: MapType => descFilePathOrOptions + case _ => Literal(null) + } + ) } def this(data: Expression, messageName: Expression) = { @@ -210,8 +221,19 @@ case class ToProtobuf( descFilePath: Expression, options: Expression) extends QuaternaryExpression with RuntimeReplaceable { - def this(data: Expression, messageName: Expression, descFilePath: Expression) = { - this(data, messageName, descFilePath, Literal(null)) + def this(data: Expression, messageName: Expression, descFilePathOrOptions: Expression) = { + this( + data, + messageName, + descFilePathOrOptions.dataType match { + case _: StringType | BinaryType => descFilePathOrOptions + case _ => Literal(null) + }, + descFilePathOrOptions.dataType match { + case _: MapType => descFilePathOrOptions + case _ => Literal(null) + } + ) } def this(data: Expression, messageName: Expression) = { diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain index 6eb4805b4fcc4..7e100d577ef74 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain @@ -1,2 +1,2 @@ -Project [from_protobuf(bytes#0, StorageLevel, Some([B)) AS from_protobuf(bytes)#0] +Project [from_protobuf(bytes#0, StorageLevel, Some([B)) AS fromprotobuf(bytes, StorageLevel, X'0AFC010A0C636F6D6D6F6E2E70726F746F120D737061726B2E636F6E6E65637422B0010A0C53746F726167654C6576656C12190A087573655F6469736B18012001280852077573654469736B121D0A0A7573655F6D656D6F727918022001280852097573654D656D6F727912200A0C7573655F6F66665F68656170180320012808520A7573654F66664865617012220A0C646573657269616C697A6564180420012808520C646573657269616C697A656412200A0B7265706C69636174696F6E180520012805520B7265706C69636174696F6E42220A1E6F72672E6170616368652E737061726B2E636F6E6E6563742E70726F746F5001620670726F746F33', NULL)#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain index c4a47b1aef07b..08644da4595e6 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain @@ -1,2 +1,2 @@ -Project [from_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS from_protobuf(bytes)#0] +Project [from_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS fromprotobuf(bytes, StorageLevel, X'0AFC010A0C636F6D6D6F6E2E70726F746F120D737061726B2E636F6E6E65637422B0010A0C53746F726167654C6576656C12190A087573655F6469736B18012001280852077573654469736B121D0A0A7573655F6D656D6F727918022001280852097573654D656D6F727912200A0C7573655F6F66665F68656170180320012808520A7573654F66664865617012220A0C646573657269616C697A6564180420012808520C646573657269616C697A656412200A0B7265706C69636174696F6E180520012805520B7265706C69636174696F6E42220A1E6F72672E6170616368652E737061726B2E636F6E6E6563742E70726F746F5001620670726F746F33', map(recursive.fields.max.depth, 2))#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName.explain index e7f70fa2c1a9e..6e928917d0f95 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName.explain @@ -1,2 +1,2 @@ -Project [to_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None) AS to_protobuf(bytes)#0] +Project [to_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None) AS toprotobuf(bytes, org.apache.spark.connect.proto.StorageLevel, NULL, NULL)#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain index 7c688cc446947..c54e8da223074 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain @@ -1,2 +1,2 @@ -Project [to_protobuf(bytes#0, StorageLevel, Some([B)) AS to_protobuf(bytes)#0] +Project [to_protobuf(bytes#0, StorageLevel, Some([B)) AS toprotobuf(bytes, StorageLevel, X'0AFC010A0C636F6D6D6F6E2E70726F746F120D737061726B2E636F6E6E65637422B0010A0C53746F726167654C6576656C12190A087573655F6469736B18012001280852077573654469736B121D0A0A7573655F6D656D6F727918022001280852097573654D656D6F727912200A0C7573655F6F66665F68656170180320012808520A7573654F66664865617012220A0C646573657269616C697A6564180420012808520C646573657269616C697A656412200A0B7265706C69636174696F6E180520012805520B7265706C69636174696F6E42220A1E6F72672E6170616368652E737061726B2E636F6E6E6563742E70726F746F5001620670726F746F33', NULL)#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain index 9f05bb03c9c6d..301562203955f 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain @@ -1,2 +1,2 @@ -Project [to_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS to_protobuf(bytes)#0] +Project [to_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS toprotobuf(bytes, StorageLevel, X'0AFC010A0C636F6D6D6F6E2E70726F746F120D737061726B2E636F6E6E65637422B0010A0C53746F726167654C6576656C12190A087573655F6469736B18012001280852077573654469736B121D0A0A7573655F6D656D6F727918022001280852097573654D656D6F727912200A0C7573655F6F66665F68656170180320012808520A7573654F66664865617012220A0C646573657269616C697A6564180420012808520C646573657269616C697A656412200A0B7265706C69636174696F6E180520012805520B7265706C69636174696F6E42220A1E6F72672E6170616368652E737061726B2E636F6E6E6563742E70726F746F5001620670726F746F33', map(recursive.fields.max.depth, 2))#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_options.explain index a5d8851a7d1f3..27c35ed7d1da0 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_options.explain @@ -1,2 +1,2 @@ -Project [to_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None, (recursive.fields.max.depth,2)) AS to_protobuf(bytes)#0] +Project [to_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None, (recursive.fields.max.depth,2)) AS toprotobuf(bytes, org.apache.spark.connect.proto.StorageLevel, NULL, map(recursive.fields.max.depth, 2))#0] +- LocalRelation , [id#0L, bytes#0]