Commit 93f71fb

Added tf.keras.backend frontends and fixed unsupported dtypes function (#27771)

Added the tf.keras.backend.rnn, tf.keras.backend.dot and tf.keras.backend.bias_add frontends
Added the ivy.rnn function
Fixed the unsupported-dtypes inference for frontend code that calls other frontend code, which in turn calls ivy code
Fixed ivy.Array.__repr__ so that it works when printing arrays created with with_backend
Fixed the issue with paddle.fluid being deprecated in the recent paddle release
vedpatwardhan authored Dec 27, 2023
1 parent b931822 commit 93f71fb
Showing 24 changed files with 954 additions and 36 deletions.
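
For context, a minimal sketch of how the new keras.backend frontends are meant to be used (illustrative values; the import paths follow the frontend layout added in this commit):

    import ivy
    import ivy.functional.frontends.tensorflow as tf_frontend

    ivy.set_backend("tensorflow")

    x = tf_frontend.constant([[1.0, 2.0], [3.0, 4.0]])
    y = tf_frontend.constant([[5.0], [6.0]])

    # mirrors tf.keras.backend.dot
    out = tf_frontend.keras.backend.dot(x, y)  # shape (2, 1)

    # mirrors tf.keras.backend.bias_add
    bias = tf_frontend.constant([0.5])
    out = tf_frontend.keras.backend.bias_add(out, bias, data_format="channels_last")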
8 changes: 7 additions & 1 deletion ivy/data_classes/array/array.py
@@ -399,7 +399,13 @@ def __repr__(self):
self._post_repr = ")"
sig_fig = ivy.array_significant_figures
dec_vals = ivy.array_decimal_values
backend = ivy.with_backend(self.backend)
if self.backend == "" or ivy.is_local():
            # If the array was constructed using the implicit backend
backend = ivy.current_backend()
else:
            # Required in the case that the backend is different
            # from the currently set backend
backend = ivy.with_backend(self.backend)
arr_np = backend.to_numpy(self._data)
rep = (
np.array(ivy.vec_sig_fig(arr_np, sig_fig))
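
A sketch of the failure mode this hunk addresses (the exact repro is an assumption; the hunk itself only shows the backend-resolution change):

    import ivy

    torch_ivy = ivy.with_backend("torch")
    x = torch_ivy.array([1.0, 2.0, 3.0])
    print(x)  # previously __repr__ called ivy.with_backend(self.backend)
              # unconditionally, which could fail here; it now falls back to
              # ivy.current_backend() for implicit-backend and local arrays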
75 changes: 75 additions & 0 deletions ivy/data_classes/container/experimental/layers.py
@@ -2826,3 +2826,78 @@ def max_unpool1d(
padding=padding,
data_format=data_format,
)

@staticmethod
def static_rnn(
step_function: Callable,
inputs: ivy.Array,
initial_states: List[ivy.Array],
/,
*,
go_backwards: bool = False,
mask: Optional[ivy.Array] = None,
constants: Optional[ivy.Array] = None,
unroll: bool = False,
input_length: Optional[int] = None,
time_major: bool = False,
zero_output_for_mask: bool = False,
return_all_outputs: bool = True,
) -> ivy.Container:
"""ivy.Container static method variant of ivy.rnn.
Parameters
----------
step_function
RNN step function.
inputs
Array of temporal data of shape (samples, time, ...).
initial_states
Array with shape (samples, state_size).
go_backwards
If True, do the iteration over the time dimension in reverse order and
return the reversed sequence.
mask
Binary array with shape (samples, time, 1), with a zero for every element
that is masked.
constants
List of constant values passed at each step.
unroll
            Whether to use a pure Python while loop or ivy.while_loop.
input_length
            An integer or a 1-D array, depending on whether the time dimension is
            fixed-length. For variable-length input, it is used for masking when no
            mask is specified.
time_major
If True, the inputs and outputs will be in shape (timesteps, batch, ...)
whereas in the False case, it will be (batch, timesteps, ...).
zero_output_for_mask
            If True, the output for the masked timestep will be zeros, whereas in
            the False case, the output from the previous timestep is returned.
return_all_outputs
If True, return the recurrent outputs for all timesteps in the sequence. If
            False, only return the output for the last timestep.

        Returns
-------
ret
A tuple of
- the latest output of the rnn of shape (samples, ...)
- the output of the rnn of shape (samples, time, ...) if
return_all_outputs=True else (samples, 1, ...)
            - list of tensors, latest states returned by the step function, of shape
              (samples, ...)
"""
return ContainerBase.cont_multi_map_in_function(
"rnn",
step_function,
inputs,
initial_states,
go_backwards=go_backwards,
mask=mask,
constants=constants,
unroll=unroll,
input_length=input_length,
time_major=time_major,
zero_output_for_mask=zero_output_for_mask,
return_all_outputs=return_all_outputs,
)
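
A minimal usage sketch for the new ivy.rnn (the identity step function and the Keras-style (output, new_states) contract are illustrative assumptions):

    import ivy

    def step(inputs, states):
        # inputs: the (samples, features) slice for the current timestep
        return inputs, states

    inputs = ivy.random_uniform(shape=(2, 5, 3))  # (samples, time, features)
    initial_states = [ivy.zeros((2, 3))]

    last_output, outputs, new_states = ivy.rnn(step, inputs, initial_states)
    # outputs has shape (2, 5, 3), since return_all_outputs defaults to True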
3 changes: 2 additions & 1 deletion ivy/functional/backends/jax/control_flow_ops.py
@@ -16,7 +16,8 @@ def body_fn_wrapper(loop_vars):
def test_fn_wrapper(loop_vars):
return test_fn(*loop_vars)

vars = list(vars.values())
if isinstance(vars, dict):
vars = list(vars.values())
with jax.disable_jit():
final_loop_vars = jax.lax.while_loop(test_fn_wrapper, body_fn_wrapper, vars)
return final_loop_vars
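
The same guard recurs in the numpy, paddle, torch and tensorflow backends below: while_loop now accepts vars as either a dict (whose values are unpacked) or a sequence (passed through unchanged). A sketch of both call styles against this backend function:

    def test_fn(i):
        return i < 3

    def body_fn(i):
        return i + 1

    while_loop(test_fn, body_fn, {"i": 0})  # dict form, as before
    while_loop(test_fn, body_fn, (0,))      # sequence form, now also supported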
4 changes: 3 additions & 1 deletion ivy/functional/backends/numpy/control_flow_ops.py
@@ -14,7 +14,9 @@ def cond(*_):


def while_loop(test_fn, body_fn, vars):
result = list(vars.values())
result = vars
if isinstance(vars, dict):
result = list(vars.values())
while test_fn(*result):
result = body_fn(*result)
if not isinstance(result, tuple):
4 changes: 4 additions & 0 deletions ivy/functional/backends/paddle/control_flow_ops.py
@@ -12,6 +12,10 @@ def if_else(cond, body_fn, orelse_fn, vars):

def while_loop(test_fn, body_fn, vars):
result = vars
if isinstance(vars, dict):
result = list(vars.values())
while test_fn(*result):
result = body_fn(*result)
if not isinstance(result, tuple):
result = (result,)
return result
2 changes: 1 addition & 1 deletion ivy/functional/backends/paddle/elementwise.py
@@ -28,7 +28,7 @@ def _elementwise_helper(x1, x2):


@with_unsupported_dtypes(
{"2.5.1 and below": ("int8", "uint8", "float16", "bool", "bfloat16")},
{"2.5.2 and below": ("int8", "uint8", "float16", "bool", "bfloat16")},
backend_version,
)
def add(
18 changes: 17 additions & 1 deletion ivy/functional/backends/paddle/manipulation.py
@@ -91,7 +91,7 @@ def flip(


@with_unsupported_dtypes(
{"2.5.2 and below": ("int16", "int8", "uint8", "bfloat16")}, backend_version
{"2.5.2 and below": ("uint8", "int8", "int16", "bfloat16")}, backend_version
)
def permute_dims(
x: paddle.Tensor,
@@ -417,6 +417,22 @@ def zero_pad(
return paddle_backend.constant_pad(x, pad_width=pad_width, value=0)


@with_supported_dtypes(
{
"2.5.2 and below": (
"bool",
"int32",
"int64",
"float16",
"bfloat16",
"float32",
"float64",
"complex64",
"complex128",
)
},
backend_version,
)
def swapaxes(
x: paddle.Tensor,
axis0: int,
10 changes: 5 additions & 5 deletions ivy/functional/backends/paddle/random.py
@@ -8,7 +8,7 @@

# local
import ivy
from paddle.fluid.libpaddle import Place
from paddle.device import core
from ivy.functional.ivy.random import (
_check_bounds_and_get_shape,
_randint_check_dtype_and_bound,
@@ -35,7 +35,7 @@ def random_uniform(
high: Union[float, paddle.Tensor] = 1.0,
shape: Optional[Union[paddle.Tensor, ivy.NativeShape, Sequence[int]]] = None,
dtype: paddle.dtype,
device: Place = None,
device: core.Place = None,
seed=None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
@@ -66,7 +66,7 @@ def random_normal(
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
dtype: paddle.dtype,
seed: Optional[int] = None,
device: Place = None,
device: core.Place = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
_check_valid_scale(std)
@@ -95,7 +95,7 @@ def multinomial(
batch_size: int = 1,
probs: Optional[paddle.Tensor] = None,
replace: bool = True,
device: Place = None,
device: core.Place = None,
seed: Optional[int] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
@@ -118,7 +118,7 @@ def randint(
/,
*,
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
device: Place = None,
device: core.Place = None,
dtype: Optional[Union[paddle.dtype, ivy.Dtype]] = None,
seed: Optional[int] = None,
out: Optional[paddle.Tensor] = None,
2 changes: 1 addition & 1 deletion ivy/functional/backends/tensorflow/control_flow_ops.py
@@ -24,7 +24,7 @@ def test_fn_wrapper(*loop_vars):

if not vars:
vars = (0,)
else:
elif isinstance(vars, dict):
vars = list(vars.values())
return tf.while_loop(test_fn_wrapper, body_fn_wrapper, loop_vars=vars)

33 changes: 33 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/layers.py
@@ -5,6 +5,8 @@

# local
from ivy.func_wrapper import (
inputs_to_ivy_arrays,
output_to_native_arrays,
with_unsupported_dtypes,
with_supported_dtypes,
with_supported_device_and_dtypes,
@@ -1675,3 +1677,34 @@ def sliding_window(
return tf.image.extract_patches(
images=input, sizes=kernel_size, strides=stride, rates=dilation, padding=padding
)


def rnn(
step_function,
inputs,
initial_states,
/,
*,
go_backwards: bool = False,
mask: Optional[Union[tf.Tensor, tf.Variable]] = None,
constants: Optional[Union[tf.Tensor, tf.Variable]] = None,
unroll: bool = False,
input_length: Optional[int] = None,
time_major: bool = False,
zero_output_for_mask: bool = False,
return_all_outputs: bool = True,
):
step_function = inputs_to_ivy_arrays(output_to_native_arrays(step_function))
return tf.keras.backend.rnn(
step_function,
inputs,
initial_states,
go_backwards=go_backwards,
mask=mask,
constants=constants,
unroll=unroll,
input_length=input_length,
time_major=time_major,
zero_output_for_mask=zero_output_for_mask,
return_all_outputs=return_all_outputs,
)
4 changes: 3 additions & 1 deletion ivy/functional/backends/torch/control_flow_ops.py
@@ -17,7 +17,9 @@ def cond(*_):


def while_loop(test_fn, body_fn, vars):
result = list(vars.values())
result = vars
if isinstance(vars, dict):
result = list(vars.values())
while test_fn(*result):
result = body_fn(*result)
if not isinstance(result, tuple):
@@ -213,6 +213,7 @@ def lu_factor(
raise IvyNotImplementedException()


@with_unsupported_dtypes({"2.1.1 and below": ("float16",)}, backend_version)
def dot(
a: torch.Tensor,
b: torch.Tensor,
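
This decorator marks float16 as unsupported for the torch dot backend; combined with the dtype-inference fix in this commit, a query along these lines (a sketch, assuming the existing ivy.function_unsupported_dtypes helper) should report it even when dot is reached through frontend code:

    import ivy

    ivy.set_backend("torch")
    print(ivy.function_unsupported_dtypes(ivy.dot))  # expected to include "float16"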
1 change: 1 addition & 0 deletions ivy/functional/frontends/tensorflow/keras/__init__.py
@@ -1,4 +1,5 @@
from . import activations
from . import backend
from . import layers
from . import metrics
from . import regularizers
69 changes: 69 additions & 0 deletions ivy/functional/frontends/tensorflow/keras/backend.py
@@ -0,0 +1,69 @@
import functools
import ivy
import ivy.functional.frontends.tensorflow as tf_frontend
from ivy.functional.frontends.tensorflow.func_wrapper import (
_ivy_array_to_tensorflow,
_to_ivy_array,
to_ivy_arrays_and_back,
)


def bias_add(x, bias, data_format=None):
if data_format is None:
data_format = "channels_last"
bias_shape = bias.shape
if len(bias_shape) == 1:
if data_format == "channels_first":
return tf_frontend.nn.bias_add(x, bias, data_format="NC...")
return tf_frontend.nn.bias_add(x, bias, data_format="N...C")
if x.ndim in (3, 4, 5):
if data_format == "channels_first":
bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
return x + tf_frontend.reshape(bias, bias_reshape_axis)
return x + tf_frontend.reshape(bias, (1,) + bias_shape)
return tf_frontend.nn.bias_add(x, bias)


@to_ivy_arrays_and_back
def dot(x, y):
return ivy.dot(x, y)


@to_ivy_arrays_and_back
def rnn(
step_function,
inputs,
initial_states,
go_backwards=False,
mask=None,
constants=None,
unroll=False,
input_length=None,
time_major=False,
zero_output_for_mask=False,
return_all_outputs=True,
):
@functools.wraps(step_function)
def _new_step_function(*args, **kwargs):
frontend_args = ivy.nested_map(
_ivy_array_to_tensorflow, args, include_derived=True, shallow=False
)
frontend_kwargs = ivy.nested_map(
_ivy_array_to_tensorflow, kwargs, include_derived=True, shallow=False
)
ret = step_function(*frontend_args, **frontend_kwargs)
return ivy.nested_map(_to_ivy_array, ret, include_derived=True)

return ivy.rnn(
_new_step_function,
inputs,
initial_states,
go_backwards=go_backwards,
mask=mask,
constants=constants,
unroll=unroll,
input_length=input_length,
time_major=time_major,
zero_output_for_mask=zero_output_for_mask,
return_all_outputs=return_all_outputs,
)
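
A usage sketch for the new frontend rnn (illustrative shapes; the step function follows the tf.keras.backend.rnn contract of returning (output, new_states)):

    import ivy
    import ivy.functional.frontends.tensorflow as tf_frontend

    def step(cell_inputs, cell_states):
        # _new_step_function hands this callable frontend tensors and converts
        # its outputs back to ivy arrays
        return cell_inputs, cell_states

    inputs = ivy.random_uniform(shape=(2, 5, 3))  # (samples, time, features)
    states = [ivy.zeros((2, 3))]

    last_output, outputs, new_states = tf_frontend.keras.backend.rnn(
        step, inputs, states
    )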
4 changes: 2 additions & 2 deletions ivy/functional/frontends/tensorflow/nn.py
@@ -127,8 +127,8 @@ def bias_add(value, bias, data_format=None, name=None):
if data_format is None:
data_format = "N...C"

chanel_index = data_format.find("C")
if chanel_index != 1:
channel_index = data_format.find("C")
if channel_index != 1:
return ivy.add(value, bias)
else:
value = ivy.swapaxes(value, 1, -1)
8 changes: 8 additions & 0 deletions ivy/functional/ivy/control_flow_ops.py
@@ -3,8 +3,12 @@
from ivy.utils.backend import current_backend
from ivy.func_wrapper import (
handle_array_like_without_promotion,
handle_backend_invalid,
handle_device,
outputs_to_ivy_arrays,
to_native_arrays_and_back,
)
from ivy.utils.exceptions import handle_exceptions


def if_else(
@@ -61,6 +65,10 @@ def _if_else(cond, body_fn, orelse_fn, **vars):
return _if_else(cond, body_fn, orelse_fn, **vars)


@handle_exceptions
@handle_backend_invalid
@outputs_to_ivy_arrays
@handle_device
def while_loop(
test_fn: Callable,
body_fn: Callable,