From 58ba2fc0a970e9728152107efe50128aae408241 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Sun, 21 Sep 2025 23:02:58 +0000 Subject: [PATCH 1/2] Feat(sqlmesh_dbt): Select based on dbt name, not sqlmesh name --- sqlmesh/core/context.py | 5 + sqlmesh/core/selector.py | 65 ++++++++++- sqlmesh_dbt/operations.py | 4 +- tests/dbt/cli/test_list.py | 8 +- tests/dbt/cli/test_operations.py | 10 +- tests/dbt/cli/test_run.py | 2 +- tests/dbt/cli/test_selectors.py | 180 +++++++++++++++++++++++++++++++ tests/dbt/conftest.py | 2 +- 8 files changed, 259 insertions(+), 17 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 437fbd6edd..f82e2aa6a2 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -348,6 +348,8 @@ class GenericContext(BaseContext, t.Generic[C]): load: Whether or not to automatically load all models and macros (default True). console: The rich instance used for printing out CLI command results. users: A list of users to make known to SQLMesh. + dbt_mode: A flag to indicate we are running in 'dbt mode' which means that things like + model selections should use the dbt names and not the native SQLMesh names """ CONFIG_TYPE: t.Type[C] @@ -368,6 +370,7 @@ def __init__( load: bool = True, users: t.Optional[t.List[User]] = None, config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None, + dbt_mode: bool = False, ): self.configs = ( config @@ -390,6 +393,7 @@ def __init__( self._engine_adapter: t.Optional[EngineAdapter] = None self._linters: t.Dict[str, Linter] = {} self._loaded: bool = False + self._dbt_mode = dbt_mode self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items()))) @@ -2901,6 +2905,7 @@ def _new_selector( default_catalog=self.default_catalog, dialect=self.default_dialect, cache_dir=self.cache_dir, + dbt_mode=self._dbt_mode, ) def _register_notification_targets(self) -> None: diff --git a/sqlmesh/core/selector.py b/sqlmesh/core/selector.py index c44065bdc0..34a69046c3 100644 --- a/sqlmesh/core/selector.py +++ b/sqlmesh/core/selector.py @@ -3,6 +3,7 @@ import fnmatch import typing as t from pathlib import Path +from itertools import zip_longest from sqlglot import exp from sqlglot.errors import ParseError @@ -36,6 +37,7 @@ def __init__( default_catalog: t.Optional[str] = None, dialect: t.Optional[str] = None, cache_dir: t.Optional[Path] = None, + dbt_mode: bool = False, ): self._state_reader = state_reader self._models = models @@ -44,6 +46,7 @@ def __init__( self._default_catalog = default_catalog self._dialect = dialect self._git_client = GitClient(context_path) + self._dbt_mode = dbt_mode if dag is None: self._dag: DAG[str] = DAG() @@ -167,13 +170,13 @@ def get_model(fqn: str) -> t.Optional[Model]: def expand_model_selections( self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Model]] = None ) -> t.Set[str]: - """Expands a set of model selections into a set of model names. + """Expands a set of model selections into a set of model fqns that can be looked up in the Context. Args: model_selections: A set of model selections. Returns: - A set of model names. + A set of model fqns. """ node = parse(" | ".join(f"({s})" for s in model_selections)) @@ -194,10 +197,9 @@ def evaluate(node: exp.Expression) -> t.Set[str]: return { fqn for fqn, model in all_models.items() - if fnmatch.fnmatchcase(model.name, node.this) + if fnmatch.fnmatchcase(self._model_name(model), node.this) } - fqn = normalize_model_name(pattern, self._default_catalog, self._dialect) - return {fqn} if fqn in all_models else set() + return self._pattern_to_model_fqns(pattern, all_models) if isinstance(node, exp.And): return evaluate(node.left) & evaluate(node.right) if isinstance(node, exp.Or): @@ -241,6 +243,59 @@ def evaluate(node: exp.Expression) -> t.Set[str]: return evaluate(node) + def _model_fqn(self, model: Model) -> str: + if self._dbt_mode: + dbt_fqn = model.dbt_fqn + if dbt_fqn is None: + raise SQLMeshError("Expecting dbt node information to be populated; it wasnt") + return dbt_fqn + return model.fqn + + def _model_name(self, model: Model) -> str: + if self._dbt_mode: + # dbt always matches on the fqn, not the name + return self._model_fqn(model) + return model.name + + def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]: + # note: all_models should be keyed by sqlmesh fqn, not dbt fqn + if not self._dbt_mode: + fqn = normalize_model_name(pattern, self._default_catalog, self._dialect) + return {fqn} if fqn in all_models else set() + + # a pattern like "staging.customers" should match a model called "jaffle_shop.staging.customers" + # but not a model called "jaffle_shop.customers.staging" + # also a pattern like "aging" should not match "staging" so we need to consider components; not substrings + pattern_components = pattern.split(".") + first_pattern_component = pattern_components[0] + matches = set() + for fqn, model in all_models.items(): + if not model.dbt_fqn: + continue + + dbt_fqn_components = model.dbt_fqn.split(".") + try: + starting_idx = dbt_fqn_components.index(first_pattern_component) + except ValueError: + continue + for pattern_component, fqn_component in zip_longest( + pattern_components, dbt_fqn_components[starting_idx:] + ): + if pattern_component and not fqn_component: + # the pattern still goes but we have run out of fqn components to match; no match + break + if fqn_component and not pattern_component: + # all elements of the pattern have matched elements of the fqn; match + matches.add(fqn) + break + if pattern_component != fqn_component: + # the pattern explicitly doesnt match a component; no match + break + else: + # called if no explicit break, indicating all components of the pattern matched all components of the fqn + matches.add(fqn) + return matches + class SelectorDialect(Dialect): IDENTIFIERS_CAN_START_WITH_DIGIT = True diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py index e15a2cb93e..c79ab7e1bf 100644 --- a/sqlmesh_dbt/operations.py +++ b/sqlmesh_dbt/operations.py @@ -185,7 +185,7 @@ def _plan_builder_options( options.update( dict( # Add every selected model as a restatement to force them to get repopulated from scratch - restate_models=list(self.context.models) + restate_models=[m.dbt_fqn for m in self.context.models.values() if m.dbt_fqn] if not select_models else select_models, # by default in SQLMesh, restatements only operate on what has been committed to state. @@ -250,6 +250,8 @@ def create( paths=[project_dir], config_loader_kwargs=dict(profile=profile, target=target, variables=vars), load=True, + # dbt mode enables selectors to use dbt model fqn's rather than SQLMesh model names + dbt_mode=True, ) dbt_loader = sqlmesh_context._loaders[0] diff --git a/tests/dbt/cli/test_list.py b/tests/dbt/cli/test_list.py index 4d294decc1..712d80b2fe 100644 --- a/tests/dbt/cli/test_list.py +++ b/tests/dbt/cli/test_list.py @@ -19,7 +19,7 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): - result = invoke_cli(["list", "--select", "main.raw_customers+"]) + result = invoke_cli(["list", "--select", "raw_customers+"]) assert result.exit_code == 0 assert not result.exception @@ -34,7 +34,7 @@ def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Resul def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): # single exclude - result = invoke_cli(["list", "--select", "main.raw_customers+", "--exclude", "main.orders"]) + result = invoke_cli(["list", "--select", "raw_customers+", "--exclude", "orders"]) assert result.exit_code == 0 assert not result.exception @@ -49,8 +49,8 @@ def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[.. # multiple exclude for args in ( - ["--select", "main.stg_orders+", "--exclude", "main.customers", "--exclude", "main.orders"], - ["--select", "main.stg_orders+", "--exclude", "main.customers main.orders"], + ["--select", "stg_orders+", "--exclude", "customers", "--exclude", "orders"], + ["--select", "stg_orders+", "--exclude", "customers orders"], ): result = invoke_cli(["list", *args]) assert result.exit_code == 0 diff --git a/tests/dbt/cli/test_operations.py b/tests/dbt/cli/test_operations.py index 769887efe4..b23c87882a 100644 --- a/tests/dbt/cli/test_operations.py +++ b/tests/dbt/cli/test_operations.py @@ -138,7 +138,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path): assert plan.selected_models_to_backfill is None assert {s.name for s in plan.snapshots} == {k for k in operations.context.snapshots} - plan = operations.run(select=["main.stg_orders+"]) + plan = operations.run(select=["stg_orders+"]) assert plan.environment.name == "prod" assert console.no_prompts is True assert console.no_diff is True @@ -155,7 +155,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path): plan.selected_models_to_backfill | {standalone_audit_name} ) - plan = operations.run(select=["main.stg_orders+"], exclude=["main.customers"]) + plan = operations.run(select=["stg_orders+"], exclude=["customers"]) assert plan.environment.name == "prod" assert console.no_prompts is True assert console.no_diff is True @@ -171,7 +171,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path): plan.selected_models_to_backfill | {standalone_audit_name} ) - plan = operations.run(exclude=["main.customers"]) + plan = operations.run(exclude=["customers"]) assert plan.environment.name == "prod" assert console.no_prompts is True assert console.no_diff is True @@ -238,7 +238,7 @@ def test_run_option_mapping_dev(jaffle_shop_duckdb: Path): assert plan.skip_backfill is True assert plan.selected_models_to_backfill == {'"jaffle_shop"."main"."new_model"'} - plan = operations.run(environment="dev", select=["main.stg_orders+"]) + plan = operations.run(environment="dev", select=["stg_orders+"]) assert plan.environment.name == "dev" assert console.no_prompts is True assert console.no_diff is True @@ -325,7 +325,7 @@ def test_run_option_full_refresh_with_selector(jaffle_shop_duckdb: Path): console = PlanCapturingConsole() operations.context.console = console - plan = operations.run(select=["main.stg_customers"], full_refresh=True) + plan = operations.run(select=["stg_customers"], full_refresh=True) assert len(plan.restatements) == 1 assert list(plan.restatements)[0].name == '"jaffle_shop"."main"."stg_customers"' diff --git a/tests/dbt/cli/test_run.py b/tests/dbt/cli/test_run.py index 788a7b04a8..7aeb8dd4d7 100644 --- a/tests/dbt/cli/test_run.py +++ b/tests/dbt/cli/test_run.py @@ -27,7 +27,7 @@ def test_run_with_selectors(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[... assert result.exit_code == 0 assert "main.orders" in result.output - result = invoke_cli(["run", "--select", "main.raw_customers+", "--exclude", "main.orders"]) + result = invoke_cli(["run", "--select", "raw_customers+", "--exclude", "orders"]) assert result.exit_code == 0 assert not result.exception diff --git a/tests/dbt/cli/test_selectors.py b/tests/dbt/cli/test_selectors.py index 6041a50d0a..319d4372a9 100644 --- a/tests/dbt/cli/test_selectors.py +++ b/tests/dbt/cli/test_selectors.py @@ -1,6 +1,8 @@ import typing as t import pytest from sqlmesh_dbt import selectors +from sqlmesh.core.context import Context +from pathlib import Path @pytest.mark.parametrize( @@ -77,3 +79,181 @@ def test_split_unions_and_intersections( expression: str, expected: t.Tuple[t.List[str], t.List[str]] ): assert selectors._split_unions_and_intersections(expression) == expected + + +@pytest.mark.parametrize( + "dbt_select,expected", + [ + (["aging"], set()), + ( + ["staging"], + { + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + }, + ), + (["staging.stg_customers"], {'"jaffle_shop"."main"."stg_customers"'}), + (["stg_customers.staging"], set()), + ( + ["+customers"], + { + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + '"jaffle_shop"."main"."raw_customers"', + '"jaffle_shop"."main"."raw_orders"', + '"jaffle_shop"."main"."raw_payments"', + }, + ), + (["customers+"], {'"jaffle_shop"."main"."customers"'}), + ( + ["customers+", "stg_orders"], + {'"jaffle_shop"."main"."customers"', '"jaffle_shop"."main"."stg_orders"'}, + ), + (["tag:agg"], {'"jaffle_shop"."main"."agg_orders"'}), + ( + ["staging.stg_customers", "tag:agg"], + { + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."agg_orders"', + }, + ), + ( + ["+tag:agg"], + { + '"jaffle_shop"."main"."agg_orders"', + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + '"jaffle_shop"."main"."raw_orders"', + '"jaffle_shop"."main"."raw_payments"', + }, + ), + ( + ["tag:agg+"], + { + '"jaffle_shop"."main"."agg_orders"', + }, + ), + ], +) +def test_select_by_dbt_names( + jaffle_shop_duckdb: Path, + jaffle_shop_duckdb_context: Context, + dbt_select: t.List[str], + expected: t.Set[str], +): + (jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text(""" + {{ config(tags=["agg"]) }} + select order_date, count(*) as num_orders from {{ ref('orders') }} + """) + + ctx = jaffle_shop_duckdb_context + ctx.load() + assert '"jaffle_shop"."main"."agg_orders"' in ctx.models + + selector = ctx._new_selector() + assert selector._dbt_mode + + sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=[]) + assert sqlmesh_selector + + assert selector.expand_model_selections([sqlmesh_selector]) == expected + + +@pytest.mark.parametrize( + "dbt_exclude,expected", + [ + (["jaffle_shop"], set()), + ( + ["staging"], + { + '"jaffle_shop"."main"."agg_orders"', + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."raw_customers"', + '"jaffle_shop"."main"."raw_orders"', + '"jaffle_shop"."main"."raw_payments"', + }, + ), + (["+customers"], {'"jaffle_shop"."main"."orders"', '"jaffle_shop"."main"."agg_orders"'}), + ( + ["+tag:agg"], + { + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."raw_customers"', + }, + ), + ], +) +def test_exclude_by_dbt_names( + jaffle_shop_duckdb: Path, + jaffle_shop_duckdb_context: Context, + dbt_exclude: t.List[str], + expected: t.Set[str], +): + (jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text(""" + {{ config(tags=["agg"]) }} + select order_date, count(*) as num_orders from {{ ref('orders') }} + """) + + ctx = jaffle_shop_duckdb_context + ctx.load() + assert '"jaffle_shop"."main"."agg_orders"' in ctx.models + + selector = ctx._new_selector() + assert selector._dbt_mode + + sqlmesh_selector = selectors.to_sqlmesh(dbt_select=[], dbt_exclude=dbt_exclude) + assert sqlmesh_selector + + assert selector.expand_model_selections([sqlmesh_selector]) == expected + + +@pytest.mark.parametrize( + "dbt_select,dbt_exclude,expected", + [ + (["jaffle_shop"], ["jaffle_shop"], set()), + ( + ["staging"], + ["stg_customers"], + { + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + }, + ), + ( + ["staging.stg_customers", "tag:agg"], + ["tag:agg"], + { + '"jaffle_shop"."main"."stg_customers"', + }, + ), + ], +) +def test_selection_and_exclusion_by_dbt_names( + jaffle_shop_duckdb: Path, + jaffle_shop_duckdb_context: Context, + dbt_select: t.List[str], + dbt_exclude: t.List[str], + expected: t.Set[str], +): + (jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text(""" + {{ config(tags=["agg"]) }} + select order_date, count(*) as num_orders from {{ ref('orders') }} + """) + + ctx = jaffle_shop_duckdb_context + ctx.load() + assert '"jaffle_shop"."main"."agg_orders"' in ctx.models + + selector = ctx._new_selector() + assert selector._dbt_mode + + sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=dbt_exclude) + assert sqlmesh_selector + + assert selector.expand_model_selections([sqlmesh_selector]) == expected diff --git a/tests/dbt/conftest.py b/tests/dbt/conftest.py index 56d77e7496..b2ea36ec86 100644 --- a/tests/dbt/conftest.py +++ b/tests/dbt/conftest.py @@ -99,7 +99,7 @@ def jaffle_shop_duckdb(copy_to_temp_path: t.Callable[..., t.List[Path]]) -> t.It @pytest.fixture def jaffle_shop_duckdb_context(jaffle_shop_duckdb: Path) -> Context: init_project_if_required(jaffle_shop_duckdb) - return Context(paths=[jaffle_shop_duckdb]) + return Context(paths=[jaffle_shop_duckdb], dbt_mode=True) @pytest.fixture() From 1e6cf0758e06dc237d12d523724bce892261a1f2 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Tue, 23 Sep 2025 22:09:51 +0000 Subject: [PATCH 2/2] PR feedback --- sqlmesh/core/context.py | 11 +++------ sqlmesh/core/selector.py | 44 ++++++++++++++++++++------------- sqlmesh_dbt/operations.py | 5 ++-- tests/core/test_selector.py | 20 +++++++-------- tests/dbt/cli/test_selectors.py | 18 +++++++++++--- tests/dbt/conftest.py | 3 ++- 6 files changed, 61 insertions(+), 40 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index f82e2aa6a2..e3feb1e14b 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -93,7 +93,7 @@ from sqlmesh.core.reference import ReferenceGraph from sqlmesh.core.scheduler import Scheduler, CompletionStatus from sqlmesh.core.schema_loader import create_external_models_file -from sqlmesh.core.selector import Selector +from sqlmesh.core.selector import Selector, NativeSelector from sqlmesh.core.snapshot import ( DeployabilityIndex, Snapshot, @@ -348,8 +348,6 @@ class GenericContext(BaseContext, t.Generic[C]): load: Whether or not to automatically load all models and macros (default True). console: The rich instance used for printing out CLI command results. users: A list of users to make known to SQLMesh. - dbt_mode: A flag to indicate we are running in 'dbt mode' which means that things like - model selections should use the dbt names and not the native SQLMesh names """ CONFIG_TYPE: t.Type[C] @@ -370,7 +368,7 @@ def __init__( load: bool = True, users: t.Optional[t.List[User]] = None, config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None, - dbt_mode: bool = False, + selector: t.Optional[t.Type[Selector]] = None, ): self.configs = ( config @@ -393,7 +391,7 @@ def __init__( self._engine_adapter: t.Optional[EngineAdapter] = None self._linters: t.Dict[str, Linter] = {} self._loaded: bool = False - self._dbt_mode = dbt_mode + self._selector_cls = selector or NativeSelector self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items()))) @@ -2897,7 +2895,7 @@ def _new_state_sync(self) -> StateSync: def _new_selector( self, models: t.Optional[UniqueKeyDict[str, Model]] = None, dag: t.Optional[DAG[str]] = None ) -> Selector: - return Selector( + return self._selector_cls( self.state_reader, models=models or self._models, context_path=self.path, @@ -2905,7 +2903,6 @@ def _new_selector( default_catalog=self.default_catalog, dialect=self.default_dialect, cache_dir=self.cache_dir, - dbt_mode=self._dbt_mode, ) def _register_notification_targets(self) -> None: diff --git a/sqlmesh/core/selector.py b/sqlmesh/core/selector.py index 34a69046c3..1484d06cee 100644 --- a/sqlmesh/core/selector.py +++ b/sqlmesh/core/selector.py @@ -4,6 +4,7 @@ import typing as t from pathlib import Path from itertools import zip_longest +import abc from sqlglot import exp from sqlglot.errors import ParseError @@ -27,7 +28,7 @@ from sqlmesh.core.state_sync import StateReader -class Selector: +class Selector(abc.ABC): def __init__( self, state_reader: StateReader, @@ -37,7 +38,6 @@ def __init__( default_catalog: t.Optional[str] = None, dialect: t.Optional[str] = None, cache_dir: t.Optional[Path] = None, - dbt_mode: bool = False, ): self._state_reader = state_reader self._models = models @@ -46,7 +46,6 @@ def __init__( self._default_catalog = default_catalog self._dialect = dialect self._git_client = GitClient(context_path) - self._dbt_mode = dbt_mode if dag is None: self._dag: DAG[str] = DAG() @@ -243,26 +242,37 @@ def evaluate(node: exp.Expression) -> t.Set[str]: return evaluate(node) - def _model_fqn(self, model: Model) -> str: - if self._dbt_mode: - dbt_fqn = model.dbt_fqn - if dbt_fqn is None: - raise SQLMeshError("Expecting dbt node information to be populated; it wasnt") - return dbt_fqn - return model.fqn + @abc.abstractmethod + def _model_name(self, model: Model) -> str: + """Given a model, return the name that a selector pattern contining wildcards should be fnmatch'd on""" + pass + + @abc.abstractmethod + def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]: + """Given a pattern, return the keys of the matching models from :all_models""" + pass + + +class NativeSelector(Selector): + """Implementation of selectors that matches objects based on SQLMesh native names""" def _model_name(self, model: Model) -> str: - if self._dbt_mode: - # dbt always matches on the fqn, not the name - return self._model_fqn(model) return model.name def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]: - # note: all_models should be keyed by sqlmesh fqn, not dbt fqn - if not self._dbt_mode: - fqn = normalize_model_name(pattern, self._default_catalog, self._dialect) - return {fqn} if fqn in all_models else set() + fqn = normalize_model_name(pattern, self._default_catalog, self._dialect) + return {fqn} if fqn in all_models else set() + +class DbtSelector(Selector): + """Implementation of selectors that matches objects based on the DBT names instead of the SQLMesh native names""" + + def _model_name(self, model: Model) -> str: + if dbt_fqn := model.dbt_fqn: + return dbt_fqn + raise SQLMeshError("dbt node information must be populated to use dbt selectors") + + def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]: # a pattern like "staging.customers" should match a model called "jaffle_shop.staging.customers" # but not a model called "jaffle_shop.customers.staging" # also a pattern like "aging" should not match "staging" so we need to consider components; not substrings diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py index c79ab7e1bf..a157705ffd 100644 --- a/sqlmesh_dbt/operations.py +++ b/sqlmesh_dbt/operations.py @@ -231,6 +231,7 @@ def create( from sqlmesh.core.console import set_console from sqlmesh_dbt.console import DbtCliConsole from sqlmesh.utils.errors import SQLMeshError + from sqlmesh.core.selector import DbtSelector # clear any existing handlers set up by click/rich as defaults so that once SQLMesh logging config is applied, # we dont get duplicate messages logged from things like console.log_warning() @@ -250,8 +251,8 @@ def create( paths=[project_dir], config_loader_kwargs=dict(profile=profile, target=target, variables=vars), load=True, - # dbt mode enables selectors to use dbt model fqn's rather than SQLMesh model names - dbt_mode=True, + # DbtSelector selects based on dbt model fqn's rather than SQLMesh model names + selector=DbtSelector, ) dbt_loader = sqlmesh_context._loaders[0] diff --git a/tests/core/test_selector.py b/tests/core/test_selector.py index 80b9ef691e..46d666db64 100644 --- a/tests/core/test_selector.py +++ b/tests/core/test_selector.py @@ -12,7 +12,7 @@ from sqlmesh.core.environment import Environment from sqlmesh.core.model import Model, SqlModel from sqlmesh.core.model.common import ParsableSql -from sqlmesh.core.selector import Selector +from sqlmesh.core.selector import NativeSelector from sqlmesh.core.snapshot import SnapshotChangeCategory from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.date import now_timestamp @@ -88,7 +88,7 @@ def test_select_models(mocker: MockerFixture, make_snapshot, default_catalog: t. local_models[modified_model_v2.fqn] = modified_model_v2.copy( update={"mapping_schema": added_model_schema} ) - selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog) + selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog) _assert_models_equal( selector.select_models(["db.added_model"], env_name), @@ -243,7 +243,7 @@ def test_select_models_expired_environment(mocker: MockerFixture, make_snapshot) local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") local_models[modified_model_v2.fqn] = modified_model_v2 - selector = Selector(state_reader_mock, local_models) + selector = NativeSelector(state_reader_mock, local_models) _assert_models_equal( selector.select_models(["*.modified_model"], env_name, fallback_env_name="prod"), @@ -305,7 +305,7 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot): local_child = child.copy(update={"mapping_schema": {'"db"': {'"parent"': {"b": "INT"}}}}) local_models[local_child.fqn] = local_child - selector = Selector(state_reader_mock, local_models) + selector = NativeSelector(state_reader_mock, local_models) selected = selector.select_models(["db.parent"], env_name) assert selected[local_child.fqn].render_query() != child.render_query() @@ -339,7 +339,7 @@ def test_select_models_missing_env(mocker: MockerFixture, make_snapshot): local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") local_models[model.fqn] = model - selector = Selector(state_reader_mock, local_models) + selector = NativeSelector(state_reader_mock, local_models) assert selector.select_models([model.name], "missing_env").keys() == {model.fqn} assert not selector.select_models(["missing"], "missing_env") @@ -563,7 +563,7 @@ def test_expand_model_selections( ) models[model.fqn] = model - selector = Selector(mocker.Mock(), models) + selector = NativeSelector(mocker.Mock(), models) assert selector.expand_model_selections(selections) == output @@ -576,7 +576,7 @@ def test_model_selection_normalized(mocker: MockerFixture, make_snapshot): dialect="bigquery", ) models[model.fqn] = model - selector = Selector(mocker.Mock(), models, dialect="bigquery") + selector = NativeSelector(mocker.Mock(), models, dialect="bigquery") assert selector.expand_model_selections(["db.test_Model"]) == {'"db"."test_Model"'} @@ -624,7 +624,7 @@ def test_expand_git_selection( git_client_mock.list_uncommitted_changed_files.return_value = [] git_client_mock.list_committed_changed_files.return_value = [model_a._path, model_c._path] - selector = Selector(mocker.Mock(), models) + selector = NativeSelector(mocker.Mock(), models) selector._git_client = git_client_mock assert selector.expand_model_selections(expressions) == expected_fqns @@ -658,7 +658,7 @@ def test_select_models_with_external_parent(mocker: MockerFixture): local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") local_models[added_model.fqn] = added_model - selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog) + selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog) expanded_selections = selector.expand_model_selections(["+*added_model*"]) assert expanded_selections == {added_model.fqn} @@ -699,7 +699,7 @@ def test_select_models_local_tags_take_precedence_over_remote( local_models[local_existing.fqn] = local_existing local_models[local_new.fqn] = local_new - selector = Selector(state_reader_mock, local_models) + selector = NativeSelector(state_reader_mock, local_models) selected = selector.select_models(["tag:a"], env_name) diff --git a/tests/dbt/cli/test_selectors.py b/tests/dbt/cli/test_selectors.py index 319d4372a9..99907bda84 100644 --- a/tests/dbt/cli/test_selectors.py +++ b/tests/dbt/cli/test_selectors.py @@ -1,6 +1,7 @@ import typing as t import pytest from sqlmesh_dbt import selectors +from sqlmesh.core.selector import DbtSelector from sqlmesh.core.context import Context from pathlib import Path @@ -112,6 +113,7 @@ def test_split_unions_and_intersections( ["customers+", "stg_orders"], {'"jaffle_shop"."main"."customers"', '"jaffle_shop"."main"."stg_orders"'}, ), + (["*.staging.stg_c*"], {'"jaffle_shop"."main"."stg_customers"'}), (["tag:agg"], {'"jaffle_shop"."main"."agg_orders"'}), ( ["staging.stg_customers", "tag:agg"], @@ -137,6 +139,16 @@ def test_split_unions_and_intersections( '"jaffle_shop"."main"."agg_orders"', }, ), + ( + ["tag:b*"], + set(), + ), + ( + ["tag:a*"], + { + '"jaffle_shop"."main"."agg_orders"', + }, + ), ], ) def test_select_by_dbt_names( @@ -155,7 +167,7 @@ def test_select_by_dbt_names( assert '"jaffle_shop"."main"."agg_orders"' in ctx.models selector = ctx._new_selector() - assert selector._dbt_mode + assert isinstance(selector, DbtSelector) sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=[]) assert sqlmesh_selector @@ -205,7 +217,7 @@ def test_exclude_by_dbt_names( assert '"jaffle_shop"."main"."agg_orders"' in ctx.models selector = ctx._new_selector() - assert selector._dbt_mode + assert isinstance(selector, DbtSelector) sqlmesh_selector = selectors.to_sqlmesh(dbt_select=[], dbt_exclude=dbt_exclude) assert sqlmesh_selector @@ -251,7 +263,7 @@ def test_selection_and_exclusion_by_dbt_names( assert '"jaffle_shop"."main"."agg_orders"' in ctx.models selector = ctx._new_selector() - assert selector._dbt_mode + assert isinstance(selector, DbtSelector) sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=dbt_exclude) assert sqlmesh_selector diff --git a/tests/dbt/conftest.py b/tests/dbt/conftest.py index b2ea36ec86..846dfc6aa9 100644 --- a/tests/dbt/conftest.py +++ b/tests/dbt/conftest.py @@ -7,6 +7,7 @@ import pytest from sqlmesh.core.context import Context +from sqlmesh.core.selector import DbtSelector from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.project import Project from sqlmesh.dbt.target import PostgresConfig @@ -99,7 +100,7 @@ def jaffle_shop_duckdb(copy_to_temp_path: t.Callable[..., t.List[Path]]) -> t.It @pytest.fixture def jaffle_shop_duckdb_context(jaffle_shop_duckdb: Path) -> Context: init_project_if_required(jaffle_shop_duckdb) - return Context(paths=[jaffle_shop_duckdb], dbt_mode=True) + return Context(paths=[jaffle_shop_duckdb], selector=DbtSelector) @pytest.fixture()