From 30fd8fb4d929775a1ac7daf173b336cd0365cfac Mon Sep 17 00:00:00 2001 From: Ben King <9087625+benfdking@users.noreply.github.com> Date: Mon, 11 Aug 2025 12:00:38 +0200 Subject: [PATCH 1/4] refactor: refactor dialect parsing code to be reused --- sqlmesh/core/linter/helpers.py | 111 ++++++++++++++++++++++----------- 1 file changed, 76 insertions(+), 35 deletions(-) diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py index 3f6e96765f..075b25b6d4 100644 --- a/sqlmesh/core/linter/helpers.py +++ b/sqlmesh/core/linter/helpers.py @@ -156,7 +156,8 @@ def get_range_of_model_block( splitlines = sql.splitlines() return Range( - start=start_position.to_range(splitlines).start, end=end_position.to_range(splitlines).end + start=start_position.to_range(splitlines).start, + end=end_position.to_range(splitlines).end, ) @@ -169,44 +170,84 @@ def get_range_of_a_key_in_model_block( Get the range of a specific key in the model block of an SQL file. """ tokens = tokenize(sql, dialect=dialect) - if tokens is None: + if not tokens: return None - # Find the start of the model block - start_index = next( - ( + # 1) Find the MODEL token + try: + model_idx = next( i - for i, t in enumerate(tokens) - if t.token_type is TokenType.VAR and t.text.upper() == "MODEL" - ), - None, - ) - end_index = next( - (i for i, t in enumerate(tokens) if t.token_type is TokenType.SEMICOLON), - None, - ) - if start_index is None or end_index is None: - return None - if start_index >= end_index: + for i, tok in enumerate(tokens) + if tok.token_type is TokenType.VAR and tok.text.upper() == "MODEL" + ) + except StopIteration: return None - tokens_of_interest = tokens[start_index + 1 : end_index] - # Find the key token - key_token = next( - ( - t - for t in tokens_of_interest - if t.token_type is TokenType.VAR and t.text.upper() == key.upper() - ), - None, - ) - if key_token is None: + # 2) Find the opening parenthesis for the MODEL properties list + try: + lparen_idx = next( + i + for i in range(model_idx + 1, len(tokens)) + if tokens[i].token_type is TokenType.L_PAREN + ) + except StopIteration: return None - position = TokenPositionDetails( - line=key_token.line, - col=key_token.col, - start=key_token.start, - end=key_token.end, - ) - return position.to_range(sql.splitlines()) + # 3) Find the matching closing parenthesis for that list by tracking depth + depth = 0 + rparen_idx: t.Optional[int] = None + for i in range(lparen_idx, len(tokens)): + tt = tokens[i].token_type + if tt is TokenType.L_PAREN: + depth += 1 + elif tt is TokenType.R_PAREN: + depth -= 1 + if depth == 0: + rparen_idx = i + break + + if rparen_idx is None: + # Fallback: stop at the first semicolon after MODEL + try: + rparen_idx = next( + i + for i in range(lparen_idx + 1, len(tokens)) + if tokens[i].token_type is TokenType.SEMICOLON + ) + except StopIteration: + return None + + # 4) Scan within the MODEL property list for the key at top-level (depth == 1) + # Initialize depth to 1 since we're inside the first parentheses + depth = 1 + for i in range(lparen_idx + 1, rparen_idx): + tok = tokens[i] + tt = tok.token_type + + if tt is TokenType.L_PAREN: + depth += 1 + continue + if tt is TokenType.R_PAREN: + depth -= 1 + # If we somehow exit before rparen_idx, stop early + if depth <= 0: + break + continue + + if depth == 1 and tt is TokenType.VAR and tok.text.upper() == key.upper(): + # Validate key position: it should immediately follow '(' or ',' at top level + prev_idx = i - 1 + # Skip over non-significant tokens we don't want to gate on (e.g., comments) + while prev_idx >= 0 and tokens[prev_idx].token_type in (TokenType.COMMENT,): + prev_idx -= 1 + prev_tt = tokens[prev_idx].token_type if prev_idx >= 0 else None + if prev_tt in (TokenType.L_PAREN, TokenType.COMMA): + position = TokenPositionDetails( + line=tok.line, + col=tok.col, + start=tok.start, + end=tok.end, + ) + return position.to_range(sql.splitlines()) + + return None From ea963255abf1770208f2579cb75907b656514036 Mon Sep 17 00:00:00 2001 From: Ben King <9087625+benfdking@users.noreply.github.com> Date: Mon, 11 Aug 2025 12:32:04 +0200 Subject: [PATCH 2/4] feat: add the ability to return range for key and value --- sqlmesh/core/linter/helpers.py | 199 +++++++++++++++++++++--------- tests/core/linter/test_helpers.py | 66 ++++++++-- 2 files changed, 197 insertions(+), 68 deletions(-) diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py index 075b25b6d4..78e1a514c5 100644 --- a/sqlmesh/core/linter/helpers.py +++ b/sqlmesh/core/linter/helpers.py @@ -2,7 +2,7 @@ from sqlmesh.core.linter.rule import Range, Position from sqlmesh.utils.pydantic import PydanticModel -from sqlglot import tokenize, TokenType +from sqlglot import tokenize, TokenType, Token import typing as t @@ -122,57 +122,14 @@ def read_range_from_file(file: Path, text_range: Range) -> str: return read_range_from_string("".join(lines), text_range) -def get_range_of_model_block( - sql: str, - dialect: str, -) -> t.Optional[Range]: - """ - Get the range of the model block in an SQL file. - """ - tokens = tokenize(sql, dialect=dialect) - - # Find start of the model block - start = next( - (t for t in tokens if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"), - None, - ) - end = next((t for t in tokens if t.token_type is TokenType.SEMICOLON), None) - - if start is None or end is None: - return None - - start_position = TokenPositionDetails( - line=start.line, - col=start.col, - start=start.start, - end=start.end, - ) - end_position = TokenPositionDetails( - line=end.line, - col=end.col, - start=end.start, - end=end.end, - ) - - splitlines = sql.splitlines() - return Range( - start=start_position.to_range(splitlines).start, - end=end_position.to_range(splitlines).end, - ) - - -def get_range_of_a_key_in_model_block( - sql: str, - dialect: str, - key: str, -) -> t.Optional[Range]: +def get_start_and_end_of_model_block( + tokens: t.List[Token], +) -> t.Optional[t.Tuple[int, int]]: """ - Get the range of a specific key in the model block of an SQL file. + Returns the start and end tokens of the MODEL block in an SQL file. + The MODEL block is defined as the first occurrence of the keyword "MODEL" followed by + an opening parenthesis and a closing parenthesis that matches the opening one. """ - tokens = tokenize(sql, dialect=dialect) - if not tokens: - return None - # 1) Find the MODEL token try: model_idx = next( @@ -216,6 +173,65 @@ def get_range_of_a_key_in_model_block( ) except StopIteration: return None + return ( + lparen_idx, + rparen_idx, + ) + + +def get_range_of_model_block( + sql: str, + dialect: str, +) -> t.Optional[Range]: + """ + Get the range of the model block in an SQL file, + """ + tokens = tokenize(sql, dialect=dialect) + + block = get_start_and_end_of_model_block(tokens) + if not block: + return None + + (start_idx, end_idx) = block + start = tokens[start_idx - 1] + end = tokens[end_idx + 1] + start_position = TokenPositionDetails( + line=start.line, + col=start.col, + start=start.start, + end=start.end, + ) + end_position = TokenPositionDetails( + line=end.line, + col=end.col, + start=end.start, + end=end.end, + ) + splitlines = sql.splitlines() + return Range( + start=start_position.to_range(splitlines).start, + end=end_position.to_range(splitlines).end, + ) + + +def get_range_of_a_key_in_model_block( + sql: str, + dialect: str, + key: str, +) -> t.Optional[t.Tuple[Range, Range]]: + """ + Get the ranges of a specific key and its value in the MODEL block of an SQL file. + + Returns a tuple of (key_range, value_range) if found, otherwise None. + """ + tokens = tokenize(sql, dialect=dialect) + if not tokens: + return None + + block = get_start_and_end_of_model_block(tokens) + if not block: + return None + (lparen_idx, rparen_idx) = block # 4) Scan within the MODEL property list for the key at top-level (depth == 1) # Initialize depth to 1 since we're inside the first parentheses @@ -237,17 +253,78 @@ def get_range_of_a_key_in_model_block( if depth == 1 and tt is TokenType.VAR and tok.text.upper() == key.upper(): # Validate key position: it should immediately follow '(' or ',' at top level prev_idx = i - 1 - # Skip over non-significant tokens we don't want to gate on (e.g., comments) - while prev_idx >= 0 and tokens[prev_idx].token_type in (TokenType.COMMENT,): - prev_idx -= 1 prev_tt = tokens[prev_idx].token_type if prev_idx >= 0 else None - if prev_tt in (TokenType.L_PAREN, TokenType.COMMA): - position = TokenPositionDetails( - line=tok.line, - col=tok.col, - start=tok.start, - end=tok.end, - ) - return position.to_range(sql.splitlines()) + if prev_tt not in (TokenType.L_PAREN, TokenType.COMMA): + continue + + # Key range + lines = sql.splitlines() + key_start = TokenPositionDetails( + line=tok.line, col=tok.col, start=tok.start, end=tok.end + ) + key_range = key_start.to_range(lines) + + # Find value start: the next non-comment token after the key + value_start_idx = i + 1 + if value_start_idx >= rparen_idx: + return None + + # Walk to the end of the value expression: until top-level comma or closing paren + # Track internal nesting for (), [], {} + nested = 0 + j = value_start_idx + value_end_idx = value_start_idx + + def is_open(t: TokenType) -> bool: + return t in (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET) + + def is_close(t: TokenType) -> bool: + return t in (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET) + + while j < rparen_idx: + ttype = tokens[j].token_type + if is_open(ttype): + nested += 1 + elif is_close(ttype): + nested -= 1 + + # End of value: at top-level (nested == 0) encountering a comma or the end paren + if nested == 0 and ( + ttype is TokenType.COMMA or (ttype is TokenType.R_PAREN and depth == 1) + ): + # For comma, don't include it in the value range + # For closing paren, include it only if it's part of the value structure + if ttype is TokenType.COMMA: + # Don't include the comma in the value range + break + else: + # Include the closing parenthesis in the value range + value_end_idx = j + break + + value_end_idx = j + j += 1 + + value_start_tok = tokens[value_start_idx] + value_end_tok = tokens[value_end_idx] + + value_start_pos = TokenPositionDetails( + line=value_start_tok.line, + col=value_start_tok.col, + start=value_start_tok.start, + end=value_start_tok.end, + ) + value_end_pos = TokenPositionDetails( + line=value_end_tok.line, + col=value_end_tok.col, + start=value_end_tok.start, + end=value_end_tok.end, + ) + value_range = Range( + start=value_start_pos.to_range(lines).start, + end=value_end_pos.to_range(lines).end, + ) + + return (key_range, value_range) return None diff --git a/tests/core/linter/test_helpers.py b/tests/core/linter/test_helpers.py index f3ae193bb0..c3ba46f304 100644 --- a/tests/core/linter/test_helpers.py +++ b/tests/core/linter/test_helpers.py @@ -52,8 +52,17 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi(): ] assert len(sql_models) > 0 + # Test that the function works for all keys in the model block for model in sql_models: - possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"] + possible_keys = [ + "name", + "tags", + "description", + "column_descriptions", + "owner", + "cron", + "dialect", + ] dialect = model.dialect assert dialect is not None @@ -67,12 +76,55 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi(): count_properties_checked = 0 for key in possible_keys: - range = get_range_of_a_key_in_model_block(content, dialect, key) - - # Check that the range starts with the key and ends with ; - if range: - read_range = read_range_from_file(path, range) - assert read_range.lower() == key.lower() + ranges = get_range_of_a_key_in_model_block(content, dialect, key) + + if ranges: + key_range, value_range = ranges + read_key = read_range_from_file(path, key_range) + assert read_key.lower() == key.lower() + # Value range should be non-empty + read_value = read_range_from_file(path, value_range) + assert len(read_value) > 0 count_properties_checked += 1 assert count_properties_checked > 0 + + # Test that the function works for different kind of value blocks + tests = [ + ("sushi.customers", "name", "sushi.customers"), + ( + "sushi.customers", + "tags", + "(pii, fact)", + ), + ("sushi.customers", "description", "'Sushi customer data'"), + ( + "sushi.customers", + "column_descriptions", + "( customer_id = 'customer_id uniquely identifies customers' )", + ), + ("sushi.customers", "owner", "jen"), + ("sushi.customers", "cron", "'@daily'"), + ] + for model_name, key, value in tests: + model = context.get_model(model_name) + assert model is not None + + dialect = model.dialect + assert dialect is not None + + path = model._path + assert path is not None + + with open(path, "r", encoding="utf-8") as file: + content = file.read() + + ranges = get_range_of_a_key_in_model_block(content, dialect, key) + assert ranges is not None, f"Could not find key '{key}' in model '{model_name}'" + + key_range, value_range = ranges + read_key = read_range_from_file(path, key_range) + assert read_key.lower() == key.lower() + + read_value = read_range_from_file(path, value_range) + assert read_value == value From cc4a7fc67d50e4201162d3dd7940a6db4e6d619a Mon Sep 17 00:00:00 2001 From: Ben King <9087625+benfdking@users.noreply.github.com> Date: Thu, 14 Aug 2025 12:20:50 +0200 Subject: [PATCH 3/4] simplify after comments --- sqlmesh/core/linter/helpers.py | 44 ++++++++++++---------------------- 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py index 78e1a514c5..475efb9b6a 100644 --- a/sqlmesh/core/linter/helpers.py +++ b/sqlmesh/core/linter/helpers.py @@ -150,33 +150,21 @@ def get_start_and_end_of_model_block( except StopIteration: return None - # 3) Find the matching closing parenthesis for that list by tracking depth - depth = 0 - rparen_idx: t.Optional[int] = None - for i in range(lparen_idx, len(tokens)): - tt = tokens[i].token_type - if tt is TokenType.L_PAREN: - depth += 1 - elif tt is TokenType.R_PAREN: - depth -= 1 - if depth == 0: - rparen_idx = i - break - - if rparen_idx is None: - # Fallback: stop at the first semicolon after MODEL - try: - rparen_idx = next( - i - for i in range(lparen_idx + 1, len(tokens)) - if tokens[i].token_type is TokenType.SEMICOLON - ) - except StopIteration: - return None - return ( - lparen_idx, - rparen_idx, - ) + # 3) Find the matching closing parenthesis by looking for the first semicolon after + # the opening parenthesis and assuming the MODEL block ends there. + try: + closing_semicolon = next( + i + for i in range(lparen_idx + 1, len(tokens)) + if tokens[i].token_type is TokenType.SEMICOLON + ) + # If we find a semicolon, we can assume the MODEL block ends there + rparen_idx = closing_semicolon - 1 + if tokens[rparen_idx].token_type is TokenType.R_PAREN: + return (lparen_idx, rparen_idx) + return None + except StopIteration: + return None def get_range_of_model_block( @@ -187,11 +175,9 @@ def get_range_of_model_block( Get the range of the model block in an SQL file, """ tokens = tokenize(sql, dialect=dialect) - block = get_start_and_end_of_model_block(tokens) if not block: return None - (start_idx, end_idx) = block start = tokens[start_idx - 1] end = tokens[end_idx + 1] From dee3e24ad8ad7a667f514ba93225d178e18caaa1 Mon Sep 17 00:00:00 2001 From: Ben <9087625+benfdking@users.noreply.github.com> Date: Fri, 15 Aug 2025 11:19:48 +0200 Subject: [PATCH 4/4] Update sqlmesh/core/linter/helpers.py Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com> --- sqlmesh/core/linter/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py index 475efb9b6a..3c79f83a43 100644 --- a/sqlmesh/core/linter/helpers.py +++ b/sqlmesh/core/linter/helpers.py @@ -250,7 +250,6 @@ def get_range_of_a_key_in_model_block( ) key_range = key_start.to_range(lines) - # Find value start: the next non-comment token after the key value_start_idx = i + 1 if value_start_idx >= rparen_idx: return None