Skip to content

Commit

Permalink
[aot] Fixed ti_get_last_error signature (taichi-dev#7165)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

Initially `ti_get_last_error` returns `TI_EROR_TRUNCATED` to hint the
user that the message output is not complete (because the buffer size is
insufficient), but this 'non-failing' error code has been removed in
later commits. So at this point this API doesn't have a reliable way to
know whether the error message is complete. This is considered
problematic.
  • Loading branch information
PENGUINLIONG authored and quadpixels committed May 13, 2023
1 parent cbd1f3e commit 3aae825
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 21 deletions.
31 changes: 31 additions & 0 deletions c_api/include/taichi/cpp/taichi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,37 @@ inline bool is_arch_available(TiArch arch) {
return false;
}

struct Error {
TiError error;
std::string message;
};

inline Error get_last_error() {
uint64_t message_size = 0;
ti_get_last_error(&message_size, nullptr);
std::string message(message_size, '\0');
TiError error = ti_get_last_error(&message_size, (char *)message.data());
message.resize(message.size() - 1);
return Error{error, message};
}
inline void check_last_error() {
#ifdef TI_WITH_EXCEPTIONS
Error error = get_last_error();
if (error != TI_ERROR_SUCCESS) {
throw std::runtime_error(error.message);
}
#endif // TI_WITH_EXCEPTIONS
}
inline void set_last_error(TiError error) {
ti_set_last_error(error, nullptr);
}
inline void set_last_error(TiError error, const std::string &message) {
ti_set_last_error(error, message.c_str());
}
inline void set_last_error(const Error &error) {
set_last_error(error.error, error.message);
}

// Token type for half-precision floats.
struct half {
uint16_t _;
Expand Down
2 changes: 1 addition & 1 deletion c_api/include/taichi/taichi_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ TI_DLL_EXPORT void TI_API_CALL ti_get_available_archs(uint32_t *arch_count,
// semantical error code.
TI_DLL_EXPORT TiError TI_API_CALL ti_get_last_error(
// Size of textual error message in `function.get_last_error.message`
uint64_t message_size,
uint64_t *message_size,
// Text buffer for the textual error message. Ignored when `message_size` is
// 0.
char *message);
Expand Down
19 changes: 12 additions & 7 deletions c_api/src/taichi_core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,24 @@ void ti_get_available_archs(uint32_t *arch_count, TiArch *archs) {
}
}

TiError ti_get_last_error(uint64_t message_size, char *message) {
TiError ti_get_last_error(uint64_t *message_size, char *message) {
TiError out = TI_ERROR_INVALID_STATE;
TI_CAPI_TRY_CATCH_BEGIN();
out = thread_error_cache.error;

// Emit message only if the output buffer is property provided.
if (message_size > 0 && message != nullptr) {
size_t n = thread_error_cache.message.size();
if (n >= message_size) {
n = message_size - 1; // -1 for the byte of `\0`.
}
if (message_size == nullptr) {
return out;
}
size_t buffer_size = *message_size;
*message_size = thread_error_cache.message.size() + 1;

if (buffer_size > 0 && message != nullptr) {
// -1 for the byte of `\0`.
size_t n = std::min(thread_error_cache.message.size(), buffer_size - 1);
std::memcpy(message, thread_error_cache.message.data(), n);
message[n] = '\0';
}
out = thread_error_cache.error;
TI_CAPI_TRY_CATCH_END();
return out;
}
Expand Down
3 changes: 2 additions & 1 deletion c_api/taichi.json
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@
},
{
"name": "message_size",
"type": "uint64_t"
"type": "uint64_t",
"by_mut": true
},
{
"name": "message",
Expand Down
19 changes: 9 additions & 10 deletions c_api/tests/gtest_fixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,33 @@ inline bool is_error_ignorable(TiError error) {
class CapiTest : public ::testing::Test {
public:
void ASSERT_TAICHI_SUCCESS() {
TiError actual = ti_get_last_error(0, nullptr);
EXPECT_EQ(actual, TI_ERROR_SUCCESS);
ti::Error actual = ti::get_last_error();
EXPECT_EQ(actual.error, TI_ERROR_SUCCESS);
}

void EXPECT_TAICHI_ERROR(TiError expected,
const std::string &match = "",
bool reset_error = true) {
char err_msg[4096]{0};
TiError err = ti_get_last_error(sizeof(err_msg), err_msg);
ti::Error err = ti::get_last_error();

EXPECT_EQ(err, expected);
EXPECT_EQ(err.error, expected);

if (!match.empty())
EXPECT_NE(std::string(err_msg).find(match), std::string::npos);
EXPECT_NE(err.message.find(match), std::string::npos);

if (reset_error)
ti_set_last_error(TI_ERROR_SUCCESS, nullptr);
ti::set_last_error(TI_ERROR_SUCCESS);
}

protected:
virtual void SetUp() {
}

virtual void TearDown() {
auto error_code = ti_get_last_error(0, nullptr);
ti::Error err = ti::get_last_error();

if (!is_error_ignorable(error_code)) {
EXPECT_GE(error_code, TI_ERROR_SUCCESS);
if (!is_error_ignorable(err.error)) {
EXPECT_GE(err.error, TI_ERROR_SUCCESS);
}
}
};
2 changes: 1 addition & 1 deletion docs/lang/articles/c-api/taichi_core.md
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ An available arch has at least one device available, i.e., device index 0 is alw
```c
// function.get_last_error
TI_DLL_EXPORT TiError TI_API_CALL ti_get_last_error(
uint64_t message_size,
uint64_t* message_size,
char* message
);
```
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ void Block::replace_with(Stmt *old_statement,
*iter = std::move(new_statements[0]);
(*iter)->parent = this;
} else {
statements.erase(iter);
iter = statements.erase(iter);
insert_at(std::move(new_statements), iter);
}
}
Expand Down

0 comments on commit 3aae825

Please sign in to comment.