Skip to content

Commit ca03fcf

Browse files
authored
feat: add snowflake grant support (#5433)
1 parent 2f3b72a commit ca03fcf

File tree

8 files changed

+564
-171
lines changed

8 files changed

+564
-171
lines changed

sqlmesh/core/engine_adapter/_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@
3131

3232
QueryOrDF = t.Union[Query, DF]
3333
GrantsConfig = t.Dict[str, t.List[str]]
34+
DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke)

sqlmesh/core/engine_adapter/postgres.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818

1919
if t.TYPE_CHECKING:
2020
from sqlmesh.core._typing import TableName
21-
from sqlmesh.core.engine_adapter._typing import DF, GrantsConfig, QueryOrDF
22-
23-
DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke)
21+
from sqlmesh.core.engine_adapter._typing import DCL, DF, GrantsConfig, QueryOrDF
2422

2523
logger = logging.getLogger(__name__)
2624

@@ -38,7 +36,7 @@ class PostgresEngineAdapter(
3836
HAS_VIEW_BINDING = True
3937
CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
4038
SUPPORTS_REPLACE_TABLE = False
41-
MAX_IDENTIFIER_LENGTH = 63
39+
MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63
4240
SUPPORTS_QUERY_EXECUTION_TRACKING = True
4341
SCHEMA_DIFFER_KWARGS = {
4442
"parameterized_type_defaults": {

sqlmesh/core/engine_adapter/risingwave.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class RisingwaveEngineAdapter(PostgresEngineAdapter):
3232
SUPPORTS_MATERIALIZED_VIEWS = True
3333
SUPPORTS_TRANSACTIONS = False
3434
MAX_IDENTIFIER_LENGTH = None
35+
SUPPORTS_GRANTS = False
3536

3637
def columns(
3738
self, table_name: TableName, include_pseudo_columns: bool = False

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,14 @@
3434
import pandas as pd
3535

3636
from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
37-
from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF, SnowparkSession
37+
from sqlmesh.core.engine_adapter._typing import (
38+
DCL,
39+
DF,
40+
GrantsConfig,
41+
Query,
42+
QueryOrDF,
43+
SnowparkSession,
44+
)
3845
from sqlmesh.core.node import IntervalUnit
3946

4047

@@ -73,6 +80,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
7380
MANAGED_TABLE_KIND = "DYNAMIC TABLE"
7481
SNOWPARK = "snowpark"
7582
SUPPORTS_QUERY_EXECUTION_TRACKING = True
83+
SUPPORTS_GRANTS = True
7684

7785
@contextlib.contextmanager
7886
def session(self, properties: SessionProperties) -> t.Iterator[None]:
@@ -127,6 +135,118 @@ def snowpark(self) -> t.Optional[SnowparkSession]:
127135
def catalog_support(self) -> CatalogSupport:
128136
return CatalogSupport.FULL_SUPPORT
129137

138+
@staticmethod
139+
def _grant_object_kind(table_type: DataObjectType) -> str:
140+
if table_type == DataObjectType.VIEW:
141+
return "VIEW"
142+
if table_type == DataObjectType.MATERIALIZED_VIEW:
143+
return "MATERIALIZED VIEW"
144+
if table_type == DataObjectType.MANAGED_TABLE:
145+
return "DYNAMIC TABLE"
146+
return "TABLE"
147+
148+
def _get_current_schema(self) -> str:
149+
"""Returns the current default schema for the connection."""
150+
result = self.fetchone("SELECT CURRENT_SCHEMA()")
151+
if not result or not result[0]:
152+
raise SQLMeshError("Unable to determine current schema")
153+
return str(result[0])
154+
155+
def _dcl_grants_config_expr(
156+
self,
157+
dcl_cmd: t.Type[DCL],
158+
table: exp.Table,
159+
grant_config: GrantsConfig,
160+
table_type: DataObjectType = DataObjectType.TABLE,
161+
) -> t.List[exp.Expression]:
162+
expressions: t.List[exp.Expression] = []
163+
if not grant_config:
164+
return expressions
165+
166+
object_kind = self._grant_object_kind(table_type)
167+
for privilege, principals in grant_config.items():
168+
for principal in principals:
169+
args: t.Dict[str, t.Any] = {
170+
"privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))],
171+
"securable": table.copy(),
172+
"principals": [principal],
173+
}
174+
175+
if object_kind:
176+
args["kind"] = exp.Var(this=object_kind)
177+
178+
expressions.append(dcl_cmd(**args)) # type: ignore[arg-type]
179+
180+
return expressions
181+
182+
def _apply_grants_config_expr(
183+
self,
184+
table: exp.Table,
185+
grant_config: GrantsConfig,
186+
table_type: DataObjectType = DataObjectType.TABLE,
187+
) -> t.List[exp.Expression]:
188+
return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type)
189+
190+
def _revoke_grants_config_expr(
191+
self,
192+
table: exp.Table,
193+
grant_config: GrantsConfig,
194+
table_type: DataObjectType = DataObjectType.TABLE,
195+
) -> t.List[exp.Expression]:
196+
return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type)
197+
198+
def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig:
199+
schema_identifier = table.args.get("db") or normalize_identifiers(
200+
exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect
201+
)
202+
catalog_identifier = table.args.get("catalog")
203+
if not catalog_identifier:
204+
current_catalog = self.get_current_catalog()
205+
if not current_catalog:
206+
raise SQLMeshError("Unable to determine current catalog for fetching grants")
207+
catalog_identifier = normalize_identifiers(
208+
exp.to_identifier(current_catalog, quoted=True), dialect=self.dialect
209+
)
210+
catalog_identifier.set("quoted", True)
211+
table_identifier = table.args.get("this")
212+
213+
grant_expr = (
214+
exp.select("privilege_type", "grantee")
215+
.from_(
216+
exp.table_(
217+
"TABLE_PRIVILEGES",
218+
db="INFORMATION_SCHEMA",
219+
catalog=catalog_identifier,
220+
)
221+
)
222+
.where(
223+
exp.and_(
224+
exp.column("table_schema").eq(exp.Literal.string(schema_identifier.this)),
225+
exp.column("table_name").eq(exp.Literal.string(table_identifier.this)), # type: ignore
226+
exp.column("grantor").eq(exp.func("CURRENT_ROLE")),
227+
exp.column("grantee").neq(exp.func("CURRENT_ROLE")),
228+
)
229+
)
230+
)
231+
232+
results = self.fetchall(grant_expr)
233+
234+
grants_dict: GrantsConfig = {}
235+
for privilege_raw, grantee_raw in results:
236+
if privilege_raw is None or grantee_raw is None:
237+
continue
238+
239+
privilege = str(privilege_raw)
240+
grantee = str(grantee_raw)
241+
if not privilege or not grantee:
242+
continue
243+
244+
grantees = grants_dict.setdefault(privilege, [])
245+
if grantee not in grantees:
246+
grantees.append(grantee)
247+
248+
return grants_dict
249+
130250
def _create_catalog(self, catalog_name: exp.Identifier) -> None:
131251
props = exp.Properties(
132252
expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))]

tests/core/engine_adapter/integration/__init__.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import sys
66
import typing as t
77
import time
8+
from contextlib import contextmanager
89

910
import pandas as pd # noqa: TID253
1011
import pytest
1112
from sqlglot import exp, parse_one
13+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1214

1315
from sqlmesh import Config, Context, EngineAdapter
1416
from sqlmesh.core.config import load_config_from_paths
@@ -744,6 +746,55 @@ def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]:
744746
self._context.upsert_model(model)
745747
return self._context, model
746748

749+
def _get_create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str:
750+
password = password or random_id()
751+
if self.dialect == "postgres":
752+
return f"CREATE USER \"{username}\" WITH PASSWORD '{password}'"
753+
if self.dialect == "snowflake":
754+
return f"CREATE ROLE {username}"
755+
raise ValueError(f"User creation not supported for dialect: {self.dialect}")
756+
757+
def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> None:
758+
create_user_sql = self._get_create_user_or_role(username, password)
759+
self.engine_adapter.execute(create_user_sql)
760+
761+
@contextmanager
762+
def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]]:
763+
created_users = []
764+
roles = {}
765+
766+
try:
767+
for role_name in role_names:
768+
user_name = normalize_identifiers(
769+
self.add_test_suffix(f"test_{role_name}"), dialect=self.dialect
770+
).sql(dialect=self.dialect)
771+
password = random_id()
772+
self._create_user_or_role(user_name, password)
773+
created_users.append(user_name)
774+
roles[role_name] = user_name
775+
776+
yield roles
777+
778+
finally:
779+
for user_name in created_users:
780+
self._cleanup_user_or_role(user_name)
781+
782+
def _cleanup_user_or_role(self, user_name: str) -> None:
783+
"""Helper function to clean up a PostgreSQL user and all their dependencies."""
784+
try:
785+
if self.dialect == "postgres":
786+
self.engine_adapter.execute(f"""
787+
SELECT pg_terminate_backend(pid)
788+
FROM pg_stat_activity
789+
WHERE usename = '{user_name}' AND pid <> pg_backend_pid()
790+
""")
791+
self.engine_adapter.execute(f'DROP OWNED BY "{user_name}"')
792+
self.engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"')
793+
elif self.dialect == "snowflake":
794+
self.engine_adapter.execute(f"DROP ROLE IF EXISTS {user_name}")
795+
except Exception:
796+
pass
797+
747798

748799
def wait_until(fn: t.Callable[..., bool], attempts=3, wait=5) -> None:
749800
current_attempt = 0

0 commit comments

Comments
 (0)