47 changes: 34 additions & 13 deletions docs/integrations/engines/trino.md
@@ -81,19 +81,21 @@ hive.metastore.glue.default-warehouse-dir=s3://my-bucket/

### Connection options

| Option | Description | Type | Required |
|----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:|
| `type` | Engine type name - must be `trino` | string | Y |
| `user` | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y |
| `host` | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y |
| `catalog` | The name of a catalog in your cluster. | string | Y |
| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N |
| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N |
| `roles` | Mapping of catalog name to a role | dict | N |
| `http_headers` | Additional HTTP headers to send with each request. | dict | N |
| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N |
| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N |
| `timezone` | Timezone to use for the connection. Default: client-side local timezone | string | N |
| Option | Description | Type | Required |
|---------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:|
| `type` | Engine type name - must be `trino` | string | Y |
| `user` | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y |
| `host` | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y |
| `catalog` | The name of a catalog in your cluster. | string | Y |
| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N |
| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N |
| `roles` | Mapping of catalog name to a role | dict | N |
| `http_headers` | Additional HTTP headers to send with each request. | dict | N |
| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N |
| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N |
| `timezone` | Timezone to use for the connection. Default: client-side local timezone | string | N |
| `schema_location_mapping` | A mapping of regex patterns to S3 locations to use for the `LOCATION` property when creating schemas. See [Table and Schema locations](#table-and-schema-locations) for more details. | dict | N |
| `catalog_type_overrides` | A mapping of catalog names to their connector type. This is used to enable/disable connector specific behavior. See [Catalog Type Overrides](#catalog-type-overrides) for more details. | dict | N |

## Table and Schema locations

@@ -204,6 +206,25 @@ SELECT ...

This will cause SQLMesh to set the specified `LOCATION` when issuing a `CREATE TABLE` statement.

## Catalog Type Overrides

SQLMesh attempts to determine the connector type of a catalog by querying the `system.metadata.catalogs` table and checking the `connector_name` column: a connector name of `hive` selects Hive connector behavior, while a name containing `iceberg` or `delta_lake` selects Iceberg or Delta Lake connector behavior, respectively.
However, the connector name is not always a reliable indicator of the connector type, for example when using a custom connector or a fork of an existing connector.
To handle such cases, you can use the `catalog_type_overrides` connection property to explicitly specify the connector type for specific catalogs.
For example, to specify that the `datalake` catalog is using the Iceberg connector and the `analytics` catalog is using the Hive connector, you can configure the connection as follows:

```yaml title="config.yaml"
gateways:
  trino:
    connection:
      type: trino
      ...
      catalog_type_overrides:
        datalake: iceberg
        analytics: hive
```
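To make the precedence concrete, here is a small standalone sketch (not SQLMesh's actual implementation) of how an explicit override takes priority over the connector-name heuristic described above. The function name and default are illustrative assumptions.

```python
# Hypothetical sketch: an explicit catalog_type_overrides entry wins;
# otherwise fall back to the connector-name heuristic from the docs.
DEFAULT_CATALOG_TYPE = "hive"


def resolve_catalog_type(
    catalog: str,
    connector_name: str,
    overrides: dict[str, str],
) -> str:
    # An explicit override always takes priority.
    if catalog in overrides:
        return overrides[catalog]
    # Heuristic: exact "hive", or substring match for iceberg/delta_lake.
    if connector_name == "hive":
        return "hive"
    if "iceberg" in connector_name:
        return "iceberg"
    if "delta_lake" in connector_name:
        return "delta_lake"
    return DEFAULT_CATALOG_TYPE


print(resolve_catalog_type("datalake", "custom_iceberg_fork", {}))
# -> iceberg (heuristic matches the substring)
print(resolve_catalog_type("datalake", "my_connector", {"datalake": "iceberg"}))
# -> iceberg (explicit override, no heuristic needed)
```

This mirrors the motivation in the docs: a fork whose `connector_name` no longer matches any heuristic can still be classified correctly via an override.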

## Authentication

=== "No Auth"
2 changes: 2 additions & 0 deletions sqlmesh/core/config/connection.py
@@ -101,6 +101,7 @@ class ConnectionConfig(abc.ABC, BaseConfig):
    pre_ping: bool
    pretty_sql: bool = False
    schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
    catalog_type_overrides: t.Optional[t.Dict[str, str]] = None

    # Whether to share a single connection across threads or create a new connection per thread.
    shared_connection: t.ClassVar[bool] = False
@@ -176,6 +177,7 @@ def create_engine_adapter(
            pretty_sql=self.pretty_sql,
            shared_connection=self.shared_connection,
            schema_differ_overrides=self.schema_differ_overrides,
            catalog_type_overrides=self.catalog_type_overrides,
            **self._extra_engine_config,
        )

10 changes: 9 additions & 1 deletion sqlmesh/core/engine_adapter/base.py
@@ -223,6 +223,10 @@ def schema_differ(self) -> SchemaDiffer:
            }
        )

    @property
    def _catalog_type_overrides(self) -> t.Dict[str, str]:
        return self._extra_config.get("catalog_type_overrides") or {}

    @classmethod
    def _casted_columns(
        cls,
@@ -430,7 +434,11 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str:
            raise UnsupportedCatalogOperationError(
                f"{self.dialect} does not support catalogs and a catalog was provided: {catalog}"
            )
        return self.DEFAULT_CATALOG_TYPE
        return (
            self._catalog_type_overrides.get(catalog, self.DEFAULT_CATALOG_TYPE)
            if catalog
            else self.DEFAULT_CATALOG_TYPE
        )

    def get_catalog_type_from_table(self, table: TableName) -> str:
        """Get the catalog type from a table name if it has a catalog specified, otherwise return the current catalog type"""
4 changes: 3 additions & 1 deletion sqlmesh/core/engine_adapter/trino.py
@@ -71,7 +71,7 @@ class TrinoEngineAdapter(
    MAX_TIMESTAMP_PRECISION = 3

    @property
    def schema_location_mapping(self) -> t.Optional[dict[re.Pattern, str]]:
    def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
        return self._extra_config.get("schema_location_mapping")

    @property
@@ -86,6 +86,8 @@ def set_current_catalog(self, catalog: str) -> None:
    def get_catalog_type(self, catalog: t.Optional[str]) -> str:
        row: t.Tuple = tuple()
        if catalog:
            if catalog_type_override := self._catalog_type_overrides.get(catalog):
                return catalog_type_override
            row = (
                self.fetchone(
                    f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'"
11 changes: 6 additions & 5 deletions tests/cli/test_cli.py
@@ -957,6 +957,7 @@ def test_dlt_filesystem_pipeline(tmp_path):
        " # pre_ping: False\n"
        " # pretty_sql: False\n"
        " # schema_differ_overrides: \n"
        " # catalog_type_overrides: \n"
        " # aws_access_key_id: \n"
        " # aws_secret_access_key: \n"
        " # role_arn: \n"
@@ -1960,11 +1961,11 @@ def test_init_dbt_template(runner: CliRunner, tmp_path: Path):
@time_machine.travel(FREEZE_TIME)
def test_init_project_engine_configs(tmp_path):
    engine_type_to_config = {
        "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ",
        "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ",
        "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ",
        "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False",
        "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ",
        "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ",
        "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ",
        "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ",
        "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False",
        "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ",
    }

    for engine_type, expected_config in engine_type_to_config.items():
72 changes: 72 additions & 0 deletions tests/core/engine_adapter/test_trino.py
@@ -11,6 +11,7 @@
from sqlmesh.core.model import load_sql_based_model
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.core.dialect import schema_
from sqlmesh.utils.date import to_ds
from sqlmesh.utils.errors import SQLMeshError
from tests.core.engine_adapter import to_sql_calls

@@ -683,3 +684,74 @@ def test_replace_table_catalog_support(
        sql_calls[0]
        == f'CREATE TABLE IF NOT EXISTS "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"'
    )


@pytest.mark.parametrize(
    "catalog_type_overrides", [{}, {"my_catalog": "hive"}, {"other_catalog": "iceberg"}]
)
def test_insert_overwrite_time_partition_hive(
    make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str]
):
    config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        catalog_type_overrides=catalog_type_overrides,
    )
    adapter: TrinoEngineAdapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides
    )
    adapter.fetchone = MagicMock(return_value=None)  # type: ignore

    adapter.insert_overwrite_by_time_partition(
        table_name=".".join(["my_catalog", "schema", "test_table"]),
        query_or_df=parse_one("SELECT a, b FROM tbl"),
        start="2022-01-01",
        end="2022-01-02",
        time_column="b",
        time_formatter=lambda x, _: exp.Literal.string(to_ds(x)),
        target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")},
    )

    assert to_sql_calls(adapter) == [
        "SET SESSION my_catalog.insert_existing_partitions_behavior='OVERWRITE'",
        'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
        "SET SESSION my_catalog.insert_existing_partitions_behavior='APPEND'",
    ]


@pytest.mark.parametrize(
    "catalog_type_overrides",
    [
        {"my_catalog": "iceberg"},
        {"my_catalog": "unknown"},
    ],
)
def test_insert_overwrite_time_partition_iceberg(
    make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str]
):
    config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        catalog_type_overrides=catalog_type_overrides,
    )
    adapter: TrinoEngineAdapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides
    )
    adapter.fetchone = MagicMock(return_value=None)  # type: ignore

    adapter.insert_overwrite_by_time_partition(
        table_name=".".join(["my_catalog", "schema", "test_table"]),
        query_or_df=parse_one("SELECT a, b FROM tbl"),
        start="2022-01-01",
        end="2022-01-02",
        time_column="b",
        time_formatter=lambda x, _: exp.Literal.string(to_ds(x)),
        target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")},
    )

    assert to_sql_calls(adapter) == [
        'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
        'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
    ]