Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add Indexer trait #2685

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,18 @@ what we publish.
use `str(my_list)` yet.
([PR #2673](https://github.com/modularml/mojo/pull/2673) by [@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse))

- Added the `Indexer` trait to denote types that implement the `__index__()` method
which allows all integral types to be accepted in `__getitem__` and `__setitem__`
implementations. For example:

```mojo
struct MyList:
var data: List[Int]

fn __getitem__[T: Indexer](self, idx: T) -> T:
return self.data[index(idx)]
```

### 🦋 Changed

- The `let` keyword has been completely removed from the language. We previously
Expand Down
16 changes: 15 additions & 1 deletion stdlib/src/builtin/bool.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ trait Boolable:
@value
@register_passable("trivial")
struct Bool(
Stringable, CollectionElement, Boolable, EqualityComparable, Intable
Stringable,
CollectionElement,
Boolable,
EqualityComparable,
Intable,
Indexer,
):
"""The primitive Bool scalar value used in Mojo."""

Expand Down Expand Up @@ -151,6 +156,15 @@ struct Bool(
"""
return __mlir_op.`pop.select`[_type=Int](self.value, Int(1), Int(0))

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Convert this Bool to an integer for indexing purposes.

Returns:
1 if the Bool is True, 0 otherwise.
"""
return self.__int__()

@always_inline("nodebug")
fn __eq__(self, rhs: Bool) -> Bool:
"""Compare this Bool to RHS.
Expand Down
45 changes: 45 additions & 0 deletions stdlib/src/builtin/int.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,50 @@ from utils._visualizers import lldb_formatter_wrapping_type
from utils._format import Formattable, Formatter
from utils.inlined_string import _ArrayMem

# ===----------------------------------------------------------------------=== #
# Indexer
# ===----------------------------------------------------------------------=== #


trait Indexer:
"""This trait denotes a type that can be used to index a container that
handles integral index values.

This solves the issue of being able to index data structures such as `List` with the various
integral types without being too broad and allowing types that should not be used such as float point
values.
"""

fn __index__(self) -> Int:
"""Return the index value.

Returns:
The index value of the object.
"""
...


# ===----------------------------------------------------------------------=== #
# index
# ===----------------------------------------------------------------------=== #


@always_inline("nodebug")
fn index[indexer: Indexer](idx: indexer) -> Int:
"""Returns the value of `__index__` for the given value.

Parameters:
indexer: The type of the given value.

Args:
idx: The value.

Returns:
An int respresenting the index value.
"""
return idx.__index__()


# ===----------------------------------------------------------------------=== #
# Intable
# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -205,6 +249,7 @@ struct Int(
Roundable,
Stringable,
Truncable,
Indexer,
):
"""This type represents an integer value."""

Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct IntLiteral(
Roundable,
Stringable,
Truncable,
Indexer,
):
"""This type represents a static integer literal value with
infinite precision. They can't be materialized at runtime and
Expand Down
17 changes: 17 additions & 0 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Sized,
Stringable,
Truncable,
Indexer,
):
"""Represents a small vector that is backed by a hardware vector element.

Expand Down Expand Up @@ -179,6 +180,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
_simd_construction_checks[type, size]()
self = _unchecked_zero[type, size]()

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Returns the value as an int if it is an integral value.

Constraints:
Must be a scalar integral value.

Returns:
The value as an integer.
"""
constrained[
type.is_integral() or type.is_bool(),
"expected integral or bool type",
]()
return self.__int__()

@always_inline("nodebug")
fn __init__(inout self, value: SIMD[DType.float64, 1]):
"""Initializes the SIMD vector with a float.
Expand Down
6 changes: 6 additions & 0 deletions stdlib/test/builtin/test_bool.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,16 @@ def test_neg():
assert_equal(0, -False)


def test_indexer():
assert_equal(1, Bool.__index__(True))
assert_equal(0, Bool.__index__(False))


def main():
test_bool_cast_to_int()
test_bool_none()
bgreni marked this conversation as resolved.
Show resolved Hide resolved
test_convert_from_boolable()
test_bool_to_string()
test_bitwise()
test_neg()
test_indexer()
6 changes: 6 additions & 0 deletions stdlib/test/builtin/test_int.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def test_int_representation():
assert_equal(repr(Int(-100)), "-100")


def test_indexer():
assert_equal(5, Int(5).__index__())
assert_equal(987, Int(987).__index__())


def main():
test_constructors()
test_properties()
Expand All @@ -160,3 +165,4 @@ def main():
test_abs()
test_string_conversion()
test_int_representation()
test_indexer()
6 changes: 6 additions & 0 deletions stdlib/test/builtin/test_int_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def test_abs():
assert_equal(abs(0), 0)


def test_indexer():
assert_equal(1, IntLiteral.__index__(1))
assert_equal(88, IntLiteral.__index__(88))


def main():
test_int()
test_ceil()
Expand All @@ -88,3 +93,4 @@ def main():
test_mod()
test_bit_width()
test_abs()
test_indexer()
8 changes: 8 additions & 0 deletions stdlib/test/builtin/test_simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,13 @@ def test_min_max_clamp():
assert_equal(i.clamp(-7, 4), I(-7, -5, 4, 4))


def test_indexer():
assert_equal(5, Int8(5).__index__())
assert_equal(56, UInt32(56).__index__())
assert_equal(1, Scalar[DType.bool](True).__index__())
assert_equal(0, Scalar[DType.bool](False).__index__())


def main():
test_cast()
test_simd_variadic()
Expand Down Expand Up @@ -1008,3 +1015,4 @@ def main():
test_mul_with_overflow()
test_abs()
test_min_max_clamp()
test_indexer()
Loading