From c4962ae1fba39f9da17ad2bf8d013a5bd705cc55 Mon Sep 17 00:00:00 2001 From: bgreni Date: Mon, 22 Apr 2024 18:49:45 -0600 Subject: [PATCH] Use __index__ for __getitem__ and __setitem__ When indexing stdlib containers we should accept a generic type that calls on the __index__ method to allow types other than Int to be used but doesn't allow Intable types that should not be used for such purposes (such as Float) Signed-off-by: Brian Grenier --- stdlib/src/builtin/bool.mojo | 16 ++++++++- stdlib/src/builtin/builtin_list.mojo | 10 +++--- stdlib/src/builtin/builtin_slice.mojo | 4 +-- stdlib/src/builtin/int.mojo | 2 +- stdlib/src/builtin/int_literal.mojo | 2 +- stdlib/src/builtin/range.mojo | 12 +++---- stdlib/src/builtin/simd.mojo | 43 +++++++++++++++++------- stdlib/src/builtin/value.mojo | 10 ++++++ stdlib/src/collections/vector.mojo | 16 ++++----- stdlib/src/memory/unsafe.mojo | 10 +++--- stdlib/src/python/object.mojo | 8 ++++- stdlib/src/utils/index.mojo | 8 ++--- stdlib/src/utils/static_tuple.mojo | 14 ++++---- stdlib/src/utils/stringref.mojo | 4 +-- stdlib/test/collections/test_vector.mojo | 5 +++ 15 files changed, 109 insertions(+), 55 deletions(-) diff --git a/stdlib/src/builtin/bool.mojo b/stdlib/src/builtin/bool.mojo index fb2eaef300..eca9585a9d 100644 --- a/stdlib/src/builtin/bool.mojo +++ b/stdlib/src/builtin/bool.mojo @@ -47,7 +47,12 @@ trait Boolable: @value @register_passable("trivial") struct Bool( - Stringable, CollectionElement, Boolable, EqualityComparable, Intable + Stringable, + CollectionElement, + Boolable, + EqualityComparable, + Intable, + Indexer, ): """The primitive Bool scalar value used in Mojo.""" @@ -261,6 +266,15 @@ struct Bool( ) ) + @always_inline("nodebug") + fn __index__(self) -> Int: + """Convert this Bool to an integer for indexing purposes + + Returns: + Bool as Int + """ + return self.__int__() + @always_inline fn bool(value: None) -> Bool: diff --git a/stdlib/src/builtin/builtin_list.mojo b/stdlib/src/builtin/builtin_list.mojo 
index 16ef431d3b..43847c764e 100644 --- a/stdlib/src/builtin/builtin_list.mojo +++ b/stdlib/src/builtin/builtin_list.mojo @@ -138,7 +138,7 @@ struct VariadicList[type: AnyRegType](Sized): return __mlir_op.`pop.variadic.size`(self.value) @always_inline - fn __getitem__(self, index: Int) -> type: + fn __getitem__[indexer: Indexer](self, index: indexer) -> type: """Gets a single element on the variadic list. Args: @@ -147,7 +147,7 @@ struct VariadicList[type: AnyRegType](Sized): Returns: The element on the list corresponding to the given index. """ - return __mlir_op.`pop.variadic.get`(self.value, index.value) + return __mlir_op.`pop.variadic.get`(self.value, index.__index__().value) @always_inline fn __iter__(self) -> Self.IterType: @@ -348,7 +348,9 @@ struct VariadicListMem[ # TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol # allowing us to get rid of this and make foreach iteration clean. @always_inline - fn __getitem__(self, index: Int) -> Self.reference_type: + fn __getitem__[ + indexer: Indexer + ](self, index: indexer) -> Self.reference_type: """Gets a single element on the variadic list. Args: @@ -359,7 +361,7 @@ struct VariadicListMem[ given index. """ return Self.reference_type( - __mlir_op.`pop.variadic.get`(self.value, index.value) + __mlir_op.`pop.variadic.get`(self.value, index.__index__().value) ) @always_inline diff --git a/stdlib/src/builtin/builtin_slice.mojo b/stdlib/src/builtin/builtin_slice.mojo index b6da005a8f..ae00b55589 100644 --- a/stdlib/src/builtin/builtin_slice.mojo +++ b/stdlib/src/builtin/builtin_slice.mojo @@ -155,7 +155,7 @@ struct Slice(Sized, Stringable, EqualityComparable): return len(range(self.start, self.end, self.step)) @always_inline - fn __getitem__(self, idx: Int) -> Int: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: """Get the slice index. Args: @@ -164,7 +164,7 @@ struct Slice(Sized, Stringable, EqualityComparable): Returns: The slice index. 
""" - return self.start + idx * self.step + return self.start + idx.__index__() * self.step @always_inline("nodebug") fn _has_end(self) -> Bool: diff --git a/stdlib/src/builtin/int.mojo b/stdlib/src/builtin/int.mojo index b33aaaa681..6e93a7bd3b 100644 --- a/stdlib/src/builtin/int.mojo +++ b/stdlib/src/builtin/int.mojo @@ -172,7 +172,7 @@ fn int[T: IntableRaising](value: T) raises -> Int: @lldb_formatter_wrapping_type @value @register_passable("trivial") -struct Int(Intable, Stringable, KeyElement, Boolable, Formattable): +struct Int(Intable, Stringable, KeyElement, Boolable, Formattable, Indexer): """This type represents an integer value.""" var value: __mlir_type.index diff --git a/stdlib/src/builtin/int_literal.mojo b/stdlib/src/builtin/int_literal.mojo index c9e21c0974..3241b5f5bd 100644 --- a/stdlib/src/builtin/int_literal.mojo +++ b/stdlib/src/builtin/int_literal.mojo @@ -16,7 +16,7 @@ @value @nonmaterializable(Int) @register_passable("trivial") -struct IntLiteral(Intable, Stringable, Boolable, EqualityComparable): +struct IntLiteral(Intable, Stringable, Boolable, EqualityComparable, Indexer): """This type represents a static integer literal value with infinite precision. 
They can't be materialized at runtime and must be lowered to other integer types (like Int), but allow for diff --git a/stdlib/src/builtin/range.mojo b/stdlib/src/builtin/range.mojo index 2d2e9dc032..87f3468271 100644 --- a/stdlib/src/builtin/range.mojo +++ b/stdlib/src/builtin/range.mojo @@ -92,8 +92,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange): return self.curr @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return idx + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return idx.__index__() @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: @@ -123,8 +123,8 @@ struct _SequentialRange(Sized, ReversibleRange): return self.end - self.start if self.start < self.end else 0 @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return self.start + idx + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return self.start + idx.__index__() @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: @@ -195,8 +195,8 @@ struct _StridedRange(Sized, ReversibleRange): return _div_ceil_positive(_abs(self.start - self.end), _abs(self.step)) @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return self.start + idx * self.step + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return self.start + idx.__index__() * self.step @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: diff --git a/stdlib/src/builtin/simd.mojo b/stdlib/src/builtin/simd.mojo index 8e8e6d62f3..9fa18c6e74 100644 --- a/stdlib/src/builtin/simd.mojo +++ b/stdlib/src/builtin/simd.mojo @@ -113,12 +113,7 @@ fn _unchecked_zero[type: DType, size: Int]() -> SIMD[type, size]: @lldb_formatter_wrapping_type @register_passable("trivial") struct SIMD[type: DType, size: Int = simdwidthof[type]()]( - Sized, - Intable, - CollectionElement, - Stringable, - Hashable, - Boolable, + Sized, Intable, CollectionElement, Stringable, Hashable, Boolable, Indexer ): 
"""Represents a small vector that is backed by a hardware vector element. @@ -513,6 +508,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( rebind[Scalar[type]](self).value ) + @always_inline("nodebug") + fn __index__(self) -> Int: + """Returns the value as an int if it is an integral value + + Constraints: + Must be an integral value + + Returns: + The value as an integer + """ + constrained[ + type.is_integral() or type.is_bool(), + "expected integral or bool type", + ]() + return self.__int__() + @always_inline fn __str__(self) -> String: """Get the SIMD as a string. @@ -1518,7 +1529,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( # ===-------------------------------------------------------------------===# @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Scalar[type]: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Scalar[type]: """Gets an element from the vector. Args: @@ -1529,10 +1540,12 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( """ return __mlir_op.`pop.simd.extractelement`[ _type = __mlir_type[`!pop.scalar<`, type.value, `>`] - ](self.value, idx.value) + ](self.value, idx.__index__().value) @always_inline("nodebug") - fn __setitem__(inout self, idx: Int, val: Scalar[type]): + fn __setitem__[ + indexer: Indexer + ](inout self, idx: indexer, val: Scalar[type]): """Sets an element in the vector. Args: @@ -1540,12 +1553,16 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( val: The value to set. """ self.value = __mlir_op.`pop.simd.insertelement`( - self.value, val.value, idx.value + self.value, val.value, idx.__index__().value ) @always_inline("nodebug") - fn __setitem__( - inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`] + fn __setitem__[ + indexer: Indexer + ]( + inout self, + idx: indexer, + val: __mlir_type[`!pop.scalar<`, type.value, `>`], ): """Sets an element in the vector. 
@@ -1554,7 +1571,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( val: The value to set. """ self.value = __mlir_op.`pop.simd.insertelement`( - self.value, val, idx.value + self.value, val, idx.__index__().value ) fn __hash__(self) -> Int: diff --git a/stdlib/src/builtin/value.mojo b/stdlib/src/builtin/value.mojo index a1e7067733..b87028aeda 100644 --- a/stdlib/src/builtin/value.mojo +++ b/stdlib/src/builtin/value.mojo @@ -121,3 +121,13 @@ trait StringableCollectionElement(CollectionElement, Stringable): """ pass + + +trait Indexer: + fn __index__(self) -> Int: + """Return the index value + + Returns: + The index value of the object + """ + ... diff --git a/stdlib/src/collections/vector.mojo b/stdlib/src/collections/vector.mojo index 51a9ff9632..9b22f283f8 100644 --- a/stdlib/src/collections/vector.mojo +++ b/stdlib/src/collections/vector.mojo @@ -182,7 +182,7 @@ struct InlinedFixedVector[ return self.current_size @always_inline - fn __getitem__(self, i: Int) -> type: + fn __getitem__[indexer: Indexer](self, i: indexer) -> type: """Gets a vector element at the given index. Args: @@ -191,12 +191,12 @@ struct InlinedFixedVector[ Returns: The element at the given index. """ + var normalized_idx = i.__index__() debug_assert( - -self.current_size <= i < self.current_size, + -self.current_size <= normalized_idx < self.current_size, "index must be within bounds", ) - var normalized_idx = i - if i < 0: + if normalized_idx < 0: normalized_idx += len(self) if normalized_idx < Self.static_size: @@ -205,20 +205,20 @@ struct InlinedFixedVector[ return self.dynamic_data[normalized_idx - Self.static_size] @always_inline - fn __setitem__(inout self, i: Int, value: type): + fn __setitem__[indexer: Indexer](inout self, i: indexer, value: type): """Sets a vector element at the given index. Args: i: The index of the element. value: The value to assign. 
""" + var normalized_idx = i.__index__() debug_assert( - -self.current_size <= i < self.current_size, + -self.current_size <= normalized_idx < self.current_size, "index must be within bounds", ) - var normalized_idx = i - if i < 0: + if normalized_idx < 0: normalized_idx += len(self) if normalized_idx < Self.static_size: diff --git a/stdlib/src/memory/unsafe.mojo b/stdlib/src/memory/unsafe.mojo index 3788ba39bc..af5801213c 100644 --- a/stdlib/src/memory/unsafe.mojo +++ b/stdlib/src/memory/unsafe.mojo @@ -289,11 +289,11 @@ struct LegacyPointer[ ) @always_inline("nodebug") - fn __refitem__[T: Intable](self, offset: T) -> Self._mlir_ref_type: + fn __refitem__[T: Indexer](self, offset: T) -> Self._mlir_ref_type: """Enable subscript syntax `ref[idx]` to access the element. Parameters: - T: The Intable type of the offset. + T: The Indexer type of the offset. Args: offset: The offset to load from. @@ -301,7 +301,7 @@ struct LegacyPointer[ Returns: The MLIR reference for the Mojo compiler to use. """ - return (self + offset).__refitem__() + return (self + offset.__index__()).__refitem__() # ===------------------------------------------------------------------=== # # Load/Store @@ -714,7 +714,7 @@ struct DTypePointer[ return arg.get_legacy_pointer() @always_inline("nodebug") - fn __getitem__[T: Intable](self, offset: T) -> Scalar[type]: + fn __getitem__[T: Indexer](self, offset: T) -> Scalar[type]: """Loads a single element (SIMD of size 1) from the pointer at the specified index. @@ -727,7 +727,7 @@ struct DTypePointer[ Returns: The loaded value. 
""" - return self.load(offset) + return self.load(offset.__index__()) @always_inline("nodebug") fn __setitem__[T: Intable](self, offset: T, val: Scalar[type]): diff --git a/stdlib/src/python/object.mojo b/stdlib/src/python/object.mojo index 44bd922bca..151984b992 100644 --- a/stdlib/src/python/object.mojo +++ b/stdlib/src/python/object.mojo @@ -101,7 +101,13 @@ struct _PyIter(Sized): @register_passable struct PythonObject( - Intable, Stringable, SizedRaising, Boolable, CollectionElement, KeyElement + Intable, + Stringable, + SizedRaising, + Boolable, + CollectionElement, + KeyElement, + Indexer, ): """A Python object.""" diff --git a/stdlib/src/utils/index.mojo b/stdlib/src/utils/index.mojo index 9d26582984..846f4e7fa3 100644 --- a/stdlib/src/utils/index.mojo +++ b/stdlib/src/utils/index.mojo @@ -335,11 +335,11 @@ struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable): return size @always_inline("nodebug") - fn __getitem__[intable: Intable](self, index: intable) -> Int: + fn __getitem__[indexer: Indexer](self, index: indexer) -> Int: """Gets an element from the tuple by index. Parameters: - intable: The intable type. + indexer: The index type. Args: index: The element index. @@ -362,11 +362,11 @@ struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable): self.data.__setitem__[index](val) @always_inline("nodebug") - fn __setitem__[intable: Intable](inout self, index: intable, val: Int): + fn __setitem__[indexer: Indexer](inout self, index: indexer, val: Int): """Sets an element in the tuple at the given index. Parameters: - intable: The intable type. + indexer: The index type. Args: index: The element index. 
diff --git a/stdlib/src/utils/static_tuple.mojo b/stdlib/src/utils/static_tuple.mojo index 6b2e0d2332..cb1e8e596b 100644 --- a/stdlib/src/utils/static_tuple.mojo +++ b/stdlib/src/utils/static_tuple.mojo @@ -197,11 +197,11 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized): self = tmp @always_inline("nodebug") - fn __getitem__[intable: Intable](self, index: intable) -> Self.element_type: + fn __getitem__[indexer: Indexer](self, index: indexer) -> Self.element_type: """Returns the value of the tuple at the given dynamic index. Parameters: - intable: The intable type. + indexer: The index type. Args: index: The index into the tuple. @@ -209,7 +209,7 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized): Returns: The value at the specified position. """ - var offset = int(index) + var offset = index.__index__() debug_assert(offset < size, "index must be within bounds") # Copy the array so we can get its address, because we can't take the # address of 'self' in a non-mutating method. @@ -221,18 +221,18 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized): @always_inline("nodebug") fn __setitem__[ - intable: Intable - ](inout self, index: intable, val: Self.element_type): + indexer: Indexer + ](inout self, index: indexer, val: Self.element_type): """Stores a single value into the tuple at the specified dynamic index. Parameters: - intable: The intable type. + indexer: The index type. Args: index: The index into the tuple. val: The value to store. 
""" - var offset = int(index) + var offset = index.__index__() debug_assert(offset < size, "index must be within bounds") var tmp = self var ptr = __mlir_op.`pop.array.gep`( diff --git a/stdlib/src/utils/stringref.mojo b/stdlib/src/utils/stringref.mojo index 86cb86f944..95e0616094 100644 --- a/stdlib/src/utils/stringref.mojo +++ b/stdlib/src/utils/stringref.mojo @@ -195,7 +195,7 @@ struct StringRef( return not (self == rhs) @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> StringRef: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> StringRef: """Get the string value at the specified position. Args: @@ -204,7 +204,7 @@ struct StringRef( Returns: The character at the specified position. """ - return StringRef {data: self.data + idx, length: 1} + return StringRef {data: self.data + idx.__index__(), length: 1} fn __hash__(self) -> Int: """Hash the underlying buffer using builtin hash. diff --git a/stdlib/test/collections/test_vector.mojo b/stdlib/test/collections/test_vector.mojo index 9750141bd8..7537e06e93 100644 --- a/stdlib/test/collections/test_vector.mojo +++ b/stdlib/test/collections/test_vector.mojo @@ -103,6 +103,11 @@ def test_inlined_fixed_vector_with_default(): vector[5] = -2 assert_equal(-2, vector[5]) + # check we can index with non Int or IntLiteral + assert_equal(1, vector[Int16(1)]) + assert_equal(1, vector[True]) + assert_equal(1, vector[Scalar[DType.bool](True)]) + vector.clear() assert_equal(0, len(vector))