diff --git a/python/tvm/topi/cuda/reduction.py b/python/tvm/topi/cuda/reduction.py
index ee868ac9b639..ceab71640533 100644
--- a/python/tvm/topi/cuda/reduction.py
+++ b/python/tvm/topi/cuda/reduction.py
@@ -130,12 +130,14 @@ def traverse_after_reduce(operator):
             for tensor in operator.input_tensors:
                 traverse_after_reduce(tensor.op)
         elif operator.tag == "comm_reduce":
-            _schedule_reduce(operator, sch, is_idx_reduce=False)
+            if operator not in scheduled_ops:
+                _schedule_reduce(operator, sch, is_idx_reduce=False)
             for tensor in operator.input_tensors:
                 if tensor.op not in scheduled_ops:
                     traverse_before_reduce(tensor.op)
         elif operator.tag == "comm_reduce_idx":
-            _schedule_reduce(operator, sch, is_idx_reduce=True)
+            if operator not in scheduled_ops:
+                _schedule_reduce(operator, sch, is_idx_reduce=True)
             input_tensors = operator.input_tensors[0].op.input_tensors
             for tensor in input_tensors:
                 if tensor.op not in scheduled_ops:
@@ -147,5 +149,6 @@ def traverse_after_reduce(operator):
 
         scheduled_ops.append(operator)
 
-    traverse_after_reduce(outs[0].op)
+    for out in outs:
+        traverse_after_reduce(out.op)
     return sch
diff --git a/tests/python/topi/python/test_topi_reduce.py b/tests/python/topi/python/test_topi_reduce.py
index daf380d9f4e6..9ddcb0d3884b 100644
--- a/tests/python/topi/python/test_topi_reduce.py
+++ b/tests/python/topi/python/test_topi_reduce.py
@@ -152,5 +152,31 @@ def test_reduce_map():
     )
 
 
+@tvm.testing.uses_gpu
+def test_complex_reduce():
+    in_shape = (2, 3)
+    dtype = "float32"
+    axis = 0
+    keepdims = False
+    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
+    B = topi.sum(A, axis=axis, keepdims=keepdims)
+    C = topi.add(B, B)
+    D = topi.multiply(B, B)
+    E = topi.add(C, D)
+    for device, ctx in tvm.testing.enabled_targets():
+        print("Running on target: %s" % device)
+        with tvm.target.Target(device):
+            s = tvm.topi.testing.get_reduce_schedule(device)(E)
+        foo = tvm.build(s, [A, E], device, name="sum")
+        in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
+        sum_npy = in_npy.sum(axis=axis, keepdims=keepdims)
+        out_npy = sum_npy * 2 + sum_npy * sum_npy
+        data_tvm = tvm.nd.array(in_npy, ctx=ctx)
+        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=dtype)
+        foo(data_tvm, out_tvm)
+        tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1e-3, 1e-3)
+
+
 if __name__ == "__main__":
     test_reduce_map()
+    test_complex_reduce()
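
For illustration only (not part of the patch): a minimal sketch of the multiple-output case that the `for out in outs` traversal is meant to cover, with two elementwise consumers of one shared comm_reduce scheduled together. It assumes the patched topi.cuda.schedule_reduce, a CUDA-enabled TVM build with a GPU available, and the ctx/asnumpy()-era API used in the diff; the tensor names and shapes are arbitrary examples.

# Sketch: schedule a reduction whose result feeds two separate outputs.
# Assumes the patched topi.cuda.schedule_reduce, a CUDA-enabled build, and a GPU;
# uses the ctx/asnumpy() API of the same TVM version as the diff above.
import numpy as np
import tvm
from tvm import te, topi

A = te.placeholder((2, 3), name="A", dtype="float32")
B = topi.sum(A, axis=0)      # shared comm_reduce stage
C = topi.add(B, B)           # first output consuming B
D = topi.multiply(B, B)      # second output consuming B

with tvm.target.Target("cuda"):
    # Before the fix only outs[0] was traversed, so D was left unscheduled.
    s = topi.cuda.schedule_reduce([C, D])

f = tvm.build(s, [A, C, D], "cuda")
ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(2, 3)).astype("float32")
a = tvm.nd.array(a_np, ctx)
c = tvm.nd.empty((3,), "float32", ctx)
d = tvm.nd.empty((3,), "float32", ctx)
f(a, c, d)
np.testing.assert_allclose(c.asnumpy(), a_np.sum(axis=0) * 2, rtol=1e-4, atol=1e-6)
np.testing.assert_allclose(d.asnumpy(), a_np.sum(axis=0) ** 2, rtol=1e-4, atol=1e-6)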