Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[TOPI][OP] cuda for argwhere #6868

Merged
merged 13 commits into from
Dec 4, 2020
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
16 changes: 1 addition & 15 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,7 @@ def compute_strided_set(attrs, inputs, output_type):
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)

# argwhere
@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type):
    """Compute definition of argwhere"""
    # Build a concrete output shape: keep static dims, and substitute a
    # symbolic var for each dynamic (Any) dimension in the inferred type.
    output_shape = []
    for s in output_type.shape:
        if hasattr(s, "value"):
            # Static dimension (IntImm-like) — keep as-is.
            output_shape.append(s)
        else:
            # see Any, replace it with a var
            output_shape.append(te.var("any_dim", "int32"))
    # argwhere always produces int32 indices.
    new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
    return [topi.argwhere(new_output_type, inputs[0])]


# Register scheduling/strategy hooks for the argwhere op.
_reg.register_schedule("argwhere", strategy.schedule_argwhere)
_reg.register_strategy("argwhere", strategy.argwhere_strategy)

# scatter
@_reg.register_compute("scatter")
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,3 +921,15 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
name="correlation.cuda",
)
return strategy


@argwhere_strategy.register(["cuda", "gpu"])
def argwhere_strategy_cuda(attrs, inputs, out_type, target):
    """argwhere cuda strategy"""
    # Single CUDA implementation: wrap the topi compute/schedule pair.
    cuda_strategy = _op.OpStrategy()
    compute = wrap_compute_argwhere(topi.cuda.argwhere)
    schedule = wrap_topi_schedule(topi.cuda.schedule_argwhere)
    cuda_strategy.add_implementation(compute, schedule, name="argwhere.cuda")
    return cuda_strategy
39 changes: 30 additions & 9 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging

import re
from tvm import topi, _ffi
from tvm import topi, _ffi, te, ir
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
from tvm.target import generic_func, override_native_generic_func
from .. import op as _op
Expand Down Expand Up @@ -1034,14 +1034,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
return strategy


# argwhere
@generic_func
def schedule_argwhere(attrs, outs, target):
"""schedule argwhere"""
with target:
return topi.generic.schedule_argwhere(outs)


# scatter
@override_native_generic_func("scatter_strategy")
def scatter_strategy(attrs, outs, out_type, target):
Expand Down Expand Up @@ -1223,3 +1215,32 @@ def correlation_strategy(attrs, inputs, out_type, target):
name="correlation.generic",
)
return strategy


# argwhere
def wrap_compute_argwhere(topi_compute):
    """wrap argwhere topi compute"""

    def _compute_argwhere(attrs, inputs, out_type):
        # Keep static dims; substitute a symbolic var for each dynamic
        # (Any) dimension so the topi compute sees a concrete TensorType.
        dims = [
            dim if hasattr(dim, "value") else te.var("any_dim", "int32")
            for dim in out_type.shape
        ]
        return [topi_compute(ir.TensorType(dims, "int32"), inputs[0])]

    return _compute_argwhere


@override_native_generic_func("argwhere_strategy")
def argwhere_strategy(attrs, inputs, out_type, target):
    """argwhere generic strategy"""
    # Fallback implementation used when no target-specific strategy applies.
    generic_strategy = _op.OpStrategy()
    generic_strategy.add_implementation(
        wrap_compute_argwhere(topi.argwhere),
        wrap_topi_schedule(topi.generic.schedule_argwhere),
        name="argwhere.generic",
    )
    return generic_strategy
2 changes: 2 additions & 0 deletions python/tvm/topi/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Argwhere operator"""
import tvm
from tvm.te import hybrid


Expand Down Expand Up @@ -169,6 +170,7 @@ def hybrid_argwhere_5d(output_shape, condition):
return a


@tvm.target.generic_func
def argwhere(output_shape, condition):
"""Find the indices of elements of a tensor that are non-zero.

Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@
from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
from .argwhere import *
Loading