From 5147e7180f72696515fc0068ca8e41462b4546d5 Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Tue, 27 Feb 2024 06:51:57 +0000 Subject: [PATCH 1/2] fix: dealing with boolean arrays with tuple query in ivy.get_item --- ivy/functional/ivy/general.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index dfc16539047ef..8ce71bba60b1e 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2937,6 +2937,10 @@ def _parse_query(query, x_shape, scatter=False): array_queries = ivy.broadcast_arrays( *[v for i, v in enumerate(query) if i in array_inds] ) + array_queries = [ + ivy.nonzero(q, as_tuple=False) if ivy.is_bool_dtype(q) else q + for q in array_queries + ] array_queries = [ ( ivy.where(arr < 0, arr + x_shape[i], arr).astype(ivy.int64) From 28b8382068bf142757948df6277d8af7b5a21446 Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Tue, 27 Feb 2024 08:40:11 +0000 Subject: [PATCH 2/2] made some fixes and added a test --- ivy/functional/ivy/general.py | 2 +- .../test_ivy/test_functional/test_core/test_general.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 8ce71bba60b1e..cb70031569070 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2938,7 +2938,7 @@ def _parse_query(query, x_shape, scatter=False): *[v for i, v in enumerate(query) if i in array_inds] ) array_queries = [ - ivy.nonzero(q, as_tuple=False) if ivy.is_bool_dtype(q) else q + ivy.nonzero(q, as_tuple=False)[0] if ivy.is_bool_dtype(q) else q for q in array_queries ] array_queries = [ diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_core/test_general.py index ebeecffba73dc..0a75937e8a1ca 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_general.py @@ -1074,6 +1074,15 @@ def test_get_all_arrays_in_memory(): container_flags=st.just([False]), test_with_copy=st.just(True), ) +@handle_example( + test_example=True, + test_flags={ + "num_positional_args": 2, + }, + dtypes_x_query=(["float32", "bool"], np.ones((1, 3, 3)), (np.array([True]), 2, 2)), + copy=None, + fn_name="get_item", +) def test_get_item( dtypes_x_query, copy,