Skip to content

Commit 9728f66

Browse files
committed
feat: add the ability to return range for key and value
1 parent d8aa539 commit 9728f66

File tree

2 files changed

+174
-20
lines changed

2 files changed

+174
-20
lines changed

sqlmesh/core/linter/helpers.py

Lines changed: 115 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def to_range(self, read_file: t.Optional[t.List[str]]) -> Range:
5353
)
5454

5555
if read_file is None:
56-
raise ValueError("read_file must be provided when start and end positions differ.")
56+
raise ValueError(
57+
"read_file must be provided when start and end positions differ."
58+
)
5759

5860
# Convert from 1-indexed to 0-indexed for line only
5961
end_line_0 = self.line - 1
@@ -133,7 +135,11 @@ def get_range_of_model_block(
133135

134136
# Find start of the model block
135137
start = next(
136-
(t for t in tokens if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"),
138+
(
139+
t
140+
for t in tokens
141+
if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"
142+
),
137143
None,
138144
)
139145
end = next((t for t in tokens if t.token_type is TokenType.SEMICOLON), None)
@@ -165,9 +171,11 @@ def get_range_of_a_key_in_model_block(
165171
sql: str,
166172
dialect: str,
167173
key: str,
168-
) -> t.Optional[Range]:
174+
) -> t.Optional[t.Tuple[Range, Range]]:
169175
"""
170-
Get the range of a specific key in the model block of an SQL file.
176+
Get the ranges of a specific key and its value in the MODEL block of an SQL file.
177+
178+
Returns a tuple of (key_range, value_range) if found, otherwise None.
171179
"""
172180
tokens = tokenize(sql, dialect=dialect)
173181
if not tokens:
@@ -237,17 +245,111 @@ def get_range_of_a_key_in_model_block(
237245
if depth == 1 and tt is TokenType.VAR and tok.text.upper() == key.upper():
238246
# Validate key position: it should immediately follow '(' or ',' at top level
239247
prev_idx = i - 1
240-
# Skip over non-significant tokens we don't want to gate on (e.g., comments)
248+
# Skip comments
241249
while prev_idx >= 0 and tokens[prev_idx].token_type in (TokenType.COMMENT,):
242250
prev_idx -= 1
243251
prev_tt = tokens[prev_idx].token_type if prev_idx >= 0 else None
244-
if prev_tt in (TokenType.L_PAREN, TokenType.COMMA):
245-
position = TokenPositionDetails(
246-
line=tok.line,
247-
col=tok.col,
248-
start=tok.start,
249-
end=tok.end,
250-
)
251-
return position.to_range(sql.splitlines())
252+
if prev_tt not in (TokenType.L_PAREN, TokenType.COMMA):
253+
continue
254+
255+
# Key range
256+
lines = sql.splitlines()
257+
key_start = TokenPositionDetails(
258+
line=tok.line, col=tok.col, start=tok.start, end=tok.end
259+
)
260+
key_range = key_start.to_range(lines)
261+
262+
# Find value start: the next non-comment token after the key
263+
value_start_idx = i + 1
264+
while value_start_idx < rparen_idx and tokens[
265+
value_start_idx
266+
].token_type in (TokenType.COMMENT,):
267+
value_start_idx += 1
268+
if value_start_idx >= rparen_idx:
269+
return None
270+
271+
# Walk to the end of the value expression: until top-level comma or closing paren
272+
# Track internal nesting for (), [], {}
273+
nested = 0
274+
j = value_start_idx
275+
value_end_idx = value_start_idx
276+
277+
def is_open(t: TokenType) -> bool:
278+
return t in (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET)
279+
280+
def is_close(t: TokenType) -> bool:
281+
return t in (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET)
282+
283+
while j < rparen_idx:
284+
ttype = tokens[j].token_type
285+
if ttype is TokenType.COMMENT:
286+
j += 1
287+
continue
288+
if is_open(ttype):
289+
nested += 1
290+
elif is_close(ttype):
291+
nested -= 1
292+
293+
# End of value: at top-level (nested == 0) encountering a comma or the end paren
294+
if nested == 0 and (
295+
ttype is TokenType.COMMA
296+
or (ttype is TokenType.R_PAREN and depth == 1)
297+
):
298+
# For comma, don't include it in the value range
299+
# For closing paren, include it only if it's part of the value structure
300+
if ttype is TokenType.COMMA:
301+
# Don't include the comma in the value range
302+
break
303+
else:
304+
# Include the closing parenthesis in the value range
305+
value_end_idx = j
306+
break
307+
308+
value_end_idx = j
309+
j += 1
310+
311+
# Special case: if the value ends with a closing parenthesis that's part of the value
312+
# (not the MODEL block's closing parenthesis), we need to include it
313+
if value_end_idx < rparen_idx - 1:
314+
next_token = tokens[value_end_idx + 1]
315+
if next_token.token_type is TokenType.COMMA:
316+
# Value ends before the comma, which is correct
317+
pass
318+
elif next_token.token_type is TokenType.R_PAREN and depth == 1:
319+
# This is the MODEL block's closing parenthesis, don't include it
320+
pass
321+
else:
322+
# Check if we should extend the range to include more tokens
323+
# This handles cases like incomplete parsing
324+
pass
325+
326+
# Trim trailing comments from value end
327+
while (
328+
value_end_idx > value_start_idx
329+
and tokens[value_end_idx].token_type is TokenType.COMMENT
330+
):
331+
value_end_idx -= 1
332+
333+
value_start_tok = tokens[value_start_idx]
334+
value_end_tok = tokens[value_end_idx]
335+
336+
value_start_pos = TokenPositionDetails(
337+
line=value_start_tok.line,
338+
col=value_start_tok.col,
339+
start=value_start_tok.start,
340+
end=value_start_tok.end,
341+
)
342+
value_end_pos = TokenPositionDetails(
343+
line=value_end_tok.line,
344+
col=value_end_tok.col,
345+
start=value_end_tok.start,
346+
end=value_end_tok.end,
347+
)
348+
value_range = Range(
349+
start=value_start_pos.to_range(lines).start,
350+
end=value_end_pos.to_range(lines).end,
351+
)
352+
353+
return (key_range, value_range)
252354

253355
return None

tests/core/linter/test_helpers.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,17 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
5252
]
5353
assert len(sql_models) > 0
5454

55+
# Test that the function works for all keys in the model block
5556
for model in sql_models:
56-
possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"]
57+
possible_keys = [
58+
"name",
59+
"tags",
60+
"description",
61+
"column_descriptions",
62+
"owner",
63+
"cron",
64+
"dialect",
65+
]
5766

5867
dialect = model.dialect
5968
assert dialect is not None
@@ -67,12 +76,55 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
6776
count_properties_checked = 0
6877

6978
for key in possible_keys:
70-
range = get_range_of_a_key_in_model_block(content, dialect, key)
71-
72-
# Check that the range starts with the key and ends with ;
73-
if range:
74-
read_range = read_range_from_file(path, range)
75-
assert read_range.lower() == key.lower()
79+
ranges = get_range_of_a_key_in_model_block(content, dialect, key)
80+
81+
if ranges:
82+
key_range, value_range = ranges
83+
read_key = read_range_from_file(path, key_range)
84+
assert read_key.lower() == key.lower()
85+
# Value range should be non-empty
86+
read_value = read_range_from_file(path, value_range)
87+
assert len(read_value) > 0
7688
count_properties_checked += 1
7789

7890
assert count_properties_checked > 0
91+
92+
# Test that the function works for different kinds of value blocks
93+
tests = [
94+
("sushi.customers", "name", "sushi.customers"),
95+
(
96+
"sushi.customers",
97+
"tags",
98+
"(pii, fact)",
99+
),
100+
("sushi.customers", "description", "'Sushi customer data'"),
101+
(
102+
"sushi.customers",
103+
"column_descriptions",
104+
"( customer_id = 'customer_id uniquely identifies customers' )",
105+
),
106+
("sushi.customers", "owner", "jen"),
107+
("sushi.customers", "cron", "'@daily'"),
108+
]
109+
for model_name, key, value in tests:
110+
model = context.get_model(model_name)
111+
assert model is not None
112+
113+
dialect = model.dialect
114+
assert dialect is not None
115+
116+
path = model._path
117+
assert path is not None
118+
119+
with open(path, "r", encoding="utf-8") as file:
120+
content = file.read()
121+
122+
ranges = get_range_of_a_key_in_model_block(content, dialect, key)
123+
assert ranges is not None, f"Could not find key '{key}' in model '{model_name}'"
124+
125+
key_range, value_range = ranges
126+
read_key = read_range_from_file(path, key_range)
127+
assert read_key.lower() == key.lower()
128+
129+
read_value = read_range_from_file(path, value_range)
130+
assert read_value == value

0 commit comments

Comments
 (0)