diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5d47e2718..72712f54e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,6 +13,7 @@ add_library( postgres_execute.cpp postgres_extension.cpp postgres_filter_pushdown.cpp + postgres_parameters.cpp postgres_query.cpp postgres_scanner.cpp postgres_storage.cpp diff --git a/src/include/postgres_connection.hpp b/src/include/postgres_connection.hpp index 69787217c..4f8f5d8fd 100644 --- a/src/include/postgres_connection.hpp +++ b/src/include/postgres_connection.hpp @@ -9,6 +9,7 @@ #pragma once #include "postgres_utils.hpp" +#include "postgres_parameters.hpp" #include "postgres_result.hpp" #include "duckdb/common/shared_ptr.hpp" @@ -45,9 +46,13 @@ class PostgresConnection { public: static PostgresConnection Open(const string &dsn, const string &attach_path); - void Execute(optional_ptr context, const string &query); - unique_ptr TryQuery(optional_ptr context, const string &query, optional_ptr error_message = nullptr); - unique_ptr Query(optional_ptr context, const string &query); + void Execute(optional_ptr context, const string &query, + const PostgresParameters ¶ms = PostgresParameters()); + unique_ptr TryQuery(optional_ptr context, const string &query, + optional_ptr error_message = nullptr, + const PostgresParameters ¶ms = PostgresParameters()); + unique_ptr Query(optional_ptr context, const string &query, + const PostgresParameters ¶ms = PostgresParameters()); //! Submits a set of queries to be executed in the connection. vector> ExecuteQueries(ClientContext &context, const string &queries); @@ -87,7 +92,8 @@ class PostgresConnection { static bool DebugPrintQueries(); private: - PGresult *PQExecute(optional_ptr context, const string &query); + PGresult *PQExecute(optional_ptr context, const string &query, + const PostgresParameters ¶ms = PostgresParameters()); shared_ptr connection; string dsn; diff --git a/src/include/postgres_logging.hpp b/src/include/postgres_logging.hpp index 5c038efca..1fc9650b9 100644 --- a/src/include/postgres_logging.hpp +++ b/src/include/postgres_logging.hpp @@ -17,4 +17,4 @@ class PostgresQueryLogType : public LogType { static LogicalType GetLogType(); }; -} // namespace +} // namespace duckdb diff --git a/src/include/postgres_parameters.hpp b/src/include/postgres_parameters.hpp new file mode 100644 index 000000000..62c948ac9 --- /dev/null +++ b/src/include/postgres_parameters.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// postgres_parameters.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.hpp" +#include +#include "postgres_version.hpp" + +namespace duckdb { + +class PostgresParameters { + vector types; + vector values; + vector> copied_values; + vector value_ptrs; + vector lengths; + vector formats; + +public: + PostgresParameters() { + } + + PostgresParameters(vector types_p, vector values_p); + + bool Empty() const { + return types.empty(); + } + + int Count() const { + return static_cast(types.size()); + } + + const Oid *Types() const { + return types.data(); + } + + const char *const *Values() const { + return value_ptrs.data(); + } + + const int *Lengths() const { + return lengths.data(); + } + + const int *Formats() const { + return formats.data(); + } +}; + +} // namespace duckdb diff --git a/src/include/postgres_scanner.hpp b/src/include/postgres_scanner.hpp index 0f9cda70f..3e0162cdc 100644 --- a/src/include/postgres_scanner.hpp +++ b/src/include/postgres_scanner.hpp @@ -11,6 +11,7 @@ #include "duckdb.hpp" #include "postgres_utils.hpp" #include "postgres_connection.hpp" +#include "postgres_parameters.hpp" #include "storage/postgres_connection_pool.hpp" namespace duckdb { @@ -29,6 +30,7 @@ struct PostgresBindData : public FunctionData { string schema_name; string table_name; string sql; + PostgresParameters params; string limit; idx_t pages_approx = 0; diff --git a/src/include/storage/postgres_catalog.hpp b/src/include/storage/postgres_catalog.hpp index 957be9bd5..9530835b7 100644 --- a/src/include/storage/postgres_catalog.hpp +++ b/src/include/storage/postgres_catalog.hpp @@ -21,7 +21,8 @@ class PostgresSchemaEntry; class PostgresCatalog : public Catalog { public: explicit PostgresCatalog(AttachedDatabase &db_p, string connection_string, string attach_path, - AccessMode access_mode, string schema_to_load, PostgresIsolationLevel isolation_level, ClientContext &context); + AccessMode access_mode, string schema_to_load, PostgresIsolationLevel isolation_level, + ClientContext &context); ~PostgresCatalog(); string connection_string; diff --git a/src/include/storage/postgres_catalog_set.hpp b/src/include/storage/postgres_catalog_set.hpp index 0c24441a4..d454c591d 100644 --- a/src/include/storage/postgres_catalog_set.hpp +++ b/src/include/storage/postgres_catalog_set.hpp @@ -25,7 +25,8 @@ class PostgresCatalogSet { optional_ptr GetEntry(ClientContext &context, PostgresTransaction &transaction, const string &name); void DropEntry(PostgresTransaction &transaction, DropInfo &info); - void Scan(ClientContext& context, PostgresTransaction &transaction, const std::function &callback); + void Scan(ClientContext &context, PostgresTransaction &transaction, + const std::function &callback); virtual optional_ptr CreateEntry(PostgresTransaction &transaction, shared_ptr entry); void ClearEntries(); virtual bool SupportReload() const { diff --git a/src/include/storage/postgres_delete.hpp b/src/include/storage/postgres_delete.hpp index 4e878de01..5b074ff40 100644 --- a/src/include/storage/postgres_delete.hpp +++ b/src/include/storage/postgres_delete.hpp @@ -22,7 +22,8 @@ class PostgresDelete : public PhysicalOperator { public: // Source interface - SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/include/storage/postgres_index.hpp b/src/include/storage/postgres_index.hpp index 76593a023..556924f20 100644 --- a/src/include/storage/postgres_index.hpp +++ b/src/include/storage/postgres_index.hpp @@ -24,7 +24,8 @@ class PostgresCreateIndex : public PhysicalOperator { public: // Source interface - SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/include/storage/postgres_insert.hpp b/src/include/storage/postgres_insert.hpp index 4bd0e2991..3d20e97f5 100644 --- a/src/include/storage/postgres_insert.hpp +++ b/src/include/storage/postgres_insert.hpp @@ -35,7 +35,8 @@ class PostgresInsert : public PhysicalOperator { public: // Source interface - SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/include/storage/postgres_table_set.hpp b/src/include/storage/postgres_table_set.hpp index a5993ba35..fcdff5428 100644 --- a/src/include/storage/postgres_table_set.hpp +++ b/src/include/storage/postgres_table_set.hpp @@ -26,8 +26,8 @@ class PostgresTableSet : public PostgresInSchemaSet { static unique_ptr GetTableInfo(PostgresTransaction &transaction, PostgresSchemaEntry &schema, const string &table_name); - static unique_ptr GetTableInfo(ClientContext &context, PostgresConnection &connection, const string &schema_name, - const string &table_name); + static unique_ptr GetTableInfo(ClientContext &context, PostgresConnection &connection, + const string &schema_name, const string &table_name); optional_ptr ReloadEntry(PostgresTransaction &transaction, const string &table_name) override; void AlterTable(ClientContext &context, PostgresTransaction &transaction, AlterTableInfo &info); diff --git a/src/include/storage/postgres_update.hpp b/src/include/storage/postgres_update.hpp index dd25e7594..1d567b978 100644 --- a/src/include/storage/postgres_update.hpp +++ b/src/include/storage/postgres_update.hpp @@ -29,7 +29,8 @@ class PostgresUpdate : public PhysicalOperator { public: // Source interface - SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/postgres_connection.cpp b/src/postgres_connection.cpp index 7c890ef54..9b11f370c 100644 --- a/src/postgres_connection.cpp +++ b/src/postgres_connection.cpp @@ -61,26 +61,42 @@ static bool ResultHasError(PGresult *result) { } } -PGresult *PostgresConnection::PQExecute(optional_ptr context, const string &query) { +PGresult *PostgresConnection::PQExecute(optional_ptr context, const string &query, + const PostgresParameters ¶ms) { if (PostgresConnection::DebugPrintQueries()) { Printer::Print(query + "\n"); } - int64_t start_time = std::chrono::time_point_cast(std::chrono::steady_clock::now()) - .time_since_epoch() - .count(); - auto res = PQexec(GetConn(), query.c_str()); - int64_t end_time = std::chrono::time_point_cast(std::chrono::steady_clock::now()) - .time_since_epoch() - .count(); + int64_t start_time = std::chrono::time_point_cast(std::chrono::steady_clock::now()) + .time_since_epoch() + .count(); + + PGconn *conn = GetConn(); + PGresult *res = nullptr; + + if (params.Empty()) { + res = PQexec(GetConn(), query.c_str()); + } else { + // Unlike PQexec, PQexecParams allows at most one SQL command in the given string. + int format = 0; // text format + res = PQexecParams(conn, query.c_str(), params.Count(), params.Types(), params.Values(), params.Lengths(), + params.Formats(), format); + } + + int64_t end_time = std::chrono::time_point_cast(std::chrono::steady_clock::now()) + .time_since_epoch() + .count(); if (context) { DUCKDB_LOG(*context, PostgresQueryLogType, query, end_time - start_time); } return res; } -unique_ptr PostgresConnection::TryQuery(optional_ptr context, const string &query, optional_ptr error_message) { +unique_ptr PostgresConnection::TryQuery(optional_ptr context, const string &query, + optional_ptr error_message, + const PostgresParameters ¶ms) { lock_guard guard(connection->connection_lock); - auto result = PQExecute(context, query.c_str()); + auto result = PQExecute(context, query.c_str(), params); + if (ResultHasError(result)) { if (error_message) { *error_message = StringUtil::Format("Failed to execute query \"" + query + @@ -92,26 +108,28 @@ unique_ptr PostgresConnection::TryQuery(optional_ptr(result); } -unique_ptr PostgresConnection::Query(optional_ptr context, const string &query) { +unique_ptr PostgresConnection::Query(optional_ptr context, const string &query, + const PostgresParameters ¶ms) { string error_msg; - auto result = TryQuery(context, query, &error_msg); + auto result = TryQuery(context, query, &error_msg, params); if (!result) { throw std::runtime_error(error_msg); } return result; } -void PostgresConnection::Execute(optional_ptr context, const string &query) { - Query(context, query); +void PostgresConnection::Execute(optional_ptr context, const string &query, + const PostgresParameters ¶ms) { + Query(context, query, params); } vector> PostgresConnection::ExecuteQueries(ClientContext &context, const string &queries) { if (PostgresConnection::DebugPrintQueries()) { Printer::Print(queries + "\n"); } - int64_t start_time = std::chrono::time_point_cast(std::chrono::steady_clock::now()) - .time_since_epoch() - .count(); + int64_t start_time = std::chrono::time_point_cast(std::chrono::steady_clock::now()) + .time_since_epoch() + .count(); auto res = PQsendQuery(GetConn(), queries.c_str()); if (res == 0) { throw std::runtime_error("Failed to execute query \"" + queries + "\": " + string(PQerrorMessage(GetConn()))); @@ -132,9 +150,9 @@ vector> PostgresConnection::ExecuteQueries(ClientCont } results.push_back(std::move(result)); } - int64_t end_time = std::chrono::time_point_cast(std::chrono::steady_clock::now()) - .time_since_epoch() - .count(); + int64_t end_time = std::chrono::time_point_cast(std::chrono::steady_clock::now()) + .time_since_epoch() + .count(); DUCKDB_LOG(context, PostgresQueryLogType, queries, end_time - start_time); return results; } diff --git a/src/postgres_extension.cpp b/src/postgres_extension.cpp index dce65b121..c7c67c462 100644 --- a/src/postgres_extension.cpp +++ b/src/postgres_extension.cpp @@ -198,8 +198,8 @@ static void LoadInternal(ExtensionLoader &loader) { connection->registered_state->Insert("postgres_extension", make_shared_ptr()); } -auto &instance = loader.GetDatabaseInstance(); -auto &log_manager = instance.GetLogManager(); + auto &instance = loader.GetDatabaseInstance(); + auto &log_manager = instance.GetLogManager(); log_manager.RegisterLogType(make_uniq()); } diff --git a/src/postgres_logging.cpp b/src/postgres_logging.cpp index 55ab8ab66..2701382e2 100644 --- a/src/postgres_logging.cpp +++ b/src/postgres_logging.cpp @@ -16,20 +16,20 @@ constexpr LogLevel PostgresQueryLogType::LEVEL; // PostgresQueryLogType //===--------------------------------------------------------------------===// string PostgresQueryLogType::ConstructLogMessage(const string &str, int64_t duration) { - child_list_t child_list = { - {"query", str}, - {"duration_ms", duration}, - }; + child_list_t child_list = { + {"query", str}, + {"duration_ms", duration}, + }; - return Value::STRUCT(std::move(child_list)).ToString(); + return Value::STRUCT(std::move(child_list)).ToString(); } LogicalType PostgresQueryLogType::GetLogType() { - child_list_t child_list = { - {"query", LogicalType::VARCHAR}, - {"duration_ms", LogicalType::BIGINT}, - }; - return LogicalType::STRUCT(child_list); + child_list_t child_list = { + {"query", LogicalType::VARCHAR}, + {"duration_ms", LogicalType::BIGINT}, + }; + return LogicalType::STRUCT(child_list); } -} // namespace +} // namespace duckdb diff --git a/src/postgres_parameters.cpp b/src/postgres_parameters.cpp new file mode 100644 index 000000000..4fb57b88b --- /dev/null +++ b/src/postgres_parameters.cpp @@ -0,0 +1,125 @@ +#include "postgres_parameters.hpp" + +#include "duckdb.hpp" + +#include + +#include "postgres_conversion.hpp" + +namespace duckdb { + +static const int FORMAT_TEXT = 0; +static const int FORMAT_BINARY = 1; + +struct Param { + const char *ptr = nullptr; + int length = 0; + int format = FORMAT_TEXT; + + Param() { + } + + Param(const char *ptr_in, int length_in, int format_in) : ptr(ptr_in), length(length_in), format(format_in) { + } +}; + +static Param CreateVarcharParam(Value &value) { + const string &str = StringValue::Get(value); + return Param(str.c_str(), static_cast(str.length()), FORMAT_TEXT); +} + +template +static Param CreateIntParam(INT_TYPE num, vector ©_holder) { + copy_holder.resize(sizeof(INT_TYPE)); + memcpy(copy_holder.data(), &num, sizeof(INT_TYPE)); + return Param(copy_holder.data(), sizeof(INT_TYPE), FORMAT_BINARY); +} + +static uint32_t FloatHtonl(float num) { + std::array arr; + memcpy(arr.data(), &num, sizeof(float)); + uint32_t int_num = *reinterpret_cast(arr.data()); + return htonl(int_num); +} + +static uint64_t DoubleHtonll(double num) { + std::array arr; + memcpy(arr.data(), &num, sizeof(double)); + uint64_t int_num = *reinterpret_cast(arr.data()); + return htonll(int_num); +} + +static Param CreateParam(Value &value, vector ©_holder) { + if (value.IsNull()) { + return Param(nullptr, 0, FORMAT_BINARY); + } + + switch (value.type().id()) { + case LogicalTypeId::VARCHAR: + return CreateVarcharParam(value); + case LogicalTypeId::TINYINT: { + uint16_t num = static_cast(TinyIntValue::Get(value)); + return CreateIntParam(htons(num), copy_holder); + } + case LogicalTypeId::UTINYINT: { + uint16_t num = static_cast(UTinyIntValue::Get(value)); + return CreateIntParam(htons(num), copy_holder); + } + case LogicalTypeId::SMALLINT: { + uint16_t num = static_cast(SmallIntValue::Get(value)); + return CreateIntParam(htons(num), copy_holder); + } + case LogicalTypeId::USMALLINT: { + uint16_t num = static_cast(USmallIntValue::Get(value)); + return CreateIntParam(htonl(num), copy_holder); + } + case LogicalTypeId::INTEGER: { + uint32_t num = static_cast(IntegerValue::Get(value)); + return CreateIntParam(htonl(num), copy_holder); + } + case LogicalTypeId::UINTEGER: { + uint64_t num = static_cast(UIntegerValue::Get(value)); + return CreateIntParam(htonll(num), copy_holder); + } + case LogicalTypeId::BIGINT: { + uint64_t num = static_cast(BigIntValue::Get(value)); + return CreateIntParam(htonll(num), copy_holder); + } + case LogicalTypeId::FLOAT: { + float num = FloatValue::Get(value); + return CreateIntParam(FloatHtonl(num), copy_holder); + } + case LogicalTypeId::DOUBLE: { + double num = DoubleValue::Get(value); + return CreateIntParam(DoubleHtonll(num), copy_holder); + } + default: + throw BinderException("Unsupported parameter type: %s", value.type().ToString().c_str()); + } +} + +PostgresParameters::PostgresParameters(vector types_p, vector values_p) + : types(std::move(types_p)), values(std::move(values_p)) { + idx_t count = types.size(); + if (values.size() != count) { + throw BinderException("Parameters count mismatch, types count: %zu, values count: %zu", count, values.size()); + } + + copied_values.resize(count); + value_ptrs.resize(count); + lengths.resize(count); + formats.resize(count); + + for (idx_t i = 0; i < types.size(); i++) { + Value &val = values[i]; + vector ©_holder = copied_values[i]; + + Param param = CreateParam(val, copy_holder); + + value_ptrs[i] = param.ptr; + lengths[i] = param.length; + formats[i] = param.format; + } +} + +} // namespace duckdb diff --git a/src/postgres_query.cpp b/src/postgres_query.cpp index 700db2fa7..f999617dd 100644 --- a/src/postgres_query.cpp +++ b/src/postgres_query.cpp @@ -1,6 +1,7 @@ #include "duckdb.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" +#include "postgres_parameters.hpp" #include "postgres_scanner.hpp" #include "duckdb/main/database_manager.hpp" #include "duckdb/main/attached_database.hpp" @@ -44,7 +45,19 @@ static unique_ptr PGQueryBind(ClientContext &context, TableFunctio use_transaction = BooleanValue::Get(kv.second); } } - result->use_transaction = use_transaction; + + vector param_values; + auto params_it = input.named_parameters.find("params"); + if (params_it != input.named_parameters.end()) { + Value &struct_val = params_it->second; + if (struct_val.IsNull()) { + throw BinderException("Parameters to postgres_query cannot be NULL"); + } + if (struct_val.type().id() != LogicalTypeId::STRUCT) { + throw BinderException("Query parameters must be specified in a STRUCT"); + } + param_values = StructValue::GetChildren(struct_val); + } auto &con = use_transaction ? transaction.GetConnection() : transaction.GetConnectionWithoutTransaction(); @@ -65,7 +78,7 @@ static unique_ptr PGQueryBind(ClientContext &context, TableFunctio auto extended_err = describe_prepared ? PQresultErrorMessage(describe_prepared) : PQerrorMessage(conn); throw BinderException("Failed to describe prepared statement: %s", extended_err); } - auto nfields = PQnfields(describe_prepared); + int nfields = PQnfields(describe_prepared); if (nfields <= 0) { throw BinderException("No fields returned by query \"%s\" - the query must be a SELECT statement that returns " "at least one column", @@ -82,6 +95,16 @@ static unique_ptr PGQueryBind(ClientContext &context, TableFunctio return_types.emplace_back(converted_type); names.emplace_back(PQfname(describe_prepared, c)); } + int nparams = PQnparams(describe_prepared); + if (nparams != param_values.size()) { + throw BinderException("Incorrect number of parameters specified, expected: %d, actual: %zu, query: \"%s\"", + nparams, param_values.size(), sql); + } + vector param_types; + for (idx_t p = 0; p < nparams; p++) { + Oid ptype = PQparamtype(describe_prepared, p); + param_types.emplace_back(ptype); + } // set up the bind data result->SetCatalog(pg_catalog); @@ -91,12 +114,15 @@ static unique_ptr PGQueryBind(ClientContext &context, TableFunctio result->read_only = false; result->SetTablePages(0); result->sql = std::move(sql); + result->params = PostgresParameters(std::move(param_types), std::move(param_values)); + result->use_transaction = use_transaction; return std::move(result); } PostgresQueryFunction::PostgresQueryFunction() : TableFunction("postgres_query", {LogicalType::VARCHAR, LogicalType::VARCHAR}, nullptr, PGQueryBind) { named_parameters["use_transaction"] = LogicalType::BOOLEAN; + named_parameters["params"] = LogicalType::ANY; PostgresScanFunction scan_function; init_global = scan_function.init_global; init_local = scan_function.init_local; diff --git a/src/postgres_scanner.cpp b/src/postgres_scanner.cpp index 829f4552d..eb4360a39 100644 --- a/src/postgres_scanner.cpp +++ b/src/postgres_scanner.cpp @@ -89,8 +89,8 @@ static void PostgresGetSnapshot(ClientContext &context, PostgresVersion version, return; } - result = - con.TryQuery(context, "SELECT pg_is_in_recovery(), pg_export_snapshot(), (select count(*) from pg_stat_wal_receiver)"); + result = con.TryQuery( + context, "SELECT pg_is_in_recovery(), pg_export_snapshot(), (select count(*) from pg_stat_wal_receiver)"); if (result) { auto in_recovery = result->GetBool(0, 0) || result->GetInt64(0, 2) > 0; gstate.snapshot = ""; diff --git a/src/postgres_text_reader.cpp b/src/postgres_text_reader.cpp index 94db174ea..e345d2b8c 100644 --- a/src/postgres_text_reader.cpp +++ b/src/postgres_text_reader.cpp @@ -14,7 +14,7 @@ PostgresTextReader::~PostgresTextReader() { } void PostgresTextReader::BeginCopy(ClientContext &context, const string &sql) { - result = con.Query(context, sql); + result = con.Query(context, sql, bind_data.params); row_offset = 0; } diff --git a/src/storage/postgres_catalog.cpp b/src/storage/postgres_catalog.cpp index 888463b27..84f637ddf 100644 --- a/src/storage/postgres_catalog.cpp +++ b/src/storage/postgres_catalog.cpp @@ -11,7 +11,8 @@ namespace duckdb { PostgresCatalog::PostgresCatalog(AttachedDatabase &db_p, string connection_string_p, string attach_path_p, - AccessMode access_mode, string schema_to_load, PostgresIsolationLevel isolation_level, ClientContext &context) + AccessMode access_mode, string schema_to_load, PostgresIsolationLevel isolation_level, + ClientContext &context) : Catalog(db_p), connection_string(std::move(connection_string_p)), attach_path(std::move(attach_path_p)), access_mode(access_mode), isolation_level(isolation_level), schemas(*this, schema_to_load), connection_pool(*this), default_schema(schema_to_load) { @@ -138,7 +139,8 @@ void PostgresCatalog::DropSchema(ClientContext &context, DropInfo &info) { void PostgresCatalog::ScanSchemas(ClientContext &context, std::function callback) { auto &postgres_transaction = PostgresTransaction::Get(context, *this); - schemas.Scan(context, postgres_transaction, [&](CatalogEntry &schema) { callback(schema.Cast()); }); + schemas.Scan(context, postgres_transaction, + [&](CatalogEntry &schema) { callback(schema.Cast()); }); } optional_ptr PostgresCatalog::LookupSchema(CatalogTransaction transaction, diff --git a/src/storage/postgres_catalog_set.cpp b/src/storage/postgres_catalog_set.cpp index f0218f093..a46aa241e 100644 --- a/src/storage/postgres_catalog_set.cpp +++ b/src/storage/postgres_catalog_set.cpp @@ -8,7 +8,8 @@ namespace duckdb { PostgresCatalogSet::PostgresCatalogSet(Catalog &catalog, bool is_loaded_p) : catalog(catalog), is_loaded(is_loaded_p) { } -optional_ptr PostgresCatalogSet::GetEntry(ClientContext &context, PostgresTransaction &transaction, const string &name) { +optional_ptr PostgresCatalogSet::GetEntry(ClientContext &context, PostgresTransaction &transaction, + const string &name) { TryLoadEntries(context, transaction); { lock_guard l(entry_lock); @@ -78,7 +79,8 @@ void PostgresCatalogSet::DropEntry(PostgresTransaction &transaction, DropInfo &i entries.erase(info.name); } -void PostgresCatalogSet::Scan(ClientContext &context, PostgresTransaction &transaction, const std::function &callback) { +void PostgresCatalogSet::Scan(ClientContext &context, PostgresTransaction &transaction, + const std::function &callback) { TryLoadEntries(context, transaction); lock_guard l(entry_lock); for (auto &entry : entries) { diff --git a/src/storage/postgres_delete.cpp b/src/storage/postgres_delete.cpp index b50fbc171..d2e6a24b5 100644 --- a/src/storage/postgres_delete.cpp +++ b/src/storage/postgres_delete.cpp @@ -97,7 +97,7 @@ SinkFinalizeType PostgresDelete::Finalize(Pipeline &pipeline, Event &event, Clie // GetData //===--------------------------------------------------------------------===// SourceResultType PostgresDelete::GetDataInternal(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { + OperatorSourceInput &input) const { auto &insert_gstate = sink_state->Cast(); chunk.SetCardinality(1); chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.delete_count)); diff --git a/src/storage/postgres_index.cpp b/src/storage/postgres_index.cpp index eae13de59..cf26c37bd 100644 --- a/src/storage/postgres_index.cpp +++ b/src/storage/postgres_index.cpp @@ -21,7 +21,7 @@ PostgresCreateIndex::PostgresCreateIndex(PhysicalPlan &physical_plan, unique_ptr // Source //===--------------------------------------------------------------------===// SourceResultType PostgresCreateIndex::GetDataInternal(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { + OperatorSourceInput &input) const { auto &catalog = table.catalog; auto &schema = table.schema; auto transaction = catalog.GetCatalogTransaction(context.client); diff --git a/src/storage/postgres_insert.cpp b/src/storage/postgres_insert.cpp index 385351df7..6b129061d 100644 --- a/src/storage/postgres_insert.cpp +++ b/src/storage/postgres_insert.cpp @@ -149,7 +149,7 @@ SinkFinalizeType PostgresInsert::Finalize(Pipeline &pipeline, Event &event, Clie // GetData //===--------------------------------------------------------------------===// SourceResultType PostgresInsert::GetDataInternal(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { + OperatorSourceInput &input) const { auto &insert_gstate = sink_state->Cast(); chunk.SetCardinality(1); chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.insert_count)); diff --git a/src/storage/postgres_schema_entry.cpp b/src/storage/postgres_schema_entry.cpp index 5c1b9def9..d4b16cd52 100644 --- a/src/storage/postgres_schema_entry.cpp +++ b/src/storage/postgres_schema_entry.cpp @@ -194,7 +194,8 @@ optional_ptr PostgresSchemaEntry::LookupEntry(CatalogTransaction t return nullptr; } auto &postgres_transaction = GetPostgresTransaction(transaction); - return GetCatalogSet(catalog_type).GetEntry(transaction.GetContext(), postgres_transaction, lookup_info.GetEntryName()); + return GetCatalogSet(catalog_type) + .GetEntry(transaction.GetContext(), postgres_transaction, lookup_info.GetEntryName()); } PostgresCatalogSet &PostgresSchemaEntry::GetCatalogSet(CatalogType type) { diff --git a/src/storage/postgres_table_set.cpp b/src/storage/postgres_table_set.cpp index 624e137f1..ae53dc61b 100644 --- a/src/storage/postgres_table_set.cpp +++ b/src/storage/postgres_table_set.cpp @@ -173,8 +173,8 @@ unique_ptr PostgresTableSet::GetTableInfo(PostgresTransaction return table_info; } -unique_ptr PostgresTableSet::GetTableInfo(ClientContext &context, PostgresConnection &connection, const string &schema_name, - const string &table_name) { +unique_ptr PostgresTableSet::GetTableInfo(ClientContext &context, PostgresConnection &connection, + const string &schema_name, const string &table_name) { auto query = PostgresTableSet::GetInitializeQuery(schema_name, table_name); auto result = connection.Query(context, query); auto rows = result->Count(); @@ -340,7 +340,8 @@ string PostgresTableSet::GetAlterTableColumnName(const string &name, optional_pt return table.postgres_names[column_index.index]; } -string PostgresTableSet::GetAlterTablePrefix(ClientContext &context, PostgresTransaction &transaction, const string &name) { +string PostgresTableSet::GetAlterTablePrefix(ClientContext &context, PostgresTransaction &transaction, + const string &name) { auto entry = GetEntry(context, transaction, name); return GetAlterTablePrefix(name, entry); } @@ -352,7 +353,7 @@ void PostgresTableSet::AlterTable(ClientContext &context, PostgresTransaction &t transaction.Query(sql); } -void PostgresTableSet::AlterTable(ClientContext &context,PostgresTransaction &transaction, RenameColumnInfo &info) { +void PostgresTableSet::AlterTable(ClientContext &context, PostgresTransaction &transaction, RenameColumnInfo &info) { auto entry = GetEntry(context, transaction, info.name); string sql = GetAlterTablePrefix(info.name, entry); sql += " RENAME COLUMN "; diff --git a/src/storage/postgres_transaction.cpp b/src/storage/postgres_transaction.cpp index 11cc301dc..dfc7cd44c 100644 --- a/src/storage/postgres_transaction.cpp +++ b/src/storage/postgres_transaction.cpp @@ -107,7 +107,7 @@ unique_ptr PostgresTransaction::QueryWithoutTransaction(const st return con.Query(GetContext(), query); } -vector> PostgresTransaction::ExecuteQueries(ClientContext& context, const string &queries) { +vector> PostgresTransaction::ExecuteQueries(ClientContext &context, const string &queries) { auto &con = GetConnectionRaw(); if (transaction_state == PostgresTransactionState::TRANSACTION_NOT_YET_STARTED) { transaction_state = PostgresTransactionState::TRANSACTION_STARTED; diff --git a/src/storage/postgres_update.cpp b/src/storage/postgres_update.cpp index af167ef24..42dc5bb71 100644 --- a/src/storage/postgres_update.cpp +++ b/src/storage/postgres_update.cpp @@ -178,7 +178,7 @@ SinkFinalizeType PostgresUpdate::Finalize(Pipeline &pipeline, Event &event, Clie // GetData //===--------------------------------------------------------------------===// SourceResultType PostgresUpdate::GetDataInternal(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { + OperatorSourceInput &input) const { auto &insert_gstate = sink_state->Cast(); chunk.SetCardinality(1); chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.update_count)); diff --git a/test/sql/scanner/postgres_query_params.test b/test/sql/scanner/postgres_query_params.test new file mode 100644 index 000000000..e83d041e9 --- /dev/null +++ b/test/sql/scanner/postgres_query_params.test @@ -0,0 +1,110 @@ +# name: test/sql/scanner/postgres_query_params.test +# description: Test running postgres_query with parameters +# group: [scanner] + +require postgres_scanner + +require-env POSTGRES_TEST_DATABASE_AVAILABLE + +# COPY does not support parameters: ERROR: there is no parameter $1 +statement ok +SET pg_use_text_protocol=true + +statement ok +ATTACH 'dbname=postgresscanner' AS s1 (TYPE POSTGRES) + +query II +FROM postgres_query('s1', 'SELECT $1::INTEGER, $2::TEXT', params=row(42, 'foo')) +---- +42 foo + +statement ok +PREPARE p1 AS SELECT * FROM postgres_query('s1', 'SELECT $1::INTEGER, $2::TEXT', params=row(?::INTEGER, ?::VARCHAR)) + +query II +EXECUTE p1(42, 'foo') +---- +42 foo + +query II +EXECUTE p1(43, 'bar') +---- +43 bar + +statement ok +DEALLOCATE p1 + +query I +FROM postgres_query('s1', 'SELECT $1::TEXT', params=row(NULL)) +---- +NULL + +query I +FROM postgres_query('s1', 'SELECT $1::TEXT', params=row('foo')) +---- +foo + +query I +FROM postgres_query('s1', 'SELECT $1::BOOLEAN', params=row('t')) +---- +TRUE + +query I +FROM postgres_query('s1', 'SELECT $1::BOOLEAN', params=row('f')) +---- +FALSE + +query I +FROM postgres_query('s1', 'SELECT $1::SMALLINT', params=row(-127::TINYINT)) +---- +-127 + +query I +FROM postgres_query('s1', 'SELECT $1::SMALLINT', params=row(255::UTINYINT)) +---- +255 + +query I +FROM postgres_query('s1', 'SELECT $1::SMALLINT', params=row(-32767::SMALLINT)) +---- +-32767 + +statement error +FROM postgres_query('s1', 'SELECT $1::SMALLINT', params=row(32767::USMALLINT)) +---- +incorrect binary data format in bind parameter 1 + +query I +FROM postgres_query('s1', 'SELECT $1::INTEGER', params=row(65535::USMALLINT)) +---- +65535 + +query I +FROM postgres_query('s1', 'SELECT $1::INTEGER', params=row(-2147483647::INTEGER)) +---- +-2147483647 + +statement error +FROM postgres_query('s1', 'SELECT $1::INTEGER', params=row(2147483647::UINTEGER)) +---- +incorrect binary data format in bind parameter 1 + +query I +FROM postgres_query('s1', 'SELECT $1::BIGINT', params=row(2147483647::UINTEGER)) +---- +2147483647 + +query I +FROM postgres_query('s1', 'SELECT $1::BIGINT', params=row(-9223372036854775807::BIGINT)) +---- +-9223372036854775807 + +query I +FROM postgres_query('s1', 'SELECT $1::REAL', params=row(42.123::FLOAT)) +---- +42.123 + +query I +FROM postgres_query('s1', 'SELECT $1::DOUBLE PRECISION', params=row(42.123::DOUBLE)) +---- +42.123