[TOPI] GPU scatter_add using atomic #7044

Merged · 5 commits · Dec 7, 2020
17 changes: 14 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
@@ -1921,18 +1921,29 @@ def empty(self, inputs, input_types):
    def bincount(self, inputs, input_types):
        data = inputs[0]
        weights = inputs[1]
+        input_type = _infer_type(data).checked_type.dtype
+        if input_type == "int64":
+            logging.warning(
+                "Casting an int64 input to int32, since we do not have int64 atomic add"
+                "needed for bincount yet."
+            )
+            data = _op.cast(data, "int32")
        maximum = _op.max(data)
-        dim = maximum + _expr.const(1, dtype="int64")
+        dim = maximum + _expr.const(1, dtype="int32")
        if weights:
            weight_type = _infer_type(weights).checked_type
            out_dtype = weight_type.dtype
            updates = weights
        else:
-            out_dtype = "int64"
+            out_dtype = "int32"
            updates = _op.ones_like(data)

        counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
-        return _op.scatter_add(counts, data, updates, axis=0)
+        out = _op.scatter_add(counts, data, updates, axis=0)
+        if input_type == "int32":
+            # Torch always outputs int64 results for bincount
+            return _op.cast(out, "int64")
+        return out

    def scatter_add(self, inputs, input_types):
        data = inputs[0]
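The final cast back to int64 matches PyTorch's own behaviour. A small standalone check (illustrative only, not part of this diff) of the behaviour the converter has to preserve:

```python
import torch

# torch.bincount on an integer tensor returns int64 counts, which is why the
# converter above casts its int32 scatter_add result back to int64.
x = torch.tensor([1, 1, 3, 5], dtype=torch.int64)
counts = torch.bincount(x)
print(counts)        # tensor([0, 2, 0, 1, 0, 1])
print(counts.dtype)  # torch.int64
```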
80 changes: 79 additions & 1 deletion python/tvm/topi/cuda/scatter.py
@@ -19,6 +19,7 @@
import tvm
from tvm import te
from ..scatter import _verify_scatter_nd_inputs
+from .nms import atomic_add


def ceil_div(a, b):
@@ -470,6 +471,83 @@ def update_func(dst_ptr, dst_index, update):
    return out


def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
    """Generate scatter add ir for 1d inputs, using atomic_add instruction

    Parameters
    ----------
    data : tir.Tensor
        The input data to the operator.

    indices : tir.Tensor
        The index locations to update.

    updates : tir.Tensor
        The values to update.

    axis : int
        The axis to scatter on

    out : tir.Tensor
        The output tensor.

    Returns
    -------
    ret : tir
        The computational ir.
    """
    assert axis == 0
    n = data.shape[0]

    ib = tvm.tir.ir_builder.create()

    out_ptr = ib.buffer_ptr(out)
    data_ptr = ib.buffer_ptr(data)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads

    with ib.new_scope():
        nthread_bx = ceil_div(n, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx
        with ib.if_scope(tid < n):
            out_ptr[tid] = data_ptr[tid]

    indices_ptr = ib.buffer_ptr(indices)
    updates_ptr = ib.buffer_ptr(updates)

    ni = indices.shape[0]

    atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local")

    with ib.new_scope():
        nthread_bx = ceil_div(ni, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx

        with ib.if_scope(tid < ni):
            index = indices_ptr[tid]
            with ib.if_scope(index < 0):
                atomic_add_return[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]),
                    updates_ptr[tid],
                )
            with ib.else_scope():
                atomic_add_return[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]),
                    updates_ptr[tid],
                )

    return ib.get()
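The IR above launches two kernels: the first copies `data` into `out` with one thread per element, and the second gives one thread to each update and issues an `atomic_add` at `out[index]`, wrapping negative indices by `n`. A minimal NumPy sketch of the same 1-D semantics (names are illustrative, not part of this patch):

```python
import numpy as np

def scatter_add_1d_ref(data, indices, updates):
    # Start from a copy of data, then accumulate updates at the given indices.
    # np.add.at is unbuffered, so repeated (and negative) indices accumulate
    # exactly like the GPU atomic adds do.
    out = data.copy()
    np.add.at(out, indices, updates)
    return out

data = np.zeros(8, dtype="float32")
indices = np.array([1, 1, -1, 3], dtype="int64")
updates = np.array([1.0, 2.0, 5.0, 7.0], dtype="float32")
print(scatter_add_1d_ref(data, indices, updates))
# [0. 3. 0. 7. 0. 0. 0. 5.]
```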


def scatter_add(data, indices, updates, axis=0):
    """Update data by adding values in updates at positions defined by indices

@@ -501,7 +579,7 @@ def scatter_add(data, indices, updates, axis=0):
    assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"

    ir_funcs = {
-        1: gen_ir_1d,
+        1: gen_scatter_add_1d_atomic,
        2: gen_ir_2d,
        3: gen_ir_3d,
        4: gen_ir_4d,
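With `gen_scatter_add_1d_atomic` registered for rank 1, a 1-D `scatter_add` compiled for CUDA now goes through the atomic kernel. A hedged usage sketch, written against the same Relay executor API the tests below use (shapes, dtypes, and variable names are illustrative):

```python
import numpy as np
import tvm
from tvm import relay

# Build a 1-D scatter_add; on the CUDA target the TOPI lowering now selects
# the atomic-add IR for rank-1 inputs.
d = relay.var("data", shape=(1000,), dtype="float32")
i = relay.var("indices", shape=(1000,), dtype="int64")
u = relay.var("updates", shape=(1000,), dtype="float32")
func = relay.Function([d, i, u], relay.op.scatter_add(d, i, u, axis=0))

data_np = np.zeros((1000,), dtype="float32")
indices_np = np.random.randint(0, 1000, size=(1000,)).astype("int64")
updates_np = np.random.uniform(size=(1000,)).astype("float32")

intrp = relay.create_executor("graph", ctx=tvm.gpu(0), target="cuda")
res = intrp.evaluate(func)(data_np, indices_np, updates_np)
print(res.asnumpy()[:10])
```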
10 changes: 5 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
@@ -3355,12 +3355,12 @@ def test_bincount():
    def test_fn(x, weights=None):
        return torch.bincount(x, weights=weights)

-    inp = torch.randint(0, 8, (5,), dtype=torch.int64)
-    weights = torch.linspace(0, 1, steps=5)
+    inp = torch.randint(0, 100, (10000,), dtype=torch.int64)
+    weights = torch.linspace(0, 100, steps=10000)

-    verify_trace_model(test_fn, [inp], ["llvm"])
-    verify_trace_model(test_fn, [inp, weights], ["llvm"])
-    verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"])
+    targets = ["llvm", "cuda"]
+    verify_trace_model(test_fn, [inp], targets)
+    verify_trace_model(test_fn, [inp, weights], targets)


if __name__ == "__main__":
4 changes: 4 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -1017,11 +1017,15 @@ def verify_scatter_add(dshape, ishape, axis=0):
        ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
        for target, ctx in tvm.testing.enabled_targets():
            for kind in ["graph", "debug"]:
+                if target == "nvptx":
+                    # TODO(masahi): support atomic in LLVM codegen
+                    continue
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

    verify_scatter_add((10,), (10,), 0)
+    verify_scatter_add((1000,), (1000,), 0)
    verify_scatter_add((10, 5), (10, 5), -2)
    verify_scatter_add((10, 5), (10, 5), -1)
    verify_scatter_add((10, 5), (3, 5), 0)