Skip to content

Commit

Permalink
Fix #228: explicitly handle NULL bytes in strings by throwing a helpf…
Browse files Browse the repository at this point in the history
…ul error message, and add a setting pg_null_byte_replacement which can be used to replace them
  • Loading branch information
Mytherin committed Sep 3, 2024
1 parent b6dda6c commit eb5e479
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 17 deletions.
29 changes: 27 additions & 2 deletions src/include/postgres_binary_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace duckdb {

class PostgresBinaryWriter {
public:
explicit PostgresBinaryWriter(PostgresCopyState &state) : state(state) {
}

template <class T>
T GetInteger(T val) {
if (sizeof(T) == sizeof(uint8_t)) {
Expand Down Expand Up @@ -198,8 +201,29 @@ class PostgresBinaryWriter {
}

void WriteVarchar(string_t value) {
WriteRawInteger<int32_t>(value.GetSize());
stream.WriteData(const_data_ptr_cast(value.GetData()), value.GetSize());
auto str_size = value.GetSize();
auto str_data = value.GetData();
if (memchr(str_data, '\0', str_size) != nullptr) {
if (!state.has_null_byte_replacement) {
throw InvalidInputException("Attempting to write a VARCHAR value with a NULL-byte. Postgres does not "
"support NULL-bytes in VARCHAR values.\n* SET pg_null_byte_replacement='' "
"to remove NULL bytes or replace them with another character");
}
// we have a NULL byte replacement - construct a new string that has all null bytes replaced and write it
// out
string new_str;
for (idx_t i = 0; i < str_size; i++) {
if (str_data[i] == '\0') {
new_str += state.null_byte_replacement;
} else {
new_str += str_data[i];
}
}
WriteVarchar(new_str);
return;
}
WriteRawInteger<int32_t>(NumericCast<int32_t>(str_size));
stream.WriteData(const_data_ptr_cast(str_data), str_size);
}

void WriteArray(Vector &col, idx_t r, const vector<uint32_t> &dimensions, idx_t depth, uint32_t count) {
Expand Down Expand Up @@ -405,6 +429,7 @@ class PostgresBinaryWriter {

public:
MemoryStream stream;
PostgresCopyState &state;
};

} // namespace duckdb
4 changes: 0 additions & 4 deletions src/include/postgres_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ struct OwnedPostgresConnection {
PGconn *connection;
};

struct PostgresCopyState {
PostgresCopyFormat format = PostgresCopyFormat::AUTO;
};

class PostgresConnection {
public:
explicit PostgresConnection(shared_ptr<OwnedPostgresConnection> connection = nullptr);
Expand Down
14 changes: 14 additions & 0 deletions src/include/postgres_text_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace duckdb {

class PostgresTextWriter {
public:
explicit PostgresTextWriter(PostgresCopyState &state) : state(state) {
}

void WriteNull() {
stream.WriteData(const_data_ptr_cast("\b"), 1);
}
Expand Down Expand Up @@ -59,6 +62,16 @@ class PostgresTextWriter {
WriteCharInternal('\\');
WriteCharInternal('"');
break;
case '\0':
if (!state.has_null_byte_replacement) {
throw InvalidInputException("Attempting to write a VARCHAR value with a NULL-byte. Postgres does not "
"support NULL-bytes in VARCHAR values.\n* SET pg_null_byte_replacement='' "
"to remove NULL bytes or replace them with another character");
}
for (const auto replacement_chr : state.null_byte_replacement) {
WriteChar(replacement_chr);
}
break;
default:
WriteCharInternal(c);
break;
Expand Down Expand Up @@ -98,6 +111,7 @@ class PostgresTextWriter {

public:
MemoryStream stream;
PostgresCopyState &state;
};

} // namespace duckdb
8 changes: 8 additions & 0 deletions src/include/postgres_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ struct PostgresType {

enum class PostgresCopyFormat { AUTO = 0, BINARY = 1, TEXT = 2 };

struct PostgresCopyState {
PostgresCopyFormat format = PostgresCopyFormat::AUTO;
bool has_null_byte_replacement = false;
string null_byte_replacement;

void Initialize(ClientContext &context);
};

class PostgresUtils {
public:
static PGconn *PGConnect(const string &dsn);
Expand Down
17 changes: 11 additions & 6 deletions src/postgres_binary_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
namespace duckdb {

PostgresBinaryCopyFunction::PostgresBinaryCopyFunction() : CopyFunction("postgres_binary") {

copy_to_bind = PostgresBinaryWriteBind;
copy_to_initialize_global = PostgresBinaryWriteInitializeGlobal;
copy_to_initialize_local = PostgresBinaryWriteInitializeLocal;
Expand All @@ -16,21 +15,23 @@ PostgresBinaryCopyFunction::PostgresBinaryCopyFunction() : CopyFunction("postgre
}

struct PostgresBinaryCopyGlobalState : public GlobalFunctionData {
unique_ptr<BufferedFileWriter> file_writer;
explicit PostgresBinaryCopyGlobalState(ClientContext &context) {
copy_state.Initialize(context);
}

void Flush(PostgresBinaryWriter &writer) {
file_writer->WriteData(writer.stream.GetData(), writer.stream.GetPosition());
}

void WriteHeader() {
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(copy_state);
writer.WriteHeader();
Flush(writer);
}

void WriteChunk(DataChunk &chunk) {
chunk.Flatten();
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(copy_state);
for (idx_t r = 0; r < chunk.size(); r++) {
writer.BeginRow(chunk.ColumnCount());
for (idx_t c = 0; c < chunk.ColumnCount(); c++) {
Expand All @@ -44,13 +45,17 @@ struct PostgresBinaryCopyGlobalState : public GlobalFunctionData {

void Flush() {
// write the footer
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(copy_state);
writer.WriteFooter();
Flush(writer);
// flush and close the file
file_writer->Flush();
file_writer.reset();
}

public:
unique_ptr<BufferedFileWriter> file_writer;
PostgresCopyState copy_state;
};

struct PostgresBinaryWriteBindData : public TableFunctionData {};
Expand All @@ -65,7 +70,7 @@ unique_ptr<FunctionData> PostgresBinaryCopyFunction::PostgresBinaryWriteBind(Cli
unique_ptr<GlobalFunctionData>
PostgresBinaryCopyFunction::PostgresBinaryWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data,
const string &file_path) {
auto result = make_uniq<PostgresBinaryCopyGlobalState>();
auto result = make_uniq<PostgresBinaryCopyGlobalState>(context);
auto &fs = FileSystem::GetFileSystem(context);
result->file_writer = make_uniq<BufferedFileWriter>(fs, file_path);
// write the header
Expand Down
29 changes: 24 additions & 5 deletions src/postgres_copy_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,24 @@

namespace duckdb {

void PostgresCopyState::Initialize(ClientContext &context) {
Value replacement_value;
if (!context.TryGetCurrentSetting("pg_null_byte_replacement", replacement_value)) {
return;
}
if (replacement_value.IsNull()) {
return;
}
auto replacement_str = StringValue::Get(replacement_value);
for (const auto c : replacement_str) {
if (c == '\0') {
throw InternalException("NULL byte replacement string cannot contain NULL values");
}
}
has_null_byte_replacement = true;
null_byte_replacement = std::move(replacement_str);
}

void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &state, PostgresCopyFormat format,
const string &schema_name, const string &table_name,
const vector<string> &column_names) {
Expand All @@ -24,6 +42,7 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &
query += ") ";
}
query += "FROM STDIN (FORMAT ";
state.Initialize(context);
state.format = format;
switch (state.format) {
case PostgresCopyFormat::BINARY:
Expand All @@ -43,7 +62,7 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &
}
if (state.format == PostgresCopyFormat::BINARY) {
// binary copy requires a header
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(state);
writer.WriteHeader();
CopyData(writer);
}
Expand All @@ -70,12 +89,12 @@ void PostgresConnection::CopyData(PostgresTextWriter &writer) {
void PostgresConnection::FinishCopyTo(PostgresCopyState &state) {
if (state.format == PostgresCopyFormat::BINARY) {
// binary copy requires a footer
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(state);
writer.WriteFooter();
CopyData(writer);
} else if (state.format == PostgresCopyFormat::TEXT) {
// text copy requires a footer
PostgresTextWriter writer;
PostgresTextWriter writer(state);
writer.WriteFooter();
CopyData(writer);
}
Expand Down Expand Up @@ -263,7 +282,7 @@ void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &st
chunk.Flatten();

if (state.format == PostgresCopyFormat::BINARY) {
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(state);
for (idx_t r = 0; r < chunk.size(); r++) {
writer.BeginRow(chunk.ColumnCount());
for (idx_t c = 0; c < chunk.ColumnCount(); c++) {
Expand All @@ -290,7 +309,7 @@ void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &st
}
varchar_chunk.SetCardinality(chunk.size());

PostgresTextWriter writer;
PostgresTextWriter writer(state);
for (idx_t r = 0; r < chunk.size(); r++) {
for (idx_t c = 0; c < chunk.ColumnCount(); c++) {
if (c > 0) {
Expand Down
14 changes: 14 additions & 0 deletions src/postgres_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ void SetPostgresSecretParameters(CreateSecretFunction &function) {
function.named_parameters["dbname"] = LogicalType::VARCHAR;
}

void SetPostgresNullByteReplacement(ClientContext &context, SetScope scope, Value &parameter) {
if (parameter.IsNull()) {
return;
}
for (const auto c : StringValue::Get(parameter)) {
if (c == '\0') {
throw BinderException("NULL byte replacement string cannot contain NULL values");
}
}
}

static void LoadInternal(DatabaseInstance &db) {
PostgresScanFunction postgres_fun;
ExtensionUtil::RegisterFunction(db, postgres_fun);
Expand Down Expand Up @@ -165,6 +176,9 @@ static void LoadInternal(DatabaseInstance &db) {
config.AddExtensionOption("pg_experimental_filter_pushdown",
"Whether or not to use filter pushdown (currently experimental)", LogicalType::BOOLEAN,
Value::BOOLEAN(false));
config.AddExtensionOption("pg_null_byte_replacement",
"When writing NULL bytes to Postgres, replace them with the given character",
LogicalType::VARCHAR, Value(), SetPostgresNullByteReplacement);
config.AddExtensionOption("pg_debug_show_queries", "DEBUG SETTING: print all queries sent to Postgres to stdout",
LogicalType::BOOLEAN, Value::BOOLEAN(false), SetPostgresDebugQueryPrint);

Expand Down
65 changes: 65 additions & 0 deletions test/sql/storage/attach_null_byte.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# name: test/sql/storage/attach_null_byte.test
# description: Test inserting null byte values through ATTACH
# group: [storage]

require postgres_scanner

require-env POSTGRES_TEST_DATABASE_AVAILABLE

statement ok
PRAGMA enable_verification

statement ok
ATTACH 'dbname=postgresscanner' AS s1 (TYPE POSTGRES)

statement ok
USE s1

foreach pg_binary true false

statement ok
SET pg_use_binary_copy=${pg_binary}

statement ok
CREATE OR REPLACE TABLE nullbyte_tbl(s VARCHAR);

statement error
INSERT INTO nullbyte_tbl VALUES (chr(0))
----
Postgres does not support NULL-bytes in VARCHAR values

statement ok
SET pg_null_byte_replacement=''

statement ok
INSERT INTO nullbyte_tbl VALUES (chr(0)), ('FF' || chr(0) || 'FF');

query I
SELECT * FROM nullbyte_tbl
----
(empty)
FFFF

statement ok
SET pg_null_byte_replacement='NULLBYTE'

statement ok
INSERT INTO nullbyte_tbl VALUES (chr(0)), ('FF' || chr(0) || 'FF');

query I
SELECT * FROM nullbyte_tbl
----
(empty)
FFFF
NULLBYTE
FFNULLBYTEFF

statement ok
RESET pg_null_byte_replacement

endloop

statement error
SET pg_null_byte_replacement=chr(0)
----
NULL byte replacement string cannot contain NULL values

0 comments on commit eb5e479

Please sign in to comment.