Skip to content

Commit

Permalink
fix rebase bug
Browse files Browse the repository at this point in the history
  • Loading branch information
alter-xp committed Jan 14, 2021
1 parent 45c270a commit 76681a3
Showing 1 changed file with 115 additions and 0 deletions.
115 changes: 115 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,3 +1361,118 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
name="threefry_split.generic",
)
return strategy


# segment_max
def wrap_compute_segment_max(topi_compute):
"""wrap segment_max topi compute"""

def _compute_segment_max(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "max")]

return _compute_segment_max


@override_native_generic_func("segment_max_strategy")
def segment_max_strategy(attrs, inputs, out_type, target):
"""segment_max generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_max(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_max.generic",
)
return strategy


# segment_min
def wrap_compute_segment_min(topi_compute):
"""wrap segment_min topi compute"""

def _compute_segment_min(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "min")]

return _compute_segment_min


@override_native_generic_func("segment_min_strategy")
def segment_min_strategy(attrs, inputs, out_type, target):
"""segment_min generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_min(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_min.generic",
)
return strategy


# segment_mean
def wrap_compute_segment_mean(topi_compute):
"""wrap segment_mean topi compute"""

def _compute_segment_mean(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "mean")]

return _compute_segment_mean


@override_native_generic_func("segment_mean_strategy")
def segment_mean_strategy(attrs, inputs, out_type, target):
"""segment_mean generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_mean(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_mean.generic",
)
return strategy


# segment_sum
def wrap_compute_segment_sum(topi_compute):
"""wrap segment_sum topi compute"""

def _compute_segment_sum(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "sum")]

return _compute_segment_sum


@override_native_generic_func("segment_sum_strategy")
def segment_sum_strategy(attrs, inputs, out_type, target):
"""segment_sum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_sum(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_sum.generic",
)
return strategy


# segment_prod
def wrap_compute_segment_prod(topi_compute):
"""wrap segment_prod topi compute"""

def _compute_segment_prod(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "prod")]

return _compute_segment_prod


@override_native_generic_func("segment_prod_strategy")
def segment_prod_strategy(attrs, inputs, out_type, target):
"""segment_prod generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_prod(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_prod.generic",
)
return strategy

0 comments on commit 76681a3

Please sign in to comment.