Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions etc/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,8 @@ protocol.spooling.enabled=true
protocol.spooling.shared-secret-key=jxTKysfCBuMZtFqUf8UJDQ1w9ez8rynEJsJqgJf66u0=
protocol.spooling.retrieval-mode=coordinator_proxy

# Enable dynamic catalog management
catalog.management=dynamic

# Disable http request log
http-server.log.enabled=false
4 changes: 1 addition & 3 deletions tests/development_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ def start_development_server(port=None, trino_version=TRINO_VERSION):

root = Path(__file__).parent.parent

trino = trino \
.with_volume_mapping(str(root / "etc/catalog"), "/etc/trino/catalog")

# Enable spooling config
if supports_spooling_protocol:
trino \
Expand All @@ -89,6 +86,7 @@ def start_development_server(port=None, trino_version=TRINO_VERSION):
.with_volume_mapping(str(root / "etc/config.properties"), "/etc/trino/config.properties")
else:
trino \
.with_volume_mapping(str(root / "etc/catalog"), "/etc/trino/catalog") \
.with_volume_mapping(str(root / "etc/jvm-pre-466.config"), "/etc/trino/jvm.config") \
.with_volume_mapping(str(root / "etc/config-pre-466.properties"), "/etc/trino/config.properties")

Expand Down
97 changes: 97 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sqlalchemy.sql import or_
from sqlalchemy.types import ARRAY

import trino.dbapi
from tests.integration.conftest import trino_version
from tests.unit.conftest import sqlalchemy_version
from trino.sqlalchemy.datatype import JSON
Expand Down Expand Up @@ -757,3 +758,99 @@ def _num_queries_containing_string(connection, query_string):
result = connection.execute(statement)
rows = result.fetchall()
return len(list(filter(lambda rec: query_string in rec[0], rows)))


@pytest.mark.skipif(trino_version() == 351, reason="Dynamic catalogs not supported")
def test_get_indexes_returns_empty_for_iceberg_table(run_trino):
    """get_indexes() must return [] for a partitioned Iceberg table.

    Iceberg's ``<table>$partitions`` metadata table describes partition
    metadata (record/file counts, sizes) rather than partition key columns,
    so the dialect should not surface any of it as index information.
    """
    host, port = run_trino
    catalog_name = "test_iceberg"
    schema_name = "test_schema"
    table_name = "partitioned"

    conn = trino.dbapi.connect(host=host, port=port, user="test")
    engine = None
    try:
        cur = conn.cursor()
        # Requires catalog.management=dynamic on the server (see etc/config.properties).
        cur.execute(
            f"CREATE CATALOG {catalog_name} USING iceberg "
            f"WITH (\"iceberg.catalog.type\" = 'TESTING_FILE_METASTORE', "
            f"\"hive.metastore.catalog.dir\" = 'file:///tmp/iceberg-test', "
            f"\"fs.native-local.enabled\" = 'true')"
        )
        cur.fetchall()
        cur.execute(f"CREATE SCHEMA {catalog_name}.{schema_name}")
        cur.fetchall()
        cur.execute(
            f"CREATE TABLE {catalog_name}.{schema_name}.{table_name} "
            f"(id INTEGER, year INTEGER) "
            f"WITH (partitioning = ARRAY['year'])"
        )
        cur.fetchall()
        # At least one row so the $partitions metadata table is non-empty.
        cur.execute(
            f"INSERT INTO {catalog_name}.{schema_name}.{table_name} VALUES (1, 2023)"
        )
        cur.fetchall()

        engine = sqla.create_engine(
            f"trino://test@{host}:{port}/{catalog_name}",
            connect_args={"source": "test", "max_attempts": 1},
        )
        indexes = sqla.inspect(engine).get_indexes(table_name, schema=schema_name)
        assert indexes == []
    finally:
        # Best-effort cleanup: attempt every DROP even if an earlier one
        # fails, and always release the engine and the connection so a
        # cleanup error cannot mask the test result or leak resources.
        try:
            cur = conn.cursor()
            for statement in (
                f"DROP TABLE IF EXISTS {catalog_name}.{schema_name}.{table_name}",
                f"DROP SCHEMA IF EXISTS {catalog_name}.{schema_name}",
                f"DROP CATALOG IF EXISTS {catalog_name}",
            ):
                try:
                    cur.execute(statement)
                    cur.fetchall()
                except Exception:
                    # Deliberately swallowed: teardown is best-effort.
                    pass
        finally:
            if engine is not None:
                engine.dispose()
            conn.close()


@pytest.mark.skipif(trino_version() == 351, reason="Dynamic catalogs not supported")
def test_get_indexes_returns_partitions_for_hive_table(run_trino):
    """get_indexes() must expose Hive partition columns as a pseudo-index.

    For a Hive table the ``<table>$partitions`` metadata table's columns ARE
    the partition key names, so the dialect reports them as a single index
    named "partition".
    """
    host, port = run_trino
    catalog_name = "test_hive"
    schema_name = "test_schema"
    table_name = "partitioned"

    conn = trino.dbapi.connect(host=host, port=port, user="test")
    engine = None
    try:
        cur = conn.cursor()
        # Requires catalog.management=dynamic on the server (see etc/config.properties).
        cur.execute(
            f"CREATE CATALOG {catalog_name} USING hive "
            f"WITH (\"hive.metastore\" = 'file', "
            f"\"hive.metastore.catalog.dir\" = 'file:///tmp/hive-test', "
            f"\"fs.native-local.enabled\" = 'true')"
        )
        cur.fetchall()
        cur.execute(f"CREATE SCHEMA {catalog_name}.{schema_name}")
        cur.fetchall()
        cur.execute(
            f"CREATE TABLE {catalog_name}.{schema_name}.{table_name} "
            f"(id INTEGER, name VARCHAR, region VARCHAR) "
            f"WITH (partitioned_by = ARRAY['name', 'region'])"
        )
        cur.fetchall()
        # At least one row so the $partitions metadata table is non-empty.
        cur.execute(
            f"INSERT INTO {catalog_name}.{schema_name}.{table_name} VALUES (1, 'alice', 'us-east')"
        )
        cur.fetchall()

        engine = sqla.create_engine(
            f"trino://test@{host}:{port}/{catalog_name}",
            connect_args={"source": "test", "max_attempts": 1},
        )
        indexes = sqla.inspect(engine).get_indexes(table_name, schema=schema_name)
        assert len(indexes) == 1
        assert indexes[0]["name"] == "partition"
        assert indexes[0]["column_names"] == ["name", "region"]
    finally:
        # Best-effort cleanup: attempt every DROP even if an earlier one
        # fails, and always release the engine and the connection so a
        # cleanup error cannot mask the test result or leak resources.
        try:
            cur = conn.cursor()
            for statement in (
                f"DROP TABLE IF EXISTS {catalog_name}.{schema_name}.{table_name}",
                f"DROP SCHEMA IF EXISTS {catalog_name}.{schema_name}",
                f"DROP CATALOG IF EXISTS {catalog_name}",
            ):
                try:
                    cur.execute(statement)
                    cur.fetchall()
                except Exception:
                    # Deliberately swallowed: teardown is best-effort.
                    pass
        finally:
            if engine is not None:
                engine.dispose()
            conn.close()
15 changes: 13 additions & 2 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _get_partitions(
connection: Connection,
table_name: str,
schema: str = None
) -> List[Dict[str, List[Any]]]:
) -> Optional[List[str]]:
schema = schema or self._get_default_schema_name(connection)
query = dedent(
f"""
Expand All @@ -223,6 +223,17 @@ def _get_partitions(
).strip()
res = connection.execute(sql.text(query))
partition_names = [desc[0] for desc in res.cursor.description]
data_types = [desc[1] for desc in res.cursor.description]
# Compare the column names and types to the shape of an Iceberg $partitions table
if (partition_names == ['partition', 'record_count', 'file_count', 'total_size', 'data']
and data_types[0].startswith('row(')
and data_types[1] == 'bigint'
and data_types[2] == 'bigint'
and data_types[3] == 'bigint'
and data_types[4].startswith('row(')):
# This is an Iceberg $partitions table - these match the partition metadata columns
return None
# This is a Hive table - these are the partition names
return partition_names

def get_pk_constraint(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]:
Expand Down Expand Up @@ -322,7 +333,7 @@ def get_indexes(self, connection: Connection, table_name: str, schema: str = Non
try:
partitioned_columns = self._get_partitions(connection, f"{table_name}", schema)
except Exception as e:
# e.g. it's not a Hive table or an unpartitioned Hive table
# e.g. it's an unpartitioned Hive table
logger.debug("Couldn't fetch partition columns. schema: %s, table: %s, error: %s", schema, table_name, e)
if not partitioned_columns:
return []
Expand Down