Skip to content

Commit

Permalink
fix: Variable frontend methods setting ivy_array to frontend tensor (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong authored Feb 26, 2024
1 parent e22cd90 commit c7a867d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,15 @@


class ResourceVariable(tf_frontend.Variable):
pass
def __repr__(self):
return (
repr(self._ivy_array).replace(
"ivy.array",
"ivy.functional.frontends.tensorflow.python.ops.resource_variable_ops.ResourceVariable",
)[:-1]
+ ", shape="
+ str(self._ivy_array.shape)
+ ", dtype="
+ str(self._ivy_array.dtype)
+ ")"
)
7 changes: 5 additions & 2 deletions ivy/functional/frontends/tensorflow/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,25 @@ def assign(self, value, use_locking=None, name=None, read_value=True):
as_array=False,
)
self._ivy_array = value._ivy_array
return self

def assign_add(self, delta, use_locking=None, name=None, read_value=True):
ivy.utils.assertions.check_equal(
delta.ivy_array.shape if hasattr(delta, "ivy_array") else ivy.shape(delta),
self.shape,
as_array=False,
)
self._ivy_array = tf_frontend.math.add(self._ivy_array, delta._ivy_array)
self._ivy_array = ivy.add(self._ivy_array, delta._ivy_array)
return self

def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
ivy.utils.assertions.check_equal(
delta.ivy_array.shape if hasattr(delta, "ivy_array") else ivy.shape(delta),
self.shape,
as_array=False,
)
self._ivy_array = tf_frontend.math.subtract(self._ivy_array, delta._ivy_array)
self._ivy_array = ivy.subtract(self._ivy_array, delta._ivy_array)
return self

def batch_scatter_update(
self, sparse_delta, use_locking=None, name=None, read_value=True
Expand Down

0 comments on commit c7a867d

Please sign in to comment.