From 514fedc4f6fc7c55ef521c7dd19d1fcb658440a3 Mon Sep 17 00:00:00 2001 From: Haris Mahmood <70361308+hmahmood24@users.noreply.github.com> Date: Wed, 6 Mar 2024 20:32:14 +0500 Subject: [PATCH] feat(backends): Added a primary implementation for flatten in tensorflow backend (#28488) --- .../backends/tensorflow/manipulation.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/ivy/functional/backends/tensorflow/manipulation.py b/ivy/functional/backends/tensorflow/manipulation.py index 8635ca5ebb173..4326b5a86caf7 100644 --- a/ivy/functional/backends/tensorflow/manipulation.py +++ b/ivy/functional/backends/tensorflow/manipulation.py @@ -69,6 +69,51 @@ def expand_dims( raise ivy.utils.exceptions.IvyIndexError(error) from error +def flatten( + x: tf.Tensor, + /, + *, + copy: Optional[bool] = None, + start_dim: Optional[int] = 0, + end_dim: Optional[int] = -1, + order: Optional[str] = "C", + out: Optional[tf.Tensor] = None, +) -> tf.Tensor: + if x.shape == (): + x = tf.reshape(x, (1, -1))[0, :] + if start_dim == end_dim: + return ivy.inplace_update(out, x) if ivy.exists(out) else x + if start_dim not in range(-x.shape.rank, x.shape.rank): + raise IndexError( + "Dimension out of range (expected to be in range of" + f" {[-x.shape.rank, x.shape.rank - 1]}, but got {start_dim}" + ) + if end_dim not in range(-x.shape.rank, x.shape.rank): + raise IndexError( + "Dimension out of range (expected to be in range of" + f" {[-x.shape.rank, x.shape.rank - 1]}, but got {end_dim}" + ) + + # If end_dim or start_dim is negative, count them from the end + if end_dim < 0: + end_dim += x.shape.rank + if start_dim < 0: + start_dim += x.shape.rank + + if start_dim == end_dim: + return x + + in_shape = tf.shape(x) + flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1]) + out_shape = tf.concat( + [in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0 + ) + ivy.utils.assertions.check_elem_in_list(order, ["C", "F"]) + if order == "F": + return _reshape_fortran_tf(x, out_shape) + return tf.reshape(x, out_shape) + + def flip( x: Union[tf.Tensor, tf.Variable], /,