-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: gabrieldemarmiesse <gabrieldemarmiesse@gmail.com>
- Loading branch information
1 parent
510ab7c
commit 5d859b1
Showing
3 changed files
with
155 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# ===----------------------------------------------------------------------=== # | ||
# Copyright (c) 2024, Modular Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
# https://llvm.org/LICENSE.txt | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ===----------------------------------------------------------------------=== # | ||
"""The utilities provided in this module help normalize the access | ||
to data elements in arrays.""" | ||
|
||
|
||
fn get_out_of_bounds_error_message[ | ||
container_name: String | ||
](i: Int, container_length: Int) -> String: | ||
if container_length == 0: | ||
return ( | ||
"The " | ||
+ container_name | ||
+ " has a length of 0. " | ||
+ "Thus it's not possible to access its values with an index " | ||
+ "but the index value " | ||
+ str(i) | ||
+ " was used. " | ||
+ "Aborting now to avoid an out-of-bounds access." | ||
) | ||
else: | ||
return ( | ||
"The " | ||
+ container_name | ||
+ " has a length of " | ||
+ str(container_length) | ||
+ ". " | ||
+ "Thus the index provided should be between " | ||
+ str(-container_length) | ||
+ " (inclusive) and " | ||
+ str(container_length) | ||
+ " (exclusive) but the index value " | ||
+ str(i) | ||
+ " was used. " | ||
+ "Aborting now to avoid an out-of-bounds access." | ||
) | ||
|
||
|
||
@always_inline | ||
fn normalize_index[ | ||
inferred IndexType: Indexer, | ||
inferred ContainerType: Sized, | ||
container_name: StringLiteral, | ||
](index_value: IndexType, container: ContainerType) -> Int: | ||
"""Normalize the given index value to a valid index value for the given container length. | ||
If the provided value is negative, the `index + container_length` is returned. | ||
Parameters: | ||
IndexType: The type of the index value. Must have an `__index__` method. | ||
ContainerType: The type of the container. Must have a `__len__` method. | ||
container_name: The name of the container. Used for the error message. | ||
Args: | ||
index_value: The index value to normalize. | ||
container: The container to normalize the index for. | ||
Returns: | ||
The normalized index value. | ||
""" | ||
var index_as_int = index(index_value) | ||
var container_length = len(container) | ||
|
||
if not (-container_length <= index_as_int < container_length): | ||
# TODO: Get the container_name from the ContainerType when the compiler allows it. | ||
abort( | ||
get_out_of_bounds_error_message[container_name]( | ||
index_as_int, container_length | ||
) | ||
) | ||
if index_as_int < 0: | ||
index_as_int += container_length | ||
return index_as_int |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# ===----------------------------------------------------------------------=== # | ||
# Copyright (c) 2024, Modular Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
# https://llvm.org/LICENSE.txt | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ===----------------------------------------------------------------------=== # | ||
# RUN: %mojo %s | ||
|
||
from collections._index_normalization import ( | ||
get_out_of_bounds_error_message, | ||
normalize_index, | ||
) | ||
from testing import assert_equal | ||
|
||
|
||
def test_out_of_bounds_message(): | ||
assert_equal( | ||
get_out_of_bounds_error_message[container_name="List"](5, 2), | ||
( | ||
"The List has a length of 2. Thus the index provided should be" | ||
" between -2 (inclusive) and 2 (exclusive) but the index value 5" | ||
" was used. Aborting now to avoid an out-of-bounds access." | ||
), | ||
) | ||
|
||
assert_equal( | ||
get_out_of_bounds_error_message[container_name="List"](0, 0), | ||
( | ||
"The List has a length of 0. Thus it's not possible to access its" | ||
" values with an index but the index value 0 was used. Aborting now" | ||
" to avoid an out-of-bounds access." | ||
), | ||
) | ||
assert_equal( | ||
get_out_of_bounds_error_message[container_name="InlineArray"](8, 0), | ||
( | ||
"The InlineArray has a length of 0. Thus it's not possible to" | ||
" access its values with an index but the index value 8 was used." | ||
" Aborting now to avoid an out-of-bounds access." | ||
), | ||
) | ||
|
||
|
||
def test_normalize_index(): | ||
container = List[Int](1, 1, 1, 1) | ||
assert_equal(normalize_index[container_name=""](-4, container), 0) | ||
assert_equal(normalize_index[container_name=""](-3, container), 1) | ||
assert_equal(normalize_index[container_name=""](-2, container), 2) | ||
assert_equal(normalize_index[container_name=""](-1, container), 3) | ||
assert_equal(normalize_index[container_name=""](0, container), 0) | ||
assert_equal(normalize_index[container_name=""](1, container), 1) | ||
assert_equal(normalize_index[container_name=""](2, container), 2) | ||
assert_equal(normalize_index[container_name=""](3, container), 3) | ||
|
||
|
||
def main(): | ||
test_out_of_bounds_message() | ||
test_normalize_index() |