Skip to content

Commit

Permalink
[Arith][TIR] Recognize empty extents (apache#15129)
Browse files Browse the repository at this point in the history
Generate empty interval sets when empty extents are encountered. Handle
empty regions when constructing ScheduleState.
  • Loading branch information
Krzysztof Parzyszek authored Jun 21, 2023
1 parent 54b9741 commit b37ad17
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1053,11 +1053,17 @@ Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
/*! \brief Helper function to convert IterSumExpr to the actual touched range. */
static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent,
Analyzer* analyzer) {
if (analyzer->CanProve(extent == 0)) {
return IntSet::Nothing();
}
if (iter_min->args.empty()) {
return IntSet::FromMinExtent(iter_min->base, extent);
}
ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr";
const IterSplitExpr& split = iter_min->args[0];
if (analyzer->CanProve(split->extent == 0)) {
return IntSet::Nothing();
}
if (!analyzer->CanProve(extent >= split->scale)) {
return NullOpt;
}
Expand Down
3 changes: 3 additions & 0 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape,
if (produced_region[i].IsNothing()) {
return false;
}
if (consumed_region[i].IsNothing()) {
continue;
}
arith::IntSet produced =
arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()),
analyzer->canonical_simplify(produced_region[i].max()));
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,5 +417,33 @@ def two_elementwise(a: T.handle, c: T.handle) -> None:
assert is_output_block(sch, block_rv)


def test_empty_grid():
@T.prim_func
def foo(out: T.Buffer((T.int64(1), T.int64(8), T.int64(8)), "int32")):
act = T.alloc_buffer((1, 8, 8), "int32")
for z2, y2, x2 in T.grid(1, 8, 8):
with T.block("b0"):
az, ay, ax = T.axis.remap("SSS", [z2, y2, x2])
T.writes(act[az, ay, ax])
act[az, ay, az] = T.int32(0)
# Empty grid:
for z1, y1, x1 in T.grid(0, 8, 8):
with T.block("b1"):
az, ay, ax = T.axis.remap("SSS", [z1, y1, x1])
T.reads(act[az + 1, ay, ax])
T.writes(out[az, ay, ax])
out[az, ay, ax] = act[az + 1, ay, ax]
# The block below is not needed to show the bug, but the 'out'
# buffer would be undefined without it.
for z2, y2, x2 in T.grid(1, 8, 8):
with T.block("b2"):
az, ay, ax = T.axis.remap("SSS", [z2, y2, x2])
T.writes(out[az, ay, ax])
out[az, ay, az] = T.int32(0)

# This caused a crash before.
sch = tvm.tir.Schedule(foo)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit b37ad17

Please sign in to comment.