From d721e536c9432225511881a4e269d7b960532ed7 Mon Sep 17 00:00:00 2001 From: Heinrich Chan Date: Fri, 15 Mar 2024 11:06:41 +0800 Subject: [PATCH] add test to make sure it is pure function --- tests/unit_tests/repositories/test_base.py | 30 ++++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/repositories/test_base.py b/tests/unit_tests/repositories/test_base.py index 8208bc0..60f1683 100644 --- a/tests/unit_tests/repositories/test_base.py +++ b/tests/unit_tests/repositories/test_base.py @@ -22,14 +22,6 @@ class TestBaseRepository(TestCase): def setUpClass(cls) -> None: cls.repo = CountryRepository(Session()) - def test_country(self) -> None: - val = extract_sort(Country, SortCountryEnum, "name")(select(Country)) - print("val", val) - - def test_person(self) -> None: - val = extract_sort(Person, SortPersonEnum, "id")(select(Person)) - print("person", val) - class TestCountryRepository(TestCase): def generate_from_sqlalchemy(self, sort_key_: str) -> str: @@ -42,6 +34,17 @@ def get_extract_query(self, q: QueryCountrySchema) -> Select[Tuple[Country]]: """ return extract_query(Country, ["name"], q)(select(Country)) # type: ignore + def test_country_extract_sort_is_pure_function(self) -> None: + stmt = select(Country) + # first call + stmt_ = extract_sort(Country, SortCountryEnum, "name")(stmt) + assert str(stmt_) == country_reference_sql + " ORDER BY countries.name ASC" + assert str(stmt) == str(select(Country)) + # second call + stmt_ = extract_sort(Country, SortCountryEnum, "phone")(stmt) + assert str(stmt_) == country_reference_sql + " ORDER BY countries.phone ASC" + assert str(stmt) == str(select(Country)) + def test_country_extract_sort_if_part_of_enum_asc(self) -> None: def generate_sql(sort_key_: str) -> str: return country_reference_sql + f" ORDER BY countries.{sort_key_} ASC" @@ -104,6 +107,17 @@ def get_extract_query(self, q: QueryPersonSchema) -> Select[Tuple[Person]]: """ return extract_query(Person, ["first_name", "last_name"], q)(select(Person)) # type: ignore + def test_person_is_pure_function(self) -> None: + stmt = select(Person) + # first call + stmt_ = extract_sort(Person, SortPersonEnum, "first_name")(stmt) + assert str(stmt_) == person_reference_sql + " ORDER BY persons.first_name ASC" + assert str(stmt) == str(select(Person)) + # second call + stmt_ = extract_sort(Person, SortPersonEnum, "last_name")(stmt) + assert str(stmt_) == person_reference_sql + " ORDER BY persons.last_name ASC" + assert str(stmt) == str(select(Person)) + # extract sort def test_person_extract_sort_if_part_of_enum_asc(self) -> None: def generate_sql(sort_key_: str) -> str: