Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 118 additions & 14 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(
self.correlation_id = correlation_id
self._schema_differ_overrides = schema_differ_overrides
self._query_execution_tracker = query_execution_tracker
self._data_object_cache: t.Dict[str, t.Optional[DataObject]] = {}

def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
extra_kwargs = {
Expand Down Expand Up @@ -983,6 +984,13 @@ def _create_table(
),
track_rows_processed=track_rows_processed,
)
# Extract table name to clear cache
table_name = (
table_name_or_schema.this
if isinstance(table_name_or_schema, exp.Schema)
else table_name_or_schema
)
self._clear_data_object_cache(table_name)

def _build_create_table_exp(
self,
Expand Down Expand Up @@ -1038,7 +1046,8 @@ def create_table_like(
target_table_name: The name of the table to create. Can be fully qualified or just table name.
source_table_name: The name of the table to base the new table on.
"""
self.create_table(target_table_name, self.columns(source_table_name), exists=exists)
self._create_table_like(target_table_name, source_table_name, exists=exists, **kwargs)
self._clear_data_object_cache(target_table_name)

def clone_table(
self,
Expand Down Expand Up @@ -1074,6 +1083,7 @@ def clone_table(
**kwargs,
)
)
self._clear_data_object_cache(target_table_name)

def drop_data_object(self, data_object: DataObject, ignore_if_not_exists: bool = True) -> None:
"""Drops a data object of arbitrary type.
Expand Down Expand Up @@ -1139,6 +1149,7 @@ def _drop_object(
drop_args["cascade"] = cascade

self.execute(exp.Drop(this=exp.to_table(name), kind=kind, exists=exists, **drop_args))
self._clear_data_object_cache(name)

def get_alter_operations(
self,
Expand Down Expand Up @@ -1329,6 +1340,8 @@ def create_view(
quote_identifiers=self.QUOTE_IDENTIFIERS_IN_VIEWS,
)

self._clear_data_object_cache(view_name)

# Register table comment with commands if the engine doesn't support doing it in CREATE
if (
table_description
Expand Down Expand Up @@ -1458,8 +1471,14 @@ def columns(
}

def table_exists(self, table_name: TableName) -> bool:
    """Report whether the given table exists.

    The data object cache is consulted first. A cached data object means the
    table exists; a cached ``None`` means a previous lookup recorded it as
    missing. Without a cache entry, a ``DESCRIBE TABLE`` statement is executed
    and any exception it raises is interpreted as "the table does not exist".
    The outcome of that probe is not written back to the cache.

    Args:
        table_name: The (optionally qualified) name of the table to check.

    Returns:
        True if the table is known or found to exist, False otherwise.
    """
    table = exp.to_table(table_name)
    data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
    if data_object_cache_key in self._data_object_cache:
        logger.debug("Table existence cache hit: %s", data_object_cache_key)
        return self._data_object_cache[data_object_cache_key] is not None

    try:
        # Reuse the already parsed table so the name is parsed once and the
        # probe statement is issued a single time.
        self.execute(exp.Describe(this=table, kind="TABLE"))
        return True
    except Exception:
        # Best-effort probe: any failure of DESCRIBE is treated as "missing".
        return False
Expand Down Expand Up @@ -2253,40 +2272,99 @@ def rename_table(
"Tried to rename table across catalogs which is not supported"
)
self._rename_table(old_table_name, new_table_name)
self._clear_data_object_cache(old_table_name)
self._clear_data_object_cache(new_table_name)

def get_data_object(
    self, target_name: TableName, safe_to_cache: bool = False
) -> t.Optional[DataObject]:
    """Look up a single data object by its (optionally qualified) name.

    The name is split into catalog, schema and object name, and the lookup is
    delegated to `get_data_objects` with a one-element name filter.

    Args:
        target_name: The name of the object to look up.
        safe_to_cache: Forwarded to `get_data_objects`; whether it is safe to
            cache the result of this lookup.

    Returns:
        The first data object returned for the name, or None if nothing was returned.
    """
    target_table = exp.to_table(target_name)
    existing_data_objects = self.get_data_objects(
        schema_(target_table.db, target_table.catalog),
        {target_table.name},
        safe_to_cache=safe_to_cache,
    )
    return existing_data_objects[0] if existing_data_objects else None

def get_data_objects(
    self,
    schema_name: SchemaName,
    object_names: t.Optional[t.Set[str]] = None,
    safe_to_cache: bool = False,
) -> t.List[DataObject]:
    """Lists all data objects in the target schema.

    Args:
        schema_name: The name of the schema to list data objects from.
        object_names: If provided, only return data objects with these names.
        safe_to_cache: Whether it is safe to cache the results of this call.

    Returns:
        A list of data objects in the target schema. When `object_names` is provided,
        objects served from the cache come first, followed by the objects fetched from
        the engine; names that are cached as missing are omitted.
    """
    if object_names is None:
        # Unfiltered listing: the engine is always queried, the cache is only written.
        listed_objects = self._get_data_objects(schema_name)
        if safe_to_cache:
            for obj in listed_objects:
                cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name)
                self._data_object_cache[cache_key] = obj
        return listed_objects

    if not object_names:
        return []

    # Split the requested names into those answered by the cache and those that
    # still have to be fetched.
    target_schema = to_schema(schema_name)
    cached_objects: t.List[DataObject] = []
    missing_names: t.Set[str] = set()

    for name in object_names:
        cache_key = _get_data_object_cache_key(target_schema.catalog, target_schema.db, name)
        if cache_key in self._data_object_cache:
            logger.debug("Data object cache hit: %s", cache_key)
            data_object = self._data_object_cache[cache_key]
            # A cached None means the object was previously looked for but not found.
            if data_object:
                cached_objects.append(data_object)
        else:
            logger.debug("Data object cache miss: %s", cache_key)
            missing_names.add(name)

    if not missing_names:
        return cached_objects

    # Fetch the remaining names in batches of DATA_OBJECT_FILTER_BATCH_SIZE.
    names_to_fetch = list(missing_names)
    batch_size = self.DATA_OBJECT_FILTER_BATCH_SIZE
    fetched_objects: t.List[DataObject] = []
    fetched_object_names: t.Set[str] = set()

    for start in range(0, len(names_to_fetch), batch_size):
        batch = set(names_to_fetch[start : start + batch_size])
        for obj in self._get_data_objects(schema_name, batch):
            if safe_to_cache:
                # NOTE(review): positive entries are keyed by the catalog/schema/name
                # reported on the returned object, while lookups and negative entries
                # use the requested schema and name. Confirm these always agree
                # (identifier case, implicit catalog); otherwise a found object could
                # also be recorded as missing below.
                cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name)
                self._data_object_cache[cache_key] = obj
            fetched_objects.append(obj)
            fetched_object_names.add(obj.name)

    if safe_to_cache:
        # Remember names that were requested but not returned, so repeated lookups
        # for them are answered from the cache.
        for missing_name in missing_names - fetched_object_names:
            cache_key = _get_data_object_cache_key(
                target_schema.catalog, target_schema.db, missing_name
            )
            self._data_object_cache[cache_key] = None

    return cached_objects + fetched_objects

def fetchone(
self,
Expand Down Expand Up @@ -2693,6 +2771,17 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An

return expression.sql(**sql_gen_kwargs, copy=False) # type: ignore

def _clear_data_object_cache(self, table_name: t.Optional[TableName] = None) -> None:
    """Invalidate cached data object lookups.

    Args:
        table_name: When given, only the entry keyed by this table's catalog, schema
            and name is dropped (nothing happens if it is not cached). When None,
            every cache entry is removed.
    """
    if table_name is not None:
        parsed_table = exp.to_table(table_name)
        key = _get_data_object_cache_key(
            parsed_table.catalog, parsed_table.db, parsed_table.name
        )
        logger.debug("Clearing data object cache key: %s", key)
        self._data_object_cache.pop(key, None)
        return

    logger.debug("Clearing entire data object cache")
    self._data_object_cache.clear()

def _get_data_objects(
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
) -> t.List[DataObject]:
Expand Down Expand Up @@ -2878,6 +2967,15 @@ def _create_column_comments(
exc_info=True,
)

def _create_table_like(
    self,
    target_table_name: TableName,
    source_table_name: TableName,
    exists: bool,
    **kwargs: t.Any,
) -> None:
    """Create the target table with the same columns as the source table.

    The source table's columns are read via `columns` and handed to `create_table`.

    Args:
        target_table_name: The name of the table to create.
        source_table_name: The name of the table whose columns are copied.
        exists: Passed through to `create_table` as its `exists` argument.
        kwargs: Accepted but not used by this implementation.
    """
    source_columns = self.columns(source_table_name)
    self.create_table(target_table_name, source_columns, exists=exists)

def _rename_table(
self,
old_table_name: TableName,
Expand Down Expand Up @@ -2940,3 +3038,9 @@ def _decoded_str(value: t.Union[str, bytes]) -> str:
if isinstance(value, bytes):
return value.decode("utf-8")
return value


def _get_data_object_cache_key(catalog: t.Optional[str], schema_name: str, object_name: str) -> str:
    """Build the data object cache key from the parts of a qualified name.

    The key is the dot-joined ``catalog.schema.object``; the catalog segment and its
    trailing dot are left out when `catalog` is None or empty.
    """
    schema_and_object = f"{schema_name}.{object_name}"
    if catalog:
        return f"{catalog}.{schema_and_object}"
    return schema_and_object
10 changes: 9 additions & 1 deletion sqlmesh/core/engine_adapter/base_postgres.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import typing as t
import logging

from sqlglot import exp

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.engine_adapter.base import EngineAdapter, _get_data_object_cache_key
from sqlmesh.core.engine_adapter.shared import (
CatalogSupport,
CommentCreationTable,
Expand All @@ -20,6 +21,9 @@
from sqlmesh.core.engine_adapter._typing import QueryOrDF


logger = logging.getLogger(__name__)


class BasePostgresEngineAdapter(EngineAdapter):
DEFAULT_BATCH_SIZE = 400
COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
Expand Down Expand Up @@ -75,6 +79,10 @@ def table_exists(self, table_name: TableName) -> bool:
Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553
"""
table = exp.to_table(table_name)
data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
if data_object_cache_key in self._data_object_cache:
logger.debug("Table existence cache hit: %s", data_object_cache_key)
return self._data_object_cache[data_object_cache_key] is not None

sql = (
exp.select("1")
Expand Down
7 changes: 7 additions & 0 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlglot.transforms import remove_precision_parameterized_types

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key
from sqlmesh.core.engine_adapter.mixins import (
ClusteredByMixin,
RowDiffMixin,
Expand Down Expand Up @@ -744,6 +745,12 @@ def insert_overwrite_by_partition(
)

def table_exists(self, table_name: TableName) -> bool:
table = exp.to_table(table_name)
data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
if data_object_cache_key in self._data_object_cache:
logger.debug("Table existence cache hit: %s", data_object_cache_key)
return self._data_object_cache[data_object_cache_key] is not None

try:
from google.cloud.exceptions import NotFound
except ModuleNotFoundError:
Expand Down
27 changes: 15 additions & 12 deletions sqlmesh/core/engine_adapter/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _insert_overwrite_by_condition(
target_columns_to_types = target_columns_to_types or self.columns(target_table)

temp_table = self._get_temp_table(target_table)
self._create_table_like(temp_table, target_table)
self.create_table_like(temp_table, target_table)

# REPLACE BY KEY: extract kwargs if present
dynamic_key = kwargs.get("dynamic_key")
Expand Down Expand Up @@ -456,7 +456,11 @@ def insert_overwrite_by_partition(
)

def _create_table_like(
self, target_table_name: TableName, source_table_name: TableName
self,
target_table_name: TableName,
source_table_name: TableName,
exists: bool,
**kwargs: t.Any,
) -> None:
"""Create table with identical structure as source table"""
self.execute(
Expand Down Expand Up @@ -632,16 +636,15 @@ def _drop_object(
kind: What kind of object to drop. Defaults to TABLE
**drop_args: Any extra arguments to set on the Drop expression
"""
self.execute(
exp.Drop(
this=exp.to_table(name),
kind=kind,
exists=exists,
cluster=exp.OnCluster(this=exp.to_identifier(self.cluster))
if self.engine_run_mode.is_cluster
else None,
**drop_args,
)
super()._drop_object(
name=name,
exists=exists,
kind=kind,
cascade=cascade,
cluster=exp.OnCluster(this=exp.to_identifier(self.cluster))
if self.engine_run_mode.is_cluster
else None,
**drop_args,
)

def _build_partitioned_by_exp(
Expand Down
9 changes: 9 additions & 0 deletions sqlmesh/core/engine_adapter/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import typing as t
import logging

from sqlglot import exp

Expand All @@ -13,6 +14,7 @@
InsertOverwriteStrategy,
MERGE_SOURCE_ALIAS,
MERGE_TARGET_ALIAS,
_get_data_object_cache_key,
)
from sqlmesh.core.engine_adapter.mixins import (
GetCurrentCatalogFromFunctionMixin,
Expand All @@ -36,6 +38,9 @@
from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF


logger = logging.getLogger(__name__)


@set_catalog()
class MSSQLEngineAdapter(
EngineAdapterWithIndexSupport,
Expand Down Expand Up @@ -144,6 +149,10 @@ def build_var_length_col(
def table_exists(self, table_name: TableName) -> bool:
"""MsSql doesn't support describe so we query information_schema."""
table = exp.to_table(table_name)
data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
if data_object_cache_key in self._data_object_cache:
logger.debug("Table existence cache hit: %s", data_object_cache_key)
return self._data_object_cache[data_object_cache_key] is not None

sql = (
exp.select("1")
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/engine_adapter/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ def _create_column_comments(
exc_info=True,
)

def create_table_like(
def _create_table_like(
self,
target_table_name: TableName,
source_table_name: TableName,
exists: bool = True,
exists: bool,
**kwargs: t.Any,
) -> None:
self.execute(
Expand Down
Loading