diff --git a/sqlmesh/core/selector.py b/sqlmesh/core/selector.py index 1484d06cee..3865327acd 100644 --- a/sqlmesh/core/selector.py +++ b/sqlmesh/core/selector.py @@ -16,6 +16,7 @@ from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.environment import Environment from sqlmesh.core.model import update_model_schemas +from sqlmesh.core.audit import StandaloneAudit from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.dag import DAG from sqlmesh.utils.git import GitClient @@ -25,6 +26,7 @@ if t.TYPE_CHECKING: from typing_extensions import Literal as Lit # noqa from sqlmesh.core.model import Model + from sqlmesh.core.node import Node from sqlmesh.core.state_sync import StateReader @@ -167,7 +169,7 @@ def get_model(fqn: str) -> t.Optional[Model]: return models def expand_model_selections( - self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Model]] = None + self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Node]] = None ) -> t.Set[str]: """Expands a set of model selections into a set of model fqns that can be looked up in the Context. 
@@ -180,7 +182,7 @@ def expand_model_selections( node = parse(" | ".join(f"({s})" for s in model_selections)) - all_models = models or self._models + all_models: t.Dict[str, Node] = models or dict(self._models) models_by_tags: t.Dict[str, t.Set[str]] = {} for fqn, model in all_models.items(): @@ -226,6 +228,13 @@ def evaluate(node: exp.Expression) -> t.Set[str]: if fnmatch.fnmatchcase(tag, pattern) } return models_by_tags.get(pattern, set()) + if isinstance(node, ResourceType): + resource_type = node.name.lower() + return { + fqn + for fqn, model in all_models.items() + if self._matches_resource_type(resource_type, model) + } if isinstance(node, Direction): selected = set() @@ -243,36 +252,49 @@ def evaluate(node: exp.Expression) -> t.Set[str]: return evaluate(node) @abc.abstractmethod - def _model_name(self, model: Model) -> str: + def _model_name(self, model: Node) -> str: """Given a model, return the name that a selector pattern contining wildcards should be fnmatch'd on""" pass @abc.abstractmethod - def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]: + def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Node]) -> t.Set[str]: """Given a pattern, return the keys of the matching models from :all_models""" pass + @abc.abstractmethod + def _matches_resource_type(self, resource_type: str, model: Node) -> bool: + """Indicate whether or not the supplied model matches the supplied resource type""" + pass + class NativeSelector(Selector): """Implementation of selectors that matches objects based on SQLMesh native names""" - def _model_name(self, model: Model) -> str: + def _model_name(self, model: Node) -> str: return model.name - def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]: + def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Node]) -> t.Set[str]: fqn = normalize_model_name(pattern, self._default_catalog, self._dialect) return {fqn} if fqn in 
all_models else set() + def _matches_resource_type(self, resource_type: str, model: Node) -> bool: + if resource_type == "model": + return model.is_model + if resource_type == "audit": + return isinstance(model, StandaloneAudit) + + raise SQLMeshError(f"Unsupported resource type: {resource_type}") + class DbtSelector(Selector): """Implementation of selectors that matches objects based on the DBT names instead of the SQLMesh native names""" - def _model_name(self, model: Model) -> str: + def _model_name(self, model: Node) -> str: if dbt_fqn := model.dbt_fqn: return dbt_fqn raise SQLMeshError("dbt node information must be populated to use dbt selectors") - def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]: + def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Node]) -> t.Set[str]: # a pattern like "staging.customers" should match a model called "jaffle_shop.staging.customers" # but not a model called "jaffle_shop.customers.staging" # also a pattern like "aging" should not match "staging" so we need to consider components; not substrings @@ -306,6 +328,40 @@ def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) - matches.add(fqn) return matches + def _matches_resource_type(self, resource_type: str, model: Node) -> bool: + """ + ref: https://docs.getdbt.com/reference/node-selection/methods#resource_type + + # supported by SQLMesh + "model" + "seed" + "source" # external model + "test" # standalone audit + + # not supported by SQLMesh yet; selecting any of these raises an "Unsupported resource type" error + "analysis" + "exposure" + "metric" + "saved_query" + "semantic_model" + "snapshot" + "unit_test" + """ + if resource_type not in ("model", "seed", "source", "test"): + raise SQLMeshError(f"Unsupported resource type: {resource_type}") + + if isinstance(model, StandaloneAudit): + return resource_type == "test" + + if resource_type == "model": + return model.is_model and not 
model.kind.is_external and not model.kind.is_seed + if resource_type == "source": + return model.kind.is_external + if resource_type == "seed": + return model.kind.is_seed + + return False + class SelectorDialect(Dialect): IDENTIFIERS_CAN_START_WITH_DIGIT = True @@ -336,6 +392,10 @@ class Tag(exp.Expression): pass +class ResourceType(exp.Expression): + pass + + class Direction(exp.Expression): pass @@ -388,7 +448,8 @@ def _parse_var() -> exp.Expression: upstream = _match(TokenType.PLUS) downstream = None tag = _parse_kind("tag") - git = False if tag else _parse_kind("git") + resource_type = False if tag else _parse_kind("resource_type") + git = False if tag or resource_type else _parse_kind("git") lstar = "*" if _match(TokenType.STAR) else "" directions = {} @@ -414,6 +475,8 @@ def _parse_var() -> exp.Expression: if tag: this = Tag(this=this) + if resource_type: + this = ResourceType(this=this) if git: this = Git(this=this) if directions: diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py index fa75d303a1..83230de3fd 100644 --- a/sqlmesh_dbt/cli.py +++ b/sqlmesh_dbt/cli.py @@ -33,15 +33,39 @@ def _cleanup() -> None: select_option = click.option( "-s", - "-m", "--select", + multiple=True, + help="Specify the nodes to include.", +) +model_option = click.option( + "-m", "--models", "--model", multiple=True, - help="Specify the nodes to include.", + help="Specify the model nodes to include; other nodes are excluded.", ) exclude_option = click.option("--exclude", multiple=True, help="Specify the nodes to exclude.") +# TODO: expand this out into --resource-type/--resource-types and --exclude-resource-type/--exclude-resource-types +resource_types = [ + "metric", + "semantic_model", + "saved_query", + "source", + "analysis", + "model", + "test", + "unit_test", + "exposure", + "snapshot", + "seed", + "default", + "all", +] +resource_type_option = click.option( + "--resource-type", type=click.Choice(resource_types, case_sensitive=False) +) + @click.group(cls=ErrorHandlingGroup, 
invoke_without_command=True) @click.option("--profile", help="Which existing profile to load. Overrides output.profile") @@ -86,7 +110,9 @@ def dbt( @dbt.command() @select_option +@model_option @exclude_option +@resource_type_option @click.option( "-f", "--full-refresh", @@ -116,7 +142,9 @@ def run( @dbt.command(name="list") @select_option +@model_option @exclude_option +@resource_type_option @vars_option @click.pass_context def list_(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None: diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py index a157705ffd..6e8b452b28 100644 --- a/sqlmesh_dbt/operations.py +++ b/sqlmesh_dbt/operations.py @@ -26,12 +26,16 @@ def list_( self, select: t.Optional[t.List[str]] = None, exclude: t.Optional[t.List[str]] = None, + models: t.Optional[t.List[str]] = None, + resource_type: t.Optional[str] = None, ) -> None: # dbt list prints: # - models # - "data tests" (audits) for those models # it also applies selectors which is useful for testing selectors - selected_models = list(self._selected_models(select, exclude).values()) + selected_models = list( + self._selected_models(select, exclude, models, resource_type).values() + ) self.console.list_models( selected_models, {k: v.node for k, v in self.context.snapshots.items()} ) @@ -41,13 +45,19 @@ def run( environment: t.Optional[str] = None, select: t.Optional[t.List[str]] = None, exclude: t.Optional[t.List[str]] = None, + models: t.Optional[t.List[str]] = None, + resource_type: t.Optional[str] = None, full_refresh: bool = False, empty: bool = False, ) -> Plan: + consolidated_select, consolidated_exclude = selectors.consolidate( + select or [], exclude or [], models or [], resource_type + ) + plan_builder = self._plan_builder( environment=environment, - select=select, - exclude=exclude, + select=consolidated_select, + exclude=consolidated_exclude, full_refresh=full_refresh, empty=empty, ) @@ -86,9 +96,15 @@ def _plan_builder( ) def 
_selected_models( - self, select: t.Optional[t.List[str]] = None, exclude: t.Optional[t.List[str]] = None + self, + select: t.Optional[t.List[str]] = None, + exclude: t.Optional[t.List[str]] = None, + models: t.Optional[t.List[str]] = None, + resource_type: t.Optional[str] = None, ) -> t.Dict[str, Model]: - if sqlmesh_selector := selectors.to_sqlmesh(select or [], exclude or []): + if sqlmesh_selector := selectors.to_sqlmesh( + *selectors.consolidate(select or [], exclude or [], models or [], resource_type) + ): if self.debug: self.console.print(f"dbt --select: {select}") self.console.print(f"dbt --exclude: {exclude}") diff --git a/sqlmesh_dbt/selectors.py b/sqlmesh_dbt/selectors.py index 120d5dcb36..5821586ad3 100644 --- a/sqlmesh_dbt/selectors.py +++ b/sqlmesh_dbt/selectors.py @@ -4,7 +4,45 @@ logger = logging.getLogger(__name__) -def to_sqlmesh(dbt_select: t.Collection[str], dbt_exclude: t.Collection[str]) -> t.Optional[str]: +def consolidate( + select: t.List[str], + exclude: t.List[str], + models: t.List[str], + resource_type: t.Optional[str], +) -> t.Tuple[t.List[str], t.List[str]]: + """ + Given a bunch of dbt CLI arguments that may or may not be defined: + --select, --exclude, --models, --resource-type + + Combine them into a single set of --select/--exclude node selectors, throwing an error if mutually exclusive combinations are provided + Note that the returned value is still in dbt format, pass it to to_sqlmesh() to create a selector for the sqlmesh selector engine + """ + if models and select: + raise ValueError('"models" and "select" are mutually exclusive arguments') + + if models and resource_type: + raise ValueError('"models" and "resource_type" are mutually exclusive arguments') + + if models: + # --models implies resource_type:model + resource_type = "model" + + if resource_type: + resource_type_selector = f"resource_type:{resource_type}" + all_selectors = [*select, *models] + select = ( + [ + f"resource_type:{resource_type},{original_selector}" + 
for original_selector in all_selectors + ] + if all_selectors + else [resource_type_selector] + ) + + return select, exclude + + +def to_sqlmesh(dbt_select: t.List[str], dbt_exclude: t.List[str]) -> t.Optional[str]: """ Given selectors defined in the format of the dbt cli --select and --exclude arguments, convert them into a selector expression that the SQLMesh selector engine can understand. diff --git a/tests/core/test_selector_dbt.py b/tests/core/test_selector_dbt.py new file mode 100644 index 0000000000..112c5740ac --- /dev/null +++ b/tests/core/test_selector_dbt.py @@ -0,0 +1,63 @@ +import typing as t +import pytest +from pytest_mock import MockerFixture +from sqlglot import exp +from sqlmesh.core.model.kind import SeedKind, ExternalKind, FullKind +from sqlmesh.core.model.seed import Seed +from sqlmesh.core.model.definition import SqlModel, SeedModel, ExternalModel +from sqlmesh.core.audit.definition import StandaloneAudit +from sqlmesh.core.snapshot.definition import Node +from sqlmesh.core.selector import DbtSelector +from sqlmesh.core.selector import parse, ResourceType +from sqlmesh.utils.errors import SQLMeshError +import sqlmesh.core.dialect as d +from sqlmesh.utils import UniqueKeyDict + + +def test_parse_resource_type(): + assert parse("resource_type:foo") == ResourceType(this=exp.Var(this="foo")) + + +@pytest.mark.parametrize( + "resource_type,expected", + [ + ("model", {'"test"."normal_model"'}), + ("seed", {'"test"."seed_model"'}), + ("test", {'"test"."standalone_audit"'}), + ("source", {'"external"."model"'}), + ], +) +def test_expand_model_selections_resource_type( + mocker: MockerFixture, resource_type: str, expected: t.Set[str] +): + models: t.Dict[str, Node] = { + '"test"."normal_model"': SqlModel( + name="test.normal_model", + kind=FullKind(), + query=d.parse_one("SELECT 'normal_model' AS what"), + ), + '"test"."seed_model"': SeedModel( + name="test.seed_model", kind=SeedKind(path="/tmp/foo"), seed=Seed(content="id,name") + ), + 
'"test"."standalone_audit"': StandaloneAudit( + name="test.standalone_audit", query=d.parse_one("SELECT 'standalone_audit' AS what") + ), + '"external"."model"': ExternalModel(name="external.model", kind=ExternalKind()), + } + + selector = DbtSelector(state_reader=mocker.Mock(), models=UniqueKeyDict("models")) + + assert selector.expand_model_selections([f"resource_type:{resource_type}"], models) == expected + + +def test_unsupported_resource_type(mocker: MockerFixture): + selector = DbtSelector(state_reader=mocker.Mock(), models=UniqueKeyDict("models")) + + models: t.Dict[str, Node] = { + '"test"."normal_model"': SqlModel( + name="test.normal_model", query=d.parse_one("SELECT 'normal_model' AS what") + ), + } + + with pytest.raises(SQLMeshError, match="Unsupported"): + selector.expand_model_selections(["resource_type:analysis"], models) diff --git a/tests/core/test_selector.py b/tests/core/test_selector_native.py similarity index 100% rename from tests/core/test_selector.py rename to tests/core/test_selector_native.py diff --git a/tests/dbt/cli/test_list.py b/tests/dbt/cli/test_list.py index 712d80b2fe..3e6a55125c 100644 --- a/tests/dbt/cli/test_list.py +++ b/tests/dbt/cli/test_list.py @@ -79,3 +79,26 @@ def test_list_with_vars(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Re │ └── depends_on: jaffle_shop.customers""" in result.output ) + + +def test_list_models_mutually_exclusive( + jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result] +): + result = invoke_cli(["list", "--select", "foo", "--models", "bar"]) + assert result.exit_code != 0 + assert '"models" and "select" are mutually exclusive arguments' in result.output + + result = invoke_cli(["list", "--resource-type", "test", "--models", "bar"]) + assert result.exit_code != 0 + assert '"models" and "resource_type" are mutually exclusive arguments' in result.output + + +def test_list_models(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["list", "--models", 
"jaffle_shop"]) + assert result.exit_code == 0 + assert not result.exception + + assert "─ jaffle_shop.customers" in result.output + assert ( + "─ jaffle_shop.raw_customers" not in result.output + ) # should be excluded because dbt --models excludes seeds diff --git a/tests/dbt/cli/test_selectors.py b/tests/dbt/cli/test_selectors.py index 99907bda84..3d50fe6ed2 100644 --- a/tests/dbt/cli/test_selectors.py +++ b/tests/dbt/cli/test_selectors.py @@ -269,3 +269,61 @@ def test_selection_and_exclusion_by_dbt_names( assert sqlmesh_selector assert selector.expand_model_selections([sqlmesh_selector]) == expected + + +@pytest.mark.parametrize( + "input_args,expected", + [ + ( + dict(select=["jaffle_shop"], models=["jaffle_shop"]), + '"models" and "select" are mutually exclusive', + ), + ( + dict(models=["jaffle_shop"], resource_type="test"), + '"models" and "resource_type" are mutually exclusive', + ), + ( + dict(select=["jaffle_shop"], resource_type="test"), + (["resource_type:test,jaffle_shop"], []), + ), + (dict(resource_type="model"), (["resource_type:model"], [])), + (dict(models=["stg_customers"]), (["resource_type:model,stg_customers"], [])), + ( + dict(models=["stg_customers"], exclude=["orders"]), + (["resource_type:model,stg_customers"], ["orders"]), + ), + ], +) +def test_consolidate(input_args: t.Dict[str, t.Any], expected: t.Union[t.Tuple[str, str], str]): + all_input_args: t.Dict[str, t.Any] = dict(select=[], exclude=[], models=[], resource_type=None) + + all_input_args.update(input_args) + + def _do_assert(): + assert selectors.consolidate(**all_input_args) == expected + + if isinstance(expected, str): + with pytest.raises(ValueError, match=expected): + _do_assert() + else: + _do_assert() + + +def test_models_by_dbt_names(jaffle_shop_duckdb_context: Context): + ctx = jaffle_shop_duckdb_context + + selector = ctx._new_selector() + assert isinstance(selector, DbtSelector) + + selector_expr = selectors.to_sqlmesh( + *selectors.consolidate(select=[], exclude=[], 
models=["jaffle_shop"], resource_type=None) + ) + assert selector_expr + + assert selector.expand_model_selections([selector_expr]) == { + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + }