Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flattening functionality to chex.Dimensions. #341

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions chex/_src/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ class Dimensions:
>>> dims.size('BT') # Same as prod(dims['BT']).
15

Similarly, you can flatten axes together by wrapping them in parentheses:

.. code::

>>> dims['(BT)N']
(15, 7)

You can set a wildcard dimension, cf. :func:`chex.assert_shape`:

.. code::
Expand Down Expand Up @@ -118,7 +125,6 @@ class Dimensions:

>>> dims['M']
(7,)

"""
# Tell static type checker not to worry about attribute errors.
_HAS_DYNAMIC_ATTRIBUTES = True
Expand All @@ -129,15 +135,48 @@ def __init__(self, **dim_sizes) -> None:

def size(self, key: str) -> int:
"""Returns the flat size of a given named shape, i.e. prod(shape)."""
if None in (shape := self[key]):
shape = self[key]
if any(size is None or size <= 0 for size in shape):
raise ValueError(
f"cannot take product of shape '{key}' = {shape}, "
'because it contains wildcard dimensions')
'because it contains non-positive sized dimensions'
)
return math.prod(shape)

def __getitem__(self, key: str) -> Shape:
self._validate_key(key)
return tuple(self._getdim(dim) for dim in key)
shape = []
open_parentheses = False
dims_to_flatten = ''
for dim in key:
# Signal to start accumulating `dims_to_flatten`.
if dim == '(':
if open_parentheses:
raise ValueError(f"nested parentheses are unsupported; got: '{key}'")
open_parentheses = True

# Signal to collect accumulated `dims_to_flatten`.
elif dim == ')':
if not open_parentheses:
raise ValueError(f"unmatched parentheses in named shape: '{key}'")
if not dims_to_flatten:
raise ValueError(f"found empty parentheses in named shape: '{key}'")
shape.append(self.size(dims_to_flatten))
# Reset.
open_parentheses = False
dims_to_flatten = ''

# Accumulate `dims_to_flatten`.
elif open_parentheses:
dims_to_flatten += dim

# The typical (non-flattening) case.
else:
shape.append(self._getdim(dim))

if open_parentheses:
raise ValueError(f"unmatched parentheses in named shape: '{key}'")
return tuple(shape)

def __setitem__(self, key: str, value: Collection[Optional[int]]) -> None:
self._validate_key(key)
Expand Down
40 changes: 37 additions & 3 deletions chex/_src/dimensions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_get_wildcard(self):
self.assertEqual(dims['x*y**'], (23, None, 29, None, None))
asserts.assert_shape(np.empty((23, 1, 29, 2, 3)), dims['x*y**'])
with self.assertRaisesRegex(KeyError, r'\_'):
dims['xy_'] # pylint: disable=pointless-statement
_ = dims['xy_']

def test_get_literals(self):
dims = dimensions.Dimensions(x=23, y=29)
Expand All @@ -89,7 +89,7 @@ def test_set_exception(self, k, v, e, m):
def test_get_exception(self, k, e, m):
dims = dimensions.Dimensions(x=23, y=29)
with self.assertRaisesRegex(e, m):
dims[k] # pylint: disable=pointless-statement
_ = dims[k]

@parameterized.named_parameters([
('scalar', '', (), 1),
Expand All @@ -102,12 +102,46 @@ def test_size_ok(self, names, shape, expected_size):
@parameterized.named_parameters([
('named', 'ab'),
('asterisk', 'a*'),
('zero', 'a0'),
('negative', 'ac'),
])
def test_size_fail_wildcard(self, names):
dims = dimensions.Dimensions(a=3, b=None)
dims = dimensions.Dimensions(a=3, b=None, c=-1)
with self.assertRaisesRegex(ValueError, r'cannot take product of shape'):
dims.size(names)

@parameterized.named_parameters([
('trivial_start', '(a)bc', (3, 5, 7)),
('trivial_mid', 'a(b)c', (3, 5, 7)),
('trivial_end', 'ab(c)', (3, 5, 7)),
('start', '(ab)cd', (15, 7, 11)),
('mid', 'a(bc)d', (3, 35, 11)),
('end', 'ab(cd)', (3, 5, 77)),
('multiple', '(ab)(cd)', (15, 77)),
('all', '(abc)', (105,)),
])
def test_flatten_ok(self, named_shape, expected_shape):
dims = dimensions.Dimensions(a=3, b=5, c=7, d=11)
self.assertEqual(dims[named_shape], expected_shape)

@parameterized.named_parameters([
('unmatched_open', '(ab', r'unmatched parentheses in named shape'),
('unmatched_closed', 'a)b', r'unmatched parentheses in named shape'),
('nested', '(a(bc))', r'nested parentheses are unsupported'),
('wildcard_named', 'a(bx)', r'cannot take product of shape'),
('wildcard_asterisk', '(a*)b', r'cannot take product of shape'),
('zero_sized_dim', '(a0)b', r'cannot take product of shape'),
('neg_sized_dim', '(ay)b', r'cannot take product of shape'),
('empty_start', '()ab', r'found empty parentheses in named shape'),
('empty_mid', 'a()b', r'found empty parentheses in named shape'),
('empty_end', 'ab()', r'found empty parentheses in named shape'),
('empty_solo', '()', r'found empty parentheses in named shape'),
])
def test_flatten_fail(self, named_shape, error_message):
dims = dimensions.Dimensions(a=3, b=5, x=None, y=-1)
with self.assertRaisesRegex(ValueError, error_message):
_ = dims[named_shape]


if __name__ == '__main__':
absltest.main()
Loading