From ea82df0251995638df00d58983a06c5aed250fe1 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 30 Sep 2025 12:40:24 -0700 Subject: [PATCH 1/3] Chore: Cache results of get_data_objects --- sqlmesh/core/engine_adapter/base.py | 83 ++++++- sqlmesh/core/snapshot/evaluator.py | 67 ++++- tests/core/engine_adapter/test_athena.py | 1 + tests/core/engine_adapter/test_base.py | 261 ++++++++++++++++++++ tests/core/engine_adapter/test_snowflake.py | 8 +- tests/core/test_snapshot_evaluator.py | 2 + 6 files changed, 405 insertions(+), 17 deletions(-) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 68c6404081..d9f417695f 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -161,6 +161,7 @@ def __init__( self.correlation_id = correlation_id self._schema_differ_overrides = schema_differ_overrides self._query_execution_tracker = query_execution_tracker + self._data_object_cache: t.Dict[str, t.Optional[DataObject]] = {} def with_settings(self, **kwargs: t.Any) -> EngineAdapter: extra_kwargs = { @@ -983,6 +984,13 @@ def _create_table( ), track_rows_processed=track_rows_processed, ) + # Extract table name to clear cache + table_name = ( + table_name_or_schema.this + if isinstance(table_name_or_schema, exp.Schema) + else table_name_or_schema + ) + self._clear_data_object_cache(table_name) def _build_create_table_exp( self, @@ -1074,6 +1082,7 @@ def clone_table( **kwargs, ) ) + self._clear_data_object_cache(target_table_name) def drop_data_object(self, data_object: DataObject, ignore_if_not_exists: bool = True) -> None: """Drops a data object of arbitrary type. @@ -1139,6 +1148,7 @@ def _drop_object( drop_args["cascade"] = cascade self.execute(exp.Drop(this=exp.to_table(name), kind=kind, exists=exists, **drop_args)) + self._clear_data_object_cache(name) def get_alter_operations( self, @@ -1329,6 +1339,8 @@ def create_view( quote_identifiers=self.QUOTE_IDENTIFIERS_IN_VIEWS, ) + self._clear_data_object_cache(view_name) + # Register table comment with commands if the engine doesn't support doing it in CREATE if ( table_description @@ -2278,14 +2290,52 @@ def get_data_objects( if object_names is not None: if not object_names: return [] - object_names_list = list(object_names) - batches = [ - object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE] - for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE) - ] - return [ - obj for batch in batches for obj in self._get_data_objects(schema_name, set(batch)) - ] + + # Check cache for each object name + target_schema = to_schema(schema_name) + cached_objects = [] + missing_names = set() + + for name in object_names: + cache_key = _get_data_object_cache_key( + target_schema.catalog, target_schema.db, name + ) + if cache_key in self._data_object_cache: + data_object = self._data_object_cache[cache_key] + # If the object is none, then the table was previously looked for but not found + if data_object: + cached_objects.append(data_object) + else: + missing_names.add(name) + + # Fetch missing objects from database + if missing_names: + object_names_list = list(missing_names) + batches = [ + object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE] + for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE) + ] + fetched_objects = [ + obj + for batch in batches + for obj in self._get_data_objects(schema_name, set(batch)) + ] + + # Cache the fetched objects + for obj in fetched_objects: + cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name) + self._data_object_cache[cache_key] = obj + + fetched_object_names = {obj.name for obj in fetched_objects} + for missing_name in missing_names - fetched_object_names: + cache_key = _get_data_object_cache_key( + target_schema.catalog, target_schema.db, missing_name + ) + self._data_object_cache[cache_key] = None + + return cached_objects + fetched_objects + + return cached_objects return self._get_data_objects(schema_name) def fetchone( @@ -2693,6 +2743,15 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An return expression.sql(**sql_gen_kwargs, copy=False) # type: ignore + def _clear_data_object_cache(self, table_name: t.Optional[TableName] = None) -> None: + """Clears the cache entry for the given table name, or clears the entire cache if table_name is None.""" + if table_name is None: + self._data_object_cache.clear() + else: + table = exp.to_table(table_name) + cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + self._data_object_cache.pop(cache_key, None) + def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None ) -> t.List[DataObject]: @@ -2940,3 +2999,11 @@ def _decoded_str(value: t.Union[str, bytes]) -> str: if isinstance(value, bytes): return value.decode("utf-8") return value + + +def _get_data_object_cache_key(catalog: t.Optional[str], schema_name: str, object_name: str) -> str: + """Returns a cache key for a data object based on its fully qualified name.""" + catalog_part = catalog.lower() if catalog else "" + schema_part = schema_name.lower() + object_part = object_name.lower() + return f"{catalog_part}.{schema_part}.{object_part}" diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 4ac87199c6..9419815939 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -307,6 +307,9 @@ def promote( ] self._create_schemas(gateway_table_pairs=gateway_table_pairs) + # Fetch the view data objects for the promoted snapshots to get them cached + self._get_virtual_data_objects(target_snapshots, environment_naming_info) + deployability_index = deployability_index or DeployabilityIndex.all_deployable() with self.concurrent_context(): concurrent_apply_to_snapshots( @@ -425,7 +428,9 @@ def get_snapshots_to_create( target_snapshots: Target snapshots. deployability_index: Determines snapshots that are deployable / representative in the context of this creation. """ - existing_data_objects = self._get_data_objects(target_snapshots, deployability_index) + existing_data_objects = self._get_physical_data_objects( + target_snapshots, deployability_index + ) snapshots_to_create = [] for snapshot in target_snapshots: if not snapshot.is_model or snapshot.is_symbolic: @@ -482,7 +487,7 @@ def migrate( deployability_index: Determines snapshots that are deployable in the context of this evaluation. """ deployability_index = deployability_index or DeployabilityIndex.all_deployable() - target_data_objects = self._get_data_objects(target_snapshots, deployability_index) + target_data_objects = self._get_physical_data_objects(target_snapshots, deployability_index) if not target_data_objects: return @@ -1472,7 +1477,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex and adapter.table_exists(snapshot.table_name()) ) - def _get_data_objects( + def _get_physical_data_objects( self, target_snapshots: t.Iterable[Snapshot], deployability_index: DeployabilityIndex, @@ -1488,6 +1493,59 @@ def _get_data_objects( A dictionary of snapshot IDs to existing data objects of their physical tables. If the data object for a snapshot is not found, it will not be included in the dictionary. """ + return self._get_data_objects( + target_snapshots, + lambda s: exp.to_table( + s.table_name(deployability_index.is_deployable(s)), dialect=s.model.dialect + ), + ) + + def _get_virtual_data_objects( + self, + target_snapshots: t.Iterable[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + ) -> t.Dict[SnapshotId, DataObject]: + """Returns a dictionary of snapshot IDs to existing data objects of their virtual views. + + Args: + target_snapshots: Target snapshots. + environment_naming_info: The environment naming info of the target virtual environment. + + Returns: + A dictionary of snapshot IDs to existing data objects of their virtual views. If the data object + for a snapshot is not found, it will not be included in the dictionary. + """ + + def _get_view_name(s: Snapshot) -> exp.Table: + adapter = ( + self.get_adapter(s.model_gateway) + if environment_naming_info.gateway_managed + else self.adapter + ) + return exp.to_table( + s.qualified_view_name.for_environment( + environment_naming_info, dialect=adapter.dialect + ), + dialect=adapter.dialect, + ) + + return self._get_data_objects(target_snapshots, _get_view_name) + + def _get_data_objects( + self, + target_snapshots: t.Iterable[Snapshot], + table_name_callable: t.Callable[[Snapshot], exp.Table], + ) -> t.Dict[SnapshotId, DataObject]: + """Returns a dictionary of snapshot IDs to existing data objects. + + Args: + target_snapshots: Target snapshots. + table_name_callable: A function that takes a snapshot and returns the table to look for. + + Returns: + A dictionary of snapshot IDs to existing data objects. If the data object for a snapshot is not found, + it will not be included in the dictionary. + """ tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = ( defaultdict(lambda: defaultdict(set)) ) @@ -1495,8 +1553,7 @@ def _get_data_objects( for snapshot in target_snapshots: if not snapshot.is_model or snapshot.is_symbolic: continue - is_deployable = deployability_index.is_deployable(snapshot) - table = exp.to_table(snapshot.table_name(is_deployable), dialect=snapshot.model.dialect) + table = table_name_callable(snapshot) table_schema = d.schema_(table.db, catalog=table.catalog) tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name) snapshots_by_table_name[table.name] = snapshot diff --git a/tests/core/engine_adapter/test_athena.py b/tests/core/engine_adapter/test_athena.py index 4fe57baf34..66e84ae025 100644 --- a/tests/core/engine_adapter/test_athena.py +++ b/tests/core/engine_adapter/test_athena.py @@ -312,6 +312,7 @@ def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture): ) mocker.patch.object(adapter, "_get_data_objects", return_value=[]) adapter.cursor.execute.reset_mock() + adapter._clear_data_object_cache() adapter.s3_warehouse_location = "s3://foo" adapter.replace_query( diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index 140fac43eb..5a4bc459ef 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -3695,3 +3695,264 @@ def test_casted_columns( assert [ x.sql() for x in EngineAdapter._casted_columns(columns_to_types, source_columns) ] == expected + + +def test_data_object_cache_get_data_objects( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[table1, table2] + ) + + result1 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert len(result1) == 2 + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert len(result2) == 2 + assert mock_get_data_objects.call_count == 1 # Should not increase + + result3 = adapter.get_data_objects("test_schema", {"table1"}) + assert len(result3) == 1 + assert result3[0].name == "table1" + assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_get_data_object( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table = DataObject(catalog=None, schema="test_schema", name="test_table", type="table") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) + + result1 = adapter.get_data_object("test_schema.test_table") + assert result1 is not None + assert result1.name == "test_table" + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_object("test_schema.test_table") + assert result2 is not None + assert result2.name == "test_table" + assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_cleared_on_drop_table( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table = DataObject(catalog=None, schema="test_schema", name="test_table", type="table") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) + + adapter.get_data_object("test_schema.test_table") + assert mock_get_data_objects.call_count == 1 + + adapter.drop_table("test_schema.test_table") + + mock_get_data_objects.return_value = [] + result = adapter.get_data_object("test_schema.test_table") + assert result is None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_drop_view( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + view = DataObject(catalog=None, schema="test_schema", name="test_view", type="view") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[view]) + + adapter.get_data_object("test_schema.test_view") + assert mock_get_data_objects.call_count == 1 + + adapter.drop_view("test_schema.test_view") + + mock_get_data_objects.return_value = [] + result = adapter.get_data_object("test_schema.test_view") + assert result is None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_drop_data_object( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table = DataObject(catalog=None, schema="test_schema", name="test_table", type="table") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) + + adapter.get_data_object("test_schema.test_table") + assert mock_get_data_objects.call_count == 1 + + adapter.drop_data_object(table) + + mock_get_data_objects.return_value = [] + result = adapter.get_data_object("test_schema.test_table") + assert result is None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_create_table( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlglot import exp + + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + # Initially cache that table doesn't exist + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + result = adapter.get_data_object("test_schema.test_table") + assert result is None + assert mock_get_data_objects.call_count == 1 + + # Create the table + table = DataObject(catalog=None, schema="test_schema", name="test_table", type="table") + mock_get_data_objects.return_value = [table] + adapter.create_table( + "test_schema.test_table", + {"col1": exp.DataType.build("INT")}, + ) + + # Cache should be cleared, so next get_data_object should call _get_data_objects again + result = adapter.get_data_object("test_schema.test_table") + assert result is not None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_create_view( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlglot import parse_one + + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + # Initially cache that view doesn't exist + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + result = adapter.get_data_object("test_schema.test_view") + assert result is None + assert mock_get_data_objects.call_count == 1 + + # Create the view + view = DataObject(catalog=None, schema="test_schema", name="test_view", type="view") + mock_get_data_objects.return_value = [view] + adapter.create_view( + "test_schema.test_view", + parse_one("SELECT 1 AS col1"), + ) + + # Cache should be cleared, so next get_data_object should call _get_data_objects again + result = adapter.get_data_object("test_schema.test_view") + assert result is not None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_clone_table( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlmesh.core.engine_adapter.snowflake import SnowflakeEngineAdapter + + adapter = make_mocked_engine_adapter( + SnowflakeEngineAdapter, patch_get_data_objects=False, default_catalog="test_catalog" + ) + + # Initially cache that target table doesn't exist + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + result = adapter.get_data_object("test_schema.test_target") + assert result is None + assert mock_get_data_objects.call_count == 1 + + # Clone the table + target_table = DataObject( + catalog="test_catalog", schema="test_schema", name="test_target", type="table" + ) + mock_get_data_objects.return_value = [target_table] + adapter.clone_table("test_schema.test_target", "test_schema.test_source") + + # Cache should be cleared, so next get_data_object should call _get_data_objects again + result = adapter.get_data_object("test_schema.test_target") + assert result is not None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_with_catalog( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlmesh.core.engine_adapter.snowflake import SnowflakeEngineAdapter + + adapter = make_mocked_engine_adapter( + SnowflakeEngineAdapter, patch_get_data_objects=False, default_catalog="test_catalog" + ) + + table = DataObject( + catalog="test_catalog", schema="test_schema", name="test_table", type="table" + ) + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) + + result1 = adapter.get_data_object("test_catalog.test_schema.test_table") + assert result1 is not None + assert result1.catalog == "test_catalog" + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_object("test_catalog.test_schema.test_table") + assert result2 is not None + assert result2.catalog == "test_catalog" + assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_partial_cache_hit( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + table3 = DataObject(catalog=None, schema="test_schema", name="table3", type="table") + + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[table1, table2] + ) + + adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert mock_get_data_objects.call_count == 1 + + mock_get_data_objects.return_value = [table3] + result = adapter.get_data_objects("test_schema", {"table1", "table3"}) + + assert len(result) == 2 + assert {obj.name for obj in result} == {"table1", "table3"} + assert mock_get_data_objects.call_count == 2 # Called again for table3 + + +def test_data_object_cache_get_data_objects_missing_objects( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + + result1 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert not result1 + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert not result2 + assert mock_get_data_objects.call_count == 1 # Should not increase + + result3 = adapter.get_data_objects("test_schema", {"table1"}) + assert not result3 + assert mock_get_data_objects.call_count == 1 # Should not increase diff --git a/tests/core/engine_adapter/test_snowflake.py b/tests/core/engine_adapter/test_snowflake.py index 62c4a4f3eb..ce4d3a886c 100644 --- a/tests/core/engine_adapter/test_snowflake.py +++ b/tests/core/engine_adapter/test_snowflake.py @@ -358,12 +358,12 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo def test_drop_managed_table(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) - adapter.drop_managed_table(table_name=exp.parse_identifier("foo"), exists=False) - adapter.drop_managed_table(table_name=exp.parse_identifier("foo"), exists=True) + adapter.drop_managed_table(table_name="foo.bar", exists=False) + adapter.drop_managed_table(table_name="foo.bar", exists=True) assert to_sql_calls(adapter) == [ - 'DROP DYNAMIC TABLE "foo"', - 'DROP DYNAMIC TABLE IF EXISTS "foo"', + 'DROP DYNAMIC TABLE "foo"."bar"', + 'DROP DYNAMIC TABLE IF EXISTS "foo"."bar"', ] diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 2df91afb10..d2f28a08a7 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -1344,6 +1344,7 @@ def test_promote_deployable(mocker: MockerFixture, make_snapshot): ) adapter_mock.create_table.assert_not_called() + adapter_mock.get_data_objects.return_value = [] evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) adapter_mock.create_schema.assert_called_once_with(to_schema("test_schema__test_env")) @@ -4188,6 +4189,7 @@ def test_multiple_engine_promotion(mocker: MockerFixture, adapter_mock, make_sna connection_mock.cursor.return_value = cursor_mock adapter = EngineAdapter(lambda: connection_mock, "") adapter.with_settings = lambda **kwargs: adapter # type: ignore + adapter._get_data_objects = lambda *args, **kwargs: [] # type: ignore engine_adapters = {"default": adapter_mock, "secondary": adapter} def columns(table_name): From c7f5771cf8b9e2323bc4e67ebfd848cd26674864 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 30 Sep 2025 17:35:16 -0700 Subject: [PATCH 2/3] address comments --- sqlmesh/core/engine_adapter/base.py | 26 ++++++++++++++------ sqlmesh/core/engine_adapter/base_postgres.py | 10 +++++++- sqlmesh/core/engine_adapter/bigquery.py | 7 ++++++ sqlmesh/core/engine_adapter/mssql.py | 9 +++++++ sqlmesh/core/engine_adapter/postgres.py | 2 +- tests/core/engine_adapter/test_base.py | 21 ++++++++++++++++ tests/dbt/test_transformation.py | 6 ++--- 7 files changed, 69 insertions(+), 12 deletions(-) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index d9f417695f..9a30bbb080 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -1470,8 +1470,14 @@ def columns( } def table_exists(self, table_name: TableName) -> bool: + table = exp.to_table(table_name) + data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + if data_object_cache_key in self._data_object_cache: + logger.debug("Table existence cache hit: %s", data_object_cache_key) + return self._data_object_cache[data_object_cache_key] is not None + try: - self.execute(exp.Describe(this=exp.to_table(table_name), kind="TABLE")) + self.execute(exp.Describe(this=table, kind="TABLE")) return True except Exception: return False @@ -2301,11 +2307,13 @@ def get_data_objects( target_schema.catalog, target_schema.db, name ) if cache_key in self._data_object_cache: + logger.debug("Data object cache hit: %s", cache_key) data_object = self._data_object_cache[cache_key] # If the object is none, then the table was previously looked for but not found if data_object: cached_objects.append(data_object) else: + logger.debug("Data object cache miss: %s", cache_key) missing_names.add(name) # Fetch missing objects from database @@ -2321,7 +2329,6 @@ def get_data_objects( for obj in self._get_data_objects(schema_name, set(batch)) ] - # Cache the fetched objects for obj in fetched_objects: cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name) self._data_object_cache[cache_key] = obj @@ -2336,7 +2343,12 @@ def get_data_objects( return cached_objects + fetched_objects return cached_objects - return self._get_data_objects(schema_name) + + fetched_objects = self._get_data_objects(schema_name) + for obj in fetched_objects: + cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name) + self._data_object_cache[cache_key] = obj + return fetched_objects def fetchone( self, @@ -2746,10 +2758,12 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An def _clear_data_object_cache(self, table_name: t.Optional[TableName] = None) -> None: """Clears the cache entry for the given table name, or clears the entire cache if table_name is None.""" if table_name is None: + logger.debug("Clearing entire data object cache") self._data_object_cache.clear() else: table = exp.to_table(table_name) cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + logger.debug("Clearing data object cache key: %s", cache_key) self._data_object_cache.pop(cache_key, None) def _get_data_objects( @@ -3003,7 +3017,5 @@ def _decoded_str(value: t.Union[str, bytes]) -> str: def _get_data_object_cache_key(catalog: t.Optional[str], schema_name: str, object_name: str) -> str: """Returns a cache key for a data object based on its fully qualified name.""" - catalog_part = catalog.lower() if catalog else "" - schema_part = schema_name.lower() - object_part = object_name.lower() - return f"{catalog_part}.{schema_part}.{object_part}" + catalog = catalog or "" + return f"{catalog}.{schema_name}.{object_name}" diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index c6ba7d6d62..3de975d6a5 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -1,11 +1,12 @@ from __future__ import annotations import typing as t +import logging from sqlglot import exp from sqlmesh.core.dialect import to_schema -from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.engine_adapter.base import EngineAdapter, _get_data_object_cache_key from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, CommentCreationTable, @@ -20,6 +21,9 @@ from sqlmesh.core.engine_adapter._typing import QueryOrDF +logger = logging.getLogger(__name__) + + class BasePostgresEngineAdapter(EngineAdapter): DEFAULT_BATCH_SIZE = 400 COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY @@ -75,6 +79,10 @@ def table_exists(self, table_name: TableName) -> bool: Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553 """ table = exp.to_table(table_name) + data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + if data_object_cache_key in self._data_object_cache: + logger.debug("Table existence cache hit: %s", data_object_cache_key) + return self._data_object_cache[data_object_cache_key] is not None sql = ( exp.select("1") diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 26abad9ebc..09fd7537ef 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -8,6 +8,7 @@ from sqlglot.transforms import remove_precision_parameterized_types from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key from sqlmesh.core.engine_adapter.mixins import ( ClusteredByMixin, RowDiffMixin, @@ -744,6 +745,12 @@ def insert_overwrite_by_partition( ) def table_exists(self, table_name: TableName) -> bool: + table = exp.to_table(table_name) + data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + if data_object_cache_key in self._data_object_cache: + logger.debug("Table existence cache hit: %s", data_object_cache_key) + return self._data_object_cache[data_object_cache_key] is not None + try: from google.cloud.exceptions import NotFound except ModuleNotFoundError: diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index fd0bf1011b..05c3753f14 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing as t +import logging from sqlglot import exp @@ -13,6 +14,7 @@ InsertOverwriteStrategy, MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS, + _get_data_object_cache_key, ) from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, @@ -36,6 +38,9 @@ from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF +logger = logging.getLogger(__name__) + + @set_catalog() class MSSQLEngineAdapter( EngineAdapterWithIndexSupport, @@ -144,6 +149,10 @@ def build_var_length_col( def table_exists(self, table_name: TableName) -> bool: """MsSql doesn't support describe so we query information_schema.""" table = exp.to_table(table_name) + data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + if data_object_cache_key in self._data_object_cache: + logger.debug("Table existence cache hit: %s", data_object_cache_key) + return self._data_object_cache[data_object_cache_key] is not None sql = ( exp.select("1") diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index e9c212bd5f..c67c7cc4e3 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -34,7 +34,7 @@ class PostgresEngineAdapter( HAS_VIEW_BINDING = True CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") SUPPORTS_REPLACE_TABLE = False - MAX_IDENTIFIER_LENGTH = 63 + MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63 SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index 5a4bc459ef..08207ac726 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -3723,6 +3723,27 @@ def test_data_object_cache_get_data_objects( assert mock_get_data_objects.call_count == 1 # Should not increase +def test_data_object_cache_get_data_objects_no_object_names( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[table1, table2] + ) + + result1 = adapter.get_data_objects("test_schema") + assert len(result1) == 2 + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert len(result2) == 2 + assert mock_get_data_objects.call_count == 1 # Should not increase + + def test_data_object_cache_get_data_object( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 9a9ce8f906..a33e3ed843 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -122,10 +122,10 @@ def test_dbt_custom_materialization(): selected_model = list(plan.selected_models)[0] assert selected_model == "model.sushi.custom_incremental_model" - qoery = "SELECT * FROM sushi.custom_incremental_model ORDER BY created_at" + query = "SELECT * FROM sushi.custom_incremental_model ORDER BY created_at" hook_table = "SELECT * FROM hook_table ORDER BY id" sushi_context.apply(plan) - result = sushi_context.engine_adapter.fetchdf(qoery) + result = sushi_context.engine_adapter.fetchdf(query) assert len(result) == 1 assert {"created_at", "id"}.issubset(result.columns) @@ -140,7 +140,7 @@ def test_dbt_custom_materialization(): tomorrow = datetime.now() + timedelta(days=1) sushi_context.run(select_models=["sushi.custom_incremental_model"], execution_time=tomorrow) - result_after_run = sushi_context.engine_adapter.fetchdf(qoery) + result_after_run = sushi_context.engine_adapter.fetchdf(query) assert {"created_at", "id"}.issubset(result_after_run.columns) # this should have added new unique values for the new row From 032cd9789c4c3771dfc8cef4c7a3568a602e85a2 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Thu, 2 Oct 2025 10:30:02 -0700 Subject: [PATCH 3/3] only cache data objects mantained by sqlmesh --- sqlmesh/core/engine_adapter/base.py | 73 ++++++++---- sqlmesh/core/engine_adapter/clickhouse.py | 27 +++-- sqlmesh/core/engine_adapter/mysql.py | 4 +- sqlmesh/core/engine_adapter/postgres.py | 4 +- sqlmesh/core/engine_adapter/spark.py | 6 +- sqlmesh/core/snapshot/evaluator.py | 4 +- tests/core/engine_adapter/test_base.py | 138 ++++++++++++++++++---- tests/core/test_snapshot_evaluator.py | 6 + 8 files changed, 194 insertions(+), 68 deletions(-) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 9a30bbb080..d9cc4f44a2 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -1046,7 +1046,8 @@ def create_table_like( target_table_name: The name of the table to create. Can be fully qualified or just table name. source_table_name: The name of the table to base the new table on. """ - self.create_table(target_table_name, self.columns(source_table_name), exists=exists) + self._create_table_like(target_table_name, source_table_name, exists=exists, **kwargs) + self._clear_data_object_cache(target_table_name) def clone_table( self, @@ -2271,24 +2272,34 @@ def rename_table( "Tried to rename table across catalogs which is not supported" ) self._rename_table(old_table_name, new_table_name) + self._clear_data_object_cache(old_table_name) + self._clear_data_object_cache(new_table_name) - def get_data_object(self, target_name: TableName) -> t.Optional[DataObject]: + def get_data_object( + self, target_name: TableName, safe_to_cache: bool = False + ) -> t.Optional[DataObject]: target_table = exp.to_table(target_name) existing_data_objects = self.get_data_objects( - schema_(target_table.db, target_table.catalog), {target_table.name} + schema_(target_table.db, target_table.catalog), + {target_table.name}, + safe_to_cache=safe_to_cache, ) if existing_data_objects: return existing_data_objects[0] return None def get_data_objects( - self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None + self, + schema_name: SchemaName, + object_names: t.Optional[t.Set[str]] = None, + safe_to_cache: bool = False, ) -> t.List[DataObject]: """Lists all data objects in the target schema. Args: schema_name: The name of the schema to list data objects from. object_names: If provided, only return data objects with these names. + safe_to_cache: Whether it is safe to cache the results of this call. Returns: A list of data objects in the target schema. @@ -2323,31 +2334,36 @@ def get_data_objects( object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE] for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE) ] - fetched_objects = [ - obj - for batch in batches - for obj in self._get_data_objects(schema_name, set(batch)) - ] - - for obj in fetched_objects: - cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name) - self._data_object_cache[cache_key] = obj - fetched_object_names = {obj.name for obj in fetched_objects} - for missing_name in missing_names - fetched_object_names: - cache_key = _get_data_object_cache_key( - target_schema.catalog, target_schema.db, missing_name - ) - self._data_object_cache[cache_key] = None + fetched_objects = [] + fetched_object_names = set() + for batch in batches: + objects = self._get_data_objects(schema_name, set(batch)) + for obj in objects: + if safe_to_cache: + cache_key = _get_data_object_cache_key( + obj.catalog, obj.schema_name, obj.name + ) + self._data_object_cache[cache_key] = obj + fetched_objects.append(obj) + fetched_object_names.add(obj.name) + + if safe_to_cache: + for missing_name in missing_names - fetched_object_names: + cache_key = _get_data_object_cache_key( + target_schema.catalog, target_schema.db, missing_name + ) + self._data_object_cache[cache_key] = None return cached_objects + fetched_objects return cached_objects fetched_objects = self._get_data_objects(schema_name) - for obj in fetched_objects: - cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name) - self._data_object_cache[cache_key] = obj + if safe_to_cache: + for obj in fetched_objects: + cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name) + self._data_object_cache[cache_key] = obj return fetched_objects def fetchone( @@ -2951,6 +2967,15 @@ def _create_column_comments( exc_info=True, ) + def _create_table_like( + self, + target_table_name: TableName, + source_table_name: TableName, + exists: bool, + **kwargs: t.Any, + ) -> None: + self.create_table(target_table_name, self.columns(source_table_name), exists=exists) + def _rename_table( self, old_table_name: TableName, @@ -3017,5 +3042,5 @@ def _decoded_str(value: t.Union[str, bytes]) -> str: def _get_data_object_cache_key(catalog: t.Optional[str], schema_name: str, object_name: str) -> str: """Returns a cache key for a data object based on its fully qualified name.""" - catalog = catalog or "" - return f"{catalog}.{schema_name}.{object_name}" + catalog = f"{catalog}." if catalog else "" + return f"{catalog}{schema_name}.{object_name}" diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index 84d6ad311e..45c22a6e55 100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -224,7 +224,7 @@ def _insert_overwrite_by_condition( target_columns_to_types = target_columns_to_types or self.columns(target_table) temp_table = self._get_temp_table(target_table) - self._create_table_like(temp_table, target_table) + self.create_table_like(temp_table, target_table) # REPLACE BY KEY: extract kwargs if present dynamic_key = kwargs.get("dynamic_key") @@ -456,7 +456,11 @@ def insert_overwrite_by_partition( ) def _create_table_like( - self, target_table_name: TableName, source_table_name: TableName + self, + target_table_name: TableName, + source_table_name: TableName, + exists: bool, + **kwargs: t.Any, ) -> None: """Create table with identical structure as source table""" self.execute( @@ -632,16 +636,15 @@ def _drop_object( kind: What kind of object to drop. Defaults to TABLE **drop_args: Any extra arguments to set on the Drop expression """ - self.execute( - exp.Drop( - this=exp.to_table(name), - kind=kind, - exists=exists, - cluster=exp.OnCluster(this=exp.to_identifier(self.cluster)) - if self.engine_run_mode.is_cluster - else None, - **drop_args, - ) + super()._drop_object( + name=name, + exists=exists, + kind=kind, + cascade=cascade, + cluster=exp.OnCluster(this=exp.to_identifier(self.cluster)) + if self.engine_run_mode.is_cluster + else None, + **drop_args, ) def _build_partitioned_by_exp( diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py index 26cc7c0197..31773d6c63 100644 --- a/sqlmesh/core/engine_adapter/mysql.py +++ b/sqlmesh/core/engine_adapter/mysql.py @@ -164,11 +164,11 @@ def _create_column_comments( exc_info=True, ) - def create_table_like( + def _create_table_like( self, target_table_name: TableName, source_table_name: TableName, - exists: bool = True, + exists: bool, **kwargs: t.Any, ) -> None: self.execute( diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index c67c7cc4e3..79431ee360 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -79,11 +79,11 @@ def _fetch_native_df( self._connection_pool.commit() return df - def create_table_like( + def _create_table_like( self, target_table_name: TableName, source_table_name: TableName, - exists: bool = True, + exists: bool, **kwargs: t.Any, ) -> None: self.execute( diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 18ba6ea106..b2d6a9cbb5 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -402,14 +402,16 @@ def get_current_database(self) -> str: return self.spark.catalog.currentDatabase() return self.fetchone(exp.select(exp.func("current_database")))[0] # type: ignore - def get_data_object(self, target_name: TableName) -> t.Optional[DataObject]: + def get_data_object( + self, target_name: TableName, safe_to_cache: bool = False + ) -> t.Optional[DataObject]: target_table = exp.to_table(target_name) if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith( f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}" ): # Exclude the branch name target_table.set("this", target_table.this.this) - return super().get_data_object(target_table) + return super().get_data_object(target_table, safe_to_cache=safe_to_cache) def create_state_table( self, diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 9419815939..1483bdeece 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -1564,7 +1564,9 @@ def _get_data_objects_in_schema( gateway: t.Optional[str] = None, ) -> t.List[DataObject]: logger.info("Listing data objects in schema %s", schema.sql()) - return self.get_adapter(gateway).get_data_objects(schema, object_names) + return self.get_adapter(gateway).get_data_objects( + schema, object_names, safe_to_cache=True + ) with self.concurrent_context(): existing_objects: t.List[DataObject] = [] diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index 08207ac726..ba775c0779 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -3709,11 +3709,11 @@ def test_data_object_cache_get_data_objects( adapter, "_get_data_objects", return_value=[table1, table2] ) - result1 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + result1 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) assert len(result1) == 2 assert mock_get_data_objects.call_count == 1 - result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) assert len(result2) == 2 assert mock_get_data_objects.call_count == 1 # Should not increase @@ -3723,6 +3723,35 @@ def test_data_object_cache_get_data_objects( assert mock_get_data_objects.call_count == 1 # Should not increase +def test_data_object_cache_get_data_objects_bypasses_cache( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[table1, table2] + ) + + assert adapter.get_data_objects("test_schema") + assert adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert adapter.get_data_objects("test_schema", {"table1"}) + assert adapter.get_data_object("test_schema.table1") is not None + + mock_get_data_objects.return_value = [] + assert not adapter.get_data_objects("test_schema") + assert not adapter.get_data_objects("test_schema", {"missing"}) + assert not adapter.get_data_objects("test_schema", {"missing"}) + assert adapter.get_data_object("test_schema.missing") is None + + # None of the calls should've been cached + assert mock_get_data_objects.call_count == 9 + assert not adapter._data_object_cache + + def test_data_object_cache_get_data_objects_no_object_names( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): @@ -3735,11 +3764,11 @@ def test_data_object_cache_get_data_objects_no_object_names( adapter, "_get_data_objects", return_value=[table1, table2] ) - result1 = adapter.get_data_objects("test_schema") + result1 = adapter.get_data_objects("test_schema", safe_to_cache=True) assert len(result1) == 2 assert mock_get_data_objects.call_count == 1 - result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) assert len(result2) == 2 assert mock_get_data_objects.call_count == 1 # Should not increase @@ -3753,12 +3782,12 @@ def test_data_object_cache_get_data_object( mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) - result1 = adapter.get_data_object("test_schema.test_table") + result1 = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) assert result1 is not None assert result1.name == "test_table" assert mock_get_data_objects.call_count == 1 - result2 = adapter.get_data_object("test_schema.test_table") + result2 = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) assert result2 is not None assert result2.name == "test_table" assert mock_get_data_objects.call_count == 1 # Should not increase @@ -3773,13 +3802,13 @@ def test_data_object_cache_cleared_on_drop_table( mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) - adapter.get_data_object("test_schema.test_table") + adapter.get_data_object("test_schema.test_table", safe_to_cache=True) assert mock_get_data_objects.call_count == 1 adapter.drop_table("test_schema.test_table") mock_get_data_objects.return_value = [] - result = adapter.get_data_object("test_schema.test_table") + result = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) assert result is None assert mock_get_data_objects.call_count == 2 @@ -3793,13 +3822,13 @@ def test_data_object_cache_cleared_on_drop_view( mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[view]) - adapter.get_data_object("test_schema.test_view") + adapter.get_data_object("test_schema.test_view", safe_to_cache=True) assert mock_get_data_objects.call_count == 1 adapter.drop_view("test_schema.test_view") mock_get_data_objects.return_value = [] - result = adapter.get_data_object("test_schema.test_view") + result = adapter.get_data_object("test_schema.test_view", safe_to_cache=True) assert result is None assert mock_get_data_objects.call_count == 2 @@ -3813,13 +3842,13 @@ def test_data_object_cache_cleared_on_drop_data_object( mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) - adapter.get_data_object("test_schema.test_table") + adapter.get_data_object("test_schema.test_table", safe_to_cache=True) assert mock_get_data_objects.call_count == 1 adapter.drop_data_object(table) mock_get_data_objects.return_value = [] - result = adapter.get_data_object("test_schema.test_table") + result = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) assert result is None assert mock_get_data_objects.call_count == 2 @@ -3833,7 +3862,7 @@ def test_data_object_cache_cleared_on_create_table( # Initially cache that table doesn't exist mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) - result = adapter.get_data_object("test_schema.test_table") + result = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) assert result is None assert mock_get_data_objects.call_count == 1 @@ -3846,7 +3875,7 @@ def test_data_object_cache_cleared_on_create_table( ) # Cache should be cleared, so next get_data_object should call _get_data_objects again - result = adapter.get_data_object("test_schema.test_table") + result = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) assert result is not None assert mock_get_data_objects.call_count == 2 @@ -3860,7 +3889,7 @@ def test_data_object_cache_cleared_on_create_view( # Initially cache that view doesn't exist mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) - result = adapter.get_data_object("test_schema.test_view") + result = adapter.get_data_object("test_schema.test_view", safe_to_cache=True) assert result is None assert mock_get_data_objects.call_count == 1 @@ -3873,7 +3902,7 @@ def test_data_object_cache_cleared_on_create_view( ) # Cache should be cleared, so next get_data_object should call _get_data_objects again - result = adapter.get_data_object("test_schema.test_view") + result = adapter.get_data_object("test_schema.test_view", safe_to_cache=True) assert result is not None assert mock_get_data_objects.call_count == 2 @@ -3889,7 +3918,7 @@ def test_data_object_cache_cleared_on_clone_table( # Initially cache that target table doesn't exist mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) - result = adapter.get_data_object("test_schema.test_target") + result = adapter.get_data_object("test_schema.test_target", safe_to_cache=True) assert result is None assert mock_get_data_objects.call_count == 1 @@ -3901,7 +3930,7 @@ def test_data_object_cache_cleared_on_clone_table( adapter.clone_table("test_schema.test_target", "test_schema.test_source") # Cache should be cleared, so next get_data_object should call _get_data_objects again - result = adapter.get_data_object("test_schema.test_target") + result = adapter.get_data_object("test_schema.test_target", safe_to_cache=True) assert result is not None assert mock_get_data_objects.call_count == 2 @@ -3921,12 +3950,12 @@ def test_data_object_cache_with_catalog( mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) - result1 = adapter.get_data_object("test_catalog.test_schema.test_table") + result1 = adapter.get_data_object("test_catalog.test_schema.test_table", safe_to_cache=True) assert result1 is not None assert result1.catalog == "test_catalog" assert mock_get_data_objects.call_count == 1 - result2 = adapter.get_data_object("test_catalog.test_schema.test_table") + result2 = adapter.get_data_object("test_catalog.test_schema.test_table", safe_to_cache=True) assert result2 is not None assert result2.catalog == "test_catalog" assert mock_get_data_objects.call_count == 1 # Should not increase @@ -3945,11 +3974,11 @@ def test_data_object_cache_partial_cache_hit( adapter, "_get_data_objects", return_value=[table1, table2] ) - adapter.get_data_objects("test_schema", {"table1", "table2"}) + adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) assert mock_get_data_objects.call_count == 1 mock_get_data_objects.return_value = [table3] - result = adapter.get_data_objects("test_schema", {"table1", "table3"}) + result = adapter.get_data_objects("test_schema", {"table1", "table3"}, safe_to_cache=True) assert len(result) == 2 assert {obj.name for obj in result} == {"table1", "table3"} @@ -3966,14 +3995,73 @@ def test_data_object_cache_get_data_objects_missing_objects( mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) - result1 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + result1 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) assert not result1 assert mock_get_data_objects.call_count == 1 - result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}) + result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) assert not result2 assert mock_get_data_objects.call_count == 1 # Should not increase - result3 = adapter.get_data_objects("test_schema", {"table1"}) + result3 = adapter.get_data_objects("test_schema", {"table1"}, safe_to_cache=True) assert not result3 assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_cleared_on_rename_table( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + old_table = DataObject(catalog=None, schema="test_schema", name="old_table", type="table") + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[old_table] + ) + + result = adapter.get_data_object("test_schema.old_table", safe_to_cache=True) + assert result is not None + assert result.name == "old_table" + assert mock_get_data_objects.call_count == 1 + + new_table = DataObject(catalog=None, schema="test_schema", name="new_table", type="table") + mock_get_data_objects.return_value = [new_table] + adapter.rename_table("test_schema.old_table", "test_schema.new_table") + + mock_get_data_objects.return_value = [] + result = adapter.get_data_object("test_schema.old_table", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 2 + + mock_get_data_objects.return_value = [new_table] + result = adapter.get_data_object("test_schema.new_table", safe_to_cache=True) + assert result is not None + assert result.name == "new_table" + assert mock_get_data_objects.call_count == 3 + + +def test_data_object_cache_cleared_on_create_table_like( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlglot import exp + + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + columns_to_types = { + "col1": exp.DataType.build("INT"), + "col2": exp.DataType.build("TEXT"), + } + mocker.patch.object(adapter, "columns", return_value=columns_to_types) + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + result = adapter.get_data_object("test_schema.target_table", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 1 + + target_table = DataObject(catalog=None, schema="test_schema", name="target_table", type="table") + mock_get_data_objects.return_value = [target_table] + adapter.create_table_like("test_schema.target_table", "test_schema.source_table") + + result = adapter.get_data_object("test_schema.target_table", safe_to_cache=True) + assert result is not None + assert result.name == "target_table" + assert mock_get_data_objects.call_count == 2 diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index d2f28a08a7..19685e81c3 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -888,6 +888,7 @@ def test_create_prod_table_exists(mocker: MockerFixture, adapter_mock, make_snap { f"test_schema__test_model__{snapshot.version}", }, + safe_to_cache=True, ) @@ -974,6 +975,7 @@ def test_create_only_dev_table_exists(mocker: MockerFixture, adapter_mock, make_ { f"test_schema__test_model__{snapshot.version}__dev", }, + safe_to_cache=True, ) @@ -1023,6 +1025,7 @@ def test_create_new_forward_only_model(mocker: MockerFixture, adapter_mock, make { f"test_schema__test_model__{snapshot.dev_version}__dev", }, + safe_to_cache=True, ) @@ -1113,6 +1116,7 @@ def test_create_tables_exist( adapter_mock.get_data_objects.assert_called_once_with( schema_("sqlmesh__db"), {table_name}, + safe_to_cache=True, ) adapter_mock.create_schema.assert_not_called() adapter_mock.create_table.assert_not_called() @@ -1150,6 +1154,7 @@ def test_create_prod_table_exists_forward_only(mocker: MockerFixture, adapter_mo { f"test_schema__test_model__{snapshot.version}", }, + safe_to_cache=True, ) adapter_mock.create_table.assert_not_called() @@ -1341,6 +1346,7 @@ def test_promote_deployable(mocker: MockerFixture, make_snapshot): { f"test_schema__test_model__{snapshot.version}", }, + safe_to_cache=True, ) adapter_mock.create_table.assert_not_called()