Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions sqlmesh/core/linter/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,55 @@ def get_range_of_model_block(
return Range(
start=start_position.to_range(splitlines).start, end=end_position.to_range(splitlines).end
)


def get_range_of_a_key_in_model_block(
    sql: str,
    dialect: str,
    key: str,
) -> t.Optional[Range]:
    """
    Get the range of a specific key in the model block of an SQL file.

    The model block is taken to start at the first ``MODEL`` token and to end at
    the first semicolon that follows it. Within that span, the first token whose
    text matches ``key`` (case-insensitively) in property-key position is
    selected.

    Args:
        sql: The full text of the SQL file.
        dialect: The dialect used to tokenize ``sql``.
        key: The property key to look for, e.g. ``"name"``.

    Returns:
        The range covering the key token, or None if the SQL yields no tokens,
        no terminated model block is found, or the key is not present as a
        property key.
    """
    tokens = tokenize(sql, dialect=dialect)
    if not tokens:
        return None

    # Find the start of the model block
    model_index = next(
        (
            index
            for index, token in enumerate(tokens)
            if token.token_type is TokenType.VAR and token.text.upper() == "MODEL"
        ),
        None,
    )
    if model_index is None:
        return None

    # The block ends at the first semicolon *after* MODEL; searching from the
    # start of the file would pick up a semicolon that precedes the block.
    end_index = next(
        (
            index
            for index in range(model_index + 1, len(tokens))
            if tokens[index].token_type is TokenType.SEMICOLON
        ),
        None,
    )
    if end_index is None:
        return None

    # Review feedback on the original first-match lookup noted that the key's
    # text can also appear elsewhere in the MODEL block, and suggested checking
    # the previous token since properties are comma-separated. A key is
    # therefore only accepted directly inside the MODEL parentheses (nesting
    # depth 1) and immediately after the opening parenthesis or a comma, so
    # identifiers inside nested values such as `columns (name TEXT)` are not
    # mistaken for the property key.
    openers = (TokenType.L_PAREN, TokenType.L_BRACKET, TokenType.L_BRACE)
    closers = (TokenType.R_PAREN, TokenType.R_BRACKET, TokenType.R_BRACE)
    key_predecessors = (TokenType.L_PAREN, TokenType.COMMA)
    wanted = key.upper()

    depth = 0
    key_token = None
    for index in range(model_index + 1, end_index):
        token = tokens[index]
        token_type = token.token_type
        if token_type in openers:
            depth += 1
        elif token_type in closers:
            depth -= 1
        elif (
            depth == 1
            and token_type is TokenType.VAR
            and token.text.upper() == wanted
            and tokens[index - 1].token_type in key_predecessors
        ):
            key_token = token
            break

    if key_token is None:
        return None

    position = TokenPositionDetails(
        line=key_token.line,
        col=key_token.col,
        start=key_token.start,
        end=key_token.end,
    )
    return position.to_range(sql.splitlines())
44 changes: 43 additions & 1 deletion tests/core/linter/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from sqlmesh import Context
from sqlmesh.core.linter.helpers import read_range_from_file, get_range_of_model_block
from sqlmesh.core.linter.helpers import (
read_range_from_file,
get_range_of_model_block,
get_range_of_a_key_in_model_block,
)
from sqlmesh.core.model import SqlModel


Expand Down Expand Up @@ -34,3 +38,41 @@ def test_get_position_of_model_block():
read_range = read_range_from_file(path, range)
assert read_range.startswith("MODEL")
assert read_range.endswith(";")


def test_get_range_of_a_key_in_model_block_testing_on_sushi():
    """
    Check key ranges against every SQL-file model of the sushi example project.

    For each model, every candidate key that is found must map to a range whose
    text in the file is exactly that key (ignoring case), and at least one
    candidate key must be found per model.
    """
    context = Context(paths=["examples/sushi"])

    sql_models = [
        model
        for model in context.models.values()
        if isinstance(model, SqlModel)
        and model._path is not None
        and str(model._path).endswith(".sql")
    ]
    assert len(sql_models) > 0

    # The candidate keys are the same for every model, so build the list once.
    possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"]

    for model in sql_models:
        dialect = model.dialect
        assert dialect is not None

        path = model._path
        assert path is not None

        with open(path, "r", encoding="utf-8") as file:
            content = file.read()

        count_properties_checked = 0

        for key in possible_keys:
            # Named key_range rather than range to avoid shadowing the builtin.
            key_range = get_range_of_a_key_in_model_block(content, dialect, key)

            # Not every model declares every key; when one is found, the text
            # at the returned range must be exactly the key itself.
            if key_range:
                read_range = read_range_from_file(path, key_range)
                assert read_range.lower() == key.lower()
                count_properties_checked += 1

        assert count_properties_checked > 0