From 49d4a983d78d00ead4de8d707e5a9d3a85eebbeb Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Tue, 17 Mar 2026 08:55:37 +0530 Subject: [PATCH 1/3] Add support for more protocol headers to connection options Add new connection options that exist in the Trino CLI but were missing from the Python client: client_info, trace_token, sql_path and resource_estimates. --- tests/unit/sqlalchemy/test_dialect.py | 28 +++++ tests/unit/test_client.py | 160 ++++++++++++++++++++++++++ tests/unit/test_client_session.py | 20 ++++ tests/unit/test_dbapi.py | 37 ++++++ trino/client.py | 39 +++++++ trino/constants.py | 3 + trino/dbapi.py | 12 ++ trino/sqlalchemy/dialect.py | 12 ++ trino/sqlalchemy/util.py | 16 +++ 9 files changed, 327 insertions(+) diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index d247a536..c7b03c80 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -197,6 +197,34 @@ def setup_method(self): source="trino-sqlalchemy", ), ), + ( + make_url(trino_url( + user="user", + host="localhost", + client_info="my-app", + trace_token="trace-123", + sql_path="catalog.schema", + resource_estimates={"CPU_TIME": "10s", "PEAK_MEMORY": "1GB"}, + )), + 'trino://user@localhost:8080/' + '?client_info=my-app' + '&resource_estimates=%7B%22CPU_TIME%22%3A+%2210s%22%2C+%22PEAK_MEMORY%22%3A+%221GB%22%7D' + '&source=trino-sqlalchemy' + '&sql_path=catalog.schema' + '&trace_token=trace-123', + list(), + dict( + host="localhost", + port=8080, + catalog="system", + user="user", + source="trino-sqlalchemy", + client_info="my-app", + trace_token="trace-123", + sql_path="catalog.schema", + resource_estimates={"CPU_TIME": "10s", "PEAK_MEMORY": "1GB"}, + ), + ), ( make_url(trino_url( user="user", diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index f011a54d..61ba6111 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -284,6 +284,166 @@ def assert_headers(headers): assert_headers(get_kwargs["headers"]) +def test_request_client_info_headers(mock_get_and_post): + get, post = mock_get_and_post + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession( + user="test_user", + client_info="test-client-info", + ), + ) + + req.post("URL") + _, post_kwargs = post.call_args + assert post_kwargs["headers"][constants.HEADER_CLIENT_INFO] == "test-client-info" + + req.get("URL") + _, get_kwargs = get.call_args + assert get_kwargs["headers"][constants.HEADER_CLIENT_INFO] == "test-client-info" + + +def test_request_client_info_headers_absent(mock_get_and_post): + get, post = mock_get_and_post + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test_user"), + ) + + req.post("URL") + _, post_kwargs = post.call_args + assert constants.HEADER_CLIENT_INFO not in post_kwargs["headers"] + + req.get("URL") + _, get_kwargs = get.call_args + assert constants.HEADER_CLIENT_INFO not in get_kwargs["headers"] + + +def test_request_trace_token_headers(mock_get_and_post): + get, post = mock_get_and_post + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession( + user="test_user", + trace_token="test-trace-token", + ), + ) + + req.post("URL") + _, post_kwargs = post.call_args + assert post_kwargs["headers"][constants.HEADER_TRACE_TOKEN] == "test-trace-token" + + req.get("URL") + _, get_kwargs = get.call_args + assert get_kwargs["headers"][constants.HEADER_TRACE_TOKEN] == "test-trace-token" + + +def test_request_trace_token_headers_absent(mock_get_and_post): + get, post = mock_get_and_post + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test_user"), + ) + + req.post("URL") + _, post_kwargs = post.call_args + assert constants.HEADER_TRACE_TOKEN not in post_kwargs["headers"] + + req.get("URL") + _, get_kwargs = get.call_args + assert constants.HEADER_TRACE_TOKEN not in get_kwargs["headers"] + + +def test_request_sql_path_headers(mock_get_and_post): + get, post = mock_get_and_post + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession( + user="test_user", + sql_path="catalog.schema", + ), + ) + + req.post("URL") + _, post_kwargs = post.call_args + assert post_kwargs["headers"][constants.HEADER_PATH] == "catalog.schema" + + req.get("URL") + _, get_kwargs = get.call_args + assert get_kwargs["headers"][constants.HEADER_PATH] == "catalog.schema" + + +def test_request_sql_path_headers_absent(mock_get_and_post): + get, post = mock_get_and_post + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test_user"), + ) + + req.post("URL") + _, post_kwargs = post.call_args + assert constants.HEADER_PATH not in post_kwargs["headers"] + + req.get("URL") + _, get_kwargs = get.call_args + assert constants.HEADER_PATH not in get_kwargs["headers"] + + +def test_request_resource_estimates_headers(mock_get_and_post): + get, post = mock_get_and_post + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession( + user="test_user", + resource_estimates={"CPU_TIME": "10s", "PEAK_MEMORY": "1GB"}, + ), + ) + + req.post("URL") + _, post_kwargs = post.call_args + header = post_kwargs["headers"][constants.HEADER_RESOURCE_ESTIMATE] + assert "CPU_TIME=10s" in header + assert "PEAK_MEMORY=1GB" in header + + req.get("URL") + _, get_kwargs = get.call_args + header = get_kwargs["headers"][constants.HEADER_RESOURCE_ESTIMATE] + assert "CPU_TIME=10s" in header + assert "PEAK_MEMORY=1GB" in header + + +def test_request_resource_estimates_headers_absent(mock_get_and_post): + get, post = mock_get_and_post + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test_user"), + ) + + req.post("URL") + _, post_kwargs = post.call_args + assert constants.HEADER_RESOURCE_ESTIMATE not in post_kwargs["headers"] + + req.get("URL") + _, get_kwargs = get.call_args + assert constants.HEADER_RESOURCE_ESTIMATE not in get_kwargs["headers"] + + def test_enabling_https_automatically_when_using_port_443(mock_get_and_post): _, post = mock_get_and_post diff --git a/tests/unit/test_client_session.py b/tests/unit/test_client_session.py index ce9a05b0..755104d2 100644 --- a/tests/unit/test_client_session.py +++ b/tests/unit/test_client_session.py @@ -72,6 +72,26 @@ def test_client_session_extra_client_tags() -> None: assert session.client_tags == [] +def test_client_session_client_info() -> None: + session = ClientSession(user="user") + assert session.client_info is None + + +def test_client_session_trace_token() -> None: + session = ClientSession(user="user") + assert session.trace_token is None + + +def test_client_session_sql_path() -> None: + session = ClientSession(user="user") + assert session.sql_path is None + + +def test_client_session_resource_estimates() -> None: + session = ClientSession(user="user") + assert session.resource_estimates == {} + + @pytest.mark.parametrize( argnames=["argument", "result"], argvalues=[ diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index e3821bba..457ec789 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -264,6 +264,43 @@ def test_tags_are_set_when_specified(mock_client): assert passed_client_tags["client_tags"] == client_tags +@patch("trino.dbapi.trino.client") +def test_client_info_is_set_when_specified(mock_client): + with connect("sample_trino_cluster:443", client_info="test-info") as conn: + conn.cursor().execute("SOME FAKE QUERY") + + _, passed_kwargs = mock_client.ClientSession.call_args + assert passed_kwargs["client_info"] == "test-info" + + +@patch("trino.dbapi.trino.client") +def test_trace_token_is_set_when_specified(mock_client): + with connect("sample_trino_cluster:443", trace_token="test-token") as conn: + conn.cursor().execute("SOME FAKE QUERY") + + _, passed_kwargs = mock_client.ClientSession.call_args + assert passed_kwargs["trace_token"] == "test-token" + + +@patch("trino.dbapi.trino.client") +def test_sql_path_is_set_when_specified(mock_client): + with connect("sample_trino_cluster:443", sql_path="catalog.schema") as conn: + conn.cursor().execute("SOME FAKE QUERY") + + _, passed_kwargs = mock_client.ClientSession.call_args + assert passed_kwargs["sql_path"] == "catalog.schema" + + +@patch("trino.dbapi.trino.client") +def test_resource_estimates_is_set_when_specified(mock_client): + resource_estimates = {"CPU_TIME": "10s", "PEAK_MEMORY": "1GB"} + with connect("sample_trino_cluster:443", resource_estimates=resource_estimates) as conn: + conn.cursor().execute("SOME FAKE QUERY") + + _, passed_kwargs = mock_client.ClientSession.call_args + assert passed_kwargs["resource_estimates"] == resource_estimates + + @patch("trino.dbapi.trino.client") def test_role_is_set_when_specified(mock_client): roles = {"system": "finance"} diff --git a/trino/client.py b/trino/client.py index b5cc62ba..a6bcada6 100644 --- a/trino/client.py +++ b/trino/client.py @@ -190,6 +190,10 @@ def __init__( transaction_id: Optional[str] = None, extra_credential: Optional[List[Tuple[str, str]]] = None, client_tags: Optional[List[str]] = None, + client_info: Optional[str] = None, + trace_token: Optional[str] = None, + sql_path: Optional[str] = None, + resource_estimates: Optional[Dict[str, str]] = None, roles: Optional[Union[Dict[str, str], str]] = None, timezone: Optional[str] = None, encoding: Optional[Union[str, List[str]]] = None, @@ -207,6 +211,10 @@ def __init__( self._transaction_id = transaction_id self._extra_credential = extra_credential self._client_tags = client_tags.copy() if client_tags is not None else list() + self._client_info = client_info + self._trace_token = trace_token + self._sql_path = sql_path + self._resource_estimates = resource_estimates.copy() if resource_estimates is not None else {} self._roles = self._format_roles(roles) if roles is not None else {} if timezone: # Check timezone validity ZoneInfo(timezone) @@ -286,6 +294,22 @@ def extra_credential(self) -> Optional[List[Tuple[str, str]]]: def client_tags(self) -> List[str]: return self._client_tags + @property + def client_info(self) -> Optional[str]: + return self._client_info + + @property + def trace_token(self) -> Optional[str]: + return self._trace_token + + @property + def sql_path(self) -> Optional[str]: + return self._sql_path + + @property + def resource_estimates(self) -> Dict[str, str]: + return self._resource_estimates + @property def roles(self) -> Dict[str, str]: with self._object_lock: @@ -572,6 +596,21 @@ def http_headers(self) -> CaseInsensitiveDict[str]: if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0: headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags) + if self._client_session.client_info is not None: + headers[constants.HEADER_CLIENT_INFO] = self._client_session.client_info + + if self._client_session.trace_token is not None: + headers[constants.HEADER_TRACE_TOKEN] = self._client_session.trace_token + + if self._client_session.sql_path is not None: + headers[constants.HEADER_PATH] = self._client_session.sql_path + + if self._client_session.resource_estimates is not None and len(self._client_session.resource_estimates) > 0: + headers[constants.HEADER_RESOURCE_ESTIMATE] = ",".join( + "{}={}".format(name, urllib.parse.quote_plus(str(value))) + for name, value in self._client_session.resource_estimates.items() + ) + headers[constants.HEADER_SESSION] = ",".join( # ``name`` must not contain ``=`` "{}={}".format(name, urllib.parse.quote(str(value))) diff --git a/trino/constants.py b/trino/constants.py index b136aaaf..9a2f0382 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -38,6 +38,9 @@ HEADER_CLIENT_TAGS = "X-Trino-Client-Tags" HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential" HEADER_TIMEZONE = "X-Trino-Time-Zone" +HEADER_TRACE_TOKEN = "X-Trino-Trace-Token" +HEADER_RESOURCE_ESTIMATE = "X-Trino-Resource-Estimate" +HEADER_PATH = "X-Trino-Path" HEADER_ENCODING = "X-Trino-Query-Data-Encoding" HEADER_SESSION = "X-Trino-Session" diff --git a/trino/dbapi.py b/trino/dbapi.py index 42eeb547..1d5096a1 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -159,6 +159,10 @@ def __init__( verify=True, http_session=None, client_tags=None, + client_info=None, + trace_token=None, + sql_path=None, + resource_estimates=None, legacy_primitive_types=False, legacy_prepared_statements=None, roles=None, @@ -191,6 +195,10 @@ def __init__( transaction_id=NO_TRANSACTION, extra_credential=extra_credential, client_tags=client_tags, + client_info=client_info, + trace_token=trace_token, + sql_path=sql_path, + resource_estimates=resource_estimates, roles=roles, timezone=timezone, encoding=encoding, @@ -229,6 +237,10 @@ def __init__( self.max_attempts = max_attempts self.request_timeout = request_timeout self.client_tags = client_tags + self.client_info = client_info + self.trace_token = trace_token + self.sql_path = sql_path + self.resource_estimates = resource_estimates self._isolation_level = isolation_level self._request = None diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index ad05aeec..020d100d 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -163,6 +163,18 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if "client_tags" in url.query: kwargs["client_tags"] = json.loads(unquote_plus(url.query["client_tags"])) + if "client_info" in url.query: + kwargs["client_info"] = unquote_plus(url.query["client_info"]) + + if "trace_token" in url.query: + kwargs["trace_token"] = unquote_plus(url.query["trace_token"]) + + if "sql_path" in url.query: + kwargs["sql_path"] = unquote_plus(url.query["sql_path"]) + + if "resource_estimates" in url.query: + kwargs["resource_estimates"] = json.loads(unquote_plus(url.query["resource_estimates"])) + if "legacy_primitive_types" in url.query: kwargs["legacy_primitive_types"] = json.loads(unquote_plus(url.query["legacy_primitive_types"])) diff --git a/trino/sqlalchemy/util.py b/trino/sqlalchemy/util.py index 58af9206..c797ec7b 100644 --- a/trino/sqlalchemy/util.py +++ b/trino/sqlalchemy/util.py @@ -26,6 +26,10 @@ def _url( http_headers: Dict[str, Union[str, int]] = None, extra_credential: Optional[List[Tuple[str, str]]] = None, client_tags: Optional[List[str]] = None, + client_info: Optional[str] = None, + trace_token: Optional[str] = None, + sql_path: Optional[str] = None, + resource_estimates: Optional[Dict[str, str]] = None, legacy_primitive_types: Optional[bool] = None, legacy_prepared_statements: Optional[bool] = None, access_token: Optional[str] = None, @@ -87,6 +91,18 @@ def _url( if client_tags is not None: trino_url += f"&client_tags={quote_plus(json.dumps(client_tags))}" + if client_info is not None: + trino_url += f"&client_info={quote_plus(client_info)}" + + if trace_token is not None: + trino_url += f"&trace_token={quote_plus(trace_token)}" + + if sql_path is not None: + trino_url += f"&sql_path={quote_plus(sql_path)}" + + if resource_estimates is not None: + trino_url += f"&resource_estimates={quote_plus(json.dumps(resource_estimates))}" + if legacy_primitive_types is not None: trino_url += f"&legacy_primitive_types={json.dumps(legacy_primitive_types)}" From 3decbb52b0a7e790b4459ab7b5023a4f440d56f9 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Tue, 17 Mar 2026 08:56:49 +0530 Subject: [PATCH 2/3] Change client_tags from List to Set to match trino-cli --- tests/integration/test_dbapi_integration.py | 6 +++--- tests/unit/test_client.py | 2 +- tests/unit/test_client_session.py | 2 +- trino/client.py | 8 +++++--- trino/sqlalchemy/util.py | 2 +- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 71b6c663..6c4f3edf 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1502,19 +1502,19 @@ def test_info_uri(trino_connection): def test_client_tags_single_tag(run_trino): client_tags = ["foo"] query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags) - assert query_client_tags == client_tags + assert set(query_client_tags) == set(client_tags) def test_client_tags_multiple_tags(run_trino): client_tags = ["foo", "bar"] query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags) - assert query_client_tags == client_tags + assert set(query_client_tags) == set(client_tags) def test_client_tags_special_characters(run_trino): client_tags = ["foo %20", "bar=test"] query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags) - assert query_client_tags == client_tags + assert set(query_client_tags) == set(client_tags) def retrieve_client_tags_from_query(run_trino, client_tags): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 61ba6111..a031d2b3 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -250,7 +250,7 @@ def test_request_client_tags_headers(mock_get_and_post): ) def assert_headers(headers): - assert headers[constants.HEADER_CLIENT_TAGS] == "tag1,tag2" + assert set(headers[constants.HEADER_CLIENT_TAGS].split(",")) == {"tag1", "tag2"} req.post("URL") _, post_kwargs = post.call_args diff --git a/tests/unit/test_client_session.py b/tests/unit/test_client_session.py index 755104d2..b63b2993 100644 --- a/tests/unit/test_client_session.py +++ b/tests/unit/test_client_session.py @@ -69,7 +69,7 @@ def test_client_session_extra_credential() -> None: def test_client_session_extra_client_tags() -> None: session = ClientSession(user="user") - assert session.client_tags == [] + assert session.client_tags == set() def test_client_session_client_info() -> None: diff --git a/trino/client.py b/trino/client.py index a6bcada6..9f05c732 100644 --- a/trino/client.py +++ b/trino/client.py @@ -47,6 +47,7 @@ import warnings from abc import abstractmethod from collections.abc import Iterator +from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from datetime import datetime @@ -59,6 +60,7 @@ from typing import List from typing import Literal from typing import Optional +from typing import Set from typing import Tuple from typing import TypedDict from typing import Union @@ -189,7 +191,7 @@ def __init__( headers: Optional[Dict[str, str]] = None, transaction_id: Optional[str] = None, extra_credential: Optional[List[Tuple[str, str]]] = None, - client_tags: Optional[List[str]] = None, + client_tags: Optional[Sequence[str]] = None, client_info: Optional[str] = None, trace_token: Optional[str] = None, sql_path: Optional[str] = None, @@ -210,7 +212,7 @@ def __init__( self._headers = headers.copy() if headers is not None else {} self._transaction_id = transaction_id self._extra_credential = extra_credential - self._client_tags = client_tags.copy() if client_tags is not None else list() + self._client_tags = set(client_tags) if client_tags is not None else set() self._client_info = client_info self._trace_token = trace_token self._sql_path = sql_path @@ -291,7 +293,7 @@ def extra_credential(self) -> Optional[List[Tuple[str, str]]]: return self._extra_credential @property - def client_tags(self) -> List[str]: + def client_tags(self) -> Set[str]: return self._client_tags @property diff --git a/trino/sqlalchemy/util.py b/trino/sqlalchemy/util.py index c797ec7b..0d7c2b3c 100644 --- a/trino/sqlalchemy/util.py +++ b/trino/sqlalchemy/util.py @@ -89,7 +89,7 @@ def _url( trino_url += f"&extra_credential={quote_plus(json.dumps(extra_credential))}" if client_tags is not None: - trino_url += f"&client_tags={quote_plus(json.dumps(client_tags))}" + trino_url += f"&client_tags={quote_plus(json.dumps(sorted(client_tags)))}" if client_info is not None: trino_url += f"&client_info={quote_plus(client_info)}" From f010d499c93dc1e5ad9c157a0ee871071c23dec5 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Tue, 17 Mar 2026 08:57:22 +0530 Subject: [PATCH 3/3] Change default request_timeout from 30s to 120s to match trino-cli --- trino/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trino/constants.py b/trino/constants.py index 9a2f0382..409d91ee 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -19,7 +19,7 @@ DEFAULT_SCHEMA: Optional[str] = None DEFAULT_AUTH: Optional[Any] = None DEFAULT_MAX_ATTEMPTS = 3 -DEFAULT_REQUEST_TIMEOUT: float = 30.0 +DEFAULT_REQUEST_TIMEOUT: float = 120.0 MAX_NT_PASSWORD_SIZE: int = 1280 HTTP = "http"