Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

index_add_ #26761

Merged
merged 37 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
1ba0eab
index_add_ issue #26393
imsoumya18 Oct 8, 2023
83c86a9
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 9, 2023
d2681b0
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 10, 2023
b26d2ec
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 10, 2023
6338e5d
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 12, 2023
69b11b8
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 13, 2023
c947eb5
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 14, 2023
90c77b3
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 15, 2023
8a60203
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 16, 2023
8550b08
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 18, 2023
4535d73
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 19, 2023
b9265b5
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 22, 2023
43e83cb
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Oct 25, 2023
64e4bc1
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Nov 8, 2023
0e45dc3
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Nov 9, 2023
3fbebaf
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Nov 10, 2023
0f862b6
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Nov 19, 2023
ca84dbe
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Nov 21, 2023
dd1832d
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Dec 1, 2023
7d2be3c
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Dec 4, 2023
44ce753
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Dec 5, 2023
464ff40
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Dec 7, 2023
37fbc54
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Dec 7, 2023
22d1409
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Dec 9, 2023
597036f
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Dec 11, 2023
10ecbb1
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Jan 2, 2024
52e7d6b
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Jan 9, 2024
511b8c8
update
imsoumya18 Jan 9, 2024
448b916
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Jan 9, 2024
bafba4f
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Jan 9, 2024
5e76663
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Jan 11, 2024
699cee6
update
imsoumya18 Jan 11, 2024
c3a7095
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Jan 11, 2024
b8f6f7b
update
imsoumya18 Jan 11, 2024
98854df
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Jan 13, 2024
1911ddd
all tests passed
imsoumya18 Jan 13, 2024
e640ad2
Merge branch 'unifyai:main' into index_add_#26393
imsoumya18 Jan 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions ivy/functional/frontends/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,39 @@
)
from ivy.func_wrapper import with_unsupported_dtypes


@with_supported_dtypes(
{"2.5.1 and below": ("bool", "int32", "int64", "float16", "float32", "float64")},
"paddle",
)
@to_ivy_arrays_and_back
def index_add_(x, index, axis, value, *, name=None):
x = ivy.swapaxes(x, axis, 0)
value = ivy.swapaxes(value, axis, 0)
_to_adds = []
index = sorted(zip(ivy.to_list(index), range(len(index))), key=(lambda i: i[0]))
while index:
_curr_idx = index[0][0]
while len(_to_adds) < _curr_idx:
_to_adds.append(ivy.zeros_like(value[0]))
_to_add_cum = ivy.get_item(value, index[0][1])
while (len(index)) > 1 and (index[0][0] == index[1][0]):
_to_add_cum = _to_add_cum + ivy.get_item(value, index.pop(1)[1])
index.pop(0)
_to_adds.append(_to_add_cum)
while len(_to_adds) < x.shape[0]:
_to_adds.append(ivy.zeros_like(value[0]))
_to_adds = ivy.stack(_to_adds)
if len(x.shape) < 2:
# Added this line due to the paddle backend treating scalars as 1-d arrays
_to_adds = ivy.flatten(_to_adds)

ret = ivy.add(x, _to_adds)
ret = ivy.swapaxes(ret, axis, 0)
x = ret
return x


# NOTE:
# Only inplace functions are to be added in this file.
# Please add non-inplace counterparts to `/frontends/paddle/manipulation.py`.
Expand All @@ -17,6 +50,5 @@
)
@to_ivy_arrays_and_back
def reshape_(x, shape):
ret = ivy.reshape(x, shape)
ivy.inplace_update(x, ret)
ivy.reshape(x, shape)
return x
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,63 @@
# --------------- #


@st.composite
def _arrays_dim_idx_n_dtypes(draw):
num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims"))
num_arrays = 2
common_shape = draw(
helpers.lists(
x=helpers.ints(min_value=2, max_value=3),
min_size=num_dims - 1,
max_size=num_dims - 1,
)
)
_dim = draw(helpers.ints(min_value=0, max_value=num_dims - 1))
unique_dims = draw(
helpers.lists(
x=helpers.ints(min_value=2, max_value=3),
min_size=num_arrays,
max_size=num_arrays,
)
)

min_dim = min(unique_dims)
max_dim = max(unique_dims)
_idx = draw(
helpers.array_values(
shape=min_dim,
dtype="int64",
min_value=0,
max_value=max_dim,
exclude_min=False,
)
)

xs = []
# available_input_types = draw(helpers.get_dtypes("integer"))
# available_input_types = ["int32", "int64", "float16", "float32", "float64"]
available_input_types = ["int32", "int64"]
input_dtypes = draw(
helpers.array_dtypes(
available_dtypes=available_input_types,
num_arrays=num_arrays,
shared_dtype=True,
)
)
for ud, dt in zip(unique_dims, input_dtypes):
x = draw(
helpers.array_values(
shape=common_shape[:_dim] + [ud] + common_shape[_dim:],
dtype=dt,
large_abs_safety_factor=2.5,
small_abs_safety_factor=2.5,
safety_factor_scale="log",
)
)
xs.append(x)
return xs, input_dtypes, _dim, _idx


@st.composite
def dtypes_x_reshape_(draw):
shape = draw(helpers.get_shape(min_num_dims=1))
Expand All @@ -25,6 +82,42 @@ def dtypes_x_reshape_(draw):
return dtypes, x, shape


# --- Main --- #
# ------------ #


@handle_frontend_test(
fn_tree="paddle.tensor.manipulation.index_add_",
xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(),
)
def test_paddle_index_add_(
*,
xs_dtypes_dim_idx,
on_device,
fn_tree,
frontend,
test_flags,
backend_fw,
):
xs, input_dtypes, axis, indices = xs_dtypes_dim_idx
if xs[0].shape[axis] < xs[1].shape[axis]:
source, input = xs
else:
input, source = xs
helpers.test_frontend_function(
input_dtypes=input_dtypes,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
frontend=frontend,
on_device=on_device,
x=input,
index=indices,
axis=axis,
value=source,
)


# reshape_
@handle_frontend_test(
fn_tree="paddle.tensor.manipulation.reshape_",
Expand Down
Loading