diff --git a/sqlmesh/core/linter/definition.py b/sqlmesh/core/linter/definition.py index 9cfa4076cd..c7cee6aaa9 100644 --- a/sqlmesh/core/linter/definition.py +++ b/sqlmesh/core/linter/definition.py @@ -108,9 +108,10 @@ def check_model(self, model: Model, context: GenericContext) -> t.List[RuleViola for rule in self._underlying.values(): violation = rule(context).check_model(model) - + if isinstance(violation, RuleViolation): + violation = [violation] if violation: - violations.append(violation) + violations.extend(violation) return violations diff --git a/sqlmesh/core/linter/rule.py b/sqlmesh/core/linter/rule.py index da33df2124..6e63dd2ee6 100644 --- a/sqlmesh/core/linter/rule.py +++ b/sqlmesh/core/linter/rule.py @@ -70,7 +70,9 @@ def __init__(self, context: GenericContext): self.context = context @abc.abstractmethod - def check_model(self, model: Model) -> t.Optional[RuleViolation]: + def check_model( + self, model: Model + ) -> t.Optional[t.Union[RuleViolation, t.List[RuleViolation]]]: """The evaluation function that'll check for a violation of this rule.""" @property diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py index f16bb5d111..ada3334982 100644 --- a/sqlmesh/core/linter/rules/builtin.py +++ b/sqlmesh/core/linter/rules/builtin.py @@ -11,6 +11,8 @@ from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit from sqlmesh.core.linter.definition import RuleSet from sqlmesh.core.model import Model, SqlModel, ExternalModel +from sqlmesh.core.model import Model, SqlModel, ExternalModel +from sqlmesh.core.linter.rules.helpers.lineage import find_external_model_ranges class NoSelectStar(Rule): @@ -110,18 +112,22 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]: return self.violation() -class NoMissingExternalModels(Rule): +class NoUnregisteredExternalModels(Rule): """All external models must be registered in the external_models.yaml file""" - def check_model(self, model: Model) -> t.Optional[RuleViolation]: - # Ignore external models themselves, because either they are registered, - # and if they are not, they will be caught as referenced in another model. + def check_model( + self, model: Model + ) -> t.Optional[t.Union[RuleViolation, t.List[RuleViolation]]]: + depends_on = model.depends_on + + # Ignore external models themselves, because either they are registered + # if they are not, they will be caught as referenced in another model. if isinstance(model, ExternalModel): return None - # Handle other models that may refer to the external models. + # Handle other models that are referring to them not_registered_external_models: t.Set[str] = set() - for depends_on_model in model.depends_on: + for depends_on_model in depends_on: existing_model = self.context.get_model(depends_on_model) if existing_model is None: not_registered_external_models.add(depends_on_model) @@ -129,12 +135,59 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]: if not not_registered_external_models: return None + path = model._path + # For SQL models, try to do better than just raise it + if isinstance(model, SqlModel) and path is not None and str(path).endswith(".sql"): + external_model_ranges = self.find_external_model_ranges( + not_registered_external_models, model + ) + if external_model_ranges is None: + return RuleViolation( + rule=self, + violation_msg=f"Model '{model.fqn}' depends on unregistered external models: " + f"{', '.join(m for m in not_registered_external_models)}. " + "Please register them in the external_models.yaml file.", + ) + + outs: t.List[RuleViolation] = [] + for external_model in not_registered_external_models: + external_model_range = external_model_ranges.get(external_model) + if external_model_range: + outs.extend( + RuleViolation( + rule=self, + violation_msg=f"Model '{model.fqn}' depends on unregistered external model: " + f"{external_model}. Please register it in the external_models.yaml file.", + violation_range=target, + ) + for target in external_model_range + ) + else: + outs.append( + RuleViolation( + rule=self, + violation_msg=f"Model '{model.fqn}' depends on unregistered external model: " + f"{external_model}. Please register it in the external_models.yaml file.", + ) + ) + + return outs + return RuleViolation( rule=self, violation_msg=f"Model '{model.name}' depends on unregistered external models: " f"{', '.join(m for m in not_registered_external_models)}. " - "Please register them in the external models file. This can be done by running 'sqlmesh create_external_models'.", + "Please register them in the external_models.yaml file.", ) + def find_external_model_ranges( + self, external_models_not_registered: t.Set[str], model: SqlModel + ) -> t.Optional[t.Dict[str, t.List[Range]]]: + """Returns a map of external model names to their ranges found in the query. + + It returns a dictionary of fqn to a list of ranges where the external model + """ + return find_external_model_ranges(self.context, external_models_not_registered, model) + BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,))) diff --git a/sqlmesh/core/linter/rules/helpers/__init__.py b/sqlmesh/core/linter/rules/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sqlmesh/core/linter/rules/helpers/lineage.py b/sqlmesh/core/linter/rules/helpers/lineage.py new file mode 100644 index 0000000000..e6a6d68036 --- /dev/null +++ b/sqlmesh/core/linter/rules/helpers/lineage.py @@ -0,0 +1,415 @@ +import typing as t +from pathlib import Path + +from pydantic import Field +from sqlglot.optimizer import build_scope +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers + +from sqlmesh.core.dialect import normalize_model_name +from sqlmesh.core.linter.helpers import TokenPositionDetails +from sqlmesh.core.linter.rule import Range, Position +from sqlmesh.core.linter.rules.helpers.yaml import _get_yaml_model_range +from sqlmesh.core.model import SqlModel, ExternalModel +from sqlglot import exp + +from sqlmesh.utils.pydantic import PydanticModel + +if t.TYPE_CHECKING: + from sqlmesh.core.context import GenericContext, Context + + +class LSPModelReference(PydanticModel): + """A LSP reference to a model, excluding external models.""" + + type: t.Literal["model"] = "model" + path: Path + range: Range + markdown_description: t.Optional[str] = None + + +class LSPExternalModelReference(PydanticModel): + """A LSP reference to an external model.""" + + type: t.Literal["external_model"] = "external_model" + + model_name: str + path: Path + range: Range + markdown_description: t.Optional[str] = None + target_range: t.Optional[Range] = None + + +class LSPCteReference(PydanticModel): + """A LSP reference to a CTE.""" + + type: t.Literal["cte"] = "cte" + path: Path + range: Range + target_range: Range + + +class LSPMacroReference(PydanticModel): + """A LSP reference to a macro.""" + + type: t.Literal["macro"] = "macro" + path: Path + range: Range + target_range: Range + markdown_description: t.Optional[str] = None + + +Reference = t.Annotated[ + t.Union[LSPModelReference, LSPCteReference, LSPMacroReference, LSPExternalModelReference], + Field(discriminator="type"), +] + + +def find_external_model_ranges( + context: "GenericContext", + external_models_not_registered: t.Set[str], + model: SqlModel, +) -> t.Optional[t.Dict[str, t.List[Range]]]: + """Returns a map of external model names to their ranges found in the query. + + It returns a dictionary of fqn to a list of ranges where the external model + """ + path = model._path + if path is None or not str(path).endswith(".sql"): + return None + + depends_on = model.depends_on + query = model.query + with open(path, "r", encoding="utf-8") as file: + content = file.readlines() + + references = extract_references_from_query( + context=context, # type: ignore + read_file=content, + query=query, + dialect=model.dialect, + depends_on=depends_on, + document_path=path, + ) + external_model_references = [ + reference for reference in references if isinstance(reference, LSPExternalModelReference) + ] + if not external_model_references: + return None + + external_model_references_filtered = [ + reference + for reference in external_model_references + if reference.model_name in external_models_not_registered + ] + if not external_model_references_filtered: + return None + external_model_ranges: t.Dict[str, t.List[Range]] = {} + for reference in external_model_references_filtered: + if reference.model_name not in external_model_ranges: + external_model_ranges[reference.model_name] = [] + # Convert LSP Range to linter Range + lsp_range = reference.range + linter_range = Range( + start=Position(line=lsp_range.start.line, character=lsp_range.start.character), + end=Position(line=lsp_range.end.line, character=lsp_range.end.character), + ) + external_model_ranges[reference.model_name].append(linter_range) + return external_model_ranges + + +def extract_references_from_query( + query: exp.Expression, + context: "Context", + document_path: Path, + read_file: t.List[str], + depends_on: t.Set[str], + dialect: t.Optional[str] = None, +) -> t.List[Reference]: + """ + Extract references from a SQL query, including CTEs and external + models. + """ + references: t.List[Reference] = [] + # Build a scope tree to properly handle nested CTEs + try: + query = normalize_identifiers(query.copy(), dialect=dialect) + root_scope = build_scope(query) + except Exception: + root_scope = None + + if root_scope: + # Traverse all scopes to find CTE definitions and table references + for scope in root_scope.traverse(): + for table in scope.tables: + table_name = table.name + + # Check if this table reference is a CTE in the current scope + if cte_scope := scope.cte_sources.get(table_name): + cte = cte_scope.expression.parent + alias = cte.args["alias"] + if isinstance(alias, exp.TableAlias): + identifier = alias.this + if isinstance(identifier, exp.Identifier): + target_range = TokenPositionDetails.from_meta(identifier.meta).to_range( + read_file + ) + table_range = TokenPositionDetails.from_meta(table.this.meta).to_range( + read_file + ) + + references.append( + LSPCteReference( + path=document_path, # Same file + range=table_range, + target_range=target_range, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=table.name, + read_file=read_file, + referenced_model_path=document_path, + description="", + reference_type="cte", + cte_target_range=target_range, + ) + references.extend(column_references) + continue + + # For non-CTE tables, process as before (external model references) + # Normalize the table reference + unaliased = table.copy() + if unaliased.args.get("alias") is not None: + unaliased.set("alias", None) + reference_name = unaliased.sql(dialect=dialect) + try: + normalized_reference_name = normalize_model_name( + reference_name, + default_catalog=context.default_catalog, + dialect=dialect, + ) + if normalized_reference_name not in depends_on: + continue + except Exception: + # Skip references that cannot be normalized + continue + + # Get the referenced model uri + referenced_model = context.get_model( + model_or_snapshot=normalized_reference_name, raise_if_missing=False + ) + if referenced_model is None: + continue + referenced_model_path = referenced_model._path + if referenced_model_path is None: + continue + # Check whether the path exists + if not referenced_model_path.is_file(): + continue + + # Extract metadata for positioning + table_meta = TokenPositionDetails.from_meta(table.this.meta) + table_range_sqlmesh = table_meta.to_range(read_file) + start_pos_sqlmesh = table_range_sqlmesh.start + end_pos_sqlmesh = table_range_sqlmesh.end + + # If there's a catalog or database qualifier, adjust the start position + catalog_or_db = table.args.get("catalog") or table.args.get("db") + if catalog_or_db is not None: + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) + catalog_or_db_range_sqlmesh = catalog_or_db_meta.to_range(read_file) + start_pos_sqlmesh = catalog_or_db_range_sqlmesh.start + + description = generate_markdown_description(referenced_model) + + # For external models in YAML files, find the specific model block + if isinstance(referenced_model, ExternalModel): + yaml_target_range: t.Optional[Range] = None + if ( + referenced_model_path.suffix in (".yaml", ".yml") + and referenced_model_path.is_file() + ): + yaml_target_range = _get_yaml_model_range( + referenced_model_path, referenced_model.name + ) + references.append( + LSPExternalModelReference( + model_name=normalized_reference_name, + path=referenced_model_path, + range=Range( + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, + ), + markdown_description=description, + target_range=yaml_target_range, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=normalized_reference_name, + read_file=read_file, + referenced_model_path=referenced_model_path, + description=description, + yaml_target_range=yaml_target_range, + reference_type="external_model", + default_catalog=context.default_catalog, + dialect=dialect, + ) + references.extend(column_references) + else: + references.append( + LSPModelReference( + path=referenced_model_path, + range=Range( + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, + ), + markdown_description=description, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=normalized_reference_name, + read_file=read_file, + referenced_model_path=referenced_model_path, + description=description, + reference_type="model", + default_catalog=context.default_catalog, + dialect=dialect, + ) + references.extend(column_references) + + return references + + +def _process_column_references( + scope: t.Any, + reference_name: str, + read_file: t.List[str], + referenced_model_path: Path, + description: t.Optional[str] = None, + yaml_target_range: t.Optional[Range] = None, + reference_type: t.Literal["model", "external_model", "cte"] = "model", + default_catalog: t.Optional[str] = None, + dialect: t.Optional[str] = None, + cte_target_range: t.Optional[Range] = None, +) -> t.List[Reference]: + """ + Process column references for a given table and create appropriate reference objects. + + Args: + scope: The SQL scope to search for columns + reference_name: The full reference name (may include database/catalog) + read_file: The file content as list of lines + referenced_model_uri: URI of the referenced model + description: Markdown description for the reference + yaml_target_range: Target range for external models (YAML files) + reference_type: Type of reference - "model", "external_model", or "cte" + default_catalog: Default catalog for normalization + dialect: SQL dialect for normalization + cte_target_range: Target range for CTE references + + Returns: + List of table references for column usages + """ + + references: t.List[Reference] = [] + for column in scope.find_all(exp.Column): + if column.table: + if reference_type == "cte": + if column.table == reference_name: + table_range = _get_column_table_range(column, read_file) + references.append( + LSPCteReference( + path=referenced_model_path, + range=table_range, + target_range=cte_target_range, + ) + ) + else: + table_parts = [part.sql(dialect) for part in column.parts[:-1]] + table_ref = ".".join(table_parts) + normalized_reference_name = normalize_model_name( + table_ref, + default_catalog=default_catalog, + dialect=dialect, + ) + if normalized_reference_name == reference_name: + table_range = _get_column_table_range(column, read_file) + if reference_type == "external_model": + references.append( + LSPExternalModelReference( + model_name=reference_name, + path=referenced_model_path, + range=table_range, + markdown_description=description, + target_range=yaml_target_range, + ) + ) + else: + references.append( + LSPModelReference( + path=referenced_model_path, + range=table_range, + markdown_description=description, + ) + ) + + return references + + +def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range: + """ + Get the range for a column's table reference, handling both simple and qualified table names. + + Args: + column: The column expression + read_file: The file content as list of lines + + Returns: + The Range covering the table reference in the column + """ + + table_parts = column.parts[:-1] + + start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file) + end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file) + + return Range( + start=start_range.start, + end=end_range.end, + ) + + +from sqlmesh.core.model.definition import ( + ExternalModel, + PythonModel, + SeedModel, + SqlModel, +) +import typing as t + + +def generate_markdown_description( + model: t.Union[SqlModel, ExternalModel, PythonModel, SeedModel], +) -> t.Optional[str]: + description = model.description + columns = model.columns_to_types + column_descriptions = model.column_descriptions + + if columns is None: + return description or None + + columns_table = "\n".join( + [ + f"| {column} | {column_type} | {column_descriptions.get(column, '')} |" + for column, column_type in columns.items() + ] + ) + + table_header = "| Column | Type | Description |\n|--------|------|-------------|\n" + columns_text = table_header + columns_table + return f"{description}\n\n{columns_text}" if description else columns_text diff --git a/sqlmesh/core/linter/rules/helpers/yaml.py b/sqlmesh/core/linter/rules/helpers/yaml.py new file mode 100644 index 0000000000..3def077374 --- /dev/null +++ b/sqlmesh/core/linter/rules/helpers/yaml.py @@ -0,0 +1,34 @@ +from pathlib import Path + +from ruamel.yaml import YAML + +from sqlmesh.core.linter.rule import Range, Position +import typing as t + + +def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]: + """ + Find the range of a specific model block in a YAML file. + + Args: + yaml_path: Path to the YAML file + model_name: Name of the model to find + + Returns: + The Range of the model block in the YAML file, or None if not found + """ + yaml = YAML() + with path.open("r", encoding="utf-8") as f: + data = yaml.load(f) + + if not isinstance(data, list): + return None + + for item in data: + if isinstance(item, dict) and item.get("name") == model_name: + # Get size of block by taking the earliest line/col in the items block and the last line/col of the block + position_data = item.lc.data["name"] # type: ignore + start = Position(line=position_data[2], character=position_data[3]) + end = Position(line=position_data[2], character=position_data[3] + len(item["name"])) + return Range(start=start, end=end) + return None diff --git a/sqlmesh/lsp/completions.py b/sqlmesh/lsp/completions.py index 0026260481..7c4dc8aa46 100644 --- a/sqlmesh/lsp/completions.py +++ b/sqlmesh/lsp/completions.py @@ -1,5 +1,7 @@ from functools import lru_cache from sqlglot import Dialect, Tokenizer + +from sqlmesh.core.linter.rules.helpers.lineage import generate_markdown_description from sqlmesh.lsp.custom import ( AllModelsResponse, MacroCompletion, @@ -8,7 +10,6 @@ from sqlmesh import macro import typing as t from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget -from sqlmesh.lsp.description import generate_markdown_description from sqlmesh.lsp.uri import URI diff --git a/sqlmesh/lsp/description.py b/sqlmesh/lsp/description.py deleted file mode 100644 index 768197742f..0000000000 --- a/sqlmesh/lsp/description.py +++ /dev/null @@ -1,29 +0,0 @@ -from sqlmesh.core.model.definition import ( - ExternalModel, - PythonModel, - SeedModel, - SqlModel, -) -import typing as t - - -def generate_markdown_description( - model: t.Union[SqlModel, ExternalModel, PythonModel, SeedModel], -) -> t.Optional[str]: - description = model.description - columns = model.columns_to_types - column_descriptions = model.column_descriptions - - if columns is None: - return description or None - - columns_table = "\n".join( - [ - f"| {column} | {column_type} | {column_descriptions.get(column, '')} |" - for column, column_type in columns.items() - ] - ) - - table_header = "| Column | Type | Description |\n|--------|------|-------------|\n" - columns_text = table_header + columns_table - return f"{description}\n\n{columns_text}" if description else columns_text diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index c257ccfaa1..407939ded6 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -16,6 +16,7 @@ from pygls.server import LanguageServer from sqlmesh._version import __version__ from sqlmesh.core.context import Context +from sqlmesh.core.linter.rules.helpers.lineage import LSPExternalModelReference from sqlmesh.lsp.api import ( API_FEATURE, ApiRequest, @@ -48,11 +49,11 @@ CustomMethod, ) from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic +from sqlmesh.lsp.helpers import to_lsp_range from sqlmesh.lsp.hints import get_hints from sqlmesh.lsp.reference import ( LSPCteReference, LSPModelReference, - LSPExternalModelReference, get_references, get_all_references, ) @@ -432,7 +433,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov kind=types.MarkupKind.Markdown, value=reference.markdown_description, ), - range=reference.range, + range=to_lsp_range(reference.range), ) except Exception as e: @@ -492,19 +493,19 @@ def goto_definition( end=types.Position(line=0, character=0), ) if reference.target_range is not None: - target_range = reference.target_range - target_selection_range = reference.target_range + target_range = to_lsp_range(reference.target_range) + target_selection_range = to_lsp_range(reference.target_range) else: # CTEs and Macros always have target_range - target_range = reference.target_range - target_selection_range = reference.target_range + target_range = to_lsp_range(reference.target_range) + target_selection_range = to_lsp_range(reference.target_range) location_links.append( types.LocationLink( - target_uri=reference.uri, + target_uri=URI.from_path(reference.path).value, target_selection_range=target_selection_range, target_range=target_range, - origin_selection_range=reference.range, + origin_selection_range=to_lsp_range(reference.range), ) ) return location_links @@ -524,7 +525,10 @@ def find_references( all_references = get_all_references(context, uri, params.position) # Convert references to Location objects - locations = [types.Location(uri=ref.uri, range=ref.range) for ref in all_references] + locations = [ + types.Location(uri=URI.from_path(ref.path).value, range=to_lsp_range(ref.range)) + for ref in all_references + ] return locations if locations else None except Exception as e: diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index 9f1215d9ca..8dfd9115d7 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -1,72 +1,28 @@ from lsprotocol.types import Range, Position import typing as t from pathlib import Path -from pydantic import Field from sqlmesh.core.audit import StandaloneAudit -from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.linter.helpers import ( TokenPositionDetails, ) -from sqlmesh.core.model.definition import SqlModel, ExternalModel +from sqlmesh.core.linter.rules.helpers.lineage import ( + extract_references_from_query, + LSPMacroReference, + LSPCteReference, + LSPModelReference, + Reference, +) +from sqlmesh.core.model.definition import SqlModel from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget from sqlglot import exp -from sqlmesh.lsp.description import generate_markdown_description -from sqlglot.optimizer.scope import build_scope -from sqlmesh.lsp.helpers import to_lsp_range, to_lsp_position +from sqlmesh.lsp.helpers import to_lsp_range from sqlmesh.lsp.uri import URI -from sqlmesh.utils.pydantic import PydanticModel -from sqlglot.optimizer.normalize_identifiers import normalize_identifiers import ast from sqlmesh.core.model import Model from sqlmesh import macro import inspect -from ruamel.yaml import YAML - - -class LSPModelReference(PydanticModel): - """A LSP reference to a model, excluding external models.""" - - type: t.Literal["model"] = "model" - uri: str - range: Range - markdown_description: t.Optional[str] = None - - -class LSPExternalModelReference(PydanticModel): - """A LSP reference to an external model.""" - - type: t.Literal["external_model"] = "external_model" - uri: str - range: Range - markdown_description: t.Optional[str] = None - target_range: t.Optional[Range] = None - - -class LSPCteReference(PydanticModel): - """A LSP reference to a CTE.""" - - type: t.Literal["cte"] = "cte" - uri: str - range: Range - target_range: Range - - -class LSPMacroReference(PydanticModel): - """A LSP reference to a macro.""" - - type: t.Literal["macro"] = "macro" - uri: str - range: Range - target_range: Range - markdown_description: t.Optional[str] = None - - -Reference = t.Annotated[ - t.Union[LSPModelReference, LSPCteReference, LSPMacroReference, LSPExternalModelReference], - Field(discriminator="type"), -] def by_position(position: Position) -> t.Callable[[Reference], bool]: @@ -81,7 +37,7 @@ def by_position(position: Position) -> t.Callable[[Reference], bool]: """ def contains_position(r: Reference) -> bool: - return _position_within_range(position, r.range) + return _position_within_range(position, to_lsp_range(r.range)) return contains_position @@ -167,169 +123,17 @@ def get_model_definitions_for_a_path( if file_path is None: return [] - # Find all possible references - references: t.List[Reference] = [] - with open(file_path, "r", encoding="utf-8") as file: read_file = file.readlines() - # Build a scope tree to properly handle nested CTEs - try: - query = normalize_identifiers(query.copy(), dialect=dialect) - root_scope = build_scope(query) - except Exception: - root_scope = None - - if root_scope: - # Traverse all scopes to find CTE definitions and table references - for scope in root_scope.traverse(): - for table in scope.tables: - table_name = table.name - - # Check if this table reference is a CTE in the current scope - if cte_scope := scope.cte_sources.get(table_name): - cte = cte_scope.expression.parent - alias = cte.args["alias"] - if isinstance(alias, exp.TableAlias): - identifier = alias.this - if isinstance(identifier, exp.Identifier): - target_range_sqlmesh = TokenPositionDetails.from_meta( - identifier.meta - ).to_range(read_file) - table_range_sqlmesh = TokenPositionDetails.from_meta( - table.this.meta - ).to_range(read_file) - - # Convert SQLMesh Range to LSP Range - target_range = to_lsp_range(target_range_sqlmesh) - table_range = to_lsp_range(table_range_sqlmesh) - - references.append( - LSPCteReference( - uri=document_uri.value, # Same file - range=table_range, - target_range=target_range, - ) - ) - - column_references = _process_column_references( - scope=scope, - reference_name=table.name, - read_file=read_file, - referenced_model_uri=document_uri, - description="", - reference_type="cte", - cte_target_range=target_range, - ) - references.extend(column_references) - continue - - # For non-CTE tables, process as before (external model references) - # Normalize the table reference - unaliased = table.copy() - if unaliased.args.get("alias") is not None: - unaliased.set("alias", None) - reference_name = unaliased.sql(dialect=dialect) - try: - normalized_reference_name = normalize_model_name( - reference_name, - default_catalog=lint_context.context.default_catalog, - dialect=dialect, - ) - if normalized_reference_name not in depends_on: - continue - except Exception: - # Skip references that cannot be normalized - continue - - # Get the referenced model uri - referenced_model = lint_context.context.get_model( - model_or_snapshot=normalized_reference_name, raise_if_missing=False - ) - if referenced_model is None: - continue - referenced_model_path = referenced_model._path - if referenced_model_path is None: - continue - # Check whether the path exists - if not referenced_model_path.is_file(): - continue - referenced_model_uri = URI.from_path(referenced_model_path) - - # Extract metadata for positioning - table_meta = TokenPositionDetails.from_meta(table.this.meta) - table_range_sqlmesh = table_meta.to_range(read_file) - start_pos_sqlmesh = table_range_sqlmesh.start - end_pos_sqlmesh = table_range_sqlmesh.end - - # If there's a catalog or database qualifier, adjust the start position - catalog_or_db = table.args.get("catalog") or table.args.get("db") - if catalog_or_db is not None: - catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) - catalog_or_db_range_sqlmesh = catalog_or_db_meta.to_range(read_file) - start_pos_sqlmesh = catalog_or_db_range_sqlmesh.start - - description = generate_markdown_description(referenced_model) - - # For external models in YAML files, find the specific model block - if isinstance(referenced_model, ExternalModel): - yaml_target_range: t.Optional[Range] = None - if ( - referenced_model_path.suffix in (".yaml", ".yml") - and referenced_model_path.is_file() - ): - yaml_target_range = _get_yaml_model_range( - referenced_model_path, referenced_model.name - ) - references.append( - LSPExternalModelReference( - uri=referenced_model_uri.value, - range=Range( - start=to_lsp_position(start_pos_sqlmesh), - end=to_lsp_position(end_pos_sqlmesh), - ), - markdown_description=description, - target_range=yaml_target_range, - ) - ) - - column_references = _process_column_references( - scope=scope, - reference_name=normalized_reference_name, - read_file=read_file, - referenced_model_uri=referenced_model_uri, - description=description, - yaml_target_range=yaml_target_range, - reference_type="external_model", - default_catalog=lint_context.context.default_catalog, - dialect=dialect, - ) - references.extend(column_references) - else: - references.append( - LSPModelReference( - uri=referenced_model_uri.value, - range=Range( - start=to_lsp_position(start_pos_sqlmesh), - end=to_lsp_position(end_pos_sqlmesh), - ), - markdown_description=description, - ) - ) - - column_references = _process_column_references( - scope=scope, - reference_name=normalized_reference_name, - read_file=read_file, - referenced_model_uri=referenced_model_uri, - description=description, - reference_type="model", - default_catalog=lint_context.context.default_catalog, - dialect=dialect, - ) - references.extend(column_references) - - return references + return extract_references_from_query( + query=query, + context=lint_context.context, + document_path=document_uri.to_path(), + read_file=read_file, + depends_on=depends_on, + dialect=dialect, + ) def get_macro_definitions_for_a_path( @@ -464,11 +268,8 @@ def get_macro_reference( if start_line is None or end_line is None or get_length_of_end_line is None: return None - # Create a reference to the macro definition - macro_uri = URI.from_path(path) - return LSPMacroReference( - uri=macro_uri.value, + path=path, range=to_lsp_range(macro_range), target_range=Range( start=Position(line=start_line - 1, character=0), @@ -501,7 +302,7 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio end_line_number = line_number + len(source_lines) - 1 return LSPMacroReference( - uri=URI.from_path(Path(filename)).value, + path=Path(filename), range=macro_range, target_range=Range( start=Position(line=line_number - 1, character=0), @@ -532,7 +333,7 @@ def get_model_find_all_references( model_at_position = next( filter( lambda ref: isinstance(ref, LSPModelReference) - and _position_within_range(position, ref.range), + and _position_within_range(position, to_lsp_range(ref.range)), get_model_definitions_for_a_path(lint_context, document_uri), ), None, @@ -543,12 +344,10 @@ def get_model_find_all_references( assert isinstance(model_at_position, LSPModelReference) # for mypy - target_model_uri = model_at_position.uri - # Start with the model definition all_references: t.List[LSPModelReference] = [ LSPModelReference( - uri=model_at_position.uri, + path=model_at_position.path, range=Range( start=Position(line=0, character=0), end=Position(line=0, character=0), @@ -559,7 +358,7 @@ def get_model_find_all_references( # Then add references from the current file current_file_refs = filter( - lambda ref: isinstance(ref, LSPModelReference) and ref.uri == target_model_uri, + lambda ref: isinstance(ref, LSPModelReference) and ref.path == path, get_model_definitions_for_a_path(lint_context, document_uri), ) @@ -568,7 +367,7 @@ def get_model_find_all_references( all_references.append( LSPModelReference( - uri=document_uri.value, + path=document_uri.to_path(), range=ref.range, markdown_description=ref.markdown_description, ) @@ -584,7 +383,7 @@ def get_model_find_all_references( # Get model references that point to the target model matching_refs = filter( - lambda ref: isinstance(ref, LSPModelReference) and ref.uri == target_model_uri, + lambda ref: isinstance(ref, LSPModelReference) and ref.path == path, get_model_definitions_for_a_path(lint_context, file_uri), ) @@ -593,7 +392,7 @@ def get_model_find_all_references( all_references.append( LSPModelReference( - uri=file_uri.value, + path=file_uri.to_path(), range=ref.range, markdown_description=ref.markdown_description, ) @@ -632,11 +431,11 @@ def get_cte_references( target_cte_definition_range = None for ref in cte_references: # Check if cursor is on a CTE usage - if _position_within_range(position, ref.range): + if _position_within_range(position, to_lsp_range(ref.range)): target_cte_definition_range = ref.target_range break # Check if cursor is on the CTE definition - elif _position_within_range(position, ref.target_range): + elif _position_within_range(position, to_lsp_range(ref.target_range)): target_cte_definition_range = ref.target_range break @@ -646,7 +445,7 @@ def get_cte_references( # Add the CTE definition matching_references = [ LSPCteReference( - uri=document_uri.value, + path=document_uri.to_path(), range=target_cte_definition_range, target_range=target_cte_definition_range, ) @@ -657,7 +456,7 @@ def get_cte_references( if ref.target_range == target_cte_definition_range: matching_references.append( LSPCteReference( - uri=document_uri.value, + path=document_uri.to_path(), range=ref.range, target_range=ref.target_range, ) @@ -686,7 +485,7 @@ def get_macro_find_all_references( macro_at_position = next( filter( lambda ref: isinstance(ref, LSPMacroReference) - and _position_within_range(position, ref.range), + and _position_within_range(position, to_lsp_range(ref.range)), get_macro_definitions_for_a_path(lsp_context, document_uri), ), None, @@ -697,13 +496,13 @@ def get_macro_find_all_references( assert isinstance(macro_at_position, LSPMacroReference) # for mypy - target_macro_uri = macro_at_position.uri + target_macro_path = macro_at_position.path target_macro_target_range = macro_at_position.target_range # Start with the macro definition all_references: t.List[LSPMacroReference] = [ LSPMacroReference( - uri=target_macro_uri, + path=macro_at_position.path, range=target_macro_target_range, target_range=target_macro_target_range, markdown_description=None, @@ -717,7 +516,7 @@ def get_macro_find_all_references( # Get macro references that point to the same macro definition matching_refs = filter( lambda ref: isinstance(ref, LSPMacroReference) - and ref.uri == target_macro_uri + and ref.path == target_macro_path and ref.target_range == target_macro_target_range, get_macro_definitions_for_a_path(lsp_context, file_uri), ) @@ -726,7 +525,7 @@ def get_macro_find_all_references( assert isinstance(ref, LSPMacroReference) # for mypy all_references.append( LSPMacroReference( - uri=file_uri.value, + path=file_uri.to_path(), range=ref.range, target_range=ref.target_range, markdown_description=ref.markdown_description, @@ -777,129 +576,3 @@ def _position_within_range(position: Position, range: Range) -> bool: range.end.line > position.line or (range.end.line == position.line and range.end.character >= position.character) ) - - -def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range: - """ - Get the range for a column's table reference, handling both simple and qualified table names. - - Args: - column: The column expression - read_file: The file content as list of lines - - Returns: - The Range covering the table reference in the column - """ - - table_parts = column.parts[:-1] - - start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file) - end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file) - - return Range( - start=to_lsp_position(start_range.start), - end=to_lsp_position(end_range.end), - ) - - -def _process_column_references( - scope: t.Any, - reference_name: str, - read_file: t.List[str], - referenced_model_uri: URI, - description: t.Optional[str] = None, - yaml_target_range: t.Optional[Range] = None, - reference_type: t.Literal["model", "external_model", "cte"] = "model", - default_catalog: t.Optional[str] = None, - dialect: t.Optional[str] = None, - cte_target_range: t.Optional[Range] = None, -) -> t.List[Reference]: - """ - Process column references for a given table and create appropriate reference objects. - - Args: - scope: The SQL scope to search for columns - reference_name: The full reference name (may include database/catalog) - read_file: The file content as list of lines - referenced_model_uri: URI of the referenced model - description: Markdown description for the reference - yaml_target_range: Target range for external models (YAML files) - reference_type: Type of reference - "model", "external_model", or "cte" - default_catalog: Default catalog for normalization - dialect: SQL dialect for normalization - cte_target_range: Target range for CTE references - - Returns: - List of table references for column usages - """ - - references: t.List[Reference] = [] - for column in scope.find_all(exp.Column): - if column.table: - if reference_type == "cte": - if column.table == reference_name: - table_range = _get_column_table_range(column, read_file) - references.append( - LSPCteReference( - uri=referenced_model_uri.value, - range=table_range, - target_range=cte_target_range, - ) - ) - else: - table_parts = [part.sql(dialect) for part in column.parts[:-1]] - table_ref = ".".join(table_parts) - normalized_reference_name = normalize_model_name( - table_ref, - default_catalog=default_catalog, - dialect=dialect, - ) - if normalized_reference_name == reference_name: - table_range = _get_column_table_range(column, read_file) - if reference_type == "external_model": - references.append( - LSPExternalModelReference( - uri=referenced_model_uri.value, - range=table_range, - markdown_description=description, - target_range=yaml_target_range, - ) - ) - else: - references.append( - LSPModelReference( - uri=referenced_model_uri.value, - range=table_range, - markdown_description=description, - ) - ) - - return references - - -def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]: - """ - Find the range of a specific model block in a YAML file. - - Args: - yaml_path: Path to the YAML file - model_name: Name of the model to find - - Returns: - The Range of the model block in the YAML file, or None if not found - """ - yaml = YAML() - with path.open("r", encoding="utf-8") as f: - data = yaml.load(f) - - if not isinstance(data, list): - return None - - for item in data: - if isinstance(item, dict) and item.get("name") == model_name: - # Get size of block by taking the earliest line/col in the items block and the last line/col of the block - position_data = item.lc.data["name"] # type: ignore - start = Position(line=position_data[2], character=position_data[3]) - end = Position(line=position_data[2], character=position_data[3] + len(item["name"])) - return Range(start=start, end=end) - return None diff --git a/sqlmesh/lsp/rename.py b/sqlmesh/lsp/rename.py index 0dbe2594ea..bd05c584ba 100644 --- a/sqlmesh/lsp/rename.py +++ b/sqlmesh/lsp/rename.py @@ -9,6 +9,7 @@ ) from sqlmesh.lsp.context import LSPContext +from sqlmesh.lsp.helpers import to_lsp_range from sqlmesh.lsp.reference import ( _position_within_range, get_cte_references, @@ -38,15 +39,17 @@ def prepare_rename( target_range = None for ref in cte_references: # Check if cursor is on a CTE usage - if _position_within_range(position, ref.range): + if _position_within_range(position, to_lsp_range(ref.range)): target_range = ref.target_range break # Check if cursor is on the CTE definition - elif _position_within_range(position, ref.target_range): + elif _position_within_range(position, to_lsp_range(ref.target_range)): target_range = ref.target_range break if target_range: - return PrepareRenameResult_Type1(range=target_range, placeholder="cte_name") + return PrepareRenameResult_Type1( + range=to_lsp_range(target_range), placeholder="cte_name" + ) # For now, only CTEs are supported return None @@ -90,12 +93,13 @@ def _rename_cte(cte_references: t.List[LSPCteReference], new_name: str) -> Works changes: t.Dict[str, t.List[TextEdit]] = {} for ref in cte_references: - uri = ref.uri + uri = URI.from_path(ref.path).value + if uri not in changes: changes[uri] = [] # Create a text edit for this reference - text_edit = TextEdit(range=ref.range, new_text=new_name) + text_edit = TextEdit(range=to_lsp_range(ref.range), new_text=new_name) changes[uri].append(text_edit) return WorkspaceEdit(changes=changes) @@ -130,7 +134,7 @@ def get_document_highlights( else DocumentHighlightKind.Read ) - highlights.append(DocumentHighlight(range=ref.range, kind=kind)) + highlights.append(DocumentHighlight(range=to_lsp_range(ref.range), kind=kind)) return highlights # For now, only CTEs are supported diff --git a/tests/core/linter/test_builtin.py b/tests/core/linter/test_builtin.py index 208b591a2d..86e17dde4d 100644 --- a/tests/core/linter/test_builtin.py +++ b/tests/core/linter/test_builtin.py @@ -48,3 +48,5 @@ def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None: "Model 'sushi.customers' depends on unregistered external models: " in lints[0].violation_msg ) + + assert lints[0].violation_range is not None diff --git a/tests/lsp/test_description.py b/tests/lsp/test_description.py index 054d55fecc..7a372586b3 100644 --- a/tests/lsp/test_description.py +++ b/tests/lsp/test_description.py @@ -1,5 +1,5 @@ from sqlmesh.core.context import Context -from sqlmesh.lsp.description import generate_markdown_description +from sqlmesh.core.linter.rules.helpers.lineage import generate_markdown_description def test_model_description() -> None: diff --git a/tests/lsp/test_reference.py b/tests/lsp/test_reference.py index f39bddc059..83995d1818 100644 --- a/tests/lsp/test_reference.py +++ b/tests/lsp/test_reference.py @@ -59,7 +59,7 @@ def test_reference_with_alias() -> None: with open(waiter_revenue_by_day_path, "r") as file: read_file = file.readlines() - assert references[0].uri.endswith("orders.py") + assert str(references[0].path).endswith("orders.py") assert get_string_from_range(read_file, references[0].range) == "sushi.orders" assert ( references[0].markdown_description diff --git a/tests/lsp/test_reference_external_model.py b/tests/lsp/test_reference_external_model.py index ebf6420934..95e91cc720 100644 --- a/tests/lsp/test_reference_external_model.py +++ b/tests/lsp/test_reference_external_model.py @@ -3,7 +3,7 @@ from sqlmesh.core.linter.helpers import read_range_from_file from sqlmesh.lsp.context import LSPContext, ModelTarget from sqlmesh.lsp.helpers import to_sqlmesh_range -from sqlmesh.lsp.reference import get_references, LSPExternalModelReference +from sqlmesh.lsp.reference import get_references from sqlmesh.lsp.uri import URI