Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 155 additions & 52 deletions sqlmesh/core/linter/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlmesh.core.linter.rule import Range, Position
from sqlmesh.utils.pydantic import PydanticModel
from sqlglot import tokenize, TokenType
from sqlglot import tokenize, TokenType, Token
import typing as t


Expand Down Expand Up @@ -122,25 +122,65 @@ def read_range_from_file(file: Path, text_range: Range) -> str:
return read_range_from_string("".join(lines), text_range)


def get_start_and_end_of_model_block(
    tokens: t.List[Token],
) -> t.Optional[t.Tuple[int, int]]:
    """
    Locate the parenthesised property list of the MODEL block in a token stream.

    The search proceeds in three steps:

    1. the first VAR token whose text is "MODEL" (case-insensitive),
    2. the first opening parenthesis after that token,
    3. the first semicolon after that opening parenthesis; the token directly
       before the semicolon must be a closing parenthesis.

    Note that parentheses are not balanced here: the closing parenthesis is
    identified purely by its position right before the first semicolon.

    Args:
        tokens: Tokens of the SQL file, in source order.

    Returns:
        A tuple ``(lparen_idx, rparen_idx)`` of indices into ``tokens`` for the
        opening and closing parentheses, or None when any of the steps above
        fails to find what it is looking for.
    """
    # Step 1: the MODEL keyword itself.
    model_pos = next(
        (
            pos
            for pos, token in enumerate(tokens)
            if token.token_type is TokenType.VAR and token.text.upper() == "MODEL"
        ),
        None,
    )
    if model_pos is None:
        return None

    # Step 2: the parenthesis that opens the MODEL property list.
    open_pos = next(
        (
            pos
            for pos, token in enumerate(tokens)
            if pos > model_pos and token.token_type is TokenType.L_PAREN
        ),
        None,
    )
    if open_pos is None:
        return None

    # Step 3: the MODEL block is assumed to end at the first semicolon that
    # follows the opening parenthesis.
    semicolon_pos = next(
        (
            pos
            for pos, token in enumerate(tokens)
            if pos > open_pos and token.token_type is TokenType.SEMICOLON
        ),
        None,
    )
    if semicolon_pos is None:
        return None

    # The token immediately before that semicolon has to close the block.
    close_pos = semicolon_pos - 1
    if tokens[close_pos].token_type is not TokenType.R_PAREN:
        return None
    return (open_pos, close_pos)


def get_range_of_model_block(
sql: str,
dialect: str,
) -> t.Optional[Range]:
"""
Get the range of the model block in an SQL file.
Get the range of the model block in an SQL file.
"""
tokens = tokenize(sql, dialect=dialect)

# Find start of the model block
start = next(
(t for t in tokens if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"),
None,
)
end = next((t for t in tokens if t.token_type is TokenType.SEMICOLON), None)

if start is None or end is None:
block = get_start_and_end_of_model_block(tokens)
if not block:
return None

(start_idx, end_idx) = block
start = tokens[start_idx - 1]
end = tokens[end_idx + 1]
start_position = TokenPositionDetails(
line=start.line,
col=start.col,
Expand All @@ -153,60 +193,123 @@ def get_range_of_model_block(
start=end.start,
end=end.end,
)

splitlines = sql.splitlines()
return Range(
start=start_position.to_range(splitlines).start, end=end_position.to_range(splitlines).end
start=start_position.to_range(splitlines).start,
end=end_position.to_range(splitlines).end,
)


def get_range_of_a_key_in_model_block(
sql: str,
dialect: str,
key: str,
) -> t.Optional[Range]:
) -> t.Optional[t.Tuple[Range, Range]]:
"""
Get the range of a specific key in the model block of an SQL file.
Get the ranges of a specific key and its value in the MODEL block of an SQL file.

Returns a tuple of (key_range, value_range) if found, otherwise None.
"""
tokens = tokenize(sql, dialect=dialect)
if tokens is None:
if not tokens:
return None

# Find the start of the model block
start_index = next(
(
i
for i, t in enumerate(tokens)
if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"
),
None,
)
end_index = next(
(i for i, t in enumerate(tokens) if t.token_type is TokenType.SEMICOLON),
None,
)
if start_index is None or end_index is None:
return None
if start_index >= end_index:
block = get_start_and_end_of_model_block(tokens)
if not block:
return None
(lparen_idx, rparen_idx) = block

# Scan the MODEL property list for the key at the top level (depth == 1).
# Depth starts at 1 because the scan begins just inside the block's opening parenthesis.
depth = 1
for i in range(lparen_idx + 1, rparen_idx):
tok = tokens[i]
tt = tok.token_type

if tt is TokenType.L_PAREN:
depth += 1
continue
if tt is TokenType.R_PAREN:
depth -= 1
# If we somehow exit before rparen_idx, stop early
if depth <= 0:
break
continue

if depth == 1 and tt is TokenType.VAR and tok.text.upper() == key.upper():
# Validate key position: it should immediately follow '(' or ',' at top level
prev_idx = i - 1
prev_tt = tokens[prev_idx].token_type if prev_idx >= 0 else None
if prev_tt not in (TokenType.L_PAREN, TokenType.COMMA):
continue

# Key range
lines = sql.splitlines()
key_start = TokenPositionDetails(
line=tok.line, col=tok.col, start=tok.start, end=tok.end
)
key_range = key_start.to_range(lines)

value_start_idx = i + 1
if value_start_idx >= rparen_idx:
return None

# Walk to the end of the value expression: until top-level comma or closing paren
# Track internal nesting for (), [], {}
nested = 0
j = value_start_idx
value_end_idx = value_start_idx

def is_open(t: TokenType) -> bool:
return t in (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET)

def is_close(t: TokenType) -> bool:
return t in (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET)

while j < rparen_idx:
ttype = tokens[j].token_type
if is_open(ttype):
nested += 1
elif is_close(ttype):
nested -= 1

# End of value: at top-level (nested == 0) encountering a comma or the end paren
if nested == 0 and (
ttype is TokenType.COMMA or (ttype is TokenType.R_PAREN and depth == 1)
):
# For comma, don't include it in the value range
# For closing paren, include it only if it's part of the value structure
if ttype is TokenType.COMMA:
# Don't include the comma in the value range
break
else:
# Include the closing parenthesis in the value range
value_end_idx = j
break

value_end_idx = j
j += 1

value_start_tok = tokens[value_start_idx]
value_end_tok = tokens[value_end_idx]

value_start_pos = TokenPositionDetails(
line=value_start_tok.line,
col=value_start_tok.col,
start=value_start_tok.start,
end=value_start_tok.end,
)
value_end_pos = TokenPositionDetails(
line=value_end_tok.line,
col=value_end_tok.col,
start=value_end_tok.start,
end=value_end_tok.end,
)
value_range = Range(
start=value_start_pos.to_range(lines).start,
end=value_end_pos.to_range(lines).end,
)

tokens_of_interest = tokens[start_index + 1 : end_index]
# Find the key token
key_token = next(
(
t
for t in tokens_of_interest
if t.token_type is TokenType.VAR and t.text.upper() == key.upper()
),
None,
)
if key_token is None:
return None
return (key_range, value_range)

position = TokenPositionDetails(
line=key_token.line,
col=key_token.col,
start=key_token.start,
end=key_token.end,
)
return position.to_range(sql.splitlines())
return None
66 changes: 59 additions & 7 deletions tests/core/linter/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,17 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
]
assert len(sql_models) > 0

# Test that the function works for all keys in the model block
for model in sql_models:
possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"]
possible_keys = [
"name",
"tags",
"description",
"column_descriptions",
"owner",
"cron",
"dialect",
]

dialect = model.dialect
assert dialect is not None
Expand All @@ -67,12 +76,55 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
count_properties_checked = 0

for key in possible_keys:
range = get_range_of_a_key_in_model_block(content, dialect, key)

# Check that the range starts with the key and ends with ;
if range:
read_range = read_range_from_file(path, range)
assert read_range.lower() == key.lower()
ranges = get_range_of_a_key_in_model_block(content, dialect, key)

if ranges:
key_range, value_range = ranges
read_key = read_range_from_file(path, key_range)
assert read_key.lower() == key.lower()
# Value range should be non-empty
read_value = read_range_from_file(path, value_range)
assert len(read_value) > 0
count_properties_checked += 1

assert count_properties_checked > 0

# Test that the function works for different kinds of value blocks
tests = [
("sushi.customers", "name", "sushi.customers"),
(
"sushi.customers",
"tags",
"(pii, fact)",
),
("sushi.customers", "description", "'Sushi customer data'"),
(
"sushi.customers",
"column_descriptions",
"( customer_id = 'customer_id uniquely identifies customers' )",
),
("sushi.customers", "owner", "jen"),
("sushi.customers", "cron", "'@daily'"),
]
for model_name, key, value in tests:
model = context.get_model(model_name)
assert model is not None

dialect = model.dialect
assert dialect is not None

path = model._path
assert path is not None

with open(path, "r", encoding="utf-8") as file:
content = file.read()

ranges = get_range_of_a_key_in_model_block(content, dialect, key)
assert ranges is not None, f"Could not find key '{key}' in model '{model_name}'"

key_range, value_range = ranges
read_key = read_range_from_file(path, key_range)
assert read_key.lower() == key.lower()

read_value = read_range_from_file(path, value_range)
assert read_value == value
Loading