diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 822e8e468377..c586e81f1b9c 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -492,6 +492,34 @@ Var EnvThread(String thread_tag) { } void BufferStore(Buffer buffer, PrimExpr value, Array indices) { + runtime::DataType buffer_dtype = buffer->dtype; + int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; + runtime::DataType lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes); + runtime::DataType rhs_dtype = value->dtype; + if (lhs_dtype != rhs_dtype) { + if (lhs_dtype.lanes() != rhs_dtype.lanes()) { + LOG(FATAL) << "TypeError: Incompatible types in BufferStore" + << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype + << "`, indexing lanes: " << index_lanes; + } + if (lhs_dtype.code() != rhs_dtype.code()) { + if ( + // Case 1. lhs is handle, and rhs needs to be casted to handle. + (lhs_dtype.code() == runtime::DataType::kHandle) || + // Case 2. rhs is handle, and it needs to be casted to non-handle. + (rhs_dtype.code() == runtime::DataType::kHandle) || + // Case 3. rhs is float or bfloat, and casting to non-float can lose precision. + ((lhs_dtype.code() == runtime::DataType::kInt || + lhs_dtype.code() == runtime::DataType::kUInt) && + (rhs_dtype.code() == runtime::DataType::kFloat || + rhs_dtype.code() == runtime::DataType::kBFloat))) { + LOG(WARNING) << "Casting in BufferStore may lose precision" + << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype + << "`, indexing lanes: " << index_lanes; + } + } + value = tvm::cast(lhs_dtype, value); + } AddToParent(tvm::tir::BufferStore(buffer, value, indices)); } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 733c975fad7e..485757063867 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -21,6 +21,7 @@ #include #include +#include #include namespace tvm { diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 355a3b16b855..1652786cb510 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -507,6 +507,12 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, << "Cannot store value with " << value.dtype().lanes() << ", expected value with " << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes << " buffer element lanes)"; + if (buffer->dtype.with_lanes(buffer_lanes * index_lanes) != value.dtype()) { + LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " // + << "buffer's dtype is `" << buffer->dtype // + << "`, the lanes of indexing are: `" << index_lanes // + << "`, but RHS's dtype is `" << value.dtype() << "`"; + } ObjectPtr node = make_object(); node->buffer = std::move(buffer); diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 6f591efc2d2d..2df644d7e198 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -16,7 +16,6 @@ # under the License. import pytest - import tvm from tvm import te @@ -153,7 +152,7 @@ def test_stmt_constructor(): buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("uint1"))) buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var) - x = tvm.tir.BufferStore(buffer, 1, [10]) + x = tvm.tir.BufferStore(buffer, tvm.tir.IntImm("bool", 1), [10]) assert isinstance(x, tvm.tir.BufferStore) assert x.buffer == buffer assert x.buffer.data == buffer_var