diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py index c136a00cc0..210ae9da1b 100644 --- a/sqlmesh/core/audit/definition.py +++ b/sqlmesh/core/audit/definition.py @@ -438,13 +438,15 @@ def load_audit( extra_kwargs: t.Dict[str, t.Any] = {} if is_standalone: - jinja_macro_refrences, used_variables = extract_macro_references_and_variables( + jinja_macro_refrences, referenced_variables = extract_macro_references_and_variables( *(gen(s) for s in statements), gen(query), ) jinja_macros = (jinja_macros or JinjaMacroRegistry()).trim(jinja_macro_refrences) for jinja_macro in jinja_macros.root_macros.values(): - used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) + referenced_variables.update( + extract_macro_references_and_variables(jinja_macro.definition)[1] + ) extra_kwargs["jinja_macros"] = jinja_macros extra_kwargs["python_env"] = make_python_env( @@ -453,7 +455,7 @@ def load_audit( module_path, macros or macro.get_registry(), variables=variables, - used_variables=used_variables, + referenced_variables=referenced_variables, ) extra_kwargs["default_catalog"] = default_catalog if project is not None: diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py index 27d6cf0d7f..2df7697b9d 100644 --- a/sqlmesh/core/constants.py +++ b/sqlmesh/core/constants.py @@ -80,7 +80,9 @@ DEFAULT_SCHEMA = "default" SQLMESH_VARS = "__sqlmesh__vars__" +SQLMESH_VARS_METADATA = "__sqlmesh__vars__metadata__" SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__" +SQLMESH_BLUEPRINT_VARS_METADATA = "__sqlmesh__blueprint__vars__metadata__" VAR = "var" BLUEPRINT_VAR = "blueprint_var" diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index af891a5460..42a4a8b8dc 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -210,7 +210,12 @@ def __init__( self.macros[normalize_macro_name(k)] = self.env[k] elif v.is_value: value = self.env[k] - if k in (c.SQLMESH_VARS, c.SQLMESH_BLUEPRINT_VARS): + if k in ( + c.SQLMESH_VARS, + c.SQLMESH_VARS_METADATA, + c.SQLMESH_BLUEPRINT_VARS, + c.SQLMESH_BLUEPRINT_VARS_METADATA, + ): value = { var_name: ( self.parse_one(var_value.sql) @@ -557,17 +562,25 @@ def views(self) -> t.List[str]: def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: """Returns the value of the specified variable, or the default value if it doesn't exist.""" - return (self.locals.get(c.SQLMESH_VARS) or {}).get(var_name.lower(), default) + return { + **(self.locals.get(c.SQLMESH_VARS) or {}), + **(self.locals.get(c.SQLMESH_VARS_METADATA) or {}), + }.get(var_name.lower(), default) def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: """Returns the value of the specified blueprint variable, or the default value if it doesn't exist.""" - return (self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}).get(var_name.lower(), default) + return { + **(self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}), + **(self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA) or {}), + }.get(var_name.lower(), default) @property def variables(self) -> t.Dict[str, t.Any]: return { **self.locals.get(c.SQLMESH_VARS, {}), + **self.locals.get(c.SQLMESH_VARS_METADATA, {}), **self.locals.get(c.SQLMESH_BLUEPRINT_VARS, {}), + **self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}), } def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any: diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 11ddc8234b..9a68ec18c0 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -28,7 +28,7 @@ from sqlmesh.utils import registry_decorator from sqlmesh.utils.jinja import MacroReference - MacroCallable = registry_decorator + MacroCallable = t.Union[Executable, registry_decorator] def make_python_env( @@ -40,7 +40,7 @@ def make_python_env( module_path: Path, macros: MacroRegistry, variables: t.Optional[t.Dict[str, t.Any]] = None, - used_variables: t.Optional[t.Set[str]] = None, + referenced_variables: t.Optional[t.Set[str]] = None, path: t.Optional[Path] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, strict_resolution: bool = True, @@ -48,20 +48,64 @@ def make_python_env( dialect: DialectType = None, ) -> t.Dict[str, Executable]: python_env = {} if python_env is None else python_env - variables = variables or {} env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {} - used_macros: t.Dict[ - str, - t.Tuple[t.Union[Executable | MacroCallable], t.Optional[bool]], - ] = {} - used_variables = (used_variables or set()).copy() + + variables = variables or {} + blueprint_variables = blueprint_variables or {} + + used_macros: t.Dict[str, t.Tuple[MacroCallable, bool]] = {} + + # var -> True: var is metadata-only + # var -> False: var is not metadata-only + # var -> None: cannot determine whether var is metadata-only yet, need to walk macros first + used_variables: t.Dict[str, t.Optional[bool]] = dict.fromkeys( + referenced_variables or set(), False + ) + + # id(expr) -> true: expr appears under the AST of a metadata-only macro function + # id(expr) -> false: expr appears under the AST of a macro function whose metadata status we don't yet know + expr_under_metadata_macro_func: t.Dict[int, bool] = {} + + # For @m1(@m2(@x), @y), we'd get x -> m1 and y -> m1 + outermost_macro_func_ancestor_by_var: t.Dict[str, str] = {} + visited_macro_funcs: t.Set[int] = set() + + def _is_metadata_var( + name: str, expression: exp.Expression, appears_in_metadata_expression: bool + ) -> t.Optional[bool]: + is_metadata_so_far = used_variables.get(name, True) + if is_metadata_so_far is False: + # We've concluded this variable is definitely not metadata-only + return False + + appears_under_metadata_macro_func = expr_under_metadata_macro_func.get(id(expression)) + if is_metadata_so_far and ( + appears_in_metadata_expression or appears_under_metadata_macro_func + ): + # The variable appears in a metadata expression, e.g., audits (...), + # or in the AST of metadata-only macro call, e.g., @FOO(@x) + return True + + # The variable appears in the AST of a macro call, but we don't know if it's metadata-only + if appears_under_metadata_macro_func is False: + return None + + # The variable appears elsewhere, e.g., in the model's query: SELECT @x + return False + + def _is_metadata_macro(name: str, appears_in_metadata_expression: bool) -> bool: + if name in used_macros: + is_metadata_so_far = used_macros[name][1] + return is_metadata_so_far and appears_in_metadata_expression + + return appears_in_metadata_expression expressions = ensure_list(expressions) for expression_metadata in expressions: if isinstance(expression_metadata, tuple): expression, is_metadata = expression_metadata else: - expression, is_metadata = expression_metadata, None + expression, is_metadata = expression_metadata, False if isinstance(expression, d.Jinja): continue @@ -72,31 +116,51 @@ def make_python_env( if name not in macros: continue - # If this macro has been seen before as a non-metadata macro, prioritize that - used_macros[name] = ( - macros[name], - used_macros.get(name, (None, is_metadata))[1] and is_metadata, - ) - if name == c.VAR: + used_macros[name] = (macros[name], _is_metadata_macro(name, is_metadata)) + + if name in (c.VAR, c.BLUEPRINT_VAR): args = macro_func_or_var.this.expressions if len(args) < 1: - raise_config_error("Macro VAR requires at least one argument", path) + raise_config_error( + f"Macro {name.upper()} requires at least one argument", path + ) + if not args[0].is_string: raise_config_error( f"The variable name must be a string literal, '{args[0].sql()}' was given instead", path, ) - used_variables.add(args[0].this.lower()) + + var_name = args[0].this.lower() + used_variables[var_name] = _is_metadata_var( + var_name, macro_func_or_var, is_metadata + ) + elif id(macro_func_or_var) not in visited_macro_funcs: + # We only care about the top-level macro function calls to determine the metadata + # status of the variables referenced in their ASTs. For example, in @m1(@m2(@x)), + # if m1 is metadata-only but m2 is not, we can still determine that @x only affects + # the metadata hash, since m2's result feeds into a metadata-only macro function. + # + # Generally, if the top-level call is known to be metadata-only or appear in a + # metadata expression, then we can avoid traversing nested macro function calls. + + var_refs, _expr_under_metadata_macro_func, _visited_macro_funcs = ( + _extract_macro_func_variable_references(macro_func_or_var, is_metadata) + ) + expr_under_metadata_macro_func.update(_expr_under_metadata_macro_func) + visited_macro_funcs.update(_visited_macro_funcs) + outermost_macro_func_ancestor_by_var |= {var_ref: name for var_ref in var_refs} elif macro_func_or_var.__class__ is d.MacroVar: - name = macro_func_or_var.name.lower() - if name in macros: - # If this macro has been seen before as a non-metadata macro, prioritize that - used_macros[name] = ( - macros[name], - used_macros.get(name, (None, is_metadata))[1] and is_metadata, + var_name = macro_func_or_var.name.lower() + if var_name in macros: + used_macros[var_name] = ( + macros[var_name], + _is_metadata_macro(var_name, is_metadata), + ) + elif var_name in variables or var_name in blueprint_variables: + used_variables[var_name] = _is_metadata_var( + var_name, macro_func_or_var, is_metadata ) - elif name in variables: - used_variables.add(name) elif ( isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) ) and "@" in macro_func_or_var.name: @@ -104,12 +168,14 @@ def make_python_env( macro_func_or_var.name ): var_name = braced_identifier or identifier - if var_name in variables: - used_variables.add(var_name) + if var_name in variables or var_name in blueprint_variables: + used_variables[var_name] = _is_metadata_var( + var_name, macro_func_or_var, is_metadata + ) for macro_ref in jinja_macro_references or set(): if macro_ref.package is None and macro_ref.name in macros: - used_macros[macro_ref.name] = (macros[macro_ref.name], None) + used_macros[macro_ref.name] = (macros[macro_ref.name], False) for name, (used_macro, is_metadata) in used_macros.items(): if isinstance(used_macro, Executable): @@ -131,16 +197,49 @@ def make_python_env( blueprint_variables=blueprint_variables, dialect=dialect, strict_resolution=strict_resolution, + outermost_macro_func_ancestor_by_var=outermost_macro_func_ancestor_by_var, ) +def _extract_macro_func_variable_references( + macro_func: exp.Expression, + is_metadata: bool, +) -> t.Tuple[t.Set[str], t.Dict[int, bool], t.Set[int]]: + var_references = set() + visited_macro_funcs = set() + expr_under_metadata_macro_func = {} + + for n in macro_func.walk(): + if type(n) is d.MacroFunc: + visited_macro_funcs.add(id(n)) + + this = n.this + args = this.expressions + + if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and args and args[0].is_string: + var_references.add(args[0].this.lower()) + expr_under_metadata_macro_func[id(n)] = is_metadata + elif isinstance(n, d.MacroVar): + var_references.add(n.name.lower()) + expr_under_metadata_macro_func[id(n)] = is_metadata + elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name: + var_references.update( + (braced_identifier or identifier).lower() + for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(n.name) + ) + expr_under_metadata_macro_func[id(n)] = is_metadata + + return (var_references, expr_under_metadata_macro_func, visited_macro_funcs) + + def _add_variables_to_python_env( python_env: t.Dict[str, Executable], - used_variables: t.Optional[t.Set[str]], + used_variables: t.Dict[str, t.Optional[bool]], variables: t.Optional[t.Dict[str, t.Any]], strict_resolution: bool = True, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, dialect: DialectType = None, + outermost_macro_func_ancestor_by_var: t.Optional[t.Dict[str, str]] = None, ) -> t.Dict[str, Executable]: _, python_used_variables = parse_dependencies( python_env, @@ -149,20 +248,67 @@ def _add_variables_to_python_env( variables=variables, blueprint_variables=blueprint_variables, ) - used_variables = (used_variables or set()) | python_used_variables + for var_name, is_metadata in python_used_variables.items(): + used_variables[var_name] = is_metadata and used_variables.get(var_name, True) + + # Variables are treated as metadata-only when all of their references either: + # - appear in metadata-only expressions, such as `audits (...)`, virtual statements, etc + # - appear in the ASTs or definitions of metadata-only macros + # + # See also: https://github.com/TobikoData/sqlmesh/pull/4936#issuecomment-3136339936, + # specifically the "Terminology" and "Observations" section. + metadata_used_variables = { + var_name for var_name, is_metadata in used_variables.items() if is_metadata + } + for used_var, outermost_macro_func in (outermost_macro_func_ancestor_by_var or {}).items(): + used_var_is_metadata = used_variables.get(used_var) + if used_var_is_metadata is False: + continue + + # At this point we can decide whether a variable reference in a macro call's AST is + # metadata-only, because we've annotated the corresponding macro call in the python env. + if outermost_macro_func in python_env and python_env[outermost_macro_func].is_metadata: + metadata_used_variables.add(used_var) + + non_metadata_used_variables = set(used_variables) - metadata_used_variables + + if overlapping_variables := (non_metadata_used_variables & metadata_used_variables): + raise ConfigError( + f"Variables {', '.join(overlapping_variables)} are both metadata and non-metadata, " + "which is unexpected. Please file an issue at https://github.com/TobikoData/sqlmesh/issues/new." + ) + + metadata_variables = { + k: v for k, v in (variables or {}).items() if k in metadata_used_variables + } + variables = {k: v for k, v in (variables or {}).items() if k in non_metadata_used_variables} - variables = {k: v for k, v in (variables or {}).items() if k in used_variables} if variables: python_env[c.SQLMESH_VARS] = Executable.value(variables, sort_root_dict=True) + if metadata_variables: + python_env[c.SQLMESH_VARS_METADATA] = Executable.value( + metadata_variables, sort_root_dict=True, is_metadata=True + ) if blueprint_variables: + metadata_blueprint_variables = { + k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v + for k, v in blueprint_variables.items() + if k in metadata_used_variables + } blueprint_variables = { k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v for k, v in blueprint_variables.items() + if k in non_metadata_used_variables } - python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value( - blueprint_variables, sort_root_dict=True - ) + if blueprint_variables: + python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value( + blueprint_variables, sort_root_dict=True + ) + if metadata_blueprint_variables: + python_env[c.SQLMESH_BLUEPRINT_VARS_METADATA] = Executable.value( + metadata_blueprint_variables, sort_root_dict=True, is_metadata=True + ) return python_env @@ -173,7 +319,7 @@ def parse_dependencies( strict_resolution: bool = True, variables: t.Optional[t.Dict[str, t.Any]] = None, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, -) -> t.Tuple[t.Set[str], t.Set[str]]: +) -> t.Tuple[t.Set[str], t.Dict[str, bool]]: """ Parses the source of a model function and finds upstream table dependencies and referenced variables based on calls to context / evaluator. @@ -187,7 +333,8 @@ def parse_dependencies( blueprint_variables: The blueprint variables available to the python environment. Returns: - A tuple containing the set of upstream table dependencies and the set of referenced variables. + A tuple containing the set of upstream table dependencies and a mapping of + the referenced variables associated with their metadata status. """ class VariableResolutionContext: @@ -205,12 +352,16 @@ def blueprint_var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optiona local_env = dict.fromkeys(("context", "evaluator"), VariableResolutionContext) depends_on = set() - used_variables = set() + used_variables: t.Dict[str, bool] = {} for executable in python_env.values(): if not executable.is_definition: continue + + is_metadata = executable.is_metadata for node in ast.walk(ast.parse(executable.payload)): + next_variables = set() + if isinstance(node, ast.Call): func = node.func if not isinstance(func, ast.Attribute) or not isinstance(func.value, ast.Name): @@ -241,8 +392,11 @@ def get_first_arg(keyword_arg_name: str) -> t.Any: if func.value.id == "context" and func.attr in ("table", "resolve_table"): depends_on.add(get_first_arg("model_name")) - elif func.value.id in ("context", "evaluator") and func.attr == c.VAR: - used_variables.add(get_first_arg("var_name").lower()) + elif func.value.id in ("context", "evaluator") and func.attr in ( + c.VAR, + c.BLUEPRINT_VAR, + ): + next_variables.add(get_first_arg("var_name").lower()) elif ( isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) @@ -250,9 +404,9 @@ def get_first_arg(keyword_arg_name: str) -> t.Any: and node.attr == c.GATEWAY ): # Check whether the gateway attribute is referenced. - used_variables.add(c.GATEWAY) + next_variables.add(c.GATEWAY) elif isinstance(node, ast.FunctionDef) and node.name == entrypoint: - used_variables.update( + next_variables.update( [ arg.arg for arg in [*node.args.args, *node.args.kwonlyargs] @@ -260,6 +414,9 @@ def get_first_arg(keyword_arg_name: str) -> t.Any: ] ) + for var_name in next_variables: + used_variables[var_name] = used_variables.get(var_name, True) and bool(is_metadata) + return depends_on, used_variables diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 559d67e960..f2cfeac163 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -1782,12 +1782,17 @@ def render( start, end = make_inclusive(start or c.EPOCH, end or c.EPOCH, self.dialect) execution_time = to_datetime(execution_time or c.EPOCH) - variables = env.get(c.SQLMESH_VARS, {}) - variables.update(kwargs.pop("variables", {})) - + variables = { + **env.get(c.SQLMESH_VARS, {}), + **env.get(c.SQLMESH_VARS_METADATA, {}), + **kwargs.pop("variables", {}), + } blueprint_variables = { k: d.parse_one(v.sql, dialect=self.dialect) if isinstance(v, SqlValue) else v - for k, v in env.get(c.SQLMESH_BLUEPRINT_VARS, {}).items() + for k, v in { + **env.get(c.SQLMESH_BLUEPRINT_VARS, {}), + **env.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}), + }.items() } try: kwargs = { @@ -1909,11 +1914,11 @@ def _extract_blueprint_variables(blueprint: t.Any, path: Path) -> t.Dict[str, t. return {} if isinstance(blueprint, (exp.Paren, exp.PropertyEQ)): blueprint = blueprint.unnest() - return {blueprint.left.name: blueprint.right} + return {blueprint.left.name.lower(): blueprint.right} if isinstance(blueprint, (exp.Tuple, exp.Array)): - return {e.left.name: e.right for e in blueprint.expressions} + return {e.left.name.lower(): e.right for e in blueprint.expressions} if isinstance(blueprint, dict): - return blueprint + return {k.lower(): v for k, v in blueprint.items()} raise_config_error( f"Expected a key-value mapping for the blueprint value, got '{blueprint}' instead", @@ -2509,7 +2514,7 @@ def _create_model( if isinstance(getattr(kwargs.get("kind"), "merge_filter", None), exp.Expression): statements.append(kwargs["kind"].merge_filter) - jinja_macro_references, used_variables = extract_macro_references_and_variables( + jinja_macro_references, referenced_variables = extract_macro_references_and_variables( *(gen(e if isinstance(e, exp.Expression) else e[0]) for e in statements) ) @@ -2532,11 +2537,13 @@ def _create_model( _extract_migrated_dbt_variable_references(jinja_macros, variables) ) - used_variables.update(nested_macro_used_variables) + referenced_variables.update(nested_macro_used_variables) variables.update(flattened_package_variables) else: for jinja_macro in jinja_macros.root_macros.values(): - used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) + referenced_variables.update( + extract_macro_references_and_variables(jinja_macro.definition)[1] + ) # Merge model-specific audits with default audits if default_audits := defaults.pop("audits", None): @@ -2598,7 +2605,7 @@ def _create_model( module_path, macros or macro.get_registry(), variables=variables, - used_variables=used_variables, + referenced_variables=referenced_variables, path=path, python_env=python_env, strict_resolution=depends_on is None, diff --git a/sqlmesh/migrations/v0088_warn_about_variable_python_env_diffs.py b/sqlmesh/migrations/v0088_warn_about_variable_python_env_diffs.py new file mode 100644 index 0000000000..eb33a8041f --- /dev/null +++ b/sqlmesh/migrations/v0088_warn_about_variable_python_env_diffs.py @@ -0,0 +1,74 @@ +""" +This script's goal is to warn users about two situations that could lead to a diff: + +- They have blueprint models and some of their variables may be trimmed from `python_env` +- Variables are used in metadata-only contexts, e.g., within metadata-only macros + +Context: + +We used to store *all* blueprint variables in `python_env`, even though some of them were +redundant. For example, if a blueprint variable is only used in the model's `name` property, +then it is rendered once, at load time, and after that point it's not needed elsewhere. + +This behavior is now different: we only store the blueprint variables that are required to render +expressions at runtime, such as model query or runtime-rendered properties, like `merge_filter`. + +Additionally, variables were previously treated as non-metadata, regardless of how they were used. +This behavior changed as well: SQLMesh now analyzes variable references and tracks the data flow, +in order to detect whether changing them will result in a metadata diff for a given model. + +Some examples where variables can be treated as metadata-only `python_env` executables are: + +- A variable is referenced in metadata-only macros +- A variable is referenced in metadata-only expressions, such as virtual update statements +- A variable is passed as argument to metadata-only macros +""" + +import json + +from sqlglot import exp + +from sqlmesh.core.console import get_console + +SQLMESH_VARS = "__sqlmesh__vars__" +SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__" +METADATA_HASH_EXPRESSIONS = {"on_virtual_update", "audits", "signals", "audit_definitions"} + + +def migrate(state_sync, **kwargs): # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + warning = ( + "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact " + "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` " + "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new " + "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these " + "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. " + "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n" + ) + + for (snapshot,) in engine_adapter.fetchall( + exp.select("snapshot").from_(snapshots_table), quote_identifiers=True + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + + # Standalone audits don't have a data hash, so they're unaffected + if node.get("source_type") == "audit": + continue + + python_env = node.get("python_env") or {} + + if SQLMESH_BLUEPRINT_VARS in python_env or ( + SQLMESH_VARS in python_env + and ( + any(v.get("is_metadata") for v in python_env.values()) + or any(node.get(k) for k in METADATA_HASH_EXPRESSIONS) + ) + ): + get_console().log_warning(warning) + return diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index d1c0ef0361..6720c24581 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -229,7 +229,7 @@ def extract_macro_references_and_variables( ) for call_name, node in extract_call_names(jinja_str): - if call_name[0] == c.VAR: + if call_name[0] in (c.VAR, c.BLUEPRINT_VAR): assert isinstance(node, nodes.Call) args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index 9330532442..858e8a50da 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -283,7 +283,7 @@ def build_env( env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]], name: str, path: Path, - is_metadata_obj: t.Optional[bool] = None, + is_metadata_obj: bool = False, ) -> None: """Fills in env dictionary with all globals needed to execute the object. @@ -299,7 +299,7 @@ def build_env( # We don't rely on `env` to keep track of visited objects, because it's populated in post-order visited: t.Set[str] = set() - def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None: + def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None: obj_module = inspect.getmodule(obj) if obj_module and obj_module.__name__ == "builtins": return @@ -320,7 +320,7 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None: # The existing object in the env is "metadata only" but we're walking it again as a # non-"metadata only" dependency, so we update this flag to ensure all transitive # dependencies are also not marked as "metadata only" - is_metadata = None + is_metadata = False if hasattr(obj, c.SQLMESH_MACRO): # We only need to add the undecorated code of @macro() functions in env, which @@ -380,7 +380,7 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None: ) # The "metadata only" annotation of the object is transitive - walk(obj, name, is_metadata_obj or getattr(obj, c.SQLMESH_METADATA, None)) + walk(obj, name, is_metadata_obj or getattr(obj, c.SQLMESH_METADATA, False)) @dataclass @@ -432,7 +432,11 @@ def value( cls, v: t.Any, is_metadata: t.Optional[bool] = None, sort_root_dict: bool = False ) -> Executable: payload = _dict_sort(v) if sort_root_dict else repr(v) - return Executable(payload=payload, kind=ExecutableKind.VALUE, is_metadata=is_metadata) + return Executable( + payload=payload, + kind=ExecutableKind.VALUE, + is_metadata=is_metadata or None, + ) def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable]: @@ -447,6 +451,9 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable serialized = {} for k, (v, is_metadata) in env.items(): + # We don't store `False` for `is_metadata` to reduce the pydantic model's payload size + is_metadata = is_metadata or None + if isinstance(v, LITERALS) or v is None: serialized[k] = Executable.value(v, is_metadata=is_metadata) elif inspect.ismodule(v): diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 0be1702fa1..f8070a98a4 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -9383,13 +9383,15 @@ def entrypoint(evaluator): assert "blueprints" not in model.all_fields() python_env = model.python_env - serialized_blueprint = ( - SqlValue(sql=blueprint_value) if model_name == "test_model_sql" else blueprint_value - ) + assert python_env.get(c.SQLMESH_VARS) == Executable.value({"x": gateway_no}) - assert python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( - {"blueprint": serialized_blueprint} - ) + + if model_name == "test_model_sql": + assert c.SQLMESH_BLUEPRINT_VARS not in python_env + else: + assert python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"blueprint": blueprint_value} + ) assert context.fetchdf(f"from {model.fqn}").to_dict() == {"x": {0: gateway_no}} @@ -10053,6 +10055,185 @@ def metadata_macro(evaluator): assert new_snapshot.change_category == SnapshotChangeCategory.METADATA +def test_vars_are_taken_into_account_when_propagating_metadata_status(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + test_model = tmp_path / "models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text( + "MODEL (name test_model, kind FULL, blueprints ((v4 := 4, v5 := 5)));" + "@m1_metadata_references_v1();" # metadata macro, references v1 internally => v1 metadata + "@m2_metadata_does_not_reference_var(@v2, @v3);" # metadata macro => v2 metadata, v3 metadata + "@m3_non_metadata_references_v4(@v3);" # non-metadata macro, references v4 => v3, v4 are not metadata + "SELECT 1 AS c;" + "@m2_metadata_does_not_reference_var(@v6);" # metadata macro => v6 is metadata + "@m4_non_metadata_references_v6();" # non-metadata macro, references v6 => v6 is not metadata + "ON_VIRTUAL_UPDATE_BEGIN;" + "@m3_non_metadata_references_v4(@v5);" # non-metadata macro, metadata expression => v5 metadata + "ON_VIRTUAL_UPDATE_END;" + ) + + macro_code = """ +from sqlmesh import macro + +@macro(metadata_only=True) +def m1_metadata_references_v1(evaluator): + evaluator.var("v1") + return None + +@macro(metadata_only=True) +def m2_metadata_does_not_reference_var(evaluator, *args): + return None + +@macro() +def m3_non_metadata_references_v4(evaluator, *args): + evaluator.var("v4") + return None + +@macro() +def m4_non_metadata_references_v6(evaluator): + evaluator.var("v6") + return None""" + + test_macros = tmp_path / "macros/test_macros.py" + test_macros.parent.mkdir(parents=True, exist_ok=True) + test_macros.write_text(macro_code) + + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"v1": 1, "v2": 2, "v3": 3, "v6": 6}, + ), + paths=tmp_path, + ) + model = ctx.get_model("test_model") + + python_env = model.python_env + + assert len(python_env) == 8 + assert "m1_metadata_references_v1" in python_env + assert "m2_metadata_does_not_reference_var" in python_env + assert "m3_non_metadata_references_v4" in python_env + assert "m4_non_metadata_references_v6" in python_env + + variables = python_env.get(c.SQLMESH_VARS) + metadata_variables = python_env.get(c.SQLMESH_VARS_METADATA) + + assert variables == Executable.value({"v3": 3, "v6": 6}) + assert metadata_variables == Executable.value({"v1": 1, "v2": 2}, is_metadata=True) + + blueprint_variables = python_env.get(c.SQLMESH_BLUEPRINT_VARS) + blueprint_metadata_variables = python_env.get(c.SQLMESH_BLUEPRINT_VARS_METADATA) + + assert blueprint_variables == Executable.value({"v4": SqlValue(sql="4")}) + assert blueprint_metadata_variables == Executable.value( + {"v5": SqlValue(sql="5")}, is_metadata=True + ) + + macro_evaluator = MacroEvaluator(python_env=python_env) + + assert macro_evaluator.locals == { + "runtime_stage": "loading", + "default_catalog": None, + c.SQLMESH_VARS: {"v3": 3, "v6": 6}, + c.SQLMESH_VARS_METADATA: {"v1": 1, "v2": 2}, + c.SQLMESH_BLUEPRINT_VARS: {"v4": exp.Literal.number("4")}, + c.SQLMESH_BLUEPRINT_VARS_METADATA: {"v5": exp.Literal.number("5")}, + } + assert macro_evaluator.var("v1") == 1 + assert macro_evaluator.var("v2") == 2 + assert macro_evaluator.var("v3") == 3 + assert macro_evaluator.var("v6") == 6 + assert macro_evaluator.blueprint_var("v4") == exp.Literal.number("4") + assert macro_evaluator.blueprint_var("v5") == exp.Literal.number("5") + + query_with_vars = macro_evaluator.transform( + parse_one("SELECT " + ", ".join(f"@v{var}, @VAR('v{var}')" for var in [1, 2, 3, 6])) + ) + assert t.cast(exp.Expression, query_with_vars).sql() == "SELECT 1, 1, 2, 2, 3, 3, 6, 6" + + query_with_blueprint_vars = macro_evaluator.transform( + parse_one("SELECT " + ", ".join(f"@v{var}, @BLUEPRINT_VAR('v{var}')" for var in [4, 5])) + ) + assert t.cast(exp.Expression, query_with_blueprint_vars).sql() == "SELECT 4, 4, 5, 5" + + +def test_variable_mentioned_in_both_metadata_and_non_metadata_macro(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + test_model = tmp_path / "models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text( + "MODEL (name test_model, kind FULL); @m1_references_v_metadata(); SELECT @m2_references_v_non_metadata() AS c;" + ) + + macro_code = """ +from sqlmesh import macro + +@macro(metadata_only=True) +def m1_references_v_metadata(evaluator): + evaluator.var("v") + return None + +@macro() +def m2_references_v_non_metadata(evaluator): + evaluator.var("v") + return None""" + + test_macros = tmp_path / "macros/test_macros.py" + test_macros.parent.mkdir(parents=True, exist_ok=True) + test_macros.write_text(macro_code) + + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), variables={"v": 1}), + paths=tmp_path, + ) + model = ctx.get_model("test_model") + + python_env = model.python_env + + assert len(python_env) == 3 + assert set(python_env) > {"m1_references_v_metadata", "m2_references_v_non_metadata"} + assert python_env.get(c.SQLMESH_VARS) == Executable.value({"v": 1}) + + +def test_only_top_level_macro_func_impacts_var_descendant_metadata_status(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + test_model = tmp_path / "models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text( + "MODEL (name test_model, kind FULL); @m1_metadata(@m2_non_metadata(@v)); SELECT 1 AS c;" + ) + + macro_code = """ +from sqlmesh import macro + +@macro(metadata_only=True) +def m1_metadata(evaluator, *args): + return None + +@macro() +def m2_non_metadata(evaluator, *args): + return None""" + + test_macros = tmp_path / "macros/test_macros.py" + test_macros.parent.mkdir(parents=True, exist_ok=True) + test_macros.write_text(macro_code) + + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), variables={"v": 1}), + paths=tmp_path, + ) + model = ctx.get_model("test_model") + + python_env = model.python_env + + assert len(python_env) == 3 + assert set(python_env) > {"m1_metadata", "m2_non_metadata"} + assert python_env.get(c.SQLMESH_VARS_METADATA) == Executable.value({"v": 1}, is_metadata=True) + + def test_non_metadata_object_takes_precedence_over_metadata_only_object(tmp_path: Path) -> None: init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) @@ -10958,7 +11139,7 @@ def entrypoint( assert customer1_model.enabled assert "blueprints" not in customer1_model.all_fields() assert customer1_model.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( - {"customer": "customer1", "field_a": "x", "field_b": "y", "min": 5} + {"customer": "customer1", "field_a": "x", "field_b": "y"} ) # Test second blueprint @@ -10966,7 +11147,7 @@ def entrypoint( assert customer2_model is not None assert customer2_model.cron == "*/10 * * * *" assert customer2_model.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( - {"customer": "customer2", "field_a": "z", "field_b": "w", "min": 10} + {"customer": "customer2", "field_a": "z", "field_b": "w"} ) # Test that the models can be planned and applied @@ -11158,3 +11339,22 @@ def test_each_macro_with_paren_expression_arg(assert_exp_eq): 'value' AS "property1" """, ) + + +@pytest.mark.parametrize( + "macro_func, variables", + [ + ("@M(@v1)", {"v1"}), + ("@M(@{v1})", {"v1"}), + ("@M(@SQL('@v1'))", {"v1"}), + ("@M(@'@{v1}_foo')", {"v1"}), + ("@M1(@VAR('v1'))", {"v1"}), + ("@M1(@v1, @M2(@v2), @BLUEPRINT_VAR('v3'))", {"v1", "v2", "v3"}), + ("@M1(@BLUEPRINT_VAR(@VAR('v1')))", {"v1"}), + ], +) +def test_extract_macro_func_variable_references(macro_func: str, variables: t.Set[str]) -> None: + from sqlmesh.core.model.common import _extract_macro_func_variable_references + + macro_func_ast = parse_one(macro_func) + assert _extract_macro_func_variable_references(macro_func_ast, True)[0] == variables diff --git a/tests/utils/test_metaprogramming.py b/tests/utils/test_metaprogramming.py index 8519e1eb04..19413f68ef 100644 --- a/tests/utils/test_metaprogramming.py +++ b/tests/utils/test_metaprogramming.py @@ -406,7 +406,7 @@ def function_with_custom_decorator(): "SQLGLOT_META": Executable.value("sqlglot.meta"), } - assert all(is_metadata is None for (_, is_metadata) in env.values()) + assert all(not is_metadata for (_, is_metadata) in env.values()) assert serialized_env == expected_env # Annotate the entrypoint as "metadata only" to show how it propagates