diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 4f75cf380cc6..d2c52fbc262a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1921,18 +1921,29 @@ def empty(self, inputs, input_types):
     def bincount(self, inputs, input_types):
         data = inputs[0]
         weights = inputs[1]
+        input_type = _infer_type(data).checked_type.dtype
+        if input_type == "int64":
+            logging.warning(
+                "Casting an int64 input to int32, since we do not have int64 atomic add"
+                " needed for bincount yet."
+            )
+            data = _op.cast(data, "int32")
         maximum = _op.max(data)
-        dim = maximum + _expr.const(1, dtype="int64")
+        dim = maximum + _expr.const(1, dtype="int32")
         if weights:
             weight_type = _infer_type(weights).checked_type
             out_dtype = weight_type.dtype
             updates = weights
         else:
-            out_dtype = "int64"
+            out_dtype = "int32"
             updates = _op.ones_like(data)
 
         counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
-        return _op.scatter_add(counts, data, updates, axis=0)
+        out = _op.scatter_add(counts, data, updates, axis=0)
+        if input_type == "int32":
+            # Torch always outputs int64 results for bincount
+            return _op.cast(out, "int64")
+        return out
 
     def scatter_add(self, inputs, input_types):
         data = inputs[0]
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 5e03fafcfb58..89c5cd23111b 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -19,6 +19,7 @@
 import tvm
 from tvm import te
 from ..scatter import _verify_scatter_nd_inputs
+from .nms import atomic_add
 
 
 def ceil_div(a, b):
@@ -470,6 +471,83 @@ def update_func(dst_ptr, dst_index, update):
     return out
 
 
+def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
+    """Generate scatter add ir for 1d inputs, using atomic_add instruction
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    with ib.new_scope():
+        nthread_bx = ceil_div(n, nthread_tx)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+        with ib.if_scope(tid < n):
+            out_ptr[tid] = data_ptr[tid]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local")
+
+    with ib.new_scope():
+        nthread_bx = ceil_div(ni, nthread_tx)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        with ib.if_scope(tid < ni):
+            index = indices_ptr[tid]
+            with ib.if_scope(index < 0):
+                atomic_add_return[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]),
+                    updates_ptr[tid],
+                )
+            with ib.else_scope():
+                atomic_add_return[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]),
+                    updates_ptr[tid],
+                )
+
+    return ib.get()
+
+
 def scatter_add(data, indices, updates, axis=0):
     """Update data by adding values in updates at positions defined by indices
 
@@ -501,7 +579,7 @@ def scatter_add(data, indices, updates, axis=0):
     assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"
 
     ir_funcs = {
-        1: gen_ir_1d,
+        1: gen_scatter_add_1d_atomic,
         2: gen_ir_2d,
         3: gen_ir_3d,
         4: gen_ir_4d,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6250dfff811a..2dda675c74f5 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3355,12 +3355,12 @@ def test_bincount():
     def test_fn(x, weights=None):
         return torch.bincount(x, weights=weights)
 
-    inp = torch.randint(0, 8, (5,), dtype=torch.int64)
-    weights = torch.linspace(0, 1, steps=5)
+    inp = torch.randint(0, 100, (10000,), dtype=torch.int64)
+    weights = torch.linspace(0, 100, steps=10000)
 
-    verify_trace_model(test_fn, [inp], ["llvm"])
-    verify_trace_model(test_fn, [inp, weights], ["llvm"])
-    verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"])
+    targets = ["llvm", "cuda"]
+    verify_trace_model(test_fn, [inp], targets)
+    verify_trace_model(test_fn, [inp, weights], targets)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 82d056381666..fc1929e9dc18 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1017,11 +1017,15 @@ def verify_scatter_add(dshape, ishape, axis=0):
         ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
+                if target == "nvptx":
+                    # TODO(masahi): support atomic in LLVM codegen
+                    continue
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
     verify_scatter_add((10,), (10,), 0)
+    verify_scatter_add((1000,), (1000,), 0)
     verify_scatter_add((10, 5), (10, 5), -2)
     verify_scatter_add((10, 5), (10, 5), -1)
     verify_scatter_add((10, 5), (3, 5), 0)
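Note on what the patch relies on: torch.bincount is lowered here to a 1-D scatter_add of ones (or the given weights) into a zero buffer of length max(x) + 1, and the new CUDA kernel computes that same sum with atomic adds so concurrent updates to a bin do not race. The relationship can be sketched with plain NumPy; this is only an illustrative reference, and the helper name ref_scatter_add_1d is made up for the sketch, not part of the patch or of TVM:

import numpy as np

def ref_scatter_add_1d(data, indices, updates):
    # Sequential reference for the 1-D kernel: start from a copy of `data`,
    # then accumulate each update at its index. A negative index wraps to
    # the end, mirroring the `index + n` branch in the IR above.
    out = data.copy()
    for i, idx in enumerate(indices):
        out[idx] += updates[i]
    return out

x = np.random.randint(0, 100, size=10000).astype("int32")
w = np.linspace(0, 100, num=10000)  # float64 weights, analogous to the torch test

# bincount(x) == scatter_add(zeros(max(x) + 1), x, ones_like(x))
counts = ref_scatter_add_1d(np.zeros(x.max() + 1, dtype="int32"), x, np.ones_like(x))
assert np.array_equal(counts, np.bincount(x))

# bincount(x, weights=w) == scatter_add(zeros(max(x) + 1), x, w)
weighted = ref_scatter_add_1d(np.zeros(x.max() + 1, dtype="float64"), x, w)
np.testing.assert_allclose(weighted, np.bincount(x, weights=w), rtol=1e-5)

Because atomic float adds can land in any order on the GPU, results are compared with a tolerance (rtol=1e-5 in the relay test) rather than bit-exact equality, and the enlarged inputs (10000-element bincount, 1000-element scatter_add) exercise many concurrent updates per bin.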