Skip to content

Commit

Permalink
fix conflicting II
Browse files Browse the repository at this point in the history
  • Loading branch information
alter-xp committed May 24, 2021
2 parents ebf80cb + 76681a3 commit 2d8a118
Show file tree
Hide file tree
Showing 11 changed files with 771 additions and 1 deletion.
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,15 @@ struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
.set_default(false);
}
}; // struct UniqueAttrs
/*! \brief Attributes used in segment_max, segment_min,
segment_mean, segment_sum, segment_prod operator */
struct SegmentAttrs : public tvm::AttrsNode<SegmentAttrs> {
int num_segments;

TVM_DECLARE_ATTRS(SegmentAttrs, "relay.attrs.SegmentAttrs") {
TVM_ATTR_FIELD(num_segments).set_default(0).describe("The maximum of segment_ids.");
}
}; // struct SegmentAttrs

} // namespace relay
} // namespace tvm
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,22 @@ def _impl(inputs, attr, params, mod):
return _impl


def _unsorted_segment(name):
def _impl(inputs, attr, params, mod):
# op description: https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_max
try:
num_segments = _infer_value(inputs[2], params).asnumpy().tolist()
except Exception:
raise tvm.error.OpAttributeInvalid("Can't find num_segments.")
return AttrCvt(
op_name="segment_" + name,
ignores=["Tdim", "Tidx", "Tindices", "Tnumsegments"],
extras={"num_segments": num_segments},
)([inputs[0], inputs[1]], attr)

return _impl


def _crop_and_resize():
def _impl(inputs, attr, params, mod):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
Expand Down Expand Up @@ -2617,6 +2633,23 @@ def _impl(inputs, attr, params, mod):
return _impl


def _segment(opname):
def _impl(inputs, attr, params, mod):
# op description: https://www.tensorflow.org/api_docs/python/tf/math/segment_max
try:
segment_ids = _infer_value(inputs[1], params)
except Exception:
raise tvm.error.OpAttributeInvalid("Can't get value of segment_ids.")

num_out = segment_ids.asnumpy().max() + 1
out = AttrCvt(op_name=opname, ignores=["T", "Tindices"], extras={"num_segments": num_out})(
inputs, attr
)
return out

return _impl


def _size():
def _impl(inputs, attr, params, mod):
new_attr = attr
Expand Down Expand Up @@ -2864,6 +2897,11 @@ def _impl(inputs, attr, params, mod):
"SelectV2": _where(),
"Selu": _selu(),
"Shape": _shape(),
"SegmentMax": _segment("segment_max"),
"SegmentMean": _segment("segment_mean"),
"SegmentMin": _segment("segment_min"),
"SegmentProd": _segment("segment_prod"),
"SegmentSum": _segment("segment_sum"),
"Sigmoid": AttrCvt("sigmoid"),
"Sign": AttrCvt("sign"),
"Sin": AttrCvt("sin"),
Expand Down Expand Up @@ -2915,6 +2953,11 @@ def _impl(inputs, attr, params, mod):
"UniqueWithCounts": _unique(True),
"Unpack": _unpack(),
"UnravelIndex": _unravel_index(),
"UnsortedSegmentMax": _unsorted_segment("max"),
"UnsortedSegmentMin": _unsorted_segment("min"),
"UnsortedSegmentMean": _unsorted_segment("mean"),
"UnsortedSegmentProd": _unsorted_segment("prod"),
"UnsortedSegmentSum": _unsorted_segment("sum"),
"Where": _where(),
"ZerosLike": AttrCvt("zeros_like"),
}
Expand Down
23 changes: 22 additions & 1 deletion python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from tvm import topi
from tvm.runtime import convert

from .op import register_compute, register_shape_func
from . import strategy
from .op import register_compute, register_shape_func, register_strategy
from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_pattern, OpPattern

Expand Down Expand Up @@ -283,3 +284,23 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("sigmoid", False, elemwise_shape_func)
register_shape_func("tanh", False, elemwise_shape_func)
register_shape_func("logical_not", False, elemwise_shape_func)

# segment_max
register_strategy("segment_max", strategy.segment_max_strategy)
register_pattern("segment_max", OpPattern.OPAQUE)

# segment_min
register_strategy("segment_min", strategy.segment_min_strategy)
register_pattern("segment_min", OpPattern.OPAQUE)

# segment_mean
register_strategy("segment_mean", strategy.segment_mean_strategy)
register_pattern("segment_mean", OpPattern.OPAQUE)

# segment_sum
register_strategy("segment_sum", strategy.segment_sum_strategy)
register_pattern("segment_sum", OpPattern.OPAQUE)

# segment_prod
register_strategy("segment_prod", strategy.segment_prod_strategy)
register_pattern("segment_prod", OpPattern.OPAQUE)
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ class ProposalAttrs(Attrs):
"""Attributes used in proposal operators"""


@tvm._ffi.register_object("relay.attrs.SegmentAttrs")
class SegmentAttrs(Attrs):
"""Attributes used in segment operators"""


@tvm._ffi.register_object("relay.attrs.MaxPool2DAttrs")
class MaxPool2DAttrs(Attrs):
"""Attributes used in max_pool2d operators"""
Expand Down
115 changes: 115 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,29 @@ def uniform_strategy(attrs, inputs, out_type, target):
return strategy


# segment_max
def wrap_compute_segment_max(topi_compute):
"""wrap segment_max topi compute"""

def _compute_segment_max(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "max")]

return _compute_segment_max


@override_native_generic_func("segment_max_strategy")
def segment_max_strategy(attrs, inputs, out_type, target):
"""segment_max generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_max(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_max.generic",
)
return strategy


def wrap_compute_scanop(topi_compute):
"""Wrap scanop style topi compute"""

Expand All @@ -1561,6 +1584,29 @@ def cumsum_strategy(attrs, inputs, out_type, target):
return strategy


# segment_min
def wrap_compute_segment_min(topi_compute):
"""wrap segment_min topi compute"""

def _compute_segment_min(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "min")]

return _compute_segment_min


@override_native_generic_func("segment_min_strategy")
def segment_min_strategy(attrs, inputs, out_type, target):
"""segment_min generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_min(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_min.generic",
)
return strategy


@override_native_generic_func("cumprod_strategy")
def cumprod_strategy(attrs, inputs, out_type, target):
"""cumprod generic strategy"""
Expand All @@ -1573,6 +1619,29 @@ def cumprod_strategy(attrs, inputs, out_type, target):
return strategy


# segment_mean
def wrap_compute_segment_mean(topi_compute):
"""wrap segment_mean topi compute"""

def _compute_segment_mean(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "mean")]

return _compute_segment_mean


@override_native_generic_func("segment_mean_strategy")
def segment_mean_strategy(attrs, inputs, out_type, target):
"""segment_mean generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_mean(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_mean.generic",
)
return strategy


def wrap_compute_unique(topi_compute):
"""Wrap unique topi compute"""

Expand All @@ -1594,8 +1663,54 @@ def unique_strategy(attrs, inputs, out_type, target):
return strategy


# segment_sum
def wrap_compute_segment_sum(topi_compute):
"""wrap segment_sum topi compute"""

def _compute_segment_sum(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "sum")]

return _compute_segment_sum


@override_native_generic_func("segment_sum_strategy")
def segment_sum_strategy(attrs, inputs, out_type, target):
"""segment_sum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_sum(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_sum.generic",
)
return strategy


@generic_func
def schedule_transpose(attrs, outs, target):
"""schedule transpose"""
with target:
return schedule_injective(attrs, outs, target)


# segment_prod
def wrap_compute_segment_prod(topi_compute):
"""wrap segment_prod topi compute"""

def _compute_segment_prod(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "prod")]

return _compute_segment_prod


@override_native_generic_func("segment_prod_strategy")
def segment_prod_strategy(attrs, inputs, out_type, target):
"""segment_prod generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_prod(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_prod.generic",
)
return strategy
110 changes: 110 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,3 +1285,113 @@ def isinf(data):
The computed result.
"""
return _make.isinf(data)


def segment_max(data, segment_ids, num_segments):
"""Computes the maximum along segments of a tensor.
Parameters
----------
data : relay.Expr
The input data
segment_ids : relay.Expr
The segments data
num_segments : int
The maximum of segment_ids.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_max(data, segment_ids, num_segments)


def segment_min(data, segment_ids, num_segments):
"""Computes the minimum along segments of a tensor.
Parameters
----------
data : relay.Expr
The input data
segment_ids : relay.Expr
The segments data
num_segments : int
The maximum of segment_ids.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_min(data, segment_ids, num_segments)


def segment_mean(data, segment_ids, num_segments):
"""Computes the mean along segments of a tensor.
Parameters
----------
data : relay.Expr
The input data
segment_ids : relay.Expr
The segments data
num_segments : int
The maximum of segment_ids.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_mean(data, segment_ids, num_segments)


def segment_sum(data, segment_ids, num_segments):
"""Computes the sum along segments of a tensor.
Parameters
----------
data : relay.Expr
The input data
segment_ids : relay.Expr
The segments data
num_segments : int
The maximum of segment_ids.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_sum(data, segment_ids, num_segments)


def segment_prod(data, segment_ids, num_segments):
"""Computes the prod along segments of a tensor.
Parameters
----------
data : relay.Expr
The input data
segment_ids : relay.Expr
The segments data
num_segments : int
The maximum of segment_ids.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_prod(data, segment_ids, num_segments)
Loading

0 comments on commit 2d8a118

Please sign in to comment.