Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,53 @@ D SELECT * FROM postgres_db.uuids;

For more information on how to use the connector, refer to the [Postgres documentation on the website](https://duckdb.org/docs/extensions/postgres).

### AWS RDS IAM Authentication

The extension supports AWS RDS IAM-based authentication, which allows you to connect to RDS PostgreSQL instances using IAM database authentication instead of static passwords. This feature automatically generates temporary authentication tokens using the AWS CLI.

#### Requirements

- AWS CLI installed and configured
- RDS instance with IAM database authentication enabled
- IAM user/role with `rds-db:connect` permission for the RDS instance
- AWS credentials configured (via `AWS_PROFILE`, `AWS_ACCESS_KEY_ID`/`AWS_SECRET_ACCESS_KEY`, or IAM role)

#### Usage

To use RDS IAM authentication, create a Postgres secret with the `USE_RDS_IAM_AUTH` parameter set to `true`:

```sql
CREATE SECRET rds_secret (
TYPE POSTGRES,
HOST 'my-db-instance.xxxxxx.us-west-2.rds.amazonaws.com',
PORT '5432',
USER 'my_iam_user',
DATABASE 'mydb',
USE_RDS_IAM_AUTH true,
AWS_REGION 'us-west-2' -- Optional: uses AWS CLI default if not specified
);

ATTACH '' AS rds_db (TYPE POSTGRES, SECRET rds_secret);
```

#### Secret Parameters for RDS IAM Authentication

| Parameter | Type | Required | Description |
|-----------|------|----------|-------------|
| `USE_RDS_IAM_AUTH` | BOOLEAN | Yes | Enable RDS IAM authentication |
| `HOST` | VARCHAR | Yes | RDS instance hostname |
| `PORT` | VARCHAR | Yes | RDS instance port (typically 5432) |
| `USER` | VARCHAR | Yes | IAM database username |
| `DATABASE` or `DBNAME` | VARCHAR | No | Database name |
| `AWS_REGION` | VARCHAR | No | AWS region of the RDS instance; if omitted, the AWS CLI's default region is used |

#### Important Notes

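- **Token Expiration**: RDS auth tokens expire after 15 minutes. A token is only used when a connection is established, so the extension generates a fresh token each time it opens a new connection; connections that are already open, including ones running long queries, are not affected when the token expires.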
- **AWS CLI**: The extension uses the `aws rds generate-db-auth-token` command. Make sure the AWS CLI is installed and configured with appropriate credentials.
- **Environment Variables**: The AWS CLI command inherits environment variables from the parent process, so `AWS_PROFILE`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN`, and `AWS_REGION` will be available to the AWS CLI.


## Building & Loading the Extension

The DuckDB submodule must be initialized prior to building.
Expand Down
6 changes: 5 additions & 1 deletion src/include/storage/postgres_catalog.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ class PostgresSchemaEntry;
class PostgresCatalog : public Catalog {
public:
explicit PostgresCatalog(AttachedDatabase &db_p, string connection_string, string attach_path,
AccessMode access_mode, string schema_to_load, PostgresIsolationLevel isolation_level);
AccessMode access_mode, string schema_to_load, PostgresIsolationLevel isolation_level,
string secret_name = string());
~PostgresCatalog();

string connection_string;
string attach_path;
string secret_name;
AccessMode access_mode;
PostgresIsolationLevel isolation_level;

Expand All @@ -40,6 +42,8 @@ class PostgresCatalog : public Catalog {

static string GetConnectionString(ClientContext &context, const string &attach_path, string secret_name);

string GetFreshConnectionString(ClientContext &context);

optional_ptr<CatalogEntry> CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) override;

void ScanSchemas(ClientContext &context, std::function<void(SchemaCatalogEntry &)> callback) override;
Expand Down
8 changes: 4 additions & 4 deletions src/include/storage/postgres_connection_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ class PostgresConnectionPool {
PostgresConnectionPool(PostgresCatalog &postgres_catalog, idx_t maximum_connections = DEFAULT_MAX_CONNECTIONS);

public:
bool TryGetConnection(PostgresPoolConnection &connection);
PostgresPoolConnection GetConnection();
bool TryGetConnection(PostgresPoolConnection &connection, optional_ptr<ClientContext> context = nullptr);
PostgresPoolConnection GetConnection(optional_ptr<ClientContext> context = nullptr);
//! Always returns a connection - even if the connection slots are exhausted
PostgresPoolConnection ForceGetConnection();
PostgresPoolConnection ForceGetConnection(optional_ptr<ClientContext> context = nullptr);
void ReturnConnection(PostgresConnection connection);
void SetMaximumConnections(idx_t new_max);

Expand All @@ -61,7 +61,7 @@ class PostgresConnectionPool {
vector<PostgresConnection> connection_cache;

private:
PostgresPoolConnection GetConnectionInternal();
PostgresPoolConnection GetConnectionInternal(optional_ptr<ClientContext> context = nullptr);
};

} // namespace duckdb
6 changes: 6 additions & 0 deletions src/postgres_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ unique_ptr<BaseSecret> CreatePostgresSecretFunction(ClientContext &context, Crea
result->secret_map["port"] = named_param.second.ToString();
} else if (lower_name == "passfile") {
result->secret_map["passfile"] = named_param.second.ToString();
} else if (lower_name == "use_rds_iam_auth") {
result->secret_map["use_rds_iam_auth"] = named_param.second.ToString();
} else if (lower_name == "aws_region") {
result->secret_map["aws_region"] = named_param.second.ToString();
} else {
throw InternalException("Unknown named parameter passed to CreatePostgresSecretFunction: " + lower_name);
}
Expand All @@ -112,6 +116,8 @@ void SetPostgresSecretParameters(CreateSecretFunction &function) {
function.named_parameters["database"] = LogicalType::VARCHAR; // alias for dbname
function.named_parameters["dbname"] = LogicalType::VARCHAR;
function.named_parameters["passfile"] = LogicalType::VARCHAR;
function.named_parameters["use_rds_iam_auth"] = LogicalType::BOOLEAN;
function.named_parameters["aws_region"] = LogicalType::VARCHAR;
}

void SetPostgresNullByteReplacement(ClientContext &context, SetScope scope, Value &parameter) {
Expand Down
4 changes: 2 additions & 2 deletions src/postgres_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ bool PostgresGlobalState::TryOpenNewConnection(ClientContext &context, PostgresL
} else {
// we cannot use the main thread but we haven't initiated ANY scan yet
// we HAVE to open a new connection
lstate.pool_connection = pg_catalog->GetConnectionPool().ForceGetConnection();
lstate.pool_connection = pg_catalog->GetConnectionPool().ForceGetConnection(&context);
lstate.connection = PostgresConnection(lstate.pool_connection.GetConnection().GetConnection());
}
used_main_thread = true;
Expand All @@ -400,7 +400,7 @@ bool PostgresGlobalState::TryOpenNewConnection(ClientContext &context, PostgresL
}

if (pg_catalog) {
if (!pg_catalog->GetConnectionPool().TryGetConnection(lstate.pool_connection)) {
if (!pg_catalog->GetConnectionPool().TryGetConnection(lstate.pool_connection, &context)) {
return false;
}
lstate.connection = PostgresConnection(lstate.pool_connection.GetConnection().GetConnection());
Expand Down
3 changes: 2 additions & 1 deletion src/postgres_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ static unique_ptr<Catalog> PostgresAttach(optional_ptr<StorageExtensionInfo> sto
}
auto connection_string = PostgresCatalog::GetConnectionString(context, attach_path, secret_name);
return make_uniq<PostgresCatalog>(db, std::move(connection_string), std::move(attach_path),
attach_options.access_mode, std::move(schema_to_load), isolation_level);
attach_options.access_mode, std::move(schema_to_load), isolation_level,
std::move(secret_name));
}

static unique_ptr<TransactionManager> PostgresCreateTransactionManager(optional_ptr<StorageExtensionInfo> storage_info,
Expand Down
119 changes: 115 additions & 4 deletions src/storage/postgres_catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
#include "duckdb/parser/parsed_data/create_schema_info.hpp"
#include "duckdb/main/attached_database.hpp"
#include "duckdb/main/secret/secret_manager.hpp"
#include "duckdb/common/printer.hpp"
#include <cstdio>

namespace duckdb {

PostgresCatalog::PostgresCatalog(AttachedDatabase &db_p, string connection_string_p, string attach_path_p,
AccessMode access_mode, string schema_to_load, PostgresIsolationLevel isolation_level)
AccessMode access_mode, string schema_to_load, PostgresIsolationLevel isolation_level,
string secret_name_p)
: Catalog(db_p), connection_string(std::move(connection_string_p)), attach_path(std::move(attach_path_p)),
access_mode(access_mode), isolation_level(isolation_level), schemas(*this, schema_to_load),
connection_pool(*this), default_schema(schema_to_load) {
secret_name(std::move(secret_name_p)), access_mode(access_mode), isolation_level(isolation_level),
schemas(*this, schema_to_load), connection_pool(*this), default_schema(schema_to_load) {
if (default_schema.empty()) {
default_schema = "public";
}
Expand Down Expand Up @@ -72,6 +75,71 @@ unique_ptr<SecretEntry> GetSecret(ClientContext &context, const string &secret_n
return nullptr;
}

//! Generates a temporary RDS IAM authentication token by invoking
//! `aws rds generate-db-auth-token` through the shell.
//!
//! \param hostname   value passed to the CLI's --hostname option
//! \param port       value passed to the CLI's --port option
//! \param username   value passed to the CLI's --username option
//! \param aws_region value passed to --region; when empty the option is omitted
//! \return the CLI output with trailing whitespace removed
//! \throws IOException if the command cannot be started, exits with a non-zero status,
//!         or produces output that cannot be a single token
//!
//! NOTE(review): popen/pclose and the single-quote escaping below assume a POSIX shell -
//! confirm how this is expected to behave on Windows builds.
string GenerateRdsAuthToken(const string &hostname, const string &port, const string &username,
                            const string &aws_region) {
	// Wrap an argument in single quotes; an embedded single quote is emitted as '\''
	// (close quote, escaped quote, reopen quote) so the shell never interprets the value.
	auto escape_shell_arg = [](const string &arg) -> string {
		string escaped = "'";
		for (char c : arg) {
			if (c == '\'') {
				escaped += "'\\''";
			} else {
				escaped += c;
			}
		}
		escaped += "'";
		return escaped;
	};

	string command = "aws rds generate-db-auth-token --hostname " + escape_shell_arg(hostname) + " --port " +
	                 escape_shell_arg(port) + " --username " + escape_shell_arg(username);
	if (!aws_region.empty()) {
		command += " --region " + escape_shell_arg(aws_region);
	}
	// stderr is merged into stdout so the CLI's error text can be reported when the command fails
	command += " 2>&1";

	FILE *pipe = popen(command.c_str(), "r");
	if (!pipe) {
		throw IOException("Failed to execute AWS CLI command to generate RDS auth token. "
		                  "Make sure AWS CLI is installed and configured.");
	}

	string token;
	char buffer[128];
	while (fgets(buffer, sizeof(buffer), pipe) != nullptr) {
		token += buffer;
	}
	int status = pclose(pipe);

	// Strip all trailing whitespace (including a '\r' from CRLF line endings), not just one '\n'
	while (!token.empty()) {
		const char last = token.back();
		if (last != '\n' && last != '\r' && last != ' ' && last != '\t') {
			break;
		}
		token.pop_back();
	}

	if (status != 0) {
		throw IOException("Failed to generate RDS auth token: %s. "
		                  "Make sure AWS CLI is installed, configured, and you have the necessary permissions.",
		                  token.empty() ? "Unknown error" : token.c_str());
	}

	// A successful exit with no output would otherwise be passed on as an empty password
	if (token.empty()) {
		throw IOException("Failed to generate RDS auth token: the AWS CLI returned no output.");
	}
	// Because stderr is merged into the captured output, warnings printed by a successful
	// command end up mixed with the token. Embedded whitespace means the output is not a
	// single token - fail clearly instead of attempting to authenticate with a corrupt password.
	if (token.find_first_of(" \t\r\n") != string::npos) {
		throw IOException("Failed to generate RDS auth token: unexpected output from the AWS CLI: %s",
		                  token.c_str());
	}

	if (PostgresConnection::DebugPrintQueries()) {
		// The token is a credential and is deliberately left out of the debug output
		string debug_msg = StringUtil::Format(
		    "[RDS IAM Auth] Generated auth token for hostname=%s, port=%s, username=%s, region=%s", hostname.c_str(),
		    port.c_str(), username.c_str(), aws_region.empty() ? "(default)" : aws_region.c_str());
		Printer::Print(debug_msg);
	}

	return token;
}

string PostgresCatalog::GetConnectionString(ClientContext &context, const string &attach_path, string secret_name) {
// if no secret is specified we default to the unnamed postgres secret, if it exists
string connection_string = attach_path;
Expand All @@ -87,8 +155,47 @@ string PostgresCatalog::GetConnectionString(ClientContext &context, const string
const auto &kv_secret = dynamic_cast<const KeyValueSecret &>(*secret_entry->secret);
string new_connection_info;

Value use_rds_iam_auth_val = kv_secret.TryGetValue("use_rds_iam_auth");
bool use_rds_iam_auth = false;
if (!use_rds_iam_auth_val.IsNull()) {
use_rds_iam_auth = BooleanValue::Get(use_rds_iam_auth_val);
}

new_connection_info += AddConnectionOption(kv_secret, "user");
new_connection_info += AddConnectionOption(kv_secret, "password");

if (use_rds_iam_auth) {
Value host_val = kv_secret.TryGetValue("host");
Value port_val = kv_secret.TryGetValue("port");
Value user_val = kv_secret.TryGetValue("user");
Value aws_region_val = kv_secret.TryGetValue("aws_region");

if (host_val.IsNull() || port_val.IsNull() || user_val.IsNull()) {
throw BinderException(
"RDS IAM authentication requires 'host', 'port', and 'user' to be set in the secret");
}

string hostname = host_val.ToString();
string port = port_val.ToString();
string username = user_val.ToString();
string aws_region;


if (!aws_region_val.IsNull()) {
aws_region = aws_region_val.ToString();
}

try {
string rds_token = GenerateRdsAuthToken(hostname, port, username, aws_region);
new_connection_info += "password=";
new_connection_info += EscapeConnectionString(rds_token);
new_connection_info += " ";
} catch (const std::exception &e) {
throw BinderException("Failed to generate RDS auth token: %s", e.what());
}
} else {
new_connection_info += AddConnectionOption(kv_secret, "password");
}

new_connection_info += AddConnectionOption(kv_secret, "host");
new_connection_info += AddConnectionOption(kv_secret, "port");
new_connection_info += AddConnectionOption(kv_secret, "dbname");
Expand All @@ -102,6 +209,10 @@ string PostgresCatalog::GetConnectionString(ClientContext &context, const string
return connection_string;
}

//! Rebuilds the connection string for this catalog by re-running GetConnectionString
//! with the attach path and secret name stored on the catalog.
string PostgresCatalog::GetFreshConnectionString(ClientContext &context) {
	string refreshed_connection_string = GetConnectionString(context, attach_path, secret_name);
	return refreshed_connection_string;
}

PostgresCatalog::~PostgresCatalog() = default;

void PostgresCatalog::Initialize(bool load_builtin) {
Expand Down
21 changes: 13 additions & 8 deletions src/storage/postgres_connection_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ PostgresConnectionPool::PostgresConnectionPool(PostgresCatalog &postgres_catalog
: postgres_catalog(postgres_catalog), active_connections(0), maximum_connections(maximum_connections_p) {
}

PostgresPoolConnection PostgresConnectionPool::GetConnectionInternal() {
PostgresPoolConnection PostgresConnectionPool::GetConnectionInternal(optional_ptr<ClientContext> context) {
active_connections++;
// check if we have any cached connections left
if (!connection_cache.empty()) {
Expand All @@ -54,21 +54,26 @@ PostgresPoolConnection PostgresConnectionPool::GetConnectionInternal() {
}

// no cached connections left but there is space to open a new one - open it
// If we have a context, generate a fresh connection string (with new RDS token if needed)
string connection_string_to_use = postgres_catalog.connection_string;
if (context) {
connection_string_to_use = postgres_catalog.GetFreshConnectionString(*context);
}
return PostgresPoolConnection(
this, PostgresConnection::Open(postgres_catalog.connection_string, postgres_catalog.attach_path));
this, PostgresConnection::Open(connection_string_to_use, postgres_catalog.attach_path));
}

PostgresPoolConnection PostgresConnectionPool::ForceGetConnection() {
PostgresPoolConnection PostgresConnectionPool::ForceGetConnection(optional_ptr<ClientContext> context) {
lock_guard<mutex> l(connection_lock);
return GetConnectionInternal();
return GetConnectionInternal(context);
}

bool PostgresConnectionPool::TryGetConnection(PostgresPoolConnection &connection) {
bool PostgresConnectionPool::TryGetConnection(PostgresPoolConnection &connection, optional_ptr<ClientContext> context) {
lock_guard<mutex> l(connection_lock);
if (active_connections >= maximum_connections) {
return false;
}
connection = GetConnectionInternal();
connection = GetConnectionInternal(context);
return true;
}

Expand All @@ -79,9 +84,9 @@ void PostgresConnectionPool::PostgresSetConnectionCache(ClientContext &context,
pg_use_connection_cache = BooleanValue::Get(parameter);
}

PostgresPoolConnection PostgresConnectionPool::GetConnection() {
PostgresPoolConnection PostgresConnectionPool::GetConnection(optional_ptr<ClientContext> context) {
PostgresPoolConnection result;
if (!TryGetConnection(result)) {
if (!TryGetConnection(result, context)) {
throw IOException(
"Failed to get connection from PostgresConnectionPool - maximum connection count exceeded (%llu/%llu max)",
active_connections, maximum_connections);
Expand Down
2 changes: 1 addition & 1 deletion src/storage/postgres_transaction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ PostgresTransaction::PostgresTransaction(PostgresCatalog &postgres_catalog, Tran
ClientContext &context)
: Transaction(manager, context), access_mode(postgres_catalog.access_mode),
isolation_level(postgres_catalog.isolation_level) {
connection = postgres_catalog.GetConnectionPool().GetConnection();
connection = postgres_catalog.GetConnectionPool().GetConnection(&context);
}

PostgresTransaction::~PostgresTransaction() = default;
Expand Down