Skip to content

Commit 8847325

Browse files
committed
feat: helper to find key in model block
- to be used by linter in the future
1 parent 751c38d commit 8847325

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed

sqlmesh/core/linter/helpers.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,55 @@ def get_range_of_model_block(
158158
return Range(
159159
start=start_position.to_range(splitlines).start, end=end_position.to_range(splitlines).end
160160
)
161+
162+
163+
def get_range_of_a_key_in_model_block(
164+
sql: str,
165+
dialect: str,
166+
key: str,
167+
) -> t.Optional[Range]:
168+
"""
169+
Get the range of a specific key in the model block of an SQL file.
170+
"""
171+
tokens = tokenize(sql, dialect=dialect)
172+
if tokens is None:
173+
return None
174+
175+
# Find the start of the model block
176+
start_index = next(
177+
(
178+
i
179+
for i, t in enumerate(tokens)
180+
if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"
181+
),
182+
None,
183+
)
184+
end_index = next(
185+
(i for i, t in enumerate(tokens) if t.token_type is TokenType.SEMICOLON),
186+
None,
187+
)
188+
if start_index is None or end_index is None:
189+
return None
190+
if start_index >= end_index:
191+
return None
192+
193+
tokens_of_interest = tokens[start_index + 1 : end_index]
194+
# Find the key token
195+
key_token = next(
196+
(
197+
t
198+
for t in tokens_of_interest
199+
if t.token_type is TokenType.VAR and t.text.upper() == key.upper()
200+
),
201+
None,
202+
)
203+
if key_token is None:
204+
return None
205+
206+
position = TokenPositionDetails(
207+
line=key_token.line,
208+
col=key_token.col,
209+
start=key_token.start,
210+
end=key_token.end,
211+
)
212+
return position.to_range(sql.splitlines())

tests/core/linter/test_helpers.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from sqlmesh import Context
2-
from sqlmesh.core.linter.helpers import read_range_from_file, get_range_of_model_block
2+
from sqlmesh.core.linter.helpers import (
3+
read_range_from_file,
4+
get_range_of_model_block,
5+
get_range_of_a_key_in_model_block,
6+
)
37
from sqlmesh.core.model import SqlModel
48

59

@@ -34,3 +38,41 @@ def test_get_position_of_model_block():
3438
read_range = read_range_from_file(path, range)
3539
assert read_range.startswith("MODEL")
3640
assert read_range.endswith(";")
41+
42+
43+
def test_get_range_of_a_key_in_model_block_testing_on_sushi():
44+
context = Context(paths=["examples/sushi"])
45+
46+
sql_models = [
47+
model
48+
for model in context.models.values()
49+
if isinstance(model, SqlModel)
50+
and model._path is not None
51+
and str(model._path).endswith(".sql")
52+
]
53+
assert len(sql_models) > 0
54+
55+
for model in sql_models:
56+
possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"]
57+
58+
dialect = model.dialect
59+
assert dialect is not None
60+
61+
path = model._path
62+
assert path is not None
63+
64+
with open(path, "r", encoding="utf-8") as file:
65+
content = file.read()
66+
67+
count_properties_checked = 0
68+
69+
for key in possible_keys:
70+
range = get_range_of_a_key_in_model_block(content, dialect, key)
71+
72+
# Check that the range starts with the key and ends with ;
73+
if range:
74+
read_range = read_range_from_file(path, range)
75+
assert read_range.lower() == key.lower()
76+
count_properties_checked += 1
77+
78+
assert count_properties_checked > 0

0 commit comments

Comments
 (0)