[Bug] Check failed: (args.size() == initial_indices_orig.size()) is false #2276
Comments
Thanks for reporting. If it is possible to get a minimal repro, that would be helpful.
@jpf888 I ran into the same problem. Did you solve it?
Hi, this bug can be reproduced as follows:
It occurs only when kernel_size is equal to 1.
Hi @senlyu163, it looks like this is a known issue when applying dlight to conv2d with a kernel size of 1. It arises because the reindex schedule primitive simplifies the index expressions. To address this, I previously created a draft PR: apache/tvm#16440. You can check out that draft PR and merge the relevant changes. The key component related to this issue is the addition of a skip_simplify option to reindex, which normalize_to_matmul uses as follows:

from typing import List, Optional

from tvm import tir
from tvm.tir.schedule import BlockRV

# get_index_map and logger are assumed to be defined in the surrounding dlight module.


def normalize_to_matmul(sch: tir.Schedule,
                        main_block: BlockRV,
                        layout: Optional[List[str]] = None) -> Optional[tir.Schedule]:
    if layout is None:
        layout = ["n", "t", "n"]
    block_stmt = sch.get(main_block)
    # Let layout be 'a' to auto infer the layout
    index_maps = get_index_map(block_stmt, layout=layout)
    if index_maps is None:
        logger.debug("Cannot find the appropriate index map for tensorcore")
        return None
    matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
    # Use `skip_simplify` to avoid the bug in the 1x1 conv
    block = sch.reindex(main_block, ("read", 0), skip_simplify=True)
    sch.transform_layout(block, ("write", 0), a_index_map)
    block = sch.reindex(main_block, ("read", 1), skip_simplify=True)
    sch.transform_layout(block, ("write", 0), b_index_map)
    block = sch.reindex(main_block, ("write", 0), skip_simplify=True)
    sch.transform_layout(block, ("read", 0), c_index_map)
    sch.transform_block_layout(main_block, matmul_index_map)
    sch.mod["main"] = sch.mod["main"].with_attr("dlight.tensorcore_prenormlized", True)
    return sch
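To make the explanation above concrete, here is a toy model in plain Python of why simplification can change the number of indices. It is not TVM code and does not reproduce what reindex actually does. It assumes that simplification folds the unit-extent kernel iterators ry and rx to the constant 0, so they no longer appear in the pad_temp access. The names EXTENTS, PAD_TEMP_ACCESS, KERNEL_ITERS and index_vars are made up for this sketch; the extents and the access pattern are copied from the conv2d_nchw block in this issue.

```python
# Toy model only: plain Python, not TVM's reindex implementation.
# Iterator extents copied from the conv2d_nchw block in this issue.
EXTENTS = {"nn": 1, "ff": 256, "yy": 64, "xx": 64, "rc": 768, "ry": 1, "rx": 1}

# pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], one tuple of iterators per index.
PAD_TEMP_ACCESS = [("nn",), ("rc",), ("yy", "ry"), ("xx", "rx")]

# Assumption: only the kernel iterators are folded, since the thread ties
# the failure to kernel_size == 1.
KERNEL_ITERS = ("ry", "rx")


def index_vars(access, simplify):
    """Return the distinct iterators that appear in a buffer access.

    With simplify=True, kernel iterators of extent 1 are treated as the
    constant 0 and dropped.
    """
    seen = []
    for index in access:
        for var in index:
            folded = simplify and var in KERNEL_ITERS and EXTENTS[var] == 1
            if not folded and var not in seen:
                seen.append(var)
    return seen


unsimplified = index_vars(PAD_TEMP_ACCESS, simplify=False)
simplified = index_vars(PAD_TEMP_ACCESS, simplify=True)
print(unsimplified)  # ['nn', 'rc', 'yy', 'ry', 'xx', 'rx']
print(simplified)    # ['nn', 'rc', 'yy', 'xx']
```

Under this assumption, an index map written for six inputs would be applied to something that only has four, which is consistent with the failed size check. With skip_simplify=True the access would keep all six iterators.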
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
I think this problem is caused by the fusion of the permute and conv operators after dl.gpu.Matmul(), which results in a mismatch between the buffer shape and the index_map shape.
1. Error log
tvm.error.InternalError: Traceback (most recent call last):
  4: operator()
        at /workspace/tvm-unity/src/tir/schedule/schedule.cc:287
  3: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/traced_schedule.cc:678
  2: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/concrete_schedule.cc:993
  1: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc:1160
  0: tvm::tir::LegalizeIndexMapDType(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::PrimExpr, void> const&)
        at /workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc:1106
  File "/workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc", line 1106
InternalError: Check failed: (args.size() == initial_indices_orig.size()) is false:
2. Other information
1).
T.index_map(lambda i0, i1, i2, i3, i4, i5: (T.int64(0), i1 * T.int64(64) + i2, i3))
Does this index map not match the buffer?
2).
with T.block("conv2d_nchw", no_realize=True):
    v_nn = T.axis.spatial(T.int64(1))
    v_ff = T.axis.spatial(T.int64(256))
    v_yy = T.axis.spatial(T.int64(64))
    v_xx = T.axis.spatial(T.int64(64))
    v_rc = T.axis.reduce(T.int64(768))
    v_ry = T.axis.reduce(T.int64(1))
    v_rx = T.axis.reduce(T.int64(1))
    pad_temp = T.Buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16")
    B = T.Buffer((T.int64(256), T.int64(768), T.int64(1), T.int64(1)), "float16")
    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], B[v_ff, v_rc, v_ry, v_rx])
    conv2d_nchw = T.Buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float16")
    T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
    with T.init():
        conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * B[v_ff, v_rc, v_ry, v_rx]
3).
@T.prim_func(private=True)
def main(permute_dims161: T.Buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16"), vision_tower_vision_tower_high_neck_0_weight1: T.Buffer((T.int64(256), T.int64(768), T.int64(1), T.int64(1)), "float16"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    pad_temp = T.alloc_buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16")
    conv2d_nchw_intermediate = T.alloc_buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float16")
    for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(768), T.int64(64), T.int64(64)):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(permute_dims161[v_i0, v_i1, v_i2, v_i3])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
            pad_temp[v_i0, v_i1, v_i2, v_i3] = permute_dims161[v_i0, v_i1, v_i2, v_i3]
    for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(256), T.int64(64), T.int64(64), T.int64(768), T.int64(1), T.int64(1)):
        with T.block("conv2d_nchw"):
            v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
            T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], vision_tower_vision_tower_high_neck_0_weight1[v_ff, v_rc, v_ry, v_rx])
            T.writes(conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx])
            with T.init():
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
            conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * vision_tower_vision_tower_high_neck_0_weight1[v_ff, v_rc, v_ry, v_rx]
    for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(256), T.int64(64), T.int64(64)):
        with T.block("compute"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(conv2d_nchw_intermediate[v_i0, v_i1, v_i2, v_i3])
            T.writes(compute_intermediate[v_i0, v_i1, v_i2, v_i3])
            compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", conv2d_nchw_intermediate[v_i0, v_i1, v_i2, v_i3])
T.index_map(lambda i0, i1, i2, i3, i4, i5: (T.int64(0), i1 * T.int64(64) + i2, i3))
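The suspected mismatch between the index map and the buffer shape can be checked by hand. The sketch below is plain Python and does not call TVM. The function a_index_map is a hypothetical stand-in for the T.index_map printed above, with ordinary ints in place of T.int64, and arity_matches is a made-up helper. The issue does not confirm which buffer the map is applied to; the shape used here is the pad_temp shape from the printed prim_func, taken as an assumption.

```python
import inspect


# Hypothetical stand-in for the T.index_map above, using plain ints.
def a_index_map(i0, i1, i2, i3, i4, i5):
    return (0, i1 * 64 + i2, i3)


# Shape of pad_temp as printed in the prim_func (assumed to be the target buffer).
PAD_TEMP_SHAPE = (1, 768, 64, 64)


def arity_matches(index_map, shape):
    """True when the map takes exactly one input index per buffer dimension."""
    return len(inspect.signature(index_map).parameters) == len(shape)


print(arity_matches(a_index_map, PAD_TEMP_SHAPE))  # False: 6 inputs vs 4 dimensions
print(a_index_map(0, 63, 63, 767, 0, 0))  # (0, 4095, 767)
```

The map expects six input indices, while a four-dimensional buffer supplies only four, which is the kind of size disagreement the failed check reports.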
Expected behavior
Environment
Additional context