import re
import unicodedata

# ASCII-only matcher: every character outside [a-zA-Z0-9_] is replaced.
NON_ALNUM = re.compile(r"[^a-zA-Z0-9_]")

# Unicode-aware matcher: for ``str`` patterns ``\W`` already uses Unicode
# semantics in Python 3, so no explicit ``re.UNICODE`` flag is required.
NON_ALNUM_INCLUDE_UNICODE = re.compile(r"\W")

# Backward-compatible alias for the originally introduced (misspelled) name.
NON_ALUM_INCLUDE_UNICODE = NON_ALNUM_INCLUDE_UNICODE


def sanitize_name(name: str, *, include_unicode: bool = False) -> str:
    """Replace every character that is not allowed in a sanitized name with ``_``.

    Each offending character is replaced one-for-one, so the result always has
    the same number of code points as the (normalized) input.

    Args:
        name: The raw name to sanitize. An empty string is returned unchanged.
        include_unicode: When ``False`` (the default) only ASCII letters,
            digits and underscores are kept. When ``True`` any Unicode word
            character (as matched by ``\\w``) is kept as well, e.g. CJK
            characters; this is the mode used for cache entry file names.

    Returns:
        The sanitized name.

    Examples:
        ``sanitize_name("客户-数据 v2")`` returns ``"______v2"`` while
        ``sanitize_name("客户-数据 v2", include_unicode=True)`` returns
        ``"客户_数据_v2"``.
    """
    if not include_unicode:
        return NON_ALNUM.sub("_", name)

    # Compose to NFC first so that a base letter followed by a combining mark
    # becomes a single word character instead of "letter + replaced mark".
    normalized = unicodedata.normalize("NFC", name)
    return NON_ALNUM_INCLUDE_UNICODE.sub("_", normalized)