Skip to content

Commit

Permalink
fix: update tf backend matmul to be compatible with tf.function
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Sep 1, 2024
1 parent c51862a commit 8a460e1
Showing 1 changed file with 7 additions and 20 deletions.
27 changes: 7 additions & 20 deletions ivy/functional/backends/tensorflow/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,37 +233,24 @@ def matmul(
x1 = tf.cast(x1, dtype_from)
x2 = tf.cast(x2, dtype_from)

if (
x1.shape == ()
or x2.shape == ()
or (len(x1.shape) == len(x2.shape) == 1 and x1.shape != x2.shape)
or (len(x1.shape) == len(x2.shape) == 1 and x1.shape != x2.shape)
or (len(x1.shape) == 1 and len(x2.shape) >= 2 and x1.shape[0] != x2.shape[-2])
or (len(x2.shape) == 1 and len(x1.shape) >= 2 and x2.shape[0] != x1.shape[-1])
or (len(x1.shape) >= 2 and len(x2.shape) >= 2 and x1.shape[-1] != x2.shape[-2])
):
raise ivy.utils.exceptions.IvyException("Error,shapes not compatible")

x1_padded = False
x1_padded_2 = False
x2_padded = False

if len(x1.shape) == len(x2.shape) == 1:
if x1.shape == 0:
if tf.reduce_all([tf.equal(tf.rank(x1), tf.rank(x2)), tf.equal(tf.rank(x1), 1)]):
if tf.equal(tf.size(x1), 0):
ret = tf.constant(0)
else:
ret = tf.reduce_sum(tf.math.multiply(x1, x2))
ret = tf.cast(ret, dtype=dtype_from) # return ret

ret = tf.cast(ret, dtype=dtype_from)
else:
if len(x1.shape) == 1:
if len(x2.shape) == 2:
if tf.equal(tf.rank(x1), 1):
if tf.equal(tf.rank(x2), 2):
x1_padded_2 = True
elif len(x2.shape) > 2:
elif tf.greater(tf.rank(x2), 2):
x1_padded = True
x1 = tf.expand_dims(x1, axis=0)

elif len(x2.shape) == 1 and len(x1.shape) >= 2:
elif tf.reduce_all([tf.equal(tf.rank(x2), 1), tf.greater_equal(tf.rank(x1), 2)]):
x2 = tf.expand_dims(x2, axis=1)
x2_padded = True
ret = tf.matmul(x1, x2)
Expand Down

0 comments on commit 8a460e1

Please sign in to comment.