Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: manual dtype casting removal #22700

Merged
merged 54 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
30db6a3
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 26, 2023
b77ad16
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
f40c70d
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
2bf152f
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
5e19928
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
da88855
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
aae6444
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
db791d0
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
79da6bf
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
5dff2ac
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
ec7ee1b
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
116a406
feat(paddle backend)! removed manual dtype casting from functions in …
Madjid-CH Aug 27, 2023
58c2976
fix(paddle backend): reintroduced the support for complex numbers.
Madjid-CH Aug 28, 2023
049cce3
fix(paddle backend)! fixed failing pre-commit hook.
Madjid-CH Aug 28, 2023
3225b77
feat(paddle backend)! removed manual dtype casting from experimental …
Madjid-CH Aug 28, 2023
c0d7e08
merging
Madjid-CH Aug 28, 2023
992af94
feat(numpy backend)! removed manual dtype casting from the backend fu…
Madjid-CH Aug 28, 2023
849c99c
refactor(jax backend): grouped supported types in with_supported_dtyp…
Madjid-CH Aug 30, 2023
4e636c3
Merge remote-tracking branch 'upstream/main' into manual-dtype-castin…
Madjid-CH Aug 31, 2023
0622ef8
feat(jax backend)! removed manual dtype casting from the backend func…
Madjid-CH Aug 31, 2023
40ee0dc
feat(tensorflow backend)! removed manual dtype casting from the backe…
Madjid-CH Aug 31, 2023
dab7274
feat(torch backend)! removed manual dtype casting from the backend fu…
Madjid-CH Aug 31, 2023
d329020
merging
Madjid-CH Aug 31, 2023
00ddfe8
Merging 'upstream/main' into 'origin/manual-dtype-casting-removal'.
Madjid-CH Sep 5, 2023
5e22411
Merge remote-tracking branch 'upstream/main' into manual-dtype-castin…
Madjid-CH Sep 8, 2023
d66f5f2
revert(backends): undo irrelevant changes to this PR
Sep 8, 2023
9fe1105
removed uint16 from paddle backend functions and removed unrelated de…
Madjid-CH Sep 11, 2023
19c48b2
fix(tests) removed duplicate test that made the pre-commit hook fails.
Sep 11, 2023
b28f784
lint: run pre-commit hook on all files.
Sep 11, 2023
18990b6
Merge remote-tracking branch 'upstream/main' into manual-dtype-castin…
Sep 11, 2023
e008e2f
fix(paddle backend): removed uint16 from with_supported_dtype decorator.
Sep 11, 2023
e418e63
removed bfloat16 support
Madjid-CH Sep 12, 2023
2501718
refactor(backends): replaced from ivy import with_supported_dtypes wi…
Sep 12, 2023
dfefc38
fix(paddle backends):removed support for float16 in functions that do…
Sep 12, 2023
536684f
fix(backends): fixed missing argument error.
Sep 12, 2023
ce0827c
feat(paddle backend): added complex type support for logical function…
Sep 12, 2023
2cbf317
Merging.
Sep 12, 2023
6602caa
fix(paddle backends):removed support for float16 in functions that do…
Sep 12, 2023
a71d8a7
Merge remote-tracking branch 'upstream/main' into manual-dtype-castin…
Sep 13, 2023
f554061
Merge remote-tracking branch 'upstream/main' into manual-dtype-castin…
Sep 13, 2023
5d844e2
revert unreviewed changes
Sep 13, 2023
cf805ca
Merge remote-tracking branch 'upstream/main' into manual-dtype-castin…
Sep 14, 2023
ef3070f
fix(paddle backends) reintroduced complex support for paddle.exp in e…
Sep 14, 2023
d3107a5
feat(paddle backends) added complex support for divide function in el…
Sep 14, 2023
c5886c8
fix(paddle backend): removed unsupported dtypes
Sep 15, 2023
d0bb6a0
fix(paddle backend): removed unsupported dtypes from broadcast_to
Sep 15, 2023
58dbc91
made requested changes
Madjid-CH Sep 18, 2023
5b6f661
joined lines that make less than 88 characters.
Sep 18, 2023
68e2039
replaced dtypes decorator
Madjid-CH Sep 29, 2023
88ed132
🤖 Lint code
ivy-branch Sep 29, 2023
0af0309
Merge remote-tracking branch 'upstream/main' into manual-dtype-castin…
Sep 29, 2023
6e52ff4
🤖 Lint code
ivy-branch Sep 29, 2023
37fb5c3
Merge remote-tracking branch 'upstream/main' into manual-dtype-castin…
Nov 3, 2023
64aa4dc
Merge remote-tracking branch 'origin/manual-dtype-casting-removal' in…
Nov 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 32 additions & 45 deletions ivy/functional/backends/paddle/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,27 @@
# local
import ivy.functional.backends.paddle as paddle_backend
import ivy
from ivy.func_wrapper import with_unsupported_device_and_dtypes
from ivy.func_wrapper import (
with_unsupported_device_and_dtypes,
with_supported_dtypes,
with_supported_device_and_dtypes,
)
from . import backend_version


unsupported_dtypes = [
paddle.int8,
paddle.int16,
paddle.int32,
paddle.int64,
paddle.uint8,
paddle.float16,
paddle.complex64,
paddle.complex128,
paddle.bool,
]


def relu(
x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
if x.dtype in unsupported_dtypes:
if paddle.is_complex(x):
return paddle.complex(F.relu(x.real()), F.relu(x.imag()))
return F.relu(x.cast("float32")).cast(x.dtype)
@with_supported_dtypes(
{"2.5.1 and below": ("float32", "float64", "complex")},
backend_version,
)
def relu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
if paddle.is_complex(x):
return paddle.complex(F.relu(x.real()), F.relu(x.imag()))
return F.relu(x)


@with_unsupported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("bfloat16",)}}, backend_version
@with_supported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("float32", "float64", "complex")}},
backend_version,
)
def leaky_relu(
x: paddle.Tensor,
Expand All @@ -52,18 +44,17 @@ def leaky_relu(
complex_mode="jax",
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
if x.dtype in unsupported_dtypes:
if paddle.is_complex(x):
return paddle.complex(
F.leaky_relu(x.real(), negative_slope=alpha),
F.leaky_relu(x.imag(), negative_slope=alpha),
)
return F.leaky_relu(x.cast("float32"), negative_slope=alpha).cast(x.dtype)
if paddle.is_complex(x):
return paddle.complex(
F.leaky_relu(x.real(), negative_slope=alpha),
F.leaky_relu(x.imag(), negative_slope=alpha),
)
return F.leaky_relu(x, negative_slope=alpha)


@with_unsupported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("bfloat16",)}}, backend_version
@with_supported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("float32", "float64", "complex")}},
backend_version,
)
def gelu(
x: paddle.Tensor,
Expand All @@ -82,26 +73,23 @@ def gelu(
* x
* (1 + paddle_backend.tanh(sqrt_2_over_pi * (x + 0.044715 * x * x * x)))
)
if x.dtype in unsupported_dtypes:
return F.gelu(x.cast("float32"), approximate=approximate).cast(x.dtype)
return F.gelu(x, approximate=approximate)


@with_unsupported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("bfloat16",)}}, backend_version
@with_supported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("float32", "float64", "complex")}},
backend_version,
)
def sigmoid(
x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
if paddle.is_complex(x):
return 1.0 / (1.0 + paddle_backend.exp(-x))
if x.dtype in unsupported_dtypes:
return F.sigmoid(x.cast("float32")).cast(x.dtype)
return F.sigmoid(x)


@with_unsupported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("float16", "bfloat16")}}, backend_version
{"2.5.2 and below": {"cpu": ("bfloat16", "float16")}}, backend_version
)
def softmax(
x: paddle.Tensor,
Expand Down Expand Up @@ -183,8 +171,9 @@ def log_softmax(
return ret


@with_unsupported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("bfloat16",)}}, backend_version
@with_supported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("float32", "float64", "complex")}},
backend_version,
)
def mish(
x: paddle.Tensor,
Expand All @@ -193,10 +182,8 @@ def mish(
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
if x.dtype in unsupported_dtypes:
if paddle.is_complex(x):
return x * paddle_backend.tanh(paddle_backend.log1p(paddle_backend.exp(x)))
return F.mish(x.cast("float32")).cast(x.dtype)
if paddle.is_complex(x):
return x * paddle_backend.tanh(paddle_backend.log1p(paddle_backend.exp(x)))
return F.mish(x)


Expand Down
24 changes: 15 additions & 9 deletions ivy/functional/backends/paddle/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import ivy.functional.backends.paddle as paddle_backend
import numpy as np
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.functional.ivy.data_type import _handle_nestable_dtype_info
from . import backend_version


ivy_dtype_dict = {
Expand Down Expand Up @@ -139,6 +141,18 @@ def broadcast_arrays(*arrays: paddle.Tensor) -> List[paddle.Tensor]:
return result


@with_unsupported_dtypes(
{
"2.5.1 and below": (
"uint8",
"int8",
"int16",
"float16",
"bfloat16",
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
)
},
backend_version,
)
def broadcast_to(
x: paddle.Tensor,
/,
Expand All @@ -157,15 +171,7 @@ def broadcast_to(
if x.ndim > len(shape):
x = x.reshape([-1])

if x.dtype in [
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
paddle.int8,
paddle.int16,
paddle.uint8,
paddle.float16,
paddle.bfloat16,
]:
return paddle.broadcast_to(x.cast("float32"), shape).cast(x.dtype)
elif x.dtype in [paddle.complex64, paddle.complex128]:
if x.dtype in [paddle.complex64, paddle.complex128]:
x_real = paddle.broadcast_to(x.real(), shape)
x_imag = paddle.broadcast_to(x.imag(), shape)
return paddle.complex(x_real, x_imag)
Expand Down
Loading
Loading