
Sklearn decisiontree #23630

Merged
Changes from all commits (124 commits)
c5c6906
added sklearn classes
umairjavaid Aug 28, 2023
ab4442f
added frontend support for DecisionTrees scikit-learn
umairjavaid Aug 29, 2023
5b8cc5a
added sklearn decision tree extracted code
umairjavaid Sep 3, 2023
a8a4ef7
removed the files that are not being worked on rn
umairjavaid Sep 3, 2023
bd339b3
added apply dense
umairjavaid Sep 3, 2023
1d70bf0
fixed translation
umairjavaid Sep 3, 2023
1f3d95c
fixed tree functions
umairjavaid Sep 4, 2023
30625dd
Merge branch 'main' into sklearn-decisiontree
HaiderSultanArc Sep 4, 2023
b428c52
Update _classes.py
HaiderSultanArc Sep 4, 2023
5f27d6f
implemented compute partial in tree module
umairjavaid Sep 6, 2023
2601ef5
added splitter
umairjavaid Sep 7, 2023
cd87ce9
added splitter
umairjavaid Sep 8, 2023
b854c1b
Merge branch 'sklearn-decisiontree' of https://github.com/umairjavaid…
umairjavaid Sep 8, 2023
032a748
Update: Decision Tree
HaiderSultanArc Sep 9, 2023
4bc26b7
Merge branch 'unifyai:main' into sklearn-decisiontree
HaiderSultanArc Sep 9, 2023
ddfd42a
Frontend Paddle: pixel_unshuffle (#23289)
aibenStunner Sep 9, 2023
c98b5a3
Implemented precision_score function in _classification.py and Added …
muzakkirhussain011 Sep 9, 2023
a6307a8
Feature/add frontend function pytorch eigh (#22637)
AliTarekk Sep 9, 2023
70ec79a
🤖 Lint code
ivy-branch Sep 10, 2023
33d0042
fixed numpy.prod failing tests (#21858)
ShreyanshBardia Sep 10, 2023
09163ec
Add patchelf for binaries handling.
vaatsalya123 Sep 10, 2023
c6e6167
Update demos 🤖
ivy-branch Sep 10, 2023
c010d17
changed key and value to pos argument in multiheadattention (#23375)
Killua7362 Sep 10, 2023
0d17da0
move from keyword arg to pos arg multihead (#23376)
Killua7362 Sep 10, 2023
637bbd9
Update requirements.txt
vaatsalya123 Sep 10, 2023
7695f21
Ivy Functional API column_stack (#23110)
ReneFabricius Sep 11, 2023
d667dc9
fixing maxwell() for jax frontend (#23358)
dash96 Sep 11, 2023
b0d6f60
For double sided maxwell (#21264)
stalemate1 Sep 11, 2023
af64f7d
🤖 Lint code
ivy-branch Sep 11, 2023
e37911b
Reformatted array_equal (#23058)
JomBarce Sep 11, 2023
f5531f4
docs: ivy lint (#22991)
NripeshN Sep 11, 2023
c6dc597
fix(docs): renamed ivy.get_backend to ivy.current_backend in backend …
lucasalavapena Sep 11, 2023
1723568
refactor: Reformatted and Refactored few files to make the code bette…
Sai-Suraj-27 Sep 11, 2023
70d8127
added paddle.nanmean function frontend and test for it (#23278)
sabre-code Sep 11, 2023
6b545c2
Resolve all the issues with the synchronize_db.py script and add supp…
RashulChutani Sep 11, 2023
c05be23
Reformat synchronize_db.py script and update the MongoDB user account…
RashulChutani Sep 11, 2023
dfa2526
feat: add complex type decorator and arguments for logsigmoid (#23273)
mosesdaudu001 Sep 11, 2023
389434f
fixed wrong returned shape of paddle.mean function (#23389)
Mghrabi Sep 11, 2023
e227737
refactor: Simplified conditional logic using `De Morgan's law` at few…
Sai-Suraj-27 Sep 11, 2023
f713068
fix(ivy): Extends ivy.pow to work for all input and exponent cases
AnnaTz Sep 11, 2023
9c2dc94
feat(Paddle-frontend): added multilabel_soft_margin_loss to the paddl…
Indraneel99 Sep 11, 2023
b6cdb1f
feat(Numpy-frontend): Added eigvals function to numpy frontend (#22920)
Javeria-Siddique Sep 11, 2023
20440c5
fix(ivy): Tests and fixes ivy.pow for int/float exponents
AnnaTz Sep 11, 2023
1faa24c
rfftfreq (#23306)
Kiprop2020 Sep 11, 2023
b1a4da3
updated scipy.linalg.svdvals test (#22716)
Sep 11, 2023
b26bdac
lcm_ (#22396)
NiteshK84 Sep 11, 2023
f3ef825
feat(frontend): Added lerp method to the Paddle frontend (#22571)
he11owthere Sep 11, 2023
832e3ee
fix
HaiderSultanArc Sep 11, 2023
510ce00
chore: remove test for not fullly implemented max_unpool1d
Daniel4078 Sep 12, 2023
c541422
chore: remove frontend for not fullly implemented max_unpool1d
Daniel4078 Sep 12, 2023
7bccf71
fix(testing): update CLI flags to be dynamically retrieved (#22788)
CatB1t Sep 12, 2023
0f4a62e
fix(testing): update compile flag to be False by default
CatB1t Sep 12, 2023
6fad2fa
feat(github): Add Sherry as a testing CODEOWNER
CatB1t Sep 12, 2023
61ddd4c
Added kl_div loss to ivy experimental api (#23054)
vismaysur Sep 12, 2023
57a3a50
🤖 Lint code
ivy-branch Sep 12, 2023
f296f4b
Jax Frontend: from_dlpack (#23445)
aibenStunner Sep 12, 2023
f201cff
refactor: handle_exceptions decorator (#23383)
Madjid-CH Sep 12, 2023
06bbb9e
refactor: Manual lint fixes
KareemMAX Sep 12, 2023
48deaeb
feat: make mish activation function support complex dtype (#23136)
mohame54 Sep 12, 2023
369a21c
feat: made log_softmax function support complex dtype (#23412)
mohame54 Sep 12, 2023
0fc82a4
lint: Updated black formatter version in `.pre-commit-config.yaml` fi…
Sai-Suraj-27 Sep 12, 2023
b8227ec
🤖 Lint code
ivy-branch Sep 12, 2023
8516d3f
lint: ignore `E704`
KareemMAX Sep 12, 2023
24e561e
lint: ignore paddle frontend lint errors
KareemMAX Sep 12, 2023
d7dba52
fix(paddle_frontend): reinstate `repeat_interleave` at `manipulation…
akshatvishu Sep 12, 2023
cbd1c96
revert multiheadattention (#23460)
Killua7362 Sep 12, 2023
d20d4ee
fix: missing kwargs in `MultiHeadAttention` class forward method
MahmoudAshraf97 Sep 12, 2023
4fe1301
fix(ivy): Fixes ivy.pow for the complex raised to inf case, and for t…
AnnaTz Sep 12, 2023
5c57b6e
Update demos 🤖
ivy-branch Sep 12, 2023
9edecd9
fix(namespaces): Correct namespace references in multiple functions a…
akshatvishu Sep 12, 2023
6e48cd9
control_flow_ops refactor (#23465)
YushaArif99 Sep 12, 2023
964cf14
chore: Add @KareemMAX as demos codeowner
KareemMAX Sep 12, 2023
3048f60
Sub ivy (#23192)
khethan123 Sep 12, 2023
a71c7ce
manual lint
NripeshN Sep 12, 2023
4a67bc0
removed unnecessary restriction on dtype and fill_value in ivy.full (…
ShreyanshBardia Sep 12, 2023
8f01791
fix(testing): Fixes bug in test_function when used on functions that …
AnnaTz Sep 12, 2023
73bc791
Topk method added to paddle frontend (#22934)
Aaryan562 Sep 12, 2023
3097fac
Added jax.numpy.linalg.lstsq to JAX frontend (#22870)
Supremolink81 Sep 12, 2023
a8b7251
Update meta.py (#22592)
MahadShahid8 Sep 12, 2023
0a696f2
Implement_stft_functional_api (#22581)
Dharshannan Sep 12, 2023
56d2b96
Update demos 🤖
ivy-branch Sep 12, 2023
acc8335
refactor: improve function import exception handling in _import_fn (#…
akshatvishu Sep 12, 2023
cd27e20
manual linting
NripeshN Sep 12, 2023
6d603cd
manual lint for comparison_ops.py
NripeshN Sep 12, 2023
52bfe1c
Stateful changes 2 (#23313)
RickSanchezStoic Sep 13, 2023
4f21c7d
Update demos 🤖
ivy-branch Sep 13, 2023
e119a41
fix(torch-frontend): Fixes torch.pow value and dtype errors
AnnaTz Sep 13, 2023
7944a66
fix(ivy): remove exception traceback object printing in functional API
CatB1t Sep 13, 2023
63ce8e5
fix: max_pool2d of jax backend to return array of same dtype as input
Daniel4078 Sep 13, 2023
8ca3cc4
🤖 Lint code
ivy-branch Sep 13, 2023
476d591
Update demos 🤖
ivy-branch Sep 13, 2023
3d0c21e
ci: add semantic PR action and update welcome message (#23382)
a0m0rajab Sep 13, 2023
b54f968
Update demos 🤖
ivy-branch Sep 13, 2023
c964019
Update the XLA compiler engine.
vaatsalya123 Sep 13, 2023
7fe2d53
feat: implement complex dtypes for sigmoid (#23436)
mosesdaudu001 Sep 13, 2023
19574b7
manual pre-commit all files
NripeshN Sep 13, 2023
80989d7
Implement a new workflow to build and push the multiversion docker fi…
RashulChutani Sep 13, 2023
0ab5367
Incorporate Multi version Testing into the CI and Migrate completely …
RashulChutani Sep 13, 2023
5ca141a
Remove redundant functions, run_multiversion_tests and remove_fron_db…
RashulChutani Sep 13, 2023
c202a2a
Pass backends as arguments to the multiversion_framework_directory sc…
RashulChutani Sep 13, 2023
1e3533c
Resolve pytest not found issue and update backend_version while runni…
RashulChutani Sep 13, 2023
1b00f37
Add /bin/bash -c to docker run command to correct multiversion testin…
RashulChutani Sep 13, 2023
81cc410
Update python version in DockerfileMultiversion to 3.10 [skip ci]
RashulChutani Sep 13, 2023
dfcbb15
Reformat requirement_mappings_multiversion.json and add tensorflow-pr…
RashulChutani Sep 13, 2023
33f79de
Merge branch 'unifyai:main' into sklearn-decisiontree
HaiderSultanArc Sep 13, 2023
ba03b93
Merge branch 'sklearn-decisiontree-test' into sklearn-decisiontree
HaiderSultanArc Sep 13, 2023
0813e39
Merge remote-tracking branch 'upstream/sklearn-decisiontree-test' int…
HaiderSultanArc Sep 13, 2023
18cc469
feat(frontend): Added kaiser_bessel_derived_window to tensorflow fron…
waqaarahmed Sep 14, 2023
4153e78
Add subtract_ to paddle frontend
Guilhermeslucas Sep 14, 2023
19540d4
🤖 Lint code
ivy-branch Sep 14, 2023
a7ed845
feat(dx): enforce pre-commit & tests through pr template (#23590)
lucasalavapena Sep 14, 2023
0ff39ab
Test frontend methods refactor (#23380)
mosesdaudu001 Sep 14, 2023
5b48641
ci: show welcome message always (#23592)
a0m0rajab Sep 14, 2023
4469bfe
fix(testing): remove duplicate code in `test_function`. (#23591)
CatB1t Sep 14, 2023
8997af3
feat: add support for complex dtype to relu6 (#23599)
mosesdaudu001 Sep 14, 2023
db29dca
Docs(Error Handling): Fixed some small grammatical mistakes.
zaeemansari70 Sep 14, 2023
f8dee53
manual lint
NripeshN Sep 14, 2023
a9c3583
pixel_unshuffle function added for Paddle Frontend (#21754)
Sameerk22 Sep 14, 2023
0eb0cfb
fix: more efficient implementation of layernorm (#23620)
tomatillos Sep 14, 2023
09473c3
add max_pool3d to pytorch frontend (#22038)
progs2002 Sep 14, 2023
a240a4b
Added arctan2 math function for numpy (#23315)
Abhayooo7 Sep 14, 2023
9caa8c7
fix: activate welcome message (#23629)
a0m0rajab Sep 14, 2023
12165d7
Tree and Classes
HaiderSultanArc Sep 15, 2023
118f442
Merge branch 'unifyai:main' into sklearn-decisiontree
HaiderSultanArc Sep 15, 2023
13 changes: 13 additions & 0 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -22,8 +22,21 @@ Close #

- [ ] Did you add a function?
- [ ] Did you add the tests?
- [ ] Did you run your tests and are your tests passing?
- [ ] Did pre-commit not fail on any check?
- [ ] Did you follow the steps we provided?

<!--
Please mark your PR as a draft if you realise after the fact that your tests are not passing or
that your pre-commit check has some failures.

Here are some relevant resources regarding tests and pre-commit:

https://unify.ai/docs/ivy/overview/deep_dive/ivy_tests.html
https://unify.ai/docs/ivy/overview/deep_dive/formatting.html#pre-commit

-->

### Socials:

<!--
@@ -1,62 +1,55 @@
name: Check Semantic and welcome new contributors

on:
pull_request_target:
types:
- opened
- edited
- synchronize
- reopened
workflow_call:

permissions:
pull-requests: write

jobs:
semantics:
name: Semantics
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: amannn/action-semantic-pull-request@v3.4.0
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

pr-compliance-checks:
name: PR Compliance Checks
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: mtfoley/pr-compliance-action@v0.5.0
with:
body-auto-close: false
protected-branch-auto-close: false
body-comment: >
## Issue Reference

In order to be considered for merging, the pull request description must refer to a
specific issue number. This is described in our
[contributing guide](https://unify.ai/docs/ivy/overview/contributing/the_basics.html#todo-list-issues) and our PR template.

This check is looking for a phrase similar to: "Fixes #XYZ" or "Resolves #XYZ" where XYZ is the issue
number that this PR is meant to address.

welcome:
name: Welcome
runs-on: ubuntu-latest
timeout-minutes: 10
needs: semantics
if: github.event.action == 'opened'
steps:
- uses: actions/first-interaction@v1
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
pr-message: |-
Congrats on making your first Pull Request and thanks for supporting Ivy! 🎉
Joing the conversation in our [Discord](https://discord.com/invite/sXyFF8tDtm)

Here are some notes to understand our tests:
- We have merged all the tests in one file called \`display_test_results\` job. 👀 It contains the following two sections:
- **Combined Test Results:** This shows the results of all the ivy tests that ran on the PR. ✔️
- **New Failures Introduced:** This lists the tests that are passing on main, but fail on the PR Fork.
Please try to make sure that there are no such tests. 💪
name: Check Semantic and welcome new contributors

on:
pull_request_target:
types:
- opened
- edited
- synchronize
- reopened
workflow_call:

permissions:
pull-requests: write

jobs:
pr-compliance-checks:
name: PR Compliance Checks
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: mtfoley/pr-compliance-action@v0.5.0
with:
body-auto-close: false
protected-branch-auto-close: false
body-comment: >
## Issue Reference

In order to be considered for merging, the pull request description must refer to a
specific issue number. This is described in our
[contributing guide](https://unify.ai/docs/ivy/overview/contributing/the_basics.html#todo-list-issues) and our PR template.

This check is looking for a phrase similar to: "Fixes #XYZ" or "Resolves #XYZ" where XYZ is the issue
number that this PR is meant to address.

welcome:
name: Welcome
runs-on: ubuntu-latest
timeout-minutes: 10
if: github.event.action == 'opened'
steps:
- uses: actions/first-interaction@v1
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
pr-message: |-
Congrats on making your first Pull Request and thanks for supporting Ivy! 🎉
Join the conversation in our [Discord](https://discord.com/invite/sXyFF8tDtm)

Here are some notes to understand our tests:
- We have merged all the tests in one file called \`display_test_results\` job. 👀 It contains the following two sections:
- **Combined Test Results:** This shows the results of all the ivy tests that ran on the PR. ✔️
- **New Failures Introduced:** This lists the tests that fail on this PR.

Please make sure they are passing. 💪

Keep in mind that we will assign an engineer for this task and they will look at it based on the workload that they have, so be patient.
4 changes: 2 additions & 2 deletions docs/overview/contributing/error_handling.rst
@@ -10,7 +10,7 @@ Error Handling

This section, "Error Handling" aims to assist you in navigating through some common errors you might encounter while working with the Ivy's Functional API. We'll go through some common errors which you might encounter while working as a contributor or a developer.

#. This is the case where we pass in a dtype to `torch` which is not actually supported by the torch's native framework itself. The function which was
#. This is the case where we pass in a dtype to `torch` which is not actually supported by the torch's native framework itself.

.. code-block:: python

@@ -64,7 +64,7 @@ This section, "Error Handling" aims to assist you in navigating through some com
E
E You can reproduce this example by temporarily adding @reproduce_failure('6.82.4', b'AXicY2BAABYQwQgiAABDAAY=') as a decorator on your test case

#. This is a similar assertion as stated in point 2 but with torch and ground-truth tensorflow not matching but the matrices are quite different so there should be an issue in the backends rather than a numerical instability here:
#. This is a similar assertion as stated in point 2 but with torch and ground-truth tensorflow not matching but the matrices are quite different so there should be an issue in the backends rather than a numerical instability here.

.. code-block:: python

13 changes: 11 additions & 2 deletions ivy/data_classes/array/experimental/activations.py
@@ -123,14 +123,23 @@ def prelu(
"""
return ivy.prelu(self._data, slope, out=out)

def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def relu6(
self,
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Apply the rectified linear unit 6 function element-wise.

Parameters
----------
self
input array
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.
@@ -156,7 +165,7 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
"""
return ivy.relu6(self._data, out=out)
return ivy.relu6(self._data, complex_mode=complex_mode, out=out)
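The diff above threads the new `complex_mode` keyword through the stateful API down to `ivy.relu6`. As a minimal sketch of the underlying element-wise behaviour (in plain NumPy rather than Ivy itself), relu6 simply clamps its input to the `[0, 6]` range, matching the docstring example:

```python
import numpy as np

def relu6(x):
    # Element-wise min(max(x, 0), 6), as in the docstring example above.
    return np.minimum(np.maximum(x, 0), 6)

x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
print(relu6(x))  # [0. 0. 1. 2. 3. 4. 5. 6. 6.]
```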

def logsigmoid(
self: ivy.Array,
10 changes: 10 additions & 0 deletions ivy/data_classes/container/experimental/activations.py
@@ -329,6 +329,7 @@ def static_relu6(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
@@ -351,6 +352,9 @@
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
@@ -379,6 +383,7 @@
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
complex_mode=complex_mode,
out=out,
)

@@ -390,6 +395,7 @@ def relu6(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
@@ -412,6 +418,9 @@
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
@@ -439,6 +448,7 @@ def relu6(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
complex_mode=complex_mode,
out=out,
)

4 changes: 3 additions & 1 deletion ivy/functional/backends/jax/experimental/activations.py
@@ -23,7 +23,9 @@ def logit(
return jnp.log(x / (1 - x))


def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
def relu6(
x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None
) -> JaxArray:
relu6_func = jax.nn.relu6

# sets gradient at 0 and 6 to 0 instead of 0.5
4 changes: 3 additions & 1 deletion ivy/functional/backends/numpy/experimental/activations.py
@@ -44,7 +44,9 @@ def thresholded_relu(


@_scalar_output_to_0d_array
def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
def relu6(
x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None
) -> np.ndarray:
return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype)
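Each backend signature now accepts a `complex_mode` keyword (`"split"`, `"magnitude"`, or `"jax"`), which Ivy's `handle_complex_input` wrapper uses to decide how complex inputs are treated before reaching these real-valued kernels. The helper below is a hypothetical illustration only (not Ivy's actual wrapper): one plausible reading of a "split"-style mode is to apply the real function to the real and imaginary components independently and recombine them.

```python
import numpy as np

def relu6(x):
    return np.minimum(np.maximum(x, 0), 6)

def apply_split_mode(fn, x):
    # Hypothetical sketch of a "split" complex mode: run fn on the
    # real and imaginary parts separately, then recombine.
    if np.iscomplexobj(x):
        return fn(x.real) + 1j * fn(x.imag)
    return fn(x)

z = np.array([3 + 8j, -1 - 2j])
print(apply_split_mode(relu6, z))  # [3.+6.j 0.+0.j]
```

The `"magnitude"` and `"jax"` modes would differ in how they treat the complex plane; see `ivy.func_wrapper.handle_complex_input` (referenced in the docstrings above) for the authoritative behaviour.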


7 changes: 6 additions & 1 deletion ivy/functional/backends/paddle/experimental/activations.py
@@ -54,7 +54,12 @@ def thresholded_relu(
)


def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
)
def relu6(
x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.relu6(x)
if paddle.is_complex(x):
@@ -38,8 +38,7 @@ def thresholded_relu(
return tf.cast(tf.where(x > threshold, x, 0), x.dtype)


@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
def relu6(x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None) -> Tensor:
return tf.nn.relu6(x)


4 changes: 3 additions & 1 deletion ivy/functional/backends/torch/experimental/activations.py
@@ -34,7 +34,9 @@ def thresholded_relu(


@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version)
def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
def relu6(
x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.nn.functional.relu6(x)


2 changes: 1 addition & 1 deletion ivy/functional/frontends/jax/nn/non_linear_activations.py
@@ -273,7 +273,7 @@ def relu(x):

@to_ivy_arrays_and_back
def relu6(x):
res = ivy.relu6(x)
res = ivy.relu6(x, complex_mode="jax")
return _type_conversion_64(res)


2 changes: 2 additions & 0 deletions ivy/functional/frontends/numpy/__init__.py
@@ -565,6 +565,7 @@ def promote_types_of_numpy_inputs(
_sin,
_tan,
_degrees,
_arctan2,
)

from ivy.functional.frontends.numpy.mathematical_functions.handling_complex_numbers import ( # noqa
@@ -672,6 +673,7 @@ def promote_types_of_numpy_inputs(
arccos = ufunc("_arccos")
arcsin = ufunc("_arcsin")
arctan = ufunc("_arctan")
arctan2 = ufunc("_arctan2")
cos = ufunc("_cos")
deg2rad = ufunc("_deg2rad")
rad2deg = ufunc("_rad2deg")
@@ -103,6 +103,32 @@ def _arctan(
return ret


# arctan2


@handle_numpy_out
@handle_numpy_dtype
@to_ivy_arrays_and_back
@handle_numpy_casting
@from_zero_dim_arrays_to_scalar
def _arctan2(
x1,
x2,
/,
out=None,
*,
where=True,
casting="same_kind",
order="K",
dtype=None,
subok=True,
):
ret = ivy.atan2(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret
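The new `_arctan2` ufunc delegates to `ivy.atan2`, then applies the NumPy-style `where` mask. The key property of arctan2, shown here with the standard library's `math.atan2`, is that it is quadrant-aware: it uses the signs of both arguments, which a plain `atan(y / x)` cannot do.

```python
import math

# atan2(y, x) distinguishes all four quadrants; atan(y / x) would map
# (1, 1) and (-1, -1) to the same angle.
print(math.atan2(1.0, 1.0))    # 0.7853981633974483  (pi/4)
print(math.atan2(-1.0, -1.0))  # -2.356194490192345  (-3*pi/4)
```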


@handle_numpy_out
@handle_numpy_dtype
@to_ivy_arrays_and_back