Skip to content

Commit

Permalink
added support for an extra case of advanced indexing to ivy get_item/…
Browse files Browse the repository at this point in the history
…set_item; this solved some rare assertion error
  • Loading branch information
AnnaTz committed Jul 13, 2023
1 parent b80f5b6 commit a922b72
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2859,8 +2859,26 @@ def set_item(
}


def _int_list_or_array(var):
# check if var is a list/tuple of integers or a 1-d array of integers
return \
(isinstance(var, (list, tuple)) and all(isinstance(i, int) for i in var)) or \
(ivy.is_array(var) and ivy.is_int_dtype(var) and len(var.shape) == 1)


def _parse_query(query, x_shape):
query = (query,) if not isinstance(query, (tuple, list)) else query
if isinstance(query, tuple) and all(_int_list_or_array(q) for q in query):
query = list(query) if isinstance(query, tuple) else query
for i, idx in enumerate(query):
if ivy.is_array(idx):
query[i] = ivy.where(idx < 0, idx + x_shape[i], idx)
else:
query[i] = [ii + x_shape[i] if ii < 0 else ii for ii in idx]
query = ivy.array(query)
query = query.T if len(query.shape) > 1 else query
target_shape = [query.shape[0], *x_shape[query.shape[1]:]]
return query, target_shape
query = (query,) if not isinstance(query, tuple) else query
query = (
_parse_ellipsis(query, len(x_shape))
if any(q is Ellipsis for q in query)
Expand Down Expand Up @@ -2889,15 +2907,14 @@ def _parse_query(query, x_shape):
elif isinstance(idx, int):
query[i] = ivy.array(idx + s if idx < 0 else idx)
elif isinstance(idx, (tuple, list)):
# ToDo: add handling for case of nested tuple/lists
query[i] = ivy.array([ii + s if ii < 0 else ii for ii in idx])
elif ivy.is_array(idx):
query[i] = ivy.where(idx < 0, idx + s, idx)
else:
raise ivy.exceptions.IvyException("unsupported query format")
query[i] = ivy.astype(query[i], ivy.int64)
target_shape = [s for q in query for s in q.shape]
if all(s == 1 for s in target_shape):
target_shape = list(max(query, key=lambda x: x.ndim).shape)
target_shape += list(x_shape[len(query) :])
query = [q.reshape((-1,)) if len(q.shape) > 1 else q for q in query]
grid = ivy.meshgrid(*query, indexing="ij")
Expand Down

0 comments on commit a922b72

Please sign in to comment.