From d3eaedb25c9f9f0e03c94f05dc643588cfd0629b Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 23 Dec 2025 10:47:14 +0000 Subject: [PATCH] refactor: improve code quality, add architecture docs and tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major refactoring: - Split main.py (1349→307 lines) into modular route handlers - Create app/dependencies.py with shared state and auth helpers - Create app/routes/proxy.py for HTTP/LLM proxy endpoints - Create app/routes/rapidapi.py for RapidAPI integration - Update app/routes/health.py with all health check endpoints Improved typing: - Add Pydantic enums (HTTPMethod, MessageRole, CostPolicy) - Add typed response models (TokenUsage, LLMResponseData) - Add field validators with constraints (ge, le, min_length) - Use Literal types for success/error flags Code quality fixes: - Replace print() with logger.warning() in analytics.py - Remove unused imports - Improve import organization Documentation: - Add docs/ARCHITECTURE.md with system overview - Document request flow, design decisions, deployment Testing: - Add tests/test_routes_business.py for business routes - Cover onboarding, analytics, health, and schema validation --- app/dependencies.py | 477 +++++++++++ app/main.py | 1465 +++++---------------------------- app/routes/__init__.py | 19 +- app/routes/analytics.py | 26 +- app/routes/health.py | 99 ++- app/routes/proxy.py | 579 +++++++++++++ app/routes/rapidapi.py | 310 +++++++ app/schemas.py | 312 +++++-- docs/ARCHITECTURE.md | 368 +++++++++ tests/test_routes_business.py | 517 ++++++++++++ 10 files changed, 2840 insertions(+), 1332 deletions(-) create mode 100644 app/dependencies.py create mode 100644 app/routes/proxy.py create mode 100644 app/routes/rapidapi.py create mode 100644 docs/ARCHITECTURE.md create mode 100644 tests/test_routes_business.py diff --git a/app/dependencies.py b/app/dependencies.py new file mode 100644 index 0000000..a8ff1fe --- /dev/null +++ b/app/dependencies.py @@ -0,0 +1,477 @@ +"""Shared dependencies and utilities for ReliAPI FastAPI application. + +This module contains: +- Global state management (config, cache, rate limiter, etc.) +- Authentication and authorization helpers +- Client profile detection +- Configuration initialization +""" +import hashlib +import logging +import os +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +from fastapi import HTTPException, Request + +from reliapi.config.loader import ConfigLoader +from reliapi.core.cache import Cache +from reliapi.core.client_profile import ClientProfile, ClientProfileManager +from reliapi.core.errors import ErrorCode +from reliapi.core.idempotency import IdempotencyManager +from reliapi.core.key_pool import KeyPoolManager, ProviderKey +from reliapi.core.rate_limiter import RateLimiter +from reliapi.core.rate_scheduler import RateScheduler +from reliapi.integrations.rapidapi import RapidAPIClient +from reliapi.integrations.rapidapi_tenant import RapidAPITenantManager +from reliapi.metrics.prometheus import rapidapi_tier_cache_total + +logger = logging.getLogger(__name__) + + +class ConfigValidationError(Exception): + """Raised when configuration validation fails.""" + pass + + +@dataclass +class AppState: + """Application state container for all shared components. + + This class encapsulates all global state for the ReliAPI application, + making it easier to manage, test, and inject dependencies. + """ + config_loader: Optional[ConfigLoader] = None + targets: Dict[str, Dict] = field(default_factory=dict) + cache: Optional[Cache] = None + idempotency: Optional[IdempotencyManager] = None + rate_limiter: Optional[RateLimiter] = None + key_pool_manager: Optional[KeyPoolManager] = None + rate_scheduler: Optional[RateScheduler] = None + client_profile_manager: Optional[ClientProfileManager] = None + rapidapi_client: Optional[RapidAPIClient] = None + rapidapi_tenant_manager: Optional[RapidAPITenantManager] = None + + +# Global application state instance +app_state = AppState() + + +def get_app_state() -> AppState: + """Get the global application state. + + Returns: + AppState: The global application state instance. + """ + return app_state + + +def verify_api_key(request: Request) -> Tuple[Optional[str], Optional[str], str]: + """Verify API key from header and resolve tenant and tier. + + Priority for tier detection: + 1. RapidAPI headers (X-RapidAPI-User, X-RapidAPI-Subscription) + 2. Redis cache (from previous RapidAPI detection) + 3. Config-based tenants + 4. Test key prefixes (sk-free, sk-dev, sk-pro) + 5. Default: 'free' + + Args: + request: FastAPI request object + + Returns: + Tuple of (api_key, tenant_name, tier). + tenant_name is None if multi-tenant not enabled or tenant not found. + tier is 'free', 'developer', 'pro', or 'enterprise'. + + Raises: + HTTPException: If API key is missing or invalid. + """ + state = get_app_state() + api_key = request.headers.get("X-API-Key") + headers_dict = dict(request.headers) + + def get_tier(api_key: str, headers: Dict[str, str]) -> str: + """Determine tier with RapidAPI priority.""" + # 1. Check RapidAPI headers first + if state.rapidapi_client: + result = state.rapidapi_client.get_tier_from_headers(headers) + if result: + user_id, tier_enum = result + rapidapi_tier_cache_total.labels(operation="hit").inc() + return tier_enum.value + + # 2. Use rate_limiter with RapidAPI client for cache lookup + if state.rate_limiter and api_key: + tier = state.rate_limiter.get_account_tier( + api_key, headers, state.rapidapi_client + ) + return tier + + return "free" + + # Multi-tenant mode: check tenants config + if (state.config_loader and + hasattr(state.config_loader, 'config') and + state.config_loader.config.tenants): + tenants = state.config_loader.config.tenants + + # Find tenant by API key + for tenant_name, tenant_config in tenants.items(): + if tenant_config.api_key == api_key: + request.state.tenant = tenant_name + tier = get_tier(api_key, headers_dict) + request.state.tier = tier + return api_key, tenant_name, tier + + # Tenant not found - check global API key (backward compatibility) + required_key = os.getenv("RELIAPI_API_KEY") + if required_key and api_key == required_key: + request.state.tenant = None + tier = get_tier(api_key, headers_dict) + request.state.tier = tier + return api_key, None, tier + + # No matching tenant and no global key match - allow with free tier + tier = get_tier(api_key, headers_dict) + request.state.tier = tier + request.state.tenant = None + return api_key, None, tier + + # Single-tenant mode: use global API key + if not api_key: + # Check if RapidAPI headers are present + if state.rapidapi_client: + result = state.rapidapi_client.get_tier_from_headers(headers_dict) + if result: + user_id, tier_enum = result + virtual_api_key = f"rapidapi:{user_id}" + + # Auto-create tenant for RapidAPI user + if state.rapidapi_tenant_manager: + tenant_name = state.rapidapi_tenant_manager.ensure_tenant_exists( + user_id, tier_enum + ) + request.state.tenant = tenant_name + else: + request.state.tenant = None + + request.state.tier = tier_enum.value + rapidapi_tier_cache_total.labels(operation="hit").inc() + return virtual_api_key, request.state.tenant, tier_enum.value + + raise HTTPException( + status_code=401, + detail={ + "success": False, + "error": { + "type": "client_error", + "code": ErrorCode.UNAUTHORIZED.value, + "message": "Missing X-API-Key header", + "retryable": False, + "target": None, + "status_code": 401, + }, + }, + ) + + tier = get_tier(api_key, headers_dict) + + # For testing: allow keys starting with sk-free/sk-dev/sk-pro + if api_key and ( + api_key.startswith("sk-free") or + api_key.startswith("sk-dev") or + api_key.startswith("sk-pro") + ): + request.state.tenant = None + request.state.tier = tier + return api_key, None, tier + + # Check against required key + required_key = os.getenv("RELIAPI_API_KEY") + if not required_key: + request.state.tenant = None + request.state.tier = tier + return api_key, None, tier + + if api_key != required_key: + raise HTTPException( + status_code=401, + detail={ + "success": False, + "error": { + "type": "client_error", + "code": ErrorCode.UNAUTHORIZED.value, + "message": "Invalid API key", + "retryable": False, + "target": None, + "status_code": 401, + }, + }, + ) + + request.state.tenant = None + request.state.tier = tier + return api_key, None, tier + + +def detect_client_profile( + request: Request, + tenant: Optional[str] = None +) -> Optional[str]: + """Detect client profile using priority: X-Client header > tenant.profile > default. + + Args: + request: FastAPI request + tenant: Tenant name (if known) + + Returns: + Profile name or None + """ + state = get_app_state() + + # Priority 1: X-Client header + client_header = request.headers.get("X-Client") + if (client_header and + state.client_profile_manager and + state.client_profile_manager.has_profile(client_header)): + return client_header + + # Priority 2: tenant.profile + if tenant and state.config_loader: + tenant_config = state.config_loader.get_tenant(tenant) + if tenant_config and tenant_config.get("profile"): + profile_name = tenant_config.get("profile") + if (state.client_profile_manager and + state.client_profile_manager.has_profile(profile_name)): + return profile_name + + # Priority 3: default + return "default" + + +def get_account_id(api_key: Optional[str]) -> str: + """Generate account ID from API key hash. + + Args: + api_key: API key string or None + + Returns: + Hashed account ID (16 chars) or 'unknown' + """ + if not api_key: + return "unknown" + return hashlib.sha256(api_key.encode()).hexdigest()[:16] + + +def validate_startup_config( + config_loader: ConfigLoader, + strict: bool = True +) -> List[str]: + """Validate configuration at startup. + + Args: + config_loader: Configuration loader + strict: If True, fail on missing required env vars + + Returns: + List of validation warnings (non-fatal issues) + + Raises: + ConfigValidationError: If configuration is invalid + """ + errors: List[str] = [] + warnings: List[str] = [] + + # 1. Validate required environment variables + redis_url = os.getenv("REDIS_URL") + if not redis_url: + warnings.append("REDIS_URL not set - Redis features will be disabled") + + rapidapi_key = os.getenv("RAPIDAPI_API_KEY") + if not rapidapi_key: + warnings.append("RAPIDAPI_API_KEY not set - RapidAPI tier detection may be limited") + + # 2. Validate key pool configuration + pools_config = config_loader.get_provider_key_pools() + if pools_config: + seen_key_ids: Dict[str, str] = {} + + for provider, pool_config in pools_config.items(): + keys_config = pool_config.get("keys", []) + + if not keys_config: + warnings.append(f"Key pool for provider '{provider}' has no keys configured") + continue + + for key_config in keys_config: + key_id = key_config.get("id") + api_key_str = key_config.get("api_key", "") + qps_limit = key_config.get("qps_limit") + rate_limit = key_config.get("rate_limit", {}) + + if not key_id: + errors.append(f"Key in provider '{provider}' is missing 'id' field") + continue + + full_key_id = f"{provider}:{key_id}" + if full_key_id in seen_key_ids: + errors.append(f"Duplicate key ID '{key_id}' in provider '{provider}'") + else: + seen_key_ids[full_key_id] = provider + + if not api_key_str: + errors.append( + f"Key '{key_id}' in provider '{provider}' is missing 'api_key' field" + ) + continue + + if api_key_str.startswith("env:"): + env_var = api_key_str[4:] + if strict and not os.getenv(env_var): + errors.append( + f"Environment variable '{env_var}' not set for key " + f"'{key_id}' in provider '{provider}'" + ) + + effective_qps = rate_limit.get("max_qps") or qps_limit + if effective_qps is not None and effective_qps <= 0: + errors.append( + f"Key '{key_id}' in provider '{provider}' has invalid " + f"QPS limit: {effective_qps} (must be > 0)" + ) + + # 3. Validate client profiles configuration + profiles_config = config_loader.get_client_profiles() + if profiles_config: + for profile_name, profile_config in profiles_config.items(): + max_parallel = profile_config.get("max_parallel_requests") + if max_parallel is not None and max_parallel <= 0: + errors.append( + f"Client profile '{profile_name}' has invalid " + f"max_parallel_requests: {max_parallel} (must be > 0)" + ) + + timeout = profile_config.get("default_timeout_s") + if timeout is not None and timeout <= 0: + errors.append( + f"Client profile '{profile_name}' has invalid " + f"default_timeout_s: {timeout} (must be > 0)" + ) + + max_qps_tenant = profile_config.get("max_qps_per_tenant") + if max_qps_tenant is not None and max_qps_tenant <= 0: + errors.append( + f"Client profile '{profile_name}' has invalid " + f"max_qps_per_tenant: {max_qps_tenant} (must be > 0)" + ) + + max_qps_key = profile_config.get("max_qps_per_provider_key") + if max_qps_key is not None and max_qps_key <= 0: + errors.append( + f"Client profile '{profile_name}' has invalid " + f"max_qps_per_provider_key: {max_qps_key} (must be > 0)" + ) + + # Log warnings + for warning in warnings: + logger.warning(f"Configuration warning: {warning}") + + # Fail fast on errors + if errors: + for error in errors: + logger.error(f"Configuration error: {error}") + raise ConfigValidationError( + f"Configuration validation failed with {len(errors)} error(s): " + f"{'; '.join(errors)}" + ) + + return warnings + + +def init_client_profile_manager(config_loader: ConfigLoader) -> ClientProfileManager: + """Initialize ClientProfileManager from configuration. + + Args: + config_loader: Configuration loader + + Returns: + Initialized ClientProfileManager + """ + profiles_config = config_loader.get_client_profiles() + if not profiles_config: + return ClientProfileManager() + + profiles: Dict[str, ClientProfile] = {} + + for profile_name, profile_config in profiles_config.items(): + profile = ClientProfile( + max_parallel_requests=profile_config.get("max_parallel_requests", 10), + max_qps_per_tenant=profile_config.get("max_qps_per_tenant"), + max_qps_per_provider_key=profile_config.get("max_qps_per_provider_key"), + burst_size=profile_config.get("burst_size", 5), + default_timeout_s=profile_config.get("default_timeout_s"), + ) + profiles[profile_name] = profile + logger.info(f"Initialized client profile: {profile_name}") + + return ClientProfileManager(profiles) + + +def init_key_pool_manager(config_loader: ConfigLoader) -> Optional[KeyPoolManager]: + """Initialize KeyPoolManager from configuration. + + Args: + config_loader: Configuration loader + + Returns: + Initialized KeyPoolManager or None if not configured + """ + pools_config = config_loader.get_provider_key_pools() + if not pools_config: + return None + + pools: Dict[str, List[ProviderKey]] = {} + + for provider, pool_config in pools_config.items(): + keys = [] + for key_config in pool_config.get("keys", []): + key_id = key_config.get("id") + api_key_str = key_config.get("api_key", "") + qps_limit = key_config.get("qps_limit") + + if not key_id: + continue + + # Resolve API key from env if needed + if api_key_str.startswith("env:"): + env_var = api_key_str[4:] + api_key = os.getenv(env_var) + if not api_key: + logger.debug(f"Skipping key {key_id}: env var {env_var} not set") + continue + else: + api_key = api_key_str + + # Get rate limit config if present + rate_limit_config = key_config.get("rate_limit", {}) + if rate_limit_config: + qps_limit = rate_limit_config.get("max_qps") or qps_limit + if qps_limit: + qps_limit = int(qps_limit) + + key = ProviderKey( + id=key_id, + provider=provider, + key=api_key, + qps_limit=qps_limit, + ) + keys.append(key) + + if keys: + pools[provider] = keys + logger.info(f"Initialized key pool for {provider} with {len(keys)} keys") + + if pools: + return KeyPoolManager(pools) + return None diff --git a/app/main.py b/app/main.py index d1a1503..d2410b4 100644 --- a/app/main.py +++ b/app/main.py @@ -1,1348 +1,305 @@ -"""ReliAPI FastAPI application - minimal reliability layer.""" -import json +"""ReliAPI FastAPI application - minimal reliability layer. + +This is the main application module that: +- Initializes the FastAPI application +- Configures middleware (CORS, exception handling) +- Registers all route handlers +- Manages application lifespan (startup/shutdown) +""" import logging import os -import uuid +import traceback from contextlib import asynccontextmanager -from typing import Dict, List, Optional -import hashlib +from typing import List -from fastapi import FastAPI, HTTPException, Request, status +from fastapi import FastAPI, Request, status from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, Response, StreamingResponse -from prometheus_client import CONTENT_TYPE_LATEST, generate_latest - -from reliapi.app.schemas import HTTPProxyRequest, LLMProxyRequest -from reliapi.app.services import handle_http_proxy, handle_llm_proxy, handle_llm_stream_generator +from fastapi.responses import JSONResponse + +from reliapi.app.dependencies import ( + ConfigValidationError, + get_app_state, + init_client_profile_manager, + init_key_pool_manager, + validate_startup_config, +) from reliapi.config.loader import ConfigLoader from reliapi.core.cache import Cache -from reliapi.core.errors import ErrorCode from reliapi.core.idempotency import IdempotencyManager -from reliapi.core.client_profile import ClientProfileManager, ClientProfile -from reliapi.core.key_pool import KeyPoolManager, ProviderKey from reliapi.core.rate_limiter import RateLimiter from reliapi.core.rate_scheduler import RateScheduler -from reliapi.core.free_tier_restrictions import FreeTierRestrictions -from reliapi.core.security import SecurityManager -from reliapi.integrations.rapidapi import RapidAPIClient, SubscriptionTier +from reliapi.integrations.rapidapi import RapidAPIClient from reliapi.integrations.rapidapi_tenant import RapidAPITenantManager -from reliapi.integrations.routellm import ( - RouteLLMDecision, - extract_routellm_decision, - apply_routellm_overrides, - routellm_metrics, -) -from reliapi.metrics.prometheus import ( - rapidapi_webhook_events_total, - rapidapi_tier_cache_total, - rapidapi_tier_distribution, - routellm_decisions_total, - routellm_overrides_total, - free_tier_abuse_attempts_total, -) # Configure structured JSON logging log_level = os.getenv("LOG_LEVEL", "INFO").upper() logging.basicConfig( level=getattr(logging, log_level, logging.INFO), - format="%(message)s", # JSON logs are already formatted + format="%(message)s", ) logger = logging.getLogger(__name__) -# Global state -config_loader: ConfigLoader = None -targets: Dict[str, Dict] = {} -cache: Cache = None -idempotency: IdempotencyManager = None -rate_limiter: RateLimiter = None -key_pool_manager: KeyPoolManager = None -rate_scheduler: RateScheduler = None -client_profile_manager: ClientProfileManager = None -rapidapi_client: RapidAPIClient = None -rapidapi_tenant_manager: RapidAPITenantManager = None - - -def verify_api_key(request: Request) -> tuple[Optional[str], Optional[str], str]: - """Verify API key from header and resolve tenant and tier. - - Priority for tier detection: - 1. RapidAPI headers (X-RapidAPI-User, X-RapidAPI-Subscription) - 2. Redis cache (from previous RapidAPI detection) - 3. Config-based tenants - 4. Test key prefixes (sk-free, sk-dev, sk-pro) - 5. Default: 'free' - - Returns: - Tuple of (api_key, tenant_name, tier). - tenant_name is None if multi-tenant not enabled or tenant not found. - tier is 'free', 'developer', 'pro', or 'enterprise'. - """ - api_key = request.headers.get("X-API-Key") - - # Get headers dict for RapidAPI detection - headers_dict = dict(request.headers) - - # Helper to determine tier with RapidAPI priority - def get_tier(api_key: str, headers: Dict[str, str]) -> str: - # 1. Check RapidAPI headers first - if rapidapi_client: - result = rapidapi_client.get_tier_from_headers(headers) - if result: - user_id, tier_enum = result - rapidapi_tier_cache_total.labels(operation="hit").inc() - return tier_enum.value - - # 2. Use rate_limiter with RapidAPI client for cache lookup - if rate_limiter and api_key: - tier = rate_limiter.get_account_tier(api_key, headers, rapidapi_client) - return tier - - return "free" - - # Multi-tenant mode: check tenants config - if config_loader and hasattr(config_loader, 'config') and config_loader.config.tenants: - tenants = config_loader.config.tenants - # Find tenant by API key - for tenant_name, tenant_config in tenants.items(): - if tenant_config.api_key == api_key: - # Store tenant in request state for later use - request.state.tenant = tenant_name - tier = get_tier(api_key, headers_dict) - request.state.tier = tier - return api_key, tenant_name, tier - - # Tenant not found - check if global API key is set (backward compatibility) - required_key = os.getenv("RELIAPI_API_KEY") - if required_key and api_key == required_key: - request.state.tenant = None # No tenant, use default - tier = get_tier(api_key, headers_dict) - request.state.tier = tier - return api_key, None, tier - - # No matching tenant and no global key match - allow with free tier - tier = get_tier(api_key, headers_dict) - request.state.tier = tier - request.state.tenant = None - return api_key, None, tier - - # Single-tenant mode: use global API key - if not api_key: - # Check if RapidAPI headers are present (RapidAPI handles auth) - if rapidapi_client: - result = rapidapi_client.get_tier_from_headers(headers_dict) - if result: - # RapidAPI user - create virtual API key from user ID - user_id, tier_enum = result - virtual_api_key = f"rapidapi:{user_id}" - - # Auto-create tenant for RapidAPI user - if rapidapi_tenant_manager: - tenant_name = rapidapi_tenant_manager.ensure_tenant_exists(user_id, tier_enum) - request.state.tenant = tenant_name - else: - request.state.tenant = None - - request.state.tier = tier_enum.value - rapidapi_tier_cache_total.labels(operation="hit").inc() - return virtual_api_key, request.state.tenant, tier_enum.value - - raise HTTPException( - status_code=401, - detail={ - "success": False, - "error": { - "type": "client_error", - "code": ErrorCode.UNAUTHORIZED.value, - "message": "Missing X-API-Key header", - "retryable": False, - "target": None, - "status_code": 401, - }, - }, - ) - - # Determine tier - tier = get_tier(api_key, headers_dict) - - # For testing: allow keys starting with sk-free/sk-dev/sk-pro to pass - if api_key and (api_key.startswith("sk-free") or api_key.startswith("sk-dev") or api_key.startswith("sk-pro")): - request.state.tenant = None - request.state.tier = tier - return api_key, None, tier - - # Check against required key - required_key = os.getenv("RELIAPI_API_KEY") - if not required_key: - # No auth required if env var not set - request.state.tenant = None - request.state.tier = tier - return api_key, None, tier - - if api_key != required_key: - raise HTTPException( - status_code=401, - detail={ - "success": False, - "error": { - "type": "client_error", - "code": ErrorCode.UNAUTHORIZED.value, - "message": "Invalid API key", - "retryable": False, - "target": None, - "status_code": 401, - }, - }, - ) - - request.state.tenant = None # No tenant in single-tenant mode - request.state.tier = tier - return api_key, None, tier - - -def detect_client_profile(request: Request, tenant: Optional[str] = None) -> Optional[str]: - """Detect client profile using priority: X-Client header > tenant.profile > default. - - Args: - request: FastAPI request - tenant: Tenant name (if known) - - Returns: - Profile name or None - """ - # Priority 1: X-Client header - client_header = request.headers.get("X-Client") - if client_header and client_profile_manager and client_profile_manager.has_profile(client_header): - return client_header - - # Priority 2: tenant.profile - if tenant and config_loader: - tenant_config = config_loader.get_tenant(tenant) - if tenant_config and tenant_config.get("profile"): - profile_name = tenant_config.get("profile") - if client_profile_manager and client_profile_manager.has_profile(profile_name): - return profile_name - - # Priority 3: default - return "default" - - -def _init_client_profile_manager(config_loader: ConfigLoader) -> ClientProfileManager: - """Initialize ClientProfileManager from configuration.""" - profiles_config = config_loader.get_client_profiles() - if not profiles_config: - return ClientProfileManager() - - profiles: Dict[str, ClientProfile] = {} - - for profile_name, profile_config in profiles_config.items(): - profile = ClientProfile( - max_parallel_requests=profile_config.get("max_parallel_requests", 10), - max_qps_per_tenant=profile_config.get("max_qps_per_tenant"), - max_qps_per_provider_key=profile_config.get("max_qps_per_provider_key"), - burst_size=profile_config.get("burst_size", 5), - default_timeout_s=profile_config.get("default_timeout_s"), - ) - profiles[profile_name] = profile - logger.info(f"Initialized client profile: {profile_name}") - - return ClientProfileManager(profiles) - - -class ConfigValidationError(Exception): - """Raised when configuration validation fails.""" - pass - - -def _validate_startup_config(config_loader: ConfigLoader, strict: bool = True) -> List[str]: - """Validate configuration at startup. - - Args: - config_loader: Configuration loader - strict: If True, fail on missing required env vars - - Returns: - List of validation warnings (non-fatal issues) - - Raises: - ConfigValidationError: If configuration is invalid - """ - errors: List[str] = [] - warnings: List[str] = [] - - # 1. Validate required environment variables - redis_url = os.getenv("REDIS_URL") - if not redis_url: - warnings.append("REDIS_URL not set - Redis features will be disabled") - - # RapidAPI API key is optional but recommended - rapidapi_key = os.getenv("RAPIDAPI_API_KEY") - if not rapidapi_key: - warnings.append("RAPIDAPI_API_KEY not set - RapidAPI tier detection may be limited") - - # 2. Validate key pool configuration - pools_config = config_loader.get_provider_key_pools() - if pools_config: - seen_key_ids: Dict[str, str] = {} # key_id -> provider (for uniqueness check) - - for provider, pool_config in pools_config.items(): - keys_config = pool_config.get("keys", []) - - if not keys_config: - warnings.append(f"Key pool for provider '{provider}' has no keys configured") - continue - - for key_config in keys_config: - key_id = key_config.get("id") - api_key_str = key_config.get("api_key", "") - qps_limit = key_config.get("qps_limit") - rate_limit = key_config.get("rate_limit", {}) - - # Check key ID is present - if not key_id: - errors.append(f"Key in provider '{provider}' is missing 'id' field") - continue - - # Check key ID uniqueness (within provider) - full_key_id = f"{provider}:{key_id}" - if full_key_id in seen_key_ids: - errors.append(f"Duplicate key ID '{key_id}' in provider '{provider}'") - else: - seen_key_ids[full_key_id] = provider - - # Check API key is present - if not api_key_str: - errors.append(f"Key '{key_id}' in provider '{provider}' is missing 'api_key' field") - continue - - # Check env var exists if using env:VAR_NAME format - if api_key_str.startswith("env:"): - env_var = api_key_str[4:] - if strict and not os.getenv(env_var): - errors.append(f"Environment variable '{env_var}' not set for key '{key_id}' in provider '{provider}'") - - # Check QPS limit is positive - effective_qps = rate_limit.get("max_qps") or qps_limit - if effective_qps is not None and effective_qps <= 0: - errors.append(f"Key '{key_id}' in provider '{provider}' has invalid QPS limit: {effective_qps} (must be > 0)") - - # 3. Validate client profiles configuration - profiles_config = config_loader.get_client_profiles() - if profiles_config: - for profile_name, profile_config in profiles_config.items(): - max_parallel = profile_config.get("max_parallel_requests") - if max_parallel is not None and max_parallel <= 0: - errors.append(f"Client profile '{profile_name}' has invalid max_parallel_requests: {max_parallel} (must be > 0)") - - timeout = profile_config.get("default_timeout_s") - if timeout is not None and timeout <= 0: - errors.append(f"Client profile '{profile_name}' has invalid default_timeout_s: {timeout} (must be > 0)") - - max_qps_tenant = profile_config.get("max_qps_per_tenant") - if max_qps_tenant is not None and max_qps_tenant <= 0: - errors.append(f"Client profile '{profile_name}' has invalid max_qps_per_tenant: {max_qps_tenant} (must be > 0)") - - max_qps_key = profile_config.get("max_qps_per_provider_key") - if max_qps_key is not None and max_qps_key <= 0: - errors.append(f"Client profile '{profile_name}' has invalid max_qps_per_provider_key: {max_qps_key} (must be > 0)") - - # Log warnings - for warning in warnings: - logger.warning(f"Configuration warning: {warning}") - - # Fail fast on errors - if errors: - for error in errors: - logger.error(f"Configuration error: {error}") - raise ConfigValidationError(f"Configuration validation failed with {len(errors)} error(s): {'; '.join(errors)}") - - return warnings - - -def _init_key_pool_manager(config_loader: ConfigLoader) -> Optional[KeyPoolManager]: - """Initialize KeyPoolManager from configuration.""" - pools_config = config_loader.get_provider_key_pools() - if not pools_config: - return None - - pools: Dict[str, List[ProviderKey]] = {} - - for provider, pool_config in pools_config.items(): - keys = [] - for key_config in pool_config.get("keys", []): - key_id = key_config.get("id") - api_key_str = key_config.get("api_key", "") - qps_limit = key_config.get("qps_limit") - - if not key_id: - continue # Already validated, skip - - # Resolve API key from env if needed - if api_key_str.startswith("env:"): - env_var = api_key_str[4:] - api_key = os.getenv(env_var) - if not api_key: - # Already validated in strict mode, skip silently - logger.debug(f"Skipping key {key_id}: env var {env_var} not set") - continue - else: - api_key = api_key_str - - # Get rate limit config if present - rate_limit_config = key_config.get("rate_limit", {}) - if rate_limit_config: - # Use rate_limit.max_qps if present, otherwise fallback to qps_limit - qps_limit = rate_limit_config.get("max_qps") or qps_limit - if qps_limit: - qps_limit = int(qps_limit) - - key = ProviderKey( - id=key_id, - provider=provider, - key=api_key, - qps_limit=qps_limit, - ) - keys.append(key) - - if keys: - pools[provider] = keys - logger.info(f"Initialized key pool for {provider} with {len(keys)} keys") - - if pools: - return KeyPoolManager(pools) - return None - @asynccontextmanager async def lifespan(app: FastAPI): - """Application lifespan manager.""" - global config_loader, targets, cache, idempotency, rate_limiter, key_pool_manager, rate_scheduler, client_profile_manager, rapidapi_client, rapidapi_tenant_manager - + """Application lifespan manager. + + Handles startup initialization and shutdown cleanup for: + - Configuration loading and validation + - Redis connections (cache, idempotency, rate limiting) + - RapidAPI integration + - Key pool and rate scheduler + - Client profile management + """ + state = get_app_state() + # Startup - config_path = os.getenv("RELIAPI_CONFIG_PATH", os.getenv("RELIAPI_CONFIG", "config.yaml")) + config_path = os.getenv( + "RELIAPI_CONFIG_PATH", os.getenv("RELIAPI_CONFIG", "config.yaml") + ) redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") - + logger.info(f"Loading configuration from {config_path}") - config_loader = ConfigLoader(config_path) - config_loader.load() - targets = config_loader.get_targets() - + state.config_loader = ConfigLoader(config_path) + state.config_loader.load() + state.targets = state.config_loader.get_targets() + # Validate configuration (fail fast on invalid config) - # Set strict=False for development to allow missing env vars with warnings strict_validation = os.getenv("RELIAPI_STRICT_CONFIG", "true").lower() == "true" try: - validation_warnings = _validate_startup_config(config_loader, strict=strict_validation) + validation_warnings = validate_startup_config( + state.config_loader, strict=strict_validation + ) if validation_warnings: - logger.info(f"Configuration loaded with {len(validation_warnings)} warning(s)") + logger.info( + f"Configuration loaded with {len(validation_warnings)} warning(s)" + ) except ConfigValidationError as e: logger.critical(f"Configuration validation failed: {e}") raise - + logger.info(f"Initializing Redis connection: {redis_url}") - cache = Cache(redis_url, key_prefix="reliapi") - idempotency = IdempotencyManager(redis_url, key_prefix="reliapi") - - # Initialize rate limiter - rate_limiter = RateLimiter(redis_url, key_prefix="reliapi") - + state.cache = Cache(redis_url, key_prefix="reliapi") + state.idempotency = IdempotencyManager(redis_url, key_prefix="reliapi") + state.rate_limiter = RateLimiter(redis_url, key_prefix="reliapi") + # Initialize RapidAPI client - rapidapi_client = RapidAPIClient( + state.rapidapi_client = RapidAPIClient( redis_url=redis_url, key_prefix="reliapi", ) logger.info("RapidAPI client initialized") - + # Initialize RapidAPI tenant manager - if cache and cache.client: - rapidapi_tenant_manager = RapidAPITenantManager( - redis_client=cache.client, + if state.cache and state.cache.client: + state.rapidapi_tenant_manager = RapidAPITenantManager( + redis_client=state.cache.client, key_prefix="reliapi", ) logger.info("RapidAPI tenant manager initialized") - + # Initialize key pool manager - key_pool_manager = _init_key_pool_manager(config_loader) - if key_pool_manager: + state.key_pool_manager = init_key_pool_manager(state.config_loader) + if state.key_pool_manager: logger.info("Key pool manager initialized") else: logger.info("No key pools configured, using targets.auth") - + # Initialize rate scheduler with memory management - rate_scheduler = RateScheduler( + state.rate_scheduler = RateScheduler( max_buckets=1000, - bucket_ttl_seconds=3600, # 1 hour TTL - cleanup_interval_seconds=300, # 5 minute cleanup interval + bucket_ttl_seconds=3600, + cleanup_interval_seconds=300, ) - await rate_scheduler.start_cleanup_task() + await state.rate_scheduler.start_cleanup_task() logger.info("Rate scheduler initialized with memory management") - + # Initialize client profile manager - client_profile_manager = _init_client_profile_manager(config_loader) + state.client_profile_manager = init_client_profile_manager(state.config_loader) logger.info("Client profile manager initialized") - - logger.info(f"ReliAPI started with {len(targets)} targets") - + + logger.info(f"ReliAPI started with {len(state.targets)} targets") + yield - + # Shutdown logger.info("Shutting down ReliAPI...") - if rate_scheduler: - await rate_scheduler.stop_cleanup_task() - if rapidapi_client: - await rapidapi_client.close() + if state.rate_scheduler: + await state.rate_scheduler.stop_cleanup_task() + if state.rapidapi_client: + await state.rapidapi_client.close() -app = FastAPI( - title="ReliAPI", - version="1.0.7", - description="ReliAPI is a small LLM reliability layer for HTTP and LLM calls: retries, circuit breaker, cache, idempotency, and budget caps. Idempotent LLM proxy with predictable AI costs. Self-hosted AI gateway focused on reliability, not features.", - lifespan=lifespan, -) +def create_app() -> FastAPI: + """Create and configure the FastAPI application. -# CORS middleware with production security -cors_origins_env = os.getenv("CORS_ORIGINS", "*") -is_production = os.getenv("ENVIRONMENT", "").lower() == "production" + Returns: + Configured FastAPI application instance. + """ + app = FastAPI( + title="ReliAPI", + version="1.0.7", + description=( + "ReliAPI is a small LLM reliability layer for HTTP and LLM calls: " + "retries, circuit breaker, cache, idempotency, and budget caps. " + "Idempotent LLM proxy with predictable AI costs. " + "Self-hosted AI gateway focused on reliability, not features." + ), + lifespan=lifespan, + ) -if cors_origins_env == "*": - if is_production: - # In production, warn if CORS_ORIGINS is "*" (security risk) - logger.warning( - "SECURITY WARNING: CORS_ORIGINS is set to '*' in production. " - "This allows requests from any origin. Consider restricting to specific origins." - ) - cors_origins = ["*"] -else: - cors_origins = [origin.strip() for origin in cors_origins_env.split(",")] - # Validate origins (basic security check) - validated_origins = [] - for origin in cors_origins: - if not origin: - continue - # Basic validation: must start with http:// or https:// - if origin != "*" and not (origin.startswith("http://") or origin.startswith("https://")): - logger.warning(f"Invalid CORS origin format (skipping): {origin}") - continue - # In production, don't allow wildcard subdomains without explicit configuration - if is_production and origin == "*": - logger.warning("Wildcard CORS origin '*' not recommended in production") - validated_origins.append(origin) - cors_origins = validated_origins + # Configure CORS middleware + _configure_cors(app) -# Log CORS configuration -if is_production: - logger.info(f"CORS configured for production with {len(cors_origins)} allowed origin(s)") - if len(cors_origins) > 10: - logger.warning(f"Large number of CORS origins ({len(cors_origins)}), consider consolidating") + # Register exception handlers + _register_exception_handlers(app) -app.add_middleware( - CORSMiddleware, - allow_origins=cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - expose_headers=["X-Request-ID", "X-RateLimit-Remaining", "X-RateLimit-Reset"], -) + # Register routes + _register_routes(app) + return app -@app.exception_handler(Exception) -async def global_exception_handler(request: Request, exc: Exception): - """Global exception handler.""" - import traceback - error_details = traceback.format_exc() - logger.error(f"Unhandled exception: {type(exc).__name__}: {str(exc)}\n{error_details}", exc_info=True) - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={ - "success": False, - "error": { - "type": "internal_error", - "code": "INTERNAL_ERROR", - "message": f"Internal server error: {str(exc)}", - "retryable": True, - "target": None, - "status_code": 500, - }, - "meta": { - "target": None, - "cache_hit": False, - "retries": 0, - "duration_ms": 0, - "request_id": request.headers.get("X-Request-ID", "unknown"), - "trace_id": request.headers.get("X-Trace-ID"), - }, - }, - ) +def _configure_cors(app: FastAPI) -> None: + """Configure CORS middleware with production security.""" + cors_origins_env = os.getenv("CORS_ORIGINS", "*") + is_production = os.getenv("ENVIRONMENT", "").lower() == "production" -@app.get("/healthz") -async def healthz(http_request: Request): - """Health check endpoint with optional rate limiting.""" - # Optional rate limiting for healthz (20 req/min per IP) - if rate_limiter: - client_ip = http_request.client.host if http_request.client else "unknown" - allowed, error = rate_limiter.check_ip_rate_limit(client_ip, limit_per_minute=20, prefix="healthz") - if not allowed: - # For healthz, return 429 but don't log as warning (expected for monitoring) - raise HTTPException( - status_code=429, - detail={ - "type": "rate_limit_error", - "code": error, - "message": "Rate limit exceeded for healthz endpoint.", - }, + if cors_origins_env == "*": + if is_production: + logger.warning( + "SECURITY WARNING: CORS_ORIGINS is set to '*' in production. " + "This allows requests from any origin. " + "Consider restricting to specific origins." ) - return {"status": "healthy"} - + cors_origins = ["*"] + else: + cors_origins = _validate_cors_origins(cors_origins_env, is_production) -@app.get("/readyz") -async def readyz(http_request: Request): - """Readiness check endpoint with optional rate limiting.""" - # Optional rate limiting for readyz (20 req/min per IP) - if rate_limiter: - client_ip = http_request.client.host if http_request.client else "unknown" - allowed, error = rate_limiter.check_ip_rate_limit(client_ip, limit_per_minute=20, prefix="readyz") - if not allowed: - # For readyz, return 429 but don't log as warning (expected for monitoring) - raise HTTPException( - status_code=429, - detail={ - "type": "rate_limit_error", - "code": error, - "message": "Rate limit exceeded for readyz endpoint.", - }, + if is_production: + logger.info( + f"CORS configured for production with {len(cors_origins)} allowed origin(s)" + ) + if len(cors_origins) > 10: + logger.warning( + f"Large number of CORS origins ({len(cors_origins)}), " + "consider consolidating" ) - return {"status": "ready"} + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["X-Request-ID", "X-RateLimit-Remaining", "X-RateLimit-Reset"], + ) -@app.get("/livez") -async def livez(http_request: Request): - """Liveness check endpoint with optional rate limiting.""" - # Optional rate limiting for livez (20 req/min per IP) - if rate_limiter: - client_ip = http_request.client.host if http_request.client else "unknown" - allowed, error = rate_limiter.check_ip_rate_limit(client_ip, limit_per_minute=20, prefix="livez") - if not allowed: - # For livez, return 429 but don't log as warning (expected for monitoring) - raise HTTPException( - status_code=429, - detail={ - "type": "rate_limit_error", - "code": error, - "message": "Rate limit exceeded for livez endpoint.", - }, - ) - return {"status": "alive"} +def _validate_cors_origins(cors_origins_env: str, is_production: bool) -> List[str]: + """Validate and filter CORS origins. -@app.get("/metrics") -async def metrics(http_request: Request): - """Prometheus metrics endpoint with rate limiting.""" - # Rate limiting for metrics endpoint (10 req/min per IP) - if rate_limiter: - client_ip = http_request.client.host if http_request.client else "unknown" - allowed, error = rate_limiter.check_ip_rate_limit(client_ip, limit_per_minute=10, prefix="metrics") - if not allowed: - logger.warning(f"Rate limit exceeded for /metrics endpoint: IP={client_ip}") - raise HTTPException( - status_code=429, - detail={ - "type": "rate_limit_error", - "code": error, - "message": "Rate limit exceeded for metrics endpoint (10 req/min per IP).", - }, - ) - - return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST) + Args: + cors_origins_env: Comma-separated CORS origins string + is_production: Whether running in production + Returns: + List of validated CORS origins + """ + origins = [origin.strip() for origin in cors_origins_env.split(",")] + validated_origins = [] -@app.post("/proxy/http", summary="Proxy HTTP request", description="Universal HTTP proxy endpoint for any HTTP API. Supports retries, circuit breaker, cache, and idempotency. Use this endpoint to add reliability layers to any HTTP API call.") -async def proxy_http( - request: HTTPProxyRequest, - http_request: Request, -): - """Universal HTTP proxy endpoint for any HTTP API.""" - # Verify API key if required and resolve tenant and tier - api_key, tenant, tier = verify_api_key(http_request) - - # Security: Validate API key format (BYO-key security) - if api_key: - is_valid, error_msg = SecurityManager.validate_api_key_format(api_key) - if not is_valid: - raise HTTPException( - status_code=400, - detail={ - "type": "client_error", - "code": "INVALID_API_KEY_FORMAT", - "message": error_msg or "Invalid API key format", - }, - ) - - # Rate limiting and abuse protection for Free tier - if rate_limiter and tier == "free": - # Set current tier for abuse detection - rate_limiter._current_tier = tier - - client_ip = http_request.client.host if http_request.client else "unknown" - user_agent = http_request.headers.get("User-Agent", "") - - # Check IP rate limit (20 req/min) - allowed, error = rate_limiter.check_ip_rate_limit(client_ip, limit_per_minute=20) - if not allowed: - free_tier_abuse_attempts_total.labels(abuse_type="rate_limit_bypass", tier=tier).inc() - logger.warning(f"Free tier abuse attempt: IP rate limit exceeded for tier={tier}, IP={client_ip}") - raise HTTPException( - status_code=429, - detail={ - "type": "rate_limit_error", - "code": error, - "message": "Rate limit exceeded. Free tier: 20 requests/minute per IP.", - }, - ) - - # Check account burst limit (500 req/min) - account_id = hashlib.sha256(api_key.encode()).hexdigest()[:16] if api_key else "unknown" - allowed, error = rate_limiter.check_account_burst_limit(account_id, limit_per_minute=500) - if not allowed: - free_tier_abuse_attempts_total.labels(abuse_type="burst_limit", tier=tier).inc() - logger.warning(f"Free tier abuse attempt: burst limit exceeded for tier={tier}, account_id={account_id}") - raise HTTPException( - status_code=429, - detail={ - "type": "abuse_error", - "code": error, - "message": "Burst limit exceeded. Free tier abuse detected.", - }, - ) - - # Check fingerprint limit - allowed, error = rate_limiter.check_fingerprint_limit(client_ip, user_agent, api_key or "", limit_per_minute=20) - if not allowed: - free_tier_abuse_attempts_total.labels(abuse_type="fingerprint_mismatch", tier=tier).inc() - logger.warning(f"Free tier abuse attempt: fingerprint mismatch for tier={tier}, account_id={account_id}") - raise HTTPException( - status_code=429, - detail={ - "type": "rate_limit_error", - "code": error, - "message": "Rate limit exceeded based on fingerprint.", - }, - ) - - # Check anomaly detector - allowed, error = rate_limiter.check_anomaly_detector(account_id) - if not allowed: - raise HTTPException( - status_code=429, - detail={ - "type": "anomaly_error", - "code": error, - "message": "Anomalous activity detected. Request throttled.", - }, - ) - - # Generate request ID (UUID4 for better uniqueness) - request_id = f"req_{uuid.uuid4().hex[:16]}" - - # Detect client profile - client_profile_name = detect_client_profile(http_request, tenant=tenant) - - result = await handle_http_proxy( - target_name=request.target, - method=request.method, - path=request.path, - headers=request.headers, - query=request.query, - body=request.body, - idempotency_key=request.idempotency_key, - cache_ttl=request.cache, - targets=targets, - cache=cache, - idempotency=idempotency, - key_pool_manager=key_pool_manager, - rate_scheduler=rate_scheduler, - client_profile_name=client_profile_name, - client_profile_manager=client_profile_manager, - request_id=request_id, - tenant=tenant, - tier=tier, - ) - - # Record usage for RapidAPI tracking - if rapidapi_client and api_key: - await rapidapi_client.record_usage( - api_key=api_key, - endpoint="/proxy/http", - latency_ms=result.meta.duration_ms, - status="success" if result.success else "error", - ) - rapidapi_tier_distribution.labels(tier=tier).inc() - - status_code = 200 if result.success else (result.error.status_code or 500) - return JSONResponse( - content=result.model_dump(), - status_code=status_code, - headers={ - "X-Request-ID": request_id, - "X-Cache-Hit": str(result.meta.cache_hit).lower(), - "X-Retries": str(result.meta.retries), - "X-Duration-MS": str(result.meta.duration_ms), - }, - ) + for origin in origins: + if not origin: + continue + if origin != "*" and not ( + origin.startswith("http://") or origin.startswith("https://") + ): + logger.warning(f"Invalid CORS origin format (skipping): {origin}") + continue -@app.post("/proxy/llm", summary="Proxy LLM request", description="LLM proxy endpoint with idempotency, budget caps, and caching. Make idempotent LLM API calls with predictable costs. Supports OpenAI, Anthropic, and Mistral providers. Set stream=true for Server-Sent Events (SSE) streaming.") -async def proxy_llm( - request: LLMProxyRequest, - http_request: Request, -): - """LLM proxy endpoint with idempotency and budget control.""" - # Verify API key if required and resolve tenant and tier - api_key, tenant, tier = verify_api_key(http_request) - - # Security: Validate API key format (BYO-key security) - if api_key: - is_valid, error_msg = SecurityManager.validate_api_key_format(api_key) - if not is_valid: - raise HTTPException( - status_code=400, - detail={ - "type": "client_error", - "code": "INVALID_API_KEY_FORMAT", - "message": error_msg or "Invalid API key format", - }, - ) - - # Free tier: Block SSE streaming - if tier == "free" and request.stream: - allowed, error = FreeTierRestrictions.is_feature_allowed("streaming", tier) - if not allowed: - raise HTTPException( - status_code=403, - detail={ - "type": "feature_error", - "code": error, - "message": "SSE streaming not available for Free tier.", - }, - ) - - # Rate limiting and abuse protection for Free tier - if rate_limiter and tier == "free": - # Set current tier for abuse detection - rate_limiter._current_tier = tier - - client_ip = http_request.client.host if http_request.client else "unknown" - user_agent = http_request.headers.get("User-Agent", "") - accept_language = http_request.headers.get("Accept-Language", "") - account_id = hashlib.sha256(api_key.encode()).hexdigest()[:16] if api_key else "unknown" - - # Check auto-ban first (>5 bypass attempts) - should_ban, ban_reason = rate_limiter.check_auto_ban(account_id, client_ip, max_attempts=5) - if should_ban: - free_tier_abuse_attempts_total.labels(abuse_type="auto_ban", tier=tier).inc() - logger.warning(f"Free tier abuse: account/IP banned for tier={tier}, account_id={account_id}, reason={ban_reason}") - raise HTTPException( - status_code=403, - detail={ - "type": "abuse_error", - "code": "ACCOUNT_BANNED", - "message": f"Account/IP banned: {ban_reason}", - }, - ) - - # Check IP rate limit (20 req/min) - allowed, error = rate_limiter.check_ip_rate_limit(client_ip, limit_per_minute=20) - if not allowed: - rate_limiter.abuse_detector.record_limit_bypass_attempt(account_id, client_ip) - raise HTTPException( - status_code=429, - detail={ - "type": "rate_limit_error", - "code": error, - "message": "Rate limit exceeded. Free tier: 20 requests/minute per IP.", - }, - ) - - # Check burst protection (≤300 req/10min) - allowed, error = rate_limiter.check_burst_protection(account_id, limit_per_10min=300) - if not allowed: - rate_limiter.abuse_detector.record_limit_bypass_attempt(account_id, client_ip) - raise HTTPException( - status_code=429, - detail={ - "type": "abuse_error", - "code": error, - "message": "Burst limit exceeded. Free tier: maximum 300 requests per 10 minutes.", - }, - ) - - # Check account burst limit (500 req/min) - allowed, error = rate_limiter.check_account_burst_limit(account_id, limit_per_minute=500) - if not allowed: - rate_limiter.abuse_detector.record_limit_bypass_attempt(account_id, client_ip) - raise HTTPException( - status_code=429, - detail={ - "type": "abuse_error", - "code": error, - "message": "Burst limit exceeded. Free tier abuse detected.", - }, - ) - - # Check usage anomaly (3x average) - allowed, error = rate_limiter.check_usage_anomaly(account_id, multiplier=3.0) - if not allowed: - raise HTTPException( - status_code=429, - detail={ - "type": "anomaly_error", - "code": error, - "message": "Usage anomaly detected. Request throttled.", - }, - ) - - # Check fingerprint-based identity - allowed, error = rate_limiter.check_fingerprint( - account_id, client_ip, user_agent, accept_language + if is_production and origin == "*": + logger.warning("Wildcard CORS origin '*' not recommended in production") + + validated_origins.append(origin) + + return validated_origins + + +def _register_exception_handlers(app: FastAPI) -> None: + """Register global exception handlers.""" + + @app.exception_handler(Exception) + async def global_exception_handler(request: Request, exc: Exception): + """Global exception handler for unhandled errors.""" + error_details = traceback.format_exc() + logger.error( + f"Unhandled exception: {type(exc).__name__}: {str(exc)}\n{error_details}", + exc_info=True, ) - if not allowed: - if error == "FINGERPRINT_MISMATCH_BANNED": - raise HTTPException( - status_code=403, - detail={ - "type": "abuse_error", - "code": error, - "message": "Account banned due to multiple fingerprint mismatches.", - }, - ) - raise HTTPException( - status_code=429, - detail={ - "type": "abuse_error", - "code": error, - "message": "Fingerprint mismatch detected. Request throttled.", + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={ + "success": False, + "error": { + "type": "internal_error", + "code": "INTERNAL_ERROR", + "message": f"Internal server error: {str(exc)}", + "retryable": True, + "target": None, + "status_code": 500, }, - ) - - # Validate Free tier restrictions - if request.model: - allowed, error = FreeTierRestrictions.is_model_allowed( - request.target, - request.model, - tier - ) - if not allowed: - raise HTTPException( - status_code=403, - detail={ - "type": "feature_error", - "code": error, - "message": f"Model {request.model} not allowed for Free tier. Allowed: gpt-4o-mini, claude-3-haiku, mistral-small", - }, - ) - - # Check idempotency restriction - if request.idempotency_key: - allowed, error = FreeTierRestrictions.is_feature_allowed("idempotency", tier) - if not allowed: - raise HTTPException( - status_code=403, - detail={ - "type": "feature_error", - "code": error, - "message": "Idempotency not available for Free tier.", - }, - ) - - # Generate request ID (UUID4 for better uniqueness) - request_id = f"req_{uuid.uuid4().hex[:16]}" - - # Extract RouteLLM routing decision from headers - routellm_decision = extract_routellm_decision(dict(http_request.headers)) - - # Apply RouteLLM overrides to target and model - resolved_target = request.target - resolved_model = request.model - if routellm_decision and routellm_decision.has_override: - resolved_target, resolved_model = apply_routellm_overrides( - request.target, - request.model, - targets, - routellm_decision, - ) - - # Record metrics for RouteLLM routing - routellm_decisions_total.labels( - route_name=routellm_decision.route_name or "unknown", - provider=routellm_decision.provider or "default", - model=routellm_decision.model or "default", - ).inc() - - if routellm_decision.provider and routellm_decision.model: - routellm_overrides_total.labels(override_type="both").inc() - elif routellm_decision.provider: - routellm_overrides_total.labels(override_type="provider").inc() - elif routellm_decision.model: - routellm_overrides_total.labels(override_type="model").inc() - - routellm_metrics.record_decision(routellm_decision) - - # Handle streaming requests - if request.stream: - generator = handle_llm_stream_generator( - target_name=resolved_target, - messages=request.messages, - model=resolved_model, - max_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - stop=request.stop, - idempotency_key=request.idempotency_key, - cache_ttl=request.cache, - targets=targets, - cache=cache, - idempotency=idempotency, - request_id=request_id, - tenant=tenant, - tier=tier, - ) - - # Build response headers including RouteLLM correlation - response_headers = { - "X-Request-ID": request_id, - "Cache-Control": "no-cache", - "Connection": "keep-alive", - } - if routellm_decision: - response_headers.update(routellm_decision.to_response_headers()) - - return StreamingResponse( - generator, - media_type="text/event-stream", - headers=response_headers, - ) - - # Detect client profile - client_profile_name = detect_client_profile(http_request, tenant=tenant) - - # Handle non-streaming requests (existing behavior) - result = await handle_llm_proxy( - target_name=resolved_target, - messages=request.messages, - model=resolved_model, - max_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - stop=request.stop, - stream=False, - idempotency_key=request.idempotency_key, - cache_ttl=request.cache, - targets=targets, - cache=cache, - idempotency=idempotency, - request_id=request_id, - tenant=tenant, - tier=tier, - key_pool_manager=key_pool_manager, - rate_scheduler=rate_scheduler, - client_profile_name=client_profile_name, - client_profile_manager=client_profile_manager, - ) - - # Record usage for RapidAPI tracking - if rapidapi_client and api_key: - cost_usd = result.data.usage.estimated_cost_usd if result.success and result.data and result.data.usage else 0.0 - await rapidapi_client.record_usage( - api_key=api_key, - endpoint="/proxy/llm", - latency_ms=result.meta.duration_ms, - status="success" if result.success else "error", - cost_usd=cost_usd, + "meta": { + "target": None, + "cache_hit": False, + "retries": 0, + "duration_ms": 0, + "request_id": request.headers.get("X-Request-ID", "unknown"), + "trace_id": request.headers.get("X-Trace-ID"), + }, + }, ) - rapidapi_tier_distribution.labels(tier=tier).inc() - - # Add RouteLLM correlation to response meta - if routellm_decision: - result.meta.routellm_decision_id = routellm_decision.decision_id - result.meta.routellm_route_name = routellm_decision.route_name - result.meta.routellm_provider_override = routellm_decision.provider - result.meta.routellm_model_override = routellm_decision.model - - # Build response headers including RouteLLM correlation - response_headers = { - "X-Request-ID": request_id, - "X-Cache-Hit": str(result.meta.cache_hit).lower(), - "X-Retries": str(result.meta.retries), - "X-Duration-MS": str(result.meta.duration_ms), - } - if routellm_decision: - response_headers.update(routellm_decision.to_response_headers()) - - status_code = 200 if result.success else (result.error.status_code or 500) - return JSONResponse( - content=result.model_dump(), - status_code=status_code, - headers=response_headers, - ) -# === RapidAPI Webhook Endpoint === +def _register_routes(app: FastAPI) -> None: + """Register all route handlers.""" + # Import and register core routes + from reliapi.app.routes import health, proxy, rapidapi -@app.post( - "/webhooks/rapidapi", - summary="RapidAPI Webhook", - description="Webhook endpoint for RapidAPI events (subscription changes, usage alerts).", - include_in_schema=False, # Hidden from public docs -) -async def rapidapi_webhook(request: Request): - """ - Handle RapidAPI webhook events. - - Supported events: - - subscription.created: New subscription created - - subscription.updated: Subscription tier changed - - subscription.cancelled: Subscription cancelled - - usage.alert: Usage threshold reached - """ - if not rapidapi_client: - raise HTTPException(status_code=503, detail="RapidAPI integration not configured") - - # Rate limiting for webhook endpoint (IP-based, 10 req/min) - if rate_limiter: - client_ip = request.client.host if request.client else "unknown" - webhook_rate_key = f"webhook_ip:{client_ip}" - - # Check rate limit (10 requests per minute) - allowed, error = rate_limiter.check_ip_rate_limit(client_ip, limit_per_minute=10, prefix="webhook") - if not allowed: - logger.warning(f"Webhook rate limit exceeded for IP: {client_ip}") - rapidapi_webhook_events_total.labels(event_type="unknown", status="rate_limited").inc() - raise HTTPException( - status_code=429, - detail="Webhook rate limit exceeded (10 requests/minute)", - ) - - # Request size limit (10KB) - content_length = request.headers.get("content-length") - if content_length and int(content_length) > 10240: - logger.warning(f"Webhook payload too large: {content_length} bytes") - rapidapi_webhook_events_total.labels(event_type="unknown", status="payload_too_large").inc() - raise HTTPException( - status_code=413, - detail="Webhook payload too large (max 10KB)", - ) - - # Get raw body for signature verification - body = await request.body() - - # Verify webhook signature - signature = request.headers.get("X-RapidAPI-Signature", "") - if not rapidapi_client.verify_webhook_signature(body, signature): - logger.warning("Invalid RapidAPI webhook signature") - rapidapi_webhook_events_total.labels(event_type="unknown", status="invalid_signature").inc() - raise HTTPException(status_code=401, detail="Invalid signature") - - # Parse webhook payload - try: - payload = json.loads(body.decode("utf-8")) - except json.JSONDecodeError: - logger.error("Invalid JSON in webhook payload") - rapidapi_webhook_events_total.labels(event_type="unknown", status="invalid_json").inc() - raise HTTPException(status_code=400, detail="Invalid JSON payload") - - event_type = payload.get("type", "unknown") - event_data = payload.get("data", {}) - event_id = payload.get("id") or payload.get("event_id") or "" - - logger.info(f"Received RapidAPI webhook: {event_type}, event_id={event_id}") - - # Idempotency check: generate key from event_type + event_id - # This prevents duplicate processing of the same webhook event - if idempotency and event_id: - webhook_idempotency_key = f"webhook:rapidapi:{event_type}:{event_id}" - - # Check if this webhook has already been processed - existing_result = idempotency.get_result(webhook_idempotency_key) - if existing_result: - logger.info(f"Duplicate webhook detected: {event_type}, event_id={event_id}") - rapidapi_webhook_events_total.labels(event_type=event_type, status="duplicate").inc() - return JSONResponse( - content={ - "status": "ok", - "event_type": event_type, - "duplicate": True, - "message": "Event already processed", - }, - status_code=200, - ) - - # Mark as in progress to prevent concurrent processing - idempotency.mark_in_progress(webhook_idempotency_key, ttl_s=60) - else: - webhook_idempotency_key = None - + app.include_router(health.router) + app.include_router(proxy.router) + app.include_router(rapidapi.router) + + # Import and register business routes try: - if event_type == "subscription.created": - # New subscription - cache tier info and create tenant - api_key = event_data.get("api_key") - tier = event_data.get("tier", "free") - user_id = event_data.get("user_id") - - if api_key: - tier_enum = SubscriptionTier(tier) if tier in [t.value for t in SubscriptionTier] else SubscriptionTier.FREE - await rapidapi_client._cache_tier(api_key, tier_enum, user_id) - rapidapi_tier_cache_total.labels(operation="set").inc() - rapidapi_tier_distribution.labels(tier=tier).inc() - - # Create tenant for RapidAPI user - if rapidapi_tenant_manager and user_id: - rapidapi_tenant_manager.create_tenant( - user_id, - tier_enum, - metadata={"api_key_hash": rapidapi_client._hash_api_key(api_key)}, - ) - - logger.info(f"Cached new subscription: tier={tier}, user_id={user_id}") - - rapidapi_webhook_events_total.labels(event_type="subscription.created", status="success").inc() - - elif event_type == "subscription.updated": - # Subscription tier changed - invalidate cache, update tier, and migrate tenant - api_key = event_data.get("api_key") - new_tier = event_data.get("tier", "free") - user_id = event_data.get("user_id") - - if api_key: - # Invalidate old cache - await rapidapi_client.invalidate_tier_cache(api_key) - rapidapi_tier_cache_total.labels(operation="invalidate").inc() - - # Cache new tier - tier_enum = SubscriptionTier(new_tier) if new_tier in [t.value for t in SubscriptionTier] else SubscriptionTier.FREE - await rapidapi_client._cache_tier(api_key, tier_enum, user_id) - rapidapi_tier_cache_total.labels(operation="set").inc() - rapidapi_tier_distribution.labels(tier=new_tier).inc() - - # Update tenant tier (migration) - if rapidapi_tenant_manager and user_id: - rapidapi_tenant_manager.update_tenant_tier( - user_id, - tier_enum, - metadata={"api_key_hash": rapidapi_client._hash_api_key(api_key)}, - ) - - logger.info(f"Updated subscription: new_tier={new_tier}, user_id={user_id}") - - rapidapi_webhook_events_total.labels(event_type="subscription.updated", status="success").inc() - - elif event_type == "subscription.cancelled": - # Subscription cancelled - invalidate cache and cleanup tenant - api_key = event_data.get("api_key") - user_id = event_data.get("user_id") - - if api_key: - await rapidapi_client.invalidate_tier_cache(api_key) - rapidapi_tier_cache_total.labels(operation="invalidate").inc() - - # Cleanup tenant (delete tenant and associated data) - if rapidapi_tenant_manager and user_id: - rapidapi_tenant_manager.delete_tenant(user_id) - - logger.info(f"Subscription cancelled: user_id={user_id}") - - rapidapi_webhook_events_total.labels(event_type="subscription.cancelled", status="success").inc() - - elif event_type == "usage.alert": - # Usage alert - log and potentially throttle - api_key = event_data.get("api_key") - usage_percent = event_data.get("usage_percent", 0) - threshold = event_data.get("threshold", "unknown") - - logger.warning(f"Usage alert: api_key_hash={rapidapi_client._hash_api_key(api_key) if api_key else 'unknown'}, usage={usage_percent}%, threshold={threshold}") - rapidapi_webhook_events_total.labels(event_type="usage.alert", status="success").inc() - - else: - logger.info(f"Unknown webhook event type: {event_type}") - rapidapi_webhook_events_total.labels(event_type=event_type, status="unknown_type").inc() - - # Store idempotency result for successful processing - if idempotency and webhook_idempotency_key: - idempotency.store_result( - webhook_idempotency_key, - {"status": "processed", "event_type": event_type, "event_id": event_id}, - ttl_s=86400, # 24 hours - ) - idempotency.clear_in_progress(webhook_idempotency_key) - - return JSONResponse( - content={"status": "ok", "event_type": event_type}, - status_code=200, + from reliapi.app.routes import ( + analytics, + calculators, + dashboard, + onboarding, + paddle, ) - - except Exception as e: - logger.error(f"Error processing webhook {event_type}: {e}") - rapidapi_webhook_events_total.labels(event_type=event_type, status="error").inc() - - # Clear in-progress marker on error - if idempotency and webhook_idempotency_key: - idempotency.clear_in_progress(webhook_idempotency_key) - - raise HTTPException(status_code=500, detail=f"Webhook processing error: {str(e)}") + app.include_router(paddle.router) + app.include_router(onboarding.router) + app.include_router(analytics.router) + app.include_router(calculators.router) + app.include_router(dashboard.router) -@app.get( - "/rapidapi/status", - summary="RapidAPI Integration Status", - description="Check the status of RapidAPI integration.", -) -async def rapidapi_status(request: Request): - """Get RapidAPI integration status.""" - # Verify API key - api_key, tenant, tier = verify_api_key(request) - - if not rapidapi_client: - return JSONResponse( - content={ - "status": "not_configured", - "message": "RapidAPI integration not configured", - }, - status_code=200, + logger.info( + "Business routes registered: paddle, onboarding, analytics, " + "calculators, dashboard" ) - - # Get usage stats for the current API key - usage_stats = await rapidapi_client.get_usage_stats(api_key) if api_key else {} - - return JSONResponse( - content={ - "status": "configured", - "tier": tier, - "usage": usage_stats, - "redis_connected": rapidapi_client.redis_enabled, - "api_configured": bool(rapidapi_client.api_key), - }, - status_code=200, - ) - + except ImportError as e: + logger.warning(f"Business routes not available: {e}") -# ============================================================================= -# Business Routes (Paddle, Onboarding, Analytics, Calculators, Dashboard) -# ============================================================================= -# Import and register business routes -try: - from reliapi.app.routes import paddle, onboarding, analytics, calculators, dashboard - - app.include_router(paddle.router) - app.include_router(onboarding.router) - app.include_router(analytics.router) - app.include_router(calculators.router) - app.include_router(dashboard.router) - - logger.info("Business routes registered: paddle, onboarding, analytics, calculators, dashboard") -except ImportError as e: - logger.warning(f"Business routes not available: {e}") +# Create the application instance +app = create_app() if __name__ == "__main__": import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/app/routes/__init__.py b/app/routes/__init__.py index b0ad6f9..fc948eb 100644 --- a/app/routes/__init__.py +++ b/app/routes/__init__.py @@ -1,2 +1,19 @@ -"""ReliAPI Routes Package.""" +"""ReliAPI Routes Package. +This package contains all route handlers organized by domain: + +Core routes: +- health: Health check and monitoring endpoints +- proxy: HTTP and LLM proxy endpoints +- rapidapi: RapidAPI integration endpoints + +Business routes: +- paddle: Paddle payment integration +- onboarding: Self-service API key generation +- analytics: Usage analytics tracking +- calculators: ROI/pricing calculators +- dashboard: Admin dashboard +""" +from reliapi.app.routes import health, proxy, rapidapi + +__all__ = ["health", "proxy", "rapidapi"] diff --git a/app/routes/analytics.py b/app/routes/analytics.py index a79f900..7fd4851 100644 --- a/app/routes/analytics.py +++ b/app/routes/analytics.py @@ -3,13 +3,18 @@ This module provides analytics endpoints for tracking user behavior, conversion funnels, and events. All tracking is automated through APIs. """ - +import base64 +import json +import logging import os -from typing import Dict, Any, Optional, List from datetime import datetime, timedelta -from fastapi import APIRouter, Request, Header, Body -from pydantic import BaseModel, Field +from typing import Any, Dict, Optional + import httpx +from fastapi import APIRouter, Request +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/analytics", tags=["analytics"]) @@ -174,16 +179,12 @@ async def _track_google_analytics(event_data: Dict[str, Any], ga_id: str) -> Non timeout=5.0, ) except Exception as e: - # Log error but don't fail the request - print(f"Google Analytics tracking error: {e}") + logger.warning(f"Google Analytics tracking error: {e}") async def _track_mixpanel(event_data: Dict[str, Any], token: str) -> None: """Track event in Mixpanel.""" try: - import base64 - import json - # Mixpanel uses base64 encoded JSON event_payload = { "event": event_data["event_name"], @@ -203,8 +204,7 @@ async def _track_mixpanel(event_data: Dict[str, Any], token: str) -> None: timeout=5.0, ) except Exception as e: - # Log error but don't fail the request - print(f"Mixpanel tracking error: {e}") + logger.warning(f"Mixpanel tracking error: {e}") async def _track_posthog(event_data: Dict[str, Any], api_key: str, host: str) -> None: @@ -222,6 +222,4 @@ async def _track_posthog(event_data: Dict[str, Any], api_key: str, host: str) -> timeout=5.0, ) except Exception as e: - # Log error but don't fail the request - print(f"PostHog tracking error: {e}") - + logger.warning(f"PostHog tracking error: {e}") diff --git a/app/routes/health.py b/app/routes/health.py index 8666b4a..75bfb40 100644 --- a/app/routes/health.py +++ b/app/routes/health.py @@ -1,31 +1,104 @@ -"""Health check endpoints. +"""Health check and monitoring endpoints. -This module provides health check endpoints for load balancers and monitoring. -The actual health endpoints are defined in reliapi.app.main, this module -is provided for api-template compatibility. +This module provides: +- GET /health - Basic health check +- GET /healthz - Kubernetes-style health check +- GET /readyz - Readiness check +- GET /livez - Liveness check +- GET /metrics - Prometheus metrics """ +import logging -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import Response +from prometheus_client import CONTENT_TYPE_LATEST, generate_latest from pydantic import BaseModel -router = APIRouter() +from reliapi.app.dependencies import get_app_state + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Health"]) class HealthResponse(BaseModel): """Health check response model.""" status: str - version: str + version: str = "1.0.7" + + +class StatusResponse(BaseModel): + """Simple status response model.""" + + status: str + + +def _check_health_rate_limit(request: Request, prefix: str) -> None: + """Check rate limit for health endpoints. + + Args: + request: FastAPI request + prefix: Rate limit prefix (e.g., 'healthz', 'metrics') + + Raises: + HTTPException: If rate limit exceeded + """ + state = get_app_state() + + if not state.rate_limiter: + return + + client_ip = request.client.host if request.client else "unknown" + limit = 10 if prefix == "metrics" else 20 + + allowed, error = state.rate_limiter.check_ip_rate_limit( + client_ip, limit_per_minute=limit, prefix=prefix + ) + + if not allowed: + if prefix == "metrics": + logger.warning(f"Rate limit exceeded for /metrics endpoint: IP={client_ip}") + + raise HTTPException( + status_code=429, + detail={ + "type": "rate_limit_error", + "code": error, + "message": f"Rate limit exceeded for {prefix} endpoint.", + }, + ) @router.get("/health", response_model=HealthResponse) async def health_check() -> HealthResponse: - """Health check endpoint for load balancers and monitoring.""" - return HealthResponse(status="ok", version="1.0.7") + """Basic health check endpoint for load balancers and monitoring.""" + return HealthResponse(status="ok") + + +@router.get("/healthz", response_model=StatusResponse) +async def healthz(request: Request) -> StatusResponse: + """Kubernetes-style health check endpoint with optional rate limiting.""" + _check_health_rate_limit(request, "healthz") + return StatusResponse(status="healthy") + + +@router.get("/readyz", response_model=StatusResponse) +async def readyz(request: Request) -> StatusResponse: + """Readiness check endpoint with optional rate limiting.""" + _check_health_rate_limit(request, "readyz") + return StatusResponse(status="ready") + +@router.get("/livez", response_model=StatusResponse) +async def livez(request: Request) -> StatusResponse: + """Liveness check endpoint with optional rate limiting.""" + _check_health_rate_limit(request, "livez") + return StatusResponse(status="alive") -@router.get("/healthz", response_model=HealthResponse) -async def healthz() -> HealthResponse: - """Kubernetes-style health check endpoint.""" - return HealthResponse(status="ok", version="1.0.7") +@router.get("/metrics") +async def metrics(request: Request) -> Response: + """Prometheus metrics endpoint with rate limiting.""" + _check_health_rate_limit(request, "metrics") + return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST) diff --git a/app/routes/proxy.py b/app/routes/proxy.py new file mode 100644 index 0000000..1133a82 --- /dev/null +++ b/app/routes/proxy.py @@ -0,0 +1,579 @@ +"""Proxy endpoints for HTTP and LLM requests. + +This module provides: +- POST /proxy/http - Universal HTTP proxy with reliability features +- POST /proxy/llm - LLM proxy with idempotency and budget control +""" +import logging +import uuid +from typing import Dict, Optional + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from reliapi.app.dependencies import ( + detect_client_profile, + get_account_id, + get_app_state, + verify_api_key, +) +from reliapi.app.schemas import HTTPProxyRequest, LLMProxyRequest +from reliapi.app.services import ( + handle_http_proxy, + handle_llm_proxy, + handle_llm_stream_generator, +) +from reliapi.core.free_tier_restrictions import FreeTierRestrictions +from reliapi.core.security import SecurityManager +from reliapi.integrations.routellm import ( + apply_routellm_overrides, + extract_routellm_decision, + routellm_metrics, +) +from reliapi.metrics.prometheus import ( + free_tier_abuse_attempts_total, + rapidapi_tier_distribution, + routellm_decisions_total, + routellm_overrides_total, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Proxy"]) + + +def _check_api_key_format(api_key: Optional[str]) -> None: + """Validate API key format and raise HTTPException if invalid.""" + if api_key: + is_valid, error_msg = SecurityManager.validate_api_key_format(api_key) + if not is_valid: + raise HTTPException( + status_code=400, + detail={ + "type": "client_error", + "code": "INVALID_API_KEY_FORMAT", + "message": error_msg or "Invalid API key format", + }, + ) + + +def _check_free_tier_rate_limits( + request: Request, + api_key: Optional[str], + tier: str, + endpoint: str = "http", +) -> None: + """Check rate limits and abuse protection for Free tier. + + Args: + request: FastAPI request + api_key: API key + tier: User tier + endpoint: Endpoint type ('http' or 'llm') + + Raises: + HTTPException: If rate limit exceeded or abuse detected + """ + state = get_app_state() + + if not state.rate_limiter or tier != "free": + return + + state.rate_limiter._current_tier = tier + client_ip = request.client.host if request.client else "unknown" + user_agent = request.headers.get("User-Agent", "") + account_id = get_account_id(api_key) + + # Check IP rate limit (20 req/min) + allowed, error = state.rate_limiter.check_ip_rate_limit( + client_ip, limit_per_minute=20 + ) + if not allowed: + free_tier_abuse_attempts_total.labels( + abuse_type="rate_limit_bypass", tier=tier + ).inc() + logger.warning( + f"Free tier abuse attempt: IP rate limit exceeded for " + f"tier={tier}, IP={client_ip}" + ) + raise HTTPException( + status_code=429, + detail={ + "type": "rate_limit_error", + "code": error, + "message": "Rate limit exceeded. Free tier: 20 requests/minute per IP.", + }, + ) + + # Check account burst limit (500 req/min) + allowed, error = state.rate_limiter.check_account_burst_limit( + account_id, limit_per_minute=500 + ) + if not allowed: + free_tier_abuse_attempts_total.labels( + abuse_type="burst_limit", tier=tier + ).inc() + logger.warning( + f"Free tier abuse attempt: burst limit exceeded for " + f"tier={tier}, account_id={account_id}" + ) + raise HTTPException( + status_code=429, + detail={ + "type": "abuse_error", + "code": error, + "message": "Burst limit exceeded. Free tier abuse detected.", + }, + ) + + # Check fingerprint limit + allowed, error = state.rate_limiter.check_fingerprint_limit( + client_ip, user_agent, api_key or "", limit_per_minute=20 + ) + if not allowed: + free_tier_abuse_attempts_total.labels( + abuse_type="fingerprint_mismatch", tier=tier + ).inc() + logger.warning( + f"Free tier abuse attempt: fingerprint mismatch for " + f"tier={tier}, account_id={account_id}" + ) + raise HTTPException( + status_code=429, + detail={ + "type": "rate_limit_error", + "code": error, + "message": "Rate limit exceeded based on fingerprint.", + }, + ) + + # Check anomaly detector + allowed, error = state.rate_limiter.check_anomaly_detector(account_id) + if not allowed: + raise HTTPException( + status_code=429, + detail={ + "type": "anomaly_error", + "code": error, + "message": "Anomalous activity detected. Request throttled.", + }, + ) + + +def _check_llm_free_tier_restrictions( + request: Request, + llm_request: LLMProxyRequest, + api_key: Optional[str], + tier: str, +) -> None: + """Check LLM-specific Free tier restrictions. + + Args: + request: FastAPI request + llm_request: LLM proxy request + api_key: API key + tier: User tier + + Raises: + HTTPException: If restriction violated + """ + state = get_app_state() + + # Block SSE streaming for free tier + if tier == "free" and llm_request.stream: + allowed, error = FreeTierRestrictions.is_feature_allowed("streaming", tier) + if not allowed: + raise HTTPException( + status_code=403, + detail={ + "type": "feature_error", + "code": error, + "message": "SSE streaming not available for Free tier.", + }, + ) + + if not state.rate_limiter or tier != "free": + return + + state.rate_limiter._current_tier = tier + client_ip = request.client.host if request.client else "unknown" + user_agent = request.headers.get("User-Agent", "") + accept_language = request.headers.get("Accept-Language", "") + account_id = get_account_id(api_key) + + # Check auto-ban first (>5 bypass attempts) + should_ban, ban_reason = state.rate_limiter.check_auto_ban( + account_id, client_ip, max_attempts=5 + ) + if should_ban: + free_tier_abuse_attempts_total.labels(abuse_type="auto_ban", tier=tier).inc() + logger.warning( + f"Free tier abuse: account/IP banned for tier={tier}, " + f"account_id={account_id}, reason={ban_reason}" + ) + raise HTTPException( + status_code=403, + detail={ + "type": "abuse_error", + "code": "ACCOUNT_BANNED", + "message": f"Account/IP banned: {ban_reason}", + }, + ) + + # Check IP rate limit (20 req/min) + allowed, error = state.rate_limiter.check_ip_rate_limit( + client_ip, limit_per_minute=20 + ) + if not allowed: + state.rate_limiter.abuse_detector.record_limit_bypass_attempt( + account_id, client_ip + ) + raise HTTPException( + status_code=429, + detail={ + "type": "rate_limit_error", + "code": error, + "message": "Rate limit exceeded. Free tier: 20 requests/minute per IP.", + }, + ) + + # Check burst protection (≤300 req/10min) + allowed, error = state.rate_limiter.check_burst_protection( + account_id, limit_per_10min=300 + ) + if not allowed: + state.rate_limiter.abuse_detector.record_limit_bypass_attempt( + account_id, client_ip + ) + raise HTTPException( + status_code=429, + detail={ + "type": "abuse_error", + "code": error, + "message": "Burst limit exceeded. Free tier: maximum 300 requests per 10 minutes.", + }, + ) + + # Check account burst limit (500 req/min) + allowed, error = state.rate_limiter.check_account_burst_limit( + account_id, limit_per_minute=500 + ) + if not allowed: + state.rate_limiter.abuse_detector.record_limit_bypass_attempt( + account_id, client_ip + ) + raise HTTPException( + status_code=429, + detail={ + "type": "abuse_error", + "code": error, + "message": "Burst limit exceeded. Free tier abuse detected.", + }, + ) + + # Check usage anomaly (3x average) + allowed, error = state.rate_limiter.check_usage_anomaly( + account_id, multiplier=3.0 + ) + if not allowed: + raise HTTPException( + status_code=429, + detail={ + "type": "anomaly_error", + "code": error, + "message": "Usage anomaly detected. Request throttled.", + }, + ) + + # Check fingerprint-based identity + allowed, error = state.rate_limiter.check_fingerprint( + account_id, client_ip, user_agent, accept_language + ) + if not allowed: + if error == "FINGERPRINT_MISMATCH_BANNED": + raise HTTPException( + status_code=403, + detail={ + "type": "abuse_error", + "code": error, + "message": "Account banned due to multiple fingerprint mismatches.", + }, + ) + raise HTTPException( + status_code=429, + detail={ + "type": "abuse_error", + "code": error, + "message": "Fingerprint mismatch detected. Request throttled.", + }, + ) + + # Validate model restrictions + if llm_request.model: + allowed, error = FreeTierRestrictions.is_model_allowed( + llm_request.target, + llm_request.model, + tier, + ) + if not allowed: + raise HTTPException( + status_code=403, + detail={ + "type": "feature_error", + "code": error, + "message": ( + f"Model {llm_request.model} not allowed for Free tier. " + "Allowed: gpt-4o-mini, claude-3-haiku, mistral-small" + ), + }, + ) + + # Check idempotency restriction + if llm_request.idempotency_key: + allowed, error = FreeTierRestrictions.is_feature_allowed("idempotency", tier) + if not allowed: + raise HTTPException( + status_code=403, + detail={ + "type": "feature_error", + "code": error, + "message": "Idempotency not available for Free tier.", + }, + ) + + +@router.post( + "/proxy/http", + summary="Proxy HTTP request", + description=( + "Universal HTTP proxy endpoint for any HTTP API. " + "Supports retries, circuit breaker, cache, and idempotency. " + "Use this endpoint to add reliability layers to any HTTP API call." + ), +) +async def proxy_http( + request: HTTPProxyRequest, + http_request: Request, +) -> JSONResponse: + """Universal HTTP proxy endpoint for any HTTP API.""" + state = get_app_state() + + # Verify API key and resolve tenant/tier + api_key, tenant, tier = verify_api_key(http_request) + + # Validate API key format + _check_api_key_format(api_key) + + # Check rate limits for free tier + _check_free_tier_rate_limits(http_request, api_key, tier, endpoint="http") + + # Generate request ID + request_id = f"req_{uuid.uuid4().hex[:16]}" + + # Detect client profile + client_profile_name = detect_client_profile(http_request, tenant=tenant) + + result = await handle_http_proxy( + target_name=request.target, + method=request.method, + path=request.path, + headers=request.headers, + query=request.query, + body=request.body, + idempotency_key=request.idempotency_key, + cache_ttl=request.cache, + targets=state.targets, + cache=state.cache, + idempotency=state.idempotency, + key_pool_manager=state.key_pool_manager, + rate_scheduler=state.rate_scheduler, + client_profile_name=client_profile_name, + client_profile_manager=state.client_profile_manager, + request_id=request_id, + tenant=tenant, + tier=tier, + ) + + # Record usage for RapidAPI tracking + if state.rapidapi_client and api_key: + await state.rapidapi_client.record_usage( + api_key=api_key, + endpoint="/proxy/http", + latency_ms=result.meta.duration_ms, + status="success" if result.success else "error", + ) + rapidapi_tier_distribution.labels(tier=tier).inc() + + status_code = 200 if result.success else (result.error.status_code or 500) + return JSONResponse( + content=result.model_dump(), + status_code=status_code, + headers={ + "X-Request-ID": request_id, + "X-Cache-Hit": str(result.meta.cache_hit).lower(), + "X-Retries": str(result.meta.retries), + "X-Duration-MS": str(result.meta.duration_ms), + }, + ) + + +@router.post( + "/proxy/llm", + summary="Proxy LLM request", + description=( + "LLM proxy endpoint with idempotency, budget caps, and caching. " + "Make idempotent LLM API calls with predictable costs. " + "Supports OpenAI, Anthropic, and Mistral providers. " + "Set stream=true for Server-Sent Events (SSE) streaming." + ), +) +async def proxy_llm( + request: LLMProxyRequest, + http_request: Request, +): + """LLM proxy endpoint with idempotency and budget control.""" + state = get_app_state() + + # Verify API key and resolve tenant/tier + api_key, tenant, tier = verify_api_key(http_request) + + # Validate API key format + _check_api_key_format(api_key) + + # Check LLM-specific free tier restrictions + _check_llm_free_tier_restrictions(http_request, request, api_key, tier) + + # Generate request ID + request_id = f"req_{uuid.uuid4().hex[:16]}" + + # Extract RouteLLM routing decision from headers + routellm_decision = extract_routellm_decision(dict(http_request.headers)) + + # Apply RouteLLM overrides to target and model + resolved_target = request.target + resolved_model = request.model + if routellm_decision and routellm_decision.has_override: + resolved_target, resolved_model = apply_routellm_overrides( + request.target, + request.model, + state.targets, + routellm_decision, + ) + + # Record metrics for RouteLLM routing + routellm_decisions_total.labels( + route_name=routellm_decision.route_name or "unknown", + provider=routellm_decision.provider or "default", + model=routellm_decision.model or "default", + ).inc() + + if routellm_decision.provider and routellm_decision.model: + routellm_overrides_total.labels(override_type="both").inc() + elif routellm_decision.provider: + routellm_overrides_total.labels(override_type="provider").inc() + elif routellm_decision.model: + routellm_overrides_total.labels(override_type="model").inc() + + routellm_metrics.record_decision(routellm_decision) + + # Handle streaming requests + if request.stream: + generator = handle_llm_stream_generator( + target_name=resolved_target, + messages=request.messages, + model=resolved_model, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + stop=request.stop, + idempotency_key=request.idempotency_key, + cache_ttl=request.cache, + targets=state.targets, + cache=state.cache, + idempotency=state.idempotency, + request_id=request_id, + tenant=tenant, + tier=tier, + ) + + # Build response headers including RouteLLM correlation + response_headers: Dict[str, str] = { + "X-Request-ID": request_id, + "Cache-Control": "no-cache", + "Connection": "keep-alive", + } + if routellm_decision: + response_headers.update(routellm_decision.to_response_headers()) + + return StreamingResponse( + generator, + media_type="text/event-stream", + headers=response_headers, + ) + + # Detect client profile + client_profile_name = detect_client_profile(http_request, tenant=tenant) + + # Handle non-streaming requests + result = await handle_llm_proxy( + target_name=resolved_target, + messages=request.messages, + model=resolved_model, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + stop=request.stop, + stream=False, + idempotency_key=request.idempotency_key, + cache_ttl=request.cache, + targets=state.targets, + cache=state.cache, + idempotency=state.idempotency, + request_id=request_id, + tenant=tenant, + tier=tier, + key_pool_manager=state.key_pool_manager, + rate_scheduler=state.rate_scheduler, + client_profile_name=client_profile_name, + client_profile_manager=state.client_profile_manager, + ) + + # Record usage for RapidAPI tracking + if state.rapidapi_client and api_key: + cost_usd = ( + result.data.usage.estimated_cost_usd + if result.success and result.data and result.data.usage + else 0.0 + ) + await state.rapidapi_client.record_usage( + api_key=api_key, + endpoint="/proxy/llm", + latency_ms=result.meta.duration_ms, + status="success" if result.success else "error", + cost_usd=cost_usd, + ) + rapidapi_tier_distribution.labels(tier=tier).inc() + + # Add RouteLLM correlation to response meta + if routellm_decision: + result.meta.routellm_decision_id = routellm_decision.decision_id + result.meta.routellm_route_name = routellm_decision.route_name + result.meta.routellm_provider_override = routellm_decision.provider + result.meta.routellm_model_override = routellm_decision.model + + # Build response headers including RouteLLM correlation + response_headers: Dict[str, str] = { + "X-Request-ID": request_id, + "X-Cache-Hit": str(result.meta.cache_hit).lower(), + "X-Retries": str(result.meta.retries), + "X-Duration-MS": str(result.meta.duration_ms), + } + if routellm_decision: + response_headers.update(routellm_decision.to_response_headers()) + + status_code = 200 if result.success else (result.error.status_code or 500) + return JSONResponse( + content=result.model_dump(), + status_code=status_code, + headers=response_headers, + ) diff --git a/app/routes/rapidapi.py b/app/routes/rapidapi.py new file mode 100644 index 0000000..4861c7c --- /dev/null +++ b/app/routes/rapidapi.py @@ -0,0 +1,310 @@ +"""RapidAPI integration endpoints. + +This module provides: +- POST /webhooks/rapidapi - RapidAPI webhook handler +- GET /rapidapi/status - RapidAPI integration status +""" +import json +import logging + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse + +from reliapi.app.dependencies import get_app_state, verify_api_key +from reliapi.integrations.rapidapi import SubscriptionTier +from reliapi.metrics.prometheus import ( + rapidapi_tier_cache_total, + rapidapi_tier_distribution, + rapidapi_webhook_events_total, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["RapidAPI"]) + + +@router.post( + "/webhooks/rapidapi", + summary="RapidAPI Webhook", + description="Webhook endpoint for RapidAPI events (subscription changes, usage alerts).", + include_in_schema=False, +) +async def rapidapi_webhook(request: Request) -> JSONResponse: + """Handle RapidAPI webhook events. + + Supported events: + - subscription.created: New subscription created + - subscription.updated: Subscription tier changed + - subscription.cancelled: Subscription cancelled + - usage.alert: Usage threshold reached + """ + state = get_app_state() + + if not state.rapidapi_client: + raise HTTPException( + status_code=503, + detail="RapidAPI integration not configured", + ) + + # Rate limiting for webhook endpoint (IP-based, 10 req/min) + if state.rate_limiter: + client_ip = request.client.host if request.client else "unknown" + + allowed, error = state.rate_limiter.check_ip_rate_limit( + client_ip, limit_per_minute=10, prefix="webhook" + ) + if not allowed: + logger.warning(f"Webhook rate limit exceeded for IP: {client_ip}") + rapidapi_webhook_events_total.labels( + event_type="unknown", status="rate_limited" + ).inc() + raise HTTPException( + status_code=429, + detail="Webhook rate limit exceeded (10 requests/minute)", + ) + + # Request size limit (10KB) + content_length = request.headers.get("content-length") + if content_length and int(content_length) > 10240: + logger.warning(f"Webhook payload too large: {content_length} bytes") + rapidapi_webhook_events_total.labels( + event_type="unknown", status="payload_too_large" + ).inc() + raise HTTPException( + status_code=413, + detail="Webhook payload too large (max 10KB)", + ) + + # Get raw body for signature verification + body = await request.body() + + # Verify webhook signature + signature = request.headers.get("X-RapidAPI-Signature", "") + if not state.rapidapi_client.verify_webhook_signature(body, signature): + logger.warning("Invalid RapidAPI webhook signature") + rapidapi_webhook_events_total.labels( + event_type="unknown", status="invalid_signature" + ).inc() + raise HTTPException(status_code=401, detail="Invalid signature") + + # Parse webhook payload + try: + payload = json.loads(body.decode("utf-8")) + except json.JSONDecodeError: + logger.error("Invalid JSON in webhook payload") + rapidapi_webhook_events_total.labels( + event_type="unknown", status="invalid_json" + ).inc() + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + event_type = payload.get("type", "unknown") + event_data = payload.get("data", {}) + event_id = payload.get("id") or payload.get("event_id") or "" + + logger.info(f"Received RapidAPI webhook: {event_type}, event_id={event_id}") + + # Idempotency check + webhook_idempotency_key = None + if state.idempotency and event_id: + webhook_idempotency_key = f"webhook:rapidapi:{event_type}:{event_id}" + + existing_result = state.idempotency.get_result(webhook_idempotency_key) + if existing_result: + logger.info( + f"Duplicate webhook detected: {event_type}, event_id={event_id}" + ) + rapidapi_webhook_events_total.labels( + event_type=event_type, status="duplicate" + ).inc() + return JSONResponse( + content={ + "status": "ok", + "event_type": event_type, + "duplicate": True, + "message": "Event already processed", + }, + status_code=200, + ) + + state.idempotency.mark_in_progress(webhook_idempotency_key, ttl_s=60) + + try: + await _process_webhook_event(event_type, event_data) + + # Store idempotency result for successful processing + if state.idempotency and webhook_idempotency_key: + state.idempotency.store_result( + webhook_idempotency_key, + {"status": "processed", "event_type": event_type, "event_id": event_id}, + ttl_s=86400, + ) + state.idempotency.clear_in_progress(webhook_idempotency_key) + + return JSONResponse( + content={"status": "ok", "event_type": event_type}, + status_code=200, + ) + + except Exception as e: + logger.error(f"Error processing webhook {event_type}: {e}") + rapidapi_webhook_events_total.labels( + event_type=event_type, status="error" + ).inc() + + if state.idempotency and webhook_idempotency_key: + state.idempotency.clear_in_progress(webhook_idempotency_key) + + raise HTTPException( + status_code=500, + detail=f"Webhook processing error: {str(e)}", + ) + + +async def _process_webhook_event(event_type: str, event_data: dict) -> None: + """Process a webhook event based on its type. + + Args: + event_type: Type of the webhook event + event_data: Event payload data + """ + state = get_app_state() + + if event_type == "subscription.created": + api_key = event_data.get("api_key") + tier = event_data.get("tier", "free") + user_id = event_data.get("user_id") + + if api_key: + tier_enum = ( + SubscriptionTier(tier) + if tier in [t.value for t in SubscriptionTier] + else SubscriptionTier.FREE + ) + await state.rapidapi_client._cache_tier(api_key, tier_enum, user_id) + rapidapi_tier_cache_total.labels(operation="set").inc() + rapidapi_tier_distribution.labels(tier=tier).inc() + + if state.rapidapi_tenant_manager and user_id: + state.rapidapi_tenant_manager.create_tenant( + user_id, + tier_enum, + metadata={ + "api_key_hash": state.rapidapi_client._hash_api_key(api_key) + }, + ) + + logger.info(f"Cached new subscription: tier={tier}, user_id={user_id}") + + rapidapi_webhook_events_total.labels( + event_type="subscription.created", status="success" + ).inc() + + elif event_type == "subscription.updated": + api_key = event_data.get("api_key") + new_tier = event_data.get("tier", "free") + user_id = event_data.get("user_id") + + if api_key: + await state.rapidapi_client.invalidate_tier_cache(api_key) + rapidapi_tier_cache_total.labels(operation="invalidate").inc() + + tier_enum = ( + SubscriptionTier(new_tier) + if new_tier in [t.value for t in SubscriptionTier] + else SubscriptionTier.FREE + ) + await state.rapidapi_client._cache_tier(api_key, tier_enum, user_id) + rapidapi_tier_cache_total.labels(operation="set").inc() + rapidapi_tier_distribution.labels(tier=new_tier).inc() + + if state.rapidapi_tenant_manager and user_id: + state.rapidapi_tenant_manager.update_tenant_tier( + user_id, + tier_enum, + metadata={ + "api_key_hash": state.rapidapi_client._hash_api_key(api_key) + }, + ) + + logger.info( + f"Updated subscription: new_tier={new_tier}, user_id={user_id}" + ) + + rapidapi_webhook_events_total.labels( + event_type="subscription.updated", status="success" + ).inc() + + elif event_type == "subscription.cancelled": + api_key = event_data.get("api_key") + user_id = event_data.get("user_id") + + if api_key: + await state.rapidapi_client.invalidate_tier_cache(api_key) + rapidapi_tier_cache_total.labels(operation="invalidate").inc() + + if state.rapidapi_tenant_manager and user_id: + state.rapidapi_tenant_manager.delete_tenant(user_id) + + logger.info(f"Subscription cancelled: user_id={user_id}") + + rapidapi_webhook_events_total.labels( + event_type="subscription.cancelled", status="success" + ).inc() + + elif event_type == "usage.alert": + api_key = event_data.get("api_key") + usage_percent = event_data.get("usage_percent", 0) + threshold = event_data.get("threshold", "unknown") + + logger.warning( + f"Usage alert: api_key_hash=" + f"{state.rapidapi_client._hash_api_key(api_key) if api_key else 'unknown'}, " + f"usage={usage_percent}%, threshold={threshold}" + ) + rapidapi_webhook_events_total.labels( + event_type="usage.alert", status="success" + ).inc() + + else: + logger.info(f"Unknown webhook event type: {event_type}") + rapidapi_webhook_events_total.labels( + event_type=event_type, status="unknown_type" + ).inc() + + +@router.get( + "/rapidapi/status", + summary="RapidAPI Integration Status", + description="Check the status of RapidAPI integration.", +) +async def rapidapi_status(request: Request) -> JSONResponse: + """Get RapidAPI integration status.""" + state = get_app_state() + + # Verify API key + api_key, tenant, tier = verify_api_key(request) + + if not state.rapidapi_client: + return JSONResponse( + content={ + "status": "not_configured", + "message": "RapidAPI integration not configured", + }, + status_code=200, + ) + + # Get usage stats for the current API key + usage_stats = ( + await state.rapidapi_client.get_usage_stats(api_key) if api_key else {} + ) + + return JSONResponse( + content={ + "status": "configured", + "tier": tier, + "usage": usage_stats, + "redis_connected": state.rapidapi_client.redis_enabled, + "api_configured": bool(state.rapidapi_client.api_key), + }, + status_code=200, + ) diff --git a/app/schemas.py b/app/schemas.py index d6e85e7..d6e96c9 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -1,106 +1,318 @@ -"""Request/Response schemas for ReliAPI.""" -from typing import Any, Dict, List, Optional +"""Request/Response schemas for ReliAPI. -from pydantic import BaseModel, Field +This module provides Pydantic models for: +- HTTP proxy requests and responses +- LLM proxy requests and responses +- Error and metadata structures +""" +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field, field_validator + + +class HTTPMethod(str, Enum): + """Supported HTTP methods.""" + + GET = "GET" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + PATCH = "PATCH" + HEAD = "HEAD" + OPTIONS = "OPTIONS" + + +class MessageRole(str, Enum): + """LLM message roles.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + + +class CostPolicy(str, Enum): + """Cost policy types.""" + + NONE = "none" + SOFT_CAP_THROTTLED = "soft_cap_throttled" + HARD_CAP_REJECTED = "hard_cap_rejected" + + +class ErrorSource(str, Enum): + """Error source types.""" + + RELIAPI = "reliapi" + UPSTREAM = "upstream" + + +class ChatMessage(BaseModel): + """LLM chat message structure.""" + + role: MessageRole = Field(..., description="Message role: system, user, or assistant") + content: str = Field(..., description="Message content") class HTTPProxyRequest(BaseModel): """Request schema for POST /proxy/http. - + Use this endpoint to proxy any HTTP API request with reliability layers: - Retries with exponential backoff - Circuit breaker per target - TTL cache for GET/HEAD requests - Idempotency with request coalescing """ - - target: str = Field(..., description="Target name from config.yaml (e.g., 'my_api')") - method: str = Field(..., description="HTTP method: GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS") - path: str = Field(..., description="API path (e.g., '/users/123' or '/api/v1/data')") - headers: Optional[Dict[str, str]] = Field(None, description="HTTP headers to include in request") - query: Optional[Dict[str, Any]] = Field(None, description="Query parameters (e.g., {'page': 1, 'limit': 10})") - body: Optional[str] = Field(None, description="Request body as JSON string (for POST/PUT/PATCH)") - idempotency_key: Optional[str] = Field(None, description="Idempotency key for request coalescing. Concurrent requests with same key execute once.") - cache: Optional[int] = Field(None, description="Cache TTL in seconds (overrides config default). Only applies to GET/HEAD requests.") + + target: str = Field( + ..., description="Target name from config.yaml (e.g., 'my_api')" + ) + method: str = Field( + ..., + description="HTTP method: GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS", + ) + path: str = Field( + ..., description="API path (e.g., '/users/123' or '/api/v1/data')" + ) + headers: Optional[Dict[str, str]] = Field( + None, description="HTTP headers to include in request" + ) + query: Optional[Dict[str, Any]] = Field( + None, description="Query parameters (e.g., {'page': 1, 'limit': 10})" + ) + body: Optional[str] = Field( + None, description="Request body as JSON string (for POST/PUT/PATCH)" + ) + idempotency_key: Optional[str] = Field( + None, + description=( + "Idempotency key for request coalescing. " + "Concurrent requests with same key execute once." + ), + ) + cache: Optional[int] = Field( + None, + ge=0, + description=( + "Cache TTL in seconds (overrides config default). " + "Only applies to GET/HEAD requests." + ), + ) + + @field_validator("method") + @classmethod + def validate_method(cls, v: str) -> str: + """Validate HTTP method is uppercase and supported.""" + v = v.upper() + valid_methods = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"} + if v not in valid_methods: + raise ValueError(f"Invalid HTTP method: {v}. Must be one of {valid_methods}") + return v class LLMProxyRequest(BaseModel): """Request schema for POST /proxy/llm. - - Make idempotent LLM API calls with predictable costs. Supports OpenAI, Anthropic, and Mistral. + + Make idempotent LLM API calls with predictable costs. + Supports OpenAI, Anthropic, and Mistral. + Features: - Idempotency: duplicate requests return cached result - Budget caps: hard cap (reject) and soft cap (throttle) - Caching: TTL cache for LLM responses - Retries: automatic retries on failures """ - - target: str = Field(..., description="LLM target name from config.yaml (e.g., 'openai', 'anthropic')") - messages: List[Dict[str, str]] = Field(..., description="Messages list with 'role' and 'content' (e.g., [{'role': 'user', 'content': 'Hello'}])") - model: Optional[str] = Field(None, description="Model name (e.g., 'gpt-4o-mini', 'claude-3-haiku'). Uses default from config if not specified.") - max_tokens: Optional[int] = Field(None, description="Maximum tokens in response (limited by config max_tokens and budget caps)") - temperature: Optional[float] = Field(None, description="Temperature for sampling (0.0-2.0, limited by config)") - top_p: Optional[float] = Field(None, description="Top-p sampling parameter (0.0-1.0)") - stop: Optional[List[str]] = Field(None, description="Stop sequences (e.g., ['\\n', 'END'])") - stream: Optional[bool] = Field(False, description="Streaming mode. If true, returns Server-Sent Events (SSE) stream. If false or omitted, returns standard JSON response.") - idempotency_key: Optional[str] = Field(None, description="Idempotency key for request coalescing. Use same key for duplicate requests to avoid duplicate LLM calls.") - cache: Optional[int] = Field(None, description="Cache TTL in seconds (overrides config default). Cached responses return instantly without LLM call.") + + target: str = Field( + ..., + description="LLM target name from config.yaml (e.g., 'openai', 'anthropic')", + ) + messages: List[Dict[str, str]] = Field( + ..., + min_length=1, + description=( + "Messages list with 'role' and 'content' " + "(e.g., [{'role': 'user', 'content': 'Hello'}])" + ), + ) + model: Optional[str] = Field( + None, + description=( + "Model name (e.g., 'gpt-4o-mini', 'claude-3-haiku'). " + "Uses default from config if not specified." + ), + ) + max_tokens: Optional[int] = Field( + None, + ge=1, + description=( + "Maximum tokens in response (limited by config max_tokens and budget caps)" + ), + ) + temperature: Optional[float] = Field( + None, + ge=0.0, + le=2.0, + description="Temperature for sampling (0.0-2.0, limited by config)", + ) + top_p: Optional[float] = Field( + None, + ge=0.0, + le=1.0, + description="Top-p sampling parameter (0.0-1.0)", + ) + stop: Optional[List[str]] = Field( + None, description="Stop sequences (e.g., ['\\n', 'END'])" + ) + stream: bool = Field( + False, + description=( + "Streaming mode. If true, returns Server-Sent Events (SSE) stream. " + "If false or omitted, returns standard JSON response." + ), + ) + idempotency_key: Optional[str] = Field( + None, + description=( + "Idempotency key for request coalescing. " + "Use same key for duplicate requests to avoid duplicate LLM calls." + ), + ) + cache: Optional[int] = Field( + None, + ge=0, + description=( + "Cache TTL in seconds (overrides config default). " + "Cached responses return instantly without LLM call." + ), + ) + + +class TokenUsage(BaseModel): + """Token usage statistics for LLM responses.""" + + prompt_tokens: int = Field(..., ge=0, description="Number of tokens in the prompt") + completion_tokens: int = Field( + ..., ge=0, description="Number of tokens in the completion" + ) + total_tokens: int = Field(..., ge=0, description="Total tokens used") + estimated_cost_usd: Optional[float] = Field( + None, ge=0, description="Estimated cost in USD" + ) + + +class LLMResponseData(BaseModel): + """LLM response data structure.""" + + content: str = Field(..., description="Generated text content") + model: str = Field(..., description="Model used for generation") + usage: Optional[TokenUsage] = Field(None, description="Token usage statistics") + finish_reason: Optional[str] = Field( + None, description="Reason for completion (stop, length, etc.)" + ) class ErrorDetail(BaseModel): """Error detail in response.""" - + type: str = Field(..., description="Error type") code: str = Field(..., description="Error code") message: str = Field(..., description="Error message") retryable: bool = Field(..., description="Whether error is retryable") target: Optional[str] = Field(None, description="Target name if applicable") status_code: Optional[int] = Field(None, description="HTTP status code") - source: Optional[str] = Field(None, description="Error source: 'reliapi' or 'upstream'") - retry_after_s: Optional[float] = Field(None, description="Retry after seconds (for rate limit errors)") - provider_key_status: Optional[str] = Field(None, description="Provider key status if applicable") + source: Optional[str] = Field( + None, description="Error source: 'reliapi' or 'upstream'" + ) + retry_after_s: Optional[float] = Field( + None, ge=0, description="Retry after seconds (for rate limit errors)" + ) + provider_key_status: Optional[str] = Field( + None, description="Provider key status if applicable" + ) hint: Optional[str] = Field(None, description="Hint for debugging") - details: Optional[Dict[str, Any]] = Field(None, description="Additional error details") + details: Optional[Dict[str, Any]] = Field( + None, description="Additional error details" + ) class MetaResponse(BaseModel): """Metadata in response.""" - + target: Optional[str] = Field(None, description="Target name") provider: Optional[str] = Field(None, description="Provider name (for LLM)") model: Optional[str] = Field(None, description="Model name (for LLM)") cache_hit: bool = Field(False, description="Whether response was from cache") - idempotent_hit: bool = Field(False, description="Whether response was from idempotency cache") - retries: int = Field(0, description="Number of retries") - duration_ms: int = Field(..., description="Request duration in milliseconds") + idempotent_hit: bool = Field( + False, description="Whether response was from idempotency cache" + ) + retries: int = Field(0, ge=0, description="Number of retries") + duration_ms: int = Field(..., ge=0, description="Request duration in milliseconds") request_id: str = Field(..., description="Request ID") trace_id: Optional[str] = Field(None, description="Trace ID") - cost_usd: Optional[float] = Field(None, description="Actual cost in USD (for LLM)") - cost_estimate_usd: Optional[float] = Field(None, description="Estimated cost before request (for LLM)") - cost_policy_applied: Optional[str] = Field(None, description="Cost policy applied: none, soft_cap_throttled, hard_cap_rejected") - max_tokens_reduced: Optional[bool] = Field(None, description="Whether max_tokens was automatically reduced due to soft cost cap (for LLM)") - original_max_tokens: Optional[int] = Field(None, description="Original max_tokens before reduction (for LLM)") - fallback_used: Optional[bool] = Field(None, description="Whether fallback was used") - fallback_target: Optional[str] = Field(None, description="Fallback target name if used") + cost_usd: Optional[float] = Field( + None, ge=0, description="Actual cost in USD (for LLM)" + ) + cost_estimate_usd: Optional[float] = Field( + None, ge=0, description="Estimated cost before request (for LLM)" + ) + cost_policy_applied: Optional[str] = Field( + None, + description="Cost policy applied: none, soft_cap_throttled, hard_cap_rejected", + ) + max_tokens_reduced: Optional[bool] = Field( + None, + description=( + "Whether max_tokens was automatically reduced due to soft cost cap (for LLM)" + ), + ) + original_max_tokens: Optional[int] = Field( + None, description="Original max_tokens before reduction (for LLM)" + ) + fallback_used: Optional[bool] = Field( + None, description="Whether fallback was used" + ) + fallback_target: Optional[str] = Field( + None, description="Fallback target name if used" + ) # RouteLLM correlation fields - routellm_decision_id: Optional[str] = Field(None, description="RouteLLM routing decision ID for correlation") - routellm_route_name: Optional[str] = Field(None, description="RouteLLM route name that was applied") - routellm_provider_override: Optional[str] = Field(None, description="Provider override from RouteLLM (if any)") - routellm_model_override: Optional[str] = Field(None, description="Model override from RouteLLM (if any)") + routellm_decision_id: Optional[str] = Field( + None, description="RouteLLM routing decision ID for correlation" + ) + routellm_route_name: Optional[str] = Field( + None, description="RouteLLM route name that was applied" + ) + routellm_provider_override: Optional[str] = Field( + None, description="Provider override from RouteLLM (if any)" + ) + routellm_model_override: Optional[str] = Field( + None, description="Model override from RouteLLM (if any)" + ) class SuccessResponse(BaseModel): """Success response format.""" - - success: bool = Field(True, description="Success flag") + + success: Literal[True] = Field(True, description="Success flag") data: Dict[str, Any] = Field(..., description="Response data") meta: MetaResponse = Field(..., description="Response metadata") +class LLMSuccessResponse(BaseModel): + """LLM-specific success response format with typed data.""" + + success: Literal[True] = Field(True, description="Success flag") + data: LLMResponseData = Field(..., description="LLM response data") + meta: MetaResponse = Field(..., description="Response metadata") + + class ErrorResponse(BaseModel): """Error response format.""" - - success: bool = Field(False, description="Success flag") + + success: Literal[False] = Field(False, description="Success flag") error: ErrorDetail = Field(..., description="Error details") meta: MetaResponse = Field(..., description="Response metadata") + +# Type alias for proxy response (union of success and error) +ProxyResponse = Union[SuccessResponse, ErrorResponse] +LLMProxyResponse = Union[LLMSuccessResponse, ErrorResponse] diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..fd1c504 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,368 @@ +# ReliAPI Architecture + +This document describes the high-level architecture and design decisions of ReliAPI. + +## Overview + +ReliAPI is a reliability layer for HTTP and LLM API calls. It provides: + +- **Retries** with exponential backoff +- **Circuit breaker** to prevent cascading failures +- **Caching** for GET/HEAD requests and LLM responses +- **Idempotency** with request coalescing +- **Rate limiting** with abuse detection +- **Budget caps** for LLM cost control +- **Multi-tenancy** with RapidAPI integration + +## System Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Client Request │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ FastAPI Application │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ Health │ │ Proxy │ │ RapidAPI │ │ Business │ │ +│ │ Routes │ │ Routes │ │ Routes │ │ Routes │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌───────────────────────────────────────────────────────────────────┐ │ +│ │ Dependencies Layer │ │ +│ │ • verify_api_key() • detect_client_profile() │ │ +│ │ • AppState • Configuration validation │ │ +│ └───────────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Services Layer │ +│ ┌─────────────────────────────┐ ┌─────────────────────────────────┐ │ +│ │ handle_http_proxy │ │ handle_llm_proxy │ │ +│ │ • Target resolution │ │ • Provider selection │ │ +│ │ • Request building │ │ • Model routing │ │ +│ │ • Response processing │ │ • Cost estimation │ │ +│ └─────────────────────────────┘ └─────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Core Reliability Layer │ +│ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐│ +│ │ Cache │ │ Circuit │ │ Retry │ │ Rate │ │Idempotency││ +│ │ │ │ Breaker │ │ Engine │ │ Limiter │ │ Manager ││ +│ └───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘│ +│ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐│ +│ │ Key Pool │ │ Rate │ │ Client │ │ Cost │ │ Security ││ +│ │ Manager │ │ Scheduler │ │ Profiles │ │ Estimator │ │ Manager ││ +│ └───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘│ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Adapters Layer │ +│ ┌─────────────────────────────┐ ┌─────────────────────────────────┐ │ +│ │ LLM Adapters │ │ HTTP Client Adapter │ │ +│ │ • OpenAI │ │ • Universal HTTP client │ │ +│ │ • Anthropic │ │ • Connection pooling │ │ +│ │ • Mistral │ │ • Timeout handling │ │ +│ └─────────────────────────────┘ └─────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ External Services │ +│ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ │ +│ │ Redis │ │ OpenAI │ │ Anthropic │ │ Mistral │ │ Target ││ +│ │ (State) │ │ API │ │ API │ │ API │ │ APIs ││ +│ └───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘│ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +## Directory Structure + +``` +reliapi/ +├── app/ # FastAPI application +│ ├── main.py # Application entry point +│ ├── dependencies.py # Shared dependencies and state +│ ├── schemas.py # Pydantic request/response models +│ ├── services.py # Business logic +│ └── routes/ # Route handlers by domain +│ ├── health.py # Health check endpoints +│ ├── proxy.py # HTTP and LLM proxy endpoints +│ ├── rapidapi.py # RapidAPI integration +│ └── ... # Business routes +├── core/ # Core reliability components +│ ├── cache.py # Redis-based TTL cache +│ ├── circuit_breaker.py # Circuit breaker pattern +│ ├── retry.py # Retry engine with backoff +│ ├── rate_limiter.py # Multi-layer rate limiting +│ ├── rate_scheduler.py # Token bucket algorithm +│ ├── idempotency.py # Request coalescing +│ ├── key_pool.py # Provider key management +│ ├── cost_estimator.py # LLM cost calculation +│ ├── client_profile.py # Per-client constraints +│ ├── security.py # API key validation +│ └── errors.py # Error codes and normalization +├── adapters/ # Provider adapters +│ ├── llm/ # LLM provider implementations +│ │ ├── base.py # Abstract base class +│ │ ├── openai.py # OpenAI adapter +│ │ ├── anthropic.py # Anthropic adapter +│ │ └── mistral.py # Mistral adapter +│ └── http_generic/ # Generic HTTP adapter +├── config/ # Configuration +│ ├── loader.py # YAML config loader +│ └── schema.py # Config schema +├── integrations/ # External integrations +│ ├── rapidapi.py # RapidAPI tier detection +│ └── routellm.py # RouteLLM routing +├── metrics/ # Observability +│ └── prometheus.py # Prometheus metrics +└── tests/ # Test suite +``` + +## Core Components + +### 1. Cache (core/cache.py) + +Redis-backed TTL cache with multi-tenant isolation. + +**Key features:** +- GET/HEAD request caching +- Conditional POST caching (explicit opt-in) +- Tenant-isolated cache keys +- Graceful degradation if Redis unavailable + +**Cache key format:** +``` +{prefix}:tenant:{tenant}:cache:{hash} +``` + +### 2. Circuit Breaker (core/circuit_breaker.py) + +Prevents cascading failures using the circuit breaker pattern. + +**States:** +- **Closed**: Normal operation, requests pass through +- **Open**: Failures exceeded threshold, requests rejected +- **Half-Open**: Testing if service recovered + +**Configuration:** +```yaml +circuit_breaker: + failure_threshold: 5 + recovery_timeout: 30 +``` + +### 3. Retry Engine (core/retry.py) + +Exponential backoff retry with customizable policies. + +**Retry matrix by error type:** +```yaml +retry: + 429: {retries: 5, backoff: 2.0} # Rate limit + 5xx: {retries: 3, backoff: 1.5} # Server errors + timeout: {retries: 3, backoff: 2.0} + network: {retries: 2, backoff: 1.0} +``` + +### 4. Rate Limiter (core/rate_limiter.py) + +Multi-layer rate limiting with abuse detection. + +**Layers:** +1. IP-based rate limiting +2. Account burst limiting +3. Fingerprint-based identity +4. Anomaly detection +5. Auto-ban for repeated violations + +### 5. Idempotency Manager (core/idempotency.py) + +Request coalescing with Redis SETNX. + +**Features:** +- Duplicate request detection +- In-progress locking +- Result caching with TTL + +### 6. Key Pool Manager (core/key_pool.py) + +Multi-key rotation for LLM providers. + +**Features:** +- Round-robin key selection +- QPS limit per key +- Automatic failover on rate limits + +## Request Flow + +### HTTP Proxy Request + +``` +1. Request arrives at POST /proxy/http +2. verify_api_key() - Authenticate and resolve tenant/tier +3. Rate limiting checks (IP, burst, fingerprint) +4. Check idempotency cache +5. Check response cache +6. Check circuit breaker state +7. Execute request with retry logic +8. Store in cache if cacheable +9. Record metrics +10. Return response +``` + +### LLM Proxy Request + +``` +1. Request arrives at POST /proxy/llm +2. verify_api_key() - Authenticate and resolve tenant/tier +3. Free tier restriction checks (model, features) +4. Rate limiting and abuse detection +5. RouteLLM routing decision (if configured) +6. Check idempotency cache +7. Check response cache +8. Cost estimation and budget check +9. Select provider key from pool +10. Execute LLM request with retry +11. Calculate actual cost +12. Store in cache +13. Record metrics and usage +14. Return response +``` + +## Design Decisions + +### 1. Graceful Degradation + +All Redis-dependent features degrade gracefully: +```python +try: + self.client = redis.from_url(redis_url) + self.enabled = True +except Exception: + self.client = None + self.enabled = False +``` + +### 2. Adapter Pattern for LLM Providers + +Unified interface for multiple LLM providers: +```python +class BaseLLMAdapter(ABC): + @abstractmethod + async def complete(self, messages, model, **kwargs) -> LLMResponse: + pass +``` + +### 3. Factory Pattern for Provider Selection + +```python +def create_adapter(provider: str) -> BaseLLMAdapter: + adapters = { + "openai": OpenAIAdapter, + "anthropic": AnthropicAdapter, + "mistral": MistralAdapter, + } + return adapters[provider]() +``` + +### 4. AppState for Dependency Injection + +Centralized state management: +```python +@dataclass +class AppState: + config_loader: Optional[ConfigLoader] = None + cache: Optional[Cache] = None + rate_limiter: Optional[RateLimiter] = None + # ... other components +``` + +### 5. Configuration-Driven Behavior + +All behavior configurable via YAML: +```yaml +targets: + my_api: + base_url: https://api.example.com + cache: + ttl: 300 + retry: + max_attempts: 3 + circuit_breaker: + failure_threshold: 5 +``` + +## Observability + +### Prometheus Metrics + +- `reliapi_requests_total` - Total requests by target/status +- `reliapi_request_duration_seconds` - Request latency histogram +- `reliapi_cache_hits_total` - Cache hit/miss counts +- `reliapi_circuit_breaker_state` - Circuit breaker state gauge +- `reliapi_llm_cost_usd` - LLM cost histogram + +### Structured Logging + +JSON-formatted logs for easy aggregation: +```json +{ + "timestamp": "2025-01-15T10:30:00Z", + "level": "INFO", + "request_id": "req_abc123", + "target": "openai", + "duration_ms": 150, + "cache_hit": false +} +``` + +## Security Considerations + +1. **API Key Validation** - Format validation before use +2. **Rate Limiting** - Multi-layer protection against abuse +3. **Fingerprinting** - Detect account sharing/abuse +4. **Auto-ban** - Automatic blocking of repeat offenders +5. **CORS** - Configurable origin restrictions +6. **Non-root Docker** - Container runs as unprivileged user + +## Deployment + +### Docker + +```bash +docker build -t reliapi . +docker run -p 8000:8000 -e REDIS_URL=redis://redis:6379 reliapi +``` + +### Docker Compose + +```yaml +services: + reliapi: + build: . + ports: + - "8000:8000" + depends_on: + redis: + condition: service_healthy + redis: + image: redis:7-alpine + volumes: + - redis_data:/data +``` + +## Performance Considerations + +1. **Connection Pooling** - httpx maintains connection pools +2. **Async I/O** - All I/O operations are async +3. **Memory Management** - Rate scheduler cleanup task +4. **Redis Pipelining** - Batch operations where possible +5. **Lazy Loading** - Routes loaded on demand diff --git a/tests/test_routes_business.py b/tests/test_routes_business.py new file mode 100644 index 0000000..e75e8bf --- /dev/null +++ b/tests/test_routes_business.py @@ -0,0 +1,517 @@ +"""Tests for business routes (onboarding, analytics, health). + +This module tests the business route endpoints: +- Onboarding flow (start, quick-start, verify) +- Analytics tracking (track, conversion, funnel) +- Health check endpoints +""" +import json +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + + +@pytest.fixture +def mock_redis(): + """Create a mock Redis client.""" + mock = MagicMock() + mock.get.return_value = None + mock.setex.return_value = True + return mock + + +@pytest.fixture +def client(): + """Create a test client for the FastAPI app.""" + # Import here to avoid circular imports + from reliapi.app.main import app + return TestClient(app) + + +class TestOnboardingRoutes: + """Tests for onboarding endpoints.""" + + @patch("reliapi.app.routes.onboarding.redis") + def test_start_onboarding_success(self, mock_redis_module, mock_redis): + """Test successful onboarding start.""" + mock_redis_module.from_url.return_value = mock_redis + + from reliapi.app.routes.onboarding import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.post( + "/onboarding/start", + json={"email": "test@example.com", "plan": "free"}, + ) + + assert response.status_code == 200 + data = response.json() + + assert "api_key" in data + assert data["api_key"].startswith("reliapi_") + assert "quick_start_url" in data + assert "documentation_url" in data + assert "example_code" in data + assert "python" in data["example_code"] + assert "javascript" in data["example_code"] + assert "curl" in data["example_code"] + assert data["integration_status"] == "pending_verification" + + @patch("reliapi.app.routes.onboarding.redis") + def test_start_onboarding_pro_plan(self, mock_redis_module, mock_redis): + """Test onboarding with pro plan.""" + mock_redis_module.from_url.return_value = mock_redis + + from reliapi.app.routes.onboarding import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.post( + "/onboarding/start", + json={"email": "pro@example.com", "plan": "pro"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "api_key" in data + + def test_start_onboarding_invalid_email(self): + """Test onboarding with invalid email.""" + from reliapi.app.routes.onboarding import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.post( + "/onboarding/start", + json={"email": "not-an-email", "plan": "free"}, + ) + + assert response.status_code == 422 # Validation error + + def test_get_quick_start_guide(self): + """Test getting quick start guide.""" + from reliapi.app.routes.onboarding import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.get("/onboarding/quick-start") + + assert response.status_code == 200 + data = response.json() + + assert "steps" in data + assert len(data["steps"]) == 4 + assert "code_examples" in data + assert "test_endpoint" in data + + @patch("reliapi.app.routes.onboarding.redis") + def test_verify_integration_valid_key(self, mock_redis_module, mock_redis): + """Test verification with valid API key.""" + # Mock Redis to return user data + mock_redis.get.side_effect = lambda key: ( + json.dumps({"email": "test@example.com"}).encode() + if key.startswith("api_key:") + else b"5" # 5 requests made + ) + mock_redis_module.from_url.return_value = mock_redis + + from reliapi.app.routes.onboarding import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.post( + "/onboarding/verify", + headers={"X-API-Key": "reliapi_test123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "verified" + assert data["requests_made"] == 5 + + @patch("reliapi.app.routes.onboarding.redis") + def test_verify_integration_invalid_key(self, mock_redis_module, mock_redis): + """Test verification with invalid API key.""" + mock_redis.get.return_value = None + mock_redis_module.from_url.return_value = mock_redis + + from reliapi.app.routes.onboarding import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.post( + "/onboarding/verify", + headers={"X-API-Key": "invalid_key"}, + ) + + assert response.status_code == 401 + + +class TestAnalyticsRoutes: + """Tests for analytics endpoints.""" + + def test_track_event_basic(self): + """Test basic event tracking.""" + from reliapi.app.routes.analytics import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.post( + "/analytics/track", + json={ + "event_name": "page_view", + "user_id": "user123", + "properties": {"page": "/home"}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "tracked" + assert data["event"] == "page_view" + + def test_track_event_without_user_id(self): + """Test event tracking without user ID.""" + from reliapi.app.routes.analytics import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.post( + "/analytics/track", + json={ + "event_name": "anonymous_action", + "properties": {"action": "click"}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "tracked" + + def test_track_conversion(self): + """Test conversion event tracking.""" + from reliapi.app.routes.analytics import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.post( + "/analytics/conversion", + json={ + "event_type": "signup", + "user_id": "user456", + "properties": {"plan": "pro"}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "tracked" + assert "conversion_signup" in data["event"] + + def test_get_funnel_default_dates(self): + """Test getting funnel with default date range.""" + from reliapi.app.routes.analytics import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.get("/analytics/funnel") + + assert response.status_code == 200 + data = response.json() + + assert "period" in data + assert "funnel" in data + assert "conversion_rates" in data + assert "visitors" in data["funnel"] + assert "trial_signups" in data["funnel"] + assert "paid_conversions" in data["funnel"] + + def test_get_funnel_custom_dates(self): + """Test getting funnel with custom date range.""" + from reliapi.app.routes.analytics import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.get( + "/analytics/funnel", + params={ + "start_date": "2025-01-01T00:00:00", + "end_date": "2025-01-31T23:59:59", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "2025-01-01" in data["period"]["start"] + assert "2025-01-31" in data["period"]["end"] + + +class TestHealthRoutes: + """Tests for health check endpoints.""" + + def test_health_endpoint(self): + """Test /health endpoint.""" + from reliapi.app.routes.health import router + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert "version" in data + + def test_healthz_endpoint(self): + """Test /healthz endpoint.""" + from reliapi.app.routes.health import router + from fastapi import FastAPI + + # Need to mock the app_state for rate limiter + with patch("reliapi.app.routes.health.get_app_state") as mock_state: + mock_state.return_value = MagicMock(rate_limiter=None) + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.get("/healthz") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + def test_readyz_endpoint(self): + """Test /readyz endpoint.""" + from reliapi.app.routes.health import router + from fastapi import FastAPI + + with patch("reliapi.app.routes.health.get_app_state") as mock_state: + mock_state.return_value = MagicMock(rate_limiter=None) + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.get("/readyz") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ready" + + def test_livez_endpoint(self): + """Test /livez endpoint.""" + from reliapi.app.routes.health import router + from fastapi import FastAPI + + with patch("reliapi.app.routes.health.get_app_state") as mock_state: + mock_state.return_value = MagicMock(rate_limiter=None) + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.get("/livez") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "alive" + + def test_metrics_endpoint(self): + """Test /metrics endpoint returns Prometheus format.""" + from reliapi.app.routes.health import router + from fastapi import FastAPI + + with patch("reliapi.app.routes.health.get_app_state") as mock_state: + mock_state.return_value = MagicMock(rate_limiter=None) + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + response = client.get("/metrics") + + assert response.status_code == 200 + # Prometheus metrics are text-based + assert "text/plain" in response.headers["content-type"] or \ + "text/plain" in response.headers.get("content-type", "") + + +class TestProxyRoutes: + """Tests for proxy route helpers.""" + + def test_check_api_key_format_valid(self): + """Test API key format validation with valid key.""" + from reliapi.core.security import SecurityManager + + is_valid, error = SecurityManager.validate_api_key_format("sk-valid-key-123") + assert is_valid is True + assert error is None + + def test_check_api_key_format_empty(self): + """Test API key format validation with empty key.""" + from reliapi.core.security import SecurityManager + + is_valid, error = SecurityManager.validate_api_key_format("") + # Empty keys might be valid depending on implementation + # Just verify the method works + assert isinstance(is_valid, bool) + + +class TestSchemaValidation: + """Tests for Pydantic schema validation.""" + + def test_http_proxy_request_valid(self): + """Test valid HTTP proxy request.""" + from reliapi.app.schemas import HTTPProxyRequest + + request = HTTPProxyRequest( + target="my_api", + method="GET", + path="/users", + ) + + assert request.target == "my_api" + assert request.method == "GET" + assert request.path == "/users" + + def test_http_proxy_request_method_uppercase(self): + """Test HTTP method is uppercased.""" + from reliapi.app.schemas import HTTPProxyRequest + + request = HTTPProxyRequest( + target="my_api", + method="get", # lowercase + path="/users", + ) + + assert request.method == "GET" + + def test_http_proxy_request_invalid_method(self): + """Test invalid HTTP method raises error.""" + from reliapi.app.schemas import HTTPProxyRequest + + with pytest.raises(ValueError): + HTTPProxyRequest( + target="my_api", + method="INVALID", + path="/users", + ) + + def test_llm_proxy_request_valid(self): + """Test valid LLM proxy request.""" + from reliapi.app.schemas import LLMProxyRequest + + request = LLMProxyRequest( + target="openai", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert request.target == "openai" + assert len(request.messages) == 1 + assert request.stream is False + + def test_llm_proxy_request_with_options(self): + """Test LLM proxy request with all options.""" + from reliapi.app.schemas import LLMProxyRequest + + request = LLMProxyRequest( + target="openai", + messages=[{"role": "user", "content": "Hello"}], + model="gpt-4o-mini", + max_tokens=100, + temperature=0.7, + top_p=0.9, + stream=True, + ) + + assert request.model == "gpt-4o-mini" + assert request.max_tokens == 100 + assert request.temperature == 0.7 + assert request.top_p == 0.9 + assert request.stream is True + + def test_llm_proxy_request_temperature_bounds(self): + """Test temperature must be within bounds.""" + from reliapi.app.schemas import LLMProxyRequest + + # Valid temperature + request = LLMProxyRequest( + target="openai", + messages=[{"role": "user", "content": "Hello"}], + temperature=1.5, + ) + assert request.temperature == 1.5 + + # Invalid temperature (too high) + with pytest.raises(ValueError): + LLMProxyRequest( + target="openai", + messages=[{"role": "user", "content": "Hello"}], + temperature=3.0, + ) + + def test_error_response_model(self): + """Test error response model.""" + from reliapi.app.schemas import ErrorDetail, ErrorResponse, MetaResponse + + error = ErrorDetail( + type="rate_limit_error", + code="RATE_LIMIT_EXCEEDED", + message="Too many requests", + retryable=True, + status_code=429, + ) + + meta = MetaResponse( + duration_ms=10, + request_id="req_123", + ) + + response = ErrorResponse( + error=error, + meta=meta, + ) + + assert response.success is False + assert response.error.code == "RATE_LIMIT_EXCEEDED" + assert response.meta.request_id == "req_123"