Skip to content

Commit

Permalink
fix: tf backend set_item
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Sep 16, 2024
1 parent f85e3e5 commit d5e3461
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions ivy/functional/backends/tensorflow/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,26 +86,26 @@ def set_item(
# TODO: we should re-write this at some point so it's compatible with tf.function (don't use numpy as an intermediary)
# when doing this, be sure to check the performance of the function on large tensors, compared to this implementation

x_np = x.numpy()
val_np = val.numpy()
if tf.is_tensor(x):
x = x.numpy()
if tf.is_tensor(val):
val = val.numpy()

if isinstance(query, (tf.Tensor, tf.Variable)):
query_np = query.numpy()
query = query.numpy()
elif isinstance(query, tuple):
query_np = tuple(
query = tuple(
q.numpy() if isinstance(q, (tf.Tensor, tf.Variable)) else q
for q in query
)
else:
query_np = query

x_np[query_np] = val_np
x[query] = val

if isinstance(x, tf.Variable) and not copy:
x.assign(x_np)
x.assign(x)
return x
else:
return tf.Variable(x_np) if isinstance(x, tf.Variable) else tf.convert_to_tensor(x_np)
return tf.Variable(x) if isinstance(x, tf.Variable) else tf.convert_to_tensor(x)


def to_numpy(x: Union[tf.Tensor, tf.Variable], /, *, copy: bool = True) -> np.ndarray:
Expand Down

0 comments on commit d5e3461

Please sign in to comment.