diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml
index b93caf482e..afaf0e080b 100644
--- a/.circleci/continue_config.yml
+++ b/.circleci/continue_config.yml
@@ -303,6 +303,7 @@ workflows:
               - bigquery
               - clickhouse-cloud
               - athena
+              - fabric
       filters:
         branches:
           only:
diff --git a/.circleci/install-prerequisites.sh b/.circleci/install-prerequisites.sh
index 1eebd92c71..acd25ae02c 100755
--- a/.circleci/install-prerequisites.sh
+++ b/.circleci/install-prerequisites.sh
@@ -12,7 +12,7 @@ fi
 
 ENGINE="$1"
 
-COMMON_DEPENDENCIES="libpq-dev netcat-traditional"
+COMMON_DEPENDENCIES="libpq-dev netcat-traditional unixodbc-dev"
 ENGINE_DEPENDENCIES=""
 
 if [ "$ENGINE" == "spark" ]; then
diff --git a/Makefile b/Makefile
index 0a89bba437..e643ae7ad2 100644
--- a/Makefile
+++ b/Makefile
@@ -173,6 +173,9 @@ clickhouse-cloud-test: guard-CLICKHOUSE_CLOUD_HOST guard-CLICKHOUSE_CLOUD_USERNA
 athena-test: guard-AWS_ACCESS_KEY_ID guard-AWS_SECRET_ACCESS_KEY guard-ATHENA_S3_WAREHOUSE_LOCATION engine-athena-install
 	pytest -n auto -m "athena" --retries 3 --junitxml=test-results/junit-athena.xml
 
+fabric-test: guard-FABRIC_HOST guard-FABRIC_CLIENT_ID guard-FABRIC_CLIENT_SECRET guard-FABRIC_DATABASE engine-fabric-install
+	pytest -n auto -m "fabric" --retries 3 --junitxml=test-results/junit-fabric.xml
+
 vscode_settings:
 	mkdir -p .vscode
 	cp -r ./tooling/vscode/*.json .vscode/
diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md
index 6e14d1f605..005e78541b 100644
--- a/docs/guides/configuration.md
+++ b/docs/guides/configuration.md
@@ -869,6 +869,7 @@ These pages describe the connection configuration options for each execution eng
 * [BigQuery](../integrations/engines/bigquery.md)
 * [Databricks](../integrations/engines/databricks.md)
 * [DuckDB](../integrations/engines/duckdb.md)
+* [Fabric](../integrations/engines/fabric.md)
 * [MotherDuck](../integrations/engines/motherduck.md)
 * [MySQL](../integrations/engines/mysql.md)
 * [MSSQL](../integrations/engines/mssql.md)
diff --git a/docs/integrations/engines/fabric.md b/docs/integrations/engines/fabric.md
new file mode 100644
index 0000000000..eb00b5ac1d
--- /dev/null
+++ b/docs/integrations/engines/fabric.md
@@ -0,0 +1,34 @@
+# Fabric
+
+## Local/Built-in Scheduler
+**Engine Adapter Type**: `fabric`
+
+NOTE: Fabric Warehouse is not recommended as the SQLMesh [state connection](../../reference/configuration.md#connections).
+
+### Installation
+#### Microsoft Entra ID / Azure Active Directory Authentication:
+```
+pip install "sqlmesh[fabric]"
+```
+
+### Connection options
+
+| Option            | Description                                                   |     Type     | Required |
+| ----------------- | ------------------------------------------------------------ | :----------: | :------: |
+| `type`            | Engine type name - must be `fabric` | string | Y |
+| `host`            | The hostname of the Fabric Warehouse server | string | Y |
+| `user`            | The client id to use for authentication with the Fabric Warehouse server | string | N |
+| `password`        | The client secret to use for authentication with the Fabric Warehouse server | string | N |
+| `port`            | The port number of the Fabric Warehouse server | int | N |
+| `database`        | The target database | string | N |
+| `charset`         | The character set used for the connection | string | N |
+| `timeout`         | The query timeout in seconds. Default: no timeout | int | N |
+| `login_timeout`   | The timeout for connection and login in seconds. Default: 60 | int | N |
+| `appname`         | The application name to use for the connection | string | N |
+| `conn_properties` | The list of connection properties | list[string] | N |
+| `autocommit`      | Whether autocommit mode is enabled. Default: true | bool | N |
+| `driver`          | The driver to use for the connection. Must be `pyodbc` (the default) | string | N |
+| `driver_name`     | The driver name to use for the connection. E.g., *ODBC Driver 18 for SQL Server* | string | N |
+| `tenant_id`       | The Azure / Entra tenant UUID | string | Y |
+| `workspace_id`    | The Fabric workspace UUID. The preferred way to retrieve it is by running `notebookutils.runtime.context.get("currentWorkspaceId")` in a Python notebook. | string | Y |
+| `odbc_properties` | The dict of ODBC connection properties. E.g., authentication: ActiveDirectoryServicePrincipal. See more [here](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16). | dict | N |
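+
+A minimal gateway configuration might look like the following (the host, IDs, and secrets are placeholders):
+
+```yaml
+gateways:
+  fabric:
+    connection:
+      type: fabric
+      host: yourserver.datawarehouse.fabric.microsoft.com
+      user: <client id>
+      password: <client secret>
+      database: <warehouse name>
+      tenant_id: <tenant UUID>
+      workspace_id: <workspace UUID>
+      odbc_properties:
+        Authentication: ActiveDirectoryServicePrincipal
+```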
diff --git a/docs/integrations/overview.md b/docs/integrations/overview.md
index 5e850afbf6..94b9289d21 100644
--- a/docs/integrations/overview.md
+++ b/docs/integrations/overview.md
@@ -17,6 +17,7 @@ SQLMesh supports the following execution engines for running SQLMesh projects (e
 * [ClickHouse](./engines/clickhouse.md) (clickhouse)
 * [Databricks](./engines/databricks.md) (databricks)
 * [DuckDB](./engines/duckdb.md) (duckdb)
+* [Fabric](./engines/fabric.md) (fabric)
 * [MotherDuck](./engines/motherduck.md) (motherduck)
 * [MSSQL](./engines/mssql.md) (mssql)
 * [MySQL](./engines/mysql.md) (mysql)
diff --git a/mkdocs.yml b/mkdocs.yml
index 34156b1b66..47ddca54e9 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -83,6 +83,7 @@ nav:
       - integrations/engines/clickhouse.md
       - integrations/engines/databricks.md
       - integrations/engines/duckdb.md
+      - integrations/engines/fabric.md
       - integrations/engines/motherduck.md
       - integrations/engines/mssql.md
       - integrations/engines/mysql.md
diff --git a/pyproject.toml b/pyproject.toml
index 91a671b2d4..887f65c8ca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -103,6 +103,7 @@ dev = [
 dbt = ["dbt-core<2"]
 dlt = ["dlt"]
 duckdb = []
+fabric = ["pyodbc>=5.0.0"]
 gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"]
 github = ["PyGithub>=2.6.0"]
 llm = ["langchain", "openai"]
@@ -252,6 +253,7 @@ markers = [
     "clickhouse_cloud: test for Clickhouse (cloud mode)",
     "databricks: test for Databricks",
     "duckdb: test for DuckDB",
+    "fabric: test for Fabric",
     "motherduck: test for MotherDuck",
     "mssql: test for MSSQL",
     "mysql: test for MySQL",
diff --git a/sqlmesh/core/config/__init__.py b/sqlmesh/core/config/__init__.py
index d8c7607d51..0dc99c0fd1 100644
--- a/sqlmesh/core/config/__init__.py
+++ b/sqlmesh/core/config/__init__.py
@@ -13,6 +13,7 @@
     ConnectionConfig as ConnectionConfig,
     DatabricksConnectionConfig as DatabricksConnectionConfig,
     DuckDBConnectionConfig as DuckDBConnectionConfig,
+    FabricConnectionConfig as FabricConnectionConfig,
     GCPPostgresConnectionConfig as GCPPostgresConnectionConfig,
     MotherDuckConnectionConfig as MotherDuckConnectionConfig,
     MSSQLConnectionConfig as MSSQLConnectionConfig,
diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py
index 415e916365..e72374a877 100644
--- a/sqlmesh/core/config/connection.py
+++ b/sqlmesh/core/config/connection.py
@@ -43,7 +43,13 @@
 
 logger = logging.getLogger(__name__)
 
-RECOMMENDED_STATE_SYNC_ENGINES = {"postgres", "gcp_postgres", "mysql", "mssql", "azuresql"}
+RECOMMENDED_STATE_SYNC_ENGINES = {
+    "postgres",
+    "gcp_postgres",
+    "mysql",
+    "mssql",
"azuresql", +} FORBIDDEN_STATE_SYNC_ENGINES = { # Do not support row-level operations "spark", @@ -1690,6 +1696,40 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]: return {"catalog_support": CatalogSupport.SINGLE_CATALOG_ONLY} +class FabricConnectionConfig(MSSQLConnectionConfig): + """ + Fabric Connection Configuration. + Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'. + It is recommended to use the 'pyodbc' driver for Fabric. + """ + + type_: t.Literal["fabric"] = Field(alias="type", default="fabric") # type: ignore + DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric" # type: ignore + DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric" # type: ignore + DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17 # type: ignore + driver: t.Literal["pyodbc"] = "pyodbc" + workspace_id: str + tenant_id: str + autocommit: t.Optional[bool] = True + + @property + def _engine_adapter(self) -> t.Type[EngineAdapter]: + from sqlmesh.core.engine_adapter.fabric import FabricAdapter + + return FabricAdapter + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return { + "database": self.database, + "catalog_support": CatalogSupport.FULL_SUPPORT, + "workspace_id": self.workspace_id, + "tenant_id": self.tenant_id, + "user": self.user, + "password": self.password, + } + + class SparkConnectionConfig(ConnectionConfig): """ Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks. diff --git a/sqlmesh/core/engine_adapter/__init__.py b/sqlmesh/core/engine_adapter/__init__.py index 19332dc005..337de39905 100644 --- a/sqlmesh/core/engine_adapter/__init__.py +++ b/sqlmesh/core/engine_adapter/__init__.py @@ -19,6 +19,7 @@ from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter +from sqlmesh.core.engine_adapter.fabric import FabricAdapter DIALECT_TO_ENGINE_ADAPTER = { "hive": SparkEngineAdapter, @@ -35,6 +36,7 @@ "trino": TrinoEngineAdapter, "athena": AthenaEngineAdapter, "risingwave": RisingwaveEngineAdapter, + "fabric": FabricAdapter, } DIALECT_ALIASES = { diff --git a/sqlmesh/core/engine_adapter/fabric.py b/sqlmesh/core/engine_adapter/fabric.py new file mode 100644 index 0000000000..26e93f55ed --- /dev/null +++ b/sqlmesh/core/engine_adapter/fabric.py @@ -0,0 +1,673 @@ +from __future__ import annotations + +import typing as t +import logging +from sqlglot import exp +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result +from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter +from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery +from sqlmesh.core.engine_adapter.base import EngineAdapter +from sqlmesh.utils import optional_import +from sqlmesh.utils.errors import SQLMeshError + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import TableName, SchemaName + + +from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin + +logger = logging.getLogger(__name__) +requests = optional_import("requests") + + +class FabricAdapter(LogicalMergeMixin, MSSQLEngineAdapter): + """ + Adapter for Microsoft Fabric. 
+ """ + + DIALECT = "fabric" + SUPPORTS_INDEXES = False + SUPPORTS_TRANSACTIONS = False + SUPPORTS_CREATE_DROP_CATALOG = True + INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.DELETE_INSERT + + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: + super().__init__(*args, **kwargs) + # Store the desired catalog for dynamic switching + self._target_catalog: t.Optional[str] = None + # Store the original connection factory for wrapping + self._original_connection_factory = self._connection_pool._connection_factory # type: ignore + # Replace the connection factory with our custom one + self._connection_pool._connection_factory = self._create_fabric_connection # type: ignore + + def _create_fabric_connection(self) -> t.Any: + """Custom connection factory that uses the target catalog if set.""" + # If we have a target catalog, we need to modify the connection parameters + if self._target_catalog: + # The original factory was created with partial(), so we need to extract and modify the kwargs + if hasattr(self._original_connection_factory, "keywords"): + # It's a partial function, get the original keywords + original_kwargs = self._original_connection_factory.keywords.copy() + original_kwargs["database"] = self._target_catalog + # Call the underlying function with modified kwargs + return self._original_connection_factory.func(**original_kwargs) + + # Use the original factory if no target catalog is set + return self._original_connection_factory() + + def _insert_overwrite_by_condition( + self, + table_name: TableName, + source_queries: t.List[SourceQuery], + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + where: t.Optional[exp.Condition] = None, + insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, + **kwargs: t.Any, + ) -> None: + """ + Implements the insert overwrite strategy for Fabric using DELETE and INSERT. + + This method is overridden to avoid the MERGE statement from the parent + MSSQLEngineAdapter, which is not fully supported in Fabric. 
+ """ + return EngineAdapter._insert_overwrite_by_condition( + self, + table_name=table_name, + source_queries=source_queries, + columns_to_types=columns_to_types, + where=where, + insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, + **kwargs, + ) + + def _get_access_token(self) -> str: + """Get access token using Service Principal authentication.""" + tenant_id = self._extra_config.get("tenant_id") + client_id = self._extra_config.get("user") + client_secret = self._extra_config.get("password") + + if not all([tenant_id, client_id, client_secret]): + raise SQLMeshError( + "Service Principal authentication requires tenant_id, client_id, and client_secret " + "in the Fabric connection configuration" + ) + + if not requests: + raise SQLMeshError("requests library is required for Fabric authentication") + + # Use Azure AD OAuth2 token endpoint + token_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token" + + data = { + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + "scope": "https://api.fabric.microsoft.com/.default", + } + + try: + response = requests.post(token_url, data=data) + response.raise_for_status() + token_data = response.json() + return token_data["access_token"] + except requests.exceptions.RequestException as e: + raise SQLMeshError(f"Failed to authenticate with Azure AD: {e}") + except KeyError: + raise SQLMeshError("Invalid response from Azure AD token endpoint") + + def _get_fabric_auth_headers(self) -> t.Dict[str, str]: + """Get authentication headers for Fabric REST API calls.""" + access_token = self._get_access_token() + return {"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"} + + def _make_fabric_api_request( + self, method: str, endpoint: str, data: t.Optional[t.Dict[str, t.Any]] = None + ) -> t.Dict[str, t.Any]: + """Make a request to the Fabric REST API.""" + if not requests: + raise SQLMeshError("requests library is required for Fabric catalog operations") + + workspace_id = self._extra_config.get("workspace_id") + if not workspace_id: + raise SQLMeshError( + "workspace_id parameter is required in connection config for Fabric catalog operations" + ) + + base_url = "https://api.fabric.microsoft.com/v1" + url = f"{base_url}/workspaces/{workspace_id}/{endpoint}" + + headers = self._get_fabric_auth_headers() + + try: + if method.upper() == "GET": + response = requests.get(url, headers=headers) + elif method.upper() == "POST": + response = requests.post(url, headers=headers, json=data) + elif method.upper() == "DELETE": + response = requests.delete(url, headers=headers) + else: + raise SQLMeshError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + + if response.status_code == 204: # No content + return {} + + return response.json() if response.content else {} + + except requests.exceptions.HTTPError as e: + error_details = "" + try: + if response.content: + error_response = response.json() + error_details = error_response.get("error", {}).get( + "message", str(error_response) + ) + except (ValueError, AttributeError): + error_details = response.text if hasattr(response, "text") else str(e) + + raise SQLMeshError(f"Fabric API HTTP error ({response.status_code}): {error_details}") + except requests.exceptions.RequestException as e: + raise SQLMeshError(f"Fabric API request failed: {e}") + + def _make_fabric_api_request_with_location( + self, method: str, endpoint: str, data: t.Optional[t.Dict[str, t.Any]] = None + ) -> t.Dict[str, 
+        """Make a request to the Fabric REST API and return response with status code and location."""
+        if not requests:
+            raise SQLMeshError("requests library is required for Fabric catalog operations")
+
+        workspace_id = self._extra_config.get("workspace_id")
+        if not workspace_id:
+            raise SQLMeshError(
+                "workspace_id parameter is required in connection config for Fabric catalog operations"
+            )
+
+        base_url = "https://api.fabric.microsoft.com/v1"
+        url = f"{base_url}/workspaces/{workspace_id}/{endpoint}"
+        headers = self._get_fabric_auth_headers()
+
+        try:
+            if method.upper() == "POST":
+                response = requests.post(url, headers=headers, json=data)
+            else:
+                raise SQLMeshError(f"Unsupported HTTP method for location tracking: {method}")
+
+            # Check for errors first
+            response.raise_for_status()
+
+            result = {"status_code": response.status_code}
+
+            # Extract location header for polling
+            if "location" in response.headers:
+                result["location"] = response.headers["location"]
+
+            # Include response body if present
+            if response.content:
+                json_data = response.json()
+                if json_data:
+                    result.update(json_data)
+
+            return result
+
+        except requests.exceptions.HTTPError as e:
+            error_details = ""
+            try:
+                if response.content:
+                    error_response = response.json()
+                    error_details = error_response.get("error", {}).get(
+                        "message", str(error_response)
+                    )
+            except (ValueError, AttributeError):
+                error_details = response.text if hasattr(response, "text") else str(e)
+
+            raise SQLMeshError(f"Fabric API HTTP error ({response.status_code}): {error_details}")
+        except requests.exceptions.RequestException as e:
+            raise SQLMeshError(f"Fabric API request failed: {e}")
+
+    @retry(
+        wait=wait_exponential(multiplier=1, min=1, max=30),
+        stop=stop_after_attempt(60),
+        retry=retry_if_result(lambda result: result not in ["Succeeded", "Failed"]),
+    )
+    def _check_operation_status(self, location_url: str, operation_name: str) -> str:
+        """Check the operation status and return the status string."""
+        if not requests:
+            raise SQLMeshError("requests library is required for Fabric catalog operations")
+
+        headers = self._get_fabric_auth_headers()
+
+        try:
+            response = requests.get(location_url, headers=headers)
+            response.raise_for_status()
+
+            result = response.json()
+            status = result.get("status", "Unknown")
+
+            logger.info(f"Operation {operation_name} status: {status}")
+
+            if status == "Failed":
+                error_msg = result.get("error", {}).get("message", "Unknown error")
+                raise SQLMeshError(f"Operation {operation_name} failed: {error_msg}")
+            elif status in ["InProgress", "Running"]:
+                logger.info(f"Operation {operation_name} still in progress...")
+            elif status not in ["Succeeded"]:
+                logger.warning(f"Unknown status '{status}' for operation {operation_name}")
+
+            return status
+
+        except requests.exceptions.RequestException as e:
+            logger.warning(f"Failed to poll status: {e}")
+            raise SQLMeshError(f"Failed to poll operation status: {e}")
+
+    def _poll_operation_status(self, location_url: str, operation_name: str) -> None:
+        """Poll the operation status until completion."""
+        try:
+            final_status = self._check_operation_status(location_url, operation_name)
+            if final_status != "Succeeded":
+                raise SQLMeshError(
+                    f"Operation {operation_name} completed with status: {final_status}"
+                )
+        except Exception as e:
+            if "retry" in str(e).lower():
+                raise SQLMeshError(f"Operation {operation_name} did not complete within timeout")
+            raise
+
+    def _create_catalog(self, catalog_name: exp.Identifier) -> None:
+        """Create a catalog (warehouse) in Microsoft Fabric via REST API."""
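+        # Issues POST /v1/workspaces/{workspace_id}/warehouses; a 202 response
+        # is polled to completion via its Location header (see the helpers above).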
+        warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False)
+        logger.info(f"Creating Fabric warehouse: {warehouse_name}")
+
+        request_data = {
+            "displayName": warehouse_name,
+            "description": f"Warehouse created by SQLMesh: {warehouse_name}",
+        }
+
+        response = self._make_fabric_api_request_with_location("POST", "warehouses", request_data)
+
+        # Handle direct success (201) or async creation (202)
+        if response.get("status_code") == 201:
+            logger.info(f"Successfully created Fabric warehouse: {warehouse_name}")
+            return
+
+        if response.get("status_code") == 202 and response.get("location"):
+            logger.info(f"Warehouse creation initiated for: {warehouse_name}")
+            self._poll_operation_status(response["location"], warehouse_name)
+            logger.info(f"Successfully created Fabric warehouse: {warehouse_name}")
+        else:
+            raise SQLMeshError(f"Unexpected response from warehouse creation: {response}")
+
+    def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
+        """Drop a catalog (warehouse) in Microsoft Fabric via REST API."""
+        warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False)
+
+        logger.info(f"Deleting Fabric warehouse: {warehouse_name}")
+
+        try:
+            # Get the warehouse ID by listing warehouses
+            warehouses = self._make_fabric_api_request("GET", "warehouses")
+            warehouse_id = None
+
+            for warehouse in warehouses.get("value", []):
+                if warehouse.get("displayName") == warehouse_name:
+                    warehouse_id = warehouse.get("id")
+                    break
+
+            if not warehouse_id:
+                logger.info(f"Fabric warehouse does not exist: {warehouse_name}")
+                return
+
+            # Delete the warehouse by ID
+            self._make_fabric_api_request("DELETE", f"warehouses/{warehouse_id}")
+            logger.info(f"Successfully deleted Fabric warehouse: {warehouse_name}")
+
+        except SQLMeshError as e:
+            error_msg = str(e).lower()
+            if "not found" in error_msg or "does not exist" in error_msg:
+                logger.info(f"Fabric warehouse does not exist: {warehouse_name}")
+                return
+            logger.error(f"Failed to delete Fabric warehouse {warehouse_name}: {e}")
+            raise
+
+    def set_current_catalog(self, catalog_name: str) -> None:
+        """
+        Set the current catalog for Microsoft Fabric connections.
+
+        Override to handle Fabric's stateless session limitation where USE statements
+        don't persist across queries. Instead, we close existing connections and
+        recreate them with the new catalog in the connection configuration.
+
+        Args:
+            catalog_name: The name of the catalog (warehouse) to switch to
+
+        Note:
+            Fabric doesn't support catalog switching via USE statements because each
+            statement runs as an independent session. This method works around this
+            limitation by updating the connection pool with new catalog configuration.
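+
+            For example, `set_current_catalog("sales_wh")` (an illustrative
+            warehouse name) closes the pooled connections; the next connection
+            is then created with `database="sales_wh"` by the custom connection
+            factory installed in `__init__`.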
+
+        See:
+            https://learn.microsoft.com/en-us/fabric/data-warehouse/sql-query-editor#limitations
+        """
+        current_catalog = self.get_current_catalog()
+
+        # If already using the requested catalog, do nothing
+        if current_catalog and current_catalog == catalog_name:
+            logger.debug(f"Already using catalog '{catalog_name}', no action needed")
+            return
+
+        logger.info(f"Switching from catalog '{current_catalog}' to '{catalog_name}'")
+
+        # Set the target catalog for our custom connection factory
+        self._target_catalog = catalog_name
+
+        # Close all existing connections since Fabric requires reconnection for catalog changes
+        self.close()
+
+        # Verify the catalog switch worked by getting a new connection
+        try:
+            actual_catalog = self.get_current_catalog()
+            if actual_catalog and actual_catalog == catalog_name:
+                logger.debug(f"Successfully switched to catalog '{catalog_name}'")
+            else:
+                logger.warning(
+                    f"Catalog switch may have failed. Expected '{catalog_name}', got '{actual_catalog}'"
+                )
+        except Exception as e:
+            logger.debug(f"Could not verify catalog switch: {e}")
+
+        logger.debug(f"Updated target catalog to '{catalog_name}' and closed connections")
+
+    def drop_schema(
+        self,
+        schema_name: SchemaName,
+        ignore_if_not_exists: bool = True,
+        cascade: bool = False,
+        **drop_args: t.Any,
+    ) -> None:
+        """
+        Override drop_schema to handle catalog-qualified schema names.
+        Fabric doesn't support 'DROP SCHEMA [catalog].[schema]' syntax.
+        """
+        logger.debug(f"drop_schema called with: {schema_name} (type: {type(schema_name)})")
+
+        # Handle Table objects created by schema_() function
+        if isinstance(schema_name, exp.Table) and not schema_name.name:
+            # This is a schema Table object - check for catalog qualification
+            if schema_name.catalog:
+                # Catalog-qualified schema: catalog.schema
+                catalog_name = schema_name.catalog
+                schema_only = schema_name.db
+                logger.debug(
+                    f"Detected catalog-qualified schema: catalog='{catalog_name}', schema='{schema_only}'"
+                )
+                # Switch to the catalog first
+                self.set_current_catalog(catalog_name)
+                # Use just the schema name
+                super().drop_schema(schema_only, ignore_if_not_exists, cascade, **drop_args)
+            else:
+                # Schema only, no catalog
+                schema_only = schema_name.db
+                logger.debug(f"Detected schema-only: schema='{schema_only}'")
+                super().drop_schema(schema_only, ignore_if_not_exists, cascade, **drop_args)
+        else:
+            # Handle string or table name inputs by parsing as table
+            table = exp.to_table(schema_name)
+
+            if table.catalog:
+                # 3-part name detected (catalog.db.table) - this shouldn't happen for schema operations
+                raise SQLMeshError(
+                    f"Invalid schema name format: {schema_name}. Expected 'schema' or 'catalog.schema', got 3-part name"
+                )
+            elif table.db:
+                # Catalog-qualified schema: catalog.schema
+                catalog_name = table.db
+                schema_only = table.name
+                logger.debug(
+                    f"Detected catalog.schema format: catalog='{catalog_name}', schema='{schema_only}'"
+                )
+                # Switch to the catalog first
+                self.set_current_catalog(catalog_name)
+                # Use just the schema name
+                super().drop_schema(schema_only, ignore_if_not_exists, cascade, **drop_args)
+            else:
+                # No catalog qualification, use as-is
+                logger.debug(f"No catalog detected, using original: {schema_name}")
+                super().drop_schema(schema_name, ignore_if_not_exists, cascade, **drop_args)
+
+    def create_schema(
+        self,
+        schema_name: SchemaName,
+        ignore_if_exists: bool = True,
+        **kwargs: t.Any,
+    ) -> None:
+        """
+        Override create_schema to handle catalog-qualified schema names.
+        Fabric doesn't support 'CREATE SCHEMA [catalog].[schema]' syntax.
+        """
+        logger.debug(f"create_schema called with: {schema_name} (type: {type(schema_name)})")
+
+        # Handle Table objects created by schema_() function
+        if isinstance(schema_name, exp.Table) and not schema_name.name:
+            # This is a schema Table object - check for catalog qualification
+            if schema_name.catalog:
+                # Catalog-qualified schema: catalog.schema
+                catalog_name = schema_name.catalog
+                schema_only = schema_name.db
+                logger.debug(
+                    f"Detected catalog-qualified schema: catalog='{catalog_name}', schema='{schema_only}'"
+                )
+                # Switch to the catalog first
+                self.set_current_catalog(catalog_name)
+                # Use just the schema name
+                super().create_schema(schema_only, ignore_if_exists, **kwargs)
+            else:
+                # Schema only, no catalog
+                schema_only = schema_name.db
+                logger.debug(f"Detected schema-only: schema='{schema_only}'")
+                super().create_schema(schema_only, ignore_if_exists, **kwargs)
+        else:
+            # Handle string or table name inputs by parsing as table
+            table = exp.to_table(schema_name)
+
+            if table.catalog:
+                # 3-part name detected (catalog.db.table) - this shouldn't happen for schema operations
+                raise SQLMeshError(
+                    f"Invalid schema name format: {schema_name}. Expected 'schema' or 'catalog.schema', got 3-part name"
+                )
+            elif table.db:
+                # Catalog-qualified schema: catalog.schema
+                catalog_name = table.db
+                schema_only = table.name
+                logger.debug(
+                    f"Detected catalog.schema format: catalog='{catalog_name}', schema='{schema_only}'"
+                )
+                # Switch to the catalog first
+                self.set_current_catalog(catalog_name)
+                # Use just the schema name
+                super().create_schema(schema_only, ignore_if_exists, **kwargs)
+            else:
+                # No catalog qualification, use as-is
+                logger.debug(f"No catalog detected, using original: {schema_name}")
+                super().create_schema(schema_name, ignore_if_exists, **kwargs)
+
+    def _ensure_schema_exists(self, table_name: TableName) -> None:
+        """
+        Ensure that the schema for a table exists before creating the table.
+        This is necessary for Fabric because schemas must exist before tables can be created in them.
+        """
+        table = exp.to_table(table_name)
+        if table.db:
+            schema_name = table.db
+            catalog_name = table.catalog
+
+            # Build the full schema name
+            full_schema_name = f"{catalog_name}.{schema_name}" if catalog_name else schema_name
+
+            logger.debug(f"Ensuring schema exists: {full_schema_name}")
+
+            try:
+                # Create the schema if it doesn't exist
+                self.create_schema(full_schema_name, ignore_if_exists=True)
+            except Exception as e:
+                logger.debug(f"Error creating schema {full_schema_name}: {e}")
+                # Continue anyway - the schema might already exist or we might not have permissions
+
+    def _create_table(
+        self,
+        table_name_or_schema: t.Union[exp.Schema, TableName],
+        expression: t.Optional[exp.Expression],
+        exists: bool = True,
+        replace: bool = False,
+        columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        table_kind: t.Optional[str] = None,
+        **kwargs: t.Any,
+    ) -> None:
+        """
+        Override _create_table to ensure schema exists before creating tables.
+        """
+        # Extract table name for schema creation
+        if isinstance(table_name_or_schema, exp.Schema):
+            table_name = table_name_or_schema.this
+        else:
+            table_name = table_name_or_schema
+
+        # Ensure the schema exists before creating the table
+        self._ensure_schema_exists(table_name)
+
+        # Call the parent implementation
+        super()._create_table(
+            table_name_or_schema=table_name_or_schema,
+            expression=expression,
+            exists=exists,
+            replace=replace,
+            columns_to_types=columns_to_types,
+            table_description=table_description,
+            column_descriptions=column_descriptions,
+            table_kind=table_kind,
+            **kwargs,
+        )
+
+    def create_table(
+        self,
+        table_name: TableName,
+        columns_to_types: t.Dict[str, exp.DataType],
+        primary_key: t.Optional[t.Tuple[str, ...]] = None,
+        exists: bool = True,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        **kwargs: t.Any,
+    ) -> None:
+        """
+        Override create_table to ensure schema exists before creating tables.
+        """
+        # Ensure the schema exists before creating the table
+        self._ensure_schema_exists(table_name)
+
+        # Call the parent implementation
+        super().create_table(
+            table_name=table_name,
+            columns_to_types=columns_to_types,
+            primary_key=primary_key,
+            exists=exists,
+            table_description=table_description,
+            column_descriptions=column_descriptions,
+            **kwargs,
+        )
+
+    def ctas(
+        self,
+        table_name: TableName,
+        query_or_df: t.Any,
+        columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        exists: bool = True,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        **kwargs: t.Any,
+    ) -> None:
+        """
+        Override ctas to ensure schema exists before creating tables.
+        """
+        # Ensure the schema exists before creating the table
+        self._ensure_schema_exists(table_name)
+
+        # Call the parent implementation
+        super().ctas(
+            table_name=table_name,
+            query_or_df=query_or_df,
+            columns_to_types=columns_to_types,
+            exists=exists,
+            table_description=table_description,
+            column_descriptions=column_descriptions,
+            **kwargs,
+        )
+
+    def create_view(
+        self,
+        view_name: SchemaName,
+        query_or_df: t.Any,
+        columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        replace: bool = True,
+        materialized: bool = False,
+        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+        **create_kwargs: t.Any,
+    ) -> None:
+        """
+        Override create_view to handle catalog-qualified view names and ensure schema exists.
+        Fabric doesn't support 'CREATE VIEW [catalog].[schema].[view]' syntax.
+        """
+        logger.debug(f"create_view called with: {view_name} (type: {type(view_name)})")
+
+        # Parse view_name into an exp.Table to properly handle both string and Table cases
+        table = exp.to_table(view_name)
+
+        # Ensure schema exists for the view
+        self._ensure_schema_exists(table)
+
+        if table.catalog:
+            # 3-part name: catalog.schema.view
+            catalog_name = table.catalog
+            schema_name = table.db or ""
+            view_only = table.name
+
+            logger.debug(
+                f"Detected catalog.schema.view format: catalog='{catalog_name}', schema='{schema_name}', view='{view_only}'"
+            )
+
+            # Switch to the catalog first
+            self.set_current_catalog(catalog_name)
+
+            # Create new Table expression without catalog
+            unqualified_view = exp.Table(this=view_only, db=schema_name)
+
+            super().create_view(
+                unqualified_view,
+                query_or_df,
+                columns_to_types,
+                replace,
+                materialized,
+                materialized_properties,
+                table_description,
+                column_descriptions,
+                view_properties,
+                **create_kwargs,
+            )
+        else:
+            # No catalog qualification, use as-is
+            logger.debug(f"No catalog detected, using original: {view_name}")
+            super().create_view(
+                view_name,
+                query_or_df,
+                columns_to_types,
+                replace,
+                materialized,
+                materialized_properties,
+                table_description,
+                column_descriptions,
+                view_properties,
+                **create_kwargs,
+            )
diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py
index 7e35b832be..63c4ca465f 100644
--- a/tests/core/engine_adapter/integration/__init__.py
+++ b/tests/core/engine_adapter/integration/__init__.py
@@ -82,6 +82,7 @@ def pytest_marks(self) -> t.List[MarkDecorator]:
     IntegrationTestEngine("bigquery", native_dataframe_type="bigframe", cloud=True),
     IntegrationTestEngine("databricks", native_dataframe_type="pyspark", cloud=True),
     IntegrationTestEngine("snowflake", native_dataframe_type="snowpark", cloud=True),
+    IntegrationTestEngine("fabric", cloud=True),
 ]
 
 ENGINES_BY_NAME = {e.engine: e for e in ENGINES}
@@ -679,6 +680,9 @@ def create_catalog(self, catalog_name: str):
             except Exception:
                 pass
             self.engine_adapter.cursor.connection.autocommit(False)
+        elif self.dialect == "fabric":
+            # Use the engine adapter's built-in catalog creation functionality
+            self.engine_adapter.create_catalog(catalog_name)
         elif self.dialect == "snowflake":
             self.engine_adapter.execute(f'CREATE DATABASE IF NOT EXISTS "{catalog_name}"')
         elif self.dialect == "duckdb":
@@ -695,6 +699,9 @@ def drop_catalog(self, catalog_name: str):
             return  # bigquery cannot create/drop catalogs
         if self.dialect == "databricks":
             self.engine_adapter.execute(f"DROP CATALOG IF EXISTS {catalog_name} CASCADE")
+        elif self.dialect == "fabric":
+            # Use the engine adapter's built-in catalog dropping functionality
+            self.engine_adapter.drop_catalog(catalog_name)
         else:
             self.engine_adapter.execute(f'DROP DATABASE IF EXISTS "{catalog_name}"')
diff --git a/tests/core/engine_adapter/integration/config.yaml b/tests/core/engine_adapter/integration/config.yaml
index d18ea5366f..6733077ff0 100644
--- a/tests/core/engine_adapter/integration/config.yaml
+++ b/tests/core/engine_adapter/integration/config.yaml
@@ -186,5 +186,20 @@ gateways:
     state_connection:
       type: duckdb
 
+  inttest_fabric:
+    connection:
+      type: fabric
+      driver: pyodbc
+      host: {{ env_var("FABRIC_HOST") }}
+      user: {{ env_var("FABRIC_CLIENT_ID") }}
+      password: {{ env_var("FABRIC_CLIENT_SECRET") }}
+      database: {{ env_var("FABRIC_DATABASE") }}
+      tenant_id: {{ env_var("FABRIC_TENANT_ID") }}
+      workspace_id: {{ env_var("FABRIC_WORKSPACE_ID") }}
+      odbc_properties:
+        Authentication: ActiveDirectoryServicePrincipal
+    state_connection:
+      type: duckdb
+
 model_defaults:
   dialect: duckdb
diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py
index cb09d20537..a4827579d8 100644
--- a/tests/core/engine_adapter/integration/test_integration.py
+++ b/tests/core/engine_adapter/integration/test_integration.py
@@ -1756,6 +1756,7 @@ def test_dialects(ctx: TestContext):
         {
             "default": pd.Timestamp("2020-01-01 00:00:00+00:00"),
             "clickhouse": pd.Timestamp("2020-01-01 00:00:00"),
+            "fabric": pd.Timestamp("2020-01-01 00:00:00"),
             "mysql": pd.Timestamp("2020-01-01 00:00:00"),
             "spark": pd.Timestamp("2020-01-01 00:00:00"),
             "databricks": pd.Timestamp("2020-01-01 00:00:00"),
@@ -2157,14 +2158,12 @@ def test_value_normalization(
     input_data: t.Tuple[t.Any, ...],
     expected_results: t.Tuple[str, ...],
 ) -> None:
-    if (
-        ctx.dialect == "trino"
-        and ctx.engine_adapter.current_catalog_type == "hive"
-        and column_type == exp.DataType.Type.TIMESTAMPTZ
-    ):
-        pytest.skip(
-            "Trino on Hive doesnt support creating tables with TIMESTAMP WITH TIME ZONE fields"
-        )
+    # Skip TIMESTAMPTZ tests for engines that don't support it
+    if column_type == exp.DataType.Type.TIMESTAMPTZ:
+        if ctx.dialect == "trino" and ctx.engine_adapter.current_catalog_type == "hive":
+            pytest.skip("Trino on Hive doesn't support TIMESTAMP WITH TIME ZONE fields")
+        if ctx.dialect == "fabric":
+            pytest.skip("Fabric doesn't support TIMESTAMP WITH TIME ZONE fields")
 
     if not isinstance(ctx.engine_adapter, RowDiffMixin):
         pytest.skip(
@@ -2254,7 +2253,10 @@ def test_table_diff_grain_check_single_key(ctx: TestContext):
     src_table = ctx.table("source")
     target_table = ctx.table("target")
 
-    columns_to_types = {"key1": exp.DataType.build("int"), "value": exp.DataType.build("varchar")}
+    columns_to_types = {
+        "key1": exp.DataType.build("int"),
+        "value": exp.DataType.build("varchar"),
+    }
 
     ctx.engine_adapter.create_table(src_table, columns_to_types)
     ctx.engine_adapter.create_table(target_table, columns_to_types)
diff --git a/tests/core/engine_adapter/test_fabric.py b/tests/core/engine_adapter/test_fabric.py
new file mode 100644
index 0000000000..709df816d2
--- /dev/null
+++ b/tests/core/engine_adapter/test_fabric.py
@@ -0,0 +1,83 @@
+# type: ignore
+
+import typing as t
+
+import pytest
+from sqlglot import exp, parse_one
+
+from sqlmesh.core.engine_adapter import FabricAdapter
+from tests.core.engine_adapter import to_sql_calls
+
+pytestmark = [pytest.mark.engine, pytest.mark.fabric]
+
+
+@pytest.fixture
+def adapter(make_mocked_engine_adapter: t.Callable) -> FabricAdapter:
+    return make_mocked_engine_adapter(FabricAdapter)
+
+
+def test_columns(adapter: FabricAdapter):
+    adapter.cursor.fetchall.return_value = [
+        ("decimal_ps", "decimal", None, 5, 4),
+        ("decimal", "decimal", None, 18, 0),
+        ("float", "float", None, 53, None),
+        ("char_n", "char", 10, None, None),
+        ("varchar_n", "varchar", 10, None, None),
+        ("nvarchar_max", "nvarchar", -1, None, None),
+    ]
+
+    assert adapter.columns("db.table") == {
+        "decimal_ps": exp.DataType.build("decimal(5, 4)", dialect=adapter.dialect),
+        "decimal": exp.DataType.build("decimal(18, 0)", dialect=adapter.dialect),
+        "float": exp.DataType.build("float(53)", dialect=adapter.dialect),
+        "char_n": exp.DataType.build("char(10)", dialect=adapter.dialect),
+        "varchar_n": exp.DataType.build("varchar(10)", dialect=adapter.dialect),
+        "nvarchar_max": exp.DataType.build("nvarchar(max)", dialect=adapter.dialect),
+    }
+
+    # Verify that the adapter queries the uppercase INFORMATION_SCHEMA
+    adapter.cursor.execute.assert_called_once_with(
+        """SELECT [COLUMN_NAME], [DATA_TYPE], [CHARACTER_MAXIMUM_LENGTH], [NUMERIC_PRECISION], [NUMERIC_SCALE] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_NAME] = 'table' AND [TABLE_SCHEMA] = 'db';"""
+    )
+
+
+def test_table_exists(adapter: FabricAdapter):
+    adapter.cursor.fetchone.return_value = (1,)
+    assert adapter.table_exists("db.table")
+    # Verify that the adapter queries the uppercase INFORMATION_SCHEMA
+    adapter.cursor.execute.assert_called_once_with(
+        """SELECT 1 FROM [INFORMATION_SCHEMA].[TABLES] WHERE [TABLE_NAME] = 'table' AND [TABLE_SCHEMA] = 'db';"""
+    )
+
+    adapter.cursor.fetchone.return_value = None
+    assert not adapter.table_exists("db.table")
+
+
+def test_insert_overwrite_by_time_partition(adapter: FabricAdapter):
+    adapter.insert_overwrite_by_time_partition(
+        "test_table",
+        parse_one("SELECT a, b FROM tbl"),
+        start="2022-01-01",
+        end="2022-01-02",
+        time_column="b",
+        time_formatter=lambda x, _: exp.Literal.string(x.strftime("%Y-%m-%d")),
+        columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")},
+    )
+
+    # Fabric adapter should use DELETE/INSERT strategy, not MERGE.
+    assert to_sql_calls(adapter) == [
+        """DELETE FROM [test_table] WHERE [b] BETWEEN '2022-01-01' AND '2022-01-02';""",
+        """INSERT INTO [test_table] ([a], [b]) SELECT [a], [b] FROM (SELECT [a] AS [a], [b] AS [b] FROM [tbl]) AS [_subquery] WHERE [b] BETWEEN '2022-01-01' AND '2022-01-02';""",
+    ]
+
+
+def test_replace_query(adapter: FabricAdapter):
+    adapter.cursor.fetchone.return_value = (1,)
+    adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"})
+
+    # This behavior is inherited from MSSQLEngineAdapter and should be TRUNCATE + INSERT
+    assert to_sql_calls(adapter) == [
+        """SELECT 1 FROM [INFORMATION_SCHEMA].[TABLES] WHERE [TABLE_NAME] = 'test_table';""",
+        "TRUNCATE TABLE [test_table];",
+        "INSERT INTO [test_table] ([a]) SELECT [a] FROM [tbl];",
+    ]
diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py
index 7fe2487891..522c85c434 100644
--- a/tests/core/test_connection_config.py
+++ b/tests/core/test_connection_config.py
@@ -12,6 +12,7 @@
     ConnectionConfig,
     DatabricksConnectionConfig,
     DuckDBAttachOptions,
+    FabricConnectionConfig,
     DuckDBConnectionConfig,
     GCPPostgresConnectionConfig,
     MotherDuckConnectionConfig,
@@ -1687,3 +1688,95 @@ def mock_add_output_converter(sql_type, converter_func):
     expected_dt = datetime(2023, 1, 1, 12, 0, 0, 0, timezone(timedelta(hours=-8, minutes=0)))
     assert result == expected_dt
     assert result.tzinfo == timezone(timedelta(hours=-8))
+
+
+def test_fabric_connection_config_defaults(make_config):
+    """Test Fabric connection config defaults to pyodbc and autocommit=True."""
+    config = make_config(
+        type="fabric",
+        host="localhost",
+        workspace_id="test-workspace-id",
+        tenant_id="test-tenant-id",
+        check_import=False,
+    )
+    assert isinstance(config, FabricConnectionConfig)
+    assert config.driver == "pyodbc"
+    assert config.autocommit is True
+
+    # Ensure it creates the FabricAdapter
+    from sqlmesh.core.engine_adapter.fabric import FabricAdapter
+
+    assert isinstance(config.create_engine_adapter(), FabricAdapter)
+
+
+def test_fabric_connection_config_parameter_validation(make_config):
+    """Test Fabric connection config parameter validation."""
+    # Test that FabricConnectionConfig correctly handles pyodbc-specific parameters.
+    config = make_config(
+        type="fabric",
+        host="localhost",
+        driver_name="ODBC Driver 18 for SQL Server",
+        trust_server_certificate=True,
+        encrypt=False,
+        odbc_properties={"Authentication": "ActiveDirectoryServicePrincipal"},
+        workspace_id="test-workspace-id",
+        tenant_id="test-tenant-id",
+        check_import=False,
+    )
+    assert isinstance(config, FabricConnectionConfig)
+    assert config.driver == "pyodbc"  # Driver is fixed to pyodbc
+    assert config.driver_name == "ODBC Driver 18 for SQL Server"
+    assert config.trust_server_certificate is True
+    assert config.encrypt is False
+    assert config.odbc_properties == {"Authentication": "ActiveDirectoryServicePrincipal"}
+
+    # Test that specifying a different driver for Fabric raises an error
+    with pytest.raises(ConfigError, match=r"Input should be 'pyodbc'"):
+        make_config(type="fabric", host="localhost", driver="pymssql", check_import=False)
+
+
+def test_fabric_pyodbc_connection_string_generation():
+    """Test that the Fabric pyodbc connection gets invoked with the correct ODBC connection string."""
+    with patch("pyodbc.connect") as mock_pyodbc_connect:
+        # Create a Fabric config
+        config = FabricConnectionConfig(
+            host="testserver.datawarehouse.fabric.microsoft.com",
+            port=1433,
+            database="testdb",
+            user="testuser",
+            password="testpass",
+            driver_name="ODBC Driver 18 for SQL Server",
+            trust_server_certificate=True,
+            encrypt=True,
+            login_timeout=30,
+            workspace_id="test-workspace-id",
+            tenant_id="test-tenant-id",
+            check_import=False,
+        )
+
+        # Get the connection factory with kwargs and call it
+        factory_with_kwargs = config._connection_factory_with_kwargs
+        connection = factory_with_kwargs()
+
+        # Verify pyodbc.connect was called with the correct connection string
+        mock_pyodbc_connect.assert_called_once()
+        call_args = mock_pyodbc_connect.call_args
+
+        # Check the connection string (first argument)
+        conn_str = call_args[0][0]
+        expected_parts = [
+            "DRIVER={ODBC Driver 18 for SQL Server}",
+            "SERVER=testserver.datawarehouse.fabric.microsoft.com,1433",
+            "DATABASE=testdb",
+            "Encrypt=YES",
+            "TrustServerCertificate=YES",
+            "Connection Timeout=30",
+            "UID=testuser",
+            "PWD=testpass",
+        ]
+
+        for part in expected_parts:
+            assert part in conn_str
+
+        # Check autocommit parameter, should default to True for Fabric
+        assert call_args[1]["autocommit"] is True