diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 69423a2e1..519221509 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -32,11 +32,13 @@ Iterator, OrderedDict, Sequence, + SupportsInt, Union, ) import ml_dtypes import numpy as np +from typing_extensions import TypeIs import onnxscript from onnxscript.ir import ( @@ -859,12 +861,37 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self._value})" +def _is_int_compatible(value: object) -> TypeIs[SupportsInt]: + """Return True if the value is int compatible.""" + if isinstance(value, int): + return True + if hasattr(value, "__int__"): + # For performance reasons, we do not use isinstance(value, SupportsInt) + return True + return False + + +def _maybe_convert_to_symbolic_dim( + dim: int | SupportsInt | SymbolicDim | str | None, +) -> SymbolicDim | int: + """Convert the value to a SymbolicDim if it is not an int.""" + if dim is None or isinstance(dim, str): + return SymbolicDim(dim) + if _is_int_compatible(dim): + return int(dim) + if isinstance(dim, SymbolicDim): + return dim + raise TypeError( + f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'" + ) + + class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable): __slots__ = ("_dims", "_frozen") def __init__( self, - dims: Iterable[int | SymbolicDim | str | None], + dims: Iterable[int | SupportsInt | SymbolicDim | str | None], /, denotations: Iterable[str | None] | None = None, frozen: bool = False, @@ -885,8 +912,7 @@ def __init__( is useful when the shape is initialized by a Tensor. """ self._dims: list[int | SymbolicDim] = [ - SymbolicDim(dim) if not isinstance(dim, (int, SymbolicDim)) else dim - for dim in dims + _maybe_convert_to_symbolic_dim(dim) for dim in dims ] self._denotations: list[str | None] = ( list(denotations) if denotations is not None else [None] * len(self._dims) @@ -946,12 +972,8 @@ def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None """ if self._frozen: raise TypeError("The shape is frozen and cannot be modified.") - if isinstance(value, str) or value is None: - value = SymbolicDim(value) - if not isinstance(value, (int, SymbolicDim)): - raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'") - self._dims[index] = value + self._dims[index] = _maybe_convert_to_symbolic_dim(value) def get_denotation(self, index: int) -> str | None: """Return the denotation of the dimension at the index. @@ -986,7 +1008,7 @@ def __str__(self) -> str: def __eq__(self, other: object) -> bool: """Return True if the shapes are equal. - Two shapes are eqaul if all their dimensions are equal. + Two shapes are equal if all their dimensions are equal. """ if isinstance(other, Shape): return self._dims == other._dims @@ -997,6 +1019,33 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return not self.__eq__(other) + @typing.overload + def is_static(self, dim: int) -> bool: # noqa: D418 + """Return True if the dimension is static.""" + + @typing.overload + def is_static(self) -> bool: # noqa: D418 + """Return True if all dimensions are static.""" + + def is_static(self, dim=None) -> bool: + """Return True if the dimension is static. If dim is None, return True if all dimensions are static.""" + if dim is None: + return all(isinstance(dim, int) for dim in self._dims) + return isinstance(self[dim], int) + + @typing.overload + def is_dynamic(self, dim: int) -> bool: # noqa: D418 + """Return True if the dimension is dynamic.""" + + @typing.overload + def is_dynamic(self) -> bool: # noqa: D418 + """Return True if any dimension is dynamic.""" + + def is_dynamic(self, dim=None) -> bool: + if dim is None: + return not self.is_static() + return not self.is_static(dim) + def _quoted(string: str) -> str: """Return a quoted string. diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 498a8a3ce..8662a8c01 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -520,6 +520,30 @@ def test_int_dimensions_are_python_ints(self): shape = _core.Shape([42]) self.assertIsInstance(shape[0], int) + def test_str_dimensions_are_symbolic_dims(self): + shape = _core.Shape(["any string"]) + self.assertIsInstance(shape[0], _core.SymbolicDim) + + def test_none_dimensions_are_symbolic_dims(self): + shape = _core.Shape([None]) + self.assertIsInstance(shape[0], _core.SymbolicDim) + + def test_init_raises_when_dims_is_not_a_list(self): + with self.assertRaises(TypeError): + _core.Shape(42) + + def test_init_converts_np_shape_to_tuple(self): + dims = np.array([42, 42]) + shape = _core.Shape(dims) + self.assertEqual(shape.dims, tuple(dims)) + + def test_init_converts_np_int_to_python_int(self): + dims = [np.int32(42)] + shape = _core.Shape(dims) + self.assertIsInstance(shape[0], int) + self.assertNotIsInstance(shape[0], np.int32) + self.assertIsInstance(shape.dims[0], int) + @parameterized.parameterized.expand( [ ("empty", (), ()), @@ -623,6 +647,10 @@ def test_setitem(self, _: str, value): else: self.assertEqual(dim, value) + def test_len(self): + shape = _core.Shape([42, "any string"]) + self.assertEqual(len(shape), 2) + def test_get_denotation(self): shape = _core.Shape([42], denotations=("DATA_CHANNEL",)) self.assertEqual(shape.get_denotation(0), "DATA_CHANNEL") @@ -637,6 +665,56 @@ def test_set_denotation_is_still_possible_when_shape_is_frozen(self): shape.set_denotation(0, "UPDATED") self.assertEqual(shape.get_denotation(0), "UPDATED") + def test_is_static(self): + dim_from_numpy = np.array([42]).shape[0] + np_int = np.int32(42) + shape = _core.Shape([42, "any string", dim_from_numpy, np_int]) + self.assertTrue(shape.is_static(0)) + self.assertFalse(shape.is_static(1)) + self.assertTrue(shape.is_static(2)) + self.assertTrue(shape.is_static(3)) + self.assertFalse(shape.is_static()) + + def test_is_static_raises_when_index_out_of_range(self): + shape = _core.Shape([42]) + with self.assertRaises(IndexError): + shape.is_static(1) + + def test_is_static_on_whole_shape(self): + shape = _core.Shape([42, "any string"]) + self.assertFalse(shape.is_static()) + shape = _core.Shape([42, 42]) + self.assertTrue(shape.is_static()) + + def test_is_static_on_empty_shape(self): + shape = _core.Shape(()) + self.assertTrue(shape.is_static()) + + def test_is_dynamic(self): + dim_from_numpy = np.array([42]).shape[0] + np_int = np.int32(42) + shape = _core.Shape([42, "any string", dim_from_numpy, np_int]) + self.assertFalse(shape.is_dynamic(0)) + self.assertTrue(shape.is_dynamic(1)) + self.assertFalse(shape.is_dynamic(2)) + self.assertFalse(shape.is_dynamic(3)) + self.assertTrue(shape.is_dynamic()) + + def test_is_dynamic_raises_when_index_out_of_range(self): + shape = _core.Shape([42]) + with self.assertRaises(IndexError): + shape.is_dynamic(1) + + def test_is_dynamic_on_whole_shape(self): + shape = _core.Shape([42, "any string"]) + self.assertTrue(shape.is_dynamic()) + shape = _core.Shape([42, 42]) + self.assertFalse(shape.is_dynamic()) + + def test_is_dynamic_on_empty_shape(self): + shape = _core.Shape(()) + self.assertFalse(shape.is_dynamic()) + class ValueTest(unittest.TestCase): def test_initialize(self):