Skip to content

Commit

Permalink
[TORCH][TOPI] Support mean reduction for scatter_reduce (#14110)
Browse files Browse the repository at this point in the history
* support mean reduction, clean comments, extend tests

* fix pylint

---------

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
  • Loading branch information
vvchernov and vvchernov authored Feb 24, 2023
1 parent d5806ec commit 9b6df18
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 14 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
TVM_ATTR_FIELD(reduction).set_default("update").describe(
"Reduction mode of the scatter elements, "
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
"either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".");
}
};

Expand Down
3 changes: 0 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2766,9 +2766,6 @@ def scatter_reduce(self, inputs, input_types):
reduce = "min"
elif reduce == "amax":
reduce = "max"
else: # reduce == "mean"
# TODO(vvchernov): support mean reduction
raise NotImplementedError("Mean reduction has not been supported yet!")

return _op.scatter_elements(data, index, src, axis=dim, reduction=reduce)

Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,11 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
The axis to scatter elements on. It is zero by default.
reduction : string, optional
The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"]
The reduction mode for scatter. Choise is from ["update", "add", "mul", "mean", "min", max"]
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If mul, the input data will be multiplied on the update values
If mean, the input data will be mean between the update values and the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/topi/cuda/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,11 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
The axis to scatter on. It is zero by default.
reduction : optional, string
The update mode for the algorithm, either "update", "add", "mul", "min" or "max"
The update mode for the algorithm, either "update", "add", "mul", "mean", "min" or "max"
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If mul, the input data will be multiplied on the update values
If mean, the input data will be mean between the update values and the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default
Expand All @@ -258,6 +259,9 @@ def add_func(dst_ptr, dst_index, update):
def mul_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] *= update

def mean_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2

def min_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update)

Expand All @@ -271,6 +275,8 @@ def max_func(dst_ptr, dst_index, update):
reduce_func = add_func
elif reduction == "mul":
reduce_func = mul_func
elif reduction == "mean":
reduce_func = mean_func
elif reduction == "min":
reduce_func = min_func
elif reduction == "max":
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/topi/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
The update mode for the algorithm, either "update", "add", "mul", "min" or "max"
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If mul, the input data will be multiplied on the update values
If mean, the input data will be mean between the update values and the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default
Expand Down Expand Up @@ -133,6 +134,9 @@ def add_func(dst_ptr, dst_index, update):
def mul_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] *= update

def mean_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2

def min_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update)

Expand All @@ -146,13 +150,15 @@ def max_func(dst_ptr, dst_index, update):
reduce_func = add_func
elif reduction == "mul":
reduce_func = mul_func
elif reduction == "mean":
reduce_func = mean_func
elif reduction == "min":
reduce_func = min_func
elif reduction == "max":
reduce_func = max_func
else:
raise NotImplementedError(
"scatter_elements reduction not in [update, add, mul, min, max]:", reduction
"scatter_elements reduction not in [update, add, mul, mean, min, max]:", reduction
)

out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
Expand Down
6 changes: 2 additions & 4 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4252,16 +4252,14 @@ def test_fn_scatter_reduce(dim, reduce):
in_src = torch.rand(2, 5) - 1

targets = ["llvm", "cuda"]
# TODO(vvchernov): support test of mean reduction and include_self=False
for reduce in ["sum", "prod", "amin", "amax"]:
for reduce in ["sum", "prod", "amin", "amax", "mean"]:
verify_trace_model(test_fn_scatter_reduce(0, reduce), [in_data, in_index, in_src], targets)

in_data = torch.rand(2, 4) - 1
in_index = torch.tensor([[2], [3]])
in_src = torch.rand(2, 1) - 1

# TODO(vvchernov): support test of mean reduction and include_self=False
for reduce in ["sum", "prod", "amin", "amax"]:
for reduce in ["sum", "prod", "amin", "amax", "mean"]:
verify_trace_model(test_fn_scatter_reduce(1, reduce), [in_data, in_index, in_src], targets)


Expand Down

0 comments on commit 9b6df18

Please sign in to comment.