Skip to content

Commit

Permalink
to_protobuf support
Browse files Browse the repository at this point in the history
  • Loading branch information
itholic committed Aug 27, 2024
1 parent bff0d81 commit 4914670
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ object InternalFunctionRegistration {
Some(descriptor).asInstanceOf[Option[Array[Byte]]],
options.asInstanceOf[Map[String, String]])
}

registerFunction("to_protobuf") {
case Seq(input, StringLiteral(messageName), descriptor, options) =>
CatalystDataToProtobuf(
input,
messageName,
Some(descriptor).asInstanceOf[Option[Array[Byte]]],
options.asInstanceOf[Map[String, String]])
}
}

class InternalFunctionRegistration extends SparkSessionExtensionsProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.protobuf.utils.ProtobufUtils

// scalastyle:off: object.name
Expand Down Expand Up @@ -207,7 +206,12 @@ object functions {
@Experimental
def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet))
Column.internalFn(
"to_protobuf",
data,
lit(messageName),
lit(binaryFileDescriptorSet)
)
}
/**
* Converts a column into binary of protobuf format. The Protobuf definition is provided
Expand All @@ -229,7 +233,7 @@ object functions {
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
CatalystDataToProtobuf(data, messageName, Some(fileContent), options.asScala.toMap)
to_protobuf(data, messageName, fileContent, options)
}

/**
Expand All @@ -253,7 +257,13 @@ object functions {
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]
): Column = {
CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap)
Column.internalFnWithOptions(
"to_protobuf",
options.asScala.iterator,
data,
lit(messageName),
lit(binaryFileDescriptorSet)
)
}

/**
Expand All @@ -273,7 +283,11 @@ object functions {
*/
@Experimental
def to_protobuf(data: Column, messageClassName: String): Column = {
CatalystDataToProtobuf(data, messageClassName)
Column.internalFn(
"to_protobuf",
data,
lit(messageClassName)
)
}

/**
Expand All @@ -295,6 +309,6 @@ object functions {
@Experimental
def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String])
: Column = {
CatalystDataToProtobuf(data, messageClassName, None, options.asScala.toMap)
to_protobuf(data, messageClassName, "", options)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ case class FromProtobuf(
"representing the Protobuf message name"))
}
val descFilePathCheck = descFilePath.dataType match {
case _: StringType | BinaryType if descFilePath.foldable => None
case _: StringType | BinaryType | NullType if descFilePath.foldable => None
case _ =>
Some(TypeCheckResult.TypeCheckFailure(
"The third argument of the FROM_PROTOBUF SQL function must be a constant string " +
Expand Down Expand Up @@ -209,6 +209,15 @@ case class ToProtobuf(
messageName: Expression,
descFilePath: Expression,
options: Expression) extends QuaternaryExpression with RuntimeReplaceable {

def this(data: Expression, messageName: Expression, descFilePath: Expression) = {
this(data, messageName, descFilePath, Literal(null))
}

def this(data: Expression, messageName: Expression) = {
this(data, messageName, Literal(null), Literal(null))
}

override def first: Expression = data
override def second: Expression = messageName
override def third: Expression = descFilePath
Expand All @@ -231,7 +240,7 @@ case class ToProtobuf(
"representing the Protobuf message name"))
}
val descFilePathCheck = descFilePath.dataType match {
case _: StringType | BinaryType if descFilePath.foldable => None
case _: StringType | BinaryType | NullType if descFilePath.foldable => None
case _ =>
Some(TypeCheckResult.TypeCheckFailure(
"The third argument of the TO_PROTOBUF SQL function must be a constant string " +
Expand Down

0 comments on commit 4914670

Please sign in to comment.