Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bigquery): add __eq__ method for class PartitionRange and RangePartitioning #162

Merged
merged 4 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions google/cloud/bigquery/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,10 +1891,20 @@ def interval(self, value):
def _key(self):
return tuple(sorted(self._properties.items()))

def __eq__(self, other):
if not isinstance(other, PartitionRange):
return NotImplemented
return self._key() == other._key()

def __ne__(self, other):
return not self == other

def __repr__(self):
key_vals = ["{}={}".format(key, val) for key, val in self._key()]
return "PartitionRange({})".format(", ".join(key_vals))

__hash__ = None


class RangePartitioning(object):
"""Range-based partitioning configuration for a table.
Expand Down Expand Up @@ -1961,10 +1971,20 @@ def field(self, value):
def _key(self):
return (("field", self.field), ("range_", self.range_))

def __eq__(self, other):
if not isinstance(other, RangePartitioning):
return NotImplemented
return self._key() == other._key()

def __ne__(self, other):
return not self == other

def __repr__(self):
key_vals = ["{}={}".format(key, repr(val)) for key, val in self._key()]
return "RangePartitioning({})".format(", ".join(key_vals))

__hash__ = None


class TimePartitioningType(object):
"""Specifies the type of time partitioning to perform."""
Expand Down
82 changes: 82 additions & 0 deletions tests/unit/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3525,6 +3525,37 @@ def test_constructor_w_resource(self):
assert object_under_test.end == 1234567890
assert object_under_test.interval == 1000000

def test___eq___start_mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=2, end=10, interval=2)
self.assertNotEqual(object_under_test, other)

def test___eq___end__mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=11, interval=2)
self.assertNotEqual(object_under_test, other)

def test___eq___interval__mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=11, interval=3)
self.assertNotEqual(object_under_test, other)

def test___eq___hit(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=10, interval=2)
self.assertEqual(object_under_test, other)

def test__eq___type_mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
self.assertNotEqual(object_under_test, object())
self.assertEqual(object_under_test, mock.ANY)

def test_unhashable_object(self):
object_under_test1 = self._make_one(start=1, end=10, interval=2)

with six.assertRaisesRegex(self, TypeError, r".*unhashable type.*"):
hash(object_under_test1)

def test_repr(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
assert repr(object_under_test) == "PartitionRange(end=10, interval=2, start=1)"
Expand Down Expand Up @@ -3574,6 +3605,57 @@ def test_range_w_wrong_type(self):
with pytest.raises(ValueError, match="PartitionRange"):
object_under_test.range_ = object()

def test___eq___field_mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="float_col"
)
self.assertNotEqual(object_under_test, other)

def test___eq___range__mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=2, end=20, interval=2), field="float_col"
)
self.assertNotEqual(object_under_test, other)

def test___eq___hit(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
self.assertEqual(object_under_test, other)

def test__eq___type_mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
self.assertNotEqual(object_under_test, object())
self.assertEqual(object_under_test, mock.ANY)

def test_unhashable_object(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test1 = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
with six.assertRaisesRegex(self, TypeError, r".*unhashable type.*"):
hash(object_under_test1)

def test_repr(self):
from google.cloud.bigquery.table import PartitionRange

Expand Down