From e8d8a13aa82c0ec929a103dcc81f250ab0dda02b Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Fri, 26 Nov 2021 13:59:04 -0300 Subject: [PATCH] [C++] Address ratification comments (round 4 - part 3) (#214) * Make other methods from SQLite server example to return arrow::Result instead of Status * Fix bug for null values in numeric columns on SQLite server example * Add comment regarding to performance on sqlite_statement_batch_reader * Separate GetSqlInfoResultMap from sqlite_server.h * Remove unused parameter on DoPutCommandStatementUpdate --- cpp/src/arrow/flight/sql/CMakeLists.txt | 2 + .../arrow/flight/sql/example/sqlite_server.cc | 33 +-- .../arrow/flight/sql/example/sqlite_server.h | 198 +-------------- .../flight/sql/example/sqlite_sql_info.cc | 225 ++++++++++++++++++ .../flight/sql/example/sqlite_sql_info.h | 34 +++ .../flight/sql/example/sqlite_statement.cc | 31 +-- .../flight/sql/example/sqlite_statement.h | 20 +- .../example/sqlite_statement_batch_reader.cc | 58 +++-- cpp/src/arrow/flight/sql/server.cc | 5 +- cpp/src/arrow/flight/sql/server.h | 4 +- cpp/src/arrow/flight/sql/server_test.cc | 33 +-- 11 files changed, 349 insertions(+), 294 deletions(-) create mode 100644 cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc create mode 100644 cpp/src/arrow/flight/sql/example/sqlite_sql_info.h diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index edff4d8595423..0c80d19332999 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -70,6 +70,7 @@ add_arrow_test(flight_sql_test SOURCES client_test.cc server_test.cc + example/sqlite_sql_info.cc STATIC_LINK_LIBS ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} LABELS @@ -81,6 +82,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) add_executable(flight_sql_test_server test_server_cli.cc + example/sqlite_sql_info.cc example/sqlite_statement.cc example/sqlite_statement_batch_reader.cc example/sqlite_server.cc diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc index cac80163d7132..de2bd5b7e02d9 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc @@ -24,6 +24,7 @@ #include #include "arrow/api.h" +#include "arrow/flight/sql/example/sqlite_sql_info.h" #include "arrow/flight/sql/example/sqlite_statement.h" #include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" #include "arrow/flight/sql/example/sqlite_tables_schema_batch_reader.h" @@ -204,6 +205,7 @@ arrow::Result> SQLiteFlightSqlServer::Cre INSERT INTO intTable (keyName, value, foreignId) VALUES ('one', 1, 1); INSERT INTO intTable (keyName, value, foreignId) VALUES ('zero', 0, 1); INSERT INTO intTable (keyName, value, foreignId) VALUES ('negative one', -1, 1); + INSERT INTO intTable (keyName, value, foreignId) VALUES (NULL, NULL, NULL); )")); return result; @@ -251,14 +253,11 @@ arrow::Result> SQLiteFlightSqlServer::GetFlightInfoS const FlightDescriptor& descriptor) { const std::string& query = command.query; - std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, query)); + ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db_, query)); - std::shared_ptr schema; - ARROW_RETURN_NOT_OK(statement->GetSchema(&schema)); + ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); - std::string ticket_string; - ARROW_ASSIGN_OR_RAISE(ticket_string, CreateStatementQueryTicket(query)); + ARROW_ASSIGN_OR_RAISE(auto ticket_string, CreateStatementQueryTicket(query)); std::vector endpoints{FlightEndpoint{{ticket_string}, {}}}; ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) @@ -363,17 +362,12 @@ arrow::Result> SQLiteFlightSqlServer::DoGetTab } arrow::Result SQLiteFlightSqlServer::DoPutCommandStatementUpdate( - const ServerCallContext& context, const StatementUpdate& command, - std::unique_ptr& reader) { + const ServerCallContext& context, const StatementUpdate& command) { const std::string& sql = command.query; - std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, sql)); - - int64_t record_count; - ARROW_RETURN_NOT_OK(statement->ExecuteUpdate(&record_count)); + ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db_, sql)); - return record_count; + return statement->ExecuteUpdate(); } arrow::Result @@ -385,8 +379,7 @@ SQLiteFlightSqlServer::CreatePreparedStatement( boost::uuids::uuid uuid = uuid_generator_(); prepared_statements_[uuid] = statement; - std::shared_ptr dataset_schema; - ARROW_RETURN_NOT_OK(statement->GetSchema(&dataset_schema)); + ARROW_ASSIGN_OR_RAISE(auto dataset_schema, statement->GetSchema()); sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); const int parameter_count = sqlite3_bind_parameter_count(stmt); @@ -449,8 +442,7 @@ SQLiteFlightSqlServer::GetFlightInfoPreparedStatement( std::shared_ptr statement = search->second; - std::shared_ptr schema; - ARROW_RETURN_NOT_OK(statement->GetSchema(&schema)); + ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); return GetFlightInfoForCommand(descriptor, schema); } @@ -497,10 +489,7 @@ arrow::Result SQLiteFlightSqlServer::DoPutPreparedStatementUpdate( sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); - int64_t record_count; - ARROW_RETURN_NOT_OK(statement->ExecuteUpdate(&record_count)); - - return record_count; + return statement->ExecuteUpdate(); } arrow::Result> SQLiteFlightSqlServer::GetFlightInfoTableTypes( diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.h b/cpp/src/arrow/flight/sql/example/sqlite_server.h index 58f88d9d5d318..f505c291e394c 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.h +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.h @@ -26,7 +26,6 @@ #include #include "arrow/api.h" -#include "arrow/flight/sql/FlightSql.pb.h" #include "arrow/flight/sql/example/sqlite_statement.h" #include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" #include "arrow/flight/sql/server.h" @@ -36,200 +35,6 @@ namespace flight { namespace sql { namespace example { -namespace flight_sql_pb = arrow::flight::protocol::sql; - -/// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. -/// \return the cache. -inline SqlInfoResultMap GetSqlInfoResultMap() { - return { - {flight_sql_pb::SqlInfo::FLIGHT_SQL_SERVER_NAME, - SqlInfoResult(std::string("db_name"))}, - {flight_sql_pb::SqlInfo::FLIGHT_SQL_SERVER_VERSION, - SqlInfoResult(std::string("sqlite 3"))}, - {flight_sql_pb::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION, - SqlInfoResult(std::string("7.0.0-SNAPSHOT" /* Only an example */))}, - {flight_sql_pb::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, SqlInfoResult(false)}, - {flight_sql_pb::SqlInfo::SQL_DDL_CATALOG, - SqlInfoResult(false /* SQLite 3 does not support catalogs */)}, - {flight_sql_pb::SqlInfo::SQL_DDL_SCHEMA, - SqlInfoResult(false /* SQLite 3 does not support schemas */)}, - {flight_sql_pb::SqlInfo::SQL_DDL_TABLE, SqlInfoResult(true)}, - {flight_sql_pb::SqlInfo::SQL_IDENTIFIER_CASE, - SqlInfoResult(int64_t(flight_sql_pb::SqlSupportedCaseSensitivity:: - SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, - {flight_sql_pb::SqlInfo::SQL_IDENTIFIER_QUOTE_CHAR, - SqlInfoResult(std::string("\""))}, - {flight_sql_pb::SqlInfo::SQL_QUOTED_IDENTIFIER_CASE, - SqlInfoResult(int64_t(flight_sql_pb::SqlSupportedCaseSensitivity:: - SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, - {flight_sql_pb::SqlInfo::SQL_ALL_TABLES_ARE_SELECTABLE, SqlInfoResult(true)}, - {flight_sql_pb::SqlInfo::SQL_NULL_ORDERING, - SqlInfoResult(int64_t(flight_sql_pb::SqlNullOrdering::SQL_NULLS_SORTED_AT_START))}, - {flight_sql_pb::SqlInfo::SQL_KEYWORDS, - SqlInfoResult(std::vector({"ABORT", - "ACTION", - "ADD", - "AFTER", - "ALL", - "ALTER", - "ALWAYS", - "ANALYZE", - "AND", - "AS", - "ASC", - "ATTACH", - "AUTOINCREMENT", - "BEFORE", - "BEGIN", - "BETWEEN", - "BY", - "CASCADE", - "CASE", - "CAST", - "CHECK", - "COLLATE", - "COLUMN", - "COMMIT", - "CONFLICT", - "CONSTRAINT", - "CREATE", - "CROSS", - "CURRENT", - "CURRENT_DATE", - "CURRENT_TIME", - "CURRENT_TIMESTAMP", - "DATABASE", - "DEFAULT", - "DEFERRABLE", - "DEFERRED", - "DELETE", - "DESC", - "DETACH", - "DISTINCT", - "DO", - "DROP", - "EACH", - "ELSE", - "END", - "ESCAPE", - "EXCEPT", - "EXCLUDE", - "EXCLUSIVE", - "EXISTS", - "EXPLAIN", - "FAIL", - "FILTER", - "FIRST", - "FOLLOWING", - "FOR", - "FOREIGN", - "FROM", - "FULL", - "GENERATED", - "GLOB", - "GROUP", - "GROUPS", - "HAVING", - "IF", - "IGNORE", - "IMMEDIATE", - "IN", - "INDEX", - "INDEXED", - "INITIALLY", - "INNER", - "INSERT", - "INSTEAD", - "INTERSECT", - "INTO", - "IS", - "ISNULL", - "JOIN", - "KEY", - "LAST", - "LEFT", - "LIKE", - "LIMIT", - "MATCH", - "MATERIALIZED", - "NATURAL", - "NO", - "NOT", - "NOTHING", - "NOTNULL", - "NULL", - "NULLS", - "OF", - "OFFSET", - "ON", - "OR", - "ORDER", - "OTHERS", - "OUTER", - "OVER", - "PARTITION", - "PLAN", - "PRAGMA", - "PRECEDING", - "PRIMARY", - "QUERY", - "RAISE", - "RANGE", - "RECURSIVE", - "REFERENCES", - "REGEXP", - "REINDEX", - "RELEASE", - "RENAME", - "REPLACE", - "RESTRICT", - "RETURNING", - "RIGHT", - "ROLLBACK", - "ROW", - "ROWS", - "SAVEPOINT", - "SELECT", - "SET", - "TABLE", - "TEMP", - "TEMPORARY", - "THEN", - "TIES", - "TO", - "TRANSACTION", - "TRIGGER", - "UNBOUNDED", - "UNION", - "UNIQUE", - "UPDATE", - "USING", - "VACUUM", - "VALUES", - "VIEW", - "VIRTUAL", - "WHEN", - "WHERE", - "WINDOW", - "WITH", - "WITHOUT"}))}, - {flight_sql_pb::SqlInfo::SQL_NUMERIC_FUNCTIONS, - SqlInfoResult(std::vector( - {"acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", "ceil", - "ceiling", "cos", "cosh", "degrees", "exp", "floor", "ln", "log", - "log", "log10", "log2", "mod", "pi", "pow", "power", "radians", - "sin", "sinh", "sqrt", "tan", "tanh", "trunc"}))}, - {flight_sql_pb::SqlInfo::SQL_STRING_FUNCTIONS, - SqlInfoResult( - std::vector({"SUBSTR", "TRIM", "LTRIM", "RTRIM", "LENGTH", - "REPLACE", "UPPER", "LOWER", "INSTR"}))}, - {flight_sql_pb::SqlInfo::SQL_SUPPORTS_CONVERT, - SqlInfoResult(std::unordered_map>( - {{flight_sql_pb::SqlSupportsConvert::SQL_CONVERT_BIGINT, - std::vector( - {flight_sql_pb::SqlSupportsConvert::SQL_CONVERT_INTEGER})}}))}}; -} - /// \brief Convert a column type to a ArrowType. /// \param sqlite_type the sqlite type. /// \return The equivalent ArrowType. @@ -274,8 +79,7 @@ class SQLiteFlightSqlServer : public FlightSqlServerBase { arrow::Result> DoGetSchemas( const ServerCallContext& context, const GetSchemas& command) override; arrow::Result DoPutCommandStatementUpdate( - const ServerCallContext& context, const StatementUpdate& update, - std::unique_ptr& reader) override; + const ServerCallContext& context, const StatementUpdate& update) override; arrow::Result CreatePreparedStatement( const ServerCallContext& context, const ActionCreatePreparedStatementRequest& request) override; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc new file mode 100644 index 0000000000000..92007fa9eb377 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/sql/example/sqlite_sql_info.h" + +#include "arrow/flight/sql/FlightSql.pb.h" +#include "arrow/flight/sql/sql_info_types.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +namespace flight_sql_pb = arrow::flight::protocol::sql; + +/// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. +/// \return the cache. +SqlInfoResultMap GetSqlInfoResultMap() { + return { + {flight_sql_pb::SqlInfo::FLIGHT_SQL_SERVER_NAME, + SqlInfoResult(std::string("db_name"))}, + {flight_sql_pb::SqlInfo::FLIGHT_SQL_SERVER_VERSION, + SqlInfoResult(std::string("sqlite 3"))}, + {flight_sql_pb::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION, + SqlInfoResult(std::string("7.0.0-SNAPSHOT" /* Only an example */))}, + {flight_sql_pb::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, SqlInfoResult(false)}, + {flight_sql_pb::SqlInfo::SQL_DDL_CATALOG, + SqlInfoResult(false /* SQLite 3 does not support catalogs */)}, + {flight_sql_pb::SqlInfo::SQL_DDL_SCHEMA, + SqlInfoResult(false /* SQLite 3 does not support schemas */)}, + {flight_sql_pb::SqlInfo::SQL_DDL_TABLE, SqlInfoResult(true)}, + {flight_sql_pb::SqlInfo::SQL_IDENTIFIER_CASE, + SqlInfoResult(int64_t(flight_sql_pb::SqlSupportedCaseSensitivity:: + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, + {flight_sql_pb::SqlInfo::SQL_IDENTIFIER_QUOTE_CHAR, + SqlInfoResult(std::string("\""))}, + {flight_sql_pb::SqlInfo::SQL_QUOTED_IDENTIFIER_CASE, + SqlInfoResult(int64_t(flight_sql_pb::SqlSupportedCaseSensitivity:: + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, + {flight_sql_pb::SqlInfo::SQL_ALL_TABLES_ARE_SELECTABLE, SqlInfoResult(true)}, + {flight_sql_pb::SqlInfo::SQL_NULL_ORDERING, + SqlInfoResult(int64_t(flight_sql_pb::SqlNullOrdering::SQL_NULLS_SORTED_AT_START))}, + {flight_sql_pb::SqlInfo::SQL_KEYWORDS, + SqlInfoResult(std::vector({"ABORT", + "ACTION", + "ADD", + "AFTER", + "ALL", + "ALTER", + "ALWAYS", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DESC", + "DETACH", + "DISTINCT", + "DO", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FILTER", + "FIRST", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GENERATED", + "GLOB", + "GROUP", + "GROUPS", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LAST", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "MATERIALIZED", + "NATURAL", + "NO", + "NOT", + "NOTHING", + "NOTNULL", + "NULL", + "NULLS", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OTHERS", + "OUTER", + "OVER", + "PARTITION", + "PLAN", + "PRAGMA", + "PRECEDING", + "PRIMARY", + "QUERY", + "RAISE", + "RANGE", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RETURNING", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "SAVEPOINT", + "SELECT", + "SET", + "TABLE", + "TEMP", + "TEMPORARY", + "THEN", + "TIES", + "TO", + "TRANSACTION", + "TRIGGER", + "UNBOUNDED", + "UNION", + "UNIQUE", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHOUT"}))}, + {flight_sql_pb::SqlInfo::SQL_NUMERIC_FUNCTIONS, + SqlInfoResult(std::vector( + {"acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", "ceil", + "ceiling", "cos", "cosh", "degrees", "exp", "floor", "ln", "log", + "log", "log10", "log2", "mod", "pi", "pow", "power", "radians", + "sin", "sinh", "sqrt", "tan", "tanh", "trunc"}))}, + {flight_sql_pb::SqlInfo::SQL_STRING_FUNCTIONS, + SqlInfoResult( + std::vector({"SUBSTR", "TRIM", "LTRIM", "RTRIM", "LENGTH", + "REPLACE", "UPPER", "LOWER", "INSTR"}))}, + {flight_sql_pb::SqlInfo::SQL_SUPPORTS_CONVERT, + SqlInfoResult(std::unordered_map>( + {{flight_sql_pb::SqlSupportsConvert::SQL_CONVERT_BIGINT, + std::vector( + {flight_sql_pb::SqlSupportsConvert::SQL_CONVERT_INTEGER})}}))}}; +} + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_sql_info.h b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.h new file mode 100644 index 0000000000000..f6057540774fb --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.h @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/flight/sql/sql_info_types.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +/// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. +/// \return the cache. +SqlInfoResultMap GetSqlInfoResultMap(); + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement.cc b/cpp/src/arrow/flight/sql/example/sqlite_statement.cc index 3d7c5efeb23eb..018f8de37dbdd 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_statement.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement.cc @@ -66,7 +66,7 @@ arrow::Result> SqliteStatement::Create( return result; } -Status SqliteStatement::GetSchema(std::shared_ptr* schema) const { +arrow::Result> SqliteStatement::GetSchema() const { std::vector> fields; int column_count = sqlite3_column_count(stmt_); for (int i = 0; i < column_count; i++) { @@ -99,41 +99,36 @@ Status SqliteStatement::GetSchema(std::shared_ptr* schema) const { fields.push_back(arrow::field(column_name, data_type)); } - *schema = arrow::schema(fields); - return Status::OK(); + return arrow::schema(fields); } SqliteStatement::~SqliteStatement() { sqlite3_finalize(stmt_); } -Status SqliteStatement::Step(int* rc) { - *rc = sqlite3_step(stmt_); - if (*rc == SQLITE_ERROR) { +arrow::Result SqliteStatement::Step() { + int rc = sqlite3_step(stmt_); + if (rc == SQLITE_ERROR) { return Status::ExecutionError("A SQLite runtime error has occurred: ", sqlite3_errmsg(db_)); } - return Status::OK(); + return rc; } -Status SqliteStatement::Reset(int* rc) { - *rc = sqlite3_reset(stmt_); - if (*rc == SQLITE_ERROR) { +arrow::Result SqliteStatement::Reset() { + int rc = sqlite3_reset(stmt_); + if (rc == SQLITE_ERROR) { return Status::ExecutionError("A SQLite runtime error has occurred: ", sqlite3_errmsg(db_)); } - return Status::OK(); + return rc; } sqlite3_stmt* SqliteStatement::GetSqlite3Stmt() const { return stmt_; } -Status SqliteStatement::ExecuteUpdate(int64_t* result) { - int rc; - ARROW_RETURN_NOT_OK(Step(&rc)); - - *result = sqlite3_changes(db_); - - return Status::OK(); +arrow::Result SqliteStatement::ExecuteUpdate() { + ARROW_RETURN_NOT_OK(Step()); + return sqlite3_changes(db_); } } // namespace example diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement.h b/cpp/src/arrow/flight/sql/example/sqlite_statement.h index 0d32a2f282ddf..a3f086abc4703 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_statement.h +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement.h @@ -41,28 +41,24 @@ class SqliteStatement { ~SqliteStatement(); /// \brief Creates an Arrow Schema based on the results of this statement. - /// \param[out] schema The resulting Schema. - /// \return Status. - Status GetSchema(std::shared_ptr* schema) const; + /// \return The resulting Schema. + arrow::Result> GetSchema() const; /// \brief Steps on underlying sqlite3_stmt. - /// \param[out] rc The resulting return code from SQLite. - /// \return Status. - Status Step(int* rc); + /// \return The resulting return code from SQLite. + arrow::Result Step(); /// \brief Reset the state of the sqlite3_stmt. - /// \param[out] rc The resulting return code from SQLite. - /// \return Status. - Status Reset(int* rc); + /// \return The resulting return code from SQLite. + arrow::Result Reset(); /// \brief Returns the underlying sqlite3_stmt. /// \return A sqlite statement. sqlite3_stmt* GetSqlite3Stmt() const; /// \brief Executes an UPDATE, INSERT or DELETE statement. - /// \param[out] result The number of rows changed by execution. - /// \return Status. - Status ExecuteUpdate(int64_t* result); + /// \return The number of rows changed by execution. + arrow::Result ExecuteUpdate(); private: sqlite3* db_; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc index c6ac34bfddba0..a5824ae255f8d 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc @@ -50,29 +50,40 @@ break; \ } -#define INT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ - case TYPE_CLASS##Type::type_id: { \ - sqlite3_int64 value = sqlite3_column_int64(STMT, COLUMN); \ - ARROW_RETURN_NOT_OK( \ - (reinterpret_cast(builder)).Append(value)); \ - break; \ +#define INT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case TYPE_CLASS##Type::type_id: { \ + if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).AppendNull()); \ + break; \ + } \ + sqlite3_int64 value = sqlite3_column_int64(STMT, COLUMN); \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).Append(value)); \ + break; \ } -#define FLOAT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ - case TYPE_CLASS##Type::type_id: { \ - double value = sqlite3_column_double(STMT, COLUMN); \ - ARROW_RETURN_NOT_OK( \ - (reinterpret_cast(builder)).Append(value)); \ - break; \ +#define FLOAT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case TYPE_CLASS##Type::type_id: { \ + if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).AppendNull()); \ + break; \ + } \ + double value = sqlite3_column_double(STMT, COLUMN); \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).Append(value)); \ + break; \ } -#define MAX_BATCH_SIZE 1024 - namespace arrow { namespace flight { namespace sql { namespace example { +// Batch size for SQLite statement results +static constexpr int kMaxBatchSize = 1024; + std::shared_ptr SqliteStatementBatchReader::schema() const { return schema_; } SqliteStatementBatchReader::SqliteStatementBatchReader( @@ -84,11 +95,9 @@ SqliteStatementBatchReader::SqliteStatementBatchReader( Result> SqliteStatementBatchReader::Create( const std::shared_ptr& statement_) { - int rc; - ARROW_RETURN_NOT_OK(statement_->Step(&rc)); + ARROW_RETURN_NOT_OK(statement_->Step()); - std::shared_ptr schema; - ARROW_RETURN_NOT_OK(statement_->GetSchema(&schema)); + ARROW_ASSIGN_OR_RAISE(auto schema, statement_->GetSchema()); std::shared_ptr result( new SqliteStatementBatchReader(statement_, schema)); @@ -119,19 +128,22 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { } if (!already_executed_) { - ARROW_RETURN_NOT_OK(statement_->Reset(&rc_)); - ARROW_RETURN_NOT_OK(statement_->Step(&rc_)); + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Reset()); + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step()); already_executed_ = true; } - int rows = 0; - while (rows < MAX_BATCH_SIZE && rc_ == SQLITE_ROW) { + int64_t rows = 0; + while (rows < kMaxBatchSize && rc_ == SQLITE_ROW) { rows++; for (int i = 0; i < num_fields; i++) { const std::shared_ptr& field = schema_->field(i); const std::shared_ptr& field_type = field->type(); ArrayBuilder& builder = *builders[i]; + // NOTE: This is not the optimal way of building Arrow vectors. + // That would be to presize the builders to avoiding several resizing operations + // when appending values and also to build one vector at a time. switch (field_type->id()) { INT_BUILDER_CASE(Int64, stmt_, i) INT_BUILDER_CASE(UInt64, stmt_, i) @@ -154,7 +166,7 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { } } - ARROW_RETURN_NOT_OK(statement_->Step(&rc_)); + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step()); } if (rows > 0) { diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 219ffd8e01b70..6a34bdeb96aef 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -402,7 +402,7 @@ Status FlightSqlServerBase::DoPut(const ServerCallContext& context, ARROW_ASSIGN_OR_RAISE(StatementUpdate internal_command, ParseCommandStatementUpdate(any)); ARROW_ASSIGN_OR_RAISE(auto record_count, - DoPutCommandStatementUpdate(context, internal_command, reader)) + DoPutCommandStatementUpdate(context, internal_command)) pb::sql::DoPutUpdateResult result; result.set_record_count(record_count); @@ -687,8 +687,7 @@ arrow::Result FlightSqlServerBase::DoPutPreparedStatementUpdate( } arrow::Result FlightSqlServerBase::DoPutCommandStatementUpdate( - const ServerCallContext& context, const StatementUpdate& command, - std::unique_ptr& reader) { + const ServerCallContext& context, const StatementUpdate& command) { return Status::NotImplemented("DoPutCommandStatementUpdate not implemented"); } diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h index 537cee7085b24..f63606e557a1c 100644 --- a/cpp/src/arrow/flight/sql/server.h +++ b/cpp/src/arrow/flight/sql/server.h @@ -352,11 +352,9 @@ class ARROW_EXPORT FlightSqlServerBase : public FlightServerBase { /// \brief Execute an update SQL statement. /// \param[in] context The call context. /// \param[in] command The StatementUpdate object containing the SQL statement. - /// \param[in] reader a sequence of uploaded record batches. /// \return The changed record count. virtual arrow::Result DoPutCommandStatementUpdate( - const ServerCallContext& context, const StatementUpdate& command, - std::unique_ptr& reader); + const ServerCallContext& context, const StatementUpdate& command); /// \brief Create a prepared statement from given SQL statement. /// \param[in] context The call context. diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index 5cd23444fae21..30466eb32f139 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -23,6 +23,7 @@ #include "arrow/flight/api.h" #include "arrow/flight/sql/api.h" #include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/flight/sql/example/sqlite_sql_info.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/testing/gtest_util.h" @@ -30,8 +31,6 @@ using ::testing::_; using ::testing::Ref; -namespace pb = arrow::flight::protocol; - using arrow::internal::checked_cast; namespace arrow { @@ -173,10 +172,11 @@ TEST(TestFlightSqlServer, TestCommandStatementQuery) { arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), arrow::field("value", int64()), arrow::field("foreignId", int64())}); - const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3])"); - const auto keyname_array = ArrayFromJSON(utf8(), R"(["one", "zero", "negative one"])"); - const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1])"); - const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1])"); + const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3, 4])"); + const auto keyname_array = + ArrayFromJSON(utf8(), R"(["one", "zero", "negative one", null])"); + const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1, null])"); + const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1, null])"); const std::shared_ptr& expected_table = Table::Make( expected_schema, {id_array, keyname_array, value_array, foreignId_array}); @@ -428,10 +428,11 @@ TEST(TestFlightSqlServer, TestCommandPreparedStatementQuery) { arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), arrow::field("value", int64()), arrow::field("foreignId", int64())}); - const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3])"); - const auto keyname_array = ArrayFromJSON(utf8(), R"(["one", "zero", "negative one"])"); - const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1])"); - const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1])"); + const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3, 4])"); + const auto keyname_array = + ArrayFromJSON(utf8(), R"(["one", "zero", "negative one", null])"); + const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1, null])"); + const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1, null])"); const std::shared_ptr
& expected_table = Table::Make( expected_schema, {id_array, keyname_array, value_array, foreignId_array}); @@ -538,16 +539,16 @@ TEST(TestFlightSqlServer, TestCommandPreparedStatementUpdateWithParameterBinding ASSERT_OK(prepared_statement->SetParameters(record_batch)); - ASSERT_OK_AND_EQ(3, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); ASSERT_OK_AND_EQ(1, prepared_statement->ExecuteUpdate()); - ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + ASSERT_OK_AND_EQ(5, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); ASSERT_OK_AND_EQ(1, sql_client->ExecuteUpdate( {}, "DELETE FROM intTable WHERE keyName = 'new_value'")); - ASSERT_OK_AND_EQ(3, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); } TEST(TestFlightSqlServer, TestCommandPreparedStatementUpdate) { @@ -556,16 +557,16 @@ TEST(TestFlightSqlServer, TestCommandPreparedStatementUpdate) { sql_client->Prepare( {}, "INSERT INTO INTTABLE (keyName, value) VALUES ('new_value', 999)")); - ASSERT_OK_AND_EQ(3, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); ASSERT_OK_AND_EQ(1, prepared_statement->ExecuteUpdate()); - ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + ASSERT_OK_AND_EQ(5, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); ASSERT_OK_AND_EQ(1, sql_client->ExecuteUpdate( {}, "DELETE FROM intTable WHERE keyName = 'new_value'")); - ASSERT_OK_AND_EQ(3, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); } TEST(TestFlightSqlServer, TestCommandGetPrimaryKeys) {