diff --git a/sqlmesh/core/engine_adapter/fabric.py b/sqlmesh/core/engine_adapter/fabric.py index 585622e866..a528be3cb4 100644 --- a/sqlmesh/core/engine_adapter/fabric.py +++ b/sqlmesh/core/engine_adapter/fabric.py @@ -121,10 +121,19 @@ def _create_catalog(self, catalog_name: exp.Identifier) -> None: def _drop_catalog(self, catalog_name: exp.Identifier) -> None: """Drop a catalog (warehouse) in Microsoft Fabric via REST API.""" warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False) + current_catalog = self.get_current_catalog() logger.info(f"Deleting Fabric warehouse: {warehouse_name}") self.api_client.delete_warehouse(warehouse_name) + if warehouse_name == current_catalog: + # Somewhere around 2025-09-08, Fabric started validating the "Database=" connection argument and throwing 'Authentication failed' if the database doesnt exist + # In addition, set_current_catalog() is implemented using a threadlocal variable "target_catalog" + # So, when we drop a warehouse, and there are still threads with "target_catalog" set to reference it, any operations on those threads + # that use an either use an existing connection pointing to this warehouse or trigger a new connection + # will fail with an 'Authentication Failed' error unless we close all connections here, which also clears all the threadlocal data + self.close() + def set_current_catalog(self, catalog_name: str) -> None: """ Set the current catalog for Microsoft Fabric connections. diff --git a/sqlmesh/utils/connection_pool.py b/sqlmesh/utils/connection_pool.py index a4f9486184..9a70db6885 100644 --- a/sqlmesh/utils/connection_pool.py +++ b/sqlmesh/utils/connection_pool.py @@ -227,7 +227,8 @@ def close_all(self, exclude_calling_thread: bool = False) -> None: self._thread_connections.pop(thread_id) self._thread_cursors.pop(thread_id, None) self._discard_transaction(thread_id) - self._thread_attributes.pop(thread_id, None) + + self._thread_attributes.clear() class ThreadLocalSharedConnectionPool(_ThreadLocalBase): diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index fcbc711f49..42ff8b881f 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -3743,6 +3743,14 @@ def _set_config(gateway: str, config: Config) -> None: assert not md.tables assert not md.managed_tables + if ctx.dialect == "fabric": + # TestContext is using a different EngineAdapter instance / connection pool instance to the SQLMesh context + # When the SQLMesh context drops :snapshot_schema using its EngineAdapter, connections in TestContext are unaware + # and still have their threadlocal "target_catalog" attribute pointing to a catalog that no longer exists + # Trying to establish a connection to a nonexistant catalog produces an error, so we close all connections here + # to clear the threadlocal attributes + ctx.engine_adapter.close() + md = ctx.get_metadata_results(snapshot_schema) assert not md.views assert not md.managed_tables diff --git a/tests/core/engine_adapter/integration/test_integration_fabric.py b/tests/core/engine_adapter/integration/test_integration_fabric.py index a272005bdc..41f399b3b8 100644 --- a/tests/core/engine_adapter/integration/test_integration_fabric.py +++ b/tests/core/engine_adapter/integration/test_integration_fabric.py @@ -1,8 +1,12 @@ import typing as t +import threading +import queue import pytest from pytest import FixtureRequest from sqlmesh.core.engine_adapter import FabricEngineAdapter +from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool from tests.core.engine_adapter.integration import TestContext +from concurrent.futures import ThreadPoolExecutor from tests.core.engine_adapter.integration import ( TestContext, @@ -39,3 +43,75 @@ def test_create_drop_catalog(ctx: TestContext, engine_adapter: FabricEngineAdapt finally: # if doesnt exist, should be no-op, not error ctx.drop_catalog(catalog_name) + + +def test_drop_catalog_clears_threadlocals_that_reference_it( + ctx: TestContext, engine_adapter: FabricEngineAdapter +): + catalog_name = ctx.add_test_suffix("test_drop_catalog") + default_catalog = engine_adapter.get_current_catalog() + + assert isinstance(engine_adapter._connection_pool, ThreadLocalConnectionPool) + + # sets the connection attribute for this thread + engine_adapter.create_catalog(catalog_name) + assert engine_adapter._target_catalog is None + engine_adapter.set_current_catalog(catalog_name) + assert engine_adapter.get_current_catalog() == catalog_name + assert engine_adapter._target_catalog == catalog_name + + lock = threading.RLock() + + def _set_and_return_catalog_in_another_thread( + q: queue.Queue, engine_adapter: FabricEngineAdapter + ) -> t.Optional[str]: + q.put("thread_started") + + assert engine_adapter.get_current_catalog() == default_catalog + assert engine_adapter._target_catalog is None + + engine_adapter.set_current_catalog(catalog_name) + assert engine_adapter.get_current_catalog() == catalog_name + assert engine_adapter._target_catalog == catalog_name + + q.put("catalog_set_in_thread") + + # block this thread while we drop the catalog in the main test thread + lock.acquire() + + # the current catalog should have been cleared from the threadlocal connection attributes + # when this catalog was dropped by the outer thread, causing it to fall back to the default catalog + try: + assert engine_adapter._target_catalog is None + return engine_adapter.get_current_catalog() + finally: + lock.release() + + q: queue.Queue = queue.Queue() + + with ThreadPoolExecutor() as executor: + lock.acquire() # we have the lock, thread will be blocked until we release it + + future = executor.submit(_set_and_return_catalog_in_another_thread, q, engine_adapter) + + assert q.get() == "thread_started" + assert not future.done() + + try: + assert q.get(timeout=20) == "catalog_set_in_thread" + except: + if exec := future.exception(): + raise exec + raise + + ctx.drop_catalog(catalog_name) + assert not future.done() + + lock.release() # yield the lock to the thread + + # block until thread complete + result = future.result() + + # both threads should be automatically using the default catalog now + assert result == default_catalog + assert engine_adapter.get_current_catalog() == default_catalog diff --git a/tests/utils/test_connection_pool.py b/tests/utils/test_connection_pool.py index 96c2f69012..c5926a3824 100644 --- a/tests/utils/test_connection_pool.py +++ b/tests/utils/test_connection_pool.py @@ -210,6 +210,29 @@ def thread(): assert cursor_mock_thread_two.begin.call_count == 1 +def test_thread_local_connection_pool_attributes(mocker: MockerFixture): + pool = ThreadLocalConnectionPool(connection_factory=lambda: mocker.Mock()) + + pool.set_attribute("foo", "bar") + current_threadid = get_ident() + + def _in_thread(pool: ThreadLocalConnectionPool): + assert get_ident() != current_threadid + pool.set_attribute("foo", "baz") + + with ThreadPoolExecutor() as executor: + future = executor.submit(_in_thread, pool) + assert not future.exception() + + assert pool.get_all_attributes("foo") == ["bar", "baz"] + assert pool.get_attribute("foo") == "bar" + + pool.close_all() + + assert pool.get_all_attributes("foo") == [] + assert pool.get_attribute("foo") is None + + def test_thread_local_shared_connection_pool(mocker: MockerFixture): cursor_mock_thread_one = mocker.Mock() cursor_mock_thread_two = mocker.Mock()