Skip to content

Commit

Permalink
Add Dimensions.size to get flat size from named shape.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615765097
  • Loading branch information
KristianHolsheimer authored and ChexDev committed Mar 15, 2024
1 parent a564566 commit 70f25b6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
16 changes: 16 additions & 0 deletions chex/_src/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""Utilities to hold expected dimension sizes."""

import math
import re
from typing import Any, Collection, Dict, Optional, Sized, Tuple

Expand Down Expand Up @@ -60,6 +61,13 @@ class Dimensions:
>>> dims
Dimensions(B=3, N=7, T=5, X=2, Y=4)
You can access the flat size of a shape as
.. code::
>>> dims.size('BT') # Same as prod(dims['BT']).
15
You can set a wildcard dimension, cf. :func:`chex.assert_shape`:
.. code::
Expand Down Expand Up @@ -119,6 +127,14 @@ def __init__(self, **dim_sizes) -> None:
for dim, size in dim_sizes.items():
self._setdim(dim, size)

def size(self, key: str) -> int:
"""Returns the flat size of a given named shape, i.e. prod(shape)."""
if None in (shape := self[key]):
raise ValueError(
f"cannot take product of shape '{key}' = {shape}, "
'because it contains wildcard dimensions')
return math.prod(shape)

def __getitem__(self, key: str) -> Shape:
self._validate_key(key)
return tuple(self._getdim(dim) for dim in key)
Expand Down
17 changes: 17 additions & 0 deletions chex/_src/dimensions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,23 @@ def test_get_exception(self, k, e, m):
with self.assertRaisesRegex(e, m):
dims[k] # pylint: disable=pointless-statement

@parameterized.named_parameters([
('scalar', '', (), 1),
('nonscalar', 'ab', (3, 5), 15),
])
def test_size_ok(self, names, shape, expected_size):
dims = dimensions.Dimensions(**dict(zip(names, shape)))
self.assertEqual(dims.size(names), expected_size)

@parameterized.named_parameters([
('named', 'ab'),
('asterisk', 'a*'),
])
def test_size_fail_wildcard(self, names):
dims = dimensions.Dimensions(a=3, b=None)
with self.assertRaisesRegex(ValueError, r'cannot take product of shape'):
dims.size(names)


if __name__ == '__main__':
absltest.main()

0 comments on commit 70f25b6

Please sign in to comment.