Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,19 +1502,19 @@ def test_info_uri(trino_connection):
def test_client_tags_single_tag(run_trino):
client_tags = ["foo"]
query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags)
assert query_client_tags == client_tags
assert set(query_client_tags) == set(client_tags)


def test_client_tags_multiple_tags(run_trino):
client_tags = ["foo", "bar"]
query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags)
assert query_client_tags == client_tags
assert set(query_client_tags) == set(client_tags)


def test_client_tags_special_characters(run_trino):
client_tags = ["foo %20", "bar=test"]
query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags)
assert query_client_tags == client_tags
assert set(query_client_tags) == set(client_tags)


def retrieve_client_tags_from_query(run_trino, client_tags):
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,37 @@ def setup_method(self):
source="trino-sqlalchemy",
),
),
(
make_url(trino_url(
user="user",
host="localhost",
client_info="my-app",
trace_token="trace-123",
sql_path="catalog.schema",
resource_estimates={"CPU_TIME": "10s", "PEAK_MEMORY": "1GB"},
language="en-US",
)),
'trino://user@localhost:8080/'
'?client_info=my-app'
'&language=en-US'
'&resource_estimates=%7B%22CPU_TIME%22%3A+%2210s%22%2C+%22PEAK_MEMORY%22%3A+%221GB%22%7D'
'&source=trino-sqlalchemy'
'&sql_path=catalog.schema'
'&trace_token=trace-123',
list(),
dict(
host="localhost",
port=8080,
catalog="system",
user="user",
source="trino-sqlalchemy",
client_info="my-app",
trace_token="trace-123",
sql_path="catalog.schema",
resource_estimates={"CPU_TIME": "10s", "PEAK_MEMORY": "1GB"},
language="en-US",
),
),
(
make_url(trino_url(
user="user",
Expand Down
201 changes: 200 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_request_client_tags_headers(mock_get_and_post):
)

def assert_headers(headers):
assert headers[constants.HEADER_CLIENT_TAGS] == "tag1,tag2"
assert set(headers[constants.HEADER_CLIENT_TAGS].split(",")) == {"tag1", "tag2"}

req.post("URL")
_, post_kwargs = post.call_args
Expand Down Expand Up @@ -284,6 +284,205 @@ def assert_headers(headers):
assert_headers(get_kwargs["headers"])


def test_request_client_info_headers(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(
user="test_user",
client_info="test-client-info",
),
)

req.post("URL")
_, post_kwargs = post.call_args
assert post_kwargs["headers"][constants.HEADER_CLIENT_INFO] == "test-client-info"

req.get("URL")
_, get_kwargs = get.call_args
assert get_kwargs["headers"][constants.HEADER_CLIENT_INFO] == "test-client-info"


def test_request_client_info_headers_absent(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test_user"),
)

req.post("URL")
_, post_kwargs = post.call_args
assert constants.HEADER_CLIENT_INFO not in post_kwargs["headers"]

req.get("URL")
_, get_kwargs = get.call_args
assert constants.HEADER_CLIENT_INFO not in get_kwargs["headers"]


def test_request_trace_token_headers(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(
user="test_user",
trace_token="test-trace-token",
),
)

req.post("URL")
_, post_kwargs = post.call_args
assert post_kwargs["headers"][constants.HEADER_TRACE_TOKEN] == "test-trace-token"

req.get("URL")
_, get_kwargs = get.call_args
assert get_kwargs["headers"][constants.HEADER_TRACE_TOKEN] == "test-trace-token"


def test_request_trace_token_headers_absent(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test_user"),
)

req.post("URL")
_, post_kwargs = post.call_args
assert constants.HEADER_TRACE_TOKEN not in post_kwargs["headers"]

req.get("URL")
_, get_kwargs = get.call_args
assert constants.HEADER_TRACE_TOKEN not in get_kwargs["headers"]


def test_request_sql_path_headers(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(
user="test_user",
sql_path="catalog.schema",
),
)

req.post("URL")
_, post_kwargs = post.call_args
assert post_kwargs["headers"][constants.HEADER_PATH] == "catalog.schema"

req.get("URL")
_, get_kwargs = get.call_args
assert get_kwargs["headers"][constants.HEADER_PATH] == "catalog.schema"


def test_request_sql_path_headers_absent(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test_user"),
)

req.post("URL")
_, post_kwargs = post.call_args
assert constants.HEADER_PATH not in post_kwargs["headers"]

req.get("URL")
_, get_kwargs = get.call_args
assert constants.HEADER_PATH not in get_kwargs["headers"]


def test_request_resource_estimates_headers(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(
user="test_user",
resource_estimates={"CPU_TIME": "10s", "PEAK_MEMORY": "1GB"},
),
)

req.post("URL")
_, post_kwargs = post.call_args
header = post_kwargs["headers"][constants.HEADER_RESOURCE_ESTIMATE]
assert "CPU_TIME=10s" in header
assert "PEAK_MEMORY=1GB" in header

req.get("URL")
_, get_kwargs = get.call_args
header = get_kwargs["headers"][constants.HEADER_RESOURCE_ESTIMATE]
assert "CPU_TIME=10s" in header
assert "PEAK_MEMORY=1GB" in header


def test_request_resource_estimates_headers_absent(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test_user"),
)

req.post("URL")
_, post_kwargs = post.call_args
assert constants.HEADER_RESOURCE_ESTIMATE not in post_kwargs["headers"]

req.get("URL")
_, get_kwargs = get.call_args
assert constants.HEADER_RESOURCE_ESTIMATE not in get_kwargs["headers"]


def test_request_language_headers(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(
user="test_user",
language="en-US",
),
)

req.post("URL")
_, post_kwargs = post.call_args
assert post_kwargs["headers"][constants.HEADER_LANGUAGE] == "en-US"

req.get("URL")
_, get_kwargs = get.call_args
assert get_kwargs["headers"][constants.HEADER_LANGUAGE] == "en-US"


def test_request_language_headers_absent(mock_get_and_post):
get, post = mock_get_and_post

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test_user"),
)

req.post("URL")
_, post_kwargs = post.call_args
assert constants.HEADER_LANGUAGE not in post_kwargs["headers"]

req.get("URL")
_, get_kwargs = get.call_args
assert constants.HEADER_LANGUAGE not in get_kwargs["headers"]


def test_enabling_https_automatically_when_using_port_443(mock_get_and_post):
_, post = mock_get_and_post

Expand Down
27 changes: 26 additions & 1 deletion tests/unit/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,32 @@ def test_client_session_extra_credential() -> None:

def test_client_session_extra_client_tags() -> None:
session = ClientSession(user="user")
assert session.client_tags == []
assert session.client_tags == set()


def test_client_session_client_info() -> None:
session = ClientSession(user="user")
assert session.client_info is None


def test_client_session_trace_token() -> None:
session = ClientSession(user="user")
assert session.trace_token is None


def test_client_session_sql_path() -> None:
session = ClientSession(user="user")
assert session.sql_path is None


def test_client_session_resource_estimates() -> None:
session = ClientSession(user="user")
assert session.resource_estimates == {}


def test_client_session_language() -> None:
session = ClientSession(user="user")
assert session.language is None


@pytest.mark.parametrize(
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,52 @@ def test_tags_are_set_when_specified(mock_client):
assert passed_client_tags["client_tags"] == client_tags


@patch("trino.dbapi.trino.client")
def test_client_info_is_set_when_specified(mock_client):
with connect("sample_trino_cluster:443", client_info="test-info") as conn:
conn.cursor().execute("SOME FAKE QUERY")

_, passed_kwargs = mock_client.ClientSession.call_args
assert passed_kwargs["client_info"] == "test-info"


@patch("trino.dbapi.trino.client")
def test_trace_token_is_set_when_specified(mock_client):
with connect("sample_trino_cluster:443", trace_token="test-token") as conn:
conn.cursor().execute("SOME FAKE QUERY")

_, passed_kwargs = mock_client.ClientSession.call_args
assert passed_kwargs["trace_token"] == "test-token"


@patch("trino.dbapi.trino.client")
def test_sql_path_is_set_when_specified(mock_client):
with connect("sample_trino_cluster:443", sql_path="catalog.schema") as conn:
conn.cursor().execute("SOME FAKE QUERY")

_, passed_kwargs = mock_client.ClientSession.call_args
assert passed_kwargs["sql_path"] == "catalog.schema"


@patch("trino.dbapi.trino.client")
def test_resource_estimates_is_set_when_specified(mock_client):
resource_estimates = {"CPU_TIME": "10s", "PEAK_MEMORY": "1GB"}
with connect("sample_trino_cluster:443", resource_estimates=resource_estimates) as conn:
conn.cursor().execute("SOME FAKE QUERY")

_, passed_kwargs = mock_client.ClientSession.call_args
assert passed_kwargs["resource_estimates"] == resource_estimates


@patch("trino.dbapi.trino.client")
def test_language_is_set_when_specified(mock_client):
with connect("sample_trino_cluster:443", language="en-US") as conn:
conn.cursor().execute("SOME FAKE QUERY")

_, passed_kwargs = mock_client.ClientSession.call_args
assert passed_kwargs["language"] == "en-US"


@patch("trino.dbapi.trino.client")
def test_role_is_set_when_specified(mock_client):
roles = {"system": "finance"}
Expand Down
Loading
Loading