diff --git a/docs/changelog.md b/docs/changelog.md index fc0d8700c6..869193bb40 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -241,6 +241,18 @@ what we publish. use `str(my_list)` yet. ([PR #2673](https://github.com/modularml/mojo/pull/2673) by [@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse)) +- Added the `Indexer` trait to denote types that implement the `__index__()` method + which allows all integral types to be accepted in `__getitem__` and `__setitem__` + implementations. For example: + + ```mojo + struct MyList: + var data: List[Int] + + fn __getitem__[T: Indexer](self, idx: T) -> T: + return self.data[index(idx)] + ``` + ### 🦋 Changed - The `let` keyword has been completely removed from the language. We previously diff --git a/stdlib/src/builtin/bool.mojo b/stdlib/src/builtin/bool.mojo index 1068d7fd4b..bcc45bb128 100644 --- a/stdlib/src/builtin/bool.mojo +++ b/stdlib/src/builtin/bool.mojo @@ -59,7 +59,12 @@ trait Boolable: @value @register_passable("trivial") struct Bool( - Stringable, CollectionElement, Boolable, EqualityComparable, Intable + Stringable, + CollectionElement, + Boolable, + EqualityComparable, + Intable, + Indexer, ): """The primitive Bool scalar value used in Mojo.""" @@ -151,6 +156,15 @@ struct Bool( """ return __mlir_op.`pop.select`[_type=Int](self.value, Int(1), Int(0)) + @always_inline("nodebug") + fn __index__(self) -> Int: + """Convert this Bool to an integer for indexing purposes. + + Returns: + 1 if the Bool is True, 0 otherwise. + """ + return self.__int__() + @always_inline("nodebug") fn __eq__(self, rhs: Bool) -> Bool: """Compare this Bool to RHS. diff --git a/stdlib/src/builtin/int.mojo b/stdlib/src/builtin/int.mojo index af64739040..266bc8f5ba 100644 --- a/stdlib/src/builtin/int.mojo +++ b/stdlib/src/builtin/int.mojo @@ -27,6 +27,50 @@ from utils._visualizers import lldb_formatter_wrapping_type from utils._format import Formattable, Formatter from utils.inlined_string import _ArrayMem +# ===----------------------------------------------------------------------=== # +# Indexer +# ===----------------------------------------------------------------------=== # + + +trait Indexer: + """This trait denotes a type that can be used to index a container that + handles integral index values. + + This solves the issue of being able to index data structures such as `List` with the various + integral types without being too broad and allowing types that should not be used such as float point + values. + """ + + fn __index__(self) -> Int: + """Return the index value. + + Returns: + The index value of the object. + """ + ... + + +# ===----------------------------------------------------------------------=== # +# index +# ===----------------------------------------------------------------------=== # + + +@always_inline("nodebug") +fn index[indexer: Indexer](idx: indexer) -> Int: + """Returns the value of `__index__` for the given value. + + Parameters: + indexer: The type of the given value. + + Args: + idx: The value. + + Returns: + An int respresenting the index value. + """ + return idx.__index__() + + # ===----------------------------------------------------------------------=== # # Intable # ===----------------------------------------------------------------------=== # @@ -205,6 +249,7 @@ struct Int( Roundable, Stringable, Truncable, + Indexer, ): """This type represents an integer value.""" diff --git a/stdlib/src/builtin/int_literal.mojo b/stdlib/src/builtin/int_literal.mojo index e4a13b06e8..d8691e0446 100644 --- a/stdlib/src/builtin/int_literal.mojo +++ b/stdlib/src/builtin/int_literal.mojo @@ -29,6 +29,7 @@ struct IntLiteral( Roundable, Stringable, Truncable, + Indexer, ): """This type represents a static integer literal value with infinite precision. They can't be materialized at runtime and diff --git a/stdlib/src/builtin/simd.mojo b/stdlib/src/builtin/simd.mojo index 03e27e002f..9d2f1efb42 100644 --- a/stdlib/src/builtin/simd.mojo +++ b/stdlib/src/builtin/simd.mojo @@ -139,6 +139,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( Sized, Stringable, Truncable, + Indexer, ): """Represents a small vector that is backed by a hardware vector element. @@ -179,6 +180,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( _simd_construction_checks[type, size]() self = _unchecked_zero[type, size]() + @always_inline("nodebug") + fn __index__(self) -> Int: + """Returns the value as an int if it is an integral value. + + Constraints: + Must be a scalar integral value. + + Returns: + The value as an integer. + """ + constrained[ + type.is_integral() or type.is_bool(), + "expected integral or bool type", + ]() + return self.__int__() + @always_inline("nodebug") fn __init__(inout self, value: SIMD[DType.float64, 1]): """Initializes the SIMD vector with a float. diff --git a/stdlib/test/builtin/test_bool.mojo b/stdlib/test/builtin/test_bool.mojo index ab13481124..17a4347afc 100644 --- a/stdlib/test/builtin/test_bool.mojo +++ b/stdlib/test/builtin/test_bool.mojo @@ -100,6 +100,11 @@ def test_neg(): assert_equal(0, -False) +def test_indexer(): + assert_equal(1, Bool.__index__(True)) + assert_equal(0, Bool.__index__(False)) + + def main(): test_bool_cast_to_int() test_bool_none() @@ -107,3 +112,4 @@ def main(): test_bool_to_string() test_bitwise() test_neg() + test_indexer() diff --git a/stdlib/test/builtin/test_int.mojo b/stdlib/test/builtin/test_int.mojo index 7fbe49c04d..07e9ac5c25 100644 --- a/stdlib/test/builtin/test_int.mojo +++ b/stdlib/test/builtin/test_int.mojo @@ -143,6 +143,11 @@ def test_int_representation(): assert_equal(repr(Int(-100)), "-100") +def test_indexer(): + assert_equal(5, Int(5).__index__()) + assert_equal(987, Int(987).__index__()) + + def main(): test_constructors() test_properties() @@ -160,3 +165,4 @@ def main(): test_abs() test_string_conversion() test_int_representation() + test_indexer() diff --git a/stdlib/test/builtin/test_int_literal.mojo b/stdlib/test/builtin/test_int_literal.mojo index 697034cf47..19a459eb11 100644 --- a/stdlib/test/builtin/test_int_literal.mojo +++ b/stdlib/test/builtin/test_int_literal.mojo @@ -78,6 +78,11 @@ def test_abs(): assert_equal(abs(0), 0) +def test_indexer(): + assert_equal(1, IntLiteral.__index__(1)) + assert_equal(88, IntLiteral.__index__(88)) + + def main(): test_int() test_ceil() @@ -88,3 +93,4 @@ def main(): test_mod() test_bit_width() test_abs() + test_indexer() diff --git a/stdlib/test/builtin/test_simd.mojo b/stdlib/test/builtin/test_simd.mojo index 8b1d18ffe0..b009793a07 100644 --- a/stdlib/test/builtin/test_simd.mojo +++ b/stdlib/test/builtin/test_simd.mojo @@ -974,6 +974,13 @@ def test_min_max_clamp(): assert_equal(i.clamp(-7, 4), I(-7, -5, 4, 4)) +def test_indexer(): + assert_equal(5, Int8(5).__index__()) + assert_equal(56, UInt32(56).__index__()) + assert_equal(1, Scalar[DType.bool](True).__index__()) + assert_equal(0, Scalar[DType.bool](False).__index__()) + + def main(): test_cast() test_simd_variadic() @@ -1008,3 +1015,4 @@ def main(): test_mul_with_overflow() test_abs() test_min_max_clamp() + test_indexer()