-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added tf.keras.backend frontends and fixed unsupported dtypes function (
#27771) Added the tf.keras.backend.rnn, tf.keras.backend.dot and tf.keras.backend.bias_add frontends Added the ivy.rnn function Fixes to the unsupported dtypes functions for inferring the unsupported dtypes of frontend code that uses other frontend code which uses ivy code A fix to ivy.Array.__repr__ to use it when trying to print arrays created with with_backend Fixed the issue with paddle.fluid being deprecated in the recent release
- Loading branch information
1 parent
b931822
commit 93f71fb
Showing
24 changed files
with
954 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from . import activations | ||
from . import backend | ||
from . import layers | ||
from . import metrics | ||
from . import regularizers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import functools | ||
import ivy | ||
import ivy.functional.frontends.tensorflow as tf_frontend | ||
from ivy.functional.frontends.tensorflow.func_wrapper import ( | ||
_ivy_array_to_tensorflow, | ||
_to_ivy_array, | ||
to_ivy_arrays_and_back, | ||
) | ||
|
||
|
||
def bias_add(x, bias, data_format=None): | ||
if data_format is None: | ||
data_format = "channels_last" | ||
bias_shape = bias.shape | ||
if len(bias_shape) == 1: | ||
if data_format == "channels_first": | ||
return tf_frontend.nn.bias_add(x, bias, data_format="NC...") | ||
return tf_frontend.nn.bias_add(x, bias, data_format="N...C") | ||
if x.ndim in (3, 4, 5): | ||
if data_format == "channels_first": | ||
bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1] | ||
return x + tf_frontend.reshape(bias, bias_reshape_axis) | ||
return x + tf_frontend.reshape(bias, (1,) + bias_shape) | ||
return tf_frontend.nn.bias_add(x, bias) | ||
|
||
|
||
@to_ivy_arrays_and_back | ||
def dot(x, y): | ||
return ivy.dot(x, y) | ||
|
||
|
||
@to_ivy_arrays_and_back | ||
def rnn( | ||
step_function, | ||
inputs, | ||
initial_states, | ||
go_backwards=False, | ||
mask=None, | ||
constants=None, | ||
unroll=False, | ||
input_length=None, | ||
time_major=False, | ||
zero_output_for_mask=False, | ||
return_all_outputs=True, | ||
): | ||
@functools.wraps(step_function) | ||
def _new_step_function(*args, **kwargs): | ||
frontend_args = ivy.nested_map( | ||
_ivy_array_to_tensorflow, args, include_derived=True, shallow=False | ||
) | ||
frontend_kwargs = ivy.nested_map( | ||
_ivy_array_to_tensorflow, kwargs, include_derived=True, shallow=False | ||
) | ||
ret = step_function(*frontend_args, **frontend_kwargs) | ||
return ivy.nested_map(_to_ivy_array, ret, include_derived=True) | ||
|
||
return ivy.rnn( | ||
_new_step_function, | ||
inputs, | ||
initial_states, | ||
go_backwards=go_backwards, | ||
mask=mask, | ||
constants=constants, | ||
unroll=unroll, | ||
input_length=input_length, | ||
time_major=time_major, | ||
zero_output_for_mask=zero_output_for_mask, | ||
return_all_outputs=return_all_outputs, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.