Skip to content

Commit 0ba3391

Browse files
committed
use all vars when dynamic vars are detected
1 parent 349ba6d commit 0ba3391

File tree

9 files changed

+43
-172
lines changed

9 files changed

+43
-172
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -233,56 +233,6 @@ def _raise_parsetime_adapter_call_error(action: str) -> None:
233233
raise ParsetimeAdapterCallError(f"Can't {action} at parse time.")
234234

235235

236-
class StubParsetimeAdapter(BaseAdapter):
237-
"""Same as ParsetimeAdapter, but returns stub / empty values instead of raising an error."""
238-
239-
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
240-
return None
241-
242-
def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]:
243-
return None
244-
245-
def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]:
246-
return []
247-
248-
def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]:
249-
return []
250-
251-
def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]:
252-
return []
253-
254-
def get_missing_columns(
255-
self, from_relation: BaseRelation, to_relation: BaseRelation
256-
) -> t.List[Column]:
257-
return []
258-
259-
def create_schema(self, relation: BaseRelation) -> None:
260-
pass
261-
262-
def drop_schema(self, relation: BaseRelation) -> None:
263-
pass
264-
265-
def drop_relation(self, relation: BaseRelation) -> None:
266-
pass
267-
268-
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
269-
pass
270-
271-
def execute(
272-
self, sql: str, auto_begin: bool = False, fetch: bool = False
273-
) -> t.Tuple[AdapterResponse, agate.Table]:
274-
from dbt.adapters.base.impl import AdapterResponse
275-
from sqlmesh.dbt.util import empty_table
276-
277-
return AdapterResponse(""), empty_table()
278-
279-
def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]:
280-
return relation.schema
281-
282-
def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]:
283-
return relation.identifier
284-
285-
286236
class RuntimeAdapter(BaseAdapter):
287237
def __init__(
288238
self,

sqlmesh/dbt/basemodel.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,15 @@ def sqlmesh_model_kwargs(
293293
self,
294294
context: DbtContext,
295295
column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None,
296-
extra_dependencies: t.Optional[Dependencies] = None,
297296
) -> t.Dict[str, t.Any]:
298297
"""Get common sqlmesh model parameters"""
299298
self.check_for_circular_test_refs(context)
300299

301-
dependencies = self.dependencies
302-
if extra_dependencies:
303-
dependencies = dependencies.union(extra_dependencies)
300+
dependencies = self.dependencies.copy()
301+
if dependencies.has_dynamic_var_names:
302+
# Include ALL variables as dependencies since we couldn't determine
303+
# precisely which variables are referenced in the model
304+
dependencies.variables |= set(context.variables)
304305

305306
model_dialect = self.dialect(context)
306307
model_context = context.context_for_dependencies(
@@ -359,8 +360,3 @@ def _model_jinja_context(
359360
"config": self.config_attribute_dict,
360361
**context.jinja_globals,
361362
}
362-
363-
def _track_dependencies_on_render(self, input: str, context: DbtContext) -> Dependencies:
364-
return context.track_dependencies_on_render(
365-
input, self._model_jinja_context(context, self.dependencies), self.package_name
366-
)

sqlmesh/dbt/builtin.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sqlmesh.core.console import get_console
1818
from sqlmesh.core.engine_adapter import EngineAdapter
1919
from sqlmesh.core.snapshot.definition import DeployabilityIndex
20-
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter, StubParsetimeAdapter
20+
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter
2121
from sqlmesh.dbt.relation import Policy
2222
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
2323
from sqlmesh.dbt.util import DBT_VERSION
@@ -384,15 +384,15 @@ def create_builtin_globals(
384384
builtin_globals["this"] = this
385385

386386
sources = jinja_globals.pop("sources", None)
387-
if sources is not None and "source" not in jinja_globals:
387+
if sources is not None:
388388
builtin_globals["source"] = generate_source(sources, api)
389389

390390
refs = jinja_globals.pop("refs", None)
391-
if refs is not None and "ref" not in jinja_globals:
391+
if refs is not None:
392392
builtin_globals["ref"] = generate_ref(refs, api)
393393

394394
variables = jinja_globals.pop("vars", None)
395-
if variables is not None and "var" not in jinja_globals:
395+
if variables is not None:
396396
builtin_globals["var"] = Var(variables)
397397

398398
deployability_index = (
@@ -415,7 +415,6 @@ def create_builtin_globals(
415415
{k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")}
416416
)
417417

418-
execute = True
419418
if engine_adapter is not None:
420419
builtin_globals["flags"] = Flags(which="run")
421420
adapter: BaseAdapter = RuntimeAdapter(
@@ -436,11 +435,7 @@ def create_builtin_globals(
436435
)
437436
else:
438437
builtin_globals["flags"] = Flags(which="parse")
439-
adapter_class: t.Type[BaseAdapter] = ParsetimeAdapter
440-
if jinja_globals.get("use_stub_adapter", False):
441-
adapter_class = StubParsetimeAdapter
442-
execute = False
443-
adapter = adapter_class(
438+
adapter = ParsetimeAdapter(
444439
jinja_macros,
445440
jinja_globals={**builtin_globals, **jinja_globals},
446441
project_dialect=project_dialect,
@@ -451,7 +446,7 @@ def create_builtin_globals(
451446
builtin_globals.update(
452447
{
453448
"adapter": adapter,
454-
"execute": execute,
449+
"execute": True,
455450
"load_relation": adapter.load_relation,
456451
"store_result": sql_execution.store_result,
457452
"load_result": sql_execution.load_result,

sqlmesh/dbt/context.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dbt.adapters.base import BaseRelation
88

99
from sqlmesh.core.config import Config as SQLMeshConfig
10-
from sqlmesh.dbt.builtin import _relation_info_to_relation, Var
10+
from sqlmesh.dbt.builtin import _relation_info_to_relation
1111
from sqlmesh.dbt.common import Dependencies
1212
from sqlmesh.dbt.manifest import ManifestHelper
1313
from sqlmesh.dbt.target import TargetConfig
@@ -101,8 +101,6 @@ def add_variables(self, variables: t.Dict[str, t.Any]) -> None:
101101
self._jinja_environment = None
102102

103103
def set_and_render_variables(self, variables: t.Dict[str, t.Any], package: str) -> None:
104-
self.variables = variables
105-
106104
jinja_environment = self.jinja_macros.build_environment(**self.jinja_globals)
107105

108106
def _render_var(value: t.Any) -> t.Any:
@@ -212,38 +210,6 @@ def target(self, value: TargetConfig) -> None:
212210
def render(self, source: str, **kwargs: t.Any) -> str:
213211
return self.jinja_environment.from_string(source).render(**kwargs)
214212

215-
def track_dependencies_on_render(
216-
self, input: str, jinja_context: t.Dict[str, t.Any], package_name: t.Optional[str] = None
217-
) -> Dependencies:
218-
dependencies_on_render = Dependencies()
219-
220-
class TrackingVar(Var):
221-
def __call__(
222-
self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any
223-
) -> t.Any:
224-
dependencies_on_render.variables.add(name)
225-
return super().__call__(name, default, **kwargs)
226-
227-
def has_var(self, name: str) -> bool:
228-
dependencies_on_render.variables.add(name)
229-
return super().has_var(name)
230-
231-
if package_name:
232-
top_level_packages = [*self.jinja_macros.top_level_packages, package_name]
233-
jinja_macros = self.jinja_macros.copy(update={"top_level_packages": top_level_packages})
234-
else:
235-
jinja_macros = self.jinja_macros
236-
237-
jinja_environment = jinja_macros.build_environment(
238-
**{
239-
**jinja_context,
240-
"var": TrackingVar(self.variables),
241-
"use_stub_adapter": True,
242-
}
243-
)
244-
jinja_environment.from_string(input).render()
245-
return dependencies_on_render
246-
247213
def get_callable_macro(
248214
self, name: str, package: t.Optional[str] = None
249215
) -> t.Optional[t.Callable]:

sqlmesh/dbt/loader.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,6 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
124124
)
125125

126126
for project in self._load_projects():
127-
context = project.context.copy()
128-
129127
macros_max_mtime = self._macros_max_mtime
130128
yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder(
131129
project.context.project_root
@@ -135,12 +133,13 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
135133
logger.debug("Converting models to sqlmesh")
136134
# Now that config is rendered, create the sqlmesh models
137135
for package in project.packages.values():
138-
context.set_and_render_variables(package.variables, package.name)
136+
package_context = project.context.copy()
137+
package_context.set_and_render_variables(package.variables, package.name)
139138
package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}
140139

141140
for model in package_models.values():
142141
sqlmesh_model = cache.get_or_load_models(
143-
model.path, loader=lambda: [_to_sqlmesh(model, context)]
142+
model.path, loader=lambda: [_to_sqlmesh(model, package_context)]
144143
)[0]
145144

146145
models[sqlmesh_model.fqn] = sqlmesh_model
@@ -155,14 +154,13 @@ def _load_audits(
155154
audits: UniqueKeyDict = UniqueKeyDict("audits")
156155

157156
for project in self._load_projects():
158-
context = project.context
159-
160157
logger.debug("Converting audits to sqlmesh")
161158
for package in project.packages.values():
162-
context.set_and_render_variables(package.variables, package.name)
159+
package_context = project.context.copy()
160+
package_context.set_and_render_variables(package.variables, package.name)
163161
for test in package.tests.values():
164162
logger.debug("Converting '%s' to sqlmesh format", test.name)
165-
audits[test.name] = test.to_sqlmesh(context)
163+
audits[test.name] = test.to_sqlmesh(package_context)
166164

167165
return audits
168166

@@ -237,9 +235,9 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
237235
project_names: t.Set[str] = set()
238236
dialect = self.config.dialect
239237
for project in self._load_projects():
240-
context = project.context
241238
for package_name, package in project.packages.items():
242-
context.set_and_render_variables(package.variables, package_name)
239+
package_context = project.context.copy()
240+
package_context.set_and_render_variables(package.variables, package_name)
243241
on_run_start: t.List[str] = [
244242
on_run_hook.sql
245243
for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index)
@@ -254,7 +252,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
254252
for hook in [*package.on_run_start.values(), *package.on_run_end.values()]:
255253
dependencies = dependencies.union(hook.dependencies)
256254

257-
statements_context = context.context_for_dependencies(dependencies)
255+
statements_context = package_context.context_for_dependencies(dependencies)
258256
jinja_registry = make_jinja_registry(
259257
statements_context.jinja_macros, package_name, set(dependencies.macros)
260258
)

sqlmesh/dbt/model.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
)
2626
from sqlmesh.core.model.kind import SCDType2ByTimeKind, OnDestructiveChange, OnAdditiveChange
2727
from sqlmesh.dbt.basemodel import BaseModelConfig, Materialization, SnapshotStrategy
28-
from sqlmesh.dbt.column import ColumnConfig
29-
from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator, Dependencies
28+
from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator
3029
from sqlmesh.utils.errors import ConfigError
3130
from sqlmesh.utils.pydantic import field_validator
3231

@@ -437,30 +436,6 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
437436
"physical_version",
438437
}
439438

440-
def sqlmesh_model_kwargs(
441-
self,
442-
context: DbtContext,
443-
column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None,
444-
extra_dependencies: t.Optional[Dependencies] = None,
445-
) -> t.Dict[str, t.Any]:
446-
if not self.dependencies.has_dynamic_var_names:
447-
return super().sqlmesh_model_kwargs(context, column_types_override, extra_dependencies)
448-
449-
extra_dependencies = extra_dependencies or Dependencies()
450-
extra_dependencies = extra_dependencies.union(
451-
self._track_dependencies_on_render(self.sql_no_config, context)
452-
)
453-
for pre_hook in self.pre_hook:
454-
extra_dependencies = extra_dependencies.union(
455-
self._track_dependencies_on_render(pre_hook.sql, context)
456-
)
457-
for post_hook in self.post_hook:
458-
extra_dependencies = extra_dependencies.union(
459-
self._track_dependencies_on_render(post_hook.sql, context)
460-
)
461-
462-
return super().sqlmesh_model_kwargs(context, column_types_override, extra_dependencies)
463-
464439
def to_sqlmesh(
465440
self,
466441
context: DbtContext,

sqlmesh/dbt/project.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
5555
raise ConfigError(f"Could not find {PROJECT_FILENAME} in {context.project_root}")
5656
project_yaml = load_yaml(project_file_path)
5757

58-
variable_overrides = variables
59-
variables = {**project_yaml.get("vars", {}), **(variables or {})}
60-
6158
project_name = context.render(project_yaml.get("name", ""))
6259
context.project_name = project_name
6360
if not context.project_name:
@@ -69,6 +66,7 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
6966
profile = Profile.load(context, context.target_name)
7067
context.target = profile.target
7168

69+
variable_overrides = variables or {}
7270
context.manifest = ManifestHelper(
7371
project_file_path.parent,
7472
profile.path.parent,
@@ -101,13 +99,17 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
10199
package = package_loader.load(path.parent)
102100
packages[package.name] = package
103101

102+
all_project_variables = {**project_yaml.get("vars", {}), **(variable_overrides or {})}
104103
for name, package in packages.items():
105-
package_vars = variables.get(name)
104+
package_vars = all_project_variables.get(name)
106105

107106
if isinstance(package_vars, dict):
108107
package.variables.update(package_vars)
109108

110-
package.variables.update(variables)
109+
if name == context.project_name:
110+
package.variables.update(all_project_variables)
111+
else:
112+
package.variables.update(variable_overrides)
111113

112114
return Project(context, profile, packages)
113115

tests/dbt/test_config.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -375,26 +375,10 @@ def test_variables(assert_exp_eq, sushi_test_project):
375375
expected_customer_variables = {
376376
"some_var": ["foo", "bar"],
377377
"some_other_var": 5,
378-
"yet_another_var": 1,
378+
"yet_another_var": 5,
379379
"customers:bla": False,
380380
"customers:customer_id": "customer_id",
381381
"start": "Jan 1 2022",
382-
"top_waiters:limit": 10,
383-
"top_waiters:revenue": "revenue",
384-
"customers:boo": ["a", "b"],
385-
"nested_vars": {
386-
"some_nested_var": 2,
387-
},
388-
"dynamic_test_var": 3,
389-
"list_var": [
390-
{"name": "item1", "value": 1},
391-
{"name": "item2", "value": 2},
392-
],
393-
"customers": {
394-
"customers:bla": False,
395-
"customers:customer_id": "customer_id",
396-
"some_var": ["foo", "bar"],
397-
},
398382
}
399383

400384
assert sushi_test_project.packages["sushi"].variables == expected_sushi_variables
@@ -407,7 +391,9 @@ def test_nested_variables(sushi_test_project):
407391
sql="SELECT {{ var('nested_vars')['some_nested_var'] }}",
408392
dependencies=Dependencies(variables=["nested_vars"]),
409393
)
410-
sqlmesh_model = model_config.to_sqlmesh(sushi_test_project.context)
394+
context = sushi_test_project.context.copy()
395+
context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi")
396+
sqlmesh_model = model_config.to_sqlmesh(context)
411397
assert sqlmesh_model.jinja_macros.global_objs["vars"]["nested_vars"] == {"some_nested_var": 2}
412398

413399

0 commit comments

Comments
 (0)