diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py index 707b0a9159..e62545bc02 100644 --- a/sqlmesh/core/linter/helpers.py +++ b/sqlmesh/core/linter/helpers.py @@ -84,19 +84,8 @@ def to_range(self, read_file: t.Optional[t.List[str]]) -> Range: ) -def read_range_from_file(file: Path, text_range: Range) -> str: - """ - Read the file and return the content within the specified range. - - Args: - file: Path to the file to read - text_range: The range of text to extract - - Returns: - The content within the specified range - """ - with file.open("r", encoding="utf-8") as f: - lines = f.readlines() +def read_range_from_string(content: str, text_range: Range) -> str: + lines = content.splitlines(keepends=False) # Ensure the range is within bounds start_line = max(0, text_range.start.line) @@ -116,6 +105,23 @@ def read_range_from_file(file: Path, text_range: Range) -> str: return "".join(result) +def read_range_from_file(file: Path, text_range: Range) -> str: + """ + Read the file and return the content within the specified range. 
+ + Args: + file: Path to the file to read + text_range: The range of text to extract + + Returns: + The content within the specified range + """ + with file.open("r", encoding="utf-8") as f: + lines = f.readlines() + + return read_range_from_string("".join(lines), text_range) + + def get_range_of_model_block( sql: str, dialect: str, diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py index f16bb5d111..c1a5f9b877 100644 --- a/sqlmesh/core/linter/rules/builtin.py +++ b/sqlmesh/core/linter/rules/builtin.py @@ -7,10 +7,16 @@ from sqlglot.expressions import Star from sqlglot.helper import subclasses -from sqlmesh.core.linter.helpers import TokenPositionDetails, get_range_of_model_block +from sqlmesh.core.dialect import normalize_model_name +from sqlmesh.core.linter.helpers import ( + TokenPositionDetails, + get_range_of_model_block, + read_range_from_string, +) from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit from sqlmesh.core.linter.definition import RuleSet from sqlmesh.core.model import Model, SqlModel, ExternalModel +from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference class NoSelectStar(Rule): @@ -113,7 +119,9 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]: class NoMissingExternalModels(Rule): """All external models must be registered in the external_models.yaml file""" - def check_model(self, model: Model) -> t.Optional[RuleViolation]: + def check_model( + self, model: Model + ) -> t.Optional[t.Union[RuleViolation, t.List[RuleViolation]]]: # Ignore external models themselves, because either they are registered, # and if they are not, they will be caught as referenced in another model. 
if isinstance(model, ExternalModel): @@ -129,10 +137,74 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]: if not not_registered_external_models: return None + # If the model is anything other than a SQL model with a path + # that ends with .sql, we cannot extract the references from the query. + path = model._path + if not isinstance(model, SqlModel) or not path or not str(path).endswith(".sql"): + return self._standard_error_message( + model_name=model.fqn, + external_models=not_registered_external_models, + ) + + with open(path, "r", encoding="utf-8") as file: + read_file = file.read() + split_read_file = read_file.splitlines() + + # There are unregistered external models, so extract the references from + # the query to find the ranges for them. + references = extract_references_from_query( + query=model.query, + context=self.context, + document_path=path, + read_file=split_read_file, + depends_on=model.depends_on, + dialect=model.dialect, + ) + external_references = { + normalize_model_name( + table=read_range_from_string(read_file, ref.range), + default_catalog=model.default_catalog, + dialect=model.dialect, + ): ref + for ref in references + if isinstance(ref, ExternalModelReference) and ref.path is None + } + + # Ensure that depends_on and external references match. + if not_registered_external_models != set(external_references.keys()): + return self._standard_error_message( + model_name=model.fqn, + external_models=not_registered_external_models, + ) + + # Return a violation for each unregistered external model with its range. + violations = [] + for ref_name, ref in external_references.items(): + if ref_name in not_registered_external_models: + violations.append( + RuleViolation( + rule=self, + violation_msg=f"Model '{model.fqn}' depends on unregistered external model '{ref_name}'. " + "Please register it in the external models file. 
This can be done by running 'sqlmesh create_external_models'.", + violation_range=ref.range, + ) + ) + + if len(violations) < len(not_registered_external_models): + return self._standard_error_message( + model_name=model.fqn, + external_models=not_registered_external_models, + ) + + return violations + + def _standard_error_message( + self, model_name: str, external_models: t.Set[str] + ) -> RuleViolation: return RuleViolation( rule=self, - violation_msg=f"Model '{model.name}' depends on unregistered external models: " - f"{', '.join(m for m in not_registered_external_models)}. " + violation_msg=f"Model '{model_name}' depends on unregistered external models: " + f"{', '.join(m for m in external_models)}. " "Please register them in the external models file. This can be done by running 'sqlmesh create_external_models'.", ) diff --git a/sqlmesh/lsp/completions.py b/sqlmesh/lsp/completions.py index 0026260481..93162b15a4 100644 --- a/sqlmesh/lsp/completions.py +++ b/sqlmesh/lsp/completions.py @@ -8,8 +8,8 @@ from sqlmesh import macro import typing as t from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget -from sqlmesh.lsp.description import generate_markdown_description from sqlmesh.lsp.uri import URI +from sqlmesh.utils.lineage import generate_markdown_description def get_sql_completions( diff --git a/sqlmesh/lsp/description.py b/sqlmesh/lsp/description.py deleted file mode 100644 index 768197742f..0000000000 --- a/sqlmesh/lsp/description.py +++ /dev/null @@ -1,29 +0,0 @@ -from sqlmesh.core.model.definition import ( - ExternalModel, - PythonModel, - SeedModel, - SqlModel, -) -import typing as t - - -def generate_markdown_description( - model: t.Union[SqlModel, ExternalModel, PythonModel, SeedModel], -) -> t.Optional[str]: - description = model.description - columns = model.columns_to_types - column_descriptions = model.column_descriptions - - if columns is None: - return description or None - - columns_table = "\n".join( - [ - f"| {column} | 
{column_type} | {column_descriptions.get(column, '')} |" - for column, column_type in columns.items() - ] - ) - - table_header = "| Column | Type | Description |\n|--------|------|-------------|\n" - columns_text = table_header + columns_table - return f"{description}\n\n{columns_text}" if description else columns_text diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 416b092122..d53822fcac 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -23,6 +23,10 @@ ApiResponseGetLineage, ApiResponseGetModels, ) + +# Define the command constant +EXTERNAL_MODEL_UPDATE_COLUMNS = "sqlmesh.external_model_update_columns" + from sqlmesh.lsp.completions import get_sql_completions from sqlmesh.lsp.context import ( LSPContext, @@ -60,15 +64,15 @@ from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position from sqlmesh.lsp.hints import get_hints from sqlmesh.lsp.reference import ( - LSPCteReference, - LSPModelReference, - LSPExternalModelReference, + CTEReference, + ModelReference, get_references, get_all_references, ) from sqlmesh.lsp.rename import prepare_rename, rename_symbol, get_document_highlights from sqlmesh.lsp.uri import URI from sqlmesh.utils.errors import ConfigError +from sqlmesh.utils.lineage import ExternalModelReference from web.server.api.endpoints.lineage import column_lineage, model_lineage from web.server.api.endpoints.models import get_models from typing import Union @@ -479,7 +483,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov if not references: return None reference = references[0] - if isinstance(reference, LSPCteReference) or not reference.markdown_description: + if isinstance(reference, CTEReference) or not reference.markdown_description: return None return types.Hover( contents=types.MarkupContent( @@ -525,7 +529,7 @@ def goto_definition( location_links = [] for reference in references: # Use target_range if available (CTEs, Macros, and external models in YAML) - if isinstance(reference, 
LSPModelReference): + if isinstance(reference, ModelReference): # Regular SQL models - default to start of file target_range = types.Range( start=types.Position(line=0, character=0), @@ -535,7 +539,7 @@ def goto_definition( start=types.Position(line=0, character=0), end=types.Position(line=0, character=0), ) - elif isinstance(reference, LSPExternalModelReference): + elif isinstance(reference, ExternalModelReference): # External models may have target_range set for YAML files target_range = types.Range( start=types.Position(line=0, character=0), diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index 6aee3e10da..80d401f79c 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -1,74 +1,27 @@ import typing as t from pathlib import Path -from pydantic import Field from sqlmesh.core.audit import StandaloneAudit -from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.linter.helpers import ( TokenPositionDetails, ) from sqlmesh.core.linter.rule import Range, Position -from sqlmesh.core.model.definition import SqlModel, ExternalModel +from sqlmesh.core.model.definition import SqlModel from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget from sqlglot import exp -from sqlmesh.lsp.description import generate_markdown_description -from sqlglot.optimizer.scope import build_scope from sqlmesh.lsp.uri import URI -from sqlmesh.utils.pydantic import PydanticModel -from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlmesh.utils.lineage import ( + MacroReference, + CTEReference, + Reference, + ModelReference, + extract_references_from_query, +) import ast from sqlmesh.core.model import Model from sqlmesh import macro import inspect -from ruamel.yaml import YAML - - -class LSPModelReference(PydanticModel): - """A LSP reference to a model, excluding external models.""" - - type: t.Literal["model"] = "model" - path: Path - range: Range - markdown_description: t.Optional[str] = None - - -class 
LSPExternalModelReference(PydanticModel): - """A LSP reference to an external model.""" - - type: t.Literal["external_model"] = "external_model" - range: Range - target_range: t.Optional[Range] = None - path: t.Optional[Path] = None - """The path of the external model, typically a YAML file, it is optional because - external models can be unregistered and so the path is not available.""" - - markdown_description: t.Optional[str] = None - - -class LSPCteReference(PydanticModel): - """A LSP reference to a CTE.""" - - type: t.Literal["cte"] = "cte" - path: Path - range: Range - target_range: Range - - -class LSPMacroReference(PydanticModel): - """A LSP reference to a macro.""" - - type: t.Literal["macro"] = "macro" - path: Path - range: Range - target_range: Range - markdown_description: t.Optional[str] = None - - -Reference = t.Annotated[ - t.Union[LSPModelReference, LSPCteReference, LSPMacroReference, LSPExternalModelReference], - Field(discriminator="type"), -] def by_position(position: Position) -> t.Callable[[Reference], bool]: @@ -158,7 +111,6 @@ def get_model_definitions_for_a_path( audit = lint_context.context.standalone_audits.get(file_info.name) if audit is None: return [] - query = audit.query dialect = audit.dialect depends_on = audit.depends_on @@ -169,177 +121,17 @@ def get_model_definitions_for_a_path( if file_path is None: return [] - # Find all possible references - references: t.List[Reference] = [] - with open(file_path, "r", encoding="utf-8") as file: read_file = file.readlines() - # Build a scope tree to properly handle nested CTEs - try: - query = normalize_identifiers(query.copy(), dialect=dialect) - root_scope = build_scope(query) - except Exception: - root_scope = None - - if root_scope: - # Traverse all scopes to find CTE definitions and table references - for scope in root_scope.traverse(): - for table in scope.tables: - table_name = table.name - - # Check if this table reference is a CTE in the current scope - if cte_scope := 
scope.cte_sources.get(table_name): - cte = cte_scope.expression.parent - alias = cte.args["alias"] - if isinstance(alias, exp.TableAlias): - identifier = alias.this - if isinstance(identifier, exp.Identifier): - target_range_sqlmesh = TokenPositionDetails.from_meta( - identifier.meta - ).to_range(read_file) - table_range_sqlmesh = TokenPositionDetails.from_meta( - table.this.meta - ).to_range(read_file) - - references.append( - LSPCteReference( - path=document_uri.to_path(), # Same file - range=table_range_sqlmesh, - target_range=target_range_sqlmesh, - ) - ) - - column_references = _process_column_references( - scope=scope, - reference_name=table.name, - read_file=read_file, - referenced_model_path=document_uri.to_path(), - description="", - reference_type="cte", - cte_target_range=target_range_sqlmesh, - ) - references.extend(column_references) - continue - - # For non-CTE tables, process these as before (external model references) - # Normalize the table reference - unaliased = table.copy() - if unaliased.args.get("alias") is not None: - unaliased.set("alias", None) - reference_name = unaliased.sql(dialect=dialect) - try: - normalized_reference_name = normalize_model_name( - reference_name, - default_catalog=lint_context.context.default_catalog, - dialect=dialect, - ) - if normalized_reference_name not in depends_on: - continue - except Exception: - # Skip references that cannot be normalized - continue - - # Get the referenced model uri - referenced_model = lint_context.context.get_model( - model_or_snapshot=normalized_reference_name, raise_if_missing=False - ) - if referenced_model is None: - table_meta = TokenPositionDetails.from_meta(table.this.meta) - table_range_sqlmesh = table_meta.to_range(read_file) - start_pos_sqlmesh = table_range_sqlmesh.start - end_pos_sqlmesh = table_range_sqlmesh.end - references.append( - LSPExternalModelReference( - range=Range( - start=start_pos_sqlmesh, - end=end_pos_sqlmesh, - ), - markdown_description="Unregistered external 
model", - ) - ) - continue - referenced_model_path = referenced_model._path - if referenced_model_path is None: - continue - # Check whether the path exists - if not referenced_model_path.is_file(): - continue - - # Extract metadata for positioning - table_meta = TokenPositionDetails.from_meta(table.this.meta) - table_range_sqlmesh = table_meta.to_range(read_file) - start_pos_sqlmesh = table_range_sqlmesh.start - end_pos_sqlmesh = table_range_sqlmesh.end - - # If there's a catalog or database qualifier, adjust the start position - catalog_or_db = table.args.get("catalog") or table.args.get("db") - if catalog_or_db is not None: - catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) - catalog_or_db_range_sqlmesh = catalog_or_db_meta.to_range(read_file) - start_pos_sqlmesh = catalog_or_db_range_sqlmesh.start - - description = generate_markdown_description(referenced_model) - - # For external models in YAML files, find the specific model block - if isinstance(referenced_model, ExternalModel): - yaml_target_range: t.Optional[Range] = None - if ( - referenced_model_path.suffix in (".yaml", ".yml") - and referenced_model_path.is_file() - ): - yaml_target_range = _get_yaml_model_range( - referenced_model_path, referenced_model.name - ) - references.append( - LSPExternalModelReference( - path=referenced_model_path, - range=Range( - start=start_pos_sqlmesh, - end=end_pos_sqlmesh, - ), - markdown_description=description, - target_range=yaml_target_range, - ) - ) - - column_references = _process_column_references( - scope=scope, - reference_name=normalized_reference_name, - read_file=read_file, - referenced_model_path=referenced_model_path, - description=description, - yaml_target_range=yaml_target_range, - reference_type="external_model", - default_catalog=lint_context.context.default_catalog, - dialect=dialect, - ) - references.extend(column_references) - else: - references.append( - LSPModelReference( - path=referenced_model_path, - range=Range( - 
start=start_pos_sqlmesh, - end=end_pos_sqlmesh, - ), - markdown_description=description, - ) - ) - - column_references = _process_column_references( - scope=scope, - reference_name=normalized_reference_name, - read_file=read_file, - referenced_model_path=referenced_model_path, - description=description, - reference_type="model", - default_catalog=lint_context.context.default_catalog, - dialect=dialect, - ) - references.extend(column_references) - - return references + return extract_references_from_query( + query=query, + context=lint_context.context, + document_path=document_uri.to_path(), + read_file=read_file, + depends_on=depends_on, + dialect=dialect, + ) def get_macro_definitions_for_a_path( @@ -476,7 +268,7 @@ def get_macro_reference( # Create a reference to the macro definition - return LSPMacroReference( + return MacroReference( path=path, range=macro_range, target_range=Range( @@ -509,7 +301,7 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio # Calculate the end line number by counting the number of source lines end_line_number = line_number + len(source_lines) - 1 - return LSPMacroReference( + return MacroReference( path=Path(filename), range=macro_range, target_range=Range( @@ -522,7 +314,7 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio def get_model_find_all_references( lint_context: LSPContext, document_uri: URI, position: Position -) -> t.List[LSPModelReference]: +) -> t.List[ModelReference]: """ Get all references to a model across the entire project. 
@@ -540,7 +332,7 @@ def get_model_find_all_references( # Find the model reference at the cursor position model_at_position = next( filter( - lambda ref: isinstance(ref, LSPModelReference) + lambda ref: isinstance(ref, ModelReference) and _position_within_range(position, ref.range), get_model_definitions_for_a_path(lint_context, document_uri), ), @@ -550,13 +342,13 @@ def get_model_find_all_references( if not model_at_position: return [] - assert isinstance(model_at_position, LSPModelReference) # for mypy + assert isinstance(model_at_position, ModelReference) # for mypy target_model_path = model_at_position.path # Start with the model definition - all_references: t.List[LSPModelReference] = [ - LSPModelReference( + all_references: t.List[ModelReference] = [ + ModelReference( path=model_at_position.path, range=Range( start=Position(line=0, character=0), @@ -568,15 +360,15 @@ def get_model_find_all_references( # Then add references from the current file current_file_refs = filter( - lambda ref: isinstance(ref, LSPModelReference) and ref.path == target_model_path, + lambda ref: isinstance(ref, ModelReference) and ref.path == target_model_path, get_model_definitions_for_a_path(lint_context, document_uri), ) for ref in current_file_refs: - assert isinstance(ref, LSPModelReference) # for mypy + assert isinstance(ref, ModelReference) # for mypy all_references.append( - LSPModelReference( + ModelReference( path=document_uri.to_path(), range=ref.range, markdown_description=ref.markdown_description, @@ -593,15 +385,15 @@ def get_model_find_all_references( # Get model references that point to the target model matching_refs = filter( - lambda ref: isinstance(ref, LSPModelReference) and ref.path == target_model_path, + lambda ref: isinstance(ref, ModelReference) and ref.path == target_model_path, get_model_definitions_for_a_path(lint_context, file_uri), ) for ref in matching_refs: - assert isinstance(ref, LSPModelReference) # for mypy + assert isinstance(ref, ModelReference) # 
for mypy all_references.append( - LSPModelReference( + ModelReference( path=path, range=ref.range, markdown_description=ref.markdown_description, @@ -613,7 +405,7 @@ def get_model_find_all_references( def get_cte_references( lint_context: LSPContext, document_uri: URI, position: Position -) -> t.List[LSPCteReference]: +) -> t.List[CTEReference]: """ Get all references to a CTE at a specific position in a document. @@ -629,10 +421,10 @@ def get_cte_references( """ # Filter to get the CTE references - cte_references: t.List[LSPCteReference] = [ + cte_references: t.List[CTEReference] = [ ref for ref in get_model_definitions_for_a_path(lint_context, document_uri) - if isinstance(ref, LSPCteReference) + if isinstance(ref, CTEReference) ] if not cte_references: @@ -654,7 +446,7 @@ def get_cte_references( # Add the CTE definition matching_references = [ - LSPCteReference( + CTEReference( path=document_uri.to_path(), range=target_cte_definition_range, target_range=target_cte_definition_range, @@ -665,7 +457,7 @@ def get_cte_references( for ref in cte_references: if ref.target_range == target_cte_definition_range: matching_references.append( - LSPCteReference( + CTEReference( path=document_uri.to_path(), range=ref.range, target_range=ref.target_range, @@ -677,7 +469,7 @@ def get_cte_references( def get_macro_find_all_references( lsp_context: LSPContext, document_uri: URI, position: Position -) -> t.List[LSPMacroReference]: +) -> t.List[MacroReference]: """ Get all references to a macro at a specific position in a document. 
@@ -694,7 +486,7 @@ def get_macro_find_all_references( # Find the macro reference at the cursor position macro_at_position = next( filter( - lambda ref: isinstance(ref, LSPMacroReference) + lambda ref: isinstance(ref, MacroReference) and _position_within_range(position, ref.range), get_macro_definitions_for_a_path(lsp_context, document_uri), ), @@ -704,14 +496,14 @@ def get_macro_find_all_references( if not macro_at_position: return [] - assert isinstance(macro_at_position, LSPMacroReference) # for mypy + assert isinstance(macro_at_position, MacroReference) # for mypy target_macro_path = macro_at_position.path target_macro_target_range = macro_at_position.target_range # Start with the macro definition - all_references: t.List[LSPMacroReference] = [ - LSPMacroReference( + all_references: t.List[MacroReference] = [ + MacroReference( path=target_macro_path, range=target_macro_target_range, target_range=target_macro_target_range, @@ -725,16 +517,16 @@ def get_macro_find_all_references( # Get macro references that point to the same macro definition matching_refs = filter( - lambda ref: isinstance(ref, LSPMacroReference) + lambda ref: isinstance(ref, MacroReference) and ref.path == target_macro_path and ref.target_range == target_macro_target_range, get_macro_definitions_for_a_path(lsp_context, file_uri), ) for ref in matching_refs: - assert isinstance(ref, LSPMacroReference) # for mypy + assert isinstance(ref, MacroReference) # for mypy all_references.append( - LSPMacroReference( + MacroReference( path=path, range=ref.range, target_range=ref.target_range, @@ -786,129 +578,3 @@ def _position_within_range(position: Position, range: Range) -> bool: range.end.line > position.line or (range.end.line == position.line and range.end.character >= position.character) ) - - -def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range: - """ - Get the range for a column's table reference, handling both simple and qualified table names. 
- - Args: - column: The column expression - read_file: The file content as list of lines - - Returns: - The Range covering the table reference in the column - """ - - table_parts = column.parts[:-1] - - start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file) - end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file) - - return Range( - start=start_range.start, - end=end_range.end, - ) - - -def _process_column_references( - scope: t.Any, - reference_name: str, - read_file: t.List[str], - referenced_model_path: Path, - description: t.Optional[str] = None, - yaml_target_range: t.Optional[Range] = None, - reference_type: t.Literal["model", "external_model", "cte"] = "model", - default_catalog: t.Optional[str] = None, - dialect: t.Optional[str] = None, - cte_target_range: t.Optional[Range] = None, -) -> t.List[Reference]: - """ - Process column references for a given table and create appropriate reference objects. - - Args: - scope: The SQL scope to search for columns - reference_name: The full reference name (may include database/catalog) - read_file: The file content as list of lines - referenced_model_path: Path of the referenced model - description: Markdown description for the reference - yaml_target_range: Target range for external models (YAML files) - reference_type: Type of reference - "model", "external_model", or "cte" - default_catalog: Default catalog for normalization - dialect: SQL dialect for normalization - cte_target_range: Target range for CTE references - - Returns: - List of table references for column usages - """ - - references: t.List[Reference] = [] - for column in scope.find_all(exp.Column): - if column.table: - if reference_type == "cte": - if column.table == reference_name: - table_range = _get_column_table_range(column, read_file) - references.append( - LSPCteReference( - path=referenced_model_path, - range=table_range, - target_range=cte_target_range, - ) - ) - else: - table_parts = 
[part.sql(dialect) for part in column.parts[:-1]] - table_ref = ".".join(table_parts) - normalized_reference_name = normalize_model_name( - table_ref, - default_catalog=default_catalog, - dialect=dialect, - ) - if normalized_reference_name == reference_name: - table_range = _get_column_table_range(column, read_file) - if reference_type == "external_model": - references.append( - LSPExternalModelReference( - path=referenced_model_path, - range=table_range, - markdown_description=description, - target_range=yaml_target_range, - ) - ) - else: - references.append( - LSPModelReference( - path=referenced_model_path, - range=table_range, - markdown_description=description, - ) - ) - - return references - - -def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]: - """ - Find the range of a specific model block in a YAML file. - - Args: - yaml_path: Path to the YAML file - model_name: Name of the model to find - - Returns: - The Range of the model block in the YAML file, or None if not found - """ - yaml = YAML() - with path.open("r", encoding="utf-8") as f: - data = yaml.load(f) - - if not isinstance(data, list): - return None - - for item in data: - if isinstance(item, dict) and item.get("name") == model_name: - # Get size of block by taking the earliest line/col in the items block and the last line/col of the block - position_data = item.lc.data["name"] # type: ignore - start = Position(line=position_data[2], character=position_data[3]) - end = Position(line=position_data[2], character=position_data[3] + len(item["name"])) - return Range(start=start, end=end) - return None diff --git a/sqlmesh/lsp/rename.py b/sqlmesh/lsp/rename.py index 31f7eb3200..5675c4efca 100644 --- a/sqlmesh/lsp/rename.py +++ b/sqlmesh/lsp/rename.py @@ -13,7 +13,7 @@ from sqlmesh.lsp.reference import ( _position_within_range, get_cte_references, - LSPCteReference, + CTEReference, ) from sqlmesh.lsp.uri import URI @@ -82,7 +82,7 @@ def rename_symbol( return None -def 
_rename_cte(cte_references: t.List[LSPCteReference], new_name: str) -> WorkspaceEdit: +def _rename_cte(cte_references: t.List[CTEReference], new_name: str) -> WorkspaceEdit: """ Create a WorkspaceEdit for renaming a CTE. diff --git a/sqlmesh/utils/lineage.py b/sqlmesh/utils/lineage.py new file mode 100644 index 0000000000..8fcb92f56b --- /dev/null +++ b/sqlmesh/utils/lineage.py @@ -0,0 +1,404 @@ +import typing as t +from pathlib import Path + +from pydantic import Field + +from sqlmesh.core.dialect import normalize_model_name +from sqlmesh.core.linter.helpers import ( + TokenPositionDetails, +) +from sqlmesh.core.linter.rule import Range, Position +from sqlmesh.core.model.definition import SqlModel, ExternalModel, PythonModel, SeedModel +from sqlglot import exp +from sqlglot.optimizer.scope import build_scope + +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from ruamel.yaml import YAML + +from sqlmesh.utils.pydantic import PydanticModel + +if t.TYPE_CHECKING: + from sqlmesh.core.context import Context + from sqlmesh.core.context import GenericContext + + +class ModelReference(PydanticModel): + """A reference to a model, excluding external models.""" + + type: t.Literal["model"] = "model" + path: Path + range: Range + markdown_description: t.Optional[str] = None + + +class ExternalModelReference(PydanticModel): + """A reference to an external model.""" + + type: t.Literal["external_model"] = "external_model" + range: Range + target_range: t.Optional[Range] = None + path: t.Optional[Path] = None + """The path of the external model, typically a YAML file, it is optional because + external models can be unregistered and so the path is not available.""" + + markdown_description: t.Optional[str] = None + + +class CTEReference(PydanticModel): + """A reference to a CTE.""" + + type: t.Literal["cte"] = "cte" + path: Path + range: Range + target_range: Range + + +class MacroReference(PydanticModel): + """A reference to a macro.""" + + type: 
t.Literal["macro"] = "macro" + path: Path + range: Range + target_range: Range + markdown_description: t.Optional[str] = None + + +Reference = t.Annotated[ + t.Union[ModelReference, CTEReference, MacroReference, ExternalModelReference], + Field(discriminator="type"), +] + + +def extract_references_from_query( + query: exp.Expression, + context: t.Union["Context", "GenericContext[t.Any]"], + document_path: Path, + read_file: t.List[str], + depends_on: t.Set[str], + dialect: t.Optional[str] = None, +) -> t.List[Reference]: + # Build a scope tree to properly handle nested CTEs + try: + query = normalize_identifiers(query.copy(), dialect=dialect) + root_scope = build_scope(query) + except Exception: + root_scope = None + + references: t.List[Reference] = [] + if not root_scope: + return references + + # Traverse all scopes to find CTE definitions and table references + for scope in root_scope.traverse(): + for table in scope.tables: + table_name = table.name + + # Check if this table reference is a CTE in the current scope + if cte_scope := scope.cte_sources.get(table_name): + cte = cte_scope.expression.parent + alias = cte.args["alias"] + if isinstance(alias, exp.TableAlias): + identifier = alias.this + if isinstance(identifier, exp.Identifier): + target_range_sqlmesh = TokenPositionDetails.from_meta( + identifier.meta + ).to_range(read_file) + table_range_sqlmesh = TokenPositionDetails.from_meta( + table.this.meta + ).to_range(read_file) + + references.append( + CTEReference( + path=document_path, # Same file + range=table_range_sqlmesh, + target_range=target_range_sqlmesh, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=table.name, + read_file=read_file, + referenced_model_path=document_path, + description="", + reference_type="cte", + cte_target_range=target_range_sqlmesh, + ) + references.extend(column_references) + continue + + # For non-CTE tables, process these as before (external model references) + # Normalize the 
table reference + unaliased = table.copy() + if unaliased.args.get("alias") is not None: + unaliased.set("alias", None) + reference_name = unaliased.sql(dialect=dialect) + try: + normalized_reference_name = normalize_model_name( + reference_name, + default_catalog=context.default_catalog, + dialect=dialect, + ) + if normalized_reference_name not in depends_on: + continue + except Exception: + # Skip references that cannot be normalized + continue + + # Get the referenced model uri + referenced_model = context.get_model( + model_or_snapshot=normalized_reference_name, raise_if_missing=False + ) + if referenced_model is None: + # Extract metadata for positioning + table_meta = TokenPositionDetails.from_meta(table.this.meta) + table_range_sqlmesh = table_meta.to_range(read_file) + start_pos_sqlmesh = table_range_sqlmesh.start + end_pos_sqlmesh = table_range_sqlmesh.end + + # If there's a catalog or database qualifier, adjust the start position + catalog_or_db = table.args.get("catalog") or table.args.get("db") + if catalog_or_db is not None: + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) + catalog_or_db_range_sqlmesh = catalog_or_db_meta.to_range(read_file) + start_pos_sqlmesh = catalog_or_db_range_sqlmesh.start + + references.append( + ExternalModelReference( + range=Range( + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, + ), + markdown_description="Unregistered external model", + ) + ) + continue + referenced_model_path = referenced_model._path + if referenced_model_path is None: + continue + # Check whether the path exists + if not referenced_model_path.is_file(): + continue + + # Extract metadata for positioning + table_meta = TokenPositionDetails.from_meta(table.this.meta) + table_range_sqlmesh = table_meta.to_range(read_file) + start_pos_sqlmesh = table_range_sqlmesh.start + end_pos_sqlmesh = table_range_sqlmesh.end + + # If there's a catalog or database qualifier, adjust the start position + catalog_or_db = table.args.get("catalog") 
or table.args.get("db") + if catalog_or_db is not None: + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) + catalog_or_db_range_sqlmesh = catalog_or_db_meta.to_range(read_file) + start_pos_sqlmesh = catalog_or_db_range_sqlmesh.start + + description = generate_markdown_description(referenced_model) + + # For external models in YAML files, find the specific model block + if isinstance(referenced_model, ExternalModel): + yaml_target_range: t.Optional[Range] = None + if ( + referenced_model_path.suffix in (".yaml", ".yml") + and referenced_model_path.is_file() + ): + yaml_target_range = _get_yaml_model_range( + referenced_model_path, referenced_model.name + ) + references.append( + ExternalModelReference( + path=referenced_model_path, + range=Range( + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, + ), + markdown_description=description, + target_range=yaml_target_range, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=normalized_reference_name, + read_file=read_file, + referenced_model_path=referenced_model_path, + description=description, + yaml_target_range=yaml_target_range, + reference_type="external_model", + default_catalog=context.default_catalog, + dialect=dialect, + ) + references.extend(column_references) + else: + references.append( + ModelReference( + path=referenced_model_path, + range=Range( + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, + ), + markdown_description=description, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=normalized_reference_name, + read_file=read_file, + referenced_model_path=referenced_model_path, + description=description, + reference_type="model", + default_catalog=context.default_catalog, + dialect=dialect, + ) + references.extend(column_references) + + return references + + +def generate_markdown_description( + model: t.Union[SqlModel, ExternalModel, PythonModel, SeedModel], +) -> t.Optional[str]: + description = 
model.description + columns = model.columns_to_types + column_descriptions = model.column_descriptions + + if columns is None: + return description or None + + columns_table = "\n".join( + [ + f"| {column} | {column_type} | {column_descriptions.get(column, '')} |" + for column, column_type in columns.items() + ] + ) + + table_header = "| Column | Type | Description |\n|--------|------|-------------|\n" + columns_text = table_header + columns_table + return f"{description}\n\n{columns_text}" if description else columns_text + + +def _process_column_references( + scope: t.Any, + reference_name: str, + read_file: t.List[str], + referenced_model_path: Path, + description: t.Optional[str] = None, + yaml_target_range: t.Optional[Range] = None, + reference_type: t.Literal["model", "external_model", "cte"] = "model", + default_catalog: t.Optional[str] = None, + dialect: t.Optional[str] = None, + cte_target_range: t.Optional[Range] = None, +) -> t.List[Reference]: + """ + Process column references for a given table and create appropriate reference objects. 
+ + Args: + scope: The SQL scope to search for columns + reference_name: The full reference name (may include database/catalog) + read_file: The file content as list of lines + referenced_model_path: Path of the referenced model + description: Markdown description for the reference + yaml_target_range: Target range for external models (YAML files) + reference_type: Type of reference - "model", "external_model", or "cte" + default_catalog: Default catalog for normalization + dialect: SQL dialect for normalization + cte_target_range: Target range for CTE references + + Returns: + List of table references for column usages + """ + + references: t.List[Reference] = [] + for column in scope.find_all(exp.Column): + if column.table: + if reference_type == "cte": + if column.table == reference_name: + table_range = _get_column_table_range(column, read_file) + references.append( + CTEReference( + path=referenced_model_path, + range=table_range, + target_range=cte_target_range, + ) + ) + else: + table_parts = [part.sql(dialect) for part in column.parts[:-1]] + table_ref = ".".join(table_parts) + normalized_reference_name = normalize_model_name( + table_ref, + default_catalog=default_catalog, + dialect=dialect, + ) + if normalized_reference_name == reference_name: + table_range = _get_column_table_range(column, read_file) + if reference_type == "external_model": + references.append( + ExternalModelReference( + path=referenced_model_path, + range=table_range, + markdown_description=description, + target_range=yaml_target_range, + ) + ) + else: + references.append( + ModelReference( + path=referenced_model_path, + range=table_range, + markdown_description=description, + ) + ) + + return references + + +def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range: + """ + Get the range for a column's table reference, handling both simple and qualified table names. 
+ + Args: + column: The column expression + read_file: The file content as list of lines + + Returns: + The Range covering the table reference in the column + """ + + table_parts = column.parts[:-1] + + start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file) + end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file) + + return Range( + start=start_range.start, + end=end_range.end, + ) + + +def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]: + """ + Find the range of a specific model block in a YAML file. + + Args: + yaml_path: Path to the YAML file + model_name: Name of the model to find + + Returns: + The Range of the model block in the YAML file, or None if not found + """ + yaml = YAML() + with path.open("r", encoding="utf-8") as f: + data = yaml.load(f) + + if not isinstance(data, list): + return None + + for item in data: + if isinstance(item, dict) and item.get("name") == model_name: + # Get size of block by taking the earliest line/col in the items block and the last line/col of the block + position_data = item.lc.data["name"] # type: ignore + start = Position(line=position_data[2], character=position_data[3]) + end = Position(line=position_data[2], character=position_data[3] + len(item["name"])) + return Range(start=start, end=end) + return None diff --git a/tests/core/linter/test_builtin.py b/tests/core/linter/test_builtin.py index 208b591a2d..b9cf759946 100644 --- a/tests/core/linter/test_builtin.py +++ b/tests/core/linter/test_builtin.py @@ -44,7 +44,8 @@ def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None: # Lint the models lints = context.lint_models(raise_on_error=False) assert len(lints) == 1 + assert lints[0].violation_range is not None assert ( - "Model 'sushi.customers' depends on unregistered external models: " - in lints[0].violation_msg + lints[0].violation_msg + == """Model '"memory"."sushi"."customers"' depends on unregistered external 
model '"memory"."raw"."demographics"'. Please register it in the external models file. This can be done by running 'sqlmesh create_external_models'.""" ) diff --git a/tests/lsp/test_reference.py b/tests/lsp/test_reference.py index 3d1e19f3cc..6aae4b869e 100644 --- a/tests/lsp/test_reference.py +++ b/tests/lsp/test_reference.py @@ -1,7 +1,7 @@ from sqlmesh.core.context import Context from sqlmesh.core.linter.rule import Position from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget -from sqlmesh.lsp.reference import LSPModelReference, get_model_definitions_for_a_path, by_position +from sqlmesh.lsp.reference import ModelReference, get_model_definitions_for_a_path, by_position from sqlmesh.lsp.uri import URI @@ -54,7 +54,7 @@ def test_reference_with_alias() -> None: for ref in get_model_definitions_for_a_path( lsp_context, URI.from_path(waiter_revenue_by_day_path) ) - if isinstance(ref, LSPModelReference) + if isinstance(ref, ModelReference) ] assert len(references) == 3 diff --git a/tests/lsp/test_reference_cte.py b/tests/lsp/test_reference_cte.py index c6c56fd8a5..9bc74bc990 100644 --- a/tests/lsp/test_reference_cte.py +++ b/tests/lsp/test_reference_cte.py @@ -1,7 +1,7 @@ import re from sqlmesh.core.context import Context from sqlmesh.lsp.context import LSPContext, ModelTarget -from sqlmesh.lsp.reference import LSPCteReference, get_references +from sqlmesh.lsp.reference import CTEReference, get_references from sqlmesh.lsp.uri import URI from lsprotocol.types import Range, Position import typing as t @@ -28,7 +28,7 @@ def test_cte_parsing(): references = get_references(lsp_context, URI.from_path(sushi_customers_path), position) assert len(references) == 1 assert references[0].path == sushi_customers_path - assert isinstance(references[0], LSPCteReference) + assert isinstance(references[0], CTEReference) assert ( references[0].range.start.line == ranges[1].start.line ) # The reference location (where we clicked) @@ -43,7 +43,7 @@ def test_cte_parsing(): 
references = get_references(lsp_context, URI.from_path(sushi_customers_path), position) assert len(references) == 1 assert references[0].path == sushi_customers_path - assert isinstance(references[0], LSPCteReference) + assert isinstance(references[0], CTEReference) assert ( references[0].range.start.line == ranges[1].start.line ) # The reference location (where we clicked) diff --git a/tests/lsp/test_reference_external_model.py b/tests/lsp/test_reference_external_model.py index 36c64fe277..25de22f10f 100644 --- a/tests/lsp/test_reference_external_model.py +++ b/tests/lsp/test_reference_external_model.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from sqlmesh import Config @@ -5,9 +6,11 @@ from sqlmesh.core.linter.helpers import read_range_from_file from sqlmesh.core.linter.rule import Position from sqlmesh.lsp.context import LSPContext, ModelTarget -from sqlmesh.lsp.reference import get_references, LSPExternalModelReference +from sqlmesh.lsp.reference import get_references from sqlmesh.lsp.uri import URI +from sqlmesh.utils.lineage import ExternalModelReference from tests.utils.test_filesystem import create_temp_file +import typing as t def test_reference() -> None: @@ -27,7 +30,7 @@ def test_reference() -> None: assert len(references) == 1 reference = references[0] - assert isinstance(reference, LSPExternalModelReference) + assert isinstance(reference, ExternalModelReference) path = reference.path assert path is not None assert str(path).endswith("external_models.yaml") @@ -55,8 +58,65 @@ def test_unregistered_external_model(tmp_path: Path): assert len(references) == 1 reference = references[0] - assert isinstance(reference, LSPExternalModelReference) + assert isinstance(reference, ExternalModelReference) assert reference.path is None assert reference.target_range is None assert reference.markdown_description == "Unregistered external model" assert read_range_from_file(model_path, reference.range) == "external_model" + + +def 
test_unregistered_external_model_with_schema( + copy_to_temp_path: t.Callable[[str], list[Path]], +) -> None: + """ + Tests that the linter correctly identifies unregistered external model dependencies. + + This test removes the `external_models.yaml` file from the sushi example project, + enables the linter, and verifies that the linter raises a violation for a model + that depends on unregistered external models. + """ + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Remove the external_models.yaml file + os.remove(sushi_path / "external_models.yaml") + + # Override the config.py to turn on lint + with open(sushi_path / "config.py", "r") as f: + read_file = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),""" + read_file = read_file.replace(before, after) + assert after in read_file + with open(sushi_path / "config.py", "w") as f: + f.writelines(read_file) + + # Load the context with the temporary sushi path + context = Context(paths=[sushi_path]) + + model = context.get_model("sushi.customers") + if model is None: + raise AssertionError("Model 'sushi.customers' not found in context") + + lsp_context = LSPContext(context) + path = model._path + assert path is not None + uri = URI.from_path(path) + references = get_references(lsp_context, uri, Position(line=42, character=20)) + + assert len(references) == 1 + reference = references[0] + assert isinstance(reference, ExternalModelReference) + assert reference.path is None + assert read_range_from_file(path, reference.range) == "raw.demographics" diff --git a/tests/lsp/test_reference_macro.py b/tests/lsp/test_reference_macro.py index e287212dd2..3ee7c48b3b 100644 --- a/tests/lsp/test_reference_macro.py +++ 
b/tests/lsp/test_reference_macro.py @@ -1,6 +1,6 @@ from sqlmesh.core.context import Context from sqlmesh.lsp.context import LSPContext, ModelTarget -from sqlmesh.lsp.reference import LSPMacroReference, get_macro_definitions_for_a_path +from sqlmesh.lsp.reference import MacroReference, get_macro_definitions_for_a_path from sqlmesh.lsp.uri import URI @@ -24,6 +24,6 @@ def test_macro_references() -> None: # Check that all references point to the utils.py file for ref in macro_references: - assert isinstance(ref, LSPMacroReference) + assert isinstance(ref, MacroReference) assert URI.from_path(ref.path).value.endswith("sushi/macros/utils.py") assert ref.target_range is not None diff --git a/tests/lsp/test_reference_macro_multi.py b/tests/lsp/test_reference_macro_multi.py index 8226085a1d..3902c0b275 100644 --- a/tests/lsp/test_reference_macro_multi.py +++ b/tests/lsp/test_reference_macro_multi.py @@ -1,6 +1,6 @@ from sqlmesh.core.context import Context from sqlmesh.lsp.context import LSPContext, ModelTarget -from sqlmesh.lsp.reference import LSPMacroReference, get_macro_definitions_for_a_path +from sqlmesh.lsp.reference import MacroReference, get_macro_definitions_for_a_path from sqlmesh.lsp.uri import URI @@ -19,6 +19,6 @@ def test_macro_references_multirepo() -> None: assert len(macro_references) == 2 for ref in macro_references: - assert isinstance(ref, LSPMacroReference) + assert isinstance(ref, MacroReference) assert str(URI.from_path(ref.path).value).endswith("multi/repo_2/macros/__init__.py") assert ref.target_range is not None diff --git a/tests/lsp/test_description.py b/tests/utils/test_lineage_description.py similarity index 93% rename from tests/lsp/test_description.py rename to tests/utils/test_lineage_description.py index 054d55fecc..e7053e3bcc 100644 --- a/tests/lsp/test_description.py +++ b/tests/utils/test_lineage_description.py @@ -1,5 +1,5 @@ from sqlmesh.core.context import Context -from sqlmesh.lsp.description import generate_markdown_description 
+from sqlmesh.utils.lineage import generate_markdown_description def test_model_description() -> None: