|
1 | 1 | import typing as t |
| 2 | +import threading |
| 3 | +import queue |
2 | 4 | import pytest |
3 | 5 | from pytest import FixtureRequest |
4 | 6 | from sqlmesh.core.engine_adapter import FabricEngineAdapter |
| 7 | +from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool |
5 | 8 | from tests.core.engine_adapter.integration import TestContext |
| 9 | +from concurrent.futures import ThreadPoolExecutor |
6 | 10 |
|
7 | 11 | from tests.core.engine_adapter.integration import ( |
8 | 12 | TestContext, |
@@ -39,3 +43,75 @@ def test_create_drop_catalog(ctx: TestContext, engine_adapter: FabricEngineAdapt |
39 | 43 | finally: |
40 | 44 | # if doesnt exist, should be no-op, not error |
41 | 45 | ctx.drop_catalog(catalog_name) |
| 46 | + |
| 47 | + |
| 48 | +def test_drop_catalog_clears_threadlocals_that_reference_it( |
| 49 | + ctx: TestContext, engine_adapter: FabricEngineAdapter |
| 50 | +): |
| 51 | + catalog_name = ctx.add_test_suffix("test_drop_catalog") |
| 52 | + default_catalog = engine_adapter.get_current_catalog() |
| 53 | + |
| 54 | + assert isinstance(engine_adapter._connection_pool, ThreadLocalConnectionPool) |
| 55 | + |
| 56 | + # sets the connection attribute for this thread |
| 57 | + engine_adapter.create_catalog(catalog_name) |
| 58 | + assert engine_adapter._target_catalog is None |
| 59 | + engine_adapter.set_current_catalog(catalog_name) |
| 60 | + assert engine_adapter.get_current_catalog() == catalog_name |
| 61 | + assert engine_adapter._target_catalog == catalog_name |
| 62 | + |
| 63 | + lock = threading.RLock() |
| 64 | + |
| 65 | + def _set_and_return_catalog_in_another_thread( |
| 66 | + q: queue.Queue, engine_adapter: FabricEngineAdapter |
| 67 | + ) -> t.Optional[str]: |
| 68 | + q.put("thread_started") |
| 69 | + |
| 70 | + assert engine_adapter.get_current_catalog() == default_catalog |
| 71 | + assert engine_adapter._target_catalog is None |
| 72 | + |
| 73 | + engine_adapter.set_current_catalog(catalog_name) |
| 74 | + assert engine_adapter.get_current_catalog() == catalog_name |
| 75 | + assert engine_adapter._target_catalog == catalog_name |
| 76 | + |
| 77 | + q.put("catalog_set_in_thread") |
| 78 | + |
| 79 | + # block this thread while we drop the catalog in the main test thread |
| 80 | + lock.acquire() |
| 81 | + |
| 82 | + # the current catalog should have been cleared from the threadlocal connection attributes |
| 83 | + # when this catalog was dropped by the outer thread, causing it to fall back to the default catalog |
| 84 | + try: |
| 85 | + assert engine_adapter._target_catalog is None |
| 86 | + return engine_adapter.get_current_catalog() |
| 87 | + finally: |
| 88 | + lock.release() |
| 89 | + |
| 90 | + q: queue.Queue = queue.Queue() |
| 91 | + |
| 92 | + with ThreadPoolExecutor() as executor: |
| 93 | + lock.acquire() # we have the lock, thread will be blocked until we release it |
| 94 | + |
| 95 | + future = executor.submit(_set_and_return_catalog_in_another_thread, q, engine_adapter) |
| 96 | + |
| 97 | + assert q.get() == "thread_started" |
| 98 | + assert not future.done() |
| 99 | + |
| 100 | + try: |
| 101 | + assert q.get(timeout=20) == "catalog_set_in_thread" |
| 102 | + except: |
| 103 | + if exec := future.exception(): |
| 104 | + raise exec |
| 105 | + raise |
| 106 | + |
| 107 | + ctx.drop_catalog(catalog_name) |
| 108 | + assert not future.done() |
| 109 | + |
| 110 | + lock.release() # yield the lock to the thread |
| 111 | + |
| 112 | + # block until thread complete |
| 113 | + result = future.result() |
| 114 | + |
| 115 | + # both threads should be automatically using the default catalog now |
| 116 | + assert result == default_catalog |
| 117 | + assert engine_adapter.get_current_catalog() == default_catalog |
0 commit comments