diff --git a/.dockerignore b/.dockerignore index 0be3b56..f4932db 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,7 @@ # Python -__pycache__/ +**/__pycache__/ *.py[cod] +*.pyo *$py.class *.so .Python diff --git a/.launch/Dockerfile_base b/.launch/Dockerfile_base index 7d13e54..fd967e6 100644 --- a/.launch/Dockerfile_base +++ b/.launch/Dockerfile_base @@ -1,52 +1,40 @@ # Build stage -ARG PYTHON_VERSION=3.12.11 -FROM python:${PYTHON_VERSION}-slim AS builder +FROM python:3.12-slim AS builder ENV PYTHONDONTWRITEBYTECODE=1 -ENV PYTHONUNBUFFERED=1 - WORKDIR /app -# Install build dependencies -RUN apt-get update && apt-get install -y \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -# Install poetry -RUN pip install --no-cache-dir poetry +# Install build tools and poetry +RUN apt-get update && apt-get install -y build-essential \ + && rm -rf /var/lib/apt/lists/* \ + && pip install --no-cache-dir poetry -# Copy dependency files +# Install dependencies COPY pyproject.toml poetry.lock ./ - -# Configure poetry and install dependencies RUN poetry config virtualenvs.create false \ && poetry install --no-root --only=main # Production stage -ARG PYTHON_VERSION=3.12 -FROM python:${PYTHON_VERSION}-slim +FROM python:3.12-slim ENV PYTHONDONTWRITEBYTECODE=1 ENV PYTHONUNBUFFERED=1 ENV PYTHONPATH=/app -# Install runtime dependencies only -RUN apt-get update && apt-get install -y \ - libpq5 \ - && rm -rf /var/lib/apt/lists/* \ - && apt-get clean +# Install runtime dependencies +RUN apt-get update && apt-get install -y libpq5 \ + && rm -rf /var/lib/apt/lists/* WORKDIR /app -# Get Python version for dynamic path -ARG PYTHON_VERSION=3.12 -ENV PYTHON_SITE_PACKAGES=/usr/local/lib/python${PYTHON_VERSION}/site-packages - -# Copy installed packages from builder stage -COPY --from=builder ${PYTHON_SITE_PACKAGES} ${PYTHON_SITE_PACKAGES} +# Copy Python packages and binaries from builder +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages COPY --from=builder /usr/local/bin /usr/local/bin # Copy application code COPY src/ ./src/ COPY alembic.ini ./ +# Clean any cache files (safety backup) +RUN find ./src -name "*.pyc" -delete 2>/dev/null || true + diff --git a/local_prepare.sh b/local_prepare.sh index 909b946..aacc190 100644 --- a/local_prepare.sh +++ b/local_prepare.sh @@ -42,7 +42,7 @@ if [ ! "$(docker ps -aq -f name=${PROJECT_NAME_SLUG}_rabbitmq)" ]; then -p $MESSAGE_BROKER_PORT:5672 \ -e RABBITMQ_DEFAULT_USER=$MESSAGE_BROKER_USER \ -e RABBITMQ_DEFAULT_PASS=$MESSAGE_BROKER_PASSWORD \ - rabbitmq:3.11.6-management || true + rabbitmq:4.1.4-management-alpine || true fi echo " ✅ ${PROJECT_NAME_SLUG}_rabbitmq UP" diff --git a/poetry.lock b/poetry.lock index 9064d73..a36de1e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -879,7 +879,6 @@ description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\"" files = [ {file = "greenlet-3.2.4-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8c68325b0d0acf8d91dde4e6f930967dd52a5302cd4062932a6b2e7c2969f47c"}, {file = "greenlet-3.2.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94385f101946790ae13da500603491f04a76b6e4c059dab271b3ce2e283b2590"}, @@ -3015,4 +3014,4 @@ propcache = ">=0.2.1" [metadata] lock-version = "2.1" python-versions = "3.12.11" -content-hash = "553fb8d81f67ea7c51950d31416b6355494cf7c79584deb13e013d41463c28da" +content-hash = "dfea5f63160f7a2d1f8a211084d156e8c02252db4a3e678e1b5389e363ddd1a9" diff --git a/pyproject.toml b/pyproject.toml index 15957bd..25499e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ aiokafka = "^0.12.0" pytz = "^2025.2" grpcio = "^1.69.0" grpcio-tools = "^1.69.0" +greenlet = "^3.2.4" + [tool.poetry.group.dev.dependencies] diff --git a/src/app/application/container.py b/src/app/application/container.py index a0f9dab..563fbff 100644 --- a/src/app/application/container.py +++ b/src/app/application/container.py @@ -17,5 +17,11 @@ def auth_service(self) -> Type["src.app.application.services.auth_service.AuthSe return AuthService + @property + def common_service(self) -> Type["src.app.application.services.common_service.CommonApplicationService"]: + from src.app.application.services.common_service import CommonApplicationService + + return CommonApplicationService + container = ApplicationServicesContainer() diff --git a/src/app/application/services/common_service.py b/src/app/application/services/common_service.py new file mode 100644 index 0000000..4a76a16 --- /dev/null +++ b/src/app/application/services/common_service.py @@ -0,0 +1,23 @@ +from src.app.infrastructure.messaging.mq_client import mq_client +from src.app.infrastructure.repositories.container import container as repo_container +from src.app.application.common.services.base import AbstractBaseApplicationService +from loguru import logger + + +class CommonApplicationService(AbstractBaseApplicationService): + + @classmethod + async def is_healthy(cls) -> bool: + """Checks if app infrastructure is up and healthy.""" + try: + is_psql_healthy = await repo_container.common_psql_repository.is_healthy() + + is_redis_healthy = await repo_container.common_redis_repository.is_healthy() + + is_message_broker_healthy = await mq_client.is_healthy() + + except Exception as ex: + logger.error(f"Application is not healthy. Reason: {ex}") + return False + + return all([is_psql_healthy, is_redis_healthy, is_message_broker_healthy]) diff --git a/src/app/infrastructure/messaging/clients/kafka_client.py b/src/app/infrastructure/messaging/clients/kafka_client.py index 44949f3..4a543d1 100644 --- a/src/app/infrastructure/messaging/clients/kafka_client.py +++ b/src/app/infrastructure/messaging/clients/kafka_client.py @@ -23,7 +23,7 @@ async def is_healthy(self) -> bool: brokers = metadata.brokers() if callable(metadata.brokers) else metadata.brokers return len(brokers) > 0 except Exception as ex: - logger.error(f"{ex}") + logger.warning(f"{ex}") return False finally: await client.close() diff --git a/src/app/infrastructure/messaging/clients/rabbitmq_client.py b/src/app/infrastructure/messaging/clients/rabbitmq_client.py index 0cf7127..6192e12 100644 --- a/src/app/infrastructure/messaging/clients/rabbitmq_client.py +++ b/src/app/infrastructure/messaging/clients/rabbitmq_client.py @@ -33,7 +33,7 @@ async def is_healthy(self) -> bool: await connection_.close() return True except Exception as ex: - logger.error(f"{ex}") + logger.warning(f"{ex}") return False async def __get_connection(self) -> AbstractRobustConnection: diff --git a/src/app/infrastructure/repositories/base/abstract.py b/src/app/infrastructure/repositories/base/abstract.py index c0e0345..ee3e04e 100644 --- a/src/app/infrastructure/repositories/base/abstract.py +++ b/src/app/infrastructure/repositories/base/abstract.py @@ -5,6 +5,10 @@ from src.app.infrastructure.extensions.psql_ext.psql_ext import Base +class RepositoryError(Exception): + pass + + class AbstractRepository(ABC): pass diff --git a/src/app/infrastructure/repositories/base/base_psql_repository.py b/src/app/infrastructure/repositories/base/base_psql_repository.py index c61e5db..0c9f88e 100644 --- a/src/app/infrastructure/repositories/base/base_psql_repository.py +++ b/src/app/infrastructure/repositories/base/base_psql_repository.py @@ -1,133 +1,642 @@ import datetime as dt +import re from copy import deepcopy +from datetime import datetime from dataclasses import fields, make_dataclass from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type -from sqlalchemy import delete, exists, func, insert, inspect, select, Select, String, text, update +from sqlalchemy import ( + delete, + exists, + func, + insert, + inspect, + select, + Select, + String, + text, + update, + Column, + JSON, + DateTime, + Boolean, + Float, + Integer, +) from src.app.infrastructure.extensions.psql_ext.psql_ext import Base, get_session -from src.app.infrastructure.repositories.base.abstract import AbstractBaseRepository, OuterGenericType, BaseModel +from src.app.infrastructure.repositories.base.abstract import ( + AbstractBaseRepository, + OuterGenericType, + BaseModel, + RepositoryError, +) from src.app.infrastructure.utils.common import generate_str -class BasePSQLRepository(AbstractBaseRepository[OuterGenericType], Generic[OuterGenericType]): - MODEL: Optional[Type[Base]] = None - - __ATR_SEPARATOR: str = "__" +class PSQLLookupRegistry: LOOKUP_MAP = { - "gt": lambda stmt, key1, _, v: stmt.where(key1 > v), - "gte": lambda stmt, key1, _, v: stmt.where(key1 >= v), - "lt": lambda stmt, key1, _, v: stmt.where(key1 < v), - "lte": lambda stmt, key1, _, v: stmt.where(key1 <= v), - "e": lambda stmt, key1, _, v: stmt.where(key1 == v), - "ne": lambda stmt, key1, _, v: stmt.where(key1 != v), - "in": lambda stmt, key1, _, v: stmt.where(key1.in_(v)), # does not work with None - "not_in": lambda stmt, key1, _, v: stmt.where(key1.not_in(v)), # does not work with None - "like": lambda stmt, key1, _, v: stmt.filter(key1.cast(String).like(f"%{str(v)}%")), - "not_like_all": lambda stmt, key1, _, v: BasePSQLRepository.__not_like_all(stmt, key1, v), - "jsonb_like": lambda stmt, key1, key_2, v: BasePSQLRepository.__jsonb_like(stmt, key1, key_2, v), - "jsonb_not_like": lambda stmt, key1, key_2, v: BasePSQLRepository.__jsonb_not_like(stmt, key1, key_2, v), + "gt": lambda stmt, key1, _, v: PSQLLookupRegistry._greater_than(stmt, key1, v), + "gte": lambda stmt, key1, _, v: PSQLLookupRegistry._greater_than_equal(stmt, key1, v), + "lt": lambda stmt, key1, _, v: PSQLLookupRegistry._less_than(stmt, key1, v), + "lte": lambda stmt, key1, _, v: PSQLLookupRegistry._less_than_equal(stmt, key1, v), + "e": lambda stmt, key1, _, v: PSQLLookupRegistry._equal(stmt, key1, v), + "ne": lambda stmt, key1, _, v: PSQLLookupRegistry._not_equal(stmt, key1, v), + "in": lambda stmt, key1, _, v: PSQLLookupRegistry._in(stmt, key1, v), # does not work with None + "not_in": lambda stmt, key1, _, v: PSQLLookupRegistry._not_in(stmt, key1, v), # does not work with None + "ilike": lambda stmt, key1, _, v: PSQLLookupRegistry._ilike(stmt, key1, v), + "like": lambda stmt, key1, _, v: PSQLLookupRegistry._like(stmt, key1, v), + "not_like_all": lambda stmt, key1, _, v: PSQLLookupRegistry._not_like_all(stmt, key1, v), + "jsonb_like": lambda stmt, key1, key_2, v: PSQLLookupRegistry._jsonb_like(stmt, key1, key_2, v), + "jsonb_not_like": lambda stmt, key1, key_2, v: PSQLLookupRegistry._jsonb_not_like(stmt, key1, key_2, v), } + _JSONB_LOOKUPS = ( + "jsonb_like", + "jsonb_not_like", + ) + + @classmethod + def get_operation(cls, name: str) -> Callable: + """Get lookup operation by name""" + operation = cls.LOOKUP_MAP.get(name, None) + if not operation: + raise RepositoryError(f"Unknown lookup operation: '{name}'. Available: {list(cls.LOOKUP_MAP.keys())}") + return operation + + @classmethod + def apply_lookup( + cls, stmt: Any, column: Any, lookup: str, value: Any, jsonb_field: Optional[str] = None + ) -> Any: + """Apply lookup operation to statement""" + operation = cls.get_operation(lookup) + + if lookup in cls._JSONB_LOOKUPS: + return operation(stmt, column, jsonb_field, value) + else: + return operation(stmt, column, jsonb_field, value) + + # Core lookup operations + @staticmethod + def _equal(stmt: Any, column: Any, value: Any) -> Any: + """Equal comparison: column = value""" + return stmt.where(column == value) + + @staticmethod + def _not_equal(stmt: Any, column: Any, value: Any) -> Any: + """Not equal comparison: column != value""" + return stmt.where(column != value) + + @staticmethod + def _greater_than(stmt: Any, column: Any, value: Any) -> Any: + """Greater than comparison: column > value""" + return stmt.where(column > value) + + @staticmethod + def _greater_than_equal(stmt: Any, column: Any, value: Any) -> Any: + """Greater than or equal comparison: column >= value""" + return stmt.where(column >= value) + + @staticmethod + def _less_than(stmt: Any, column: Any, value: Any) -> Any: + """Less than comparison: column < value""" + return stmt.where(column < value) + + @staticmethod + def _less_than_equal(stmt: Any, column: Any, value: Any) -> Any: + """Less than or equal comparison: column <= value""" + return stmt.where(column <= value) + + @staticmethod + def _in(stmt: Any, column: Any, value: List[Any]) -> Any: + """IN comparison: column IN (values)""" + if not isinstance(value, (list, tuple)): + raise RepositoryError("IN lookup requires list or tuple value") + return stmt.where(column.in_(value)) + + @staticmethod + def _not_in(stmt: Any, column: Any, value: List[Any]) -> Any: + """NOT IN comparison: column NOT IN (values)""" + if not isinstance(value, (list, tuple)): + raise RepositoryError("NOT_IN lookup requires list or tuple value") + return stmt.where(column.not_in(value)) + + @staticmethod + def _like(stmt: Any, column: Any, value: Any) -> Any: + """LIKE comparison: column LIKE %value%""" + return stmt.filter(column.cast(String).like(f"%{str(value)}%")) @staticmethod - def __not_like_all(stmt: Any, k: Any, v: Any) -> Select: - for item in v: - stmt = stmt.filter(k.cast(String).like(f"%{str(item)}%")) + def _ilike(stmt: Any, column: Any, value: Any) -> Any: + """LIKE comparison: column LIKE %value%""" + return stmt.filter(column.cast(String).ilike(f"%{str(value)}%")) + + @staticmethod + def _not_like_all(stmt: Any, column: Any, value: List[str]) -> Select: + """NOT LIKE ALL: column NOT LIKE ALL values (all values must not match)""" + if not isinstance(value, (list, tuple)): + raise RepositoryError("NOT_LIKE_ALL lookup requires list or tuple value") + + for item in value: + stmt = stmt.filter(~column.cast(String).like(f"%{str(item)}%")) return stmt @staticmethod - def __jsonb_like(stmt: Any, key_1: Any, key_2: Any, v: Any) -> Select: + def _jsonb_like(stmt: Any, key_1: Any, key_2: Any, v: Any) -> Select: if not key_2: return stmt.where(key_1.cast(String).like(f"%{v}%")) else: - key_ = "jsonb_like" + generate_str(size=4) + value_param = "jsonb_like_val_" + generate_str(size=8) return stmt.where( - text(f"{key_1}->>'{key_2}' LIKE CONCAT('%', CAST(:{key_} AS TEXT), '%')").params(**{key_: str(v)}) + text(f"{key_1.name}->>:jsonb_key LIKE CONCAT('%', CAST(:{value_param} AS TEXT), '%')").params( + jsonb_key=str(key_2), **{value_param: str(v)} + ) ) @staticmethod - def __jsonb_not_like(stmt: Any, key_1: Any, key_2: Any, v: Any) -> Select: + def _jsonb_not_like(stmt: Any, key_1: Any, key_2: Any, v: Any) -> Select: if not key_2: return stmt.where(~key_1.cast(String).like(f"%{v}%")) else: - key_ = "jsonb_n_like" + generate_str(size=4) + value_param = "jsonb_not_like_val_" + generate_str(size=8) return stmt.where( - text(f"{key_1}->>'{key_2}' NOT LIKE CONCAT('%', CAST(:{key_} AS TEXT), '%')").params( - **{key_: str(v)} + text(f"{key_1.name}->>:jsonb_key NOT LIKE CONCAT('%', CAST(:{value_param} AS TEXT), '%')").params( + jsonb_key=str(key_2), **{value_param: str(v)} ) ) + +# ========================================== +# SECURITY AND VALIDATION CLASSES +# ========================================== + + +class SecurityConfig: + """Security configuration constants and patterns""" + + MAX_FILTER_COMPLEXITY = 50 + MAX_STRING_LENGTH = 5000 + MAX_LIST_LENGTH = 500 + KEY_MAX_LENGTH = 50 + DANGEROUS_STRINGS = [";", "--", "/*", "*/", "xp_", "sp_"] + ALLOWED_ORDER_PATTERN = re.compile(r"^-?[a-zA-Z_][a-zA-Z0-9_]*$") + ALLOWED_KEY_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + +class FilterKeyParser: + """Handles parsing of filter keys with lookup operations""" + + ATTRIBUTE_SEPARATOR = "__" + @classmethod - def _parse_filter_key(cls, key: str) -> Tuple[str, str, str]: # type: ignore - splitted: list = key.split(cls.__ATR_SEPARATOR) + def parse(cls, key: str) -> Tuple[str, str, str]: + """ + Parse filter key into components. + + Returns: + Tuple[column_name, jsonb_field, lookup_operation] + + Examples: + "name" -> ("name", "", "e") + "name__ilike" -> ("name", "", "ilike") + "meta__preferences__jsonb_like" -> ("meta", "preferences", "jsonb_like") + """ + parts = key.split(cls.ATTRIBUTE_SEPARATOR) + + if len(parts) == 1: + return parts[0], "", "e" + elif len(parts) == 2: + return parts[0], "", parts[1] + elif len(parts) == 3: + return parts[0], parts[1], parts[2] + else: + raise RepositoryError(f"Invalid filter key format: '{key}'. Too many separators.") + - if len(splitted) == 1: - key_1 = splitted[0] - key_2 = "" - return key_1, key_2, "e" - elif len(splitted) == 2: - key_1 = splitted[0] - key_2 = "" - lookup = splitted[1] - return key_1, key_2, lookup - elif len(splitted) == 3: - key_1 = splitted[0] - key_2 = splitted[1] - lookup = splitted[2] - return key_1, key_2, lookup +class SecurityValidator: + """Handles all security validation for query building""" @classmethod - def _parse_order_data(cls, order_data: Optional[Tuple[str]] = None) -> tuple: - if not order_data: - order_data = () # type: ignore - parsed_order_data = [] - - for order_item in order_data: # type: ignore - order_item_tmp = order_item - if order_item_tmp.startswith("-"): - order_item_tmp = order_item[1:] - parsed_order_data.append(getattr(cls.model(), order_item_tmp).desc()) - else: - parsed_order_data.append(getattr(cls.model(), order_item_tmp).asc()) + def validate_filter_complexity(cls, filter_data: Dict[str, Any]) -> None: + """Validate filter complexity to prevent DoS attacks""" + if len(filter_data) > SecurityConfig.MAX_FILTER_COMPLEXITY: + raise RepositoryError( + f"Filter complexity exceeds maximum allowed ({SecurityConfig.MAX_FILTER_COMPLEXITY})" + ) + + @classmethod + def validate_key_security(cls, key: str) -> None: + """Validate filter key for security""" + if len(key) > SecurityConfig.KEY_MAX_LENGTH: + raise RepositoryError(f"Key too long. Maximum {SecurityConfig.KEY_MAX_LENGTH} characters allowed.") + + if not SecurityConfig.ALLOWED_KEY_PATTERN.match(str(key)): + raise RepositoryError( + f"Invalid key format: '{key}'. Only alphanumeric characters and underscores allowed." + ) - return tuple(parsed_order_data) + # Check for potentially dangerous characters + if any(char in key for char in SecurityConfig.DANGEROUS_STRINGS): + raise RepositoryError(f"Key contains dangerous characters: '{key}'") @classmethod - def _parse_order_data_for_target(cls, target: Base, order_data: Optional[Tuple[str]] = None) -> tuple: - if not order_data: - order_data = () # type: ignore - parsed_order_data = [] - - for order_item in order_data: # type: ignore - order_item_tmp = order_item - if order_item_tmp.startswith("-"): - order_item_tmp = order_item[1:] - parsed_order_data.append(getattr(target, order_item_tmp).desc()) - else: - parsed_order_data.append(getattr(target, order_item_tmp).asc()) + def validate_value_security(cls, value: Any) -> None: + """Validate filter value for security""" + if isinstance(value, str): + if len(value) > SecurityConfig.MAX_STRING_LENGTH: + raise RepositoryError( + f"String value too long. Maximum {SecurityConfig.MAX_STRING_LENGTH} characters allowed." + ) + elif isinstance(value, (list, tuple)): + if len(value) > SecurityConfig.MAX_LIST_LENGTH: + raise RepositoryError( + f"List value too long. Maximum {SecurityConfig.MAX_LIST_LENGTH} items allowed." + ) + for item in value: + cls.validate_value_security(item) + + @classmethod + def validate_order_field(cls, order_field: str) -> None: + """Validate order field for security""" + if not SecurityConfig.ALLOWED_ORDER_PATTERN.match(order_field): + raise RepositoryError( + f"Invalid order field format: '{order_field}'. " + f"Only alphanumeric characters, underscores, and optional leading dash allowed." + ) + + if len(order_field) > SecurityConfig.KEY_MAX_LENGTH: + raise RepositoryError( + f"Order field too long. Maximum {SecurityConfig.KEY_MAX_LENGTH} characters allowed." + ) + + +class TypeValidator: + """Handles type validation for different column types""" + + TYPE_VALIDATORS = { + String: "_validate_string_type", + Integer: "_validate_integer_type", + Float: "_validate_float_type", + Boolean: "_validate_boolean_type", + DateTime: "_validate_datetime_type", + JSON: "_validate_json_type", + } + + @classmethod + def validate_value_type(cls, key: str, value: Any, column_type: Any) -> None: + """Validate a single value against column type""" + for type_class, validator_method in cls.TYPE_VALIDATORS.items(): + if isinstance(column_type, type_class): + getattr(cls, validator_method)(key, value) + return + + @classmethod + def _validate_string_type(cls, key: str, value: Any) -> None: + """Validate string type value""" + if not isinstance(value, str): + raise RepositoryError(f"Column '{key}' expects string value, got {type(value).__name__}") + + @classmethod + def _validate_integer_type(cls, key: str, value: Any) -> None: + """Validate integer type value""" + if not isinstance(value, int): + raise RepositoryError(f"Column '{key}' expects integer value, got {type(value).__name__}") + + @classmethod + def _validate_float_type(cls, key: str, value: Any) -> None: + """Validate float type value""" + if not isinstance(value, (int, float)): + raise RepositoryError(f"Column '{key}' expects numeric value, got {type(value).__name__}") + + @classmethod + def _validate_boolean_type(cls, key: str, value: Any) -> None: + """Validate boolean type value""" + if not isinstance(value, bool): + raise RepositoryError(f"Column '{key}' expects boolean value, got {type(value).__name__}") + + @classmethod + def _validate_datetime_type(cls, key: str, value: Any) -> None: + """Validate datetime type value""" + + if not isinstance(value, datetime): + raise RepositoryError(f"Column '{key}' expects datetime value, got {type(value).__name__}") + + @classmethod + def _validate_json_type(cls, key: str, value: Any) -> None: + """Validate JSON type value""" + if not isinstance(value, (dict, list, str, int, float, bool)): + raise RepositoryError(f"Column '{key}' expects JSON-compatible value, got {type(value).__name__}") + + +# ========================================== +# MAIN QUERY BUILDER CLASS +# ========================================== + + +class QueryBuilder: + """ + Main query builder class responsible for constructing SQL queries with security validations. + + This class is organized into logical sections: + - Configuration and Setup + - Column Management + - Filter Processing + - Ordering and Pagination + - Validation Orchestration + """ + + # ========================================== + # CONFIGURATION + # ========================================== + + LOOKUP_REGISTRY_CLASS = PSQLLookupRegistry + PAGINATION_KEYS = ["limit", "offset"] + _MODEL_COLUMNS_CACHE: Dict[str, Dict[str, Column]] = {} + + # ========================================== + # CORE QUERY BUILDING METHODS + # ========================================== + + @classmethod + def lookup_registry(cls) -> Type[PSQLLookupRegistry]: + """Get the lookup registry for SQL operations""" + return cls.LOOKUP_REGISTRY_CLASS + + @classmethod + def _get_model_columns(cls, model_class: Type[Base]) -> Dict[str, Column]: + """Get all columns from the model with caching""" + model_name = model_class.__name__ + + if model_name not in cls._MODEL_COLUMNS_CACHE: + inspector = inspect(model_class) + cls._MODEL_COLUMNS_CACHE[model_name] = {col.name: col for col in inspector.columns} + + return cls._MODEL_COLUMNS_CACHE[model_name] + + @classmethod + def validate_model_key(cls, key: str, model_class: Type[Base]) -> Column: + """Validate that a key exists in the model and return the column""" + column_ = cls._get_model_columns(model_class).get(key, None) + if column_ is None: + raise RepositoryError(f"Column '{key}' does not exist in model {model_class.__name__}") + return column_ + + # ========================================== + # FILTER VALIDATION METHODS + # ========================================== + + @classmethod + def validate_filter_value(cls, column: Column, key: str, value: Any, lookup: str) -> None: + """ + Comprehensive validation of filter values. + + Validates: + - None values against nullable columns + - Security constraints + - Type compatibility + - Lookup-specific requirements + """ + # Check nullable constraints + if cls._is_none_value_invalid(column, value): + raise RepositoryError(f"Column '{key}' cannot be None (not nullable)") + + if value is None: + return # None is valid for nullable columns + + # Security validation + SecurityValidator.validate_value_security(value) + + # Lookup-specific validation + if cls._is_list_based_lookup(lookup): + cls._validate_list_lookup_values(key, value, column.type) + elif cls._is_string_convertible_lookup(lookup): + cls._validate_string_convertible_lookup(key, value, lookup) + else: + # Type validation for single values + TypeValidator.validate_value_type(key, value, column.type) + + @classmethod + def _is_none_value_invalid(cls, column: Column, value: Any) -> bool: + """Check if None value is invalid for the column""" + return value is None and not column.nullable + + @classmethod + def _is_list_based_lookup(cls, lookup: str) -> bool: + """Check if lookup requires list/tuple values""" + return lookup in ("in", "not_in") - return tuple(parsed_order_data) + @classmethod + def _is_string_convertible_lookup(cls, lookup: str) -> bool: + """Check if lookup converts values to strings""" + return lookup in ("not_like_all", "like", "jsonb_like", "jsonb_not_like", "ilike") + + @classmethod + def _validate_list_lookup_values(cls, key: str, value: Any, column_type: Any) -> None: + """Validate values for list-based lookups (IN, NOT IN)""" + if not isinstance(value, (list, tuple)): + raise RepositoryError(f"List-based lookup for column '{key}' requires list/tuple value") + + # Validate each item in the list + for item in value: + if item is not None: + TypeValidator.validate_value_type(key, item, column_type) @classmethod - def _apply_where(cls, stmt: Any, filter_data: dict) -> Any: + def _validate_string_convertible_lookup(cls, key: str, value: Any, lookup: str) -> None: + """Validate values for string-convertible lookups (LIKE, ILIKE, etc.)""" + if lookup == "not_like_all" and not isinstance(value, (list, tuple)): + raise RepositoryError(f"Lookup 'not_like_all' for column '{key}' requires list/tuple value") + + # ========================================== + # MAIN QUERY PROCESSING METHODS + # ========================================== + + @classmethod + def apply_where(cls, stmt: Any, filter_data: Optional[Dict[str, Any]], model_class: Type[Base]) -> Any: + """ + Apply WHERE clauses to a SQL statement based on filter data. + + Args: + stmt: SQLAlchemy statement to modify + filter_data: Dictionary of filter conditions + model_class: SQLAlchemy model class for validation + + Returns: + Modified SQLAlchemy statement with WHERE clauses applied + """ + if not filter_data: + return stmt + + # Security validation + SecurityValidator.validate_filter_complexity(filter_data) + + # Process each filter for key, value in filter_data.items(): - key_1, key_2, lookup = cls._parse_filter_key(key) - key_1_ = getattr(cls.model(), key_1, None) - key_2_ = key_2 - if "jsonb" in lookup and key_2: - key_1_ = key_1 - key_2_ = key_2 - stmt = cls.LOOKUP_MAP[lookup](stmt, key_1_, key_2_, value) + if key in cls.PAGINATION_KEYS: + continue + + # Parse and validate the filter key + column_name, jsonb_field, lookup = FilterKeyParser.parse(key) + + # Security validation + SecurityValidator.validate_key_security(column_name) + if jsonb_field: + SecurityValidator.validate_key_security(jsonb_field) + + # Validate column exists and get column object + column = cls.validate_model_key(column_name, model_class) + + # Comprehensive value validation + cls.validate_filter_value(column, key, value, lookup) + + # Apply the lookup operation + stmt = cls.lookup_registry().apply_lookup( + stmt=stmt, column=column, lookup=lookup, value=value, jsonb_field=jsonb_field + ) + return stmt + @classmethod + def apply_ordering(cls, stmt: Any, order_data: Optional[Tuple[str, ...]], model_class: Type[Base]) -> Any: + """ + Apply ORDER BY clause to statement. + + Args: + stmt: SQLAlchemy statement to modify + order_data: Tuple of field names for ordering (prefix with "-" for DESC) + model_class: SQLAlchemy model class for validation + + Returns: + Modified SQLAlchemy statement with ORDER BY clause applied + """ + if not order_data: + return stmt + + try: + parsed_order = cls._parse_order_data(order_data, model_class) + return stmt.order_by(*parsed_order) + except Exception as e: + raise RepositoryError(f"Failed to apply ordering: {str(e)}") + + @classmethod + def apply_pagination(cls, stmt: Any, filter_data: Optional[Dict[str, Any]] = None) -> Any: + """ + Apply LIMIT and OFFSET to statement for pagination. + + Args: + stmt: SQLAlchemy statement to modify + filter_data: Dictionary containing "limit" and "offset" keys + + Returns: + Modified SQLAlchemy statement with pagination applied + """ + if not filter_data: + return stmt + + # Extract and validate pagination parameters + limit = filter_data.get("limit") + offset = filter_data.get("offset", 0) + + # Apply offset if provided + if offset: + if not isinstance(offset, int) or offset < 0: + raise RepositoryError(f"Offset must be non-negative integer, got: {offset}") + stmt = stmt.offset(offset) + + # Apply limit if provided + if limit is not None: + if not isinstance(limit, int) or limit <= 0: + raise RepositoryError(f"Limit must be positive integer, got: {limit}") + stmt = stmt.limit(limit) + + return stmt + + # ========================================== + # HELPER METHODS + # ========================================== + + @classmethod + def _parse_order_data(cls, order_data: Tuple[str, ...], model_class: Type[Base]) -> List[Any]: + """ + Parse order data into SQLAlchemy order clauses. + + Args: + order_data: Tuple of field names, optionally prefixed with "-" for descending order + model_class: SQLAlchemy model class for validation + + Returns: + List of SQLAlchemy order clauses + + Example: + ("name", "-created_at") -> [Column.asc(), Column.desc()] + """ + parsed_order = [] + + for order_item in order_data: + if not isinstance(order_item, str): + raise RepositoryError(f"Order field must be string, got: {type(order_item).__name__}") + + # Security validation + SecurityValidator.validate_order_field(order_item) + + # Parse direction and field name + if order_item.startswith("-"): + field_name = order_item[1:] + direction = "desc" + else: + field_name = order_item + direction = "asc" + + # Validate field exists in model and create order clause + try: + column = cls.validate_model_key(field_name, model_class) + parsed_order.append(getattr(column, direction)()) + except Exception as e: + raise RepositoryError(f"Invalid order field '{field_name}': {str(e)}") + + return parsed_order + + +class BasePSQLRepository(AbstractBaseRepository[OuterGenericType], Generic[OuterGenericType]): + """ + Base PostgreSQL repository with CRUD operations and bulk operations support. + + Organized into logical sections: + - Configuration and Setup + - Dataclass Helpers + - Read Operations + - Write Operations + - Bulk Operations + - Utility Methods + """ + + MODEL: Optional[Type[Base]] = None + _QUERY_BUILDER_CLASS: Type[QueryBuilder] = QueryBuilder + + # ========================================== + # CONFIGURATION AND SETUP + # ========================================== + + @classmethod + def query_builder(cls) -> Type[QueryBuilder]: + """Get the query builder class for this repository""" + if not cls._QUERY_BUILDER_CLASS: + raise AttributeError("Query builder class not configured") + return cls._QUERY_BUILDER_CLASS + @classmethod def model(cls) -> Type[BaseModel]: + """Get the SQLAlchemy model class for this repository""" if not cls.MODEL: - raise AttributeError + raise AttributeError("Model class not configured") return cls.MODEL + # ========================================== + # DATACLASS HELPERS + # ========================================== + @classmethod - def __make_out_dataclass(cls) -> Tuple[Callable, List[str]]: + def _create_dynamic_dataclass(cls) -> Tuple[Callable, List[str]]: + """Create a dynamic dataclass from the model structure""" model = cls.model() # type: ignore columns = inspect(model).c field_names = [column.name for column in columns] @@ -146,25 +655,29 @@ def __make_out_dataclass(cls) -> Tuple[Callable, List[str]]: def out_dataclass_with_columns( cls, out_dataclass: Optional[OuterGenericType] = None ) -> Tuple[Callable, List[str]]: + """Get output dataclass and column names for result conversion""" if not out_dataclass: - out_dataclass_, columns = cls.__make_out_dataclass() + out_dataclass_, columns = cls._create_dynamic_dataclass() else: out_dataclass_ = out_dataclass # type: ignore columns = [f.name for f in fields(out_dataclass_)] # type: ignore return out_dataclass_, columns + # ========================================== + # CRUD OPERATIONS + # ========================================== + @classmethod async def count(cls, filter_data: Optional[dict] = None) -> int: + """Count records matching the filter criteria""" if not filter_data: filter_data = {} - filter_data_ = deepcopy(filter_data) - filter_data_.pop("limit", "") - filter_data_.pop("offset", "") + filter_data_ = filter_data.copy() if filter_data else {} stmt: Select = select(func.count(cls.model().id)) # type: ignore - stmt = cls._apply_where(stmt, filter_data=filter_data_) + stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data_, model_class=cls.model()) async with get_session(expire_on_commit=True) as session: result = await session.execute(stmt) @@ -172,13 +685,11 @@ async def count(cls, filter_data: Optional[dict] = None) -> int: @classmethod async def is_exists(cls, filter_data: dict) -> bool: - - filter_data_ = deepcopy(filter_data) - filter_data_.pop("limit", "") - filter_data_.pop("offset", "") + """Check if any records exist matching the filter criteria""" + filter_data_ = filter_data.copy() stmt = select(exists(cls.model())) - stmt = cls._apply_where(stmt, filter_data=filter_data_) + stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data_, model_class=cls.model()) async with get_session() as session: result = await session.execute(stmt) @@ -189,12 +700,11 @@ async def is_exists(cls, filter_data: dict) -> bool: async def get_first( cls, filter_data: dict, out_dataclass: Optional[OuterGenericType] = None ) -> OuterGenericType | None: - filter_data_ = deepcopy(filter_data) - filter_data_.pop("limit", "") - filter_data_.pop("offset", "") + """Get the first record matching the filter criteria""" + filter_data_ = filter_data.copy() stmt: Select = select(cls.model()) - stmt = cls._apply_where(stmt, filter_data=filter_data_) + stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data_, model_class=cls.model()) async with get_session(expire_on_commit=True) as session: result = await session.execute(stmt) @@ -213,17 +723,15 @@ async def get_list( order_data: Optional[Tuple[str]] = ("id",), out_dataclass: Optional[OuterGenericType] = None, ) -> List[OuterGenericType]: + """Get a list of records matching the filter criteria with pagination and ordering""" if not filter_data: filter_data = {} - limit = filter_data.pop("limit", None) - offset = filter_data.pop("offset", 0) + filter_data_ = filter_data.copy() stmt: Select = select(cls.model()) - stmt = cls._apply_where(stmt, filter_data=filter_data) - stmt = stmt.order_by(*cls._parse_order_data(order_data)) - stmt = stmt.offset(offset) - if limit is not None: - stmt = stmt.limit(limit) + stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data_, model_class=cls.model()) + stmt = cls.query_builder().apply_ordering(stmt, order_data=order_data, model_class=cls.model()) + stmt = cls.query_builder().apply_pagination(stmt, filter_data=filter_data_) async with get_session(expire_on_commit=True) as session: result = await session.execute(stmt) @@ -241,6 +749,7 @@ async def get_list( async def create( cls, data: dict, is_return_require: bool = False, out_dataclass: Optional[OuterGenericType] = None ) -> OuterGenericType | None: + """Create a single record""" data_copy = data.copy() # Handle explicit ID if provided, otherwise let database auto-increment @@ -274,10 +783,78 @@ async def create( return None + @classmethod + async def update( + cls, + filter_data: dict, + data: Dict[str, Any], + is_return_require: bool = False, + out_dataclass: Optional[OuterGenericType] = None, + ) -> OuterGenericType | None: + """Update records matching the filter criteria""" + data_copy = data.copy() + + stmt = update(cls.model()) + stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data, model_class=cls.model()) + + cls._set_timestamps_on_update(items=[data_copy]) + + stmt = stmt.values(**data_copy) + stmt.execution_options(synchronize_session="fetch") + + async with get_session(expire_on_commit=True) as session: + await session.execute(stmt) + await session.commit() + + if is_return_require: + return await cls.get_first(filter_data=filter_data, out_dataclass=out_dataclass) + return None + + @classmethod + async def update_or_create( + cls, + filter_data: dict, + data: Dict[str, Any], + is_return_require: bool = False, + out_dataclass: Optional[OuterGenericType] = None, + ) -> OuterGenericType | None: + """Update existing record or create new one if not found""" + is_exists = await cls.is_exists(filter_data=filter_data) + if is_exists: + data_tmp = deepcopy(data) + data_tmp.pop("id", None) + data_tmp.pop("uuid", None) + item = await cls.update( + filter_data=filter_data, + data=data_tmp, + is_return_require=is_return_require, + out_dataclass=out_dataclass, + ) + return item + else: + item = await cls.create(data=data, is_return_require=is_return_require, out_dataclass=out_dataclass) + return item + + @classmethod + async def remove( + cls, + filter_data: Dict[str, Any], + ) -> None: + """Delete records matching the filter criteria""" + if not filter_data: + filter_data = {} + stmt = delete(cls.model()) + stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data, model_class=cls.model()) + + async with get_session() as session: + await session.execute(stmt) + await session.commit() + @classmethod async def create_bulk( cls, items: List[dict], is_return_require: bool = False, out_dataclass: Optional[OuterGenericType] = None ) -> List[OuterGenericType] | None: + """Create multiple records in a single operation""" if not items: return [] @@ -286,9 +863,6 @@ async def create_bulk( # Add timestamps to all items cls._set_timestamps_on_create(items=items_copy) - # Normalize data to handle mixed completeness - cls._normalize_bulk_data(items=items_copy) - async with get_session(expire_on_commit=True) as session: model_class = cls.model() # type: ignore model_table = model_class.__table__ # type: ignore @@ -314,36 +888,85 @@ async def create_bulk( return None @classmethod - async def update( - cls, - filter_data: dict, - data: Dict[str, Any], - is_return_require: bool = False, - out_dataclass: Optional[OuterGenericType] = None, - ) -> OuterGenericType | None: - data_copy = deepcopy(data) + async def update_bulk( + cls, items: List[dict], is_return_require: bool = False, out_dataclass: Optional[OuterGenericType] = None + ) -> List[OuterGenericType] | None: + """Update multiple records in optimized bulk operation - stmt = update(cls.model()) - stmt = cls._apply_where(stmt, filter_data=filter_data) + Note: Currently uses 2 queries for returning case: + - Option 1: Keep current ORM approach (cleaner, 2 queries for returning) + - Option 2: Go back to raw SQL (1 query, but more complex) + - Option 3: Hybrid approach - use ORM for non-returning, raw SQL for returning + """ + if not items: + return None - cls._set_timestamps_on_update(items=[data_copy]) + items_copy = deepcopy(items) - stmt = stmt.values(**data_copy) - stmt.execution_options(synchronize_session="fetch") + cls._set_timestamps_on_update(items=items_copy) async with get_session(expire_on_commit=True) as session: - await session.execute(stmt) - await session.commit() + if is_return_require: + return await cls._bulk_update_with_returning(session, items_copy, out_dataclass) + else: + await cls._bulk_update_without_returning(session, items_copy) + return None - if is_return_require: - return await cls.get_first(filter_data=filter_data, out_dataclass=out_dataclass) - return None + # ========================================== + # BULK OPERATION HELPERS + # ========================================== + + @classmethod + async def _bulk_update_with_returning( + cls, session: Any, items: List[dict], out_dataclass: Optional[OuterGenericType] = None + ) -> List[OuterGenericType]: + """Perform bulk update with RETURNING for result collection using ORM""" + if not items: + return [] + + model_class = cls.model() # type: ignore + + # Use SQLAlchemy's bulk_update_mappings with synchronize_session=False for performance + await session.execute(update(model_class), items, execution_options={"synchronize_session": False}) + await session.commit() + + # Get updated items by their IDs + updated_ids = [item["id"] for item in items if "id" in item] + if not updated_ids: + return [] + + # Query the updated records + stmt = select(model_class).where(model_class.id.in_(updated_ids)) + result = await session.execute(stmt) + updated_records = result.scalars().all() + + # Convert to output dataclass + out_entity_, _ = cls.out_dataclass_with_columns(out_dataclass=out_dataclass) + updated_items = [] + + for record in updated_records: + entity_data = {c.key: getattr(record, c.key) for c in inspect(record).mapper.column_attrs} + updated_items.append(out_entity_(**entity_data)) + + return updated_items + + @classmethod + async def _bulk_update_without_returning(cls, session: Any, items: List[dict]) -> None: + """Perform bulk update without RETURNING using SQLAlchemy's bulk operations""" + if not items: + return + + model_class = cls.model() # type: ignore + + # Use SQLAlchemy's built-in bulk update method + await session.execute(update(model_class), items, execution_options={"synchronize_session": False}) + await session.commit() @classmethod async def _update_single_with_returning( cls, session: Any, item_data: dict, out_entity_: Callable ) -> OuterGenericType | None: - """Update a single item and return the updated entity""" + """Update a single item and return the updated entity (legacy method)""" if "id" not in item_data: return None @@ -366,36 +989,13 @@ async def _update_single_with_returning( return out_entity_(**entity_data) return None - @classmethod - async def update_bulk( - cls, items: List[dict], is_return_require: bool = False, out_dataclass: Optional[OuterGenericType] = None - ) -> List[OuterGenericType] | None: - if not items: - return None - - items_copy = deepcopy(items) - - cls._set_timestamps_on_update(items=items_copy) - - async with get_session(expire_on_commit=True) as session: - if is_return_require: - return await cls._bulk_update_with_returning(session, items_copy, out_dataclass) - else: - await cls._bulk_update_without_returning(session, items_copy) - return None - - @classmethod - def _set_timestamps_on_update(cls, items: List[dict]) -> None: - """Set updated_at on update""" - if hasattr(cls.model(), "updated_at"): - dt_ = dt.datetime.now(dt.UTC).replace(tzinfo=None) - for item in items: - if "updated_at" not in item: - item["updated_at"] = dt_ + # ========================================== + # UTILITY METHODS + # ========================================== @classmethod def _set_timestamps_on_create(cls, items: List[dict]) -> None: - """Set created_at, updated_at on create""" + """Set created_at, updated_at timestamps on create operations""" if hasattr(cls.model(), "updated_at") or hasattr(cls.model(), "created_at"): dt_ = dt.datetime.now(dt.UTC).replace(tzinfo=None) for item in items: @@ -405,89 +1005,10 @@ def _set_timestamps_on_create(cls, items: List[dict]) -> None: item["created_at"] = dt_ @classmethod - def _normalize_bulk_data(cls, items: List[dict]) -> None: - """Normalize bulk data to handle mixed field completeness""" - if not items: - return - - # Get all unique keys from all items - all_keys: set[str] = set() - for item in items: - all_keys.update(item.keys()) - - # Get model column defaults and nullable info - model_class = cls.model() # type: ignore - model_table = model_class.__table__ # type: ignore - - # For each item, ensure it has all fields with appropriate defaults - for item in items: - for key in all_keys: - if key not in item: - # Check if column exists in model - if hasattr(model_class, key): - column = getattr(model_table.c, key, None) - if column is not None: - # Only add explicit None if column is nullable and has no default - if column.nullable and column.default is None and column.server_default is None: - item[key] = None - # Don't add anything for columns with defaults - let database handle it - - @classmethod - async def _bulk_update_with_returning( - cls, session: Any, items: List[dict], out_dataclass: Optional[OuterGenericType] = None - ) -> List[OuterGenericType]: - """Perform bulk update with RETURNING for result collection""" - updated_items = [] - out_entity_, _ = cls.out_dataclass_with_columns(out_dataclass=out_dataclass) - - for item_data in items: - updated_item = await cls._update_single_with_returning(session, item_data, out_entity_) - if updated_item: - updated_items.append(updated_item) - - await session.commit() - return updated_items - - @classmethod - async def _bulk_update_without_returning(cls, session: Any, items: List[dict]) -> None: - """Perform bulk update without RETURNING for better performance""" - await session.execute(update(cls.model()), items) - await session.commit() - - @classmethod - async def update_or_create( - cls, - filter_data: dict, - data: Dict[str, Any], - is_return_require: bool = False, - out_dataclass: Optional[OuterGenericType] = None, - ) -> OuterGenericType | None: - is_exists = await cls.is_exists(filter_data=filter_data) - if is_exists: - data_tmp = deepcopy(data) - data_tmp.pop("id", None) - data_tmp.pop("uuid", None) - item = await cls.update( - filter_data=filter_data, - data=data_tmp, - is_return_require=is_return_require, - out_dataclass=out_dataclass, - ) - return item - else: - item = await cls.create(data=data, is_return_require=is_return_require, out_dataclass=out_dataclass) - return item - - @classmethod - async def remove( - cls, - filter_data: Dict[str, Any], - ) -> None: - if not filter_data: - filter_data = {} - stmt = delete(cls.model()) - stmt = cls._apply_where(stmt, filter_data=filter_data) - - async with get_session() as session: - await session.execute(stmt) - await session.commit() + def _set_timestamps_on_update(cls, items: List[dict]) -> None: + """Set updated_at timestamp on update operations""" + if hasattr(cls.model(), "updated_at"): + dt_ = dt.datetime.now(dt.UTC).replace(tzinfo=None) + for item in items: + if "updated_at" not in item: + item["updated_at"] = dt_ diff --git a/src/app/infrastructure/repositories/base/base_redis_repository.py b/src/app/infrastructure/repositories/base/base_redis_repository.py new file mode 100644 index 0000000..3ac08e5 --- /dev/null +++ b/src/app/infrastructure/repositories/base/base_redis_repository.py @@ -0,0 +1,46 @@ +import json +from typing import Any + +import redis.asyncio as redis + +from src.app.infrastructure.extensions.redis_ext.redis_ext import redis_client +from src.app.infrastructure.repositories.base.abstract import AbstractRepository + + +class BaseRedisRepository(AbstractRepository): + client: redis.Redis = redis_client + + @classmethod + def get_client(cls) -> redis.Redis: + return cls.client + + @classmethod + async def set(cls, key: str, value: dict, expire_in_seconds: int) -> None: + client = cls.get_client() + value_ = json.dumps(value, default=str) + await client.setex(name=key, value=value_, time=expire_in_seconds) + + @classmethod + async def get(cls, key: str) -> Any: + client = cls.get_client() + value_ = await client.get(name=key) + if value_: + return json.loads(value_) + return None + + @classmethod + async def delete(cls, keys: list) -> Any: + client = cls.get_client() + for key in keys: + await client.delete(key) + return None + + @classmethod + async def exists(cls, key: str) -> bool: + client = cls.get_client() + return await client.exists(key) + + @classmethod + async def flush_db(cls) -> None: + client = cls.get_client() + await client.flushdb(asynchronous=True) diff --git a/src/app/infrastructure/repositories/common_psql_repository.py b/src/app/infrastructure/repositories/common_psql_repository.py new file mode 100644 index 0000000..27da23b --- /dev/null +++ b/src/app/infrastructure/repositories/common_psql_repository.py @@ -0,0 +1,16 @@ +from sqlalchemy import text + +from src.app.infrastructure.extensions.psql_ext.psql_ext import get_session +from src.app.infrastructure.repositories.base.abstract import AbstractRepository + + +class CommonPSQLRepository(AbstractRepository): + + @classmethod + async def is_healthy(cls) -> bool: + stmt = """SELECT 1;""" + + async with get_session() as session: + result = await session.execute(statement=text(stmt), params={}) + result = result.scalars().first() + return result == 1 diff --git a/src/app/infrastructure/repositories/common_redis_repository.py b/src/app/infrastructure/repositories/common_redis_repository.py new file mode 100644 index 0000000..512e5bb --- /dev/null +++ b/src/app/infrastructure/repositories/common_redis_repository.py @@ -0,0 +1,12 @@ +from src.app.infrastructure.extensions.redis_ext.redis_ext import redis_client +from src.app.infrastructure.repositories.base.base_redis_repository import BaseRedisRepository + + +class CommonRedisRepository(BaseRedisRepository): + client = redis_client + + @classmethod + async def is_healthy(cls) -> bool: + client = cls.get_client() + result = await client.ping() + return result diff --git a/src/app/infrastructure/repositories/container.py b/src/app/infrastructure/repositories/container.py index ce765e9..fb94d0d 100644 --- a/src/app/infrastructure/repositories/container.py +++ b/src/app/infrastructure/repositories/container.py @@ -1,10 +1,19 @@ from typing import NamedTuple, Type +from src.app.infrastructure.repositories.common_psql_repository import CommonPSQLRepository +from src.app.infrastructure.repositories.common_redis_repository import CommonRedisRepository from src.app.infrastructure.repositories.users_repository import UsersPSQLRepository class RepositoriesContainer(NamedTuple): + + common_psql_repository: Type[CommonPSQLRepository] + common_redis_repository: Type[CommonRedisRepository] users_repository: Type[UsersPSQLRepository] -container = RepositoriesContainer(users_repository=UsersPSQLRepository) +container = RepositoriesContainer( + common_psql_repository=CommonPSQLRepository, + common_redis_repository=CommonRedisRepository, + users_repository=UsersPSQLRepository, +) diff --git a/src/app/interfaces/api/v1/endpoints/debug/resources.py b/src/app/interfaces/api/v1/endpoints/debug/resources.py index b3d7be2..b981c05 100644 --- a/src/app/interfaces/api/v1/endpoints/debug/resources.py +++ b/src/app/interfaces/api/v1/endpoints/debug/resources.py @@ -1,7 +1,7 @@ from typing import Annotated, Dict from fastapi import APIRouter, Body, Request - +from src.app.application.container import container as services_container from src.app.interfaces.api.v1.endpoints.debug.schemas.req_schemas import MessageReq from src.app.config.settings import settings from src.app.infrastructure.messaging.mq_client import mq_client @@ -42,4 +42,6 @@ async def send_message( async def health_check( request: Request, ) -> Dict[str, str]: - return {"status": "ok"} + is_healthy = await services_container.common_service.is_healthy() + status = "OK" if is_healthy else "NOT OK" + return {"status": status} diff --git a/src/app/interfaces/cli/consume.py b/src/app/interfaces/cli/consume.py index 5e2f9a2..fa5f96f 100644 --- a/src/app/interfaces/cli/consume.py +++ b/src/app/interfaces/cli/consume.py @@ -45,6 +45,28 @@ async def queue_processing_aggregator(data: dict, handlers_by_event: Dict[str, D asyncio.set_event_loop(e_loop) try: + + # ================================================= + # WAIT FOR READINESS + # ================================================= + sleep_before = 120 + slept = 10 + logger.info("Waiting for readiness ..") + e_loop.run_until_complete(asyncio.sleep(slept)) + + is_healthy = e_loop.run_until_complete(mq_client.is_healthy()) + while slept < sleep_before and not is_healthy: + logger.info(f"Waiting for readiness {slept}/{sleep_before} sec..") + sleep_ = 15 + e_loop.run_until_complete(asyncio.sleep(sleep_)) + slept += sleep_ + is_healthy = e_loop.run_until_complete(mq_client.is_healthy()) + logger.info("READY.." if is_healthy else "NOT READY!") + + # ================================================= + # RUN CONSUMER + # ================================================= + handlers_by_event_ = HANDLERS_MAP aggregator_ = queue_processing_aggregator e_loop.run_until_complete( diff --git a/src/app/interfaces/grpc/pb/debug/debug_pb2.py b/src/app/interfaces/grpc/pb/debug/debug_pb2.py index 0b3b748..2c1c084 100644 --- a/src/app/interfaces/grpc/pb/debug/debug_pb2.py +++ b/src/app/interfaces/grpc/pb/debug/debug_pb2.py @@ -2,7 +2,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: debug.proto -# Protobuf Python Version: 5.29.0 +# Protobuf Python Version: 6.31.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -10,7 +10,7 @@ from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 29, 0, "", "debug.proto") +_runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 6, 31, 1, "", "debug.proto") # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0b\x64\x65\x62ug.proto\x12\rgrpc.pb.debug\x1a\x19google/protobuf/any.proto"\x0f\n\rSayMeqDataReq"*\n\x0bTestDataReq\x12\x0c\n\x04year\x18\x01 \x01(\t\x12\r\n\x05month\x18\x02 \x01(\t"?\n\nMessageReq\x12\r\n\x05\x65vent\x18\x01 \x01(\t\x12"\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any".\n\x0bMessageResp\x12\x0e\n\x06status\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2V\n\x0c\x44\x65\x62ugService\x12\x46\n\x0bSendMessage\x12\x19.grpc.pb.debug.MessageReq\x1a\x1a.grpc.pb.debug.MessageResp"\x00\x62\x06proto3' + b'\n\x0b\x64\x65\x62ug.proto\x12\rgrpc.pb.debug\x1a\x19google/protobuf/any.proto"\x0f\n\rSayMeqDataReq"*\n\x0bTestDataReq\x12\x0c\n\x04year\x18\x01 \x01(\t\x12\r\n\x05month\x18\x02 \x01(\t"?\n\nMessageReq\x12\r\n\x05\x65vent\x18\x01 \x01(\t\x12"\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any".\n\x0bMessageResp\x12\x0e\n\x06status\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t"\x10\n\x0eHealthCheckReq"!\n\x0fHealthCheckResp\x12\x0e\n\x06status\x18\x01 \x01(\t2\xa6\x01\n\x0c\x44\x65\x62ugService\x12\x46\n\x0bSendMessage\x12\x19.grpc.pb.debug.MessageReq\x1a\x1a.grpc.pb.debug.MessageResp"\x00\x12N\n\x0bHealthCheck\x12\x1d.grpc.pb.debug.HealthCheckReq\x1a\x1e.grpc.pb.debug.HealthCheckResp"\x00\x62\x06proto3' ) _globals = globals() @@ -36,6 +36,10 @@ _globals["_MESSAGEREQ"]._serialized_end = 181 _globals["_MESSAGERESP"]._serialized_start = 183 _globals["_MESSAGERESP"]._serialized_end = 229 - _globals["_DEBUGSERVICE"]._serialized_start = 231 - _globals["_DEBUGSERVICE"]._serialized_end = 317 + _globals["_HEALTHCHECKREQ"]._serialized_start = 231 + _globals["_HEALTHCHECKREQ"]._serialized_end = 247 + _globals["_HEALTHCHECKRESP"]._serialized_start = 249 + _globals["_HEALTHCHECKRESP"]._serialized_end = 282 + _globals["_DEBUGSERVICE"]._serialized_start = 285 + _globals["_DEBUGSERVICE"]._serialized_end = 451 # @@protoc_insertion_point(module_scope) diff --git a/src/app/interfaces/grpc/pb/debug/debug_pb2_grpc.py b/src/app/interfaces/grpc/pb/debug/debug_pb2_grpc.py index 647e9b4..dee7faf 100644 --- a/src/app/interfaces/grpc/pb/debug/debug_pb2_grpc.py +++ b/src/app/interfaces/grpc/pb/debug/debug_pb2_grpc.py @@ -1,10 +1,11 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc +import warnings -import src.app.interfaces.grpc.pb.debug.debug_pb2 as debug__pb2 +from src.app.interfaces.grpc.pb.debug import debug_pb2 as debug__pb2 -GRPC_GENERATED_VERSION = "1.70.0" +GRPC_GENERATED_VERSION = "1.75.0" GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -18,7 +19,7 @@ if _version_not_supported: raise RuntimeError( f"The grpc package installed is at version {GRPC_VERSION}," - + " but the generated code in debug_pb2_grpc.py depends on" + + f" but the generated code in debug_pb2_grpc.py depends on" + f" grpcio>={GRPC_GENERATED_VERSION}." + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." @@ -40,6 +41,12 @@ def __init__(self, channel): response_deserializer=debug__pb2.MessageResp.FromString, _registered_method=True, ) + self.HealthCheck = channel.unary_unary( + "/grpc.pb.debug.DebugService/HealthCheck", + request_serializer=debug__pb2.HealthCheckReq.SerializeToString, + response_deserializer=debug__pb2.HealthCheckResp.FromString, + _registered_method=True, + ) class DebugServiceServicer(object): @@ -51,6 +58,12 @@ def SendMessage(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def HealthCheck(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def add_DebugServiceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -59,6 +72,11 @@ def add_DebugServiceServicer_to_server(servicer, server): request_deserializer=debug__pb2.MessageReq.FromString, response_serializer=debug__pb2.MessageResp.SerializeToString, ), + "HealthCheck": grpc.unary_unary_rpc_method_handler( + servicer.HealthCheck, + request_deserializer=debug__pb2.HealthCheckReq.FromString, + response_serializer=debug__pb2.HealthCheckResp.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler("grpc.pb.debug.DebugService", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) @@ -98,3 +116,33 @@ def SendMessage( metadata, _registered_method=True, ) + + @staticmethod + def HealthCheck( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/grpc.pb.debug.DebugService/HealthCheck", + debug__pb2.HealthCheckReq.SerializeToString, + debug__pb2.HealthCheckResp.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) diff --git a/src/app/interfaces/grpc/protos/debug.proto b/src/app/interfaces/grpc/protos/debug.proto index 6e76996..1bcd131 100644 --- a/src/app/interfaces/grpc/protos/debug.proto +++ b/src/app/interfaces/grpc/protos/debug.proto @@ -22,6 +22,17 @@ message MessageResp { string message = 2; } + +message HealthCheckReq { +} + +message HealthCheckResp { + string status = 1; +} + + service DebugService { rpc SendMessage(MessageReq) returns (MessageResp) {} + + rpc HealthCheck(HealthCheckReq) returns (HealthCheckResp) {} } diff --git a/src/app/interfaces/grpc/services/debug_service.py b/src/app/interfaces/grpc/services/debug_service.py index faa0305..fee8e1f 100644 --- a/src/app/interfaces/grpc/services/debug_service.py +++ b/src/app/interfaces/grpc/services/debug_service.py @@ -8,6 +8,7 @@ from google.protobuf.json_format import MessageToJson from src.app.interfaces.grpc.pb.debug.debug_pb2_grpc import DebugServiceServicer +from src.app.application.container import container as services_container class DebugService(DebugServiceServicer): @@ -23,3 +24,8 @@ async def SendMessage(self, request, context) -> pb2.MessageResp: # type: ignor ) logger.debug(f"Sent message `{event}` with data {str(data)}") return pb2.MessageResp(status=True, message="OK") # type: ignore + + async def HealthCheck(self, request, context) -> pb2.HealthCheckResp: # type: ignore + is_healthy = await services_container.common_service.is_healthy() + status = "OK" if is_healthy else "NOT OK" + return pb2.HealthCheckResp(status=status) # type: ignore diff --git a/tests/application/users/services/test_common_service_is_healthy.py b/tests/application/users/services/test_common_service_is_healthy.py new file mode 100644 index 0000000..de6415b --- /dev/null +++ b/tests/application/users/services/test_common_service_is_healthy.py @@ -0,0 +1,128 @@ +from asyncio import AbstractEventLoop +from typing import Generator, Tuple, Any +from unittest.mock import patch + +import pytest + +from src.app.application.container import container as service_container + + +@pytest.fixture +def mock_health_services() -> Generator[Tuple[Any, Any, Any], None, None]: + with ( + patch( + "src.app.infrastructure.repositories.common_psql_repository.CommonPSQLRepository.is_healthy" + ) as mock_psql, + patch( + "src.app.infrastructure.repositories.common_redis_repository.CommonRedisRepository.is_healthy" + ) as mock_redis, + patch("src.app.infrastructure.messaging.mq_client.mq_client.is_healthy") as mock_mq, + ): + yield mock_psql, mock_redis, mock_mq + + +IS_HEALTHY_CASES = [ + {"psql_val": True, "redis_val": True, "mq_val": True, "expected": True}, + {"psql_val": False, "redis_val": True, "mq_val": True, "expected": False}, + {"psql_val": True, "redis_val": False, "mq_val": True, "expected": False}, + {"psql_val": True, "redis_val": True, "mq_val": False, "expected": False}, + {"psql_val": True, "redis_val": False, "mq_val": False, "expected": False}, + {"psql_val": False, "redis_val": True, "mq_val": False, "expected": False}, +] + + +@pytest.mark.parametrize("data", IS_HEALTHY_CASES, scope="function") +def test_common_service_is_healthy( + e_loop: AbstractEventLoop, mock_health_services: Tuple[Any, Any, Any], data: dict +) -> None: + mock_psql, mock_redis, mock_mq = mock_health_services + expected_val, psql_val, redis_val, mq_val = ( + data["expected"], + data["psql_val"], + data["redis_val"], + data["mq_val"], + ) + mock_psql.return_value = psql_val + mock_redis.return_value = redis_val + mock_mq.return_value = mq_val + + result = e_loop.run_until_complete(service_container.common_service.is_healthy()) + + assert result is expected_val + + +IS_HEALTHY_CASES_FAILED = [ + {"psql_val": True, "redis_val": True, "mq_val": False, "expected": False}, + {"psql_val": True, "redis_val": False, "mq_val": False, "expected": False}, + {"psql_val": False, "redis_val": True, "mq_val": False, "expected": False}, +] + + +@pytest.mark.parametrize("data", IS_HEALTHY_CASES_FAILED, scope="function") +def test_common_service_is_healthy_with_psq_exception( + e_loop: AbstractEventLoop, mock_health_services: Tuple[Any, Any, Any], data: dict +) -> None: + mock_psql, mock_redis, mock_mq = mock_health_services + expected_val, psql_val, redis_val, mq_val = ( + data["expected"], + data["psql_val"], + data["redis_val"], + data["mq_val"], + ) + mock_psql.return_value = psql_val + mock_redis.return_value = redis_val + mock_mq.return_value = mq_val + mock_psql.side_effect = Exception("Connection failed") + + result = e_loop.run_until_complete(service_container.common_service.is_healthy()) + + assert result is expected_val + + +@pytest.mark.parametrize("data", IS_HEALTHY_CASES_FAILED, scope="function") +def test_common_service_is_healthy_with_mq_exception( + e_loop: AbstractEventLoop, mock_health_services: Tuple[Any, Any, Any], data: dict +) -> None: + mock_psql, mock_redis, mock_mq = mock_health_services + expected_val, psql_val, redis_val, mq_val = ( + data["expected"], + data["psql_val"], + data["redis_val"], + data["mq_val"], + ) + mock_psql.return_value = psql_val + mock_redis.return_value = redis_val + mock_mq.return_value = mq_val + mock_mq.side_effect = Exception("Connection failed") + + result = e_loop.run_until_complete(service_container.common_service.is_healthy()) + + assert result is expected_val + + +@pytest.mark.parametrize("data", IS_HEALTHY_CASES_FAILED, scope="function") +def test_common_service_is_healthy_with_redis_exception( + e_loop: AbstractEventLoop, mock_health_services: Tuple[Any, Any, Any], data: dict +) -> None: + mock_psql, mock_redis, mock_mq = mock_health_services + expected_val, psql_val, redis_val, mq_val = ( + data["expected"], + data["psql_val"], + data["redis_val"], + data["mq_val"], + ) + mock_psql.return_value = psql_val + mock_redis.return_value = redis_val + mock_mq.return_value = mq_val + mock_redis.side_effect = Exception("Connection failed") + + result = e_loop.run_until_complete(service_container.common_service.is_healthy()) + + assert result is expected_val + + +def test_common_service_is_healthy_real_infrastructure(e_loop: AbstractEventLoop) -> None: + result = e_loop.run_until_complete(service_container.common_service.is_healthy()) + + assert isinstance(result, bool) + assert result is True diff --git a/tests/infrastructure/repositories/test_repository_general.py b/tests/infrastructure/repositories/test_repository_general.py index 472ef4e..86f78c6 100644 --- a/tests/infrastructure/repositories/test_repository_general.py +++ b/tests/infrastructure/repositories/test_repository_general.py @@ -9,17 +9,20 @@ from tests.fixtures.constants import USERS -def test_users_get_list_limit_offset_case_1(e_loop: AbstractEventLoop, users: Any) -> None: +repository = repo_container.users_repository +out_dataclass = UserTestAggregate + + +def test_get_list_limit_offset_case_1(e_loop: AbstractEventLoop, users: Any) -> None: """Test pagination with limit larger than remaining items""" - users_repository = repo_container.users_repository total_users = len(USERS) offset = total_users - 1 limit = 10 expected_count = 1 - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={"limit": limit, "offset": offset}, out_dataclass=UserTestAggregate), + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={"limit": limit, "offset": offset}, out_dataclass=out_dataclass), ) assert isinstance(items, list) @@ -28,21 +31,20 @@ def test_users_get_list_limit_offset_case_1(e_loop: AbstractEventLoop, users: An # Verify returned item is valid expected_ids = {user["id"] for user in USERS} for user in items: - assert isinstance(user, UserTestAggregate) + assert isinstance(user, out_dataclass) assert user.id in expected_ids -def test_users_get_list_limit_offset_case_2(e_loop: AbstractEventLoop, users: Any) -> None: +def test_get_list_limit_offset_case_2(e_loop: AbstractEventLoop, users: Any) -> None: """Test pagination with offset near end of dataset""" - users_repository = repo_container.users_repository total_users = len(USERS) offset = total_users - 2 limit = 10 expected_count = 2 - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={"limit": limit, "offset": offset}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={"limit": limit, "offset": offset}, out_dataclass=out_dataclass) ) assert isinstance(items, list) @@ -50,21 +52,20 @@ def test_users_get_list_limit_offset_case_2(e_loop: AbstractEventLoop, users: An expected_ids = {user["id"] for user in USERS} for user in items: - assert isinstance(user, UserTestAggregate) + assert isinstance(user, out_dataclass) assert user.id in expected_ids -def test_users_get_list_limit_offset_case_3(e_loop: AbstractEventLoop, users: Any) -> None: +def test_get_list_limit_offset_case_3(e_loop: AbstractEventLoop, users: Any) -> None: """Test pagination with small limit and large offset""" - users_repository = repo_container.users_repository total_users = len(USERS) offset = total_users - 2 limit = 1 expected_count = 1 - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={"limit": limit, "offset": offset}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={"limit": limit, "offset": offset}, out_dataclass=out_dataclass) ) assert isinstance(items, list) @@ -72,20 +73,19 @@ def test_users_get_list_limit_offset_case_3(e_loop: AbstractEventLoop, users: An expected_ids = {user["id"] for user in USERS} for user in items: - assert isinstance(user, UserTestAggregate) + assert isinstance(user, out_dataclass) assert user.id in expected_ids -def test_users_get_list_with_offset_only(e_loop: AbstractEventLoop, users: Any) -> None: +def test_get_list_with_offset_only(e_loop: AbstractEventLoop, users: Any) -> None: """Test pagination with offset but no limit""" - users_repository = repo_container.users_repository total_users = len(USERS) offset = 1 expected_count = total_users - offset - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={"offset": offset}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={"offset": offset}, out_dataclass=out_dataclass) ) assert isinstance(items, list) @@ -93,19 +93,18 @@ def test_users_get_list_with_offset_only(e_loop: AbstractEventLoop, users: Any) expected_ids = {user["id"] for user in USERS} for user in items: - assert isinstance(user, UserTestAggregate) + assert isinstance(user, out_dataclass) assert user.id in expected_ids -def test_users_get_list_order_by_id_asc(e_loop: AbstractEventLoop, users: Any) -> None: +def test_get_list_order_by_id_asc(e_loop: AbstractEventLoop, users: Any) -> None: """Test ordering users by ID in ascending order""" - users_repository = repo_container.users_repository users_sorted = sorted(USERS, key=lambda i: i["id"]) expected_count = len(users_sorted) - items: List[UserTestAggregate] = e_loop.run_until_complete( - users_repository.get_list(order_data=("id",), out_dataclass=UserTestAggregate) + items: List[out_dataclass] = e_loop.run_until_complete( + repository.get_list(order_data=("id",), out_dataclass=out_dataclass) ) assert isinstance(items, list) @@ -113,7 +112,7 @@ def test_users_get_list_order_by_id_asc(e_loop: AbstractEventLoop, users: Any) - # Verify ordering is correct for index, user in enumerate(items): - assert isinstance(user, UserTestAggregate) + assert isinstance(user, out_dataclass) assert user.id == users_sorted[index]["id"] # Verify items are in ascending order @@ -121,15 +120,14 @@ def test_users_get_list_order_by_id_asc(e_loop: AbstractEventLoop, users: Any) - assert user_ids == sorted(user_ids) -def test_users_get_list_order_by_id_desc(e_loop: AbstractEventLoop, users: Any) -> None: +def test_get_list_order_by_id_desc(e_loop: AbstractEventLoop, users: Any) -> None: """Test ordering users by ID in descending order""" - users_repository = repo_container.users_repository users_sorted = sorted(USERS, key=lambda i: i["id"], reverse=True) expected_count = len(users_sorted) - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(order_data=("-id",), out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(order_data=("-id",), out_dataclass=out_dataclass) ) assert isinstance(items, list) @@ -137,7 +135,7 @@ def test_users_get_list_order_by_id_desc(e_loop: AbstractEventLoop, users: Any) # Verify ordering is correct for index, user in enumerate(items): - assert isinstance(user, UserTestAggregate) + assert isinstance(user, out_dataclass) assert user.id == users_sorted[index]["id"] # Verify items are in descending order @@ -153,14 +151,14 @@ def test_users_get_list_order_by_id_desc(e_loop: AbstractEventLoop, users: Any) @pytest.mark.parametrize("data", USERS_IN_LOOKUP, scope="function") -def test_users_get_list_in_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_in_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: + users_repository = repository field = data["key"] lookup = f"{field}__in" expected_values = data["value"] - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -178,13 +176,11 @@ def test_users_get_list_in_lookup(e_loop: AbstractEventLoop, users: Any, data: d @pytest.mark.parametrize("data", USERS_GT_LOOKUP, scope="function") -def test_users_get_list_gt_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository - +def test_get_list_gt_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__gt" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -202,13 +198,12 @@ def test_users_get_list_gt_lookup(e_loop: AbstractEventLoop, users: Any, data: d @pytest.mark.parametrize("data", USERS_GTE_LOOKUP, scope="function") -def test_users_get_list_gte_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_gte_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] - lookup = f"{field}__gt" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + lookup = f"{field}__gte" + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -226,13 +221,12 @@ def test_users_get_list_gte_lookup(e_loop: AbstractEventLoop, users: Any, data: @pytest.mark.parametrize("data", USERS_LT_LOOKUP, scope="function") -def test_users_get_list_lt_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_lt_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__lt" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -250,13 +244,11 @@ def test_users_get_list_lt_lookup(e_loop: AbstractEventLoop, users: Any, data: d @pytest.mark.parametrize("data", USERS_LTE_LOOKUP, scope="function") -def test_users_get_list_lte_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository - +def test_get_list_lte_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__lte" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -275,13 +267,12 @@ def test_users_get_list_lte_lookup(e_loop: AbstractEventLoop, users: Any, data: @pytest.mark.parametrize("data", USERS_E_LOOKUP, scope="function") -def test_users_get_list_e_lookup_case_1(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_e_lookup_case_1(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__e" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -292,13 +283,12 @@ def test_users_get_list_e_lookup_case_1(e_loop: AbstractEventLoop, users: Any, d @pytest.mark.parametrize("data", USERS_E_LOOKUP, scope="function") -def test_users_get_list_e_lookup_case_2(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_e_lookup_case_2(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = field - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -317,13 +307,12 @@ def test_users_get_list_e_lookup_case_2(e_loop: AbstractEventLoop, users: Any, d @pytest.mark.parametrize("data", USERS_NE_LOOKUP, scope="function") -def test_users_get_list_ne_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_ne_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__ne" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -342,13 +331,12 @@ def test_users_get_list_ne_lookup(e_loop: AbstractEventLoop, users: Any, data: d @pytest.mark.parametrize("data", USERS_NOT_IN_LOOKUP, scope="function") -def test_users_get_list_not_in_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_not_in_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__not_in" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -370,24 +358,35 @@ def test_users_get_list_not_in_lookup(e_loop: AbstractEventLoop, users: Any, dat "value": "gmail", }, {"key": "birthday", "value": USERS[3]["birthday"]}, + {"key": "email", "value": "%gmail%"}, + {"key": "first_name", "value": "first%"}, + {"key": "last_name", "value": "%name%"}, ] @pytest.mark.parametrize("data", USERS_LIKE_LOOKUP, scope="function") -def test_users_get_list_like_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_like_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__like" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True for index, user in enumerate(items): - value = getattr(user, field) - assert str(data["value"]) in str(value) + value = str(getattr(user, field)) + pattern = str(data["value"]) + if "%" in pattern: + if pattern.startswith("%") and pattern.endswith("%"): + assert pattern[1:-1] in value + elif pattern.startswith("%"): + assert value.endswith(pattern[1:]) + elif pattern.endswith("%"): + assert value.startswith(pattern[:-1]) + else: + assert pattern in value USERS_NOT_LIKE_ALL_LOOKUP = [ @@ -401,13 +400,12 @@ def test_users_get_list_like_lookup(e_loop: AbstractEventLoop, users: Any, data: @pytest.mark.parametrize("data", USERS_NOT_LIKE_ALL_LOOKUP, scope="function") -def test_users_get_list_not_like_all_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_not_like_all_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__not_like_all" - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -426,16 +424,15 @@ def test_users_get_list_not_like_all_lookup(e_loop: AbstractEventLoop, users: An @pytest.mark.parametrize("data", USERS_JSONB_LIKE_LOOKUP, scope="function") -def test_users_get_list_jsonb_like_lookup_case_1(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_jsonb_like_lookup_case_1(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__jsonb_like" for value in data["value"]: - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list( + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list( filter_data={lookup: value}, - out_dataclass=UserTestAggregate, + out_dataclass=out_dataclass, ) ) @@ -448,14 +445,13 @@ def test_users_get_list_jsonb_like_lookup_case_1(e_loop: AbstractEventLoop, user @pytest.mark.parametrize("data", USERS_JSONB_LIKE_LOOKUP, scope="function") -def test_users_get_list_jsonb_like_lookup_case_2(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_jsonb_like_lookup_case_2(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"meta__{field}__jsonb_like" for value in data["value"]: - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: value}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: value}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -475,14 +471,13 @@ def test_users_get_list_jsonb_like_lookup_case_2(e_loop: AbstractEventLoop, user @pytest.mark.parametrize("data", USERS_JSONB_NOT_LIKE_LOOKUP, scope="function") -def test_users_get_list_jsonb_not_like_lookup_case_1(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_jsonb_not_like_lookup_case_1(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"{field}__jsonb_not_like" for value in data["value"]: - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: value}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: value}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -495,14 +490,13 @@ def test_users_get_list_jsonb_not_like_lookup_case_1(e_loop: AbstractEventLoop, @pytest.mark.parametrize("data", USERS_JSONB_NOT_LIKE_LOOKUP, scope="function") -def test_users_get_list_jsonb_not_like_lookup_case_2(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: - users_repository = repo_container.users_repository +def test_get_list_jsonb_not_like_lookup_case_2(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: field = data["key"] lookup = f"meta__{field}__jsonb_not_like" for value in data["value"]: - items: List[Type[UserTestAggregate]] = e_loop.run_until_complete( - users_repository.get_list(filter_data={lookup: value}, out_dataclass=UserTestAggregate) + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: value}, out_dataclass=out_dataclass) ) assert isinstance(items, list) is True @@ -511,3 +505,77 @@ def test_users_get_list_jsonb_not_like_lookup_case_2(e_loop: AbstractEventLoop, data = getattr(user, "meta", {}) data_value = data.get(field) assert str(value) not in str(data_value) + + +USERS_ILIKE_LOOKUP = [ + {"key": "first_name", "value": "FIRST_NAME_1"}, + {"key": "first_name", "value": "first_name_2"}, + {"key": "email", "value": "GMAIL"}, + {"key": "email", "value": "1"}, + {"key": "last_name", "value": "LAST_NAME_3"}, + {"key": "email", "value": "%gmail%"}, + {"key": "first_name", "value": "first%"}, + {"key": "last_name", "value": "%name%"}, + {"key": "email", "value": "%.com"}, +] + + +@pytest.mark.parametrize("data", USERS_ILIKE_LOOKUP, scope="function") +def test_get_list_ilike_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: + field = data["key"] + lookup = f"{field}__ilike" + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) + ) + + assert isinstance(items, list) is True + + for index, user in enumerate(items): + value = str(getattr(user, field)).lower() + pattern = str(data["value"]).lower() + if "%" in pattern: + if pattern.startswith("%") and pattern.endswith("%"): + assert pattern[1:-1] in value + elif pattern.startswith("%"): + assert value.endswith(pattern[1:]) + elif pattern.endswith("%"): + assert value.startswith(pattern[:-1]) + else: + assert pattern in value + + +USERS_EMPTY_FIELDS_LOOKUP = [ + {"key": "first_name", "value": ""}, + {"key": "email", "value": ""}, + {"key": "last_name", "value": ""}, +] + + +@pytest.mark.parametrize("data", USERS_EMPTY_FIELDS_LOOKUP, scope="function") +def test_get_list_empty_string_lookup(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: + field = data["key"] + lookup = f"{field}__e" + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data={lookup: data["value"]}, out_dataclass=out_dataclass) + ) + + assert isinstance(items, list) is True + assert len(items) == 0 + + +USERS_COMBINED_FILTERS_LOOKUP = [ + {"filters": {"first_name__like": "first_name", "id__gt": 1}, "expected_min_count": 1}, + {"filters": {"email__ilike": "GMAIL", "id__lte": 3}, "expected_min_count": 1}, + {"filters": {"first_name__e": USERS[0]["first_name"], "id__in": [1, 2, 3]}, "expected_min_count": 1}, +] + + +@pytest.mark.parametrize("data", USERS_COMBINED_FILTERS_LOOKUP, scope="function") +def test_get_list_combined_filters(e_loop: AbstractEventLoop, users: Any, data: dict) -> None: + + items: List[Type[out_dataclass]] = e_loop.run_until_complete( + repository.get_list(filter_data=data["filters"], out_dataclass=out_dataclass) + ) + + assert isinstance(items, list) is True + assert len(items) >= data["expected_min_count"] diff --git a/tests/infrastructure/repositories/test_users_repository.py b/tests/infrastructure/repositories/test_users_repository.py index 91e55bd..fe29453 100644 --- a/tests/infrastructure/repositories/test_users_repository.py +++ b/tests/infrastructure/repositories/test_users_repository.py @@ -6,7 +6,7 @@ import pytest - +from src.app.infrastructure.repositories.base.abstract import RepositoryError from src.app.infrastructure.utils.common import generate_str from src.app.infrastructure.repositories.container import container as repo_container from tests.domain.users.aggregates.common import UserTestAggregate @@ -886,11 +886,10 @@ def test_get_list_with_large_limit(e_loop: AbstractEventLoop, users: Any) -> Non def test_get_list_with_zero_limit(e_loop: AbstractEventLoop, users: Any) -> None: users_repository = repo_container.users_repository - items = e_loop.run_until_complete( - users_repository.get_list(filter_data={"limit": 0}, out_dataclass=UserTestAggregate) - ) - - assert len(items) == 0 + with pytest.raises(RepositoryError): + e_loop.run_until_complete( + users_repository.get_list(filter_data={"limit": 0}, out_dataclass=UserTestAggregate) + ) def test_get_list_with_large_offset(e_loop: AbstractEventLoop, users: Any) -> None: