From 57a91139618378144641ca60ef7fd872a4749bac Mon Sep 17 00:00:00 2001 From: Kaushal Phulgirkar Date: Tue, 23 Apr 2024 20:42:54 +0530 Subject: [PATCH] Added fn to normalize index calculation Signed-off-by: Kaushal Phulgirkar --- stdlib/src/collections/list.mojo | 26 +++++++++++--------------- stdlib/test/collections/test_list.mojo | 2 +- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index 38b9b62b78..742c007bb3 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -220,9 +220,7 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): """ debug_assert(i <= self.size, "insert index out of range") - var normalized_idx = i - if i < 0: - normalized_idx = max(0, len(self) + i) + var normalized_idx = self._normalize_index(i) var earlier_idx = len(self) var later_idx = len(self) - 1 @@ -292,9 +290,7 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): """ debug_assert(-len(self) <= i < len(self), "pop index out of range") - var normalized_idx = i - if i < 0: - normalized_idx += len(self) + var normalized_idx = self._normalize_index(i) var ret_val = move_from_pointee(self.data + normalized_idx) for j in range(normalized_idx + 1, self.size): @@ -490,9 +486,7 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): """ debug_assert(-self.size <= i < self.size, "index must be within bounds") - var normalized_idx = i - if i < 0: - normalized_idx += len(self) + var normalized_idx = self._normalize_index(i) destroy_pointee(self.data + normalized_idx) initialize_pointee_move(self.data + normalized_idx, value^) @@ -554,9 +548,7 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): """ debug_assert(-self.size <= i < self.size, "index must be within bounds") - var normalized_idx = i - if i < 0: - normalized_idx += len(self) + var normalized_idx = self._normalize_index(i) return (self.data + normalized_idx)[] @@ -572,9 +564,7 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): Returns: An immutable reference to the element at the given index. """ - var normalized_idx = i - if i < 0: - normalized_idx += self[].size + var normalized_idx = Reference(self)[]._normalize_index(i) return (self[].data + normalized_idx)[] @@ -671,3 +661,9 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): if elem[] == value: count += 1 return count + + @always_inline + fn _normalize_index(self, index: Int) -> Int: + if index < 0: + return _max(0, len(self) + index) + return index diff --git a/stdlib/test/collections/test_list.mojo b/stdlib/test/collections/test_list.mojo index 4f9f9ac0cb..1f5c4b77f8 100644 --- a/stdlib/test/collections/test_list.mojo +++ b/stdlib/test/collections/test_list.mojo @@ -549,7 +549,7 @@ def test_2d_dynamic_list(): def test_list_explicit_copy(): var list = List[CopyCounter]() - list.append(CopyCounter()^) + list.append(CopyCounter()) var list_copy = List(list) assert_equal(0, list.__get_ref(0)[].copy_count) assert_equal(1, list_copy.__get_ref(0)[].copy_count)