diff --git a/docs/integrations/engines/trino.md b/docs/integrations/engines/trino.md index c590ee32ba..ec1139e20d 100644 --- a/docs/integrations/engines/trino.md +++ b/docs/integrations/engines/trino.md @@ -81,19 +81,21 @@ hive.metastore.glue.default-warehouse-dir=s3://my-bucket/ ### Connection options -| Option | Description | Type | Required | -|----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| -| `type` | Engine type name - must be `trino` | string | Y | -| `user` | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y | -| `host` | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y | -| `catalog` | The name of a catalog in your cluster. | string | Y | -| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N | -| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N | -| `roles` | Mapping of catalog name to a role | dict | N | -| `http_headers` | Additional HTTP headers to send with each request. | dict | N | -| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N | -| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N | -| `timezone` | Timezone to use for the connection. Default: client-side local timezone | string | N | +| Option | Description | Type | Required | +|---------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| +| `type` | Engine type name - must be `trino` | string | Y | +| `user` | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y | +| `host` | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y | +| `catalog` | The name of a catalog in your cluster. | string | Y | +| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N | +| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N | +| `roles` | Mapping of catalog name to a role | dict | N | +| `http_headers` | Additional HTTP headers to send with each request. | dict | N | +| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N | +| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N | +| `timezone` | Timezone to use for the connection. Default: client-side local timezone | string | N | +| `schema_location_mapping` | A mapping of regex patterns to S3 locations to use for the `LOCATION` property when creating schemas. See [Table and Schema locations](#table-and-schema-locations) for more details. | dict | N | +| `catalog_type_overrides` | A mapping of catalog names to their connector type. This is used to enable/disable connector specific behavior. See [Catalog Type Overrides](#catalog-type-overrides) for more details. | dict | N | ## Table and Schema locations @@ -204,6 +206,25 @@ SELECT ... This will cause SQLMesh to set the specified `LOCATION` when issuing a `CREATE TABLE` statement. +## Catalog Type Overrides + +SQLMesh attempts to determine the connector type of a catalog by querying the `system.metadata.catalogs` table and checking the `connector_name` column. +It checks if the connector name is `hive` for Hive connector behavior or contains `iceberg` or `delta_lake` for Iceberg or Delta Lake connector behavior respectively. +However, the connector name may not always be a reliable way to determine the connector type, for example when using a custom connector or a fork of an existing connector. +To handle such cases, you can use the `catalog_type_overrides` connection property to explicitly specify the connector type for specific catalogs. +For example, to specify that the `datalake` catalog is using the Iceberg connector and the `analytics` catalog is using the Hive connector, you can configure the connection as follows: + +```yaml title="config.yaml" +gateways: + trino: + connection: + type: trino + ... + catalog_type_overrides: + datalake: iceberg + analytics: hive +``` + ## Authentication === "No Auth" diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 1678f5d147..553ffd58a5 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -101,6 +101,7 @@ class ConnectionConfig(abc.ABC, BaseConfig): pre_ping: bool pretty_sql: bool = False schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None + catalog_type_overrides: t.Optional[t.Dict[str, str]] = None # Whether to share a single connection across threads or create a new connection per thread. shared_connection: t.ClassVar[bool] = False @@ -176,6 +177,7 @@ def create_engine_adapter( pretty_sql=self.pretty_sql, shared_connection=self.shared_connection, schema_differ_overrides=self.schema_differ_overrides, + catalog_type_overrides=self.catalog_type_overrides, **self._extra_engine_config, ) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index c48ce2154d..81d0769cdd 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -223,6 +223,10 @@ def schema_differ(self) -> SchemaDiffer: } ) + @property + def _catalog_type_overrides(self) -> t.Dict[str, str]: + return self._extra_config.get("catalog_type_overrides") or {} + @classmethod def _casted_columns( cls, @@ -430,7 +434,11 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str: raise UnsupportedCatalogOperationError( f"{self.dialect} does not support catalogs and a catalog was provided: {catalog}" ) - return self.DEFAULT_CATALOG_TYPE + return ( + self._catalog_type_overrides.get(catalog, self.DEFAULT_CATALOG_TYPE) + if catalog + else self.DEFAULT_CATALOG_TYPE + ) def get_catalog_type_from_table(self, table: TableName) -> str: """Get the catalog type from a table name if it has a catalog specified, otherwise return the current catalog type""" diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 0e6853dd4a..21846b8693 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -71,7 +71,7 @@ class TrinoEngineAdapter( MAX_TIMESTAMP_PRECISION = 3 @property - def schema_location_mapping(self) -> t.Optional[dict[re.Pattern, str]]: + def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]: return self._extra_config.get("schema_location_mapping") @property @@ -86,6 +86,8 @@ def set_current_catalog(self, catalog: str) -> None: def get_catalog_type(self, catalog: t.Optional[str]) -> str: row: t.Tuple = tuple() if catalog: + if catalog_type_override := self._catalog_type_overrides.get(catalog): + return catalog_type_override row = ( self.fetchone( f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'" diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index e460387bbc..ba987d76da 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -957,6 +957,7 @@ def test_dlt_filesystem_pipeline(tmp_path): " # pre_ping: False\n" " # pretty_sql: False\n" " # schema_differ_overrides: \n" + " # catalog_type_overrides: \n" " # aws_access_key_id: \n" " # aws_secret_access_key: \n" " # role_arn: \n" @@ -1960,11 +1961,11 @@ def test_init_dbt_template(runner: CliRunner, tmp_path: Path): @time_machine.travel(FREEZE_TIME) def test_init_project_engine_configs(tmp_path): engine_type_to_config = { - "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ", - "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ", - "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ", - "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False", - "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ", + "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ", + "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ", + "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ", + "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False", + "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ", } for engine_type, expected_config in engine_type_to_config.items(): diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py index 07c4657eb3..526cb05b04 100644 --- a/tests/core/engine_adapter/test_trino.py +++ b/tests/core/engine_adapter/test_trino.py @@ -11,6 +11,7 @@ from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import SqlModel from sqlmesh.core.dialect import schema_ +from sqlmesh.utils.date import to_ds from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls @@ -683,3 +684,74 @@ def test_replace_table_catalog_support( sql_calls[0] == f'CREATE TABLE IF NOT EXISTS "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"' ) + + +@pytest.mark.parametrize( + "catalog_type_overrides", [{}, {"my_catalog": "hive"}, {"other_catalog": "iceberg"}] +) +def test_insert_overwrite_time_partition_hive( + make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str] +): + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + catalog_type_overrides=catalog_type_overrides, + ) + adapter: TrinoEngineAdapter = make_mocked_engine_adapter( + TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides + ) + adapter.fetchone = MagicMock(return_value=None) # type: ignore + + adapter.insert_overwrite_by_time_partition( + table_name=".".join(["my_catalog", "schema", "test_table"]), + query_or_df=parse_one("SELECT a, b FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + assert to_sql_calls(adapter) == [ + "SET SESSION my_catalog.insert_existing_partitions_behavior='OVERWRITE'", + 'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + "SET SESSION my_catalog.insert_existing_partitions_behavior='APPEND'", + ] + + +@pytest.mark.parametrize( + "catalog_type_overrides", + [ + {"my_catalog": "iceberg"}, + {"my_catalog": "unknown"}, + ], +) +def test_insert_overwrite_time_partition_iceberg( + make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str] +): + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + catalog_type_overrides=catalog_type_overrides, + ) + adapter: TrinoEngineAdapter = make_mocked_engine_adapter( + TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides + ) + adapter.fetchone = MagicMock(return_value=None) # type: ignore + + adapter.insert_overwrite_by_time_partition( + table_name=".".join(["my_catalog", "schema", "test_table"]), + query_or_df=parse_one("SELECT a, b FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + assert to_sql_calls(adapter) == [ + 'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + 'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + ] diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 907d1b70cc..4e71e18148 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -425,6 +425,25 @@ def test_trino_schema_location_mapping(make_config): assert all((isinstance(v, str) for v in config.schema_location_mapping.values())) +def test_trino_catalog_type_override(make_config): + required_kwargs = dict( + type="trino", + user="user", + host="host", + catalog="catalog", + ) + + config = make_config( + **required_kwargs, + catalog_type_overrides={"my_catalog": "iceberg"}, + ) + + assert config.catalog_type_overrides is not None + assert len(config.catalog_type_overrides) == 1 + + assert config.catalog_type_overrides == {"my_catalog": "iceberg"} + + def test_duckdb(make_config): config = make_config( type="duckdb",