Skip to content

Commit

Permalink
[torchlib] Simplify aten_sum_dim_IntList
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Nov 19, 2024
1 parent 35b20fe commit e8a87c8
Showing 1 changed file with 6 additions and 32 deletions.
38 changes: 6 additions & 32 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8019,47 +8019,21 @@ def aten_sum_dim_IntList(
) -> TReal:
"""sum(Tensor self, SymInt dim, bool keepdim, *, ScalarType? dtype=None) -> Tensor"""

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.

# TODO: Combine the overloads when OptionalHasElement() works
if dim is None:
result = _aten_sum_dim_none(self, keepdim=keepdim)
else:
result = _aten_sum_dim_onnx(self, dim, keepdim=keepdim)

if dtype != -1:
result = op.Cast(result, to=dtype)

return result


@torch_op("aten::sum", private=True, traceable=True)
def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))

if IsScalar(dim):
if dim is None:
result = op.ReduceSum(self, keepdims=keepdim)
else:
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
dim = op.Cast(dim, to=INT64.dtype)
result = op.ReduceSum(self, dim, keepdims=keepdim)

result = op.ReduceSum(self, dim, keepdims=keepdim)
if self_is_scalar:
result = op.Squeeze(result)
return result


@torch_op("aten::sum", private=True)
def _aten_sum_dim_none(self: TReal, keepdim: bool = False) -> TReal:
self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))

result = op.ReduceSum(self, keepdims=keepdim)
if dtype != -1:
result = op.Cast(result, to=dtype)

if self_is_scalar:
result = op.Squeeze(result)
return result


Expand Down

0 comments on commit e8a87c8

Please sign in to comment.