Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Jul 11, 2024
1 parent 4c6699f commit 632e9c9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 39 deletions.
7 changes: 5 additions & 2 deletions ivy/functional/frontends/tensorflow/general_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,12 @@ def scan(
return ivy.associative_scan(elems, fn, reverse=reverse)


@with_supported_dtypes(
{"2.17.0 and below": ("int32", "int64")}, "tensorflow"
)
@to_ivy_arrays_and_back
def scatter_nd(indices, updates, shape=None, reduction="sum", out=None, name=None):
return ivy.scatter_nd(indices, updates, shape=shape, reduction=reduction, out=out)
def scatter_nd(indices, updates, shape, name=None):
return ivy.astype(ivy.scatter_nd(indices, updates, shape=shape), updates.dtype)


@to_ivy_arrays_and_back
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# global
from hypothesis import strategies as st, assume
from hypothesis.extra.numpy import arrays
import numpy as np
from tensorflow import errors as tf_errors

Expand Down Expand Up @@ -371,8 +372,8 @@ def _strided_slice_helper(draw):
def _values_and_ndindices(
draw,
*,
array_dtypes,
indices_dtypes=helpers.get_dtypes("integer"),
indices_dtypes=helpers.get_dtypes("valid"),
array_dtypes=helpers.get_dtypes("numeric"),
allow_inf=False,
x_min_value=None,
x_max_value=None,
Expand All @@ -381,9 +382,9 @@ def _values_and_ndindices(
min_dim_size=1,
max_dim_size=10,
):
# Generate the dtype and values for x
x_dtype, x, x_shape = draw(
helpers.dtype_and_values(
available_dtypes=array_dtypes,
allow_inf=allow_inf,
ret_shape=True,
min_value=x_min_value,
Expand All @@ -394,54 +395,67 @@ def _values_and_ndindices(
max_dim_size=max_dim_size,
)
)
x_dtype = x_dtype[0] if isinstance(x_dtype, (list)) else x_dtype
x = x[0] if isinstance(x, (list)) else x
# indices_dims defines how far into the array to index.
x_dtype = x_dtype[0] if isinstance(x_dtype, list) else x_dtype
x = x[0] if isinstance(x, list) else x

# Determine the number of index dimensions
indices_dims = draw(
helpers.ints(
min_value=1,
max_value=len(x_shape) - 1,
max_value=len(x_shape),
)
)

dtype_str = draw(st.sampled_from(indices_dtypes))
if dtype_str == "int16":
dtype = np.int16
elif dtype_str == "int32":
dtype = np.int32
else:
dtype = np.int64

# Generate the shape of the output tensor
output_shape = draw(
arrays(
dtype=dtype,
shape=(indices_dims,),
elements=st.integers(min_value=1, max_value=max_dim_size)
)
)

# num_ndindices defines the number of elements to generate.
num_ndindices = draw(
# Ensure output_shape is at least as large as x_shape up to indices_dims
for i in range(indices_dims):
output_shape[i] = max(output_shape[i], x_shape[i])

# Generate the number of indices
num_indices = draw(
helpers.ints(
min_value=1,
max_value=x_shape[indices_dims],
max_value=10
)
)

# updates_dims defines how far into the array to index.
# Generate the indices
indices = []
for _ in range(num_indices):
index = [draw(st.integers(min_value=0, max_value=output_shape[j] - 1)) for j in range(indices_dims)]
indices.append(index)
indices = np.array(indices, dtype=dtype)

# Generate the dtype and values for updates
updates_shape = list(indices.shape[:-1]) + list(output_shape[indices.shape[-1]:])
updates_dtype, updates = draw(
helpers.dtype_and_values(
available_dtypes=array_dtypes,
allow_inf=allow_inf,
shape=x_shape[indices_dims:],
num_arrays=num_ndindices,
shape=updates_shape,
shared_dtype=True,
)
)
updates_dtype = (
updates_dtype[0] if isinstance(updates_dtype, list) else updates_dtype
)
updates_dtype = updates_dtype[0] if isinstance(updates_dtype, list) else updates_dtype
updates = updates[0] if isinstance(updates, list) else updates

indices = []
indices_dtype = draw(st.sampled_from(indices_dtypes))
for _ in range(num_ndindices):
nd_index = []
for j in range(indices_dims):
axis_index = draw(
helpers.ints(
min_value=0,
max_value=max(0, x_shape[j] - 1),
)
)
nd_index.append(axis_index)
indices.append(nd_index)
indices = np.array(indices)
return [x_dtype, indices_dtype, updates_dtype], x, indices, updates
return [x_dtype, indices.dtype, updates_dtype], x, indices, updates, output_shape


@st.composite
Expand Down Expand Up @@ -1780,29 +1794,26 @@ def _test_fn(a, x):
@handle_frontend_test(
fn_tree="tensorflow.scatter_nd",
x=_values_and_ndindices(
array_dtypes=helpers.get_dtypes("numeric"),
indices_dtypes=["int32", "int64"],
array_dtypes=helpers.get_dtypes("numeric"),
x_min_value=0,
x_max_value=0,
min_num_dims=2,
allow_inf=False,
),
reduction=st.sampled_from(["sum", "min", "max", "replace"]),
)
def test_tensorflow_scatter_nd(
*,
x,
reduction,
test_flags,
backend_fw,
fn_tree,
frontend,
on_device,
):
(val_dtype, ind_dtype, update_dtype), vals, ind, updates = x
shape = vals.shape
(_, ind_dtype, update_dtype), _, ind, updates, shape = x
helpers.test_frontend_function(
input_dtypes=[ind_dtype, update_dtype],
input_dtypes=[update_dtype, ind_dtype],
frontend=frontend,
test_flags=test_flags,
on_device=on_device,
Expand Down

0 comments on commit 632e9c9

Please sign in to comment.