Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Update BufferDomainTouched to support vector access. #11722

Merged
merged 2 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ class BufferTouchedDomain final : public StmtExprVisitor {
}

Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) {
Region ret;
auto kv = buffer_access_map_.find(buffer.get());
CHECK(kv != buffer_access_map_.end())
<< "The requested buffer is not contained in the provided stmt body.";
if (kv == buffer_access_map_.end()) {
LOG(WARNING) << "[arith::BufferDomainTouched] "
<< "The requested buffer is not contained in the provided stmt body: " << buffer;
return ret;
}

Region ret;
Range none;
BufferTouches bounds;
if (consider_loads && consider_stores) {
Expand Down Expand Up @@ -131,13 +134,16 @@ class BufferTouchedDomain final : public StmtExprVisitor {
}

private:
template <typename ArrayType>
void Touch(BufferTouches* bounds, const ArrayType& args) const {
void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) const {
if (args.size() > bounds->size()) {
bounds->resize(args.size());
}
for (size_t i = 0; i < args.size(); ++i) {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
if (args[i].as<RampNode>()) {
(*bounds)[i].emplace_back(IntSet::Vector(args[i]));
} else {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
}
}
}

Expand Down
63 changes: 38 additions & 25 deletions tests/python/unittest/test_arith_domain_touched.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,36 @@
# under the License.
import tvm
from tvm import te
from tvm.script import tir as T


@T.prim_func
def scalar_func(a: T.handle, b: T.handle):
m = T.var("int32")
n = T.int32(100)
A = T.match_buffer(a, (n, m), name="A")
B = T.match_buffer(b, (n, m), name="B")

for i, j in T.grid(n, m):
A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]


@T.prim_func
def vector_func(a: T.handle, b: T.handle):
n = T.var("int32")
m = T.int32(128)
A = T.match_buffer(a, (n, m), name="A")
B = T.match_buffer(b, (n, m), name="B")

for i in T.serial(n):
for j in T.vectorized(m):
A[i, j] = A[i, j] + B[i, j]


def test_domain_touched():
i = te.var("i")
j = te.var("j")
n = tvm.runtime.convert(100)
m = te.var("m")

a = tvm.tir.decl_buffer((n, m), name="a")
b = tvm.tir.decl_buffer((n, m), name="b")

ir = tvm.tir.For(
i,
0,
n,
tvm.tir.ForKind.SERIAL,
tvm.tir.For(
j,
0,
m,
tvm.tir.ForKind.SERIAL,
tvm.tir.BufferStore(
a,
tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a, [i - 1, j - 1]),
[i, j],
),
),
)
func = scalar_func
a, b = [func.buffer_map[var] for var in func.params]
ir = func.body

a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)

Expand Down Expand Up @@ -78,5 +80,16 @@ def test_domain_touched():
assert len(b_domain_w) == 0


def test_domain_touched_vector():
func = tvm.lower(vector_func)["main"]
a, b = [func.buffer_map[var] for var in func.params]

assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, True)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, b, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, b, True, False)[0].extent.value == 128


if __name__ == "__main__":
test_domain_touched()