From f1ab6db24a6e6954757a6714eede51a8ac22c353 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 13:52:03 +0100 Subject: [PATCH 01/15] feat(domain): add RouterConfig, RoutingProfile, ModelTier types Add smart router configuration types: RouterConfig (with Default), RoutingProfile enum (Auto/Eco/Premium/Free/Reasoning), ModelTier enum (Simple/Complex/Reasoning/Free), ClassifierConfig, TierConfig, and RouterThresholds. Add optional router field to LlmConfig. --- crates/domain/src/config/llm.rs | 164 ++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/crates/domain/src/config/llm.rs b/crates/domain/src/config/llm.rs index 1868315..893f4a6 100644 --- a/crates/domain/src/config/llm.rs +++ b/crates/domain/src/config/llm.rs @@ -39,6 +39,9 @@ pub struct LlmConfig { /// Per-model pricing for cost estimation (key = model name, e.g. "gpt-4o"). #[serde(default)] pub pricing: HashMap, + /// Smart router configuration (optional). + #[serde(default)] + pub router: Option, } impl Default for LlmConfig { @@ -52,6 +55,7 @@ impl Default for LlmConfig { roles: HashMap::new(), providers: Vec::new(), pricing: HashMap::new(), + router: None, } } } @@ -193,6 +197,104 @@ fn d_2() -> u32 { 2 } +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Smart router types +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Routing profile determines how the smart router selects a model tier. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum RoutingProfile { + #[default] + Auto, + Eco, + Premium, + Free, + Reasoning, +} + +/// Model tier for router classification. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ModelTier { + Simple, + Complex, + Reasoning, + Free, +} + +/// Smart router configuration (optional section under [llm]). 
+#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct RouterConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub default_profile: RoutingProfile, + #[serde(default)] + pub classifier: ClassifierConfig, + #[serde(default)] + pub tiers: TierConfig, + #[serde(default)] + pub thresholds: RouterThresholds, +} + +/// Embedding classifier configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClassifierConfig { + pub provider: String, + pub model: String, + pub endpoint: String, + pub cache_ttl_secs: u64, +} + +impl Default for ClassifierConfig { + fn default() -> Self { + Self { + provider: "ollama".into(), + model: "nomic-embed-text".into(), + endpoint: "http://localhost:11434".into(), + cache_ttl_secs: 300, + } + } +} + +/// Per-tier ordered list of provider/model strings. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct TierConfig { + #[serde(default)] + pub simple: Vec, + #[serde(default)] + pub complex: Vec, + #[serde(default)] + pub reasoning: Vec, + #[serde(default)] + pub free: Vec, +} + +/// Cosine similarity thresholds for the classifier. +/// +/// Each score is compared independently against the embedding centroid +/// for that tier. A prompt is assigned to the highest-scoring tier +/// that exceeds its threshold. Values are not required to be ordered. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouterThresholds { + pub simple_min_score: f64, + pub complex_min_score: f64, + pub reasoning_min_score: f64, + pub escalate_token_threshold: usize, +} + +impl Default for RouterThresholds { + fn default() -> Self { + Self { + simple_min_score: 0.6, + complex_min_score: 0.5, + reasoning_min_score: 0.55, + escalate_token_threshold: 8000, + } + } +} + // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ // Tests // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ @@ -258,4 +360,66 @@ mod tests { assert!((gpt4o.input_per_1m - 2.50).abs() < 1e-10); assert!((gpt4o.output_per_1m - 10.00).abs() < 1e-10); } + + #[test] + fn router_config_deserializes() { + let json = r#"{ + "router": { + "enabled": true, + "default_profile": "auto", + "classifier": { + "provider": "ollama", + "model": "nomic-embed-text", + "endpoint": "http://localhost:11434", + "cache_ttl_secs": 300 + }, + "tiers": { + "simple": ["deepseek/deepseek-chat"], + "complex": ["anthropic/claude-sonnet-4-20250514"], + "reasoning": ["anthropic/claude-opus-4-6"], + "free": ["venice/venice-uncensored"] + }, + "thresholds": { + "simple_min_score": 0.6, + "complex_min_score": 0.5, + "reasoning_min_score": 0.55, + "escalate_token_threshold": 8000 + } + } + }"#; + let config: LlmConfig = serde_json::from_str(json).unwrap(); + let router = config.router.unwrap(); + assert!(router.enabled); + assert_eq!(router.default_profile, RoutingProfile::Auto); + assert_eq!(router.classifier.model, "nomic-embed-text"); + assert_eq!(router.tiers.simple.len(), 1); + assert!((router.thresholds.simple_min_score - 0.6).abs() < 1e-10); + } + + #[test] + fn router_config_defaults_when_absent() { + let json = r#"{}"#; + let config: LlmConfig = serde_json::from_str(json).unwrap(); + assert!(config.router.is_none()); + } + + #[test] + fn routing_profile_serde_roundtrip() { + for profile in &["auto", "eco", "premium", "free", "reasoning"] { + let 
json = format!("\"{}\"", profile); + let parsed: RoutingProfile = serde_json::from_str(&json).unwrap(); + let back = serde_json::to_string(&parsed).unwrap(); + assert_eq!(back, json); + } + } + + #[test] + fn model_tier_serde_roundtrip() { + for tier in &["simple", "complex", "reasoning", "free"] { + let json = format!("\"{}\"", tier); + let parsed: ModelTier = serde_json::from_str(&json).unwrap(); + let back = serde_json::to_string(&parsed).unwrap(); + assert_eq!(back, json); + } + } } From d2926c789d73ba1d21507705e1fe1d8d531853a5 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 13:58:24 +0100 Subject: [PATCH 02/15] feat(providers): add embedding-based prompt classifier Introduces EmbeddingClassifier that uses cosine similarity between prompt embeddings and pre-computed tier centroids (Simple, Complex, Reasoning) to route prompts to the appropriate model tier. Includes an in-memory LRU cache with TTL eviction, threshold-based escalation rules, and agentic prompt detection via length heuristics. 
--- Cargo.lock | 1 + crates/providers/Cargo.toml | 1 + crates/providers/src/classifier.rs | 792 +++++++++++++++++++++++++++++ crates/providers/src/lib.rs | 1 + 4 files changed, 795 insertions(+) create mode 100644 crates/providers/src/classifier.rs diff --git a/Cargo.lock b/Cargo.lock index 6b90359..e169110 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3929,6 +3929,7 @@ dependencies = [ "fs2", "futures-core", "keyring", + "parking_lot", "reqwest 0.12.28", "sa-domain", "serde", diff --git a/crates/providers/Cargo.toml b/crates/providers/Cargo.toml index 869d9d7..fa67afd 100644 --- a/crates/providers/Cargo.toml +++ b/crates/providers/Cargo.toml @@ -20,3 +20,4 @@ chrono = { workspace = true } dirs = "5" fs2 = "0.4" tempfile = { workspace = true } +parking_lot = { workspace = true } diff --git a/crates/providers/src/classifier.rs b/crates/providers/src/classifier.rs new file mode 100644 index 0000000..a35055d --- /dev/null +++ b/crates/providers/src/classifier.rs @@ -0,0 +1,792 @@ +//! Embedding-based prompt classifier for smart model routing. +//! +//! Uses cosine similarity between prompt embeddings and pre-computed tier +//! centroids to classify incoming prompts as Simple, Complex, or Reasoning. +//! Embeddings are fetched from an Ollama-compatible endpoint and cached +//! in-memory with TTL-based eviction. + +use parking_lot::RwLock; +use sa_domain::config::{ClassifierConfig, ModelTier, RouterThresholds}; +use sa_domain::error::{Error, Result}; +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Constants +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Maximum number of cached embeddings before eviction runs. +const CACHE_MAX_ENTRIES: usize = 10_000; + +/// Timeout for individual embedding requests. +const EMBEDDING_TIMEOUT: Duration = Duration::from_millis(500); + +/// Timeout for batch initialization (fetching all reference embeddings). 
+const BATCH_TIMEOUT: Duration = Duration::from_secs(30); + +/// Approximate chars-per-token multiplier for agentic detection. +const CHARS_PER_TOKEN: usize = 4; + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Reference prompts +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Reference prompts used to build tier centroids at startup. +/// +/// Each tier gets a set of representative prompts whose embeddings are +/// averaged to form that tier's centroid vector. +pub fn default_reference_prompts() -> HashMap> { + let mut prompts = HashMap::new(); + + prompts.insert( + ModelTier::Simple, + vec![ + "What is the capital of France?", + "Convert 5 miles to kilometers", + "What time is it in Tokyo?", + "Define the word 'ephemeral'", + "How many cups in a gallon?", + "What year was the Eiffel Tower built?", + "Translate 'hello' to Spanish", + "What is 15% of 200?", + ], + ); + + prompts.insert( + ModelTier::Complex, + vec![ + "Write a Python script that scrapes a website and stores the data in a SQLite database with proper error handling", + "Explain the differences between microservices and monolithic architectures, including trade-offs for a startup", + "Design a REST API for a multi-tenant SaaS application with rate limiting and authentication", + "Refactor this legacy codebase to use dependency injection and add comprehensive test coverage", + "Create a data pipeline that ingests CSV files, validates schemas, transforms data, and loads into PostgreSQL", + "Build a React component library with TypeScript, Storybook documentation, and unit tests", + "Implement a caching strategy for a high-traffic e-commerce API with cache invalidation", + "Debug this distributed system issue where messages are being processed out of order", + ], + ); + + prompts.insert( + ModelTier::Reasoning, + vec![ + "Prove that the square root of 2 is irrational using proof by contradiction", + "Analyze the computational complexity of this 
recursive algorithm and suggest optimizations with formal proofs", + "Design a consensus protocol for a Byzantine fault-tolerant distributed system and prove its safety properties", + "Evaluate the philosophical implications of artificial general intelligence on human autonomy and free will", + "Derive the optimal strategy for this game theory problem using backward induction and Nash equilibrium", + "Compare and critically evaluate three competing theories of consciousness with respect to the hard problem", + "Analyze the economic second-order effects of universal basic income on labor markets and innovation", + "Formally verify the correctness of this concurrent data structure using temporal logic", + ], + ); + + prompts +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Vector math +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Cosine similarity between two vectors. +/// +/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` if either vector has +/// zero magnitude (avoiding division by zero). +pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + if a.len() != b.len() { + tracing::warn!( + len_a = a.len(), + len_b = b.len(), + "cosine_similarity: mismatched vector lengths, returning 0.0" + ); + return 0.0; + } + + let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let mag_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let mag_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + + if mag_a == 0.0 || mag_b == 0.0 { + return 0.0; + } + + dot / (mag_a * mag_b) +} + +/// Compute the centroid (element-wise average) of a set of vectors. +/// +/// Returns an empty vector if the input is empty. 
+pub fn compute_centroid(vectors: &[Vec]) -> Vec { + if vectors.is_empty() { + return Vec::new(); + } + + let dim = vectors[0].len(); + let count = vectors.len() as f32; + + let mut centroid = vec![0.0f32; dim]; + for v in vectors { + for (acc, val) in centroid.iter_mut().zip(v.iter()) { + *acc += val; + } + } + for val in &mut centroid { + *val /= count; + } + + centroid +} + +/// Build centroids from pre-computed reference embeddings. +/// +/// Each tier maps to a list of embedding vectors; this function computes +/// the centroid of each tier's vectors. +pub fn build_centroids( + embeddings: &HashMap>>, +) -> HashMap> { + embeddings + .iter() + .map(|(tier, vecs)| (*tier, compute_centroid(vecs))) + .collect() +} + +/// Classify a prompt embedding against tier centroids. +/// +/// Returns the best-matching tier and a map of all tier scores. +/// If centroids are empty, defaults to `ModelTier::Complex`. +pub fn classify_against_centroids( + embedding: &[f32], + centroids: &HashMap>, +) -> (ModelTier, HashMap) { + let mut scores = HashMap::new(); + let mut best_tier = ModelTier::Complex; + let mut best_score = f32::NEG_INFINITY; + + for (tier, centroid) in centroids { + let score = cosine_similarity(embedding, centroid); + scores.insert(*tier, score); + if score > best_score { + best_score = score; + best_tier = *tier; + } + } + + // When centroids are empty or all scores are tied / ambiguous, default to Complex. + if centroids.is_empty() { + return (ModelTier::Complex, scores); + } + + (best_tier, scores) +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Cache entry +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// A cached embedding vector with expiration time. 
+struct CachedEmbedding { + embedding: Vec, + expires_at: Instant, +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Classifier result +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Result of classifying a prompt. +#[derive(Debug, Clone)] +pub struct ClassifyResult { + /// The selected model tier. + pub tier: ModelTier, + /// Cosine similarity scores for each tier. + pub scores: HashMap, + /// Classification latency in milliseconds. + pub latency_ms: u64, +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Embedding classifier +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Embedding-based prompt classifier. +/// +/// Maintains pre-computed centroids for each model tier and classifies +/// incoming prompts by comparing their embeddings against those centroids. +pub struct EmbeddingClassifier { + config: ClassifierConfig, + thresholds: RouterThresholds, + centroids: HashMap>, + http: reqwest::Client, + cache: RwLock>, +} + +impl EmbeddingClassifier { + /// Create a classifier with pre-computed centroids (useful for testing + /// or when centroids are loaded from a snapshot). + pub fn with_centroids( + config: ClassifierConfig, + thresholds: RouterThresholds, + centroids: HashMap>, + ) -> Self { + Self { + config, + thresholds, + centroids, + http: reqwest::Client::new(), + cache: RwLock::new(HashMap::new()), + } + } + + /// Initialize the classifier by fetching embeddings for all reference + /// prompts and building centroids. + /// + /// This makes HTTP calls to the configured embedding endpoint. 
+ pub async fn initialize( + config: ClassifierConfig, + thresholds: RouterThresholds, + ) -> Result { + let http = reqwest::Client::builder() + .timeout(BATCH_TIMEOUT) + .build() + .map_err(|e| Error::Http(format!("failed to build HTTP client: {e}")))?; + + let reference_prompts = default_reference_prompts(); + let mut tier_embeddings: HashMap>> = HashMap::new(); + + for (tier, prompts) in &reference_prompts { + let texts: Vec<&str> = prompts.iter().copied().collect(); + let embeddings = Self::fetch_embeddings_batch(&http, &config, &texts).await?; + tier_embeddings.insert(*tier, embeddings); + } + + let centroids = build_centroids(&tier_embeddings); + + tracing::info!( + tiers = centroids.len(), + "embedding classifier initialized with centroids" + ); + + Ok(Self { + config, + thresholds, + centroids, + http, + cache: RwLock::new(HashMap::new()), + }) + } + + /// Classify a prompt into a model tier. + /// + /// 1. Checks the embedding cache. + /// 2. Fetches embedding from the provider if not cached. + /// 3. Compares against centroids. + /// 4. Applies threshold rules and agentic escalation. + pub async fn classify(&self, prompt: &str) -> Result { + let start = Instant::now(); + + // Check cache first. + let cache_key = hash_prompt(prompt); + if let Some(cached) = self.get_cached(cache_key) { + let (tier, scores) = classify_against_centroids(&cached, &self.centroids); + let final_tier = self.apply_thresholds(tier, &scores, prompt); + return Ok(ClassifyResult { + tier: final_tier, + scores, + latency_ms: start.elapsed().as_millis() as u64, + }); + } + + // Fetch embedding from the provider. + let embedding = Self::fetch_embedding(&self.http, &self.config, prompt).await?; + + // Cache the result. + self.put_cached(cache_key, &embedding); + + // Classify. 
+ let (tier, scores) = classify_against_centroids(&embedding, &self.centroids); + let final_tier = self.apply_thresholds(tier, &scores, prompt); + + Ok(ClassifyResult { + tier: final_tier, + scores, + latency_ms: start.elapsed().as_millis() as u64, + }) + } + + /// Apply threshold rules to potentially escalate or de-escalate the tier. + /// + /// Rules: + /// - If classified as Simple but score < simple_min_score, escalate to Complex. + /// - If classified as Reasoning but score < reasoning_min_score, fall back to Complex. + /// - If prompt is long (agentic), escalate Simple to Complex. + fn apply_thresholds( + &self, + tier: ModelTier, + scores: &HashMap, + prompt: &str, + ) -> ModelTier { + // Agentic detection: long prompts escalate Simple -> Complex. + let char_threshold = self.thresholds.escalate_token_threshold * CHARS_PER_TOKEN; + let after_length = if tier == ModelTier::Simple && prompt.len() > char_threshold { + tracing::debug!( + prompt_len = prompt.len(), + threshold = char_threshold, + "escalating Simple -> Complex due to prompt length" + ); + ModelTier::Complex + } else { + tier + }; + + // Threshold checks. + match after_length { + ModelTier::Simple => { + let score = scores + .get(&ModelTier::Simple) + .copied() + .unwrap_or(0.0) as f64; + if score < self.thresholds.simple_min_score { + tracing::debug!( + score, + min = self.thresholds.simple_min_score, + "escalating Simple -> Complex due to low score" + ); + ModelTier::Complex + } else { + ModelTier::Simple + } + } + ModelTier::Reasoning => { + let score = scores + .get(&ModelTier::Reasoning) + .copied() + .unwrap_or(0.0) as f64; + if score < self.thresholds.reasoning_min_score { + tracing::debug!( + score, + min = self.thresholds.reasoning_min_score, + "de-escalating Reasoning -> Complex due to low score" + ); + ModelTier::Complex + } else { + ModelTier::Reasoning + } + } + other => other, + } + } + + /// Fetch a single embedding vector from the Ollama-compatible endpoint. 
+ async fn fetch_embedding( + http: &reqwest::Client, + config: &ClassifierConfig, + text: &str, + ) -> Result> { + let url = format!("{}/api/embeddings", config.endpoint.trim_end_matches('/')); + + let body = serde_json::json!({ + "model": config.model, + "prompt": text, + }); + + let resp = http + .post(&url) + .timeout(EMBEDDING_TIMEOUT) + .json(&body) + .send() + .await + .map_err(|e| Error::Http(format!("embedding request failed: {e}")))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body_text = resp.text().await.unwrap_or_default(); + return Err(Error::Provider { + provider: config.provider.clone(), + message: format!("embedding HTTP {status}: {body_text}"), + }); + } + + let json: serde_json::Value = resp + .json() + .await + .map_err(|e| Error::Http(format!("failed to parse embedding response: {e}")))?; + + let embedding = json + .get("embedding") + .and_then(|v| v.as_array()) + .ok_or_else(|| { + Error::Provider { + provider: config.provider.clone(), + message: "response missing 'embedding' array".into(), + } + })? + .iter() + .map(|v| v.as_f64().unwrap_or(0.0) as f32) + .collect(); + + Ok(embedding) + } + + /// Fetch embeddings for multiple texts sequentially. + /// + /// Uses the batch timeout for the overall operation. Individual requests + /// use the standard embedding timeout. + async fn fetch_embeddings_batch( + http: &reqwest::Client, + config: &ClassifierConfig, + texts: &[&str], + ) -> Result>> { + let mut results = Vec::with_capacity(texts.len()); + for text in texts { + let embedding = Self::fetch_embedding(http, config, text).await?; + results.push(embedding); + } + Ok(results) + } + + /// Check whether the embedding endpoint is reachable. + pub async fn health_check(&self) -> bool { + let test_result = + Self::fetch_embedding(&self.http, &self.config, "health check").await; + test_result.is_ok() + } + + /// Get a reference to the classifier config. 
+ pub fn config(&self) -> &ClassifierConfig { + &self.config + } + + /// Get a reference to the centroids. + pub fn centroids(&self) -> &HashMap> { + &self.centroids + } + + // ── Cache helpers ────────────────────────────────────────────── + + /// Look up a cached embedding by prompt hash. Returns `None` if absent or expired. + fn get_cached(&self, key: u64) -> Option> { + let cache = self.cache.read(); + cache.get(&key).and_then(|entry| { + if Instant::now() < entry.expires_at { + Some(entry.embedding.clone()) + } else { + None + } + }) + } + + /// Store an embedding in the cache. Evicts expired entries if over capacity. + fn put_cached(&self, key: u64, embedding: &[f32]) { + let ttl = Duration::from_secs(self.config.cache_ttl_secs); + let entry = CachedEmbedding { + embedding: embedding.to_vec(), + expires_at: Instant::now() + ttl, + }; + + let mut cache = self.cache.write(); + + // Evict expired entries when over capacity. + if cache.len() >= CACHE_MAX_ENTRIES { + let now = Instant::now(); + cache.retain(|_, v| v.expires_at > now); + } + + cache.insert(key, entry); + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Helpers +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Hash a prompt string to a u64 for cache lookup. 
+fn hash_prompt(prompt: &str) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + prompt.hash(&mut hasher); + hasher.finish() +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Tests +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cosine_similarity_identical_vectors() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![1.0, 2.0, 3.0]; + let sim = cosine_similarity(&a, &b); + assert!( + (sim - 1.0).abs() < 1e-6, + "identical vectors should have similarity ~1.0, got {sim}" + ); + } + + #[test] + fn cosine_similarity_orthogonal_vectors() { + let a = vec![1.0, 0.0, 0.0]; + let b = vec![0.0, 1.0, 0.0]; + let sim = cosine_similarity(&a, &b); + assert!( + sim.abs() < 1e-6, + "orthogonal vectors should have similarity ~0.0, got {sim}" + ); + } + + #[test] + fn cosine_similarity_opposite_vectors() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![-1.0, -2.0, -3.0]; + let sim = cosine_similarity(&a, &b); + assert!( + (sim - (-1.0)).abs() < 1e-6, + "opposite vectors should have similarity ~-1.0, got {sim}" + ); + } + + #[test] + fn cosine_similarity_zero_vector_returns_zero() { + let a = vec![0.0, 0.0, 0.0]; + let b = vec![1.0, 2.0, 3.0]; + let sim = cosine_similarity(&a, &b); + assert!( + sim.abs() < 1e-6, + "zero vector should yield similarity 0.0, got {sim}" + ); + } + + #[test] + fn compute_centroid_single_vector() { + let vectors = vec![vec![1.0, 2.0, 3.0]]; + let centroid = compute_centroid(&vectors); + assert_eq!(centroid, vec![1.0, 2.0, 3.0]); + } + + #[test] + fn compute_centroid_average() { + let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0]]; + let centroid = compute_centroid(&vectors); + let expected = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]; + for (a, b) in centroid.iter().zip(expected.iter()) { + assert!( + (a - b).abs() < 1e-6, + "centroid mismatch: got {a}, expected 
{b}" + ); + } + } + + #[test] + fn compute_centroid_empty_returns_empty() { + let vectors: Vec> = vec![]; + let centroid = compute_centroid(&vectors); + assert!(centroid.is_empty()); + } + + #[test] + fn classify_with_centroids_picks_nearest() { + // Build centroids that are clearly separated in 3D space. + let mut centroids = HashMap::new(); + centroids.insert(ModelTier::Simple, vec![1.0, 0.0, 0.0]); + centroids.insert(ModelTier::Complex, vec![0.0, 1.0, 0.0]); + centroids.insert(ModelTier::Reasoning, vec![0.0, 0.0, 1.0]); + + // A vector close to the Simple centroid. + let embedding = vec![0.9, 0.1, 0.0]; + let (tier, scores) = classify_against_centroids(&embedding, ¢roids); + + assert_eq!(tier, ModelTier::Simple); + assert!(scores[&ModelTier::Simple] > scores[&ModelTier::Complex]); + assert!(scores[&ModelTier::Simple] > scores[&ModelTier::Reasoning]); + } + + #[test] + fn classify_ambiguous_defaults_to_complex() { + // Empty centroids should default to Complex. + let centroids: HashMap> = HashMap::new(); + let embedding = vec![1.0, 2.0, 3.0]; + let (tier, _scores) = classify_against_centroids(&embedding, ¢roids); + assert_eq!(tier, ModelTier::Complex); + } + + #[test] + fn build_centroids_from_embeddings() { + let mut embeddings = HashMap::new(); + embeddings.insert( + ModelTier::Simple, + vec![vec![1.0, 0.0], vec![0.8, 0.2]], + ); + embeddings.insert( + ModelTier::Complex, + vec![vec![0.0, 1.0], vec![0.2, 0.8]], + ); + + let centroids = build_centroids(&embeddings); + + assert_eq!(centroids.len(), 2); + + let simple = ¢roids[&ModelTier::Simple]; + assert!((simple[0] - 0.9).abs() < 1e-6); + assert!((simple[1] - 0.1).abs() < 1e-6); + + let complex = ¢roids[&ModelTier::Complex]; + assert!((complex[0] - 0.1).abs() < 1e-6); + assert!((complex[1] - 0.9).abs() < 1e-6); + } + + #[test] + fn default_reference_prompts_has_all_tiers() { + let prompts = default_reference_prompts(); + assert!(prompts.contains_key(&ModelTier::Simple)); + 
assert!(prompts.contains_key(&ModelTier::Complex)); + assert!(prompts.contains_key(&ModelTier::Reasoning)); + // Each tier should have multiple reference prompts. + for (_tier, texts) in &prompts { + assert!(texts.len() >= 3, "each tier should have at least 3 reference prompts"); + } + } + + #[test] + fn apply_thresholds_escalates_low_simple_score() { + let config = ClassifierConfig::default(); + let thresholds = RouterThresholds { + simple_min_score: 0.6, + complex_min_score: 0.5, + reasoning_min_score: 0.55, + escalate_token_threshold: 8000, + }; + + let classifier = EmbeddingClassifier::with_centroids( + config, + thresholds, + HashMap::new(), + ); + + // Simple score below threshold -> should escalate to Complex. + let mut scores = HashMap::new(); + scores.insert(ModelTier::Simple, 0.4_f32); // below 0.6 + scores.insert(ModelTier::Complex, 0.3_f32); + scores.insert(ModelTier::Reasoning, 0.2_f32); + + let result = classifier.apply_thresholds(ModelTier::Simple, &scores, "short prompt"); + assert_eq!(result, ModelTier::Complex); + } + + #[test] + fn apply_thresholds_deescalates_low_reasoning_score() { + let config = ClassifierConfig::default(); + let thresholds = RouterThresholds { + simple_min_score: 0.6, + complex_min_score: 0.5, + reasoning_min_score: 0.55, + escalate_token_threshold: 8000, + }; + + let classifier = EmbeddingClassifier::with_centroids( + config, + thresholds, + HashMap::new(), + ); + + // Reasoning score below threshold -> should fall back to Complex. 
+ let mut scores = HashMap::new(); + scores.insert(ModelTier::Simple, 0.2_f32); + scores.insert(ModelTier::Complex, 0.3_f32); + scores.insert(ModelTier::Reasoning, 0.4_f32); // below 0.55 + + let result = classifier.apply_thresholds(ModelTier::Reasoning, &scores, "short prompt"); + assert_eq!(result, ModelTier::Complex); + } + + #[test] + fn apply_thresholds_escalates_long_prompt() { + let config = ClassifierConfig::default(); + let thresholds = RouterThresholds { + simple_min_score: 0.3, // low threshold so score check won't trigger + complex_min_score: 0.5, + reasoning_min_score: 0.55, + escalate_token_threshold: 100, // 100 tokens * 4 chars = 400 chars + }; + + let classifier = EmbeddingClassifier::with_centroids( + config, + thresholds, + HashMap::new(), + ); + + let mut scores = HashMap::new(); + scores.insert(ModelTier::Simple, 0.9_f32); + scores.insert(ModelTier::Complex, 0.3_f32); + + // A prompt longer than 400 chars should escalate Simple -> Complex. + let long_prompt = "a".repeat(500); + let result = classifier.apply_thresholds(ModelTier::Simple, &scores, &long_prompt); + assert_eq!(result, ModelTier::Complex); + } + + #[test] + fn apply_thresholds_keeps_good_simple_score() { + let config = ClassifierConfig::default(); + let thresholds = RouterThresholds::default(); + + let classifier = EmbeddingClassifier::with_centroids( + config, + thresholds, + HashMap::new(), + ); + + let mut scores = HashMap::new(); + scores.insert(ModelTier::Simple, 0.8_f32); // above 0.6 threshold + scores.insert(ModelTier::Complex, 0.3_f32); + scores.insert(ModelTier::Reasoning, 0.2_f32); + + let result = classifier.apply_thresholds(ModelTier::Simple, &scores, "short"); + assert_eq!(result, ModelTier::Simple); + } + + #[test] + fn cache_stores_and_retrieves() { + let config = ClassifierConfig { + cache_ttl_secs: 300, + ..ClassifierConfig::default() + }; + let classifier = EmbeddingClassifier::with_centroids( + config, + RouterThresholds::default(), + HashMap::new(), + ); + + let 
key = hash_prompt("test prompt"); + let embedding = vec![1.0, 2.0, 3.0]; + + classifier.put_cached(key, &embedding); + let result = classifier.get_cached(key); + + assert!(result.is_some()); + assert_eq!(result.unwrap(), embedding); + } + + #[test] + fn cache_returns_none_for_missing() { + let classifier = EmbeddingClassifier::with_centroids( + ClassifierConfig::default(), + RouterThresholds::default(), + HashMap::new(), + ); + + let result = classifier.get_cached(999); + assert!(result.is_none()); + } + + #[test] + fn hash_prompt_deterministic() { + let h1 = hash_prompt("hello world"); + let h2 = hash_prompt("hello world"); + let h3 = hash_prompt("different prompt"); + + assert_eq!(h1, h2, "same input should produce same hash"); + assert_ne!(h1, h3, "different input should produce different hash"); + } +} diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 11461cb..369b092 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -1,6 +1,7 @@ pub mod anthropic; pub mod auth; pub mod bedrock; +pub mod classifier; pub mod google; pub mod oauth; pub mod openai_compat; From 8973b57b48af3450164c9ea255d6846b26798cdc Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:02:36 +0100 Subject: [PATCH 03/15] feat(providers): add smart router resolution logic Pure synchronous routing functions that resolve model strings from routing profiles, classified tiers, and tier configuration. Supports explicit model bypass, profile-to-tier mapping, classifier-driven Auto mode, and cross-tier fallback chains. 
--- crates/providers/src/lib.rs | 1 + crates/providers/src/smart_router.rs | 233 +++++++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 crates/providers/src/smart_router.rs diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 369b092..926cc4b 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -7,6 +7,7 @@ pub mod oauth; pub mod openai_compat; pub mod registry; pub mod router; +pub mod smart_router; pub mod traits; pub(crate) mod sse; pub(crate) mod util; diff --git a/crates/providers/src/smart_router.rs b/crates/providers/src/smart_router.rs new file mode 100644 index 0000000..04d1379 --- /dev/null +++ b/crates/providers/src/smart_router.rs @@ -0,0 +1,233 @@ +//! Smart router resolution logic. +//! +//! Pure, synchronous functions that resolve a model string from routing +//! profiles, classified tiers, and tier configuration. No HTTP, no async +//! — just deterministic decision logic. + +use sa_domain::config::{ModelTier, RoutingProfile, TierConfig}; + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Types +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// The result of a routing decision. +#[derive(Debug, Clone)] +pub struct RoutingDecision { + pub model: String, + pub tier: ModelTier, + pub profile: RoutingProfile, + pub bypassed: bool, +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Public API +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Map a fixed profile to its corresponding tier. +/// Returns `None` for `Auto` (requires classification). 
+pub fn profile_to_tier(profile: RoutingProfile) -> Option<ModelTier> {
+    match profile {
+        RoutingProfile::Auto => None,
+        RoutingProfile::Eco => Some(ModelTier::Simple),
+        RoutingProfile::Premium => Some(ModelTier::Complex),
+        RoutingProfile::Free => Some(ModelTier::Free),
+        RoutingProfile::Reasoning => Some(ModelTier::Reasoning),
+    }
+}
+
+/// Get the first available model from a tier.
+pub fn resolve_tier_model<'a>(tier: ModelTier, tiers: &'a TierConfig) -> Option<&'a str> {
+    let models = match tier {
+        ModelTier::Simple => &tiers.simple,
+        ModelTier::Complex => &tiers.complex,
+        ModelTier::Reasoning => &tiers.reasoning,
+        ModelTier::Free => &tiers.free,
+    };
+    models.first().map(|s| s.as_str())
+}
+
+/// Core resolution: explicit model > profile tier > classified tier > fallback.
+///
+/// Resolution order:
+/// 1. If `explicit_model` is `Some`, bypass the router entirely.
+/// 2. If the profile maps to a fixed tier, use that tier.
+/// 3. If the profile is `Auto`, use the `classified_tier`.
+/// 4. If no model is found in the chosen tier, walk the fallback chain.
+pub fn resolve_model_for_request(
+    explicit_model: Option<&str>,
+    profile: RoutingProfile,
+    classified_tier: Option<ModelTier>,
+    tiers: &TierConfig,
+) -> RoutingDecision {
+    // 1. Explicit model bypass.
+    if let Some(model) = explicit_model {
+        return RoutingDecision {
+            model: model.to_string(),
+            tier: ModelTier::Complex, // sensible default for explicit
+            profile,
+            bypassed: true,
+        };
+    }
+
+    // 2. Determine the target tier from profile or classification.
+    let target_tier = profile_to_tier(profile)
+        .or(classified_tier)
+        .unwrap_or(ModelTier::Complex); // fallback default
+
+    // 3. Try the target tier first, then walk fallbacks. 
+ if let Some(model) = resolve_tier_model(target_tier, tiers) { + return RoutingDecision { + model: model.to_string(), + tier: target_tier, + profile, + bypassed: false, + }; + } + + for fallback_tier in fallback_tiers(target_tier) { + if let Some(model) = resolve_tier_model(fallback_tier, tiers) { + return RoutingDecision { + model: model.to_string(), + tier: fallback_tier, + profile, + bypassed: false, + }; + } + } + + // 4. Absolute last resort — nothing configured anywhere. + RoutingDecision { + model: String::new(), + tier: target_tier, + profile, + bypassed: false, + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Internal helpers +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Tier fallback order when the target tier has no models configured. +fn fallback_tiers(starting: ModelTier) -> Vec { + match starting { + ModelTier::Simple => vec![ModelTier::Complex, ModelTier::Reasoning], + ModelTier::Complex => vec![ModelTier::Reasoning, ModelTier::Simple], + ModelTier::Reasoning => vec![ModelTier::Complex, ModelTier::Simple], + ModelTier::Free => vec![ModelTier::Simple, ModelTier::Complex, ModelTier::Reasoning], + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Tests +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[cfg(test)] +mod tests { + use super::*; + + fn test_tiers() -> TierConfig { + TierConfig { + simple: vec!["deepseek/deepseek-chat".into()], + complex: vec!["anthropic/claude-sonnet-4-20250514".into()], + reasoning: vec!["anthropic/claude-opus-4-6".into()], + free: vec!["venice/venice-uncensored".into()], + } + } + + // ── resolve_tier_model ──────────────────────────────────────── + + #[test] + fn resolve_tier_model_picks_first_in_list() { + let tiers = TierConfig { + simple: vec!["model-a".into(), "model-b".into()], + ..Default::default() + }; + assert_eq!(resolve_tier_model(ModelTier::Simple, &tiers), Some("model-a")); + } + + 
#[test] + fn resolve_tier_model_empty_tier_returns_none() { + let tiers = TierConfig::default(); + assert_eq!(resolve_tier_model(ModelTier::Simple, &tiers), None); + } + + // ── profile_to_tier ─────────────────────────────────────────── + + #[test] + fn profile_to_tier_eco_is_simple() { + assert_eq!(profile_to_tier(RoutingProfile::Eco), Some(ModelTier::Simple)); + } + + #[test] + fn profile_to_tier_premium_is_complex() { + assert_eq!(profile_to_tier(RoutingProfile::Premium), Some(ModelTier::Complex)); + } + + #[test] + fn profile_to_tier_auto_is_none() { + assert_eq!(profile_to_tier(RoutingProfile::Auto), None); + } + + // ── resolve_model_for_request ───────────────────────────────── + + #[test] + fn resolve_with_explicit_model_bypasses_router() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + Some("custom/my-model"), + RoutingProfile::Auto, + None, + &tiers, + ); + assert_eq!(decision.model, "custom/my-model"); + assert!(decision.bypassed); + } + + #[test] + fn resolve_with_eco_profile_uses_simple_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + None, + RoutingProfile::Eco, + None, + &tiers, + ); + assert_eq!(decision.model, "deepseek/deepseek-chat"); + assert_eq!(decision.tier, ModelTier::Simple); + assert!(!decision.bypassed); + } + + #[test] + fn resolve_with_auto_profile_uses_classified_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + None, + RoutingProfile::Auto, + Some(ModelTier::Reasoning), + &tiers, + ); + assert_eq!(decision.model, "anthropic/claude-opus-4-6"); + assert_eq!(decision.tier, ModelTier::Reasoning); + assert!(!decision.bypassed); + } + + #[test] + fn resolve_falls_back_across_tiers() { + // Simple tier is empty, should fall back to Complex. 
+        let tiers = TierConfig {
+            simple: vec![],
+            complex: vec!["fallback-model".into()],
+            ..Default::default()
+        };
+        let decision = resolve_model_for_request(
+            None,
+            RoutingProfile::Eco, // maps to Simple
+            None,
+            &tiers,
+        );
+        assert_eq!(decision.model, "fallback-model");
+        assert_eq!(decision.tier, ModelTier::Complex);
+        assert!(!decision.bypassed);
+    }
+}

From b2595d39474cdec43559a06cdcd8f582d8bb812f Mon Sep 17 00:00:00 2001
From: sblanchard
Date: Sat, 21 Feb 2026 14:04:00 +0100
Subject: [PATCH 04/15] feat(providers): add routing decisions ring buffer

Thread-safe DecisionLog backed by parking_lot::Mutex and VecDeque that
evicts the oldest entry at capacity. Provides `record()` and `recent()`
for observability of smart-router routing choices.

---
 crates/providers/src/decisions.rs | 111 ++++++++++++++++++++++++++++++
 crates/providers/src/lib.rs       |   1 +
 2 files changed, 112 insertions(+)
 create mode 100644 crates/providers/src/decisions.rs

diff --git a/crates/providers/src/decisions.rs b/crates/providers/src/decisions.rs
new file mode 100644
index 0000000..5587d95
--- /dev/null
+++ b/crates/providers/src/decisions.rs
@@ -0,0 +1,111 @@
+use chrono::{DateTime, Utc};
+use parking_lot::Mutex;
+use sa_domain::config::{ModelTier, RoutingProfile};
+use serde::Serialize;
+use std::collections::VecDeque;
+
+/// A single routing decision record.
+#[derive(Debug, Clone, Serialize)]
+pub struct Decision {
+    pub timestamp: DateTime<Utc>,
+    pub prompt_snippet: String,
+    pub profile: RoutingProfile,
+    pub tier: ModelTier,
+    pub model: String,
+    pub latency_ms: u64,
+    pub bypassed: bool,
+}
+
+/// Thread-safe ring buffer of recent routing decisions.
+///
+/// Uses `parking_lot::Mutex` for low-overhead synchronisation.
+/// The buffer evicts the oldest entry when it reaches capacity,
+/// keeping only the most recent decisions for observability. 
+pub struct DecisionLog {
+    inner: Mutex<VecDeque<Decision>>,
+    capacity: usize,
+}
+
+impl DecisionLog {
+    /// Create a new decision log with the given maximum capacity.
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            inner: Mutex::new(VecDeque::with_capacity(capacity)),
+            capacity,
+        }
+    }
+
+    /// Record a new decision. If the buffer is at capacity the oldest
+    /// entry is evicted first.
+    pub fn record(&self, decision: Decision) {
+        let mut buf = self.inner.lock();
+        if buf.len() >= self.capacity {
+            buf.pop_front();
+        }
+        buf.push_back(decision);
+    }
+
+    /// Return the `limit` most recent decisions, newest first.
+    ///
+    /// If fewer than `limit` decisions exist, all are returned.
+    pub fn recent(&self, limit: usize) -> Vec<Decision> {
+        let buf = self.inner.lock();
+        buf.iter().rev().take(limit).cloned().collect()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Helper — build a `Decision` with a distinguishing index baked
+    /// into `prompt_snippet` so assertions can identify ordering.
+    fn make_decision(index: u64) -> Decision {
+        Decision {
+            timestamp: Utc::now(),
+            prompt_snippet: format!("prompt-{index}"),
+            profile: RoutingProfile::Auto,
+            tier: ModelTier::Simple,
+            model: "test-model".into(),
+            latency_ms: index,
+            bypassed: false,
+        }
+    }
+
+    #[test]
+    fn ring_buffer_stores_up_to_capacity() {
+        let log = DecisionLog::new(3);
+        for i in 0..5 {
+            log.record(make_decision(i));
+        }
+
+        let recent = log.recent(10);
+        assert_eq!(recent.len(), 3, "should keep at most 3 entries");
+
+        // Newest first: 4, 3, 2
+        assert_eq!(recent[0].latency_ms, 4);
+        assert_eq!(recent[1].latency_ms, 3);
+        assert_eq!(recent[2].latency_ms, 2);
+    }
+
+    #[test]
+    fn ring_buffer_recent_respects_limit() {
+        let log = DecisionLog::new(100);
+        for i in 0..50 {
+            log.record(make_decision(i));
+        }
+
+        let recent = log.recent(5);
+        assert_eq!(recent.len(), 5);
+        // Newest first: 49, 48, 47, 46, 45
+        assert_eq!(recent[0].latency_ms, 49);
+        assert_eq!(recent[4].latency_ms, 45);
+    }
+
+    #[test]
+    fn 
ring_buffer_empty() { + let log = DecisionLog::new(10); + let recent = log.recent(5); + assert!(recent.is_empty()); + } +} diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 926cc4b..a6c52f8 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -2,6 +2,7 @@ pub mod anthropic; pub mod auth; pub mod bedrock; pub mod classifier; +pub mod decisions; pub mod google; pub mod oauth; pub mod openai_compat; From 1159b6c462a10e0b830d2e48e45f9f71c23a75dd Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:08:20 +0100 Subject: [PATCH 05/15] feat(gateway): wire smart router into resolve_provider() Add SmartRouterState to AppState and update resolve_provider() to check the smart router when no explicit model override is provided. The return type now includes an optional model name so the router can specify which model within a provider to use. The resolution order becomes: explicit override -> smart router -> agent models -> role defaults -> any provider. --- crates/gateway/src/bootstrap.rs | 1 + crates/gateway/src/runtime/mod.rs | 46 +++++++++++++++++++++++------- crates/gateway/src/runtime/turn.rs | 21 ++++++++++++-- crates/gateway/src/state.rs | 14 ++++++++- 4 files changed, 67 insertions(+), 15 deletions(-) diff --git a/crates/gateway/src/bootstrap.rs b/crates/gateway/src/bootstrap.rs index 7e510ce..131f9d4 100644 --- a/crates/gateway/src/bootstrap.rs +++ b/crates/gateway/src/bootstrap.rs @@ -297,6 +297,7 @@ pub async fn build_app_state( workspace, bootstrap, llm, + smart_router: None, sessions, identity, lifecycle, diff --git a/crates/gateway/src/runtime/mod.rs b/crates/gateway/src/runtime/mod.rs index 508d16e..3f64827 100644 --- a/crates/gateway/src/runtime/mod.rs +++ b/crates/gateway/src/runtime/mod.rs @@ -78,40 +78,64 @@ pub(super) fn fire_auto_capture(state: &AppState, input: &turn::TurnInput, final /// Provider resolution order: /// 1. Explicit model override (from API request / agent.run) -/// 2. 
Agent-level model mapping (per sub-agent config) -/// 3. Global role defaults (planner/executor/summarizer) -/// 4. Any available provider +/// 2. Smart router (when enabled and no explicit override) +/// 3. Agent-level model mapping (per sub-agent config) +/// 4. Global role defaults (planner/executor/summarizer) +/// 5. Any available provider +/// +/// Returns the provider and an optional model name (when the router +/// selects a specific model within the provider). pub(super) fn resolve_provider( state: &AppState, model_override: Option<&str>, agent_ctx: Option<&agent::AgentContext>, -) -> Result, Box> { + routing_profile: Option, +) -> Result<(Arc, Option), Box> { // 1. Explicit override. if let Some(spec) = model_override { let provider_id = spec.split('/').next().unwrap_or(spec); if let Some(p) = state.llm.get(provider_id) { - return Ok(p); + let model_name = spec.split_once('/').map(|(_, m)| m.to_string()); + return Ok((p, model_name)); + } + } + + // 2. Smart router (when enabled and no explicit override). + if let Some(router) = &state.smart_router { + let profile = routing_profile.unwrap_or(router.default_profile); + // For non-Auto profiles, resolve tier directly (no classifier needed). + let tier = sa_providers::smart_router::profile_to_tier(profile); + if let Some(tier) = tier { + if let Some(model_spec) = sa_providers::smart_router::resolve_tier_model(tier, &router.tiers) { + let provider_id = model_spec.split('/').next().unwrap_or(model_spec); + if let Some(p) = state.llm.get(provider_id) { + let model_name = model_spec.split_once('/').map(|(_, m)| m.to_string()); + return Ok((p, model_name)); + } + } } + // Auto profile without classifier falls through to role-based routing. } - // 2. Agent-level model mapping. + // 3. Agent-level model mapping. 
if let Some(ctx) = agent_ctx { if let Some(spec) = ctx.models.get("executor") { let provider_id = spec.split('/').next().unwrap_or(spec); if let Some(p) = state.llm.get(provider_id) { - return Ok(p); + let model_name = spec.split_once('/').map(|(_, m)| m.to_string()); + return Ok((p, model_name)); } } } - // 3. Global role defaults. + // 4. Global role defaults. if let Some(p) = state.llm.for_role("executor") { - return Ok(p); + return Ok((p, None)); } - // 4. Any available provider. + // 5. Any available provider. if let Some((_, p)) = state.llm.iter().next() { - return Ok(p.clone()); + return Ok((p.clone(), None)); } Err("no_provider_configured: no LLM providers available. \ diff --git a/crates/gateway/src/runtime/turn.rs b/crates/gateway/src/runtime/turn.rs index 73dcca6..f1e0e01 100644 --- a/crates/gateway/src/runtime/turn.rs +++ b/crates/gateway/src/runtime/turn.rs @@ -41,6 +41,8 @@ pub(super) struct TurnContext { provider: Arc, messages: Vec, tool_defs: Arc>, + /// Model name selected by the smart router (if any). + router_model: Option, } // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ @@ -399,6 +401,7 @@ async fn run_turn_inner( provider, mut messages, tool_defs, + router_model, } = ctx; // ── Phase 2: Tool loop ─────────────────────────────────────────────── @@ -461,6 +464,17 @@ async fn run_turn_inner( ); // Call LLM (streaming). + // Determine which model name to send on the request: + // - Explicit model override (provider/model) takes priority. + // - Router-selected model is used when no explicit override is present. + let effective_model = if let Some(ref m) = input.model { + // Extract the model name from "provider/model" format. 
+ m.split_once('/').map(|(_, model_name)| model_name.to_string()) + .or_else(|| Some(m.clone())) + } else { + router_model.clone() + }; + let req = sa_providers::ChatRequest { messages: messages.clone(), tools: (*tool_defs).clone(), @@ -470,7 +484,7 @@ async fn run_turn_inner( .response_format .clone() .unwrap_or_default(), - model: input.model.clone(), + model: effective_model, }; let llm_call_span = tracing::info_span!( @@ -830,8 +844,8 @@ async fn prepare_turn_context( state: &AppState, input: &TurnInput, ) -> Result> { - // 1. Resolve the LLM provider (agent models -> global roles -> any). - let provider = resolve_provider(state, input.model.as_deref(), input.agent.as_ref())?; + // 1. Resolve the LLM provider (explicit -> router -> agent models -> global roles -> any). + let (provider, resolved_model) = resolve_provider(state, input.model.as_deref(), input.agent.as_ref(), None)?; // 2. Build system context (agent-scoped workspace/skills if present). let system_prompt = build_system_context(state, input.agent.as_ref()).await; @@ -927,5 +941,6 @@ async fn prepare_turn_context( provider, messages, tool_defs, + router_model: resolved_model, }) } diff --git a/crates/gateway/src/state.rs b/crates/gateway/src/state.rs index 78cf057..ec5fe64 100644 --- a/crates/gateway/src/state.rs +++ b/crates/gateway/src/state.rs @@ -4,8 +4,10 @@ use std::sync::Arc; use std::time::Instant; use parking_lot::RwLock; -use sa_domain::config::Config; +use sa_domain::config::{Config, RoutingProfile, TierConfig}; use sa_memory::provider::SerialMemoryProvider; +use sa_providers::classifier::EmbeddingClassifier; +use sa_providers::decisions::DecisionLog; use sa_providers::registry::ProviderRegistry; use sa_sessions::{IdentityResolver, LifecycleManager, SessionStore, TranscriptWriter}; use sa_skills::registry::SkillsRegistry; @@ -43,6 +45,14 @@ pub struct CachedToolDefs { pub policy_key: String, } +/// Smart router state (None when [llm.router] is not configured or disabled). 
+pub struct SmartRouterState { + pub classifier: Option, + pub tiers: TierConfig, + pub default_profile: RoutingProfile, + pub decisions: DecisionLog, +} + /// Shared application state passed to all API handlers. /// /// Fields are grouped by concern: @@ -58,6 +68,8 @@ pub struct AppState { pub config: Arc, pub memory: Arc, pub llm: Arc, + /// Smart LLM router (None when [llm.router] is absent or disabled). + pub smart_router: Option>, // ── Session management ──────────────────────────────────────────── pub sessions: Arc, From fdacd8146d4f8c581e200a24dddb435d19f0af65 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:11:28 +0100 Subject: [PATCH 06/15] feat(gateway): add router API endpoints (status, classify, decisions) Co-Authored-By: Claude Opus 4.6 --- crates/gateway/src/api/mod.rs | 6 + crates/gateway/src/api/router.rs | 229 +++++++++++++++++++++++++++++++ 2 files changed, 235 insertions(+) create mode 100644 crates/gateway/src/api/router.rs diff --git a/crates/gateway/src/api/mod.rs b/crates/gateway/src/api/mod.rs index 9011286..6faae46 100644 --- a/crates/gateway/src/api/mod.rs +++ b/crates/gateway/src/api/mod.rs @@ -13,6 +13,7 @@ pub mod nodes; pub mod openai_compat; pub mod providers; pub mod quota; +pub mod router; pub mod runs; pub mod schedules; pub mod sessions; @@ -109,6 +110,11 @@ pub fn router(state: AppState) -> Router { .route("/v1/tasks/:id/events", get(tasks::task_events_sse)) // Quotas (per-agent daily usage limits) .route("/v1/quotas", get(quota::get_quotas)) + // Smart router + .route("/v1/router/status", get(router::status)) + .route("/v1/router/config", put(router::update_config)) + .route("/v1/router/classify", post(router::classify)) + .route("/v1/router/decisions", get(router::decisions)) // Runs (execution tracking) .route("/v1/runs", get(runs::list_runs)) .route("/v1/runs/:id", get(runs::get_run)) diff --git a/crates/gateway/src/api/router.rs b/crates/gateway/src/api/router.rs new file mode 100644 index 
0000000..86bbbe0 --- /dev/null +++ b/crates/gateway/src/api/router.rs @@ -0,0 +1,229 @@ +//! Smart router API endpoints. +//! +//! - `GET /v1/router/status` — classifier health, active profile, tier config +//! - `PUT /v1/router/config` — update profile, tiers (stub — not yet implemented) +//! - `POST /v1/router/classify` — test: send a prompt, get back tier + scores + model +//! - `GET /v1/router/decisions` — last N routing decisions + +use axum::extract::{Json, Query, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::state::AppState; + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Response / request types +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[derive(Serialize)] +struct RouterStatusResponse { + enabled: bool, + default_profile: String, + classifier: ClassifierStatus, + tiers: HashMap>, + thresholds: HashMap, +} + +#[derive(Serialize)] +struct ClassifierStatus { + provider: String, + model: String, + connected: bool, +} + +#[derive(Deserialize)] +pub struct ClassifyRequest { + prompt: String, +} + +#[derive(Serialize)] +struct ClassifyResponse { + tier: String, + scores: HashMap, + resolved_model: String, + latency_ms: u64, +} + +#[derive(Deserialize)] +pub struct DecisionsQuery { + #[serde(default = "default_limit")] + limit: usize, +} + +fn default_limit() -> usize { + 100 +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Helper +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Build a standardized JSON error response: `{ "error": "" }`. +fn api_error(status: StatusCode, message: impl Into) -> Response { + (status, Json(serde_json::json!({ "error": message.into() }))).into_response() +} + +/// Serialize a serde-serializable value to its lowercase JSON string +/// representation (e.g. `RoutingProfile::Auto` -> `"auto"`). 
+fn ser_lowercase(value: &T) -> String { + serde_json::to_value(value) + .ok() + .and_then(|v| v.as_str().map(String::from)) + .unwrap_or_default() +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// GET /v1/router/status +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +pub async fn status(State(state): State) -> impl IntoResponse { + match &state.smart_router { + Some(router) => { + let classifier_status = match &router.classifier { + Some(c) => ClassifierStatus { + provider: c.config().provider.clone(), + model: c.config().model.clone(), + connected: true, + }, + None => ClassifierStatus { + provider: String::new(), + model: String::new(), + connected: false, + }, + }; + + let mut tiers = HashMap::new(); + tiers.insert("simple".to_string(), router.tiers.simple.clone()); + tiers.insert("complex".to_string(), router.tiers.complex.clone()); + tiers.insert("reasoning".to_string(), router.tiers.reasoning.clone()); + tiers.insert("free".to_string(), router.tiers.free.clone()); + + let thresholds = if let Some(ref rc) = state.config.llm.router { + let mut t = HashMap::new(); + t.insert( + "simple_min_score".to_string(), + serde_json::json!(rc.thresholds.simple_min_score), + ); + t.insert( + "complex_min_score".to_string(), + serde_json::json!(rc.thresholds.complex_min_score), + ); + t.insert( + "reasoning_min_score".to_string(), + serde_json::json!(rc.thresholds.reasoning_min_score), + ); + t.insert( + "escalate_token_threshold".to_string(), + serde_json::json!(rc.thresholds.escalate_token_threshold), + ); + t + } else { + HashMap::new() + }; + + let resp = RouterStatusResponse { + enabled: true, + default_profile: ser_lowercase(&router.default_profile), + classifier: classifier_status, + tiers, + thresholds, + }; + Json(serde_json::json!(resp)).into_response() + } + None => Json(serde_json::json!({ + "enabled": false, + "default_profile": "auto", + "classifier": { "provider": "", "model": "", "connected": false }, + 
"tiers": {}, + "thresholds": {} + })) + .into_response(), + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// PUT /v1/router/config (stub) +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Stub — runtime config update requires rebuilding the router. +/// Returns 501 Not Implemented until hot-reload support is added. +pub async fn update_config(State(_state): State) -> Response { + api_error( + StatusCode::NOT_IMPLEMENTED, + "runtime router config update is not yet supported", + ) +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// POST /v1/router/classify +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +pub async fn classify( + State(state): State, + Json(req): Json, +) -> Response { + let router = match &state.smart_router { + Some(r) => r, + None => return api_error(StatusCode::SERVICE_UNAVAILABLE, "smart router not enabled"), + }; + + let classifier = match &router.classifier { + Some(c) => c, + None => { + return api_error( + StatusCode::SERVICE_UNAVAILABLE, + "classifier not initialized", + ) + } + }; + + match classifier.classify(&req.prompt).await { + Ok(result) => { + let resolved = sa_providers::smart_router::resolve_model_for_request( + None, + router.default_profile, + Some(result.tier), + &router.tiers, + ); + + let scores: HashMap = result + .scores + .iter() + .map(|(k, v)| (ser_lowercase(k), *v)) + .collect(); + + let resp = ClassifyResponse { + tier: ser_lowercase(&result.tier), + scores, + resolved_model: resolved.model, + latency_ms: result.latency_ms, + }; + Json(serde_json::json!(resp)).into_response() + } + Err(e) => api_error( + StatusCode::INTERNAL_SERVER_ERROR, + format!("classification failed: {e}"), + ), + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// GET /v1/router/decisions +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +pub async fn decisions( + State(state): 
State, + Query(query): Query, +) -> impl IntoResponse { + let items = match &state.smart_router { + Some(router) => router.decisions.recent(query.limit), + None => Vec::new(), + }; + + Json(serde_json::json!({ + "decisions": items, + "count": items.len(), + })) +} From a12c0646f2d7e2c6a1140e7df9f7c06ea050d518 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:14:42 +0100 Subject: [PATCH 07/15] feat(gateway): add routing_profile field to schedules Wire routing_profile through the Schedule model, create/update API, and schedule runner so scheduled runs can use a specific router profile instead of the global default. Co-Authored-By: Claude Opus 4.6 --- crates/gateway/src/api/chat.rs | 2 ++ crates/gateway/src/api/inbound.rs | 1 + crates/gateway/src/api/openai_compat.rs | 2 ++ crates/gateway/src/api/schedules.rs | 7 +++++++ crates/gateway/src/api/tasks.rs | 1 + crates/gateway/src/cli/chat.rs | 1 + crates/gateway/src/cli/run.rs | 1 + crates/gateway/src/import/openclaw/mod.rs | 1 + crates/gateway/src/runtime/agent.rs | 1 + crates/gateway/src/runtime/digest.rs | 1 + crates/gateway/src/runtime/schedule_runner.rs | 13 +++++++++++++ crates/gateway/src/runtime/schedules/model.rs | 7 +++++++ crates/gateway/src/runtime/turn.rs | 4 +++- 13 files changed, 41 insertions(+), 1 deletion(-) diff --git a/crates/gateway/src/api/chat.rs b/crates/gateway/src/api/chat.rs index ae46edc..daf1ffd 100644 --- a/crates/gateway/src/api/chat.rs +++ b/crates/gateway/src/api/chat.rs @@ -85,6 +85,7 @@ pub async fn chat( model: body.model, response_format: body.response_format, agent: None, + routing_profile: None, }; let (_run_id, mut rx) = run_turn(state.clone(), input); @@ -209,6 +210,7 @@ pub async fn chat_stream( model: body.model, response_format: body.response_format, agent: None, + routing_profile: None, }; let (_run_id, rx) = run_turn(state.clone(), input); diff --git a/crates/gateway/src/api/inbound.rs b/crates/gateway/src/api/inbound.rs index 4eb0731..97e5a2e 100644 --- 
a/crates/gateway/src/api/inbound.rs +++ b/crates/gateway/src/api/inbound.rs @@ -460,6 +460,7 @@ pub async fn inbound( model: body.model, response_format: None, agent: None, + routing_profile: None, }; let (_run_id, mut rx) = run_turn(state.clone(), input); diff --git a/crates/gateway/src/api/openai_compat.rs b/crates/gateway/src/api/openai_compat.rs index 4b87e91..93ea7be 100644 --- a/crates/gateway/src/api/openai_compat.rs +++ b/crates/gateway/src/api/openai_compat.rs @@ -170,6 +170,7 @@ async fn chat_completions_blocking( model: Some(body.model), response_format: body.response_format, agent: None, + routing_profile: None, }; let (_run_id, mut rx) = run_turn(state, input); @@ -285,6 +286,7 @@ async fn chat_completions_stream(state: AppState, body: OpenAIChatRequest) -> im model: Some(body.model), response_format: body.response_format, agent: None, + routing_profile: None, }; let (_run_id, rx) = run_turn(state, input); diff --git a/crates/gateway/src/api/schedules.rs b/crates/gateway/src/api/schedules.rs index 9b761c7..4f744d2 100644 --- a/crates/gateway/src/api/schedules.rs +++ b/crates/gateway/src/api/schedules.rs @@ -87,6 +87,8 @@ pub struct CreateScheduleRequest { pub max_catchup_runs: usize, #[serde(default)] pub webhook_secret: Option, + #[serde(default)] + pub routing_profile: Option, } fn default_max_catchup_runs() -> usize { @@ -164,6 +166,7 @@ pub async fn create_schedule( fetch_config: req.fetch_config, max_catchup_runs: req.max_catchup_runs, webhook_secret: req.webhook_secret, + routing_profile: req.routing_profile, source_states: std::collections::HashMap::new(), last_error: None, last_error_at: None, @@ -203,6 +206,7 @@ pub struct UpdateScheduleRequest { pub fetch_config: Option, pub max_catchup_runs: Option, pub webhook_secret: Option>, + pub routing_profile: Option>, } pub async fn update_schedule( @@ -299,6 +303,9 @@ pub async fn update_schedule( if let Some(ws) = req.webhook_secret { s.webhook_secret = ws; } + if let Some(rp) = req.routing_profile 
{ + s.routing_profile = rp; + } }) .await { diff --git a/crates/gateway/src/api/tasks.rs b/crates/gateway/src/api/tasks.rs index 25aa7be..69a70d0 100644 --- a/crates/gateway/src/api/tasks.rs +++ b/crates/gateway/src/api/tasks.rs @@ -112,6 +112,7 @@ pub async fn create_task( model: body.model, response_format: None, agent: None, + routing_profile: None, }; // Enqueue the task for execution. diff --git a/crates/gateway/src/cli/chat.rs b/crates/gateway/src/cli/chat.rs index f13d6aa..14b6c3d 100644 --- a/crates/gateway/src/cli/chat.rs +++ b/crates/gateway/src/cli/chat.rs @@ -205,6 +205,7 @@ async fn send_message( model: model.clone(), response_format: None, agent: None, + routing_profile: None, }; let (_run_id, mut rx) = run_turn(state.clone(), input); diff --git a/crates/gateway/src/cli/run.rs b/crates/gateway/src/cli/run.rs index a2514bd..64570e8 100644 --- a/crates/gateway/src/cli/run.rs +++ b/crates/gateway/src/cli/run.rs @@ -43,6 +43,7 @@ pub async fn run( model, response_format: None, agent: None, + routing_profile: None, }; // 4. Run the turn and obtain the event receiver. 
diff --git a/crates/gateway/src/import/openclaw/mod.rs b/crates/gateway/src/import/openclaw/mod.rs index 9cc8c37..79a2081 100644 --- a/crates/gateway/src/import/openclaw/mod.rs +++ b/crates/gateway/src/import/openclaw/mod.rs @@ -395,6 +395,7 @@ pub async fn import_schedules( last_error_at: None, consecutive_failures: 0, cooldown_until: None, + routing_profile: None, webhook_secret: None, total_input_tokens: 0, total_output_tokens: 0, diff --git a/crates/gateway/src/runtime/agent.rs b/crates/gateway/src/runtime/agent.rs index 5df0b8d..bd94778 100644 --- a/crates/gateway/src/runtime/agent.rs +++ b/crates/gateway/src/runtime/agent.rs @@ -298,6 +298,7 @@ pub async fn run_agent( model, response_format: None, agent: Some(ctx), + routing_profile: None, }; let (_run_id, mut rx) = run_turn((*state).clone(), input); diff --git a/crates/gateway/src/runtime/digest.rs b/crates/gateway/src/runtime/digest.rs index d68ca46..db403e4 100644 --- a/crates/gateway/src/runtime/digest.rs +++ b/crates/gateway/src/runtime/digest.rs @@ -405,6 +405,7 @@ mod tests { last_error_at: None, consecutive_failures: 0, cooldown_until: None, + routing_profile: None, webhook_secret: None, total_input_tokens: 0, total_output_tokens: 0, diff --git a/crates/gateway/src/runtime/schedule_runner.rs b/crates/gateway/src/runtime/schedule_runner.rs index b490833..e2b8388 100644 --- a/crates/gateway/src/runtime/schedule_runner.rs +++ b/crates/gateway/src/runtime/schedule_runner.rs @@ -257,6 +257,18 @@ pub async fn spawn_scheduled_run( Utc::now().format("%Y%m%d%H%M%S") ); + let routing_profile = schedule.routing_profile.as_deref() + .and_then(|s| { + match s { + "auto" => Some(sa_domain::config::RoutingProfile::Auto), + "eco" => Some(sa_domain::config::RoutingProfile::Eco), + "premium" => Some(sa_domain::config::RoutingProfile::Premium), + "free" => Some(sa_domain::config::RoutingProfile::Free), + "reasoning" => Some(sa_domain::config::RoutingProfile::Reasoning), + _ => None, + } + }); + let input = 
crate::runtime::TurnInput { session_key, session_id, @@ -264,6 +276,7 @@ pub async fn spawn_scheduled_run( model: None, response_format: None, agent: None, + routing_profile, }; let (run_id, mut rx) = crate::runtime::run_turn(state.clone(), input); diff --git a/crates/gateway/src/runtime/schedules/model.rs b/crates/gateway/src/runtime/schedules/model.rs index 81ffccb..0fb369f 100644 --- a/crates/gateway/src/runtime/schedules/model.rs +++ b/crates/gateway/src/runtime/schedules/model.rs @@ -168,6 +168,12 @@ pub struct Schedule { #[serde(default)] pub cooldown_until: Option>, + // ── LLM routing ──────────────────────────────────────────────────── + /// Routing profile override for this schedule (e.g. "auto", "eco", "premium"). + /// None = use default profile from router config. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub routing_profile: Option, + // ── Webhook trigger ─────────────────────────────────────────────── /// HMAC-SHA256 secret for webhook trigger authentication. /// When set, `POST /v1/schedules/:id/trigger` additionally verifies @@ -287,6 +293,7 @@ mod tests { last_error_at: None, consecutive_failures, cooldown_until: None, + routing_profile: None, webhook_secret: None, total_input_tokens: 0, total_output_tokens: 0, diff --git a/crates/gateway/src/runtime/turn.rs b/crates/gateway/src/runtime/turn.rs index f1e0e01..1cd4e42 100644 --- a/crates/gateway/src/runtime/turn.rs +++ b/crates/gateway/src/runtime/turn.rs @@ -118,6 +118,8 @@ pub struct TurnInput { pub response_format: Option, /// When running as a sub-agent, carries agent-scoped overrides. pub agent: Option, + /// Routing profile override. None = use default. + pub routing_profile: Option, } // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ @@ -845,7 +847,7 @@ async fn prepare_turn_context( input: &TurnInput, ) -> Result> { // 1. Resolve the LLM provider (explicit -> router -> agent models -> global roles -> any). 
- let (provider, resolved_model) = resolve_provider(state, input.model.as_deref(), input.agent.as_ref(), None)?; + let (provider, resolved_model) = resolve_provider(state, input.model.as_deref(), input.agent.as_ref(), input.routing_profile)?; // 2. Build system context (agent-scoped workspace/skills if present). let system_prompt = build_system_context(state, input.agent.as_ref()).await; From 37d126ba190d2c4c813063feca53bdc3f4ad0916 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:16:34 +0100 Subject: [PATCH 08/15] feat(gateway): initialize smart router at startup from config --- crates/gateway/src/bootstrap.rs | 43 ++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/crates/gateway/src/bootstrap.rs b/crates/gateway/src/bootstrap.rs index 131f9d4..6b88315 100644 --- a/crates/gateway/src/bootstrap.rs +++ b/crates/gateway/src/bootstrap.rs @@ -289,6 +289,47 @@ pub async fn build_app_state( ); } + // ── Smart router ────────────────────────────────────────────────── + let smart_router = if let Some(ref router_cfg) = config.llm.router { + if router_cfg.enabled { + let classifier = match sa_providers::classifier::EmbeddingClassifier::initialize( + router_cfg.classifier.clone(), + router_cfg.thresholds.clone(), + ) + .await + { + Ok(c) => { + tracing::info!( + provider = %router_cfg.classifier.provider, + model = %router_cfg.classifier.model, + "smart router classifier initialized" + ); + Some(c) + } + Err(e) => { + tracing::warn!( + error = %e, + "smart router classifier failed to initialize, \ + routing will use fixed profiles only" + ); + None + } + }; + + Some(Arc::new(crate::state::SmartRouterState { + classifier, + tiers: router_cfg.tiers.clone(), + default_profile: router_cfg.default_profile, + decisions: sa_providers::decisions::DecisionLog::new(100), + })) + } else { + tracing::debug!("smart router configured but disabled"); + None + } + } else { + None + }; + // ── App state (without agents — needed 
for AgentManager init) ─── let mut state = AppState { config: config.clone(), @@ -297,7 +338,7 @@ pub async fn build_app_state( workspace, bootstrap, llm, - smart_router: None, + smart_router, sessions, identity, lifecycle, From b18a15ac1024569946754ea4ea3f8666e49d9e27 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:20:25 +0100 Subject: [PATCH 09/15] feat(dashboard): add router API client methods and types --- apps/dashboard/src/api/client.ts | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/apps/dashboard/src/api/client.ts b/apps/dashboard/src/api/client.ts index eda5d6f..03204ce 100644 --- a/apps/dashboard/src/api/client.ts +++ b/apps/dashboard/src/api/client.ts @@ -756,6 +756,38 @@ export type QuotaListResponse = { quotas: QuotaStatus[]; }; +// ── Router types ──────────────────────────────────────────────────── + +export type RouterStatus = { + enabled: boolean; + default_profile: string; + classifier: { + provider: string; + model: string; + connected: boolean; + avg_latency_ms: number; + }; + tiers: Record; + thresholds: Record; +}; + +export type ClassifyResult = { + tier: string; + scores: Record; + resolved_model: string; + latency_ms: number; +}; + +export type RouterDecision = { + timestamp: string; + prompt_snippet: string; + profile: string; + tier: string; + model: string; + latency_ms: number; + bypassed: boolean; +}; + // ── API functions ────────────────────────────────────────────────── export const api = { @@ -871,4 +903,11 @@ export const api = { // Provider listing providers: () => get<{ providers: string[]; count: number }>("/v1/models"), roles: () => get<{ roles: Record }>("/v1/models/roles"), + + // Router + routerStatus: () => get("/v1/router/status"), + classifyPrompt: (prompt: string) => + post("/v1/router/classify", { prompt }), + routerDecisions: (limit = 100) => + get(`/v1/router/decisions?limit=${limit}`), }; From 11fa461841818d4aeb2152f546f8fae5367a3e45 Mon Sep 17 00:00:00 2001 
From: sblanchard Date: Sat, 21 Feb 2026 14:22:21 +0100 Subject: [PATCH 10/15] fix(dashboard): make avg_latency_ms optional, fix routerDecisions response shape --- apps/dashboard/src/api/client.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/dashboard/src/api/client.ts b/apps/dashboard/src/api/client.ts index 03204ce..3aa6578 100644 --- a/apps/dashboard/src/api/client.ts +++ b/apps/dashboard/src/api/client.ts @@ -765,7 +765,7 @@ export type RouterStatus = { provider: string; model: string; connected: boolean; - avg_latency_ms: number; + avg_latency_ms?: number; }; tiers: Record; thresholds: Record; @@ -909,5 +909,5 @@ export const api = { classifyPrompt: (prompt: string) => post("/v1/router/classify", { prompt }), routerDecisions: (limit = 100) => - get(`/v1/router/decisions?limit=${limit}`), + get<{ decisions: RouterDecision[]; count: number }>(`/v1/router/decisions?limit=${limit}`), }; From f9560a735d6a5f81a3c53b5fe450588ff9405dee Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:23:56 +0100 Subject: [PATCH 11/15] feat(dashboard): add LLM Router card to Settings page --- apps/dashboard/src/pages/Settings.vue | 133 +++++++++++++++++++++++++- 1 file changed, 132 insertions(+), 1 deletion(-) diff --git a/apps/dashboard/src/pages/Settings.vue b/apps/dashboard/src/pages/Settings.vue index a0086bc..4d89071 100644 --- a/apps/dashboard/src/pages/Settings.vue +++ b/apps/dashboard/src/pages/Settings.vue @@ -2,6 +2,7 @@ import { ref, computed, onMounted } from "vue"; import { api, ApiError, setApiToken, getApiToken } from "@/api/client"; import type { SystemInfo, ReadinessResponse } from "@/api/client"; +import type { RouterStatus, RouterDecision } from "@/api/client"; import Card from "@/components/Card.vue"; import LoadingPanel from "@/components/LoadingPanel.vue"; import ConfigEditor from "@/components/ConfigEditor.vue"; @@ -20,6 +21,27 @@ const tokenSaved = ref(false); // Restart const restarting = ref(false); +// Router 
+const routerStatus = ref(null); +const routerDecisions = ref([]); +const routerLoading = ref(false); +const routerError = ref(""); +const decisionsExpanded = ref(false); + +async function loadRouter() { + routerLoading.value = true; + routerError.value = ""; + try { + routerStatus.value = await api.routerStatus(); + const res = await api.routerDecisions(20); + routerDecisions.value = res.decisions; + } catch (e: unknown) { + routerError.value = e instanceof ApiError ? e.friendly : String(e); + } finally { + routerLoading.value = false; + } +} + const generatedToml = computed(() => { if (!sysInfo.value || !readiness.value) return ""; return configToToml(sysInfo.value, readiness.value); @@ -64,7 +86,10 @@ async function load() { } } -onMounted(load); +onMounted(() => { + load(); + loadRouter(); +}); + + + +
+ + {{ routerStatus.enabled ? "Enabled" : "Disabled" }} + + {{ routerStatus.default_profile }} +
+ + +
Classifier
+
+
Provider {{ routerStatus.classifier.provider }}
+
Model {{ routerStatus.classifier.model }}
+
Status + + {{ routerStatus.classifier.connected ? "Connected" : "Disconnected" }} + +
+
+ Avg Latency + {{ routerStatus.classifier.avg_latency_ms }}ms +
+
+ + +
Tier Assignments
+
+ {{ tier }} + {{ models.join(", ") || "\u2014" }} +
+ + +
+ Recent Decisions {{ decisionsExpanded ? "\u25BE" : "\u25B8" }} +
+
+
+ {{ new Date(d.timestamp).toLocaleTimeString() }} + {{ d.tier }} + {{ d.model }} + {{ d.latency_ms }}ms + {{ d.prompt_snippet }} +
+
+
+ No routing decisions recorded yet. +
+ +

{{ routerError }}

+
@@ -297,4 +374,58 @@ button.secondary:hover { color: var(--text); border-color: var(--text-dim); } button.secondary.danger { border-color: var(--red); color: var(--red); } button.secondary.danger:hover { background: var(--red); color: #fff; } button.secondary.danger:disabled { opacity: 0.5; cursor: not-allowed; } + +/* Router card */ +.profile-badge { + background: var(--accent); + color: #fff; + padding: 0.15rem 0.5rem; + border-radius: 3px; + font-size: 0.75rem; + font-weight: 600; + text-transform: uppercase; +} +.tier-row { + display: flex; + align-items: center; + gap: 0.8rem; + padding: 0.2rem 0; + font-size: 0.82rem; +} +.tier-label { + min-width: 5rem; + color: var(--text-dim); + font-weight: 500; + text-transform: capitalize; +} +.clickable { cursor: pointer; user-select: none; } +.decisions-log { + max-height: 300px; + overflow-y: auto; + font-size: 0.78rem; +} +.decision-row { + display: flex; + align-items: center; + gap: 0.6rem; + padding: 0.2rem 0; + border-bottom: 1px solid var(--border); +} +.decision-snippet { + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} +.tier-badge { + padding: 0.1rem 0.4rem; + border-radius: 3px; + font-size: 0.7rem; + font-weight: 600; + text-transform: uppercase; +} +.tier-simple { background: var(--green); color: #000; } +.tier-complex { background: var(--accent); color: #fff; } +.tier-reasoning { background: var(--red); color: #fff; } +.tier-free { background: var(--text-dim); color: #fff; } From 54e6b5e7caecf203e8f2a5548ca8dfa1ab059526 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:26:06 +0100 Subject: [PATCH 12/15] feat(dashboard): add routing profile dropdown to schedule forms Add routing_profile field to Schedule, CreateScheduleRequest, and UpdateScheduleRequest TypeScript types. Wire a "Routing Profile" dropdown into both the create form (Schedules.vue) and edit form (ScheduleDetail.vue) with options: Default (inherit), Auto, Eco, Premium, Free, and Reasoning. 
--- apps/dashboard/src/api/client.ts | 3 +++ apps/dashboard/src/pages/ScheduleDetail.vue | 14 ++++++++++++++ apps/dashboard/src/pages/Schedules.vue | 14 ++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/apps/dashboard/src/api/client.ts b/apps/dashboard/src/api/client.ts index 3aa6578..7e36014 100644 --- a/apps/dashboard/src/api/client.ts +++ b/apps/dashboard/src/api/client.ts @@ -636,6 +636,7 @@ export type Schedule = { max_concurrency: number; timeout_ms?: number; digest_mode: DigestMode; + routing_profile?: string; fetch_config: FetchConfig; source_states: Record; max_catchup_runs: number; @@ -671,6 +672,7 @@ export type CreateScheduleRequest = { max_concurrency?: number; timeout_ms?: number; digest_mode?: DigestMode; + routing_profile?: string; fetch_config?: Partial; max_catchup_runs?: number; }; @@ -688,6 +690,7 @@ export type UpdateScheduleRequest = { max_concurrency?: number; timeout_ms?: number | null; digest_mode?: DigestMode; + routing_profile?: string; fetch_config?: Partial; max_catchup_runs?: number; }; diff --git a/apps/dashboard/src/pages/ScheduleDetail.vue b/apps/dashboard/src/pages/ScheduleDetail.vue index 0c20dac..8dfe938 100644 --- a/apps/dashboard/src/pages/ScheduleDetail.vue +++ b/apps/dashboard/src/pages/ScheduleDetail.vue @@ -33,6 +33,7 @@ const editMissedPolicy = ref("run_once"); const editDigestMode = ref("full"); const editMaxConcurrency = ref(1); const editMaxCatchupRuns = ref(5); +const editRoutingProfile = ref(""); const editTimeoutMs = ref(null); const editSubmitting = ref(false); const editError = ref(""); @@ -49,6 +50,7 @@ function startEdit() { editDigestMode.value = s.digest_mode; editMaxConcurrency.value = s.max_concurrency; editMaxCatchupRuns.value = s.max_catchup_runs; + editRoutingProfile.value = s.routing_profile ?? ""; editTimeoutMs.value = s.timeout_ms ?? 
null; editError.value = ""; editing.value = true; @@ -83,6 +85,7 @@ async function submitEdit() { digest_mode: editDigestMode.value, max_concurrency: editMaxConcurrency.value, max_catchup_runs: editMaxCatchupRuns.value, + routing_profile: editRoutingProfile.value || undefined, timeout_ms: editTimeoutMs.value, }; @@ -368,6 +371,17 @@ function goToRun(runId?: string) { +
+ + +
diff --git a/apps/dashboard/src/pages/Schedules.vue b/apps/dashboard/src/pages/Schedules.vue index 32f2ff5..48b8734 100644 --- a/apps/dashboard/src/pages/Schedules.vue +++ b/apps/dashboard/src/pages/Schedules.vue @@ -34,6 +34,7 @@ const formMissedPolicy = ref<"skip" | "run_once" | "catch_up">("run_once"); const formDigestMode = ref<"full" | "changes_only">("full"); const formMaxConcurrency = ref(1); const formMaxCatchupRuns = ref(5); +const formRoutingProfile = ref(""); const formSubmitting = ref(false); const formError = ref(""); @@ -90,6 +91,7 @@ function openForm() { formDigestMode.value = "full"; formMaxConcurrency.value = 1; formMaxCatchupRuns.value = 5; + formRoutingProfile.value = ""; formError.value = ""; } @@ -121,6 +123,7 @@ async function submitForm() { digest_mode: formDigestMode.value, max_concurrency: formMaxConcurrency.value, max_catchup_runs: formMaxCatchupRuns.value, + routing_profile: formRoutingProfile.value || undefined, }; formSubmitting.value = true; @@ -299,6 +302,17 @@ function goToSchedule(id: string) {
+
+ + +
From 369ccc7e6cc4ff777f1bd63b7a24e7de1f14a5f2 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:27:01 +0100 Subject: [PATCH 13/15] fix(dashboard): show router card error state and loading indicator --- apps/dashboard/src/pages/Settings.vue | 87 ++++++++++++++------------- 1 file changed, 46 insertions(+), 41 deletions(-) diff --git a/apps/dashboard/src/pages/Settings.vue b/apps/dashboard/src/pages/Settings.vue index 4d89071..a66a32f 100644 --- a/apps/dashboard/src/pages/Settings.vue +++ b/apps/dashboard/src/pages/Settings.vue @@ -202,53 +202,58 @@ onMounted(() => { - -
- - {{ routerStatus.enabled ? "Enabled" : "Disabled" }} - - {{ routerStatus.default_profile }} -
- - -
Classifier
-
-
Provider {{ routerStatus.classifier.provider }}
-
Model {{ routerStatus.classifier.model }}
-
Status - - {{ routerStatus.classifier.connected ? "Connected" : "Disconnected" }} + + + + +

{{ routerError }}

From 4c2c730609406177c533b4b9ee56a7652e52c316 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:29:34 +0100 Subject: [PATCH 14/15] test(providers): add smart router integration tests Full round-trip validation of the routing pipeline without requiring Ollama or any external services. Covers all five routing profiles, explicit model bypass, tier fallback chains, and decision log recording with capacity eviction. --- crates/providers/tests/router_integration.rs | 252 +++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 crates/providers/tests/router_integration.rs diff --git a/crates/providers/tests/router_integration.rs b/crates/providers/tests/router_integration.rs new file mode 100644 index 0000000..c5e41d9 --- /dev/null +++ b/crates/providers/tests/router_integration.rs @@ -0,0 +1,252 @@ +//! Integration tests for the smart router — full round-trip without Ollama. +//! +//! These tests validate the complete routing flow across multiple modules +//! (smart_router + decisions) without requiring any external services. +//! All tests are pure and deterministic. 
+ +use chrono::Utc; +use sa_domain::config::{ModelTier, RoutingProfile, TierConfig}; +use sa_providers::decisions::{Decision, DecisionLog}; +use sa_providers::smart_router::resolve_model_for_request; +use std::time::Instant; + +fn test_tiers() -> TierConfig { + TierConfig { + simple: vec!["deepseek/deepseek-chat".into()], + complex: vec!["anthropic/claude-sonnet-4-20250514".into()], + reasoning: vec!["anthropic/claude-opus-4-6".into()], + free: vec!["venice/venice-uncensored".into()], + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Profile-to-model resolution +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn eco_profile_resolves_simple_tier_and_logs_decision() { + let tiers = test_tiers(); + let decisions = DecisionLog::new(10); + + let start = Instant::now(); + let decision = resolve_model_for_request(None, RoutingProfile::Eco, None, &tiers); + let latency_ms = start.elapsed().as_millis() as u64; + + assert_eq!(decision.model, "deepseek/deepseek-chat"); + assert_eq!(decision.tier, ModelTier::Simple); + assert_eq!(decision.profile, RoutingProfile::Eco); + assert!(!decision.bypassed); + + // Log the decision to the ring buffer and verify round-trip. 
+ decisions.record(Decision { + timestamp: Utc::now(), + prompt_snippet: "test prompt".into(), + profile: decision.profile, + tier: decision.tier, + model: decision.model.clone(), + latency_ms, + bypassed: decision.bypassed, + }); + + let recent = decisions.recent(10); + assert_eq!(recent.len(), 1); + assert_eq!(recent[0].model, "deepseek/deepseek-chat"); + assert_eq!(recent[0].tier, ModelTier::Simple); +} + +#[test] +fn premium_profile_resolves_complex_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request(None, RoutingProfile::Premium, None, &tiers); + + assert_eq!(decision.model, "anthropic/claude-sonnet-4-20250514"); + assert_eq!(decision.tier, ModelTier::Complex); + assert!(!decision.bypassed); +} + +#[test] +fn reasoning_profile_resolves_reasoning_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request(None, RoutingProfile::Reasoning, None, &tiers); + + assert_eq!(decision.model, "anthropic/claude-opus-4-6"); + assert_eq!(decision.tier, ModelTier::Reasoning); + assert!(!decision.bypassed); +} + +#[test] +fn free_profile_resolves_free_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request(None, RoutingProfile::Free, None, &tiers); + + assert_eq!(decision.model, "venice/venice-uncensored"); + assert_eq!(decision.tier, ModelTier::Free); + assert!(!decision.bypassed); +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Auto profile with classification +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn auto_profile_with_classified_tier_uses_classification() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + None, + RoutingProfile::Auto, + Some(ModelTier::Reasoning), + &tiers, + ); + + assert_eq!(decision.model, "anthropic/claude-opus-4-6"); + assert_eq!(decision.tier, ModelTier::Reasoning); + assert!(!decision.bypassed); +} + +#[test] +fn auto_profile_without_classification_falls_back_to_complex() { 
+ let tiers = test_tiers(); + let decision = resolve_model_for_request(None, RoutingProfile::Auto, None, &tiers); + + assert_eq!(decision.model, "anthropic/claude-sonnet-4-20250514"); + assert_eq!(decision.tier, ModelTier::Complex); + assert!(!decision.bypassed); +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Explicit model bypass +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn explicit_model_bypasses_router() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + Some("custom/my-fine-tune"), + RoutingProfile::Eco, + Some(ModelTier::Simple), + &tiers, + ); + + assert_eq!(decision.model, "custom/my-fine-tune"); + assert!(decision.bypassed); +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Fallback behaviour +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn fallback_when_target_tier_empty() { + let tiers = TierConfig { + simple: vec![], + complex: vec!["fallback-model".into()], + reasoning: vec![], + free: vec![], + }; + + let decision = resolve_model_for_request( + None, + RoutingProfile::Eco, // maps to Simple, which is empty + None, + &tiers, + ); + + assert_eq!(decision.model, "fallback-model"); + assert_eq!(decision.tier, ModelTier::Complex); // fell back to Complex + assert!(!decision.bypassed); +} + +#[test] +fn fallback_walks_full_chain_when_multiple_tiers_empty() { + let tiers = TierConfig { + simple: vec![], + complex: vec![], + reasoning: vec!["last-resort".into()], + free: vec![], + }; + + let decision = resolve_model_for_request( + None, + RoutingProfile::Eco, // Simple -> fallback: Complex -> Reasoning + None, + &tiers, + ); + + assert_eq!(decision.model, "last-resort"); + assert_eq!(decision.tier, ModelTier::Reasoning); + assert!(!decision.bypassed); +} + +#[test] +fn all_tiers_empty_returns_empty_model() { + let tiers = TierConfig::default(); // all vecs empty + + let decision = 
resolve_model_for_request(None, RoutingProfile::Eco, None, &tiers); + + assert!(decision.model.is_empty()); + assert!(!decision.bypassed); +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Decision log round-trip +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn decision_log_round_trip_multiple_decisions() { + let tiers = test_tiers(); + let decisions = DecisionLog::new(100); + + // Simulate multiple routing decisions. + for profile in &[ + RoutingProfile::Eco, + RoutingProfile::Premium, + RoutingProfile::Reasoning, + ] { + let decision = resolve_model_for_request(None, *profile, None, &tiers); + decisions.record(Decision { + timestamp: Utc::now(), + prompt_snippet: format!("test {:?}", profile), + profile: decision.profile, + tier: decision.tier, + model: decision.model, + latency_ms: 0, + bypassed: decision.bypassed, + }); + } + + let recent = decisions.recent(10); + assert_eq!(recent.len(), 3); + // Newest first: Reasoning, Premium, Eco + assert_eq!(recent[0].tier, ModelTier::Reasoning); + assert_eq!(recent[1].tier, ModelTier::Complex); + assert_eq!(recent[2].tier, ModelTier::Simple); +} + +#[test] +fn decision_log_capacity_evicts_oldest() { + let tiers = test_tiers(); + let decisions = DecisionLog::new(2); // capacity of 2 + + for profile in &[ + RoutingProfile::Eco, + RoutingProfile::Premium, + RoutingProfile::Reasoning, + ] { + let decision = resolve_model_for_request(None, *profile, None, &tiers); + decisions.record(Decision { + timestamp: Utc::now(), + prompt_snippet: format!("test {:?}", profile), + profile: decision.profile, + tier: decision.tier, + model: decision.model, + latency_ms: 0, + bypassed: decision.bypassed, + }); + } + + let recent = decisions.recent(10); + assert_eq!(recent.len(), 2); + // Only the last two remain: Reasoning (newest) and Premium + assert_eq!(recent[0].tier, ModelTier::Reasoning); + assert_eq!(recent[1].tier, ModelTier::Complex); +} From 
80d574db9609f4f0037c64853a2673e5f2217619 Mon Sep 17 00:00:00 2001 From: sblanchard Date: Sat, 21 Feb 2026 14:35:47 +0100 Subject: [PATCH 15/15] fix: record routing decisions, validate routing_profile, fix clearing profile --- apps/dashboard/src/pages/ScheduleDetail.vue | 2 +- apps/dashboard/src/pages/Schedules.vue | 2 +- crates/gateway/src/api/schedules.rs | 14 ++++++++++++++ crates/gateway/src/runtime/mod.rs | 10 ++++++++++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/apps/dashboard/src/pages/ScheduleDetail.vue b/apps/dashboard/src/pages/ScheduleDetail.vue index 8dfe938..2209515 100644 --- a/apps/dashboard/src/pages/ScheduleDetail.vue +++ b/apps/dashboard/src/pages/ScheduleDetail.vue @@ -85,7 +85,7 @@ async function submitEdit() { digest_mode: editDigestMode.value, max_concurrency: editMaxConcurrency.value, max_catchup_runs: editMaxCatchupRuns.value, - routing_profile: editRoutingProfile.value || undefined, + routing_profile: editRoutingProfile.value === "" ? null : editRoutingProfile.value, timeout_ms: editTimeoutMs.value, }; diff --git a/apps/dashboard/src/pages/Schedules.vue b/apps/dashboard/src/pages/Schedules.vue index 48b8734..d534b27 100644 --- a/apps/dashboard/src/pages/Schedules.vue +++ b/apps/dashboard/src/pages/Schedules.vue @@ -123,7 +123,7 @@ async function submitForm() { digest_mode: formDigestMode.value, max_concurrency: formMaxConcurrency.value, max_catchup_runs: formMaxCatchupRuns.value, - routing_profile: formRoutingProfile.value || undefined, + routing_profile: formRoutingProfile.value === "" ? 
undefined : formRoutingProfile.value, }; formSubmitting.value = true; diff --git a/crates/gateway/src/api/schedules.rs b/crates/gateway/src/api/schedules.rs index 4f744d2..c120bce 100644 --- a/crates/gateway/src/api/schedules.rs +++ b/crates/gateway/src/api/schedules.rs @@ -143,6 +143,13 @@ pub async fn create_schedule( } } + // Validate routing_profile if set + if let Some(ref rp) = req.routing_profile { + if !matches!(rp.as_str(), "auto" | "eco" | "premium" | "free" | "reasoning") { + return api_error(StatusCode::BAD_REQUEST, format!("invalid routing_profile: '{rp}'")); + } + } + let now = chrono::Utc::now(); let schedule = crate::runtime::schedules::Schedule { id: uuid::Uuid::new_v4(), @@ -255,6 +262,13 @@ pub async fn update_schedule( } } + // Validate routing_profile if provided + if let Some(Some(ref rp)) = req.routing_profile { + if !matches!(rp.as_str(), "auto" | "eco" | "premium" | "free" | "reasoning") { + return api_error(StatusCode::BAD_REQUEST, format!("invalid routing_profile: '{rp}'")); + } + } + match state .schedule_store .update(&id, |s| { diff --git a/crates/gateway/src/runtime/mod.rs b/crates/gateway/src/runtime/mod.rs index 3f64827..a753786 100644 --- a/crates/gateway/src/runtime/mod.rs +++ b/crates/gateway/src/runtime/mod.rs @@ -110,6 +110,16 @@ pub(super) fn resolve_provider( let provider_id = model_spec.split('/').next().unwrap_or(model_spec); if let Some(p) = state.llm.get(provider_id) { let model_name = model_spec.split_once('/').map(|(_, m)| m.to_string()); + // Record the routing decision for observability. + router.decisions.record(sa_providers::decisions::Decision { + timestamp: chrono::Utc::now(), + prompt_snippet: String::new(), // not available here — the decision is recorded immediately, so the snippet stays empty (TODO: thread the prompt into resolve_provider) + profile, + tier, + model: model_spec.to_string(), + latency_ms: 0, + bypassed: false, + }); return Ok((p, model_name)); } }