Skip to content

Commit

Permalink
fix: better handling for different operand dtypes in tf einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
MahmoudAshraf97 committed Oct 28, 2023
1 parent 99b0dfa commit d919d79
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions ivy/functional/backends/tensorflow/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

# local
import ivy
from ivy.functional.ivy.statistical import _get_promoted_type_of_operands
from ivy.func_wrapper import with_unsupported_dtypes
from . import backend_version
from ivy.utils.einsum_parser import legalise_einsum_expr
Expand Down Expand Up @@ -217,6 +216,15 @@ def einsum(
*operands: Union[tf.Tensor, tf.Variable],
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
dtype = _get_promoted_type_of_operands(operands)
equation = legalise_einsum_expr(*[equation, *operands])
return tf.cast(tf.einsum(equation, *operands), dtype)
dtype_list = set(map(lambda x: x.dtype, operands))
dtype = dtype_list.pop()
if len(dtype_list) > 0:
for d in dtype_list:
dtype = ivy.promote_types(dtype, d)
dtype = ivy.as_native_dtype(dtype)
operands = list(
map(lambda x: tf.cast(x, dtype) if x.dtype != dtype else x, operands)
)

return tf.einsum(equation, *operands)

0 comments on commit d919d79

Please sign in to comment.