Skip to content

Commit

Permalink
feat(backends): Added a primary implementation for flatten in tensorf…
Browse files Browse the repository at this point in the history
…low backend (#28488)
  • Loading branch information
hmahmood24 authored Mar 6, 2024
1 parent 5791218 commit 514fedc
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions ivy/functional/backends/tensorflow/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,51 @@ def expand_dims(
raise ivy.utils.exceptions.IvyIndexError(error) from error


def flatten(
x: tf.Tensor,
/,
*,
copy: Optional[bool] = None,
start_dim: Optional[int] = 0,
end_dim: Optional[int] = -1,
order: Optional[str] = "C",
out: Optional[tf.Tensor] = None,
) -> tf.Tensor:
if x.shape == ():
x = tf.reshape(x, (1, -1))[0, :]
if start_dim == end_dim:
return ivy.inplace_update(out, x) if ivy.exists(out) else x
if start_dim not in range(-x.shape.rank, x.shape.rank):
raise IndexError(
"Dimension out of range (expected to be in range of"
f" {[-x.shape.rank, x.shape.rank - 1]}, but got {start_dim}"
)
if end_dim not in range(-x.shape.rank, x.shape.rank):
raise IndexError(
"Dimension out of range (expected to be in range of"
f" {[-x.shape.rank, x.shape.rank - 1]}, but got {end_dim}"
)

# If end_dim or start_dim is negative, count them from the end
if end_dim < 0:
end_dim += x.shape.rank
if start_dim < 0:
start_dim += x.shape.rank

if start_dim == end_dim:
return x

in_shape = tf.shape(x)
flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
out_shape = tf.concat(
[in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0
)
ivy.utils.assertions.check_elem_in_list(order, ["C", "F"])
if order == "F":
return _reshape_fortran_tf(x, out_shape)
return tf.reshape(x, out_shape)


def flip(
x: Union[tf.Tensor, tf.Variable],
/,
Expand Down

0 comments on commit 514fedc

Please sign in to comment.