fix: torch backend- gather fix #27757

Merged
merged 101 commits on Mar 23, 2024
Changes from 93 commits
Commits (101)
8a5020b
Draft for the ivy.unflatten | not finished
Kacper-W-Kozdon Nov 30, 2023
48df5d3
feat: Added `tridiagonal_solve` in tensorflow frontend (#23279)
AbdullahSabry Dec 1, 2023
ed14e0e
feat: added `floor_mod` in paddle frontend (#26064)
AbdullahSabry Dec 1, 2023
f0e2366
feat: add base Splitter class for decision trees in sklearn front
Ishticode Dec 1, 2023
b892dd4
🤖 Lint code
ivy-branch Dec 1, 2023
ffdf7f3
feat: add base structure for BestSplitter method which overrides init…
Ishticode Dec 1, 2023
fa6ab8e
🤖 Lint code
ivy-branch Dec 1, 2023
00a781f
feat: add DensePartitioner required for both the best splitter and ra…
Ishticode Dec 1, 2023
17325ed
feat: add custom sort function for DensePartitions's sort_samples_and…
Ishticode Dec 1, 2023
e66e1d5
test commit (#27421)
NripeshN Dec 1, 2023
cbeaa5e
feat: Added `torchTensor.frac` to torch frontend (#27417)
AbdullahSabry Dec 1, 2023
6b880f9
Add __getstate__ and __setstate__ methods to frontend torch.nn.Module
hmahmood24 Dec 2, 2023
a39a99a
fix: Fixed typos (used `codespell` pre-commit hook) (#27426)
Sai-Suraj-27 Dec 3, 2023
f56f91d
fix: Fixed passing of arguments in `get_referrers_recursive()` functi…
Sai-Suraj-27 Dec 3, 2023
392ec7c
feat: add erfc_ to torch frontend (#27291)
G544 Dec 4, 2023
7d77583
feat: add node_split_best function in sklearn frontend splitter submo…
Ishticode Dec 4, 2023
888d200
feat: define _init_split helper function used in node_split_best
Ishticode Dec 4, 2023
27a5f98
feat: define SplitRecord class used for split structure in node best …
Ishticode Dec 4, 2023
827902b
🤖 Lint code
ivy-branch Dec 4, 2023
f3986b0
🤖 Lint code
ivy-branch Dec 4, 2023
dd2a324
fix: use ivy.inf temporarily cutting out INFINITY constant to avoid u…
Ishticode Dec 4, 2023
09f96bd
refactor: remove unwanted variables in node_split_best in sklearn fro…
Ishticode Dec 4, 2023
b48a92e
fix: Fixed order of functions and added missing `staticmethod` decora…
Sai-Suraj-27 Dec 4, 2023
3543bec
fix: Fixed raising `TypeError` exception instead of `ValueError` exce…
Sai-Suraj-27 Dec 4, 2023
070bcd6
fix: Add a check in the converters to filter any NoneType model param…
hmahmood24 Dec 4, 2023
110fd68
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
64ca8d7
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
45c3ada
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
58b4b2e
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
176af0c
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
368410d
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
fdca35a
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
25a8268
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
77d2cad
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
3d6e048
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
b0da19d
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
541d857
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
7d4f58e
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
4707d0f
fixing torch_gather backend
Kacper-W-Kozdon Dec 4, 2023
1f3cee4
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
11d0055
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
2ca112c
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
9c38328
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
5c5935e
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
f13c16f
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
7d6ad4a
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
c05b919
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
f155aa6
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
a132f70
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
8170ad6
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
a6d2676
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
e6be13c
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
c2985f9
fixing torch_gather backend
Kacper-W-Kozdon Dec 5, 2023
ed7fd7c
fixing torch_gather backend, add reshape and expand
Kacper-W-Kozdon Dec 7, 2023
07f7e59
fixing torch_gather backend, add reshape and expand
Kacper-W-Kozdon Dec 7, 2023
9e5c796
fixing torch_gather backend, add reshape and expand
Kacper-W-Kozdon Dec 7, 2023
f1c6cbf
fixing torch_gather backend, add reshape and expand
Kacper-W-Kozdon Dec 7, 2023
4cf9645
fixing torch_gather backend, add reshape and expand
Kacper-W-Kozdon Dec 7, 2023
de6e629
fixing torch_gather backend, add reshape and expand
Kacper-W-Kozdon Dec 7, 2023
3ee2241
fixing torch_gather backend, add reshape and expand
Kacper-W-Kozdon Dec 7, 2023
8c8f83b
gather_backend fixes
Kacper-W-Kozdon Dec 11, 2023
49c1e8f
torch_gather_fix copied, not cleaned
Kacper-W-Kozdon Dec 14, 2023
3eed23b
torch_gather_fix copied, not cleaned
Kacper-W-Kozdon Dec 14, 2023
c31c343
torch_gather_fix copied, not cleaned
Kacper-W-Kozdon Dec 14, 2023
3711abe
torch_gather_fix merge
Kacper-W-Kozdon Dec 14, 2023
298e46a
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
ce8e863
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
93e5542
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
e1bff94
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
48131e3
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
d557df5
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
19141e2
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
54326ed
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
4b34ec7
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
7c3cfc0
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
cd07b61
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
7ac2867
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
cd02d6a
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
c587e9f
torch_gather_fix fixing problems with the dimensions of params and in…
Kacper-W-Kozdon Dec 22, 2023
a92239f
remove manipulation.py from torch.gather() PR
Kacper-W-Kozdon Dec 23, 2023
09d689e
remove miscellaneous_ops.py from torch.gather() PR
Kacper-W-Kozdon Dec 23, 2023
b2772c9
remove manipulation.py from torch.gather() PR
Kacper-W-Kozdon Dec 23, 2023
2a12e0c
torch_gather_fix for pytest
Kacper-W-Kozdon Dec 26, 2023
f1a07b1
torch_gather_fix for pytest
Kacper-W-Kozdon Dec 26, 2023
da03cae
torch_gather_fix for pytest
Kacper-W-Kozdon Dec 26, 2023
cc24292
Merge branch 'unifyai:main' into torch_gather_fix
Kacper-W-Kozdon Jan 19, 2024
5c604fc
torch gather fix- removed redundant var
Kacper-W-Kozdon Jan 22, 2024
33602e7
added tolerance for test_gather
Kacper-W-Kozdon Jan 26, 2024
f1c2460
added safety factors in test_gather- needs associated PR update to a …
Kacper-W-Kozdon Feb 6, 2024
0f7047b
torch.gather modify axis
Kacper-W-Kozdon Feb 29, 2024
53c652f
torch.gather modify axis
Kacper-W-Kozdon Feb 29, 2024
6958edf
Merge branch 'unifyai:main' into torch_gather_fix
Kacper-W-Kozdon Feb 29, 2024
c5d6f6c
🤖 Lint code
ivy-branch Feb 29, 2024
f008bf5
pull
Kacper-W-Kozdon Feb 29, 2024
f05d18f
gather- remove redundant
Kacper-W-Kozdon Mar 1, 2024
edbbf85
demos revert
Kacper-W-Kozdon Mar 4, 2024
db0a24d
Merge branch 'unifyai:main' into torch_gather_fix
Kacper-W-Kozdon Mar 10, 2024
b9dd6d6
explicit bfloat16 tolerance in core test
Kacper-W-Kozdon Mar 13, 2024
7a88a1f
explicit bfloat16 tolerance in core test
Kacper-W-Kozdon Mar 13, 2024
dd6db2c
Merge branch 'unifyai:main' into torch_gather_fix
Kacper-W-Kozdon Mar 13, 2024
6f9f2c1
Merge branch 'unifyai:main' into torch_gather_fix
Kacper-W-Kozdon Mar 23, 2024
70 changes: 60 additions & 10 deletions ivy/functional/backends/torch/general.py
@@ -190,28 +190,78 @@ def gather(
     batch_dims %= len(params.shape)
     ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims)
     result = []
+
+    def expand_p_i(params, indices, axis=axis, batch_dims=batch_dims):
+        axis %= len(params.shape)
+        indices_ex = torch.clone(indices).detach()
+        abs(params.dim() - (indices_ex.dim() - batch_dims))
+        stack_dims1 = params.shape[:axis]
+        stack_dims2 = params.shape[axis + 1 :]
+        indices_ex = indices_ex.reshape(
+            (
+                torch.Size([1 for dim in stack_dims1])
+                + torch.Size([-1])
+                + torch.Size([1 for dim in stack_dims2])
+            )
+        )
+        indices_ex = indices_ex.expand(
+            (stack_dims1 + torch.Size([-1]) + stack_dims2)
+        ).reshape((stack_dims1 + torch.Size([-1]) + stack_dims2))
+        return indices_ex, axis
+
+    final_shape = (
+        params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :]
+    )
+
     if batch_dims == 0:
-        result = torch.gather(params, axis, indices, sparse_grad=False, out=out)
+        dim_diff = abs(params.dim() - (indices.dim() - batch_dims))
+        if dim_diff != 0:
+            indices_expanded, new_axis = expand_p_i(params, indices)
+
+            result = torch.gather(
+                params, new_axis, indices_expanded.long(), sparse_grad=False, out=out
+            ).reshape(
+                params.shape[:axis]
+                + indices.shape[batch_dims:]
+                + params.shape[axis + 1 :]
+            )
+            result = result.to(dtype=params.dtype)
+            return result
+        else:
+            indices_expanded, new_axis = expand_p_i(params, indices)
+            result = torch.gather(
+                params, new_axis, indices_expanded.long(), sparse_grad=False, out=out
+            ).reshape(final_shape)
+            result = result.to(dtype=params.dtype)
+
     else:
+        indices_ex = indices
+        new_axis = axis
+        params_slices = torch.unbind(params, axis=0) if params.shape[0] > 0 else params
+        indices_slices = (
+            torch.unbind(indices_ex, axis=0) if indices.shape[0] > 0 else indices_ex
+        )
         for b in range(batch_dims):
             if b == 0:
-                zip_list = [(p, i) for p, i in zip(params, indices)]
+                zip_list = [(p, i) for p, i in zip(params_slices, indices_slices)]
             else:
                 zip_list = [
                     (p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] for p, i in z
                 ]
         for z in zip_list:
             p, i = z
-            r = torch.gather(
-                p, (axis - batch_dims) % p.ndim, i, sparse_grad=False, out=False
-            )
+
+            i_ex, new_axis = expand_p_i(p, i, axis=axis - batch_dims)
+            r = torch.gather(p, (new_axis), i_ex.long(), sparse_grad=False, out=None)
             result.append(r)
         result = torch.stack(result)
-        result = result.reshape([*params.shape[0:batch_dims], *result.shape[1:]])
-        if ivy.exists(out):
-            return ivy.inplace_update(out, result)
+
+        result = result.reshape(
+            params.shape[:axis]
+            + max(indices.shape[batch_dims:], torch.Size([1]))
+            + params.shape[axis + 1 :]
+        )
+        result = result.to(dtype=params.dtype)
+    if ivy.exists(out):
+        return ivy.inplace_update(out, result)
     return result
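The core of the fix is the reshape/expand trick in `expand_p_i`: `torch.gather` requires `index` to have the same number of dimensions as `input`, while ivy follows `tf.gather` semantics, where an index tensor of any rank selects along a single axis. The sketch below is hypothetical (the helper name `gather_like_tf` is mine, not the PR's) and assumes `batch_dims == 0`; it shows the trick in isolation:

```python
import torch

def gather_like_tf(params, indices, axis=0):
    # tf.gather-style gather implemented on top of torch.gather:
    # flatten the index tensor onto `axis`, broadcast it across every
    # other dimension of params, gather, then restore the index shape.
    axis %= params.dim()
    lead = params.shape[:axis]
    trail = params.shape[axis + 1 :]
    idx = indices.reshape(
        torch.Size([1] * len(lead)) + torch.Size([-1]) + torch.Size([1] * len(trail))
    ).expand(lead + torch.Size([-1]) + trail)
    out = torch.gather(params, axis, idx.long())
    # the gathered axis takes on the original shape of `indices`
    return out.reshape(lead + indices.shape + trail)

params = torch.arange(12.0).reshape(3, 4)
indices = torch.tensor([2, 0])
print(gather_like_tf(params, indices, axis=1))
# tensor([[ 2.,  0.], [ 6.,  4.], [10.,  8.]])
```

This matches `tf.gather(params, indices, axis=1)` for the same inputs, whereas calling `torch.gather` directly with the 1-D `indices` would raise a dimensionality error.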


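For `batch_dims > 0`, the diff unbinds the leading batch axes and gathers slice by slice before stacking the results back together. A hedged, self-contained sketch of that strategy (a recursive variant, not the PR's exact loop; all names are my assumptions):

```python
import torch

def batched_gather(params, indices, axis, batch_dims):
    # Sketch of the slice-by-slice strategy: peel off one batch axis at
    # a time with unbind, gather on each (params, indices) slice pair,
    # then stack the per-slice results back along the batch axis.
    if batch_dims == 0:
        # index_select performs a tf-style gather for a 1-D index tensor
        return torch.index_select(params, axis, indices)
    pairs = zip(torch.unbind(params, 0), torch.unbind(indices, 0))
    out = [batched_gather(p, i, axis - 1, batch_dims - 1) for p, i in pairs]
    return torch.stack(out)

params = torch.arange(6.0).reshape(2, 3)   # batch of two rows
indices = torch.tensor([[2, 0], [1, 1]])   # one index vector per row
print(batched_gather(params, indices, axis=1, batch_dims=1))
# tensor([[2., 0.], [4., 4.]])
```

Each batch slice pairs one row of `params` with one row of `indices`, which is exactly what the `zip_list` comprehension in the merged code builds iteratively.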
5 changes: 5 additions & 0 deletions ivy_tests/test_ivy/test_functional/test_core/test_general.py
@@ -1011,6 +1011,9 @@ def test_function_unsupported_devices(func, backend_fw):
         max_num_dims=5,
         min_dim_size=1,
         max_dim_size=10,
+        large_abs_safety_factor=1.5,
+        small_abs_safety_factor=1.5,
+        safety_factor_scale="log",
     ),
 )
 def test_gather(params_indices_others, test_flags, backend_fw, fn_name, on_device):
@@ -1025,6 +1028,8 @@ def test_gather(params_indices_others, test_flags, backend_fw, fn_name, on_device):
         params=params,
         indices=indices,
         axis=axis,
+        atol_=1e-3,
+        rtol_=1e-3,
         batch_dims=batch_dims,
     )
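The relaxed `atol_`/`rtol_` of 1e-3 and the extra safety factors accommodate low-precision dtypes such as bfloat16: with only 8 significant bits of mantissa, merely round-tripping a value through bfloat16 perturbs it at roughly the 1e-3 level, so exact comparison against a higher-precision ground truth would fail. A minimal illustration (not part of the test suite):

```python
import torch

# Round-trip a few float32 values through bfloat16. With ~8 significant
# bits, the rounding error for values in [1, 4) is on the order of 1e-3,
# which motivates the relaxed tolerances in test_gather above.
x32 = torch.tensor([3.14159, 2.71828, 1.41421])
x16 = x32.to(torch.bfloat16).to(torch.float32)
err = (x32 - x16).abs().max().item()
print(err)  # small but nonzero, well below 1e-2
```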
