diff --git a/.gitignore b/.gitignore index ca8cff2f..2240f792 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,10 @@ config.yaml .vscode/ .DS_Store +# Documentation +docs/ +.claude/ + prd/ memory-bank/ .cursor/ diff --git a/Cargo.lock b/Cargo.lock index d1058810..7f109bb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -204,6 +204,42 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-process" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc50921ec0055cdd8a16de48773bfeec5c972598674347252c0399676be7da75" +dependencies = [ + "async-channel 2.5.0", + "async-io", + "async-lock", + "async-signal", + "async-task", + "blocking", + "cfg-if", + "event-listener 5.4.1", + "futures-lite 2.6.1", + "rustix", +] + +[[package]] +name = "async-signal" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43c070bbf59cd3570b6b2dd54cd772527c7c3620fce8be898406dd3ed6adc64c" +dependencies = [ + "async-io", + "async-lock", + "atomic-waker", + "cfg-if", + "futures-core", + "futures-io", + "rustix", + "signal-hook-registry", + "slab", + "windows-sys 0.61.2", +] + [[package]] name = "async-std" version = "1.13.2" @@ -214,6 +250,7 @@ dependencies = [ "async-global-executor", "async-io", "async-lock", + "async-process", "crossbeam-utils", "futures-channel", "futures-core", @@ -2269,7 +2306,7 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hub" -version = "0.7.5" +version = "0.7.7" dependencies = [ "anyhow", "async-stream", @@ -2299,9 +2336,11 @@ dependencies = [ "sqlx", "surf", "surf-vcr", + "temp-env", "tempfile", "testcontainers", "testcontainers-modules", + "thiserror 2.0.17", "tokio", "tower 0.5.3", "tower-http", @@ -3283,6 +3322,7 @@ version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "231e9d6ceef9b0b2546ddf52335785ce41252bc7474ee8ba05bfad277be13ab8" dependencies = [ + 
"async-std", "async-trait", "futures-channel", "futures-executor", @@ -3295,6 +3335,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-stream", + "tracing", ] [[package]] @@ -5062,6 +5103,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" +[[package]] +name = "temp-env" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96374855068f47402c3121c6eed88d29cb1de8f3ab27090e273e420bdabcf050" +dependencies = [ + "parking_lot", +] + [[package]] name = "tempfile" version = "3.24.0" diff --git a/Cargo.toml b/Cargo.toml index 557949b3..5d6f3662 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,7 @@ utoipa = { version = "5.0.0", features = ["axum_extras", "chrono", "uuid", "open utoipa-swagger-ui = { version = "8", features = ["axum"] } utoipa-scalar = { version = "0.3.0", features = ["axum"] } log = "0.4" +thiserror = "2" [lib] name = "hub_lib" @@ -72,6 +73,9 @@ tempfile = "3.8" testcontainers = "0.20.0" testcontainers-modules = { version = "0.8.0", features = ["postgres"] } axum-test = "17" +temp-env = "0.3" sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "migrate"] } tokio = { version = "1.45.0", features = ["full"] } reqwest = { version = "0.12", features = ["json"] } +opentelemetry = { version = "0.27" } +opentelemetry_sdk = { version = "0.27", features = ["testing"] } diff --git a/config-example.yaml b/config-example.yaml index e8143fb9..aaca3627 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -94,6 +94,9 @@ pipelines: # Default pipeline for chat completions - name: default type: chat + guards: # Optional: reference guards by name (defined in guardrails section below) + - pii-check + - injection-check plugins: - logging: level: info # Supported levels: debug, info, warning, error @@ -125,3 +128,36 @@ pipelines: - model-router: models: # List the models you want to use 
for embeddings - textembedding-gecko + +guardrails: + providers: + - name: traceloop + api_base: ${TRACELOOP_BASE_URL} # or use direct URL + api_key: ${TRACELOOP_API_KEY} # or use "" + guards: + # PII Detection - Pre-call guard + - name: pii-check + provider: traceloop + evaluator_slug: pii-detector + mode: pre_call # Runs before the model call + on_failure: block # Options: block, warn + required: true # If true, request fails if guard is unavailable + + # Prompt Injection Detection - Pre-call guard + - name: injection-check + provider: traceloop + evaluator_slug: prompt-injection + params: + threshold: 0.8 + mode: pre_call + on_failure: block + required: false + + # Toxicity Detection - Post-call guard + - name: toxicity-filter + provider: traceloop + evaluator_slug: toxicity-detector + params: + threshold: 0.8 + mode: post_call # Runs after the model call + on_failure: block diff --git a/src/config/lib.rs b/src/config/lib.rs index abec66f4..eec1f495 100644 --- a/src/config/lib.rs +++ b/src/config/lib.rs @@ -1,3 +1,4 @@ +use crate::guardrails::types::GuardrailsConfig; use crate::types::{GatewayConfig, ModelConfig, Pipeline, PipelineType, PluginConfig, Provider}; use serde::Deserialize; use std::sync::OnceLock; @@ -11,6 +12,8 @@ struct YamlCompatiblePipeline { r#type: PipelineType, #[serde(with = "serde_yaml::with::singleton_map_recursive")] plugins: Vec, + #[serde(default)] + guards: Vec, #[serde(default = "default_enabled_true_lib")] #[allow(dead_code)] enabled: bool, // Keep for YAML parsing, but won't be mapped to core Pipeline @@ -31,6 +34,8 @@ struct YamlRoot { models: Vec, #[serde(default)] pipelines: Vec, + #[serde(default)] + guardrails: Option, } fn substitute_env_vars(content: &str) -> Result> { @@ -83,11 +88,12 @@ pub fn load_config(path: &str) -> Result Result<(), Vec } } - // Add more validation checks as needed: - // - Duplicate keys for providers, models, pipelines? - // - Empty names/keys? 
- // - Specific validation for provider params based on type (more complex, might be out of scope for basic validation) + // Check 3: If any pipeline specifies guards, guardrails section must exist + let has_pipeline_guards = config.pipelines.iter().any(|p| !p.guards.is_empty()); + if has_pipeline_guards && config.guardrails.is_none() { + errors.push( + "One or more pipelines specify guards, but the 'guardrails' section is missing." + .to_string(), + ); + } + + // Check 4: Guardrails validation + if let Some(gr_config) = &config.guardrails { + // Validate all guards in a single pass + let mut seen_guard_names = HashSet::new(); + for guard in &gr_config.guards { + // Check provider reference exists; skip api_base/api_key checks + // when the provider is missing since those would be redundant. + if !gr_config.providers.contains_key(&guard.provider) { + errors.push(format!( + "Guard '{}' references non-existent guardrail provider '{}'.", + guard.name, guard.provider + )); + } else { + // Check api_base and api_key (either directly or via provider) + let has_api_base = guard.api_base.as_ref().is_some_and(|s| !s.is_empty()) + || gr_config + .providers + .get(&guard.provider) + .is_some_and(|p| !p.api_base.is_empty()); + let has_api_key = guard.api_key.as_ref().is_some_and(|s| !s.is_empty()) + || gr_config + .providers + .get(&guard.provider) + .is_some_and(|p| !p.api_key.is_empty()); + + if !has_api_base { + errors.push(format!( + "Guard '{}' has no api_base configured (neither on the guard nor on provider '{}').", + guard.name, guard.provider + )); + } + if !has_api_key { + errors.push(format!( + "Guard '{}' has no api_key configured (neither on the guard nor on provider '{}').", + guard.name, guard.provider + )); + } + } + + // Check evaluator slug is recognized + if crate::guardrails::evaluator_types::get_evaluator(&guard.evaluator_slug).is_none() { + errors.push(format!( + "Guard '{}' has unknown evaluator_slug '{}'.", + guard.name, guard.evaluator_slug + )); + } + 
+ // Check for duplicate guard names + if !seen_guard_names.insert(&guard.name) { + errors.push(format!("Duplicate guard name: '{}'.", guard.name)); + } + } + + // Pipeline guard references must exist in guardrails.guards + for pipeline in &config.pipelines { + for guard_name in &pipeline.guards { + if !seen_guard_names.contains(guard_name) { + errors.push(format!( + "Pipeline '{}' references non-existent guard '{}'.", + pipeline.name, guard_name + )); + } + } + } + } if errors.is_empty() { Ok(()) @@ -52,11 +122,16 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec #[cfg(test)] mod tests { use super::*; // To import validate_gateway_config - use crate::types::{ModelConfig, Pipeline, PipelineType, PluginConfig, Provider, ProviderType}; // For test data + use crate::guardrails::types::{ + Guard, GuardMode, GuardrailsConfig, OnFailure, ProviderConfig as GrProviderConfig, + }; + use crate::types::{ModelConfig, Pipeline, PipelineType, PluginConfig, Provider, ProviderType}; + use std::collections::HashMap; // For test data #[test] fn test_valid_config() { let config = GatewayConfig { + guardrails: None, general: None, providers: vec![Provider { key: "p1".to_string(), @@ -76,6 +151,7 @@ mod tests { plugins: vec![PluginConfig::ModelRouter { models: vec!["m1".to_string()], }], + guards: vec![], }], }; assert!(validate_gateway_config(&config).is_ok()); @@ -84,6 +160,7 @@ mod tests { #[test] fn test_invalid_model_provider_ref() { let config = GatewayConfig { + guardrails: None, general: None, providers: vec![Provider { key: "p1".to_string(), @@ -109,6 +186,7 @@ mod tests { #[test] fn test_invalid_pipeline_model_ref() { let config = GatewayConfig { + guardrails: None, general: None, providers: vec![Provider { key: "p1".to_string(), @@ -127,7 +205,8 @@ mod tests { r#type: PipelineType::Chat, plugins: vec![PluginConfig::ModelRouter { models: vec!["m2_non_existent".to_string()], - }], // Invalid model ref + }], + guards: vec![], }], }; let result = 
validate_gateway_config(&config); @@ -136,4 +215,274 @@ mod tests { assert_eq!(errors.len(), 1); assert!(errors[0].contains("references non-existent model 'm2_non_existent'")); } + + #[test] + fn test_guard_references_non_existent_guardrail_provider() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p2_non_existent".to_string(), + evaluator_slug: "pii-detector".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert!(errors.iter().any(|e| { + e.contains("references non-existent guardrail provider 'gr_p2_non_existent'") + })); + } + + #[test] + fn test_pipeline_references_non_existent_guard() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "pii-detector".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![Pipeline { + name: "pipe1".to_string(), + r#type: PipelineType::Chat, + plugins: vec![], + guards: vec!["g_non_existent".to_string()], + }], + }; + let result = 
validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].contains("references non-existent guard 'g_non_existent'")); + } + + #[test] + fn test_duplicate_guard_names() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), + guards: vec![ + Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "pii-detector".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }, + Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "toxicity-detector".to_string(), + params: Default::default(), + mode: GuardMode::PostCall, + on_failure: OnFailure::Warn, + required: true, + api_base: None, + api_key: None, + }, + ], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].contains("Duplicate guard name: 'g1'")); + } + + #[test] + fn test_guard_missing_api_base_and_api_key() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "".to_string(), + api_key: "".to_string(), + }, + )]), + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "pii-detector".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: 
vec![], + pipelines: vec![], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 2); + assert!(errors[0].contains("no api_base configured")); + assert!(errors[1].contains("no api_key configured")); + } + + #[test] + fn test_guard_inherits_api_base_from_provider() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "pii-detector".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + assert!(validate_gateway_config(&config).is_ok()); + } + + #[test] + fn test_guard_unknown_evaluator_slug() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "made-up-slug".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert!( + errors + .iter() + .any(|e| e.contains("unknown evaluator_slug 'made-up-slug'")) + ); + } + + #[test] + fn test_pipeline_guards_without_guardrails_section() 
{ + // Test the case where a pipeline specifies guards but guardrails section is missing + let config = GatewayConfig { + guardrails: None, // Missing guardrails section + general: None, + providers: vec![Provider { + key: "p1".to_string(), + r#type: ProviderType::OpenAI, + api_key: "key1".to_string(), + params: Default::default(), + }], + models: vec![ModelConfig { + key: "m1".to_string(), + r#type: "gpt-4".to_string(), + provider: "p1".to_string(), + params: Default::default(), + }], + pipelines: vec![Pipeline { + name: "pipe1".to_string(), + r#type: PipelineType::Chat, + plugins: vec![PluginConfig::ModelRouter { + models: vec!["m1".to_string()], + }], + guards: vec!["g1".to_string()], // Pipeline specifies a guard + }], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].contains("pipelines specify guards")); + assert!(errors[0].contains("guardrails' section is missing")); + } } diff --git a/src/guardrails/GUARDRAILS.md b/src/guardrails/GUARDRAILS.md new file mode 100644 index 00000000..f934ad61 --- /dev/null +++ b/src/guardrails/GUARDRAILS.md @@ -0,0 +1,161 @@ +# Guardrails + +Guardrails intercept LLM gateway traffic to evaluate requests and responses against configurable safety, validation, and quality checks. Each guard calls an external evaluator API and either **blocks** the request (HTTP 403) or **warns** via a response header, depending on configuration. + +> **Available in:** Traceloop Hub v1 +> **Full reference documentation:** [Guardrails Documentation](https://docs.traceloop.com/evaluators/guardrails) + +--- + +Guardrails can be implemented in **Config Mode (Hub v1)** : +Guardrails fully defined in YAML configuration, applied automatically to gateway requests + + +This document focuses on **config mode** available in Traceloop Hub v1. 
+ +--- + +## How It Works + +``` + ┌──────────────┐ + Request ──────►│ Pre-call │──── Block (403) ──► Client + │ Guards │ + └──────┬───────┘ + │ pass + ▼ + ┌──────────────┐ + │ LLM Call │ + └──────┬───────┘ + │ + ▼ + ┌──────────────┐ + │ Post-call │──── Block (403) ──► Client + │ Guards │ + └──────┬───────┘ + │ pass + ▼ + Response (+ warning headers if any) +``` + +1. **Pre-call guards** run on the user's prompt *before* it reaches the LLM. +2. **Post-call guards** run on the LLM's response *before* it is returned to the client. +3. All guards in a phase execute **concurrently** for minimal latency. + +**Supported Routes:** +- ✅ Chat Completions (`/v1/chat/completions`) — pre-call and post-call guards +- ✅ Completions (`/v1/completions`) — pre-call and post-call guards +- ✅ Embeddings (`/v1/embeddings`) — **pre-call guards only** +- ⚠️ Streaming requests (`"stream": true`) — **pre-call guards only** (post-call guards are skipped because the response is sent as incremental chunks) + +--- + +## Configuration + +Guards are defined in the gateway YAML config under `guardrails`. Provider-level defaults for `api_base` and `api_key` can be inherited by guards or overridden per-guard. 
+ +```yaml +guardrails: + providers: + - name: traceloop + api_base: https://api.traceloop.com + api_key: ${TRACELOOP_API_KEY} + + guards: + - name: pii-check + provider: traceloop + evaluator_slug: pii-detector + mode: pre_call # pre_call | post_call + on_failure: block # block | warn + required: false # when true, evaluator errors block the request; when false, they warn and continue (default: false) + params: # evaluator-specific parameters + probability_threshold: 0.7 +``` + +Pipelines reference guards by name: + +```yaml +pipelines: + - name: default + guards: [pii-check, injection-check] + plugins: + - model-router: + models: [gpt-4] +``` + +### Runtime Guard Addition + +Guards can be added (never removed) at request time via: +- **Header:** `X-Traceloop-Guardrails: extra-guard-1, extra-guard-2` + +These are **additive** to the pipeline-configured guards, preserving the security baseline. + +--- + +## Supported Evaluators + +| Slug | Category | Configurable Params | +|---|---|---| +| `pii-detector` | Safety | `probability_threshold` | +| `secrets-detector` | Safety | — | +| `prompt-injection` | Safety | `threshold` | +| `profanity-detector` | Safety | — | +| `sexism-detector` | Safety | `threshold` | +| `toxicity-detector` | Safety | `threshold` | +| `regex-validator` | Validation | `regex`, `should_match`, `case_sensitive`, `dot_include_nl`, `multi_line` | +| `json-validator` | Validation | `enable_schema_validation`, `schema_string` | +| `sql-validator` | Validation | — | +| `tone-detection` | Quality | — | +| `prompt-perplexity` | Quality | — | +| `uncertainty-detector` | Quality | — | + +--- + +## Failure Behavior + +| Evaluation Result | `on_failure` | `required` | Action | +|---|---|---|---| +| Pass | — | — | Continue | +| Fail | `block` | — | Return 403 | +| Fail | `warn` | — | Add warning header, continue | +| Evaluator error | — | `true` | Return 403 (fail-closed) | +| Evaluator error | — | `false` | Add warning header, continue (fail-open) | + 
+**Blocked response (403):** +```json +{ + "error": { + "type": "guardrail_blocked", + "guardrail": "pii-check", + "message": "Request blocked by guardrail 'pii-check'", + "evaluation_result": { ... }, + "reason": "evaluation_failed" + } +} +``` + +**Warning header:** +``` +X-Traceloop-Guardrail-Warning: guardrail_name="toxicity-filter", reason="failed" +``` + +--- + +## Observability + +Each guard evaluation emits an OpenTelemetry child span with these attributes: + +| Attribute | Description | +|---|---| +| `gen_ai.guardrail.name` | Guard name | +| `gen_ai.guardrail.status` | `PASSED`, `FAILED`, or `ERROR` | +| `gen_ai.guardrail.duration` | Evaluation time in ms | +| `gen_ai.guardrail.input` | Input text (when `trace_content_enabled`) | +| `gen_ai.guardrail.error.type` | Error category (`Unavailable`, `HttpError`, `Timeout`, `ParseError`) | +| `gen_ai.guardrail.error.message` | Error details | + +--- + +## Implementation + +See `src/guardrails/mod.rs` for module structure and key type definitions. 
diff --git a/src/guardrails/evaluator_types.rs b/src/guardrails/evaluator_types.rs new file mode 100644 index 00000000..be601f1a --- /dev/null +++ b/src/guardrails/evaluator_types.rs @@ -0,0 +1,199 @@ +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde_json::json; +use std::collections::HashMap; + +use super::types::GuardrailError; + +// --------------------------------------------------------------------------- +// Slugs +// --------------------------------------------------------------------------- + +// Safety +pub const PII_DETECTOR: &str = "pii-detector"; +pub const SECRETS_DETECTOR: &str = "secrets-detector"; +pub const PROMPT_INJECTION: &str = "prompt-injection"; +pub const PROFANITY_DETECTOR: &str = "profanity-detector"; +pub const SEXISM_DETECTOR: &str = "sexism-detector"; +pub const TOXICITY_DETECTOR: &str = "toxicity-detector"; +// Validators +pub const REGEX_VALIDATOR: &str = "regex-validator"; +pub const JSON_VALIDATOR: &str = "json-validator"; +pub const SQL_VALIDATOR: &str = "sql-validator"; +// Quality and adherence +pub const TONE_DETECTION: &str = "tone-detection"; +pub const PROMPT_PERPLEXITY: &str = "prompt-perplexity"; +pub const UNCERTAINTY_DETECTOR: &str = "uncertainty-detector"; + +// --------------------------------------------------------------------------- +// Trait +// --------------------------------------------------------------------------- + +/// Each supported evaluator implements this trait to build its typed request body. +pub trait EvaluatorRequest: Send + Sync { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result; +} + +/// Look up the evaluator implementation for a given slug. 
+pub fn get_evaluator(slug: &str) -> Option<&'static dyn EvaluatorRequest> { + match slug { + // Safety + PII_DETECTOR => Some(&PiiDetector), + SECRETS_DETECTOR => Some(&SecretsDetector), + PROMPT_INJECTION => Some(&PromptInjection), + PROFANITY_DETECTOR => Some(&ProfanityDetector), + SEXISM_DETECTOR => Some(&SexismDetector), + TOXICITY_DETECTOR => Some(&ToxicityDetector), + // Validators + REGEX_VALIDATOR => Some(&RegexValidator), + JSON_VALIDATOR => Some(&JsonValidator), + SQL_VALIDATOR => Some(&SqlValidator), + // Quality and adherence + TONE_DETECTION => Some(&ToneDetection), + PROMPT_PERPLEXITY => Some(&PromptPerplexity), + UNCERTAINTY_DETECTOR => Some(&UncertaintyDetector), + _ => None, + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn text_body(input: &str) -> serde_json::Value { + json!({ "input": { "text": input } }) +} + +fn prompt_body(input: &str) -> serde_json::Value { + json!({ "input": { "prompt": input } }) +} + +/// Deserialize `params` into a typed config `C`, then attach it to the body. +/// Skips the `config` key entirely when `params` is empty. +fn attach_config( + mut body: serde_json::Value, + params: &HashMap, + slug: &str, +) -> Result { + if params.is_empty() { + return Ok(body); + } + let params_value: serde_json::Value = params.clone().into_iter().collect(); + let config: C = serde_json::from_value(params_value) + .map_err(|e| GuardrailError::ParseError(format!("{slug}: invalid config — {e}")))?; + let config_json = + serde_json::to_value(config).map_err(|e| GuardrailError::ParseError(e.to_string()))?; + if config_json.as_object().is_some_and(|m| !m.is_empty()) { + body["config"] = config_json; + } + Ok(body) +} + +macro_rules! 
evaluator_with_no_config { + ($name:ident, $body_fn:ident) => { + pub struct $name; + impl EvaluatorRequest for $name { + fn build_body( + &self, + input: &str, + _params: &HashMap, + ) -> Result { + Ok($body_fn(input)) + } + } + }; +} + +macro_rules! evaluator_with_config { + ($name:ident, $body_fn:ident, $config:ty, $slug:expr) => { + pub struct $name; + impl EvaluatorRequest for $name { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result { + attach_config::<$config>($body_fn(input), params, $slug) + } + } + }; +} + +evaluator_with_no_config!(SecretsDetector, text_body); +evaluator_with_no_config!(ProfanityDetector, text_body); +evaluator_with_no_config!(SqlValidator, text_body); +evaluator_with_no_config!(ToneDetection, text_body); +evaluator_with_no_config!(PromptPerplexity, prompt_body); +evaluator_with_no_config!(UncertaintyDetector, prompt_body); + +// --------------------------------------------------------------------------- +// Config structs (mirroring the Go DTOs in evaluator_mbt.go) +// --------------------------------------------------------------------------- + +#[derive(Default, Deserialize, Serialize)] +pub struct PiiDetectorConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub probability_threshold: Option, +} + +#[derive(Default, Deserialize, Serialize)] +pub struct ThresholdConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub threshold: Option, +} + +#[derive(Default, Deserialize, Serialize)] +pub struct RegexValidatorConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub should_match: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub case_sensitive: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub dot_include_nl: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub multi_line: Option, +} + +#[derive(Default, Deserialize, Serialize)] +pub struct JsonValidatorConfig { 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub enable_schema_validation: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub schema_string: Option, +} + +// --------------------------------------------------------------------------- +// Evaluators with config +// --------------------------------------------------------------------------- + +evaluator_with_config!(PiiDetector, text_body, PiiDetectorConfig, PII_DETECTOR); +evaluator_with_config!( + PromptInjection, + prompt_body, + ThresholdConfig, + PROMPT_INJECTION +); +evaluator_with_config!(SexismDetector, text_body, ThresholdConfig, SEXISM_DETECTOR); +evaluator_with_config!( + ToxicityDetector, + text_body, + ThresholdConfig, + TOXICITY_DETECTOR +); +evaluator_with_config!( + RegexValidator, + text_body, + RegexValidatorConfig, + REGEX_VALIDATOR +); +evaluator_with_config!( + JsonValidator, + text_body, + JsonValidatorConfig, + JSON_VALIDATOR +); diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs new file mode 100644 index 00000000..6067cfd9 --- /dev/null +++ b/src/guardrails/middleware.rs @@ -0,0 +1,307 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use axum::body::Body; +use axum::extract::Request; +use axum::response::{IntoResponse, Response}; +use tower::{Layer, Service}; +use tracing::{debug, warn}; + +use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; +use crate::models::completion::{CompletionRequest, CompletionResponse}; +use crate::models::embeddings::EmbeddingsRequest; +use crate::pipelines::otel::SharedTracer; + +use serde::de::DeserializeOwned; + +use super::parsing::PromptExtractor; +use super::runner::GuardrailsRunner; +use super::types::Guardrails; + +/// Maximum allowed body size for request/response buffering (10 MB). +/// Prevents unbounded memory allocation when guardrails need to inspect bodies. 
+pub const MAX_BODY_SIZE: usize = 10 * 1024 * 1024; + +/// Enum representing the endpoint type. +#[derive(Debug, Clone, Copy)] +enum EndpointType { + Chat, + Completion, + Embeddings, +} + +impl EndpointType { + /// Determine endpoint type from request path. + fn from_path(path: &str) -> Option { + match path { + p if p.ends_with("/chat/completions") => Some(Self::Chat), + p if p.ends_with("/completions") => Some(Self::Completion), + p if p.ends_with("/embeddings") => Some(Self::Embeddings), + _ => None, + } + } +} + +/// Enum representing the type of request being processed. +enum ParsedRequest { + Chat(Box), + Completion(Box), + Embeddings(Box), +} + +impl ParsedRequest { + /// Returns true if this is a streaming request. + fn is_streaming(&self) -> bool { + match self { + ParsedRequest::Chat(req) => req.stream.unwrap_or(false), + ParsedRequest::Completion(req) => req.stream.unwrap_or(false), + ParsedRequest::Embeddings(_) => false, + } + } + + /// Returns true if this request type supports post-call guards. + /// Streaming requests do not support post-call guards because the response + /// is sent as chunks and cannot be buffered for evaluation. + fn supports_post_call(&self) -> bool { + if self.is_streaming() { + return false; + } + match self { + ParsedRequest::Chat(_) | ParsedRequest::Completion(_) => true, + ParsedRequest::Embeddings(_) => false, + } + } +} + +impl PromptExtractor for ParsedRequest { + fn extract_prompt(&self) -> String { + match self { + ParsedRequest::Chat(req) => req.extract_prompt(), + ParsedRequest::Completion(req) => req.extract_prompt(), + ParsedRequest::Embeddings(req) => req.extract_prompt(), + } + } +} + +/// Helper function to handle post-call guards for supported request types. 
+async fn handle_post_call_guards( + parsed_request: &ParsedRequest, + resp_parts: axum::http::response::Parts, + resp_body: Body, + runner: &GuardrailsRunner<'_>, + mut warnings: Vec, +) -> Response { + let resp_bytes = match axum::body::to_bytes(resp_body, MAX_BODY_SIZE).await { + Ok(b) => b, + Err(_) => { + debug!("Guardrails middleware: failed to buffer response body, skipping post-call"); + let response = Response::from_parts(resp_parts, Body::empty()); + return GuardrailsRunner::finalize_response(response, &warnings); + } + }; + + let post_result = match parsed_request { + ParsedRequest::Chat(_) => { + if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { + Some(runner.run_post_call(&completion).await) + } else { + debug!("Guardrails middleware: failed to parse chat completion response"); + None + } + } + ParsedRequest::Completion(_) => { + if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { + Some(runner.run_post_call(&completion).await) + } else { + debug!("Guardrails middleware: failed to parse completion response"); + None + } + } + ParsedRequest::Embeddings(_) => None, + }; + + if let Some(result) = post_result { + match result { + Err(blocked) => return *blocked, + Ok(w) => warnings.extend(w), + } + } + + let response = Response::from_parts(resp_parts, Body::from(resp_bytes)); + GuardrailsRunner::finalize_response(response, &warnings) +} + +/// Try to deserialize bytes into a request type, logging on failure. +/// Returns None if deserialization fails, allowing the caller to pass through. +fn try_parse(bytes: &[u8], label: &str) -> Option { + match serde_json::from_slice::(bytes) { + Ok(req) => Some(req), + Err(e) => { + debug!( + "Guardrails middleware: failed to parse {} request: {}", + label, e + ); + None + } + } +} + +/// Tower layer that applies guardrail checks around a service. +/// +/// - **Pre-call guards** run before the inner service, inspecting the request body. 
+/// - **Post-call guards** run after the inner service, inspecting the response body.
+/// - Streaming requests (`"stream": true`) run pre-call guards but skip post-call guards.
+#[derive(Clone)]
+pub struct GuardrailsLayer {
+    guardrails: Option<Arc<Guardrails>>,
+}
+
+impl GuardrailsLayer {
+    pub fn new(guardrails: Option<Arc<Guardrails>>) -> Self {
+        Self { guardrails }
+    }
+}
+
+impl<S> Layer<S> for GuardrailsLayer {
+    type Service = GuardrailsMiddleware<S>;
+
+    fn layer(&self, inner: S) -> Self::Service {
+        GuardrailsMiddleware {
+            inner,
+            guardrails: self.guardrails.clone(),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct GuardrailsMiddleware<S> {
+    inner: S, // pipeline router
+    guardrails: Option<Arc<Guardrails>>,
+}
+
+impl<S> Service<Request<Body>> for GuardrailsMiddleware<S>
+where
+    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
+    S::Future: Send + 'static,
+    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
+{
+    type Response = Response;
+    type Error = S::Error;
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, request: Request<Body>) -> Self::Future {
+        let guardrails = self.guardrails.clone();
+        let inner = self.inner.clone();
+        let mut inner = std::mem::replace(&mut self.inner, inner);
+
+        Box::pin(async move {
+            // No guardrails configured — pass through without buffering
+            let guardrails = match guardrails {
+                Some(gr) => gr,
+                None => return inner.call(request).await,
+            };
+
+            let (parts, body) = request.into_parts();
+
+            // Determine endpoint type from path (more efficient than parsing JSON)
+            let endpoint_type = match EndpointType::from_path(parts.uri.path()) {
+                Some(t) => t,
+                None => {
+                    // Unsupported endpoint — pass through
+                    debug!(
+                        "Guardrails middleware: unsupported endpoint {}, passing through",
+                        parts.uri.path()
+                    );
+                    let request = Request::from_parts(parts, body);
+                    return inner.call(request).await;
+                }
+            };
+
+            // Buffer request body
+            let bytes = match axum::body::to_bytes(body, MAX_BODY_SIZE).await {
+                Ok(b) => b,
+                Err(_) => {
+
                    warn!("Guardrails middleware: request body too large or unreadable");
+                    let body = serde_json::json!({
+                        "error": {
+                            "message": "Request body too large or unreadable",
+                            "type": "invalid_request_error",
+                        }
+                    });
+                    return Ok(
+                        (axum::http::StatusCode::BAD_REQUEST, axum::Json(body)).into_response()
+                    );
+                }
+            };
+
+            // Parse request based on endpoint type
+            let parsed_request = match endpoint_type {
+                EndpointType::Chat => try_parse::<ChatCompletionRequest>(&bytes, "chat")
+                    .map(|req| ParsedRequest::Chat(Box::new(req))),
+                EndpointType::Completion => try_parse::<CompletionRequest>(&bytes, "completion")
+                    .map(|req| ParsedRequest::Completion(Box::new(req))),
+                EndpointType::Embeddings => try_parse::<EmbeddingsRequest>(&bytes, "embeddings")
+                    .map(|req| ParsedRequest::Embeddings(Box::new(req))),
+            };
+            let parsed_request = match parsed_request {
+                Some(pr) => pr,
+                None => {
+                    let request = Request::from_parts(parts, Body::from(bytes));
+                    return inner.call(request).await;
+                }
+            };
+
+            // Resolve guards from pipeline config + request headers
+            // Extract parent context from the tracer in request extensions
+            let parent_cx = parts
+                .extensions
+                .get::<SharedTracer>()
+                .and_then(|tracer| tracer.lock().ok().map(|t| t.parent_context()));
+            let runner = GuardrailsRunner::new(Some(&guardrails), &parts.headers, parent_cx);
+
+            let runner = match runner {
+                Some(r) => r,
+                None => {
+                    // No active guards for this request
+                    let request = Request::from_parts(parts, Body::from(bytes));
+                    return inner.call(request).await;
+                }
+            };
+
+            // --- Pre-call guards ---
+            let all_warnings = match runner.run_pre_call(&parsed_request).await {
+                Ok(warnings) => warnings,
+                Err(blocked) => return Ok(*blocked),
+            };
+
+            // --- Call inner service ---
+            let request = Request::from_parts(parts, Body::from(bytes));
+            let response = inner.call(request).await?;
+
+            // --- Post-call guards (only for request types that produce text) ---
+            let (resp_parts, resp_body) = response.into_parts();
+
+            if parsed_request.supports_post_call() {
+                Ok(handle_post_call_guards(
+
&parsed_request, + resp_parts, + resp_body, + &runner, + all_warnings, + ) + .await) + } else { + // No post-call guards for this request type (e.g., embeddings) + // Pass through response with pre-call warnings attached + let response = Response::from_parts(resp_parts, resp_body); + Ok(GuardrailsRunner::finalize_response(response, &all_warnings)) + } + }) + } +} diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs new file mode 100644 index 00000000..d496100b --- /dev/null +++ b/src/guardrails/mod.rs @@ -0,0 +1,8 @@ +pub mod evaluator_types; +pub mod middleware; +pub mod parsing; +pub mod providers; +pub mod runner; +pub mod setup; +pub mod span_attributes; +pub mod types; diff --git a/src/guardrails/parsing.rs b/src/guardrails/parsing.rs new file mode 100644 index 00000000..fd984248 --- /dev/null +++ b/src/guardrails/parsing.rs @@ -0,0 +1,114 @@ +use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; +use crate::models::completion::{CompletionRequest, CompletionResponse}; +use crate::models::content::ChatMessageContent; +use crate::models::embeddings::EmbeddingsRequest; +use tracing::debug; + +use super::types::{EvaluatorResponse, GuardrailError}; + +/// Trait for extracting pre-call guardrail input from a request. +pub trait PromptExtractor { + fn extract_prompt(&self) -> String; +} + +/// Trait for extracting post-call guardrail input from a response. 
+pub trait CompletionExtractor {
+    fn extract_completion(&self) -> String;
+}
+
+impl PromptExtractor for ChatCompletionRequest {
+    fn extract_prompt(&self) -> String {
+        self.messages
+            .iter()
+            .filter_map(|m| {
+                m.content.as_ref().map(|content| match content {
+                    ChatMessageContent::String(s) => s.clone(),
+                    ChatMessageContent::Array(parts) => parts
+                        .iter()
+                        .filter(|p| p.r#type == "text")
+                        .map(|p| p.text.as_str())
+                        .collect::<Vec<_>>()
+                        .join(" "),
+                })
+            })
+            .collect::<Vec<_>>()
+            .join("\n")
+    }
+}
+
+impl CompletionExtractor for ChatCompletion {
+    fn extract_completion(&self) -> String {
+        self.choices
+            .first()
+            .and_then(|choice| choice.message.content.as_ref())
+            .map(|content| match content {
+                ChatMessageContent::String(s) => s.clone(),
+                ChatMessageContent::Array(parts) => parts
+                    .iter()
+                    .filter(|p| p.r#type == "text")
+                    .map(|p| p.text.as_str())
+                    .collect::<Vec<_>>()
+                    .join(" "),
+            })
+            .unwrap_or_default()
+    }
+}
+
+impl PromptExtractor for CompletionRequest {
+    fn extract_prompt(&self) -> String {
+        self.prompt.clone()
+    }
+}
+
+impl CompletionExtractor for CompletionResponse {
+    fn extract_completion(&self) -> String {
+        self.choices
+            .first()
+            .map(|choice| choice.text.clone())
+            .unwrap_or_default()
+    }
+}
+
+impl PromptExtractor for EmbeddingsRequest {
+    fn extract_prompt(&self) -> String {
+        match &self.input {
+            crate::models::embeddings::EmbeddingsInput::Single(s) => s.clone(),
+            crate::models::embeddings::EmbeddingsInput::Multiple(v) => v.join("\n"),
+            crate::models::embeddings::EmbeddingsInput::SingleTokenIds(_) => {
+                "[token IDs]".to_string()
+            }
+            crate::models::embeddings::EmbeddingsInput::MultipleTokenIds(_) => {
+                "[multiple token IDs]".to_string()
+            }
+        }
+    }
+}
+
+/// Parse the evaluator response body (JSON string) into an EvaluatorResponse.
+pub fn parse_evaluator_response(body: &str) -> Result<EvaluatorResponse, GuardrailError> {
+    let response = serde_json::from_str::<EvaluatorResponse>(body)
+        .map_err(|e| GuardrailError::ParseError(e.to_string()))?;
+
+    // Log for debugging
+    debug!(
+        pass = response.pass,
+        result = %response.result,
+        "Parsed evaluator response"
+    );
+
+    Ok(response)
+}
+
+/// Parse an HTTP response from the evaluator, handling non-200 status codes.
+pub fn parse_evaluator_http_response(
+    status: u16,
+    body: &str,
+) -> Result<EvaluatorResponse, GuardrailError> {
+    if !(200..300).contains(&status) {
+        return Err(GuardrailError::HttpError {
+            status,
+            body: body.to_string(),
+        });
+    }
+    parse_evaluator_response(body)
+}
diff --git a/src/guardrails/providers/mod.rs b/src/guardrails/providers/mod.rs
new file mode 100644
index 00000000..6d88b394
--- /dev/null
+++ b/src/guardrails/providers/mod.rs
@@ -0,0 +1 @@
+pub mod traceloop;
diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs
new file mode 100644
index 00000000..a2c2f6cb
--- /dev/null
+++ b/src/guardrails/providers/traceloop.rs
@@ -0,0 +1,152 @@
+use async_trait::async_trait;
+
+use crate::guardrails::evaluator_types::get_evaluator;
+use crate::guardrails::parsing::parse_evaluator_http_response;
+use crate::guardrails::types::GuardrailClient;
+use crate::guardrails::types::{EvaluatorResponse, Guard, GuardrailError};
+
+const DEFAULT_TRACELOOP_API: &str = "https://api.traceloop.com";
+const DEFAULT_TIMEOUT_SEC: u64 = 3;
+
+/// HTTP client for the Traceloop evaluator API service.
+/// Calls `POST {api_base}/v2/guardrails/execute/{evaluator_slug}`.
+pub struct TraceloopClient { + http_client: reqwest::Client, +} + +impl Default for TraceloopClient { + fn default() -> Self { + Self::new() + } +} + +impl TraceloopClient { + pub fn new() -> Self { + Self::with_timeout(std::time::Duration::from_secs(DEFAULT_TIMEOUT_SEC)) + } + + pub fn with_timeout(timeout: std::time::Duration) -> Self { + Self { + http_client: reqwest::Client::builder() + .timeout(timeout) + .build() + .expect("Failed to build HTTP client for Traceloop"), + } + } +} + +#[async_trait] +impl GuardrailClient for TraceloopClient { + async fn evaluate( + &self, + guard: &Guard, + input: &str, + ) -> Result { + let api_base = guard + .api_base + .as_deref() + .filter(|s| !s.is_empty()) + .unwrap_or(DEFAULT_TRACELOOP_API); + let url = format!( + "{}/v2/guardrails/execute/{}", + api_base, guard.evaluator_slug + ); + + let api_key = guard + .api_key + .as_deref() + .filter(|k| !k.is_empty()) + .ok_or_else(|| { + GuardrailError::Unavailable( + "Traceloop API key is required but not provided".to_string(), + ) + })?; + + let evaluator = get_evaluator(&guard.evaluator_slug).ok_or_else(|| { + GuardrailError::Unavailable(format!( + "Unknown evaluator slug '{}'", + guard.evaluator_slug + )) + })?; + let body = evaluator.build_body(input, &guard.params)?; + + let response = self + .http_client + .post(&url) + .header("Authorization", format!("Bearer {api_key}")) + .json(&body) + .send() + .await?; + + let status = response.status().as_u16(); + let response_body = response.text().await?; + + parse_evaluator_http_response(status, &response_body) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn test_build_body_text_slug() { + let params = HashMap::new(); + let body = get_evaluator("pii-detector") + .unwrap() + .build_body("hello world", ¶ms) + .unwrap(); + assert_eq!(body, json!({"input": {"text": "hello world"}})); + } + + #[test] + fn test_build_body_prompt_slug() { + let params = 
HashMap::new(); + let body = get_evaluator("prompt-injection") + .unwrap() + .build_body("hello world", ¶ms) + .unwrap(); + assert_eq!(body, json!({"input": {"prompt": "hello world"}})); + } + + #[test] + fn test_build_body_with_config() { + let mut params = HashMap::new(); + params.insert("threshold".to_string(), json!(0.8)); + let body = get_evaluator("toxicity-detector") + .unwrap() + .build_body("test", ¶ms) + .unwrap(); + assert_eq!( + body, + json!({"input": {"text": "test"}, "config": {"threshold": 0.8}}) + ); + } + + #[test] + fn test_build_body_no_config_when_params_empty() { + let params = HashMap::new(); + let body = get_evaluator("secrets-detector") + .unwrap() + .build_body("test", ¶ms) + .unwrap(); + assert!(body.get("config").is_none()); + } + + #[test] + fn test_get_evaluator_unknown_slug() { + assert!(get_evaluator("nonexistent").is_none()); + } + + #[test] + fn test_build_body_rejects_invalid_config_type() { + let mut params = HashMap::new(); + params.insert("threshold".to_string(), json!("not-a-number")); + let result = get_evaluator("toxicity-detector") + .unwrap() + .build_body("test", ¶ms); + assert!(result.is_err()); + } +} diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs new file mode 100644 index 00000000..52268e17 --- /dev/null +++ b/src/guardrails/runner.rs @@ -0,0 +1,379 @@ +use std::collections::HashSet; + +use axum::Json; +use axum::http::HeaderMap; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use futures::future::join_all; +use opentelemetry::global::{BoxedSpan, ObjectSafeSpan}; +use opentelemetry::trace::{SpanKind, Status as OtelStatus, Tracer}; +use opentelemetry::{Context, KeyValue, global}; +use serde_json::json; +use tracing::{debug, warn}; + +use crate::config::lib::get_trace_content_enabled; + +use super::parsing::{CompletionExtractor, PromptExtractor}; +use super::setup::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; +use super::span_attributes::*; +use 
super::types::{ + EvaluatorResponse, Guard, GuardResult, GuardWarning, GuardrailClient, GuardrailError, + Guardrails, GuardrailsOutcome, OnFailure, +}; + +fn error_type_name(err: &GuardrailError) -> &'static str { + match err { + GuardrailError::Unavailable(_) => "Unavailable", + GuardrailError::HttpError { .. } => "HttpError", + GuardrailError::Timeout(_) => "Timeout", + GuardrailError::ParseError(_) => "ParseError", + } +} + +fn record_guard_span( + span: &mut BoxedSpan, + guard: &Guard, + result: &Result, + elapsed: std::time::Duration, + input: &str, +) { + span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_NAME, guard.name.clone())); + span.set_attribute(KeyValue::new( + GEN_AI_GUARDRAIL_DURATION, + elapsed.as_millis() as i64, + )); + + if get_trace_content_enabled() { + span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_INPUT, input.to_string())); + } + + match result { + Ok(resp) => { + let status = if resp.pass { + GUARDRAIL_PASSED + } else { + GUARDRAIL_FAILED + }; + span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_STATUS, status)); + } + Err(err) => { + span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_STATUS, GUARDRAIL_ERROR)); + span.set_attribute(KeyValue::new( + GEN_AI_GUARDRAIL_ERROR_TYPE, + error_type_name(err), + )); + span.set_attribute(KeyValue::new( + GEN_AI_GUARDRAIL_ERROR_MESSAGE, + err.to_string(), + )); + span.set_status(OtelStatus::error(err.to_string())); + } + } +} + +/// Execute a set of guardrails against the given input text. +/// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. +/// When `parent_cx` is provided, creates a child OTel span per guard evaluation. 
+pub async fn execute_guards( + guards: &[Guard], + input: &str, + client: &dyn GuardrailClient, + parent_cx: Option<&Context>, +) -> GuardrailsOutcome { + debug!(guard_count = guards.len(), "Executing guardrails"); + + let parent_cx = parent_cx.cloned(); + + let futures: Vec<_> = guards + .iter() + .map(|guard| { + let parent_cx = parent_cx.clone(); + async move { + // Create child span BEFORE evaluation so its start time is correct + let mut span = parent_cx.as_ref().map(|cx| { + let tracer = global::tracer("traceloop_hub"); + tracer + .span_builder(format!("{}.guard", guard.name)) + .with_kind(SpanKind::Internal) + .start_with_context(&tracer, cx) + }); + + let start = std::time::Instant::now(); + let result = client.evaluate(guard, input).await; + let elapsed = start.elapsed(); + + if let Some(s) = &mut span { + record_guard_span(s, guard, &result, elapsed, input); + } + + match &result { + Ok(resp) => debug!( + guard = %guard.name, + pass = resp.pass, + elapsed_ms = elapsed.as_millis(), + "Guard evaluation complete" + ), + Err(err) => warn!( + guard = %guard.name, + error = %err, + required = guard.required, + elapsed_ms = elapsed.as_millis(), + "Guard evaluation failed" + ), + } + (guard, result, span) + } + }) + .collect(); + + let results_raw = join_all(futures).await; + + let mut results = Vec::new(); + let mut blocked = false; + let mut blocking_guard = None; + let mut warnings = Vec::new(); + let mut guard_spans: Vec = Vec::new(); + + for (guard, result, span) in results_raw { + if let Some(s) = span { + guard_spans.push(s); + } + let name = guard.name.clone(); + match result { + Ok(response) => { + if response.pass { + results.push(GuardResult::Passed { name }); + } else { + match guard.on_failure { + OnFailure::Block => { + blocked = true; + if blocking_guard.is_none() { + blocking_guard = Some(name.clone()); + } + } + OnFailure::Warn => { + warnings.push(GuardWarning { + guard_name: name.clone(), + reason: "failed".to_string(), + }); + } + } + 
results.push(GuardResult::Failed { + name, + result: response.result, + on_failure: guard.on_failure, + }); + } + } + Err(err) => { + let is_required = guard.required; + let error_msg = err.to_string(); + if is_required { + blocked = true; + if blocking_guard.is_none() { + blocking_guard = Some(name.clone()); + } + } else { + warnings.push(GuardWarning { + guard_name: name.clone(), + reason: format!("evaluator error: {error_msg}"), + }); + } + results.push(GuardResult::Error { + name, + error: error_msg, + required: is_required, + }); + } + } + } + + if blocked { + warn!(blocking_guard = ?blocking_guard, "Request blocked by guardrail"); + } + + GuardrailsOutcome { + results, + blocked, + blocking_guard, + warnings, + } +} + +/// Result of a guard phase: Ok(warnings) on pass, Err(blocked_response) on block. +pub type GuardPhaseResult = Result, Box>; + +pub struct GuardrailsRunner<'a> { + pre_call: Vec, + post_call: Vec, + client: &'a dyn GuardrailClient, + parent_cx: Option, +} + +/// Convert a GuardrailsOutcome into a GuardPhaseResult. +/// If the outcome is blocked, produces a blocked response; otherwise, forwards warnings. +fn outcome_to_phase_result(outcome: GuardrailsOutcome) -> GuardPhaseResult { + if outcome.blocked { + Err(Box::new(blocked_response(&outcome))) + } else { + Ok(outcome.warnings) + } +} + +impl<'a> GuardrailsRunner<'a> { + /// Create a runner by resolving guards from pipeline config + request headers. + /// Returns None if no guards are active for this request. + /// When `parent_cx` is provided, guardrail evaluations are traced as child spans. 
+ pub fn new( + guardrails: Option<&'a Guardrails>, + headers: &HeaderMap, + parent_cx: Option, + ) -> Option { + let gr = guardrails?; + let (pre_call, post_call) = resolve_request_guards(gr, headers); + if pre_call.is_empty() && post_call.is_empty() { + return None; + } + Some(Self { + pre_call, + post_call, + client: gr.client.as_ref(), + parent_cx, + }) + } + + /// Run pre-call guards, extracting input from the request only if guards exist. + pub async fn run_pre_call(&self, request: &impl PromptExtractor) -> GuardPhaseResult { + if self.pre_call.is_empty() { + return Ok(Vec::new()); + } + let input = request.extract_prompt(); + let outcome = + execute_guards(&self.pre_call, &input, self.client, self.parent_cx.as_ref()).await; + outcome_to_phase_result(outcome) + } + + /// Run post-call guards, extracting input from the response only if guards exist. + pub async fn run_post_call(&self, response: &impl CompletionExtractor) -> GuardPhaseResult { + if self.post_call.is_empty() { + return Ok(Vec::new()); + } + let completion = response.extract_completion(); + + if completion.is_empty() { + warn!("Skipping post-call guardrails: LLM response content is empty"); + return Ok(vec![GuardWarning { + guard_name: "all post_call guards".to_string(), + reason: "skipped due to empty response content".to_string(), + }]); + } + + let outcome = execute_guards( + &self.post_call, + &completion, + self.client, + self.parent_cx.as_ref(), + ) + .await; + outcome_to_phase_result(outcome) + } + + /// Attach warning headers to a response if there are any warnings. + /// Returns the response unchanged if there are no warnings. 
+ pub fn finalize_response(response: Response, warnings: &[GuardWarning]) -> Response { + if warnings.is_empty() { + return response; + } + let header_val = warning_header_value(warnings); + let mut response = response; + match header_val.parse() { + Ok(parsed_header) => { + response + .headers_mut() + .insert("x-traceloop-guardrail-warning", parsed_header); + } + Err(e) => { + warn!( + error = %e, + header_value = %header_val, + "Failed to parse guardrail warning header, skipping header" + ); + } + } + response + } +} + +/// Build a 403 blocked response with the guard name. +pub fn blocked_response(outcome: &GuardrailsOutcome) -> Response { + let guard_name = outcome.blocking_guard.as_deref().unwrap_or("unknown"); + + // Find the blocking guard result to get details + let blocking_result = outcome.results.iter().find(|r| match r { + GuardResult::Failed { name, .. } | GuardResult::Error { name, .. } => name == guard_name, + _ => false, + }); + + let error_obj = match blocking_result { + Some(GuardResult::Failed { result, .. }) => json!({ + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), + "evaluation_result": result, + "reason": "evaluation_failed", + }), + Some(GuardResult::Error { error, .. 
}) => json!({ + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), + "error_details": error, + "reason": "evaluator_error", + }), + _ => json!({ + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), + }), + }; + + let body = json!({ "error": error_obj }); + (StatusCode::FORBIDDEN, Json(body)).into_response() +} + +pub fn warning_header_value(warnings: &[GuardWarning]) -> String { + warnings + .iter() + .map(|w| { + format!( + "guardrail_name=\"{}\", reason=\"{}\"", + w.guard_name, w.reason + ) + }) + .collect::>() + .join("; ") +} + +/// Resolve guards for this request by merging pipeline guards with header-specified guards. +fn resolve_request_guards(gr: &Guardrails, headers: &HeaderMap) -> (Vec, Vec) { + let header_guard_names = headers + .get("x-traceloop-guardrails") + .and_then(|v| v.to_str().ok()) + .map(parse_guardrails_header) + .unwrap_or_default(); + + let pipeline_names: Vec<&str> = gr.pipeline_guard_names.iter().map(|s| s.as_str()).collect(); + let header_names: Vec<&str> = header_guard_names.iter().map(|s| s.as_str()).collect(); + let resolved = resolve_guards_by_name(&gr.all_guards, &pipeline_names, &header_names); + + // Log unknown guard names from headers + if !header_guard_names.is_empty() { + let resolved_names: HashSet<&str> = resolved.iter().map(|g| g.name.as_str()).collect(); + for name in &header_guard_names { + if !resolved_names.contains(name.as_str()) { + warn!(guard_name = %name, "Unknown guard name in X-Traceloop-Guardrails header, ignoring"); + } + } + } + + split_guards_by_mode(&resolved) +} diff --git a/src/guardrails/setup.rs b/src/guardrails/setup.rs new file mode 100644 index 00000000..0d639002 --- /dev/null +++ b/src/guardrails/setup.rs @@ -0,0 +1,96 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use super::types::{ + Guard, GuardMode, GuardrailClient, 
GuardrailResources, Guardrails, GuardrailsConfig, +}; + +/// Parse guard names from the X-Traceloop-Guardrails header value. +/// Names are comma-separated and trimmed. +pub fn parse_guardrails_header(header: &str) -> Vec { + header + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() +} + +/// Resolve the final set of guards to execute by merging pipeline and header sources. +pub fn resolve_guards_by_name( + all_guards: &[Guard], + pipeline_names: &[&str], + header_names: &[&str], +) -> Vec { + let guard_map: HashMap<&str, &Guard> = + all_guards.iter().map(|g| (g.name.as_str(), g)).collect(); + + let mut seen = HashSet::new(); + let mut resolved = Vec::new(); + + let all_names = pipeline_names.iter().chain(header_names.iter()).copied(); + + for name in all_names { + if seen.insert(name) { + if let Some(guard) = guard_map.get(name) { + resolved.push((*guard).clone()); + } + } + } + + resolved +} + +/// Split guards into (pre_call, post_call) lists by mode. +pub fn split_guards_by_mode(guards: &[Guard]) -> (Vec, Vec) { + guards + .iter() + .cloned() + .partition(|g| g.mode == GuardMode::PreCall) +} + +/// Resolve provider defaults (api_base/api_key) for all guards in the config. +pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { + let mut guards = config.guards.clone(); + for guard in &mut guards { + if guard.api_base.is_none() || guard.api_key.is_none() { + if let Some(provider) = config.providers.get(&guard.provider) { + if guard.api_base.is_none() && !provider.api_base.is_empty() { + guard.api_base = Some(provider.api_base.clone()); + } + if guard.api_key.is_none() && !provider.api_key.is_empty() { + guard.api_key = Some(provider.api_key.clone()); + } + } + } + } + guards +} + +/// Build the shared guardrail resources (resolved guards + client). +/// Returns None if the config has no guards. +/// Called once per router build; the result is shared across all pipelines. 
+pub fn build_guardrail_resources(config: &GuardrailsConfig) -> Option { + if config.guards.is_empty() { + return None; + } + let all_guards = Arc::new(resolve_guard_defaults(config)); + let client: Arc = + Arc::new(super::providers::traceloop::TraceloopClient::new()); + Some(GuardrailResources { + guards: all_guards, + client, + }) +} + +/// Build per-pipeline Guardrails from shared resources. +/// `shared` contains the Arc-wrapped guards and client built once by `build_guardrail_resources`. +pub fn build_pipeline_guardrails( + shared: &GuardrailResources, + pipeline_guard_names: &[String], +) -> Arc { + Arc::new(Guardrails { + all_guards: shared.guards.clone(), + pipeline_guard_names: pipeline_guard_names.to_vec(), + client: shared.client.clone(), + }) +} diff --git a/src/guardrails/span_attributes.rs b/src/guardrails/span_attributes.rs new file mode 100644 index 00000000..28739acc --- /dev/null +++ b/src/guardrails/span_attributes.rs @@ -0,0 +1,10 @@ +pub const GEN_AI_GUARDRAIL_NAME: &str = "gen_ai.guardrail.name"; +pub const GEN_AI_GUARDRAIL_STATUS: &str = "gen_ai.guardrail.status"; +pub const GEN_AI_GUARDRAIL_DURATION: &str = "gen_ai.guardrail.duration"; +pub const GEN_AI_GUARDRAIL_INPUT: &str = "gen_ai.guardrail.input"; +pub const GEN_AI_GUARDRAIL_ERROR_TYPE: &str = "gen_ai.guardrail.error.type"; +pub const GEN_AI_GUARDRAIL_ERROR_MESSAGE: &str = "gen_ai.guardrail.error.message"; + +pub const GUARDRAIL_PASSED: &str = "PASSED"; +pub const GUARDRAIL_FAILED: &str = "FAILED"; +pub const GUARDRAIL_ERROR: &str = "ERROR"; diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs new file mode 100644 index 00000000..a738f92a --- /dev/null +++ b/src/guardrails/types.rs @@ -0,0 +1,213 @@ +use async_trait::async_trait; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; +use thiserror::Error; + +/// Shared guardrail resources: resolved guards + client. 
+/// Built once per router build and shared across all pipelines. +pub struct GuardrailResources { + pub guards: Arc>, + pub client: Arc, +} + +fn default_on_failure() -> OnFailure { + OnFailure::Warn +} + +fn default_required() -> bool { + false +} + +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +pub enum GuardMode { + PreCall, + PostCall, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +pub enum OnFailure { + Block, + Warn, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +pub struct ProviderConfig { + pub name: String, + pub api_base: String, + pub api_key: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Guard { + pub name: String, + pub provider: String, + pub evaluator_slug: String, + #[serde(default)] + pub params: HashMap, + pub mode: GuardMode, + #[serde(default = "default_on_failure")] + pub on_failure: OnFailure, + #[serde(default = "default_required")] + pub required: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_base: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_key: Option, +} + +impl Eq for Guard {} + +impl Hash for Guard { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.provider.hash(state); + self.evaluator_slug.hash(state); + // Hash params by sorting keys and hashing serialized values + let mut params_vec: Vec<_> = self.params.iter().collect(); + params_vec.sort_by(|a, b| a.0.cmp(b.0)); + for (k, v) in params_vec { + k.hash(state); + v.to_string().hash(state); + } + self.mode.hash(state); + self.on_failure.hash(state); + self.required.hash(state); + self.api_base.hash(state); + self.api_key.hash(state); + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)] +pub struct GuardrailsConfig { + #[serde( + default, + deserialize_with = 
"deserialize_providers", + serialize_with = "serialize_providers" + )] + pub providers: HashMap, + #[serde(default)] + pub guards: Vec, +} + +fn deserialize_providers<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let list: Vec = Vec::deserialize(deserializer)?; + Ok(list.into_iter().map(|p| (p.name.clone(), p)).collect()) +} + +fn serialize_providers( + providers: &HashMap, + serializer: S, +) -> Result +where + S: Serializer, +{ + let list: Vec<&ProviderConfig> = providers.values().collect(); + list.serialize(serializer) +} + +impl Hash for GuardrailsConfig { + fn hash(&self, state: &mut H) { + let mut entries: Vec<_> = self.providers.iter().collect(); + entries.sort_by(|a, b| a.0.cmp(b.0)); + for (k, v) in entries { + k.hash(state); + v.hash(state); + } + self.guards.hash(state); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EvaluatorResponse { + pub result: serde_json::Value, + pub pass: bool, +} + +#[derive(Debug, Clone)] +pub enum GuardResult { + Passed { + name: String, + }, + Failed { + name: String, + result: serde_json::Value, + on_failure: OnFailure, + }, + Error { + name: String, + error: String, + required: bool, + }, +} + +#[derive(Debug, Clone)] +pub struct GuardWarning { + pub guard_name: String, + pub reason: String, +} + +#[derive(Debug, Clone)] +pub struct GuardrailsOutcome { + pub results: Vec, + pub blocked: bool, + pub blocking_guard: Option, + pub warnings: Vec, +} + +#[derive(Debug, Clone, Error)] +pub enum GuardrailError { + #[error("Evaluator unavailable: {0}")] + Unavailable(String), + + #[error("HTTP error {status}: {body}")] + HttpError { status: u16, body: String }, + + #[error("Timeout: {0}")] + Timeout(String), + + #[error("Parse error: {0}")] + ParseError(String), +} + +impl From for GuardrailError { + fn from(e: reqwest::Error) -> Self { + if e.is_timeout() { + GuardrailError::Timeout(e.to_string()) + } else { + GuardrailError::Unavailable(e.to_string()) + } + } +} + 
+/// Trait for guardrail evaluator clients. +/// Each provider (traceloop, etc.) implements this to call its evaluator API. +#[async_trait] +pub trait GuardrailClient: Send + Sync { + async fn evaluate( + &self, + guard: &Guard, + input: &str, + ) -> Result; +} + +/// Guardrails state attached to a pipeline, containing resolved guards and client. +/// +/// `all_guards` and `client` are shared across all pipelines via `Arc` (built once). +/// `pipeline_guard_names` holds the guard names declared by this specific pipeline. +/// At request time, guards are resolved by merging pipeline guards with any +/// additional guards specified via the `X-Traceloop-Guardrails` header. +#[derive(Clone)] +pub struct Guardrails { + pub all_guards: Arc>, + pub pipeline_guard_names: Vec, + pub client: Arc, +} diff --git a/src/lib.rs b/src/lib.rs index 73b93200..1663bf71 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod ai_models; pub mod config; +pub mod guardrails; pub mod management; pub mod models; pub mod openapi; diff --git a/src/management/services/config_provider_service.rs b/src/management/services/config_provider_service.rs index 2ad80bef..94b31331 100644 --- a/src/management/services/config_provider_service.rs +++ b/src/management/services/config_provider_service.rs @@ -257,6 +257,7 @@ impl ConfigProviderService { name: dto.name, r#type: core_pipeline_type, plugins: core_plugins, + guards: vec![], }) } diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index deff9c45..b39f2b96 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -1,2 +1,3 @@ -mod otel; +pub mod otel; pub mod pipeline; +mod tracing_middleware; diff --git a/src/pipelines/otel.rs b/src/pipelines/otel.rs index d71e4411..82bd8de6 100644 --- a/src/pipelines/otel.rs +++ b/src/pipelines/otel.rs @@ -6,21 +6,25 @@ use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest, EmbeddingsRe use crate::models::streaming::ChatCompletionChunk; use crate::models::usage::{EmbeddingUsage, 
Usage}; use opentelemetry::global::{BoxedSpan, ObjectSafeSpan}; -use opentelemetry::trace::{SpanKind, Status, Tracer}; -use opentelemetry::{KeyValue, global}; +use opentelemetry::trace::{SpanKind, Status, TraceContextExt, Tracer}; +use opentelemetry::{Context, KeyValue, global}; use opentelemetry_otlp::{SpanExporter, WithExportConfig, WithHttpConfig}; use opentelemetry_sdk::propagation::TraceContextPropagator; use opentelemetry_sdk::trace::TracerProvider; use opentelemetry_semantic_conventions::attribute::GEN_AI_REQUEST_MODEL; use opentelemetry_semantic_conventions::trace::*; use std::collections::HashMap; +use std::sync::{Arc, Mutex}; pub trait RecordSpan { fn record_span(&self, span: &mut BoxedSpan); } +pub type SharedTracer = Arc>; + pub struct OtelTracer { - span: BoxedSpan, + llm_span: Option, + root_span: BoxedSpan, accumulated_completion: Option, } @@ -87,21 +91,44 @@ impl OtelTracer { } } - pub fn start(operation: &str, request: &T) -> Self { + pub fn start() -> Self { let tracer = global::tracer("traceloop_hub"); - let mut span = tracer - .span_builder(format!("traceloop_hub.{operation}")) - .with_kind(SpanKind::Client) + let span = tracer + .span_builder("traceloop_hub") + .with_kind(SpanKind::Server) .start(&tracer); - request.record_span(&mut span); - Self { - span, + llm_span: None, + root_span: span, accumulated_completion: None, } } + /// Helper to extract SharedTracer from request extensions with fallback. + /// This is primarily used by handlers to get the tracer created by TracingMiddleware. 
+ pub fn from_extensions(extensions: &axum::http::Extensions) -> SharedTracer { + extensions + .get::() + .cloned() + .unwrap_or_else(|| { + // Fallback for backwards compatibility + Arc::new(Mutex::new(OtelTracer::start())) + }) + } + + pub fn start_llm_span(&mut self, operation: &str, request: &T) { + let tracer = global::tracer("traceloop_hub"); + let parent_cx = self.parent_context(); + let mut span = tracer + .span_builder(format!("traceloop_hub.{operation}")) + .with_kind(SpanKind::Client) + .start_with_context(&tracer, &parent_cx); + + request.record_span(&mut span); + self.llm_span = Some(span); + } + pub fn log_chunk(&mut self, chunk: &ChatCompletionChunk) { if self.accumulated_completion.is_none() { self.accumulated_completion = Some(ChatCompletion { @@ -160,23 +187,52 @@ impl OtelTracer { pub fn streaming_end(&mut self) { if let Some(completion) = self.accumulated_completion.take() { - completion.record_span(&mut self.span); - self.span.set_status(Status::Ok); + if let Some(span) = &mut self.llm_span { + completion.record_span(span); + span.set_status(Status::Ok); + } } + self.root_span.set_status(Status::Ok); } pub fn log_success(&mut self, response: &T) { - response.record_span(&mut self.span); - self.span.set_status(Status::Ok); + if let Some(span) = &mut self.llm_span { + response.record_span(span); + span.set_status(Status::Ok); + } + self.root_span.set_status(Status::Ok); } pub fn log_error(&mut self, description: String) { - self.span.set_status(Status::error(description)); + if let Some(span) = &mut self.llm_span { + span.set_status(Status::error(description.clone())); + } + self.root_span.set_status(Status::error(description)); + } + + /// Returns an OTel Context carrying this tracer's root span as parent, + /// suitable for creating child spans. 
+ pub fn parent_context(&self) -> Context { + Context::current().with_remote_span_context(self.root_span.span_context().clone()) } pub fn set_vendor(&mut self, vendor: &str) { - self.span - .set_attribute(KeyValue::new(GEN_AI_SYSTEM, vendor.to_string())); + if let Some(span) = &mut self.llm_span { + span.set_attribute(KeyValue::new(GEN_AI_SYSTEM, vendor.to_string())); + } + } +} + +fn set_optional_f64(span: &mut BoxedSpan, key: &'static str, value: Option) { + if let Some(v) = value { + span.set_attribute(KeyValue::new(key, v as f64)); + } +} + +fn content_to_string(content: &ChatMessageContent) -> String { + match content { + ChatMessageContent::String(s) => s.clone(), + ChatMessageContent::Array(parts) => serde_json::to_string(parts).unwrap_or_default(), } } @@ -185,24 +241,14 @@ impl RecordSpan for ChatCompletionRequest { span.set_attribute(KeyValue::new("llm.request.type", "chat")); span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone())); - if let Some(freq_penalty) = self.frequency_penalty { - span.set_attribute(KeyValue::new( - GEN_AI_REQUEST_FREQUENCY_PENALTY, - freq_penalty as f64, - )); - } - if let Some(pres_penalty) = self.presence_penalty { - span.set_attribute(KeyValue::new( - GEN_AI_REQUEST_PRESENCE_PENALTY, - pres_penalty as f64, - )); - } - if let Some(top_p) = self.top_p { - span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TOP_P, top_p as f64)); - } - if let Some(temp) = self.temperature { - span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temp as f64)); - } + set_optional_f64( + span, + GEN_AI_REQUEST_FREQUENCY_PENALTY, + self.frequency_penalty, + ); + set_optional_f64(span, GEN_AI_REQUEST_PRESENCE_PENALTY, self.presence_penalty); + set_optional_f64(span, GEN_AI_REQUEST_TOP_P, self.top_p); + set_optional_f64(span, GEN_AI_REQUEST_TEMPERATURE, self.temperature); if get_trace_content_enabled() { for (i, message) in self.messages.iter().enumerate() { @@ -213,12 +259,7 @@ impl RecordSpan for ChatCompletionRequest { )); 
span.set_attribute(KeyValue::new( format!("gen_ai.prompt.{i}.content"), - match &content { - ChatMessageContent::String(content) => content.clone(), - ChatMessageContent::Array(content) => { - serde_json::to_string(content).unwrap_or_default() - } - }, + content_to_string(content), )); } } @@ -242,12 +283,7 @@ impl RecordSpan for ChatCompletion { )); span.set_attribute(KeyValue::new( format!("gen_ai.completion.{}.content", choice.index), - match &content { - ChatMessageContent::String(content) => content.clone(), - ChatMessageContent::Array(content) => { - serde_json::to_string(content).unwrap_or_default() - } - }, + content_to_string(content), )); } span.set_attribute(KeyValue::new( @@ -265,24 +301,14 @@ impl RecordSpan for CompletionRequest { span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone())); span.set_attribute(KeyValue::new("gen_ai.prompt", self.prompt.clone())); - if let Some(freq_penalty) = self.frequency_penalty { - span.set_attribute(KeyValue::new( - GEN_AI_REQUEST_FREQUENCY_PENALTY, - freq_penalty as f64, - )); - } - if let Some(pres_penalty) = self.presence_penalty { - span.set_attribute(KeyValue::new( - GEN_AI_REQUEST_PRESENCE_PENALTY, - pres_penalty as f64, - )); - } - if let Some(top_p) = self.top_p { - span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TOP_P, top_p as f64)); - } - if let Some(temp) = self.temperature { - span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temp as f64)); - } + set_optional_f64( + span, + GEN_AI_REQUEST_FREQUENCY_PENALTY, + self.frequency_penalty, + ); + set_optional_f64(span, GEN_AI_REQUEST_PRESENCE_PENALTY, self.presence_penalty); + set_optional_f64(span, GEN_AI_REQUEST_TOP_P, self.top_p); + set_optional_f64(span, GEN_AI_REQUEST_TEMPERATURE, self.temperature); } } @@ -413,7 +439,8 @@ mod tests { // Test that set_vendor method compiles and can be called // This ensures the method signature is correct let mut tracer = OtelTracer { - span: 
opentelemetry::global::tracer("test").start("test"), + llm_span: Some(opentelemetry::global::tracer("test").start("test_llm")), + root_span: opentelemetry::global::tracer("test").start("test"), accumulated_completion: None, }; diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index e8351253..43ffd60c 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,9 +1,13 @@ use crate::config::models::PipelineType; +use crate::guardrails::middleware::GuardrailsLayer; +use crate::guardrails::setup::build_pipeline_guardrails; +use crate::guardrails::types::GuardrailResources; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; use crate::models::embeddings::EmbeddingsRequest; use crate::models::streaming::ChatCompletionChunk; -use crate::pipelines::otel::OtelTracer; +use crate::pipelines::otel::{OtelTracer, SharedTracer}; +use crate::pipelines::tracing_middleware::TracingLayer; use crate::providers::provider::get_vendor_name; use crate::{ ai_models::registry::ModelRegistry, @@ -12,10 +16,10 @@ use crate::{ }; use async_stream::stream; use axum::response::sse::{Event, KeepAlive}; -use axum::response::{IntoResponse, Sse}; +use axum::response::{IntoResponse, Response, Sse}; use axum::{ Json, Router, - extract::State, + extract::{Extension, State}, http::StatusCode, routing::{get, post}, }; @@ -24,7 +28,19 @@ use futures::{Stream, StreamExt}; use reqwest_streams::error::StreamBodyError; use std::sync::Arc; -pub fn create_pipeline(pipeline: &Pipeline, model_registry: &ModelRegistry) -> Router { +/// Access the tracer behind the shared mutex, executing a closure with the locked guard. +/// Panics if the mutex is poisoned (same behavior as the previous .lock().unwrap() calls). 
+fn with_tracer(tracer: &SharedTracer, f: impl FnOnce(&mut OtelTracer) -> R) -> R { + f(&mut tracer.lock().unwrap()) +} + +pub fn create_pipeline( + pipeline: &Pipeline, + model_registry: &ModelRegistry, + guardrail_resources: Option<&GuardrailResources>, +) -> Router { + let guardrails = + guardrail_resources.map(|shared| build_pipeline_guardrails(shared, &pipeline.guards)); let mut router = Router::new(); let available_models: Vec = pipeline @@ -59,26 +75,31 @@ pub fn create_pipeline(pipeline: &Pipeline, model_registry: &ModelRegistry) -> R PluginConfig::ModelRouter { models } => match pipeline.r#type { PipelineType::Chat => router.route( "/chat/completions", - post(move |state, payload| chat_completions(state, payload, models)), + post(move |tracer, state, payload| { + chat_completions(tracer, state, payload, models) + }), ), PipelineType::Completion => router.route( "/completions", - post(move |state, payload| completions(state, payload, models)), + post(move |tracer, state, payload| completions(tracer, state, payload, models)), ), PipelineType::Embeddings => router.route( "/embeddings", - post(move |state, payload| embeddings(state, payload, models)), + post(move |tracer, state, payload| embeddings(tracer, state, payload, models)), ), }, _ => router, }; } - router.with_state(Arc::new(model_registry.clone())) + router + .with_state(Arc::new(model_registry.clone())) + .layer(GuardrailsLayer::new(guardrails)) + .layer(TracingLayer::new()) } fn trace_and_stream( - mut tracer: OtelTracer, + tracer: SharedTracer, stream: BoxStream<'static, Result>, ) -> impl Stream> { stream! 
{ @@ -86,43 +107,48 @@ fn trace_and_stream( while let Some(result) = stream.next().await { yield match result { Ok(chunk) => { - tracer.log_chunk(&chunk); + with_tracer(&tracer, |t| t.log_chunk(&chunk)); Event::default().json_data(chunk) } Err(e) => { eprintln!("Error in stream: {e:?}"); - tracer.log_error(e.to_string()); + with_tracer(&tracer, |t| t.log_error(e.to_string())); Err(axum::Error::new(e)) } }; } - tracer.streaming_end(); + with_tracer(&tracer, |t| t.streaming_end()); } } pub async fn chat_completions( + Extension(tracer): Extension, State(model_registry): State>, Json(payload): Json, model_keys: Vec, -) -> Result { - let mut tracer = OtelTracer::start("chat", &payload); - +) -> Result { for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - // Set vendor now that we know which model/provider we're using - tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); + with_tracer(&tracer, |t| { + t.start_llm_span("chat", &payload); + t.set_vendor(&get_vendor_name(&model.provider.r#type())); + }); - let response = model - .chat_completions(payload.clone()) - .await - .inspect_err(|e| { + let response = match model.chat_completions(payload.clone()).await { + Ok(response) => response, + Err(e) => { eprintln!("Chat completion error for model {model_key}: {e:?}"); - })?; + with_tracer(&tracer, |t| { + t.log_error(format!("Chat completion failed: {e:?}")) + }); + return Err(e); + } + }; if let ChatCompletionResponse::NonStream(completion) = response { - tracer.log_success(&completion); + with_tracer(&tracer, |t| t.log_success(&completion)); return Ok(Json(completion).into_response()); } @@ -134,61 +160,84 @@ pub async fn chat_completions( } } - tracer.log_error("No matching model found".to_string()); + with_tracer(&tracer, |t| { + t.log_error("No matching model found".to_string()) + }); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } pub async fn 
completions( + Extension(tracer): Extension, State(model_registry): State>, Json(payload): Json, model_keys: Vec, -) -> impl IntoResponse { - let mut tracer = OtelTracer::start("completion", &payload); - +) -> Result { for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - // Set vendor now that we know which model/provider we're using - tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); + with_tracer(&tracer, |t| { + t.start_llm_span("completion", &payload); + t.set_vendor(&get_vendor_name(&model.provider.r#type())); + }); - let response = model.completions(payload.clone()).await.inspect_err(|e| { - eprintln!("Completion error for model {model_key}: {e:?}"); - })?; - tracer.log_success(&response); - return Ok(Json(response)); + let response = match model.completions(payload.clone()).await { + Ok(response) => response, + Err(e) => { + eprintln!("Completion error for model {model_key}: {e:?}"); + with_tracer(&tracer, |t| { + t.log_error(format!("Completion failed: {e:?}")) + }); + return Err(e); + } + }; + with_tracer(&tracer, |t| t.log_success(&response)); + + return Ok(Json(response).into_response()); } } - tracer.log_error("No matching model found".to_string()); + with_tracer(&tracer, |t| { + t.log_error("No matching model found".to_string()) + }); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } pub async fn embeddings( + Extension(tracer): Extension, State(model_registry): State>, Json(payload): Json, model_keys: Vec, ) -> impl IntoResponse { - let mut tracer = OtelTracer::start("embeddings", &payload); - for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - // Set vendor now that we know which model/provider we're using - tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); + with_tracer(&tracer, |t| { + t.start_llm_span("embeddings", &payload); + 
t.set_vendor(&get_vendor_name(&model.provider.r#type())); + }); - let response = model.embeddings(payload.clone()).await.inspect_err(|e| { - eprintln!("Embeddings error for model {model_key}: {e:?}"); - })?; - tracer.log_success(&response); + let response = match model.embeddings(payload.clone()).await { + Ok(response) => response, + Err(e) => { + eprintln!("Embeddings error for model {model_key}: {e:?}"); + with_tracer(&tracer, |t| { + t.log_error(format!("Embeddings failed: {e:?}")) + }); + return Err(e); + } + }; + with_tracer(&tracer, |t| t.log_success(&response)); return Ok(Json(response)); } } - tracer.log_error("No matching model found".to_string()); + with_tracer(&tracer, |t| { + t.log_error("No matching model found".to_string()) + }); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } @@ -293,6 +342,22 @@ mod tests { plugins: vec![PluginConfig::ModelRouter { models: model_keys.into_iter().map(|s| s.to_string()).collect(), }], + guards: vec![], + } + } + + // Helper function to create test pipeline with specific type (for otel tests) + fn create_test_pipeline_with_type( + model_keys: Vec<&str>, + pipeline_type: PipelineType, + ) -> Pipeline { + Pipeline { + name: "test".to_string(), + r#type: pipeline_type, + plugins: vec![PluginConfig::ModelRouter { + models: model_keys.into_iter().map(|s| s.to_string()).collect(), + }], + guards: vec![], } } @@ -322,7 +387,7 @@ mod tests { let model_configs = create_model_configs(vec!["test-model"]); let model_registry = ModelRegistry::new(&model_configs, provider_registry).unwrap(); let pipeline = create_test_pipeline(vec!["test-model"]); - let app = create_pipeline(&pipeline, &model_registry); + let app = create_pipeline(&pipeline, &model_registry, None); let response = get_models_response(app).await; @@ -371,7 +436,7 @@ mod tests { let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); let pipeline = create_test_pipeline(vec!["test-model-1", 
"test-model-2"]); - let app = create_pipeline(&pipeline, &model_registry); + let app = create_pipeline(&pipeline, &model_registry, None); let response = get_models_response(app).await; @@ -394,7 +459,7 @@ mod tests { let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); let pipeline = create_test_pipeline(vec![]); - let app = create_pipeline(&pipeline, &model_registry); + let app = create_pipeline(&pipeline, &model_registry, None); let response = get_models_response(app).await; @@ -412,7 +477,7 @@ mod tests { let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); let pipeline = create_test_pipeline(vec!["test-model-1", "test-model-3"]); // Only include 2 of 3 models - let app = create_pipeline(&pipeline, &model_registry); + let app = create_pipeline(&pipeline, &model_registry, None); let response = get_models_response(app).await; @@ -427,24 +492,20 @@ mod tests { assert!(!ids.contains(&"test-model-2")); } - // Test providers with different types for vendor testing - #[derive(Clone)] - struct TestProviderOpenAI; - #[derive(Clone)] - struct TestProviderAnthropic; + // Parameterized test provider for vendor testing #[derive(Clone)] - struct TestProviderAzure; + struct TestProvider(ProviderType); #[async_trait] - impl Provider for TestProviderOpenAI { + impl Provider for TestProvider { fn new(_config: &ProviderConfig) -> Self { - Self + Self(ProviderType::OpenAI) } fn key(&self) -> String { - "openai-key".to_string() + format!("{:?}-key", self.0).to_lowercase() } fn r#type(&self) -> ProviderType { - ProviderType::OpenAI + self.0.clone() } async fn chat_completions( @@ -457,7 +518,7 @@ mod tests { id: "test".to_string(), object: None, created: None, - model: "gpt-4".to_string(), + model: "test".to_string(), choices: vec![], usage: crate::models::usage::Usage::default(), system_fingerprint: None, @@ -482,126 +543,1099 @@ mod tests { } } - #[async_trait] - impl Provider for 
TestProviderAnthropic { - fn new(_config: &ProviderConfig) -> Self { - Self - } - fn key(&self) -> String { - "anthropic-key".to_string() + #[test] + fn test_vendor_mapping_integration() { + // Test that different provider types map to correct vendor names + // This tests the integration between provider types and vendor names + assert_eq!(get_vendor_name(&ProviderType::OpenAI), "openai"); + assert_eq!(get_vendor_name(&ProviderType::Anthropic), "Anthropic"); + assert_eq!(get_vendor_name(&ProviderType::Azure), "Azure"); + assert_eq!(get_vendor_name(&ProviderType::Bedrock), "AWS"); + assert_eq!(get_vendor_name(&ProviderType::VertexAI), "Google"); + } + + #[test] + fn test_provider_type_methods() { + // Test that our test provider returns the correct types + // and maps to the correct vendor names + let cases = [ + (ProviderType::OpenAI, "openai"), + (ProviderType::Anthropic, "Anthropic"), + (ProviderType::Azure, "Azure"), + ]; + for (provider_type, expected_vendor) in cases { + let provider = TestProvider(provider_type.clone()); + assert_eq!(provider.r#type(), provider_type); + assert_eq!(get_vendor_name(&provider.r#type()), expected_vendor); } - fn r#type(&self) -> ProviderType { - ProviderType::Anthropic + } + + // OpenTelemetry span verification tests + mod otel_span_tests { + use super::*; + use opentelemetry::trace::{SpanKind, Status as OtelStatus}; + use opentelemetry_sdk::export::trace::SpanData; + use opentelemetry_sdk::testing::trace::InMemorySpanExporter; + use opentelemetry_sdk::trace::TracerProvider; + use std::sync::{LazyLock, Mutex}; + + /// Shared OTel exporter, initialized once for all span tests + /// Tests are isolated by tracking span count before/after each request + static TEST_EXPORTER: LazyLock = LazyLock::new(|| { + let exporter = InMemorySpanExporter::default(); + let provider = TracerProvider::builder() + .with_simple_exporter(exporter.clone()) + .build(); + opentelemetry::global::set_tracer_provider(provider); + exporter + }); + + /// Mutex 
to serialize span tests and prevent race conditions + /// with the shared TEST_EXPORTER + static SPAN_TEST_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + + // Mock provider that returns realistic responses with Usage data + #[derive(Clone)] + struct MockProviderForSpanTests { + provider_type: ProviderType, } - async fn chat_completions( - &self, - _payload: crate::models::chat::ChatCompletionRequest, - _model_config: &ModelConfig, - ) -> Result { - Ok(crate::models::chat::ChatCompletionResponse::NonStream( - crate::models::chat::ChatCompletion { - id: "test".to_string(), - object: None, - created: None, - model: "claude-3".to_string(), - choices: vec![], - usage: crate::models::usage::Usage::default(), + #[async_trait] + impl Provider for MockProviderForSpanTests { + fn new(_config: &ProviderConfig) -> Self { + Self { + provider_type: ProviderType::OpenAI, + } + } + + fn key(&self) -> String { + "test-key".to_string() + } + + fn r#type(&self) -> ProviderType { + self.provider_type.clone() + } + + async fn chat_completions( + &self, + payload: crate::models::chat::ChatCompletionRequest, + _model_config: &ModelConfig, + ) -> Result { + use crate::models::chat::{ + ChatCompletion, ChatCompletionChoice, ChatCompletionResponse, + }; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + use crate::models::usage::Usage; + + Ok(ChatCompletionResponse::NonStream(ChatCompletion { + id: "chatcmpl-test123".to_string(), + object: Some("chat.completion".to_string()), + created: Some(1234567890), + model: payload.model.clone(), + choices: vec![ChatCompletionChoice { + index: 0, + message: ChatCompletionMessage { + role: "assistant".to_string(), + content: Some(ChatMessageContent::String("Test response".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }, + finish_reason: Some("stop".to_string()), + logprobs: None, + }], + usage: Usage { + prompt_tokens: 10, + completion_tokens: 15, + total_tokens: 25, + 
completion_tokens_details: None, + prompt_tokens_details: None, + }, system_fingerprint: None, - }, - )) + })) + } + + async fn completions( + &self, + payload: crate::models::completion::CompletionRequest, + _model_config: &ModelConfig, + ) -> Result { + use crate::models::completion::{CompletionChoice, CompletionResponse}; + use crate::models::usage::Usage; + + Ok(CompletionResponse { + id: "cmpl-test456".to_string(), + object: "text_completion".to_string(), + created: 1234567890, + model: payload.model.clone(), + choices: vec![CompletionChoice { + text: "Test completion".to_string(), + index: 0, + logprobs: None, + finish_reason: Some("stop".to_string()), + }], + usage: Usage { + prompt_tokens: 5, + completion_tokens: 10, + total_tokens: 15, + completion_tokens_details: None, + prompt_tokens_details: None, + }, + }) + } + + async fn embeddings( + &self, + payload: crate::models::embeddings::EmbeddingsRequest, + _model_config: &ModelConfig, + ) -> Result { + use crate::models::embeddings::{Embedding, Embeddings, EmbeddingsResponse}; + use crate::models::usage::EmbeddingUsage; + + Ok(EmbeddingsResponse { + object: "list".to_string(), + data: vec![Embeddings { + object: "embedding".to_string(), + embedding: Embedding::Float(vec![0.1, 0.2, 0.3]), + index: 0, + }], + model: payload.model.clone(), + usage: EmbeddingUsage { + prompt_tokens: Some(8), + total_tokens: Some(8), + }, + }) + } } - async fn completions( - &self, - _payload: CompletionRequest, - _model_config: &ModelConfig, - ) -> Result { - Err(StatusCode::NOT_IMPLEMENTED) + // Mock provider that returns errors + #[derive(Clone)] + struct MockProviderError { + provider_type: ProviderType, } - async fn embeddings( - &self, - _payload: EmbeddingsRequest, - _model_config: &ModelConfig, - ) -> Result { - Err(StatusCode::NOT_IMPLEMENTED) + #[async_trait] + impl Provider for MockProviderError { + fn new(_config: &ProviderConfig) -> Self { + Self { + provider_type: ProviderType::OpenAI, + } + } + + fn key(&self) -> 
String { + "test-key".to_string() + } + + fn r#type(&self) -> ProviderType { + self.provider_type.clone() + } + + async fn chat_completions( + &self, + _payload: crate::models::chat::ChatCompletionRequest, + _model_config: &ModelConfig, + ) -> Result { + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + + async fn completions( + &self, + _payload: crate::models::completion::CompletionRequest, + _model_config: &ModelConfig, + ) -> Result { + Err(StatusCode::BAD_GATEWAY) + } + + async fn embeddings( + &self, + _payload: crate::models::embeddings::EmbeddingsRequest, + _model_config: &ModelConfig, + ) -> Result { + Err(StatusCode::SERVICE_UNAVAILABLE) + } } - } - #[async_trait] - impl Provider for TestProviderAzure { - fn new(_config: &ProviderConfig) -> Self { - Self + // Helper: Create provider registry with specific provider type + fn create_test_provider_registry_for_spans( + provider_type: ProviderType, + return_errors: bool, + ) -> Arc { + let provider_config = ProviderConfig { + key: "test-provider".to_string(), + r#type: provider_type.clone(), + api_key: String::new(), + params: HashMap::new(), + }; + + let mut registry = ProviderRegistry::new(&[provider_config.clone()]).unwrap(); + + // Replace the provider with our mock + if return_errors { + let mock = MockProviderError { provider_type }; + let providers = registry.providers_mut(); + providers.insert("test-provider".to_string(), Arc::new(mock)); + } else { + let mock = MockProviderForSpanTests { provider_type }; + let providers = registry.providers_mut(); + providers.insert("test-provider".to_string(), Arc::new(mock)); + } + + Arc::new(registry) } - fn key(&self) -> String { - "azure-key".to_string() + + // Helper: Collect spans added since before_count + // Retries for up to 100ms to handle async span finalization + fn get_spans_for_test(before_count: usize) -> Vec { + use std::time::Duration; + + // Retry for up to 100ms to wait for spans to be flushed + for _ in 0..10 { + // Get all spans + let all_spans = 
TEST_EXPORTER.get_finished_spans().unwrap(); + + // Skip the before_count spans and take only the next few + // We expect exactly 2 spans per test (root + LLM) + let new_spans: Vec = all_spans.into_iter().skip(before_count).collect(); + + // If we have at least 2 spans, we can proceed + if new_spans.len() >= 2 { + // If we have more than 2 spans, try to find the most recent root span + // and its immediate child + if new_spans.len() > 2 { + // Find the last root span (traceloop_hub with Server kind) + if let Some(root_idx) = new_spans.iter().rposition(|s| { + s.name == "traceloop_hub" && s.span_kind == SpanKind::Server + }) { + let root = &new_spans[root_idx]; + let root_trace_id = root.span_context.trace_id(); + + // Collect all spans with the same trace_id + return new_spans + .into_iter() + .filter(|s| s.span_context.trace_id() == root_trace_id) + .collect(); + } + } + + return new_spans; + } + + // Wait a bit before retrying + std::thread::sleep(Duration::from_millis(10)); + } + + // Last attempt - return whatever we have + let all_spans = TEST_EXPORTER.get_finished_spans().unwrap(); + all_spans.into_iter().skip(before_count).collect() } - fn r#type(&self) -> ProviderType { - ProviderType::Azure + + // Helper: Find root span (name="traceloop_hub", SpanKind::Server) + fn get_root_span(spans: &[SpanData]) -> Option<&SpanData> { + spans + .iter() + .find(|s| s.name == "traceloop_hub" && s.span_kind == SpanKind::Server) } - async fn chat_completions( - &self, - _payload: crate::models::chat::ChatCompletionRequest, - _model_config: &ModelConfig, - ) -> Result { - Ok(crate::models::chat::ChatCompletionResponse::NonStream( - crate::models::chat::ChatCompletion { - id: "test".to_string(), - object: None, - created: None, - model: "gpt-4".to_string(), - choices: vec![], - usage: crate::models::usage::Usage::default(), - system_fingerprint: None, - }, - )) + // Helper: Find LLM span by operation + fn get_llm_span<'a>(spans: &'a [SpanData], operation: &str) -> Option<&'a 
SpanData> { + let expected_name = format!("traceloop_hub.{}", operation); + spans + .iter() + .find(|s| s.name == expected_name && s.span_kind == SpanKind::Client) } - async fn completions( - &self, - _payload: CompletionRequest, - _model_config: &ModelConfig, - ) -> Result { - Err(StatusCode::NOT_IMPLEMENTED) + // Helper: Extract attribute value from span + fn get_span_attribute(span: &SpanData, key: &str) -> Option { + span.attributes + .iter() + .find(|kv| kv.key.to_string() == key) + .map(|kv| kv.value.to_string()) } - async fn embeddings( - &self, - _payload: EmbeddingsRequest, - _model_config: &ModelConfig, - ) -> Result { - Err(StatusCode::NOT_IMPLEMENTED) + // Helper: Check if span is child of another span + fn is_child_of(child: &SpanData, parent: &SpanData) -> bool { + child.parent_span_id == parent.span_context.span_id() + && child.span_context.trace_id() == parent.span_context.trace_id() } - } - #[test] - fn test_vendor_mapping_integration() { - // Test that different provider types map to correct vendor names - // This tests the integration between provider types and vendor names - assert_eq!(get_vendor_name(&ProviderType::OpenAI), "openai"); - assert_eq!(get_vendor_name(&ProviderType::Anthropic), "Anthropic"); - assert_eq!(get_vendor_name(&ProviderType::Azure), "Azure"); - assert_eq!(get_vendor_name(&ProviderType::Bedrock), "AWS"); - assert_eq!(get_vendor_name(&ProviderType::VertexAI), "Google"); - } + // Helper: Assert span has expected attributes + fn assert_span_attributes(span: &SpanData, expected: &[(&str, &str)]) { + for (key, expected_value) in expected { + let actual = get_span_attribute(span, key); + assert_eq!( + actual.as_deref(), + Some(*expected_value), + "Attribute {} mismatch. 
Expected: {}, Got: {:?}", + key, + expected_value, + actual + ); + } + } - #[test] - fn test_provider_type_methods() { - // Test that our test providers return the correct types - // This ensures the pipeline would call set_vendor with the right values - let openai_provider = TestProviderOpenAI; - let anthropic_provider = TestProviderAnthropic; - let azure_provider = TestProviderAzure; - - assert_eq!(openai_provider.r#type(), ProviderType::OpenAI); - assert_eq!(anthropic_provider.r#type(), ProviderType::Anthropic); - assert_eq!(azure_provider.r#type(), ProviderType::Azure); - - // Test that these map to the correct vendor names - assert_eq!(get_vendor_name(&openai_provider.r#type()), "openai"); - assert_eq!(get_vendor_name(&anthropic_provider.r#type()), "Anthropic"); - assert_eq!(get_vendor_name(&azure_provider.r#type()), "Azure"); + #[tokio::test] + async fn test_chat_completions_success_spans() { + // Serialize span tests to avoid race conditions + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + // Initialize exporter + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + // Create test infrastructure + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + + // Create router (includes TracingLayer) + let app = create_pipeline(&pipeline, &model_registry, None); + + // Prepare request + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Hello".to_string())), + name: None, + tool_calls: None, + 
tool_call_id: None, + refusal: None, + }], + temperature: Some(0.7), + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + // Send request + let response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + // Verify response success + assert_eq!(response.status(), StatusCode::OK); + + // Collect new spans + let spans = get_spans_for_test(before_count); + assert_eq!( + spans.len(), + 2, + "Expected root + LLM span, got {}", + spans.len() + ); + + // Verify root span + let root = get_root_span(&spans).expect("Root span not found"); + assert_eq!(root.name, "traceloop_hub"); + assert_eq!(root.span_kind, SpanKind::Server); + assert_eq!(root.status, OtelStatus::Ok); + + // Verify LLM span + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + assert_eq!(llm.name, "traceloop_hub.chat"); + assert_eq!(llm.span_kind, SpanKind::Client); + assert_eq!(llm.status, OtelStatus::Ok); + + // Verify hierarchy + assert!( + is_child_of(llm, root), + "LLM span should be child of root span" + ); + + // Verify attributes + assert_span_attributes( + llm, + &[ + ("gen_ai.system", "openai"), + ("gen_ai.request.model", "test"), + ("llm.request.type", "chat"), + ], + ); + + // Verify usage attributes exist + assert!( + get_span_attribute(llm, "gen_ai.usage.prompt_tokens").is_some(), + "prompt_tokens attribute missing" + ); + assert!( + get_span_attribute(llm, "gen_ai.usage.completion_tokens").is_some(), + "completion_tokens attribute missing" + ); + assert!( + get_span_attribute(llm, 
"gen_ai.usage.total_tokens").is_some(), + "total_tokens attribute missing" + ); + } + + #[tokio::test] + async fn test_completions_success_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::Anthropic, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = + create_test_pipeline_with_type(vec!["test-model"], PipelineType::Completion); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::completion::CompletionRequest; + let request_body = CompletionRequest { + model: "test".to_string(), + prompt: "Test prompt".to_string(), + max_tokens: Some(50), + temperature: None, + top_p: None, + n: None, + stream: None, + logprobs: None, + echo: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + suffix: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let spans = get_spans_for_test(before_count); + assert_eq!(spans.len(), 2, "Expected root + LLM span"); + + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "completion").expect("LLM span not found"); + + assert_eq!(llm.name, "traceloop_hub.completion"); + assert_eq!(llm.span_kind, SpanKind::Client); + assert!(is_child_of(llm, root)); + + assert_span_attributes( + llm, + &[ + ("gen_ai.system", "Anthropic"), + ("gen_ai.request.model", "test"), + ("llm.request.type", "completion"), + ], + ); + } + + 
#[tokio::test] + async fn test_embeddings_success_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::Azure, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = + create_test_pipeline_with_type(vec!["test-model"], PipelineType::Embeddings); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest}; + let request_body = EmbeddingsRequest { + input: EmbeddingsInput::Single("Test text".to_string()), + model: "test".to_string(), + encoding_format: None, + user: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/embeddings") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let spans = get_spans_for_test(before_count); + assert_eq!(spans.len(), 2, "Expected root + LLM span"); + + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "embeddings").expect("LLM span not found"); + + assert_eq!(llm.name, "traceloop_hub.embeddings"); + assert_eq!(llm.span_kind, SpanKind::Client); + assert!(is_child_of(llm, root)); + + assert_span_attributes( + llm, + &[ + ("gen_ai.system", "Azure"), + ("gen_ai.request.model", "test"), + ("llm.request.type", "embeddings"), + ], + ); + } + + #[tokio::test] + async fn test_chat_completions_error_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + 
create_test_provider_registry_for_spans(ProviderType::OpenAI, true); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Hello".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + // Verify error response + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + // Collect spans + let spans = get_spans_for_test(before_count); + assert_eq!(spans.len(), 2, "Expected root + LLM span"); + + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + + // Verify error status on both spans + assert!( + matches!(root.status, OtelStatus::Error { .. }), + "Root span should have error status" + ); + assert!( + matches!(llm.status, OtelStatus::Error { .. 
}), + "LLM span should have error status" + ); + + // Hierarchy should still be correct + assert!(is_child_of(llm, root)); + } + + #[tokio::test] + async fn test_completions_error_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, true); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = + create_test_pipeline_with_type(vec!["test-model"], PipelineType::Completion); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::completion::CompletionRequest; + let request_body = CompletionRequest { + model: "test".to_string(), + prompt: "Test".to_string(), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: None, + logprobs: None, + echo: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + suffix: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_GATEWAY); + + let spans = get_spans_for_test(before_count); + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "completion").expect("LLM span not found"); + + assert!(matches!(root.status, OtelStatus::Error { .. })); + assert!(matches!(llm.status, OtelStatus::Error { .. 
})); + } + + #[tokio::test] + async fn test_embeddings_error_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, true); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = + create_test_pipeline_with_type(vec!["test-model"], PipelineType::Embeddings); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest}; + let request_body = EmbeddingsRequest { + input: EmbeddingsInput::Single("Test".to_string()), + model: "test".to_string(), + encoding_format: None, + user: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/embeddings") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + + let spans = get_spans_for_test(before_count); + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "embeddings").expect("LLM span not found"); + + assert!(matches!(root.status, OtelStatus::Error { .. })); + assert!(matches!(llm.status, OtelStatus::Error { .. 
})); + } + + #[tokio::test] + async fn test_span_request_attributes_recorded() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Hello".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: Some(0.8), + top_p: Some(0.9), + frequency_penalty: Some(0.5), + presence_penalty: Some(0.3), + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let _response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + let spans = get_spans_for_test(before_count); + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + + // Verify request parameters are recorded + assert_span_attributes(llm, &[("gen_ai.request.model", "test")]); + + // Verify float attributes exist (but don't check exact values due to float precision) + 
assert!( + get_span_attribute(llm, "gen_ai.request.temperature").is_some(), + "Temperature attribute should exist" + ); + assert!( + get_span_attribute(llm, "gen_ai.request.top_p").is_some(), + "top_p attribute should exist" + ); + assert!( + get_span_attribute(llm, "gen_ai.request.frequency_penalty").is_some(), + "frequency_penalty attribute should exist" + ); + assert!( + get_span_attribute(llm, "gen_ai.request.presence_penalty").is_some(), + "presence_penalty attribute should exist" + ); + } + + #[tokio::test] + async fn test_span_response_attributes_recorded() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Hello".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let _response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + 
.header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + let spans = get_spans_for_test(before_count); + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + + // Verify response attributes + assert_span_attributes( + llm, + &[ + ("gen_ai.response.model", "test"), + ("gen_ai.response.id", "chatcmpl-test123"), + ], + ); + + // Verify usage tokens + assert_eq!( + get_span_attribute(llm, "gen_ai.usage.prompt_tokens"), + Some("10".to_string()) + ); + assert_eq!( + get_span_attribute(llm, "gen_ai.usage.completion_tokens"), + Some("15".to_string()) + ); + assert_eq!( + get_span_attribute(llm, "gen_ai.usage.total_tokens"), + Some("25".to_string()) + ); + } + + #[tokio::test] + async fn test_vendor_attribute_mapping() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + + // Test each provider type maps to correct vendor name + // Note: Skip VertexAI as it requires project_id and location params + let test_cases = vec![ + (ProviderType::OpenAI, "openai"), + (ProviderType::Anthropic, "Anthropic"), + (ProviderType::Azure, "Azure"), + (ProviderType::Bedrock, "AWS"), + ]; + + for (provider_type, expected_vendor) in test_cases { + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(provider_type.clone(), false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + 
content: Some(ChatMessageContent::String("Test".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let _response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + let spans = get_spans_for_test(before_count); + let llm = get_llm_span(&spans, "chat") + .unwrap_or_else(|| panic!("LLM span not found for {:?}", provider_type)); + + assert_eq!( + get_span_attribute(llm, "gen_ai.system").as_deref(), + Some(expected_vendor), + "Vendor attribute mismatch for {:?}", + provider_type + ); + } + } + + #[tokio::test] + async fn test_span_parent_child_relationship() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: 
Some(ChatMessageContent::String("Test".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let _response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + let spans = get_spans_for_test(before_count); + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + + // Verify parent-child relationship + assert_eq!( + llm.parent_span_id, + root.span_context.span_id(), + "LLM span's parent_span_id should equal root span's span_id" + ); + + // Verify trace_id consistency + assert_eq!( + llm.span_context.trace_id(), + root.span_context.trace_id(), + "Both spans should share the same trace_id" + ); + + // Verify root span has no parent (is actually root) + assert_eq!( + root.parent_span_id.to_string(), + "0000000000000000", + "Root span should not have a parent" + ); + } } } diff --git a/src/pipelines/tracing_middleware.rs b/src/pipelines/tracing_middleware.rs new file mode 100644 index 00000000..1153506e --- /dev/null +++ b/src/pipelines/tracing_middleware.rs @@ -0,0 +1,66 @@ +use axum::body::Body; +use axum::http::Request; +use axum::response::Response; +use std::sync::{Arc, Mutex}; +use tower::{Layer, Service}; + +use super::otel::OtelTracer; + +#[derive(Clone)] +pub struct TracingLayer; + +impl TracingLayer { + pub fn new() -> Self { + Self + } +} + +impl Layer for TracingLayer { + type Service = 
TracingMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + TracingMiddleware { inner } + } +} + +#[derive(Clone)] +pub struct TracingMiddleware { + inner: S, +} + +impl Service> for TracingMiddleware +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let mut inner = self.inner.clone(); + + Box::pin(async move { + // Create root traceloop_hub span and wrap in Arc> + let tracer = Arc::new(Mutex::new(OtelTracer::start())); + + // Insert into request extensions for downstream middleware and handlers + req.extensions_mut().insert(tracer.clone()); + + // Call inner service + let response = inner.call(req).await?; + + // Tracer will be finalized when the Arc is dropped + Ok(response) + }) + } +} diff --git a/src/providers/registry.rs b/src/providers/registry.rs index 4e25e91e..3a012de9 100644 --- a/src/providers/registry.rs +++ b/src/providers/registry.rs @@ -34,4 +34,9 @@ impl ProviderRegistry { pub fn get(&self, name: &str) -> Option> { self.providers.get(name).cloned() } + + #[cfg(test)] + pub fn providers_mut(&mut self) -> &mut HashMap> { + &mut self.providers + } } diff --git a/src/state.rs b/src/state.rs index bb5ab2fd..2850e6c8 100644 --- a/src/state.rs +++ b/src/state.rs @@ -161,10 +161,17 @@ impl AppState { _provider_registry: &Arc, model_registry: &Arc, ) -> axum::Router { + use crate::guardrails::setup::build_guardrail_resources; use crate::pipelines::pipeline::create_pipeline; debug!("Building router with {} pipelines", config.pipelines.len()); + // Build shared guardrail resources once for all pipelines + let guardrail_resources = config + .guardrails + .as_ref() + .and_then(build_guardrail_resources); + let 
(default_pipeline, other_pipelines): (Vec<_>, Vec<_>) = config .pipelines .iter() @@ -179,7 +186,11 @@ impl AppState { "Adding default pipeline '{}' to router at index 0", default_pipeline.name ); - let pipeline_router = create_pipeline(default_pipeline, model_registry); + let pipeline_router = create_pipeline( + default_pipeline, + model_registry, + guardrail_resources.as_ref(), + ); pipeline_routers.push(pipeline_router); pipeline_names.push(default_pipeline.name.clone()); } @@ -188,7 +199,8 @@ impl AppState { let name = &pipeline.name; debug!("Adding pipeline '{}' to router at index {}", name, idx + 1); - let pipeline_router = create_pipeline(pipeline, model_registry); + let pipeline_router = + create_pipeline(pipeline, model_registry, guardrail_resources.as_ref()); pipeline_routers.push(pipeline_router); pipeline_names.push(name.clone()); } @@ -314,7 +326,7 @@ impl tower::Service> for PipelineSteeringService { }; Box::pin(async move { - let router = Arc::try_unwrap(router).unwrap_or_else(|arc_router| (*arc_router).clone()); + let router: Router = (*router).clone(); match router.oneshot(request).await { Ok(response) => Ok(response), diff --git a/src/types/mod.rs b/src/types/mod.rs index 113152dc..fbbb21e8 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -149,8 +149,8 @@ pub struct Pipeline { // #[serde(with = "serde_yaml::with::singleton_map_recursive")] #[serde(default, skip_serializing_if = "Vec::is_empty")] pub plugins: Vec, - // ee_id: Option, // Removed - // enabled: bool, // Removed + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub guards: Vec, } #[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Hash)] @@ -169,4 +169,6 @@ pub struct GatewayConfig { pub models: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub pipelines: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub guardrails: Option, } diff --git a/tests/cassettes/guardrails/json_validator_pass.json 
b/tests/cassettes/guardrails/json_validator_pass.json new file mode 100644 index 00000000..0a049f7c --- /dev/null +++ b/tests/cassettes/guardrails/json_validator_pass.json @@ -0,0 +1,25 @@ +{ + "evaluator_slug": "json-validator", + "input_text": "{\"name\": \"Alice\", \"age\": 30}", + "params": { + "enable_schema_validation": true, + "schema_string": "{\"name\": \"string\", \"age\": \"number\"}" + }, + "request_body": { + "input": { + "text": "{\"name\": \"Alice\", \"age\": 30}" + }, + "config": { + "enable_schema_validation": true, + "schema_string": "{\"name\": \"string\", \"age\": \"number\"}" + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_valid_json": true + }, + "pass": true + }, + "expected_pass": true +} diff --git a/tests/cassettes/guardrails/pii_detector_pass.json b/tests/cassettes/guardrails/pii_detector_pass.json new file mode 100644 index 00000000..2bcb62b9 --- /dev/null +++ b/tests/cassettes/guardrails/pii_detector_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "pii-detector", + "input_text": "The weather is sunny today and I like programming in Rust.", + "params": {}, + "request_body": { + "input": { + "text": "The weather is sunny today and I like programming in Rust." + } + }, + "response_status": 200, + "response_body": { + "result": { + "has_pii": false + }, + "pass": true + }, + "expected_pass": true +} diff --git a/tests/cassettes/guardrails/profanity_detector_fail.json b/tests/cassettes/guardrails/profanity_detector_fail.json new file mode 100644 index 00000000..f2fff61e --- /dev/null +++ b/tests/cassettes/guardrails/profanity_detector_fail.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "profanity-detector", + "input_text": "This is damn bullshit and I think it's a total crap product.", + "params": {}, + "request_body": { + "input": { + "text": "This is damn bullshit and I think it's a total crap product." 
+ } + }, + "response_status": 200, + "response_body": { + "result": { + "is_safe": false + }, + "pass": false + }, + "expected_pass": false +} diff --git a/tests/cassettes/guardrails/prompt_injection_pass.json b/tests/cassettes/guardrails/prompt_injection_pass.json new file mode 100644 index 00000000..67b6c61e --- /dev/null +++ b/tests/cassettes/guardrails/prompt_injection_pass.json @@ -0,0 +1,23 @@ +{ + "evaluator_slug": "prompt-injection", + "input_text": "What is the capital of France?", + "params": { + "threshold": 0.8 + }, + "request_body": { + "input": { + "prompt": "What is the capital of France?" + }, + "config": { + "threshold": 0.8 + } + }, + "response_status": 200, + "response_body": { + "result": { + "has_injection": false + }, + "pass": true + }, + "expected_pass": true +} diff --git a/tests/cassettes/guardrails/prompt_perplexity_pass.json b/tests/cassettes/guardrails/prompt_perplexity_pass.json new file mode 100644 index 00000000..5db73ae6 --- /dev/null +++ b/tests/cassettes/guardrails/prompt_perplexity_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "prompt-perplexity", + "input_text": "Please explain the concept of photosynthesis in simple terms.", + "params": {}, + "request_body": { + "input": { + "prompt": "Please explain the concept of photosynthesis in simple terms." 
+ } + }, + "response_status": 200, + "response_body": { + "result": { + "perplexity_score": 80.0 + }, + "pass": true + }, + "expected_pass": true +} diff --git a/tests/cassettes/guardrails/regex_validator_pass.json b/tests/cassettes/guardrails/regex_validator_pass.json new file mode 100644 index 00000000..a38772aa --- /dev/null +++ b/tests/cassettes/guardrails/regex_validator_pass.json @@ -0,0 +1,31 @@ +{ + "evaluator_slug": "regex-validator", + "input_text": "Order ID: ABC-12345", + "params": { + "regex": "^[A-Z]{3}-\\d{5}$", + "should_match": true, + "case_sensitive": true, + "dot_include_nl": true, + "multi_line": true + }, + "request_body": { + "input": { + "text": "Order ID: ABC-12345" + }, + "config": { + "regex": "^[A-Z]{3}-\\d{5}$", + "should_match": true, + "case_sensitive": true, + "dot_include_nl": true, + "multi_line": true + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_valid_regex": true + }, + "pass": true + }, + "expected_pass": true +} diff --git a/tests/cassettes/guardrails/secrets_detector_pass.json b/tests/cassettes/guardrails/secrets_detector_pass.json new file mode 100644 index 00000000..6f5615ee --- /dev/null +++ b/tests/cassettes/guardrails/secrets_detector_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "secrets-detector", + "input_text": "Here is a simple function that adds two numbers together.", + "params": {}, + "request_body": { + "input": { + "text": "Here is a simple function that adds two numbers together." 
+ } + }, + "response_status": 200, + "response_body": { + "result": { + "has_secret": false + }, + "pass": true + }, + "expected_pass": true +} diff --git a/tests/cassettes/guardrails/sexism_detector_fail.json b/tests/cassettes/guardrails/sexism_detector_fail.json new file mode 100644 index 00000000..9473c761 --- /dev/null +++ b/tests/cassettes/guardrails/sexism_detector_fail.json @@ -0,0 +1,23 @@ +{ + "evaluator_slug": "sexism-detector", + "input_text": "Women should not be in leadership positions because they are too emotional.", + "params": { + "threshold": 0.7 + }, + "request_body": { + "input": { + "text": "Women should not be in leadership positions because they are too emotional." + }, + "config": { + "threshold": 0.7 + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_safe": false + }, + "pass": false + }, + "expected_pass": false +} diff --git a/tests/cassettes/guardrails/sql_validator_pass.json b/tests/cassettes/guardrails/sql_validator_pass.json new file mode 100644 index 00000000..fdfedfc5 --- /dev/null +++ b/tests/cassettes/guardrails/sql_validator_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "sql-validator", + "input_text": "SELECT id, name FROM users WHERE active = true ORDER BY name", + "params": {}, + "request_body": { + "input": { + "text": "SELECT id, name FROM users WHERE active = true ORDER BY name" + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_valid_sql": true + }, + "pass": true + }, + "expected_pass": true +} diff --git a/tests/cassettes/guardrails/tone_detection_fail.json b/tests/cassettes/guardrails/tone_detection_fail.json new file mode 100644 index 00000000..b37d5475 --- /dev/null +++ b/tests/cassettes/guardrails/tone_detection_fail.json @@ -0,0 +1,19 @@ +{ + "evaluator_slug": "tone-detection", + "input_text": "This is ABSOLUTELY UNACCEPTABLE. 
I DEMAND to speak to someone competent immediately!", + "params": {}, + "request_body": { + "input": { + "text": "This is ABSOLUTELY UNACCEPTABLE. I DEMAND to speak to someone competent immediately!" + } + }, + "response_status": 200, + "response_body": { + "result": { + "tone": "", + "score": 0 + }, + "pass": false + }, + "expected_pass": false +} diff --git a/tests/cassettes/guardrails/toxicity_detector_fail.json b/tests/cassettes/guardrails/toxicity_detector_fail.json new file mode 100644 index 00000000..2b25d772 --- /dev/null +++ b/tests/cassettes/guardrails/toxicity_detector_fail.json @@ -0,0 +1,17 @@ +{ + "evaluator_slug": "toxicity-detector", + "input_text": "You are a complete idiot and everyone hates you. You should be ashamed.", + "request_body": { + "input": { + "text": "You are a complete idiot and everyone hates you. You should be ashamed." + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_safe": false + }, + "pass": false + }, + "expected_pass": false +} diff --git a/tests/cassettes/guardrails/uncertainty_detector_pass.json b/tests/cassettes/guardrails/uncertainty_detector_pass.json new file mode 100644 index 00000000..8b713854 --- /dev/null +++ b/tests/cassettes/guardrails/uncertainty_detector_pass.json @@ -0,0 +1,19 @@ +{ + "evaluator_slug": "uncertainty-detector", + "input_text": "What is 2 + 2?", + "params": {}, + "request_body": { + "input": { + "prompt": "What is 2 + 2?" 
+ } + }, + "response_status": 200, + "response_body": { + "result": { + "answer": "", + "uncertainty": 0 + }, + "pass": true + }, + "expected_pass": true +} diff --git a/tests/config_hash_tests.rs b/tests/config_hash_tests.rs index 53ff4597..d2994d2e 100644 --- a/tests/config_hash_tests.rs +++ b/tests/config_hash_tests.rs @@ -6,6 +6,7 @@ use std::collections::HashMap; fn test_identical_configs_have_same_hash() { let config1 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test".to_string(), r#type: ProviderType::OpenAI, @@ -29,6 +30,7 @@ fn test_identical_configs_have_same_hash() { fn test_different_configs_have_different_hashes() { let config1 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test1".to_string(), r#type: ProviderType::OpenAI, @@ -41,6 +43,7 @@ fn test_different_configs_have_different_hashes() { let config2 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test2".to_string(), // Different key r#type: ProviderType::OpenAI, @@ -70,6 +73,7 @@ fn test_params_order_independence() { let config1 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test".to_string(), r#type: ProviderType::OpenAI, @@ -82,6 +86,7 @@ fn test_params_order_independence() { let config2 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test".to_string(), r#type: ProviderType::OpenAI, diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs new file mode 100644 index 00000000..83f84109 --- /dev/null +++ b/tests/guardrails/helpers.rs @@ -0,0 +1,376 @@ +use std::collections::HashMap; +use std::convert::Infallible; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use axum::body::Body; +use axum::extract::Request; +use axum::http::StatusCode; +use axum::response::Response; +use hub_lib::guardrails::types::GuardrailClient; +use 
hub_lib::guardrails::types::{EvaluatorResponse, Guard, GuardMode, GuardrailError, OnFailure}; +use hub_lib::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; +use hub_lib::models::completion::{CompletionChoice, CompletionRequest, CompletionResponse}; +use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent}; +use hub_lib::models::embeddings::{ + Embedding, Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse, +}; +use hub_lib::models::usage::{EmbeddingUsage, Usage}; +use serde::Serialize; +use serde_json::json; +use tower::Service; + +// --------------------------------------------------------------------------- +// Guard config builders +// --------------------------------------------------------------------------- + +pub struct TestGuardBuilder { + guard: Guard, +} + +impl TestGuardBuilder { + pub fn new(name: &str, mode: GuardMode) -> Self { + Self { + guard: Guard { + name: name.to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "pii-detector".to_string(), + params: HashMap::new(), + mode, + on_failure: OnFailure::Block, + required: false, + api_base: Some("http://localhost:8080".to_string()), + api_key: Some("test-api-key".to_string()), + }, + } + } + + pub fn on_failure(mut self, on_failure: OnFailure) -> Self { + self.guard.on_failure = on_failure; + self + } + + pub fn required(mut self, required: bool) -> Self { + self.guard.required = required; + self + } + + pub fn api_base(mut self, api_base: &str) -> Self { + self.guard.api_base = Some(api_base.to_string()); + self + } + + pub fn evaluator_slug(mut self, slug: &str) -> Self { + self.guard.evaluator_slug = slug.to_string(); + self + } + + pub fn build(self) -> Guard { + self.guard + } +} + +// Backward-compatible helper functions +pub fn create_test_guard(name: &str, mode: GuardMode) -> Guard { + TestGuardBuilder::new(name, mode).build() +} + +pub fn create_test_guard_with_failure_action( + name: &str, + mode: GuardMode, + 
on_failure: OnFailure, +) -> Guard { + TestGuardBuilder::new(name, mode) + .on_failure(on_failure) + .build() +} + +pub fn create_test_guard_with_required(name: &str, mode: GuardMode, required: bool) -> Guard { + TestGuardBuilder::new(name, mode).required(required).build() +} + +pub fn create_test_guard_with_api_base(name: &str, mode: GuardMode, api_base: &str) -> Guard { + TestGuardBuilder::new(name, mode).api_base(api_base).build() +} + +// --------------------------------------------------------------------------- +// Evaluator response builders +// --------------------------------------------------------------------------- + +pub fn passing_response() -> EvaluatorResponse { + EvaluatorResponse { + result: json!({"score": 0.95, "label": "safe"}), + pass: true, + } +} + +pub fn failing_response() -> EvaluatorResponse { + EvaluatorResponse { + result: json!({"score": 0.2, "label": "unsafe", "reason": "Content violates policy"}), + pass: false, + } +} + +// --------------------------------------------------------------------------- +// Chat request/response builders +// --------------------------------------------------------------------------- + +pub fn default_message() -> ChatCompletionMessage { + ChatCompletionMessage { + role: String::new(), + content: None, + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + } +} + +pub fn default_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + parallel_tool_calls: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + tool_choice: None, + tools: None, + user: None, + logprobs: None, + top_logprobs: None, + response_format: None, + reasoning: None, + } +} + +pub fn create_test_chat_request(user_message: &str) -> ChatCompletionRequest { + ChatCompletionRequest { + model: 
"gpt-4".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String(user_message.to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + parallel_tool_calls: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + tool_choice: None, + tools: None, + user: None, + logprobs: None, + top_logprobs: None, + response_format: None, + reasoning: None, + } +} + +pub fn create_test_chat_completion(response_text: &str) -> ChatCompletion { + ChatCompletion { + id: "chatcmpl-test".to_string(), + object: Some("chat.completion".to_string()), + created: Some(1234567890), + model: "gpt-4".to_string(), + choices: vec![ChatCompletionChoice { + index: 0, + message: ChatCompletionMessage { + role: "assistant".to_string(), + content: Some(ChatMessageContent::String(response_text.to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }, + finish_reason: Some("stop".to_string()), + logprobs: None, + }], + usage: Usage::default(), + system_fingerprint: None, + } +} + +// --------------------------------------------------------------------------- +// Completion request/response builders +// --------------------------------------------------------------------------- + +pub fn create_test_completion_request(prompt: &str) -> CompletionRequest { + CompletionRequest { + model: "gpt-3.5-turbo-instruct".to_string(), + prompt: prompt.to_string(), + suffix: None, + max_tokens: Some(100), + temperature: None, + top_p: None, + n: None, + stream: None, + logprobs: None, + echo: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + } +} + +pub fn create_test_completion_response(text: &str) -> CompletionResponse { + CompletionResponse { + id: 
"cmpl-test".to_string(), + object: "text_completion".to_string(), + created: 1234567890, + model: "gpt-3.5-turbo-instruct".to_string(), + choices: vec![CompletionChoice { + text: text.to_string(), + index: 0, + logprobs: None, + finish_reason: Some("stop".to_string()), + }], + usage: Usage::default(), + } +} + +// --------------------------------------------------------------------------- +// Embeddings request/response builders +// --------------------------------------------------------------------------- + +pub fn create_test_embeddings_request(text: &str) -> EmbeddingsRequest { + EmbeddingsRequest { + model: "text-embedding-ada-002".to_string(), + input: EmbeddingsInput::Single(text.to_string()), + user: None, + encoding_format: None, + } +} + +pub fn create_test_embeddings_response() -> EmbeddingsResponse { + EmbeddingsResponse { + object: "list".to_string(), + data: vec![Embeddings { + object: "embedding".to_string(), + embedding: Embedding::Float(vec![0.1, 0.2, 0.3]), + index: 0, + }], + model: "text-embedding-ada-002".to_string(), + usage: EmbeddingUsage { + prompt_tokens: Some(8), + total_tokens: Some(8), + }, + } +} + +// --------------------------------------------------------------------------- +// Streaming request builders +// --------------------------------------------------------------------------- + +pub fn create_streaming_chat_request(message: &str) -> ChatCompletionRequest { + let mut req = create_test_chat_request(message); + req.stream = Some(true); + req +} + +pub fn create_streaming_completion_request(prompt: &str) -> CompletionRequest { + let mut req = create_test_completion_request(prompt); + req.stream = Some(true); + req +} + +// --------------------------------------------------------------------------- +// Mock Service for middleware testing +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct MockService { + status: StatusCode, + body: Vec, +} + +impl MockService { + pub fn 
with_json(status: StatusCode, data: &T) -> Self { + let body = serde_json::to_vec(data).unwrap(); + Self { status, body } + } +} + +impl Service> for MockService { + type Response = Response; + type Error = Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + let status = self.status; + let body = self.body.clone(); + Box::pin(async move { + let response = Response::builder() + .status(status) + .header("content-type", "application/json") + .body(Body::from(body)) + .unwrap(); + Ok(response) + }) + } +} + +// --------------------------------------------------------------------------- +// Mock GuardrailClient +// --------------------------------------------------------------------------- + +pub struct MockGuardrailClient { + pub responses: HashMap>, +} + +impl MockGuardrailClient { + pub fn with_response(name: &str, resp: Result) -> Self { + let mut responses = HashMap::new(); + responses.insert(name.to_string(), resp); + Self { responses } + } + + pub fn with_responses(entries: Vec<(&str, Result)>) -> Self { + let mut responses = HashMap::new(); + for (name, resp) in entries { + responses.insert(name.to_string(), resp); + } + Self { responses } + } +} + +#[async_trait] +impl GuardrailClient for MockGuardrailClient { + async fn evaluate( + &self, + guard: &Guard, + _input: &str, + ) -> Result { + self.responses + .get(&guard.name) + .cloned() + .unwrap_or(Err(GuardrailError::Unavailable( + "no mock configured".to_string(), + ))) + } +} diff --git a/tests/guardrails/main.rs b/tests/guardrails/main.rs new file mode 100644 index 00000000..dd32663e --- /dev/null +++ b/tests/guardrails/main.rs @@ -0,0 +1,9 @@ +mod helpers; +mod test_e2e; +mod test_middleware; +mod test_parsing; +mod test_run_evaluator; +mod test_runner; +mod test_setup; +mod test_traceloop_client; +mod test_types; diff --git a/tests/guardrails/test_e2e.rs 
b/tests/guardrails/test_e2e.rs new file mode 100644 index 00000000..0d4c3aa3 --- /dev/null +++ b/tests/guardrails/test_e2e.rs @@ -0,0 +1,633 @@ +use hub_lib::guardrails::parsing::{CompletionExtractor, PromptExtractor}; +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::runner::{ + GuardrailsRunner, blocked_response, execute_guards, warning_header_value, +}; +use hub_lib::guardrails::setup::{build_guardrail_resources, build_pipeline_guardrails}; +use hub_lib::guardrails::types::*; + +use axum::body::to_bytes; +use axum::http::HeaderMap; + +use serde_json::json; +use std::sync::Arc; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + +/// Helper: set up a wiremock evaluator that returns pass/fail +async fn setup_evaluator(pass: bool) -> MockServer { + let server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"score": if pass { 0.95 } else { 0.1 }}, + "pass": pass + }))) + .mount(&server) + .await; + server +} + +#[tokio::test] +async fn test_e2e_pre_call_block_flow() { + // Request -> guard fail+block -> 403 + let eval = setup_evaluator(false).await; + let guard = TestGuardBuilder::new("blocker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("toxicity-detector") + .build(); + + let request = create_test_chat_request("Bad input"); + let input = request.extract_prompt(); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &input, &client, None).await; + + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard.as_deref(), Some("blocker")); +} + +#[tokio::test] +async fn test_e2e_pre_call_pass_flow() { + // Request -> guard pass -> LLM -> 200 + let eval = setup_evaluator(true).await; + let guard = TestGuardBuilder::new("checker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + 
.api_base(&eval.uri()) + .evaluator_slug("toxicity-detector") + .build(); + + let request = create_test_chat_request("Safe input"); + let input = request.extract_prompt(); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &input, &client, None).await; + + assert!(!outcome.blocked); + assert!(outcome.warnings.is_empty()); + // In real flow, would proceed to LLM call +} + +#[tokio::test] +async fn test_e2e_post_call_block_flow() { + // Request -> LLM -> guard fail+block -> 403 + let eval = setup_evaluator(false).await; + let guard = TestGuardBuilder::new("pii-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("pii-detector") + .build(); + + // Simulate LLM response + let completion = create_test_chat_completion("Here is the SSN: 123-45-6789"); + let response_text = completion.extract_completion(); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &response_text, &client, None).await; + + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard.as_deref(), Some("pii-check")); +} + +#[tokio::test] +async fn test_e2e_post_call_warn_flow() { + // Request -> LLM -> guard fail+warn -> 200 + header + let eval = setup_evaluator(false).await; + let guard = TestGuardBuilder::new("tone-check", GuardMode::PostCall) + .on_failure(OnFailure::Warn) + .api_base(&eval.uri()) + .evaluator_slug("tone-detection") + .build(); + + let completion = create_test_chat_completion("Mildly concerning response"); + let response_text = completion.extract_completion(); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &response_text, &client, None).await; + + assert!(!outcome.blocked); + assert_eq!(outcome.warnings.len(), 1); + assert_eq!(outcome.warnings[0].guard_name, "tone-check"); +} + +#[tokio::test] +async fn test_e2e_pre_and_post_both_pass() { + // Both stages pass -> clean 200 response + let pre_eval = setup_evaluator(true).await; + let 
post_eval = setup_evaluator(true).await; + + let pre_guard = TestGuardBuilder::new("pre-check", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&pre_eval.uri()) + .evaluator_slug("profanity-detector") + .build(); + let post_guard = TestGuardBuilder::new("post-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&post_eval.uri()) + .evaluator_slug("pii-detector") + .build(); + + let client = TraceloopClient::new(); + + // Pre-call + let request = create_test_chat_request("Hello"); + let input = request.extract_prompt(); + let pre_outcome = execute_guards(&[pre_guard], &input, &client, None).await; + assert!(!pre_outcome.blocked); + + // Post-call + let completion = create_test_chat_completion("Hi there!"); + let response_text = completion.extract_completion(); + let post_outcome = execute_guards(&[post_guard], &response_text, &client, None).await; + assert!(!post_outcome.blocked); + assert!(post_outcome.warnings.is_empty()); +} + +#[tokio::test] +async fn test_e2e_pre_blocks_post_never_runs() { + // Pre blocks -> post evaluator gets 0 requests + let pre_eval = setup_evaluator(false).await; + let post_eval = MockServer::start().await; + + // Post evaluator should receive 0 requests + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(0) + .mount(&post_eval) + .await; + + let pre_guard = TestGuardBuilder::new("blocker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&pre_eval.uri()) + .evaluator_slug("toxicity-detector") + .build(); + let post_guard = TestGuardBuilder::new("post-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&post_eval.uri()) + .evaluator_slug("pii-detector") + .build(); + + let client = TraceloopClient::new(); + let request = create_test_chat_request("Bad input"); + let input = request.extract_prompt(); + + let pre_outcome = execute_guards(&[pre_guard], &input, &client, None).await; + 
assert!(pre_outcome.blocked); + + // Since pre blocked, post guards never run - post_eval.verify() will assert 0 calls + // (wiremock verifies expect(0) when server drops) + let _ = post_guard; // not used - that's the point +} + +#[tokio::test] +async fn test_e2e_mixed_block_and_warn() { + // Multiple guards with mixed block/warn outcomes + let eval1 = setup_evaluator(true).await; // passes + let eval2 = setup_evaluator(false).await; // fails -> warn + let eval3 = setup_evaluator(false).await; // fails -> block + + let guards = vec![ + TestGuardBuilder::new("passer", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&eval1.uri()) + .evaluator_slug("profanity-detector") + .build(), + TestGuardBuilder::new("warner", GuardMode::PreCall) + .on_failure(OnFailure::Warn) + .api_base(&eval2.uri()) + .evaluator_slug("tone-detection") + .build(), + TestGuardBuilder::new("blocker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&eval3.uri()) + .evaluator_slug("toxicity-detector") + .build(), + ]; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&guards, "test input", &client, None).await; + + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard.as_deref(), Some("blocker")); + assert!(outcome.warnings.iter().any(|w| w.guard_name == "warner")); +} + +#[tokio::test] +async fn test_e2e_streaming_post_call_buffer_pass() { + // Stream buffered, guard passes -> SSE response streamed to client + let eval = setup_evaluator(true).await; + let guard = TestGuardBuilder::new("response-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("profanity-detector") + .build(); + + let accumulated = "Hello world!"; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &accumulated, &client, None).await; + + assert!(!outcome.blocked); +} + +#[tokio::test] +async fn test_e2e_streaming_post_call_buffer_block() { + // Stream buffered, guard blocks -> 403 + let 
eval = setup_evaluator(false).await; + let guard = TestGuardBuilder::new("pii-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("pii-detector") + .build(); + + let accumulated = "Here is SSN: 123-45-6789"; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &accumulated, &client, None).await; + + assert!(outcome.blocked); +} + +#[tokio::test] +async fn test_e2e_config_from_yaml_with_env_vars() { + // Full YAML config with ${VAR} substitution in api_key + use std::io::Write; + use tempfile::NamedTempFile; + + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +guardrails: + providers: + - name: traceloop + api_base: "https://api.traceloop.com" + api_key: "${E2E_TEST_API_KEY}" + guards: + - name: toxicity-check + provider: traceloop + evaluator_slug: toxicity-detector + mode: pre_call + on_failure: block + - name: pii-check + provider: traceloop + evaluator_slug: pii-detector + mode: post_call + on_failure: warn + api_key: "override-key" +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + let temp_path = temp_file.path().to_str().unwrap().to_owned(); + + let gr = temp_env::with_var("E2E_TEST_API_KEY", Some("resolved-key-123"), || { + let config = hub_lib::config::load_config(&temp_path).unwrap(); + config.guardrails.unwrap() + }); + + assert_eq!(gr.providers.len(), 1); + assert_eq!(gr.providers["traceloop"].api_key, "resolved-key-123"); + + // Guards should have evaluator_slug at top level + assert_eq!(gr.guards[0].evaluator_slug, "toxicity-detector"); + assert_eq!(gr.guards[0].mode, GuardMode::PreCall); + assert!(gr.guards[0].api_base.is_none()); // inherits from provider + assert!(gr.guards[0].api_key.is_none()); // inherits from 
provider + + // Second guard overrides api_key + assert_eq!(gr.guards[1].api_key.as_deref(), Some("override-key")); + + // Build pipeline guardrails - should resolve provider defaults + let shared = build_guardrail_resources(&gr).unwrap(); + let guard_names: Vec = gr.guards.iter().map(|g| g.name.clone()).collect(); + let pipeline_gr = build_pipeline_guardrails(&shared, &guard_names); + assert_eq!(pipeline_gr.all_guards.len(), 2); + assert_eq!(pipeline_gr.pipeline_guard_names.len(), 2); + // Provider api_base should be resolved for guards that don't override + let pre_guard = pipeline_gr + .all_guards + .iter() + .find(|g| g.mode == GuardMode::PreCall) + .unwrap(); + let post_guard = pipeline_gr + .all_guards + .iter() + .find(|g| g.mode == GuardMode::PostCall) + .unwrap(); + assert_eq!( + pre_guard.api_base.as_deref(), + Some("https://api.traceloop.com") + ); + assert_eq!(pre_guard.api_key.as_deref(), Some("resolved-key-123")); + // Guard with override keeps its own api_key + assert_eq!(post_guard.api_key.as_deref(), Some("override-key")); +} + +#[tokio::test] +async fn test_e2e_multiple_guards_different_evaluators() { + // Different evaluator slugs -> separate mock expectations + let server = MockServer::start().await; + + Mock::given(matchers::method("POST")) + .and(matchers::path("/v2/guardrails/execute/toxicity-detector")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&server) + .await; + + Mock::given(matchers::method("POST")) + .and(matchers::path("/v2/guardrails/execute/pii-detector")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&server) + .await; + + let guards = vec![ + TestGuardBuilder::new("tox-guard", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("toxicity-detector") + .build(), + TestGuardBuilder::new("pii-guard", GuardMode::PreCall) + 
.on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("pii-detector") + .build(), + ]; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&guards, "test input", &client, None).await; + + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 2); + // wiremock will verify each path was called exactly once +} + +#[tokio::test] +async fn test_e2e_fail_open_evaluator_down() { + // Evaluator service down + required: false -> passthrough + let server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) + .mount(&server) + .await; + + let mut guard = TestGuardBuilder::new("checker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("profanity-detector") + .build(); + guard.required = false; // fail-open + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "test input", &client, None).await; + + assert!(!outcome.blocked); // Fail-open: not blocked despite error +} + +#[tokio::test] +async fn test_e2e_fail_closed_evaluator_down() { + // Evaluator service down + required: true -> 403 + let server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) + .mount(&server) + .await; + + let mut guard = TestGuardBuilder::new("checker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("profanity-detector") + .build(); + guard.required = true; // fail-closed + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "test input", &client, None).await; + + assert!(outcome.blocked); // Fail-closed: blocked due to error +} + +#[tokio::test] +async fn test_e2e_config_validation_rejects_invalid() { + // Config with missing required fields -> deserialization error + let invalid_json = json!({ + "guards": [{ + "name": 
"incomplete-guard" + // missing provider, evaluator_slug, mode + }] + }); + let result = serde_json::from_value::(invalid_json); + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_e2e_backward_compat_no_guardrails() { + // Existing config without guardrails works unchanged + use std::io::Write; + use tempfile::NamedTempFile; + + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + assert!(config.guardrails.is_none()); + + // build_guardrail_resources with None guardrails returns None + let shared = config + .guardrails + .as_ref() + .and_then(build_guardrail_resources); + assert!(shared.is_none()); +} + +// --------------------------------------------------------------------------- +// Pipeline Integration (4 tests) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_pre_call_guardrails_warn_and_continue() { + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {"reason": "borderline"}, "pass": false})), + ) + .expect(1) + .mount(&eval_server) + .await; + + let guard = TestGuardBuilder::new("tone-check", GuardMode::PreCall) + .on_failure(OnFailure::Warn) + .api_base(&eval_server.uri()) + .evaluator_slug("tone-detection") + .build(); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "borderline input", &client, None).await; + + assert!(!outcome.blocked); + assert_eq!(outcome.warnings.len(), 1); + assert_eq!(outcome.warnings[0].guard_name, "tone-check"); +} + 
+#[tokio::test] +async fn test_post_call_guardrails_warn_and_add_header() { + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {"reason": "mildly concerning"}, "pass": false})), + ) + .expect(1) + .mount(&eval_server) + .await; + + let guard = TestGuardBuilder::new("safety-check", GuardMode::PostCall) + .on_failure(OnFailure::Warn) + .api_base(&eval_server.uri()) + .evaluator_slug("pii-detector") + .build(); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "Some LLM response", &client, None).await; + + assert!(!outcome.blocked); + assert!(!outcome.warnings.is_empty()); + + // Verify warning header would be generated correctly + let header = warning_header_value(&outcome.warnings); + assert!(header.contains("guardrail_name=")); + assert!(header.contains("safety-check")); +} + +#[tokio::test] +async fn test_warning_header_format() { + let warnings = vec![GuardWarning { + guard_name: "my-guard".to_string(), + reason: "failed".to_string(), + }]; + let header = warning_header_value(&warnings); + assert_eq!(header, "guardrail_name=\"my-guard\", reason=\"failed\""); +} + +#[tokio::test] +async fn test_blocked_response_403_format() { + let outcome = GuardrailsOutcome { + results: vec![GuardResult::Failed { + name: "toxicity-check".to_string(), + result: json!({"reason": "toxic content"}), + on_failure: OnFailure::Block, + }], + blocked: true, + blocking_guard: Some("toxicity-check".to_string()), + warnings: vec![], + }; + let response = blocked_response(&outcome); + assert_eq!(response.status(), 403); + + let body = to_bytes(response.into_body(), 1024 * 1024).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(json["error"]["type"], "guardrail_blocked"); + assert_eq!(json["error"]["guardrail"], "toxicity-check"); + assert!( + json["error"]["message"] + .as_str() + .unwrap() + 
.contains("toxicity-check") + ); +} + +#[tokio::test] +async fn test_post_call_skipped_on_empty_response() { + // When the LLM returns empty content (e.g. max_tokens too low), + // post-call guards should be skipped and a warning returned. + let eval_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(0) // evaluator should never be called + .mount(&eval_server) + .await; + + let guard = TestGuardBuilder::new("toxicity-filter", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&eval_server.uri()) + .evaluator_slug("toxicity-detector") + .build(); + + let guardrails = Guardrails { + all_guards: Arc::new(vec![guard]), + pipeline_guard_names: vec!["toxicity-filter".to_string()], + client: Arc::new(TraceloopClient::new()), + }; + + let headers = HeaderMap::new(); + let runner = GuardrailsRunner::new(Some(&guardrails), &headers, None).unwrap(); + + // Completion with content: None (simulates empty LLM response) + let empty_completion = create_test_chat_completion(""); + let result = runner.run_post_call(&empty_completion).await; + + let warnings = result.expect("should not be blocked"); + assert_eq!(warnings.len(), 1); + assert!(warnings[0].reason.contains("empty response content")); + + let header = warning_header_value(&warnings); + assert!(header.contains("skipped")); + // wiremock will verify expect(0) — evaluator was never called +} + +#[tokio::test] +async fn test_evaluator_error_not_blocked_by_default() { + // Guards default to required: false (fail-open), so an evaluator HTTP 500 + // should NOT block the request. 
+ let server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) + .mount(&server) + .await; + + let guard = TestGuardBuilder::new("warn-guard", GuardMode::PreCall) + .on_failure(OnFailure::Warn) + .api_base(&server.uri()) + .evaluator_slug("profanity-detector") + .build(); + // guard.required is false by default + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "test input", &client, None).await; + + assert!(!outcome.blocked); +} diff --git a/tests/guardrails/test_middleware.rs b/tests/guardrails/test_middleware.rs new file mode 100644 index 00000000..d0163293 --- /dev/null +++ b/tests/guardrails/test_middleware.rs @@ -0,0 +1,957 @@ +use hub_lib::guardrails::middleware::{GuardrailsLayer, MAX_BODY_SIZE}; +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::types::{Guard, GuardMode, Guardrails, OnFailure}; + +use axum::body::{Body, to_bytes}; +use axum::extract::Request; +use axum::http::StatusCode; +use serde_json::json; +use std::sync::Arc; +use tower::{Layer, Service, ServiceExt}; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + +/// Helper to create a guard with a specific wiremock server +fn guard_with_server( + name: &str, + mode: GuardMode, + on_failure: OnFailure, + server_uri: &str, +) -> Guard { + Guard { + name: name.to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "toxicity-detector".to_string(), + params: Default::default(), + mode, + on_failure, + required: false, + api_base: Some(server_uri.to_string()), + api_key: Some("test-key".to_string()), + } +} + +/// Helper to create a complete guardrails configuration +fn create_guardrails(guards: Vec) -> Guardrails { + let guard_names: Vec = guards.iter().map(|g| g.name.clone()).collect(); + Guardrails { + all_guards: Arc::new(guards), + pipeline_guard_names: guard_names, + 
client: Arc::new(TraceloopClient::new()), + } +} + +// =========================================================================== +// Category 1: Endpoint Type Detection +// =========================================================================== + +#[tokio::test] +async fn test_chat_completions_endpoint_detected() { + // Set up mock evaluator for pre-call guard + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) + .mount(&eval_server) + .await; + + // Create guard + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + // Create mock inner service + let completion = create_test_chat_completion("Response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + // Apply middleware + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create chat request + let request = create_test_chat_request("Test input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + // Call middleware + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify response is 200 OK (guard passed) + assert_eq!(response.status(), StatusCode::OK); + + // Wiremock verifies evaluator was called (expect(1)) +} + +#[tokio::test] +async fn test_completions_endpoint_detected() { + // Set up mock evaluator + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) + .mount(&eval_server) + .await; + 
+ let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_completion_response("Response text"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create completion request + let request = create_test_completion_request("Complete this"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +#[tokio::test] +async fn test_embeddings_endpoint_detected() { + // Set up mock evaluator + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let embeddings_response = create_test_embeddings_response(); + let inner_service = MockService::with_json(StatusCode::OK, &embeddings_response); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create embeddings request + let request = create_test_embeddings_request("Embed this text"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/embeddings") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + 
.unwrap() + .call(http_request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +// =========================================================================== +// Category 2: Pre-Call Guard Behavior +// =========================================================================== + +#[tokio::test] +async fn test_pre_call_guard_blocks_chat() { + // Set up mock evaluator that fails + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "toxic content"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("This won't be returned"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Bad input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify blocked + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["error"]["guardrail"], "blocker"); +} + +#[tokio::test] +async fn test_pre_call_guard_warns_chat() { + // Set up mock evaluator that fails with warn + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + 
.respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "borderline"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "warner", + GuardMode::PreCall, + OnFailure::Warn, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Response text"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Borderline input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify passes with warning + assert_eq!(response.status(), StatusCode::OK); + assert!( + response + .headers() + .contains_key("x-traceloop-guardrail-warning") + ); + + let warning_header = response + .headers() + .get("x-traceloop-guardrail-warning") + .unwrap() + .to_str() + .unwrap(); + assert!(warning_header.contains("warner")); +} + +// =========================================================================== +// Category 3: Post-Call Guard Behavior +// =========================================================================== + +#[tokio::test] +async fn test_post_call_guard_blocks_chat() { + // Set up mock evaluator for post-call that fails + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "unsafe output"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "output-blocker", + GuardMode::PostCall, + OnFailure::Block, + 
&eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Unsafe response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Safe input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify blocked by post-call guard + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["error"]["guardrail"], "output-blocker"); +} + +#[tokio::test] +async fn test_post_call_guard_skipped_for_embeddings() { + // Set up mock evaluator for pre-call (should be called) + let pre_eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) // Pre-call should run + .mount(&pre_eval_server) + .await; + + // Set up mock evaluator for post-call (should NOT be called) + let post_eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": false // Would block if called + }))) + .expect(0) // Post-call should NOT run for embeddings + .mount(&post_eval_server) + .await; + + // Create both pre-call and post-call guards + let pre_guard = guard_with_server( + "pre-guard", + GuardMode::PreCall, + OnFailure::Block, + &pre_eval_server.uri(), + ); + 
let post_guard = guard_with_server( + "post-guard", + GuardMode::PostCall, + OnFailure::Block, + &post_eval_server.uri(), + ); + let guardrails = create_guardrails(vec![pre_guard, post_guard]); + + let embeddings_response = create_test_embeddings_response(); + let inner_service = MockService::with_json(StatusCode::OK, &embeddings_response); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_embeddings_request("Embed this"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/embeddings") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify response is 200 OK (no post-call blocking) + assert_eq!(response.status(), StatusCode::OK); + + // Verify no warning header + assert!( + !response + .headers() + .contains_key("x-traceloop-guardrail-warning") + ); + + // Verify response body contains embeddings + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["object"], "list"); + + // Wiremock verifies: + // - pre-call evaluator called exactly once (expect(1)) + // - post-call evaluator never called (expect(0)) +} + +// =========================================================================== +// Category 4: Streaming Behavior +// =========================================================================== + +#[tokio::test] +async fn test_streaming_chat_runs_pre_call_guards() { + // Set up mock evaluator (should be called for pre-call) + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) // Pre-call 
guard should run even for streaming + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Streamed response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create STREAMING chat request + let request = create_streaming_chat_request("Safe input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify response is 200 OK (pre-call guard passed) + assert_eq!(response.status(), StatusCode::OK); + + // Wiremock verifies evaluator was called exactly once (expect(1)) +} + +#[tokio::test] +async fn test_streaming_completion_runs_pre_call_guards() { + // Set up mock evaluator (should be called for pre-call) + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) // Pre-call guard should run even for streaming + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_completion_response("Streamed completion"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create STREAMING completion request + let request 
= create_streaming_completion_request("Safe input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify response is 200 OK (pre-call guard passed) + assert_eq!(response.status(), StatusCode::OK); + + // Wiremock verifies evaluator was called exactly once (expect(1)) +} + +#[tokio::test] +async fn test_streaming_chat_pre_call_blocks() { + // Set up mock evaluator that fails + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "toxic content"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("This won't be returned"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create STREAMING chat request with bad input + let request = create_streaming_chat_request("Bad input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify blocked by pre-call guard even for streaming + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = 
serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["error"]["guardrail"], "blocker"); +} + +#[tokio::test] +async fn test_streaming_chat_pre_call_warns() { + // Set up mock evaluator that fails with warn + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "borderline"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "warner", + GuardMode::PreCall, + OnFailure::Warn, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Response text"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_streaming_chat_request("Borderline input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify passes with warning even for streaming + assert_eq!(response.status(), StatusCode::OK); + assert!( + response + .headers() + .contains_key("x-traceloop-guardrail-warning") + ); + + let warning_header = response + .headers() + .get("x-traceloop-guardrail-warning") + .unwrap() + .to_str() + .unwrap(); + assert!(warning_header.contains("warner")); +} + +#[tokio::test] +async fn test_streaming_chat_post_call_skipped() { + // Set up mock evaluator for pre-call (should pass) + let pre_eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) // Pre-call should 
run + .mount(&pre_eval_server) + .await; + + // Set up mock evaluator for post-call (should NOT be called) + let post_eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": false // Would block if called + }))) + .expect(0) // Post-call should NOT run for streaming + .mount(&post_eval_server) + .await; + + let pre_guard = guard_with_server( + "pre-guard", + GuardMode::PreCall, + OnFailure::Block, + &pre_eval_server.uri(), + ); + let post_guard = guard_with_server( + "post-guard", + GuardMode::PostCall, + OnFailure::Block, + &post_eval_server.uri(), + ); + let guardrails = create_guardrails(vec![pre_guard, post_guard]); + + let completion = create_test_chat_completion("Streamed response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_streaming_chat_request("Safe input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify 200 OK — post-call guard was skipped for streaming + assert_eq!(response.status(), StatusCode::OK); + + // Wiremock verifies: + // - pre-call evaluator called exactly once (expect(1)) + // - post-call evaluator never called (expect(0)) +} + +#[tokio::test] +async fn test_streaming_completion_pre_call_blocks() { + // Set up mock evaluator that fails + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "toxic content"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + 
.await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_completion_response("This won't be returned"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create STREAMING completion request with bad input + let request = create_streaming_completion_request("Bad input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify blocked by pre-call guard even for streaming + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["error"]["guardrail"], "blocker"); +} + +// =========================================================================== +// Category 5: Pass-Through Scenarios +// =========================================================================== + +#[tokio::test] +async fn test_no_guardrails_configured_passes() { + // No guardrails configured + let completion = create_test_chat_completion("Response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(None); // No guardrails + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Any input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + 
.body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify response passes through unchanged + assert_eq!(response.status(), StatusCode::OK); +} + +#[tokio::test] +async fn test_unsupported_endpoint_passes() { + let eval_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": false + }))) + .expect(0) // Should never be called + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let inner_service = + MockService::with_json(StatusCode::OK, &json!({"data": [{"id": "model-1"}]})); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Request to unsupported endpoint + let http_request = Request::builder() + .method("GET") + .uri("/v1/models") // Unsupported endpoint + .body(Body::empty()) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify passes through + assert_eq!(response.status(), StatusCode::OK); +} + +// =========================================================================== +// Category 6: Body Size Limits +// =========================================================================== + +#[tokio::test] +async fn test_request_body_exceeding_limit_returns_400() { + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(0) // Should never reach the evaluator + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + 
); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Body larger than MAX_BODY_SIZE (10 MB) + let oversized = vec![b'x'; MAX_BODY_SIZE + 1]; + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(oversized)) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn test_request_body_within_limit_is_accepted() { + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Test input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Should pass through (200 from inner service) + assert_eq!(response.status(), StatusCode::OK); +} diff --git a/tests/guardrails/test_parsing.rs 
b/tests/guardrails/test_parsing.rs new file mode 100644 index 00000000..5dab2db8 --- /dev/null +++ b/tests/guardrails/test_parsing.rs @@ -0,0 +1,145 @@ +use hub_lib::guardrails::parsing::{ + CompletionExtractor, PromptExtractor, parse_evaluator_http_response, parse_evaluator_response, +}; +use hub_lib::guardrails::types::GuardrailError; +use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMessageContentPart}; + +use super::helpers::*; + +#[test] +fn test_extract_text_single_user_message() { + let request = create_test_chat_request("Hello world"); + let text = request.extract_prompt(); + assert_eq!(text, "Hello world"); +} + +#[test] +fn test_extract_text_multi_turn_conversation() { + let mut request = default_request(); + request.messages = vec![ + ChatCompletionMessage { + role: "system".to_string(), + content: Some(ChatMessageContent::String("You are helpful".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("First question".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "assistant".to_string(), + content: Some(ChatMessageContent::String("First answer".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Follow-up question".to_string())), + ..default_message() + }, + ]; + let text = request.extract_prompt(); + assert_eq!( + text, + "You are helpful\nFirst question\nFirst answer\nFollow-up question" + ); +} + +#[test] +fn test_extract_text_from_array_content_parts() { + let mut request = create_test_chat_request(""); + request.messages[0].content = Some(ChatMessageContent::Array(vec![ + ChatMessageContentPart { + r#type: "text".to_string(), + text: "Part 1".to_string(), + }, + ChatMessageContentPart { + r#type: "text".to_string(), + text: "Part 2".to_string(), + }, + ])); + let text = request.extract_prompt(); + assert_eq!(text, "Part 1 
Part 2"); +} + +#[test] +fn test_extract_response_from_chat_completion() { + let completion = create_test_chat_completion("Here is my response"); + let text = completion.extract_completion(); + assert_eq!(text, "Here is my response"); +} + +#[test] +fn test_extract_handles_empty_content() { + let mut request = create_test_chat_request(""); + request.messages[0].content = None; + let text = request.extract_prompt(); + assert_eq!(text, ""); +} + +// --------------------------------------------------------------------------- +// Response Parsing (8 tests) +// --------------------------------------------------------------------------- + +#[test] +fn test_parse_successful_pass_response() { + let body = r#"{"result": {"score": 0.95, "label": "safe"}, "pass": true}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(response.pass); + assert_eq!(response.result["score"], 0.95); +} + +#[test] +fn test_parse_failed_response() { + let body = r#"{"result": {"score": 0.2, "reason": "Toxic content"}, "pass": false}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(!response.pass); + assert_eq!(response.result["reason"], "Toxic content"); +} + +#[test] +fn test_parse_with_result_details() { + let body = r#"{"result": {"score": 0.75, "label": "borderline", "categories": ["violence", "profanity"]}, "pass": true}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(response.pass); + assert_eq!(response.result["label"], "borderline"); + assert_eq!(response.result["categories"][0], "violence"); +} + +#[test] +fn test_parse_missing_pass_field() { + let body = r#"{"result": {"score": 0.5}}"#; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_malformed_json() { + let body = "not json {at all"; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_non_json_response() { + let body = "Internal Server Error"; + let result = 
parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_empty_response_body() { + let body = ""; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_http_error_status() { + let result = parse_evaluator_http_response(500, "Internal Server Error"); + assert!(result.is_err()); + match result.unwrap_err() { + GuardrailError::HttpError { status, .. } => assert_eq!(status, 500), + other => panic!("Expected HttpError, got {other:?}"), + } +} diff --git a/tests/guardrails/test_run_evaluator.rs b/tests/guardrails/test_run_evaluator.rs new file mode 100644 index 00000000..4391e389 --- /dev/null +++ b/tests/guardrails/test_run_evaluator.rs @@ -0,0 +1,276 @@ +use hub_lib::guardrails::evaluator_types::get_evaluator; +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::types::*; +use serde::Deserialize; +use serde_json::{Value, json}; +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + +// --------------------------------------------------------------------------- +// Infrastructure +// --------------------------------------------------------------------------- + +struct EvaluatorTestCase { + slug: &'static str, + cassette_name: &'static str, + input: &'static str, + params: HashMap<String, Value>, + expected_pass: bool, +} + +#[derive(Deserialize)] +struct Cassette { + response_body: Value, +} + +fn load_cassette(name: &str) -> Cassette { + let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests/cassettes/guardrails") + .join(format!("{}.json", name)); + let content = fs::read_to_string(&path) + .unwrap_or_else(|_| panic!("Cassette '{}' not found at {:?}.", name, path)); + serde_json::from_str(&content).expect("Failed to parse cassette JSON") +} + +/// Set up a wiremock server, execute the guard, and verify the request was correct. 
+async fn run_evaluator_test(tc: &EvaluatorTestCase) { + let cassette = load_cassette(tc.cassette_name); + + // Build the expected request body using the evaluator's build_body() + let evaluator = get_evaluator(tc.slug).unwrap(); + let expected_body = evaluator.build_body(tc.input, &tc.params).unwrap(); + + // Set up wiremock with strict matchers that verify the request + let server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::path(format!( + "/v2/guardrails/execute/{}", + tc.slug + ))) + .and(matchers::header("Authorization", "Bearer test-api-key")) + .and(matchers::header("Content-Type", "application/json")) + .and(matchers::body_json(&expected_body)) + .respond_with(ResponseTemplate::new(200).set_body_json(&cassette.response_body)) + .expect(1) + .mount(&server) + .await; + + // Create guard pointing at the mock server and execute + let mut guard = create_test_guard_with_api_base(tc.slug, GuardMode::PreCall, &server.uri()); + guard.evaluator_slug = tc.slug.to_string(); + guard.params = tc.params.clone(); + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, tc.input).await.unwrap(); + + // Verify the response was interpreted correctly + assert_eq!( + result.pass, tc.expected_pass, + "{}: expected pass={}, got pass={}", + tc.slug, tc.expected_pass, result.pass + ); + // wiremock .expect(1) verifies the request matched all matchers +} + +// --------------------------------------------------------------------------- +// 1. 
PII Detector (text body, optional probability_threshold config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_pii_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "pii-detector", + cassette_name: "pii_detector_pass", + input: "The weather is sunny today and I like programming in Rust.", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 2. Secrets Detector (text body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_secrets_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "secrets-detector", + cassette_name: "secrets_detector_pass", + input: "Here is a simple function that adds two numbers together.", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 3. Prompt Injection (prompt body, optional threshold config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_prompt_injection() { + run_evaluator_test(&EvaluatorTestCase { + slug: "prompt-injection", + cassette_name: "prompt_injection_pass", + input: "What is the capital of France?", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 4. 
Profanity Detector (text body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_profanity_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "profanity-detector", + cassette_name: "profanity_detector_fail", + input: "This is damn bullshit and I think it's a total crap product.", + params: HashMap::new(), + expected_pass: false, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 5. Sexism Detector (text body, threshold config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_sexism_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "sexism-detector", + cassette_name: "sexism_detector_fail", + input: "Women should not be in leadership positions because they are too emotional.", + params: HashMap::from([("threshold".to_string(), json!(0.5))]), + expected_pass: false, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 6. Toxicity Detector (text body, threshold config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_toxicity_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "toxicity-detector", + cassette_name: "toxicity_detector_fail", + input: "You are a complete idiot and everyone hates you. You should be ashamed.", + params: HashMap::from([("threshold".to_string(), json!(0.5))]), + expected_pass: false, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 7. 
Regex Validator (text body, regex config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_regex_validator() { + run_evaluator_test(&EvaluatorTestCase { + slug: "regex-validator", + cassette_name: "regex_validator_pass", + input: "Order ID: ABC-12345", + params: HashMap::from([ + ("regex".to_string(), json!(r"^[A-Z]{3}-\d{5}$")), + ("should_match".to_string(), json!(true)), + ("case_sensitive".to_string(), json!(true)), + ("dot_include_nl".to_string(), json!(true)), + ("multi_line".to_string(), json!(true)), + ]), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 8. JSON Validator (text body, optional schema config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_json_validator() { + run_evaluator_test(&EvaluatorTestCase { + slug: "json-validator", + cassette_name: "json_validator_pass", + input: r#"{"name": "Alice", "age": 30}"#, + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 9. SQL Validator (text body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_sql_validator() { + run_evaluator_test(&EvaluatorTestCase { + slug: "sql-validator", + cassette_name: "sql_validator_pass", + input: "SELECT id, name FROM users WHERE active = true ORDER BY name", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 10. 
Tone Detection (text body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_tone_detection() { + run_evaluator_test(&EvaluatorTestCase { + slug: "tone-detection", + cassette_name: "tone_detection_fail", + input: "This is ABSOLUTELY UNACCEPTABLE. I DEMAND to speak to someone competent immediately!", + params: HashMap::new(), + expected_pass: false, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 11. Prompt Perplexity (prompt body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_prompt_perplexity() { + run_evaluator_test(&EvaluatorTestCase { + slug: "prompt-perplexity", + cassette_name: "prompt_perplexity_pass", + input: "Please explain the concept of photosynthesis in simple terms.", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 12. 
Uncertainty Detector (prompt body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_uncertainty_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "uncertainty-detector", + cassette_name: "uncertainty_detector_pass", + input: "What is 2 + 2?", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} diff --git a/tests/guardrails/test_runner.rs b/tests/guardrails/test_runner.rs new file mode 100644 index 00000000..c6e91322 --- /dev/null +++ b/tests/guardrails/test_runner.rs @@ -0,0 +1,423 @@ +use hub_lib::guardrails::parsing::CompletionExtractor; +use hub_lib::guardrails::runner::*; +use hub_lib::guardrails::types::*; +use opentelemetry::Context; +use opentelemetry::trace::{Span, SpanKind, TraceContextExt, Tracer}; +use opentelemetry_sdk::export::trace::SpanData; +use opentelemetry_sdk::testing::trace::InMemorySpanExporter; +use opentelemetry_sdk::trace::TracerProvider; + +use super::helpers::*; + +#[tokio::test] +async fn test_execute_single_pre_call_guard_passes() { + let guard = create_test_guard("check", GuardMode::PreCall); + let mock_client = MockGuardrailClient::with_response("check", Ok(passing_response())); + let outcome = execute_guards(&[guard], "test input", &mock_client, None).await; + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 1); + assert!(matches!(&outcome.results[0], GuardResult::Passed { .. 
})); + assert!(outcome.warnings.is_empty()); +} + +#[tokio::test] +async fn test_execute_single_pre_call_guard_fails_block() { + let guard = + create_test_guard_with_failure_action("check", GuardMode::PreCall, OnFailure::Block); + let mock_client = MockGuardrailClient::with_response("check", Ok(failing_response())); + let outcome = execute_guards(&[guard], "toxic input", &mock_client, None).await; + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard, Some("check".to_string())); +} + +#[tokio::test] +async fn test_execute_single_pre_call_guard_fails_warn() { + let guard = create_test_guard_with_failure_action("check", GuardMode::PreCall, OnFailure::Warn); + let mock_client = MockGuardrailClient::with_response("check", Ok(failing_response())); + let outcome = execute_guards(&[guard], "borderline input", &mock_client, None).await; + assert!(!outcome.blocked); + assert_eq!(outcome.warnings.len(), 1); + assert_eq!(outcome.warnings[0].guard_name, "check"); +} + +#[tokio::test] +async fn test_execute_multiple_pre_call_guards_all_pass() { + let guards = vec![ + create_test_guard("guard-1", GuardMode::PreCall), + create_test_guard("guard-2", GuardMode::PreCall), + create_test_guard("guard-3", GuardMode::PreCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("guard-1", Ok(passing_response())), + ("guard-2", Ok(passing_response())), + ("guard-3", Ok(passing_response())), + ]); + let outcome = execute_guards(&guards, "safe input", &mock_client, None).await; + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 3); + assert!(outcome.warnings.is_empty()); +} + +#[tokio::test] +async fn test_execute_multiple_guards_one_blocks() { + let guards = vec![ + create_test_guard("guard-1", GuardMode::PreCall), + create_test_guard_with_failure_action("guard-2", GuardMode::PreCall, OnFailure::Block), + create_test_guard("guard-3", GuardMode::PreCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("guard-1", 
Ok(passing_response())), + ("guard-2", Ok(failing_response())), + ("guard-3", Ok(passing_response())), + ]); + let outcome = execute_guards(&guards, "input", &mock_client, None).await; + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard, Some("guard-2".to_string())); +} + +#[tokio::test] +async fn test_execute_multiple_guards_one_warns_continue() { + let guards = vec![ + create_test_guard("guard-1", GuardMode::PreCall), + create_test_guard_with_failure_action("guard-2", GuardMode::PreCall, OnFailure::Warn), + create_test_guard("guard-3", GuardMode::PreCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("guard-1", Ok(passing_response())), + ("guard-2", Ok(failing_response())), + ("guard-3", Ok(passing_response())), + ]); + let outcome = execute_guards(&guards, "input", &mock_client, None).await; + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 3); + assert_eq!(outcome.warnings.len(), 1); +} + +#[tokio::test] +async fn test_guard_evaluator_unavailable_required_false() { + let guard = create_test_guard_with_required("check", GuardMode::PreCall, false); + let mock_client = MockGuardrailClient::with_response( + "check", + Err(GuardrailError::Unavailable( + "connection refused".to_string(), + )), + ); + let outcome = execute_guards(&[guard], "input", &mock_client, None).await; + assert!(!outcome.blocked); // Fail-open + assert!(matches!( + &outcome.results[0], + GuardResult::Error { + required: false, + .. 
+ } + )); + // Non-required guard error should produce a warning header (fail-open) + assert_eq!(outcome.warnings.len(), 1); + assert_eq!(outcome.warnings[0].guard_name, "check"); + assert!(outcome.warnings[0].reason.contains("evaluator error")); +} + +#[tokio::test] +async fn test_guard_evaluator_unavailable_required_true() { + let guard = create_test_guard_with_required("check", GuardMode::PreCall, true); + let mock_client = MockGuardrailClient::with_response( + "check", + Err(GuardrailError::Unavailable( + "connection refused".to_string(), + )), + ); + let outcome = execute_guards(&[guard], "input", &mock_client, None).await; + assert!(outcome.blocked); // Fail-closed + // Required guard error should NOT produce a warning (it blocks instead) + assert!(outcome.warnings.is_empty()); +} + +#[tokio::test] +async fn test_execute_post_call_guards_non_streaming() { + let guard = create_test_guard("response-check", GuardMode::PostCall); + let mock_client = MockGuardrailClient::with_response("response-check", Ok(passing_response())); + let completion = create_test_chat_completion("Safe response text"); + let response_text = completion.extract_completion(); + let outcome = execute_guards(&[guard], &response_text, &mock_client, None).await; + assert!(!outcome.blocked); +} + +#[tokio::test] +async fn test_execute_post_call_guards_streaming_accumulated() { + let guard = create_test_guard("response-check", GuardMode::PostCall); + let mock_client = MockGuardrailClient::with_response("response-check", Ok(passing_response())); + let accumulated_text = "Hello world from streaming!"; + let outcome = execute_guards(&[guard], accumulated_text, &mock_client, None).await; + assert!(!outcome.blocked); +} + +#[tokio::test] +async fn test_parallel_execution_of_independent_guards() { + // This test verifies guards run concurrently, not sequentially. + // We use the mock client without delay here; the implementation should + // use futures::join_all or similar for parallel execution. 
+ let guards = vec![ + create_test_guard("guard-1", GuardMode::PreCall), + create_test_guard("guard-2", GuardMode::PreCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("guard-1", Ok(passing_response())), + ("guard-2", Ok(passing_response())), + ]); + let start = std::time::Instant::now(); + let outcome = execute_guards(&guards, "input", &mock_client, None).await; + let _elapsed = start.elapsed(); + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 2); +} + +#[tokio::test] +async fn test_executor_returns_correct_guardrails_outcome() { + let guards = vec![ + create_test_guard_with_failure_action("passer", GuardMode::PreCall, OnFailure::Block), + create_test_guard_with_failure_action("warner", GuardMode::PreCall, OnFailure::Warn), + create_test_guard_with_failure_action("blocker", GuardMode::PreCall, OnFailure::Block), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("passer", Ok(passing_response())), + ("warner", Ok(failing_response())), + ("blocker", Ok(failing_response())), + ]); + let outcome = execute_guards(&guards, "input", &mock_client, None).await; + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard, Some("blocker".to_string())); + assert!(outcome.warnings.iter().any(|w| w.guard_name == "warner")); +} + +// --------------------------------------------------------------------------- +// Guard Span Creation +// --------------------------------------------------------------------------- + +use std::sync::LazyLock; + +/// Shared OTel exporter + provider, initialized once for all span tests. +/// Each test creates a unique parent span with a unique trace_id, then filters +/// exported spans by that trace_id — so tests are isolated despite sharing state. 
+static TEST_EXPORTER: LazyLock = LazyLock::new(|| { + let exporter = InMemorySpanExporter::default(); + let provider = TracerProvider::builder() + .with_simple_exporter(exporter.clone()) + .build(); + opentelemetry::global::set_tracer_provider(provider); + exporter +}); + +/// Helper: create a parent Context from the global tracer, returning +/// the Context and the parent's SpanContext for later assertions. +fn create_parent_context() -> (Context, opentelemetry::trace::SpanContext) { + let tracer = opentelemetry::global::tracer("traceloop_hub"); + let parent_span = tracer.start("traceloop_hub"); + let span_ctx = parent_span.span_context().clone(); + let cx = Context::current().with_span(parent_span); + (cx, span_ctx) +} + +/// Helper: collect guard spans from the shared exporter, filtering by trace_id. +fn get_guard_spans(trace_id: opentelemetry::trace::TraceId) -> Vec { + TEST_EXPORTER + .get_finished_spans() + .unwrap() + .into_iter() + .filter(|s| s.span_context.trace_id() == trace_id) + .filter(|s| s.name.ends_with(".guard")) + .collect() +} + +#[tokio::test] +async fn test_guard_spans_created_with_parent_context() { + let _ = &*TEST_EXPORTER; // ensure global provider is set + let (parent_cx, parent_span_ctx) = create_parent_context(); + + let guards = vec![ + create_test_guard("pii-check", GuardMode::PreCall), + create_test_guard("secrets-check", GuardMode::PostCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("pii-check", Ok(passing_response())), + ("secrets-check", Ok(failing_response())), + ]); + + let _outcome = execute_guards(&guards, "test input", &mock_client, Some(&parent_cx)).await; + drop(parent_cx); + + let spans = get_guard_spans(parent_span_ctx.trace_id()); + assert_eq!( + spans.len(), + 2, + "Expected 2 guard spans, got {}", + spans.len() + ); + + let span_names: Vec<&str> = spans.iter().map(|s| s.name.as_ref()).collect(); + assert!(span_names.contains(&"pii-check.guard")); + 
assert!(span_names.contains(&"secrets-check.guard")); + + // All guard spans should be children of the parent + for span in &spans { + assert_eq!( + span.parent_span_id, + parent_span_ctx.span_id(), + "Guard span '{}' should be child of the parent span", + span.name + ); + assert_eq!(span.span_context.trace_id(), parent_span_ctx.trace_id()); + assert_eq!(span.span_kind, SpanKind::Internal); + } +} + +#[tokio::test] +async fn test_guard_span_attributes_on_pass() { + let _ = &*TEST_EXPORTER; + let (parent_cx, parent_span_ctx) = create_parent_context(); + + let guard = create_test_guard("pii-check", GuardMode::PreCall); + let mock_client = MockGuardrailClient::with_response("pii-check", Ok(passing_response())); + + let _outcome = execute_guards(&[guard], "hello world", &mock_client, Some(&parent_cx)).await; + drop(parent_cx); + + let spans = get_guard_spans(parent_span_ctx.trace_id()); + assert_eq!(spans.len(), 1); + + let span = &spans[0]; + let attrs: std::collections::HashMap = span + .attributes + .iter() + .map(|kv| (kv.key.to_string(), kv.value.to_string())) + .collect(); + + assert_eq!(attrs.get("gen_ai.guardrail.name").unwrap(), "pii-check"); + assert_eq!(attrs.get("gen_ai.guardrail.status").unwrap(), "PASSED"); + assert!(attrs.contains_key("gen_ai.guardrail.duration")); +} + +#[tokio::test] +async fn test_guard_span_attributes_on_fail() { + let _ = &*TEST_EXPORTER; + let (parent_cx, parent_span_ctx) = create_parent_context(); + + let guard = + create_test_guard_with_failure_action("toxicity", GuardMode::PreCall, OnFailure::Block); + let mock_client = MockGuardrailClient::with_response("toxicity", Ok(failing_response())); + + let _outcome = execute_guards(&[guard], "bad input", &mock_client, Some(&parent_cx)).await; + drop(parent_cx); + + let spans = get_guard_spans(parent_span_ctx.trace_id()); + assert_eq!(spans.len(), 1); + + let span = &spans[0]; + let attrs: std::collections::HashMap = span + .attributes + .iter() + .map(|kv| (kv.key.to_string(), 
kv.value.to_string())) + .collect(); + + assert_eq!(attrs.get("gen_ai.guardrail.name").unwrap(), "toxicity"); + assert_eq!(attrs.get("gen_ai.guardrail.status").unwrap(), "FAILED"); +} + +#[tokio::test] +async fn test_guard_span_attributes_on_error() { + let _ = &*TEST_EXPORTER; + let (parent_cx, parent_span_ctx) = create_parent_context(); + + let guard = create_test_guard_with_required("failing-guard", GuardMode::PreCall, true); + let mock_client = MockGuardrailClient::with_response( + "failing-guard", + Err(GuardrailError::Timeout("timed out".to_string())), + ); + + let _outcome = execute_guards(&[guard], "test input", &mock_client, Some(&parent_cx)).await; + drop(parent_cx); + + let spans = get_guard_spans(parent_span_ctx.trace_id()); + assert_eq!(spans.len(), 1); + + let span = &spans[0]; + let attrs: std::collections::HashMap = span + .attributes + .iter() + .map(|kv| (kv.key.to_string(), kv.value.to_string())) + .collect(); + + assert_eq!(attrs.get("gen_ai.guardrail.name").unwrap(), "failing-guard"); + assert_eq!(attrs.get("gen_ai.guardrail.status").unwrap(), "ERROR"); + assert_eq!(attrs.get("gen_ai.guardrail.error.type").unwrap(), "Timeout"); + assert!( + attrs + .get("gen_ai.guardrail.error.message") + .unwrap() + .contains("timed out") + ); +} + +#[tokio::test] +async fn test_no_guard_spans_without_parent_context() { + let _ = &*TEST_EXPORTER; + + // Create a unique trace to establish a "before" baseline + let (marker_cx, marker_span_ctx) = create_parent_context(); + drop(marker_cx); + + let guard = create_test_guard("pii-check", GuardMode::PreCall); + let mock_client = MockGuardrailClient::with_response("pii-check", Ok(passing_response())); + + // Run with None parent — no guard spans should be created + let _outcome = execute_guards(&[guard], "test input", &mock_client, None).await; + + // No guard spans should share the marker's trace_id (nothing was parented to it) + let guard_spans = get_guard_spans(marker_span_ctx.trace_id()); + assert!( + 
guard_spans.is_empty(), + "No guard spans should be created when parent_cx is None" + ); +} + +// --------------------------------------------------------------------------- +// Response Finalization Tests +// --------------------------------------------------------------------------- + +#[test] +fn test_finalize_response_with_valid_warnings() { + use axum::response::IntoResponse; + let response = (axum::http::StatusCode::OK, "test body").into_response(); + let warnings = vec![GuardWarning { + guard_name: "test-guard".to_string(), + reason: "failed".to_string(), + }]; + let finalized = GuardrailsRunner::finalize_response(response, &warnings); + let headers = finalized.headers(); + assert!(headers.contains_key("x-traceloop-guardrail-warning")); +} + +#[test] +fn test_finalize_response_with_invalid_header_characters() { + use axum::response::IntoResponse; + let response = (axum::http::StatusCode::OK, "test body").into_response(); + // Guard name with newline character (invalid in HTTP headers) + let warnings = vec![GuardWarning { + guard_name: "test-guard\nwith-newline".to_string(), + reason: "failed".to_string(), + }]; + // Should not panic, header should be skipped + let finalized = GuardrailsRunner::finalize_response(response, &warnings); + let headers = finalized.headers(); + // Header should not be present since it couldn't be parsed + assert!(!headers.contains_key("x-traceloop-guardrail-warning")); +} + +#[test] +fn test_finalize_response_with_no_warnings() { + use axum::response::IntoResponse; + let response = (axum::http::StatusCode::OK, "test body").into_response(); + let warnings = vec![]; + let finalized = GuardrailsRunner::finalize_response(response, &warnings); + let headers = finalized.headers(); + assert!(!headers.contains_key("x-traceloop-guardrail-warning")); +} diff --git a/tests/guardrails/test_setup.rs b/tests/guardrails/test_setup.rs new file mode 100644 index 00000000..3652b5f3 --- /dev/null +++ b/tests/guardrails/test_setup.rs @@ -0,0 +1,372 @@ 
+use std::collections::HashMap; + +use hub_lib::guardrails::setup::*; +use hub_lib::guardrails::types::*; + +use super::helpers::*; + +#[test] +fn test_parse_guardrails_header_single() { + let names = parse_guardrails_header("pii-check"); + assert_eq!(names, vec!["pii-check"]); +} + +#[test] +fn test_parse_guardrails_header_multiple() { + let names = parse_guardrails_header("toxicity-check,relevance-check,pii-check"); + assert_eq!( + names, + vec!["toxicity-check", "relevance-check", "pii-check"] + ); +} + +#[test] +fn test_pipeline_guardrails_always_included() { + let pipeline_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name(&pipeline_guards, &["pipeline-guard"], &[]); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "pipeline-guard"); +} + +#[test] +fn test_header_guardrails_additive_to_pipeline() { + let all_guards = vec![ + create_test_guard("pipeline-guard", GuardMode::PreCall), + create_test_guard("header-guard", GuardMode::PreCall), + ]; + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["header-guard"]); + assert_eq!(resolved.len(), 2); +} + +#[test] +fn test_deduplication_by_name() { + let all_guards = vec![create_test_guard("shared-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name( + &all_guards, + &["shared-guard"], + &["shared-guard"], // duplicate + ); + assert_eq!(resolved.len(), 1); +} + +#[test] +fn test_unknown_guard_name_in_header_ignored() { + let all_guards = vec![create_test_guard("known-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name( + &all_guards, + &["known-guard"], + &["nonexistent-guard"], // unknown + ); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "known-guard"); +} + +#[test] +fn test_empty_header_pipeline_guards_only() { + let all_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], 
&[]); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "pipeline-guard"); +} + +#[test] +fn test_cannot_remove_pipeline_guardrails_via_api() { + let all_guards = vec![ + create_test_guard("pipeline-guard", GuardMode::PreCall), + create_test_guard("extra", GuardMode::PreCall), + ]; + // Even if header/payload only mention "extra", pipeline guard still included + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["extra"]); + assert!(resolved.iter().any(|g| g.name == "pipeline-guard")); +} + +#[test] +fn test_guards_split_into_pre_and_post_call() { + let guards = vec![ + create_test_guard("pre-1", GuardMode::PreCall), + create_test_guard("post-1", GuardMode::PostCall), + create_test_guard("pre-2", GuardMode::PreCall), + create_test_guard("post-2", GuardMode::PostCall), + ]; + let (pre_call, post_call) = split_guards_by_mode(&guards); + assert_eq!(pre_call.len(), 2); + assert_eq!(post_call.len(), 2); + assert!(pre_call.iter().all(|g| g.mode == GuardMode::PreCall)); + assert!(post_call.iter().all(|g| g.mode == GuardMode::PostCall)); +} + +#[test] +fn test_complete_resolution_merged() { + let all_guards = vec![ + create_test_guard("pipeline-pre", GuardMode::PreCall), + create_test_guard("pipeline-post", GuardMode::PostCall), + create_test_guard("header-pre", GuardMode::PreCall), + ]; + let resolved = resolve_guards_by_name( + &all_guards, + &["pipeline-pre", "pipeline-post"], + &["header-pre"], + ); + assert_eq!(resolved.len(), 3); + let (pre, post) = split_guards_by_mode(&resolved); + assert_eq!(pre.len(), 2); // pipeline-pre + header-pre + assert_eq!(post.len(), 1); // pipeline-post +} + +// --------------------------------------------------------------------------- +// Pipeline Guard Building & Provider Defaults +// --------------------------------------------------------------------------- + +fn test_guardrails_config() -> GuardrailsConfig { + GuardrailsConfig { + providers: HashMap::from([( + "traceloop".to_string(), + 
ProviderConfig { + name: "traceloop".to_string(), + api_base: "https://api.traceloop.com".to_string(), + api_key: "test-key".to_string(), + }, + )]), + guards: vec![ + Guard { + name: "pii-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "pii".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: false, + api_base: None, + api_key: None, + }, + Guard { + name: "toxicity-filter".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "toxicity".to_string(), + params: Default::default(), + mode: GuardMode::PostCall, + on_failure: OnFailure::Warn, + required: false, + api_base: None, + api_key: None, + }, + Guard { + name: "injection-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "injection".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: false, + api_base: None, + api_key: None, + }, + Guard { + name: "secrets-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "secrets".to_string(), + params: Default::default(), + mode: GuardMode::PostCall, + on_failure: OnFailure::Block, + required: false, + api_base: None, + api_key: None, + }, + ], + } +} + +#[test] +fn test_no_guardrails_passthrough() { + // Empty guardrails config -> build_guardrail_resources returns None + let config = GuardrailsConfig { + providers: Default::default(), + guards: vec![], + }; + let result = build_guardrail_resources(&config); + assert!(result.is_none()); + + // Config with no guards -> passthrough + let config_with_providers = GuardrailsConfig { + providers: HashMap::from([( + "traceloop".to_string(), + ProviderConfig { + name: "traceloop".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), + guards: vec![], + }; + let result = build_guardrail_resources(&config_with_providers); + assert!(result.is_none()); +} + +#[test] +fn 
test_build_pipeline_guardrails_with_specific_guards() { + let config = test_guardrails_config(); + let shared = build_guardrail_resources(&config).unwrap(); + let pipeline_guards = vec!["pii-check".to_string(), "toxicity-filter".to_string()]; + let gr = build_pipeline_guardrails(&shared, &pipeline_guards); + + // all_guards should contain ALL guards from config, resolved with provider defaults + assert_eq!(gr.all_guards.len(), 4); + // pipeline_guard_names should only contain the ones specified + assert_eq!( + gr.pipeline_guard_names, + vec!["pii-check", "toxicity-filter"] + ); +} + +#[test] +fn test_build_pipeline_guardrails_empty_pipeline_guards() { + let config = test_guardrails_config(); + let shared = build_guardrail_resources(&config).unwrap(); + // Pipeline with no guards specified - shared resources still exist + // (header guards can still be used at request time) + let empty: Vec = vec![]; + let gr = build_pipeline_guardrails(&shared, &empty); + + assert_eq!(gr.all_guards.len(), 4); + assert!(gr.pipeline_guard_names.is_empty()); +} + +#[test] +fn test_build_pipeline_guardrails_resolves_provider_defaults() { + let config = test_guardrails_config(); + let shared = build_guardrail_resources(&config).unwrap(); + let gr = build_pipeline_guardrails(&shared, &["pii-check".to_string()]); + + // Guards should have provider api_base/api_key resolved + for guard in gr.all_guards.iter() { + assert_eq!(guard.api_base.as_deref(), Some("https://api.traceloop.com")); + assert_eq!(guard.api_key.as_deref(), Some("test-key")); + } +} + +#[test] +fn test_resolve_guard_defaults_preserves_guard_overrides() { + let config = GuardrailsConfig { + providers: HashMap::from([( + "traceloop".to_string(), + ProviderConfig { + name: "traceloop".to_string(), + api_base: "https://default.api.com".to_string(), + api_key: "default-key".to_string(), + }, + )]), + guards: vec![Guard { + name: "custom-guard".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: 
"custom".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: Some("https://custom.api.com".to_string()), + api_key: Some("custom-key".to_string()), + }], + }; + + let resolved = resolve_guard_defaults(&config); + assert_eq!( + resolved[0].api_base.as_deref(), + Some("https://custom.api.com") + ); + assert_eq!(resolved[0].api_key.as_deref(), Some("custom-key")); +} + +#[test] +fn test_pipeline_guards_resolved_at_request_time() { + // Simulates what happens at request time: merge pipeline + header guards + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + // Pipeline declares only pii-check + let pipeline_names = vec!["pii-check"]; + // Header adds injection-check + let header_names = vec!["injection-check"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].name, "pii-check"); + assert_eq!(resolved[1].name, "injection-check"); +} + +#[test] +fn test_pipeline_guards_plus_header_guards_split_by_mode() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + // Pipeline declares pii-check (pre_call) and toxicity-filter (post_call) + let pipeline_names = vec!["pii-check", "toxicity-filter"]; + // Header adds injection-check (pre_call) and secrets-check (post_call) + let header_names = vec!["injection-check", "secrets-check"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); + assert_eq!(resolved.len(), 4); + + let (pre_call, post_call) = split_guards_by_mode(&resolved); + assert_eq!(pre_call.len(), 2); + assert_eq!(post_call.len(), 2); + assert!(pre_call.iter().any(|g| g.name == "pii-check")); + assert!(pre_call.iter().any(|g| g.name == "injection-check")); + assert!(post_call.iter().any(|g| g.name == "toxicity-filter")); + assert!(post_call.iter().any(|g| g.name == 
"secrets-check")); +} + +#[test] +fn test_header_guard_not_in_config_is_ignored() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + let pipeline_names = vec!["pii-check"]; + let header_names = vec!["nonexistent-guard"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); + // Only pii-check should be resolved; nonexistent guard is silently ignored + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "pii-check"); +} + +#[test] +fn test_duplicate_guard_in_header_and_pipeline_deduped() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + let pipeline_names = vec!["pii-check", "toxicity-filter"]; + // Header specifies same guard as pipeline + let header_names = vec!["pii-check"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); + assert_eq!(resolved.len(), 2); // pii-check only appears once +} + +#[test] +fn test_no_pipeline_guards_header_only() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + // Pipeline has no guards + let pipeline_names: Vec<&str> = vec![]; + // Header adds guards + let header_names = vec!["injection-check", "secrets-check"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].name, "injection-check"); + assert_eq!(resolved[1].name, "secrets-check"); +} + +#[test] +fn test_no_pipeline_guards_no_header_no_guards_executed() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + let resolved = resolve_guards_by_name(&all_guards, &[], &[]); + assert!(resolved.is_empty()); + + let (pre_call, post_call) = split_guards_by_mode(&resolved); + assert!(pre_call.is_empty()); + assert!(post_call.is_empty()); +} diff --git a/tests/guardrails/test_traceloop_client.rs 
b/tests/guardrails/test_traceloop_client.rs new file mode 100644 index 00000000..0b22cfa4 --- /dev/null +++ b/tests/guardrails/test_traceloop_client.rs @@ -0,0 +1,153 @@ +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::types::GuardMode; +use hub_lib::guardrails::types::GuardrailClient; +use serde_json::json; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + +#[tokio::test] +async fn test_traceloop_client_constructs_correct_url() { + let mock_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::path("/v2/guardrails/execute/toxicity-detector")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&mock_server) + .await; + + let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.evaluator_slug = "toxicity-detector".to_string(); + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input").await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_traceloop_client_sends_correct_headers() { + let mock_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::header("Authorization", "Bearer test-api-key")) + .and(matchers::header("Content-Type", "application/json")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&mock_server) + .await; + + let guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input").await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_traceloop_client_sends_correct_body() { + let mock_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::body_json(json!({ + "input": {"text": "test 
input text"}, + "config": {"threshold": 0.5} + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&mock_server) + .await; + + let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.evaluator_slug = "toxicity-detector".to_string(); + guard.params.insert("threshold".to_string(), json!(0.5)); + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input text").await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_traceloop_client_handles_successful_response() { + let mock_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"score": 0.9, "label": "safe"}, + "pass": true + }))) + .mount(&mock_server) + .await; + + let guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "safe input").await.unwrap(); + assert!(result.pass); + assert_eq!(result.result["score"], 0.9); +} + +#[tokio::test] +async fn test_traceloop_client_handles_error_response() { + let mock_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) + .mount(&mock_server) + .await; + + let guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test").await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_traceloop_client_handles_timeout() { + let mock_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {}, "pass": true})) + .set_delay(std::time::Duration::from_secs(30)), + ) + .mount(&mock_server) + .await; + + let guard = 
create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + let client = TraceloopClient::with_timeout(std::time::Duration::from_millis(100)); + let result = client.evaluate(&guard, "test").await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_traceloop_client_rejects_missing_api_key() { + let mock_server = MockServer::start().await; + let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.api_key = None; + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input").await; + + assert!(result.is_err()); + if let Err(e) = result { + assert!( + e.to_string().contains("API key is required"), + "Expected error message about missing API key, got: {}", + e + ); + } +} + +#[tokio::test] +async fn test_traceloop_client_rejects_empty_api_key() { + let mock_server = MockServer::start().await; + let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.api_key = Some("".to_string()); + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input").await; + + assert!(result.is_err()); + if let Err(e) = result { + assert!( + e.to_string().contains("API key is required"), + "Expected error message about missing API key, got: {}", + e + ); + } +} diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs new file mode 100644 index 00000000..f701358c --- /dev/null +++ b/tests/guardrails/test_types.rs @@ -0,0 +1,338 @@ +use hub_lib::guardrails::types::*; +use hub_lib::types::GatewayConfig; +use std::io::Write; +use tempfile::NamedTempFile; + +#[test] +fn test_guard_mode_deserialize_pre_call() { + let mode: GuardMode = serde_json::from_str("\"pre_call\"").unwrap(); + assert_eq!(mode, GuardMode::PreCall); +} + +#[test] +fn test_guard_mode_deserialize_post_call() { + let mode: GuardMode = serde_json::from_str("\"post_call\"").unwrap(); + assert_eq!(mode, 
GuardMode::PostCall); +} + +#[test] +fn test_on_failure_defaults_to_warn() { + let json = serde_json::json!({ + "name": "test-guard", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "mode": "pre_call" + }); + let guard: Guard = serde_json::from_value(json).unwrap(); + assert_eq!(guard.on_failure, OnFailure::Warn); +} + +#[test] +fn test_required_defaults_to_false() { + let json = serde_json::json!({ + "name": "test-guard", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "mode": "pre_call" + }); + let guard: Guard = serde_json::from_value(json).unwrap(); + assert!(!guard.required); +} + +#[test] +fn test_guard_config_full_deserialization() { + let json = serde_json::json!({ + "name": "toxicity-check", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "params": { + "threshold": 0.5 + }, + "mode": "pre_call", + "on_failure": "block", + "required": false, + "api_base": "https://api.traceloop.com", + "api_key": "tl-key-123" + }); + let guard: Guard = serde_json::from_value(json).unwrap(); + assert_eq!(guard.name, "toxicity-check"); + assert_eq!(guard.provider, "traceloop"); + assert_eq!(guard.evaluator_slug, "toxicity"); + assert_eq!(guard.params.get("threshold").unwrap(), 0.5); + assert_eq!(guard.mode, GuardMode::PreCall); + assert_eq!(guard.on_failure, OnFailure::Block); + assert!(!guard.required); + assert_eq!(guard.api_base.unwrap(), "https://api.traceloop.com"); + assert_eq!(guard.api_key.unwrap(), "tl-key-123"); +} + +#[test] +fn test_guardrails_config_yaml_deserialization() { + let yaml = r#" +guards: + - name: toxicity-check + provider: traceloop + evaluator_slug: toxicity + mode: pre_call + on_failure: block + required: true + api_base: "https://api.traceloop.com" + api_key: "test-key" + - name: relevance-check + provider: traceloop + evaluator_slug: relevance + mode: post_call + on_failure: warn + api_base: "https://api.traceloop.com" + api_key: "test-key" +"#; + let config: GuardrailsConfig = 
serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.guards.len(), 2); + assert_eq!(config.guards[0].name, "toxicity-check"); + assert_eq!(config.guards[0].evaluator_slug, "toxicity"); + assert_eq!(config.guards[0].mode, GuardMode::PreCall); + assert_eq!(config.guards[1].name, "relevance-check"); + assert_eq!(config.guards[1].evaluator_slug, "relevance"); + assert_eq!(config.guards[1].mode, GuardMode::PostCall); + assert_eq!(config.guards[1].on_failure, OnFailure::Warn); +} + +#[test] +fn test_gateway_config_with_guardrails() { + use super::helpers::create_test_guard; + let config = GatewayConfig { + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + guardrails: Some(GuardrailsConfig { + providers: Default::default(), + guards: vec![create_test_guard("test", GuardMode::PreCall)], + }), + }; + assert!(config.guardrails.is_some()); + assert_eq!(config.guardrails.unwrap().guards.len(), 1); +} + +#[test] +fn test_gateway_config_without_guardrails_backward_compat() { + let json = serde_json::json!({ + "providers": [], + "models": [], + "pipelines": [] + }); + let config: GatewayConfig = serde_json::from_value(json).unwrap(); + assert!(config.guardrails.is_none()); +} + +#[test] +fn test_guard_config_env_var_in_api_key() { + let config_content = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +guardrails: + guards: + - name: toxicity-check + provider: traceloop + evaluator_slug: toxicity + mode: pre_call + api_base: "https://api.traceloop.com" + api_key: "${TEST_GUARD_API_KEY_UNIQUE}" +"#; + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_content.as_bytes()).unwrap(); + let temp_path = temp_file.path().to_str().unwrap().to_owned(); + temp_env::with_var("TEST_GUARD_API_KEY_UNIQUE", Some("tl-secret-key"), || { + let config = 
hub_lib::config::load_config(&temp_path).unwrap(); + let guards = config.guardrails.unwrap().guards; + assert_eq!(guards[0].api_key.as_deref(), Some("tl-secret-key")); + }); +} + +// --------------------------------------------------------------------------- +// Provider config tests +// --------------------------------------------------------------------------- + +#[test] +fn test_provider_config_deserialization() { + let json = serde_json::json!({ + "name": "traceloop", + "api_base": "https://api.traceloop.com", + "api_key": "tl-key-123" + }); + let provider: ProviderConfig = serde_json::from_value(json).unwrap(); + assert_eq!(provider.name, "traceloop"); + assert_eq!(provider.api_base, "https://api.traceloop.com"); + assert_eq!(provider.api_key, "tl-key-123"); +} + +#[test] +fn test_guardrails_config_with_providers_yaml() { + let yaml = r#" +providers: + - name: traceloop + api_base: "https://api.traceloop.com" + api_key: "tl-key-123" +guards: + - name: toxicity-check + provider: traceloop + evaluator_slug: toxicity + mode: pre_call + on_failure: block + - name: pii-check + provider: traceloop + evaluator_slug: pii-detection + mode: pre_call + on_failure: block + api_base: "https://custom.traceloop.com" + api_key: "custom-key" +"#; + let config: GuardrailsConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.providers.len(), 1); + assert_eq!(config.providers["traceloop"].name, "traceloop"); + assert_eq!( + config.providers["traceloop"].api_base, + "https://api.traceloop.com" + ); + assert_eq!(config.guards.len(), 2); + // First guard has no api_base/api_key (inherits from provider) + assert!(config.guards[0].api_base.is_none()); + assert!(config.guards[0].api_key.is_none()); + // Second guard overrides api_base/api_key + assert_eq!( + config.guards[1].api_base.as_deref(), + Some("https://custom.traceloop.com") + ); + assert_eq!(config.guards[1].api_key.as_deref(), Some("custom-key")); +} + +#[test] +fn test_guard_without_api_base_deserializes() { + 
let json = serde_json::json!({ + "name": "toxicity-check", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "mode": "pre_call" + }); + let guard: Guard = serde_json::from_value(json).unwrap(); + assert!(guard.api_base.is_none()); + assert!(guard.api_key.is_none()); +} + +#[test] +fn test_guard_config_evaluator_slug_not_in_params() { + let json = serde_json::json!({ + "name": "toxicity-check", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "params": {"threshold": 0.5}, + "mode": "pre_call" + }); + let guard: Guard = serde_json::from_value(json).unwrap(); + assert_eq!(guard.evaluator_slug, "toxicity"); + assert!(!guard.params.contains_key("evaluator_slug")); + assert_eq!(guard.params.get("threshold").unwrap(), 0.5); +} + +// --------------------------------------------------------------------------- +// Pipeline config YAML parsing +// --------------------------------------------------------------------------- + +#[test] +fn test_pipeline_guards_field_in_yaml_config() { + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + guards: + - pii-check + - injection-check + plugins: + - model-router: + models: + - gpt-4 + - name: embeddings + type: embeddings + plugins: + - model-router: + models: + - gpt-4 +guardrails: + providers: + - name: traceloop + api_base: "https://api.traceloop.com" + api_key: "test-key" + guards: + - name: pii-check + provider: traceloop + evaluator_slug: pii + mode: pre_call + on_failure: block + - name: injection-check + provider: traceloop + evaluator_slug: injection + mode: pre_call + on_failure: block +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + + // Default pipeline should have guards + assert_eq!( + 
config.pipelines[0].guards, + vec!["pii-check", "injection-check"] + ); + // Embeddings pipeline should have no guards + assert!(config.pipelines[1].guards.is_empty()); +} + +#[test] +fn test_pipeline_guards_field_absent_defaults_to_empty() { + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + assert!(config.pipelines[0].guards.is_empty()); +} diff --git a/tests/pipeline_header_routing_test.rs b/tests/pipeline_header_routing_test.rs index e015138d..2e98cd32 100644 --- a/tests/pipeline_header_routing_test.rs +++ b/tests/pipeline_header_routing_test.rs @@ -25,6 +25,7 @@ fn create_test_config_with_multiple_pipelines() -> GatewayConfig { plugins: vec![PluginConfig::ModelRouter { models: vec!["test-model".to_string()], }], + guards: vec![], }; let pipeline2 = Pipeline { @@ -33,10 +34,12 @@ fn create_test_config_with_multiple_pipelines() -> GatewayConfig { plugins: vec![PluginConfig::ModelRouter { models: vec!["test-model".to_string()], }], + guards: vec![], }; GatewayConfig { general: None, + guardrails: None, providers: vec![provider], models: vec![model], pipelines: vec![pipeline1, pipeline2], @@ -75,6 +78,7 @@ async fn test_pipeline_header_routing_configuration_updates() { plugins: vec![PluginConfig::ModelRouter { models: vec!["test-model".to_string()], }], + guards: vec![], }; updated_config.pipelines.push(pipeline3); diff --git a/tests/router_cache_tests.rs b/tests/router_cache_tests.rs index ecec5b40..4cef3445 100644 --- a/tests/router_cache_tests.rs +++ b/tests/router_cache_tests.rs @@ -9,6 +9,7 @@ async fn test_router_always_available() { // Create a basic configuration let 
config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -27,6 +28,7 @@ async fn test_router_always_available() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -43,6 +45,7 @@ async fn test_configuration_change_detection() { // Create initial configuration let initial_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -61,6 +64,7 @@ async fn test_configuration_change_detection() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -87,6 +91,7 @@ async fn test_configuration_change_detection() { async fn test_invalid_configuration_rejected() { let initial_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -102,6 +107,7 @@ async fn test_invalid_configuration_rejected() { // Create invalid configuration (model references non-existent provider) let invalid_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -130,6 +136,7 @@ async fn test_invalid_configuration_rejected() { async fn test_concurrent_router_access() { let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -148,6 +155,7 @@ async fn test_concurrent_router_access() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -181,6 +189,7 @@ async fn test_empty_configuration_fallback() { providers: vec![], models: vec![], pipelines: vec![], + guardrails: None, }; let app_state = Arc::new(AppState::new(empty_config).expect("Failed to create app 
state")); @@ -195,6 +204,7 @@ async fn test_pipeline_with_failing_tracing_endpoint() { // Create configuration with a pipeline that has a failing tracing endpoint let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -219,6 +229,7 @@ async fn test_pipeline_with_failing_tracing_endpoint() { models: vec!["gpt-4".to_string()], }, ], + guards: vec![], }], }; @@ -250,6 +261,7 @@ async fn test_tracing_isolation_between_pipelines() { // Create configuration with two pipelines - one with tracing, one without let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -276,6 +288,7 @@ async fn test_tracing_isolation_between_pipelines() { models: vec!["gpt-4".to_string()], }, ], + guards: vec![], }, // Pipeline without tracing Pipeline { @@ -284,6 +297,7 @@ async fn test_tracing_isolation_between_pipelines() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }, ], }; diff --git a/tests/router_integration_test.rs b/tests/router_integration_test.rs index 3560d8af..5054f508 100644 --- a/tests/router_integration_test.rs +++ b/tests/router_integration_test.rs @@ -12,6 +12,7 @@ async fn test_router_integration_flow() { providers: vec![], models: vec![], pipelines: vec![], + guardrails: None, }; let app_state = Arc::new(AppState::new(empty_config).expect("Failed to create app state")); @@ -22,6 +23,7 @@ async fn test_router_integration_flow() { // Test 2: Valid configuration let valid_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -40,6 +42,7 @@ async fn test_router_integration_flow() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -63,6 +66,7 @@ async fn 
test_router_integration_flow() { // Test 6: Invalid configuration rejection let invalid_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -81,6 +85,7 @@ async fn test_router_integration_flow() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -95,6 +100,7 @@ async fn test_router_integration_flow() { // Test 7: Multiple pipeline configuration let multi_pipeline_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -122,6 +128,7 @@ async fn test_router_integration_flow() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }, Pipeline { name: "fast".to_string(), @@ -129,6 +136,7 @@ async fn test_router_integration_flow() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-3.5-turbo".to_string()], }], + guards: vec![], }, ], }; @@ -158,6 +166,7 @@ async fn test_router_integration_flow() { async fn test_concurrent_configuration_updates() { let initial_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -176,6 +185,7 @@ async fn test_concurrent_configuration_updates() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -189,6 +199,7 @@ async fn test_concurrent_configuration_updates() { // Create a slightly different configuration for each task let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: format!("provider-{}", i), r#type: ProviderType::OpenAI, @@ -207,6 +218,7 @@ async fn test_concurrent_configuration_updates() { plugins: vec![PluginConfig::ModelRouter { models: vec![format!("model-{}", i)], }], + guards: vec![], }], }; diff --git 
a/tests/unified_openapi_test.rs b/tests/unified_openapi_test.rs index 6b298270..400b5b23 100644 --- a/tests/unified_openapi_test.rs +++ b/tests/unified_openapi_test.rs @@ -125,6 +125,7 @@ async fn test_router_creation_no_conflicts() { // Create a basic configuration for testing let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -143,6 +144,7 @@ async fn test_router_creation_no_conflicts() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], };