Skip to content

Commit

Permalink
[SPARK-8238][SPARK-8239][SPARK-8242][SPARK-8243][SPARK-8268][SQL]Add …
Browse files Browse the repository at this point in the history
…ascii/base64/unbase64/encode/decode functions

Add `ascii`,`base64`,`unbase64`,`encode` and `decode` expressions.

Author: Cheng Hao <hao.cheng@intel.com>

Closes apache#6843 from chenghao-intel/str_funcs2 and squashes the following commits:

78dee7d [Cheng Hao] base 64 -> base64
9d6f9f4 [Cheng Hao] remove the toString method for expressions
ed5c19c [Cheng Hao] update code as comments
96170fc [Cheng Hao] scalastyle issues
e2df768 [Cheng Hao] remove the unused import
491ce7b [Cheng Hao] add ascii/base64/unbase64/encode/decode functions
  • Loading branch information
chenghao-intel authored and rxin committed Jul 4, 2015
1 parent f32487b commit f35b0c3
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,16 @@ object FunctionRegistry {
expression[Sum]("sum"),

// string functions
expression[Ascii]("ascii"),
expression[Base64]("base64"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
expression[UnHex]("unhex"),
expression[Upper]("upper"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,120 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI

override def prettyName: String = "length"
}

/**
* Returns the numeric value of the first character of str.
*/
case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)

override def eval(input: InternalRow): Any = {
val string = child.eval(input)
if (string == null) {
null
} else {
val bytes = string.asInstanceOf[UTF8String].getBytes
if (bytes.length > 0) {
bytes(0).asInstanceOf[Int]
} else {
0
}
}
}
}

/**
* Converts the argument from binary to a base 64 string.
*/
case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType)

override def eval(input: InternalRow): Any = {
val bytes = child.eval(input)
if (bytes == null) {
null
} else {
UTF8String.fromBytes(
org.apache.commons.codec.binary.Base64.encodeBase64(
bytes.asInstanceOf[Array[Byte]]))
}
}
}

/**
* Converts the argument from a base 64 string to BINARY.
*/
case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType)

override def eval(input: InternalRow): Any = {
val string = child.eval(input)
if (string == null) {
null
} else {
org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString)
}
}
}

/**
* Decodes the first argument into a String using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null. (As of Hive 0.12.0.).
*/
case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes {
override def children: Seq[Expression] = bin :: charset :: Nil
override def foldable: Boolean = bin.foldable && charset.foldable
override def nullable: Boolean = bin.nullable || charset.nullable
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)

override def eval(input: InternalRow): Any = {
val l = bin.eval(input)
if (l == null) {
null
} else {
val r = charset.eval(input)
if (r == null) {
null
} else {
val fromCharset = r.asInstanceOf[UTF8String].toString
UTF8String.fromString(new String(l.asInstanceOf[Array[Byte]], fromCharset))
}
}
}
}

/**
* Encodes the first argument into a BINARY using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null. (As of Hive 0.12.0.)
*/
case class Encode(value: Expression, charset: Expression)
extends Expression with ExpectsInputTypes {
override def children: Seq[Expression] = value :: charset :: Nil
override def foldable: Boolean = value.foldable && charset.foldable
override def nullable: Boolean = value.nullable || charset.nullable
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

override def eval(input: InternalRow): Any = {
val l = value.eval(input)
if (l == null) {
null
} else {
val r = charset.eval(input)
if (r == null) {
null
} else {
val toCharset = r.asInstanceOf[UTF8String].toString
l.asInstanceOf[UTF8String].toString.getBytes(toCharset)
}
}
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}


class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -217,11 +217,61 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("length for string") {
val regEx = 'a.string.at(0)
val a = 'a.string.at(0)
checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
checkEvaluation(StringLength(regEx), 5, create_row("abdef"))
checkEvaluation(StringLength(regEx), 0, create_row(""))
checkEvaluation(StringLength(regEx), null, create_row(null))
checkEvaluation(StringLength(a), 5, create_row("abdef"))
checkEvaluation(StringLength(a), 0, create_row(""))
checkEvaluation(StringLength(a), null, create_row(null))
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("ascii for string") {
val a = 'a.string.at(0)
checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
checkEvaluation(Ascii(a), 97, create_row("abdef"))
checkEvaluation(Ascii(a), 0, create_row(""))
checkEvaluation(Ascii(a), null, create_row(null))
checkEvaluation(Ascii(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("base64/unbase64 for string") {
val a = 'a.string.at(0)
val b = 'b.binary.at(0)
val bytes = Array[Byte](1, 2, 3, 4)

checkEvaluation(Base64(Literal(bytes)), "AQIDBA==", create_row("abdef"))
checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", create_row("abdef"))
checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef"))
checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, create_row("abdef"))
checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA=="))

checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes))
checkEvaluation(Base64(b), "", create_row(Array[Byte]()))
checkEvaluation(Base64(b), null, create_row(null))
checkEvaluation(Base64(Literal.create(null, StringType)), null, create_row("abdef"))

checkEvaluation(UnBase64(a), null, create_row(null))
checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("encode/decode for string") {
val a = 'a.string.at(0)
val b = 'b.binary.at(0)
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
checkEvaluation(
Decode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界")
checkEvaluation(
Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界"))
checkEvaluation(
Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "", create_row(""))
// scalastyle:on
checkEvaluation(Encode(a, Literal("utf-8")), null, create_row(null))
checkEvaluation(Encode(Literal.create(null, StringType), Literal("utf-8")), null)
checkEvaluation(Encode(a, Literal.create(null, StringType)), null, create_row(""))

checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null))
checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null)
checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
}
}
93 changes: 93 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1581,18 +1581,111 @@ object functions {

/**
* Computes the length of a given string value
*
* @group string_funcs
* @since 1.5.0
*/
def strlen(e: Column): Column = StringLength(e.expr)

/**
* Computes the length of a given string column
*
* @group string_funcs
* @since 1.5.0
*/
def strlen(columnName: String): Column = strlen(Column(columnName))

/**
* Computes the numeric value of the first character of the specified string value.
*
* @group string_funcs
* @since 1.5.0
*/
def ascii(e: Column): Column = Ascii(e.expr)

/**
* Computes the numeric value of the first character of the specified string column.
*
* @group string_funcs
* @since 1.5.0
*/
def ascii(columnName: String): Column = ascii(Column(columnName))

/**
* Computes the specified value from binary to a base64 string.
*
* @group string_funcs
* @since 1.5.0
*/
def base64(e: Column): Column = Base64(e.expr)

/**
* Computes the specified column from binary to a base64 string.
*
* @group string_funcs
* @since 1.5.0
*/
def base64(columnName: String): Column = base64(Column(columnName))

/**
* Computes the specified value from a base64 string to binary.
*
* @group string_funcs
* @since 1.5.0
*/
def unbase64(e: Column): Column = UnBase64(e.expr)

/**
* Computes the specified column from a base64 string to binary.
*
* @group string_funcs
* @since 1.5.0
*/
def unbase64(columnName: String): Column = unbase64(Column(columnName))

/**
* Computes the first argument into a binary from a string using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
*/
def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr)

/**
* Computes the first argument into a binary from a string using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
*/
def encode(columnName: String, charsetColumnName: String): Column =
encode(Column(columnName), Column(charsetColumnName))

/**
* Computes the first argument into a string from a binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
*/
def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr)

/**
* Computes the first argument into a string from a binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
*/
def decode(columnName: String, charsetColumnName: String): Column =
decode(Column(columnName), Column(charsetColumnName))


//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,42 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(l)
})
}

test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(
df.select(ascii($"a"), ascii("b")),
Row(97, 0))

checkAnswer(
df.selectExpr("ascii(a)", "ascii(b)"),
Row(97, 0))
}

test("string base64/unbase64 function") {
val bytes = Array[Byte](1, 2, 3, 4)
val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
checkAnswer(
df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
Row("AQIDBA==", "AQIDBA==", bytes, bytes))

checkAnswer(
df.selectExpr("base64(a)", "unbase64(b)"),
Row("AQIDBA==", bytes))
}

test("string encode/decode function") {
val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
checkAnswer(
df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")),
Row(bytes, bytes, "大千世界", "大千世界"))

checkAnswer(
df.selectExpr("encode(a, b)", "decode(c, b)"),
Row(bytes, "大千世界"))
// scalastyle:on
}
}

0 comments on commit f35b0c3

Please sign in to comment.