Skip to content

Commit

Permalink
Update core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Nov 19, 2024
1 parent e8a87c8 commit f36169e
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8013,11 +8013,20 @@ def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return aten_sub(self, other, alpha=alpha)


@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
@torch_op("aten::sum", trace_only=True)
def aten_sumt(self: TReal, dtype: int = -1) -> TReal:
"""sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
result = op.ReduceSum(self, keepdims=keepdim)
if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)
return result


@torch_op("aten::sum.dim_IntList", trace_only=True)
def aten_sum_dim_IntList(
self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
) -> TReal:
"""sum(Tensor self, SymInt dim, bool keepdim, *, ScalarType? dtype=None) -> Tensor"""
"""sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
if self_is_scalar:
Expand All @@ -8031,7 +8040,7 @@ def aten_sum_dim_IntList(
if self_is_scalar:
result = op.Squeeze(result)

if dtype != -1:
if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)

return result
Expand Down

0 comments on commit f36169e

Please sign in to comment.