Skip to content

Commit

Permalink
[PRNG] Add check to PRNG to make sure that unsigned integer arithmeti…
Browse files Browse the repository at this point in the history
…c is wrapping (apache#7287)

* [PRNG] Add check to PRNG to make sure that unsigned integer arithmetic is wrapping

* Add threefry_test_wrapping: a manual test for wrapping unsigned arithmetic.

* fix test to actually run on the target

* formatting

* lint
  • Loading branch information
tkonolige authored and Lokiiiiii committed Mar 1, 2021
1 parent b52e70d commit da7e72a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
62 changes: 60 additions & 2 deletions python/tvm/topi/random/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Pseudorandom number kernels."""
import tvm
import tvm.topi
import numpy as np
from ... import tir
from ...tir import ir_builder

Expand Down Expand Up @@ -135,7 +136,7 @@ def _threefry(
assert key_buf.dtype == counter_buf.dtype, "threefry key and counter must be the same dtype"

def mix(a, b, rotation):
x = a + b # TODO should be wrapping
x = a + b # wrapping
y = x ^ ((b << rotation) | (b >> (iwidth - rotation)))
return [x, y]

Expand Down Expand Up @@ -167,7 +168,7 @@ def key_schedule(s, i):
with irb.for_range(0, out_shape, name="l") as l: # pylint: disable=invalid-name
for i in range(nrounds // 4):
for j in range(nwords):
out_buf[out_offset + l * nwords + j] += key_schedule(i, j) # TODO wrapping
out_buf[out_offset + l * nwords + j] += key_schedule(i, j) # wrapping
for k in range(4):
for j in range(nwords // 2):
(
Expand Down Expand Up @@ -201,6 +202,13 @@ def threefry_generate(gen, out_shape):
then a new generator is created by applying Threefry to the current key, path, and counter.
This new generator will have a reset counter.
Warning
-------
Threeyfry requires that unsigned integer arithmetic wraps on overflow. Currently TVM has no
guarantee of this, so threefry contains an internal assert to check wrapping behavior. This
assert may or may not run depending on your platform, so it is recommended you run
:py:func:`threefry_test_wrapping` to verify wrapping behavior.
Parameters
----------
gen : Tensor[10, uint64]
Expand Down Expand Up @@ -234,6 +242,18 @@ def gen_ir(gen_ptr, out_gen_ptr, out_array_ptr):
out_gen = irb.buffer_ptr(out_gen_ptr)
out_array = irb.buffer_ptr(out_array_ptr)

# Check that unsigned arithmetic wraps, as it is required to implement threefry correctly.
irb.emit(
tvm.tir.AssertStmt(
tvm.tir.const(0xFFFFFFFFFFFFFFFF, "uint64") + tvm.tir.const(1, "uint64")
== tvm.tir.const(0, "uint64"),
tvm.tir.StringImm(
"Unsigned integer arithmetic is not wrapping, but threefry requires wrapping."
),
tvm.tir.Evaluate(0),
)
)

# Create a temporary array to hold the generator state we will use to create the random
# numbers. We cannot use gen because we may need to update the key + path if there is not
# enough room in the counter.
Expand Down Expand Up @@ -408,3 +428,41 @@ def gen_ir(gen_ptr, out_left_ptr, out_right_ptr):
name="threefry_split",
tag="threefry_split",
)


def threefry_test_wrapping(target, ctx):
"""Test that unsigned arithmetic wraps on overflow.
Parameters
----------
target : tvm.target.Target
Target to run against
ctx : tvm.runtime.TVMContext
Context to run the test on
Returns
-------
is_wrapping : bool
Whether or not unsigned integer arithmetic is wrapping for this target, context pair. True
indicates that threefry will work on this platform.
"""
if isinstance(target, str):
target = tvm.target.Target(target)

def gen_ir(out_ptr):
irb = ir_builder.create()
out = irb.buffer_ptr(out_ptr)
if "gpu" in target.keys:
thread_x = tvm.te.thread_axis("threadIdx.x")
irb.scope_attr(thread_x, "thread_extent", 1)
out[0] = tvm.tir.const(0xFFFFFFFFFFFFFFFF, "uint64") + tvm.tir.const(1, "uint64")
return irb.get()

out = tvm.tir.decl_buffer((1,), dtype="uint64")
f = tvm.te.extern(
[out.shape], [], lambda ins, outs: gen_ir(outs[0]), dtype="uint64", out_buffers=[out]
)
s = tvm.te.create_schedule([f.op])
out_ary = tvm.nd.array(np.ones((1,), "uint64"), ctx)
tvm.build(s, [f], target=target)(out_ary)
return out_ary.asnumpy()[0] == 0
8 changes: 8 additions & 0 deletions tests/python/topi/python/test_topi_prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ def test_threefry_generate(target, ctx):
).any(), "Overflowing counter with no space left in path should change state"


@tvm.testing.parametrize_targets
def test_threefry_wrapping(target, ctx):
assert tvm.topi.random.threefry_test_wrapping(
target, ctx
), f"{target} does not suppport wrapping unsigned integer arithmetic"


if __name__ == "__main__":
test_threefry_split(tvm.target.Target("llvm"), tvm.context("cpu"))
test_threefry_generate(tvm.target.Target("llvm"), tvm.context("cpu"))
test_threefry_wrapping(tvm.target.Target("llvm"), tvm.context("cpu"))

0 comments on commit da7e72a

Please sign in to comment.