Skip to content

Commit

Permalink
[External] [stdlib] Apply Indexer trait to stdlib containers (#40205)
Browse files Browse the repository at this point in the history
[External] [stdlib] Apply Indexer trait to stdlib containers

The second half of #2384

I have intentionally omitted `static_tuple.mojo` to avoid conflicting
with #2677

Co-authored-by: bgreni <42788181+bgreni@users.noreply.github.com>
Closes #2722
MODULAR_ORIG_COMMIT_REV_ID: 365298938b36c3f4f9ad0067dfcb3856de369875
  • Loading branch information
bgreni authored and modularbot committed May 21, 2024
1 parent 24fd95f commit 9760431
Show file tree
Hide file tree
Showing 26 changed files with 266 additions and 86 deletions.
8 changes: 4 additions & 4 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,8 @@ what we publish.
- Added the `Indexer` trait to denote types that implement the `__index__()`
method which allow these types to be accepted in common `__getitem__` and
`__setitem__` implementations, as well as allow a new builtin `index` function
to be called on them.
([PR #2685](https://github.com/modularml/mojo/pull/2685) by
[@bgreni](https://github.com/bgreni))
For example:
to be called on them. Most stdlib containers are now able to be indexed by
any type that implements `Indexer`. For example:

```mojo
@value
Expand All @@ -329,6 +327,8 @@ what we publish.
print(MyList()[AlwaysZero()]) # prints `1`
```

([PR #2685](https://github.com/modularml/mojo/pull/2685) by [@bgreni](https://github.com/bgreni))

- `StringRef` now implements `strip()` which can be used to remove leading and
trailing whitespaces. ([PR #2683](https://github.com/modularml/mojo/pull/2683)
by [@fknfilewalker](https://github.com/fknfilewalker))
Expand Down
33 changes: 22 additions & 11 deletions stdlib/src/builtin/builtin_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,19 @@ struct VariadicList[type: AnyRegType](Sized):
return __mlir_op.`pop.variadic.size`(self.value)

@always_inline
fn __getitem__(self, index: Int) -> type:
fn __getitem__[IndexerType: Indexer](self, idx: IndexerType) -> type:
"""Gets a single element on the variadic list.
Parameters:
IndexerType: The type of the indexer.
Args:
index: The index of the element to access on the list.
idx: The index of the element to access on the list.
Returns:
The element on the list corresponding to the given index.
"""
return __mlir_op.`pop.variadic.get`(self.value, index.value)
return __mlir_op.`pop.variadic.get`(self.value, index(idx).value)

@always_inline
fn __iter__(self) -> Self.IterType:
Expand Down Expand Up @@ -358,24 +361,29 @@ struct VariadicListMem[
# TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol
# allowing us to get rid of this and make foreach iteration clean.
@always_inline
fn __getitem__(self, index: Int) -> Self.reference_type:
fn __getitem__[
IndexerType: Indexer
](self, idx: IndexerType) -> Self.reference_type:
"""Gets a single element on the variadic list.
Parameters:
IndexerType: The type of the indexer.
Args:
index: The index of the element to access on the list.
idx: The index of the element to access on the list.
Returns:
A low-level pointer to the element on the list corresponding to the
given index.
"""
return Self.reference_type(
__mlir_op.`pop.variadic.get`(self.value, index.value)
__mlir_op.`pop.variadic.get`(self.value, index(idx).value)
)

@always_inline
fn __refitem__(
self, index: Int
) -> Reference[
fn __refitem__[
IndexerType: Indexer
](self, idx: IndexerType) -> Reference[
element_type,
Bool {value: elt_is_mutable},
_lit_lifetime_union[
Expand All @@ -391,14 +399,17 @@ struct VariadicListMem[
]:
"""Gets a single element on the variadic list.
Parameters:
IndexerType: The type of the indexer.
Args:
index: The index of the element to access on the list.
idx: The index of the element to access on the list.
Returns:
A low-level pointer to the element on the list corresponding to the
given index.
"""
return __mlir_op.`pop.variadic.get`(self.value, index.value)
return __mlir_op.`pop.variadic.get`(self.value, index(idx).value)

fn __iter__(
self,
Expand Down
7 changes: 5 additions & 2 deletions stdlib/src/builtin/builtin_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,19 @@ struct Slice(Sized, Stringable, EqualityComparable):
return len(range(self.start, self.end, self.step))

@always_inline
fn __getitem__(self, idx: Int) -> Int:
fn __getitem__[IndexerType: Indexer](self, idx: IndexerType) -> Int:
"""Get the slice index.
Parameters:
IndexerType: The type of the indexer.
Args:
idx: The index.
Returns:
The slice index.
"""
return self.start + idx * self.step
return self.start + index(idx) * self.step

@always_inline("nodebug")
fn _has_end(self) -> Bool:
Expand Down
12 changes: 6 additions & 6 deletions stdlib/src/builtin/range.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange, _IntIterable):
return self.curr

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return idx
fn __getitem__[IndexerType: Indexer](self, idx: IndexerType) -> Int:
return index(idx)

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRange:
Expand Down Expand Up @@ -116,8 +116,8 @@ struct _SequentialRange(Sized, ReversibleRange, _IntIterable):
return self.end - self.start if self.start < self.end else 0

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx
fn __getitem__[IndexerType: Indexer](self, idx: IndexerType) -> Int:
return self.start + index(idx)

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRange:
Expand Down Expand Up @@ -184,8 +184,8 @@ struct _StridedRange(Sized, ReversibleRange, _StridedIterable):
return _div_ceil_positive(abs(self.start - self.end), abs(self.step))

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx * self.step
fn __getitem__[IndexerType: Indexer](self, idx: IndexerType) -> Int:
return self.start + index(idx) * self.step

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRange:
Expand Down
31 changes: 24 additions & 7 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1824,9 +1824,14 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# ===-------------------------------------------------------------------===#

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Scalar[type]:
fn __getitem__[
IndexerType: Indexer
](self, idx: IndexerType) -> Scalar[type]:
"""Gets an element from the vector.
Parameters:
IndexerType: The type of the indexer.
Args:
idx: The element index.
Expand All @@ -1835,32 +1840,44 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""
return __mlir_op.`pop.simd.extractelement`[
_type = __mlir_type[`!pop.scalar<`, type.value, `>`]
](self.value, idx.value)
](self.value, index(idx).value)

@always_inline("nodebug")
fn __setitem__(inout self, idx: Int, val: Scalar[type]):
fn __setitem__[
IndexerType: Indexer
](inout self, idx: IndexerType, val: Scalar[type]):
"""Sets an element in the vector.
Parameters:
IndexerType: The type of the indexer.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val.value, idx.value
self.value, val.value, index(idx).value
)

@always_inline("nodebug")
fn __setitem__(
inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`]
fn __setitem__[
IndexerType: Indexer
](
inout self,
idx: IndexerType,
val: __mlir_type[`!pop.scalar<`, type.value, `>`],
):
"""Sets an element in the vector.
Parameters:
IndexerType: The type of the indexer.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val, idx.value
self.value, val, index(idx).value
)

fn __hash__(self) -> Int:
Expand Down
8 changes: 6 additions & 2 deletions stdlib/src/builtin/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -877,15 +877,19 @@ struct String(
"""
return len(self) > 0

fn __getitem__(self, idx: Int) -> String:
fn __getitem__[IndexerType: Indexer](self, i: IndexerType) -> String:
"""Gets the character at the specified position.
Parameters:
IndexerType: The type of the indexer.
Args:
idx: The index value.
i: The index value.
Returns:
A new string containing the character at the specified position.
"""
var idx = index(i)
if idx < 0:
return self.__getitem__(len(self) + idx)

Expand Down
11 changes: 7 additions & 4 deletions stdlib/src/collections/inline_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,22 @@ struct InlineList[ElementType: CollectionElement, capacity: Int = 16](Sized):

@always_inline
fn __refitem__[
IntableType: Intable,
](self: Reference[Self, _, _], index: IntableType) -> Reference[
IndexerType: Indexer,
](self: Reference[Self, _, _], idx: IndexerType) -> Reference[
Self.ElementType, self.is_mutable, self.lifetime
]:
"""Get a `Reference` to the element at the given index.
Parameters:
IndexerType: The type of the indexer.
Args:
index: The index of the item.
idx: The index of the item.
Returns:
A reference to the item at the given index.
"""
var i = int(index)
var i = index(idx)
debug_assert(
-self[]._size <= i < self[]._size, "Index must be within bounds."
)
Expand Down
31 changes: 22 additions & 9 deletions stdlib/src/collections/list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -553,17 +553,25 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable):
self.capacity = 0
return ptr

fn __setitem__(inout self, i: Int, owned value: T):
fn __setitem__[
IndexerType: Indexer
](inout self, i: IndexerType, owned value: T):
"""Sets a list element at the given index.
Parameters:
IndexerType: The type of the indexer.
Args:
i: The index of the element.
value: The value to assign.
"""
debug_assert(-self.size <= i < self.size, "index must be within bounds")
var normalized_idx = index(i)
debug_assert(
-self.size <= normalized_idx < self.size,
"index must be within bounds",
)

var normalized_idx = i
if i < 0:
if normalized_idx < 0:
normalized_idx += len(self)

destroy_pointee(self.data + normalized_idx)
Expand Down Expand Up @@ -613,21 +621,26 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable):
return res^

@always_inline
fn __getitem__(self, i: Int) -> T:
fn __getitem__[IndexerType: Indexer](self, i: IndexerType) -> T:
"""Gets a copy of the list element at the given index.
FIXME(lifetimes): This should return a reference, not a copy!
Parameters:
IndexerType: The type of the indexer.
Args:
i: The index of the element.
Returns:
A copy of the element at the given index.
"""
debug_assert(-self.size <= i < self.size, "index must be within bounds")

var normalized_idx = i
if i < 0:
var normalized_idx = index(i)
debug_assert(
-self.size <= normalized_idx < self.size,
"index must be within bounds",
)
if normalized_idx < 0:
normalized_idx += len(self)

return (self.data + normalized_idx)[]
Expand Down
26 changes: 17 additions & 9 deletions stdlib/src/collections/vector.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -181,21 +181,25 @@ struct InlinedFixedVector[
return self.current_size

@always_inline
fn __getitem__(self, i: Int) -> type:
fn __getitem__[IndexerType: Indexer](self, i: IndexerType) -> type:
"""Gets a vector element at the given index.
Parameters:
IndexerType: The type of the indexer.
Args:
i: The index of the element.
Returns:
The element at the given index.
"""
var normalized_idx = index(i)
debug_assert(
-self.current_size <= i < self.current_size,
-self.current_size <= normalized_idx < self.current_size,
"index must be within bounds",
)
var normalized_idx = i
if i < 0:

if normalized_idx < 0:
normalized_idx += len(self)

if normalized_idx < Self.static_size:
Expand All @@ -204,20 +208,24 @@ struct InlinedFixedVector[
return self.dynamic_data[normalized_idx - Self.static_size]

@always_inline
fn __setitem__(inout self, i: Int, value: type):
fn __setitem__[
IndexerType: Indexer
](inout self, i: IndexerType, value: type):
"""Sets a vector element at the given index.
Parameters:
IndexerType: The type of the indexer.
Args:
i: The index of the element.
value: The value to assign.
"""
var normalized_idx = index(i)
debug_assert(
-self.current_size <= i < self.current_size,
-self.current_size <= normalized_idx < self.current_size,
"index must be within bounds",
)

var normalized_idx = i
if i < 0:
if normalized_idx < 0:
normalized_idx += len(self)

if normalized_idx < Self.static_size:
Expand Down
Loading

0 comments on commit 9760431

Please sign in to comment.