diff --git a/src/postgres_extension.cpp b/src/postgres_extension.cpp index 34d46d04..a0deed2c 100644 --- a/src/postgres_extension.cpp +++ b/src/postgres_extension.cpp @@ -72,6 +72,44 @@ static void SetPostgresDebugQueryPrint(ClientContext &context, SetScope scope, V PostgresConnection::DebugSetPrintQueries(BooleanValue::Get(parameter)); } +unique_ptr CreatePostgresSecretFunction(ClientContext &context, CreateSecretInput &input) { + // apply any overridden settings + vector prefix_paths; + auto result = make_uniq(prefix_paths, "postgres", "config", input.name); + for (const auto &named_param : input.options) { + auto lower_name = StringUtil::Lower(named_param.first); + + if (lower_name == "host") { + result->secret_map["host"] = named_param.second.ToString(); + } else if (lower_name == "user") { + result->secret_map["user"] = named_param.second.ToString(); + } else if (lower_name == "database") { + result->secret_map["dbname"] = named_param.second.ToString(); + } else if (lower_name == "dbname") { + result->secret_map["dbname"] = named_param.second.ToString(); + } else if (lower_name == "password") { + result->secret_map["password"] = named_param.second.ToString(); + } else if (lower_name == "port") { + result->secret_map["port"] = named_param.second.ToString(); + } else { + throw InternalException("Unknown named parameter passed to CreatePostgresSecretFunction: " + lower_name); + } + } + + //! Set redact keys + result->redact_keys = {"password"}; + return std::move(result); +} + +void SetPostgresSecretParameters(CreateSecretFunction &function) { + function.named_parameters["host"] = LogicalType::VARCHAR; + function.named_parameters["port"] = LogicalType::VARCHAR; + function.named_parameters["password"] = LogicalType::VARCHAR; + function.named_parameters["user"] = LogicalType::VARCHAR; + function.named_parameters["database"] = LogicalType::VARCHAR; // alias for dbname + function.named_parameters["dbname"] = LogicalType::VARCHAR; +} + static void LoadInternal(DatabaseInstance &db) { PostgresScanFunction postgres_fun; ExtensionUtil::RegisterFunction(db, postgres_fun); @@ -94,6 +132,18 @@ static void LoadInternal(DatabaseInstance &db) { PostgresBinaryCopyFunction binary_copy; ExtensionUtil::RegisterFunction(db, binary_copy); + // Register the new type + SecretType secret_type; + secret_type.name = "postgres"; + secret_type.deserializer = KeyValueSecret::Deserialize; + secret_type.default_provider = "config"; + + ExtensionUtil::RegisterSecretType(db, secret_type); + + CreateSecretFunction postgres_secret_function = {"postgres", "config", CreatePostgresSecretFunction}; + SetPostgresSecretParameters(postgres_secret_function); + ExtensionUtil::RegisterFunction(db, postgres_secret_function); + auto &config = DBConfig::GetConfig(db); config.storage_extensions["postgres_scanner"] = make_uniq(); diff --git a/src/postgres_storage.cpp b/src/postgres_storage.cpp index 8837840a..30af82b3 100644 --- a/src/postgres_storage.cpp +++ b/src/postgres_storage.cpp @@ -4,13 +4,96 @@ #include "storage/postgres_catalog.hpp" #include "duckdb/parser/parsed_data/attach_info.hpp" #include "storage/postgres_transaction_manager.hpp" +#include "duckdb/main/secret/secret_manager.hpp" namespace duckdb { +string EscapeConnectionString(const string &input) { + string result = "'"; + for (auto c : input) { + if (c == '\\') { + result += "\\\\"; + } else if (c == '\'') { + result += "\\'"; + } else { + result += c; + } + } + result += "'"; + return result; +} + +string AddConnectionOption(const KeyValueSecret &kv_secret, const string &name) { + Value input_val = kv_secret.TryGetValue(name); + if (input_val.IsNull()) { + // not provided + return string(); + } + string result; + result += name; + result += "="; + result += EscapeConnectionString(input_val.ToString()); + result += " "; + return result; +} + +unique_ptr GetSecret(ClientContext &context, const string &secret_name) { + auto &secret_manager = SecretManager::Get(context); + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); + // FIXME: this should be adjusted once the `GetSecretByName` API supports this use case + auto secret_entry = secret_manager.GetSecretByName(transaction, secret_name, "memory"); + if (secret_entry) { + return secret_entry; + } + secret_entry = secret_manager.GetSecretByName(transaction, secret_name, "local_file"); + if (secret_entry) { + return secret_entry; + } + return nullptr; +} + static unique_ptr PostgresAttach(StorageExtensionInfo *storage_info, ClientContext &context, AttachedDatabase &db, const string &name, AttachInfo &info, AccessMode access_mode) { - return make_uniq(db, info.path, access_mode); + string connection_string = info.path; + + string secret_name; + for (auto &entry : info.options) { + auto lower_name = StringUtil::Lower(entry.first); + if (lower_name == "type" || lower_name == "read_only") { + // already handled + } else if (lower_name == "secret") { + secret_name = entry.second.ToString(); + } else { + throw BinderException("Unrecognized option for Postgres attach: %s", entry.first); + } + } + + // if no secret is specified we default to the unnamed postgres secret, if it exists + bool explicit_secret = !secret_name.empty(); + if (!explicit_secret) { + // look up settings from the default unnamed postgres secret if none is provided + secret_name = "__default_postgres"; + } + + auto secret_entry = GetSecret(context, secret_name); + if (secret_entry) { + // secret found - read data + const auto &kv_secret = dynamic_cast(*secret_entry->secret); + string new_connection_info; + + new_connection_info += AddConnectionOption(kv_secret, "user"); + new_connection_info += AddConnectionOption(kv_secret, "password"); + new_connection_info += AddConnectionOption(kv_secret, "host"); + new_connection_info += AddConnectionOption(kv_secret, "port"); + new_connection_info += AddConnectionOption(kv_secret, "dbname"); + + connection_string = new_connection_info + connection_string; + } else if (explicit_secret) { + // secret not found and one was explicitly provided - throw an error + throw BinderException("Secret with name \"%s\" not found", secret_name); + } + return make_uniq(db, connection_string, access_mode); } static unique_ptr PostgresCreateTransactionManager(StorageExtensionInfo *storage_info, diff --git a/test/sql/storage/attach_secret.test b/test/sql/storage/attach_secret.test new file mode 100644 index 00000000..e8a270d2 --- /dev/null +++ b/test/sql/storage/attach_secret.test @@ -0,0 +1,71 @@ +# name: test/sql/storage/attach_secret.test +# description: Test attaching using a secret +# group: [storage] + +require postgres_scanner + +require-env POSTGRES_TEST_DATABASE_AVAILABLE + +statement ok +PRAGMA enable_verification + +# attach using default secret +statement ok +CREATE SECRET ( + TYPE POSTGRES, + HOST '127.0.0.1', + DATABASE unknown_db, + PASSWORD '' +); + +statement error +ATTACH '' AS secret_attach (TYPE POSTGRES) +---- +unknown_db + +# attach using an explicit secret +statement ok +CREATE SECRET postgres_db ( + TYPE POSTGRES, + HOST '127.0.0.1', + DATABASE postgresscanner, + PASSWORD '' +); + +statement ok +ATTACH '' AS secret_attach (TYPE POSTGRES, SECRET postgres_db) + +statement ok +DETACH secret_attach + +statement ok +CREATE OR REPLACE SECRET postgres_db ( + TYPE POSTGRES, + HOST '127.0.0.1', + DBNAME unknown_database, + PASSWORD '' +); + +# non-existent database +statement error +ATTACH '' AS secret_attach (TYPE POSTGRES, SECRET postgres_db) +---- +unknown_database + +# we can override options in the attach string +statement ok +ATTACH 'dbname=postgresscanner' AS secret_attach (TYPE POSTGRES, SECRET postgres_db) + +statement error +CREATE SECRET new_secret ( + TYPE POSTGRES, + UNKNOWN_OPTION xx +); +---- +unknown_option + +# unknown secret +statement error +ATTACH '' AS secret_attach (TYPE POSTGRES, SECRET unknown_secret) +---- +unknown_secret