Skip to content

Commit

Permalink
implementation for bitwise and,or, not and xor on Column with tests a…
Browse files Browse the repository at this point in the history
…nd docs
  • Loading branch information
Shiti committed May 7, 2015
1 parent 2d6612c commit 71a9913
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 2 deletions.
5 changes: 5 additions & 0 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,11 @@ def __init__(self, jc):
__contains__ = _bin_op("contains")
__getitem__ = _bin_op("getItem")

# bitwise operators
bitwiseOR = _bin_op("bitwiseOR")
bitwiseAND = _bin_op("bitwiseAND")
bitwiseXOR = _bin_op("bitwiseXOR")

def getItem(self, key):
"""An expression that gets an item at position `ordinal` out of a list,
or gets an item by key out of a dict.
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def _(col1, col2):
'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
'measured in radians.',

'bitwiseNOT': 'Computes bitwise not.',

'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
'first': 'Aggregate function: returns the first value in a group.',
Expand Down
13 changes: 13 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,19 @@ def test_fillna(self):
self.assertEqual(row.age, None)
self.assertEqual(row.height, None)

def test_bitwise_operations(self):
from pyspark.sql import functions
row = Row(a=170, b=75)
df = self.sqlCtx.createDataFrame([row])
result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
self.assertEqual(170 & 75, result['(a & b)'])
result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
self.assertEqual(170 | 75, result['(a | b)'])
result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
self.assertEqual(170 ^ 75, result['(a ^ b)'])
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
self.assertEqual(~75, result['~b'])


class HiveContextSQLTests(ReusedPySparkTestCase):

Expand Down
31 changes: 31 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,37 @@ class Column(protected[sql] val expr: Expression) extends Logging {
println(expr.prettyString)
}
}

/**
* Compute bitwise OR of this expression with another expression.
* {{{
* df.select($"colA".bitwiseOR($"colB"))
* }}}
*
* @group expr_ops
*/
def bitwiseOR(other: Any): Column = BitwiseOr(expr, lit(other).expr)

/**
* Compute bitwise AND of this expression with another expression.
* {{{
* df.select($"colA".bitwiseAND($"colB"))
* }}}
*
* @group expr_ops
*/
def bitwiseAND(other: Any): Column = BitwiseAnd(expr, lit(other).expr)

/**
* Compute bitwise XOR of this expression with another expression.
* {{{
* df.select($"colA".bitwiseXOR($"colB"))
* }}}
*
* @group expr_ops
*/
def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr)

}


Expand Down
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,14 @@ object functions {
*/
def upper(e: Column): Column = Upper(e.expr)


/**
* Computes bitwise NOT.
*
* @group normal_funcs
*/
def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr)

//////////////////////////////////////////////////////////////////////////////////////////////
// Math Functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ import org.apache.spark.sql.types._
class ColumnExpressionSuite extends QueryTest {
import org.apache.spark.sql.TestData._

// TODO: Add test cases for bitwise operations.

test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
Expand Down Expand Up @@ -385,4 +383,35 @@ class ColumnExpressionSuite extends QueryTest {
assert(row.getDouble(1) >= -4.0)
}
}

test("bitwiseAND") {
checkAnswer(
testData2.select($"a".bitwiseAND(75)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) & 75)))

checkAnswer(
testData2.select($"a".bitwiseAND($"b").bitwiseAND(22)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) & r.getInt(1) & 22)))
}

test("bitwiseOR") {
checkAnswer(
testData2.select($"a".bitwiseOR(170)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) | 170)))

checkAnswer(
testData2.select($"a".bitwiseOR($"b").bitwiseOR(42)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) | r.getInt(1) | 42)))
}

test("bitwiseXOR") {
checkAnswer(
testData2.select($"a".bitwiseXOR(112)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ 112)))

checkAnswer(
testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -81,4 +82,10 @@ class DataFrameFunctionsSuite extends QueryTest {
struct(col("a") * 2)
}
}

test("bitwiseNOT") {
checkAnswer(
testData2.select(bitwiseNOT($"a")),
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
}
}

0 comments on commit 71a9913

Please sign in to comment.