Skip to content

Commit da3dad9

Browse files
committed
feat: add the ability to return range for key and value
1 parent 4106dcd commit da3dad9

File tree

2 files changed

+197
-68
lines changed

2 files changed

+197
-68
lines changed

sqlmesh/core/linter/helpers.py

Lines changed: 138 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from sqlmesh.core.linter.rule import Range, Position
44
from sqlmesh.utils.pydantic import PydanticModel
5-
from sqlglot import tokenize, TokenType
5+
from sqlglot import tokenize, TokenType, Token
66
import typing as t
77

88

@@ -122,57 +122,14 @@ def read_range_from_file(file: Path, text_range: Range) -> str:
122122
return read_range_from_string("".join(lines), text_range)
123123

124124

125-
def get_range_of_model_block(
126-
sql: str,
127-
dialect: str,
128-
) -> t.Optional[Range]:
129-
"""
130-
Get the range of the model block in an SQL file.
131-
"""
132-
tokens = tokenize(sql, dialect=dialect)
133-
134-
# Find start of the model block
135-
start = next(
136-
(t for t in tokens if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"),
137-
None,
138-
)
139-
end = next((t for t in tokens if t.token_type is TokenType.SEMICOLON), None)
140-
141-
if start is None or end is None:
142-
return None
143-
144-
start_position = TokenPositionDetails(
145-
line=start.line,
146-
col=start.col,
147-
start=start.start,
148-
end=start.end,
149-
)
150-
end_position = TokenPositionDetails(
151-
line=end.line,
152-
col=end.col,
153-
start=end.start,
154-
end=end.end,
155-
)
156-
157-
splitlines = sql.splitlines()
158-
return Range(
159-
start=start_position.to_range(splitlines).start,
160-
end=end_position.to_range(splitlines).end,
161-
)
162-
163-
164-
def get_range_of_a_key_in_model_block(
165-
sql: str,
166-
dialect: str,
167-
key: str,
168-
) -> t.Optional[Range]:
125+
def get_start_and_end_of_model_block(
126+
tokens: t.List[Token],
127+
) -> t.Optional[t.Tuple[int, int]]:
169128
"""
170-
Get the range of a specific key in the model block of an SQL file.
129+
Returns the token indices of the opening and closing parentheses of the MODEL block in an SQL file.
130+
The MODEL block is defined as the first occurrence of the keyword "MODEL" followed by
131+
an opening parenthesis and a closing parenthesis that matches the opening one.
171132
"""
172-
tokens = tokenize(sql, dialect=dialect)
173-
if not tokens:
174-
return None
175-
176133
# 1) Find the MODEL token
177134
try:
178135
model_idx = next(
@@ -216,6 +173,65 @@ def get_range_of_a_key_in_model_block(
216173
)
217174
except StopIteration:
218175
return None
176+
return (
177+
lparen_idx,
178+
rparen_idx,
179+
)
180+
181+
182+
def get_range_of_model_block(
    sql: str,
    dialect: str,
) -> t.Optional[Range]:
    """
    Get the range of the model block in an SQL file.

    The range runs from the token immediately before the opening parenthesis
    reported by ``get_start_and_end_of_model_block`` (presumably the MODEL
    keyword) to the token immediately after the matching closing parenthesis
    (presumably the terminating semicolon). If no token follows the closing
    parenthesis, the range ends at the closing parenthesis itself.

    Args:
        sql: The full SQL text to inspect.
        dialect: The SQL dialect passed to the tokenizer.

    Returns:
        The range covering the model block, or None if no model block is found.
    """
    tokens = tokenize(sql, dialect=dialect)

    block = get_start_and_end_of_model_block(tokens)
    if block is None:
        return None

    start_idx, end_idx = block

    # Clamp both neighbours to the token list: an index of -1 would silently
    # wrap around to the last token, and an index past the end would raise
    # IndexError when nothing follows the closing parenthesis.
    start = tokens[max(start_idx - 1, 0)]
    end = tokens[min(end_idx + 1, len(tokens) - 1)]

    start_position = TokenPositionDetails(
        line=start.line,
        col=start.col,
        start=start.start,
        end=start.end,
    )
    end_position = TokenPositionDetails(
        line=end.line,
        col=end.col,
        start=end.start,
        end=end.end,
    )

    # Both positions are resolved against the same split of the source text.
    splitlines = sql.splitlines()
    return Range(
        start=start_position.to_range(splitlines).start,
        end=end_position.to_range(splitlines).end,
    )
215+
216+
217+
def get_range_of_a_key_in_model_block(
218+
sql: str,
219+
dialect: str,
220+
key: str,
221+
) -> t.Optional[t.Tuple[Range, Range]]:
222+
"""
223+
Get the ranges of a specific key and its value in the MODEL block of an SQL file.
224+
225+
Returns a tuple of (key_range, value_range) if found, otherwise None.
226+
"""
227+
tokens = tokenize(sql, dialect=dialect)
228+
if not tokens:
229+
return None
230+
231+
block = get_start_and_end_of_model_block(tokens)
232+
if not block:
233+
return None
234+
(lparen_idx, rparen_idx) = block
219235

220236
# 4) Scan within the MODEL property list for the key at top-level (depth == 1)
221237
# Initialize depth to 1 since we're inside the first parentheses
@@ -237,17 +253,78 @@ def get_range_of_a_key_in_model_block(
237253
if depth == 1 and tt is TokenType.VAR and tok.text.upper() == key.upper():
238254
# Validate key position: it should immediately follow '(' or ',' at top level
239255
prev_idx = i - 1
240-
# Skip over non-significant tokens we don't want to gate on (e.g., comments)
241-
while prev_idx >= 0 and tokens[prev_idx].token_type in (TokenType.COMMENT,):
242-
prev_idx -= 1
243256
prev_tt = tokens[prev_idx].token_type if prev_idx >= 0 else None
244-
if prev_tt in (TokenType.L_PAREN, TokenType.COMMA):
245-
position = TokenPositionDetails(
246-
line=tok.line,
247-
col=tok.col,
248-
start=tok.start,
249-
end=tok.end,
250-
)
251-
return position.to_range(sql.splitlines())
257+
if prev_tt not in (TokenType.L_PAREN, TokenType.COMMA):
258+
continue
259+
260+
# Key range
261+
lines = sql.splitlines()
262+
key_start = TokenPositionDetails(
263+
line=tok.line, col=tok.col, start=tok.start, end=tok.end
264+
)
265+
key_range = key_start.to_range(lines)
266+
267+
# Find value start: the token immediately after the key
268+
value_start_idx = i + 1
269+
if value_start_idx >= rparen_idx:
270+
return None
271+
272+
# Walk to the end of the value expression: until top-level comma or closing paren
273+
# Track internal nesting for (), [], {}
274+
nested = 0
275+
j = value_start_idx
276+
value_end_idx = value_start_idx
277+
278+
def is_open(t: TokenType) -> bool:
279+
return t in (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET)
280+
281+
def is_close(t: TokenType) -> bool:
282+
return t in (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET)
283+
284+
while j < rparen_idx:
285+
ttype = tokens[j].token_type
286+
if is_open(ttype):
287+
nested += 1
288+
elif is_close(ttype):
289+
nested -= 1
290+
291+
# End of value: at top-level (nested == 0) encountering a comma or the end paren
292+
if nested == 0 and (
293+
ttype is TokenType.COMMA or (ttype is TokenType.R_PAREN and depth == 1)
294+
):
295+
# For comma, don't include it in the value range
296+
# For closing paren, include it only if it's part of the value structure
297+
if ttype is TokenType.COMMA:
298+
# Don't include the comma in the value range
299+
break
300+
else:
301+
# Include the closing parenthesis in the value range
302+
value_end_idx = j
303+
break
304+
305+
value_end_idx = j
306+
j += 1
307+
308+
value_start_tok = tokens[value_start_idx]
309+
value_end_tok = tokens[value_end_idx]
310+
311+
value_start_pos = TokenPositionDetails(
312+
line=value_start_tok.line,
313+
col=value_start_tok.col,
314+
start=value_start_tok.start,
315+
end=value_start_tok.end,
316+
)
317+
value_end_pos = TokenPositionDetails(
318+
line=value_end_tok.line,
319+
col=value_end_tok.col,
320+
start=value_end_tok.start,
321+
end=value_end_tok.end,
322+
)
323+
value_range = Range(
324+
start=value_start_pos.to_range(lines).start,
325+
end=value_end_pos.to_range(lines).end,
326+
)
327+
328+
return (key_range, value_range)
252329

253330
return None

tests/core/linter/test_helpers.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,17 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
5252
]
5353
assert len(sql_models) > 0
5454

55+
# Test that the function works for all keys in the model block
5556
for model in sql_models:
56-
possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"]
57+
possible_keys = [
58+
"name",
59+
"tags",
60+
"description",
61+
"column_descriptions",
62+
"owner",
63+
"cron",
64+
"dialect",
65+
]
5766

5867
dialect = model.dialect
5968
assert dialect is not None
@@ -67,12 +76,55 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
6776
count_properties_checked = 0
6877

6978
for key in possible_keys:
70-
range = get_range_of_a_key_in_model_block(content, dialect, key)
71-
72-
# Check that the range starts with the key and ends with ;
73-
if range:
74-
read_range = read_range_from_file(path, range)
75-
assert read_range.lower() == key.lower()
79+
ranges = get_range_of_a_key_in_model_block(content, dialect, key)
80+
81+
if ranges:
82+
key_range, value_range = ranges
83+
read_key = read_range_from_file(path, key_range)
84+
assert read_key.lower() == key.lower()
85+
# Value range should be non-empty
86+
read_value = read_range_from_file(path, value_range)
87+
assert len(read_value) > 0
7688
count_properties_checked += 1
7789

7890
assert count_properties_checked > 0
91+
92+
# Test that the function works for different kinds of value blocks
93+
tests = [
94+
("sushi.customers", "name", "sushi.customers"),
95+
(
96+
"sushi.customers",
97+
"tags",
98+
"(pii, fact)",
99+
),
100+
("sushi.customers", "description", "'Sushi customer data'"),
101+
(
102+
"sushi.customers",
103+
"column_descriptions",
104+
"( customer_id = 'customer_id uniquely identifies customers' )",
105+
),
106+
("sushi.customers", "owner", "jen"),
107+
("sushi.customers", "cron", "'@daily'"),
108+
]
109+
for model_name, key, value in tests:
110+
model = context.get_model(model_name)
111+
assert model is not None
112+
113+
dialect = model.dialect
114+
assert dialect is not None
115+
116+
path = model._path
117+
assert path is not None
118+
119+
with open(path, "r", encoding="utf-8") as file:
120+
content = file.read()
121+
122+
ranges = get_range_of_a_key_in_model_block(content, dialect, key)
123+
assert ranges is not None, f"Could not find key '{key}' in model '{model_name}'"
124+
125+
key_range, value_range = ranges
126+
read_key = read_range_from_file(path, key_range)
127+
assert read_key.lower() == key.lower()
128+
129+
read_value = read_range_from_file(path, value_range)
130+
assert read_value == value

0 commit comments

Comments
 (0)