Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add size based assertions #265

Merged
merged 2 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions chex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from chex._src.asserts import assert_devices_available
from chex._src.asserts import assert_equal
from chex._src.asserts import assert_equal_rank
from chex._src.asserts import assert_equal_size
from chex._src.asserts import assert_equal_shape
from chex._src.asserts import assert_equal_shape_prefix
from chex._src.asserts import assert_equal_shape_suffix
Expand All @@ -39,6 +40,7 @@
from chex._src.asserts import assert_scalar_negative
from chex._src.asserts import assert_scalar_non_negative
from chex._src.asserts import assert_scalar_positive
from chex._src.asserts import assert_size
from chex._src.asserts import assert_shape
from chex._src.asserts import assert_tpu_available
from chex._src.asserts import assert_tree_all_close # Deprecated
Expand All @@ -58,6 +60,7 @@
from chex._src.asserts import assert_trees_all_equal
from chex._src.asserts import assert_trees_all_equal_comparator
from chex._src.asserts import assert_trees_all_equal_dtypes
from chex._src.asserts import assert_trees_all_equal_sizes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing this one in __all__.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot, this is sorted in the latest commit.

from chex._src.asserts import assert_trees_all_equal_shapes
from chex._src.asserts import assert_trees_all_equal_shapes_and_dtypes
from chex._src.asserts import assert_trees_all_equal_structs
Expand Down Expand Up @@ -124,6 +127,7 @@
"assert_devices_available",
"assert_equal",
"assert_equal_rank",
"assert_equal_size",
"assert_equal_shape",
"assert_equal_shape_prefix",
"assert_equal_shape_suffix",
Expand All @@ -140,6 +144,7 @@
"assert_scalar_negative",
"assert_scalar_non_negative",
"assert_scalar_positive",
"assert_size",
"assert_shape",
"assert_tpu_available",
"assert_tree_all_close", # Deprecated
Expand All @@ -159,6 +164,7 @@
"assert_trees_all_equal",
"assert_trees_all_equal_comparator",
"assert_trees_all_equal_dtypes",
"assert_trees_all_equal_sizes",
"assert_trees_all_equal_shapes",
"assert_trees_all_equal_shapes_and_dtypes",
"assert_trees_all_equal_structs",
Expand Down
101 changes: 101 additions & 0 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,87 @@ def assert_scalar_negative(x: Scalar) -> None:
raise AssertionError(f"The argument must be negative, was {x}.")


@_static_assertion
def assert_equal_size(inputs: Sequence[Array]) -> None:
"""Checks that all arrays have the same size.

Args:
inputs: A collection of arrays.

Raises:
AssertionError: If the size of all arrays do not match.
"""
_ai.assert_collection_of_arrays(inputs)
size = inputs[0].size
expected_sizes = [size] * len(inputs)
sizes = [x.size for x in inputs]
if sizes != expected_sizes:
raise AssertionError(f"Arrays have different sizes: {sizes}")


@_static_assertion
def assert_size(
inputs: Union[Scalar, Union[Array, Sequence[Array]]],
expected_sizes: Union[_ai.TShapeMatcher,
Sequence[_ai.TShapeMatcher]]) -> None:
"""Checks that the size of all inputs matches specified ``expected_sizes``.

Valid usages include:

.. code-block:: python

assert_size(x, 1) # x is scalar (size 1)
assert_size([x, y], (2, {1, 3})) # x has size 2, y has size 1 or 3
assert_size([x, y], (2, ...)) # x has size 2, y has any size
assert_size([x, y], 1) # x and y are scalar (size 1)
assert_size((x, y), (5, 2)) # x has size 5, y has size 2

Args:
inputs: An array or a sequence of arrays.
expected_sizes: A sqeuence of expected sizes associated with each input,
where the expected size is a sequence of integer and `None` dimensions;
if all inputs have same size, a single size may be passed as
``expected_sizes``.

Raises:
AssertionError: If the lengths of ``inputs`` and ``expected_sizes`` do not
match; if ``expected_sizes`` has wrong type; if size of ``input`` does
not match ``expected_sizes``.
"""
# Ensure inputs and expected sizes are sequences.
if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]

if isinstance(expected_sizes, int):
expected_sizes = [expected_sizes] * len(inputs)

if not isinstance(expected_sizes, (list, tuple)):
raise AssertionError(
"Error in size compatibility check: expected sizes should be an int, "
f"list, or tuple of ints, got {expected_sizes}.")

if len(inputs) != len(expected_sizes):
raise AssertionError(
"Length of `inputs` and `expected_sizes` must match: "
f"{len(inputs)} is not equal to {len(expected_sizes)}.")

errors = []
for idx, (x, expected) in enumerate(zip(inputs, expected_sizes)):
size = getattr(x, "size", 1) # scalars have size 1 by definition.
# Allow any size for the ellipsis case and allow handling of integer
# expected sizes or collection of acceptable expected sizes.
int_condition = expected in {Ellipsis, None} or size == expected
set_condition = (isinstance(expected, collections.abc.Collection) and
size in expected)
if not (int_condition or set_condition):
errors.append((idx, size, expected))

if errors:
msg = "; ".join(
f"input {e[0]} has size {e[1]} but expected {e[2]}" for e in errors)
raise AssertionError(f"Error in size compatibility check: {msg}.")


@_static_assertion
def assert_equal_shape(
inputs: Sequence[Array],
Expand Down Expand Up @@ -1405,6 +1486,26 @@ def err_msg_fn(arr_1, arr_2):
cmp_fn, err_msg_fn, *trees, ignore_nones=ignore_nones)


@_static_assertion
def assert_trees_all_equal_sizes(*trees: ArrayTree,
ignore_nones: bool = False) -> None:
"""Checks that trees have the same structure and leaves' sizes.

Args:
*trees: A sequence of (at least 2) trees with array leaves.
ignore_nones: Whether to ignore `None` in the trees.

Raises:
AssertionError: If trees' structures or leaves' sizes are different;
if the trees contain `None` (with ``ignore_nones=False``).
"""
cmp_fn = lambda arr_1, arr_2: arr_1.size == arr_2.size
err_msg_fn = lambda arr_1, arr_2: f"sizes: {arr_1.size} != {arr_2.size}"
assert_trees_all_equal_comparator(
cmp_fn, err_msg_fn, *trees, ignore_nones=ignore_nones
)


@_static_assertion
def assert_trees_all_equal_shapes(*trees: ArrayTree,
ignore_nones: bool = False) -> None:
Expand Down
98 changes: 98 additions & 0 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,80 @@ def test_scalar_in_excluded(self):
asserts.assert_scalar_in(1, 0, 1, included=False)


class EqualSizeAssertTest(parameterized.TestCase):
@parameterized.named_parameters(
('scalar_vector_matrix', [1, 2, [3], [[4, 5]]]),
('vector_matrix', [[1], [2], [[3, 5]]]),
('matrix', [[[1, 2]], [[3], [4], [5]]]),
)
def test_equal_size_should_fail(self, arrays):
arrays = as_arrays(arrays)
with self.assertRaisesRegex(AssertionError,
_get_err_regex('Arrays have different sizes')):
asserts.assert_equal_size(arrays)

@parameterized.named_parameters(
('scalar_vector_matrix', [1, 2, [3], [[4]]]),
('vector_matrix', [[1], [2], [[3]]]),
('matrix', [[[1, 2]], [[3], [4]]]),
)
def test_equal_size_should_pass(self, arrays):
arrays = as_arrays(arrays)
asserts.assert_equal_size(arrays)


class SizeAssertTest(parameterized.TestCase):

@parameterized.named_parameters(
('wrong_size', [1, 2], 2),
('some_wrong_size', [[1, 2], [3, 4]], (2, 3)),
('wrong_common_shape', [[1, 2], [3, 4, 3]], 3),
('wrong_common_shape_2', [[1, 2, 3], [1, 2]], 2),
('some_wrong_size_set', [[1, 2], [3, 4]], (2, {3, 4})),
)
def test_size_should_fail(self, arrays, sizes):
arrays = as_arrays(arrays)
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('input .+ has size .+ but expected .+')):
asserts.assert_size(arrays, sizes)

@parameterized.named_parameters(
('too_many_sizes', [[1]], (1, 1)),
('not_enough_sizes', [[1, 2], [3, 4], [5, 6]], (2, 2)),
)
def test_size_should_fail_wrong_length(self, arrays, sizes):
arrays = as_arrays(arrays)
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('Length of `inputs` and `expected_sizes` must match')):
asserts.assert_size(arrays, sizes)

@parameterized.named_parameters(
('scalars', [1, 2], 1),
('vectors', [[1, 2], [3, 4, 5]], [2, 3]),
('matrices', [[[1, 2], [3, 4]]], 4),
('common_size_set', [[[1, 2], [3, 4]], [[1], [3]]], (4, {1, 2})),
)
def test_size_should_pass(self, arrays, sizes):
arrays = as_arrays(arrays)
asserts.assert_size(arrays, sizes)

def test_pytypes_pass(self):
arrays = as_arrays([[[1, 2], [3, 4]], [[1], [3]]])
asserts.assert_size(arrays, (4, None))
asserts.assert_size(arrays, (4, {1, 2}))
asserts.assert_size(arrays, (4, ...))

@parameterized.named_parameters(
('single_ellipsis', [[1,2,3,4], [1,2]],(..., 2) ),
('multiple_ellipsis', [[1,2,3], [1,2,3]], (..., ...)),
)
def test_ellipsis_should_pass(self, arrays, expected_size):
arrays = as_arrays(arrays)
asserts.assert_size(arrays, expected_size)


class EqualShapeAssertTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -1299,6 +1373,30 @@ def test_assert_trees_all_close_fails_values_differ(self):
AssertionError, _get_err_regex('Values not approximately equal')):
asserts.assert_trees_all_close(tree1, tree2, rtol=0.01)

def test_assert_trees_all_equal_sizes(self):
get_val = lambda s1, s2: jnp.zeros([s1, s2])
tree1 = dict(a1=get_val(3,1), d=dict(a2=get_val(4,1), a3=get_val(5,3)))
tree2 = dict(a1=get_val(3,1), d=dict(a2=get_val(4,1), a3=get_val(5,3)))
tree3 = dict(a1=get_val(3,1), d=dict(a2=get_val(4,2), a3=get_val(5,3)))

self._assert_tree_structs_validation(asserts.assert_trees_all_equal_sizes)
asserts.assert_trees_all_equal_sizes(tree1, tree1)
asserts.assert_trees_all_equal_sizes(tree2, tree1)

with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
r'Trees 0 and 1 differ in leaves \'d/a2\': sizes: 4 != 8'
)):
asserts.assert_trees_all_equal_sizes(tree1, tree3)

with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
r'Trees 0 and 3 differ in leaves \'d/a2\': sizes: 4 != 8'
)):
asserts.assert_trees_all_equal_sizes(tree1, tree2, tree2, tree3, tree1)

def test_assert_trees_all_equal_shapes(self):
get_val = lambda s: jnp.zeros([s])
tree1 = dict(a1=get_val(3), d=dict(a2=get_val(4), a3=get_val(5)))
Expand Down
6 changes: 6 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Assertions
assert_devices_available
assert_equal
assert_equal_rank
assert_equal_size
assert_equal_shape
assert_equal_shape_prefix
assert_equal_shape_suffix
Expand All @@ -30,6 +31,7 @@ Assertions
assert_scalar_negative
assert_scalar_non_negative
assert_scalar_positive
assert_size
assert_shape
assert_tpu_available
assert_tree_all_finite
Expand All @@ -45,6 +47,7 @@ Assertions
assert_trees_all_equal
assert_trees_all_equal_comparator
assert_trees_all_equal_dtypes
assert_trees_all_equal_sizes
assert_trees_all_equal_shapes
assert_trees_all_equal_shapes_and_dtypes
assert_trees_all_equal_structs
Expand Down Expand Up @@ -94,6 +97,7 @@ Tree Assertions
.. autofunction:: assert_trees_all_equal
.. autofunction:: assert_trees_all_equal_comparator
.. autofunction:: assert_trees_all_equal_dtypes
.. autofunction:: assert_trees_all_equal_sizes
.. autofunction:: assert_trees_all_equal_shapes
.. autofunction:: assert_trees_all_equal_shapes_and_dtypes
.. autofunction:: assert_trees_all_equal_structs
Expand All @@ -110,6 +114,7 @@ Generic Assertions
.. autofunction:: assert_axis_dimension_lteq
.. autofunction:: assert_equal
.. autofunction:: assert_equal_rank
.. autofunction:: assert_equal_size
.. autofunction:: assert_equal_shape
.. autofunction:: assert_equal_shape_prefix
.. autofunction:: assert_equal_shape_suffix
Expand All @@ -124,6 +129,7 @@ Generic Assertions
.. autofunction:: assert_scalar_negative
.. autofunction:: assert_scalar_non_negative
.. autofunction:: assert_scalar_positive
.. autofunction:: assert_size
.. autofunction:: assert_shape
.. autofunction:: assert_type

Expand Down