diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..90df308 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,77 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + name: Format & Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install linters + run: pip install black isort flake8 + + - name: black + run: black --check app/ tests/ + + - name: isort + run: isort --check-only app/ tests/ + + - name: flake8 + run: flake8 app/ tests/ --max-line-length=88 --extend-ignore=E203,W503 + + typecheck: + name: Type Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install package with dev extras + run: pip install -e ".[dev]" + + - name: mypy + run: mypy app/ --ignore-missing-imports + + test: + name: Tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install package with dev extras + run: pip install -e ".[dev]" + + - name: pytest + run: pytest tests/ -v --tb=short + + secret-scan: + name: Secret Scan + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: gitleaks/gitleaks-action@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..40464cf --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,30 @@ +name: Publish to PyPI + +on: + push: + tags: + - "v*" + +jobs: + publish: + name: Build & Publish + runs-on: ubuntu-latest + environment: pypi + permissions: + id-token: write # OIDC trusted publishing + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + 
with: + python-version: "3.11" + + - name: Install build tools + run: pip install build + + - name: Build distribution + run: python -m build + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/app/budget.py b/app/budget.py index 4e24ed6..9537146 100644 --- a/app/budget.py +++ b/app/budget.py @@ -30,15 +30,32 @@ def estimate_llm_cost(model: str, input_tokens: int = 0, output_tokens: int = 0) return input_cost + output_cost +def _estimate_tokens(text: str) -> int: + """Estimate token count without an external tokeniser. + + Uses two complementary heuristics and takes the higher value to avoid + under-counting (safer for budget enforcement): + + 1. Character-based — 1 token ≈ 4 chars (works well for dense prose) + 2. Word-based — 1 word ≈ 1.3 tokens (handles short words & punctuation + that the character rule under-counts) + + The result is clamped to a minimum of 1 so zero-length strings don't + produce a zero cost estimate silently. + """ + char_estimate = len(text) / 4 + word_estimate = len(text.split()) * 1.3 + return max(1, int(max(char_estimate, word_estimate))) + + def estimate_request_cost(raw_text: str, model: str = "gpt-4") -> float: - """Estimate cost for a single request based on text length""" - # Rough estimation: 1 token ≈ 4 characters for English text - estimated_tokens = len(raw_text) // 4 - - # Estimate 50% input, 50% output for typical requests + """Estimate cost for a single request based on text length.""" + estimated_tokens = _estimate_tokens(raw_text) + + # Assume 50 % input / 50 % output split for a typical chat request input_tokens = int(estimated_tokens * 0.5) output_tokens = int(estimated_tokens * 0.5) - + return estimate_llm_cost(model, input_tokens, output_tokens) def get_purchase_amount(tool_config: Dict[str, Any]) -> Optional[float]: @@ -56,7 +73,16 @@ def get_purchase_amount(tool_config: Dict[str, Any]) -> Optional[float]: return None def get_user_budget(user_id: str, db: Session) -> Budget: - """Get 
or create budget for user""" + """Get or create budget for user. + + .. deprecated:: + Budget state is now owned by the Console (governsai-console). + Use ``check_budget_with_context`` with a ``budget_context`` payload + sourced from the Console's ``/api/v1/budget/context`` endpoint instead. + This function operates against Precheck's local Budget table which may + disagree with Console; it exists only for backwards-compatibility with + standalone deployments and will be removed in a future release. + """ budget = db.query(Budget).filter(Budget.user_id == user_id).first() if not budget: @@ -145,12 +171,19 @@ def check_budget_with_context( return budget_status, budget_info def check_budget( - user_id: str, - estimated_llm_cost: float, + user_id: str, + estimated_llm_cost: float, estimated_purchase: Optional[float] = None, db: Optional[Session] = None ) -> Tuple[BudgetStatus, BudgetInfo]: - """Check if request is within budget limits""" + """Check if request is within budget limits (local-DB path). + + .. deprecated:: + Prefer ``check_budget_with_context`` which uses budget state supplied + by the Console. This function reads from Precheck's local Budget table + and can produce results that disagree with Console when both services + are deployed together. It will be removed in a future release. 
import logging
import threading
import time
from collections import deque
from typing import Deque, Dict, Optional

from .settings import settings

logger = logging.getLogger(__name__)

try:
    import redis
except ImportError:  # pragma: no cover - exercised in environments without redis package
    # Narrowed from a bare `except Exception`: per the original comment, only
    # a missing package is the expected failure here; anything else should
    # surface loudly at import time instead of being masked.
    redis = None


class RateLimiter:
    """Redis-first sliding-window rate limiter with in-memory fallback.

    When a Redis client is available, each key's request history lives in a
    Redis sorted set (one member per attempt). When Redis is unreachable or
    the package is not installed, a per-process, lock-guarded deque of
    timestamps is used instead. NOTE: the fallback enforces limits per
    worker process, not globally across a fleet.
    """

    def __init__(self, redis_url: Optional[str] = None):
        """Create a limiter, connecting to Redis when *redis_url* is usable."""
        self.redis_client = None
        # In-memory fallback state: one deque of event timestamps per key,
        # plus last-seen times so idle keys can be garbage-collected.
        self._local_lock = threading.Lock()
        self._local_windows: Dict[str, Deque[float]] = {}
        self._local_last_seen: Dict[str, float] = {}
        self._local_idle_ttl = 3600.0    # drop keys idle for an hour
        self._cleanup_interval = 60.0    # scan for idle keys at most once a minute
        self._last_cleanup = 0.0

        if redis_url and redis is not None:
            try:
                self.redis_client = redis.from_url(redis_url)
                # Test connection (hidden between diff hunks in the original;
                # ping() is the standard probe -- TODO confirm against repo)
                self.redis_client.ping()
            except Exception as e:
                logger.warning("Failed to connect to Redis: %s", type(e).__name__)
                self.redis_client = None
        elif redis_url and redis is None:
            logger.warning("redis package not installed; using in-memory rate limiter")

    def is_allowed(self, key: str, limit: int, window: int) -> bool:
        """
        Check if request is allowed using a sliding window counter.

        Args:
            key: Unique identifier for the rate limit (e.g., user_id)
            limit: Maximum number of requests allowed
            window: Time window in seconds

        Returns:
            True if request is allowed, False otherwise
        """
        # Degenerate configuration: nothing can ever be allowed.
        if limit <= 0 or window <= 0:
            return False

        if self.redis_client:
            try:
                return self._is_allowed_redis(key=key, limit=limit, window=window)
            except Exception as e:
                logger.warning(
                    "Redis rate limiter unavailable; falling back to in-memory limiter: %s",
                    type(e).__name__,
                )

        return self._is_allowed_local(key=key, limit=limit, window=window)

    def _is_allowed_redis(self, key: str, limit: int, window: int) -> bool:
        """Sliding-window check backed by a Redis sorted set (one member per attempt)."""
        current_time = time.time()
        window_start = current_time - window
        # time_ns suffix keeps members unique when two requests land on the
        # same float timestamp (a plain str(time) would collapse them).
        member = f"{current_time}:{time.time_ns()}"

        # Use Redis pipeline for atomic operations.
        pipe = self.redis_client.pipeline()
        pipe.zremrangebyscore(key, 0, window_start)   # evict events outside the window
        pipe.zcard(key)                               # count what remains
        pipe.zadd(key, {member: current_time})        # record this attempt, even if denied
        pipe.expire(key, max(1, int(window)))         # let fully idle keys expire

        results = pipe.execute()
        current_count = int(results[1])
        # Count excludes the member just added, so "< limit" admits exactly
        # `limit` requests per window.
        return current_count < limit

    def _is_allowed_local(self, key: str, limit: int, window: int) -> bool:
        """In-memory sliding-window check mirroring the Redis path's semantics."""
        current_time = time.time()
        window_start = current_time - window

        with self._local_lock:
            self._cleanup_local_state(current_time)
            events = self._local_windows.setdefault(key, deque())

            # Evict events that have fallen out of the window.
            while events and events[0] <= window_start:
                events.popleft()

            self._local_last_seen[key] = current_time

            # Record the attempt unconditionally so denied requests keep the
            # window "hot", matching the Redis path (which zadd's before
            # deciding). Previously only allowed requests were recorded, so
            # limits silently loosened whenever Redis was down.
            allowed = len(events) < limit
            events.append(current_time)
            return allowed

    def _cleanup_local_state(self, current_time: float) -> None:
        """Drop per-key state idle past ``_local_idle_ttl``; caller holds the lock."""
        # Rate-limit the cleanup itself so it cannot dominate hot paths.
        if current_time - self._last_cleanup < self._cleanup_interval:
            return

        expired_keys = [
            key
            for key, last_seen in self._local_last_seen.items()
            if current_time - last_seen > self._local_idle_ttl
        ]
        for expired_key in expired_keys:
            self._local_last_seen.pop(expired_key, None)
            self._local_windows.pop(expired_key, None)

        self._last_cleanup = current_time


# Global rate limiter instance
rate_limiter = RateLimiter(settings.redis_url)
from app.rate_limit import RateLimiter


class FailingPipeline:
    """Pipeline stub whose execute() raises, simulating Redis dying mid-request."""

    def zremrangebyscore(self, *_args, **_kwargs):
        return self

    def zcard(self, *_args, **_kwargs):
        return self

    def zadd(self, *_args, **_kwargs):
        return self

    def expire(self, *_args, **_kwargs):
        return self

    def execute(self):
        raise RuntimeError("redis unavailable")


class FailingRedis:
    """Redis client stub that hands out always-failing pipelines."""

    def pipeline(self):
        return FailingPipeline()


def test_in_memory_fallback_enforces_limit_without_redis():
    # With no Redis URL the limiter should use its local window immediately.
    rl = RateLimiter(redis_url=None)

    outcomes = [rl.is_allowed("user-a", limit=2, window=60) for _ in range(3)]
    assert outcomes == [True, True, False]


def test_in_memory_fallback_enforces_limit_when_redis_errors():
    # A broken Redis client must degrade to the local limiter, not fail open.
    rl = RateLimiter(redis_url=None)
    rl.redis_client = FailingRedis()

    assert rl.is_allowed("user-b", limit=1, window=60) is True
    assert rl.is_allowed("user-b", limit=1, window=60) is False


def test_in_memory_fallback_resets_after_window(monkeypatch):
    rl = RateLimiter(redis_url=None)
    clock = [1000.0]

    # Freeze module time so the window boundary is deterministic.
    monkeypatch.setattr("app.rate_limit.time.time", lambda: clock[0])

    assert rl.is_allowed("user-c", limit=1, window=10) is True
    assert rl.is_allowed("user-c", limit=1, window=10) is False

    # Advance past the 10 s window: the old event must have expired.
    clock[0] = 1011.0
    assert rl.is_allowed("user-c", limit=1, window=10) is True