Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions src/include/storage/postgres_connection_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,22 @@
#include "duckdb/common/mutex.hpp"
#include "duckdb/common/optional_ptr.hpp"
#include "postgres_connection.hpp"
#include <chrono>
#include <thread>
#include <condition_variable>
#include <atomic>

namespace duckdb {
class PostgresCatalog;
class PostgresConnectionPool;

class PostgresPoolConnection {
public:
using time_point_t = std::chrono::steady_clock::time_point;

PostgresPoolConnection();
PostgresPoolConnection(optional_ptr<PostgresConnectionPool> pool, PostgresConnection connection);
PostgresPoolConnection(optional_ptr<PostgresConnectionPool> pool, PostgresConnection connection,
time_point_t created_at);
~PostgresPoolConnection();
// disable copy constructors
PostgresPoolConnection(const PostgresPoolConnection &other) = delete;
Expand All @@ -35,32 +42,57 @@ class PostgresPoolConnection {
private:
optional_ptr<PostgresConnectionPool> pool;
PostgresConnection connection;
time_point_t created_at;
};

class PostgresConnectionPool {
public:
using steady_clock = std::chrono::steady_clock;
using steady_time_point = steady_clock::time_point;

static constexpr const idx_t DEFAULT_MAX_CONNECTIONS = 64;

PostgresConnectionPool(PostgresCatalog &postgres_catalog, idx_t maximum_connections = DEFAULT_MAX_CONNECTIONS);
~PostgresConnectionPool();

public:
bool TryGetConnection(PostgresPoolConnection &connection);
PostgresPoolConnection GetConnection();
//! Always returns a connection - even if the connection slots are exhausted
PostgresPoolConnection ForceGetConnection();
void ReturnConnection(PostgresConnection connection);
void ReturnConnection(PostgresConnection connection, steady_time_point created_at);
void SetMaximumConnections(idx_t new_max);
void SetMaxLifetime(idx_t seconds);
void SetIdleTimeout(idx_t seconds);

static void PostgresSetConnectionCache(ClientContext &context, SetScope scope, Value &parameter);

private:
struct CachedConnection {
PostgresConnection connection;
steady_time_point created_at;
steady_time_point returned_at;
};

PostgresCatalog &postgres_catalog;
mutex connection_lock;
idx_t active_connections;
idx_t maximum_connections;
vector<PostgresConnection> connection_cache;
vector<CachedConnection> connection_cache;

idx_t max_lifetime_seconds = 0;
idx_t idle_timeout_seconds = 0;

std::thread reaper_thread;
std::condition_variable reaper_cv;
std::atomic<bool> shutdown {false};

bool IsExpired(const CachedConnection &entry, steady_time_point now) const;
void ReaperLoop();
void StartReaperIfNeeded(unique_lock<mutex> &lock);
void StopReaper(unique_lock<mutex> &lock);
void UpdateTimeoutSetting(idx_t &field, idx_t seconds);

private:
PostgresPoolConnection GetConnectionInternal(unique_lock<mutex> &lock);
};

Expand Down
39 changes: 39 additions & 0 deletions src/postgres_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,39 @@ static void SetPostgresConnectionLimit(ClientContext &context, SetScope scope, V
config.SetOption("pg_connection_limit", parameter);
}

using PoolTimeoutSetter = void (PostgresConnectionPool::*)(idx_t);

static void SetPostgresPoolTimeout(ClientContext &context, SetScope scope, Value &parameter,
const char *option_name, PoolTimeoutSetter setter) {
if (parameter.IsNull()) {
throw BinderException("Cannot be set to NULL");
}
if (scope == SetScope::LOCAL) {
throw InvalidInputException("%s can only be set globally", option_name);
}
auto databases = DatabaseManager::Get(context).GetDatabases(context);
for (auto &db_ref : databases) {
auto &db = *db_ref;
auto &catalog = db.GetCatalog();
if (catalog.GetCatalogType() != "postgres") {
continue;
}
(catalog.Cast<PostgresCatalog>().GetConnectionPool().*setter)(UBigIntValue::Get(parameter));
}
auto &config = DBConfig::GetConfig(context);
config.SetOption(option_name, parameter);
}

static void SetPostgresMaxLifetime(ClientContext &context, SetScope scope, Value &parameter) {
SetPostgresPoolTimeout(context, scope, parameter, "pg_connection_max_lifetime",
&PostgresConnectionPool::SetMaxLifetime);
}

static void SetPostgresIdleTimeout(ClientContext &context, SetScope scope, Value &parameter) {
SetPostgresPoolTimeout(context, scope, parameter, "pg_connection_idle_timeout",
&PostgresConnectionPool::SetIdleTimeout);
}

static void SetPostgresDebugQueryPrint(ClientContext &context, SetScope scope, Value &parameter) {
PostgresConnection::DebugSetPrintQueries(BooleanValue::Get(parameter));
}
Expand Down Expand Up @@ -172,6 +205,12 @@ static void LoadInternal(ExtensionLoader &loader) {
config.AddExtensionOption("pg_connection_limit", "The maximum amount of concurrent Postgres connections",
LogicalType::UBIGINT, Value::UBIGINT(PostgresConnectionPool::DEFAULT_MAX_CONNECTIONS),
SetPostgresConnectionLimit);
config.AddExtensionOption("pg_connection_max_lifetime",
"Maximum lifetime of a pooled connection in seconds (0 = disabled)",
LogicalType::UBIGINT, Value::UBIGINT(0), SetPostgresMaxLifetime);
config.AddExtensionOption("pg_connection_idle_timeout",
"Maximum idle time of a pooled connection in seconds before it is closed (0 = disabled)",
LogicalType::UBIGINT, Value::UBIGINT(0), SetPostgresIdleTimeout);
config.AddExtensionOption(
"pg_array_as_varchar", "Read Postgres arrays as varchar - enables reading mixed dimensional arrays",
LogicalType::BOOLEAN, Value::BOOLEAN(false), PostgresClearCacheFunction::ClearCacheOnSetting);
Expand Down
8 changes: 8 additions & 0 deletions src/storage/postgres_catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ PostgresCatalog::PostgresCatalog(AttachedDatabase &db_p, string connection_strin
if (db_instance.TryGetCurrentSetting("pg_connection_limit", connection_limit)) {
connection_pool.SetMaximumConnections(UBigIntValue::Get(connection_limit));
}
Value max_lifetime;
if (db_instance.TryGetCurrentSetting("pg_connection_max_lifetime", max_lifetime)) {
connection_pool.SetMaxLifetime(UBigIntValue::Get(max_lifetime));
}
Value idle_timeout;
if (db_instance.TryGetCurrentSetting("pg_connection_idle_timeout", idle_timeout)) {
connection_pool.SetIdleTimeout(UBigIntValue::Get(idle_timeout));
}

auto connection = connection_pool.GetConnection();
this->version = connection.GetConnection().GetPostgresVersion(context);
Expand Down
147 changes: 134 additions & 13 deletions src/storage/postgres_connection_pool.cpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
#include "storage/postgres_connection_pool.hpp"
#include "storage/postgres_catalog.hpp"
#include <algorithm>

namespace duckdb {
static bool pg_use_connection_cache = true;

PostgresPoolConnection::PostgresPoolConnection() : pool(nullptr) {
PostgresPoolConnection::PostgresPoolConnection() : pool(nullptr), created_at() {
}

PostgresPoolConnection::PostgresPoolConnection(optional_ptr<PostgresConnectionPool> pool,
PostgresConnection connection_p)
: pool(pool), connection(std::move(connection_p)) {
PostgresConnection connection_p, time_point_t created_at_p)
: pool(pool), connection(std::move(connection_p)), created_at(created_at_p) {
}

PostgresPoolConnection::~PostgresPoolConnection() {
if (pool) {
pool->ReturnConnection(std::move(connection));
pool->ReturnConnection(std::move(connection), created_at);
}
}

PostgresPoolConnection::PostgresPoolConnection(PostgresPoolConnection &&other) noexcept {
std::swap(pool, other.pool);
std::swap(connection, other.connection);
std::swap(created_at, other.created_at);
}

PostgresPoolConnection &PostgresPoolConnection::operator=(PostgresPoolConnection &&other) noexcept {
std::swap(pool, other.pool);
std::swap(connection, other.connection);
std::swap(created_at, other.created_at);
return *this;
}

Expand All @@ -44,18 +47,102 @@ PostgresConnectionPool::PostgresConnectionPool(PostgresCatalog &postgres_catalog
: postgres_catalog(postgres_catalog), active_connections(0), maximum_connections(maximum_connections_p) {
}

PostgresConnectionPool::~PostgresConnectionPool() {
unique_lock<mutex> l(connection_lock);
StopReaper(l);
}

bool PostgresConnectionPool::IsExpired(const CachedConnection &entry, steady_time_point now) const {
if (max_lifetime_seconds > 0) {
auto age = std::chrono::duration_cast<std::chrono::seconds>(now - entry.created_at).count();
if (static_cast<idx_t>(age) >= max_lifetime_seconds) {
return true;
}
}
if (idle_timeout_seconds > 0) {
auto idle = std::chrono::duration_cast<std::chrono::seconds>(now - entry.returned_at).count();
if (static_cast<idx_t>(idle) >= idle_timeout_seconds) {
return true;
}
}
return false;
}

void PostgresConnectionPool::ReaperLoop() {
unique_lock<mutex> l(connection_lock);
while (!shutdown.load()) {
idx_t sleep_seconds = 30;
if (max_lifetime_seconds > 0 && idle_timeout_seconds > 0) {
sleep_seconds = std::min(max_lifetime_seconds, idle_timeout_seconds);
} else if (max_lifetime_seconds > 0) {
sleep_seconds = max_lifetime_seconds;
} else if (idle_timeout_seconds > 0) {
sleep_seconds = idle_timeout_seconds;
}
sleep_seconds = std::max<idx_t>(1, sleep_seconds / 2);
sleep_seconds = std::min<idx_t>(60, sleep_seconds);

reaper_cv.wait_for(l, std::chrono::seconds(sleep_seconds), [this]() { return shutdown.load(); });

if (shutdown.load()) {
break;
}

auto now = steady_clock::now();
auto it = std::partition(connection_cache.begin(), connection_cache.end(),
[this, now](const CachedConnection &e) { return !IsExpired(e, now); });
vector<CachedConnection> expired(std::make_move_iterator(it), std::make_move_iterator(connection_cache.end()));
connection_cache.erase(it, connection_cache.end());
// release lock while destroying expired connections (PQfinish may block)
l.unlock();
expired.clear();
l.lock();
}
}

void PostgresConnectionPool::StartReaperIfNeeded(unique_lock<mutex> &lock) {
if (max_lifetime_seconds == 0 && idle_timeout_seconds == 0) {
return;
}
if (reaper_thread.joinable()) {
return;
}
shutdown.store(false);
reaper_thread = std::thread(&PostgresConnectionPool::ReaperLoop, this);
}

void PostgresConnectionPool::StopReaper(unique_lock<mutex> &lock) {
if (!reaper_thread.joinable()) {
return;
}
shutdown.store(true);
reaper_cv.notify_all();
lock.unlock();
reaper_thread.join();
lock.lock();
}

PostgresPoolConnection PostgresConnectionPool::GetConnectionInternal(unique_lock<mutex> &lock) {
active_connections++;
// check if we have any cached connections left
if (!connection_cache.empty()) {
auto connection = PostgresPoolConnection(this, std::move(connection_cache.back()));
auto now = steady_clock::now();
while (!connection_cache.empty()) {
auto cached = std::move(connection_cache.back());
connection_cache.pop_back();
return connection;
if (IsExpired(cached, now)) {
continue;
}
return PostgresPoolConnection(this, std::move(cached.connection), cached.created_at);
}
// no cached connections left but there is space to open a new one - open it after releasing the cache lock
lock.unlock();
return PostgresPoolConnection(
this, PostgresConnection::Open(postgres_catalog.connection_string, postgres_catalog.attach_path));
auto created = steady_clock::now();
try {
return PostgresPoolConnection(
this, PostgresConnection::Open(postgres_catalog.connection_string, postgres_catalog.attach_path), created);
} catch (...) {
lock.lock();
active_connections--;
throw;
}
}

PostgresPoolConnection PostgresConnectionPool::ForceGetConnection() {
Expand Down Expand Up @@ -89,7 +176,7 @@ PostgresPoolConnection PostgresConnectionPool::GetConnection() {
return result;
}

void PostgresConnectionPool::ReturnConnection(PostgresConnection connection) {
void PostgresConnectionPool::ReturnConnection(PostgresConnection connection, steady_time_point created_at) {
unique_lock<mutex> l(connection_lock);
if (active_connections <= 0) {
throw InternalException("PostgresConnectionPool::ReturnConnection called but active_connections is 0");
Expand All @@ -99,6 +186,16 @@ void PostgresConnectionPool::ReturnConnection(PostgresConnection connection) {
active_connections--;
return;
}

// check if the connection has exceeded its max lifetime before doing anything else
if (max_lifetime_seconds > 0) {
auto age = std::chrono::duration_cast<std::chrono::seconds>(steady_clock::now() - created_at).count();
if (static_cast<idx_t>(age) >= max_lifetime_seconds) {
active_connections--;
return;
}
}

// we want to cache the connection
// check if the underlying connection is still usable
// avoid holding the lock while doing this
Expand All @@ -116,6 +213,7 @@ void PostgresConnectionPool::ReturnConnection(PostgresConnection connection) {
if (!connection_is_bad && PQtransactionStatus(pg_con) != PQTRANS_IDLE) {
connection_is_bad = true;
}

// lock and return the connection
l.lock();
active_connections--;
Expand All @@ -128,7 +226,11 @@ void PostgresConnectionPool::ReturnConnection(PostgresConnection connection) {
// immediately
return;
}
connection_cache.push_back(std::move(connection));
CachedConnection cached;
cached.connection = std::move(connection);
cached.created_at = created_at;
cached.returned_at = steady_clock::now();
connection_cache.push_back(std::move(cached));
}

void PostgresConnectionPool::SetMaximumConnections(idx_t new_max) {
Expand All @@ -146,4 +248,23 @@ void PostgresConnectionPool::SetMaximumConnections(idx_t new_max) {
maximum_connections = new_max;
}

void PostgresConnectionPool::UpdateTimeoutSetting(idx_t &field, idx_t seconds) {
unique_lock<mutex> l(connection_lock);
field = seconds;
if (max_lifetime_seconds == 0 && idle_timeout_seconds == 0) {
StopReaper(l);
} else {
StartReaperIfNeeded(l);
reaper_cv.notify_all();
}
}

void PostgresConnectionPool::SetMaxLifetime(idx_t seconds) {
UpdateTimeoutSetting(max_lifetime_seconds, seconds);
}

void PostgresConnectionPool::SetIdleTimeout(idx_t seconds) {
UpdateTimeoutSetting(idle_timeout_seconds, seconds);
}

} // namespace duckdb
Loading
Loading