diff --git a/src/apprentice/apprentice_class.py b/src/apprentice/apprentice_class.py index a0b8f05..05a30c5 100644 --- a/src/apprentice/apprentice_class.py +++ b/src/apprentice/apprentice_class.py @@ -433,6 +433,36 @@ async def close(self): if self._running: await self.__aexit__(None, None, None) + async def feedback(self, request_id: str, feedback_type: str, **kwargs) -> None: + """Record feedback for a previous request. No-op if feedback collector not configured.""" + collector = getattr(self, '_feedback_collector', None) + if collector is None: + return + from apprentice.feedback_collector import FeedbackEntry, FeedbackType + entry = FeedbackEntry( + request_id=request_id, + task_name=kwargs.get('task_name', ''), + feedback_type=FeedbackType(feedback_type), + score=kwargs.get('score', 0.0), + edited_output=kwargs.get('edited_output'), + reason=kwargs.get('reason'), + ) + collector.record_feedback(entry) + + async def observe(self, event) -> None: + """Record an observation event. No-op if observer not configured.""" + observer = getattr(self, '_observer', None) + if observer is None: + return + observer.observe(event) + + async def record_action(self, event_id: str, action: dict) -> None: + """Record an actual action for a previously observed event.""" + observer = getattr(self, '_observer', None) + if observer is None: + return + observer.record_action(event_id, action) + async def run(self, task_name: str, input_data: Dict[str, Any]) -> TaskResponse: """ Primary public method. 
Executes a task: routes to appropriate model backend, diff --git a/src/apprentice/config_loader.py b/src/apprentice/config_loader.py index ac530a8..82df7c4 100644 --- a/src/apprentice/config_loader.py +++ b/src/apprentice/config_loader.py @@ -332,6 +332,39 @@ class TrainingDataStoreConfig(BaseModel): max_examples_per_task: int = Field(default=50000, ge=100, le=10000000) +class PluginEntryConfig(BaseModel): + """Configuration for a single plugin entry.""" + model_config = ConfigDict(frozen=True, strict=True, extra="forbid") + + factory: str = Field(min_length=1) + + +class MiddlewareEntryConfig(BaseModel): + """Configuration for a single middleware entry in the pipeline.""" + model_config = ConfigDict(frozen=True, strict=True, extra="allow") + + name: str = Field(min_length=1) + config: Mapping[str, Any] = Field(default_factory=dict) + + +class FeedbackConfig(BaseModel): + """Configuration for the feedback collector.""" + model_config = ConfigDict(frozen=True, strict=True, extra="forbid") + + enabled: bool = False + storage_dir: str = ".apprentice/feedback/" + + +class ObserverConfig(BaseModel): + """Configuration for the observer.""" + model_config = ConfigDict(frozen=True, strict=True, extra="forbid") + + enabled: bool = False + context_window_size: int = Field(default=50, ge=1, le=1000) + shadow_recommendation_rate: float = Field(default=0.1, ge=0.0, le=1.0) + min_context_before_recommending: int = Field(default=10, ge=1) + + class ApprenticeConfig(BaseModel): """Root configuration model. 
Frozen and immutable after construction.""" model_config = ConfigDict(frozen=True, strict=True, extra="forbid") @@ -344,6 +377,13 @@ class ApprenticeConfig(BaseModel): audit: AuditConfig training_data: TrainingDataStoreConfig + # New extensibility fields (all optional, backward compatible) + mode: str = Field(default="distillation", pattern=r"^(distillation|copilot|observer)$") + plugins: Optional[Mapping[str, Mapping[str, PluginEntryConfig]]] = None + middleware: Optional[List[MiddlewareEntryConfig]] = None + feedback: Optional[FeedbackConfig] = None + observer: Optional[ObserverConfig] = None + @model_validator(mode="after") def validate_cross_field_constraints(self) -> "ApprenticeConfig": """Performs all cross-field validations.""" diff --git a/src/apprentice/factory.py b/src/apprentice/factory.py index bfc7bd2..f8cd519 100644 --- a/src/apprentice/factory.py +++ b/src/apprentice/factory.py @@ -278,6 +278,7 @@ async def build_from_config(config_path: str) -> Any: TaskConfig as ACTaskConfig, ConfidenceThresholds as ACThresholds, ) + from apprentice.plugin_registry import PluginRegistrySet from apprentice.audit_log import AuditConfig, JsonLinesAuditLogger from apprentice.budget_manager import BudgetConfig, BudgetManager, PeriodLimit, PeriodType from apprentice.confidence_engine import ConfidenceEngine, ConfidenceEngineConfig @@ -306,6 +307,14 @@ async def build_from_config(config_path: str) -> Any: # 1. Load validated config cfg = config_loader.load_config(Path(config_path)) + # 1.5. Construct Plugin Registry + plugin_registry_set = PluginRegistrySet.with_defaults() + if cfg.plugins: + plugin_registry_set.register_from_config( + {domain: {name: {"factory": entry.factory} for name, entry in plugins.items()} + for domain, plugins in cfg.plugins.items()} + ) + # 2. 
Create directories base_dir = Path(".apprentice") base_dir.mkdir(parents=True, exist_ok=True) @@ -551,5 +560,41 @@ async def build_from_config(config_path: str) -> Any: apprentice._ft_version_store = ft_version_store apprentice._model_validator = model_validator apprentice._real_config = cfg + apprentice._plugin_registry_set = plugin_registry_set + + # Construct middleware pipeline if configured + if cfg.middleware: + from apprentice.middleware import MiddlewarePipeline + middleware_registry = plugin_registry_set.get_registry("middleware") + apprentice._middleware_pipeline = MiddlewarePipeline.from_config( + [m.model_dump() for m in cfg.middleware], middleware_registry, + ) + else: + apprentice._middleware_pipeline = None + + # Construct feedback collector if configured + if cfg.feedback and cfg.feedback.enabled: + from apprentice.feedback_collector import FeedbackCollector + apprentice._feedback_collector = FeedbackCollector( + storage_dir=cfg.feedback.storage_dir, + enabled=True, + ) + else: + apprentice._feedback_collector = None + + # Construct observer if configured + if cfg.observer and cfg.observer.enabled: + from apprentice.observer import Observer, ObserverConfig as ObsCfg + apprentice._observer = Observer(ObsCfg( + enabled=True, + context_window_size=cfg.observer.context_window_size, + shadow_recommendation_rate=cfg.observer.shadow_recommendation_rate, + min_context_before_recommending=cfg.observer.min_context_before_recommending, + )) + else: + apprentice._observer = None + + # Store mode + apprentice._mode = cfg.mode return apprentice diff --git a/src/apprentice/feedback_collector.py b/src/apprentice/feedback_collector.py new file mode 100644 index 0000000..d855add --- /dev/null +++ b/src/apprentice/feedback_collector.py @@ -0,0 +1,253 @@ +""" +Feedback Collector (feedback_collector) v1 + +Human-in-the-loop and AI-scored feedback collection for the Apprentice system. 
+Records accept/reject/edit/ignore/ai_score feedback per task as append-only +JSON-lines files, computes summaries and confidence adjustments from feedback history. +""" + +import json +import logging +import uuid +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field + +_PACT_KEY = "PACT:feedback_collector" +logger = logging.getLogger(__name__) + + +def _log(level: str, msg: str, **kwargs) -> None: + """Log with PACT key embedded for production traceability.""" + getattr(logger, level)(f"[{_PACT_KEY}] {msg}", **kwargs) + + +# ============================================================================ +# ENUMS +# ============================================================================ + +class FeedbackType(str, Enum): + """StrEnum representing the kind of feedback provided.""" + accept = "accept" + reject = "reject" + edit = "edit" + ignore = "ignore" + ai_score = "ai_score" + + +# ============================================================================ +# DATA MODELS +# ============================================================================ + +class FeedbackEntry(BaseModel): + """Frozen Pydantic v2 model representing a single feedback record.""" + model_config = ConfigDict(frozen=True) + + feedback_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + request_id: str + task_name: str + feedback_type: FeedbackType + score: float = Field(default=0.0, ge=0.0, le=1.0) + edited_output: Optional[dict] = None + reason: Optional[str] = None + timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + +class FeedbackSummary(BaseModel): + """Frozen Pydantic v2 model summarizing all feedback for a single task.""" + model_config = ConfigDict(frozen=True) + + task_name: str + accept_count: int = 0 + reject_count: int = 0 + edit_count: int = 0 + ignore_count: int = 0 + ai_score_count: int = 0 + total_count: int 
= 0 + acceptance_rate: float = 0.0 + average_ai_score: float = 0.0 + + +class FeedbackConfig(BaseModel): + """Frozen Pydantic v2 configuration model for the FeedbackCollector.""" + model_config = ConfigDict(frozen=True) + + enabled: bool = False + storage_dir: str = ".apprentice/feedback/" + + +# ============================================================================ +# FEEDBACK COLLECTOR +# ============================================================================ + +class FeedbackCollector: + """ + Collects, persists, and summarizes human and AI feedback per task. + + Feedback is stored as append-only JSON-lines files, one file per task, + inside a configurable storage directory. When disabled, all write + operations are no-ops and read operations return empty/default values. + """ + + def __init__( + self, + storage_dir: str = ".apprentice/feedback/", + enabled: bool = False, + ) -> None: + self._storage_dir = Path(storage_dir) + self._enabled = enabled + _log("info", f"FeedbackCollector initialized (enabled={enabled}, dir={storage_dir})") + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _task_file(self, task_name: str) -> Path: + """Return the JSON-lines file path for a given task.""" + return self._storage_dir / f"{task_name}.jsonl" + + def _ensure_dir(self) -> None: + """Create the storage directory tree if it does not already exist.""" + self._storage_dir.mkdir(parents=True, exist_ok=True) + + def _read_entries(self, task_name: str) -> list[FeedbackEntry]: + """Read all FeedbackEntry records from a task's JSON-lines file.""" + path = self._task_file(task_name) + if not path.exists(): + return [] + + entries: list[FeedbackEntry] = [] + with open(path, "r", encoding="utf-8") as fh: + for line in fh: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + entries.append(FeedbackEntry(**data)) + except 
(json.JSONDecodeError, Exception) as exc: + _log("warning", f"Skipping malformed line in {path}: {exc}") + return entries + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def record_feedback(self, entry: FeedbackEntry) -> None: + """ + Persist a single FeedbackEntry by appending it to the task's + JSON-lines file. No-op when the collector is disabled. + """ + if not self._enabled: + return + + self._ensure_dir() + path = self._task_file(entry.task_name) + with open(path, "a", encoding="utf-8") as fh: + fh.write(entry.model_dump_json() + "\n") + + _log("info", f"Recorded {entry.feedback_type.value} feedback for task '{entry.task_name}'") + + def get_feedback_summary(self, task_name: str) -> FeedbackSummary: + """ + Read all feedback for *task_name* and compute an aggregate summary. + Returns an empty summary when disabled or when no feedback exists. + """ + if not self._enabled: + return FeedbackSummary(task_name=task_name) + + entries = self._read_entries(task_name) + if not entries: + return FeedbackSummary(task_name=task_name) + + accept_count = sum(1 for e in entries if e.feedback_type == FeedbackType.accept) + reject_count = sum(1 for e in entries if e.feedback_type == FeedbackType.reject) + edit_count = sum(1 for e in entries if e.feedback_type == FeedbackType.edit) + ignore_count = sum(1 for e in entries if e.feedback_type == FeedbackType.ignore) + ai_score_count = sum(1 for e in entries if e.feedback_type == FeedbackType.ai_score) + total_count = len(entries) + + # Acceptance rate: accepts / (accepts + rejects), 0.0 if denominator is zero + accept_reject_total = accept_count + reject_count + acceptance_rate = (accept_count / accept_reject_total) if accept_reject_total > 0 else 0.0 + + # Average AI score: mean of score field across ai_score entries + ai_entries = [e for e in entries if e.feedback_type == FeedbackType.ai_score] + average_ai_score = 
( + sum(e.score for e in ai_entries) / len(ai_entries) + if ai_entries + else 0.0 + ) + + return FeedbackSummary( + task_name=task_name, + accept_count=accept_count, + reject_count=reject_count, + edit_count=edit_count, + ignore_count=ignore_count, + ai_score_count=ai_score_count, + total_count=total_count, + acceptance_rate=acceptance_rate, + average_ai_score=average_ai_score, + ) + + def compute_feedback_adjustment(self, task_name: str) -> float: + """ + Derive a confidence adjustment in the range [-0.1, +0.1] from the + accept/reject ratio for *task_name*. + + * Pure accepts -> +0.1 + * Pure rejects -> -0.1 + * No data or disabled -> 0.0 + + The value scales linearly between -0.1 and +0.1 based on + (accepts - rejects) / (accepts + rejects). + """ + if not self._enabled: + return 0.0 + + summary = self.get_feedback_summary(task_name) + accept_reject_total = summary.accept_count + summary.reject_count + if accept_reject_total == 0: + return 0.0 + + # Ratio in [-1.0, +1.0] + ratio = (summary.accept_count - summary.reject_count) / accept_reject_total + # Scale to [-0.1, +0.1] + adjustment = ratio * 0.1 + # Clamp for safety + return max(-0.1, min(0.1, adjustment)) + + def list_tasks(self) -> list[str]: + """ + Return the names of all tasks that have at least one feedback entry + on disk. Returns an empty list when disabled or when the storage + directory does not exist. 
+ """ + if not self._enabled: + return [] + + if not self._storage_dir.exists(): + return [] + + tasks: list[str] = [] + for path in sorted(self._storage_dir.iterdir()): + if path.suffix == ".jsonl" and path.stat().st_size > 0: + tasks.append(path.stem) + return tasks + + +# ============================================================================ +# EXPORTS +# ============================================================================ + +__all__ = [ + "FeedbackType", + "FeedbackEntry", + "FeedbackSummary", + "FeedbackConfig", + "FeedbackCollector", +] diff --git a/src/apprentice/middleware.py b/src/apprentice/middleware.py new file mode 100644 index 0000000..9c87cbd --- /dev/null +++ b/src/apprentice/middleware.py @@ -0,0 +1,229 @@ +""" +Middleware Pipeline — pre/post processing hooks for the Apprentice request lifecycle. + +Provides a Protocol-based middleware abstraction with an onion-model pipeline: + - pre_process runs in registration order (first-registered runs first) + - post_process runs in reverse order (first-registered runs last — onion/LIFO) + +Error in any single middleware is logged and skipped; it never prevents +other middleware from executing. + +All models use frozen=True. State flows between pre and post phases via +the middleware_state dict on MiddlewareContext and MiddlewareResponse. +New instances are created (via model_copy) at each pipeline step. +""" + +import logging +from typing import Any, Protocol, runtime_checkable + +from pydantic import BaseModel, ConfigDict, Field + + +logger = logging.getLogger(__name__) + + +# =========================================================================== +# Data Models +# =========================================================================== + + +class MiddlewareContext(BaseModel): + """Frozen Pydantic model carrying request data through the middleware pipeline. + + Attributes: + request_id: Unique identifier for this request. + task_name: Name of the task being processed. 
+ input_data: The raw input payload. + metadata: Arbitrary metadata associated with the request. + middleware_state: Opaque dict for passing data from pre_process to post_process. + Each middleware may add keys; accumulated across the pipeline via model_copy. + """ + model_config = ConfigDict(frozen=True) + + request_id: str + task_name: str + input_data: dict + metadata: dict = Field(default_factory=dict) + middleware_state: dict = Field(default_factory=dict) + + +class MiddlewareResponse(BaseModel): + """Frozen Pydantic model carrying response data through the middleware pipeline. + + Attributes: + output_data: The output payload. + metadata: Arbitrary metadata associated with the response. + middleware_state: Opaque dict for passing data through the post-processing phase. + """ + model_config = ConfigDict(frozen=True) + + output_data: dict + metadata: dict = Field(default_factory=dict) + middleware_state: dict = Field(default_factory=dict) + + +# =========================================================================== +# Errors +# =========================================================================== + + +class MiddlewareError(Exception): + """Base error for middleware failures. + + Attributes: + middleware_name: Name/class of the middleware that failed. + phase: Which phase failed — ``"pre_process"`` or ``"post_process"``. + reason: Human-readable description of the failure. 
+ """ + + def __init__(self, middleware_name: str, phase: str, reason: str) -> None: + self.middleware_name = middleware_name + self.phase = phase + self.reason = reason + super().__init__( + f"MiddlewareError in '{middleware_name}' during {phase}: {reason}" + ) + + +# =========================================================================== +# Protocol +# =========================================================================== + + +@runtime_checkable +class Middleware(Protocol): + """Runtime-checkable protocol that all middleware implementations must satisfy.""" + + def pre_process(self, context: MiddlewareContext) -> MiddlewareContext: + """Transform the context before the core task executes. + + Must return a (potentially new) MiddlewareContext instance. + """ + ... + + def post_process( + self, context: MiddlewareContext, response: MiddlewareResponse + ) -> MiddlewareResponse: + """Transform the response after the core task executes. + + Must return a (potentially new) MiddlewareResponse instance. + """ + ... + + +# =========================================================================== +# Pipeline +# =========================================================================== + + +class MiddlewarePipeline: + """Ordered pipeline of Middleware instances. + + - ``execute_pre`` runs each middleware's ``pre_process`` in registration order. + - ``execute_post`` runs each middleware's ``post_process`` in **reverse** order + (onion/LIFO model). + - An error in any single middleware is logged and skipped — it does **not** + prevent the remaining middleware from running. + - An empty pipeline is a safe no-op passthrough. 
+ """ + + def __init__(self, middlewares: list[Middleware] | None = None) -> None: + self._middlewares: list[Middleware] = list(middlewares) if middlewares else [] + + @property + def middlewares(self) -> list[Middleware]: + """Return a shallow copy of the registered middleware list.""" + return list(self._middlewares) + + # ---- Pre-processing (forward order) ------------------------------------ + + def execute_pre(self, context: MiddlewareContext) -> MiddlewareContext: + """Run ``pre_process`` on each middleware in registration order. + + Middleware state is accumulated: the ``middleware_state`` dict from each + step is merged into the next context via ``model_copy``. + + If a middleware raises, the error is logged and that middleware is + skipped; processing continues with the remaining middleware. + """ + current = context + for mw in self._middlewares: + mw_name = type(mw).__name__ + try: + result = mw.pre_process(current) + # Merge middleware_state from result into the running state + merged_state = {**current.middleware_state, **result.middleware_state} + current = result.model_copy(update={"middleware_state": merged_state}) + except Exception as exc: + logger.error( + "Middleware '%s' failed during pre_process: %s", + mw_name, + exc, + ) + return current + + # ---- Post-processing (reverse order) ----------------------------------- + + def execute_post( + self, context: MiddlewareContext, response: MiddlewareResponse + ) -> MiddlewareResponse: + """Run ``post_process`` on each middleware in **reverse** registration order. + + If a middleware raises, the error is logged and that middleware is + skipped; processing continues with the remaining middleware. 
+ """ + current = response + for mw in reversed(self._middlewares): + mw_name = type(mw).__name__ + try: + result = mw.post_process(context, current) + # Merge middleware_state from result into the running state + merged_state = {**current.middleware_state, **result.middleware_state} + current = result.model_copy(update={"middleware_state": merged_state}) + except Exception as exc: + logger.error( + "Middleware '%s' failed during post_process: %s", + mw_name, + exc, + ) + return current + + # ---- Factory ----------------------------------------------------------- + + @classmethod + def from_config( + cls, + config_list: list[dict], + registry: Any, + ) -> "MiddlewarePipeline": + """Build a pipeline from a YAML-style config list using a registry. + + Expected config format:: + + [ + {"name": "pii_tokenizer", "config": {"key": "value"}}, + {"name": "rate_limiter", "config": {}}, + ] + + The *registry* must support ``registry.create(name, **config)`` which + returns a ``Middleware``-compatible instance. + + If ``config_list`` is empty or ``None``, returns an empty (no-op) pipeline. + """ + if not config_list: + return cls([]) + + middlewares: list[Middleware] = [] + for entry in config_list: + name = entry.get("name", "") + config = entry.get("config", {}) + try: + mw = registry.create(name, **config) + middlewares.append(mw) + except Exception as exc: + logger.error( + "Failed to create middleware '%s' from config: %s", + name, + exc, + ) + return cls(middlewares) diff --git a/src/apprentice/observer.py b/src/apprentice/observer.py new file mode 100644 index 0000000..24ce197 --- /dev/null +++ b/src/apprentice/observer.py @@ -0,0 +1,244 @@ +""" +Observer Component (observer) v1 + +Observes user and agent actions per task, maintains a rolling context window, +and probabilistically generates shadow recommendations for later comparison. + +Shadow recommendations are placeholder-only in this phase; actual model +integration is deferred. 
+""" + +import random +import uuid +from collections import deque +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field + + +# =========================================================================== +# Enums +# =========================================================================== + + +class EventType(str, Enum): + """Type of event observed by the Observer.""" + user_action = "user_action" + agent_action = "agent_action" + system_event = "system_event" + + +# =========================================================================== +# Data Models +# =========================================================================== + + +class ObservationEvent(BaseModel): + """Frozen Pydantic model representing a single observed event.""" + model_config = ConfigDict(frozen=True) + + event_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + task_name: str + event_type: EventType + action_data: dict = Field(default_factory=dict) + context: dict = Field(default_factory=dict) + timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + +class ShadowRecommendation(BaseModel): + """Frozen Pydantic model representing a shadow recommendation paired with an observed event.""" + model_config = ConfigDict(frozen=True) + + recommendation_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + event_id: str + recommended_action: dict = Field(default_factory=dict) + actual_action: Optional[dict] = None + match_score: Optional[float] = None # 0.0-1.0 when actual_action filled + model_used: str = "" + + +class ObserverConfig(BaseModel): + """Frozen Pydantic configuration model for the Observer.""" + model_config = ConfigDict(frozen=True) + + enabled: bool = False + context_window_size: int = Field(default=50, ge=1, le=1000) + shadow_recommendation_rate: float = Field(default=0.1, ge=0.0, le=1.0) + min_context_before_recommending: int = 
Field(default=10, ge=1) + + +# =========================================================================== +# Match Score Computation +# =========================================================================== + + +def _compute_match_score(recommended: dict, actual: dict) -> float: + """Compute simple dict key overlap ratio between recommended_action and actual_action. + + Returns 0.0 if both dicts are empty (no keys to compare). + Otherwise returns |intersection of keys| / |union of keys|. + """ + rec_keys = set(recommended.keys()) + act_keys = set(actual.keys()) + union = rec_keys | act_keys + if not union: + return 0.0 + intersection = rec_keys & act_keys + return len(intersection) / len(union) + + +# =========================================================================== +# Observer +# =========================================================================== + + +class Observer: + """Observes task events, maintains rolling context windows, and generates + probabilistic shadow recommendations for offline comparison. + + When ``config.enabled`` is False, ``observe`` and ``record_action`` are + no-ops (observe does nothing; record_action returns None). 
+ """ + + def __init__(self, config: ObserverConfig) -> None: + self._config = config + # task_name -> deque of ObservationEvent (bounded by context_window_size) + self._context: dict[str, deque[ObservationEvent]] = {} + # event_id -> ShadowRecommendation + self._recommendations: dict[str, ShadowRecommendation] = {} + # task_name -> list of event_ids that have recommendations (preserves order) + self._task_recommendation_ids: dict[str, list[str]] = {} + # track total events observed per task (for min_context_before_recommending) + self._event_counts: dict[str, int] = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def observe(self, event: ObservationEvent) -> None: + """Add an event to the rolling context window for its task. + + If the context window exceeds ``config.context_window_size``, the + oldest event is dropped. When the observer is disabled this is a + no-op. + """ + if not self._config.enabled: + return + + task = event.task_name + + # Ensure deque exists for this task + if task not in self._context: + self._context[task] = deque(maxlen=self._config.context_window_size) + + self._context[task].append(event) + + # Track total events observed (not just window contents) + self._event_counts[task] = self._event_counts.get(task, 0) + 1 + + # Possibly generate a shadow recommendation + self._maybe_generate_recommendation(event) + + def record_action(self, event_id: str, action: dict) -> Optional[ShadowRecommendation]: + """Record the actual action taken for a previously observed event. + + If a shadow recommendation was generated for this event, computes the + match score and returns the updated ``ShadowRecommendation``. + Returns ``None`` if the observer is disabled or no recommendation + exists for the given event_id. 
+ """ + if not self._config.enabled: + return None + + if event_id not in self._recommendations: + return None + + old_rec = self._recommendations[event_id] + score = _compute_match_score(old_rec.recommended_action, action) + + updated = old_rec.model_copy( + update={ + "actual_action": action, + "match_score": score, + } + ) + self._recommendations[event_id] = updated + return updated + + def get_context(self, task_name: str) -> list[ObservationEvent]: + """Return the current context window for *task_name*. + + Returns an empty list if no events have been observed for the task. + """ + if task_name not in self._context: + return [] + return list(self._context[task_name]) + + def get_recommendations(self, task_name: str) -> list[ShadowRecommendation]: + """Return all shadow recommendations for *task_name*. + + Returns an empty list if no recommendations exist for the task. + """ + if task_name not in self._task_recommendation_ids: + return [] + return [ + self._recommendations[eid] + for eid in self._task_recommendation_ids[task_name] + if eid in self._recommendations + ] + + def clear(self, task_name: str) -> None: + """Clear context window and recommendations for *task_name*.""" + self._context.pop(task_name, None) + removed_ids = self._task_recommendation_ids.pop(task_name, []) + for eid in removed_ids: + self._recommendations.pop(eid, None) + self._event_counts.pop(task_name, None) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _maybe_generate_recommendation(self, event: ObservationEvent) -> None: + """Probabilistically generate a shadow recommendation for *event*. + + A recommendation is only considered when the total number of events + observed for the task meets ``min_context_before_recommending``. + The recommendation is then generated with probability + ``shadow_recommendation_rate``. 
+ """ + task = event.task_name + total_events = self._event_counts.get(task, 0) + + if total_events < self._config.min_context_before_recommending: + return + + if random.random() >= self._config.shadow_recommendation_rate: + return + + # Placeholder model call — returns empty recommended_action + rec = ShadowRecommendation( + event_id=event.event_id, + recommended_action={}, + model_used="placeholder", + ) + + self._recommendations[event.event_id] = rec + if task not in self._task_recommendation_ids: + self._task_recommendation_ids[task] = [] + self._task_recommendation_ids[task].append(event.event_id) + + +# =========================================================================== +# Exports +# =========================================================================== + +__all__ = [ + "EventType", + "ObservationEvent", + "ShadowRecommendation", + "ObserverConfig", + "Observer", +] diff --git a/src/apprentice/pii_tokenizer.py b/src/apprentice/pii_tokenizer.py new file mode 100644 index 0000000..a4586ed --- /dev/null +++ b/src/apprentice/pii_tokenizer.py @@ -0,0 +1,219 @@ +""" +PII Tokenizer Middleware — scans and replaces PII with opaque tokens. + +Implements the Middleware protocol from apprentice.middleware. + +Critical invariants: + - Token registry is NEVER written to disk. + - Training data store NEVER receives un-tokenized PII. + - Model NEVER sees real PII values. + - Tokens are opaque to the model. 
+""" + +import hashlib +import re +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from apprentice.middleware import MiddlewareContext, MiddlewareResponse + + +# =========================================================================== +# Built-in PII patterns +# =========================================================================== + +_BUILTIN_PATTERNS: dict[str, str] = { + "email": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", + "phone": r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b", + "ssn": r"\b\d{3}-\d{2}-\d{4}\b", + "credit_card": r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b", +} + + +# =========================================================================== +# Configuration +# =========================================================================== + + +class PIITokenizerConfig(BaseModel): + """Frozen configuration for the PII tokenizer middleware. + + Attributes: + enabled_entity_types: Which built-in entity types to scan for. + custom_patterns: Additional name-to-regex mappings beyond the built-ins. + token_format: Python format string for replacement tokens. + Must contain ``{type}`` and ``{hash}`` placeholders. + """ + model_config = ConfigDict(frozen=True) + + enabled_entity_types: list[str] = Field( + default_factory=lambda: ["email", "phone", "ssn", "credit_card"] + ) + custom_patterns: dict[str, str] = Field(default_factory=dict) + token_format: str = "" + + +# =========================================================================== +# Token Registry (in-memory only — NEVER persisted) +# =========================================================================== + + +class TokenRegistry: + """In-memory bidirectional mapping between PII values and opaque tokens. + + - Deterministic: the same (value, entity_type) pair always produces the + same token within the lifetime of this registry instance. + - The registry is NEVER serialized or written to disk. 
+ """ + + def __init__(self, token_format: str = "") -> None: + self._token_format = token_format + # Forward: token -> original value + self._token_to_value: dict[str, str] = {} + # Reverse: original value -> token (for determinism) + self._value_to_token: dict[str, str] = {} + + def tokenize(self, value: str, entity_type: str) -> str: + """Replace *value* with an opaque token. + + If the same *value* was already tokenized, returns the same token + (deterministic within one registry instance). + """ + if value in self._value_to_token: + return self._value_to_token[value] + + short_hash = hashlib.sha256(value.encode("utf-8")).hexdigest()[:8] + token = self._token_format.format(type=entity_type, hash=short_hash) + + self._token_to_value[token] = value + self._value_to_token[value] = token + return token + + def detokenize(self, token: str) -> str: + """Restore the original value from *token*. + + Returns the token unchanged if it is not recognized. + """ + return self._token_to_value.get(token, token) + + def clear(self) -> None: + """Wipe all mappings.""" + self._token_to_value.clear() + self._value_to_token.clear() + + +# =========================================================================== +# Recursive scan / replace helpers +# =========================================================================== + + +def _scan_and_replace( + data: Any, + compiled_patterns: list[tuple[str, re.Pattern]], + registry: TokenRegistry, +) -> Any: + """Recursively walk *data* and replace PII matches with tokens.""" + if isinstance(data, str): + result = data + for entity_type, pattern in compiled_patterns: + result = pattern.sub( + lambda m, et=entity_type: registry.tokenize(m.group(), et), + result, + ) + return result + if isinstance(data, dict): + return {k: _scan_and_replace(v, compiled_patterns, registry) for k, v in data.items()} + if isinstance(data, list): + return [_scan_and_replace(item, compiled_patterns, registry) for item in data] + return data + + +def 
_restore_tokens(data: Any, registry: TokenRegistry) -> Any: + """Recursively walk *data* and restore original PII values from tokens.""" + if isinstance(data, str): + result = data + # Replace every token found in the string + for token, original in registry._token_to_value.items(): + result = result.replace(token, original) + return result + if isinstance(data, dict): + return {k: _restore_tokens(v, registry) for k, v in data.items()} + if isinstance(data, list): + return [_restore_tokens(item, registry) for item in data] + return data + + +# =========================================================================== +# PIITokenizer — Middleware implementation +# =========================================================================== + + +class PIITokenizer: + """Middleware that tokenizes PII in ``input_data`` before the model sees it, + and restores original values in ``output_data`` after the model responds. + + Implements the ``Middleware`` protocol from ``apprentice.middleware``. + """ + + def __init__(self, config: PIITokenizerConfig | None = None) -> None: + self._config = config or PIITokenizerConfig() + self._compiled_patterns = self._build_patterns() + + def _build_patterns(self) -> list[tuple[str, re.Pattern]]: + """Compile the enabled built-in patterns plus any custom patterns.""" + patterns: list[tuple[str, re.Pattern]] = [] + + # SSN pattern must be checked before phone to avoid partial matches + # when both are enabled. Order: ssn first, then everything else. 
+ ordered_types: list[str] = [] + if "ssn" in self._config.enabled_entity_types: + ordered_types.append("ssn") + for et in self._config.enabled_entity_types: + if et != "ssn" and et in _BUILTIN_PATTERNS: + ordered_types.append(et) + + for entity_type in ordered_types: + raw = _BUILTIN_PATTERNS[entity_type] + patterns.append((entity_type, re.compile(raw))) + + # Custom patterns (appended after built-ins) + for name, raw in self._config.custom_patterns.items(): + patterns.append((name, re.compile(raw))) + + return patterns + + # ---- Middleware Protocol ------------------------------------------------ + + def pre_process(self, context: MiddlewareContext) -> MiddlewareContext: + """Scan ``input_data`` for PII and replace with opaque tokens. + + The ``TokenRegistry`` is stored in ``middleware_state["pii_token_registry"]`` + so that ``post_process`` can restore original values. + """ + registry = TokenRegistry(self._config.token_format) + tokenized_input = _scan_and_replace( + context.input_data, self._compiled_patterns, registry + ) + + new_state = {**context.middleware_state, "pii_token_registry": registry} + return context.model_copy( + update={"input_data": tokenized_input, "middleware_state": new_state} + ) + + def post_process( + self, context: MiddlewareContext, response: MiddlewareResponse + ) -> MiddlewareResponse: + """Restore original PII values in ``output_data`` from tokens. + + Reads the ``TokenRegistry`` from ``context.middleware_state["pii_token_registry"]``. + If the registry is missing, returns the response unchanged. 
+ """ + registry: TokenRegistry | None = context.middleware_state.get( + "pii_token_registry" + ) + if registry is None: + return response + + restored_output = _restore_tokens(response.output_data, registry) + return response.model_copy(update={"output_data": restored_output}) diff --git a/src/apprentice/plugin_registry.py b/src/apprentice/plugin_registry.py new file mode 100644 index 0000000..4064610 --- /dev/null +++ b/src/apprentice/plugin_registry.py @@ -0,0 +1,234 @@ +""" +Plugin Registry — instance-based, domain-scoped plugin registration system. + +Replaces hardcoded enums with string-keyed factory registries. +Thread-safe, O(1) lookup, no global state. +""" + +import importlib +import re +import threading +from types import SimpleNamespace +from typing import Any, Callable + + +# ============================================================================ +# Errors +# ============================================================================ + + +class PluginError(Exception): + """Base error for all plugin registry errors.""" + + +class DuplicatePluginError(PluginError): + """Raised when registering a name that already exists.""" + + def __init__(self, name: str, domain: str): + self.name = name + self.domain = domain + super().__init__(f"Plugin '{name}' already registered in domain '{domain}'") + + +class UnknownPluginError(PluginError): + """Raised when looking up a name that doesn't exist.""" + + def __init__(self, name: str, domain: str, available: list[str]): + self.name = name + self.domain = domain + self.available = available + super().__init__( + f"Unknown plugin '{name}' in domain '{domain}'. 
" + f"Available: {available}" + ) + + +class PluginConfigError(PluginError): + """Raised when config-based registration fails.""" + + def __init__(self, domain: str, name: str, reason: str): + self.domain = domain + self.name = name + self.reason = reason + super().__init__( + f"Failed to register plugin '{name}' in domain '{domain}': {reason}" + ) + + +# ============================================================================ +# Name validation +# ============================================================================ + +_PLUGIN_NAME_RE = re.compile(r"^[a-z][a-z0-9_]*$") + + +def _validate_plugin_name(name: str) -> None: + """Validate that a plugin name matches the required pattern.""" + if not name: + raise ValueError("Plugin name must be non-empty") + if not _PLUGIN_NAME_RE.match(name): + raise ValueError( + f"Plugin name '{name}' must match ^[a-z][a-z0-9_]*$ " + f"(lowercase letters, digits, underscores, starting with a letter)" + ) + + +# ============================================================================ +# PluginRegistry +# ============================================================================ + + +class PluginRegistry: + """Instance-based, domain-scoped plugin registry. + + Maps string names to factory callables. O(1) lookup. Thread-safe. 
+ """ + + def __init__(self, domain: str): + if not domain or not domain.strip(): + raise ValueError("Domain must be a non-empty string") + self._domain = domain + self._factories: dict[str, Callable[..., Any]] = {} + self._lock = threading.Lock() + + @property + def domain(self) -> str: + return self._domain + + def register(self, name: str, factory: Callable[..., Any]) -> None: + """Register a factory callable under the given name.""" + _validate_plugin_name(name) + if not callable(factory): + raise TypeError(f"Factory must be callable, got {type(factory).__name__}") + with self._lock: + if name in self._factories: + raise DuplicatePluginError(name, self._domain) + self._factories[name] = factory + + def create(self, name: str, **config: Any) -> Any: + """Look up factory by name, call it with **config, return the result.""" + with self._lock: + factory = self._factories.get(name) + if factory is None: + raise UnknownPluginError(name, self._domain, self.list_plugins()) + return factory(**config) + + def validate_name(self, name: str) -> bool: + """Return True if name is registered, False otherwise.""" + with self._lock: + return name in self._factories + + def list_plugins(self) -> list[str]: + """Return sorted list of registered plugin names.""" + with self._lock: + return sorted(self._factories.keys()) + + def __contains__(self, name: str) -> bool: + return self.validate_name(name) + + def __len__(self) -> int: + with self._lock: + return len(self._factories) + + +# ============================================================================ +# PluginRegistrySet +# ============================================================================ + + +def _make_default_factory(name: str) -> Callable[..., Any]: + """Create a placeholder factory that returns a SimpleNamespace with config.""" + def factory(**config: Any) -> SimpleNamespace: + return SimpleNamespace(plugin_name=name, **config) + return factory + + +class PluginRegistrySet: + """Bundles multiple 
PluginRegistry instances, one per domain.""" + + def __init__(self): + self._registries: dict[str, PluginRegistry] = {} + self._lock = threading.Lock() + + def get_registry(self, domain: str) -> PluginRegistry: + """Return the registry for the given domain, creating if needed.""" + with self._lock: + if domain not in self._registries: + self._registries[domain] = PluginRegistry(domain) + return self._registries[domain] + + def register_domain(self, domain: str) -> PluginRegistry: + """Explicitly create and return a new registry for the domain.""" + with self._lock: + if domain in self._registries: + raise DuplicatePluginError(domain, "domains") + registry = PluginRegistry(domain) + self._registries[domain] = registry + return registry + + @classmethod + def with_defaults(cls) -> "PluginRegistrySet": + """Create a new PluginRegistrySet pre-populated with built-in registries.""" + registry_set = cls() + + # Evaluators + evaluators = registry_set.get_registry("evaluators") + for name in ("exact_match", "structured_match", "semantic_similarity", "custom"): + evaluators.register(name, _make_default_factory(name)) + + # Fine-tune backends + ft_backends = registry_set.get_registry("fine_tune_backends") + for name in ("unsloth", "openai", "huggingface"): + ft_backends.register(name, _make_default_factory(name)) + + # Providers + providers = registry_set.get_registry("providers") + for name in ("anthropic", "openai", "google"): + providers.register(name, _make_default_factory(name)) + + # Local backends + local_backends = registry_set.get_registry("local_backends") + for name in ("ollama", "vllm", "llamacpp"): + local_backends.register(name, _make_default_factory(name)) + + # Decay functions + decay_functions = registry_set.get_registry("decay_functions") + for name in ("exponential", "linear", "step"): + decay_functions.register(name, _make_default_factory(name)) + + # Middleware (empty) + registry_set.get_registry("middleware") + + return registry_set + + def 
register_from_config(self, config_dict: dict) -> None: + """Register plugins from a config dict. + + Format: {"domain": {"name": {"factory": "dotted.path.ClassName"}, ...}, ...} + """ + for domain, plugins in config_dict.items(): + if not isinstance(plugins, dict): + raise PluginConfigError(domain, "", f"Expected dict, got {type(plugins).__name__}") + registry = self.get_registry(domain) + for name, plugin_config in plugins.items(): + if not isinstance(plugin_config, dict): + raise PluginConfigError(domain, name, f"Expected dict, got {type(plugin_config).__name__}") + factory_path = plugin_config.get("factory") + if not factory_path: + raise PluginConfigError(domain, name, "Missing 'factory' key") + if not isinstance(factory_path, str): + raise PluginConfigError(domain, name, f"'factory' must be a string, got {type(factory_path).__name__}") + try: + module_path, attr_name = factory_path.rsplit(".", 1) + module = importlib.import_module(module_path) + factory = getattr(module, attr_name) + except (ValueError, ImportError, AttributeError) as e: + raise PluginConfigError(domain, name, f"Failed to import '{factory_path}': {e}") + if not callable(factory): + raise PluginConfigError(domain, name, f"'{factory_path}' is not callable") + registry.register(name, factory) + + def list_domains(self) -> list[str]: + """Return sorted list of domain names.""" + with self._lock: + return sorted(self._registries.keys()) diff --git a/src/apprentice/root/composition.py b/src/apprentice/root/composition.py index f4b6d43..d83b516 100644 --- a/src/apprentice/root/composition.py +++ b/src/apprentice/root/composition.py @@ -25,6 +25,7 @@ class InitializationPhase(str, Enum): CONFIG_PARSE = "CONFIG_PARSE" + PLUGIN_REGISTRY_INIT = "PLUGIN_REGISTRY_INIT" REGISTRY_BUILD = "REGISTRY_BUILD" BUDGET_MANAGER_INIT = "BUDGET_MANAGER_INIT" CONFIDENCE_ENGINE_INIT = "CONFIDENCE_ENGINE_INIT" @@ -47,6 +48,7 @@ class ComponentStatus(str, Enum): class ComponentId(str, Enum): CONFIG_AND_REGISTRY = 
"CONFIG_AND_REGISTRY" + PLUGIN_REGISTRY = "PLUGIN_REGISTRY" BUDGET_MANAGER = "BUDGET_MANAGER" CONFIDENCE_ENGINE = "CONFIDENCE_ENGINE" EXTERNAL_INTERFACES = "EXTERNAL_INTERFACES" @@ -203,6 +205,7 @@ def __init__( _COMPONENT_IDS = [ "config_and_registry", + "plugin_registry", "budget_manager", "confidence_engine", "external_interfaces", @@ -215,6 +218,7 @@ def __init__( _COMPONENT_ID_TO_ENUM = { "config_and_registry": ComponentId.CONFIG_AND_REGISTRY, + "plugin_registry": ComponentId.PLUGIN_REGISTRY, "budget_manager": ComponentId.BUDGET_MANAGER, "confidence_engine": ComponentId.CONFIDENCE_ENGINE, "external_interfaces": ComponentId.EXTERNAL_INTERFACES, @@ -227,6 +231,7 @@ def __init__( _COMPONENT_PHASE = { "config_and_registry": InitializationPhase.CONFIG_PARSE, + "plugin_registry": InitializationPhase.PLUGIN_REGISTRY_INIT, "budget_manager": InitializationPhase.BUDGET_MANAGER_INIT, "confidence_engine": InitializationPhase.CONFIDENCE_ENGINE_INIT, "external_interfaces": InitializationPhase.EXTERNAL_INTERFACES_INIT, @@ -430,6 +435,7 @@ async def initialize_composition_root( # DAG order init phases _init_phases = [ ("config_and_registry", InitializationPhase.CONFIG_PARSE), + ("plugin_registry", InitializationPhase.PLUGIN_REGISTRY_INIT), ("budget_manager", InitializationPhase.BUDGET_MANAGER_INIT), ("confidence_engine", InitializationPhase.CONFIDENCE_ENGINE_INIT), ("external_interfaces", InitializationPhase.EXTERNAL_INTERFACES_INIT), @@ -493,6 +499,10 @@ def _construct_component(cid: str, raw_config: dict, initialized: dict, env: dic cr.task_registry.task_names = lambda: [t.get("name", "") for t in config.tasks] if isinstance(config.tasks, list) else [] return cr + elif cid == "plugin_registry": + from apprentice.plugin_registry import PluginRegistrySet + return PluginRegistrySet.with_defaults() + elif cid == "budget_manager": return BudgetManager() diff --git a/tests/test_feedback_collector.py b/tests/test_feedback_collector.py new file mode 100644 index 0000000..7424f78 
"""
Contract tests for FeedbackCollector.

Tests verify FeedbackEntry/FeedbackSummary/FeedbackConfig model construction,
FeedbackCollector disabled mode, record/retrieve/summarize workflows,
acceptance rate computation, confidence adjustment, task listing,
and on-disk persistence via JSON-lines.
"""

import json
import uuid

import pytest

# BUG FIX: the module lives at apprentice.feedback_collector — see
# src/apprentice/apprentice_class.py, which imports
# `from apprentice.feedback_collector import FeedbackEntry, FeedbackType`.
# The previous `from src.feedback_collector import ...` path does not match
# the package layout (src/apprentice/...) and fails at collection time.
from apprentice.feedback_collector import (
    FeedbackType,
    FeedbackEntry,
    FeedbackSummary,
    FeedbackConfig,
    FeedbackCollector,
)


# ===================================================================
# 1. FEEDBACKENTRY MODEL TESTS
# ===================================================================

class TestFeedbackEntry:
    """Test FeedbackEntry construction, defaults, and immutability."""

    def test_creation_with_defaults(self):
        """FeedbackEntry populates feedback_id, score, and timestamp automatically."""
        entry = FeedbackEntry(
            request_id="req-001",
            task_name="summarize",
            feedback_type=FeedbackType.accept,
        )

        assert entry.request_id == "req-001"
        assert entry.task_name == "summarize"
        assert entry.feedback_type == FeedbackType.accept
        assert entry.score == 0.0
        assert entry.edited_output is None
        assert entry.reason is None
        # feedback_id should be a valid UUID string
        uuid.UUID(entry.feedback_id)
        # timestamp should be a non-empty ISO-8601 string
        assert len(entry.timestamp) > 0

    def test_creation_with_all_fields(self):
        """FeedbackEntry accepts all explicit fields."""
        entry = FeedbackEntry(
            feedback_id="custom-id-123",
            request_id="req-002",
            task_name="classify",
            feedback_type=FeedbackType.edit,
            score=0.75,
            edited_output={"label": "positive"},
            reason="Corrected sentiment label",
            timestamp="2026-01-15T12:00:00+00:00",
        )

        assert entry.feedback_id == "custom-id-123"
        assert entry.request_id == "req-002"
        assert entry.task_name == "classify"
        assert entry.feedback_type == FeedbackType.edit
        assert entry.score == 0.75
        assert entry.edited_output == {"label": "positive"}
        assert entry.reason == "Corrected sentiment label"
        assert entry.timestamp == "2026-01-15T12:00:00+00:00"

    def test_frozen_immutability(self):
        """FeedbackEntry is frozen and cannot be mutated after creation."""
        entry = FeedbackEntry(
            request_id="req-003",
            task_name="extract",
            feedback_type=FeedbackType.reject,
        )
        with pytest.raises((AttributeError, TypeError, Exception)):
            entry.score = 0.5

    def test_score_lower_bound(self):
        """Score below 0.0 is rejected by the validator."""
        with pytest.raises(Exception):
            FeedbackEntry(
                request_id="req-004",
                task_name="task",
                feedback_type=FeedbackType.ai_score,
                score=-0.1,
            )

    def test_score_upper_bound(self):
        """Score above 1.0 is rejected by the validator."""
        with pytest.raises(Exception):
            FeedbackEntry(
                request_id="req-005",
                task_name="task",
                feedback_type=FeedbackType.ai_score,
                score=1.1,
            )


# ===================================================================
# 2.
FEEDBACKSUMMARY MODEL TESTS +# =================================================================== + +class TestFeedbackSummary: + """Test FeedbackSummary construction and defaults.""" + + def test_creation_defaults(self): + """FeedbackSummary defaults all counters and rates to zero.""" + summary = FeedbackSummary(task_name="summarize") + + assert summary.task_name == "summarize" + assert summary.accept_count == 0 + assert summary.reject_count == 0 + assert summary.edit_count == 0 + assert summary.ignore_count == 0 + assert summary.ai_score_count == 0 + assert summary.total_count == 0 + assert summary.acceptance_rate == 0.0 + assert summary.average_ai_score == 0.0 + + def test_frozen_immutability(self): + """FeedbackSummary is frozen and cannot be mutated.""" + summary = FeedbackSummary(task_name="classify") + with pytest.raises((AttributeError, TypeError, Exception)): + summary.accept_count = 99 + + +# =================================================================== +# 3. FEEDBACKCONFIG MODEL TESTS +# =================================================================== + +class TestFeedbackConfig: + """Test FeedbackConfig construction and defaults.""" + + def test_defaults(self): + """FeedbackConfig defaults to disabled with standard storage dir.""" + cfg = FeedbackConfig() + assert cfg.enabled is False + assert cfg.storage_dir == ".apprentice/feedback/" + + def test_custom_values(self): + """FeedbackConfig accepts custom values.""" + cfg = FeedbackConfig(enabled=True, storage_dir="/tmp/my_feedback/") + assert cfg.enabled is True + assert cfg.storage_dir == "/tmp/my_feedback/" + + def test_frozen_immutability(self): + """FeedbackConfig is frozen and cannot be mutated.""" + cfg = FeedbackConfig() + with pytest.raises((AttributeError, TypeError, Exception)): + cfg.enabled = True + + +# =================================================================== +# 4. 
FEEDBACKCOLLECTOR — DISABLED MODE +# =================================================================== + +class TestFeedbackCollectorDisabled: + """When disabled, the collector is a safe no-op.""" + + def test_record_feedback_noop(self, tmp_path): + """record_feedback does nothing when disabled.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=False) + entry = FeedbackEntry( + request_id="req-100", + task_name="summarize", + feedback_type=FeedbackType.accept, + ) + collector.record_feedback(entry) + # Directory should not be created + assert not (tmp_path / "fb").exists() + + def test_get_feedback_summary_empty(self, tmp_path): + """get_feedback_summary returns empty summary when disabled.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=False) + summary = collector.get_feedback_summary("summarize") + + assert summary.task_name == "summarize" + assert summary.total_count == 0 + assert summary.acceptance_rate == 0.0 + + def test_compute_feedback_adjustment_zero(self, tmp_path): + """compute_feedback_adjustment returns 0.0 when disabled.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=False) + assert collector.compute_feedback_adjustment("summarize") == 0.0 + + def test_list_tasks_empty(self, tmp_path): + """list_tasks returns empty list when disabled.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=False) + assert collector.list_tasks() == [] + + +# =================================================================== +# 5. 
FEEDBACKCOLLECTOR — RECORD AND RETRIEVE +# =================================================================== + +class TestFeedbackCollectorRecordRetrieve: + """Test record_feedback -> get_feedback_summary round-trip.""" + + def test_record_and_retrieve_summary(self, tmp_path): + """Recording a single accept entry yields correct summary.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + entry = FeedbackEntry( + request_id="req-200", + task_name="summarize", + feedback_type=FeedbackType.accept, + ) + collector.record_feedback(entry) + + summary = collector.get_feedback_summary("summarize") + assert summary.task_name == "summarize" + assert summary.accept_count == 1 + assert summary.total_count == 1 + assert summary.acceptance_rate == 1.0 + + def test_multiple_feedback_types(self, tmp_path): + """Summary correctly tallies mixed feedback types.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + types_to_record = [ + FeedbackType.accept, + FeedbackType.accept, + FeedbackType.reject, + FeedbackType.edit, + FeedbackType.ignore, + FeedbackType.ai_score, + ] + for i, ft in enumerate(types_to_record): + entry = FeedbackEntry( + request_id=f"req-{300 + i}", + task_name="classify", + feedback_type=ft, + score=0.8 if ft == FeedbackType.ai_score else 0.0, + ) + collector.record_feedback(entry) + + summary = collector.get_feedback_summary("classify") + assert summary.accept_count == 2 + assert summary.reject_count == 1 + assert summary.edit_count == 1 + assert summary.ignore_count == 1 + assert summary.ai_score_count == 1 + assert summary.total_count == 6 + + def test_acceptance_rate_computation(self, tmp_path): + """acceptance_rate = accepts / (accepts + rejects).""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + # 3 accepts, 1 reject -> rate = 3/4 = 0.75 + for ft in [FeedbackType.accept, FeedbackType.accept, FeedbackType.accept, FeedbackType.reject]: + 
collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="rate_task", + feedback_type=ft, + )) + + summary = collector.get_feedback_summary("rate_task") + assert summary.acceptance_rate == pytest.approx(0.75) + + def test_acceptance_rate_no_rejects(self, tmp_path): + """acceptance_rate is 1.0 when there are only accepts.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + for _ in range(5): + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="perfect_task", + feedback_type=FeedbackType.accept, + )) + + summary = collector.get_feedback_summary("perfect_task") + assert summary.acceptance_rate == pytest.approx(1.0) + + def test_acceptance_rate_no_accepts(self, tmp_path): + """acceptance_rate is 0.0 when there are only rejects.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + for _ in range(3): + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="bad_task", + feedback_type=FeedbackType.reject, + )) + + summary = collector.get_feedback_summary("bad_task") + assert summary.acceptance_rate == pytest.approx(0.0) + + def test_average_ai_score(self, tmp_path): + """average_ai_score is the mean of scores across ai_score entries.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + scores = [0.6, 0.8, 1.0] + for s in scores: + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="scored_task", + feedback_type=FeedbackType.ai_score, + score=s, + )) + + summary = collector.get_feedback_summary("scored_task") + assert summary.average_ai_score == pytest.approx(0.8) + + def test_summary_for_unknown_task(self, tmp_path): + """get_feedback_summary for a task with no data returns empty summary.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + summary = collector.get_feedback_summary("nonexistent") + assert 
summary.total_count == 0 + assert summary.acceptance_rate == 0.0 + + +# =================================================================== +# 6. FEEDBACKCOLLECTOR — CONFIDENCE ADJUSTMENT +# =================================================================== + +class TestFeedbackCollectorAdjustment: + """Test compute_feedback_adjustment.""" + + def test_positive_adjustment(self, tmp_path): + """Mostly accepts yields a positive adjustment.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + # 9 accepts, 1 reject -> ratio = 8/10 = 0.8 -> adjustment = 0.08 + for _ in range(9): + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="good_task", + feedback_type=FeedbackType.accept, + )) + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="good_task", + feedback_type=FeedbackType.reject, + )) + + adj = collector.compute_feedback_adjustment("good_task") + assert adj > 0.0 + assert adj <= 0.1 + + def test_negative_adjustment(self, tmp_path): + """Mostly rejects yields a negative adjustment.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + # 1 accept, 9 rejects -> ratio = -8/10 = -0.8 -> adjustment = -0.08 + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="bad_task", + feedback_type=FeedbackType.accept, + )) + for _ in range(9): + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="bad_task", + feedback_type=FeedbackType.reject, + )) + + adj = collector.compute_feedback_adjustment("bad_task") + assert adj < 0.0 + assert adj >= -0.1 + + def test_all_accepts_max_positive(self, tmp_path): + """All accepts produces +0.1.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + for _ in range(5): + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="perfect", + feedback_type=FeedbackType.accept, + )) + + adj 
= collector.compute_feedback_adjustment("perfect") + assert adj == pytest.approx(0.1) + + def test_all_rejects_max_negative(self, tmp_path): + """All rejects produces -0.1.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + for _ in range(5): + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="terrible", + feedback_type=FeedbackType.reject, + )) + + adj = collector.compute_feedback_adjustment("terrible") + assert adj == pytest.approx(-0.1) + + def test_no_feedback_returns_zero(self, tmp_path): + """No feedback for a task yields 0.0 adjustment.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + adj = collector.compute_feedback_adjustment("empty_task") + assert adj == 0.0 + + def test_only_non_accept_reject_returns_zero(self, tmp_path): + """Only edit/ignore/ai_score entries (no accept or reject) yields 0.0.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + for ft in [FeedbackType.edit, FeedbackType.ignore, FeedbackType.ai_score]: + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="neutral", + feedback_type=ft, + score=0.5 if ft == FeedbackType.ai_score else 0.0, + )) + + adj = collector.compute_feedback_adjustment("neutral") + assert adj == 0.0 + + def test_adjustment_bounded(self, tmp_path): + """Adjustment is always in [-0.1, +0.1].""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + # Record a mix + for _ in range(50): + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="bounded", + feedback_type=FeedbackType.accept, + )) + for _ in range(50): + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name="bounded", + feedback_type=FeedbackType.reject, + )) + + adj = collector.compute_feedback_adjustment("bounded") + assert -0.1 <= adj <= 0.1 + + +# 
=================================================================== +# 7. FEEDBACKCOLLECTOR — LIST TASKS +# =================================================================== + +class TestFeedbackCollectorListTasks: + """Test list_tasks.""" + + def test_list_tasks_multiple(self, tmp_path): + """list_tasks returns all tasks with recorded feedback.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + + for task in ["alpha", "beta", "gamma"]: + collector.record_feedback(FeedbackEntry( + request_id=str(uuid.uuid4()), + task_name=task, + feedback_type=FeedbackType.accept, + )) + + tasks = collector.list_tasks() + assert set(tasks) == {"alpha", "beta", "gamma"} + + def test_list_tasks_empty_when_no_data(self, tmp_path): + """list_tasks returns empty list when no feedback has been recorded.""" + collector = FeedbackCollector(storage_dir=str(tmp_path / "fb"), enabled=True) + assert collector.list_tasks() == [] + + def test_list_tasks_no_directory(self, tmp_path): + """list_tasks returns empty when storage dir does not exist.""" + collector = FeedbackCollector( + storage_dir=str(tmp_path / "nonexistent_fb"), + enabled=True, + ) + assert collector.list_tasks() == [] + + +# =================================================================== +# 8. 
FEEDBACKCOLLECTOR — PERSISTENCE TO DISK +# =================================================================== + +class TestFeedbackCollectorPersistence: + """Test that data is persisted as JSON-lines on disk and survives re-instantiation.""" + + def test_file_created_on_record(self, tmp_path): + """Recording feedback creates a .jsonl file for the task.""" + storage = str(tmp_path / "fb") + collector = FeedbackCollector(storage_dir=storage, enabled=True) + + collector.record_feedback(FeedbackEntry( + request_id="req-persist-1", + task_name="persist_task", + feedback_type=FeedbackType.accept, + )) + + jsonl_path = tmp_path / "fb" / "persist_task.jsonl" + assert jsonl_path.exists() + + def test_jsonl_format(self, tmp_path): + """Each line in the file is valid JSON representing a FeedbackEntry.""" + storage = str(tmp_path / "fb") + collector = FeedbackCollector(storage_dir=storage, enabled=True) + + for i in range(3): + collector.record_feedback(FeedbackEntry( + request_id=f"req-fmt-{i}", + task_name="format_task", + feedback_type=FeedbackType.accept, + )) + + jsonl_path = tmp_path / "fb" / "format_task.jsonl" + lines = jsonl_path.read_text().strip().split("\n") + assert len(lines) == 3 + + for line in lines: + data = json.loads(line) + assert "request_id" in data + assert "feedback_type" in data + + def test_survives_reinstantiation(self, tmp_path): + """Data written by one collector instance is readable by a new one.""" + storage = str(tmp_path / "fb") + + collector1 = FeedbackCollector(storage_dir=storage, enabled=True) + for i in range(5): + collector1.record_feedback(FeedbackEntry( + request_id=f"req-surv-{i}", + task_name="survive_task", + feedback_type=FeedbackType.accept if i < 3 else FeedbackType.reject, + )) + + # New instance, same directory + collector2 = FeedbackCollector(storage_dir=storage, enabled=True) + summary = collector2.get_feedback_summary("survive_task") + + assert summary.total_count == 5 + assert summary.accept_count == 3 + assert 
summary.reject_count == 2 + assert summary.acceptance_rate == pytest.approx(0.6) + + def test_append_only_semantics(self, tmp_path): + """Subsequent record_feedback calls append, never overwrite.""" + storage = str(tmp_path / "fb") + collector = FeedbackCollector(storage_dir=storage, enabled=True) + + collector.record_feedback(FeedbackEntry( + request_id="req-app-1", + task_name="append_task", + feedback_type=FeedbackType.accept, + )) + + # Re-instantiate and append more + collector2 = FeedbackCollector(storage_dir=storage, enabled=True) + collector2.record_feedback(FeedbackEntry( + request_id="req-app-2", + task_name="append_task", + feedback_type=FeedbackType.reject, + )) + + summary = collector2.get_feedback_summary("append_task") + assert summary.total_count == 2 + assert summary.accept_count == 1 + assert summary.reject_count == 1 diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..d85c3dd --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,620 @@ +""" +Tests for the Middleware Pipeline component. + +Covers: MiddlewareContext, MiddlewareResponse, Middleware protocol, +MiddlewarePipeline (pre/post ordering, error resilience, state accumulation), +MiddlewareError, and from_config factory. 
+""" + +import logging +import pytest +from unittest.mock import MagicMock + +from apprentice.middleware import ( + Middleware, + MiddlewareContext, + MiddlewareError, + MiddlewarePipeline, + MiddlewareResponse, +) + + +# =========================================================================== +# Fixtures +# =========================================================================== + + +@pytest.fixture +def base_context(): + """A minimal valid MiddlewareContext.""" + return MiddlewareContext( + request_id="req-001", + task_name="summarize", + input_data={"text": "hello world"}, + ) + + +@pytest.fixture +def base_response(): + """A minimal valid MiddlewareResponse.""" + return MiddlewareResponse( + output_data={"summary": "hello"}, + ) + + +class PassthroughMiddleware: + """Middleware that passes data through unchanged.""" + + def pre_process(self, context: MiddlewareContext) -> MiddlewareContext: + return context + + def post_process( + self, context: MiddlewareContext, response: MiddlewareResponse + ) -> MiddlewareResponse: + return response + + +class AnnotatingMiddleware: + """Middleware that injects a marker into middleware_state so we can + verify ordering and accumulation.""" + + def __init__(self, tag: str) -> None: + self.tag = tag + + def pre_process(self, context: MiddlewareContext) -> MiddlewareContext: + pre_order = context.middleware_state.get("pre_order", []) + new_state = { + **context.middleware_state, + "pre_order": pre_order + [self.tag], + f"pre_{self.tag}": True, + } + return context.model_copy(update={"middleware_state": new_state}) + + def post_process( + self, context: MiddlewareContext, response: MiddlewareResponse + ) -> MiddlewareResponse: + post_order = response.middleware_state.get("post_order", []) + new_state = { + **response.middleware_state, + "post_order": post_order + [self.tag], + f"post_{self.tag}": True, + } + return response.model_copy(update={"middleware_state": new_state}) + + +class FailingPreMiddleware: + """Middleware 
whose pre_process always raises.""" + + def pre_process(self, context: MiddlewareContext) -> MiddlewareContext: + raise RuntimeError("pre_process boom") + + def post_process( + self, context: MiddlewareContext, response: MiddlewareResponse + ) -> MiddlewareResponse: + return response + + +class FailingPostMiddleware: + """Middleware whose post_process always raises.""" + + def pre_process(self, context: MiddlewareContext) -> MiddlewareContext: + return context + + def post_process( + self, context: MiddlewareContext, response: MiddlewareResponse + ) -> MiddlewareResponse: + raise RuntimeError("post_process boom") + + +class TransformingMiddleware: + """Middleware that modifies input_data (pre) and output_data (post).""" + + def __init__(self, key: str, value: str) -> None: + self.key = key + self.value = value + + def pre_process(self, context: MiddlewareContext) -> MiddlewareContext: + new_input = {**context.input_data, self.key: self.value} + return context.model_copy(update={"input_data": new_input}) + + def post_process( + self, context: MiddlewareContext, response: MiddlewareResponse + ) -> MiddlewareResponse: + new_output = {**response.output_data, self.key: self.value} + return response.model_copy(update={"output_data": new_output}) + + +# =========================================================================== +# MiddlewareContext Tests +# =========================================================================== + + +class TestMiddlewareContext: + + def test_creation_with_required_fields(self): + ctx = MiddlewareContext( + request_id="r1", + task_name="classify", + input_data={"text": "hi"}, + ) + assert ctx.request_id == "r1" + assert ctx.task_name == "classify" + assert ctx.input_data == {"text": "hi"} + assert ctx.metadata == {} + assert ctx.middleware_state == {} + + def test_creation_with_all_fields(self): + ctx = MiddlewareContext( + request_id="r2", + task_name="translate", + input_data={"text": "bonjour"}, + metadata={"lang": "fr"}, + 
middleware_state={"token_count": 5}, + ) + assert ctx.metadata == {"lang": "fr"} + assert ctx.middleware_state == {"token_count": 5} + + def test_frozen_immutability(self, base_context): + with pytest.raises((AttributeError, TypeError, Exception)): + base_context.request_id = "changed" + + def test_frozen_immutability_input_data(self, base_context): + with pytest.raises((AttributeError, TypeError, Exception)): + base_context.input_data = {"new": "data"} + + def test_model_copy_produces_new_instance(self, base_context): + updated = base_context.model_copy( + update={"middleware_state": {"key": "val"}} + ) + assert updated is not base_context + assert updated.middleware_state == {"key": "val"} + assert base_context.middleware_state == {} + + +# =========================================================================== +# MiddlewareResponse Tests +# =========================================================================== + + +class TestMiddlewareResponse: + + def test_creation_with_required_fields(self): + resp = MiddlewareResponse(output_data={"result": 42}) + assert resp.output_data == {"result": 42} + assert resp.metadata == {} + assert resp.middleware_state == {} + + def test_creation_with_all_fields(self): + resp = MiddlewareResponse( + output_data={"result": 42}, + metadata={"model": "gpt-4"}, + middleware_state={"tokens_used": 100}, + ) + assert resp.metadata == {"model": "gpt-4"} + assert resp.middleware_state == {"tokens_used": 100} + + def test_frozen_immutability(self, base_response): + with pytest.raises((AttributeError, TypeError, Exception)): + base_response.output_data = {"hacked": True} + + def test_model_copy_produces_new_instance(self, base_response): + updated = base_response.model_copy( + update={"middleware_state": {"flag": True}} + ) + assert updated is not base_response + assert updated.middleware_state == {"flag": True} + assert base_response.middleware_state == {} + + +# 
# ===========================================================================
# MiddlewareError Tests
# ===========================================================================


class TestMiddlewareError:

    def test_attributes(self):
        err = MiddlewareError("PiiTokenizer", "pre_process", "connection timeout")
        assert err.middleware_name == "PiiTokenizer"
        assert err.phase == "pre_process"
        assert err.reason == "connection timeout"

    def test_inherits_from_exception(self):
        assert isinstance(MiddlewareError("X", "post_process", "boom"), Exception)

    def test_string_representation(self):
        rendered = str(MiddlewareError("MyMW", "pre_process", "bad input"))
        assert "MyMW" in rendered
        assert "pre_process" in rendered
        assert "bad input" in rendered


# ===========================================================================
# Middleware Protocol Tests
# ===========================================================================


class TestMiddlewareProtocol:

    def test_passthrough_satisfies_protocol(self):
        assert isinstance(PassthroughMiddleware(), Middleware)

    def test_annotating_satisfies_protocol(self):
        assert isinstance(AnnotatingMiddleware("test"), Middleware)

    def test_object_without_methods_fails_protocol(self):
        for candidate in ("not a middleware", 42, {}):
            assert not isinstance(candidate, Middleware)

    def test_partial_implementation_fails_protocol(self):
        class OnlyPre:
            def pre_process(self, context):
                return context

        # Missing post_process, should not satisfy the protocol
        assert not isinstance(OnlyPre(), Middleware)


# ===========================================================================
# MiddlewarePipeline — Empty Pipeline (no-op)
# ===========================================================================


class TestEmptyPipeline:

    def test_execute_pre_passthrough(self, base_context):
        assert MiddlewarePipeline([]).execute_pre(base_context) == base_context

    def test_execute_post_passthrough(self, base_context, base_response):
        out = MiddlewarePipeline([]).execute_post(base_context, base_response)
        assert out == base_response

    def test_none_middlewares_is_empty(self, base_context):
        assert MiddlewarePipeline(None).execute_pre(base_context) == base_context

    def test_default_constructor_is_empty(self, base_context):
        assert MiddlewarePipeline().execute_pre(base_context) == base_context


# ===========================================================================
# MiddlewarePipeline — Single Middleware
# ===========================================================================


class TestSingleMiddleware:

    def test_pre_process_runs(self, base_context):
        pipe = MiddlewarePipeline([AnnotatingMiddleware("alpha")])
        out = pipe.execute_pre(base_context)
        assert out.middleware_state.get("pre_alpha") is True

    def test_post_process_runs(self, base_context, base_response):
        pipe = MiddlewarePipeline([AnnotatingMiddleware("alpha")])
        out = pipe.execute_post(base_context, base_response)
        assert out.middleware_state.get("post_alpha") is True

    def test_passthrough_does_not_modify(self, base_context, base_response):
        pipe = MiddlewarePipeline([PassthroughMiddleware()])
        assert pipe.execute_pre(base_context).input_data == base_context.input_data

        post_out = pipe.execute_post(base_context, base_response)
        assert post_out.output_data == base_response.output_data


# ===========================================================================
# MiddlewarePipeline — Multiple Middleware (ordering)
# ===========================================================================
class TestMultipleMiddlewareOrdering:

    def test_pre_process_runs_in_forward_order(self, base_context):
        pipe = MiddlewarePipeline([
            AnnotatingMiddleware("first"),
            AnnotatingMiddleware("second"),
            AnnotatingMiddleware("third"),
        ])
        out = pipe.execute_pre(base_context)
        assert out.middleware_state["pre_order"] == ["first", "second", "third"]

    def test_post_process_runs_in_reverse_order(self, base_context, base_response):
        pipe = MiddlewarePipeline([
            AnnotatingMiddleware("first"),
            AnnotatingMiddleware("second"),
            AnnotatingMiddleware("third"),
        ])
        out = pipe.execute_post(base_context, base_response)
        # Reverse order: third -> second -> first
        assert out.middleware_state["post_order"] == ["third", "second", "first"]

    def test_all_middleware_state_accumulated_pre(self, base_context):
        pipe = MiddlewarePipeline([
            AnnotatingMiddleware("a"),
            AnnotatingMiddleware("b"),
        ])
        state = pipe.execute_pre(base_context).middleware_state
        assert state.get("pre_a") is True
        assert state.get("pre_b") is True

    def test_all_middleware_state_accumulated_post(self, base_context, base_response):
        pipe = MiddlewarePipeline([
            AnnotatingMiddleware("a"),
            AnnotatingMiddleware("b"),
        ])
        state = pipe.execute_post(base_context, base_response).middleware_state
        assert state.get("post_a") is True
        assert state.get("post_b") is True

    def test_transforming_middleware_composes(self, base_context, base_response):
        pipe = MiddlewarePipeline([
            TransformingMiddleware("injected_a", "val_a"),
            TransformingMiddleware("injected_b", "val_b"),
        ])
        pre_out = pipe.execute_pre(base_context)
        assert pre_out.input_data["injected_a"] == "val_a"
        assert pre_out.input_data["injected_b"] == "val_b"

        post_out = pipe.execute_post(base_context, base_response)
        assert post_out.output_data["injected_a"] == "val_a"
        assert post_out.output_data["injected_b"] == "val_b"


# ===========================================================================
# MiddlewarePipeline — Error Resilience
# ===========================================================================


class TestErrorResilience:

    def test_pre_process_error_skips_and_continues(self, base_context, caplog):
        pipe = MiddlewarePipeline([
            AnnotatingMiddleware("before"),
            FailingPreMiddleware(),
            AnnotatingMiddleware("after"),
        ])
        with caplog.at_level(logging.ERROR):
            out = pipe.execute_pre(base_context)

        # "before" and "after" should have run; FailingPreMiddleware skipped
        assert out.middleware_state.get("pre_before") is True
        assert out.middleware_state.get("pre_after") is True
        assert "FailingPreMiddleware" in caplog.text
        assert "pre_process" in caplog.text

    def test_post_process_error_skips_and_continues(
        self, base_context, base_response, caplog
    ):
        pipe = MiddlewarePipeline([
            AnnotatingMiddleware("outer"),
            FailingPostMiddleware(),
            AnnotatingMiddleware("inner"),
        ])
        with caplog.at_level(logging.ERROR):
            out = pipe.execute_post(base_context, base_response)

        # Reverse order: inner -> FailingPost (skipped) -> outer
        assert out.middleware_state.get("post_outer") is True
        assert out.middleware_state.get("post_inner") is True
        assert "FailingPostMiddleware" in caplog.text
        assert "post_process" in caplog.text

    def test_all_middleware_fail_returns_original_context(self, base_context, caplog):
        pipe = MiddlewarePipeline([
            FailingPreMiddleware(),
            FailingPreMiddleware(),
        ])
        with caplog.at_level(logging.ERROR):
            out = pipe.execute_pre(base_context)

        # Should get back essentially the original context
        assert out.request_id == base_context.request_id
        assert out.input_data == base_context.input_data

    def test_all_middleware_fail_returns_original_response(
        self, base_context, base_response, caplog
    ):
        pipe = MiddlewarePipeline([
            FailingPostMiddleware(),
            FailingPostMiddleware(),
        ])
        with caplog.at_level(logging.ERROR):
            out = pipe.execute_post(base_context, base_response)

        assert out.output_data == base_response.output_data


# ===========================================================================
# MiddlewarePipeline — Middleware State Accumulation
# ===========================================================================


class TestMiddlewareStateAccumulation:

    def test_state_accumulates_through_pre_pipeline(self, base_context):
        pipe = MiddlewarePipeline([
            AnnotatingMiddleware("step1"),
            AnnotatingMiddleware("step2"),
            AnnotatingMiddleware("step3"),
        ])
        state = pipe.execute_pre(base_context).middleware_state

        # All steps should have left their mark
        assert state["pre_step1"] is True
        assert state["pre_step2"] is True
        assert state["pre_step3"] is True
        assert state["pre_order"] == ["step1", "step2", "step3"]

    def test_state_accumulates_through_post_pipeline(
        self, base_context, base_response
    ):
        pipe = MiddlewarePipeline([
            AnnotatingMiddleware("step1"),
            AnnotatingMiddleware("step2"),
        ])
        state = pipe.execute_post(base_context, base_response).middleware_state

        assert state["post_step1"] is True
        assert state["post_step2"] is True
        # Reverse order
        assert state["post_order"] == ["step2", "step1"]

    def test_initial_state_preserved_through_pipeline(self):
        ctx = MiddlewareContext(
            request_id="r1",
            task_name="task",
            input_data={},
            middleware_state={"existing_key": "existing_value"},
        )
        out = MiddlewarePipeline([AnnotatingMiddleware("mw")]).execute_pre(ctx)

        assert out.middleware_state["existing_key"] == "existing_value"
        assert out.middleware_state["pre_mw"] is True

    def test_original_context_not_mutated(self, base_context):
        pipe = MiddlewarePipeline([AnnotatingMiddleware("mutator")])
        _ = pipe.execute_pre(base_context)

        # The original context should be untouched (frozen model)
        assert base_context.middleware_state == {}

    def test_original_response_not_mutated(self, base_context, base_response):
        pipe = MiddlewarePipeline([AnnotatingMiddleware("mutator")])
        _ = pipe.execute_post(base_context, base_response)

        # The original response should be untouched (frozen model)
        assert base_response.middleware_state == {}


# ===========================================================================
# MiddlewarePipeline — from_config
# ===========================================================================
registry.create.return_value = PassthroughMiddleware() + + config = [{"name": "pii_tokenizer", "config": {"pattern": r"\d{3}-\d{4}", "replace": "***"}}] + MiddlewarePipeline.from_config(config, registry) + + registry.create.assert_called_once_with( + "pii_tokenizer", pattern=r"\d{3}-\d{4}", replace="***" + ) + + def test_config_missing_config_key_uses_empty_dict(self): + registry = MagicMock() + registry.create.return_value = PassthroughMiddleware() + + config = [{"name": "simple"}] + pipeline = MiddlewarePipeline.from_config(config, registry) + + registry.create.assert_called_once_with("simple") + assert len(pipeline.middlewares) == 1 + + def test_registry_error_skips_middleware(self, caplog): + registry = MagicMock() + registry.create.side_effect = [ + AnnotatingMiddleware("ok"), + ValueError("unknown plugin"), + AnnotatingMiddleware("also_ok"), + ] + + config = [ + {"name": "good1", "config": {}}, + {"name": "bad", "config": {}}, + {"name": "good2", "config": {}}, + ] + with caplog.at_level(logging.ERROR): + pipeline = MiddlewarePipeline.from_config(config, registry) + + # The bad one should be skipped + assert len(pipeline.middlewares) == 2 + assert "bad" in caplog.text + + def test_from_config_then_execute(self, base_context): + """End-to-end: build from config, then execute pre-processing.""" + mw = AnnotatingMiddleware("configured") + registry = MagicMock() + registry.create.return_value = mw + + config = [{"name": "configured_mw", "config": {"tag": "configured"}}] + pipeline = MiddlewarePipeline.from_config(config, registry) + + result = pipeline.execute_pre(base_context) + assert result.middleware_state.get("pre_configured") is True + + +# =========================================================================== +# MiddlewarePipeline — middlewares property +# =========================================================================== + + +class TestMiddlewaresProperty: + + def test_middlewares_returns_copy(self): + mw = PassthroughMiddleware() + pipeline 
= MiddlewarePipeline([mw]) + mws = pipeline.middlewares + mws.clear() # mutate the returned list + assert len(pipeline.middlewares) == 1 # internal list is unaffected + + def test_middlewares_preserves_order(self): + mw_a = AnnotatingMiddleware("a") + mw_b = AnnotatingMiddleware("b") + mw_c = AnnotatingMiddleware("c") + pipeline = MiddlewarePipeline([mw_a, mw_b, mw_c]) + mws = pipeline.middlewares + assert mws[0] is mw_a + assert mws[1] is mw_b + assert mws[2] is mw_c diff --git a/tests/test_observer.py b/tests/test_observer.py new file mode 100644 index 0000000..b575482 --- /dev/null +++ b/tests/test_observer.py @@ -0,0 +1,508 @@ +""" +Contract tests for the Observer component. + +Sections: +1. Fixtures and helpers +2. ObservationEvent model tests +3. ShadowRecommendation model tests +4. ObserverConfig model tests +5. EventType enum tests +6. Observer unit tests (disabled mode, observe, context window, record_action, + get_context, get_recommendations, clear, probabilistic generation) +""" + +import uuid +from unittest.mock import patch + +import pytest + +from src.observer import ( + EventType, + ObservationEvent, + ShadowRecommendation, + ObserverConfig, + Observer, + _compute_match_score, +) + + +# ═══════════════════════════════════════════════════════════════════════════ +# FIXTURES & HELPERS +# ═══════════════════════════════════════════════════════════════════════════ + + +def make_event( + task_name: str = "test-task", + event_type: EventType = EventType.user_action, + action_data: dict | None = None, + context: dict | None = None, + event_id: str | None = None, +) -> ObservationEvent: + kwargs: dict = { + "task_name": task_name, + "event_type": event_type, + } + if action_data is not None: + kwargs["action_data"] = action_data + if context is not None: + kwargs["context"] = context + if event_id is not None: + kwargs["event_id"] = event_id + return ObservationEvent(**kwargs) + + +def enabled_config(**overrides) -> ObserverConfig: + defaults = { + 
"enabled": True, + "context_window_size": 50, + "shadow_recommendation_rate": 0.1, + "min_context_before_recommending": 10, + } + defaults.update(overrides) + return ObserverConfig(**defaults) + + +def disabled_config() -> ObserverConfig: + return ObserverConfig(enabled=False) + + +@pytest.fixture +def observer_disabled(): + return Observer(disabled_config()) + + +@pytest.fixture +def observer_enabled(): + return Observer(enabled_config()) + + +@pytest.fixture +def observer_always_recommend(): + """Observer that always generates recommendations after 1 event.""" + return Observer( + enabled_config( + shadow_recommendation_rate=1.0, + min_context_before_recommending=1, + ) + ) + + +@pytest.fixture +def observer_never_recommend(): + """Observer that never generates recommendations.""" + return Observer( + enabled_config( + shadow_recommendation_rate=0.0, + min_context_before_recommending=1, + ) + ) + + +# ═══════════════════════════════════════════════════════════════════════════ +# EVENT TYPE ENUM +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestEventType: + def test_values(self): + assert EventType.user_action == "user_action" + assert EventType.agent_action == "agent_action" + assert EventType.system_event == "system_event" + + def test_member_count(self): + assert len(EventType) == 3 + + def test_is_str_enum(self): + assert isinstance(EventType.user_action, str) + + +# ═══════════════════════════════════════════════════════════════════════════ +# OBSERVATION EVENT MODEL +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestObservationEvent: + def test_creation_with_defaults(self): + event = ObservationEvent(task_name="my-task", event_type=EventType.user_action) + assert event.task_name == "my-task" + assert event.event_type == EventType.user_action + assert event.action_data == {} + assert event.context == {} + # event_id should be a valid UUID + uuid.UUID(event.event_id) + # 
timestamp should be a non-empty ISO string + assert len(event.timestamp) > 0 + + def test_creation_with_custom_values(self): + event = ObservationEvent( + event_id="custom-id", + task_name="task-2", + event_type=EventType.agent_action, + action_data={"key": "value"}, + context={"session": "abc"}, + timestamp="2025-01-01T00:00:00+00:00", + ) + assert event.event_id == "custom-id" + assert event.action_data == {"key": "value"} + assert event.context == {"session": "abc"} + assert event.timestamp == "2025-01-01T00:00:00+00:00" + + def test_frozen(self): + event = make_event() + with pytest.raises(Exception): + event.task_name = "changed" + + +# ═══════════════════════════════════════════════════════════════════════════ +# SHADOW RECOMMENDATION MODEL +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestShadowRecommendation: + def test_creation_defaults(self): + rec = ShadowRecommendation(event_id="ev-1") + uuid.UUID(rec.recommendation_id) + assert rec.event_id == "ev-1" + assert rec.recommended_action == {} + assert rec.actual_action is None + assert rec.match_score is None + assert rec.model_used == "" + + def test_creation_full(self): + rec = ShadowRecommendation( + recommendation_id="rec-1", + event_id="ev-1", + recommended_action={"a": 1}, + actual_action={"a": 2}, + match_score=0.75, + model_used="gpt-4", + ) + assert rec.recommendation_id == "rec-1" + assert rec.match_score == 0.75 + assert rec.model_used == "gpt-4" + + def test_frozen(self): + rec = ShadowRecommendation(event_id="ev-1") + with pytest.raises(Exception): + rec.event_id = "changed" + + +# ═══════════════════════════════════════════════════════════════════════════ +# OBSERVER CONFIG MODEL +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestObserverConfig: + def test_defaults(self): + cfg = ObserverConfig() + assert cfg.enabled is False + assert cfg.context_window_size == 50 + assert cfg.shadow_recommendation_rate == 
class TestObserverConfig:
    def test_defaults(self):
        cfg = ObserverConfig()
        assert cfg.enabled is False
        assert cfg.context_window_size == 50
        assert cfg.shadow_recommendation_rate == 0.1
        assert cfg.min_context_before_recommending == 10

    def test_custom(self):
        cfg = ObserverConfig(
            enabled=True,
            context_window_size=200,
            shadow_recommendation_rate=0.5,
            min_context_before_recommending=5,
        )
        assert cfg.enabled is True
        assert cfg.context_window_size == 200
        assert cfg.shadow_recommendation_rate == 0.5
        assert cfg.min_context_before_recommending == 5

    def test_frozen(self):
        cfg = ObserverConfig()
        with pytest.raises(Exception):
            cfg.enabled = True

    def test_context_window_size_bounds(self):
        for out_of_range in (0, 1001):
            with pytest.raises(Exception):
                ObserverConfig(context_window_size=out_of_range)
        # Edge values should succeed
        ObserverConfig(context_window_size=1)
        ObserverConfig(context_window_size=1000)

    def test_shadow_recommendation_rate_bounds(self):
        for out_of_range in (-0.01, 1.01):
            with pytest.raises(Exception):
                ObserverConfig(shadow_recommendation_rate=out_of_range)
        ObserverConfig(shadow_recommendation_rate=0.0)
        ObserverConfig(shadow_recommendation_rate=1.0)

    def test_min_context_before_recommending_bounds(self):
        with pytest.raises(Exception):
            ObserverConfig(min_context_before_recommending=0)
        ObserverConfig(min_context_before_recommending=1)


# ═══════════════════════════════════════════════════════════════════════════
# MATCH SCORE HELPER
# ═══════════════════════════════════════════════════════════════════════════


class TestComputeMatchScore:
    def test_both_empty(self):
        assert _compute_match_score({}, {}) == 0.0

    def test_identical_keys(self):
        assert _compute_match_score({"a": 1, "b": 2}, {"a": 3, "b": 4}) == 1.0

    def test_no_overlap(self):
        assert _compute_match_score({"a": 1}, {"b": 2}) == 0.0

    def test_partial_overlap(self):
        # intersection = {b}, union = {a, b, c} => 1/3
        score = _compute_match_score({"a": 1, "b": 2}, {"b": 3, "c": 4})
        assert abs(score - 1.0 / 3.0) < 1e-9

    def test_one_empty(self):
        assert _compute_match_score({"a": 1}, {}) == 0.0
        assert _compute_match_score({}, {"a": 1}) == 0.0


# ═══════════════════════════════════════════════════════════════════════════
# OBSERVER — DISABLED MODE
# ═══════════════════════════════════════════════════════════════════════════


class TestObserverDisabled:
    def test_observe_is_noop(self, observer_disabled):
        observer_disabled.observe(make_event())
        assert observer_disabled.get_context("test-task") == []

    def test_record_action_returns_none(self, observer_disabled):
        outcome = observer_disabled.record_action("any-id", {"action": "do"})
        assert outcome is None


# ═══════════════════════════════════════════════════════════════════════════
# OBSERVER — OBSERVE & CONTEXT WINDOW
# ═══════════════════════════════════════════════════════════════════════════
len(observer_enabled.get_context("beta")) == 1 + + +# ═══════════════════════════════════════════════════════════════════════════ +# OBSERVER — GET CONTEXT & GET RECOMMENDATIONS (NON-EXISTING TASK) +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestObserverGetters: + def test_get_context_nonexistent_task(self, observer_enabled): + assert observer_enabled.get_context("no-such-task") == [] + + def test_get_recommendations_nonexistent_task(self, observer_enabled): + assert observer_enabled.get_recommendations("no-such-task") == [] + + +# ═══════════════════════════════════════════════════════════════════════════ +# OBSERVER — RECORD ACTION +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestObserverRecordAction: + def test_record_action_no_recommendation(self, observer_enabled): + """record_action returns None when no recommendation exists for the event.""" + event = make_event() + observer_enabled.observe(event) + result = observer_enabled.record_action(event.event_id, {"do": "something"}) + # With default config (rate=0.1, min_context=10), after 1 event no rec is generated + assert result is None + + def test_record_action_with_recommendation(self, observer_always_recommend): + """record_action returns updated ShadowRecommendation with match score.""" + obs = observer_always_recommend + event = make_event(task_name="t1") + obs.observe(event) + + # Verify recommendation was generated + recs = obs.get_recommendations("t1") + assert len(recs) == 1 + assert recs[0].actual_action is None + + # Record actual action + actual = {"key_a": 1, "key_b": 2} + result = obs.record_action(event.event_id, actual) + assert result is not None + assert result.actual_action == actual + assert result.match_score is not None + assert 0.0 <= result.match_score <= 1.0 + # The recommended_action is {} (placeholder), actual has keys -> score = 0.0 + assert result.match_score == 0.0 + + def 
test_record_action_match_score_with_overlapping_keys(self, observer_always_recommend): + """Verify match_score computation when recommended and actual share keys.""" + obs = observer_always_recommend + event = make_event(task_name="t1") + obs.observe(event) + + # Manually replace the recommendation with one that has keys + rec = obs._recommendations[event.event_id] + patched_rec = rec.model_copy(update={"recommended_action": {"a": 1, "b": 2}}) + obs._recommendations[event.event_id] = patched_rec + + actual = {"b": 99, "c": 3} + result = obs.record_action(event.event_id, actual) + # intersection={b}, union={a,b,c} => 1/3 + assert result is not None + assert abs(result.match_score - 1.0 / 3.0) < 1e-9 + + +# ═══════════════════════════════════════════════════════════════════════════ +# OBSERVER — CLEAR +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestObserverClear: + def test_clear_removes_context_and_recommendations(self, observer_always_recommend): + obs = observer_always_recommend + for _ in range(5): + obs.observe(make_event(task_name="clearing")) + assert len(obs.get_context("clearing")) == 5 + assert len(obs.get_recommendations("clearing")) > 0 + + obs.clear("clearing") + assert obs.get_context("clearing") == [] + assert obs.get_recommendations("clearing") == [] + + def test_clear_nonexistent_task_is_safe(self, observer_enabled): + observer_enabled.clear("nonexistent") # should not raise + + +# ═══════════════════════════════════════════════════════════════════════════ +# OBSERVER — MIN CONTEXT BEFORE RECOMMENDING +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestObserverMinContext: + def test_no_recommendations_before_min_context(self): + cfg = enabled_config( + shadow_recommendation_rate=1.0, + min_context_before_recommending=5, + ) + obs = Observer(cfg) + for i in range(4): + obs.observe(make_event(task_name="t")) + # 4 events < min 5 => no recommendations + assert 
obs.get_recommendations("t") == [] + + def test_recommendations_after_min_context(self): + cfg = enabled_config( + shadow_recommendation_rate=1.0, + min_context_before_recommending=5, + ) + obs = Observer(cfg) + for i in range(5): + obs.observe(make_event(task_name="t")) + # 5th event meets threshold, rate=1.0 => recommendation generated + recs = obs.get_recommendations("t") + assert len(recs) == 1 + + +# ═══════════════════════════════════════════════════════════════════════════ +# OBSERVER — SHADOW RECOMMENDATION RATE +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestObserverRecommendationRate: + def test_rate_1_always_generates(self, observer_always_recommend): + obs = observer_always_recommend + events = [make_event(task_name="t") for _ in range(20)] + for e in events: + obs.observe(e) + recs = obs.get_recommendations("t") + # All 20 events should produce recommendations (rate=1.0, min_context=1) + assert len(recs) == 20 + + def test_rate_0_never_generates(self, observer_never_recommend): + obs = observer_never_recommend + for _ in range(50): + obs.observe(make_event(task_name="t")) + recs = obs.get_recommendations("t") + assert len(recs) == 0 + + def test_probabilistic_rate(self): + """With rate=0.5, roughly half of eligible events should get recommendations.""" + cfg = enabled_config( + shadow_recommendation_rate=0.5, + min_context_before_recommending=1, + ) + obs = Observer(cfg) + n = 1000 + for _ in range(n): + obs.observe(make_event(task_name="t")) + recs = obs.get_recommendations("t") + # Allow wide margin for randomness but not 0 or n + assert 100 < len(recs) < 900 + + +# ═══════════════════════════════════════════════════════════════════════════ +# OBSERVER — RECOMMENDATIONS SURVIVE CONTEXT EVICTION +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestObserverEvictionBehavior: + def test_recommendations_persist_after_window_eviction(self): + """Recommendations remain 
accessible even after their events are evicted from context.""" + cfg = enabled_config( + context_window_size=3, + shadow_recommendation_rate=1.0, + min_context_before_recommending=1, + ) + obs = Observer(cfg) + early_events = [make_event(task_name="t") for _ in range(3)] + for e in early_events: + obs.observe(e) + + # Context has 3 events, 3 recommendations + assert len(obs.get_context("t")) == 3 + assert len(obs.get_recommendations("t")) == 3 + + # Add 3 more events to evict the earlier ones from context + for _ in range(3): + obs.observe(make_event(task_name="t")) + + assert len(obs.get_context("t")) == 3 + # All 6 recommendations should still be accessible + assert len(obs.get_recommendations("t")) == 6 + + # Can still record_action for evicted event + result = obs.record_action(early_events[0].event_id, {"x": 1}) + assert result is not None diff --git a/tests/test_pii_tokenizer.py b/tests/test_pii_tokenizer.py new file mode 100644 index 0000000..7fdcdc4 --- /dev/null +++ b/tests/test_pii_tokenizer.py @@ -0,0 +1,406 @@ +"""Tests for the PII tokenizer middleware.""" + +import pytest + +from apprentice.middleware import MiddlewareContext, MiddlewareResponse +from apprentice.pii_tokenizer import ( + PIITokenizer, + PIITokenizerConfig, + TokenRegistry, +) + + +# ============================================================================ +# PIITokenizerConfig +# ============================================================================ + + +class TestPIITokenizerConfig: + def test_defaults(self): + config = PIITokenizerConfig() + assert config.enabled_entity_types == ["email", "phone", "ssn", "credit_card"] + assert config.custom_patterns == {} + assert config.token_format == "" + + def test_custom_config(self): + config = PIITokenizerConfig( + enabled_entity_types=["email"], + custom_patterns={"ip_address": r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b"}, + token_format="[REDACTED:{type}:{hash}]", + ) + assert config.enabled_entity_types == ["email"] + assert 
"ip_address" in config.custom_patterns + assert config.token_format == "[REDACTED:{type}:{hash}]" + + def test_frozen(self): + config = PIITokenizerConfig() + with pytest.raises(Exception): + config.enabled_entity_types = ["email"] + + +# ============================================================================ +# TokenRegistry +# ============================================================================ + + +class TestTokenRegistry: + def test_tokenize_detokenize_roundtrip(self): + registry = TokenRegistry() + token = registry.tokenize("test@example.com", "email") + assert token != "test@example.com" + assert registry.detokenize(token) == "test@example.com" + + def test_deterministic_same_input(self): + registry = TokenRegistry() + token1 = registry.tokenize("test@example.com", "email") + token2 = registry.tokenize("test@example.com", "email") + assert token1 == token2 + + def test_different_values_different_tokens(self): + registry = TokenRegistry() + token1 = registry.tokenize("alice@example.com", "email") + token2 = registry.tokenize("bob@example.com", "email") + assert token1 != token2 + + def test_token_format(self): + registry = TokenRegistry(token_format="") + token = registry.tokenize("test@example.com", "email") + assert token.startswith("") + + def test_custom_token_format(self): + registry = TokenRegistry(token_format="[REDACTED:{type}:{hash}]") + token = registry.tokenize("555-12-3456", "ssn") + assert token.startswith("[REDACTED:ssn:") + assert token.endswith("]") + + def test_clear(self): + registry = TokenRegistry() + token = registry.tokenize("test@example.com", "email") + registry.clear() + # After clear, detokenize should return the token itself (unrecognized) + assert registry.detokenize(token) == token + + def test_detokenize_unknown_returns_unchanged(self): + registry = TokenRegistry() + assert registry.detokenize("not-a-token") == "not-a-token" + + +# ============================================================================ +# 
PIITokenizer — pre_process +# ============================================================================ + + +class TestPIITokenizerPreProcess: + def _make_context(self, input_data: dict) -> MiddlewareContext: + return MiddlewareContext( + request_id="req-001", + task_name="test_task", + input_data=input_data, + ) + + def test_pre_process_email(self): + tokenizer = PIITokenizer() + ctx = self._make_context({"text": "Contact alice@example.com for info."}) + result = tokenizer.pre_process(ctx) + assert "alice@example.com" not in result.input_data["text"] + assert "