|
7 | 7 |
|
8 | 8 | from sqlglot import exp, parse_one |
9 | 9 | from sqlglot.helper import seq_get |
| 10 | +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers |
10 | 11 |
|
11 | 12 | from sqlmesh.core.engine_adapter.base import EngineAdapter |
| 13 | +from sqlmesh.core.engine_adapter.shared import DataObjectType |
12 | 14 | from sqlmesh.core.node import IntervalUnit |
13 | 15 | from sqlmesh.core.dialect import schema_ |
14 | 16 | from sqlmesh.core.schema_diff import TableAlterOperation |
15 | 17 | from sqlmesh.utils.errors import SQLMeshError |
16 | 18 |
|
17 | 19 | if t.TYPE_CHECKING: |
18 | 20 | from sqlmesh.core._typing import TableName |
19 | | - from sqlmesh.core.engine_adapter._typing import DF |
| 21 | + from sqlmesh.core.engine_adapter._typing import ( |
| 22 | + DCL, |
| 23 | + DF, |
| 24 | + GrantsConfig, |
| 25 | + QueryOrDF, |
| 26 | + ) |
20 | 27 | from sqlmesh.core.engine_adapter.base import QueryOrDF |
21 | 28 |
|
22 | 29 | logger = logging.getLogger(__name__) |
@@ -548,3 +555,137 @@ def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp. |
548 | 555 |
|
549 | 556 | def _normalize_boolean_value(self, expr: exp.Expression) -> exp.Expression: |
550 | 557 | return exp.cast(expr, "INT") |
| 558 | + |
| 559 | + |
| 560 | +class GrantsFromInfoSchemaMixin(EngineAdapter): |
| 561 | + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("current_user") |
| 562 | + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = False |
| 563 | + USE_CATALOG_IN_GRANTS = False |
| 564 | + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "table_privileges" |
| 565 | + |
| 566 | + @staticmethod |
| 567 | + @abc.abstractmethod |
| 568 | + def _grant_object_kind(table_type: DataObjectType) -> t.Optional[str]: |
| 569 | + pass |
| 570 | + |
| 571 | + @abc.abstractmethod |
| 572 | + def _get_current_schema(self) -> str: |
| 573 | + pass |
| 574 | + |
| 575 | + def _dcl_grants_config_expr( |
| 576 | + self, |
| 577 | + dcl_cmd: t.Type[DCL], |
| 578 | + table: exp.Table, |
| 579 | + grant_config: GrantsConfig, |
| 580 | + table_type: DataObjectType = DataObjectType.TABLE, |
| 581 | + ) -> t.List[exp.Expression]: |
| 582 | + expressions: t.List[exp.Expression] = [] |
| 583 | + if not grant_config: |
| 584 | + return expressions |
| 585 | + |
| 586 | + object_kind = self._grant_object_kind(table_type) |
| 587 | + for privilege, principals in grant_config.items(): |
| 588 | + args: t.Dict[str, t.Any] = { |
| 589 | + "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], |
| 590 | + "securable": table.copy(), |
| 591 | + } |
| 592 | + if object_kind: |
| 593 | + args["kind"] = exp.Var(this=object_kind) |
| 594 | + if self.SUPPORTS_MULTIPLE_GRANT_PRINCIPALS: |
| 595 | + args["principals"] = [ |
| 596 | + normalize_identifiers( |
| 597 | + parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect), |
| 598 | + dialect=self.dialect, |
| 599 | + ) |
| 600 | + for principal in principals |
| 601 | + ] |
| 602 | + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] |
| 603 | + else: |
| 604 | + for principal in principals: |
| 605 | + args["principals"] = [ |
| 606 | + normalize_identifiers( |
| 607 | + parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect), |
| 608 | + dialect=self.dialect, |
| 609 | + ) |
| 610 | + ] |
| 611 | + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] |
| 612 | + |
| 613 | + return expressions |
| 614 | + |
| 615 | + def _apply_grants_config_expr( |
| 616 | + self, |
| 617 | + table: exp.Table, |
| 618 | + grant_config: GrantsConfig, |
| 619 | + table_type: DataObjectType = DataObjectType.TABLE, |
| 620 | + ) -> t.List[exp.Expression]: |
| 621 | + return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type) |
| 622 | + |
| 623 | + def _revoke_grants_config_expr( |
| 624 | + self, |
| 625 | + table: exp.Table, |
| 626 | + grant_config: GrantsConfig, |
| 627 | + table_type: DataObjectType = DataObjectType.TABLE, |
| 628 | + ) -> t.List[exp.Expression]: |
| 629 | + return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type) |
| 630 | + |
| 631 | + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: |
| 632 | + schema_identifier = table.args.get("db") or normalize_identifiers( |
| 633 | + exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect |
| 634 | + ) |
| 635 | + schema_name = schema_identifier.this |
| 636 | + table_name = table.args.get("this").this # type: ignore |
| 637 | + |
| 638 | + grant_conditions = [ |
| 639 | + exp.column("table_schema").eq(exp.Literal.string(schema_name)), |
| 640 | + exp.column("table_name").eq(exp.Literal.string(table_name)), |
| 641 | + exp.column("grantor").eq(self.CURRENT_USER_OR_ROLE_EXPRESSION), |
| 642 | + exp.column("grantee").neq(self.CURRENT_USER_OR_ROLE_EXPRESSION), |
| 643 | + ] |
| 644 | + |
| 645 | + info_schema_table = normalize_identifiers( |
| 646 | + exp.table_(self.GRANT_INFORMATION_SCHEMA_TABLE_NAME, db="information_schema"), |
| 647 | + dialect=self.dialect, |
| 648 | + ) |
| 649 | + if self.USE_CATALOG_IN_GRANTS: |
| 650 | + catalog_identifier = table.args.get("catalog") |
| 651 | + if not catalog_identifier: |
| 652 | + catalog_name = self.get_current_catalog() |
| 653 | + if not catalog_name: |
| 654 | + raise SQLMeshError( |
| 655 | + "Current catalog could not be determined for fetching grants. This is unexpected." |
| 656 | + ) |
| 657 | + catalog_identifier = normalize_identifiers( |
| 658 | + exp.to_identifier(catalog_name, quoted=True), dialect=self.dialect |
| 659 | + ) |
| 660 | + catalog_name = catalog_identifier.this |
| 661 | + info_schema_table.set("catalog", catalog_identifier.copy()) |
| 662 | + grant_conditions.insert( |
| 663 | + 0, exp.column("table_catalog").eq(exp.Literal.string(catalog_name)) |
| 664 | + ) |
| 665 | + |
| 666 | + return ( |
| 667 | + exp.select("privilege_type", "grantee") |
| 668 | + .from_(info_schema_table) |
| 669 | + .where(exp.and_(*grant_conditions)) |
| 670 | + ) |
| 671 | + |
| 672 | + def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: |
| 673 | + grant_expr = self._get_grant_expression(table) |
| 674 | + |
| 675 | + results = self.fetchall(grant_expr) |
| 676 | + |
| 677 | + grants_dict: GrantsConfig = {} |
| 678 | + for privilege_raw, grantee_raw in results: |
| 679 | + if privilege_raw is None or grantee_raw is None: |
| 680 | + continue |
| 681 | + |
| 682 | + privilege = str(privilege_raw) |
| 683 | + grantee = str(grantee_raw) |
| 684 | + if not privilege or not grantee: |
| 685 | + continue |
| 686 | + |
| 687 | + grantees = grants_dict.setdefault(privilege, []) |
| 688 | + if grantee not in grantees: |
| 689 | + grantees.append(grantee) |
| 690 | + |
| 691 | + return grants_dict |
0 commit comments