diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 24f370543def4..cee804f5cc1f7 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1277,6 +1277,11 @@ def __init__(self, jc): __contains__ = _bin_op("contains") __getitem__ = _bin_op("getItem") + # bitwise operators + bitwiseOR = _bin_op("bitwiseOR") + bitwiseAND = _bin_op("bitwiseAND") + bitwiseXOR = _bin_op("bitwiseXOR") + def getItem(self, key): """An expression that gets an item at position `ordinal` out of a list, or gets an item by key out of a dict. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 692af868dd534..274c410a1ee9c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -104,6 +104,8 @@ def _(col1, col2): 'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + 'measured in radians.', + 'bitwiseNOT': 'Computes bitwise not.', + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', 'first': 'Aggregate function: returns the first value in a group.', diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b232f3a965526..45dfedce22add 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -645,6 +645,19 @@ def test_fillna(self): self.assertEqual(row.age, None) self.assertEqual(row.height, None) + def test_bitwise_operations(self): + from pyspark.sql import functions + row = Row(a=170, b=75) + df = self.sqlCtx.createDataFrame([row]) + result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict() + self.assertEqual(170 & 75, result['(a & b)']) + result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict() + self.assertEqual(170 | 75, result['(a | b)']) + result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict() + self.assertEqual(170 ^ 75, result['(a ^ b)']) + result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() + self.assertEqual(~75, result['~b']) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8eb632d3d600b..8bbe11b412214 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -698,6 +698,37 @@ class Column(protected[sql] val expr: Expression) extends Logging { println(expr.prettyString) } } + + /** + * Compute bitwise OR of this expression with another expression. + * {{{ + * df.select($"colA".bitwiseOR($"colB")) + * }}} + * + * @group expr_ops + */ + def bitwiseOR(other: Any): Column = BitwiseOr(expr, lit(other).expr) + + /** + * Compute bitwise AND of this expression with another expression. + * {{{ + * df.select($"colA".bitwiseAND($"colB")) + * }}} + * + * @group expr_ops + */ + def bitwiseAND(other: Any): Column = BitwiseAnd(expr, lit(other).expr) + + /** + * Compute bitwise XOR of this expression with another expression. + * {{{ + * df.select($"colA".bitwiseXOR($"colB")) + * }}} + * + * @group expr_ops + */ + def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 830b5017717b5..1728b0b8c910e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -438,6 +438,14 @@ object functions { */ def upper(e: Column): Column = Upper(e.expr) + + /** + * Computes bitwise NOT. + * + * @group normal_funcs + */ + def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 3c1ad656fc855..d96186c268720 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ - // TODO: Add test cases for bitwise operations. - test("collect on column produced by a binary operator") { val df = Seq((1, 2, 3)).toDF("a", "b", "c") checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) @@ -385,4 +383,35 @@ class ColumnExpressionSuite extends QueryTest { assert(row.getDouble(1) >= -4.0) } } + + test("bitwiseAND") { + checkAnswer( + testData2.select($"a".bitwiseAND(75)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) & 75))) + + checkAnswer( + testData2.select($"a".bitwiseAND($"b").bitwiseAND(22)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) & r.getInt(1) & 22))) + } + + test("bitwiseOR") { + checkAnswer( + testData2.select($"a".bitwiseOR(170)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) | 170))) + + checkAnswer( + testData2.select($"a".bitwiseOR($"b").bitwiseOR(42)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) | r.getInt(1) | 42))) + } + + test("bitwiseXOR") { + checkAnswer( + testData2.select($"a".bitwiseXOR(112)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ 112))) + + checkAnswer( + testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39))) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ca03713ef4658..b1e0faa310b68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ @@ -81,4 +82,10 @@ class DataFrameFunctionsSuite extends QueryTest { struct(col("a") * 2) } } + + test("bitwiseNOT") { + checkAnswer( + testData2.select(bitwiseNOT($"a")), + testData2.collect().toSeq.map(r => Row(~r.getInt(0)))) + } }