Skip to content

Commit d8aa539

Browse files
committed
refactor: extract dialect parsing code to be reused
1 parent 9eba5c1 commit d8aa539

File tree

1 file changed

+76
-35
lines changed

1 file changed

+76
-35
lines changed

sqlmesh/core/linter/helpers.py

Lines changed: 76 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -156,7 +156,8 @@ def get_range_of_model_block(
156156

157157
splitlines = sql.splitlines()
158158
return Range(
159-
start=start_position.to_range(splitlines).start, end=end_position.to_range(splitlines).end
159+
start=start_position.to_range(splitlines).start,
160+
end=end_position.to_range(splitlines).end,
160161
)
161162

162163

@@ -169,44 +170,84 @@ def get_range_of_a_key_in_model_block(
169170
Get the range of a specific key in the model block of an SQL file.
170171
"""
171172
tokens = tokenize(sql, dialect=dialect)
172-
if tokens is None:
173+
if not tokens:
173174
return None
174175

175-
# Find the start of the model block
176-
start_index = next(
177-
(
176+
# 1) Find the MODEL token
177+
try:
178+
model_idx = next(
178179
i
179-
for i, t in enumerate(tokens)
180-
if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"
181-
),
182-
None,
183-
)
184-
end_index = next(
185-
(i for i, t in enumerate(tokens) if t.token_type is TokenType.SEMICOLON),
186-
None,
187-
)
188-
if start_index is None or end_index is None:
189-
return None
190-
if start_index >= end_index:
180+
for i, tok in enumerate(tokens)
181+
if tok.token_type is TokenType.VAR and tok.text.upper() == "MODEL"
182+
)
183+
except StopIteration:
191184
return None
192185

193-
tokens_of_interest = tokens[start_index + 1 : end_index]
194-
# Find the key token
195-
key_token = next(
196-
(
197-
t
198-
for t in tokens_of_interest
199-
if t.token_type is TokenType.VAR and t.text.upper() == key.upper()
200-
),
201-
None,
202-
)
203-
if key_token is None:
186+
# 2) Find the opening parenthesis for the MODEL properties list
187+
try:
188+
lparen_idx = next(
189+
i
190+
for i in range(model_idx + 1, len(tokens))
191+
if tokens[i].token_type is TokenType.L_PAREN
192+
)
193+
except StopIteration:
204194
return None
205195

206-
position = TokenPositionDetails(
207-
line=key_token.line,
208-
col=key_token.col,
209-
start=key_token.start,
210-
end=key_token.end,
211-
)
212-
return position.to_range(sql.splitlines())
196+
# 3) Find the matching closing parenthesis for that list by tracking depth
197+
depth = 0
198+
rparen_idx: t.Optional[int] = None
199+
for i in range(lparen_idx, len(tokens)):
200+
tt = tokens[i].token_type
201+
if tt is TokenType.L_PAREN:
202+
depth += 1
203+
elif tt is TokenType.R_PAREN:
204+
depth -= 1
205+
if depth == 0:
206+
rparen_idx = i
207+
break
208+
209+
if rparen_idx is None:
210+
# Fallback: stop at the first semicolon after MODEL
211+
try:
212+
rparen_idx = next(
213+
i
214+
for i in range(lparen_idx + 1, len(tokens))
215+
if tokens[i].token_type is TokenType.SEMICOLON
216+
)
217+
except StopIteration:
218+
return None
219+
220+
# 4) Scan within the MODEL property list for the key at top-level (depth == 1)
221+
# Initialize depth to 1 since we're inside the first parentheses
222+
depth = 1
223+
for i in range(lparen_idx + 1, rparen_idx):
224+
tok = tokens[i]
225+
tt = tok.token_type
226+
227+
if tt is TokenType.L_PAREN:
228+
depth += 1
229+
continue
230+
if tt is TokenType.R_PAREN:
231+
depth -= 1
232+
# If we somehow exit before rparen_idx, stop early
233+
if depth <= 0:
234+
break
235+
continue
236+
237+
if depth == 1 and tt is TokenType.VAR and tok.text.upper() == key.upper():
238+
# Validate key position: it should immediately follow '(' or ',' at top level
239+
prev_idx = i - 1
240+
# Skip over non-significant tokens we don't want to gate on (e.g., comments)
241+
while prev_idx >= 0 and tokens[prev_idx].token_type in (TokenType.COMMENT,):
242+
prev_idx -= 1
243+
prev_tt = tokens[prev_idx].token_type if prev_idx >= 0 else None
244+
if prev_tt in (TokenType.L_PAREN, TokenType.COMMA):
245+
position = TokenPositionDetails(
246+
line=tok.line,
247+
col=tok.col,
248+
start=tok.start,
249+
end=tok.end,
250+
)
251+
return position.to_range(sql.splitlines())
252+
253+
return None

0 commit comments

Comments (0)