def get_start_and_end_of_model_block(
    tokens: t.List[Token],
) -> t.Optional[t.Tuple[int, int]]:
    """
    Locate the MODEL property block in a tokenized SQL file.

    The block is identified as the first ``MODEL`` keyword followed by an
    opening parenthesis. The end is found heuristically: the block is assumed
    to terminate at the first semicolon after the opening parenthesis, with
    the token immediately before that semicolon required to be a closing
    parenthesis.

    Args:
        tokens: The token stream produced by ``sqlglot.tokenize``.

    Returns:
        A tuple ``(lparen_idx, rparen_idx)`` of token *indices* for the
        opening and closing parentheses of the MODEL block, or ``None`` if no
        well-formed block is found.
    """
    # 1) Find the MODEL keyword token.
    model_idx = next(
        (
            i
            for i, tok in enumerate(tokens)
            if tok.token_type is TokenType.VAR and tok.text.upper() == "MODEL"
        ),
        None,
    )
    if model_idx is None:
        return None

    # 2) Find the opening parenthesis of the MODEL property list.
    lparen_idx = next(
        (
            i
            for i in range(model_idx + 1, len(tokens))
            if tokens[i].token_type is TokenType.L_PAREN
        ),
        None,
    )
    if lparen_idx is None:
        return None

    # 3) Heuristic for the end of the block: the first semicolon after the
    #    opening parenthesis. The token just before it must be the closing
    #    parenthesis, otherwise the block is considered malformed.
    closing_semicolon = next(
        (
            i
            for i in range(lparen_idx + 1, len(tokens))
            if tokens[i].token_type is TokenType.SEMICOLON
        ),
        None,
    )
    if closing_semicolon is None:
        return None

    rparen_idx = closing_semicolon - 1
    if tokens[rparen_idx].token_type is TokenType.R_PAREN:
        return (lparen_idx, rparen_idx)
    return None
def get_range_of_model_block(
    sql: str,
    dialect: str,
) -> t.Optional[Range]:
    """
    Get the range of the MODEL block in an SQL file.

    Args:
        sql: The full text of the SQL file.
        dialect: The SQL dialect used to tokenize the file.

    Returns:
        A ``Range`` spanning ``MODEL ( ... );`` in full, or ``None`` when no
        MODEL block can be located.
    """
    tokens = tokenize(sql, dialect=dialect)
    block = get_start_and_end_of_model_block(tokens)
    if not block:
        return None
    (lparen_idx, rparen_idx) = block

    # The token before the opening parenthesis is the MODEL keyword itself,
    # and the token after the closing parenthesis is the terminating
    # semicolon, so the reported range covers the whole `MODEL ( ... );`.
    start = tokens[lparen_idx - 1]
    end = tokens[rparen_idx + 1]

    start_position = TokenPositionDetails(
        line=start.line,
        col=start.col,
        start=start.start,
        end=start.end,
    )
    end_position = TokenPositionDetails(
        line=end.line,
        col=end.col,
        start=end.start,
        end=end.end,
    )
    splitlines = sql.splitlines()
    return Range(
        start=start_position.to_range(splitlines).start,
        end=end_position.to_range(splitlines).end,
    )
def get_range_of_a_key_in_model_block(
    sql: str,
    dialect: str,
    key: str,
) -> t.Optional[t.Tuple[Range, Range]]:
    """
    Get the ranges of a specific key and its value in the MODEL block of an
    SQL file.

    Args:
        sql: The full text of the SQL file.
        dialect: The SQL dialect used to tokenize the file.
        key: The MODEL property name to look for (case-insensitive).

    Returns:
        A tuple of ``(key_range, value_range)`` if the key is found at the
        top level of the MODEL property list, otherwise ``None``.
    """
    tokens = tokenize(sql, dialect=dialect)
    if not tokens:
        return None

    block = get_start_and_end_of_model_block(tokens)
    if not block:
        return None
    (lparen_idx, rparen_idx) = block

    # Hoisted out of the scan loop: these are invariant per call.
    lines = sql.splitlines()

    def _opens(tt: TokenType) -> bool:
        return tt in (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET)

    def _closes(tt: TokenType) -> bool:
        return tt in (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET)

    # Scan the MODEL property list for the key at top level. Depth starts at
    # 1 because we begin just inside the block's opening parenthesis.
    depth = 1
    for i in range(lparen_idx + 1, rparen_idx):
        tok = tokens[i]
        tt = tok.token_type

        if tt is TokenType.L_PAREN:
            depth += 1
            continue
        if tt is TokenType.R_PAREN:
            depth -= 1
            if depth <= 0:
                # Unbalanced block: we closed the outer parenthesis before
                # reaching rparen_idx, so stop early.
                break
            continue

        if depth != 1 or tt is not TokenType.VAR or tok.text.upper() != key.upper():
            continue

        # A property key must immediately follow '(' or ',' at top level;
        # anything else means the name merely appears inside another value.
        prev_tt = tokens[i - 1].token_type
        if prev_tt not in (TokenType.L_PAREN, TokenType.COMMA):
            continue

        key_range = TokenPositionDetails(
            line=tok.line, col=tok.col, start=tok.start, end=tok.end
        ).to_range(lines)

        value_start_idx = i + 1
        if value_start_idx >= rparen_idx:
            return None

        # Walk to the end of the value expression: a top-level comma ends it
        # (comma excluded from the range); a closing parenthesis that
        # balances the value's own nesting ends it too (paren included).
        nested = 0
        value_end_idx = value_start_idx
        for j in range(value_start_idx, rparen_idx):
            ttype = tokens[j].token_type
            if _opens(ttype):
                nested += 1
            elif _closes(ttype):
                nested -= 1

            if nested == 0:
                if ttype is TokenType.COMMA:
                    break
                if ttype is TokenType.R_PAREN:
                    value_end_idx = j
                    break
            value_end_idx = j

        value_start_tok = tokens[value_start_idx]
        value_end_tok = tokens[value_end_idx]

        value_start_pos = TokenPositionDetails(
            line=value_start_tok.line,
            col=value_start_tok.col,
            start=value_start_tok.start,
            end=value_start_tok.end,
        )
        value_end_pos = TokenPositionDetails(
            line=value_end_tok.line,
            col=value_end_tok.col,
            start=value_end_tok.start,
            end=value_end_tok.end,
        )
        value_range = Range(
            start=value_start_pos.to_range(lines).start,
            end=value_end_pos.to_range(lines).end,
        )

        return (key_range, value_range)

    return None
+ # Test that the function works for all keys in the model block for model in sql_models: - possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"] + possible_keys = [ + "name", + "tags", + "description", + "column_descriptions", + "owner", + "cron", + "dialect", + ] dialect = model.dialect assert dialect is not None @@ -67,12 +76,55 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi(): count_properties_checked = 0 for key in possible_keys: - range = get_range_of_a_key_in_model_block(content, dialect, key) - - # Check that the range starts with the key and ends with ; - if range: - read_range = read_range_from_file(path, range) - assert read_range.lower() == key.lower() + ranges = get_range_of_a_key_in_model_block(content, dialect, key) + + if ranges: + key_range, value_range = ranges + read_key = read_range_from_file(path, key_range) + assert read_key.lower() == key.lower() + # Value range should be non-empty + read_value = read_range_from_file(path, value_range) + assert len(read_value) > 0 count_properties_checked += 1 assert count_properties_checked > 0 + + # Test that the function works for different kind of value blocks + tests = [ + ("sushi.customers", "name", "sushi.customers"), + ( + "sushi.customers", + "tags", + "(pii, fact)", + ), + ("sushi.customers", "description", "'Sushi customer data'"), + ( + "sushi.customers", + "column_descriptions", + "( customer_id = 'customer_id uniquely identifies customers' )", + ), + ("sushi.customers", "owner", "jen"), + ("sushi.customers", "cron", "'@daily'"), + ] + for model_name, key, value in tests: + model = context.get_model(model_name) + assert model is not None + + dialect = model.dialect + assert dialect is not None + + path = model._path + assert path is not None + + with open(path, "r", encoding="utf-8") as file: + content = file.read() + + ranges = get_range_of_a_key_in_model_block(content, dialect, key) + assert ranges is not None, f"Could not find key '{key}' in 
model '{model_name}'" + + key_range, value_range = ranges + read_key = read_range_from_file(path, key_range) + assert read_key.lower() == key.lower() + + read_value = read_range_from_file(path, value_range) + assert read_value == value