Skip to content

Commit 349d420

Browse files
committed
feat(lsp): return load time errors as diagnostics
- adds path in places to make this as possible with - refreshes well on typing
1 parent 4da307a commit 349d420

File tree

7 files changed

+157
-72
lines changed

7 files changed

+157
-72
lines changed

sqlmesh/core/loader.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def _load(path: Path) -> t.List[Model]:
349349
for row in YAML().load(file.read())
350350
]
351351
except Exception as ex:
352-
raise ConfigError(self._failed_to_load_model_error(path, ex))
352+
raise ConfigError(self._failed_to_load_model_error(path, ex), path)
353353

354354
for path in paths_to_load:
355355
self._track_file(path)
@@ -363,7 +363,8 @@ def _load(path: Path) -> t.List[Model]:
363363
raise ConfigError(
364364
self._failed_to_load_model_error(
365365
path, f"Duplicate external model name: '{model.name}'."
366-
)
366+
),
367+
path,
367368
)
368369
models[model.fqn] = model
369370

@@ -375,7 +376,8 @@ def _load(path: Path) -> t.List[Model]:
375376
raise ConfigError(
376377
self._failed_to_load_model_error(
377378
path, f"Duplicate external model name: '{model.name}'."
378-
)
379+
),
380+
path,
379381
)
380382
models.update({model.fqn: model})
381383

@@ -402,13 +404,15 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
402404
args = [k.strip() for k in line.split("==")]
403405
if len(args) != 2:
404406
raise ConfigError(
405-
f"Invalid lock file entry '{line.strip()}'. Only 'dep==ver' is supported"
407+
f"Invalid lock file entry '{line.strip()}'. Only 'dep==ver' is supported",
408+
requirements_path,
406409
)
407410
dep, ver = args
408411
other_ver = requirements.get(dep, ver)
409412
if ver != other_ver:
410413
raise ConfigError(
411-
f"Conflicting requirement {dep}: {ver} != {other_ver}. Fix your {c.REQUIREMENTS} file."
414+
f"Conflicting requirement {dep}: {ver} != {other_ver}. Fix your {c.REQUIREMENTS} file.",
415+
requirements_path,
412416
)
413417
requirements[dep] = ver
414418

@@ -619,13 +623,14 @@ def _load_sql_models(
619623
raise ConfigError(
620624
self._failed_to_load_model_error(
621625
path, f"Duplicate SQL model name: '{model.name}'."
622-
)
626+
),
627+
path,
623628
)
624629
elif model.enabled:
625630
model._path = path
626631
models[model.fqn] = model
627632
except Exception as ex:
628-
raise ConfigError(self._failed_to_load_model_error(path, ex))
633+
raise ConfigError(self._failed_to_load_model_error(path, ex), path)
629634

630635
return models
631636

@@ -678,7 +683,7 @@ def _load_python_models(
678683
if model.enabled:
679684
models[model.fqn] = model
680685
except Exception as ex:
681-
raise ConfigError(self._failed_to_load_model_error(path, ex))
686+
raise ConfigError(self._failed_to_load_model_error(path, ex), path)
682687

683688
finally:
684689
model_registry._dialect = None
@@ -782,7 +787,9 @@ def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
782787
metric = load_metric_ddl(expression, path=path, dialect=dialect)
783788
metrics[metric.name] = metric
784789
except SqlglotError as ex:
785-
raise ConfigError(f"Failed to parse metric definitions at '{path}': {ex}.")
790+
raise ConfigError(
791+
f"Failed to parse metric definitions at '{path}': {ex}.", path
792+
)
786793

787794
return metrics
788795

@@ -1005,7 +1012,7 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
10051012
package=package,
10061013
)
10071014
except Exception as e:
1008-
raise ConfigError(f"Failed to load macro file: {path}", e)
1015+
raise ConfigError(f"Failed to load macro file: {e}", path)
10091016

10101017
self._macros_max_mtime = macros_max_mtime
10111018

sqlmesh/core/model/common.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def make_python_env(
4141
macros: MacroRegistry,
4242
variables: t.Optional[t.Dict[str, t.Any]] = None,
4343
used_variables: t.Optional[t.Set[str]] = None,
44-
path: t.Optional[str | Path] = None,
44+
path: t.Optional[Path] = None,
4545
python_env: t.Optional[t.Dict[str, Executable]] = None,
4646
strict_resolution: bool = True,
4747
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
@@ -265,12 +265,14 @@ def validate_extra_and_required_fields(
265265
klass: t.Type[PydanticModel],
266266
provided_fields: t.Set[str],
267267
entity_name: str,
268+
path: t.Optional[Path] = None,
268269
) -> None:
269270
missing_required_fields = klass.missing_required_fields(provided_fields)
270271
if missing_required_fields:
271272
field_names = "'" + "', '".join(missing_required_fields) + "'"
272273
raise_config_error(
273-
f"Please add required field{'s' if len(missing_required_fields) > 1 else ''} {field_names} to the {entity_name}."
274+
f"Please add required field{'s' if len(missing_required_fields) > 1 else ''} {field_names} to the {entity_name}.",
275+
path,
274276
)
275277

276278
extra_fields = klass.extra_fields(provided_fields)
@@ -293,7 +295,8 @@ def validate_extra_and_required_fields(
293295
similar_msg = "\n\n " + "\n ".join(similar) if similar else ""
294296

295297
raise_config_error(
296-
f"Invalid field name{'s' if len(extra_fields) > 1 else ''} present in the {entity_name}: {extra_field_names}{similar_msg}"
298+
f"Invalid field name{'s' if len(extra_fields) > 1 else ''} present in the {entity_name}: {extra_field_names}{similar_msg}",
299+
path,
297300
)
298301

299302

sqlmesh/core/model/definition.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2434,7 +2434,10 @@ def _create_model(
24342434
**kwargs: t.Any,
24352435
) -> Model:
24362436
validate_extra_and_required_fields(
2437-
klass, {"name", *kwargs} - {"grain", "table_properties"}, "MODEL block"
2437+
klass,
2438+
{"name", *kwargs} - {"grain", "table_properties"},
2439+
"MODEL block",
2440+
path,
24382441
)
24392442

24402443
for prop in PROPERTIES:

sqlmesh/lsp/context.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3-
import uuid
43
from sqlmesh.core.context import Context
54
import typing as t
65

@@ -35,14 +34,12 @@ class LSPContext:
3534
map: t.Dict[Path, t.Union[ModelTarget, AuditTarget]]
3635
_render_cache: t.Dict[Path, t.List[RenderModelEntry]]
3736
_lint_cache: t.Dict[Path, t.List[AnnotatedRuleViolation]]
38-
_version_id: str
3937
"""
4038
This is a version ID for the context. It is used to track changes to the context. It can be used to
4139
return a version number to the LSP client.
4240
"""
4341

4442
def __init__(self, context: Context) -> None:
45-
self._version_id = str(uuid.uuid4())
4643
self.context = context
4744
self._render_cache = {}
4845
self._lint_cache = {}
@@ -70,11 +67,6 @@ def __init__(self, context: Context) -> None:
7067
**audit_map,
7168
}
7269

73-
@property
74-
def version_id(self) -> str:
75-
"""Get the version ID for the context."""
76-
return self._version_id
77-
7870
def render_model(self, uri: URI) -> t.List[RenderModelEntry]:
7971
"""Get rendered models for a file, using cache when available.
8072

sqlmesh/lsp/errors.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from lsprotocol.types import Diagnostic, DiagnosticSeverity, Range, Position
2+
3+
from sqlmesh.lsp.uri import URI
4+
from sqlmesh.utils.errors import (
5+
ConfigError,
6+
)
7+
import typing as t
8+
9+
type ContextFailedError = str | ConfigError | Exception
10+
11+
12+
def context_error_to_diagnostic(
13+
error: t.Union[Exception, ContextFailedError],
14+
uri_filter: t.Optional[URI] = None,
15+
) -> t.Tuple[t.Optional[t.Tuple[str, Diagnostic]], ContextFailedError]:
16+
"""
17+
Convert an error to a diagnostic message.
18+
If the error is a ConfigError, it will be converted to a diagnostic message.
19+
20+
uri_filter is used to filter diagnostics by URI. If present, only diagnostics
21+
with a matching URI will be returned.
22+
"""
23+
if isinstance(error, ConfigError):
24+
return config_error_to_diagnostic(error), error
25+
return None, str(error)
26+
27+
28+
def config_error_to_diagnostic(
29+
error: ConfigError,
30+
uri_filter: t.Optional[URI] = None,
31+
) -> t.Optional[t.Tuple[str, Diagnostic]]:
32+
if error.location is None:
33+
return None
34+
uri = URI.from_path(error.location).value
35+
if uri_filter and uri != uri_filter.value:
36+
return None
37+
return uri, Diagnostic(
38+
range=Range(
39+
start=Position(
40+
line=0,
41+
character=0,
42+
),
43+
end=Position(
44+
line=0,
45+
character=0,
46+
),
47+
),
48+
message=str(error),
49+
severity=DiagnosticSeverity.Error,
50+
source="SQLMesh",
51+
)

0 commit comments

Comments
 (0)