Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/client_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def example_prepared_statement():
# Execute with different parameters
for user_id in [1, 2, 3]:
print(f"Fetching user {user_id}...")
reader = stmt.query(parameters={"id": user_id})
reader = stmt.query(parameters=[user_id])
for batch in reader:
print(batch.data.to_pandas())

Expand Down
86 changes: 75 additions & 11 deletions src/altertable_flightsql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
This module provides a high-level Python client for Altertable.
"""

from collections.abc import Mapping
from typing import Any, Optional
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union

import pyarrow as pa
import pyarrow.flight as flight
Expand Down Expand Up @@ -271,7 +271,15 @@ def prepare(
# Parse result
result = sql_pb2.ActionCreatePreparedStatementResult()
_unpack_command(results[0].body.to_pybytes(), result)
return PreparedStatement(self._client, result.prepared_statement_handle)

# Extract parameter schema if available
parameter_schema = None
if result.parameter_schema:
parameter_schema = pa.ipc.read_schema(pa.py_buffer(result.parameter_schema))

return PreparedStatement(
self._client, result.prepared_statement_handle, parameter_schema=parameter_schema
)

def get_catalogs(self) -> flight.FlightStreamReader:
"""
Expand Down Expand Up @@ -445,42 +453,65 @@ class PreparedStatement:
Prepared statements can be executed multiple times with different parameters.
"""

def __init__(self, client: flight.FlightClient, handle: bytes):
def __init__(
self,
client: flight.FlightClient,
handle: bytes,
parameter_schema: Optional[pa.Schema] = None,
):
"""
Initialize a prepared statement.

Args:
client: FlightClient instance.
handle: Prepared statement handle from server.
parameter_schema: Optional parameter schema for the prepared statement.
"""
self._client = client
self._handle = handle
self._parameter_schema = parameter_schema

def query(
self,
*,
parameters: Optional[Mapping[str, Any]] = None,
parameters: Optional[
Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]]
] = None,
) -> flight.FlightStreamReader:
"""
Execute the prepared statement query.

Args:
parameters: Optional RecordBatch containing parameter values.
parameters: Optional parameters for the query. Can be:
- pyarrow.Table: A table of parameter values
- pyarrow.RecordBatch: A batch of parameter values
- Mapping[str, Any]: A dictionary mapping parameter names to values
- Sequence[Any]: A list of positional parameter values

Returns:
FlightStreamReader with query results.

Example:
>>> # Using a dictionary
>>> stmt.query(parameters={"id": 42, "name": "Alice"})

>>> # Using a list
>>> stmt.query(parameters=[42, "Alice"])

>>> # Using a RecordBatch
>>> batch = pa.record_batch({"id": [42], "name": ["Alice"]})
>>> stmt.query(parameters=batch)
"""
cmd = sql_pb2.CommandPreparedStatementQuery()
cmd.prepared_statement_handle = self._handle

descriptor = flight.FlightDescriptor.for_command(_pack_command(cmd))
info = self._client.get_flight_info(descriptor)

# If parameters are provided, send them via DoPut
if parameters:
record_batch = pa.record_batch({key: [value] for (key, value) in parameters.items()})
writer, _ = self._client.do_put(descriptor, record_batch.schema)
writer.write_batch(record_batch)
if parameters is not None:
as_pyarrow = self._get_parameter_as_pyarrow(parameters)
writer, _ = self._client.do_put(descriptor, as_pyarrow.schema)
writer.write(as_pyarrow)
writer.close()

return self._client.do_get(info.endpoints[0].ticket)
Expand All @@ -501,3 +532,36 @@ def __enter__(self) -> "PreparedStatement":
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Context manager exit."""
self.close()

def _get_parameter_as_pyarrow(
self, parameters: Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]]
) -> Union[pa.Table, pa.RecordBatch]:
if isinstance(parameters, pa.Table):
return parameters
elif isinstance(parameters, pa.RecordBatch):
return parameters
elif isinstance(parameters, Mapping):
return pa.record_batch({key: [value] for (key, value) in parameters.items()})
elif isinstance(parameters, Sequence):
if self._parameter_schema is None:
raise ValueError(
"Cannot use positional parameters without parameter schema. "
"Use a dictionary (Mapping[str, Any]) instead."
)

# Create record batch with positional parameters
if len(parameters) != len(self._parameter_schema):
raise ValueError(
f"Expected {len(self._parameter_schema)} parameters, "
f"but got {len(parameters)}"
)
param_dict = {
field.name: [value] for field, value in zip(self._parameter_schema, parameters)
}

return pa.record_batch(param_dict)
else:
raise TypeError(
f"Unsupported parameter type: {type(parameters)}. "
"Expected Table, RecordBatch, Mapping, or Sequence."
)
42 changes: 38 additions & 4 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Tests basic query execution, updates, and prepared statements.
"""

import pyarrow as pa
import pytest

from altertable_flightsql import Client
Expand Down Expand Up @@ -56,14 +57,47 @@ def test_empty_result_set(self, altertable_client: Client, test_table: TableInfo
class TestPreparedStatements:
"""Test prepared statement functionality."""

def test_prepare_and_execute(self, altertable_client: Client, test_table: TableInfo):
"""Test creating and executing a prepared statement."""
def test_prepare_with_dict_parameters(self, altertable_client: Client, test_table: TableInfo):
"""Test prepared statement with dict parameters."""
# Prepare statement
with altertable_client.prepare(
f"SELECT * FROM {test_table.full_name} WHERE id = $id"
f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value"
) as stmt:
# Execute prepared statement
reader = stmt.query(parameters={"id": 1})
reader = stmt.query(parameters={"id": 1, "min_value": 100})
table = reader.read_all()
assert table.num_rows > 0

def test_prepare_with_list_parameters(self, altertable_client: Client, test_table: TableInfo):
"""Test prepared statement with list parameters."""
with altertable_client.prepare(
f"SELECT * FROM {test_table.full_name} WHERE id = ? AND value >= ?"
) as stmt:
reader = stmt.query(parameters=[1, 100])
table = reader.read_all()
assert table.num_rows >= 0

def test_prepare_with_record_batch_parameters(
self, altertable_client: Client, test_table: TableInfo
):
"""Test prepared statement with RecordBatch parameters."""
with altertable_client.prepare(
f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value"
) as stmt:
# Create a RecordBatch with parameters
batch = pa.record_batch({"id": [1], "min_value": [100]})
reader = stmt.query(parameters=batch)
table = reader.read_all()
assert table.num_rows > 0

def test_prepare_with_table_parameters(self, altertable_client: Client, test_table: TableInfo):
"""Test prepared statement with Table parameters."""
with altertable_client.prepare(
f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value"
) as stmt:
# Create a Table with parameters
param_table = pa.table({"id": [1], "min_value": [100]})
reader = stmt.query(parameters=param_table)
table = reader.read_all()
assert table.num_rows > 0

Expand Down