diff --git a/docs/changelog.md b/docs/changelog.md index ed4b6d8330..7265674c07 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -22,6 +22,20 @@ what we publish. [PR #2609](https://github.com/modularml/mojo/pull/2609) by [@mzaks](https://github.com/mzaks) +- Mojo functions can return an auto-dereferenced refeference to storage with a + new `ref` keyword in the result type specifier. For example: + + ```mojo + struct Pair: + var first: Int + var second: Int + fn get_first_ref(inout self) -> ref[__lifetime_of(self)] Int: + return self.first + fn show_mutation(): + var somePair = ... + get_first_ref(somePair) = 1 + ``` + - Mojo has introduced `@parameter for`, a new feature for compile-time programming. `@parameter for` defines a for loop where the sequence and the induction values in the sequence must be parameter values. For example: diff --git a/docs/manual/types.ipynb b/docs/manual/types.ipynb index 6ceffe596b..1a454a11b4 100644 --- a/docs/manual/types.ipynb +++ b/docs/manual/types.ipynb @@ -530,7 +530,7 @@ "Many types have a boolean representation. Any type that implements the \n", "[`Boolable`](/mojo/stdlib/builtin/bool/Boolable) trait has a boolean \n", "representation. As a general principle, collections evaluate as True if they \n", - "contain any elements, False if the are empty; strings evaluate as True if they\n", + "contain any elements, False if they are empty; strings evaluate as True if they\n", "have a non-zero length." ] }, diff --git a/stdlib/COMPATIBLE_COMPILER_VERSION b/stdlib/COMPATIBLE_COMPILER_VERSION index 3fa5fb09ab..353e233e87 100644 --- a/stdlib/COMPATIBLE_COMPILER_VERSION +++ b/stdlib/COMPATIBLE_COMPILER_VERSION @@ -1 +1 @@ -2024.5.2514 +2024.5.2605 diff --git a/stdlib/src/builtin/builtin_list.mojo b/stdlib/src/builtin/builtin_list.mojo index d25572f336..0135d962f8 100644 --- a/stdlib/src/builtin/builtin_list.mojo +++ b/stdlib/src/builtin/builtin_list.mojo @@ -194,7 +194,8 @@ struct _VariadicListMemIter[ self.index += 1 # TODO: Need to make this return a dereferenced reference, not a # reference that must be deref'd by the user. - return self.src[].__getitem__(self.index - 1) + # NOTE: Using UnsafePointer here to get lifetimes to match. + return UnsafePointer.address_of(self.src[][self.index - 1])[] fn __len__(self) -> Int: return len(self.src[]) - self.index @@ -355,29 +356,10 @@ struct VariadicListMem[ """ return __mlir_op.`pop.variadic.size`(self.value) - # TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol - # allowing us to get rid of this and make foreach iteration clean. @always_inline - fn __getitem__(self, idx: Int) -> Self.reference_type: - """Gets a single element on the variadic list. - - Args: - idx: The index of the element to access on the list. - - Returns: - A low-level pointer to the element on the list corresponding to the - given index. - """ - return Self.reference_type( - __mlir_op.`pop.variadic.get`(self.value, idx.value) - ) - - @always_inline - fn __refitem__( + fn __getitem__( self, idx: Int - ) -> Reference[ - element_type, - Bool {value: elt_is_mutable}, + ) -> ref [ _lit_lifetime_union[ Bool {value: elt_is_mutable}, lifetime, @@ -387,8 +369,8 @@ struct VariadicListMem[ _lit_mut_cast[ False, __lifetime_of(self), Bool {value: elt_is_mutable} ].result, - ].result, - ]: + ].result + ] element_type: """Gets a single element on the variadic list. Args: @@ -398,7 +380,7 @@ struct VariadicListMem[ A low-level pointer to the element on the list corresponding to the given index. """ - return __mlir_op.`pop.variadic.get`(self.value, idx.value) + return Reference(__mlir_op.`pop.variadic.get`(self.value, idx.value))[] fn __iter__( self, @@ -590,13 +572,9 @@ struct VariadicPack[ return Self.__len__() @always_inline - fn __refitem__[ + fn __getitem__[ index: Int - ](self) -> Reference[ - element_types[index.value], - Bool {value: Self.elt_is_mutable}, - Self.lifetime, - ]: + ](self) -> ref [Self.lifetime] element_types[index.value]: """Return a reference to an element of the pack. Parameters: @@ -618,7 +596,7 @@ struct VariadicPack[ Bool {value: Self.elt_is_mutable}, Self.lifetime, ] - return rebind[result_ref._mlir_type](ref_elt) + return Reference(rebind[result_ref._mlir_type](ref_elt))[] @always_inline fn each[func: fn[T: element_trait] (T) capturing -> None](self): diff --git a/stdlib/src/builtin/string.mojo b/stdlib/src/builtin/string.mojo index 1614f4e17e..77a0003724 100644 --- a/stdlib/src/builtin/string.mojo +++ b/stdlib/src/builtin/string.mojo @@ -236,7 +236,7 @@ fn _atol(str_ref: StringRef, base: Int = 10) raises -> Int: alias ord_0 = ord("0") # FIXME: - # Change this to `alias` after fixing support for __refitem__ of alias. + # Change this to `alias` after fixing support for __getitem__ of alias. var ord_letter_min = (ord("a"), ord("A")) alias ord_underscore = ord("_") diff --git a/stdlib/src/builtin/tuple.mojo b/stdlib/src/builtin/tuple.mojo index 45c150a092..885220c040 100644 --- a/stdlib/src/builtin/tuple.mojo +++ b/stdlib/src/builtin/tuple.mojo @@ -150,11 +150,19 @@ struct Tuple[*element_types: Movable](Sized, Movable): return Self.__len__() @always_inline("nodebug") - fn __refitem__[ + fn __getitem__[ idx: Int - ](self: Reference[Self, _, _]) -> Reference[ - element_types[idx.value], self.is_mutable, self.lifetime + ](self: Reference[Self, _, _]) -> ref [self.lifetime] element_types[ + idx.value ]: + """Get a reference to an element in the tuple. + + Parameters: + idx: The element to return. + + Returns: + A referece to the specified element. + """ # Return a reference to an element at the specified index, propagating # mutability of self. var storage_kgen_ptr = UnsafePointer.address_of(self[].storage).address @@ -220,16 +228,8 @@ struct Tuple[*element_types: Movable](Sized, Movable): @parameter if _type_is_eq[T, element_types[i]](): - var tmp_ref = self.__refitem__[i]() - var tmp = rebind[ - Reference[ - T, - tmp_ref.is_mutable, - tmp_ref.lifetime, - tmp_ref.address_space, - ] - ](tmp_ref) - if tmp[].__eq__(value): + var elt_ptr = UnsafePointer.address_of(self[i]).bitcast[T]() + if elt_ptr[].__eq__(value): return True return False diff --git a/stdlib/src/collections/dict.mojo b/stdlib/src/collections/dict.mojo index 9d348ecc04..e3977e1371 100644 --- a/stdlib/src/collections/dict.mojo +++ b/stdlib/src/collections/dict.mojo @@ -546,7 +546,7 @@ struct Dict[K: KeyElement, V: CollectionElement]( """ return self._find_ref(key)[] - # TODO(MSTDL-452): rename to __refitem__ + # TODO(MSTDL-452): rename to __getitem__ returning a reference fn __get_ref( self: Reference[Self, _, _], key: K ) raises -> Reference[V, self.is_mutable, self.lifetime]: diff --git a/stdlib/src/collections/inline_list.mojo b/stdlib/src/collections/inline_list.mojo index e53b9f0f06..b1ccfcc860 100644 --- a/stdlib/src/collections/inline_list.mojo +++ b/stdlib/src/collections/inline_list.mojo @@ -58,10 +58,10 @@ struct _InlineListIter[ @parameter if forward: self.index += 1 - return self.src[].__refitem__(self.index - 1) + return self.src[][self.index - 1] else: self.index -= 1 - return self.src[].__refitem__(self.index) + return self.src[][self.index] fn __len__(self) -> Int: @parameter @@ -128,9 +128,9 @@ struct InlineList[ElementType: CollectionElement, capacity: Int = 16](Sized): self._size += 1 @always_inline - fn __refitem__( + fn __getitem__( self: Reference[Self, _, _], owned idx: Int - ) -> Reference[Self.ElementType, self.is_mutable, self.lifetime]: + ) -> ref [self.lifetime] Self.ElementType: """Get a `Reference` to the element at the given index. Args: diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index 5d6c502f00..79d60fce10 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -910,7 +910,7 @@ struct List[T: CollectionElement, small_buffer_size: Int = 0]( return (self.data + normalized_idx)[] - # TODO(30737): Replace __getitem__ with this as __refitem__, but lots of places use it + # TODO(30737): Replace __getitem__ with this, but lots of places use it fn __get_ref( self: Reference[Self, _, _], i: Int ) -> Reference[T, self.is_mutable, self.lifetime]: diff --git a/stdlib/src/memory/arc.mojo b/stdlib/src/memory/arc.mojo index 836bd4a0a6..ffa3354a71 100644 --- a/stdlib/src/memory/arc.mojo +++ b/stdlib/src/memory/arc.mojo @@ -101,9 +101,7 @@ struct Arc[T: Movable](CollectionElement): # FIXME: This isn't right - the element should be mutable regardless # of whether the 'self' type is mutable. - fn __refitem__( - self: Reference[Self, _, _] - ) -> Reference[T, self.is_mutable, self.lifetime]: + fn __getitem__(self: Reference[Self, _, _]) -> ref [self.lifetime] T: """Returns a Reference to the managed value. Returns: @@ -117,4 +115,4 @@ struct Arc[T: Movable](CollectionElement): Returns: The UnsafePointer to the underlying memory. """ - return UnsafePointer.address_of(self._inner[].payload) + return UnsafePointer.address_of(self._inner[].payload)[] diff --git a/stdlib/src/utils/span.mojo b/stdlib/src/utils/span.mojo index d1ab7f5a61..b3c2c3c0c9 100644 --- a/stdlib/src/utils/span.mojo +++ b/stdlib/src/utils/span.mojo @@ -169,10 +169,10 @@ struct Span[ value: The value to set at the given index. """ # note that self._refitem__ is already bounds checking - var ref = Reference[T, __mlir_attr.`1: i1`, __lifetime_of(self)]( + var r = Reference[T, __mlir_attr.`1: i1`, __lifetime_of(self)]( UnsafePointer(self._refitem__(idx))[] ) - ref[] = value + r[] = value @always_inline fn __getitem__(self, slice: Slice) -> Self: diff --git a/stdlib/src/utils/static_tuple.mojo b/stdlib/src/utils/static_tuple.mojo index 408c3c2438..9ab41f97eb 100644 --- a/stdlib/src/utils/static_tuple.mojo +++ b/stdlib/src/utils/static_tuple.mojo @@ -331,9 +331,9 @@ struct InlineArray[ElementType: CollectionElement, size: Int](Sized): @parameter for i in range(size): - var ref = self._get_reference_unsafe(i) + var eltref = self._get_reference_unsafe(i) initialize_pointee_move( - UnsafePointer[Self.ElementType](ref), elems[i] + UnsafePointer[Self.ElementType](eltref), elems[i] ) # ===------------------------------------------------------------------===# @@ -341,11 +341,11 @@ struct InlineArray[ElementType: CollectionElement, size: Int](Sized): # ===------------------------------------------------------------------===# @always_inline("nodebug") - fn __refitem__[ + fn __getitem__[ IntableType: Intable, - ](self: Reference[Self, _, _], index: IntableType) -> Reference[ - Self.ElementType, self.is_mutable, self.lifetime - ]: + ](self: Reference[Self, _, _], index: IntableType) -> ref [ + self.lifetime + ] Self.ElementType: """Get a `Reference` to the element at the given index. Parameters: @@ -362,15 +362,13 @@ struct InlineArray[ElementType: CollectionElement, size: Int](Sized): if normalized_idx < 0: normalized_idx += size - return self[]._get_reference_unsafe(normalized_idx) + return self[]._get_reference_unsafe(normalized_idx)[] @always_inline("nodebug") - fn __refitem__[ + fn __getitem__[ IntableType: Intable, index: IntableType, - ](self: Reference[Self, _, _]) -> Reference[ - Self.ElementType, self.is_mutable, self.lifetime - ]: + ](self: Reference[Self, _, _]) -> ref [self.lifetime] Self.ElementType: """Get a `Reference` to the element at the given index. Parameters: @@ -389,7 +387,7 @@ struct InlineArray[ElementType: CollectionElement, size: Int](Sized): if i < 0: normalized_idx += size - return self[]._get_reference_unsafe(normalized_idx) + return self[]._get_reference_unsafe(normalized_idx)[] # ===------------------------------------------------------------------=== # # Trait implementations @@ -414,7 +412,7 @@ struct InlineArray[ElementType: CollectionElement, size: Int](Sized): ) -> Reference[Self.ElementType, self.is_mutable, self.lifetime]: """Get a reference to an element of self without checking index bounds. - Users should opt for `__refitem__` instead of this method. + Users should opt for `__getitem__` instead of this method. """ var ptr = __mlir_op.`pop.array.gep`( UnsafePointer.address_of(self[]._array).address, diff --git a/stdlib/src/utils/variant.mojo b/stdlib/src/utils/variant.mojo index 3adcf36ed7..b1dcc7cd57 100644 --- a/stdlib/src/utils/variant.mojo +++ b/stdlib/src/utils/variant.mojo @@ -210,11 +210,9 @@ struct Variant[*Ts: CollectionElement](CollectionElement): # Operator dunders # ===-------------------------------------------------------------------===# - fn __refitem__[ + fn __getitem__[ T: CollectionElement - ](self: Reference[Self, _, _]) -> Reference[ - T, self.is_mutable, self.lifetime - ]: + ](self: Reference[Self, _, _]) -> ref [self.lifetime] T: """Get the value out of the variant as a type-checked type. This explicitly check that your value is of that type! @@ -233,7 +231,7 @@ struct Variant[*Ts: CollectionElement](CollectionElement): if not self[].isa[T](): abort("get: wrong variant type") - return self[].unsafe_get[T]() + return self[].unsafe_get[T]()[] # ===-------------------------------------------------------------------===# # Methods