From 93ef1d7972ff01749c361a72b20e13b6ac0d689e Mon Sep 17 00:00:00 2001 From: Retro <44505837+dankmolot@users.noreply.github.com> Date: Sat, 23 Nov 2024 21:16:18 +0200 Subject: [PATCH] Update --- .github/workflows/build.yml | 6 +- CMakePresets.json | 20 ++- source/async_postgres.hpp | 25 ++- source/connection.cpp | 72 +++++++- source/main.cpp | 329 +++++++++++++----------------------- source/misc.cpp | 201 ++++++++++++++++++++++ source/notifications.cpp | 26 +-- source/query.cpp | 26 +++ 8 files changed, 467 insertions(+), 238 deletions(-) create mode 100644 source/misc.cpp diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3d176a1..3911d87 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,13 +34,13 @@ jobs: CMAKE_ARCH_FLAG: --preset=x86-windows - os: ubuntu-20.04 arch: x64 - CMAKE_ARCH_FLAG: --preset=unix + CMAKE_ARCH_FLAG: --preset=x64-linux - os: ubuntu-20.04 arch: x86 - CMAKE_ARCH_FLAG: --preset=unix -DCMAKE_C_FLAGS="-m32" -DCMAKE_CXX_FLAGS="-m32" -DVCPKG_TARGET_TRIPLET=x86-linux + CMAKE_ARCH_FLAG: --preset=x86-linux - os: macos-13 arch: x64 - CMAKE_ARCH_FLAG: --preset=unix + CMAKE_ARCH_FLAG: --preset=x64-macos runs-on: ${{ matrix.os }} diff --git a/CMakePresets.json b/CMakePresets.json index e9b1187..3f5afbe 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -29,9 +29,25 @@ } }, { - "name": "unix", + "name": "x64-linux", "inherits": "vcpkg", - "generator": "Ninja" + "generator": "Ninja", + "architecture": "x64" + }, + { + "name": "x86-linux", + "inherits": "vcpkg", + "generator": "Ninja", + "architecture": "x86" + }, + { + "name": "x64-macos", + "inherits": "vcpkg", + "generator": "Ninja", + "architecture": "x64", + "cacheVariables": { + "VCPKG_TARGET_TRIPLET": "x64-osx" + } } ] } diff --git a/source/async_postgres.hpp b/source/async_postgres.hpp index 6c7aef2..86a36d5 100644 --- a/source/async_postgres.hpp +++ b/source/async_postgres.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -8,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -49,20 +51,34 @@ namespace async_postgres { std::string command; }; + struct ParameterizedCommand { + std::string command; + std::vector values; + }; + struct Query { - std::variant command; + std::variant command; GLua::AutoReference callback; bool sent = false; bool flushed = false; }; + struct ResetEvent { + std::vector callbacks; + PostgresPollingStatusType status = PGRES_POLLING_WRITING; + }; + using PGconnPtr = std::unique_ptr; struct Connection { PGconnPtr conn; GLua::AutoReference lua_table; std::queue queries; + std::optional reset_event; + bool receive_notifications = + false; // enabled if on_notify lua field is set + Connection(GLua::ILuaInterface* lua, PGconnPtr&& conn); ~Connection(); }; @@ -74,6 +90,10 @@ namespace async_postgres { GLua::AutoReference&& callback); void process_pending_connections(GLua::ILuaInterface* lua); + void reset(GLua::ILuaInterface* lua, Connection* state, + GLua::AutoReference&& callback); + void process_reset(GLua::ILuaInterface* lua, Connection* state); + // notifications.cpp void process_notifications(GLua::ILuaInterface* lua, Connection* state); @@ -83,6 +103,9 @@ namespace async_postgres { // result.cpp void create_result_table(GLua::ILuaInterface* lua, PGresult* result); + // misc.cpp + void register_misc_connection_functions(GLua::ILuaInterface* lua); + // util.cpp std::string_view get_string(GLua::ILuaInterface* lua, int index = -1); void pcall(GLua::ILuaInterface* lua, int nargs, int nresults); diff --git a/source/connection.cpp b/source/connection.cpp index 4c90bf6..657e5dd 100644 --- a/source/connection.cpp +++ b/source/connection.cpp @@ -4,6 +4,21 @@ using namespace async_postgres; std::vector async_postgres::connections = {}; +Connection::Connection(GLua::ILuaInterface* lua, PGconnPtr&& conn) + : conn(std::move(conn)) { + lua->CreateTable(); + this->lua_table = GLua::AutoReference(lua); + + // add connection to global list + connections.push_back(this); +} + +Connection::~Connection() { + // remove connection from global list + // so event loop doesn't try to process it + connections.erase(std::find(connections.begin(), connections.end(), this)); +} + struct ConnectionEvent { PGconnPtr conn; GLua::AutoReference callback; @@ -47,22 +62,14 @@ void async_postgres::connect(GLua::ILuaInterface* lua, std::string_view url, inline bool poll_pending_connection(GLua::ILuaInterface* lua, ConnectionEvent& event) { if (!socket_is_ready(event.conn.get(), event.status)) { - lua->Msg("socket is not ready (%s)\n", - event.status == PGRES_POLLING_READING ? "reading" : "writing"); return false; } // TODO: handle reset event.status = PQconnectPoll(event.conn.get()); - lua->Msg("status: %d (%d)\n", event.status, PQstatus(event.conn.get())); if (event.status == PGRES_POLLING_OK) { - auto state = new Connection{std::move(event.conn)}; - - lua->CreateTable(); - state->lua_table = GLua::AutoReference(lua); - - connections.push_back(state); + auto state = new Connection(lua, std::move(event.conn)); event.callback.Push(); lua->PushBool(true); @@ -96,3 +103,50 @@ void async_postgres::process_pending_connections(GLua::ILuaInterface* lua) { } } } + +void async_postgres::reset(GLua::ILuaInterface* lua, Connection* state, + GLua::AutoReference&& callback) { + if (!state->reset_event) { + if (PQresetStart(state->conn.get()) == 0) { + throw std::runtime_error(PQerrorMessage(state->conn.get())); + } + + state->reset_event = ResetEvent(); + } + + if (callback) { + state->reset_event->callbacks.push_back(std::move(callback)); + } +} + +void async_postgres::process_reset(GLua::ILuaInterface* lua, + Connection* state) { + if (!state->reset_event) { + return; + } + + auto& event = state->reset_event.value(); + if (!socket_is_ready(state->conn.get(), state->reset_event->status)) { + return; + } + + event.status = PQresetPoll(state->conn.get()); + if (event.status == PGRES_POLLING_OK) { + for (auto& callback : event.callbacks) { + callback.Push(); + lua->PushBool(true); + pcall(lua, 1, 0); + } + + state->reset_event.reset(); + } else if (event.status == PGRES_POLLING_FAILED) { + for (auto& callback : event.callbacks) { + callback.Push(); + lua->PushBool(false); + lua->PushString(PQerrorMessage(state->conn.get())); + pcall(lua, 2, 0); + } + + state->reset_event.reset(); + } +} diff --git a/source/main.cpp b/source/main.cpp index 556655d..893022e 100644 --- a/source/main.cpp +++ b/source/main.cpp @@ -1,5 +1,3 @@ -#include - #include "async_postgres.hpp" int async_postgres::connection_meta = 0; @@ -8,264 +6,171 @@ int async_postgres::connection_meta = 0; lua->GetUserType( \ 1, async_postgres::connection_meta) -lua_protected_fn(gc_connection) { - delete lua_connection_state(); - return 0; -} +namespace async_postgres::lua { + lua_protected_fn(__gc) { + delete lua_connection_state(); + return 0; + } -async_postgres::Connection::~Connection() { - // remove connection from global list - // so event loop doesn't try to process it - async_postgres::connections.erase( - std::find(async_postgres::connections.begin(), - async_postgres::connections.end(), this)); -} + lua_protected_fn(__index) { + auto state = lua_connection_state(); + + state->lua_table.Push(); + lua->Push(2); + lua->GetTable(-2); + if (!lua->IsType(-1, GLua::Type::Nil)) { + return 1; + } + + // is it alright if I don't pop previous stack values? -lua_protected_fn(index_connection) { - auto state = lua_connection_state(); + lua->PushMetaTable(async_postgres::connection_meta); + lua->Push(2); + lua->GetTable(-2); - state->lua_table.Push(); - lua->Push(2); - lua->GetTable(-2); - if (!lua->IsType(-1, GLua::Type::Nil)) { return 1; } - // is it alright if I don't pop previous stack values? + lua_protected_fn(__newindex) { + auto state = lua_connection_state(); - lua->PushMetaTable(async_postgres::connection_meta); - lua->Push(2); - lua->GetTable(-2); + state->lua_table.Push(); + lua->Push(2); + lua->Push(3); + lua->SetTable(-3); - return 1; -} + auto key = get_string(lua, 2); + if (key == "on_notify") { + state->receive_notifications = !lua->IsType(3, GLua::Type::Nil); + } + + return 1; + } -lua_protected_fn(newindex_connection) { - auto state = lua_connection_state(); + lua_protected_fn(loop) { + async_postgres::process_pending_connections(lua); - state->lua_table.Push(); - lua->Push(2); - lua->Push(3); - lua->SetTable(-3); + for (auto* state : async_postgres::connections) { + if (!state->conn) { + lua->Msg("[async_postgres] connection is null for %p\n", state); + continue; + } - return 1; -} + async_postgres::process_reset(lua, state); + async_postgres::process_notifications(lua, state); + async_postgres::process_queries(lua, state); + } -// inline void process_connections(GLua::ILuaInterface* lua) { -// for (auto it = connections.begin(); it != connections.end();) { -// auto& event = *it; -// auto status = PQconnectPoll(event.conn); -// if (status == PGRES_POLLING_OK || status == PGRES_POLLING_FAILED) { -// lua->GetField(GLua::INDEX_GLOBAL, "ErrorNoHaltWithStack"); -// lua->ReferencePush(event.callback); -// if (status == PGRES_POLLING_OK) { -// lua->PushBool(true); -// lua->PushUserType(event.conn, connection_meta); -// lua->PushMetaTable(connection_meta); -// lua->SetMetaTable(-2); - -// queries[event.conn] = {}; -// } else { -// lua->PushBool(false); -// lua->PushString(PQerrorMessage(event.conn)); -// } - -// if (lua->PCall(2, 0, -4) != 0) { -// lua->Pop(); -// } -// lua->Pop(); -// lua->ReferenceFree(event.callback); - -// it = connections.erase(it); -// } else { -// ++it; -// } -// } -// } - -// inline void query_failed(GLua::ILuaInterface* lua, PGconn* conn, -// QueryEvent& event) { -// lua->GetField(GLua::INDEX_GLOBAL, "ErrorNoHaltWithStack"); -// lua->ReferencePush(event.callback); -// lua->PushBool(false); -// lua->PushString(PQerrorMessage(conn)); -// if (lua->PCall(2, 0, -4) != 0) { -// lua->Pop(); -// } -// lua->Pop(); -// lua->ReferenceFree(event.callback); -// } - -// inline void parse_query_result(GLua::ILuaInterface* lua, -// const PGresult* result) { -// lua->CreateTable(); -// for (int i = 0; i < PQntuples(result); i++) { -// lua->PushNumber(i); -// lua->CreateTable(); -// for (int j = 0; j < PQnfields(result); j++) { -// lua->PushString(PQfname(result, j)); -// lua->PushString(PQgetvalue(result, i, j)); -// lua->SetTable(-3); -// } -// lua->SetTable(-3); -// } -// } - -// inline void process_query_result(GLua::ILuaInterface* lua, QueryEvent& event, -// const PGresult* result) { -// auto status = PQresultStatus(result); - -// lua->GetField(GLua::INDEX_GLOBAL, "ErrorNoHaltWithStack"); -// lua->ReferencePush(event.callback); - -// if (status == PGRES_BAD_RESPONSE || status == PGRES_NONFATAL_ERROR || -// status == PGRES_FATAL_ERROR) { -// lua->PushBool(false); -// lua->PushString(PQresultErrorMessage(result)); -// } else { -// lua->PushBool(true); -// parse_query_result(lua, result); -// } - -// if (lua->PCall(2, 0, -4) != 0) { -// lua->Pop(); -// } -// lua->Pop(); -// } - -// inline void process_queries(GLua::ILuaInterface* lua) { -// for (auto& [conn, queue] : queries) { -// while (!queue.empty()) { -// auto& event = queue.front(); -// if (!event.sent) { -// if (PQsendQuery(conn, event.query.c_str()) == 0) { -// query_failed(lua, conn, event); -// queue.pop(); -// continue; -// } -// event.sent = true; -// } - -// if (!event.flushed) { -// auto status = PQflush(conn); -// if (status == -1) { -// query_failed(lua, conn, event); -// queue.pop(); -// continue; -// } -// event.flushed = status == 0; -// } - -// PQconsumeInput(conn); - -// if (PQisBusy(conn) == 1) { -// break; -// } - -// if (auto result = PQgetResult(conn)) { -// process_query_result(lua, event, result); -// PQclear(result); -// } else { -// lua->ReferenceFree(event.callback); -// queue.pop(); -// } -// } -// } -// } - -lua_protected_fn(loop) { - async_postgres::process_pending_connections(lua); - - for (auto state : async_postgres::connections) { - async_postgres::process_notifications(lua, state); - async_postgres::process_queries(lua, state); + return 0; } - return 0; -} + lua_protected_fn(connect) { + lua->CheckType(1, GLua::Type::String); + lua->CheckType(2, GLua::Type::Function); -lua_protected_fn(connect) { - lua->CheckType(1, GLua::Type::String); - lua->CheckType(2, GLua::Type::Function); + auto url = lua->GetString(1); + GLua::AutoReference callback(lua, 2); - auto url = lua->GetString(1); - GLua::AutoReference callback(lua, 2); + async_postgres::connect(lua, url, std::move(callback)); - async_postgres::connect(lua, url, std::move(callback)); + return 0; + } - return 0; -} + lua_protected_fn(query) { + lua->CheckType(1, async_postgres::connection_meta); + lua->CheckType(2, GLua::Type::String); + lua->CheckType(3, GLua::Type::Function); -lua_protected_fn(query) { - lua->CheckType(1, async_postgres::connection_meta); - lua->CheckType(2, GLua::Type::String); - lua->CheckType(3, GLua::Type::Function); + auto state = lua_connection_state(); - auto state = lua_connection_state(); + async_postgres::SimpleCommand command = {lua->GetString(2)}; + async_postgres::Query query = {std::move(command)}; + if (!lua->IsType(3, GLua::Type::Nil)) { + query.callback = GLua::AutoReference(lua, 3); + } - async_postgres::SimpleCommand command = {lua->GetString(2)}; - async_postgres::Query query = {std::move(command)}; - if (!lua->IsType(3, GLua::Type::Nil)) { - query.callback = GLua::AutoReference(lua, 3); + state->queries.push(std::move(query)); + return 0; } - state->queries.push(std::move(query)); - return 0; -} + lua_protected_fn(queryParams) { + lua->CheckType(1, async_postgres::connection_meta); + lua->CheckType(2, GLua::Type::String); + lua->CheckType(3, GLua::Type::Table); + lua->CheckType(4, GLua::Type::Function); -// int query(lua_State* L) { -// auto lua = reinterpret_cast(L->luabase); -// lua->SetState(L); + auto state = lua_connection_state(); -// lua->CheckType(1, connection_meta); -// lua->CheckType(2, GLua::Type::String); -// lua->CheckType(3, GLua::Type::Function); + async_postgres::ParameterizedCommand command = {lua->GetString(2)}; -// auto conn = lua->GetUserType(1, connection_meta); -// auto query = lua->GetString(2); + lua->Push(3); + for (int i = 1;; i++) { + lua->PushNumber(i); + lua->GetTable(-2); + if (lua->IsType(-1, GLua::Type::Nil)) { + lua->Pop(2); + break; + } -// lua->Push(3); -// int callback = lua->ReferenceCreate(); + auto str = get_string(lua, -1); + command.values.push_back({str.data(), str.size()}); + lua->Pop(1); + } -// auto& queue = queries[conn]; -// queue.push({conn, callback, query}); + async_postgres::Query query = {std::move(command)}; + query.callback = GLua::AutoReference(lua, 4); + state->queries.push(std::move(query)); -// return 0; -// } + return 0; + } -inline void register_connection_mt(GLua::ILuaInterface* lua) { - async_postgres::connection_meta = lua->CreateMetaTable("PGconn"); + lua_protected_fn(reset) { + lua->CheckType(1, async_postgres::connection_meta); + + GLua::AutoReference callback; + if (!lua->IsType(2, GLua::Type::Nil)) { + callback = GLua::AutoReference(lua, 2); + } + + auto state = lua_connection_state(); + async_postgres::reset(lua, state, std::move(callback)); - lua->PushCFunction(index_connection); - lua->SetField(-2, "__index"); + return 0; + } +} // namespace async_postgres::lua + +#define register_lua_fn(name) \ + lua->PushCFunction(async_postgres::lua::name); \ + lua->SetField(-2, #name) - lua->PushCFunction(newindex_connection); - lua->SetField(-2, "__newindex"); +void register_connection_mt(GLua::ILuaInterface* lua) { + async_postgres::connection_meta = lua->CreateMetaTable("PGconn"); - lua->PushCFunction(gc_connection); - lua->SetField(-2, "__gc"); + register_lua_fn(__index); + register_lua_fn(__newindex); + register_lua_fn(__gc); + register_lua_fn(query); + register_lua_fn(reset); - lua->PushCFunction(query); - lua->SetField(-2, "query"); + async_postgres::register_misc_connection_functions(lua); lua->Pop(); } -inline void make_global_table(GLua::ILuaInterface* lua) { +void make_global_table(GLua::ILuaInterface* lua) { lua->CreateTable(); - lua->PushCFunction(connect); - lua->SetField(-2, "connect"); + register_lua_fn(connect); lua->SetField(GLua::INDEX_GLOBAL, "async_postgres"); } -inline void register_loop_hook(GLua::ILuaInterface* lua) { +void register_loop_hook(GLua::ILuaInterface* lua) { lua->GetField(GLua::INDEX_GLOBAL, "hook"); lua->GetField(-1, "Add"); lua->PushString("Think"); lua->PushString("async_postgres_loop"); - lua->PushCFunction(loop); + lua->PushCFunction(async_postgres::lua::loop); lua->Call(3, 0); lua->Pop(); } diff --git a/source/misc.cpp b/source/misc.cpp new file mode 100644 index 0000000..085f530 --- /dev/null +++ b/source/misc.cpp @@ -0,0 +1,201 @@ +#include "async_postgres.hpp" + +using namespace async_postgres; + +#define lua_connection_state() \ + lua->GetUserType( \ + 1, async_postgres::connection_meta) + +#define lua_connection() lua_connection_state()->conn.get() + +#define lua_string_getter(name, getter) \ + lua_protected_fn(name) { \ + lua->CheckType(1, connection_meta); \ + lua->PushString(getter(lua_connection())); \ + return 1; \ + } + +#define lua_number_getter(name, getter) \ + lua_protected_fn(name) { \ + lua->CheckType(1, connection_meta); \ + lua->PushNumber(getter(lua_connection())); \ + return 1; \ + } + +#define lua_bool_getter(name, getter) \ + lua_protected_fn(name) { \ + lua->CheckType(1, connection_meta); \ + lua->PushBool(getter(lua_connection())); \ + return 1; \ + } + +namespace async_postgres::lua { + // 34.2. Connection Status Functions + lua_string_getter(db, PQdb); + lua_string_getter(user, PQuser); + lua_string_getter(host, PQhost); + lua_string_getter(hostaddr, PQhostaddr); + lua_string_getter(port, PQport); + lua_number_getter(status, PQstatus); + lua_number_getter(transactionStatus, PQtransactionStatus); + + lua_protected_fn(parameterStatus) { + lua->CheckType(1, connection_meta); + lua->CheckType(2, GLua::Type::String); + + lua->PushString(PQparameterStatus(lua_connection(), lua->GetString(2))); + return 1; + } + + lua_number_getter(protocolVersion, PQprotocolVersion); + lua_number_getter(serverVersion, PQserverVersion); + lua_string_getter(errorMessage, PQerrorMessage); + lua_number_getter(backendPID, PQbackendPID); + lua_bool_getter(sslInUse, PQsslInUse); + + lua_protected_fn(sslAttribute) { + lua->CheckType(1, connection_meta); + lua->CheckType(2, GLua::Type::String); + + lua->PushString(PQsslAttribute(lua_connection(), lua->GetString(2))); + return 1; + } + + lua_protected_fn(clientEncoding) { + lua->CheckType(1, connection_meta); + lua->PushString( + pg_encoding_to_char(PQclientEncoding(lua_connection()))); + return 1; + } + + lua_protected_fn(setClientEncoding) { + lua->CheckType(1, connection_meta); + lua->CheckType(2, GLua::Type::String); + + lua->PushBool( + PQsetClientEncoding(lua_connection(), lua->GetString(2)) == 0); + return 1; + } + + lua_protected_fn(encryptPassword) { + lua->CheckType(1, connection_meta); + lua->CheckType(2, GLua::Type::String); + lua->CheckType(3, GLua::Type::String); + lua->CheckType(4, GLua::Type::String); + + auto result = + PQencryptPasswordConn(lua_connection(), lua->GetString(2), + lua->GetString(3), lua->GetString(4)); + + lua->PushString(result); + PQfreemem(result); + return 1; + } + + lua_protected_fn(escape) { + lua->CheckType(1, connection_meta); + lua->CheckType(2, GLua::Type::String); + + auto str = get_string(lua, 2); + auto escaped = + PQescapeLiteral(lua_connection(), str.data(), str.size()); + + if (!escaped) { + throw std::runtime_error(PQerrorMessage(lua_connection())); + } + + lua->PushString(escaped); + PQfreemem(escaped); + return 1; + } + + lua_protected_fn(escapeIdentifier) { + lua->CheckType(1, connection_meta); + lua->CheckType(2, GLua::Type::String); + + auto str = get_string(lua, 2); + auto escaped = + PQescapeIdentifier(lua_connection(), str.data(), str.size()); + + if (!escaped) { + throw std::runtime_error(PQerrorMessage(lua_connection())); + } + + lua->PushString(escaped); + PQfreemem(escaped); + return 1; + } + + lua_protected_fn(escapeBytea) { + lua->CheckType(1, connection_meta); + lua->CheckType(2, GLua::Type::String); + + unsigned int strLen = 0; + size_t outLen = 0; + const unsigned char* str = + reinterpret_cast(lua->GetString(2, &strLen)); + char* escaped = + reinterpret_cast(PQescapeBytea(str, strLen, &outLen)); + + if (!escaped) { + throw std::runtime_error(PQerrorMessage(lua_connection())); + } + + lua->PushString(escaped, outLen - 1); + PQfreemem(escaped); + return 1; + } + + lua_protected_fn(unescapeBytea) { + lua->CheckType(1, connection_meta); + lua->CheckType(2, GLua::Type::String); + + unsigned int strLen = 0; + size_t outLen = 0; + const unsigned char* str = + reinterpret_cast(lua->GetString(2, &strLen)); + char* unescaped = + reinterpret_cast(PQunescapeBytea(str, &outLen)); + + if (!unescaped) { + throw std::runtime_error(PQerrorMessage(lua_connection())); + } + + lua->PushString(unescaped, outLen); + PQfreemem(unescaped); + return 1; + } + + // 34.3. Asynchronous Command Processing + lua_bool_getter(isBusy, PQisBusy); +} // namespace async_postgres::lua + +#define register_lua_fn(name) \ + lua->PushCFunction(async_postgres::lua::name); \ + lua->SetField(-2, #name) + +void async_postgres::register_misc_connection_functions( + GLua::ILuaInterface* lua) { + register_lua_fn(db); + register_lua_fn(user); + register_lua_fn(host); + register_lua_fn(hostaddr); + register_lua_fn(port); + register_lua_fn(status); + register_lua_fn(transactionStatus); + register_lua_fn(parameterStatus); + register_lua_fn(protocolVersion); + register_lua_fn(serverVersion); + register_lua_fn(errorMessage); + register_lua_fn(backendPID); + register_lua_fn(sslInUse); + register_lua_fn(sslAttribute); + register_lua_fn(clientEncoding); + register_lua_fn(setClientEncoding); + register_lua_fn(encryptPassword); + register_lua_fn(escape); + register_lua_fn(escapeIdentifier); + register_lua_fn(escapeBytea); + register_lua_fn(unescapeBytea); + register_lua_fn(isBusy); +} diff --git a/source/notifications.cpp b/source/notifications.cpp index 2cc6a0c..3b89f6e 100644 --- a/source/notifications.cpp +++ b/source/notifications.cpp @@ -16,29 +16,33 @@ bool push_on_notify(GLua::ILuaInterface* lua, Connection* state) { void async_postgres::process_notifications(GLua::ILuaInterface* lua, Connection* state) { - if (!push_on_notify(lua, state)) { + if (state->reset_event) { + // don't process notifications while reconnecting return; } - if (state->queries.empty() && PQconsumeInput(state->conn.get()) == 0) { + if (!state->receive_notifications) { + return; + } + + if (state->queries.empty() && + check_socket_status(state->conn.get()).read_ready && + PQconsumeInput(state->conn.get()) == 0) { // we consumed input // but there was some error - // TODO: update connection state return; } PGnotify* notify; while ((notify = PQnotifies(state->conn.get()))) { - lua->Push(-1); // copy on_notify + if (push_on_notify(lua, state)) { + lua->PushString(notify->relname); // arg 1 channel name + lua->PushString(notify->extra); // arg 2 payload + lua->PushNumber(notify->be_pid); // arg 3 backend pid - lua->PushString(notify->relname); // arg 1 channel name - lua->PushString(notify->extra); // arg 2 payload - lua->PushNumber(notify->be_pid); // arg 3 backend pid - - pcall(lua, 3, 0); + pcall(lua, 3, 0); + } PQfreemem(notify); } - - lua->Pop(); // pop on_notify } diff --git a/source/query.cpp b/source/query.cpp index 6339004..44664e5 100644 --- a/source/query.cpp +++ b/source/query.cpp @@ -7,6 +7,19 @@ using namespace async_postgres; inline bool send_query(PGconn* conn, Query& query) { if (const auto* command = std::get_if(&query.command)) { return PQsendQuery(conn, command->command.c_str()) == 1; + } else if (const auto* command = + std::get_if(&query.command)) { + size_t nParams = command->values.size(); + std::vector paramValues(nParams); + for (size_t i = 0; i < nParams; i++) { + paramValues[i] = command->values[i].c_str(); + } + + bool success = + PQsendQueryParams(conn, command->command.c_str(), nParams, nullptr, + paramValues.data(), nullptr, nullptr, 0) == 1; + + return success; } return false; } @@ -47,6 +60,8 @@ bool bad_result(PGresult* result) { status == PGRES_FATAL_ERROR; } +static bool reseted = false; + void async_postgres::process_queries(GLua::ILuaInterface* lua, Connection* state) { if (state->queries.empty()) { @@ -54,6 +69,11 @@ void async_postgres::process_queries(GLua::ILuaInterface* lua, return; } + if (state->reset_event) { + // don't process queries while reconnecting + return; + } + auto& query = state->queries.front(); if (!query.sent) { if (!send_query(state->conn.get(), query)) { @@ -66,6 +86,12 @@ void async_postgres::process_queries(GLua::ILuaInterface* lua, query.flushed = PQflush(state->conn.get()) == 0; } + if (query.flushed && !reseted) { + reset(lua, state, {}); + reseted = true; + return; + } + if (!poll_query(state->conn.get(), query)) { query_failed(lua, state->conn.get(), query); state->queries.pop();