diff --git a/Cargo.lock b/Cargo.lock index 6b90359..e169110 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3929,6 +3929,7 @@ dependencies = [ "fs2", "futures-core", "keyring", + "parking_lot", "reqwest 0.12.28", "sa-domain", "serde", diff --git a/apps/dashboard/src/api/client.ts b/apps/dashboard/src/api/client.ts index eda5d6f..7e36014 100644 --- a/apps/dashboard/src/api/client.ts +++ b/apps/dashboard/src/api/client.ts @@ -636,6 +636,7 @@ export type Schedule = { max_concurrency: number; timeout_ms?: number; digest_mode: DigestMode; + routing_profile?: string; fetch_config: FetchConfig; source_states: Record; max_catchup_runs: number; @@ -671,6 +672,7 @@ export type CreateScheduleRequest = { max_concurrency?: number; timeout_ms?: number; digest_mode?: DigestMode; + routing_profile?: string; fetch_config?: Partial; max_catchup_runs?: number; }; @@ -688,6 +690,7 @@ export type UpdateScheduleRequest = { max_concurrency?: number; timeout_ms?: number | null; digest_mode?: DigestMode; + routing_profile?: string; fetch_config?: Partial; max_catchup_runs?: number; }; @@ -756,6 +759,38 @@ export type QuotaListResponse = { quotas: QuotaStatus[]; }; +// ── Router types ──────────────────────────────────────────────────── + +export type RouterStatus = { + enabled: boolean; + default_profile: string; + classifier: { + provider: string; + model: string; + connected: boolean; + avg_latency_ms?: number; + }; + tiers: Record; + thresholds: Record; +}; + +export type ClassifyResult = { + tier: string; + scores: Record; + resolved_model: string; + latency_ms: number; +}; + +export type RouterDecision = { + timestamp: string; + prompt_snippet: string; + profile: string; + tier: string; + model: string; + latency_ms: number; + bypassed: boolean; +}; + // ── API functions ────────────────────────────────────────────────── export const api = { @@ -871,4 +906,11 @@ export const api = { // Provider listing providers: () => get<{ providers: string[]; count: number }>("/v1/models"), roles: () 
=> get<{ roles: Record }>("/v1/models/roles"), + + // Router + routerStatus: () => get("/v1/router/status"), + classifyPrompt: (prompt: string) => + post("/v1/router/classify", { prompt }), + routerDecisions: (limit = 100) => + get<{ decisions: RouterDecision[]; count: number }>(`/v1/router/decisions?limit=${limit}`), }; diff --git a/apps/dashboard/src/pages/ScheduleDetail.vue b/apps/dashboard/src/pages/ScheduleDetail.vue index 0c20dac..2209515 100644 --- a/apps/dashboard/src/pages/ScheduleDetail.vue +++ b/apps/dashboard/src/pages/ScheduleDetail.vue @@ -33,6 +33,7 @@ const editMissedPolicy = ref("run_once"); const editDigestMode = ref("full"); const editMaxConcurrency = ref(1); const editMaxCatchupRuns = ref(5); +const editRoutingProfile = ref(""); const editTimeoutMs = ref(null); const editSubmitting = ref(false); const editError = ref(""); @@ -49,6 +50,7 @@ function startEdit() { editDigestMode.value = s.digest_mode; editMaxConcurrency.value = s.max_concurrency; editMaxCatchupRuns.value = s.max_catchup_runs; + editRoutingProfile.value = s.routing_profile ?? ""; editTimeoutMs.value = s.timeout_ms ?? null; editError.value = ""; editing.value = true; @@ -83,6 +85,7 @@ async function submitEdit() { digest_mode: editDigestMode.value, max_concurrency: editMaxConcurrency.value, max_catchup_runs: editMaxCatchupRuns.value, + routing_profile: editRoutingProfile.value === "" ? null : editRoutingProfile.value, timeout_ms: editTimeoutMs.value, }; @@ -368,6 +371,17 @@ function goToRun(runId?: string) { +
+ + +
diff --git a/apps/dashboard/src/pages/Schedules.vue b/apps/dashboard/src/pages/Schedules.vue index 32f2ff5..d534b27 100644 --- a/apps/dashboard/src/pages/Schedules.vue +++ b/apps/dashboard/src/pages/Schedules.vue @@ -34,6 +34,7 @@ const formMissedPolicy = ref<"skip" | "run_once" | "catch_up">("run_once"); const formDigestMode = ref<"full" | "changes_only">("full"); const formMaxConcurrency = ref(1); const formMaxCatchupRuns = ref(5); +const formRoutingProfile = ref(""); const formSubmitting = ref(false); const formError = ref(""); @@ -90,6 +91,7 @@ function openForm() { formDigestMode.value = "full"; formMaxConcurrency.value = 1; formMaxCatchupRuns.value = 5; + formRoutingProfile.value = ""; formError.value = ""; } @@ -121,6 +123,7 @@ async function submitForm() { digest_mode: formDigestMode.value, max_concurrency: formMaxConcurrency.value, max_catchup_runs: formMaxCatchupRuns.value, + routing_profile: formRoutingProfile.value === "" ? undefined : formRoutingProfile.value, }; formSubmitting.value = true; @@ -299,6 +302,17 @@ function goToSchedule(id: string) {
+
+ + +
diff --git a/apps/dashboard/src/pages/Settings.vue b/apps/dashboard/src/pages/Settings.vue index a0086bc..a66a32f 100644 --- a/apps/dashboard/src/pages/Settings.vue +++ b/apps/dashboard/src/pages/Settings.vue @@ -2,6 +2,7 @@ import { ref, computed, onMounted } from "vue"; import { api, ApiError, setApiToken, getApiToken } from "@/api/client"; import type { SystemInfo, ReadinessResponse } from "@/api/client"; +import type { RouterStatus, RouterDecision } from "@/api/client"; import Card from "@/components/Card.vue"; import LoadingPanel from "@/components/LoadingPanel.vue"; import ConfigEditor from "@/components/ConfigEditor.vue"; @@ -20,6 +21,27 @@ const tokenSaved = ref(false); // Restart const restarting = ref(false); +// Router +const routerStatus = ref(null); +const routerDecisions = ref([]); +const routerLoading = ref(false); +const routerError = ref(""); +const decisionsExpanded = ref(false); + +async function loadRouter() { + routerLoading.value = true; + routerError.value = ""; + try { + routerStatus.value = await api.routerStatus(); + const res = await api.routerDecisions(20); + routerDecisions.value = res.decisions; + } catch (e: unknown) { + routerError.value = e instanceof ApiError ? e.friendly : String(e); + } finally { + routerLoading.value = false; + } +} + const generatedToml = computed(() => { if (!sysInfo.value || !readiness.value) return ""; return configToToml(sysInfo.value, readiness.value); @@ -64,7 +86,10 @@ async function load() { } } -onMounted(load); +onMounted(() => { + load(); + loadRouter(); +}); + + + + + + + + +

{{ routerError }}

+
@@ -297,4 +379,58 @@ button.secondary:hover { color: var(--text); border-color: var(--text-dim); } button.secondary.danger { border-color: var(--red); color: var(--red); } button.secondary.danger:hover { background: var(--red); color: #fff; } button.secondary.danger:disabled { opacity: 0.5; cursor: not-allowed; } + +/* Router card */ +.profile-badge { + background: var(--accent); + color: #fff; + padding: 0.15rem 0.5rem; + border-radius: 3px; + font-size: 0.75rem; + font-weight: 600; + text-transform: uppercase; +} +.tier-row { + display: flex; + align-items: center; + gap: 0.8rem; + padding: 0.2rem 0; + font-size: 0.82rem; +} +.tier-label { + min-width: 5rem; + color: var(--text-dim); + font-weight: 500; + text-transform: capitalize; +} +.clickable { cursor: pointer; user-select: none; } +.decisions-log { + max-height: 300px; + overflow-y: auto; + font-size: 0.78rem; +} +.decision-row { + display: flex; + align-items: center; + gap: 0.6rem; + padding: 0.2rem 0; + border-bottom: 1px solid var(--border); +} +.decision-snippet { + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} +.tier-badge { + padding: 0.1rem 0.4rem; + border-radius: 3px; + font-size: 0.7rem; + font-weight: 600; + text-transform: uppercase; +} +.tier-simple { background: var(--green); color: #000; } +.tier-complex { background: var(--accent); color: #fff; } +.tier-reasoning { background: var(--red); color: #fff; } +.tier-free { background: var(--text-dim); color: #fff; } diff --git a/crates/domain/src/config/llm.rs b/crates/domain/src/config/llm.rs index 1868315..893f4a6 100644 --- a/crates/domain/src/config/llm.rs +++ b/crates/domain/src/config/llm.rs @@ -39,6 +39,9 @@ pub struct LlmConfig { /// Per-model pricing for cost estimation (key = model name, e.g. "gpt-4o"). #[serde(default)] pub pricing: HashMap, + /// Smart router configuration (optional). 
+ #[serde(default)] + pub router: Option, } impl Default for LlmConfig { @@ -52,6 +55,7 @@ impl Default for LlmConfig { roles: HashMap::new(), providers: Vec::new(), pricing: HashMap::new(), + router: None, } } } @@ -193,6 +197,104 @@ fn d_2() -> u32 { 2 } +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Smart router types +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Routing profile determines how the smart router selects a model tier. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum RoutingProfile { + #[default] + Auto, + Eco, + Premium, + Free, + Reasoning, +} + +/// Model tier for router classification. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ModelTier { + Simple, + Complex, + Reasoning, + Free, +} + +/// Smart router configuration (optional section under [llm]). +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct RouterConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub default_profile: RoutingProfile, + #[serde(default)] + pub classifier: ClassifierConfig, + #[serde(default)] + pub tiers: TierConfig, + #[serde(default)] + pub thresholds: RouterThresholds, +} + +/// Embedding classifier configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClassifierConfig { + pub provider: String, + pub model: String, + pub endpoint: String, + pub cache_ttl_secs: u64, +} + +impl Default for ClassifierConfig { + fn default() -> Self { + Self { + provider: "ollama".into(), + model: "nomic-embed-text".into(), + endpoint: "http://localhost:11434".into(), + cache_ttl_secs: 300, + } + } +} + +/// Per-tier ordered list of provider/model strings. 
+#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct TierConfig { + #[serde(default)] + pub simple: Vec, + #[serde(default)] + pub complex: Vec, + #[serde(default)] + pub reasoning: Vec, + #[serde(default)] + pub free: Vec, +} + +/// Cosine similarity thresholds for the classifier. +/// +/// Each score is compared independently against the embedding centroid +/// for that tier. A prompt is assigned to the highest-scoring tier +/// that exceeds its threshold. Values are not required to be ordered. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouterThresholds { + pub simple_min_score: f64, + pub complex_min_score: f64, + pub reasoning_min_score: f64, + pub escalate_token_threshold: usize, +} + +impl Default for RouterThresholds { + fn default() -> Self { + Self { + simple_min_score: 0.6, + complex_min_score: 0.5, + reasoning_min_score: 0.55, + escalate_token_threshold: 8000, + } + } +} + // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ // Tests // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ @@ -258,4 +360,66 @@ mod tests { assert!((gpt4o.input_per_1m - 2.50).abs() < 1e-10); assert!((gpt4o.output_per_1m - 10.00).abs() < 1e-10); } + + #[test] + fn router_config_deserializes() { + let json = r#"{ + "router": { + "enabled": true, + "default_profile": "auto", + "classifier": { + "provider": "ollama", + "model": "nomic-embed-text", + "endpoint": "http://localhost:11434", + "cache_ttl_secs": 300 + }, + "tiers": { + "simple": ["deepseek/deepseek-chat"], + "complex": ["anthropic/claude-sonnet-4-20250514"], + "reasoning": ["anthropic/claude-opus-4-6"], + "free": ["venice/venice-uncensored"] + }, + "thresholds": { + "simple_min_score": 0.6, + "complex_min_score": 0.5, + "reasoning_min_score": 0.55, + "escalate_token_threshold": 8000 + } + } + }"#; + let config: LlmConfig = serde_json::from_str(json).unwrap(); + let router = config.router.unwrap(); + assert!(router.enabled); + 
assert_eq!(router.default_profile, RoutingProfile::Auto); + assert_eq!(router.classifier.model, "nomic-embed-text"); + assert_eq!(router.tiers.simple.len(), 1); + assert!((router.thresholds.simple_min_score - 0.6).abs() < 1e-10); + } + + #[test] + fn router_config_defaults_when_absent() { + let json = r#"{}"#; + let config: LlmConfig = serde_json::from_str(json).unwrap(); + assert!(config.router.is_none()); + } + + #[test] + fn routing_profile_serde_roundtrip() { + for profile in &["auto", "eco", "premium", "free", "reasoning"] { + let json = format!("\"{}\"", profile); + let parsed: RoutingProfile = serde_json::from_str(&json).unwrap(); + let back = serde_json::to_string(&parsed).unwrap(); + assert_eq!(back, json); + } + } + + #[test] + fn model_tier_serde_roundtrip() { + for tier in &["simple", "complex", "reasoning", "free"] { + let json = format!("\"{}\"", tier); + let parsed: ModelTier = serde_json::from_str(&json).unwrap(); + let back = serde_json::to_string(&parsed).unwrap(); + assert_eq!(back, json); + } + } } diff --git a/crates/gateway/src/api/chat.rs b/crates/gateway/src/api/chat.rs index ae46edc..daf1ffd 100644 --- a/crates/gateway/src/api/chat.rs +++ b/crates/gateway/src/api/chat.rs @@ -85,6 +85,7 @@ pub async fn chat( model: body.model, response_format: body.response_format, agent: None, + routing_profile: None, }; let (_run_id, mut rx) = run_turn(state.clone(), input); @@ -209,6 +210,7 @@ pub async fn chat_stream( model: body.model, response_format: body.response_format, agent: None, + routing_profile: None, }; let (_run_id, rx) = run_turn(state.clone(), input); diff --git a/crates/gateway/src/api/inbound.rs b/crates/gateway/src/api/inbound.rs index 4eb0731..97e5a2e 100644 --- a/crates/gateway/src/api/inbound.rs +++ b/crates/gateway/src/api/inbound.rs @@ -460,6 +460,7 @@ pub async fn inbound( model: body.model, response_format: None, agent: None, + routing_profile: None, }; let (_run_id, mut rx) = run_turn(state.clone(), input); diff --git 
a/crates/gateway/src/api/mod.rs b/crates/gateway/src/api/mod.rs index 9011286..6faae46 100644 --- a/crates/gateway/src/api/mod.rs +++ b/crates/gateway/src/api/mod.rs @@ -13,6 +13,7 @@ pub mod nodes; pub mod openai_compat; pub mod providers; pub mod quota; +pub mod router; pub mod runs; pub mod schedules; pub mod sessions; @@ -109,6 +110,11 @@ pub fn router(state: AppState) -> Router { .route("/v1/tasks/:id/events", get(tasks::task_events_sse)) // Quotas (per-agent daily usage limits) .route("/v1/quotas", get(quota::get_quotas)) + // Smart router + .route("/v1/router/status", get(router::status)) + .route("/v1/router/config", put(router::update_config)) + .route("/v1/router/classify", post(router::classify)) + .route("/v1/router/decisions", get(router::decisions)) // Runs (execution tracking) .route("/v1/runs", get(runs::list_runs)) .route("/v1/runs/:id", get(runs::get_run)) diff --git a/crates/gateway/src/api/openai_compat.rs b/crates/gateway/src/api/openai_compat.rs index 4b87e91..93ea7be 100644 --- a/crates/gateway/src/api/openai_compat.rs +++ b/crates/gateway/src/api/openai_compat.rs @@ -170,6 +170,7 @@ async fn chat_completions_blocking( model: Some(body.model), response_format: body.response_format, agent: None, + routing_profile: None, }; let (_run_id, mut rx) = run_turn(state, input); @@ -285,6 +286,7 @@ async fn chat_completions_stream(state: AppState, body: OpenAIChatRequest) -> im model: Some(body.model), response_format: body.response_format, agent: None, + routing_profile: None, }; let (_run_id, rx) = run_turn(state, input); diff --git a/crates/gateway/src/api/router.rs b/crates/gateway/src/api/router.rs new file mode 100644 index 0000000..86bbbe0 --- /dev/null +++ b/crates/gateway/src/api/router.rs @@ -0,0 +1,229 @@ +//! Smart router API endpoints. +//! +//! - `GET /v1/router/status` — classifier health, active profile, tier config +//! - `PUT /v1/router/config` — update profile, tiers (stub — not yet implemented) +//! 
- `POST /v1/router/classify` — test: send a prompt, get back tier + scores + model +//! - `GET /v1/router/decisions` — last N routing decisions + +use axum::extract::{Json, Query, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::state::AppState; + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Response / request types +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[derive(Serialize)] +struct RouterStatusResponse { + enabled: bool, + default_profile: String, + classifier: ClassifierStatus, + tiers: HashMap>, + thresholds: HashMap, +} + +#[derive(Serialize)] +struct ClassifierStatus { + provider: String, + model: String, + connected: bool, +} + +#[derive(Deserialize)] +pub struct ClassifyRequest { + prompt: String, +} + +#[derive(Serialize)] +struct ClassifyResponse { + tier: String, + scores: HashMap, + resolved_model: String, + latency_ms: u64, +} + +#[derive(Deserialize)] +pub struct DecisionsQuery { + #[serde(default = "default_limit")] + limit: usize, +} + +fn default_limit() -> usize { + 100 +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Helper +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Build a standardized JSON error response: `{ "error": "" }`. +fn api_error(status: StatusCode, message: impl Into) -> Response { + (status, Json(serde_json::json!({ "error": message.into() }))).into_response() +} + +/// Serialize a serde-serializable value to its lowercase JSON string +/// representation (e.g. `RoutingProfile::Auto` -> `"auto"`). 
+fn ser_lowercase(value: &T) -> String { + serde_json::to_value(value) + .ok() + .and_then(|v| v.as_str().map(String::from)) + .unwrap_or_default() +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// GET /v1/router/status +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +pub async fn status(State(state): State) -> impl IntoResponse { + match &state.smart_router { + Some(router) => { + let classifier_status = match &router.classifier { + Some(c) => ClassifierStatus { + provider: c.config().provider.clone(), + model: c.config().model.clone(), + connected: true, + }, + None => ClassifierStatus { + provider: String::new(), + model: String::new(), + connected: false, + }, + }; + + let mut tiers = HashMap::new(); + tiers.insert("simple".to_string(), router.tiers.simple.clone()); + tiers.insert("complex".to_string(), router.tiers.complex.clone()); + tiers.insert("reasoning".to_string(), router.tiers.reasoning.clone()); + tiers.insert("free".to_string(), router.tiers.free.clone()); + + let thresholds = if let Some(ref rc) = state.config.llm.router { + let mut t = HashMap::new(); + t.insert( + "simple_min_score".to_string(), + serde_json::json!(rc.thresholds.simple_min_score), + ); + t.insert( + "complex_min_score".to_string(), + serde_json::json!(rc.thresholds.complex_min_score), + ); + t.insert( + "reasoning_min_score".to_string(), + serde_json::json!(rc.thresholds.reasoning_min_score), + ); + t.insert( + "escalate_token_threshold".to_string(), + serde_json::json!(rc.thresholds.escalate_token_threshold), + ); + t + } else { + HashMap::new() + }; + + let resp = RouterStatusResponse { + enabled: true, + default_profile: ser_lowercase(&router.default_profile), + classifier: classifier_status, + tiers, + thresholds, + }; + Json(serde_json::json!(resp)).into_response() + } + None => Json(serde_json::json!({ + "enabled": false, + "default_profile": "auto", + "classifier": { "provider": "", "model": "", "connected": false }, + 
"tiers": {}, + "thresholds": {} + })) + .into_response(), + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// PUT /v1/router/config (stub) +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Stub — runtime config update requires rebuilding the router. +/// Returns 501 Not Implemented until hot-reload support is added. +pub async fn update_config(State(_state): State) -> Response { + api_error( + StatusCode::NOT_IMPLEMENTED, + "runtime router config update is not yet supported", + ) +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// POST /v1/router/classify +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +pub async fn classify( + State(state): State, + Json(req): Json, +) -> Response { + let router = match &state.smart_router { + Some(r) => r, + None => return api_error(StatusCode::SERVICE_UNAVAILABLE, "smart router not enabled"), + }; + + let classifier = match &router.classifier { + Some(c) => c, + None => { + return api_error( + StatusCode::SERVICE_UNAVAILABLE, + "classifier not initialized", + ) + } + }; + + match classifier.classify(&req.prompt).await { + Ok(result) => { + let resolved = sa_providers::smart_router::resolve_model_for_request( + None, + router.default_profile, + Some(result.tier), + &router.tiers, + ); + + let scores: HashMap = result + .scores + .iter() + .map(|(k, v)| (ser_lowercase(k), *v)) + .collect(); + + let resp = ClassifyResponse { + tier: ser_lowercase(&result.tier), + scores, + resolved_model: resolved.model, + latency_ms: result.latency_ms, + }; + Json(serde_json::json!(resp)).into_response() + } + Err(e) => api_error( + StatusCode::INTERNAL_SERVER_ERROR, + format!("classification failed: {e}"), + ), + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// GET /v1/router/decisions +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +pub async fn decisions( + State(state): 
State, + Query(query): Query, +) -> impl IntoResponse { + let items = match &state.smart_router { + Some(router) => router.decisions.recent(query.limit), + None => Vec::new(), + }; + + Json(serde_json::json!({ + "decisions": items, + "count": items.len(), + })) +} diff --git a/crates/gateway/src/api/schedules.rs b/crates/gateway/src/api/schedules.rs index 9b761c7..c120bce 100644 --- a/crates/gateway/src/api/schedules.rs +++ b/crates/gateway/src/api/schedules.rs @@ -87,6 +87,8 @@ pub struct CreateScheduleRequest { pub max_catchup_runs: usize, #[serde(default)] pub webhook_secret: Option, + #[serde(default)] + pub routing_profile: Option, } fn default_max_catchup_runs() -> usize { @@ -141,6 +143,13 @@ pub async fn create_schedule( } } + // Validate routing_profile if set + if let Some(ref rp) = req.routing_profile { + if !matches!(rp.as_str(), "auto" | "eco" | "premium" | "free" | "reasoning") { + return api_error(StatusCode::BAD_REQUEST, format!("invalid routing_profile: '{rp}'")); + } + } + let now = chrono::Utc::now(); let schedule = crate::runtime::schedules::Schedule { id: uuid::Uuid::new_v4(), @@ -164,6 +173,7 @@ pub async fn create_schedule( fetch_config: req.fetch_config, max_catchup_runs: req.max_catchup_runs, webhook_secret: req.webhook_secret, + routing_profile: req.routing_profile, source_states: std::collections::HashMap::new(), last_error: None, last_error_at: None, @@ -203,6 +213,7 @@ pub struct UpdateScheduleRequest { pub fetch_config: Option, pub max_catchup_runs: Option, pub webhook_secret: Option>, + pub routing_profile: Option>, } pub async fn update_schedule( @@ -251,6 +262,13 @@ pub async fn update_schedule( } } + // Validate routing_profile if provided + if let Some(Some(ref rp)) = req.routing_profile { + if !matches!(rp.as_str(), "auto" | "eco" | "premium" | "free" | "reasoning") { + return api_error(StatusCode::BAD_REQUEST, format!("invalid routing_profile: '{rp}'")); + } + } + match state .schedule_store .update(&id, |s| { @@ -299,6 +317,9 
@@ pub async fn update_schedule( if let Some(ws) = req.webhook_secret { s.webhook_secret = ws; } + if let Some(rp) = req.routing_profile { + s.routing_profile = rp; + } }) .await { diff --git a/crates/gateway/src/api/tasks.rs b/crates/gateway/src/api/tasks.rs index 25aa7be..69a70d0 100644 --- a/crates/gateway/src/api/tasks.rs +++ b/crates/gateway/src/api/tasks.rs @@ -112,6 +112,7 @@ pub async fn create_task( model: body.model, response_format: None, agent: None, + routing_profile: None, }; // Enqueue the task for execution. diff --git a/crates/gateway/src/bootstrap.rs b/crates/gateway/src/bootstrap.rs index 7e510ce..6b88315 100644 --- a/crates/gateway/src/bootstrap.rs +++ b/crates/gateway/src/bootstrap.rs @@ -289,6 +289,47 @@ pub async fn build_app_state( ); } + // ── Smart router ────────────────────────────────────────────────── + let smart_router = if let Some(ref router_cfg) = config.llm.router { + if router_cfg.enabled { + let classifier = match sa_providers::classifier::EmbeddingClassifier::initialize( + router_cfg.classifier.clone(), + router_cfg.thresholds.clone(), + ) + .await + { + Ok(c) => { + tracing::info!( + provider = %router_cfg.classifier.provider, + model = %router_cfg.classifier.model, + "smart router classifier initialized" + ); + Some(c) + } + Err(e) => { + tracing::warn!( + error = %e, + "smart router classifier failed to initialize, \ + routing will use fixed profiles only" + ); + None + } + }; + + Some(Arc::new(crate::state::SmartRouterState { + classifier, + tiers: router_cfg.tiers.clone(), + default_profile: router_cfg.default_profile, + decisions: sa_providers::decisions::DecisionLog::new(100), + })) + } else { + tracing::debug!("smart router configured but disabled"); + None + } + } else { + None + }; + // ── App state (without agents — needed for AgentManager init) ─── let mut state = AppState { config: config.clone(), @@ -297,6 +338,7 @@ pub async fn build_app_state( workspace, bootstrap, llm, + smart_router, sessions, identity, 
lifecycle, diff --git a/crates/gateway/src/cli/chat.rs b/crates/gateway/src/cli/chat.rs index f13d6aa..14b6c3d 100644 --- a/crates/gateway/src/cli/chat.rs +++ b/crates/gateway/src/cli/chat.rs @@ -205,6 +205,7 @@ async fn send_message( model: model.clone(), response_format: None, agent: None, + routing_profile: None, }; let (_run_id, mut rx) = run_turn(state.clone(), input); diff --git a/crates/gateway/src/cli/run.rs b/crates/gateway/src/cli/run.rs index a2514bd..64570e8 100644 --- a/crates/gateway/src/cli/run.rs +++ b/crates/gateway/src/cli/run.rs @@ -43,6 +43,7 @@ pub async fn run( model, response_format: None, agent: None, + routing_profile: None, }; // 4. Run the turn and obtain the event receiver. diff --git a/crates/gateway/src/import/openclaw/mod.rs b/crates/gateway/src/import/openclaw/mod.rs index 9cc8c37..79a2081 100644 --- a/crates/gateway/src/import/openclaw/mod.rs +++ b/crates/gateway/src/import/openclaw/mod.rs @@ -395,6 +395,7 @@ pub async fn import_schedules( last_error_at: None, consecutive_failures: 0, cooldown_until: None, + routing_profile: None, webhook_secret: None, total_input_tokens: 0, total_output_tokens: 0, diff --git a/crates/gateway/src/runtime/agent.rs b/crates/gateway/src/runtime/agent.rs index 5df0b8d..bd94778 100644 --- a/crates/gateway/src/runtime/agent.rs +++ b/crates/gateway/src/runtime/agent.rs @@ -298,6 +298,7 @@ pub async fn run_agent( model, response_format: None, agent: Some(ctx), + routing_profile: None, }; let (_run_id, mut rx) = run_turn((*state).clone(), input); diff --git a/crates/gateway/src/runtime/digest.rs b/crates/gateway/src/runtime/digest.rs index d68ca46..db403e4 100644 --- a/crates/gateway/src/runtime/digest.rs +++ b/crates/gateway/src/runtime/digest.rs @@ -405,6 +405,7 @@ mod tests { last_error_at: None, consecutive_failures: 0, cooldown_until: None, + routing_profile: None, webhook_secret: None, total_input_tokens: 0, total_output_tokens: 0, diff --git a/crates/gateway/src/runtime/mod.rs 
b/crates/gateway/src/runtime/mod.rs index 508d16e..a753786 100644 --- a/crates/gateway/src/runtime/mod.rs +++ b/crates/gateway/src/runtime/mod.rs @@ -78,40 +78,74 @@ pub(super) fn fire_auto_capture(state: &AppState, input: &turn::TurnInput, final /// Provider resolution order: /// 1. Explicit model override (from API request / agent.run) -/// 2. Agent-level model mapping (per sub-agent config) -/// 3. Global role defaults (planner/executor/summarizer) -/// 4. Any available provider +/// 2. Smart router (when enabled and no explicit override) +/// 3. Agent-level model mapping (per sub-agent config) +/// 4. Global role defaults (planner/executor/summarizer) +/// 5. Any available provider +/// +/// Returns the provider and an optional model name (when the router +/// selects a specific model within the provider). pub(super) fn resolve_provider( state: &AppState, model_override: Option<&str>, agent_ctx: Option<&agent::AgentContext>, -) -> Result, Box> { + routing_profile: Option, +) -> Result<(Arc, Option), Box> { // 1. Explicit override. if let Some(spec) = model_override { let provider_id = spec.split('/').next().unwrap_or(spec); if let Some(p) = state.llm.get(provider_id) { - return Ok(p); + let model_name = spec.split_once('/').map(|(_, m)| m.to_string()); + return Ok((p, model_name)); + } + } + + // 2. Smart router (when enabled and no explicit override). + if let Some(router) = &state.smart_router { + let profile = routing_profile.unwrap_or(router.default_profile); + // For non-Auto profiles, resolve tier directly (no classifier needed). 
+ let tier = sa_providers::smart_router::profile_to_tier(profile); + if let Some(tier) = tier { + if let Some(model_spec) = sa_providers::smart_router::resolve_tier_model(tier, &router.tiers) { + let provider_id = model_spec.split('/').next().unwrap_or(model_spec); + if let Some(p) = state.llm.get(provider_id) { + let model_name = model_spec.split_once('/').map(|(_, m)| m.to_string()); + // Record the routing decision for observability. + router.decisions.record(sa_providers::decisions::Decision { + timestamp: chrono::Utc::now(), + prompt_snippet: String::new(), // populated by caller + profile, + tier, + model: model_spec.to_string(), + latency_ms: 0, + bypassed: false, + }); + return Ok((p, model_name)); + } + } } + // Auto profile without classifier falls through to role-based routing. } - // 2. Agent-level model mapping. + // 3. Agent-level model mapping. if let Some(ctx) = agent_ctx { if let Some(spec) = ctx.models.get("executor") { let provider_id = spec.split('/').next().unwrap_or(spec); if let Some(p) = state.llm.get(provider_id) { - return Ok(p); + let model_name = spec.split_once('/').map(|(_, m)| m.to_string()); + return Ok((p, model_name)); } } } - // 3. Global role defaults. + // 4. Global role defaults. if let Some(p) = state.llm.for_role("executor") { - return Ok(p); + return Ok((p, None)); } - // 4. Any available provider. + // 5. Any available provider. if let Some((_, p)) = state.llm.iter().next() { - return Ok(p.clone()); + return Ok((p.clone(), None)); } Err("no_provider_configured: no LLM providers available. 
\ diff --git a/crates/gateway/src/runtime/schedule_runner.rs b/crates/gateway/src/runtime/schedule_runner.rs index b490833..e2b8388 100644 --- a/crates/gateway/src/runtime/schedule_runner.rs +++ b/crates/gateway/src/runtime/schedule_runner.rs @@ -257,6 +257,18 @@ pub async fn spawn_scheduled_run( Utc::now().format("%Y%m%d%H%M%S") ); + let routing_profile = schedule.routing_profile.as_deref() + .and_then(|s| { + match s { + "auto" => Some(sa_domain::config::RoutingProfile::Auto), + "eco" => Some(sa_domain::config::RoutingProfile::Eco), + "premium" => Some(sa_domain::config::RoutingProfile::Premium), + "free" => Some(sa_domain::config::RoutingProfile::Free), + "reasoning" => Some(sa_domain::config::RoutingProfile::Reasoning), + _ => None, + } + }); + let input = crate::runtime::TurnInput { session_key, session_id, @@ -264,6 +276,7 @@ pub async fn spawn_scheduled_run( model: None, response_format: None, agent: None, + routing_profile, }; let (run_id, mut rx) = crate::runtime::run_turn(state.clone(), input); diff --git a/crates/gateway/src/runtime/schedules/model.rs b/crates/gateway/src/runtime/schedules/model.rs index 81ffccb..0fb369f 100644 --- a/crates/gateway/src/runtime/schedules/model.rs +++ b/crates/gateway/src/runtime/schedules/model.rs @@ -168,6 +168,12 @@ pub struct Schedule { #[serde(default)] pub cooldown_until: Option>, + // ── LLM routing ──────────────────────────────────────────────────── + /// Routing profile override for this schedule (e.g. "auto", "eco", "premium"). + /// None = use default profile from router config. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub routing_profile: Option, + // ── Webhook trigger ─────────────────────────────────────────────── /// HMAC-SHA256 secret for webhook trigger authentication. 
/// When set, `POST /v1/schedules/:id/trigger` additionally verifies @@ -287,6 +293,7 @@ mod tests { last_error_at: None, consecutive_failures, cooldown_until: None, + routing_profile: None, webhook_secret: None, total_input_tokens: 0, total_output_tokens: 0, diff --git a/crates/gateway/src/runtime/turn.rs b/crates/gateway/src/runtime/turn.rs index 73dcca6..1cd4e42 100644 --- a/crates/gateway/src/runtime/turn.rs +++ b/crates/gateway/src/runtime/turn.rs @@ -41,6 +41,8 @@ pub(super) struct TurnContext { provider: Arc, messages: Vec, tool_defs: Arc>, + /// Model name selected by the smart router (if any). + router_model: Option, } // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ @@ -116,6 +118,8 @@ pub struct TurnInput { pub response_format: Option, /// When running as a sub-agent, carries agent-scoped overrides. pub agent: Option, + /// Routing profile override. None = use default. + pub routing_profile: Option, } // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ @@ -399,6 +403,7 @@ async fn run_turn_inner( provider, mut messages, tool_defs, + router_model, } = ctx; // ── Phase 2: Tool loop ─────────────────────────────────────────────── @@ -461,6 +466,17 @@ async fn run_turn_inner( ); // Call LLM (streaming). + // Determine which model name to send on the request: + // - Explicit model override (provider/model) takes priority. + // - Router-selected model is used when no explicit override is present. + let effective_model = if let Some(ref m) = input.model { + // Extract the model name from "provider/model" format. 
+ m.split_once('/').map(|(_, model_name)| model_name.to_string()) + .or_else(|| Some(m.clone())) + } else { + router_model.clone() + }; + let req = sa_providers::ChatRequest { messages: messages.clone(), tools: (*tool_defs).clone(), @@ -470,7 +486,7 @@ async fn run_turn_inner( .response_format .clone() .unwrap_or_default(), - model: input.model.clone(), + model: effective_model, }; let llm_call_span = tracing::info_span!( @@ -830,8 +846,8 @@ async fn prepare_turn_context( state: &AppState, input: &TurnInput, ) -> Result> { - // 1. Resolve the LLM provider (agent models -> global roles -> any). - let provider = resolve_provider(state, input.model.as_deref(), input.agent.as_ref())?; + // 1. Resolve the LLM provider (explicit -> router -> agent models -> global roles -> any). + let (provider, resolved_model) = resolve_provider(state, input.model.as_deref(), input.agent.as_ref(), input.routing_profile)?; // 2. Build system context (agent-scoped workspace/skills if present). let system_prompt = build_system_context(state, input.agent.as_ref()).await; @@ -927,5 +943,6 @@ async fn prepare_turn_context( provider, messages, tool_defs, + router_model: resolved_model, }) } diff --git a/crates/gateway/src/state.rs b/crates/gateway/src/state.rs index 78cf057..ec5fe64 100644 --- a/crates/gateway/src/state.rs +++ b/crates/gateway/src/state.rs @@ -4,8 +4,10 @@ use std::sync::Arc; use std::time::Instant; use parking_lot::RwLock; -use sa_domain::config::Config; +use sa_domain::config::{Config, RoutingProfile, TierConfig}; use sa_memory::provider::SerialMemoryProvider; +use sa_providers::classifier::EmbeddingClassifier; +use sa_providers::decisions::DecisionLog; use sa_providers::registry::ProviderRegistry; use sa_sessions::{IdentityResolver, LifecycleManager, SessionStore, TranscriptWriter}; use sa_skills::registry::SkillsRegistry; @@ -43,6 +45,14 @@ pub struct CachedToolDefs { pub policy_key: String, } +/// Smart router state (None when [llm.router] is not configured or disabled). 
+pub struct SmartRouterState { + pub classifier: Option, + pub tiers: TierConfig, + pub default_profile: RoutingProfile, + pub decisions: DecisionLog, +} + /// Shared application state passed to all API handlers. /// /// Fields are grouped by concern: @@ -58,6 +68,8 @@ pub struct AppState { pub config: Arc, pub memory: Arc, pub llm: Arc, + /// Smart LLM router (None when [llm.router] is absent or disabled). + pub smart_router: Option>, // ── Session management ──────────────────────────────────────────── pub sessions: Arc, diff --git a/crates/providers/Cargo.toml b/crates/providers/Cargo.toml index 869d9d7..fa67afd 100644 --- a/crates/providers/Cargo.toml +++ b/crates/providers/Cargo.toml @@ -20,3 +20,4 @@ chrono = { workspace = true } dirs = "5" fs2 = "0.4" tempfile = { workspace = true } +parking_lot = { workspace = true } diff --git a/crates/providers/src/classifier.rs b/crates/providers/src/classifier.rs new file mode 100644 index 0000000..a35055d --- /dev/null +++ b/crates/providers/src/classifier.rs @@ -0,0 +1,792 @@ +//! Embedding-based prompt classifier for smart model routing. +//! +//! Uses cosine similarity between prompt embeddings and pre-computed tier +//! centroids to classify incoming prompts as Simple, Complex, or Reasoning. +//! Embeddings are fetched from an Ollama-compatible endpoint and cached +//! in-memory with TTL-based eviction. + +use parking_lot::RwLock; +use sa_domain::config::{ClassifierConfig, ModelTier, RouterThresholds}; +use sa_domain::error::{Error, Result}; +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Constants +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Maximum number of cached embeddings before eviction runs. +const CACHE_MAX_ENTRIES: usize = 10_000; + +/// Timeout for individual embedding requests. 
+const EMBEDDING_TIMEOUT: Duration = Duration::from_millis(500); + +/// Timeout for batch initialization (fetching all reference embeddings). +const BATCH_TIMEOUT: Duration = Duration::from_secs(30); + +/// Approximate chars-per-token multiplier for agentic detection. +const CHARS_PER_TOKEN: usize = 4; + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Reference prompts +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Reference prompts used to build tier centroids at startup. +/// +/// Each tier gets a set of representative prompts whose embeddings are +/// averaged to form that tier's centroid vector. +pub fn default_reference_prompts() -> HashMap> { + let mut prompts = HashMap::new(); + + prompts.insert( + ModelTier::Simple, + vec![ + "What is the capital of France?", + "Convert 5 miles to kilometers", + "What time is it in Tokyo?", + "Define the word 'ephemeral'", + "How many cups in a gallon?", + "What year was the Eiffel Tower built?", + "Translate 'hello' to Spanish", + "What is 15% of 200?", + ], + ); + + prompts.insert( + ModelTier::Complex, + vec![ + "Write a Python script that scrapes a website and stores the data in a SQLite database with proper error handling", + "Explain the differences between microservices and monolithic architectures, including trade-offs for a startup", + "Design a REST API for a multi-tenant SaaS application with rate limiting and authentication", + "Refactor this legacy codebase to use dependency injection and add comprehensive test coverage", + "Create a data pipeline that ingests CSV files, validates schemas, transforms data, and loads into PostgreSQL", + "Build a React component library with TypeScript, Storybook documentation, and unit tests", + "Implement a caching strategy for a high-traffic e-commerce API with cache invalidation", + "Debug this distributed system issue where messages are being processed out of order", + ], + ); + + prompts.insert( + 
ModelTier::Reasoning, + vec![ + "Prove that the square root of 2 is irrational using proof by contradiction", + "Analyze the computational complexity of this recursive algorithm and suggest optimizations with formal proofs", + "Design a consensus protocol for a Byzantine fault-tolerant distributed system and prove its safety properties", + "Evaluate the philosophical implications of artificial general intelligence on human autonomy and free will", + "Derive the optimal strategy for this game theory problem using backward induction and Nash equilibrium", + "Compare and critically evaluate three competing theories of consciousness with respect to the hard problem", + "Analyze the economic second-order effects of universal basic income on labor markets and innovation", + "Formally verify the correctness of this concurrent data structure using temporal logic", + ], + ); + + prompts +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Vector math +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Cosine similarity between two vectors. +/// +/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` if either vector has +/// zero magnitude (avoiding division by zero). +pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + if a.len() != b.len() { + tracing::warn!( + len_a = a.len(), + len_b = b.len(), + "cosine_similarity: mismatched vector lengths, returning 0.0" + ); + return 0.0; + } + + let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let mag_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let mag_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + + if mag_a == 0.0 || mag_b == 0.0 { + return 0.0; + } + + dot / (mag_a * mag_b) +} + +/// Compute the centroid (element-wise average) of a set of vectors. +/// +/// Returns an empty vector if the input is empty. 
+pub fn compute_centroid(vectors: &[Vec]) -> Vec { + if vectors.is_empty() { + return Vec::new(); + } + + let dim = vectors[0].len(); + let count = vectors.len() as f32; + + let mut centroid = vec![0.0f32; dim]; + for v in vectors { + for (acc, val) in centroid.iter_mut().zip(v.iter()) { + *acc += val; + } + } + for val in &mut centroid { + *val /= count; + } + + centroid +} + +/// Build centroids from pre-computed reference embeddings. +/// +/// Each tier maps to a list of embedding vectors; this function computes +/// the centroid of each tier's vectors. +pub fn build_centroids( + embeddings: &HashMap>>, +) -> HashMap> { + embeddings + .iter() + .map(|(tier, vecs)| (*tier, compute_centroid(vecs))) + .collect() +} + +/// Classify a prompt embedding against tier centroids. +/// +/// Returns the best-matching tier and a map of all tier scores. +/// If centroids are empty, defaults to `ModelTier::Complex`. +pub fn classify_against_centroids( + embedding: &[f32], + centroids: &HashMap>, +) -> (ModelTier, HashMap) { + let mut scores = HashMap::new(); + let mut best_tier = ModelTier::Complex; + let mut best_score = f32::NEG_INFINITY; + + for (tier, centroid) in centroids { + let score = cosine_similarity(embedding, centroid); + scores.insert(*tier, score); + if score > best_score { + best_score = score; + best_tier = *tier; + } + } + + // When centroids are empty or all scores are tied / ambiguous, default to Complex. + if centroids.is_empty() { + return (ModelTier::Complex, scores); + } + + (best_tier, scores) +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Cache entry +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// A cached embedding vector with expiration time. 
+struct CachedEmbedding { + embedding: Vec, + expires_at: Instant, +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Classifier result +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Result of classifying a prompt. +#[derive(Debug, Clone)] +pub struct ClassifyResult { + /// The selected model tier. + pub tier: ModelTier, + /// Cosine similarity scores for each tier. + pub scores: HashMap, + /// Classification latency in milliseconds. + pub latency_ms: u64, +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Embedding classifier +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Embedding-based prompt classifier. +/// +/// Maintains pre-computed centroids for each model tier and classifies +/// incoming prompts by comparing their embeddings against those centroids. +pub struct EmbeddingClassifier { + config: ClassifierConfig, + thresholds: RouterThresholds, + centroids: HashMap>, + http: reqwest::Client, + cache: RwLock>, +} + +impl EmbeddingClassifier { + /// Create a classifier with pre-computed centroids (useful for testing + /// or when centroids are loaded from a snapshot). + pub fn with_centroids( + config: ClassifierConfig, + thresholds: RouterThresholds, + centroids: HashMap>, + ) -> Self { + Self { + config, + thresholds, + centroids, + http: reqwest::Client::new(), + cache: RwLock::new(HashMap::new()), + } + } + + /// Initialize the classifier by fetching embeddings for all reference + /// prompts and building centroids. + /// + /// This makes HTTP calls to the configured embedding endpoint. 
+ pub async fn initialize( + config: ClassifierConfig, + thresholds: RouterThresholds, + ) -> Result { + let http = reqwest::Client::builder() + .timeout(BATCH_TIMEOUT) + .build() + .map_err(|e| Error::Http(format!("failed to build HTTP client: {e}")))?; + + let reference_prompts = default_reference_prompts(); + let mut tier_embeddings: HashMap>> = HashMap::new(); + + for (tier, prompts) in &reference_prompts { + let texts: Vec<&str> = prompts.iter().copied().collect(); + let embeddings = Self::fetch_embeddings_batch(&http, &config, &texts).await?; + tier_embeddings.insert(*tier, embeddings); + } + + let centroids = build_centroids(&tier_embeddings); + + tracing::info!( + tiers = centroids.len(), + "embedding classifier initialized with centroids" + ); + + Ok(Self { + config, + thresholds, + centroids, + http, + cache: RwLock::new(HashMap::new()), + }) + } + + /// Classify a prompt into a model tier. + /// + /// 1. Checks the embedding cache. + /// 2. Fetches embedding from the provider if not cached. + /// 3. Compares against centroids. + /// 4. Applies threshold rules and agentic escalation. + pub async fn classify(&self, prompt: &str) -> Result { + let start = Instant::now(); + + // Check cache first. + let cache_key = hash_prompt(prompt); + if let Some(cached) = self.get_cached(cache_key) { + let (tier, scores) = classify_against_centroids(&cached, &self.centroids); + let final_tier = self.apply_thresholds(tier, &scores, prompt); + return Ok(ClassifyResult { + tier: final_tier, + scores, + latency_ms: start.elapsed().as_millis() as u64, + }); + } + + // Fetch embedding from the provider. + let embedding = Self::fetch_embedding(&self.http, &self.config, prompt).await?; + + // Cache the result. + self.put_cached(cache_key, &embedding); + + // Classify. 
+ let (tier, scores) = classify_against_centroids(&embedding, &self.centroids); + let final_tier = self.apply_thresholds(tier, &scores, prompt); + + Ok(ClassifyResult { + tier: final_tier, + scores, + latency_ms: start.elapsed().as_millis() as u64, + }) + } + + /// Apply threshold rules to potentially escalate or de-escalate the tier. + /// + /// Rules: + /// - If classified as Simple but score < simple_min_score, escalate to Complex. + /// - If classified as Reasoning but score < reasoning_min_score, fall back to Complex. + /// - If prompt is long (agentic), escalate Simple to Complex. + fn apply_thresholds( + &self, + tier: ModelTier, + scores: &HashMap, + prompt: &str, + ) -> ModelTier { + // Agentic detection: long prompts escalate Simple -> Complex. + let char_threshold = self.thresholds.escalate_token_threshold * CHARS_PER_TOKEN; + let after_length = if tier == ModelTier::Simple && prompt.len() > char_threshold { + tracing::debug!( + prompt_len = prompt.len(), + threshold = char_threshold, + "escalating Simple -> Complex due to prompt length" + ); + ModelTier::Complex + } else { + tier + }; + + // Threshold checks. + match after_length { + ModelTier::Simple => { + let score = scores + .get(&ModelTier::Simple) + .copied() + .unwrap_or(0.0) as f64; + if score < self.thresholds.simple_min_score { + tracing::debug!( + score, + min = self.thresholds.simple_min_score, + "escalating Simple -> Complex due to low score" + ); + ModelTier::Complex + } else { + ModelTier::Simple + } + } + ModelTier::Reasoning => { + let score = scores + .get(&ModelTier::Reasoning) + .copied() + .unwrap_or(0.0) as f64; + if score < self.thresholds.reasoning_min_score { + tracing::debug!( + score, + min = self.thresholds.reasoning_min_score, + "de-escalating Reasoning -> Complex due to low score" + ); + ModelTier::Complex + } else { + ModelTier::Reasoning + } + } + other => other, + } + } + + /// Fetch a single embedding vector from the Ollama-compatible endpoint. 
+ async fn fetch_embedding( + http: &reqwest::Client, + config: &ClassifierConfig, + text: &str, + ) -> Result> { + let url = format!("{}/api/embeddings", config.endpoint.trim_end_matches('/')); + + let body = serde_json::json!({ + "model": config.model, + "prompt": text, + }); + + let resp = http + .post(&url) + .timeout(EMBEDDING_TIMEOUT) + .json(&body) + .send() + .await + .map_err(|e| Error::Http(format!("embedding request failed: {e}")))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body_text = resp.text().await.unwrap_or_default(); + return Err(Error::Provider { + provider: config.provider.clone(), + message: format!("embedding HTTP {status}: {body_text}"), + }); + } + + let json: serde_json::Value = resp + .json() + .await + .map_err(|e| Error::Http(format!("failed to parse embedding response: {e}")))?; + + let embedding = json + .get("embedding") + .and_then(|v| v.as_array()) + .ok_or_else(|| { + Error::Provider { + provider: config.provider.clone(), + message: "response missing 'embedding' array".into(), + } + })? + .iter() + .map(|v| v.as_f64().unwrap_or(0.0) as f32) + .collect(); + + Ok(embedding) + } + + /// Fetch embeddings for multiple texts sequentially. + /// + /// Each request is bounded by the per-request embedding timeout (which + /// overrides the client's default); there is no overall batch deadline. + async fn fetch_embeddings_batch( + http: &reqwest::Client, + config: &ClassifierConfig, + texts: &[&str], + ) -> Result>> { + let mut results = Vec::with_capacity(texts.len()); + for text in texts { + let embedding = Self::fetch_embedding(http, config, text).await?; + results.push(embedding); + } + Ok(results) + } + + /// Check whether the embedding endpoint is reachable. + pub async fn health_check(&self) -> bool { + let test_result = + Self::fetch_embedding(&self.http, &self.config, "health check").await; + test_result.is_ok() + } + + /// Get a reference to the classifier config. 
+ pub fn config(&self) -> &ClassifierConfig { + &self.config + } + + /// Get a reference to the centroids. + pub fn centroids(&self) -> &HashMap> { + &self.centroids + } + + // ── Cache helpers ────────────────────────────────────────────── + + /// Look up a cached embedding by prompt hash. Returns `None` if absent or expired. + fn get_cached(&self, key: u64) -> Option> { + let cache = self.cache.read(); + cache.get(&key).and_then(|entry| { + if Instant::now() < entry.expires_at { + Some(entry.embedding.clone()) + } else { + None + } + }) + } + + /// Store an embedding in the cache. Evicts expired entries if over capacity. + fn put_cached(&self, key: u64, embedding: &[f32]) { + let ttl = Duration::from_secs(self.config.cache_ttl_secs); + let entry = CachedEmbedding { + embedding: embedding.to_vec(), + expires_at: Instant::now() + ttl, + }; + + let mut cache = self.cache.write(); + + // Evict expired entries when over capacity. + if cache.len() >= CACHE_MAX_ENTRIES { + let now = Instant::now(); + cache.retain(|_, v| v.expires_at > now); + } + + cache.insert(key, entry); + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Helpers +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Hash a prompt string to a u64 for cache lookup. 
+fn hash_prompt(prompt: &str) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + prompt.hash(&mut hasher); + hasher.finish() +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Tests +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cosine_similarity_identical_vectors() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![1.0, 2.0, 3.0]; + let sim = cosine_similarity(&a, &b); + assert!( + (sim - 1.0).abs() < 1e-6, + "identical vectors should have similarity ~1.0, got {sim}" + ); + } + + #[test] + fn cosine_similarity_orthogonal_vectors() { + let a = vec![1.0, 0.0, 0.0]; + let b = vec![0.0, 1.0, 0.0]; + let sim = cosine_similarity(&a, &b); + assert!( + sim.abs() < 1e-6, + "orthogonal vectors should have similarity ~0.0, got {sim}" + ); + } + + #[test] + fn cosine_similarity_opposite_vectors() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![-1.0, -2.0, -3.0]; + let sim = cosine_similarity(&a, &b); + assert!( + (sim - (-1.0)).abs() < 1e-6, + "opposite vectors should have similarity ~-1.0, got {sim}" + ); + } + + #[test] + fn cosine_similarity_zero_vector_returns_zero() { + let a = vec![0.0, 0.0, 0.0]; + let b = vec![1.0, 2.0, 3.0]; + let sim = cosine_similarity(&a, &b); + assert!( + sim.abs() < 1e-6, + "zero vector should yield similarity 0.0, got {sim}" + ); + } + + #[test] + fn compute_centroid_single_vector() { + let vectors = vec![vec![1.0, 2.0, 3.0]]; + let centroid = compute_centroid(&vectors); + assert_eq!(centroid, vec![1.0, 2.0, 3.0]); + } + + #[test] + fn compute_centroid_average() { + let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0]]; + let centroid = compute_centroid(&vectors); + let expected = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]; + for (a, b) in centroid.iter().zip(expected.iter()) { + assert!( + (a - b).abs() < 1e-6, + "centroid mismatch: got {a}, expected 
{b}" + ); + } + } + + #[test] + fn compute_centroid_empty_returns_empty() { + let vectors: Vec> = vec![]; + let centroid = compute_centroid(&vectors); + assert!(centroid.is_empty()); + } + + #[test] + fn classify_with_centroids_picks_nearest() { + // Build centroids that are clearly separated in 3D space. + let mut centroids = HashMap::new(); + centroids.insert(ModelTier::Simple, vec![1.0, 0.0, 0.0]); + centroids.insert(ModelTier::Complex, vec![0.0, 1.0, 0.0]); + centroids.insert(ModelTier::Reasoning, vec![0.0, 0.0, 1.0]); + + // A vector close to the Simple centroid. + let embedding = vec![0.9, 0.1, 0.0]; + let (tier, scores) = classify_against_centroids(&embedding, ¢roids); + + assert_eq!(tier, ModelTier::Simple); + assert!(scores[&ModelTier::Simple] > scores[&ModelTier::Complex]); + assert!(scores[&ModelTier::Simple] > scores[&ModelTier::Reasoning]); + } + + #[test] + fn classify_ambiguous_defaults_to_complex() { + // Empty centroids should default to Complex. + let centroids: HashMap> = HashMap::new(); + let embedding = vec![1.0, 2.0, 3.0]; + let (tier, _scores) = classify_against_centroids(&embedding, ¢roids); + assert_eq!(tier, ModelTier::Complex); + } + + #[test] + fn build_centroids_from_embeddings() { + let mut embeddings = HashMap::new(); + embeddings.insert( + ModelTier::Simple, + vec![vec![1.0, 0.0], vec![0.8, 0.2]], + ); + embeddings.insert( + ModelTier::Complex, + vec![vec![0.0, 1.0], vec![0.2, 0.8]], + ); + + let centroids = build_centroids(&embeddings); + + assert_eq!(centroids.len(), 2); + + let simple = ¢roids[&ModelTier::Simple]; + assert!((simple[0] - 0.9).abs() < 1e-6); + assert!((simple[1] - 0.1).abs() < 1e-6); + + let complex = ¢roids[&ModelTier::Complex]; + assert!((complex[0] - 0.1).abs() < 1e-6); + assert!((complex[1] - 0.9).abs() < 1e-6); + } + + #[test] + fn default_reference_prompts_has_all_tiers() { + let prompts = default_reference_prompts(); + assert!(prompts.contains_key(&ModelTier::Simple)); + 
assert!(prompts.contains_key(&ModelTier::Complex)); + assert!(prompts.contains_key(&ModelTier::Reasoning)); + // Each tier should have multiple reference prompts. + for (_tier, texts) in &prompts { + assert!(texts.len() >= 3, "each tier should have at least 3 reference prompts"); + } + } + + #[test] + fn apply_thresholds_escalates_low_simple_score() { + let config = ClassifierConfig::default(); + let thresholds = RouterThresholds { + simple_min_score: 0.6, + complex_min_score: 0.5, + reasoning_min_score: 0.55, + escalate_token_threshold: 8000, + }; + + let classifier = EmbeddingClassifier::with_centroids( + config, + thresholds, + HashMap::new(), + ); + + // Simple score below threshold -> should escalate to Complex. + let mut scores = HashMap::new(); + scores.insert(ModelTier::Simple, 0.4_f32); // below 0.6 + scores.insert(ModelTier::Complex, 0.3_f32); + scores.insert(ModelTier::Reasoning, 0.2_f32); + + let result = classifier.apply_thresholds(ModelTier::Simple, &scores, "short prompt"); + assert_eq!(result, ModelTier::Complex); + } + + #[test] + fn apply_thresholds_deescalates_low_reasoning_score() { + let config = ClassifierConfig::default(); + let thresholds = RouterThresholds { + simple_min_score: 0.6, + complex_min_score: 0.5, + reasoning_min_score: 0.55, + escalate_token_threshold: 8000, + }; + + let classifier = EmbeddingClassifier::with_centroids( + config, + thresholds, + HashMap::new(), + ); + + // Reasoning score below threshold -> should fall back to Complex. 
+ let mut scores = HashMap::new(); + scores.insert(ModelTier::Simple, 0.2_f32); + scores.insert(ModelTier::Complex, 0.3_f32); + scores.insert(ModelTier::Reasoning, 0.4_f32); // below 0.55 + + let result = classifier.apply_thresholds(ModelTier::Reasoning, &scores, "short prompt"); + assert_eq!(result, ModelTier::Complex); + } + + #[test] + fn apply_thresholds_escalates_long_prompt() { + let config = ClassifierConfig::default(); + let thresholds = RouterThresholds { + simple_min_score: 0.3, // low threshold so score check won't trigger + complex_min_score: 0.5, + reasoning_min_score: 0.55, + escalate_token_threshold: 100, // 100 tokens * 4 chars = 400 chars + }; + + let classifier = EmbeddingClassifier::with_centroids( + config, + thresholds, + HashMap::new(), + ); + + let mut scores = HashMap::new(); + scores.insert(ModelTier::Simple, 0.9_f32); + scores.insert(ModelTier::Complex, 0.3_f32); + + // A prompt longer than 400 chars should escalate Simple -> Complex. + let long_prompt = "a".repeat(500); + let result = classifier.apply_thresholds(ModelTier::Simple, &scores, &long_prompt); + assert_eq!(result, ModelTier::Complex); + } + + #[test] + fn apply_thresholds_keeps_good_simple_score() { + let config = ClassifierConfig::default(); + let thresholds = RouterThresholds::default(); + + let classifier = EmbeddingClassifier::with_centroids( + config, + thresholds, + HashMap::new(), + ); + + let mut scores = HashMap::new(); + scores.insert(ModelTier::Simple, 0.8_f32); // above 0.6 threshold + scores.insert(ModelTier::Complex, 0.3_f32); + scores.insert(ModelTier::Reasoning, 0.2_f32); + + let result = classifier.apply_thresholds(ModelTier::Simple, &scores, "short"); + assert_eq!(result, ModelTier::Simple); + } + + #[test] + fn cache_stores_and_retrieves() { + let config = ClassifierConfig { + cache_ttl_secs: 300, + ..ClassifierConfig::default() + }; + let classifier = EmbeddingClassifier::with_centroids( + config, + RouterThresholds::default(), + HashMap::new(), + ); + + let 
key = hash_prompt("test prompt"); + let embedding = vec![1.0, 2.0, 3.0]; + + classifier.put_cached(key, &embedding); + let result = classifier.get_cached(key); + + assert!(result.is_some()); + assert_eq!(result.unwrap(), embedding); + } + + #[test] + fn cache_returns_none_for_missing() { + let classifier = EmbeddingClassifier::with_centroids( + ClassifierConfig::default(), + RouterThresholds::default(), + HashMap::new(), + ); + + let result = classifier.get_cached(999); + assert!(result.is_none()); + } + + #[test] + fn hash_prompt_deterministic() { + let h1 = hash_prompt("hello world"); + let h2 = hash_prompt("hello world"); + let h3 = hash_prompt("different prompt"); + + assert_eq!(h1, h2, "same input should produce same hash"); + assert_ne!(h1, h3, "different input should produce different hash"); + } +} diff --git a/crates/providers/src/decisions.rs b/crates/providers/src/decisions.rs new file mode 100644 index 0000000..5587d95 --- /dev/null +++ b/crates/providers/src/decisions.rs @@ -0,0 +1,111 @@ +use chrono::{DateTime, Utc}; +use parking_lot::Mutex; +use sa_domain::config::{ModelTier, RoutingProfile}; +use serde::Serialize; +use std::collections::VecDeque; + +/// A single routing decision record. +#[derive(Debug, Clone, Serialize)] +pub struct Decision { + pub timestamp: DateTime, + pub prompt_snippet: String, + pub profile: RoutingProfile, + pub tier: ModelTier, + pub model: String, + pub latency_ms: u64, + pub bypassed: bool, +} + +/// Thread-safe ring buffer of recent routing decisions. +/// +/// Uses `parking_lot::Mutex` for low-overhead synchronisation. +/// The buffer evicts the oldest entry when it reaches capacity, +/// keeping only the most recent decisions for observability. +pub struct DecisionLog { + inner: Mutex>, + capacity: usize, +} + +impl DecisionLog { + /// Create a new decision log with the given maximum capacity. 
+ pub fn new(capacity: usize) -> Self { + Self { + inner: Mutex::new(VecDeque::with_capacity(capacity)), + capacity, + } + } + + /// Record a new decision. If the buffer is at capacity the oldest + /// entry is evicted first. + pub fn record(&self, decision: Decision) { + let mut buf = self.inner.lock(); + if buf.len() >= self.capacity { + buf.pop_front(); + } + buf.push_back(decision); + } + + /// Return the `limit` most recent decisions, newest first. + /// + /// If fewer than `limit` decisions exist, all are returned. + pub fn recent(&self, limit: usize) -> Vec { + let buf = self.inner.lock(); + buf.iter().rev().take(limit).cloned().collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper — build a `Decision` with a distinguishing index baked + /// into `prompt_snippet` so assertions can identify ordering. + fn make_decision(index: u64) -> Decision { + Decision { + timestamp: Utc::now(), + prompt_snippet: format!("prompt-{index}"), + profile: RoutingProfile::Auto, + tier: ModelTier::Simple, + model: "test-model".into(), + latency_ms: index, + bypassed: false, + } + } + + #[test] + fn ring_buffer_stores_up_to_capacity() { + let log = DecisionLog::new(3); + for i in 0..5 { + log.record(make_decision(i)); + } + + let recent = log.recent(10); + assert_eq!(recent.len(), 3, "should keep at most 3 entries"); + + // Newest first: 4, 3, 2 + assert_eq!(recent[0].latency_ms, 4); + assert_eq!(recent[1].latency_ms, 3); + assert_eq!(recent[2].latency_ms, 2); + } + + #[test] + fn ring_buffer_recent_respects_limit() { + let log = DecisionLog::new(100); + for i in 0..50 { + log.record(make_decision(i)); + } + + let recent = log.recent(5); + assert_eq!(recent.len(), 5); + // Newest first: 49, 48, 47, 46, 45 + assert_eq!(recent[0].latency_ms, 49); + assert_eq!(recent[4].latency_ms, 45); + } + + #[test] + fn ring_buffer_empty() { + let log = DecisionLog::new(10); + let recent = log.recent(5); + assert!(recent.is_empty()); + } +} diff --git 
a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 11461cb..a6c52f8 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -1,11 +1,14 @@ pub mod anthropic; pub mod auth; pub mod bedrock; +pub mod classifier; +pub mod decisions; pub mod google; pub mod oauth; pub mod openai_compat; pub mod registry; pub mod router; +pub mod smart_router; pub mod traits; pub(crate) mod sse; pub(crate) mod util; diff --git a/crates/providers/src/smart_router.rs b/crates/providers/src/smart_router.rs new file mode 100644 index 0000000..04d1379 --- /dev/null +++ b/crates/providers/src/smart_router.rs @@ -0,0 +1,233 @@ +//! Smart router resolution logic. +//! +//! Pure, synchronous functions that resolve a model string from routing +//! profiles, classified tiers, and tier configuration. No HTTP, no async +//! — just deterministic decision logic. + +use sa_domain::config::{ModelTier, RoutingProfile, TierConfig}; + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Types +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// The result of a routing decision. +#[derive(Debug, Clone)] +pub struct RoutingDecision { + pub model: String, + pub tier: ModelTier, + pub profile: RoutingProfile, + pub bypassed: bool, +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Public API +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +/// Map a fixed profile to its corresponding tier. +/// Returns `None` for `Auto` (requires classification). +pub fn profile_to_tier(profile: RoutingProfile) -> Option { + match profile { + RoutingProfile::Auto => None, + RoutingProfile::Eco => Some(ModelTier::Simple), + RoutingProfile::Premium => Some(ModelTier::Complex), + RoutingProfile::Free => Some(ModelTier::Free), + RoutingProfile::Reasoning => Some(ModelTier::Reasoning), + } +} + +/// Get the first available model from a tier. 
+pub fn resolve_tier_model<'a>(tier: ModelTier, tiers: &'a TierConfig) -> Option<&'a str> {
+    let models = match tier {
+        ModelTier::Simple => &tiers.simple,
+        ModelTier::Complex => &tiers.complex,
+        ModelTier::Reasoning => &tiers.reasoning,
+        ModelTier::Free => &tiers.free,
+    };
+    models.first().map(|s| s.as_str())
+}
+
+/// Core resolution: explicit model > profile tier > classified tier > fallback.
+///
+/// Resolution order:
+/// 1. If `explicit_model` is `Some`, bypass the router entirely.
+/// 2. If the profile maps to a fixed tier, use that tier.
+/// 3. If the profile is `Auto`, use the `classified_tier`.
+/// 4. If no model is found in the chosen tier, walk the fallback chain.
+pub fn resolve_model_for_request(
+    explicit_model: Option<&str>,
+    profile: RoutingProfile,
+    classified_tier: Option<ModelTier>,
+    tiers: &TierConfig,
+) -> RoutingDecision {
+    // 1. Explicit model bypass.
+    if let Some(model) = explicit_model {
+        return RoutingDecision {
+            model: model.to_string(),
+            tier: ModelTier::Complex, // sensible default for explicit
+            profile,
+            bypassed: true,
+        };
+    }
+
+    // 2. Determine the target tier from profile or classification.
+    let target_tier = profile_to_tier(profile)
+        .or(classified_tier)
+        .unwrap_or(ModelTier::Complex); // fallback default
+
+    // 3. Try the target tier first, then walk fallbacks.
+    if let Some(model) = resolve_tier_model(target_tier, tiers) {
+        return RoutingDecision {
+            model: model.to_string(),
+            tier: target_tier,
+            profile,
+            bypassed: false,
+        };
+    }
+
+    for fallback_tier in fallback_tiers(target_tier) {
+        if let Some(model) = resolve_tier_model(fallback_tier, tiers) {
+            return RoutingDecision {
+                model: model.to_string(),
+                tier: fallback_tier,
+                profile,
+                bypassed: false,
+            };
+        }
+    }
+
+    // 4. Absolute last resort — nothing configured anywhere.
+    RoutingDecision {
+        model: String::new(),
+        tier: target_tier,
+        profile,
+        bypassed: false,
+    }
+}
+
+// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+// Internal helpers
+// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+/// Tier fallback order when the target tier has no models configured.
+fn fallback_tiers(starting: ModelTier) -> Vec<ModelTier> {
+    match starting {
+        ModelTier::Simple => vec![ModelTier::Complex, ModelTier::Reasoning],
+        ModelTier::Complex => vec![ModelTier::Reasoning, ModelTier::Simple],
+        ModelTier::Reasoning => vec![ModelTier::Complex, ModelTier::Simple],
+        ModelTier::Free => vec![ModelTier::Simple, ModelTier::Complex, ModelTier::Reasoning],
+    }
+}
+
+// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+// Tests
+// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn test_tiers() -> TierConfig {
+        TierConfig {
+            simple: vec!["deepseek/deepseek-chat".into()],
+            complex: vec!["anthropic/claude-sonnet-4-20250514".into()],
+            reasoning: vec!["anthropic/claude-opus-4-6".into()],
+            free: vec!["venice/venice-uncensored".into()],
+        }
+    }
+
+    // ── resolve_tier_model ────────────────────────────────────────
+
+    #[test]
+    fn resolve_tier_model_picks_first_in_list() {
+        let tiers = TierConfig {
+            simple: vec!["model-a".into(), "model-b".into()],
+            ..Default::default()
+        };
+        assert_eq!(resolve_tier_model(ModelTier::Simple, &tiers), Some("model-a"));
+    }
+
+    #[test]
+    fn resolve_tier_model_empty_tier_returns_none() {
+        let tiers = TierConfig::default();
+        assert_eq!(resolve_tier_model(ModelTier::Simple, &tiers), None);
+    }
+
+    // ── profile_to_tier ───────────────────────────────────────────
+
+    #[test]
+    fn profile_to_tier_eco_is_simple() {
+        assert_eq!(profile_to_tier(RoutingProfile::Eco), Some(ModelTier::Simple));
+    }
+
+    #[test]
+    fn profile_to_tier_premium_is_complex() {
+        assert_eq!(profile_to_tier(RoutingProfile::Premium),
Some(ModelTier::Complex)); + } + + #[test] + fn profile_to_tier_auto_is_none() { + assert_eq!(profile_to_tier(RoutingProfile::Auto), None); + } + + // ── resolve_model_for_request ───────────────────────────────── + + #[test] + fn resolve_with_explicit_model_bypasses_router() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + Some("custom/my-model"), + RoutingProfile::Auto, + None, + &tiers, + ); + assert_eq!(decision.model, "custom/my-model"); + assert!(decision.bypassed); + } + + #[test] + fn resolve_with_eco_profile_uses_simple_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + None, + RoutingProfile::Eco, + None, + &tiers, + ); + assert_eq!(decision.model, "deepseek/deepseek-chat"); + assert_eq!(decision.tier, ModelTier::Simple); + assert!(!decision.bypassed); + } + + #[test] + fn resolve_with_auto_profile_uses_classified_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + None, + RoutingProfile::Auto, + Some(ModelTier::Reasoning), + &tiers, + ); + assert_eq!(decision.model, "anthropic/claude-opus-4-6"); + assert_eq!(decision.tier, ModelTier::Reasoning); + assert!(!decision.bypassed); + } + + #[test] + fn resolve_falls_back_across_tiers() { + // Simple tier is empty, should fall back to Complex. + let tiers = TierConfig { + simple: vec![], + complex: vec!["fallback-model".into()], + ..Default::default() + }; + let decision = resolve_model_for_request( + None, + RoutingProfile::Eco, // maps to Simple + None, + &tiers, + ); + assert_eq!(decision.model, "fallback-model"); + assert_eq!(decision.tier, ModelTier::Complex); + assert!(!decision.bypassed); + } +} diff --git a/crates/providers/tests/router_integration.rs b/crates/providers/tests/router_integration.rs new file mode 100644 index 0000000..c5e41d9 --- /dev/null +++ b/crates/providers/tests/router_integration.rs @@ -0,0 +1,252 @@ +//! Integration tests for the smart router — full round-trip without Ollama. +//! +//! 
These tests validate the complete routing flow across multiple modules +//! (smart_router + decisions) without requiring any external services. +//! All tests are pure and deterministic. + +use chrono::Utc; +use sa_domain::config::{ModelTier, RoutingProfile, TierConfig}; +use sa_providers::decisions::{Decision, DecisionLog}; +use sa_providers::smart_router::resolve_model_for_request; +use std::time::Instant; + +fn test_tiers() -> TierConfig { + TierConfig { + simple: vec!["deepseek/deepseek-chat".into()], + complex: vec!["anthropic/claude-sonnet-4-20250514".into()], + reasoning: vec!["anthropic/claude-opus-4-6".into()], + free: vec!["venice/venice-uncensored".into()], + } +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Profile-to-model resolution +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn eco_profile_resolves_simple_tier_and_logs_decision() { + let tiers = test_tiers(); + let decisions = DecisionLog::new(10); + + let start = Instant::now(); + let decision = resolve_model_for_request(None, RoutingProfile::Eco, None, &tiers); + let latency_ms = start.elapsed().as_millis() as u64; + + assert_eq!(decision.model, "deepseek/deepseek-chat"); + assert_eq!(decision.tier, ModelTier::Simple); + assert_eq!(decision.profile, RoutingProfile::Eco); + assert!(!decision.bypassed); + + // Log the decision to the ring buffer and verify round-trip. 
+ decisions.record(Decision { + timestamp: Utc::now(), + prompt_snippet: "test prompt".into(), + profile: decision.profile, + tier: decision.tier, + model: decision.model.clone(), + latency_ms, + bypassed: decision.bypassed, + }); + + let recent = decisions.recent(10); + assert_eq!(recent.len(), 1); + assert_eq!(recent[0].model, "deepseek/deepseek-chat"); + assert_eq!(recent[0].tier, ModelTier::Simple); +} + +#[test] +fn premium_profile_resolves_complex_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request(None, RoutingProfile::Premium, None, &tiers); + + assert_eq!(decision.model, "anthropic/claude-sonnet-4-20250514"); + assert_eq!(decision.tier, ModelTier::Complex); + assert!(!decision.bypassed); +} + +#[test] +fn reasoning_profile_resolves_reasoning_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request(None, RoutingProfile::Reasoning, None, &tiers); + + assert_eq!(decision.model, "anthropic/claude-opus-4-6"); + assert_eq!(decision.tier, ModelTier::Reasoning); + assert!(!decision.bypassed); +} + +#[test] +fn free_profile_resolves_free_tier() { + let tiers = test_tiers(); + let decision = resolve_model_for_request(None, RoutingProfile::Free, None, &tiers); + + assert_eq!(decision.model, "venice/venice-uncensored"); + assert_eq!(decision.tier, ModelTier::Free); + assert!(!decision.bypassed); +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Auto profile with classification +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn auto_profile_with_classified_tier_uses_classification() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + None, + RoutingProfile::Auto, + Some(ModelTier::Reasoning), + &tiers, + ); + + assert_eq!(decision.model, "anthropic/claude-opus-4-6"); + assert_eq!(decision.tier, ModelTier::Reasoning); + assert!(!decision.bypassed); +} + +#[test] +fn auto_profile_without_classification_falls_back_to_complex() { 
+ let tiers = test_tiers(); + let decision = resolve_model_for_request(None, RoutingProfile::Auto, None, &tiers); + + assert_eq!(decision.model, "anthropic/claude-sonnet-4-20250514"); + assert_eq!(decision.tier, ModelTier::Complex); + assert!(!decision.bypassed); +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Explicit model bypass +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn explicit_model_bypasses_router() { + let tiers = test_tiers(); + let decision = resolve_model_for_request( + Some("custom/my-fine-tune"), + RoutingProfile::Eco, + Some(ModelTier::Simple), + &tiers, + ); + + assert_eq!(decision.model, "custom/my-fine-tune"); + assert!(decision.bypassed); +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Fallback behaviour +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn fallback_when_target_tier_empty() { + let tiers = TierConfig { + simple: vec![], + complex: vec!["fallback-model".into()], + reasoning: vec![], + free: vec![], + }; + + let decision = resolve_model_for_request( + None, + RoutingProfile::Eco, // maps to Simple, which is empty + None, + &tiers, + ); + + assert_eq!(decision.model, "fallback-model"); + assert_eq!(decision.tier, ModelTier::Complex); // fell back to Complex + assert!(!decision.bypassed); +} + +#[test] +fn fallback_walks_full_chain_when_multiple_tiers_empty() { + let tiers = TierConfig { + simple: vec![], + complex: vec![], + reasoning: vec!["last-resort".into()], + free: vec![], + }; + + let decision = resolve_model_for_request( + None, + RoutingProfile::Eco, // Simple -> fallback: Complex -> Reasoning + None, + &tiers, + ); + + assert_eq!(decision.model, "last-resort"); + assert_eq!(decision.tier, ModelTier::Reasoning); + assert!(!decision.bypassed); +} + +#[test] +fn all_tiers_empty_returns_empty_model() { + let tiers = TierConfig::default(); // all vecs empty + + let decision = 
resolve_model_for_request(None, RoutingProfile::Eco, None, &tiers); + + assert!(decision.model.is_empty()); + assert!(!decision.bypassed); +} + +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +// Decision log round-trip +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +#[test] +fn decision_log_round_trip_multiple_decisions() { + let tiers = test_tiers(); + let decisions = DecisionLog::new(100); + + // Simulate multiple routing decisions. + for profile in &[ + RoutingProfile::Eco, + RoutingProfile::Premium, + RoutingProfile::Reasoning, + ] { + let decision = resolve_model_for_request(None, *profile, None, &tiers); + decisions.record(Decision { + timestamp: Utc::now(), + prompt_snippet: format!("test {:?}", profile), + profile: decision.profile, + tier: decision.tier, + model: decision.model, + latency_ms: 0, + bypassed: decision.bypassed, + }); + } + + let recent = decisions.recent(10); + assert_eq!(recent.len(), 3); + // Newest first: Reasoning, Premium, Eco + assert_eq!(recent[0].tier, ModelTier::Reasoning); + assert_eq!(recent[1].tier, ModelTier::Complex); + assert_eq!(recent[2].tier, ModelTier::Simple); +} + +#[test] +fn decision_log_capacity_evicts_oldest() { + let tiers = test_tiers(); + let decisions = DecisionLog::new(2); // capacity of 2 + + for profile in &[ + RoutingProfile::Eco, + RoutingProfile::Premium, + RoutingProfile::Reasoning, + ] { + let decision = resolve_model_for_request(None, *profile, None, &tiers); + decisions.record(Decision { + timestamp: Utc::now(), + prompt_snippet: format!("test {:?}", profile), + profile: decision.profile, + tier: decision.tier, + model: decision.model, + latency_ms: 0, + bypassed: decision.bypassed, + }); + } + + let recent = decisions.recent(10); + assert_eq!(recent.len(), 2); + // Only the last two remain: Reasoning (newest) and Premium + assert_eq!(recent[0].tier, ModelTier::Reasoning); + assert_eq!(recent[1].tier, ModelTier::Complex); +}