Skip to content

Commit

Permalink
fix: use primary get_item implementation for 1d tensor queries with…
Browse files Browse the repository at this point in the history
… tf, and added a test for partial_mixed_handler to check for this (#28456)
  • Loading branch information
mattbarrett98 authored Mar 4, 2024
1 parent 9ecf0f2 commit 2c11f9f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
18 changes: 12 additions & 6 deletions ivy/functional/backends/tensorflow/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,21 @@ def get_item(
if ivy.is_array(query) and ivy.is_bool_dtype(query):
if not len(query.shape):
return tf.expand_dims(x, 0)
if ivy.is_array(query) and not ivy.is_bool_dtype(query):
return tf.gather(x, query)
return x.__getitem__(query)


get_item.partial_mixed_handler = lambda x, query, **kwargs: (
all(_check_query(i) for i in query)
and len({i.shape for i in query if ivy.is_array(i)}) <= 1
if isinstance(query, tuple)
else _check_query(query)
)
def _get_item_condition(x, query, **kwargs):
return (
all(_check_query(i) for i in query)
and len({i.shape for i in query if ivy.is_array(i)}) <= 1
if isinstance(query, tuple)
else _check_query(query) or ivy.is_array(query)
)


get_item.partial_mixed_handler = _get_item_condition


def to_numpy(x: Union[tf.Tensor, tf.Variable], /, *, copy: bool = True) -> np.ndarray:
Expand Down
10 changes: 10 additions & 0 deletions ivy_tests/test_ivy/test_functional/test_core/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,16 @@ def test_supports_inplace_updates(
)


def test_tensorflow_get_item_condition():
from ivy.functional.backends.tensorflow.general import _get_item_condition

query = tf.constant([0])
assert _get_item_condition(None, query)

query = tf.constant([[0, 1], [2, 2]])
assert _get_item_condition(None, query)


# to_list
@handle_test(
fn_tree="functional.ivy.to_list",
Expand Down

0 comments on commit 2c11f9f

Please sign in to comment.