From eb5e4796176c39d029d71ebaff228727ed8c1708 Mon Sep 17 00:00:00 2001 From: Mark Raasveldt Date: Tue, 3 Sep 2024 19:56:09 +0200 Subject: [PATCH 1/2] Fix #228: explicitly handle NULL bytes in strings by throwing a helpful error message, and add a setting pg_null_byte_replacement which can be used to replace them --- src/include/postgres_binary_writer.hpp | 29 +++++++++++- src/include/postgres_connection.hpp | 4 -- src/include/postgres_text_writer.hpp | 14 ++++++ src/include/postgres_utils.hpp | 8 ++++ src/postgres_binary_copy.cpp | 17 ++++--- src/postgres_copy_to.cpp | 29 ++++++++++-- src/postgres_extension.cpp | 14 ++++++ test/sql/storage/attach_null_byte.test | 65 ++++++++++++++++++++++++++ 8 files changed, 163 insertions(+), 17 deletions(-) create mode 100644 test/sql/storage/attach_null_byte.test diff --git a/src/include/postgres_binary_writer.hpp b/src/include/postgres_binary_writer.hpp index 5e3a1f86..4482a0d1 100644 --- a/src/include/postgres_binary_writer.hpp +++ b/src/include/postgres_binary_writer.hpp @@ -17,6 +17,9 @@ namespace duckdb { class PostgresBinaryWriter { public: + explicit PostgresBinaryWriter(PostgresCopyState &state) : state(state) { + } + template T GetInteger(T val) { if (sizeof(T) == sizeof(uint8_t)) { @@ -198,8 +201,29 @@ class PostgresBinaryWriter { } void WriteVarchar(string_t value) { - WriteRawInteger(value.GetSize()); - stream.WriteData(const_data_ptr_cast(value.GetData()), value.GetSize()); + auto str_size = value.GetSize(); + auto str_data = value.GetData(); + if (memchr(str_data, '\0', str_size) != nullptr) { + if (!state.has_null_byte_replacement) { + throw InvalidInputException("Attempting to write a VARCHAR value with a NULL-byte. Postgres does not " + "support NULL-bytes in VARCHAR values.\n* SET pg_null_byte_replacement='' " + "to remove NULL bytes or replace them with another character"); + } + // we have a NULL byte replacement - construct a new string that has all null bytes replaced and write it + // out + string new_str; + for (idx_t i = 0; i < str_size; i++) { + if (str_data[i] == '\0') { + new_str += state.null_byte_replacement; + } else { + new_str += str_data[i]; + } + } + WriteVarchar(new_str); + return; + } + WriteRawInteger(NumericCast(str_size)); + stream.WriteData(const_data_ptr_cast(str_data), str_size); } void WriteArray(Vector &col, idx_t r, const vector &dimensions, idx_t depth, uint32_t count) { @@ -405,6 +429,7 @@ class PostgresBinaryWriter { public: MemoryStream stream; + PostgresCopyState &state; }; } // namespace duckdb diff --git a/src/include/postgres_connection.hpp b/src/include/postgres_connection.hpp index 08fd19c4..d13b834c 100644 --- a/src/include/postgres_connection.hpp +++ b/src/include/postgres_connection.hpp @@ -29,10 +29,6 @@ struct OwnedPostgresConnection { PGconn *connection; }; -struct PostgresCopyState { - PostgresCopyFormat format = PostgresCopyFormat::AUTO; -}; - class PostgresConnection { public: explicit PostgresConnection(shared_ptr connection = nullptr); diff --git a/src/include/postgres_text_writer.hpp b/src/include/postgres_text_writer.hpp index 389349a8..5394f2b4 100644 --- a/src/include/postgres_text_writer.hpp +++ b/src/include/postgres_text_writer.hpp @@ -17,6 +17,9 @@ namespace duckdb { class PostgresTextWriter { public: + explicit PostgresTextWriter(PostgresCopyState &state) : state(state) { + } + void WriteNull() { stream.WriteData(const_data_ptr_cast("\b"), 1); } @@ -59,6 +62,16 @@ class PostgresTextWriter { WriteCharInternal('\\'); WriteCharInternal('"'); break; + case '\0': + if (!state.has_null_byte_replacement) { + throw InvalidInputException("Attempting to write a VARCHAR value with a NULL-byte. Postgres does not " + "support NULL-bytes in VARCHAR values.\n* SET pg_null_byte_replacement='' " + "to remove NULL bytes or replace them with another character"); + } + for (const auto replacement_chr : state.null_byte_replacement) { + WriteChar(replacement_chr); + } + break; default: WriteCharInternal(c); break; @@ -98,6 +111,7 @@ class PostgresTextWriter { public: MemoryStream stream; + PostgresCopyState &state; }; } // namespace duckdb diff --git a/src/include/postgres_utils.hpp b/src/include/postgres_utils.hpp index 89e1a726..ebeb37fa 100644 --- a/src/include/postgres_utils.hpp +++ b/src/include/postgres_utils.hpp @@ -46,6 +46,14 @@ struct PostgresType { enum class PostgresCopyFormat { AUTO = 0, BINARY = 1, TEXT = 2 }; +struct PostgresCopyState { + PostgresCopyFormat format = PostgresCopyFormat::AUTO; + bool has_null_byte_replacement = false; + string null_byte_replacement; + + void Initialize(ClientContext &context); +}; + class PostgresUtils { public: static PGconn *PGConnect(const string &dsn); diff --git a/src/postgres_binary_copy.cpp b/src/postgres_binary_copy.cpp index 4c89c3f6..9c0212f8 100644 --- a/src/postgres_binary_copy.cpp +++ b/src/postgres_binary_copy.cpp @@ -6,7 +6,6 @@ namespace duckdb { PostgresBinaryCopyFunction::PostgresBinaryCopyFunction() : CopyFunction("postgres_binary") { - copy_to_bind = PostgresBinaryWriteBind; copy_to_initialize_global = PostgresBinaryWriteInitializeGlobal; copy_to_initialize_local = PostgresBinaryWriteInitializeLocal; @@ -16,21 +15,23 @@ PostgresBinaryCopyFunction::PostgresBinaryCopyFunction() : CopyFunction("postgre } struct PostgresBinaryCopyGlobalState : public GlobalFunctionData { - unique_ptr file_writer; + explicit PostgresBinaryCopyGlobalState(ClientContext &context) { + copy_state.Initialize(context); + } void Flush(PostgresBinaryWriter &writer) { file_writer->WriteData(writer.stream.GetData(), writer.stream.GetPosition()); } void WriteHeader() { - PostgresBinaryWriter writer; + PostgresBinaryWriter writer(copy_state); writer.WriteHeader(); Flush(writer); } void WriteChunk(DataChunk &chunk) { chunk.Flatten(); - PostgresBinaryWriter writer; + PostgresBinaryWriter writer(copy_state); for (idx_t r = 0; r < chunk.size(); r++) { writer.BeginRow(chunk.ColumnCount()); for (idx_t c = 0; c < chunk.ColumnCount(); c++) { @@ -44,13 +45,17 @@ struct PostgresBinaryCopyGlobalState : public GlobalFunctionData { void Flush() { // write the footer - PostgresBinaryWriter writer; + PostgresBinaryWriter writer(copy_state); writer.WriteFooter(); Flush(writer); // flush and close the file file_writer->Flush(); file_writer.reset(); } + +public: + unique_ptr file_writer; + PostgresCopyState copy_state; }; struct PostgresBinaryWriteBindData : public TableFunctionData {}; @@ -65,7 +70,7 @@ unique_ptr PostgresBinaryCopyFunction::PostgresBinaryWriteBind(Cli unique_ptr PostgresBinaryCopyFunction::PostgresBinaryWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data, const string &file_path) { - auto result = make_uniq(); + auto result = make_uniq(context); auto &fs = FileSystem::GetFileSystem(context); result->file_writer = make_uniq(fs, file_path); // write the header diff --git a/src/postgres_copy_to.cpp b/src/postgres_copy_to.cpp index d0189d5d..69b78671 100644 --- a/src/postgres_copy_to.cpp +++ b/src/postgres_copy_to.cpp @@ -5,6 +5,24 @@ namespace duckdb { +void PostgresCopyState::Initialize(ClientContext &context) { + Value replacement_value; + if (!context.TryGetCurrentSetting("pg_null_byte_replacement", replacement_value)) { + return; + } + if (replacement_value.IsNull()) { + return; + } + auto replacement_str = StringValue::Get(replacement_value); + for (const auto c : replacement_str) { + if (c == '\0') { + throw InternalException("NULL byte replacement string cannot contain NULL values"); + } + } + has_null_byte_replacement = true; + null_byte_replacement = std::move(replacement_str); +} + void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &state, PostgresCopyFormat format, const string &schema_name, const string &table_name, const vector &column_names) { @@ -24,6 +42,7 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState & query += ") "; } query += "FROM STDIN (FORMAT "; + state.Initialize(context); state.format = format; switch (state.format) { case PostgresCopyFormat::BINARY: @@ -43,7 +62,7 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState & } if (state.format == PostgresCopyFormat::BINARY) { // binary copy requires a header - PostgresBinaryWriter writer; + PostgresBinaryWriter writer(state); writer.WriteHeader(); CopyData(writer); } @@ -70,12 +89,12 @@ void PostgresConnection::CopyData(PostgresTextWriter &writer) { void PostgresConnection::FinishCopyTo(PostgresCopyState &state) { if (state.format == PostgresCopyFormat::BINARY) { // binary copy requires a footer - PostgresBinaryWriter writer; + PostgresBinaryWriter writer(state); writer.WriteFooter(); CopyData(writer); } else if (state.format == PostgresCopyFormat::TEXT) { // text copy requires a footer - PostgresTextWriter writer; + PostgresTextWriter writer(state); writer.WriteFooter(); CopyData(writer); } @@ -263,7 +282,7 @@ void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &st chunk.Flatten(); if (state.format == PostgresCopyFormat::BINARY) { - PostgresBinaryWriter writer; + PostgresBinaryWriter writer(state); for (idx_t r = 0; r < chunk.size(); r++) { writer.BeginRow(chunk.ColumnCount()); for (idx_t c = 0; c < chunk.ColumnCount(); c++) { @@ -290,7 +309,7 @@ void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &st } varchar_chunk.SetCardinality(chunk.size()); - PostgresTextWriter writer; + PostgresTextWriter writer(state); for (idx_t r = 0; r < chunk.size(); r++) { for (idx_t c = 0; c < chunk.ColumnCount(); c++) { if (c > 0) { diff --git a/src/postgres_extension.cpp b/src/postgres_extension.cpp index 7b417d86..90fc88b4 100644 --- a/src/postgres_extension.cpp +++ b/src/postgres_extension.cpp @@ -111,6 +111,17 @@ void SetPostgresSecretParameters(CreateSecretFunction &function) { function.named_parameters["dbname"] = LogicalType::VARCHAR; } +void SetPostgresNullByteReplacement(ClientContext &context, SetScope scope, Value ¶meter) { + if (parameter.IsNull()) { + return; + } + for (const auto c : StringValue::Get(parameter)) { + if (c == '\0') { + throw BinderException("NULL byte replacement string cannot contain NULL values"); + } + } +} + static void LoadInternal(DatabaseInstance &db) { PostgresScanFunction postgres_fun; ExtensionUtil::RegisterFunction(db, postgres_fun); @@ -165,6 +176,9 @@ static void LoadInternal(DatabaseInstance &db) { config.AddExtensionOption("pg_experimental_filter_pushdown", "Whether or not to use filter pushdown (currently experimental)", LogicalType::BOOLEAN, Value::BOOLEAN(false)); + config.AddExtensionOption("pg_null_byte_replacement", + "When writing NULL bytes to Postgres, replace them with the given character", + LogicalType::VARCHAR, Value(), SetPostgresNullByteReplacement); config.AddExtensionOption("pg_debug_show_queries", "DEBUG SETTING: print all queries sent to Postgres to stdout", LogicalType::BOOLEAN, Value::BOOLEAN(false), SetPostgresDebugQueryPrint); diff --git a/test/sql/storage/attach_null_byte.test b/test/sql/storage/attach_null_byte.test new file mode 100644 index 00000000..280fee38 --- /dev/null +++ b/test/sql/storage/attach_null_byte.test @@ -0,0 +1,65 @@ +# name: test/sql/storage/attach_null_byte.test +# description: Test inserting null byte values through ATTACH +# group: [storage] + +require postgres_scanner + +require-env POSTGRES_TEST_DATABASE_AVAILABLE + +statement ok +PRAGMA enable_verification + +statement ok +ATTACH 'dbname=postgresscanner' AS s1 (TYPE POSTGRES) + +statement ok +USE s1 + +foreach pg_binary true false + +statement ok +SET pg_use_binary_copy=${pg_binary} + +statement ok +CREATE OR REPLACE TABLE nullbyte_tbl(s VARCHAR); + +statement error +INSERT INTO nullbyte_tbl VALUES (chr(0)) +---- +Postgres does not support NULL-bytes in VARCHAR values + +statement ok +SET pg_null_byte_replacement='' + +statement ok +INSERT INTO nullbyte_tbl VALUES (chr(0)), ('FF' || chr(0) || 'FF'); + +query I +SELECT * FROM nullbyte_tbl +---- +(empty) +FFFF + +statement ok +SET pg_null_byte_replacement='NULLBYTE' + +statement ok +INSERT INTO nullbyte_tbl VALUES (chr(0)), ('FF' || chr(0) || 'FF'); + +query I +SELECT * FROM nullbyte_tbl +---- +(empty) +FFFF +NULLBYTE +FFNULLBYTEFF + +statement ok +RESET pg_null_byte_replacement + +endloop + +statement error +SET pg_null_byte_replacement=chr(0) +---- +NULL byte replacement string cannot contain NULL values From a13f20e6f2881ebe27c0eb8328a0ac9f2d191198 Mon Sep 17 00:00:00 2001 From: Mark Raasveldt Date: Tue, 3 Sep 2024 22:41:01 +0200 Subject: [PATCH 2/2] Fix for writing blobs with null bytes --- src/include/postgres_binary_writer.hpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/include/postgres_binary_writer.hpp b/src/include/postgres_binary_writer.hpp index 4482a0d1..eff48d1b 100644 --- a/src/include/postgres_binary_writer.hpp +++ b/src/include/postgres_binary_writer.hpp @@ -200,6 +200,13 @@ class PostgresBinaryWriter { } } + void WriteRawBlob(string_t value) { + auto str_size = value.GetSize(); + auto str_data = value.GetData(); + WriteRawInteger(NumericCast(str_size)); + stream.WriteData(const_data_ptr_cast(str_data), str_size); + } + void WriteVarchar(string_t value) { auto str_size = value.GetSize(); auto str_data = value.GetData(); @@ -219,11 +226,10 @@ class PostgresBinaryWriter { new_str += str_data[i]; } } - WriteVarchar(new_str); + WriteRawBlob(new_str); return; } - WriteRawInteger(NumericCast(str_size)); - stream.WriteData(const_data_ptr_cast(str_data), str_size); + WriteRawBlob(value); } void WriteArray(Vector &col, idx_t r, const vector &dimensions, idx_t depth, uint32_t count) { @@ -336,12 +342,16 @@ class PostgresBinaryWriter { WriteUUID(data); break; } - case LogicalTypeId::BLOB: case LogicalTypeId::VARCHAR: { auto data = FlatVector::GetData(col)[r]; WriteVarchar(data); break; } + case LogicalTypeId::BLOB: { + auto data = FlatVector::GetData(col)[r]; + WriteRawBlob(data); + break; + } case LogicalTypeId::ENUM: { idx_t pos; switch (type.InternalType()) {