Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #228: explicitly handle NULL bytes in strings by throwing a helpful error message, and add a setting pg_null_byte_replacement which can be used to replace them #255

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/include/postgres_binary_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace duckdb {

class PostgresBinaryWriter {
public:
explicit PostgresBinaryWriter(PostgresCopyState &state) : state(state) {
}

template <class T>
T GetInteger(T val) {
if (sizeof(T) == sizeof(uint8_t)) {
Expand Down Expand Up @@ -197,9 +200,36 @@ class PostgresBinaryWriter {
}
}

void WriteRawBlob(string_t value) {
auto str_size = value.GetSize();
auto str_data = value.GetData();
WriteRawInteger<int32_t>(NumericCast<int32_t>(str_size));
stream.WriteData(const_data_ptr_cast(str_data), str_size);
}

void WriteVarchar(string_t value) {
WriteRawInteger<int32_t>(value.GetSize());
stream.WriteData(const_data_ptr_cast(value.GetData()), value.GetSize());
auto str_size = value.GetSize();
auto str_data = value.GetData();
if (memchr(str_data, '\0', str_size) != nullptr) {
if (!state.has_null_byte_replacement) {
throw InvalidInputException("Attempting to write a VARCHAR value with a NULL-byte. Postgres does not "
"support NULL-bytes in VARCHAR values.\n* SET pg_null_byte_replacement='' "
"to remove NULL bytes or replace them with another character");
}
// we have a NULL byte replacement - construct a new string that has all null bytes replaced and write it
// out
string new_str;
for (idx_t i = 0; i < str_size; i++) {
if (str_data[i] == '\0') {
new_str += state.null_byte_replacement;
} else {
new_str += str_data[i];
}
}
WriteRawBlob(new_str);
return;
}
WriteRawBlob(value);
}

void WriteArray(Vector &col, idx_t r, const vector<uint32_t> &dimensions, idx_t depth, uint32_t count) {
Expand Down Expand Up @@ -312,12 +342,16 @@ class PostgresBinaryWriter {
WriteUUID(data);
break;
}
case LogicalTypeId::BLOB:
case LogicalTypeId::VARCHAR: {
auto data = FlatVector::GetData<string_t>(col)[r];
WriteVarchar(data);
break;
}
case LogicalTypeId::BLOB: {
auto data = FlatVector::GetData<string_t>(col)[r];
WriteRawBlob(data);
break;
}
case LogicalTypeId::ENUM: {
idx_t pos;
switch (type.InternalType()) {
Expand Down Expand Up @@ -405,6 +439,7 @@ class PostgresBinaryWriter {

public:
MemoryStream stream;
PostgresCopyState &state;
};

} // namespace duckdb
4 changes: 0 additions & 4 deletions src/include/postgres_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ struct OwnedPostgresConnection {
PGconn *connection;
};

struct PostgresCopyState {
PostgresCopyFormat format = PostgresCopyFormat::AUTO;
};

class PostgresConnection {
public:
explicit PostgresConnection(shared_ptr<OwnedPostgresConnection> connection = nullptr);
Expand Down
14 changes: 14 additions & 0 deletions src/include/postgres_text_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace duckdb {

class PostgresTextWriter {
public:
explicit PostgresTextWriter(PostgresCopyState &state) : state(state) {
}

void WriteNull() {
stream.WriteData(const_data_ptr_cast("\b"), 1);
}
Expand Down Expand Up @@ -59,6 +62,16 @@ class PostgresTextWriter {
WriteCharInternal('\\');
WriteCharInternal('"');
break;
case '\0':
if (!state.has_null_byte_replacement) {
throw InvalidInputException("Attempting to write a VARCHAR value with a NULL-byte. Postgres does not "
"support NULL-bytes in VARCHAR values.\n* SET pg_null_byte_replacement='' "
"to remove NULL bytes or replace them with another character");
}
for (const auto replacement_chr : state.null_byte_replacement) {
WriteChar(replacement_chr);
}
break;
default:
WriteCharInternal(c);
break;
Expand Down Expand Up @@ -98,6 +111,7 @@ class PostgresTextWriter {

public:
MemoryStream stream;
PostgresCopyState &state;
};

} // namespace duckdb
8 changes: 8 additions & 0 deletions src/include/postgres_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ struct PostgresType {

enum class PostgresCopyFormat { AUTO = 0, BINARY = 1, TEXT = 2 };

struct PostgresCopyState {
PostgresCopyFormat format = PostgresCopyFormat::AUTO;
bool has_null_byte_replacement = false;
string null_byte_replacement;

void Initialize(ClientContext &context);
};

class PostgresUtils {
public:
static PGconn *PGConnect(const string &dsn);
Expand Down
17 changes: 11 additions & 6 deletions src/postgres_binary_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
namespace duckdb {

PostgresBinaryCopyFunction::PostgresBinaryCopyFunction() : CopyFunction("postgres_binary") {

copy_to_bind = PostgresBinaryWriteBind;
copy_to_initialize_global = PostgresBinaryWriteInitializeGlobal;
copy_to_initialize_local = PostgresBinaryWriteInitializeLocal;
Expand All @@ -16,21 +15,23 @@ PostgresBinaryCopyFunction::PostgresBinaryCopyFunction() : CopyFunction("postgre
}

struct PostgresBinaryCopyGlobalState : public GlobalFunctionData {
unique_ptr<BufferedFileWriter> file_writer;
explicit PostgresBinaryCopyGlobalState(ClientContext &context) {
copy_state.Initialize(context);
}

void Flush(PostgresBinaryWriter &writer) {
file_writer->WriteData(writer.stream.GetData(), writer.stream.GetPosition());
}

void WriteHeader() {
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(copy_state);
writer.WriteHeader();
Flush(writer);
}

void WriteChunk(DataChunk &chunk) {
chunk.Flatten();
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(copy_state);
for (idx_t r = 0; r < chunk.size(); r++) {
writer.BeginRow(chunk.ColumnCount());
for (idx_t c = 0; c < chunk.ColumnCount(); c++) {
Expand All @@ -44,13 +45,17 @@ struct PostgresBinaryCopyGlobalState : public GlobalFunctionData {

void Flush() {
// write the footer
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(copy_state);
writer.WriteFooter();
Flush(writer);
// flush and close the file
file_writer->Flush();
file_writer.reset();
}

public:
unique_ptr<BufferedFileWriter> file_writer;
PostgresCopyState copy_state;
};

struct PostgresBinaryWriteBindData : public TableFunctionData {};
Expand All @@ -65,7 +70,7 @@ unique_ptr<FunctionData> PostgresBinaryCopyFunction::PostgresBinaryWriteBind(Cli
unique_ptr<GlobalFunctionData>
PostgresBinaryCopyFunction::PostgresBinaryWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data,
const string &file_path) {
auto result = make_uniq<PostgresBinaryCopyGlobalState>();
auto result = make_uniq<PostgresBinaryCopyGlobalState>(context);
auto &fs = FileSystem::GetFileSystem(context);
result->file_writer = make_uniq<BufferedFileWriter>(fs, file_path);
// write the header
Expand Down
29 changes: 24 additions & 5 deletions src/postgres_copy_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,24 @@

namespace duckdb {

void PostgresCopyState::Initialize(ClientContext &context) {
Value replacement_value;
if (!context.TryGetCurrentSetting("pg_null_byte_replacement", replacement_value)) {
return;
}
if (replacement_value.IsNull()) {
return;
}
auto replacement_str = StringValue::Get(replacement_value);
for (const auto c : replacement_str) {
if (c == '\0') {
throw InternalException("NULL byte replacement string cannot contain NULL values");
}
}
has_null_byte_replacement = true;
null_byte_replacement = std::move(replacement_str);
}

void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &state, PostgresCopyFormat format,
const string &schema_name, const string &table_name,
const vector<string> &column_names) {
Expand All @@ -24,6 +42,7 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &
query += ") ";
}
query += "FROM STDIN (FORMAT ";
state.Initialize(context);
state.format = format;
switch (state.format) {
case PostgresCopyFormat::BINARY:
Expand All @@ -43,7 +62,7 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &
}
if (state.format == PostgresCopyFormat::BINARY) {
// binary copy requires a header
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(state);
writer.WriteHeader();
CopyData(writer);
}
Expand All @@ -70,12 +89,12 @@ void PostgresConnection::CopyData(PostgresTextWriter &writer) {
void PostgresConnection::FinishCopyTo(PostgresCopyState &state) {
if (state.format == PostgresCopyFormat::BINARY) {
// binary copy requires a footer
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(state);
writer.WriteFooter();
CopyData(writer);
} else if (state.format == PostgresCopyFormat::TEXT) {
// text copy requires a footer
PostgresTextWriter writer;
PostgresTextWriter writer(state);
writer.WriteFooter();
CopyData(writer);
}
Expand Down Expand Up @@ -263,7 +282,7 @@ void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &st
chunk.Flatten();

if (state.format == PostgresCopyFormat::BINARY) {
PostgresBinaryWriter writer;
PostgresBinaryWriter writer(state);
for (idx_t r = 0; r < chunk.size(); r++) {
writer.BeginRow(chunk.ColumnCount());
for (idx_t c = 0; c < chunk.ColumnCount(); c++) {
Expand All @@ -290,7 +309,7 @@ void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &st
}
varchar_chunk.SetCardinality(chunk.size());

PostgresTextWriter writer;
PostgresTextWriter writer(state);
for (idx_t r = 0; r < chunk.size(); r++) {
for (idx_t c = 0; c < chunk.ColumnCount(); c++) {
if (c > 0) {
Expand Down
14 changes: 14 additions & 0 deletions src/postgres_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ void SetPostgresSecretParameters(CreateSecretFunction &function) {
function.named_parameters["dbname"] = LogicalType::VARCHAR;
}

void SetPostgresNullByteReplacement(ClientContext &context, SetScope scope, Value &parameter) {
if (parameter.IsNull()) {
return;
}
for (const auto c : StringValue::Get(parameter)) {
if (c == '\0') {
throw BinderException("NULL byte replacement string cannot contain NULL values");
}
}
}

static void LoadInternal(DatabaseInstance &db) {
PostgresScanFunction postgres_fun;
ExtensionUtil::RegisterFunction(db, postgres_fun);
Expand Down Expand Up @@ -165,6 +176,9 @@ static void LoadInternal(DatabaseInstance &db) {
config.AddExtensionOption("pg_experimental_filter_pushdown",
"Whether or not to use filter pushdown (currently experimental)", LogicalType::BOOLEAN,
Value::BOOLEAN(false));
config.AddExtensionOption("pg_null_byte_replacement",
"When writing NULL bytes to Postgres, replace them with the given character",
LogicalType::VARCHAR, Value(), SetPostgresNullByteReplacement);
config.AddExtensionOption("pg_debug_show_queries", "DEBUG SETTING: print all queries sent to Postgres to stdout",
LogicalType::BOOLEAN, Value::BOOLEAN(false), SetPostgresDebugQueryPrint);

Expand Down
65 changes: 65 additions & 0 deletions test/sql/storage/attach_null_byte.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# name: test/sql/storage/attach_null_byte.test
# description: Test inserting null byte values through ATTACH
# group: [storage]

require postgres_scanner

require-env POSTGRES_TEST_DATABASE_AVAILABLE

statement ok
PRAGMA enable_verification

statement ok
ATTACH 'dbname=postgresscanner' AS s1 (TYPE POSTGRES)

statement ok
USE s1

foreach pg_binary true false

statement ok
SET pg_use_binary_copy=${pg_binary}

statement ok
CREATE OR REPLACE TABLE nullbyte_tbl(s VARCHAR);

statement error
INSERT INTO nullbyte_tbl VALUES (chr(0))
----
Postgres does not support NULL-bytes in VARCHAR values

statement ok
SET pg_null_byte_replacement=''

statement ok
INSERT INTO nullbyte_tbl VALUES (chr(0)), ('FF' || chr(0) || 'FF');

query I
SELECT * FROM nullbyte_tbl
----
(empty)
FFFF

statement ok
SET pg_null_byte_replacement='NULLBYTE'

statement ok
INSERT INTO nullbyte_tbl VALUES (chr(0)), ('FF' || chr(0) || 'FF');

query I
SELECT * FROM nullbyte_tbl
----
(empty)
FFFF
NULLBYTE
FFNULLBYTEFF

statement ok
RESET pg_null_byte_replacement

endloop

statement error
SET pg_null_byte_replacement=chr(0)
----
NULL byte replacement string cannot contain NULL values
Loading