diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index cf93a481c226..625488430bf8 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -1053,11 +1053,17 @@ Map AsIntSet(const Map& var_dom) { /*! \brief Helper function to convert IterSumExpr to the actual touched range. */ static Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, Analyzer* analyzer) { + if (analyzer->CanProve(extent == 0)) { + return IntSet::Nothing(); + } if (iter_min->args.empty()) { return IntSet::FromMinExtent(iter_min->base, extent); } ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr"; const IterSplitExpr& split = iter_min->args[0]; + if (analyzer->CanProve(split->extent == 0)) { + return IntSet::Nothing(); + } if (!analyzer->CanProve(extent >= split->scale)) { return NullOpt; } diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 33f859828927..facee629afcf 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -99,6 +99,9 @@ bool ProducerCoversConsumer(const Array& buffer_shape, if (produced_region[i].IsNothing()) { return false; } + if (consumed_region[i].IsNothing()) { + continue; + } arith::IntSet produced = arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()), analyzer->canonical_simplify(produced_region[i].max())); diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index cd91a44b6518..4484b6ab39ba 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -417,5 +417,33 @@ def two_elementwise(a: T.handle, c: T.handle) -> None: assert is_output_block(sch, block_rv) +def test_empty_grid(): + @T.prim_func + def foo(out: T.Buffer((T.int64(1), T.int64(8), T.int64(8)), "int32")): + act = T.alloc_buffer((1, 8, 8), "int32") + for z2, y2, x2 in T.grid(1, 8, 8): + with T.block("b0"): + az, ay, ax = T.axis.remap("SSS", [z2, y2, x2]) + T.writes(act[az, ay, ax]) + act[az, ay, az] = T.int32(0) + # Empty grid: + for z1, y1, x1 in T.grid(0, 8, 8): + with T.block("b1"): + az, ay, ax = T.axis.remap("SSS", [z1, y1, x1]) + T.reads(act[az + 1, ay, ax]) + T.writes(out[az, ay, ax]) + out[az, ay, ax] = act[az + 1, ay, ax] + # The block below is not needed to show the bug, but the 'out' + # buffer would be undefined without it. + for z2, y2, x2 in T.grid(1, 8, 8): + with T.block("b2"): + az, ay, ax = T.axis.remap("SSS", [z2, y2, x2]) + T.writes(out[az, ay, ax]) + out[az, ay, az] = T.int32(0) + + # This caused a crash before. + sch = tvm.tir.Schedule(foo) + + if __name__ == "__main__": tvm.testing.main()