diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 1faef5340f20..1dfa09ffcce9 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -137,17 +137,16 @@ class AOTOnDemandAllocator : public ExprVisitor { void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } - void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "if is not supported."; } + void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "let is not supported."; } private: void AssignReturnSid(Expr e) { if (storage_device_map_.find(e) != storage_device_map_.end()) { auto buffers = storage_device_map_[e]; - std::vector return_ids; + return_ids_.clear(); for (auto buffer : buffers) { - return_ids.push_back(buffer.sid); + return_ids_.push_back(buffer.sid); } - return_ids_ = return_ids; } } /*! @@ -163,7 +162,7 @@ class AOTOnDemandAllocator : public ExprVisitor { * \param prototype The prototype token. * \return The required memory size. 
*/ - size_t GetMemorySize(const TensorTypeNode* ttype) { + size_t GetMemorySizeBytes(const TensorTypeNode* ttype) { ICHECK(ttype != nullptr); size_t size = 1; for (IndexExpr dim : ttype->shape) { @@ -200,8 +199,8 @@ class AOTOnDemandAllocator : public ExprVisitor { const auto* ttype = t.as(); ICHECK(ttype); StorageInfo buffer; - buffer.sid = sid_++; - buffer.size_bytes = GetMemorySize(ttype); + buffer.sid = next_available_sid_++; + buffer.size_bytes = GetMemorySizeBytes(ttype); buffer.dev_type = device_type; buffers.push_back(buffer); } @@ -209,8 +208,8 @@ class AOTOnDemandAllocator : public ExprVisitor { const auto* ttype = op->checked_type().as(); ICHECK(ttype); StorageInfo buffer; - buffer.sid = sid_++; - buffer.size_bytes = GetMemorySize(ttype); + buffer.sid = next_available_sid_++; + buffer.size_bytes = GetMemorySizeBytes(ttype); buffer.dev_type = device_type; buffers.push_back(buffer); } @@ -221,7 +220,7 @@ class AOTOnDemandAllocator : public ExprVisitor { /*! \brief mapping of expression -> device type*/ Map node_device_map_; /*! \brief current id of the temporary allocated*/ - int sid_{0}; + int next_available_sid_{0}; /*! \brief the set of intermediate tensors that are return variables */ std::vector return_ids_; }; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 66218b3555d3..cd91a4b53317 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -142,7 +142,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { // Recall that the arguments of a tvm_call_cpacked are passed as // TVMValues. But a TVMValue is only a container, that points to // a real buffer previously allocated. 
We need to signal that those - buffers need to be live at the same time (i.e., cannot be overridden) + buffers need to be live at the same time (i.e., cannot be overwritten during the function + call) Array args = op->args; for (auto arg : args) { const VarNode* var = arg.as(); @@ -234,7 +235,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { bool in_thread_env_{false}; // The scope stack. std::vector scope_; - // This is a map to connect TVMValues to real allocations + // This is a map to connect TVMValues to real allocations. When we pass parameters + // to a tvm_call_cpacked, the data needs to be wrapped in a TVMValue. The wrapping + // happens through the tvm_struct_set built-in. This map maps the variable + // representing the TVMValue to the variable representing the real buffer. The liveness + // analysis needs to happen on the latter and not on the TVMValue, which only acts as + // a container. std::unordered_map> value_to_alloc_; }; diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index ee0af7fce66c..7a692341814a 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -387,6 +387,9 @@ def test_byoc_utvm(use_calculated_workspaces, target_options): def test_quant_mobilenet_tfl(): + """Since in AOT we pass the output buffer directly from the user, in quantized networks sharing the output buffers is not possible. + This is because the output data type is int8 and the intermediate buffers are int32 or int16. 
We use mobilenet quantized to stress this + situation and verify that the output buffer sharing is disabled in AOT.""" pytest.importorskip("tflite") import tvm.relay.testing.tf as tf_testing @@ -410,6 +413,8 @@ def test_quant_mobilenet_tfl(): @pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) def test_transpose(target_options): + """Test that non-in-placeable operations (e.g., transpose) do not happen in-place.""" + dtype = "float32" x = relay.var("x", shape=(10, 5), dtype=dtype) y = relay.var("y", shape=(10, 5), dtype=dtype)