Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Parameterize __getitem__ and __setitem__ in stdlib types #2384

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion stdlib/src/builtin/bool.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ trait Boolable:
@value
@register_passable("trivial")
struct Bool(
Stringable, CollectionElement, Boolable, EqualityComparable, Intable
Stringable,
CollectionElement,
Boolable,
EqualityComparable,
Intable,
Indexer,
):
"""The primitive Bool scalar value used in Mojo."""

Expand Down Expand Up @@ -324,6 +329,15 @@ struct Bool(
"""
return __mlir_op.`index.casts`[_type = __mlir_type.index](self.value)

@always_inline("nodebug")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test for this (we test these implementations directly using explicit calls to __index__ instead of through the free functions).

fn __index__(self) -> Int:
bgreni marked this conversation as resolved.
Show resolved Hide resolved
"""Convert this Bool to an integer for indexing purposes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
"""Convert this Bool to an integer for indexing purposes
"""Convert this Bool to an integer for indexing purposes.


Returns:
Bool as Int
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide a more details description here, e.g. "1 if True and 0 otherwise."

"""
return self.__int__()


# ===----------------------------------------------------------------------=== #
# bool
Expand Down
18 changes: 12 additions & 6 deletions stdlib/src/builtin/builtin_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,19 @@ struct VariadicList[type: AnyRegType](Sized):
return __mlir_op.`pop.variadic.size`(self.value)

@always_inline
fn __getitem__(self, index: Int) -> type:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> type:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: here and elsewhere: we try to make type parameters PascalCase, could you please change these?

Suggested change
fn __getitem__[indexer: Indexer](self, idx: indexer) -> type:
fn __getitem__[IndexerType: Indexer](self, idx: IndexerType) -> type:

If you want, I'm also okay with calling these just T for brevity.

"""Gets a single element on the variadic list.

Parameters:
indexer: The type of the indexing value.

Args:
index: The index of the element to access on the list.
idx: The index of the element to access on the list.

Returns:
The element on the list corresponding to the given index.
"""
return __mlir_op.`pop.variadic.get`(self.value, index.value)
return __mlir_op.`pop.variadic.get`(self.value, index(idx).value)

@always_inline
fn __iter__(self) -> Self.IterType:
Expand Down Expand Up @@ -358,18 +361,21 @@ struct VariadicListMem[
# TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol
# allowing us to get rid of this and make foreach iteration clean.
@always_inline
fn __getitem__(self, index: Int) -> Self.reference_type:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Self.reference_type:
"""Gets a single element on the variadic list.

Parameters:
indexer: The type of the indexing value.

Args:
index: The index of the element to access on the list.
idx: The index of the element to access on the list.

Returns:
A low-level pointer to the element on the list corresponding to the
given index.
"""
return Self.reference_type(
__mlir_op.`pop.variadic.get`(self.value, index.value)
__mlir_op.`pop.variadic.get`(self.value, index(idx).value)
)

@always_inline
Expand Down
7 changes: 5 additions & 2 deletions stdlib/src/builtin/builtin_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,19 @@ struct Slice(Sized, Stringable, EqualityComparable):
return len(range(self.start, self.end, self.step))

@always_inline
fn __getitem__(self, idx: Int) -> Int:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
"""Get the slice index.

Parameters:
indexer: The type of the indexing value.

Args:
idx: The index.

Returns:
The slice index.
"""
return self.start + idx * self.step
return self.start + index(idx) * self.step

@always_inline("nodebug")
fn _has_end(self) -> Bool:
Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ struct Int(
Roundable,
Stringable,
Truncable,
Indexer,
):
"""This type represents an integer value."""

Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct IntLiteral(
Roundable,
Stringable,
Truncable,
Indexer,
):
"""This type represents a static integer literal value with
infinite precision. They can't be materialized at runtime and
Expand Down
12 changes: 6 additions & 6 deletions stdlib/src/builtin/range.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange):
return self.curr

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return index(idx)

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -113,8 +113,8 @@ struct _SequentialRange(Sized, ReversibleRange):
return self.end - self.start if self.start < self.end else 0

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + index(idx)

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -185,8 +185,8 @@ struct _StridedRange(Sized, ReversibleRange):
return _div_ceil_positive(abs(self.start - self.end), abs(self.step))

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx * self.step
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + index(idx) * self.step

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down
46 changes: 39 additions & 7 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Sized,
Stringable,
Truncable,
Indexer,
):
"""Represents a small vector that is backed by a hardware vector element.

Expand Down Expand Up @@ -483,6 +484,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
rebind[Scalar[type]](self).value
)

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Returns the value as an int if it is an integral value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit here and below: missing periods

Suggested change
"""Returns the value as an int if it is an integral value
"""Returns the value as an int if it is an integral value.


Contraints:
Must be an integral value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also needs to be a scalar, could you please add that here?


Returns:
The value as an integer
"""
constrained[
type.is_integral() or type.is_bool(),
"expected integral or bool type",
]()
return self.__int__()

@always_inline
fn __str__(self) -> String:
"""Get the SIMD as a string.
Expand Down Expand Up @@ -1731,9 +1748,12 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# ===-------------------------------------------------------------------===#

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Scalar[type]:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Scalar[type]:
"""Gets an element from the vector.

Parameters:
indexer: The type of the indexing value.

Args:
idx: The element index.

Expand All @@ -1742,32 +1762,44 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""
return __mlir_op.`pop.simd.extractelement`[
_type = __mlir_type[`!pop.scalar<`, type.value, `>`]
](self.value, idx.value)
](self.value, index(idx).value)

@always_inline("nodebug")
fn __setitem__(inout self, idx: Int, val: Scalar[type]):
fn __setitem__[
indexer: Indexer
](inout self, idx: indexer, val: Scalar[type]):
"""Sets an element in the vector.

Parameters:
indexer: The type of the indexing value.

Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val.value, idx.value
self.value, val.value, index(idx).value
)

@always_inline("nodebug")
fn __setitem__(
inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`]
fn __setitem__[
indexer: Indexer
](
inout self,
idx: indexer,
val: __mlir_type[`!pop.scalar<`, type.value, `>`],
):
"""Sets an element in the vector.

Parameters:
indexer: The type of the indexing value.

Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val, idx.value
self.value, val, index(idx).value
)

fn __hash__(self) -> Int:
Expand Down
14 changes: 9 additions & 5 deletions stdlib/src/builtin/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -798,21 +798,25 @@ struct String(
"""
return len(self) > 0

fn __getitem__(self, idx: Int) -> String:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> String:
"""Gets the character at the specified position.

Parameters:
indexer: The type of the indexing value.

Args:
idx: The index value.

Returns:
A new string containing the character at the specified position.
"""
if idx < 0:
return self.__getitem__(len(self) + idx)
var index_val = index(idx)
if index_val < 0:
return self.__getitem__(len(self) + index_val)

debug_assert(0 <= idx < len(self), "index must be in range")
debug_assert(0 <= index_val < len(self), "index must be in range")
var buf = Self._buffer_type(capacity=1)
buf.append(self._buffer[idx])
buf.append(self._buffer[index_val])
buf.append(0)
return String(buf^)

Expand Down
34 changes: 34 additions & 0 deletions stdlib/src/builtin/value.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,37 @@ trait BoolableKeyElement(Boolable, KeyElement):
"""

pass


trait Indexer:
bgreni marked this conversation as resolved.
Show resolved Hide resolved
"""This trait denotes a type that can be used to index a container that
handles integral index values.

This solves the issue of being able to index data structures such as `List` with the various
integral types without being too broad and allowing types that should not be used such as float point
values.
Comment on lines +212 to +214
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style nit: please wrap lines to 80 char

"""

fn __index__(self) -> Int:
"""Return the index value

Returns:
The index value of the object
"""
Comment on lines +218 to +222
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JoeLoser Do we have styleguide recommendations on documentating method declarations in traits? And do they even generate docs correctly?

...


@always_inline("nodebug")
fn index[indexer: Indexer](idx: indexer) -> Int:
"""Returns the value of `__index__` for the given value.

Parameters:
indexer: The type of the given value.

Args:
idx: The value.

Returns:
An int respresenting the index value.
"""
return idx.__index__()
11 changes: 7 additions & 4 deletions stdlib/src/collections/inline_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,22 @@ struct InlineList[ElementType: CollectionElement, capacity: Int = 16](Sized):

@always_inline
fn __refitem__[
IntableType: Intable,
](self: Reference[Self, _, _], index: IntableType) -> Reference[
indexer: Indexer
](self: Reference[Self, _, _], idx: indexer) -> Reference[
Self.ElementType, self.is_mutable, self.lifetime
]:
"""Get a `Reference` to the element at the given index.

Parameters:
indexer: The inferred type of the indexer.

Args:
index: The index of the item.
idx: The index of the item.

Returns:
A reference to the item at the given index.
"""
var i = int(index)
var i = index(idx)
debug_assert(
-self[]._size <= i < self[]._size, "Index must be within bounds."
)
Expand Down
Loading
Loading