From fbaa085e3533488f696f61cc9f8b18d65c3573cc Mon Sep 17 00:00:00 2001 From: PENGUINLIONG Date: Mon, 16 Jan 2023 12:14:29 +0800 Subject: [PATCH] [aot] Fixed ti_get_last_error signature (#7165) Issue: # ### Brief Summary Initially `ti_get_last_error` returns `TI_EROR_TRUNCATED` to hint the user that the message output is not complete (because the buffer size is insufficient), but this 'non-failing' error code has been removed in later commits. So at this point this API doesn't have a reliable way to know whether the error message is complete. This is considered problematic. --- c_api/include/taichi/cpp/taichi.hpp | 31 +++++++++++++++++++++++++ c_api/include/taichi/taichi_core.h | 2 +- c_api/src/taichi_core_impl.cpp | 19 +++++++++------ c_api/taichi.json | 3 ++- c_api/tests/gtest_fixture.h | 19 +++++++-------- docs/lang/articles/c-api/taichi_core.md | 2 +- taichi/ir/ir.cpp | 2 +- 7 files changed, 57 insertions(+), 21 deletions(-) diff --git a/c_api/include/taichi/cpp/taichi.hpp b/c_api/include/taichi/cpp/taichi.hpp index d8e008e4b33af..3a216f793892b 100644 --- a/c_api/include/taichi/cpp/taichi.hpp +++ b/c_api/include/taichi/cpp/taichi.hpp @@ -51,6 +51,37 @@ inline bool is_arch_available(TiArch arch) { return false; } +struct Error { + TiError error; + std::string message; +}; + +inline Error get_last_error() { + uint64_t message_size = 0; + ti_get_last_error(&message_size, nullptr); + std::string message(message_size, '\0'); + TiError error = ti_get_last_error(&message_size, (char *)message.data()); + message.resize(message.size() - 1); + return Error{error, message}; +} +inline void check_last_error() { +#ifdef TI_WITH_EXCEPTIONS + Error error = get_last_error(); + if (error != TI_ERROR_SUCCESS) { + throw std::runtime_error(error.message); + } +#endif // TI_WITH_EXCEPTIONS +} +inline void set_last_error(TiError error) { + ti_set_last_error(error, nullptr); +} +inline void set_last_error(TiError error, const std::string &message) { + ti_set_last_error(error, message.c_str()); +} +inline void set_last_error(const Error &error) { + set_last_error(error.error, error.message); +} + // Token type for half-precision floats. struct half { uint16_t _; diff --git a/c_api/include/taichi/taichi_core.h b/c_api/include/taichi/taichi_core.h index 1c5061525eeb9..8a0723072fdba 100644 --- a/c_api/include/taichi/taichi_core.h +++ b/c_api/include/taichi/taichi_core.h @@ -835,7 +835,7 @@ TI_DLL_EXPORT void TI_API_CALL ti_get_available_archs(uint32_t *arch_count, // semantical error code. TI_DLL_EXPORT TiError TI_API_CALL ti_get_last_error( // Size of textual error message in `function.get_last_error.message` - uint64_t message_size, + uint64_t *message_size, // Text buffer for the textual error message. Ignored when `message_size` is // 0. char *message); diff --git a/c_api/src/taichi_core_impl.cpp b/c_api/src/taichi_core_impl.cpp index df46d30bf3040..ff3300d1fb0ab 100644 --- a/c_api/src/taichi_core_impl.cpp +++ b/c_api/src/taichi_core_impl.cpp @@ -190,19 +190,24 @@ void ti_get_available_archs(uint32_t *arch_count, TiArch *archs) { } } -TiError ti_get_last_error(uint64_t message_size, char *message) { +TiError ti_get_last_error(uint64_t *message_size, char *message) { TiError out = TI_ERROR_INVALID_STATE; TI_CAPI_TRY_CATCH_BEGIN(); + out = thread_error_cache.error; + // Emit message only if the output buffer is property provided. - if (message_size > 0 && message != nullptr) { - size_t n = thread_error_cache.message.size(); - if (n >= message_size) { - n = message_size - 1; // -1 for the byte of `\0`. - } + if (message_size == nullptr) { + return out; + } + size_t buffer_size = *message_size; + *message_size = thread_error_cache.message.size() + 1; + + if (buffer_size > 0 && message != nullptr) { + // -1 for the byte of `\0`. + size_t n = std::min(thread_error_cache.message.size(), buffer_size - 1); std::memcpy(message, thread_error_cache.message.data(), n); message[n] = '\0'; } - out = thread_error_cache.error; TI_CAPI_TRY_CATCH_END(); return out; } diff --git a/c_api/taichi.json b/c_api/taichi.json index 1775eb9fccc0c..cc5f2b0ca34a4 100644 --- a/c_api/taichi.json +++ b/c_api/taichi.json @@ -528,7 +528,8 @@ }, { "name": "message_size", - "type": "uint64_t" + "type": "uint64_t", + "by_mut": true }, { "name": "message", diff --git a/c_api/tests/gtest_fixture.h b/c_api/tests/gtest_fixture.h index 3b0358c3c0211..fceb533012c1c 100644 --- a/c_api/tests/gtest_fixture.h +++ b/c_api/tests/gtest_fixture.h @@ -10,23 +10,22 @@ inline bool is_error_ignorable(TiError error) { class CapiTest : public ::testing::Test { public: void ASSERT_TAICHI_SUCCESS() { - TiError actual = ti_get_last_error(0, nullptr); - EXPECT_EQ(actual, TI_ERROR_SUCCESS); + ti::Error actual = ti::get_last_error(); + EXPECT_EQ(actual.error, TI_ERROR_SUCCESS); } void EXPECT_TAICHI_ERROR(TiError expected, const std::string &match = "", bool reset_error = true) { - char err_msg[4096]{0}; - TiError err = ti_get_last_error(sizeof(err_msg), err_msg); + ti::Error err = ti::get_last_error(); - EXPECT_EQ(err, expected); + EXPECT_EQ(err.error, expected); if (!match.empty()) - EXPECT_NE(std::string(err_msg).find(match), std::string::npos); + EXPECT_NE(err.message.find(match), std::string::npos); if (reset_error) - ti_set_last_error(TI_ERROR_SUCCESS, nullptr); + ti::set_last_error(TI_ERROR_SUCCESS); } protected: @@ -34,10 +33,10 @@ class CapiTest : public ::testing::Test { } virtual void TearDown() { - auto error_code = ti_get_last_error(0, nullptr); + ti::Error err = ti::get_last_error(); - if (!is_error_ignorable(error_code)) { - EXPECT_GE(error_code, TI_ERROR_SUCCESS); + if (!is_error_ignorable(err.error)) { + EXPECT_GE(err.error, TI_ERROR_SUCCESS); } } }; diff --git a/docs/lang/articles/c-api/taichi_core.md b/docs/lang/articles/c-api/taichi_core.md index bc87c3769b62e..1eb27d8b9fb1d 100644 --- a/docs/lang/articles/c-api/taichi_core.md +++ b/docs/lang/articles/c-api/taichi_core.md @@ -974,7 +974,7 @@ An available arch has at least one device available, i.e., device index 0 is alw ```c // function.get_last_error TI_DLL_EXPORT TiError TI_API_CALL ti_get_last_error( - uint64_t message_size, + uint64_t* message_size, char* message ); ``` diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 0d5946ced60c1..a728cdda1a0d2 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -350,7 +350,7 @@ void Block::replace_with(Stmt *old_statement, *iter = std::move(new_statements[0]); (*iter)->parent = this; } else { - statements.erase(iter); + iter = statements.erase(iter); insert_at(std::move(new_statements), iter); } }