22
33from sqlmesh .core .linter .rule import Range , Position
44from sqlmesh .utils .pydantic import PydanticModel
5- from sqlglot import tokenize , TokenType
5+ from sqlglot import tokenize , TokenType , Token
66import typing as t
77
88
@@ -122,57 +122,14 @@ def read_range_from_file(file: Path, text_range: Range) -> str:
122122 return read_range_from_string ("" .join (lines ), text_range )
123123
124124
125- def get_range_of_model_block (
126- sql : str ,
127- dialect : str ,
128- ) -> t .Optional [Range ]:
129- """
130- Get the range of the model block in an SQL file.
131- """
132- tokens = tokenize (sql , dialect = dialect )
133-
134- # Find start of the model block
135- start = next (
136- (t for t in tokens if t .token_type is TokenType .VAR and t .text .upper () == "MODEL" ),
137- None ,
138- )
139- end = next ((t for t in tokens if t .token_type is TokenType .SEMICOLON ), None )
140-
141- if start is None or end is None :
142- return None
143-
144- start_position = TokenPositionDetails (
145- line = start .line ,
146- col = start .col ,
147- start = start .start ,
148- end = start .end ,
149- )
150- end_position = TokenPositionDetails (
151- line = end .line ,
152- col = end .col ,
153- start = end .start ,
154- end = end .end ,
155- )
156-
157- splitlines = sql .splitlines ()
158- return Range (
159- start = start_position .to_range (splitlines ).start ,
160- end = end_position .to_range (splitlines ).end ,
161- )
162-
163-
164- def get_range_of_a_key_in_model_block (
165- sql : str ,
166- dialect : str ,
167- key : str ,
168- ) -> t .Optional [Range ]:
125+ def get_start_and_end_of_model_block (
126+ tokens : t .List [Token ],
127+ ) -> t .Optional [t .Tuple [int , int ]]:
169128 """
170- Get the range of a specific key in the model block of an SQL file.
129+ Returns the start and end tokens of the MODEL block in an SQL file.
130+ The MODEL block is defined as the first occurrence of the keyword "MODEL" followed by
131+ an opening parenthesis and a closing parenthesis that matches the opening one.
171132 """
172- tokens = tokenize (sql , dialect = dialect )
173- if not tokens :
174- return None
175-
176133 # 1) Find the MODEL token
177134 try :
178135 model_idx = next (
@@ -216,6 +173,65 @@ def get_range_of_a_key_in_model_block(
216173 )
217174 except StopIteration :
218175 return None
176+ return (
177+ lparen_idx ,
178+ rparen_idx ,
179+ )
180+
181+
def get_range_of_model_block(
    sql: str,
    dialect: str,
) -> t.Optional[Range]:
    """
    Get the range of the MODEL block in an SQL file.

    The returned range spans from the token immediately before the block's
    opening parenthesis (the MODEL keyword) to the token immediately after
    its closing parenthesis (normally the terminating semicolon).

    Args:
        sql: The raw SQL text to scan.
        dialect: The SQL dialect used to tokenize `sql`.

    Returns:
        The Range covering the MODEL block, or None if no MODEL block is
        found or the surrounding tokens the range relies on do not exist.
    """
    tokens = tokenize(sql, dialect=dialect)

    block = get_start_and_end_of_model_block(tokens)
    if not block:
        return None

    (lparen_idx, rparen_idx) = block

    # Guard against the block sitting at the very start/end of the token
    # stream: the unguarded indexing would raise IndexError when the closing
    # paren is the last token (e.g. no trailing semicolon), and would wrap
    # around to the end of the list if the opening paren were token 0.
    if lparen_idx - 1 < 0 or rparen_idx + 1 >= len(tokens):
        return None

    # NOTE(review): assumes the token just before '(' is the MODEL keyword
    # and the token just after ')' is the terminating semicolon — confirm
    # against get_start_and_end_of_model_block's contract.
    start = tokens[lparen_idx - 1]
    end = tokens[rparen_idx + 1]

    start_position = TokenPositionDetails(
        line=start.line,
        col=start.col,
        start=start.start,
        end=start.end,
    )
    end_position = TokenPositionDetails(
        line=end.line,
        col=end.col,
        start=end.start,
        end=end.end,
    )

    lines = sql.splitlines()
    return Range(
        start=start_position.to_range(lines).start,
        end=end_position.to_range(lines).end,
    )
215+
216+
217+ def get_range_of_a_key_in_model_block (
218+ sql : str ,
219+ dialect : str ,
220+ key : str ,
221+ ) -> t .Optional [t .Tuple [Range , Range ]]:
222+ """
223+ Get the ranges of a specific key and its value in the MODEL block of an SQL file.
224+
225+ Returns a tuple of (key_range, value_range) if found, otherwise None.
226+ """
227+ tokens = tokenize (sql , dialect = dialect )
228+ if not tokens :
229+ return None
230+
231+ block = get_start_and_end_of_model_block (tokens )
232+ if not block :
233+ return None
234+ (lparen_idx , rparen_idx ) = block
219235
220236 # 4) Scan within the MODEL property list for the key at top-level (depth == 1)
221237 # Initialize depth to 1 since we're inside the first parentheses
@@ -237,17 +253,78 @@ def get_range_of_a_key_in_model_block(
237253 if depth == 1 and tt is TokenType .VAR and tok .text .upper () == key .upper ():
238254 # Validate key position: it should immediately follow '(' or ',' at top level
239255 prev_idx = i - 1
240- # Skip over non-significant tokens we don't want to gate on (e.g., comments)
241- while prev_idx >= 0 and tokens [prev_idx ].token_type in (TokenType .COMMENT ,):
242- prev_idx -= 1
243256 prev_tt = tokens [prev_idx ].token_type if prev_idx >= 0 else None
244- if prev_tt in (TokenType .L_PAREN , TokenType .COMMA ):
245- position = TokenPositionDetails (
246- line = tok .line ,
247- col = tok .col ,
248- start = tok .start ,
249- end = tok .end ,
250- )
251- return position .to_range (sql .splitlines ())
257+ if prev_tt not in (TokenType .L_PAREN , TokenType .COMMA ):
258+ continue
259+
260+ # Key range
261+ lines = sql .splitlines ()
262+ key_start = TokenPositionDetails (
263+ line = tok .line , col = tok .col , start = tok .start , end = tok .end
264+ )
265+ key_range = key_start .to_range (lines )
266+
267+ # Find value start: the next non-comment token after the key
268+ value_start_idx = i + 1
269+ if value_start_idx >= rparen_idx :
270+ return None
271+
272+ # Walk to the end of the value expression: until top-level comma or closing paren
273+ # Track internal nesting for (), [], {}
274+ nested = 0
275+ j = value_start_idx
276+ value_end_idx = value_start_idx
277+
278+ def is_open (t : TokenType ) -> bool :
279+ return t in (TokenType .L_PAREN , TokenType .L_BRACE , TokenType .L_BRACKET )
280+
281+ def is_close (t : TokenType ) -> bool :
282+ return t in (TokenType .R_PAREN , TokenType .R_BRACE , TokenType .R_BRACKET )
283+
284+ while j < rparen_idx :
285+ ttype = tokens [j ].token_type
286+ if is_open (ttype ):
287+ nested += 1
288+ elif is_close (ttype ):
289+ nested -= 1
290+
291+ # End of value: at top-level (nested == 0) encountering a comma or the end paren
292+ if nested == 0 and (
293+ ttype is TokenType .COMMA or (ttype is TokenType .R_PAREN and depth == 1 )
294+ ):
295+ # For comma, don't include it in the value range
296+ # For closing paren, include it only if it's part of the value structure
297+ if ttype is TokenType .COMMA :
298+ # Don't include the comma in the value range
299+ break
300+ else :
301+ # Include the closing parenthesis in the value range
302+ value_end_idx = j
303+ break
304+
305+ value_end_idx = j
306+ j += 1
307+
308+ value_start_tok = tokens [value_start_idx ]
309+ value_end_tok = tokens [value_end_idx ]
310+
311+ value_start_pos = TokenPositionDetails (
312+ line = value_start_tok .line ,
313+ col = value_start_tok .col ,
314+ start = value_start_tok .start ,
315+ end = value_start_tok .end ,
316+ )
317+ value_end_pos = TokenPositionDetails (
318+ line = value_end_tok .line ,
319+ col = value_end_tok .col ,
320+ start = value_end_tok .start ,
321+ end = value_end_tok .end ,
322+ )
323+ value_range = Range (
324+ start = value_start_pos .to_range (lines ).start ,
325+ end = value_end_pos .to_range (lines ).end ,
326+ )
327+
328+ return (key_range , value_range )
252329
253330 return None
0 commit comments