From 64e2b0294f8e64342c3912aedced911f0e943ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Ercolanelli?= Date: Tue, 25 Nov 2025 15:01:44 +0100 Subject: [PATCH 1/2] More permissive statement parameters --- examples/client_usage.py | 2 +- src/altertable_flightsql/client.py | 85 ++++++++++++++++++++++++++---- tests/test_queries.py | 42 +++++++++++++-- 3 files changed, 113 insertions(+), 16 deletions(-) diff --git a/examples/client_usage.py b/examples/client_usage.py index 2e1079c..38b82a4 100644 --- a/examples/client_usage.py +++ b/examples/client_usage.py @@ -98,7 +98,7 @@ def example_prepared_statement(): # Execute with different parameters for user_id in [1, 2, 3]: print(f"Fetching user {user_id}...") - reader = stmt.query(parameters={"id": user_id}) + reader = stmt.query(parameters=[user_id]) for batch in reader: print(batch.data.to_pandas()) diff --git a/src/altertable_flightsql/client.py b/src/altertable_flightsql/client.py index 3190e0a..ba5ea46 100644 --- a/src/altertable_flightsql/client.py +++ b/src/altertable_flightsql/client.py @@ -4,8 +4,8 @@ This module provides a high-level Python client for Altertable. """ -from collections.abc import Mapping -from typing import Any, Optional +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union import pyarrow as pa import pyarrow.flight as flight @@ -271,7 +271,17 @@ def prepare( # Parse result result = sql_pb2.ActionCreatePreparedStatementResult() _unpack_command(results[0].body.to_pybytes(), result) - return PreparedStatement(self._client, result.prepared_statement_handle) + + # Extract parameter schema if available + parameter_schema = None + if result.parameter_schema: + parameter_schema = pa.ipc.read_schema(pa.py_buffer(result.parameter_schema)) + + return PreparedStatement( + self._client, + result.prepared_statement_handle, + parameter_schema=parameter_schema + ) def get_catalogs(self) -> flight.FlightStreamReader: """ @@ -445,30 +455,52 @@ class PreparedStatement: Prepared statements can be executed multiple times with different parameters. """ - def __init__(self, client: flight.FlightClient, handle: bytes): + def __init__( + self, + client: flight.FlightClient, + handle: bytes, + parameter_schema: Optional[pa.Schema] = None + ): """ Initialize a prepared statement. Args: client: FlightClient instance. handle: Prepared statement handle from server. + parameter_schema: Optional parameter schema for the prepared statement. """ self._client = client self._handle = handle + self._parameter_schema = parameter_schema def query( self, *, - parameters: Optional[Mapping[str, Any]] = None, + parameters: Optional[Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]]] = None, ) -> flight.FlightStreamReader: """ Execute the prepared statement query. Args: - parameters: Optional RecordBatch containing parameter values. + parameters: Optional parameters for the query. Can be: + - pyarrow.Table: A table of parameter values + - pyarrow.RecordBatch: A batch of parameter values + - Mapping[str, Any]: A dictionary mapping parameter names to values + - Sequence[Any]: A list of positional parameter values Returns: FlightStreamReader with query results. + + Example: + >>> # Using a dictionary + >>> stmt.query(parameters={"id": 42, "name": "Alice"}) + + >>> # Using a list + >>> stmt.query(parameters=[42, "Alice"]) + + >>> # Using a RecordBatch + >>> batch = pa.record_batch({"id": [42], "name": ["Alice"]}) + >>> stmt.query(parameters=batch) """ cmd = sql_pb2.CommandPreparedStatementQuery() cmd.prepared_statement_handle = self._handle @@ -476,11 +508,10 @@ def query( descriptor = flight.FlightDescriptor.for_command(_pack_command(cmd)) info = self._client.get_flight_info(descriptor) - # If parameters are provided, send them via DoPut - if parameters: - record_batch = pa.record_batch({key: [value] for (key, value) in parameters.items()}) - writer, _ = self._client.do_put(descriptor, record_batch.schema) - writer.write_batch(record_batch) + if parameters is not None: + as_pyarrow = self._get_parameter_as_pyarrow(parameters) + writer, _ = self._client.do_put(descriptor, as_pyarrow.schema) + writer.write(as_pyarrow) writer.close() return self._client.do_get(info.endpoints[0].ticket) @@ -501,3 +532,35 @@ def __enter__(self) -> "PreparedStatement": def __exit__(self, exc_type, exc_val, exc_tb) -> None: """Context manager exit.""" self.close() + + def _get_parameter_as_pyarrow(self, parameters: Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]]) -> Union[pa.Table, pa.RecordBatch]: + if isinstance(parameters, pa.Table): + return parameters + elif isinstance(parameters, pa.RecordBatch): + return parameters + elif isinstance(parameters, Mapping): + return pa.record_batch({key: [value] for (key, value) in parameters.items()}) + elif isinstance(parameters, Sequence): + if self._parameter_schema is None: + raise ValueError( + "Cannot use positional parameters without parameter schema. " + "Use a dictionary (Mapping[str, Any]) instead." + ) + + # Create record batch with positional parameters + if len(parameters) != len(self._parameter_schema): + raise ValueError( + f"Expected {len(self._parameter_schema)} parameters, " + f"but got {len(parameters)}" + ) + param_dict = { + field.name: [value] + for field, value in zip(self._parameter_schema, parameters) + } + + return pa.record_batch(param_dict) + else: + raise TypeError( + f"Unsupported parameter type: {type(parameters)}. " + "Expected Table, RecordBatch, Mapping, or Sequence." + ) \ No newline at end of file diff --git a/tests/test_queries.py b/tests/test_queries.py index a8c0aee..0ed98ef 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -4,6 +4,7 @@ Tests basic query execution, updates, and prepared statements. """ +import pyarrow as pa import pytest from altertable_flightsql import Client @@ -56,14 +57,47 @@ def test_empty_result_set(self, altertable_client: Client, test_table: TableInfo class TestPreparedStatements: """Test prepared statement functionality.""" - def test_prepare_and_execute(self, altertable_client: Client, test_table: TableInfo): - """Test creating and executing a prepared statement.""" + def test_prepare_with_dict_parameters(self, altertable_client: Client, test_table: TableInfo): + """Test prepared statement with dict parameters.""" # Prepare statement with altertable_client.prepare( - f"SELECT * FROM {test_table.full_name} WHERE id = $id" + f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value" ) as stmt: # Execute prepared statement - reader = stmt.query(parameters={"id": 1}) + reader = stmt.query(parameters={"id": 1, "min_value": 100}) + table = reader.read_all() + assert table.num_rows > 0 + + def test_prepare_with_list_parameters(self, altertable_client: Client, test_table: TableInfo): + """Test prepared statement with list parameters.""" + with altertable_client.prepare( + f"SELECT * FROM {test_table.full_name} WHERE id = ? AND value >= ?" + ) as stmt: + reader = stmt.query(parameters=[1, 100]) + table = reader.read_all() + assert table.num_rows >= 0 + + def test_prepare_with_record_batch_parameters( + self, altertable_client: Client, test_table: TableInfo + ): + """Test prepared statement with RecordBatch parameters.""" + with altertable_client.prepare( + f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value" + ) as stmt: + # Create a RecordBatch with parameters + batch = pa.record_batch({"id": [1], "min_value": [100]}) + reader = stmt.query(parameters=batch) + table = reader.read_all() + assert table.num_rows > 0 + + def test_prepare_with_table_parameters(self, altertable_client: Client, test_table: TableInfo): + """Test prepared statement with Table parameters.""" + with altertable_client.prepare( + f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value" + ) as stmt: + # Create a Table with parameters + param_table = pa.table({"id": [1], "min_value": [100]}) + reader = stmt.query(parameters=param_table) table = reader.read_all() assert table.num_rows > 0 From 0153af425a1f1a8a542488bfb9e8bcf00ea7e3b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Ercolanelli?= Date: Tue, 25 Nov 2025 15:04:25 +0100 Subject: [PATCH 2/2] fix lint --- src/altertable_flightsql/client.py | 35 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/altertable_flightsql/client.py b/src/altertable_flightsql/client.py index ba5ea46..6d58e32 100644 --- a/src/altertable_flightsql/client.py +++ b/src/altertable_flightsql/client.py @@ -271,16 +271,14 @@ def prepare( # Parse result result = sql_pb2.ActionCreatePreparedStatementResult() _unpack_command(results[0].body.to_pybytes(), result) - + # Extract parameter schema if available parameter_schema = None if result.parameter_schema: parameter_schema = pa.ipc.read_schema(pa.py_buffer(result.parameter_schema)) - + return PreparedStatement( - self._client, - result.prepared_statement_handle, - parameter_schema=parameter_schema + self._client, result.prepared_statement_handle, parameter_schema=parameter_schema ) def get_catalogs(self) -> flight.FlightStreamReader: @@ -456,10 +454,10 @@ class PreparedStatement: """ def __init__( - self, - client: flight.FlightClient, + self, + client: flight.FlightClient, handle: bytes, - parameter_schema: Optional[pa.Schema] = None + parameter_schema: Optional[pa.Schema] = None, ): """ Initialize a prepared statement. @@ -476,7 +474,9 @@ def __init__( def query( self, *, - parameters: Optional[Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]]] = None, + parameters: Optional[ + Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]] + ] = None, ) -> flight.FlightStreamReader: """ Execute the prepared statement query. @@ -490,14 +490,14 @@ def query( Returns: FlightStreamReader with query results. - + Example: >>> # Using a dictionary >>> stmt.query(parameters={"id": 42, "name": "Alice"}) - + >>> # Using a list >>> stmt.query(parameters=[42, "Alice"]) - + >>> # Using a RecordBatch >>> batch = pa.record_batch({"id": [42], "name": ["Alice"]}) >>> stmt.query(parameters=batch) @@ -533,7 +533,9 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: """Context manager exit.""" self.close() - def _get_parameter_as_pyarrow(self, parameters: Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]]) -> Union[pa.Table, pa.RecordBatch]: + def _get_parameter_as_pyarrow( + self, parameters: Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]] + ) -> Union[pa.Table, pa.RecordBatch]: if isinstance(parameters, pa.Table): return parameters elif isinstance(parameters, pa.RecordBatch): @@ -546,7 +548,7 @@ def _get_parameter_as_pyarrow(self, parameters: Union[pa.Table, pa.RecordBatch, "Cannot use positional parameters without parameter schema. " "Use a dictionary (Mapping[str, Any]) instead." ) - + # Create record batch with positional parameters if len(parameters) != len(self._parameter_schema): raise ValueError( @@ -554,8 +556,7 @@ def _get_parameter_as_pyarrow(self, parameters: Union[pa.Table, pa.RecordBatch, f"but got {len(parameters)}" ) param_dict = { - field.name: [value] - for field, value in zip(self._parameter_schema, parameters) + field.name: [value] for field, value in zip(self._parameter_schema, parameters) } return pa.record_batch(param_dict) @@ -563,4 +564,4 @@ def _get_parameter_as_pyarrow(self, parameters: Union[pa.Table, pa.RecordBatch, raise TypeError( f"Unsupported parameter type: {type(parameters)}. " "Expected Table, RecordBatch, Mapping, or Sequence." - ) \ No newline at end of file + )