def get_range_of_a_key_in_model_block(
    sql: str,
    dialect: str,
    key: str,
) -> t.Optional[Range]:
    """
    Get the range of a specific top-level property key in the MODEL block of an SQL file.

    The MODEL block is taken to run from the first ``MODEL`` token to the first
    semicolon that follows it. Only tokens in a key position are considered, i.e.
    the token directly after the block's opening parenthesis or directly after a
    comma at the block's top nesting level. Identifiers nested inside a property
    value (for example a column named like the key inside ``columns (...)``, or a
    schema name inside a qualified model name) are therefore never reported.

    Args:
        sql: Full text of the SQL file.
        dialect: Dialect passed to the tokenizer.
        key: Property name to look for; compared case-insensitively.

    Returns:
        The range covering the key token, or ``None`` if the file has no
        ``MODEL`` token, the block is not terminated by a semicolon, or the key
        is not present as a top-level property.
    """
    tokens = tokenize(sql, dialect=dialect)
    if not tokens:
        return None

    # Locate the MODEL keyword that opens the block.
    model_index = next(
        (
            index
            for index, token in enumerate(tokens)
            if token.token_type is TokenType.VAR and token.text.upper() == "MODEL"
        ),
        None,
    )
    if model_index is None:
        return None

    openers = (TokenType.L_PAREN, TokenType.L_BRACKET, TokenType.L_BRACE)
    closers = (TokenType.R_PAREN, TokenType.R_BRACKET, TokenType.R_BRACE)
    wanted = key.upper()

    depth = 0
    at_key_position = False
    key_token = None
    terminated = False

    # Scan forward from MODEL only, so a semicolon that appears earlier in the
    # file cannot be mistaken for the end of the block.
    for token in tokens[model_index + 1 :]:
        token_type = token.token_type
        if token_type is TokenType.SEMICOLON:
            terminated = True
            break
        if key_token is not None:
            # Key already found; keep scanning only to confirm the terminator.
            continue
        if token_type in openers:
            depth += 1
            # Only the block's own opening bracket introduces a property key.
            at_key_position = depth == 1
        elif token_type in closers:
            depth -= 1
            at_key_position = False
        elif token_type is TokenType.COMMA:
            # A comma separates properties only at the block's top level.
            at_key_position = depth == 1
        else:
            # The token type is deliberately not restricted to VAR here.
            # NOTE(review): presumably some property names (e.g. `end`) are
            # tokenized as keywords rather than VAR - confirm against sqlglot.
            if at_key_position and token.text.upper() == wanted:
                key_token = token
            at_key_position = False

    if not terminated or key_token is None:
        return None

    position = TokenPositionDetails(
        line=key_token.line,
        col=key_token.col,
        start=key_token.start,
        end=key_token.end,
    )
    return position.to_range(sql.splitlines())
def test_get_range_of_a_key_in_model_block_testing_on_sushi():
    """Check that each key located in a sushi SQL model maps to a range whose text is that key."""
    context = Context(paths=["examples/sushi"])

    sql_models = [
        model
        for model in context.models.values()
        if isinstance(model, SqlModel)
        and model._path is not None
        and str(model._path).endswith(".sql")
    ]
    assert len(sql_models) > 0

    # The candidate keys do not depend on the model, so build the list once.
    possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"]

    for model in sql_models:
        dialect = model.dialect
        assert dialect is not None

        path = model._path
        assert path is not None

        with open(path, "r", encoding="utf-8") as file:
            content = file.read()

        count_properties_checked = 0

        for key in possible_keys:
            key_range = get_range_of_a_key_in_model_block(content, dialect, key)
            if key_range is None:
                # The helper found no such key in this model; nothing to verify.
                continue

            # The text covered by the returned range must be exactly the key.
            read_range = read_range_from_file(path, key_range)
            assert read_range.lower() == key.lower()
            count_properties_checked += 1

        # At least one of the candidate keys must have been found in each model.
        assert count_properties_checked > 0