Skip to content

Commit

Permalink
[Hexagon][LLVM] Enable/test tensorized Hexagon DMA on 2d transformed …
Browse files Browse the repository at this point in the history
…layout (#10905)

* [Hexagon][LLVM] Enable/test tensorized Hexagon DMA

- In the `CodeGenLLVM::CreateIntrinsic` handler for
  `builtin::address_of()`, pass N-d indices to
  `CodeGenLLVM::CreateBufferPtr`.  The base class implementation still
  asserts that there is a flat memory space, while the
  `CodeGenHexagon::CreateBufferPtr` override allows 2-d memory.

- Enable tensorization in `test_cache_read_write.py`, using
  `tir.address_of` to pass the lowered value.

Co-authored-by: Adam Straw <astraw@octoml.ai>

* [TIR] Allow buffer_bind_scope of N-d buffers

Previously, any `buffer_bind_scope` attribute that provides a view
into a non-flat buffer would result in an error.  After this commit,
`buffer_bind_scope` may be used for non-flat buffers, but use of
`arg_buffer->elem_offset` within the body of the bind statement is
still an error.

The `BufferNode::elem_offset` field represents the offset between the
pointer of the backing allocation and the first element of the buffer.
This offset is only well-defined for flat memory spaces.

* update test to tensorize cache_read `y` (works) and cache_write `z` (fails)

* add `split` to allow for tensorization of cache_write of `z`

* fix typo and cleanup comment

* add back original 1d test_cache_read_write

* update comments

* format error

Co-authored-by: Adam Straw <astraw@octoml.ai>
  • Loading branch information
Lunderberg and adstraw authored Apr 12, 2022
1 parent cd6aa7b commit 11d22bd
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 58 deletions.
16 changes: 11 additions & 5 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1006,13 +1006,19 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
} else if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations.";
PrimExpr index = load->indices[0];
if (const RampNode* r = index.as<RampNode>()) {
index = r->base;

Array<PrimExpr> indices = load->indices;
if (const RampNode* r = indices[indices.size() - 1].as<RampNode>()) {
indices.Set(indices.size() - 1, r->base);
}

std::vector<llvm::Value*> indices_val;
for (const auto& index : indices) {
indices_val.push_back(MakeValue(index));
}

TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
{MakeValue(index)}, load->dtype);
indices_val, load->dtype);
unsigned addrspace =
llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));
Expand Down
17 changes: 14 additions & 3 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,6 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
begins = SimplifyArray(&ana, begins);
Array<PrimExpr> elem_offset = n->ElemOffset(begins);
elem_offset.MutateByApply([&](const PrimExpr& expr) { return ana.Simplify(expr); });
ICHECK_EQ(elem_offset.size(), 1) << "MakeSlice currently supports only flat 1-d memory.";

Array<PrimExpr> strides = n->strides;
if (strides.size() == 0) {
Expand All @@ -480,8 +479,20 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
return MakeStrideView().MakeSlice(begins, extents);
}
}
return Buffer(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
n->data_alignment, 0, n->buffer_type);
Buffer slice(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
n->data_alignment, 0, n->buffer_type);

// A Buffer can only be constructed with a single element offset, so there is no
// way to represent the offset of a slice into an n-dimensional buffer where n > 1.
// In that case, replace elem_offset with an undefined PrimExpr as a sentinel value
// for ArgBinder::BindBuffer, stating that any usage of the element offset is invalid.
// This allows construction of a Buffer that would need multiple element offsets,
// but disallows any usage of those element offsets. See PR #10816 for discussion on
// supporting multiple element offsets in TIR Buffer.
// TODO(Lunderberg): Remove if/when TIR supports multiple element offsets in TIR Buffer
if (elem_offset.size() != 1) {
slice.CopyOnWrite()->elem_offset = PrimExpr();
}
return slice;
}

PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
Expand Down
33 changes: 18 additions & 15 deletions src/tir/transforms/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,25 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st
<< " required_alignment=" << arg->data_alignment
<< ", provided_alignment=" << value->data_alignment;
}
// bind pointer and offset.
if (is_zero(arg->elem_offset)) {
ICHECK(is_zero(value->elem_offset))
<< "Trying to bind a Buffer with offset into one without offset "
<< " required elem_offset=" << arg->elem_offset
<< ", provided elem_offset=" << value->elem_offset;
}

this->Bind(arg->data, value->data, arg_name + ".data");
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
PrimExpr offset = value->elem_offset;
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
&asserts_);
if (value->elem_offset.defined()) {
// bind pointer and offset.
if (is_zero(arg->elem_offset)) {
ICHECK(is_zero(value->elem_offset))
<< "Trying to bind a Buffer with offset into one without offset "
<< " required elem_offset=" << arg->elem_offset
<< ", provided elem_offset=" << value->elem_offset;
}

this->Bind(arg->data, value->data, arg_name + ".data");
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
PrimExpr offset = value->elem_offset;
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
&asserts_);
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,9 @@ class BufferBindUnwrapper : public StmtExprMutator {
}

PrimExpr VisitExpr_(const VarNode* op) final {
ICHECK(!illegal_vars_.count(op)) << "Variable " << op->name_hint << " is not well defined. "
<< "(e.g. use of buffer.elem_offset for a non-flat buffer)";

auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
Expand Down Expand Up @@ -1110,6 +1113,11 @@ class BufferBindUnwrapper : public StmtExprMutator {
// transformations should have been handled in
// BufferShapeLegalize.
binder.BindBuffer(source, view, source->name, false);
if (auto* elem_offset_var = source->elem_offset.as<VarNode>()) {
if (!view->elem_offset.defined()) {
illegal_vars_.insert(elem_offset_var);
}
}

// Apply the remaps
Stmt body = op->body;
Expand Down Expand Up @@ -1162,6 +1170,8 @@ class BufferBindUnwrapper : public StmtExprMutator {
// The buffer assignment map
// Variable remap
std::unordered_map<const VarNode*, PrimExpr> var_remap_;
// Variables that may not occur within the body.
std::unordered_set<const VarNode*> illegal_vars_;
// Buffer map
std::unordered_map<const BufferNode*, BufferEntry> buf_map_;
// Set of vars that have occurred in an AllocateNode, but haven't
Expand Down
125 changes: 90 additions & 35 deletions tests/python/contrib/test_hexagon/test_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@


def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
assert len(shape) == 1
src = te.placeholder(shape=shape, dtype=dtype, name="src")
dst = te.compute(shape, lambda i: src[i], name="dst")
size = shape[0] * np.dtype(dtype).itemsize
Expand All @@ -38,30 +37,72 @@ def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
dtype,
scope=src_scope,
offset_factor=1,
name="mem_copy_src_buffer",
)

dst_buffer = tvm.tir.decl_buffer(
shape,
dtype,
scope=dst_scope,
offset_factor=1,
name="mem_copy_dst_buffer",
)

zero_indices = [0 for _ in shape]

def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()

_src = ins[0]
_dst = outs[0]

dst_handle = ib.buffer_ptr(dst_buffer)
src_handle = ib.buffer_ptr(src_buffer)

ib.emit(
tvm.tir.call_intrin(
"handle", "tir.mem_copy", _dst.access_ptr("w"), _src.access_ptr("r"), size
"handle",
"tir.mem_copy",
tvm.tir.call_intrin("handle", "tir.address_of", dst_handle[zero_indices]),
tvm.tir.call_intrin("handle", "tir.address_of", src_handle[zero_indices]),
size,
)
)
return ib.get()

return te.decl_tensor_intrin(dst.op, intrin_func, binds={src: src_buffer, dst: dst_buffer})


# Build schedule `s` for Hexagon and, when a device session is available, run
# the built function and compare its output against a NumPy reference.
#
# Arguments, as used in the body below:
#   hexagon_session: object providing `load_module` and `device`; when it is
#       None the test is skipped via pytest.skip (after the build step).
#   s: schedule handed to tvm.lower and tvm.build.
#   x, y, z: tensors used as the argument list of the built function; their
#       `dtype` attributes select the dtypes of the random test arrays.
#   size: forwarded to np.random.randint as `size`, i.e. the shape of each
#       test array.
def verify(hexagon_session, s, x, y, z, size):
# Print the lowered form of the schedule for inspection in the test log.
print(tvm.lower(s, [x, y, z]))

# The same Hexagon v68 target is used as both target and host; the built
# function is named "dmacpy" so it can be looked up by that name below.
target_hexagon = tvm.target.hexagon("v68", link_params=True)
func = tvm.build(
s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
)

# The build above always runs; only the on-device execution below is
# skipped when no session is available.
if hexagon_session is None:
pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")

mod = hexagon_session.load_module(func)
# Random test data allocated on the session's device. np.random.randint
# excludes `high`, so the generated values lie in [-128, 126].
xt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
device=hexagon_session.device,
)
yt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
device=hexagon_session.device,
)
# `zt` is passed as the third argument and checked against the reference
# below; presumably the kernel overwrites its random initial contents.
zt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
device=hexagon_session.device,
)
mod["dmacpy"](xt, yt, zt)

# Reference result: elementwise sum of the two inputs, computed with NumPy.
ref = xt.numpy() + yt.numpy()
np.testing.assert_equal(zt.numpy(), ref)


@requires_hexagon_toolchain
def test_cache_read_write(hexagon_session):
size = 128
Expand All @@ -75,52 +116,66 @@ def test_cache_read_write(hexagon_session):
z = te.compute(outer_shape, lambda i: x[i] + y[i], name="z")
s = te.create_schedule(z.op)

x_global = s.cache_read(x, "global.vtcm", [z])
y_global = s.cache_read(y, "global.vtcm", [z])
z_global = s.cache_write(z, "global.vtcm")
x_vtcm = s.cache_read(x, "global.vtcm", [z])
y_vtcm = s.cache_read(y, "global.vtcm", [z])
z_vtcm = s.cache_write(z, "global.vtcm")

zouter, zinner = s[z_global].split(z_global.op.axis[0], factor=factor)
zouter, zinner = s[z_vtcm].split(z_vtcm.op.axis[0], factor=factor)

s[x_global].compute_at(s[z_global], zouter)
s[y_global].compute_at(s[z_global], zouter)
s[x_vtcm].compute_at(s[z_vtcm], zouter)
s[y_vtcm].compute_at(s[z_vtcm], zouter)

mem_copy_read = intrin_mem_copy(inner_shape, dtype, "global.vtcm", "global")

(cache_read_x,) = s[x_global].op.axis
s[x_global].tensorize(cache_read_x, mem_copy_read)
(cache_read_x,) = s[x_vtcm].op.axis
s[x_vtcm].tensorize(cache_read_x, mem_copy_read)

(cache_read_y,) = s[y_global].op.axis
s[y_global].tensorize(cache_read_y, mem_copy_read)
(cache_read_y,) = s[y_vtcm].op.axis
s[y_vtcm].tensorize(cache_read_y, mem_copy_read)

mem_copy_write = intrin_mem_copy(outer_shape, dtype, "global", "global.vtcm")

(cache_write_z,) = s[z].op.axis
s[z].tensorize(cache_write_z, mem_copy_write)

print(tvm.lower(s, [x, y, z]))
verify(hexagon_session, s, x, y, z, size)

target_hexagon = tvm.target.hexagon("v68", link_params=True)
func = tvm.build(
s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
)

if hexagon_session is None:
pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
# Index map passed to `transform_layout` in test_cache_read_write_2d: the flat
# index `n` becomes the pair (n // 16, n % 16), with te.AXIS_SEPARATOR placed
# between the two entries.
# NOTE(review): presumably the separator marks the boundary between physical
# buffer dimensions, making the transformed buffer 2-d rather than flat —
# confirm against the te documentation.
# NOTE(review): the 16 is hard-coded and equals `factor` in
# test_cache_read_write_2d; the two need to be kept in sync.
def layout_transform_2d(n):
return [n // 16, te.AXIS_SEPARATOR, n % 16]

mod = hexagon_session.load_module(func)
xt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
device=hexagon_session.device,
)
yt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
device=hexagon_session.device,
)
zt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
device=hexagon_session.device,
)
mod["dmacpy"](xt, yt, zt)

ref = xt.numpy() + yt.numpy()
np.testing.assert_equal(zt.numpy(), ref)
# Variant of the cache read/write test in which the "global.vtcm" cache stages
# are given the layout from `layout_transform_2d` before being tensorized with
# the memory-copy intrinsic.
@requires_hexagon_toolchain
def test_cache_read_write_2d(hexagon_session):
size = 128
outer_shape = (size,)
# Same value as the hard-coded 16 in layout_transform_2d.
factor = 16
inner_shape = (factor,)
dtype = "int8"

# Elementwise add: z[i] = x[i] + y[i].
x = te.placeholder(shape=outer_shape, dtype=dtype, name="x")
y = te.placeholder(shape=outer_shape, dtype=dtype, name="y")
z = te.compute(outer_shape, lambda i: x[i] + y[i], name="z")
s = te.create_schedule(z.op)

# Cache stages in the "global.vtcm" scope for both inputs and the output.
x_vtcm = s.cache_read(x, "global.vtcm", [z])
y_vtcm = s.cache_read(y, "global.vtcm", [z])
z_vtcm = s.cache_write(z, "global.vtcm")

# Apply the 2-d index map to each cache stage; the returned values are
# indexed below to pick the axis at which to tensorize.
layout_x_vtcm = s[x_vtcm].transform_layout(layout_transform_2d)
layout_y_vtcm = s[y_vtcm].transform_layout(layout_transform_2d)
layout_z_vtcm = s[z_vtcm].transform_layout(layout_transform_2d)

# Copy intrinsic with destination scope "global.vtcm" and source scope
# "global", applied at index 1 of each transformed layout (presumably the
# inner axis of extent `factor`).
mem_copy_read = intrin_mem_copy(inner_shape, dtype, "global.vtcm", "global")
s[x_vtcm].tensorize(layout_x_vtcm[1], mem_copy_read)
s[y_vtcm].tensorize(layout_y_vtcm[1], mem_copy_read)

# Calling `transform_layout` on `z_vtcm` above does not modify the loop
# schedule over `z`. Therefore we must call `split` to modify the loop
# schedule over `z` to match the layout of `z_vtcm`, so that we can accurately
# write `z_vtcm` back to `z` using the memory copy intrinsic.
# Only `zinner` is used afterwards; `zouter` is not referenced again.
zouter, zinner = s[z].split(z.op.axis[0], factor=factor)
mem_copy_write = intrin_mem_copy(inner_shape, dtype, "global", "global.vtcm")
s[z].tensorize(zinner, mem_copy_write)

# Build, optionally run on device, and compare against the NumPy reference.
verify(hexagon_session, s, x, y, z, size)

0 comments on commit 11d22bd

Please sign in to comment.