diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 90894fd23d..2b40be0230 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -349,7 +349,7 @@ def _load(path: Path) -> t.List[Model]: for row in YAML().load(file.read()) ] except Exception as ex: - raise ConfigError(self._failed_to_load_model_error(path, ex)) + raise ConfigError(self._failed_to_load_model_error(path, ex), path) for path in paths_to_load: self._track_file(path) @@ -363,7 +363,8 @@ def _load(path: Path) -> t.List[Model]: raise ConfigError( self._failed_to_load_model_error( path, f"Duplicate external model name: '{model.name}'." - ) + ), + path, ) models[model.fqn] = model @@ -375,7 +376,8 @@ def _load(path: Path) -> t.List[Model]: raise ConfigError( self._failed_to_load_model_error( path, f"Duplicate external model name: '{model.name}'." - ) + ), + path, ) models.update({model.fqn: model}) @@ -402,13 +404,15 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: args = [k.strip() for k in line.split("==")] if len(args) != 2: raise ConfigError( - f"Invalid lock file entry '{line.strip()}'. Only 'dep==ver' is supported" + f"Invalid lock file entry '{line.strip()}'. Only 'dep==ver' is supported", + requirements_path, ) dep, ver = args other_ver = requirements.get(dep, ver) if ver != other_ver: raise ConfigError( - f"Conflicting requirement {dep}: {ver} != {other_ver}. Fix your {c.REQUIREMENTS} file." + f"Conflicting requirement {dep}: {ver} != {other_ver}. Fix your {c.REQUIREMENTS} file.", + requirements_path, ) requirements[dep] = ver @@ -619,13 +623,14 @@ def _load_sql_models( raise ConfigError( self._failed_to_load_model_error( path, f"Duplicate SQL model name: '{model.name}'." - ) + ), + path, ) elif model.enabled: model._path = path models[model.fqn] = model except Exception as ex: - raise ConfigError(self._failed_to_load_model_error(path, ex)) + raise ConfigError(self._failed_to_load_model_error(path, ex), path) return models @@ -678,7 +683,7 @@ def _load_python_models( if model.enabled: models[model.fqn] = model except Exception as ex: - raise ConfigError(self._failed_to_load_model_error(path, ex)) + raise ConfigError(self._failed_to_load_model_error(path, ex), path) finally: model_registry._dialect = None @@ -782,7 +787,9 @@ def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]: metric = load_metric_ddl(expression, path=path, dialect=dialect) metrics[metric.name] = metric except SqlglotError as ex: - raise ConfigError(f"Failed to parse metric definitions at '{path}': {ex}.") + raise ConfigError( + f"Failed to parse metric definitions at '{path}': {ex}.", path + ) return metrics @@ -1005,7 +1012,7 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: package=package, ) except Exception as e: - raise ConfigError(f"Failed to load macro file: {path}", e) + raise ConfigError(f"Failed to load macro file: {e}", path) self._macros_max_mtime = macros_max_mtime diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index c843213f2a..f35a08a28b 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -41,7 +41,7 @@ def make_python_env( macros: MacroRegistry, variables: t.Optional[t.Dict[str, t.Any]] = None, used_variables: t.Optional[t.Set[str]] = None, - path: t.Optional[str | Path] = None, + path: t.Optional[Path] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, strict_resolution: bool = True, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, @@ -265,12 +265,14 @@ def validate_extra_and_required_fields( klass: t.Type[PydanticModel], provided_fields: t.Set[str], entity_name: str, + path: t.Optional[Path] = None, ) -> None: missing_required_fields = klass.missing_required_fields(provided_fields) if missing_required_fields: field_names = "'" + "', '".join(missing_required_fields) + "'" raise_config_error( - f"Please add required field{'s' if len(missing_required_fields) > 1 else ''} {field_names} to the {entity_name}." + f"Please add required field{'s' if len(missing_required_fields) > 1 else ''} {field_names} to the {entity_name}.", + path, ) extra_fields = klass.extra_fields(provided_fields) @@ -293,7 +295,8 @@ def validate_extra_and_required_fields( similar_msg = "\n\n " + "\n ".join(similar) if similar else "" raise_config_error( - f"Invalid field name{'s' if len(extra_fields) > 1 else ''} present in the {entity_name}: {extra_field_names}{similar_msg}" + f"Invalid field name{'s' if len(extra_fields) > 1 else ''} present in the {entity_name}: {extra_field_names}{similar_msg}", + path, ) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 910e6eccc5..f6c83c85f7 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -2434,7 +2434,10 @@ def _create_model( **kwargs: t.Any, ) -> Model: validate_extra_and_required_fields( - klass, {"name", *kwargs} - {"grain", "table_properties"}, "MODEL block" + klass, + {"name", *kwargs} - {"grain", "table_properties"}, + "MODEL block", + path, ) for prop in PROPERTIES: diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py index 0d7ba16c10..30adfce5a2 100644 --- a/sqlmesh/lsp/context.py +++ b/sqlmesh/lsp/context.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from pathlib import Path -import uuid from sqlmesh.core.context import Context import typing as t @@ -35,14 +34,8 @@ class LSPContext: map: t.Dict[Path, t.Union[ModelTarget, AuditTarget]] _render_cache: t.Dict[Path, t.List[RenderModelEntry]] _lint_cache: t.Dict[Path, t.List[AnnotatedRuleViolation]] - _version_id: str - """ - This is a version ID for the context. It is used to track changes to the context. It can be used to - return a version number to the LSP client. - """ def __init__(self, context: Context) -> None: - self._version_id = str(uuid.uuid4()) self.context = context self._render_cache = {} self._lint_cache = {} @@ -70,11 +63,6 @@ def __init__(self, context: Context) -> None: **audit_map, } - @property - def version_id(self) -> str: - """Get the version ID for the context.""" - return self._version_id - def render_model(self, uri: URI) -> t.List[RenderModelEntry]: """Get rendered models for a file, using cache when available. diff --git a/sqlmesh/lsp/errors.py b/sqlmesh/lsp/errors.py new file mode 100644 index 0000000000..a9e778a555 --- /dev/null +++ b/sqlmesh/lsp/errors.py @@ -0,0 +1,51 @@ +from lsprotocol.types import Diagnostic, DiagnosticSeverity, Range, Position + +from sqlmesh.lsp.uri import URI +from sqlmesh.utils.errors import ( + ConfigError, +) +import typing as t + +ContextFailedError = t.Union[str, ConfigError, Exception] + + +def context_error_to_diagnostic( + error: t.Union[Exception, ContextFailedError], + uri_filter: t.Optional[URI] = None, +) -> t.Tuple[t.Optional[t.Tuple[str, Diagnostic]], ContextFailedError]: + """ + Convert an error to a diagnostic message. + If the error is a ConfigError, it will be converted to a diagnostic message. + + uri_filter is used to filter diagnostics by URI. If present, only diagnostics + with a matching URI will be returned. + """ + if isinstance(error, ConfigError): + return config_error_to_diagnostic(error), error + return None, str(error) + + +def config_error_to_diagnostic( + error: ConfigError, + uri_filter: t.Optional[URI] = None, +) -> t.Optional[t.Tuple[str, Diagnostic]]: + if error.location is None: + return None + uri = URI.from_path(error.location).value + if uri_filter and uri != uri_filter.value: + return None + return uri, Diagnostic( + range=Range( + start=Position( + line=0, + character=0, + ), + end=Position( + line=0, + character=0, + ), + ), + message=str(error), + severity=DiagnosticSeverity.Error, + source="SQLMesh", + ) diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 0082c4a911..771ebd19c1 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -6,6 +6,7 @@ import typing as t from pathlib import Path import urllib.parse +import uuid from lsprotocol import types from lsprotocol.types import ( @@ -46,6 +47,7 @@ FormatProjectResponse, CustomMethod, ) +from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic from sqlmesh.lsp.hints import get_hints from sqlmesh.lsp.reference import ( LSPCteReference, @@ -56,17 +58,18 @@ ) from sqlmesh.lsp.rename import prepare_rename, rename_symbol, get_document_highlights from sqlmesh.lsp.uri import URI +from sqlmesh.utils.errors import ConfigError from web.server.api.endpoints.lineage import column_lineage, model_lineage from web.server.api.endpoints.models import get_models from typing import Union -from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass class NoContext: """State when no context has been attempted to load.""" - pass + version_id: str = field(default_factory=lambda: str(uuid.uuid4())) @dataclass @@ -74,14 +77,16 @@ class ContextLoaded: """State when context has been successfully loaded.""" lsp_context: LSPContext + version_id: str = field(default_factory=lambda: str(uuid.uuid4())) @dataclass class ContextFailed: """State when context failed to load with an error message.""" - error_message: str + error: ContextFailedError context: t.Optional[Context] = None + version_id: str = field(default_factory=lambda: str(uuid.uuid4())) ContextState = Union[NoContext, ContextLoaded, ContextFailed] @@ -110,7 +115,7 @@ def __init__( self._supported_custom_methods: t.Dict[ str, t.Callable[ - # mypy unable to recognise the base class + # mypy unable to recognize the base class [LanguageServer, t.Any], t.Any, ], @@ -223,9 +228,8 @@ def _reload_context_and_publish_diagnostics( ) -> None: """Helper method to reload context and publish diagnostics.""" if isinstance(self.context_state, NoContext): - return - - if isinstance(self.context_state, ContextFailed): + pass + elif isinstance(self.context_state, ContextFailed): if self.context_state.context: try: self.context_state.context.load() @@ -235,14 +239,17 @@ def _reload_context_and_publish_diagnostics( ) except Exception as e: ls.log_trace(f"Error loading context: {e}") - if not isinstance(self.context_state, ContextFailed): - raise Exception("Context state should be failed") - self.context_state = ContextFailed( - error_message=str(e), context=self.context_state.context + context = ( + self.context_state.context + if hasattr(self.context_state, "context") + else None ) - return + self.context_state = ContextFailed(error=e, context=context) else: - # If there's no context, try to create one from scratch + # If there's no context, reset to NoContext and try to create one from scratch + ls.log_trace("No partial context available, attempting fresh creation") + self.context_state = NoContext() + self.has_raised_loading_error = False # Reset error flag to show new errors try: self._ensure_context_for_document(uri) # If successful, context_state will be ContextLoaded @@ -253,43 +260,42 @@ def _reload_context_and_publish_diagnostics( ) except Exception as e: ls.log_trace(f"Still cannot load context: {e}") - return - - # Reload the context if it was successfully loaded - try: - context = self.context_state.lsp_context.context - context.load() - # Create new LSPContext which will have fresh, empty caches - self.context_state = ContextLoaded(lsp_context=LSPContext(context)) - except Exception as e: - ls.log_trace(f"Error loading context: {e}") - self.context_state = ContextFailed( - error_message=str(e), context=self.context_state.lsp_context.context - ) - return + # The error will be stored in context_state by _ensure_context_for_document + else: + # Reload the context if it was successfully loaded + try: + context = self.context_state.lsp_context.context + context.load() + # Create new LSPContext which will have fresh, empty caches + self.context_state = ContextLoaded(lsp_context=LSPContext(context)) + except Exception as e: + ls.log_trace(f"Error loading context: {e}") + self.context_state = ContextFailed( + error=e, context=self.context_state.lsp_context.context + ) # Send a workspace diagnostic refresh request to the client. This is used to notify the client that the diagnostics have changed. ls.lsp.send_request( types.WORKSPACE_DIAGNOSTIC_REFRESH, WorkspaceDiagnosticRefreshRequest( - id=self.context_state.lsp_context.version_id, + id=self.context_state.version_id, ), ) - ls.lsp.send_request( types.WORKSPACE_INLAY_HINT_REFRESH, WorkspaceInlayHintRefreshRequest( - id=self.context_state.lsp_context.version_id, + id=self.context_state.version_id, ), ) # Only publish diagnostics if client doesn't support pull diagnostics if not self.client_supports_pull_diagnostics: - diagnostics = self.context_state.lsp_context.lint_model(uri) - ls.publish_diagnostics( - document_uri, - LSPContext.diagnostics_to_lsp_diagnostics(diagnostics), - ) + if hasattr(self.context_state, "lsp_context"): + diagnostics = self.context_state.lsp_context.lint_model(uri) + ls.publish_diagnostics( + document_uri, + LSPContext.diagnostics_to_lsp_diagnostics(diagnostics), + ) def _register_features(self) -> None: """Register LSP features on the internal LanguageServer instance.""" @@ -650,9 +656,21 @@ def workspace_diagnostic( return types.WorkspaceDiagnosticReport(items=items) except Exception as e: - ls.log_trace( - f"Error getting workspace diagnostics: {e}", - ) + ls.log_trace(f"Error getting workspace diagnostics: {e}") + error_diagnostic, error = context_error_to_diagnostic(e) + if error_diagnostic: + uri_value, unpacked_diagnostic = error_diagnostic + return types.WorkspaceDiagnosticReport( + items=[ + types.WorkspaceFullDocumentDiagnosticReport( + kind=types.DocumentDiagnosticReportKind.Full, + result_id=self.context_state.version_id, # No versioning, always fresh + uri=uri_value, + items=[unpacked_diagnostic], + ) + ] + ) + return types.WorkspaceDiagnosticReport(items=[]) @self.server.feature(types.TEXT_DOCUMENT_CODE_ACTION) @@ -759,17 +777,23 @@ def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic] context = self._context_get_or_load(uri) diagnostics = context.lint_model(uri) return LSPContext.diagnostics_to_lsp_diagnostics(diagnostics), 0 - except Exception: + except ConfigError as config_error: + diagnostic, error = context_error_to_diagnostic(config_error, uri_filter=uri) + if diagnostic: + return [diagnostic[1]], 0 return [], 0 def _context_get_or_load(self, document_uri: t.Optional[URI] = None) -> LSPContext: - if isinstance(self.context_state, ContextFailed): - raise RuntimeError(self.context_state.error_message) - if isinstance(self.context_state, NoContext): + state = self.context_state + if isinstance(state, ContextFailed): + if isinstance(state.error, str): + raise Exception(state.error) + raise state.error + if isinstance(state, NoContext): self._ensure_context_for_document(document_uri) - if not isinstance(self.context_state, ContextLoaded): - raise RuntimeError("Context is not loaded") - return self.context_state.lsp_context + if isinstance(state, ContextLoaded): + return state.lsp_context + raise RuntimeError("Context failed to load") def _ensure_context_for_document( self, @@ -866,7 +890,7 @@ def _create_lsp_context(self, paths: t.List[Path]) -> t.Optional[LSPContext]: context = self.context_state.lsp_context.context elif isinstance(self.context_state, ContextFailed) and self.context_state.context: context = self.context_state.context - self.context_state = ContextFailed(error_message=str(e), context=context) + self.context_state = ContextFailed(error=e, context=context) return None @staticmethod diff --git a/sqlmesh/utils/errors.py b/sqlmesh/utils/errors.py index a2156d0438..9974fdce0a 100644 --- a/sqlmesh/utils/errors.py +++ b/sqlmesh/utils/errors.py @@ -24,7 +24,12 @@ class SQLMeshError(Exception): class ConfigError(SQLMeshError): - pass + location: t.Optional[Path] = None + + def __init__(self, message: str | Exception, location: t.Optional[Path] = None) -> None: + super().__init__(message) + if location: + self.location = Path(location) if isinstance(location, str) else location class MissingDependencyError(SQLMeshError): @@ -188,12 +193,12 @@ class SignalEvalError(SQLMeshError): def raise_config_error( msg: str, - location: t.Optional[str | Path] = None, + location: t.Optional[Path] = None, error_type: t.Type[ConfigError] = ConfigError, ) -> None: if location: - raise error_type(f"{msg} at '{location}'") - raise error_type(msg) + raise error_type(f"{msg} at '{location}'", location) + raise error_type(msg, location=location) def raise_for_status(response: Response) -> None: