From a7b2218d367656ab07b95858546c3ac73f7c2495 Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Fri, 8 Mar 2024 02:09:14 +0000 Subject: [PATCH] fix: tf backend range supporting float args with int dtype (#28507) --- ivy/functional/backends/tensorflow/creation.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ivy/functional/backends/tensorflow/creation.py b/ivy/functional/backends/tensorflow/creation.py index 2b7f602d8a2dd..c6f61ba5fde3c 100644 --- a/ivy/functional/backends/tensorflow/creation.py +++ b/ivy/functional/backends/tensorflow/creation.py @@ -52,6 +52,15 @@ def arange( stop = float(start) else: stop = start + + # convert builtin types to tf scalars, as is expected by tf.range + if isinstance(start, (float, int)): + start = tf.constant(start) + if isinstance(stop, (float, int)): + stop = tf.constant(stop) + if isinstance(step, (float, int)): + step = tf.constant(step) + if dtype is None: if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): return tf.cast(tf.range(start, stop, delta=step, dtype=tf.int64), tf.int32)