diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index c549c0ae78..c4b7bcbd53 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -148,7 +148,7 @@ jobs: command: ./.circleci/test_migration.sh sushi "--gateway duckdb_persistent" - run: name: Run the migration test - sushi_dbt - command: ./.circleci/test_migration.sh sushi_dbt "--config migration_test_config" + command: ./.circleci/test_migration.sh sushi_dbt "--config migration_test_config" ui_style: docker: diff --git a/.gitignore b/.gitignore index 72b41b5ce1..16593984dd 100644 --- a/.gitignore +++ b/.gitignore @@ -138,6 +138,12 @@ dmypy.json *~ *# +# Vim +*.swp +*.swo +.null-ls* + + *.duckdb *.duckdb.wal @@ -158,3 +164,4 @@ spark-warehouse/ # claude .claude/ + diff --git a/sqlmesh/core/_typing.py b/sqlmesh/core/_typing.py index e495df169e..8e28312c1a 100644 --- a/sqlmesh/core/_typing.py +++ b/sqlmesh/core/_typing.py @@ -11,6 +11,7 @@ SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]] CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]] + if sys.version_info >= (3, 11): from typing import Self as Self else: diff --git a/sqlmesh/core/engine_adapter/_typing.py b/sqlmesh/core/engine_adapter/_typing.py index 98821bb2d4..77bcf2c015 100644 --- a/sqlmesh/core/engine_adapter/_typing.py +++ b/sqlmesh/core/engine_adapter/_typing.py @@ -30,3 +30,5 @@ ] QueryOrDF = t.Union[Query, DF] + GrantsConfig = t.Dict[str, t.List[str]] + DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index d9cc4f44a2..ebbf136cd1 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -63,6 +63,7 @@ from sqlmesh.core.engine_adapter._typing import ( DF, BigframeSession, + GrantsConfig, PySparkDataFrame, PySparkSession, Query, @@ -114,6 +115,7 @@ class EngineAdapter: SUPPORTS_TUPLE_IN = True HAS_VIEW_BINDING = False SUPPORTS_REPLACE_TABLE = True + SUPPORTS_GRANTS = False DEFAULT_CATALOG_TYPE = DIALECT QUOTE_IDENTIFIERS_IN_VIEWS = True MAX_IDENTIFIER_LENGTH: t.Optional[int] = None @@ -2478,6 +2480,33 @@ def wap_publish(self, table_name: TableName, wap_id: str) -> None: """ raise NotImplementedError(f"Engine does not support WAP: {type(self)}") + def sync_grants_config( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> None: + """Applies the grants_config to a table authoritatively. + It first compares the specified grants against the current grants, and then + applies the diffs to the table by revoking and granting privileges as needed. + + Args: + table: The table/view to apply grants to. + grants_config: Dictionary mapping privileges to lists of grantees. + table_type: The type of database object (TABLE, VIEW, MATERIALIZED_VIEW). + """ + if not self.SUPPORTS_GRANTS: + raise NotImplementedError(f"Engine does not support grants: {type(self)}") + + current_grants = self._get_current_grants_config(table) + new_grants, revoked_grants = self._diff_grants_configs(grants_config, current_grants) + revoke_exprs = self._revoke_grants_config_expr(table, revoked_grants, table_type) + grant_exprs = self._apply_grants_config_expr(table, new_grants, table_type) + dcl_exprs = revoke_exprs + grant_exprs + + if dcl_exprs: + self.execute(dcl_exprs) + @contextlib.contextmanager def transaction( self, @@ -3029,6 +3058,124 @@ def _check_identifier_length(self, expression: exp.Expression) -> None: def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: raise NotImplementedError() + @classmethod + def _diff_grants_configs( + cls, new_config: GrantsConfig, old_config: GrantsConfig + ) -> t.Tuple[GrantsConfig, GrantsConfig]: + """Compute additions and removals between two grants configurations. + + This method compares new (desired) and old (current) GrantsConfigs case-insensitively + for both privilege keys and grantees, while preserving original casing + in the output GrantsConfigs. + + Args: + new_config: Desired grants configuration (specified by the user). + old_config: Current grants configuration (returned by the database). + + Returns: + A tuple of (additions, removals) GrantsConfig where: + - additions contains privileges/grantees present in new_config but not in old_config + - additions uses keys and grantee strings from new_config (user-specified casing) + - removals contains privileges/grantees present in old_config but not in new_config + - removals uses keys and grantee strings from old_config (database-returned casing) + + Notes: + - Comparison is case-insensitive using casefold(); original casing is preserved in results. + - Overlapping grantees (case-insensitive) are excluded from the results. + """ + + def _diffs(config1: GrantsConfig, config2: GrantsConfig) -> GrantsConfig: + diffs: GrantsConfig = {} + cf_config2 = {k.casefold(): {g.casefold() for g in v} for k, v in config2.items()} + for key, grantees in config1.items(): + cf_key = key.casefold() + + # Missing key (add all grantees) + if cf_key not in cf_config2: + diffs[key] = grantees.copy() + continue + + # Include only grantees not in config2 + cf_grantees2 = cf_config2[cf_key] + diff_grantees = [] + for grantee in grantees: + if grantee.casefold() not in cf_grantees2: + diff_grantees.append(grantee) + if diff_grantees: + diffs[key] = diff_grantees + return diffs + + return _diffs(new_config, old_config), _diffs(old_config, new_config) + + def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: + """Returns current grants for a table as a dictionary. + + This method queries the database and returns the current grants/permissions + for the given table, parsed into a dictionary format. The it handles + case-insensitive comparison between these current grants and the desired + grants from model configuration. + + Args: + table: The table/view to query grants for. + + Returns: + Dictionary mapping permissions to lists of grantees. Permission names + should be returned as the database provides them (typically uppercase + for standard SQL permissions, but engine-specific roles may vary). + + Raises: + NotImplementedError: If the engine does not support grants. + """ + if not self.SUPPORTS_GRANTS: + raise NotImplementedError(f"Engine does not support grants: {type(self)}") + raise NotImplementedError("Subclass must implement get_current_grants") + + def _apply_grants_config_expr( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + """Returns SQLGlot Grant expressions to apply grants to a table. + + Args: + table: The table/view to grant permissions on. + grants_config: Dictionary mapping permissions to lists of grantees. + table_type: The type of database object (TABLE, VIEW, MATERIALIZED_VIEW). + + Returns: + List of SQLGlot expressions for grant operations. + + Raises: + NotImplementedError: If the engine does not support grants. + """ + if not self.SUPPORTS_GRANTS: + raise NotImplementedError(f"Engine does not support grants: {type(self)}") + raise NotImplementedError("Subclass must implement _apply_grants_config_expr") + + def _revoke_grants_config_expr( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + """Returns SQLGlot expressions to revoke grants from a table. + + Args: + table: The table/view to revoke permissions from. + grants_config: Dictionary mapping permissions to lists of grantees. + table_type: The type of database object (TABLE, VIEW, MATERIALIZED_VIEW). + + Returns: + List of SQLGlot expressions for revoke operations. + + Raises: + NotImplementedError: If the engine does not support grants. + """ + if not self.SUPPORTS_GRANTS: + raise NotImplementedError(f"Engine does not support grants: {type(self)}") + raise NotImplementedError("Subclass must implement _revoke_grants_config_expr") + class EngineAdapterWithIndexSupport(EngineAdapter): SUPPORTS_INDEXES = True diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index 3de975d6a5..11f56da133 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -62,6 +62,7 @@ def columns( raise SQLMeshError( f"Could not get columns for table '{table.sql(dialect=self.dialect)}'. Table not found." ) + return { column_name: exp.DataType.build(data_type, dialect=self.dialect, udt=True) for column_name, data_type in resp @@ -196,3 +197,10 @@ def _get_data_objects( ) for row in df.itertuples() ] + + def _get_current_schema(self) -> str: + """Returns the current default schema for the connection.""" + result = self.fetchone(exp.select(exp.func("current_schema"))) + if result and result[0]: + return result[0] + return "public" diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 09fd7537ef..59a56b6ace 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -11,6 +11,7 @@ from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key from sqlmesh.core.engine_adapter.mixins import ( ClusteredByMixin, + GrantsFromInfoSchemaMixin, RowDiffMixin, TableAlterClusterByOperation, ) @@ -40,7 +41,7 @@ from google.cloud.bigquery.table import Table as BigQueryTable from sqlmesh.core._typing import SchemaName, SessionProperties, TableName - from sqlmesh.core.engine_adapter._typing import BigframeSession, DF, Query + from sqlmesh.core.engine_adapter._typing import BigframeSession, DCL, DF, GrantsConfig, Query from sqlmesh.core.engine_adapter.base import QueryOrDF @@ -55,7 +56,7 @@ @set_catalog() -class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin): +class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin): """ BigQuery Engine Adapter using the `google-cloud-bigquery` library's DB API. """ @@ -65,6 +66,11 @@ class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin): SUPPORTS_TRANSACTIONS = False SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_CLONING = True + SUPPORTS_GRANTS = True + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("session_user") + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True + USE_CATALOG_IN_GRANTS = True + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "OBJECT_PRIVILEGES" MAX_TABLE_COMMENT_LENGTH = 1024 MAX_COLUMN_COMMENT_LENGTH = 1024 SUPPORTS_QUERY_EXECUTION_TRACKING = True @@ -1326,6 +1332,108 @@ def _session_id(self) -> t.Any: def _session_id(self, value: t.Any) -> None: self._connection_pool.set_attribute("session_id", value) + def _get_current_schema(self) -> str: + raise NotImplementedError("BigQuery does not support current schema") + + def _get_bq_dataset_location(self, project: str, dataset: str) -> str: + return self._db_call(self.client.get_dataset, dataset_ref=f"{project}.{dataset}").location + + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + if not table.db: + raise ValueError( + f"Table {table.sql(dialect=self.dialect)} does not have a schema (dataset)" + ) + project = table.catalog or self.get_current_catalog() + if not project: + raise ValueError( + f"Table {table.sql(dialect=self.dialect)} does not have a catalog (project)" + ) + + dataset = table.db + table_name = table.name + location = self._get_bq_dataset_location(project, dataset) + + # https://cloud.google.com/bigquery/docs/information-schema-object-privileges + # OBJECT_PRIVILEGES is a project-level INFORMATION_SCHEMA view with regional qualifier + object_privileges_table = exp.to_table( + f"`{project}`.`region-{location}`.INFORMATION_SCHEMA.{self.GRANT_INFORMATION_SCHEMA_TABLE_NAME}", + dialect=self.dialect, + ) + return ( + exp.select("privilege_type", "grantee") + .from_(object_privileges_table) + .where( + exp.and_( + exp.column("object_schema").eq(exp.Literal.string(dataset)), + exp.column("object_name").eq(exp.Literal.string(table_name)), + # Filter out current_user + # BigQuery grantees format: "user:email" or "group:name" + exp.func("split", exp.column("grantee"), exp.Literal.string(":"))[ + exp.func("OFFSET", exp.Literal.number("1")) + ].neq(self.CURRENT_USER_OR_ROLE_EXPRESSION), + ) + ) + ) + + @staticmethod + def _grant_object_kind(table_type: DataObjectType) -> str: + if table_type == DataObjectType.VIEW: + return "VIEW" + if table_type == DataObjectType.MATERIALIZED_VIEW: + # We actually need to use "MATERIALIZED VIEW" here even though it's not listed + # as a supported resource_type in the BigQuery DCL doc: + # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-control-language + return "MATERIALIZED VIEW" + return "TABLE" + + def _dcl_grants_config_expr( + self, + dcl_cmd: t.Type[DCL], + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + expressions: t.List[exp.Expression] = [] + if not grants_config: + return expressions + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-control-language + + def normalize_principal(p: str) -> str: + if ":" not in p: + raise ValueError(f"Principal '{p}' missing a prefix label") + + # allUsers and allAuthenticatedUsers special groups that are cas-sensitive and must start with "specialGroup:" + if p.endswith("allUsers") or p.endswith("allAuthenticatedUsers"): + if not p.startswith("specialGroup:"): + raise ValueError( + f"Special group principal '{p}' must start with 'specialGroup:' prefix label" + ) + return p + + label, principal = p.split(":", 1) + # always lowercase principals + return f"{label}:{principal.lower()}" + + object_kind = self._grant_object_kind(table_type) + for privilege, principals in grants_config.items(): + if not principals: + continue + + noramlized_principals = [exp.Literal.string(normalize_principal(p)) for p in principals] + args: t.Dict[str, t.Any] = { + "privileges": [exp.GrantPrivilege(this=exp.to_identifier(privilege, quoted=True))], + "securable": table.copy(), + "principals": noramlized_principals, + } + + if object_kind: + args["kind"] = exp.Var(this=object_kind) + + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + + return expressions + class _ErrorCounter: """ diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 173e1b08af..7521124684 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -5,7 +5,9 @@ from functools import partial from sqlglot import exp + from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.mixins import GrantsFromInfoSchemaMixin from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, DataObject, @@ -28,12 +30,14 @@ logger = logging.getLogger(__name__) -class DatabricksEngineAdapter(SparkEngineAdapter): +class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin): DIALECT = "databricks" INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE SUPPORTS_CLONING = True SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True + SUPPORTS_GRANTS = True + USE_CATALOG_IN_GRANTS = True # Spark has this set to false for compatibility when mixing with Trino but that isn't a concern with Databricks QUOTE_IDENTIFIERS_IN_VIEWS = True SCHEMA_DIFFER_KWARGS = { @@ -151,6 +155,28 @@ def spark(self) -> PySparkSession: def catalog_support(self) -> CatalogSupport: return CatalogSupport.FULL_SUPPORT + @staticmethod + def _grant_object_kind(table_type: DataObjectType) -> str: + if table_type == DataObjectType.VIEW: + return "VIEW" + if table_type == DataObjectType.MATERIALIZED_VIEW: + return "MATERIALIZED VIEW" + return "TABLE" + + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + # We only care about explicitly granted privileges and not inherited ones + # if this is removed you would see grants inherited from the catalog get returned + expression = super()._get_grant_expression(table) + expression.args["where"].set( + "this", + exp.and_( + expression.args["where"].this, + exp.column("inherited_from").eq(exp.Literal.string("NONE")), + wrap=False, + ), + ) + return expression + def _begin_session(self, properties: SessionProperties) -> t.Any: """Begin a new session.""" # Align the different possible connectors to a single catalog diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 1d66da0607..c8ef32b9da 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -7,8 +7,10 @@ from sqlglot import exp, parse_one from sqlglot.helper import seq_get +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core.engine_adapter.base import EngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObjectType from sqlmesh.core.node import IntervalUnit from sqlmesh.core.dialect import schema_ from sqlmesh.core.schema_diff import TableAlterOperation @@ -16,7 +18,12 @@ if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName - from sqlmesh.core.engine_adapter._typing import DF + from sqlmesh.core.engine_adapter._typing import ( + DCL, + DF, + GrantsConfig, + QueryOrDF, + ) from sqlmesh.core.engine_adapter.base import QueryOrDF logger = logging.getLogger(__name__) @@ -548,3 +555,137 @@ def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp. def _normalize_boolean_value(self, expr: exp.Expression) -> exp.Expression: return exp.cast(expr, "INT") + + +class GrantsFromInfoSchemaMixin(EngineAdapter): + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("current_user") + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = False + USE_CATALOG_IN_GRANTS = False + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "table_privileges" + + @staticmethod + @abc.abstractmethod + def _grant_object_kind(table_type: DataObjectType) -> t.Optional[str]: + pass + + @abc.abstractmethod + def _get_current_schema(self) -> str: + pass + + def _dcl_grants_config_expr( + self, + dcl_cmd: t.Type[DCL], + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + expressions: t.List[exp.Expression] = [] + if not grants_config: + return expressions + + object_kind = self._grant_object_kind(table_type) + for privilege, principals in grants_config.items(): + args: t.Dict[str, t.Any] = { + "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], + "securable": table.copy(), + } + if object_kind: + args["kind"] = exp.Var(this=object_kind) + if self.SUPPORTS_MULTIPLE_GRANT_PRINCIPALS: + args["principals"] = [ + normalize_identifiers( + parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect), + dialect=self.dialect, + ) + for principal in principals + ] + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + else: + for principal in principals: + args["principals"] = [ + normalize_identifiers( + parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect), + dialect=self.dialect, + ) + ] + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + + return expressions + + def _apply_grants_config_expr( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Grant, table, grants_config, table_type) + + def _revoke_grants_config_expr( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Revoke, table, grants_config, table_type) + + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + schema_identifier = table.args.get("db") or normalize_identifiers( + exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect + ) + schema_name = schema_identifier.this + table_name = table.args.get("this").this # type: ignore + + grant_conditions = [ + exp.column("table_schema").eq(exp.Literal.string(schema_name)), + exp.column("table_name").eq(exp.Literal.string(table_name)), + exp.column("grantor").eq(self.CURRENT_USER_OR_ROLE_EXPRESSION), + exp.column("grantee").neq(self.CURRENT_USER_OR_ROLE_EXPRESSION), + ] + + info_schema_table = normalize_identifiers( + exp.table_(self.GRANT_INFORMATION_SCHEMA_TABLE_NAME, db="information_schema"), + dialect=self.dialect, + ) + if self.USE_CATALOG_IN_GRANTS: + catalog_identifier = table.args.get("catalog") + if not catalog_identifier: + catalog_name = self.get_current_catalog() + if not catalog_name: + raise SQLMeshError( + "Current catalog could not be determined for fetching grants. This is unexpected." + ) + catalog_identifier = normalize_identifiers( + exp.to_identifier(catalog_name, quoted=True), dialect=self.dialect + ) + catalog_name = catalog_identifier.this + info_schema_table.set("catalog", catalog_identifier.copy()) + grant_conditions.insert( + 0, exp.column("table_catalog").eq(exp.Literal.string(catalog_name)) + ) + + return ( + exp.select("privilege_type", "grantee") + .from_(info_schema_table) + .where(exp.and_(*grant_conditions)) + ) + + def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: + grant_expr = self._get_grant_expression(table) + + results = self.fetchall(grant_expr) + + grants_dict: GrantsConfig = {} + for privilege_raw, grantee_raw in results: + if privilege_raw is None or grantee_raw is None: + continue + + privilege = str(privilege_raw) + grantee = str(grantee_raw) + if not privilege or not grantee: + continue + + grantees = grants_dict.setdefault(privilege, []) + if grantee not in grantees: + grantees.append(grantee) + + return grants_dict diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index 79431ee360..3dd108cf91 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -12,6 +12,7 @@ PandasNativeFetchDFSupportMixin, RowDiffMixin, logical_merge, + GrantsFromInfoSchemaMixin, ) from sqlmesh.core.engine_adapter.shared import set_catalog @@ -28,14 +29,19 @@ class PostgresEngineAdapter( PandasNativeFetchDFSupportMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin, + GrantsFromInfoSchemaMixin, ): DIALECT = "postgres" + SUPPORTS_GRANTS = True SUPPORTS_INDEXES = True HAS_VIEW_BINDING = True CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") SUPPORTS_REPLACE_TABLE = False MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63 SUPPORTS_QUERY_EXECUTION_TRACKING = True + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "role_table_grants" + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.column("current_role") + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { # DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point" diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 7979268473..03dc89053e 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -14,6 +14,7 @@ VarcharSizeWorkaroundMixin, RowDiffMixin, logical_merge, + GrantsFromInfoSchemaMixin, ) from sqlmesh.core.engine_adapter.shared import ( CommentCreationView, @@ -40,12 +41,15 @@ class RedshiftEngineAdapter( NonTransactionalTruncateMixin, VarcharSizeWorkaroundMixin, RowDiffMixin, + GrantsFromInfoSchemaMixin, ): DIALECT = "redshift" CURRENT_CATALOG_EXPRESSION = exp.func("current_database") # Redshift doesn't support comments for VIEWs WITH NO SCHEMA BINDING (which we always use) COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED SUPPORTS_REPLACE_TABLE = False + SUPPORTS_GRANTS = True + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { diff --git a/sqlmesh/core/engine_adapter/risingwave.py b/sqlmesh/core/engine_adapter/risingwave.py index fdcee90f0f..61b44f5bbb 100644 --- a/sqlmesh/core/engine_adapter/risingwave.py +++ b/sqlmesh/core/engine_adapter/risingwave.py @@ -32,6 +32,7 @@ class RisingwaveEngineAdapter(PostgresEngineAdapter): SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_TRANSACTIONS = False MAX_IDENTIFIER_LENGTH = None + SUPPORTS_GRANTS = False def columns( self, table_name: TableName, include_pseudo_columns: bool = False diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 1554589779..a8eabe070d 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -15,6 +15,7 @@ GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin, + GrantsFromInfoSchemaMixin, ) from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -34,7 +35,12 @@ import pandas as pd from sqlmesh.core._typing import SchemaName, SessionProperties, TableName - from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF, SnowparkSession + from sqlmesh.core.engine_adapter._typing import ( + DF, + Query, + QueryOrDF, + SnowparkSession, + ) from sqlmesh.core.node import IntervalUnit @@ -46,7 +52,9 @@ "drop_catalog": CatalogSupport.REQUIRES_SET_CATALOG, # needs a catalog to issue a query to information_schema.databases even though the result is global } ) -class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin): +class SnowflakeEngineAdapter( + GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin +): DIALECT = "snowflake" SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True @@ -74,6 +82,9 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi MANAGED_TABLE_KIND = "DYNAMIC TABLE" SNOWPARK = "snowpark" SUPPORTS_QUERY_EXECUTION_TRACKING = True + SUPPORTS_GRANTS = True + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("CURRENT_ROLE") + USE_CATALOG_IN_GRANTS = True @contextlib.contextmanager def session(self, properties: SessionProperties) -> t.Iterator[None]: @@ -128,6 +139,23 @@ def snowpark(self) -> t.Optional[SnowparkSession]: def catalog_support(self) -> CatalogSupport: return CatalogSupport.FULL_SUPPORT + @staticmethod + def _grant_object_kind(table_type: DataObjectType) -> str: + if table_type == DataObjectType.VIEW: + return "VIEW" + if table_type == DataObjectType.MATERIALIZED_VIEW: + return "MATERIALIZED VIEW" + if table_type == DataObjectType.MANAGED_TABLE: + return "DYNAMIC TABLE" + return "TABLE" + + def _get_current_schema(self) -> str: + """Returns the current default schema for the connection.""" + result = self.fetchone("SELECT CURRENT_SCHEMA()") + if not result or not result[0]: + raise SQLMeshError("Unable to determine current schema") + return str(result[0]) + def _create_catalog(self, catalog_name: exp.Identifier) -> None: props = exp.Properties( expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))] @@ -533,13 +561,32 @@ def _get_data_objects( for row in df.rename(columns={col: col.lower() for col in df.columns}).itertuples() ] + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + # Upon execute the catalog in table expressions are properly normalized to handle the case where a user provides + # the default catalog in their connection config. This doesn't though update catalogs in strings like when querying + # the information schema. So we need to manually replace those here. + expression = super()._get_grant_expression(table) + for col_exp in expression.find_all(exp.Column): + if col_exp.this.name == "table_catalog": + and_exp = col_exp.parent + assert and_exp is not None, "Expected column expression to have a parent" + assert and_exp.expression, "Expected AND expression to have an expression" + normalized_catalog = self._normalize_catalog( + exp.table_("placeholder", db="placeholder", catalog=and_exp.expression.this) + ) + and_exp.set( + "expression", + exp.Literal.string(normalized_catalog.args["catalog"].alias_or_name), + ) + return expression + def set_current_catalog(self, catalog: str) -> None: self.execute(exp.Use(this=exp.to_identifier(catalog))) def set_current_schema(self, schema: str) -> None: self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema))) - def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + def _normalize_catalog(self, expression: exp.Expression) -> exp.Expression: # note: important to use self._default_catalog instead of the self.default_catalog property # otherwise we get RecursionError: maximum recursion depth exceeded # because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... etc @@ -572,8 +619,12 @@ def catalog_rewriter(node: exp.Expression) -> exp.Expression: # Snowflake connection config. This is because the catalog present on the model gets normalized and quoted to match # the source dialect, which isnt always compatible with Snowflake expression = expression.transform(catalog_rewriter) + return expression - return super()._to_sql(expression=expression, quote=quote, **kwargs) + def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + return super()._to_sql( + expression=self._normalize_catalog(expression), quote=quote, **kwargs + ) def _create_column_comments( self, diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index b2d6a9cbb5..5216b0a329 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -397,7 +397,7 @@ def get_current_catalog(self) -> t.Optional[str]: def set_current_catalog(self, catalog_name: str) -> None: self.connection.set_current_catalog(catalog_name) - def get_current_database(self) -> str: + def _get_current_schema(self) -> str: if self._use_spark_session: return self.spark.catalog.currentDatabase() return self.fetchone(exp.select(exp.func("current_database")))[0] # type: ignore @@ -539,7 +539,7 @@ def _ensure_fqn(self, table_name: TableName) -> exp.Table: if not table.catalog: table.set("catalog", self.get_current_catalog()) if not table.db: - table.set("db", self.get_current_database()) + table.set("db", self._get_current_schema()) return table def _build_create_comment_column_exp( diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 0a55f80cee..d2b9a11c08 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -641,6 +641,7 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any: "physical_properties_", "virtual_properties_", "materialization_properties_", + "grants_", mode="before", check_fields=False, )(parse_properties) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 974901cb55..f81dae004b 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -67,6 +67,7 @@ from sqlmesh.core.context import ExecutionContext from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter._typing import QueryOrDF + from sqlmesh.core.engine_adapter.shared import DataObjectType from sqlmesh.core.linter.rule import Rule from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot from sqlmesh.utils.jinja import MacroReference @@ -1186,6 +1187,8 @@ def metadata_hash(self) -> str: gen(self.session_properties_) if self.session_properties_ else None, *[gen(g) for g in self.grains], *self._audit_metadata_hash_values(), + json.dumps(self.grants, sort_keys=True) if self.grants else None, + self.grants_target_layer, ] for key, value in (self.virtual_properties or {}).items(): @@ -1210,6 +1213,24 @@ def is_model(self) -> bool: """Return True if this is a model node""" return True + @property + def grants_table_type(self) -> DataObjectType: + """Get the table type for grants application (TABLE, VIEW, MATERIALIZED_VIEW). + + Returns: + The DataObjectType that should be used when applying grants to this model. + """ + from sqlmesh.core.engine_adapter.shared import DataObjectType + + if self.kind.is_view: + if hasattr(self.kind, "materialized") and getattr(self.kind, "materialized", False): + return DataObjectType.MATERIALIZED_VIEW + return DataObjectType.VIEW + if self.kind.is_managed: + return DataObjectType.MANAGED_TABLE + # All other materialized models are tables + return DataObjectType.TABLE + @property def _additional_metadata(self) -> t.List[str]: additional_metadata = [] @@ -1823,6 +1844,12 @@ def _data_hash_values_no_sql(self) -> t.List[str]: for column_name, column_hash in self.column_hashes.items(): data.append(column_name) data.append(column_hash) + + # Include grants in data hash for seed models to force recreation on grant changes + # since seed models don't support migration + data.append(json.dumps(self.grants, sort_keys=True) if self.grants else "") + data.append(self.grants_target_layer) + return data @@ -3023,6 +3050,8 @@ def render_expression( "optimize_query": str, "virtual_environment_mode": lambda value: exp.Literal.string(value.value), "dbt_node_info_": lambda value: value.to_expression(), + "grants_": lambda value: value, + "grants_target_layer": lambda value: exp.Literal.string(value.value), } diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 7b8e88ac17..cc4c6f0826 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -154,6 +154,11 @@ def full_history_restatement_only(self) -> bool: def supports_python_models(self) -> bool: return True + @property + def supports_grants(self) -> bool: + """Whether this model kind supports grants configuration.""" + return self.is_materialized or self.is_view + class ModelKindName(str, ModelKindMixin, Enum): """The kind of model, determining how this data is computed and stored in the warehouse.""" diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index 9208fbdbb5..c48b7d1524 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +from enum import Enum from functools import cached_property from typing_extensions import Self @@ -13,6 +14,7 @@ from sqlmesh.core.config.common import VirtualEnvironmentMode from sqlmesh.core.config.linter import LinterConfig from sqlmesh.core.dialect import normalize_model_name +from sqlmesh.utils import classproperty from sqlmesh.core.model.common import ( bool_validator, default_catalog_validator, @@ -46,10 +48,41 @@ if t.TYPE_CHECKING: from sqlmesh.core._typing import CustomMaterializationProperties, SessionProperties + from sqlmesh.core.engine_adapter._typing import GrantsConfig FunctionCall = t.Tuple[str, t.Dict[str, exp.Expression]] +class GrantsTargetLayer(str, Enum): + """Target layer(s) where grants should be applied.""" + + ALL = "all" + PHYSICAL = "physical" + VIRTUAL = "virtual" + + @classproperty + def default(cls) -> "GrantsTargetLayer": + return GrantsTargetLayer.VIRTUAL + + @property + def is_all(self) -> bool: + return self == GrantsTargetLayer.ALL + + @property + def is_physical(self) -> bool: + return self == GrantsTargetLayer.PHYSICAL + + @property + def is_virtual(self) -> bool: + return self == GrantsTargetLayer.VIRTUAL + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return str(self) + + class ModelMeta(_Node): """Metadata for models which can be defined in SQL.""" @@ -85,6 +118,8 @@ class ModelMeta(_Node): ) formatting: t.Optional[bool] = Field(default=None, exclude=True) virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default + grants_: t.Optional[exp.Tuple] = Field(default=None, alias="grants") + grants_target_layer: GrantsTargetLayer = GrantsTargetLayer.default _bool_validator = bool_validator _model_kind_validator = model_kind_validator @@ -287,6 +322,14 @@ def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expressi def ignored_rules_validator(cls, vs: t.Any) -> t.Any: return LinterConfig._validate_rules(vs) + @field_validator("grants_target_layer", mode="before") + def _grants_target_layer_validator(cls, v: t.Any) -> t.Any: + if isinstance(v, exp.Identifier): + return v.this + if isinstance(v, exp.Literal) and v.is_string: + return v.this + return v + @field_validator("session_properties_", mode="before") def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any: # use the generic properties validator to parse the session properties @@ -394,6 +437,10 @@ def _root_validator(self) -> Self: f"Model {self.name} has `storage_format` set to a table format '{storage_format}' which is deprecated. Please use the `table_format` property instead." ) + # Validate grants configuration for model kind support + if self.grants is not None and not kind.supports_grants: + raise ValueError(f"grants cannot be set for {kind.name} models") + return self @property @@ -465,6 +512,30 @@ def custom_materialization_properties(self) -> CustomMaterializationProperties: return self.kind.materialization_properties return {} + @cached_property + def grants(self) -> t.Optional[GrantsConfig]: + """A dictionary of grants mapping permission names to lists of grantees.""" + + if self.grants_ is None: + return None + + if not self.grants_.expressions: + return {} + + grants_dict = {} + for eq_expr in self.grants_.expressions: + try: + permission_name = self._validate_config_expression(eq_expr.left) + grantee_list = self._validate_nested_config_values(eq_expr.expression) + grants_dict[permission_name] = grantee_list + except ConfigError as e: + permission_name = ( + eq_expr.left.name if hasattr(eq_expr.left, "name") else str(eq_expr.left) + ) + raise ConfigError(f"Invalid grants configuration for '{permission_name}': {e}") + + return grants_dict if grants_dict else None + @property def all_references(self) -> t.List[Reference]: """All references including grains.""" @@ -529,3 +600,33 @@ def on_additive_change(self) -> OnAdditiveChange: @property def ignored_rules(self) -> t.Set[str]: return self.ignored_rules_ or set() + + def _validate_config_expression(self, expr: exp.Expression) -> str: + if isinstance(expr, (d.MacroFunc, d.MacroVar)): + raise ConfigError(f"Unresolved macro: {expr.sql(dialect=self.dialect)}") + + if isinstance(expr, exp.Null): + raise ConfigError("NULL value") + + if isinstance(expr, exp.Literal): + return str(expr.this).strip() + if isinstance(expr, (exp.Column, exp.Identifier)): + return expr.name + return expr.sql(dialect=self.dialect).strip() + + def _validate_nested_config_values(self, value_expr: exp.Expression) -> t.List[str]: + result = [] + + def flatten_expr(expr: exp.Expression) -> None: + if isinstance(expr, exp.Array): + for elem in expr.expressions: + flatten_expr(elem) + elif isinstance(expr, (exp.Tuple, exp.Paren)): + expressions = [expr.unnest()] if isinstance(expr, exp.Paren) else expr.expressions + for elem in expressions: + flatten_expr(elem) + else: + result.append(self._validate_config_expression(expr)) + + flatten_expr(value_expr) + return result diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 1483bdeece..2676709d85 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -39,6 +39,7 @@ from sqlmesh.core.audit import Audit, StandaloneAudit from sqlmesh.core.dialect import schema_ from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType, DataObject +from sqlmesh.core.model.meta import GrantsTargetLayer from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import ( AuditResult, @@ -932,6 +933,7 @@ def _render_and_insert_snapshot( model = snapshot.model adapter = self.get_adapter(model.gateway) evaluation_strategy = _evaluation_strategy(snapshot, adapter) + is_snapshot_deployable = deployability_index.is_deployable(snapshot) queries_or_dfs = self._render_snapshot_for_evaluation( snapshot, @@ -955,6 +957,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: execution_time=execution_time, physical_properties=rendered_physical_properties, render_kwargs=create_render_kwargs, + is_snapshot_deployable=is_snapshot_deployable, ) else: logger.info( @@ -977,6 +980,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: execution_time=execution_time, physical_properties=rendered_physical_properties, render_kwargs=create_render_kwargs, + is_snapshot_deployable=is_snapshot_deployable, ) # DataFrames, unlike SQL expressions, can provide partial results by yielding dataframes. As a result, @@ -1066,6 +1070,7 @@ def _clone_snapshot_in_dev( allow_additive_snapshots=allow_additive_snapshots, run_pre_post_statements=run_pre_post_statements, ) + except Exception: adapter.drop_table(target_table_name) raise @@ -1166,6 +1171,7 @@ def _migrate_target_table( rendered_physical_properties=rendered_physical_properties, dry_run=False, run_pre_post_statements=run_pre_post_statements, + skip_grants=True, # skip grants for tmp table ) try: evaluation_strategy = _evaluation_strategy(snapshot, adapter) @@ -1183,6 +1189,7 @@ def _migrate_target_table( allow_additive_snapshots=allow_additive_snapshots, ignore_destructive=snapshot.model.on_destructive_change.is_ignore, ignore_additive=snapshot.model.on_additive_change.is_ignore, + deployability_index=deployability_index, ) finally: if snapshot.is_materialized: @@ -1232,6 +1239,7 @@ def _promote_snapshot( model=snapshot.model, environment=environment_naming_info.name, snapshots=snapshots, + snapshot=snapshot, **render_kwargs, ) @@ -1431,6 +1439,7 @@ def _execute_create( rendered_physical_properties: t.Dict[str, exp.Expression], dry_run: bool, run_pre_post_statements: bool = True, + skip_grants: bool = False, ) -> None: adapter = self.get_adapter(snapshot.model.gateway) evaluation_strategy = _evaluation_strategy(snapshot, adapter) @@ -1451,11 +1460,14 @@ def _execute_create( table_name=table_name, model=snapshot.model, is_table_deployable=is_table_deployable, + skip_grants=skip_grants, render_kwargs=create_render_kwargs, is_snapshot_deployable=is_snapshot_deployable, is_snapshot_representative=is_snapshot_representative, dry_run=dry_run, physical_properties=rendered_physical_properties, + snapshot=snapshot, + deployability_index=deployability_index, ) if run_pre_post_statements: evaluation_strategy.run_post_statements( @@ -1469,7 +1481,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex and snapshot.is_materialized and bool(snapshot.previous_versions) and adapter.SUPPORTS_CLONING - # managed models cannot have their schema mutated because theyre based on queries, so clone + alter wont work + # managed models cannot have their schema mutated because they're based on queries, so clone + alter won't work and not snapshot.is_managed and not snapshot.is_dbt_custom and not deployability_index.is_deployable(snapshot) @@ -1690,6 +1702,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: """Creates the target table or view. @@ -1780,6 +1793,66 @@ def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: render_kwargs: Additional key-value arguments to pass when rendering the statements. """ + def _apply_grants( + self, + model: Model, + table_name: str, + target_layer: GrantsTargetLayer, + is_snapshot_deployable: bool = False, + ) -> None: + """Apply grants for a model if grants are configured. + + This method provides consistent grants application across all evaluation strategies. + It ensures that whenever a physical database object (table, view, materialized view) + is created or modified, the appropriate grants are applied. + + Args: + model: The SQLMesh model containing grants configuration + table_name: The target table/view name to apply grants to + target_layer: The grants application layer (physical or virtual) + is_snapshot_deployable: Whether the snapshot is deployable (targeting production) + """ + grants_config = model.grants + if grants_config is None: + return + + if not self.adapter.SUPPORTS_GRANTS: + logger.warning( + f"Engine {self.adapter.__class__.__name__} does not support grants. " + f"Skipping grants application for model {model.name}" + ) + return + + model_grants_target_layer = model.grants_target_layer + deployable_vde_dev_only = ( + is_snapshot_deployable and model.virtual_environment_mode.is_dev_only + ) + + # table_type is always a VIEW in the virtual layer unless model is deployable and VDE is dev_only + # in which case we fall back to the model's model_grants_table_type + if target_layer == GrantsTargetLayer.VIRTUAL and not deployable_vde_dev_only: + model_grants_table_type = DataObjectType.VIEW + else: + model_grants_table_type = model.grants_table_type + + if ( + model_grants_target_layer.is_all + or model_grants_target_layer == target_layer + # Always apply grants in production when VDE is dev_only regardless of target_layer + # since only physical tables are created in production + or deployable_vde_dev_only + ): + logger.info(f"Applying grants for model {model.name} to table {table_name}") + self.adapter.sync_grants_config( + exp.to_table(table_name, dialect=self.adapter.dialect), + grants_config, + model_grants_table_type, + ) + else: + logger.debug( + f"Skipping grants application for model {model.name} in {target_layer} layer" + ) + class SymbolicStrategy(EvaluationStrategy): def insert( @@ -1809,6 +1882,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: pass @@ -1890,6 +1964,17 @@ def promote( view_properties=model.render_virtual_properties(**render_kwargs), ) + snapshot = kwargs.get("snapshot") + deployability_index = kwargs.get("deployability_index") + is_snapshot_deployable = ( + deployability_index.is_deployable(snapshot) + if snapshot and deployability_index + else False + ) + + # Apply grants to the virtual layer (view) after promotion + self._apply_grants(model, view_name, GrantsTargetLayer.VIRTUAL, is_snapshot_deployable) + def demote(self, view_name: str, **kwargs: t.Any) -> None: logger.info("Dropping view '%s'", view_name) self.adapter.drop_view(view_name, cascade=False) @@ -1908,6 +1993,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: ctas_query = model.ctas_query(**render_kwargs) @@ -1952,6 +2038,13 @@ def create( column_descriptions=model.column_descriptions if is_table_deployable else None, ) + # Apply grants after table creation (unless explicitly skipped by caller) + if not skip_grants: + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def migrate( self, target_table_name: str, @@ -1977,6 +2070,15 @@ def migrate( ) self.adapter.alter_table(alter_operations) + # Apply grants after schema migration + deployability_index = kwargs.get("deployability_index") + is_snapshot_deployable = ( + deployability_index.is_deployable(snapshot) if deployability_index else False + ) + self._apply_grants( + snapshot.model, target_table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def delete(self, name: str, **kwargs: t.Any) -> None: _check_table_db_is_physical_schema(name, kwargs["physical_schema"]) self.adapter.drop_table(name, cascade=kwargs.pop("cascade", False)) @@ -1988,6 +2090,7 @@ def _replace_query_for_model( name: str, query_or_df: QueryOrDF, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool = False, **kwargs: t.Any, ) -> None: """Replaces the table for the given model. @@ -2024,6 +2127,11 @@ def _replace_query_for_model( source_columns=source_columns, ) + # Apply grants after table replacement (unless explicitly skipped by caller) + if not skip_grants: + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants(model, name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable) + def _get_target_and_source_columns( self, model: Model, @@ -2271,6 +2379,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: model = t.cast(SeedModel, model) @@ -2284,16 +2393,37 @@ def create( ) return - super().create(table_name, model, is_table_deployable, render_kwargs, **kwargs) + super().create( + table_name, + model, + is_table_deployable, + render_kwargs, + skip_grants=True, # Skip grants; they're applied after data insertion + **kwargs, + ) # For seeds we insert data at the time of table creation. try: for index, df in enumerate(model.render_seed()): if index == 0: - self._replace_query_for_model(model, table_name, df, render_kwargs, **kwargs) + self._replace_query_for_model( + model, + table_name, + df, + render_kwargs, + skip_grants=True, # Skip grants; they're applied after data insertion + **kwargs, + ) else: self.adapter.insert_append( table_name, df, target_columns_to_types=model.columns_to_types ) + + if not skip_grants: + # Apply grants after seed table creation and data insertion + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) except Exception: self.adapter.drop_table(table_name) raise @@ -2341,6 +2471,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: assert isinstance(model.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind)) @@ -2370,9 +2501,17 @@ def create( model, is_table_deployable, render_kwargs, + skip_grants, **kwargs, ) + if not skip_grants: + # Apply grants after SCD Type 2 table creation + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def insert( self, table_name: str, @@ -2440,6 +2579,10 @@ def insert( f"Unexpected SCD Type 2 kind: {model.kind}. This is not expected and please report this as a bug." ) + # Apply grants after SCD Type 2 table recreation + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants(model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable) + def append( self, table_name: str, @@ -2496,6 +2639,10 @@ def insert( column_descriptions=model.column_descriptions, ) + # Apply grants after view creation / replacement + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants(model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable) + def append( self, table_name: str, @@ -2512,12 +2659,21 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + if self.adapter.table_exists(table_name): # Make sure we don't recreate the view to prevent deletion of downstream views in engines with no late # binding support (because of DROP CASCADE). logger.info("View '%s' already exists", table_name) + + if not skip_grants: + # Always apply grants when present, even if view exists, to handle grants updates + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) return logger.info("Creating view '%s'", table_name) @@ -2541,6 +2697,12 @@ def create( column_descriptions=model.column_descriptions if is_table_deployable else None, ) + if not skip_grants: + # Apply grants after view creation + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def migrate( self, target_table_name: str, @@ -2567,6 +2729,15 @@ def migrate( column_descriptions=model.column_descriptions, ) + # Apply grants after view migration + deployability_index = kwargs.get("deployability_index") + is_snapshot_deployable = ( + deployability_index.is_deployable(snapshot) if deployability_index else False + ) + self._apply_grants( + snapshot.model, target_table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def delete(self, name: str, **kwargs: t.Any) -> None: cascade = kwargs.pop("cascade", False) try: @@ -2723,6 +2894,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: original_query = model.render_query_or_raise(**render_kwargs) @@ -2852,6 +3024,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: is_snapshot_deployable: bool = kwargs["is_snapshot_deployable"] @@ -2870,6 +3043,13 @@ def create( column_descriptions=model.column_descriptions, table_format=model.table_format, ) + + # Apply grants after managed table creation + if not skip_grants: + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + elif not is_table_deployable: # Only create the dev preview table as a normal table. # For the main table, if the snapshot is cant be deployed to prod (eg upstream is forward-only) do nothing. @@ -2880,6 +3060,7 @@ def create( model=model, is_table_deployable=is_table_deployable, render_kwargs=render_kwargs, + skip_grants=skip_grants, **kwargs, ) @@ -2895,7 +3076,6 @@ def insert( deployability_index: DeployabilityIndex = kwargs["deployability_index"] snapshot: Snapshot = kwargs["snapshot"] is_snapshot_deployable = deployability_index.is_deployable(snapshot) - if is_first_insert and is_snapshot_deployable and not self.adapter.table_exists(table_name): self.adapter.create_managed_table( table_name=table_name, @@ -2908,6 +3088,9 @@ def insert( column_descriptions=model.column_descriptions, table_format=model.table_format, ) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) elif not is_snapshot_deployable: # Snapshot isnt deployable; update the preview table instead # If the snapshot was deployable, then data would have already been loaded in create() because a managed table would have been created @@ -2956,6 +3139,15 @@ def migrate( f"The schema of the managed model '{target_table_name}' cannot be updated in a forward-only fashion." ) + # Apply grants after verifying no schema changes + deployability_index = kwargs.get("deployability_index") + is_snapshot_deployable = ( + deployability_index.is_deployable(snapshot) if deployability_index else False + ) + self._apply_grants( + snapshot.model, target_table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def delete(self, name: str, **kwargs: t.Any) -> None: # a dev preview table is created as a normal table, so it needs to be dropped as a normal table _check_table_db_is_physical_schema(name, kwargs["physical_schema"]) diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 0b75955129..3e325f13e6 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -165,7 +165,11 @@ def _validate_hooks(cls, v: t.Union[str, t.List[t.Union[SqlStr, str]]]) -> t.Lis @field_validator("grants", mode="before") @classmethod - def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]: + def _validate_grants( + cls, v: t.Optional[t.Dict[str, str]] + ) -> t.Optional[t.Dict[str, t.List[str]]]: + if v is None: + return None return {key: ensure_list(value) for key, value in v.items()} _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = { diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index f47283d06e..f21eefe95d 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -679,6 +679,12 @@ def to_sqlmesh( if physical_properties: model_kwargs["physical_properties"] = physical_properties + kind = self.model_kind(context) + + # A falsy grants config (None or {}) is considered as unmanaged per dbt semantics + if self.grants and kind.supports_grants: + model_kwargs["grants"] = self.grants + allow_partials = model_kwargs.pop("allow_partials", None) if allow_partials is None: # Set allow_partials to True for dbt models to preserve the original semantics. diff --git a/sqlmesh/migrations/v0100_add_grants_and_grants_target_layer.py b/sqlmesh/migrations/v0100_add_grants_and_grants_target_layer.py new file mode 100644 index 0000000000..fa23935da0 --- /dev/null +++ b/sqlmesh/migrations/v0100_add_grants_and_grants_target_layer.py @@ -0,0 +1,9 @@ +"""Add grants and grants_target_layer to incremental model metadata hash.""" + + +def migrate_schemas(state_sync, **kwargs): # type: ignore + pass + + +def migrate_rows(state_sync, **kwargs): # type: ignore + pass diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py index c5377e309a..49624154e4 100644 --- a/tests/core/engine_adapter/integration/__init__.py +++ b/tests/core/engine_adapter/integration/__init__.py @@ -5,10 +5,12 @@ import sys import typing as t import time +from contextlib import contextmanager import pandas as pd # noqa: TID253 import pytest from sqlglot import exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh import Config, Context, EngineAdapter from sqlmesh.core.config import load_config_from_paths @@ -744,6 +746,106 @@ def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]: self._context.upsert_model(model) return self._context, model + def _get_create_user_or_role( + self, username: str, password: t.Optional[str] = None + ) -> t.Tuple[str, t.Optional[str]]: + password = password or random_id() + if self.dialect in ["postgres", "redshift"]: + return username, f"CREATE USER \"{username}\" WITH PASSWORD '{password}'" + if self.dialect == "snowflake": + return username, f"CREATE ROLE {username}" + if self.dialect == "databricks": + # Creating an account-level group in Databricks requires making REST API calls so we are going to + # use a pre-created group instead. We assume the suffix on the name is the unique id + return "_".join(username.split("_")[:-1]), None + if self.dialect == "bigquery": + # BigQuery uses IAM service accounts that need to be pre-created + # Pre-created GCP service accounts: + # - sqlmesh-test-admin@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-analyst@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-etl-user@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-reader@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-user@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-writer@{project-id}.iam.gserviceaccount.com + role_name = ( + username.replace(f"_{self.test_id}", "").replace("test_", "").replace("_", "-") + ) + project_id = self.engine_adapter.get_current_catalog() + service_account = f"sqlmesh-test-{role_name}@{project_id}.iam.gserviceaccount.com" + return f"serviceAccount:{service_account}", None + raise ValueError(f"User creation not supported for dialect: {self.dialect}") + + def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str: + username, create_user_sql = self._get_create_user_or_role(username, password) + if create_user_sql: + self.engine_adapter.execute(create_user_sql) + return username + + @contextmanager + def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]]: + created_users = [] + roles = {} + + try: + for role_name in role_names: + user_name = normalize_identifiers( + self.add_test_suffix(f"test_{role_name}"), dialect=self.dialect + ).sql(dialect=self.dialect) + password = random_id() + if self.dialect == "redshift": + password += ( + "A" # redshift requires passwords to have at least one uppercase letter + ) + user_name = self._create_user_or_role(user_name, password) + created_users.append(user_name) + roles[role_name] = user_name + + yield roles + + finally: + for user_name in created_users: + self._cleanup_user_or_role(user_name) + + def get_select_privilege(self) -> str: + if self.dialect == "bigquery": + return "roles/bigquery.dataViewer" + return "SELECT" + + def get_insert_privilege(self) -> str: + if self.dialect == "databricks": + # This would really be "MODIFY" but for the purposes of having this be unique from UPDATE + # we return "MANAGE" instead + return "MANAGE" + if self.dialect == "bigquery": + return "roles/bigquery.dataEditor" + return "INSERT" + + def get_update_privilege(self) -> str: + if self.dialect == "databricks": + return "MODIFY" + if self.dialect == "bigquery": + return "roles/bigquery.dataOwner" + return "UPDATE" + + def _cleanup_user_or_role(self, user_name: str) -> None: + """Helper function to clean up a user/role and all their dependencies.""" + try: + if self.dialect in ["postgres", "redshift"]: + self.engine_adapter.execute(f""" + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE usename = '{user_name}' AND pid <> pg_backend_pid() + """) + self.engine_adapter.execute(f'DROP OWNED BY "{user_name}"') + self.engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"') + elif self.dialect == "snowflake": + self.engine_adapter.execute(f"DROP ROLE IF EXISTS {user_name}") + elif self.dialect in ["databricks", "bigquery"]: + # For Databricks and BigQuery, we use pre-created accounts that should not be deleted + pass + except Exception: + pass + def wait_until(fn: t.Callable[..., bool], attempts=3, wait=5) -> None: current_attempt = 0 diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 995875c778..5e976f8dd5 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -4027,3 +4027,209 @@ def test_unicode_characters(ctx: TestContext, tmp_path: Path): table_results = ctx.get_metadata_results(schema) assert len(table_results.tables) == 1 assert table_results.tables[0].lower().startswith(schema_name.lower() + "________") + + +def test_sync_grants_config(ctx: TestContext) -> None: + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("sync_grants_integration") + select_privilege = ctx.get_select_privilege() + insert_privilege = ctx.get_insert_privilege() + update_privilege = ctx.get_update_privilege() + with ctx.create_users_or_roles("reader", "writer", "admin") as roles: + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + initial_grants = { + select_privilege: [roles["reader"]], + insert_privilege: [roles["writer"]], + } + ctx.engine_adapter.sync_grants_config(table, initial_grants) + + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert set(current_grants.get(select_privilege, [])) == {roles["reader"]} + assert set(current_grants.get(insert_privilege, [])) == {roles["writer"]} + + target_grants = { + select_privilege: [roles["writer"], roles["admin"]], + update_privilege: [roles["admin"]], + } + ctx.engine_adapter.sync_grants_config(table, target_grants) + + synced_grants = ctx.engine_adapter._get_current_grants_config(table) + assert set(synced_grants.get(select_privilege, [])) == { + roles["writer"], + roles["admin"], + } + assert set(synced_grants.get(update_privilege, [])) == {roles["admin"]} + assert synced_grants.get(insert_privilege, []) == [] + + +def test_grants_sync_empty_config(ctx: TestContext): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("grants_empty_test") + select_privilege = ctx.get_select_privilege() + insert_privilege = ctx.get_insert_privilege() + with ctx.create_users_or_roles("user") as roles: + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + initial_grants = { + select_privilege: [roles["user"]], + insert_privilege: [roles["user"]], + } + ctx.engine_adapter.sync_grants_config(table, initial_grants) + + initial_current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert roles["user"] in initial_current_grants.get(select_privilege, []) + assert roles["user"] in initial_current_grants.get(insert_privilege, []) + + ctx.engine_adapter.sync_grants_config(table, {}) + + final_grants = ctx.engine_adapter._get_current_grants_config(table) + assert final_grants == {} + + +def test_grants_case_insensitive_grantees(ctx: TestContext): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + with ctx.create_users_or_roles("reader", "writer") as roles: + table = ctx.table("grants_quoted_test") + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + reader = roles["reader"] + writer = roles["writer"] + select_privilege = ctx.get_select_privilege() + + if ctx.dialect == "bigquery": + # BigQuery labels are case sensitive, e.g. serviceAccount + lablel, grantee = writer.split(":", 1) + upper_case_writer = f"{lablel}:{grantee.upper()}" + else: + upper_case_writer = writer.upper() + + grants_config = {select_privilege: [reader, upper_case_writer]} + ctx.engine_adapter.sync_grants_config(table, grants_config) + + # Grantees are still in lowercase + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert reader in current_grants.get(select_privilege, []) + assert writer in current_grants.get(select_privilege, []) + + # Revoke writer + grants_config = {select_privilege: [reader.upper()]} + ctx.engine_adapter.sync_grants_config(table, grants_config) + + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert reader in current_grants.get(select_privilege, []) + assert writer not in current_grants.get(select_privilege, []) + + +def test_grants_plan(ctx: TestContext, tmp_path: Path): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("grant_model").sql(dialect="duckdb") + select_privilege = ctx.get_select_privilege() + insert_privilege = ctx.get_insert_privilege() + with ctx.create_users_or_roles("analyst", "etl_user") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = f""" + MODEL ( + name {table}, + kind FULL, + grants ( + '{select_privilege}' = ['{roles["analyst"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, CURRENT_DATE as created_date + """ + + (tmp_path / "models" / "grant_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path) + plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + + # Physical layer w/ grants + table_name = snapshot.table_name() + view_name = snapshot.qualified_view_name.for_environment( + plan_result.environment_naming_info, dialect=ctx.dialect + ) + current_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=ctx.dialect) + ) + assert current_grants == {select_privilege: [roles["analyst"]]} + + # Virtual layer (view) w/ grants + virtual_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=ctx.dialect) + ) + assert virtual_grants == {select_privilege: [roles["analyst"]]} + + # Update model with query change and new grants + updated_model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name {table}, + kind FULL, + grants ( + '{select_privilege}' = ['{roles["analyst"]}', '{roles["etl_user"]}'], + '{insert_privilege}' = ['{roles["etl_user"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, CURRENT_DATE as created_date, 'v2' as version + """, + default_dialect=context.default_dialect, + ), + dialect=context.default_dialect, + ) + context.upsert_model(updated_model) + + plan = context.plan(auto_apply=True, no_prompts=True) + plan_result = PlanResults.create(plan, ctx, ctx.add_test_suffix(TEST_SCHEMA)) + assert len(plan_result.plan.directly_modified) == 1 + + new_snapshot = plan_result.snapshot_for(updated_model) + assert new_snapshot is not None + + new_table_name = new_snapshot.table_name() + final_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(new_table_name, dialect=ctx.dialect) + ) + expected_final_grants = { + select_privilege: [roles["analyst"], roles["etl_user"]], + insert_privilege: [roles["etl_user"]], + } + assert set(final_grants.get(select_privilege, [])) == set( + expected_final_grants[select_privilege] + ) + assert final_grants.get(insert_privilege, []) == expected_final_grants[insert_privilege] + + # Virtual layer should also have the updated grants + updated_virtual_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=ctx.dialect) + ) + assert set(updated_virtual_grants.get(select_privilege, [])) == set( + expected_final_grants[select_privilege] + ) + assert ( + updated_virtual_grants.get(insert_privilege, []) + == expected_final_grants[insert_privilege] + ) diff --git a/tests/core/engine_adapter/integration/test_integration_postgres.py b/tests/core/engine_adapter/integration/test_integration_postgres.py index 26b8cbda42..f236fdebce 100644 --- a/tests/core/engine_adapter/integration/test_integration_postgres.py +++ b/tests/core/engine_adapter/integration/test_integration_postgres.py @@ -1,9 +1,11 @@ import typing as t +from contextlib import contextmanager import pytest from pytest import FixtureRequest from pathlib import Path from sqlmesh.core.engine_adapter import PostgresEngineAdapter from sqlmesh.core.config import Config, DuckDBConnectionConfig +from sqlmesh.core.config.common import VirtualEnvironmentMode from tests.core.engine_adapter.integration import TestContext import time_machine from datetime import timedelta @@ -12,6 +14,7 @@ from sqlmesh.core.context import Context from sqlmesh.core.state_sync import CachingStateSync, EngineAdapterStateSync from sqlmesh.core.snapshot.definition import SnapshotId +from sqlmesh.utils import random_id from tests.core.engine_adapter.integration import ( TestContext, @@ -22,6 +25,87 @@ ) +def _cleanup_user(engine_adapter: PostgresEngineAdapter, user_name: str) -> None: + """Helper function to clean up a PostgreSQL user and all their dependencies.""" + try: + engine_adapter.execute(f""" + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE usename = '{user_name}' AND pid <> pg_backend_pid() + """) + engine_adapter.execute(f'DROP OWNED BY "{user_name}"') + engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"') + except Exception: + pass + + +@contextmanager +def create_users( + engine_adapter: PostgresEngineAdapter, *role_names: str +) -> t.Iterator[t.Dict[str, t.Dict[str, str]]]: + """Create a set of Postgres users and yield their credentials.""" + created_users = [] + roles = {} + + try: + for role_name in role_names: + user_name = f"test_{role_name}" + _cleanup_user(engine_adapter, user_name) + + for role_name in role_names: + user_name = f"test_{role_name}" + password = random_id() + engine_adapter.execute(f"CREATE USER \"{user_name}\" WITH PASSWORD '{password}'") + engine_adapter.execute(f'GRANT USAGE ON SCHEMA public TO "{user_name}"') + created_users.append(user_name) + roles[role_name] = {"username": user_name, "password": password} + + yield roles + + finally: + for user_name in created_users: + _cleanup_user(engine_adapter, user_name) + + +def create_engine_adapter_for_role( + role_credentials: t.Dict[str, str], ctx: TestContext, config: Config +) -> PostgresEngineAdapter: + """Create a PostgreSQL adapter for a specific role to test authentication and permissions.""" + from sqlmesh.core.config import PostgresConnectionConfig + + gateway = ctx.gateway + assert gateway in config.gateways + connection_config = config.gateways[gateway].connection + assert isinstance(connection_config, PostgresConnectionConfig) + + role_connection_config = PostgresConnectionConfig( + host=connection_config.host, + port=connection_config.port, + database=connection_config.database, + user=role_credentials["username"], + password=role_credentials["password"], + keepalives_idle=connection_config.keepalives_idle, + connect_timeout=connection_config.connect_timeout, + role=connection_config.role, + sslmode=connection_config.sslmode, + application_name=connection_config.application_name, + ) + + return t.cast(PostgresEngineAdapter, role_connection_config.create_engine_adapter()) + + +@contextmanager +def engine_adapter_for_role( + role_credentials: t.Dict[str, str], ctx: TestContext, config: Config +) -> t.Iterator[PostgresEngineAdapter]: + """Context manager that yields a PostgresEngineAdapter and ensures it is closed.""" + adapter = create_engine_adapter_for_role(role_credentials, ctx, config) + try: + yield adapter + finally: + adapter.close() + + @pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["postgres"]))) def ctx( request: FixtureRequest, @@ -286,3 +370,857 @@ def _mutate_config(gateway: str, config: Config): assert after_objects.views == [ exp.to_table(model_b_prod_snapshot.table_name()).text("this") ] + + +# Grants Integration Tests + + +def test_grants_plan_target_layer_physical_only( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "reader") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = """ + MODEL ( + name test_schema.physical_grants_model, + kind FULL, + grants ( + 'select' = ['test_reader'] + ), + grants_target_layer 'physical' + ); + SELECT 1 as id, 'physical_only' as layer + """ + + (tmp_path / "models" / "physical_grants_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path) + plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + physical_table_name = snapshot.table_name() + + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert physical_grants == {"SELECT": [roles["reader"]["username"]]} + + # Virtual layer should have no grants + virtual_view_name = f"test_schema.physical_grants_model" + virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) + ) + assert virtual_grants == {} + + +def test_grants_plan_target_layer_virtual_only( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "viewer") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = """ + MODEL ( + name test_schema.virtual_grants_model, + kind FULL, + grants ( + 'select' = ['test_viewer'] + ), + grants_target_layer 'virtual' + ); + SELECT 1 as id, 'virtual_only' as layer + """ + + (tmp_path / "models" / "virtual_grants_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path) + plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + physical_table_name = snapshot.table_name() + + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + # Physical table should have no grants + assert physical_grants == {} + + virtual_view_name = f"test_schema.virtual_grants_model" + virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) + ) + assert virtual_grants == {"SELECT": [roles["viewer"]["username"]]} + + +def test_grants_plan_full_refresh_model_via_replace( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "reader") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + (tmp_path / "models" / "full_refresh_model.sql").write_text( + f""" + MODEL ( + name test_schema.full_refresh_model, + kind FULL, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, 'test_data' as status + """ + ) + + context = ctx.create_context(path=tmp_path) + + plan_result = context.plan( + "dev", # this triggers _replace_query_for_model for FULL models + auto_apply=True, + no_prompts=True, + ) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + table_name = snapshot.table_name() + + # Physical table + grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert grants == {"SELECT": [roles["reader"]["username"]]} + + # Virtual view + dev_view_name = "test_schema__dev.full_refresh_model" + dev_grants = engine_adapter._get_current_grants_config( + exp.to_table(dev_view_name, dialect=engine_adapter.dialect) + ) + assert dev_grants == {"SELECT": [roles["reader"]["username"]]} + + +def test_grants_plan_incremental_model( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "reader", "writer") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_name = "incr_model" + model_definition = f""" + MODEL ( + name test_schema.{model_name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts + ), + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, @start_ds::timestamp as ts, 'data' as value + """ + + (tmp_path / "models" / f"{model_name}.sql").write_text(model_definition) + + context = ctx.create_context(path=tmp_path) + + plan_result = context.plan( + "dev", start="2020-01-01", end="2020-01-01", auto_apply=True, no_prompts=True + ) + assert len(plan_result.new_snapshots) == 1 + + snapshot = plan_result.new_snapshots[0] + table_name = snapshot.table_name() + + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert physical_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert physical_grants.get("INSERT", []) == [roles["writer"]["username"]] + + view_name = f"test_schema__dev.{model_name}" + view_grants = engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=engine_adapter.dialect) + ) + assert view_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert view_grants.get("INSERT", []) == [roles["writer"]["username"]] + + +def test_grants_plan_clone_environment( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "reader") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + (tmp_path / "models" / "clone_model.sql").write_text( + f""" + MODEL ( + name test_schema.clone_model, + kind FULL, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'] + ), + grants_target_layer 'all' + ); + + SELECT 1 as id, 'data' as value + """ + ) + + context = ctx.create_context(path=tmp_path) + prod_plan_result = context.plan("prod", auto_apply=True, no_prompts=True) + + assert len(prod_plan_result.new_snapshots) == 1 + prod_snapshot = prod_plan_result.new_snapshots[0] + prod_table_name = prod_snapshot.table_name() + + # Prod physical table grants + prod_grants = engine_adapter._get_current_grants_config( + exp.to_table(prod_table_name, dialect=engine_adapter.dialect) + ) + assert prod_grants == {"SELECT": [roles["reader"]["username"]]} + + # Prod virtual view grants + prod_view_name = f"test_schema.clone_model" + prod_view_grants = engine_adapter._get_current_grants_config( + exp.to_table(prod_view_name, dialect=engine_adapter.dialect) + ) + assert prod_view_grants == {"SELECT": [roles["reader"]["username"]]} + + # Create dev environment (cloned from prod) + context.plan("dev", auto_apply=True, no_prompts=True, include_unmodified=True) + + # Physical table grants should remain unchanged + prod_grants_after_clone = engine_adapter._get_current_grants_config( + exp.to_table(prod_table_name, dialect=engine_adapter.dialect) + ) + assert prod_grants_after_clone == prod_grants + + # Dev virtual view should have the same grants as prod + dev_view_name = f"test_schema__dev.clone_model" + dev_view_grants = engine_adapter._get_current_grants_config( + exp.to_table(dev_view_name, dialect=engine_adapter.dialect) + ) + assert dev_view_grants == prod_grants + + +@pytest.mark.parametrize( + "model_name,kind_config,query,extra_config,needs_seed", + [ + ( + "grants_full", + "FULL", + "SELECT 1 as id, 'unchanged_query' as data", + "", + False, + ), + ( + "grants_view", + "VIEW", + "SELECT 1 as id, 'unchanged_query' as data", + "", + False, + ), + ( + "grants_incr_time", + "INCREMENTAL_BY_TIME_RANGE (time_column event_date)", + "SELECT '2025-09-01'::date as event_date, 1 as id, 'unchanged_query' as data", + "start '2025-09-01',", + False, + ), + ( + "grants_seed", + "SEED (path '../seeds/grants_seed.csv')", + "", + "", + True, + ), + ], +) +def test_grants_metadata_only_changes( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + model_name: str, + kind_config: str, + query: str, + extra_config: str, + needs_seed: bool, +): + with create_users(engine_adapter, "reader", "writer", "admin") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + if needs_seed: + (tmp_path / "seeds").mkdir(exist_ok=True) + csv_content = "id,data\\n1,unchanged_query" + (tmp_path / "seeds" / f"{model_name}.csv").write_text(csv_content) + + initial_model_def = f""" + MODEL ( + name test_schema.{model_name}, + kind {kind_config}, + {extra_config} + grants ( + 'select' = ['{roles["reader"]["username"]}'] + ), + grants_target_layer 'all' + ); + {query} + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(initial_model_def) + + context = ctx.create_context(path=tmp_path) + initial_plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(initial_plan_result.new_snapshots) == 1 + initial_snapshot = initial_plan_result.new_snapshots[0] + + physical_table_name = initial_snapshot.table_name() + virtual_view_name = f"test_schema.{model_name}" + + initial_physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert initial_physical_grants == {"SELECT": [roles["reader"]["username"]]} + + initial_virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) + ) + assert initial_virtual_grants == {"SELECT": [roles["reader"]["username"]]} + + # Metadata-only change: update grants only using upsert_model + existing_model = context.get_model(f"test_schema.{model_name}") + context.upsert_model( + existing_model, + grants={ + "select": [roles["writer"]["username"], roles["admin"]["username"]], + "insert": [roles["admin"]["username"]], + }, + ) + second_plan_result = context.plan(auto_apply=True, no_prompts=True) + + expected_grants = { + "SELECT": [roles["writer"]["username"], roles["admin"]["username"]], + "INSERT": [roles["admin"]["username"]], + } + + # For seed models, grant changes rebuild the entire table, so it will create a new physical table + if model_name == "grants_seed" and second_plan_result.new_snapshots: + updated_snapshot = second_plan_result.new_snapshots[0] + physical_table_name = updated_snapshot.table_name() + + updated_physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert set(updated_physical_grants.get("SELECT", [])) == set(expected_grants["SELECT"]) + assert updated_physical_grants.get("INSERT", []) == expected_grants["INSERT"] + + updated_virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) + ) + assert set(updated_virtual_grants.get("SELECT", [])) == set(expected_grants["SELECT"]) + assert updated_virtual_grants.get("INSERT", []) == expected_grants["INSERT"] + + +def _vde_dev_only_config(gateway: str, config: Config) -> None: + config.virtual_environment_mode = VirtualEnvironmentMode.DEV_ONLY + + +@pytest.mark.parametrize( + "grants_target_layer,model_kind", + [ + ("virtual", "FULL"), + ("physical", "FULL"), + ("all", "FULL"), + ("virtual", "VIEW"), + ("physical", "VIEW"), + ], +) +def test_grants_target_layer_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + grants_target_layer: str, + model_kind: str, +): + with create_users(engine_adapter, "reader", "writer") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + if model_kind == "VIEW": + grants_config = ( + f"'SELECT' = ['{roles['reader']['username']}', '{roles['writer']['username']}']" + ) + else: + grants_config = f""" + 'SELECT' = ['{roles["reader"]["username"]}', '{roles["writer"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'] + """.strip() + + model_def = f""" + MODEL ( + name test_schema.vde_model_{grants_target_layer}_{model_kind.lower()}, + kind {model_kind}, + grants ( + {grants_config} + ), + grants_target_layer '{grants_target_layer}' + ); + SELECT 1 as id, '{grants_target_layer}_{model_kind}' as test_type + """ + ( + tmp_path / "models" / f"vde_model_{grants_target_layer}_{model_kind.lower()}.sql" + ).write_text(model_def) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + context.plan("prod", auto_apply=True, no_prompts=True) + + table_name = f"test_schema.vde_model_{grants_target_layer}_{model_kind.lower()}" + + # In VDE dev_only mode, VIEWs are created as actual views + assert context.engine_adapter.table_exists(table_name) + + grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert roles["reader"]["username"] in grants.get("SELECT", []) + assert roles["writer"]["username"] in grants.get("SELECT", []) + + if model_kind != "VIEW": + assert roles["writer"]["username"] in grants.get("INSERT", []) + + +def test_grants_incremental_model_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "etl", "analyst") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = f""" + MODEL ( + name test_schema.vde_incremental_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date + ), + grants ( + 'SELECT' = ['{roles["analyst"]["username"]}'], + 'INSERT' = ['{roles["etl"]["username"]}'] + ), + grants_target_layer 'virtual' + ); + SELECT + 1 as id, + @start_date::date as event_date, + 'event' as event_type + """ + (tmp_path / "models" / "vde_incremental_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + context.plan("prod", auto_apply=True, no_prompts=True) + + prod_table = "test_schema.vde_incremental_model" + prod_grants = engine_adapter._get_current_grants_config( + exp.to_table(prod_table, dialect=engine_adapter.dialect) + ) + assert roles["analyst"]["username"] in prod_grants.get("SELECT", []) + assert roles["etl"]["username"] in prod_grants.get("INSERT", []) + + +@pytest.mark.parametrize( + "change_type,initial_query,updated_query,expect_schema_change", + [ + # Metadata-only change (grants only) + ( + "metadata_only", + "SELECT 1 as id, 'same' as status", + "SELECT 1 as id, 'same' as status", + False, + ), + # Breaking change only + ( + "breaking_only", + "SELECT 1 as id, 'initial' as status, 100 as amount", + "SELECT 1 as id, 'updated' as status", # Removed column + True, + ), + # Both metadata and breaking changes + ( + "metadata_and_breaking", + "SELECT 1 as id, 'initial' as status, 100 as amount", + "SELECT 2 as id, 'changed' as new_status", # Different schema + True, + ), + ], +) +def test_grants_changes_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + change_type: str, + initial_query: str, + updated_query: str, + expect_schema_change: bool, +): + with create_users(engine_adapter, "user1", "user2", "user3") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + model_path = tmp_path / "models" / f"vde_changes_{change_type}.sql" + + initial_model = f""" + MODEL ( + name test_schema.vde_changes_{change_type}, + kind FULL, + grants ( + 'SELECT' = ['{roles["user1"]["username"]}'] + ), + grants_target_layer 'virtual' + ); + {initial_query} + """ + model_path.write_text(initial_model) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + context.plan("prod", auto_apply=True, no_prompts=True) + + table_name = f"test_schema.vde_changes_{change_type}" + initial_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert roles["user1"]["username"] in initial_grants.get("SELECT", []) + assert roles["user2"]["username"] not in initial_grants.get("SELECT", []) + + # Update model with new grants and potentially new query + updated_model = f""" + MODEL ( + name test_schema.vde_changes_{change_type}, + kind FULL, + grants ( + 'SELECT' = ['{roles["user1"]["username"]}', '{roles["user2"]["username"]}', '{roles["user3"]["username"]}'], + 'INSERT' = ['{roles["user3"]["username"]}'] + ), + grants_target_layer 'virtual' + ); + {updated_query} + """ + model_path.write_text(updated_model) + + # Get initial table columns + initial_columns = set( + col[0] + for col in engine_adapter.fetchall( + f"SELECT column_name FROM information_schema.columns WHERE table_schema = 'test_schema' AND table_name = 'vde_changes_{change_type}'" + ) + ) + + context.load() + plan = context.plan("prod", auto_apply=True, no_prompts=True) + + assert len(plan.new_snapshots) == 1 + + current_columns = set( + col[0] + for col in engine_adapter.fetchall( + f"SELECT column_name FROM information_schema.columns WHERE table_schema = 'test_schema' AND table_name = 'vde_changes_{change_type}'" + ) + ) + + if expect_schema_change: + assert current_columns != initial_columns + else: + # For metadata-only changes, schema should be the same + assert current_columns == initial_columns + + # Grants should be updated in all cases + updated_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert roles["user1"]["username"] in updated_grants.get("SELECT", []) + assert roles["user2"]["username"] in updated_grants.get("SELECT", []) + assert roles["user3"]["username"] in updated_grants.get("SELECT", []) + assert roles["user3"]["username"] in updated_grants.get("INSERT", []) + + +@pytest.mark.parametrize( + "grants_target_layer,environment", + [ + ("virtual", "prod"), + ("virtual", "dev"), + ("physical", "prod"), + ("physical", "staging"), + ("all", "prod"), + ("all", "preview"), + ], +) +def test_grants_target_layer_plan_env_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + grants_target_layer: str, + environment: str, +): + with create_users(engine_adapter, "grantee") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = f""" + MODEL ( + name test_schema.vde_layer_model, + kind FULL, + grants ( + 'SELECT' = ['{roles["grantee"]["username"]}'] + ), + grants_target_layer '{grants_target_layer}' + ); + SELECT 1 as id, '{environment}' as env, '{grants_target_layer}' as layer + """ + (tmp_path / "models" / "vde_layer_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + + if environment == "prod": + context.plan("prod", auto_apply=True, no_prompts=True) + table_name = "test_schema.vde_layer_model" + grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert roles["grantee"]["username"] in grants.get("SELECT", []) + else: + context.plan(environment, auto_apply=True, no_prompts=True, include_unmodified=True) + virtual_view = f"test_schema__{environment}.vde_layer_model" + assert context.engine_adapter.table_exists(virtual_view) + virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view, dialect=engine_adapter.dialect) + ) + + data_objects = engine_adapter.get_data_objects("sqlmesh__test_schema") + physical_tables = [ + obj + for obj in data_objects + if "vde_layer_model" in obj.name + and obj.name.endswith("__dev") # Always __dev suffix in VDE dev_only + and "TABLE" in str(obj.type).upper() + ] + + if grants_target_layer == "virtual": + # Virtual layer should have grants, physical should not + assert roles["grantee"]["username"] in virtual_grants.get("SELECT", []) + + assert len(physical_tables) > 0 + for physical_table in physical_tables: + physical_table_name = f"sqlmesh__test_schema.{physical_table.name}" + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert roles["grantee"]["username"] not in physical_grants.get("SELECT", []) + + elif grants_target_layer == "physical": + # Virtual layer should not have grants, physical should + assert roles["grantee"]["username"] not in virtual_grants.get("SELECT", []) + + assert len(physical_tables) > 0 + for physical_table in physical_tables: + physical_table_name = f"sqlmesh__test_schema.{physical_table.name}" + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert roles["grantee"]["username"] in physical_grants.get("SELECT", []) + + else: # grants_target_layer == "all" + # Both layers should have grants + assert roles["grantee"]["username"] in virtual_grants.get("SELECT", []) + assert len(physical_tables) > 0 + for physical_table in physical_tables: + physical_table_name = f"sqlmesh__test_schema.{physical_table.name}" + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert roles["grantee"]["username"] in physical_grants.get("SELECT", []) + + +@pytest.mark.parametrize( + "model_kind", + [ + "SCD_TYPE_2", + "SCD_TYPE_2_BY_TIME", + ], +) +def test_grants_plan_scd_type_2_models( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + model_kind: str, +): + with create_users(engine_adapter, "reader", "writer", "analyst") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + model_name = "scd_model" + + kind_config = f"{model_kind} (unique_key [id])" + model_definition = f""" + MODEL ( + name test_schema.{model_name}, + kind {kind_config}, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, 'initial_data' as name, CURRENT_TIMESTAMP as updated_at + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(model_definition) + + context = ctx.create_context(path=tmp_path) + plan_result = context.plan( + "dev", start="2023-01-01", end="2023-01-01", auto_apply=True, no_prompts=True + ) + assert len(plan_result.new_snapshots) == 1 + + current_snapshot = plan_result.new_snapshots[0] + fingerprint_version = current_snapshot.fingerprint.to_version() + physical_table_name = ( + f"sqlmesh__test_schema.test_schema__{model_name}__{fingerprint_version}__dev" + ) + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert physical_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert physical_grants.get("INSERT", []) == [roles["writer"]["username"]] + + view_name = f"test_schema__dev.{model_name}" + view_grants = engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=engine_adapter.dialect) + ) + assert view_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert view_grants.get("INSERT", []) == [roles["writer"]["username"]] + + # Data change + updated_model_definition = f""" + MODEL ( + name test_schema.{model_name}, + kind {kind_config}, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, 'updated_data' as name, CURRENT_TIMESTAMP as updated_at + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(updated_model_definition) + + context.load() + context.plan("dev", start="2023-01-02", end="2023-01-02", auto_apply=True, no_prompts=True) + + snapshot = context.get_snapshot(f"test_schema.{model_name}") + assert snapshot + fingerprint = snapshot.fingerprint.to_version() + table_name = f"sqlmesh__test_schema.test_schema__{model_name}__{fingerprint}__dev" + data_change_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert data_change_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert data_change_grants.get("INSERT", []) == [roles["writer"]["username"]] + + # Data + grants changes + grant_change_model_definition = f""" + MODEL ( + name test_schema.{model_name}, + kind {kind_config}, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}', '{roles["analyst"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'], + 'UPDATE' = ['{roles["analyst"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, 'grant_changed_data' as name, CURRENT_TIMESTAMP as updated_at + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(grant_change_model_definition) + + context.load() + context.plan("dev", start="2023-01-03", end="2023-01-03", auto_apply=True, no_prompts=True) + + snapshot = context.get_snapshot(f"test_schema.{model_name}") + assert snapshot + fingerprint = snapshot.fingerprint.to_version() + table_name = f"sqlmesh__test_schema.test_schema__{model_name}__{fingerprint}__dev" + final_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + expected_select_users = {roles["reader"]["username"], roles["analyst"]["username"]} + assert set(final_grants.get("SELECT", [])) == expected_select_users + assert final_grants.get("INSERT", []) == [roles["writer"]["username"]] + assert final_grants.get("UPDATE", []) == [roles["analyst"]["username"]] + + final_view_grants = engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=engine_adapter.dialect) + ) + assert set(final_view_grants.get("SELECT", [])) == expected_select_users + assert final_view_grants.get("INSERT", []) == [roles["writer"]["username"]] + assert final_view_grants.get("UPDATE", []) == [roles["analyst"]["username"]] + + +@pytest.mark.parametrize( + "model_kind", + [ + "SCD_TYPE_2", + "SCD_TYPE_2_BY_TIME", + ], +) +def test_grants_plan_scd_type_2_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + model_kind: str, +): + with create_users(engine_adapter, "etl_user", "analyst") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + model_name = "vde_scd_model" + + model_def = f""" + MODEL ( + name test_schema.{model_name}, + kind {model_kind} (unique_key [customer_id]), + grants ( + 'SELECT' = ['{roles["analyst"]["username"]}'], + 'INSERT' = ['{roles["etl_user"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT + 1 as customer_id, + 'active' as status, + CURRENT_TIMESTAMP as updated_at + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + + # Prod + context.plan("prod", auto_apply=True, no_prompts=True) + prod_table = f"test_schema.{model_name}" + prod_grants = engine_adapter._get_current_grants_config( + exp.to_table(prod_table, dialect=engine_adapter.dialect) + ) + assert roles["analyst"]["username"] in prod_grants.get("SELECT", []) + assert roles["etl_user"]["username"] in prod_grants.get("INSERT", []) + + # Dev + context.plan("dev", auto_apply=True, no_prompts=True, include_unmodified=True) + dev_view = f"test_schema__dev.{model_name}" + dev_grants = engine_adapter._get_current_grants_config( + exp.to_table(dev_view, dialect=engine_adapter.dialect) + ) + assert roles["analyst"]["username"] in dev_grants.get("SELECT", []) + assert roles["etl_user"]["username"] in dev_grants.get("INSERT", []) + + snapshot = context.get_snapshot(f"test_schema.{model_name}") + assert snapshot + fingerprint_version = snapshot.fingerprint.to_version() + dev_physical_table_name = ( + f"sqlmesh__test_schema.test_schema__{model_name}__{fingerprint_version}__dev" + ) + + dev_physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(dev_physical_table_name, dialect=engine_adapter.dialect) + ) + assert roles["analyst"]["username"] in dev_physical_grants.get("SELECT", []) + assert roles["etl_user"]["username"] in dev_physical_grants.get("INSERT", []) diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index ba775c0779..2b9bcc665f 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -4065,3 +4065,108 @@ def test_data_object_cache_cleared_on_create_table_like( assert result is not None assert result.name == "target_table" assert mock_get_data_objects.call_count == 2 + + +def test_diff_grants_configs(): + new = {"SELECT": ["u1", "u2"], "INSERT": ["u1"]} + old = {"SELECT": ["u1", "u3"], "update": ["u1"]} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert additions.get("SELECT") and set(additions["SELECT"]) == {"u2"} + assert removals.get("SELECT") and set(removals["SELECT"]) == {"u3"} + + assert additions.get("INSERT") and set(additions["INSERT"]) == {"u1"} + assert removals.get("update") and set(removals["update"]) == {"u1"} + + for perm, grantees in additions.items(): + assert set(grantees).isdisjoint(set(old.get(perm, []))) + for perm, grantees in removals.items(): + assert set(grantees).isdisjoint(set(new.get(perm, []))) + + +def test_diff_grants_configs_empty_new(): + new = {} + old = {"SELECT": ["u1", "u2"], "INSERT": ["u3"]} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert additions == {} + assert removals == old + + +def test_diff_grants_configs_empty_old(): + new = {"SELECT": ["u1", "u2"], "INSERT": ["u3"]} + old = {} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert additions == new + assert removals == {} + + +def test_diff_grants_configs_identical(): + grants = {"SELECT": ["u1", "u2"], "INSERT": ["u3"]} + + additions, removals = EngineAdapter._diff_grants_configs(grants, grants) + + assert additions == {} + assert removals == {} + + +def test_diff_grants_configs_none_configs(): + grants = {"SELECT": ["u1"]} + + additions, removals = EngineAdapter._diff_grants_configs(grants, {}) + assert additions == grants + assert removals == {} + + additions, removals = EngineAdapter._diff_grants_configs({}, grants) + assert additions == {} + assert removals == grants + + additions, removals = EngineAdapter._diff_grants_configs({}, {}) + assert additions == {} + assert removals == {} + + +def test_diff_grants_configs_duplicate_grantees(): + new = {"SELECT": ["u1", "u2", "u1"]} + old = {"SELECT": ["u2", "u3", "u2"]} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert additions["SELECT"] == ["u1", "u1"] + assert removals["SELECT"] == ["u3"] + + +def test_diff_grants_configs_case_sensitive(): + new = {"select": ["u1"], "SELECT": ["u2"]} + old = {"Select": ["u3"]} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert set(additions.keys()) == {"select", "SELECT"} + assert set(removals.keys()) == {"Select"} + assert additions["select"] == ["u1"] + assert additions["SELECT"] == ["u2"] + assert removals["Select"] == ["u3"] + + +def test_sync_grants_config_unsupported_engine(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.SUPPORTS_GRANTS = False + + relation = exp.to_table("test_table") + grants_config = {"SELECT": ["user1"]} + + with pytest.raises(NotImplementedError, match="Engine does not support grants"): + adapter.sync_grants_config(relation, grants_config) + + +def test_get_current_grants_config_not_implemented(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + relation = exp.to_table("test_table") + + with pytest.raises(NotImplementedError): + adapter._get_current_grants_config(relation) diff --git a/tests/core/engine_adapter/test_base_postgres.py b/tests/core/engine_adapter/test_base_postgres.py index df280a9059..f286c47c56 100644 --- a/tests/core/engine_adapter/test_base_postgres.py +++ b/tests/core/engine_adapter/test_base_postgres.py @@ -3,6 +3,7 @@ from unittest.mock import call import pytest +from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter @@ -75,3 +76,26 @@ def test_drop_view(make_mocked_engine_adapter: t.Callable): call('DROP VIEW IF EXISTS "db"."view"'), ] ) + + +def test_get_current_schema(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(BasePostgresEngineAdapter) + + fetchone_mock = mocker.patch.object(adapter, "fetchone", return_value=("test_schema",)) + result = adapter._get_current_schema() + + assert result == "test_schema" + fetchone_mock.assert_called_once() + executed_query = fetchone_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="postgres") + assert executed_sql == "SELECT CURRENT_SCHEMA" + + fetchone_mock.reset_mock() + fetchone_mock.return_value = None + result = adapter._get_current_schema() + assert result == "public" + + fetchone_mock.reset_mock() + fetchone_mock.return_value = (None,) # search_path = '' or 'nonexistent_schema' + result = adapter._get_current_schema() + assert result == "public" diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py index f195bbaa2a..047613e47a 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -13,6 +13,7 @@ import sqlmesh.core.dialect as d from sqlmesh.core.engine_adapter import BigQueryEngineAdapter from sqlmesh.core.engine_adapter.bigquery import select_partitions_expr +from sqlmesh.core.engine_adapter.shared import DataObjectType from sqlmesh.core.node import IntervalUnit from sqlmesh.utils import AttributeDict from sqlmesh.utils.errors import SQLMeshError @@ -588,13 +589,14 @@ def _to_sql_calls(execute_mock: t.Any, identify: bool = True) -> t.List[str]: execute_mock = execute_mock.execute output = [] for call in execute_mock.call_args_list: - value = call[0][0] - sql = ( - value.sql(dialect="bigquery", identify=identify) - if isinstance(value, exp.Expression) - else str(value) - ) - output.append(sql) + values = ensure_list(call[0][0]) + for value in values: + sql = ( + value.sql(dialect="bigquery", identify=identify) + if isinstance(value, exp.Expression) + else str(value) + ) + output.append(sql) return output @@ -1213,3 +1215,168 @@ def test_scd_type_2_by_partitioning(adapter: BigQueryEngineAdapter): # Both calls should contain the partition logic (the scd logic is already covered by other tests) assert "PARTITION BY TIMESTAMP_TRUNC(`valid_from`, DAY)" in calls[0] assert "PARTITION BY TIMESTAMP_TRUNC(`valid_from`, DAY)" in calls[1] + + +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + relation = exp.to_table("project.dataset.test_table", dialect="bigquery") + new_grants_config = { + "roles/bigquery.dataViewer": ["user:analyst@example.com", "group:data-team@example.com"], + "roles/bigquery.dataEditor": ["user:admin@example.com"], + } + current_grants = [ + ("roles/bigquery.dataViewer", "user:old_analyst@example.com"), + ("roles/bigquery.admin", "user:old_admin@example.com"), + ] + + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + execute_mock = mocker.patch.object(adapter, "execute") + mocker.patch.object(adapter, "get_current_catalog", return_value="project") + mocker.patch.object(adapter.client, "location", "us-central1") + + mock_dataset = mocker.Mock() + mock_dataset.location = "us-central1" + mocker.patch.object(adapter, "_db_call", return_value=mock_dataset) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="bigquery") + expected_sql = ( + "SELECT privilege_type, grantee FROM `project`.`region-us-central1`.`INFORMATION_SCHEMA.OBJECT_PRIVILEGES` AS OBJECT_PRIVILEGES " + "WHERE object_schema = 'dataset' AND object_name = 'test_table' AND SPLIT(grantee, ':')[OFFSET(1)] <> session_user()" + ) + assert executed_sql == expected_sql + + sql_calls = _to_sql_calls(execute_mock) + + assert len(sql_calls) == 4 + assert ( + "REVOKE `roles/bigquery.dataViewer` ON TABLE `project`.`dataset`.`test_table` FROM 'user:old_analyst@example.com'" + in sql_calls + ) + assert ( + "REVOKE `roles/bigquery.admin` ON TABLE `project`.`dataset`.`test_table` FROM 'user:old_admin@example.com'" + in sql_calls + ) + assert ( + "GRANT `roles/bigquery.dataViewer` ON TABLE `project`.`dataset`.`test_table` TO 'user:analyst@example.com', 'group:data-team@example.com'" + in sql_calls + ) + assert ( + "GRANT `roles/bigquery.dataEditor` ON TABLE `project`.`dataset`.`test_table` TO 'user:admin@example.com'" + in sql_calls + ) + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + relation = exp.to_table("project.dataset.test_table", dialect="bigquery") + new_grants_config = { + "roles/bigquery.dataViewer": [ + "user:analyst1@example.com", + "user:analyst2@example.com", + "user:analyst3@example.com", + ], + "roles/bigquery.dataEditor": ["user:analyst2@example.com", "user:editor@example.com"], + } + current_grants = [ + ("roles/bigquery.dataViewer", "user:analyst1@example.com"), # Keep + ("roles/bigquery.dataViewer", "user:old_analyst@example.com"), # Remove + ("roles/bigquery.dataEditor", "user:analyst2@example.com"), # Keep + ("roles/bigquery.admin", "user:admin@example.com"), # Remove + ] + + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + execute_mock = mocker.patch.object(adapter, "execute") + mocker.patch.object(adapter, "get_current_catalog", return_value="project") + mocker.patch.object(adapter.client, "location", "us-central1") + + mock_dataset = mocker.Mock() + mock_dataset.location = "us-central1" + mocker.patch.object(adapter, "_db_call", return_value=mock_dataset) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="bigquery") + expected_sql = ( + "SELECT privilege_type, grantee FROM `project`.`region-us-central1`.`INFORMATION_SCHEMA.OBJECT_PRIVILEGES` AS OBJECT_PRIVILEGES " + "WHERE object_schema = 'dataset' AND object_name = 'test_table' AND SPLIT(grantee, ':')[OFFSET(1)] <> session_user()" + ) + assert executed_sql == expected_sql + + sql_calls = _to_sql_calls(execute_mock) + + assert len(sql_calls) == 4 + assert ( + "REVOKE `roles/bigquery.dataViewer` ON TABLE `project`.`dataset`.`test_table` FROM 'user:old_analyst@example.com'" + in sql_calls + ) + assert ( + "REVOKE `roles/bigquery.admin` ON TABLE `project`.`dataset`.`test_table` FROM 'user:admin@example.com'" + in sql_calls + ) + assert ( + "GRANT `roles/bigquery.dataViewer` ON TABLE `project`.`dataset`.`test_table` TO 'user:analyst2@example.com', 'user:analyst3@example.com'" + in sql_calls + ) + assert ( + "GRANT `roles/bigquery.dataEditor` ON TABLE `project`.`dataset`.`test_table` TO 'user:editor@example.com'" + in sql_calls + ) + + +@pytest.mark.parametrize( + "table_type, expected_keyword", + [ + (DataObjectType.TABLE, "TABLE"), + (DataObjectType.VIEW, "VIEW"), + (DataObjectType.MATERIALIZED_VIEW, "MATERIALIZED VIEW"), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockerFixture, + table_type: DataObjectType, + expected_keyword: str, +) -> None: + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + relation = exp.to_table("project.dataset.test_object", dialect="bigquery") + + mocker.patch.object(adapter, "fetchall", return_value=[]) + execute_mock = mocker.patch.object(adapter, "execute") + mocker.patch.object(adapter, "get_current_catalog", return_value="project") + mocker.patch.object(adapter.client, "location", "us-central1") + + mock_dataset = mocker.Mock() + mock_dataset.location = "us-central1" + mocker.patch.object(adapter, "_db_call", return_value=mock_dataset) + + adapter.sync_grants_config( + relation, {"roles/bigquery.dataViewer": ["user:test@example.com"]}, table_type + ) + + executed_exprs = execute_mock.call_args[0][0] + sql_calls = [expr.sql(dialect="bigquery") for expr in executed_exprs] + assert sql_calls == [ + f"GRANT `roles/bigquery.dataViewer` ON {expected_keyword} project.dataset.test_object TO 'user:test@example.com'" + ] + + +def test_sync_grants_config_no_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + relation = exp.to_table("test_table", dialect="bigquery") + new_grants_config = { + "roles/bigquery.dataViewer": ["user:analyst@example.com"], + "roles/bigquery.dataEditor": ["user:editor@example.com"], + } + + with pytest.raises(ValueError, match="Table test_table does not have a schema \\(dataset\\)"): + adapter.sync_grants_config(relation, new_grants_config) diff --git a/tests/core/engine_adapter/test_databricks.py b/tests/core/engine_adapter/test_databricks.py index 27988fed39..e4512f11c9 100644 --- a/tests/core/engine_adapter/test_databricks.py +++ b/tests/core/engine_adapter/test_databricks.py @@ -128,17 +128,194 @@ def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t. assert to_sql_calls(adapter) == ["SELECT CURRENT_CATALOG()"] -def test_get_current_database(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): +def test_get_current_schema(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): mocker.patch( "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" ) adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.cursor.fetchone.return_value = ("test_database",) - assert adapter.get_current_database() == "test_database" + assert adapter._get_current_schema() == "test_database" assert to_sql_calls(adapter) == ["SELECT CURRENT_DATABASE()"] +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockFixture): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM main.information_schema.table_privileges " + "WHERE table_catalog = 'main' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `main`.`test_schema`.`test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `main`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `main`.`test_schema`.`test_table` FROM `stale`" in sql_calls + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockFixture +): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["shared", "new_role"], + "MODIFY": ["shared", "writer"], + } + + current_grants = [ + ("SELECT", "shared"), + ("SELECT", "legacy"), + ("MODIFY", "shared"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM main.information_schema.table_privileges " + "WHERE table_catalog = 'main' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `new_role`" in sql_calls + assert "GRANT MODIFY ON TABLE `main`.`test_schema`.`test_table` TO `writer`" in sql_calls + assert "REVOKE SELECT ON TABLE `main`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + + +@pytest.mark.parametrize( + "table_type, expected_keyword", + [ + (DataObjectType.TABLE, "TABLE"), + (DataObjectType.VIEW, "VIEW"), + (DataObjectType.MATERIALIZED_VIEW, "MATERIALIZED VIEW"), + (DataObjectType.MANAGED_TABLE, "TABLE"), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockFixture, + table_type: DataObjectType, + expected_keyword: str, +) -> None: + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_object", dialect="databricks") + + mocker.patch.object(adapter, "fetchall", return_value=[]) + + adapter.sync_grants_config(relation, {"SELECT": ["test"]}, table_type) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + f"GRANT SELECT ON {expected_keyword} `main`.`test_schema`.`test_object` TO `test`" + ] + + +def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocker: MockFixture): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="`test_db`") + relation = exp.to_table("`test_db`.`test_schema`.`test_table`", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM `test_db`.information_schema.table_privileges " + "WHERE table_catalog = 'test_db' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `test_db`.`test_schema`.`test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `test_db`.`test_schema`.`test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `test_db`.`test_schema`.`test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `test_db`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `test_db`.`test_schema`.`test_table` FROM `stale`" in sql_calls + + +def test_sync_grants_config_no_catalog_or_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockFixture +): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main_catalog") + relation = exp.to_table("test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + mocker.patch.object(adapter, "_get_current_schema", return_value="schema") + mocker.patch.object(adapter, "get_current_catalog", return_value="main_catalog") + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM `main_catalog`.information_schema.table_privileges " + "WHERE table_catalog = 'main_catalog' AND table_schema = 'schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `test_table` FROM `stale`" in sql_calls + + def test_insert_overwrite_by_partition_query( make_mocked_engine_adapter: t.Callable, mocker: MockFixture, make_temp_table_name: t.Callable ): diff --git a/tests/core/engine_adapter/test_postgres.py b/tests/core/engine_adapter/test_postgres.py index 6134126a41..ebcdd03f55 100644 --- a/tests/core/engine_adapter/test_postgres.py +++ b/tests/core/engine_adapter/test_postgres.py @@ -177,3 +177,108 @@ def test_server_version(make_mocked_engine_adapter: t.Callable, mocker: MockerFi del adapter.server_version fetchone_mock.return_value = ("15.13 (Debian 15.13-1.pgdg120+1)",) assert adapter.server_version == (15, 13) + + +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + relation = exp.to_table("test_schema.test_table", dialect="postgres") + new_grants_config = {"SELECT": ["user1", "user2"], "INSERT": ["user3"]} + + current_grants = [("SELECT", "old_user"), ("UPDATE", "admin_user")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="postgres") + + assert executed_sql == ( + "SELECT privilege_type, grantee FROM information_schema.role_table_grants " + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = current_role AND grantee <> current_role" + ) + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 4 + + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user3"' in sql_calls + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "old_user"' in sql_calls + assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM "admin_user"' in sql_calls + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + relation = exp.to_table("test_schema.test_table", dialect="postgres") + new_grants_config = {"SELECT": ["user1", "user2", "user3"], "INSERT": ["user2", "user4"]} + + current_grants = [ + ("SELECT", "user1"), + ("SELECT", "user5"), + ("INSERT", "user2"), + ("UPDATE", "user3"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="postgres") + + assert executed_sql == ( + "SELECT privilege_type, grantee FROM information_schema.role_table_grants " + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = current_role AND grantee <> current_role" + ) + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 4 + + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user2", "user3"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user4"' in sql_calls + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "user5"' in sql_calls + assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM "user3"' in sql_calls + + +def test_diff_grants_configs(make_mocked_engine_adapter: t.Callable): + new_grants = {"select": ["USER1", "USER2"], "insert": ["user3"]} + old_grants = {"SELECT": ["user1", "user4"], "UPDATE": ["user5"]} + + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + additions, removals = adapter._diff_grants_configs(new_grants, old_grants) + + assert additions["select"] == ["USER2"] + assert additions["insert"] == ["user3"] + + assert removals["SELECT"] == ["user4"] + assert removals["UPDATE"] == ["user5"] + + +def test_sync_grants_config_with_default_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + relation = exp.to_table("test_table", dialect="postgres") # No schema + new_grants_config = {"SELECT": ["user1"], "INSERT": ["user2"]} + + currrent_grants = [("UPDATE", "old_user")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=currrent_grants) + get_schema_mock = mocker.patch.object(adapter, "_get_current_schema", return_value="public") + + adapter.sync_grants_config(relation, new_grants_config) + + get_schema_mock.assert_called_once() + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="postgres") + + assert executed_sql == ( + "SELECT privilege_type, grantee FROM information_schema.role_table_grants " + "WHERE table_schema = 'public' AND table_name = 'test_table' " + "AND grantor = current_role AND grantee <> current_role" + ) diff --git a/tests/core/engine_adapter/test_redshift.py b/tests/core/engine_adapter/test_redshift.py index c5e3dfff17..5438943556 100644 --- a/tests/core/engine_adapter/test_redshift.py +++ b/tests/core/engine_adapter/test_redshift.py @@ -9,7 +9,7 @@ from sqlglot import parse_one from sqlmesh.core.engine_adapter import RedshiftEngineAdapter -from sqlmesh.core.engine_adapter.shared import DataObject +from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls @@ -83,6 +83,154 @@ def test_varchar_size_workaround(make_mocked_engine_adapter: t.Callable, mocker: ] +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table("test_schema.test_table", dialect="redshift") + new_grants_config = {"SELECT": ["user1", "user2"], "INSERT": ["user3"]} + + current_grants = [("SELECT", "old_user"), ("UPDATE", "legacy_user")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="redshift") + expected_sql = ( + "SELECT privilege_type, grantee FROM information_schema.table_privileges " + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER AND grantee <> CURRENT_USER" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 4 + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "old_user"' in sql_calls + assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM "legacy_user"' in sql_calls + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user3"' in sql_calls + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table("test_schema.test_table", dialect="redshift") + new_grants_config = { + "SELECT": ["user_shared", "user_new"], + "INSERT": ["user_shared", "user_writer"], + } + + current_grants = [ + ("SELECT", "user_shared"), + ("SELECT", "user_legacy"), + ("INSERT", "user_shared"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="redshift") + expected_sql = ( + "SELECT privilege_type, grantee FROM information_schema.table_privileges " + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER AND grantee <> CURRENT_USER" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "user_legacy"' in sql_calls + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user_new"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user_writer"' in sql_calls + + +@pytest.mark.parametrize( + "table_type", + [ + (DataObjectType.TABLE), + (DataObjectType.VIEW), + (DataObjectType.MATERIALIZED_VIEW), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockerFixture, + table_type: DataObjectType, +) -> None: + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table("test_schema.test_object", dialect="redshift") + + mocker.patch.object(adapter, "fetchall", return_value=[]) + + adapter.sync_grants_config(relation, {"SELECT": ["user_test"]}, table_type) + + sql_calls = to_sql_calls(adapter) + # we don't need to explicitly specify object_type for tables and views + assert sql_calls == [f'GRANT SELECT ON "test_schema"."test_object" TO "user_test"'] + + +def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table('"TestSchema"."TestTable"', dialect="redshift") + new_grants_config = {"SELECT": ["user1", "user2"], "INSERT": ["user3"]} + + current_grants = [("SELECT", "user_old"), ("UPDATE", "user_legacy")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="redshift") + expected_sql = ( + "SELECT privilege_type, grantee FROM information_schema.table_privileges " + "WHERE table_schema = 'TestSchema' AND table_name = 'TestTable' " + "AND grantor = CURRENT_USER AND grantee <> CURRENT_USER" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 4 + assert 'REVOKE SELECT ON "TestSchema"."TestTable" FROM "user_old"' in sql_calls + assert 'REVOKE UPDATE ON "TestSchema"."TestTable" FROM "user_legacy"' in sql_calls + assert 'GRANT SELECT ON "TestSchema"."TestTable" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON "TestSchema"."TestTable" TO "user3"' in sql_calls + + +def test_sync_grants_config_no_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table("test_table", dialect="redshift") + new_grants_config = {"SELECT": ["user1"], "INSERT": ["user2"]} + + current_grants = [("UPDATE", "user_old")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + get_schema_mock = mocker.patch.object(adapter, "_get_current_schema", return_value="public") + + adapter.sync_grants_config(relation, new_grants_config) + + get_schema_mock.assert_called_once() + + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="redshift") + expected_sql = ( + "SELECT privilege_type, grantee FROM information_schema.table_privileges " + "WHERE table_schema = 'public' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER AND grantee <> CURRENT_USER" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + assert 'REVOKE UPDATE ON "test_table" FROM "user_old"' in sql_calls + assert 'GRANT SELECT ON "test_table" TO "user1"' in sql_calls + assert 'GRANT INSERT ON "test_table" TO "user2"' in sql_calls + + def test_create_table_from_query_exists_no_if_not_exists( adapter: t.Callable, mocker: MockerFixture ): diff --git a/tests/core/engine_adapter/test_snowflake.py b/tests/core/engine_adapter/test_snowflake.py index ce4d3a886c..60f6d38e5f 100644 --- a/tests/core/engine_adapter/test_snowflake.py +++ b/tests/core/engine_adapter/test_snowflake.py @@ -4,6 +4,7 @@ import pytest from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers import sqlmesh.core.dialect as d from sqlmesh.core.dialect import normalize_model_name @@ -245,6 +246,204 @@ def test_multiple_column_comments(make_mocked_engine_adapter: t.Callable, mocker ] +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_table", dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + "SELECT privilege_type, grantee FROM TEST_DB.INFORMATION_SCHEMA.TABLE_PRIVILEGES " + "WHERE table_catalog = 'TEST_DB' AND table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE3"' in sql_calls + assert ( + 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "OLD_ROLE"' + in sql_calls + ) + assert ( + 'REVOKE UPDATE ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "LEGACY_ROLE"' + in sql_calls + ) + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_table", dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = { + "SELECT": ["ROLE shared", "ROLE new_role"], + "INSERT": ["ROLE shared", "ROLE writer"], + } + + current_grants = [ + ("SELECT", "ROLE shared"), + ("SELECT", "ROLE legacy"), + ("INSERT", "ROLE shared"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM TEST_DB.INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_catalog = 'TEST_DB' AND table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + + assert ( + 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "NEW_ROLE"' in sql_calls + ) + assert ( + 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "WRITER"' in sql_calls + ) + assert ( + 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "LEGACY"' + in sql_calls + ) + + +@pytest.mark.parametrize( + "table_type, expected_keyword", + [ + (DataObjectType.TABLE, "TABLE"), + (DataObjectType.VIEW, "VIEW"), + (DataObjectType.MATERIALIZED_VIEW, "MATERIALIZED VIEW"), + (DataObjectType.MANAGED_TABLE, "DYNAMIC TABLE"), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockerFixture, + table_type: DataObjectType, + expected_keyword: str, +) -> None: + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_object", dialect="snowflake"), dialect="snowflake" + ) + + mocker.patch.object(adapter, "fetchall", return_value=[]) + + adapter.sync_grants_config(relation, {"SELECT": ["ROLE test"]}, table_type) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + f'GRANT SELECT ON {expected_keyword} "TEST_DB"."TEST_SCHEMA"."TEST_OBJECT" TO ROLE "TEST"' + ] + + +def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table('"test_db"."test_schema"."test_table"', dialect="snowflake"), + dialect="snowflake", + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM "test_db".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_catalog = 'test_db' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE3"' in sql_calls + assert ( + 'REVOKE SELECT ON TABLE "test_db"."test_schema"."test_table" FROM ROLE "OLD_ROLE"' + in sql_calls + ) + assert ( + 'REVOKE UPDATE ON TABLE "test_db"."test_schema"."test_table" FROM ROLE "LEGACY_ROLE"' + in sql_calls + ) + + +def test_sync_grants_config_no_catalog_or_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table('"TesT_Table"', dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + mocker.patch.object(adapter, "get_current_catalog", return_value="caTalog") + mocker.patch.object(adapter, "_get_current_schema", return_value="sChema") + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM "caTalog".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_catalog = 'caTalog' AND table_schema = 'sChema' AND table_name = 'TesT_Table' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "TesT_Table" TO ROLE "ROLE3"' in sql_calls + assert 'REVOKE SELECT ON TABLE "TesT_Table" FROM ROLE "OLD_ROLE"' in sql_calls + assert 'REVOKE UPDATE ON TABLE "TesT_Table" FROM ROLE "LEGACY_ROLE"' in sql_calls + + def test_df_to_source_queries_use_schema( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py index bc4e352bd7..d7c3127f05 100644 --- a/tests/core/engine_adapter/test_spark.py +++ b/tests/core/engine_adapter/test_spark.py @@ -224,7 +224,7 @@ def test_replace_query_self_ref_not_exists( lambda self: "spark_catalog", ) mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.get_current_database", + "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._get_current_schema", side_effect=lambda: "default", ) @@ -283,7 +283,7 @@ def test_replace_query_self_ref_exists( return_value="spark_catalog", ) mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.get_current_database", + "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._get_current_schema", return_value="default", ) diff --git a/tests/core/test_context.py b/tests/core/test_context.py index b7ce64eb4c..6270cec56a 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -3050,9 +3050,10 @@ def test_uppercase_gateway_external_models(tmp_path): # Check that the column types are properly loaded (not UNKNOWN) external_model = gateway_specific_models[0] column_types = {name: str(dtype) for name, dtype in external_model.columns_to_types.items()} - assert column_types == {"id": "INT", "name": "TEXT"}, ( - f"External model column types should not be UNKNOWN, got: {column_types}" - ) + assert column_types == { + "id": "INT", + "name": "TEXT", + }, f"External model column types should not be UNKNOWN, got: {column_types}" # Test that when using a different case for the gateway parameter, we get the same results context_mixed_case = Context( @@ -3177,3 +3178,55 @@ def test_lint_model_projections(tmp_path: Path): with pytest.raises(LinterError, match=config_err): prod_plan = context.plan(no_prompts=True, auto_apply=True) + + +def test_grants_through_plan_apply(sushi_context, mocker): + from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter + from sqlmesh.core.model.meta import GrantsTargetLayer + + model = sushi_context.get_model("sushi.waiter_revenue_by_day") + + mocker.patch.object(DuckDBEngineAdapter, "SUPPORTS_GRANTS", True) + sync_grants_mock = mocker.patch.object(DuckDBEngineAdapter, "sync_grants_config") + + model_with_grants = model.copy( + update={ + "grants": {"select": ["analyst", "reporter"]}, + "grants_target_layer": GrantsTargetLayer.ALL, + } + ) + sushi_context.upsert_model(model_with_grants) + + sushi_context.plan("dev", no_prompts=True, auto_apply=True) + + # When planning for dev env w/ metadata only changes, + # only virtual layer is updated, so no physical grants are applied + assert sync_grants_mock.call_count == 1 + assert all( + call[0][1] == {"select": ["analyst", "reporter"]} + for call in sync_grants_mock.call_args_list + ) + + sync_grants_mock.reset_mock() + + new_grants = ({"select": ["analyst", "reporter", "manager"], "insert": ["etl_user"]},) + model_updated = model_with_grants.copy( + update={ + "query": parse_one(model.query.sql() + " LIMIT 1000"), + "grants": new_grants, + # force model update, hence new physical table creation + "stamp": "update model and grants", + } + ) + sushi_context.upsert_model(model_updated) + sushi_context.plan("dev", no_prompts=True, auto_apply=True) + + # Applies grants 2 times: 1 x physical, 1 x virtual + assert sync_grants_mock.call_count == 2 + assert all(call[0][1] == new_grants for call in sync_grants_mock.call_args_list) + + sync_grants_mock.reset_mock() + + # plan for prod + sushi_context.plan(no_prompts=True, auto_apply=True) + assert sync_grants_mock.call_count == 2 diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 726ac52b66..f1a9eeb0b9 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1,6 +1,7 @@ # ruff: noqa: F811 import json import typing as t +import re from datetime import date, datetime from pathlib import Path from unittest.mock import patch, PropertyMock @@ -14,7 +15,7 @@ from sqlglot.schema import MappingSchema from sqlmesh.cli.project_init import init_example_project, ProjectTemplate from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.model.kind import TimeColumn, ModelKindName +from sqlmesh.core.model.kind import TimeColumn, ModelKindName, SeedKind from sqlmesh import CustomMaterialization, CustomKind from pydantic import model_validator, ValidationError @@ -36,6 +37,7 @@ from sqlmesh.core.dialect import parse from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObjectType from sqlmesh.core.macros import MacroEvaluator, macro from sqlmesh.core.model import ( CustomKind, @@ -51,6 +53,8 @@ TimeColumn, ExternalKind, ViewKind, + EmbeddedKind, + SCDType2ByTimeKind, create_external_model, create_seed_model, create_sql_model, @@ -59,7 +63,7 @@ model, ) from sqlmesh.core.model.common import parse_expression -from sqlmesh.core.model.kind import ModelKindName, _model_kind_validator +from sqlmesh.core.model.kind import _ModelKind, ModelKindName, _model_kind_validator from sqlmesh.core.model.seed import CsvSettings from sqlmesh.core.node import IntervalUnit, _Node, DbtNodeInfo from sqlmesh.core.signal import signal @@ -1922,7 +1926,8 @@ def test_render_definition_with_defaults(): kind VIEW ( materialized FALSE ), - virtual_environment_mode 'full' + virtual_environment_mode 'full', + grants_target_layer 'virtual' ); {query} @@ -1935,6 +1940,90 @@ def test_render_definition_with_defaults(): ) == d.format_model_expressions(expected_expressions) +def test_render_definition_with_grants(): + from sqlmesh.core.model.meta import GrantsTargetLayer + + expressions = d.parse( + """ + MODEL ( + name test.grants_model, + kind FULL, + grants ( + 'select' = ['user1', 'user2'], + 'insert' = ['admin'], + 'roles/bigquery.dataViewer' = ['user:data_eng@mycompany.com'] + ), + grants_target_layer all, + ); + SELECT 1 as id + """ + ) + model = load_sql_based_model(expressions) + assert model.grants_target_layer == GrantsTargetLayer.ALL + assert model.grants == { + "select": ["user1", "user2"], + "insert": ["admin"], + "roles/bigquery.dataViewer": ["user:data_eng@mycompany.com"], + } + + rendered = model.render_definition(include_defaults=True) + rendered_text = d.format_model_expressions(rendered) + assert "grants_target_layer 'all'" in rendered_text + assert re.search( + r"grants\s*\(" + r"\s*'select'\s*=\s*ARRAY\('user1',\s*'user2'\)," + r"\s*'insert'\s*=\s*ARRAY\('admin'\)," + r"\s*'roles/bigquery.dataViewer'\s*=\s*ARRAY\('user:data_eng@mycompany.com'\)" + r"\s*\)", + rendered_text, + ) + + model_with_grants = create_sql_model( + name="test_grants_programmatic", + query=d.parse_one("SELECT 1 as id"), + grants={"select": ["user1", "user2"], "insert": ["admin"]}, + grants_target_layer=GrantsTargetLayer.ALL, + ) + assert model_with_grants.grants == {"select": ["user1", "user2"], "insert": ["admin"]} + assert model_with_grants.grants_target_layer == GrantsTargetLayer.ALL + rendered_text = d.format_model_expressions( + model_with_grants.render_definition(include_defaults=True) + ) + assert "grants_target_layer 'all'" in rendered_text + assert re.search( + r"grants\s*\(" + r"\s*'select'\s*=\s*ARRAY\('user1',\s*'user2'\)," + r"\s*'insert'\s*=\s*ARRAY\('admin'\)" + r"\s*\)", + rendered_text, + ) + + virtual_expressions = d.parse( + """ + MODEL ( + name test.virtual_grants_model, + kind FULL, + grants_target_layer virtual + ); + SELECT 1 as id + """ + ) + virtual_model = load_sql_based_model(virtual_expressions) + assert virtual_model.grants_target_layer == GrantsTargetLayer.VIRTUAL + + default_expressions = d.parse( + """ + MODEL ( + name test.default_grants_model, + kind FULL + ); + SELECT 1 as id + """ + ) + default_model = load_sql_based_model(default_expressions) + assert default_model.grants_target_layer == GrantsTargetLayer.VIRTUAL # default value + + def test_render_definition_partitioned_by(): # no parenthesis in definition, no parenthesis when rendered model = load_sql_based_model( @@ -11717,3 +11806,254 @@ def my_macro(evaluator): model = context.get_model("test_model", raise_if_missing=True) assert model.render_query_or_raise().sql() == 'SELECT 3 AS "c"' + + +def test_grants(): + expressions = d.parse(""" + MODEL ( + name test.table, + kind FULL, + grants ( + 'select' = ['user1', 123, admin_role, 'user2'], + 'insert' = 'admin', + 'roles/bigquery.dataViewer' = ["group:data_eng@company.com", 'user:someone@company.com'], + 'update' = 'admin' + ) + ); + SELECT 1 as id + """) + model = load_sql_based_model(expressions) + assert model.grants == { + "select": ["user1", "123", "admin_role", "user2"], + "insert": ["admin"], + "roles/bigquery.dataViewer": ["group:data_eng@company.com", "user:someone@company.com"], + "update": ["admin"], + } + + model = create_sql_model( + "db.table", + parse_one("SELECT 1 AS id"), + kind="FULL", + grants={ + "select": ["user1", "user2"], + "insert": ["admin"], + "roles/bigquery.dataViewer": "user:data_eng@company.com", + }, + ) + assert model.grants == { + "select": ["user1", "user2"], + "insert": ["admin"], + "roles/bigquery.dataViewer": ["user:data_eng@company.com"], + } + + +@pytest.mark.parametrize( + "kind", + [ + "FULL", + "VIEW", + SeedKind(path="test.csv"), + IncrementalByTimeRangeKind(time_column="ds"), + IncrementalByUniqueKeyKind(unique_key="id"), + ], +) +def test_grants_valid_model_kinds(kind: t.Union[str, _ModelKind]): + model = create_sql_model( + "db.table", + parse_one("SELECT 1 AS id"), + kind=kind, + grants={"select": ["user1", "user2"], "insert": ["admin_user"]}, + ) + assert model.grants == {"select": ["user1", "user2"], "insert": ["admin_user"]} + + +@pytest.mark.parametrize( + "kind", + [ + "EXTERNAL", + "EMBEDDED", + ], +) +def test_grants_invalid_model_kind_errors(kind: str): + with pytest.raises(ValidationError, match=rf".*grants cannot be set for {kind}.*"): + create_sql_model( + "db.table", + parse_one("SELECT 1 AS id"), + kind=kind, + grants={"select": ["user1"], "insert": ["admin_user"]}, + ) + + +def test_model_kind_supports_grants(): + assert FullKind().supports_grants is True + assert ViewKind().supports_grants is True + assert IncrementalByTimeRangeKind(time_column="ds").supports_grants is True + assert IncrementalByUniqueKeyKind(unique_key=["id"]).supports_grants is True + assert SCDType2ByTimeKind(unique_key=["id"]).supports_grants is True + + assert EmbeddedKind().supports_grants is False + assert ExternalKind().supports_grants is False + + +def test_grants_validation_no_grants(): + model = create_sql_model("db.table", parse_one("SELECT 1 AS id"), kind="FULL") + assert model.grants is None + + +def test_grants_validation_empty_grantees(): + model = create_sql_model( + "db.table", parse_one("SELECT 1 AS id"), kind="FULL", grants={"select": []} + ) + assert model.grants == {"select": []} + + +def test_grants_single_value_conversions(): + expressions = d.parse(f""" + MODEL ( + name test.nested_arrays, + kind FULL, + grants ( + 'select' = "user1", update = user2 + ) + ); + SELECT 1 as id + """) + model = load_sql_based_model(expressions) + assert model.grants == {"select": ["user1"], "update": ["user2"]} + + model = create_sql_model( + "db.table", + parse_one("SELECT 1 AS id"), + kind="FULL", + grants={"select": "user1", "insert": 123}, + ) + assert model.grants == {"select": ["user1"], "insert": ["123"]} + + +@pytest.mark.parametrize( + "grantees", + [ + "('user1', ('user2', 'user3'), 'user4')", + "('user1', ['user2', 'user3'], user4)", + "['user1', ['user2', user3], 'user4']", + "[user1, ('user2', \"user3\"), 'user4']", + ], +) +def test_grants_array_flattening(grantees: str): + expressions = d.parse(f""" + MODEL ( + name test.nested_arrays, + kind FULL, + grants ( + 'select' = {grantees} + ) + ); + SELECT 1 as id + """) + model = load_sql_based_model(expressions) + assert model.grants == {"select": ["user1", "user2", "user3", "user4"]} + + +def test_grants_macro_var_resolved(): + expressions = d.parse(""" + MODEL ( + name test.macro_grants, + kind FULL, + grants ( + 'select' = @VAR('readers'), + 'insert' = @VAR('writers') + ) + ); + SELECT 1 as id + """) + model = load_sql_based_model( + expressions, variables={"readers": ["user1", "user2"], "writers": "admin"} + ) + assert model.grants == { + "select": ["user1", "user2"], + "insert": ["admin"], + } + + +def test_grants_macro_var_in_array_flattening(): + expressions = d.parse(""" + MODEL ( + name test.macro_in_array, + kind FULL, + grants ( + 'select' = ['user1', @VAR('admins'), 'user3'] + ) + ); + SELECT 1 as id + """) + + model = load_sql_based_model(expressions, variables={"admins": ["admin1", "admin2"]}) + assert model.grants == {"select": ["user1", "admin1", "admin2", "user3"]} + + model2 = load_sql_based_model(expressions, variables={"admins": "super_admin"}) + assert model2.grants == {"select": ["user1", "super_admin", "user3"]} + + +def test_grants_dynamic_permission_names(): + expressions = d.parse(""" + MODEL ( + name test.dynamic_keys, + kind FULL, + grants ( + @VAR('read_perm') = ['user1', 'user2'], + @VAR('write_perm') = ['admin'] + ) + ); + SELECT 1 as id + """) + model = load_sql_based_model( + expressions, variables={"read_perm": "select", "write_perm": "insert"} + ) + assert model.grants == {"select": ["user1", "user2"], "insert": ["admin"]} + + +def test_grants_unresolved_macro_errors(): + expressions1 = d.parse(""" + MODEL (name test.bad1, kind FULL, grants ('select' = @VAR('undefined'))); + SELECT 1 as id + """) + with pytest.raises(ConfigError, match=r"Invalid grants configuration for 'select': NULL value"): + load_sql_based_model(expressions1) + + expressions2 = d.parse(""" + MODEL (name test.bad2, kind FULL, grants (@VAR('undefined') = ['user'])); + SELECT 1 as id + """) + with pytest.raises(ConfigError, match=r"Invalid grants configuration.*NULL value"): + load_sql_based_model(expressions2) + + expressions3 = d.parse(""" + MODEL (name test.bad3, kind FULL, grants ('select' = ['user', @VAR('undefined')])); + SELECT 1 as id + """) + with pytest.raises(ConfigError, match=r"Invalid grants configuration for 'select': NULL value"): + load_sql_based_model(expressions3) + + +def test_grants_empty_values(): + model1 = create_sql_model( + "db.table", parse_one("SELECT 1 AS id"), kind="FULL", grants={"select": []} + ) + assert model1.grants == {"select": []} + + model2 = create_sql_model("db.table", parse_one("SELECT 1 AS id"), kind="FULL") + assert model2.grants is None + + +@pytest.mark.parametrize( + "kind, expected", + [ + ("VIEW", DataObjectType.VIEW), + ("FULL", DataObjectType.TABLE), + ("MANAGED", DataObjectType.MANAGED_TABLE), + (ViewKind(materialized=True), DataObjectType.MATERIALIZED_VIEW), + ], +) +def test_grants_table_type(kind: t.Union[str, _ModelKind], expected: DataObjectType): + model = create_sql_model("test_table", parse_one("SELECT 1 as id"), kind=kind) + assert model.grants_table_type == expected diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index c769991b86..1acc6cc265 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -168,6 +168,7 @@ def test_json(snapshot: Snapshot): "enabled": True, "extract_dependencies_from_query": True, "virtual_environment_mode": "full", + "grants_target_layer": "virtual", }, "name": '"name"', "parents": [{"name": '"parent"."tbl"', "identifier": snapshot.parents[0].identifier}], @@ -181,6 +182,36 @@ def test_json(snapshot: Snapshot): } +def test_json_with_grants(make_snapshot: t.Callable): + from sqlmesh.core.model.meta import GrantsTargetLayer + + model = SqlModel( + name="name", + kind=dict(time_column="ds", batch_size=30, name=ModelKindName.INCREMENTAL_BY_TIME_RANGE), + owner="owner", + dialect="spark", + cron="1 0 * * *", + start="2020-01-01", + query=parse_one("SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl"), + grants={"SELECT": ["role1", "role2"], "INSERT": ["role3"]}, + grants_target_layer=GrantsTargetLayer.VIRTUAL, + ) + snapshot = make_snapshot(model) + + json_str = snapshot.json() + json_data = json.loads(json_str) + assert ( + json_data["node"]["grants"] + == "('SELECT' = ARRAY('role1', 'role2'), 'INSERT' = ARRAY('role3'))" + ) + assert json_data["node"]["grants_target_layer"] == "virtual" + + reparsed_snapshot = Snapshot.model_validate_json(json_str) + assert isinstance(reparsed_snapshot.node, SqlModel) + assert reparsed_snapshot.node.grants == {"SELECT": ["role1", "role2"], "INSERT": ["role3"]} + assert reparsed_snapshot.node.grants_target_layer == GrantsTargetLayer.VIRTUAL + + def test_json_custom_materialization(make_snapshot: t.Callable): model = SqlModel( name="name", @@ -954,7 +985,7 @@ def test_fingerprint(model: Model, parent_model: Model): original_fingerprint = SnapshotFingerprint( data_hash="2406542604", - metadata_hash="3341445192", + metadata_hash="1056339358", ) assert fingerprint == original_fingerprint @@ -1014,8 +1045,8 @@ def test_fingerprint_seed_model(): ) expected_fingerprint = SnapshotFingerprint( - data_hash="1586624913", - metadata_hash="2315134974", + data_hash="2112858704", + metadata_hash="2674364560", ) model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) @@ -1054,7 +1085,7 @@ def test_fingerprint_jinja_macros(model: Model): ) original_fingerprint = SnapshotFingerprint( data_hash="93332825", - metadata_hash="3341445192", + metadata_hash="1056339358", ) fingerprint = fingerprint_from_node(model, nodes={}) @@ -1131,6 +1162,40 @@ def test_fingerprint_virtual_properties(model: Model, parent_model: Model): assert updated_fingerprint.data_hash == fingerprint.data_hash +def test_fingerprint_grants(model: Model, parent_model: Model): + from sqlmesh.core.model.meta import GrantsTargetLayer + + original_model = deepcopy(model) + fingerprint = fingerprint_from_node(model, nodes={}) + + updated_model = SqlModel( + **original_model.dict(), + grants={"SELECT": ["role1", "role2"]}, + ) + updated_fingerprint = fingerprint_from_node(updated_model, nodes={}) + + assert updated_fingerprint != fingerprint + assert updated_fingerprint.metadata_hash != fingerprint.metadata_hash + assert updated_fingerprint.data_hash == fingerprint.data_hash + + different_grants_model = SqlModel( + **original_model.dict(), + grants={"SELECT": ["role3"], "INSERT": ["role4"]}, + ) + different_grants_fingerprint = fingerprint_from_node(different_grants_model, nodes={}) + + assert different_grants_fingerprint.metadata_hash != updated_fingerprint.metadata_hash + assert different_grants_fingerprint.metadata_hash != fingerprint.metadata_hash + + target_layer_model = SqlModel( + **{**original_model.dict(), "grants_target_layer": GrantsTargetLayer.PHYSICAL}, + grants={"SELECT": ["role1", "role2"]}, + ) + target_layer_fingerprint = fingerprint_from_node(target_layer_model, nodes={}) + + assert target_layer_fingerprint.metadata_hash != updated_fingerprint.metadata_hash + + def test_tableinfo_equality(): snapshot_a = SnapshotTableInfo( name="test_schema.a", diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 19685e81c3..68061544a8 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -41,8 +41,10 @@ load_sql_based_model, ExternalModel, model, + create_sql_model, ) from sqlmesh.core.model.kind import OnDestructiveChange, ExternalKind, OnAdditiveChange +from sqlmesh.core.model.meta import GrantsTargetLayer from sqlmesh.core.node import IntervalUnit from sqlmesh.core.snapshot import ( DeployabilityIndex, @@ -55,7 +57,19 @@ SnapshotTableCleanupTask, ) from sqlmesh.core.snapshot.definition import to_view_mapping -from sqlmesh.core.snapshot.evaluator import CustomMaterialization, SnapshotCreationFailedError +from sqlmesh.core.snapshot.evaluator import ( + CustomMaterialization, + EngineManagedStrategy, + FullRefreshStrategy, + IncrementalByPartitionStrategy, + IncrementalByTimeRangeStrategy, + IncrementalByUniqueKeyStrategy, + IncrementalUnmanagedStrategy, + MaterializableStrategy, + SCDType2Strategy, + SnapshotCreationFailedError, + ViewStrategy, +) from sqlmesh.utils.concurrency import NodeExecutionFailedError from sqlmesh.utils.date import to_timestamp from sqlmesh.utils.errors import ( @@ -908,7 +922,7 @@ def test_pre_hook_forward_only_clone( time_column ds ) ); - + {pre_statement}; SELECT a::int, ds::string FROM tbl; @@ -4858,3 +4872,524 @@ def mutate_view_properties(*args, **kwargs): # Both calls should have view_properties with security invoker assert props == ["'SECURITY INVOKER'", "'SECURITY INVOKER'"] + + +def _create_grants_test_model( + grants=None, kind="FULL", grants_target_layer=None, virtual_environment_mode=None +): + if kind == "SEED": + from sqlmesh.core.model.definition import create_seed_model + from sqlmesh.core.model.kind import SeedKind + import tempfile + import os + + # Create a temporary CSV file for the test + temp_csv = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) + temp_csv.write("id,name\n1,test\n2,test2\n") + temp_csv.flush() + temp_csv.close() + + seed_kind_config = {"name": "SEED", "path": temp_csv.name} + seed_kind = SeedKind(**seed_kind_config) + + kwargs = {} + if grants is not None: + kwargs["grants"] = grants + if grants_target_layer is not None: + kwargs["grants_target_layer"] = grants_target_layer + + model = create_seed_model("test_model", seed_kind, **kwargs) + + # Clean up the temporary file + os.unlink(temp_csv.name) + + return model + + # Handle regular SQL models + kwargs = { + "kind": kind, + } + if grants is not None: + kwargs["grants"] = grants + if grants_target_layer is not None: + kwargs["grants_target_layer"] = grants_target_layer + if virtual_environment_mode is not None: + kwargs["virtual_environment_mode"] = virtual_environment_mode + + # Add column annotations for non-SEED models to ensure table creation + if kind != "SEED": + kwargs["columns"] = { + "id": "INT", + "ds": "DATE", + "updated_at": "TIMESTAMP", + } + + # Add required fields for specific model kinds + if kind == "INCREMENTAL_BY_TIME_RANGE": + kwargs["kind"] = {"name": "INCREMENTAL_BY_TIME_RANGE", "time_column": "ds"} + elif kind == "INCREMENTAL_BY_PARTITION": + kwargs["kind"] = {"name": "INCREMENTAL_BY_PARTITION"} + kwargs["partitioned_by"] = ["ds"] # This goes on the model, not the kind + elif kind == "INCREMENTAL_BY_UNIQUE_KEY": + kwargs["kind"] = {"name": "INCREMENTAL_BY_UNIQUE_KEY", "unique_key": ["id"]} + elif kind == "INCREMENTAL_UNMANAGED": + kwargs["kind"] = {"name": "INCREMENTAL_UNMANAGED"} + elif kind == "SCD_TYPE_2": + kwargs["kind"] = { + "name": "SCD_TYPE_2", + "unique_key": ["id"], + "updated_at_name": "updated_at", + } + + return create_sql_model( + "test_model", + parse_one("SELECT 1 as id, CURRENT_DATE as ds, CURRENT_TIMESTAMP as updated_at"), + **kwargs, + ) + + +@pytest.mark.parametrize( + "target_layer,apply_layer,expected_call_count", + [ + (GrantsTargetLayer.ALL, GrantsTargetLayer.PHYSICAL, 1), + (GrantsTargetLayer.ALL, GrantsTargetLayer.VIRTUAL, 1), + (GrantsTargetLayer.PHYSICAL, GrantsTargetLayer.PHYSICAL, 1), + (GrantsTargetLayer.PHYSICAL, GrantsTargetLayer.VIRTUAL, 0), + (GrantsTargetLayer.VIRTUAL, GrantsTargetLayer.PHYSICAL, 0), + (GrantsTargetLayer.VIRTUAL, GrantsTargetLayer.VIRTUAL, 1), + ], +) +def test_apply_grants_target_layer( + target_layer: GrantsTargetLayer, + apply_layer: GrantsTargetLayer, + expected_call_count: int, + adapter_mock: Mock, + mocker: MockerFixture, +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + strategy = ViewStrategy(adapter_mock) + + model = _create_grants_test_model( + grants={"select": ["user1"]}, grants_target_layer=target_layer + ) + + strategy._apply_grants(model, "test_table", apply_layer) + + if expected_call_count > 0: + assert sync_grants_mock.call_count == expected_call_count + else: + sync_grants_mock.assert_not_called() + + +@pytest.mark.parametrize( + "model_kind_name", + [ + "FULL", + "INCREMENTAL_BY_TIME_RANGE", + "SEED", + "MANAGED", + "SCD_TYPE_2", + "VIEW", + ], +) +def test_grants_create_model_kind( + model_kind_name: str, + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + grants = {"select": ["user1"]} + model = _create_grants_test_model( + grants=grants, kind=model_kind_name, grants_target_layer=GrantsTargetLayer.ALL + ) + snapshot = make_snapshot(model) + + evaluator = SnapshotEvaluator(adapter_mock) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator.create([snapshot], {}) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == grants + + +@pytest.mark.parametrize( + "target_layer", + [ + GrantsTargetLayer.PHYSICAL, + GrantsTargetLayer.VIRTUAL, + GrantsTargetLayer.ALL, + ], +) +def test_grants_target_layer( + target_layer: GrantsTargetLayer, + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + evaluator = SnapshotEvaluator(adapter_mock) + + grants = {"select": ["user1"]} + model = create_sql_model( + "test_schema.test_model", + parse_one("SELECT 1 as id"), + kind="FULL", + grants=grants, + grants_target_layer=target_layer, + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}) + if target_layer == GrantsTargetLayer.VIRTUAL: + assert sync_grants_mock.call_count == 0 + else: + assert sync_grants_mock.call_count == 1 + assert sync_grants_mock.call_args[0][1] == grants + sync_grants_mock.reset_mock() + evaluator.promote([snapshot], EnvironmentNamingInfo(name="prod")) + if target_layer == GrantsTargetLayer.VIRTUAL: + assert sync_grants_mock.call_count == 1 + elif target_layer == GrantsTargetLayer.PHYSICAL: + # Physical layer: no grants applied during promotion (already applied during create) + assert sync_grants_mock.call_count == 0 + else: # target_layer == GrantsTargetLayer.ALL + # All layers: only virtual grants applied during promotion (physical already done in create) + assert sync_grants_mock.call_count == 1 + + +def test_grants_update( + adapter_mock: Mock, mocker: MockerFixture, make_snapshot: t.Callable[..., Snapshot] +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + evaluator = SnapshotEvaluator(adapter_mock) + + model = create_sql_model( + "test_schema.test_model", + parse_one("SELECT 1 as id"), + kind="FULL", + grants={"select": ["user1"]}, + grants_target_layer=GrantsTargetLayer.ALL, + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator.create([snapshot], {}) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {"select": ["user1"]} + + # Update model query AND change grants + updated_model_dict = model.dict() + updated_model_dict["query"] = parse_one("SELECT 1 as id, 2 as value") + updated_model_dict["grants"] = {"select": ["user2", "user3"], "insert": ["admin"]} + updated_model = SqlModel.parse_obj(updated_model_dict) + + new_snapshot = make_snapshot(updated_model) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + sync_grants_mock.reset_mock() + evaluator.create([new_snapshot], {}) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {"select": ["user2", "user3"], "insert": ["admin"]} + + # Update model query AND remove grants + updated_model_dict = model.dict() + updated_model_dict["query"] = parse_one("SELECT 1 as id, 'updated' as status") + updated_model_dict["grants"] = {} + updated_model = SqlModel.parse_obj(updated_model_dict) + + new_snapshot = make_snapshot(updated_model) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + sync_grants_mock.reset_mock() + evaluator.create([new_snapshot], {}) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {} + + +def test_grants_create_and_evaluate( + adapter_mock: Mock, mocker: MockerFixture, make_snapshot: t.Callable[..., Snapshot] +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind INCREMENTAL_BY_TIME_RANGE (time_column ds), + grants ( + 'select' = ['reader1', 'reader2'], + 'insert' = ['writer'] + ), + grants_target_layer 'all' + ); + SELECT ds::DATE, value::INT FROM source WHERE ds BETWEEN @start_ds AND @end_ds; + """ + ) + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}) + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == { + "select": ["reader1", "reader2"], + "insert": ["writer"], + } + + sync_grants_mock.reset_mock() + evaluator.evaluate( + snapshot, start="2020-01-01", end="2020-01-02", execution_time="2020-01-02", snapshots={} + ) + # Evaluate should not reapply grants + sync_grants_mock.assert_not_called() + + +@pytest.mark.parametrize( + "strategy_class", + [ + EngineManagedStrategy, + FullRefreshStrategy, + IncrementalByTimeRangeStrategy, + IncrementalByPartitionStrategy, + IncrementalUnmanagedStrategy, + IncrementalByUniqueKeyStrategy, + SCDType2Strategy, + # SeedStrategy excluded because seeds do not support migrations + ], +) +def test_grants_materializable_strategy_migrate( + strategy_class: t.Type[MaterializableStrategy], + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + adapter_mock.get_alter_operations.return_value = [] + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + strategy = strategy_class(adapter_mock) + grants = {"select": ["user1"]} + model = _create_grants_test_model(grants=grants, grants_target_layer=GrantsTargetLayer.ALL) + snapshot = make_snapshot(model) + + strategy.migrate( + "target_table", + "source_table", + snapshot, + ignore_destructive=False, + ignore_additive=False, + allow_destructive_snapshots=set(), + allow_additive_snapshots=set(), + ) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == grants + + +def test_grants_clone_snapshot_in_dev( + adapter_mock: Mock, mocker: MockerFixture, make_snapshot: t.Callable[..., Snapshot] +): + adapter_mock.SUPPORTS_CLONING = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + evaluator = SnapshotEvaluator(adapter_mock) + grants = {"select": ["user1", "user2"]} + model = _create_grants_test_model(grants=grants, grants_target_layer=GrantsTargetLayer.ALL) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator._clone_snapshot_in_dev( + snapshot, {}, DeployabilityIndex.all_deployable(), {}, {}, set(), set() + ) + + sync_grants_mock.assert_called_once() + assert ( + sync_grants_mock.call_args[0][0].sql() + == f"sqlmesh__default.test_model__{snapshot.version}__dev" + ) + assert sync_grants_mock.call_args[0][1] == grants + + +@pytest.mark.parametrize( + "model_kind_name", + [ + "INCREMENTAL_BY_TIME_RANGE", + "SEED", + ], +) +def test_grants_evaluator_insert_without_replace_query_for_model( + model_kind_name: str, + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + adapter_mock.table_exists.return_value = False # Table doesn't exist + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + evaluator = SnapshotEvaluator(adapter_mock) + + grants = {"select": ["reader1", "reader2"]} + model = _create_grants_test_model( + grants=grants, kind=model_kind_name, grants_target_layer=GrantsTargetLayer.ALL + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.evaluate( + snapshot, + start="2023-01-01", + end="2023-01-01", + execution_time="2023-01-01", + snapshots={}, + ) + + # Grants are applied during the table creation phase, not during insert + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == grants + + sync_grants_mock.reset_mock() + adapter_mock.table_exists.return_value = True + snapshot.add_interval("2023-01-01", "2023-01-01") + evaluator.evaluate( + snapshot, + start="2023-01-02", # Different date from existing interval + end="2023-01-02", + execution_time="2023-01-02", + snapshots={}, + ) + + # Should not apply grants since it's not the first insert + sync_grants_mock.assert_not_called() + + +@pytest.mark.parametrize( + "model_kind_name", + [ + "INCREMENTAL_BY_PARTITION", + "INCREMENTAL_BY_UNIQUE_KEY", + "INCREMENTAL_UNMANAGED", + "FULL", + "SCD_TYPE_2", + ], +) +def test_grants_evaluator_insert_with_replace_query_for_model( + model_kind_name: str, + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + adapter_mock.table_exists.return_value = False # Table doesn't exist + adapter_mock.columns.return_value = { + "id": exp.DataType.build("int"), + "ds": exp.DataType.build("date"), + } + + evaluator = SnapshotEvaluator(adapter_mock) + + grants = {"select": ["user1"]} + model = _create_grants_test_model( + grants=grants, kind=model_kind_name, grants_target_layer=GrantsTargetLayer.ALL + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Now evaluate the snapshot (this should apply grants during first insert) + evaluator.evaluate( + snapshot, + start="2023-01-01", + end="2023-01-01", + execution_time="2023-01-01", + snapshots={}, + ) + + # Should be called twice more during evaluate: once creating table, + # once during first insert with _replace_query_for_model() + assert sync_grants_mock.call_count == 2 + assert sync_grants_mock.call_args[0][1] == grants + + sync_grants_mock.reset_mock() + adapter_mock.table_exists.return_value = True + snapshot.add_interval("2023-01-01", "2023-01-01") + evaluator.evaluate( + snapshot, + start="2023-01-02", # Different date from existing interval + end="2023-01-02", + execution_time="2023-01-02", + snapshots={}, + ) + + if model_kind_name in ("FULL", "SCD_TYPE_2"): + # Full refresh and SCD_TYPE_2 always recreate the table, so grants are always applied + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == grants + else: + # Should not apply grants since it's not the first insert + sync_grants_mock.assert_not_called() + + +@pytest.mark.parametrize( + "model_grants_target_layer", + [ + GrantsTargetLayer.ALL, + GrantsTargetLayer.VIRTUAL, + GrantsTargetLayer.PHYSICAL, + ], +) +def test_grants_in_production_with_dev_only_vde( + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], + model_grants_target_layer: GrantsTargetLayer, +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + from sqlmesh.core.model.meta import VirtualEnvironmentMode, GrantsTargetLayer + from sqlmesh.core.snapshot.definition import DeployabilityIndex + + model_virtual_grants = _create_grants_test_model( + grants={"select": ["user1"], "insert": ["role1"]}, + grants_target_layer=model_grants_target_layer, + virtual_environment_mode=VirtualEnvironmentMode.DEV_ONLY, + ) + + snapshot = make_snapshot(model_virtual_grants) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator = SnapshotEvaluator(adapter_mock) + # create will apply grants to physical layer tables + deployability_index = DeployabilityIndex.all_deployable() + evaluator.create([snapshot], {}, deployability_index=deployability_index) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {"select": ["user1"], "insert": ["role1"]} + + # Non-deployable (dev) env + sync_grants_mock.reset_mock() + deployability_index = DeployabilityIndex.none_deployable() + evaluator.create([snapshot], {}, deployability_index=deployability_index) + if model_grants_target_layer == GrantsTargetLayer.VIRTUAL: + sync_grants_mock.assert_not_called() + else: + # Should still apply grants to physical table when target layer is ALL or PHYSICAL + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {"select": ["user1"], "insert": ["role1"]} diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index e29c6768bf..eb16a4b4b1 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -9,10 +9,12 @@ from sqlmesh.core.model import TimeColumn, IncrementalByTimeRangeKind from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json +from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.model.meta import GrantsTargetLayer from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.model import ModelConfig -from sqlmesh.dbt.target import PostgresConfig +from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, PostgresConfig from sqlmesh.dbt.test import TestConfig from sqlmesh.utils.yaml import YAML from sqlmesh.utils.date import to_ds @@ -853,3 +855,176 @@ def test_load_custom_materialisations(sushi_test_dbt_context: Context) -> None: context.load() assert context.get_model("sushi.custom_incremental_model") assert context.get_model("sushi.custom_incremental_with_filter") + + +def test_model_grants_to_sqlmesh_grants_config() -> None: + grants_config = { + "select": ["user1", "user2"], + "insert": ["admin_user"], + "update": ["power_user"], + } + model_config = ModelConfig( + name="test_model", + sql="SELECT 1 as id", + grants=grants_config, + path=Path("test_model.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + model_grants = sqlmesh_model.grants + assert model_grants == grants_config + + assert sqlmesh_model.grants_target_layer == GrantsTargetLayer.default + + +def test_model_grants_empty_permissions() -> None: + model_config = ModelConfig( + name="test_model_empty", + sql="SELECT 1 as id", + grants={"select": [], "insert": ["admin_user"]}, + path=Path("test_model_empty.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + model_grants = sqlmesh_model.grants + expected_grants = {"select": [], "insert": ["admin_user"]} + assert model_grants == expected_grants + + +def test_model_no_grants() -> None: + model_config = ModelConfig( + name="test_model_no_grants", + sql="SELECT 1 as id", + path=Path("test_model_no_grants.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + grants_config = sqlmesh_model.grants + assert grants_config is None + + +def test_model_empty_grants() -> None: + model_config = ModelConfig( + name="test_model_empty_grants", + sql="SELECT 1 as id", + grants={}, + path=Path("test_model_empty_grants.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + grants_config = sqlmesh_model.grants + assert grants_config is None + + +def test_model_grants_valid_special_characters() -> None: + valid_grantees = [ + "user@domain.com", + "service-account@project.iam.gserviceaccount.com", + "group:analysts", + '"quoted user"', + "`backtick user`", + "user_with_underscores", + "user.with.dots", + ] + + model_config = ModelConfig( + name="test_model_special_chars", + sql="SELECT 1 as id", + grants={"select": valid_grantees}, + path=Path("test_model.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + grants_config = sqlmesh_model.grants + assert grants_config is not None + assert "select" in grants_config + assert grants_config["select"] == valid_grantees + + +def test_model_grants_engine_specific_bigquery() -> None: + model_config = ModelConfig( + name="test_model_bigquery", + sql="SELECT 1 as id", + grants={ + "bigquery.dataviewer": ["user@domain.com"], + "select": ["analyst@company.com"], + }, + path=Path("test_model.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = BigQueryConfig( + name="bigquery_target", + project="test-project", + dataset="test_dataset", + location="US", + database="test-project", + schema="test_dataset", + ) + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + grants_config = sqlmesh_model.grants + assert grants_config is not None + assert grants_config["bigquery.dataviewer"] == ["user@domain.com"] + assert grants_config["select"] == ["analyst@company.com"] + + +def test_ephemeral_model_ignores_grants() -> None: + """Test that ephemeral models ignore grants configuration.""" + model_config = ModelConfig( + name="ephemeral_model", + sql="SELECT 1 as id", + materialized="ephemeral", + grants={"select": ["reporter", "analyst"]}, + path=Path("ephemeral_model.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + assert sqlmesh_model.kind.is_embedded + assert sqlmesh_model.grants is None # grants config is skipped for ephemeral / embedded models