From 18ae3a77c395c624adadc7d337aff4627aa40e94 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:22:57 +0200 Subject: [PATCH 01/59] tests added --- .gitignore | 4 + Cargo.lock | 2 +- src/config/lib.rs | 4 + src/config/validation.rs | 3 + src/guardrails/api_control.rs | 28 +++ src/guardrails/executor.rs | 12 ++ src/guardrails/input_extractor.rs | 13 ++ src/guardrails/mod.rs | 7 + src/guardrails/providers/mod.rs | 21 ++ src/guardrails/providers/traceloop.rs | 36 ++++ src/guardrails/response_parser.rs | 14 ++ src/guardrails/stream_buffer.rs | 7 + src/guardrails/types.rs | 140 ++++++++++++ src/lib.rs | 1 + src/types/mod.rs | 2 + tests/config_hash_tests.rs | 5 + tests/guardrails/helpers.rs | 230 ++++++++++++++++++++ tests/guardrails/main.rs | 10 + tests/guardrails/test_api_control.rs | 174 +++++++++++++++ tests/guardrails/test_e2e.rs | 100 +++++++++ tests/guardrails/test_executor.rs | 183 ++++++++++++++++ tests/guardrails/test_input_extractor.rs | 76 +++++++ tests/guardrails/test_pipeline.rs | 52 +++++ tests/guardrails/test_response_parser.rs | 69 ++++++ tests/guardrails/test_stream_buffer.rs | 19 ++ tests/guardrails/test_traceloop_client.rs | 125 +++++++++++ tests/guardrails/test_types.rs | 252 ++++++++++++++++++++++ tests/pipeline_header_routing_test.rs | 1 + tests/router_cache_tests.rs | 8 + tests/router_integration_test.rs | 6 + tests/unified_openapi_test.rs | 1 + 31 files changed, 1604 insertions(+), 1 deletion(-) create mode 100644 src/guardrails/api_control.rs create mode 100644 src/guardrails/executor.rs create mode 100644 src/guardrails/input_extractor.rs create mode 100644 src/guardrails/mod.rs create mode 100644 src/guardrails/providers/mod.rs create mode 100644 src/guardrails/providers/traceloop.rs create mode 100644 src/guardrails/response_parser.rs create mode 100644 src/guardrails/stream_buffer.rs create mode 100644 src/guardrails/types.rs create mode 100644 tests/guardrails/helpers.rs create mode 100644 tests/guardrails/main.rs create mode 100644 tests/guardrails/test_api_control.rs create mode 100644 tests/guardrails/test_e2e.rs create mode 100644 tests/guardrails/test_executor.rs create mode 100644 tests/guardrails/test_input_extractor.rs create mode 100644 tests/guardrails/test_pipeline.rs create mode 100644 tests/guardrails/test_response_parser.rs create mode 100644 tests/guardrails/test_stream_buffer.rs create mode 100644 tests/guardrails/test_traceloop_client.rs create mode 100644 tests/guardrails/test_types.rs diff --git a/.gitignore b/.gitignore index ca8cff2f..2240f792 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,10 @@ config.yaml .vscode/ .DS_Store +# Documentation +docs/ +.claude/ + prd/ memory-bank/ .cursor/ diff --git a/Cargo.lock b/Cargo.lock index d1058810..f2f8d530 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2269,7 +2269,7 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hub" -version = "0.7.5" +version = "0.7.6" dependencies = [ "anyhow", "async-stream", diff --git a/src/config/lib.rs b/src/config/lib.rs index abec66f4..e6be59fd 100644 --- a/src/config/lib.rs +++ b/src/config/lib.rs @@ -1,3 +1,4 @@ +use crate::guardrails::types::GuardrailsConfig; use crate::types::{GatewayConfig, ModelConfig, Pipeline, PipelineType, PluginConfig, Provider}; use serde::Deserialize; use std::sync::OnceLock; @@ -31,6 +32,8 @@ struct YamlRoot { models: Vec, #[serde(default)] pipelines: Vec, + #[serde(default)] + guardrails: Option, } fn substitute_env_vars(content: &str) -> Result> { @@ -88,6 +91,7 @@ pub fn load_config(path: &str) -> Result Vec { + todo!("Implement header parsing") +} + +/// Parse guard names from the request payload's `guardrails` field. +pub fn parse_guardrails_from_payload(_payload: &serde_json::Value) -> Vec { + todo!("Implement payload guardrails parsing") +} + +/// Resolve the final set of guards to execute by merging pipeline, header, and payload sources. +/// Guards are additive and deduplicated by name. +pub fn resolve_guards_by_name<'a>( + _all_guards: &'a [GuardConfig], + _pipeline_names: &[&str], + _header_names: &[&str], + _payload_names: &[&str], +) -> Vec { + todo!("Implement additive guard resolution") +} + +/// Split guards into (pre_call, post_call) lists by mode. +pub fn split_guards_by_mode(_guards: &[GuardConfig]) -> (Vec, Vec) { + todo!("Implement guard splitting by mode") +} diff --git a/src/guardrails/executor.rs b/src/guardrails/executor.rs new file mode 100644 index 00000000..0b4c04f1 --- /dev/null +++ b/src/guardrails/executor.rs @@ -0,0 +1,12 @@ +use super::providers::GuardrailClient; +use super::types::{GuardConfig, GuardrailsOutcome}; + +/// Execute a set of guardrails against the given input text. +/// Returns a GuardrailsOutcome with results, blocked status, and warnings. +pub async fn execute_guards( + _guards: &[GuardConfig], + _input: &str, + _client: &dyn GuardrailClient, +) -> GuardrailsOutcome { + todo!("Implement guard execution orchestration") +} diff --git a/src/guardrails/input_extractor.rs b/src/guardrails/input_extractor.rs new file mode 100644 index 00000000..c6271c47 --- /dev/null +++ b/src/guardrails/input_extractor.rs @@ -0,0 +1,13 @@ +use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; + +/// Extract text from the request for pre_call guardrails. +/// Returns the content of the last user message. +pub fn extract_pre_call_input(_request: &ChatCompletionRequest) -> String { + todo!("Implement pre_call input extraction") +} + +/// Extract text from a non-streaming ChatCompletion for post_call guardrails. +/// Returns the content of the first assistant choice. +pub fn extract_post_call_input_from_completion(_completion: &ChatCompletion) -> String { + todo!("Implement post_call input extraction from completion") +} diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs new file mode 100644 index 00000000..81ae9edf --- /dev/null +++ b/src/guardrails/mod.rs @@ -0,0 +1,7 @@ +pub mod api_control; +pub mod executor; +pub mod input_extractor; +pub mod providers; +pub mod response_parser; +pub mod stream_buffer; +pub mod types; diff --git a/src/guardrails/providers/mod.rs b/src/guardrails/providers/mod.rs new file mode 100644 index 00000000..faa46196 --- /dev/null +++ b/src/guardrails/providers/mod.rs @@ -0,0 +1,21 @@ +pub mod traceloop; + +use async_trait::async_trait; + +use super::types::{EvaluatorResponse, GuardConfig, GuardrailError}; + +/// Trait for guardrail evaluator clients. +/// Each provider (traceloop, etc.) implements this to call its evaluator API. +#[async_trait] +pub trait GuardrailClient: Send + Sync { + async fn evaluate( + &self, + guard: &GuardConfig, + input: &str, + ) -> Result; +} + +/// Create a guardrail client based on the guard's provider type. +pub fn create_guardrail_client(_guard: &GuardConfig) -> Option> { + todo!("Implement client factory based on guard.provider") +} diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs new file mode 100644 index 00000000..30b43800 --- /dev/null +++ b/src/guardrails/providers/traceloop.rs @@ -0,0 +1,36 @@ +use async_trait::async_trait; +use std::time::Duration; + +use super::GuardrailClient; +use crate::guardrails::types::{EvaluatorResponse, GuardConfig, GuardrailError}; + +/// HTTP client for the Traceloop evaluator API service. +/// Calls `POST {api_base}/v2/guardrails/{evaluator_slug}`. +pub struct TraceloopClient { + http_client: reqwest::Client, +} + +impl TraceloopClient { + pub fn new() -> Self { + Self { + http_client: reqwest::Client::new(), + } + } + + pub fn with_timeout(timeout: Duration) -> Self { + Self { + http_client: reqwest::Client::builder().timeout(timeout).build().unwrap(), + } + } +} + +#[async_trait] +impl GuardrailClient for TraceloopClient { + async fn evaluate( + &self, + _guard: &GuardConfig, + _input: &str, + ) -> Result { + todo!("Implement Traceloop evaluator API call") + } +} diff --git a/src/guardrails/response_parser.rs b/src/guardrails/response_parser.rs new file mode 100644 index 00000000..eb39cb67 --- /dev/null +++ b/src/guardrails/response_parser.rs @@ -0,0 +1,14 @@ +use super::types::{EvaluatorResponse, GuardrailError}; + +/// Parse the evaluator response body (JSON string) into an EvaluatorResponse. +pub fn parse_evaluator_response(_body: &str) -> Result { + todo!("Implement evaluator response parsing") +} + +/// Parse an HTTP response from the evaluator, handling non-200 status codes. +pub fn parse_evaluator_http_response( + _status: u16, + _body: &str, +) -> Result { + todo!("Implement HTTP response parsing") +} diff --git a/src/guardrails/stream_buffer.rs b/src/guardrails/stream_buffer.rs new file mode 100644 index 00000000..526af7f5 --- /dev/null +++ b/src/guardrails/stream_buffer.rs @@ -0,0 +1,7 @@ +use crate::models::streaming::ChatCompletionChunk; + +/// Extract and concatenate text from accumulated streaming chunks. +/// Joins the delta content from all chunks into a single string. +pub fn extract_text_from_chunks(_chunks: &[ChatCompletionChunk]) -> String { + todo!("Implement text extraction from streaming chunks") +} diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs new file mode 100644 index 00000000..560ae633 --- /dev/null +++ b/src/guardrails/types.rs @@ -0,0 +1,140 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; + +fn default_on_failure() -> OnFailure { + OnFailure::Warn +} + +fn default_required() -> bool { + true +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +pub enum GuardMode { + PreCall, + PostCall, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +pub enum OnFailure { + Block, + Warn, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +pub struct ProviderConfig { + pub name: String, + pub api_base: String, + pub api_key: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct GuardConfig { + pub name: String, + pub provider: String, + pub evaluator_slug: String, + #[serde(default)] + pub params: HashMap, + pub mode: GuardMode, + #[serde(default = "default_on_failure")] + pub on_failure: OnFailure, + #[serde(default = "default_required")] + pub required: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_base: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_key: Option, +} + +impl Hash for GuardConfig { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.provider.hash(state); + self.evaluator_slug.hash(state); + // Hash params by sorting keys and hashing serialized values + let mut params_vec: Vec<_> = self.params.iter().collect(); + params_vec.sort_by_key(|(k, _)| (*k).clone()); + for (k, v) in params_vec { + k.hash(state); + v.to_string().hash(state); + } + self.mode.hash(state); + self.on_failure.hash(state); + self.required.hash(state); + self.api_base.hash(state); + self.api_key.hash(state); + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)] +pub struct GuardrailsConfig { + #[serde(default)] + pub providers: Vec, + #[serde(default)] + pub guards: Vec, +} + +impl Hash for GuardrailsConfig { + fn hash(&self, state: &mut H) { + self.providers.hash(state); + self.guards.hash(state); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EvaluatorResponse { + pub result: serde_json::Value, + pub pass: bool, +} + +#[derive(Debug, Clone)] +pub enum GuardResult { + Passed { + name: String, + result: serde_json::Value, + }, + Failed { + name: String, + result: serde_json::Value, + on_failure: OnFailure, + }, + Error { + name: String, + error: String, + required: bool, + }, +} + +#[derive(Debug, Clone)] +pub struct GuardrailsOutcome { + pub results: Vec, + pub blocked: bool, + pub blocking_guard: Option, + pub warnings: Vec, +} + +#[derive(Debug, Clone)] +pub enum GuardrailError { + Unavailable(String), + HttpError { status: u16, body: String }, + Timeout(String), + ParseError(String), +} + +impl std::fmt::Display for GuardrailError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GuardrailError::Unavailable(msg) => write!(f, "Evaluator unavailable: {msg}"), + GuardrailError::HttpError { status, body } => { + write!(f, "HTTP error {status}: {body}") + } + GuardrailError::Timeout(msg) => write!(f, "Timeout: {msg}"), + GuardrailError::ParseError(msg) => write!(f, "Parse error: {msg}"), + } + } +} + +impl std::error::Error for GuardrailError {} diff --git a/src/lib.rs b/src/lib.rs index 73b93200..1663bf71 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod ai_models; pub mod config; +pub mod guardrails; pub mod management; pub mod models; pub mod openapi; diff --git a/src/types/mod.rs b/src/types/mod.rs index 113152dc..b2e3d63c 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -169,4 +169,6 @@ pub struct GatewayConfig { pub models: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub pipelines: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub guardrails: Option, } diff --git a/tests/config_hash_tests.rs b/tests/config_hash_tests.rs index 53ff4597..d2994d2e 100644 --- a/tests/config_hash_tests.rs +++ b/tests/config_hash_tests.rs @@ -6,6 +6,7 @@ use std::collections::HashMap; fn test_identical_configs_have_same_hash() { let config1 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test".to_string(), r#type: ProviderType::OpenAI, @@ -29,6 +30,7 @@ fn test_identical_configs_have_same_hash() { fn test_different_configs_have_different_hashes() { let config1 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test1".to_string(), r#type: ProviderType::OpenAI, @@ -41,6 +43,7 @@ fn test_different_configs_have_different_hashes() { let config2 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test2".to_string(), // Different key r#type: ProviderType::OpenAI, @@ -70,6 +73,7 @@ fn test_params_order_independence() { let config1 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test".to_string(), r#type: ProviderType::OpenAI, @@ -82,6 +86,7 @@ fn test_params_order_independence() { let config2 = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test".to_string(), r#type: ProviderType::OpenAI, diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs new file mode 100644 index 00000000..2e989dd8 --- /dev/null +++ b/tests/guardrails/helpers.rs @@ -0,0 +1,230 @@ +use std::collections::HashMap; + +use async_trait::async_trait; +use hub_lib::guardrails::providers::GuardrailClient; +use hub_lib::guardrails::types::{ + EvaluatorResponse, GuardConfig, GuardMode, GuardrailError, OnFailure, +}; +use hub_lib::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; +use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent}; +use hub_lib::models::streaming::{ChatCompletionChunk, Choice, ChoiceDelta}; +use hub_lib::models::usage::Usage; +use serde_json::json; + +// --------------------------------------------------------------------------- +// Guard config builders +// --------------------------------------------------------------------------- + +pub fn create_test_guard(name: &str, mode: GuardMode) -> GuardConfig { + GuardConfig { + name: name.to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "test-evaluator".to_string(), + params: HashMap::new(), + mode, + on_failure: OnFailure::Block, + required: true, + api_base: Some("http://localhost:8080".to_string()), + api_key: Some("test-api-key".to_string()), + } +} + +pub fn create_test_guard_with_failure_action( + name: &str, + mode: GuardMode, + on_failure: OnFailure, +) -> GuardConfig { + let mut guard = create_test_guard(name, mode); + guard.on_failure = on_failure; + guard +} + +pub fn create_test_guard_with_required(name: &str, mode: GuardMode, required: bool) -> GuardConfig { + let mut guard = create_test_guard(name, mode); + guard.required = required; + guard +} + +#[allow(dead_code)] +pub fn create_test_guard_with_api_base(name: &str, mode: GuardMode, api_base: &str) -> GuardConfig { + let mut guard = create_test_guard(name, mode); + guard.api_base = Some(api_base.to_string()); + guard +} + +// --------------------------------------------------------------------------- +// Evaluator response builders +// --------------------------------------------------------------------------- + +pub fn passing_response() -> EvaluatorResponse { + EvaluatorResponse { + result: json!({"score": 0.95, "label": "safe"}), + pass: true, + } +} + +pub fn failing_response() -> EvaluatorResponse { + EvaluatorResponse { + result: json!({"score": 0.2, "label": "unsafe", "reason": "Content violates policy"}), + pass: false, + } +} + +// --------------------------------------------------------------------------- +// Chat request/response builders +// --------------------------------------------------------------------------- + +pub fn default_message() -> ChatCompletionMessage { + ChatCompletionMessage { + role: String::new(), + content: None, + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + } +} + +#[allow(dead_code)] +pub fn default_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + parallel_tool_calls: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + tool_choice: None, + tools: None, + user: None, + logprobs: None, + top_logprobs: None, + response_format: None, + reasoning: None, + } +} + +pub fn create_test_chat_request(user_message: &str) -> ChatCompletionRequest { + ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String(user_message.to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + parallel_tool_calls: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + tool_choice: None, + tools: None, + user: None, + logprobs: None, + top_logprobs: None, + response_format: None, + reasoning: None, + } +} + +pub fn create_test_chat_completion(response_text: &str) -> ChatCompletion { + ChatCompletion { + id: "chatcmpl-test".to_string(), + object: Some("chat.completion".to_string()), + created: Some(1234567890), + model: "gpt-4".to_string(), + choices: vec![ChatCompletionChoice { + index: 0, + message: ChatCompletionMessage { + role: "assistant".to_string(), + content: Some(ChatMessageContent::String(response_text.to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }, + finish_reason: Some("stop".to_string()), + logprobs: None, + }], + usage: Usage::default(), + system_fingerprint: None, + } +} + +pub fn create_test_chunk(content: &str) -> ChatCompletionChunk { + ChatCompletionChunk { + id: "chunk-1".to_string(), + choices: vec![Choice { + delta: ChoiceDelta { + content: Some(content.to_string()), + role: None, + tool_calls: None, + reasoning: None, + }, + finish_reason: None, + index: 0, + logprobs: None, + }], + created: 1234567890, + model: "gpt-4".to_string(), + service_tier: None, + system_fingerprint: None, + usage: None, + } +} + +// --------------------------------------------------------------------------- +// Mock GuardrailClient +// --------------------------------------------------------------------------- + +pub struct MockGuardrailClient { + pub responses: HashMap>, +} + +impl MockGuardrailClient { + pub fn with_response(name: &str, resp: Result) -> Self { + let mut responses = HashMap::new(); + responses.insert(name.to_string(), resp); + Self { responses } + } + + pub fn with_responses(entries: Vec<(&str, Result)>) -> Self { + let mut responses = HashMap::new(); + for (name, resp) in entries { + responses.insert(name.to_string(), resp); + } + Self { responses } + } +} + +#[async_trait] +impl GuardrailClient for MockGuardrailClient { + async fn evaluate( + &self, + guard: &GuardConfig, + _input: &str, + ) -> Result { + self.responses + .get(&guard.name) + .cloned() + .unwrap_or(Err(GuardrailError::Unavailable( + "no mock configured".to_string(), + ))) + } +} diff --git a/tests/guardrails/main.rs b/tests/guardrails/main.rs new file mode 100644 index 00000000..57873d30 --- /dev/null +++ b/tests/guardrails/main.rs @@ -0,0 +1,10 @@ +mod helpers; +mod test_api_control; +mod test_e2e; +mod test_executor; +mod test_input_extractor; +mod test_pipeline; +mod test_response_parser; +mod test_stream_buffer; +mod test_traceloop_client; +mod test_types; diff --git a/tests/guardrails/test_api_control.rs b/tests/guardrails/test_api_control.rs new file mode 100644 index 00000000..6f16fa01 --- /dev/null +++ b/tests/guardrails/test_api_control.rs @@ -0,0 +1,174 @@ +use hub_lib::guardrails::api_control::*; +use hub_lib::guardrails::types::GuardMode; +use serde_json::json; + +use super::helpers::*; + +// --------------------------------------------------------------------------- +// Phase 7: API Control (15 tests) +// --------------------------------------------------------------------------- + +#[test] +fn test_parse_guardrails_header_single() { + let names = parse_guardrails_header("pii-check"); + assert_eq!(names, vec!["pii-check"]); +} + +#[test] +fn test_parse_guardrails_header_multiple() { + let names = parse_guardrails_header("toxicity-check,relevance-check,pii-check"); + assert_eq!( + names, + vec!["toxicity-check", "relevance-check", "pii-check"] + ); +} + +#[test] +fn test_parse_guardrails_from_payload() { + let payload = json!({"guardrails": ["toxicity-check", "pii-check"]}); + let names = parse_guardrails_from_payload(&payload); + assert_eq!(names, vec!["toxicity-check", "pii-check"]); +} + +#[test] +fn test_pipeline_guardrails_always_included() { + let pipeline_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name(&pipeline_guards, &["pipeline-guard"], &[], &[]); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "pipeline-guard"); +} + +#[test] +fn test_header_guardrails_additive_to_pipeline() { + let all_guards = vec![ + create_test_guard("pipeline-guard", GuardMode::PreCall), + create_test_guard("header-guard", GuardMode::PreCall), + ]; + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["header-guard"], &[]); + assert_eq!(resolved.len(), 2); +} + +#[test] +fn test_payload_guardrails_additive_to_pipeline() { + let all_guards = vec![ + create_test_guard("pipeline-guard", GuardMode::PreCall), + create_test_guard("payload-guard", GuardMode::PreCall), + ]; + let resolved = + resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[], &["payload-guard"]); + assert_eq!(resolved.len(), 2); +} + +#[test] +fn test_header_and_payload_both_additive() { + let all_guards = vec![ + create_test_guard("pipeline-guard", GuardMode::PreCall), + create_test_guard("header-guard", GuardMode::PreCall), + create_test_guard("payload-guard", GuardMode::PreCall), + ]; + let resolved = resolve_guards_by_name( + &all_guards, + &["pipeline-guard"], + &["header-guard"], + &["payload-guard"], + ); + assert_eq!(resolved.len(), 3); +} + +#[test] +fn test_deduplication_by_name() { + let all_guards = vec![create_test_guard("shared-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name( + &all_guards, + &["shared-guard"], + &["shared-guard"], // duplicate + &["shared-guard"], // duplicate + ); + assert_eq!(resolved.len(), 1); +} + +#[test] +fn test_unknown_guard_name_in_header_ignored() { + let all_guards = vec![create_test_guard("known-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name( + &all_guards, + &["known-guard"], + &["nonexistent-guard"], // unknown + &[], + ); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "known-guard"); +} + +#[test] +fn test_unknown_guard_name_in_payload_ignored() { + let all_guards = vec![create_test_guard("known-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name( + &all_guards, + &["known-guard"], + &[], + &["nonexistent-guard"], // unknown + ); + assert_eq!(resolved.len(), 1); +} + +#[test] +fn test_empty_header_pipeline_guards_only() { + let all_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[], &[]); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "pipeline-guard"); +} + +#[test] +fn test_empty_payload_pipeline_guards_only() { + let all_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[], &[]); + assert_eq!(resolved.len(), 1); +} + +#[test] +fn test_cannot_remove_pipeline_guardrails_via_api() { + let all_guards = vec![ + create_test_guard("pipeline-guard", GuardMode::PreCall), + create_test_guard("extra", GuardMode::PreCall), + ]; + // Even if header/payload only mention "extra", pipeline guard still included + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["extra"], &[]); + assert!(resolved.iter().any(|g| g.name == "pipeline-guard")); +} + +#[test] +fn test_guards_split_into_pre_and_post_call() { + let guards = vec![ + create_test_guard("pre-1", GuardMode::PreCall), + create_test_guard("post-1", GuardMode::PostCall), + create_test_guard("pre-2", GuardMode::PreCall), + create_test_guard("post-2", GuardMode::PostCall), + ]; + let (pre_call, post_call) = split_guards_by_mode(&guards); + assert_eq!(pre_call.len(), 2); + assert_eq!(post_call.len(), 2); + assert!(pre_call.iter().all(|g| g.mode == GuardMode::PreCall)); + assert!(post_call.iter().all(|g| g.mode == GuardMode::PostCall)); +} + +#[test] +fn test_complete_resolution_merged() { + let all_guards = vec![ + create_test_guard("pipeline-pre", GuardMode::PreCall), + create_test_guard("pipeline-post", GuardMode::PostCall), + create_test_guard("header-pre", GuardMode::PreCall), + create_test_guard("payload-post", GuardMode::PostCall), + ]; + let resolved = resolve_guards_by_name( + &all_guards, + &["pipeline-pre", "pipeline-post"], + &["header-pre"], + &["payload-post"], + ); + assert_eq!(resolved.len(), 4); + let (pre, post) = split_guards_by_mode(&resolved); + assert_eq!(pre.len(), 2); // pipeline-pre + header-pre + assert_eq!(post.len(), 2); // pipeline-post + payload-post +} diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs new file mode 100644 index 00000000..50403910 --- /dev/null +++ b/tests/guardrails/test_e2e.rs @@ -0,0 +1,100 @@ +// --------------------------------------------------------------------------- +// Phase 8: End-to-End Integration (15 tests) +// +// Full request flow tests using wiremock for both evaluator and LLM services. +// These validate the complete lifecycle from request to response. +// --------------------------------------------------------------------------- + +// TODO: These tests require full pipeline integration. +// They will be fully implemented when all prior phases are complete. +// For now, we define the test signatures to drive the implementation. + +#[tokio::test] +async fn test_e2e_pre_call_block_flow() { + // Request -> guard fail+block -> 403 + todo!("Implement E2E test: pre_call block flow") +} + +#[tokio::test] +async fn test_e2e_pre_call_pass_flow() { + // Request -> guard pass -> LLM -> 200 + todo!("Implement E2E test: pre_call pass flow") +} + +#[tokio::test] +async fn test_e2e_post_call_block_flow() { + // Request -> LLM -> guard fail+block -> 403 + todo!("Implement E2E test: post_call block flow") +} + +#[tokio::test] +async fn test_e2e_post_call_warn_flow() { + // Request -> LLM -> guard fail+warn -> 200 + header + todo!("Implement E2E test: post_call warn flow") +} + +#[tokio::test] +async fn test_e2e_pre_and_post_both_pass() { + // Both stages pass -> clean 200 response + todo!("Implement E2E test: pre and post both pass") +} + +#[tokio::test] +async fn test_e2e_pre_blocks_post_never_runs() { + // Pre blocks -> post evaluator gets 0 requests + todo!("Implement E2E test: pre blocks, post never runs") +} + +#[tokio::test] +async fn test_e2e_mixed_block_and_warn() { + // Multiple guards with mixed block/warn outcomes + todo!("Implement E2E test: mixed block and warn") +} + +#[tokio::test] +async fn test_e2e_streaming_post_call_buffer_pass() { + // Stream buffered, guard passes -> SSE response streamed to client + todo!("Implement E2E test: streaming post_call buffer pass") +} + +#[tokio::test] +async fn test_e2e_streaming_post_call_buffer_block() { + // Stream buffered, guard blocks -> 403 + todo!("Implement E2E test: streaming post_call buffer block") +} + +#[tokio::test] +async fn test_e2e_config_from_yaml_with_env_vars() { + // Full YAML config with ${VAR} substitution in api_key + todo!("Implement E2E test: config from YAML with env vars") +} + +#[tokio::test] +async fn test_e2e_multiple_guards_different_evaluators() { + // Different evaluator slugs -> separate mock expectations + todo!("Implement E2E test: multiple guards different evaluators") +} + +#[tokio::test] +async fn test_e2e_fail_open_evaluator_down() { + // Evaluator service down + required: false -> passthrough + todo!("Implement E2E test: fail open evaluator down") +} + +#[tokio::test] +async fn test_e2e_fail_closed_evaluator_down() { + // Evaluator service down + required: true -> 403 + todo!("Implement E2E test: fail closed evaluator down") +} + +#[tokio::test] +async fn test_e2e_config_validation_rejects_invalid() { + // Config with missing required fields -> startup validation error + todo!("Implement E2E test: config validation rejects invalid") +} + +#[tokio::test] +async fn test_e2e_backward_compat_no_guardrails() { + // Existing config without guardrails works unchanged + todo!("Implement E2E test: backward compat no guardrails") +} diff --git a/tests/guardrails/test_executor.rs b/tests/guardrails/test_executor.rs new file mode 100644 index 00000000..0e136c7d --- /dev/null +++ b/tests/guardrails/test_executor.rs @@ -0,0 +1,183 @@ +use hub_lib::guardrails::executor::*; +use hub_lib::guardrails::input_extractor::*; +use hub_lib::guardrails::types::*; + +use super::helpers::*; + +// --------------------------------------------------------------------------- +// Phase 5: Executor (12 tests) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_execute_single_pre_call_guard_passes() { + let guard = create_test_guard("check", GuardMode::PreCall); + let mock_client = MockGuardrailClient::with_response("check", Ok(passing_response())); + let outcome = execute_guards(&[guard], "test input", &mock_client).await; + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 1); + assert!(matches!(&outcome.results[0], GuardResult::Passed { .. })); + assert!(outcome.warnings.is_empty()); +} + +#[tokio::test] +async fn test_execute_single_pre_call_guard_fails_block() { + let guard = + create_test_guard_with_failure_action("check", GuardMode::PreCall, OnFailure::Block); + let mock_client = MockGuardrailClient::with_response("check", Ok(failing_response())); + let outcome = execute_guards(&[guard], "toxic input", &mock_client).await; + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard, Some("check".to_string())); +} + +#[tokio::test] +async fn test_execute_single_pre_call_guard_fails_warn() { + let guard = create_test_guard_with_failure_action("check", GuardMode::PreCall, OnFailure::Warn); + let mock_client = MockGuardrailClient::with_response("check", Ok(failing_response())); + let outcome = execute_guards(&[guard], "borderline input", &mock_client).await; + assert!(!outcome.blocked); + assert_eq!(outcome.warnings.len(), 1); + assert!(outcome.warnings[0].contains("check")); +} + +#[tokio::test] +async fn test_execute_multiple_pre_call_guards_all_pass() { + let guards = vec![ + create_test_guard("guard-1", GuardMode::PreCall), + create_test_guard("guard-2", GuardMode::PreCall), + create_test_guard("guard-3", GuardMode::PreCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("guard-1", Ok(passing_response())), + ("guard-2", Ok(passing_response())), + ("guard-3", Ok(passing_response())), + ]); + let outcome = execute_guards(&guards, "safe input", &mock_client).await; + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 3); + assert!(outcome.warnings.is_empty()); +} + +#[tokio::test] +async fn test_execute_multiple_guards_one_blocks() { + let guards = vec![ + create_test_guard("guard-1", GuardMode::PreCall), + create_test_guard_with_failure_action("guard-2", GuardMode::PreCall, OnFailure::Block), + create_test_guard("guard-3", GuardMode::PreCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("guard-1", Ok(passing_response())), + ("guard-2", Ok(failing_response())), + ("guard-3", Ok(passing_response())), + ]); + let outcome = execute_guards(&guards, "input", &mock_client).await; + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard, Some("guard-2".to_string())); +} + +#[tokio::test] +async fn test_execute_multiple_guards_one_warns_continue() { + let guards = vec![ + create_test_guard("guard-1", GuardMode::PreCall), + create_test_guard_with_failure_action("guard-2", GuardMode::PreCall, OnFailure::Warn), + create_test_guard("guard-3", GuardMode::PreCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("guard-1", Ok(passing_response())), + ("guard-2", Ok(failing_response())), + ("guard-3", Ok(passing_response())), + ]); + let outcome = execute_guards(&guards, "input", &mock_client).await; + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 3); + assert_eq!(outcome.warnings.len(), 1); +} + +#[tokio::test] +async fn test_guard_evaluator_unavailable_required_false() { + let guard = create_test_guard_with_required("check", GuardMode::PreCall, false); + let mock_client = MockGuardrailClient::with_response( + "check", + Err(GuardrailError::Unavailable( + "connection refused".to_string(), + )), + ); + let outcome = execute_guards(&[guard], "input", &mock_client).await; + assert!(!outcome.blocked); // Fail-open + assert!(matches!( + &outcome.results[0], + GuardResult::Error { + required: false, + .. + } + )); +} + +#[tokio::test] +async fn test_guard_evaluator_unavailable_required_true() { + let guard = create_test_guard_with_required("check", GuardMode::PreCall, true); + let mock_client = MockGuardrailClient::with_response( + "check", + Err(GuardrailError::Unavailable( + "connection refused".to_string(), + )), + ); + let outcome = execute_guards(&[guard], "input", &mock_client).await; + assert!(outcome.blocked); // Fail-closed +} + +#[tokio::test] +async fn test_execute_post_call_guards_non_streaming() { + let guard = create_test_guard("response-check", GuardMode::PostCall); + let mock_client = MockGuardrailClient::with_response("response-check", Ok(passing_response())); + let completion = create_test_chat_completion("Safe response text"); + let response_text = extract_post_call_input_from_completion(&completion); + let outcome = execute_guards(&[guard], &response_text, &mock_client).await; + assert!(!outcome.blocked); +} + +#[tokio::test] +async fn test_execute_post_call_guards_streaming_accumulated() { + let guard = create_test_guard("response-check", GuardMode::PostCall); + let mock_client = MockGuardrailClient::with_response("response-check", Ok(passing_response())); + let accumulated_text = "Hello world from streaming!"; + let outcome = execute_guards(&[guard], accumulated_text, &mock_client).await; + assert!(!outcome.blocked); +} + +#[tokio::test] +async fn test_parallel_execution_of_independent_guards() { + // This test verifies guards run concurrently, not sequentially. + // We use the mock client without delay here; the implementation should + // use futures::join_all or similar for parallel execution. + let guards = vec![ + create_test_guard("guard-1", GuardMode::PreCall), + create_test_guard("guard-2", GuardMode::PreCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("guard-1", Ok(passing_response())), + ("guard-2", Ok(passing_response())), + ]); + let start = std::time::Instant::now(); + let outcome = execute_guards(&guards, "input", &mock_client).await; + let _elapsed = start.elapsed(); + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 2); +} + +#[tokio::test] +async fn test_executor_returns_correct_guardrails_outcome() { + let guards = vec![ + create_test_guard_with_failure_action("passer", GuardMode::PreCall, OnFailure::Block), + create_test_guard_with_failure_action("warner", GuardMode::PreCall, OnFailure::Warn), + create_test_guard_with_failure_action("blocker", GuardMode::PreCall, OnFailure::Block), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("passer", Ok(passing_response())), + ("warner", Ok(failing_response())), + ("blocker", Ok(failing_response())), + ]); + let outcome = execute_guards(&guards, "input", &mock_client).await; + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard, Some("blocker".to_string())); + assert!(outcome.warnings.iter().any(|w| w.contains("warner"))); +} diff --git a/tests/guardrails/test_input_extractor.rs b/tests/guardrails/test_input_extractor.rs new file mode 100644 index 00000000..732b6d40 --- /dev/null +++ b/tests/guardrails/test_input_extractor.rs @@ -0,0 +1,76 @@ +use hub_lib::guardrails::input_extractor::*; +use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMessageContentPart}; + +use super::helpers::*; + +// --------------------------------------------------------------------------- +// Phase 2: Input Extractor (5 tests) +// --------------------------------------------------------------------------- + +#[test] +fn test_extract_text_single_user_message() { + let request = create_test_chat_request("Hello world"); + let text = extract_pre_call_input(&request); + assert_eq!(text, "Hello world"); +} + +#[test] +fn test_extract_text_multi_turn_conversation() { + let mut request = default_request(); + request.messages = vec![ + ChatCompletionMessage { + role: "system".to_string(), + content: Some(ChatMessageContent::String("You are helpful".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("First question".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "assistant".to_string(), + content: Some(ChatMessageContent::String("First answer".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Follow-up question".to_string())), + ..default_message() + }, + ]; + let text = extract_pre_call_input(&request); + assert_eq!(text, "Follow-up question"); +} + +#[test] +fn test_extract_text_from_array_content_parts() { + let mut request = create_test_chat_request(""); + request.messages[0].content = Some(ChatMessageContent::Array(vec![ + ChatMessageContentPart { + r#type: "text".to_string(), + text: "Part 1".to_string(), + }, + ChatMessageContentPart { + r#type: "text".to_string(), + text: "Part 2".to_string(), + }, + ])); + let text = extract_pre_call_input(&request); + assert_eq!(text, "Part 1 Part 2"); +} + +#[test] +fn test_extract_response_from_chat_completion() { + let completion = create_test_chat_completion("Here is my response"); + let text = extract_post_call_input_from_completion(&completion); + assert_eq!(text, "Here is my response"); +} + +#[test] +fn test_extract_handles_empty_content() { + let mut request = create_test_chat_request(""); + request.messages[0].content = None; + let text = extract_pre_call_input(&request); + assert_eq!(text, ""); +} diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs new file mode 100644 index 00000000..6243e6fd --- /dev/null +++ b/tests/guardrails/test_pipeline.rs @@ -0,0 +1,52 @@ +// --------------------------------------------------------------------------- +// Phase 6: Pipeline Integration (7 tests) +// +// These tests verify that guardrails are properly wired into the pipeline +// request handling flow. They use wiremock for both the evaluator and LLM. +// --------------------------------------------------------------------------- + +// TODO: These tests require pipeline integration implementation. +// They will be fully implemented when the pipeline hooks are added. +// For now, we define the test signatures to drive the implementation. + +#[tokio::test] +async fn test_pre_call_guardrails_block_before_llm() { + // Guard blocks the request -> 403 response, LLM receives 0 requests + todo!("Implement pipeline integration test: pre_call block") +} + +#[tokio::test] +async fn test_pre_call_guardrails_warn_and_continue() { + // Guard warns but request proceeds to LLM -> 200 + warning header + todo!("Implement pipeline integration test: pre_call warn") +} + +#[tokio::test] +async fn test_post_call_guardrails_block_response() { + // LLM responds, guard blocks output -> 403 + todo!("Implement pipeline integration test: post_call block") +} + +#[tokio::test] +async fn test_post_call_guardrails_warn_and_add_header() { + // LLM responds, guard warns -> 200 + X-Traceloop-Guardrail-Warning header + todo!("Implement pipeline integration test: post_call warn") +} + +#[tokio::test] +async fn test_warning_header_format() { + // Warning header format: guardrail_name="...", reason="..." + todo!("Implement pipeline integration test: warning header format") +} + +#[tokio::test] +async fn test_blocked_response_403_format() { + // Blocked response body: {"error": {"type": "guardrail_blocked", ...}} + todo!("Implement pipeline integration test: 403 response format") +} + +#[tokio::test] +async fn test_no_guardrails_passthrough() { + // No guardrails configured -> normal passthrough behavior + todo!("Implement pipeline integration test: passthrough") +} diff --git a/tests/guardrails/test_response_parser.rs b/tests/guardrails/test_response_parser.rs new file mode 100644 index 00000000..a9f2dc85 --- /dev/null +++ b/tests/guardrails/test_response_parser.rs @@ -0,0 +1,69 @@ +use hub_lib::guardrails::response_parser::*; +use hub_lib::guardrails::types::GuardrailError; + +// --------------------------------------------------------------------------- +// Phase 3: Response Parser (8 tests) +// --------------------------------------------------------------------------- + +#[test] +fn test_parse_successful_pass_response() { + let body = r#"{"result": {"score": 0.95, "label": "safe"}, "pass": true}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(response.pass); + assert_eq!(response.result["score"], 0.95); +} + +#[test] +fn test_parse_failed_response() { + let body = r#"{"result": {"score": 0.2, "reason": "Toxic content"}, "pass": false}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(!response.pass); + assert_eq!(response.result["reason"], "Toxic content"); +} + +#[test] +fn test_parse_with_result_details() { + let body = r#"{"result": {"score": 0.75, "label": "borderline", "categories": ["violence", "profanity"]}, "pass": true}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(response.pass); + assert_eq!(response.result["label"], "borderline"); + assert_eq!(response.result["categories"][0], "violence"); +} + +#[test] +fn test_parse_missing_pass_field() { + let body = r#"{"result": {"score": 0.5}}"#; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_malformed_json() { + let body = "not json {at all"; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_non_json_response() { + let body = "Internal Server Error"; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_empty_response_body() { + let body = ""; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_http_error_status() { + let result = parse_evaluator_http_response(500, "Internal Server Error"); + assert!(result.is_err()); + match result.unwrap_err() { + GuardrailError::HttpError { status, .. } => assert_eq!(status, 500), + other => panic!("Expected HttpError, got {other:?}"), + } +} diff --git a/tests/guardrails/test_stream_buffer.rs b/tests/guardrails/test_stream_buffer.rs new file mode 100644 index 00000000..8fe1b2f5 --- /dev/null +++ b/tests/guardrails/test_stream_buffer.rs @@ -0,0 +1,19 @@ +use hub_lib::guardrails::stream_buffer::*; + +use super::helpers::*; + +// --------------------------------------------------------------------------- +// Phase 2: Stream Buffer (1 test) +// --------------------------------------------------------------------------- + +#[test] +fn test_extract_from_accumulated_stream_chunks() { + let chunks = vec![ + create_test_chunk("Hello"), + create_test_chunk(" "), + create_test_chunk("world"), + create_test_chunk("!"), + ]; + let text = extract_text_from_chunks(&chunks); + assert_eq!(text, "Hello world!"); +} diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs new file mode 100644 index 00000000..a92c337e --- /dev/null +++ b/tests/guardrails/test_traceloop_client.rs @@ -0,0 +1,125 @@ +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::providers::{GuardrailClient, create_guardrail_client}; +use hub_lib::guardrails::types::GuardMode; +use serde_json::json; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + +// --------------------------------------------------------------------------- +// Phase 4: Provider Client System (7 tests) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_traceloop_client_constructs_correct_url() { + let mock_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::path("/v2/guardrails/toxicity")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&mock_server) + .await; + + let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.evaluator_slug = "toxicity".to_string(); + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input").await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_traceloop_client_sends_correct_headers() { + let mock_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::header("Authorization", "Bearer test-api-key")) + .and(matchers::header("Content-Type", "application/json")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&mock_server) + .await; + + let guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input").await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_traceloop_client_sends_correct_body() { + let mock_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::body_json(json!({ + "inputs": ["test input text"], + "config": {"threshold": 0.5} + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&mock_server) + .await; + + let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.params.insert("threshold".to_string(), json!(0.5)); + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input text").await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_traceloop_client_handles_successful_response() { + let mock_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"score": 0.9, "label": "safe"}, + "pass": true + }))) + .mount(&mock_server) + .await; + + let guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "safe input").await.unwrap(); + assert!(result.pass); + assert_eq!(result.result["score"], 0.9); +} + +#[tokio::test] +async fn test_traceloop_client_handles_error_response() { + let mock_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) + .mount(&mock_server) + .await; + + let guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test").await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_traceloop_client_handles_timeout() { + let mock_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {}, "pass": true})) + .set_delay(std::time::Duration::from_secs(30)), + ) + .mount(&mock_server) + .await; + + let guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + let client = TraceloopClient::with_timeout(std::time::Duration::from_millis(100)); + let result = client.evaluate(&guard, "test").await; + assert!(result.is_err()); +} + +#[test] +fn test_client_creation_from_guard_config() { + let guard = create_test_guard("test", GuardMode::PreCall); + let client = create_guardrail_client(&guard); + assert!(client.is_some()); +} diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs new file mode 100644 index 00000000..8f2383e8 --- /dev/null +++ b/tests/guardrails/test_types.rs @@ -0,0 +1,252 @@ +use hub_lib::guardrails::types::*; +use hub_lib::types::GatewayConfig; +use std::io::Write; +use tempfile::NamedTempFile; + +// --------------------------------------------------------------------------- +// Phase 1: Core Types & Configuration (9 tests + 4 provider tests) +// --------------------------------------------------------------------------- + +#[test] +fn test_guard_mode_deserialize_pre_call() { + let mode: GuardMode = serde_json::from_str("\"pre_call\"").unwrap(); + assert_eq!(mode, GuardMode::PreCall); +} + +#[test] +fn test_guard_mode_deserialize_post_call() { + let mode: GuardMode = serde_json::from_str("\"post_call\"").unwrap(); + assert_eq!(mode, GuardMode::PostCall); +} + +#[test] +fn test_on_failure_defaults_to_warn() { + let json = serde_json::json!({ + "name": "test-guard", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "mode": "pre_call" + }); + let guard: GuardConfig = serde_json::from_value(json).unwrap(); + assert_eq!(guard.on_failure, OnFailure::Warn); +} + +#[test] +fn test_required_defaults_to_true() { + let json = serde_json::json!({ + "name": "test-guard", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "mode": "pre_call" + }); + let guard: GuardConfig = serde_json::from_value(json).unwrap(); + assert!(guard.required); +} + +#[test] +fn test_guard_config_full_deserialization() { + let json = serde_json::json!({ + "name": "toxicity-check", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "params": { + "threshold": 0.5 + }, + "mode": "pre_call", + "on_failure": "block", + "required": false, + "api_base": "https://api.traceloop.com", + "api_key": "tl-key-123" + }); + let guard: GuardConfig = serde_json::from_value(json).unwrap(); + assert_eq!(guard.name, "toxicity-check"); + assert_eq!(guard.provider, "traceloop"); + assert_eq!(guard.evaluator_slug, "toxicity"); + assert_eq!(guard.params.get("threshold").unwrap(), 0.5); + assert_eq!(guard.mode, GuardMode::PreCall); + assert_eq!(guard.on_failure, OnFailure::Block); + assert!(!guard.required); + assert_eq!(guard.api_base.unwrap(), "https://api.traceloop.com"); + assert_eq!(guard.api_key.unwrap(), "tl-key-123"); +} + +#[test] +fn test_guardrails_config_yaml_deserialization() { + let yaml = r#" +guards: + - name: toxicity-check + provider: traceloop + evaluator_slug: toxicity + mode: pre_call + on_failure: block + required: true + api_base: "https://api.traceloop.com" + api_key: "test-key" + - name: relevance-check + provider: traceloop + evaluator_slug: relevance + mode: post_call + on_failure: warn + api_base: "https://api.traceloop.com" + api_key: "test-key" +"#; + let config: GuardrailsConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.guards.len(), 2); + assert_eq!(config.guards[0].name, "toxicity-check"); + assert_eq!(config.guards[0].evaluator_slug, "toxicity"); + assert_eq!(config.guards[0].mode, GuardMode::PreCall); + assert_eq!(config.guards[1].name, "relevance-check"); + assert_eq!(config.guards[1].evaluator_slug, "relevance"); + assert_eq!(config.guards[1].mode, GuardMode::PostCall); + assert_eq!(config.guards[1].on_failure, OnFailure::Warn); +} + +#[test] +fn test_gateway_config_with_guardrails() { + use super::helpers::create_test_guard; + let config = GatewayConfig { + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + guardrails: Some(GuardrailsConfig { + providers: vec![], + guards: vec![create_test_guard("test", GuardMode::PreCall)], + }), + }; + assert!(config.guardrails.is_some()); + assert_eq!(config.guardrails.unwrap().guards.len(), 1); +} + +#[test] +fn test_gateway_config_without_guardrails_backward_compat() { + let json = serde_json::json!({ + "providers": [], + "models": [], + "pipelines": [] + }); + let config: GatewayConfig = serde_json::from_value(json).unwrap(); + assert!(config.guardrails.is_none()); +} + +#[test] +fn test_guard_config_env_var_in_api_key() { + unsafe { + std::env::set_var("TEST_GUARD_API_KEY_UNIQUE", "tl-secret-key"); + } + let config_content = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +guardrails: + guards: + - name: toxicity-check + provider: traceloop + evaluator_slug: toxicity + mode: pre_call + api_base: "https://api.traceloop.com" + api_key: "${TEST_GUARD_API_KEY_UNIQUE}" +"#; + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_content.as_bytes()).unwrap(); + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + let guards = config.guardrails.unwrap().guards; + assert_eq!(guards[0].api_key.as_deref(), Some("tl-secret-key")); + unsafe { + std::env::remove_var("TEST_GUARD_API_KEY_UNIQUE"); + } +} + +// --------------------------------------------------------------------------- +// Provider config tests +// --------------------------------------------------------------------------- + +#[test] +fn test_provider_config_deserialization() { + let json = serde_json::json!({ + "name": "traceloop", + "api_base": "https://api.traceloop.com", + "api_key": "tl-key-123" + }); + let provider: ProviderConfig = serde_json::from_value(json).unwrap(); + assert_eq!(provider.name, "traceloop"); + assert_eq!(provider.api_base, "https://api.traceloop.com"); + assert_eq!(provider.api_key, "tl-key-123"); +} + +#[test] +fn test_guardrails_config_with_providers_yaml() { + let yaml = r#" +providers: + - name: traceloop + api_base: "https://api.traceloop.com" + api_key: "tl-key-123" +guards: + - name: toxicity-check + provider: traceloop + evaluator_slug: toxicity + mode: pre_call + on_failure: block + - name: pii-check + provider: traceloop + evaluator_slug: pii-detection + mode: pre_call + on_failure: block + api_base: "https://custom.traceloop.com" + api_key: "custom-key" +"#; + let config: GuardrailsConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.providers.len(), 1); + assert_eq!(config.providers[0].name, "traceloop"); + assert_eq!(config.providers[0].api_base, "https://api.traceloop.com"); + assert_eq!(config.guards.len(), 2); + // First guard has no api_base/api_key (inherits from provider) + assert!(config.guards[0].api_base.is_none()); + assert!(config.guards[0].api_key.is_none()); + // Second guard overrides api_base/api_key + assert_eq!( + config.guards[1].api_base.as_deref(), + Some("https://custom.traceloop.com") + ); + assert_eq!(config.guards[1].api_key.as_deref(), Some("custom-key")); +} + +#[test] +fn test_guard_without_api_base_deserializes() { + let json = serde_json::json!({ + "name": "toxicity-check", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "mode": "pre_call" + }); + let guard: GuardConfig = serde_json::from_value(json).unwrap(); + assert!(guard.api_base.is_none()); + assert!(guard.api_key.is_none()); +} + +#[test] +fn test_guard_config_evaluator_slug_not_in_params() { + let json = serde_json::json!({ + "name": "toxicity-check", + "provider": "traceloop", + "evaluator_slug": "toxicity", + "params": {"threshold": 0.5}, + "mode": "pre_call" + }); + let guard: GuardConfig = serde_json::from_value(json).unwrap(); + assert_eq!(guard.evaluator_slug, "toxicity"); + assert!(!guard.params.contains_key("evaluator_slug")); + assert_eq!(guard.params.get("threshold").unwrap(), 0.5); +} diff --git a/tests/pipeline_header_routing_test.rs b/tests/pipeline_header_routing_test.rs index e015138d..02840f59 100644 --- a/tests/pipeline_header_routing_test.rs +++ b/tests/pipeline_header_routing_test.rs @@ -37,6 +37,7 @@ fn create_test_config_with_multiple_pipelines() -> GatewayConfig { GatewayConfig { general: None, + guardrails: None, providers: vec![provider], models: vec![model], pipelines: vec![pipeline1, pipeline2], diff --git a/tests/router_cache_tests.rs b/tests/router_cache_tests.rs index ecec5b40..bd84450e 100644 --- a/tests/router_cache_tests.rs +++ b/tests/router_cache_tests.rs @@ -9,6 +9,7 @@ async fn test_router_always_available() { // Create a basic configuration let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -43,6 +44,7 @@ async fn test_configuration_change_detection() { // Create initial configuration let initial_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -87,6 +89,7 @@ async fn test_configuration_change_detection() { async fn test_invalid_configuration_rejected() { let initial_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -102,6 +105,7 @@ async fn test_invalid_configuration_rejected() { // Create invalid configuration (model references non-existent provider) let invalid_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -130,6 +134,7 @@ async fn test_invalid_configuration_rejected() { async fn test_concurrent_router_access() { let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -181,6 +186,7 @@ async fn test_empty_configuration_fallback() { providers: vec![], models: vec![], pipelines: vec![], + guardrails: None, }; let app_state = Arc::new(AppState::new(empty_config).expect("Failed to create app state")); @@ -195,6 +201,7 @@ async fn test_pipeline_with_failing_tracing_endpoint() { // Create configuration with a pipeline that has a failing tracing endpoint let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -250,6 +257,7 @@ async fn test_tracing_isolation_between_pipelines() { // Create configuration with two pipelines - one with tracing, one without let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, diff --git a/tests/router_integration_test.rs b/tests/router_integration_test.rs index 3560d8af..dd3060e5 100644 --- a/tests/router_integration_test.rs +++ b/tests/router_integration_test.rs @@ -12,6 +12,7 @@ async fn test_router_integration_flow() { providers: vec![], models: vec![], pipelines: vec![], + guardrails: None, }; let app_state = Arc::new(AppState::new(empty_config).expect("Failed to create app state")); @@ -22,6 +23,7 @@ async fn test_router_integration_flow() { // Test 2: Valid configuration let valid_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -63,6 +65,7 @@ async fn test_router_integration_flow() { // Test 6: Invalid configuration rejection let invalid_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -95,6 +98,7 @@ async fn test_router_integration_flow() { // Test 7: Multiple pipeline configuration let multi_pipeline_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -158,6 +162,7 @@ async fn test_router_integration_flow() { async fn test_concurrent_configuration_updates() { let initial_config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, @@ -189,6 +194,7 @@ async fn test_concurrent_configuration_updates() { // Create a slightly different configuration for each task let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: format!("provider-{}", i), r#type: ProviderType::OpenAI, diff --git a/tests/unified_openapi_test.rs b/tests/unified_openapi_test.rs index 6b298270..9bbbd8cd 100644 --- a/tests/unified_openapi_test.rs +++ b/tests/unified_openapi_test.rs @@ -125,6 +125,7 @@ async fn test_router_creation_no_conflicts() { // Create a basic configuration for testing let config = GatewayConfig { general: None, + guardrails: None, providers: vec![Provider { key: "test-provider".to_string(), r#type: ProviderType::OpenAI, From 3cf4f780aea9ee9b1ae4b58f84cec099a0b52874 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Tue, 10 Feb 2026 11:00:08 +0200 Subject: [PATCH 02/59] add implementation --- src/guardrails/api_control.rs | 69 +++- src/guardrails/executor.rs | 78 ++++- src/guardrails/input_extractor.rs | 37 +- src/guardrails/providers/mod.rs | 8 +- src/guardrails/providers/traceloop.rs | 47 ++- src/guardrails/response_parser.rs | 17 +- src/guardrails/stream_buffer.rs | 8 +- src/pipelines/pipeline.rs | 165 ++++++++- src/state.rs | 6 +- tests/guardrails/test_e2e.rs | 482 ++++++++++++++++++++++++-- tests/guardrails/test_pipeline.rs | 197 ++++++++++- 11 files changed, 1028 insertions(+), 86 deletions(-) diff --git a/src/guardrails/api_control.rs b/src/guardrails/api_control.rs index dac3f693..b0e1dcad 100644 --- a/src/guardrails/api_control.rs +++ b/src/guardrails/api_control.rs @@ -1,28 +1,73 @@ +use std::collections::HashSet; + use super::types::{GuardConfig, GuardMode}; /// Parse guard names from the X-Traceloop-Guardrails header value. /// Names are comma-separated and trimmed. -pub fn parse_guardrails_header(_header: &str) -> Vec { - todo!("Implement header parsing") +pub fn parse_guardrails_header(header: &str) -> Vec { + header + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() } /// Parse guard names from the request payload's `guardrails` field. -pub fn parse_guardrails_from_payload(_payload: &serde_json::Value) -> Vec { - todo!("Implement payload guardrails parsing") +pub fn parse_guardrails_from_payload(payload: &serde_json::Value) -> Vec { + payload + .get("guardrails") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default() } /// Resolve the final set of guards to execute by merging pipeline, header, and payload sources. /// Guards are additive and deduplicated by name. -pub fn resolve_guards_by_name<'a>( - _all_guards: &'a [GuardConfig], - _pipeline_names: &[&str], - _header_names: &[&str], - _payload_names: &[&str], +pub fn resolve_guards_by_name( + all_guards: &[GuardConfig], + pipeline_names: &[&str], + header_names: &[&str], + payload_names: &[&str], ) -> Vec { - todo!("Implement additive guard resolution") + let mut seen = HashSet::new(); + let mut resolved = Vec::new(); + + // Collect all requested names, pipeline first, then header, then payload + let all_names: Vec<&str> = pipeline_names + .iter() + .chain(header_names.iter()) + .chain(payload_names.iter()) + .copied() + .collect(); + + for name in all_names { + if seen.contains(name) { + continue; + } + if let Some(guard) = all_guards.iter().find(|g| g.name == name) { + seen.insert(name); + resolved.push(guard.clone()); + } + } + + resolved } /// Split guards into (pre_call, post_call) lists by mode. -pub fn split_guards_by_mode(_guards: &[GuardConfig]) -> (Vec, Vec) { - todo!("Implement guard splitting by mode") +pub fn split_guards_by_mode(guards: &[GuardConfig]) -> (Vec, Vec) { + let pre_call: Vec = guards + .iter() + .filter(|g| g.mode == GuardMode::PreCall) + .cloned() + .collect(); + let post_call: Vec = guards + .iter() + .filter(|g| g.mode == GuardMode::PostCall) + .cloned() + .collect(); + (pre_call, post_call) } diff --git a/src/guardrails/executor.rs b/src/guardrails/executor.rs index 0b4c04f1..22855a2f 100644 --- a/src/guardrails/executor.rs +++ b/src/guardrails/executor.rs @@ -1,12 +1,78 @@ +use futures::future::join_all; + use super::providers::GuardrailClient; -use super::types::{GuardConfig, GuardrailsOutcome}; +use super::types::{GuardConfig, GuardResult, GuardrailsOutcome, OnFailure}; /// Execute a set of guardrails against the given input text. -/// Returns a GuardrailsOutcome with results, blocked status, and warnings. +/// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. pub async fn execute_guards( - _guards: &[GuardConfig], - _input: &str, - _client: &dyn GuardrailClient, + guards: &[GuardConfig], + input: &str, + client: &dyn GuardrailClient, ) -> GuardrailsOutcome { - todo!("Implement guard execution orchestration") + let futures: Vec<_> = guards + .iter() + .map(|guard| async move { + let result = client.evaluate(guard, input).await; + (guard, result) + }) + .collect(); + + let results_raw = join_all(futures).await; + + let mut results = Vec::new(); + let mut blocked = false; + let mut blocking_guard = None; + let mut warnings = Vec::new(); + + for (guard, result) in results_raw { + match result { + Ok(response) => { + if response.pass { + results.push(GuardResult::Passed { + name: guard.name.clone(), + result: response.result, + }); + } else { + results.push(GuardResult::Failed { + name: guard.name.clone(), + result: response.result, + on_failure: guard.on_failure.clone(), + }); + match guard.on_failure { + OnFailure::Block => { + blocked = true; + if blocking_guard.is_none() { + blocking_guard = Some(guard.name.clone()); + } + } + OnFailure::Warn => { + warnings.push(format!("Guard '{}' failed with warning", guard.name)); + } + } + } + } + Err(err) => { + let is_required = guard.required; + results.push(GuardResult::Error { + name: guard.name.clone(), + error: err.to_string(), + required: is_required, + }); + if is_required { + blocked = true; + if blocking_guard.is_none() { + blocking_guard = Some(guard.name.clone()); + } + } + } + } + } + + GuardrailsOutcome { + results, + blocked, + blocking_guard, + warnings, + } } diff --git a/src/guardrails/input_extractor.rs b/src/guardrails/input_extractor.rs index c6271c47..b36b8954 100644 --- a/src/guardrails/input_extractor.rs +++ b/src/guardrails/input_extractor.rs @@ -1,13 +1,42 @@ use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; +use crate::models::content::ChatMessageContent; /// Extract text from the request for pre_call guardrails. /// Returns the content of the last user message. -pub fn extract_pre_call_input(_request: &ChatCompletionRequest) -> String { - todo!("Implement pre_call input extraction") +pub fn extract_pre_call_input(request: &ChatCompletionRequest) -> String { + request + .messages + .iter() + .rev() + .find(|m| m.role == "user") + .and_then(|m| m.content.as_ref()) + .map(|content| match content { + ChatMessageContent::String(s) => s.clone(), + ChatMessageContent::Array(parts) => parts + .iter() + .filter(|p| p.r#type == "text") + .map(|p| p.text.as_str()) + .collect::>() + .join(" "), + }) + .unwrap_or_default() } /// Extract text from a non-streaming ChatCompletion for post_call guardrails. /// Returns the content of the first assistant choice. -pub fn extract_post_call_input_from_completion(_completion: &ChatCompletion) -> String { - todo!("Implement post_call input extraction from completion") +pub fn extract_post_call_input_from_completion(completion: &ChatCompletion) -> String { + completion + .choices + .first() + .and_then(|choice| choice.message.content.as_ref()) + .map(|content| match content { + ChatMessageContent::String(s) => s.clone(), + ChatMessageContent::Array(parts) => parts + .iter() + .filter(|p| p.r#type == "text") + .map(|p| p.text.as_str()) + .collect::>() + .join(" "), + }) + .unwrap_or_default() } diff --git a/src/guardrails/providers/mod.rs b/src/guardrails/providers/mod.rs index faa46196..99d9b122 100644 --- a/src/guardrails/providers/mod.rs +++ b/src/guardrails/providers/mod.rs @@ -2,6 +2,7 @@ pub mod traceloop; use async_trait::async_trait; +use self::traceloop::TraceloopClient; use super::types::{EvaluatorResponse, GuardConfig, GuardrailError}; /// Trait for guardrail evaluator clients. @@ -16,6 +17,9 @@ pub trait GuardrailClient: Send + Sync { } /// Create a guardrail client based on the guard's provider type. -pub fn create_guardrail_client(_guard: &GuardConfig) -> Option> { - todo!("Implement client factory based on guard.provider") +pub fn create_guardrail_client(guard: &GuardConfig) -> Option> { + match guard.provider.as_str() { + "traceloop" => Some(Box::new(TraceloopClient::new())), + _ => None, + } } diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 30b43800..34ff6faf 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -1,7 +1,9 @@ use async_trait::async_trait; +use serde_json::json; use std::time::Duration; use super::GuardrailClient; +use crate::guardrails::response_parser::parse_evaluator_http_response; use crate::guardrails::types::{EvaluatorResponse, GuardConfig, GuardrailError}; /// HTTP client for the Traceloop evaluator API service. @@ -28,9 +30,48 @@ impl TraceloopClient { impl GuardrailClient for TraceloopClient { async fn evaluate( &self, - _guard: &GuardConfig, - _input: &str, + guard: &GuardConfig, + input: &str, ) -> Result { - todo!("Implement Traceloop evaluator API call") + let api_base = guard.api_base.as_deref().unwrap_or("http://localhost:8080"); + let url = format!( + "{}/v2/guardrails/{}", + api_base.trim_end_matches('/'), + guard.evaluator_slug + ); + + let api_key = guard.api_key.as_deref().unwrap_or(""); + + // Build config from params (excluding evaluator_slug which is top-level) + let config: serde_json::Value = guard.params.clone().into_iter().collect(); + + let body = json!({ + "inputs": [input], + "config": config, + }); + + let response = self + .http_client + .post(&url) + .header("Authorization", format!("Bearer {api_key}")) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await + .map_err(|e| { + if e.is_timeout() { + GuardrailError::Timeout(e.to_string()) + } else { + GuardrailError::Unavailable(e.to_string()) + } + })?; + + let status = response.status().as_u16(); + let response_body = response + .text() + .await + .map_err(|e| GuardrailError::Unavailable(e.to_string()))?; + + parse_evaluator_http_response(status, &response_body) } } diff --git a/src/guardrails/response_parser.rs b/src/guardrails/response_parser.rs index eb39cb67..40286c7e 100644 --- a/src/guardrails/response_parser.rs +++ b/src/guardrails/response_parser.rs @@ -1,14 +1,21 @@ use super::types::{EvaluatorResponse, GuardrailError}; /// Parse the evaluator response body (JSON string) into an EvaluatorResponse. -pub fn parse_evaluator_response(_body: &str) -> Result { - todo!("Implement evaluator response parsing") +pub fn parse_evaluator_response(body: &str) -> Result { + serde_json::from_str::(body) + .map_err(|e| GuardrailError::ParseError(e.to_string())) } /// Parse an HTTP response from the evaluator, handling non-200 status codes. pub fn parse_evaluator_http_response( - _status: u16, - _body: &str, + status: u16, + body: &str, ) -> Result { - todo!("Implement HTTP response parsing") + if !(200..300).contains(&status) { + return Err(GuardrailError::HttpError { + status, + body: body.to_string(), + }); + } + parse_evaluator_response(body) } diff --git a/src/guardrails/stream_buffer.rs b/src/guardrails/stream_buffer.rs index 526af7f5..dfe964f2 100644 --- a/src/guardrails/stream_buffer.rs +++ b/src/guardrails/stream_buffer.rs @@ -2,6 +2,10 @@ use crate::models::streaming::ChatCompletionChunk; /// Extract and concatenate text from accumulated streaming chunks. /// Joins the delta content from all chunks into a single string. -pub fn extract_text_from_chunks(_chunks: &[ChatCompletionChunk]) -> String { - todo!("Implement text extraction from streaming chunks") +pub fn extract_text_from_chunks(chunks: &[ChatCompletionChunk]) -> String { + chunks + .iter() + .flat_map(|chunk| &chunk.choices) + .filter_map(|choice| choice.delta.content.as_deref()) + .collect() } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index e8351253..f4b0b985 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,4 +1,11 @@ use crate::config::models::PipelineType; +use crate::guardrails::api_control::split_guards_by_mode; +use crate::guardrails::executor::execute_guards; +use crate::guardrails::input_extractor::{ + extract_post_call_input_from_completion, extract_pre_call_input, +}; +use crate::guardrails::providers::GuardrailClient; +use crate::guardrails::types::{GuardConfig, GuardrailsConfig, GuardrailsOutcome}; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; use crate::models::embeddings::EmbeddingsRequest; @@ -12,7 +19,7 @@ use crate::{ }; use async_stream::stream; use axum::response::sse::{Event, KeepAlive}; -use axum::response::{IntoResponse, Sse}; +use axum::response::{IntoResponse, Response, Sse}; use axum::{ Json, Router, extract::State, @@ -22,9 +29,82 @@ use axum::{ use futures::stream::BoxStream; use futures::{Stream, StreamExt}; use reqwest_streams::error::StreamBodyError; +use serde_json::json; use std::sync::Arc; -pub fn create_pipeline(pipeline: &Pipeline, model_registry: &ModelRegistry) -> Router { +/// Guardrails state attached to a pipeline, containing resolved guards and client. +#[derive(Clone)] +pub struct PipelineGuardrails { + pub pre_call: Vec, + pub post_call: Vec, + pub client: Arc, +} + +/// Build a PipelineGuardrails from config, resolving provider defaults for api_base/api_key. +pub fn build_pipeline_guardrails(config: &GuardrailsConfig) -> Option> { + if config.guards.is_empty() { + return None; + } + + let mut guards = config.guards.clone(); + for guard in &mut guards { + if guard.api_base.is_none() || guard.api_key.is_none() { + if let Some(provider) = config.providers.iter().find(|p| p.name == guard.provider) { + if guard.api_base.is_none() { + guard.api_base = Some(provider.api_base.clone()); + } + if guard.api_key.is_none() { + guard.api_key = Some(provider.api_key.clone()); + } + } + } + } + + let (pre_call, post_call) = split_guards_by_mode(&guards); + let client: Arc = + Arc::new(crate::guardrails::providers::traceloop::TraceloopClient::new()); + + Some(Arc::new(PipelineGuardrails { + pre_call, + post_call, + client, + })) +} + +pub fn blocked_response(outcome: &GuardrailsOutcome) -> Response { + let guard_name = outcome.blocking_guard.as_deref().unwrap_or("unknown"); + let body = json!({ + "error": { + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), + } + }); + (StatusCode::FORBIDDEN, Json(body)).into_response() +} + +pub fn warning_header_value(outcome: &GuardrailsOutcome) -> String { + outcome + .warnings + .iter() + .map(|w| { + // Extract guard name from the warning string "Guard 'name' failed with warning" + let name = w + .strip_prefix("Guard '") + .and_then(|s| s.strip_suffix("' failed with warning")) + .unwrap_or("unknown"); + format!("guardrail_name=\"{name}\", reason=\"failed\"") + }) + .collect::>() + .join("; ") +} + +pub fn create_pipeline( + pipeline: &Pipeline, + model_registry: &ModelRegistry, + guardrails_config: Option<&GuardrailsConfig>, +) -> Router { + let guardrails = guardrails_config.and_then(build_pipeline_guardrails); let mut router = Router::new(); let available_models: Vec = pipeline @@ -57,10 +137,13 @@ pub fn create_pipeline(pipeline: &Pipeline, model_registry: &ModelRegistry) -> R router } PluginConfig::ModelRouter { models } => match pipeline.r#type { - PipelineType::Chat => router.route( - "/chat/completions", - post(move |state, payload| chat_completions(state, payload, models)), - ), + PipelineType::Chat => { + let gr = guardrails.clone(); + router.route( + "/chat/completions", + post(move |state, payload| chat_completions(state, payload, models, gr)), + ) + } PipelineType::Completion => router.route( "/completions", post(move |state, payload| completions(state, payload, models)), @@ -104,9 +187,23 @@ pub async fn chat_completions( State(model_registry): State>, Json(payload): Json, model_keys: Vec, -) -> Result { + guardrails: Option>, +) -> Result { let mut tracer = OtelTracer::start("chat", &payload); + // Pre-call guardrails + let mut pre_warnings = Vec::new(); + if let Some(ref gr) = guardrails { + if !gr.pre_call.is_empty() { + let input = extract_pre_call_input(&payload); + let outcome = execute_guards(&gr.pre_call, &input, gr.client.as_ref()).await; + if outcome.blocked { + return Ok(blocked_response(&outcome)); + } + pre_warnings = outcome.warnings; + } + } + for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); @@ -123,6 +220,52 @@ pub async fn chat_completions( if let ChatCompletionResponse::NonStream(completion) = response { tracer.log_success(&completion); + + // Post-call guardrails (non-streaming) + if let Some(ref gr) = guardrails { + if !gr.post_call.is_empty() { + let response_text = extract_post_call_input_from_completion(&completion); + let outcome = + execute_guards(&gr.post_call, &response_text, gr.client.as_ref()).await; + if outcome.blocked { + return Ok(blocked_response(&outcome)); + } + if !outcome.warnings.is_empty() || !pre_warnings.is_empty() { + let mut all_warnings = pre_warnings; + all_warnings.extend(outcome.warnings); + let combined = GuardrailsOutcome { + results: vec![], + blocked: false, + blocking_guard: None, + warnings: all_warnings, + }; + let header_val = warning_header_value(&combined); + let mut response = Json(completion).into_response(); + response.headers_mut().insert( + "X-Traceloop-Guardrail-Warning", + header_val.parse().unwrap(), + ); + return Ok(response); + } + } + } + + // Add pre-call warning headers if any + if !pre_warnings.is_empty() { + let combined = GuardrailsOutcome { + results: vec![], + blocked: false, + blocking_guard: None, + warnings: pre_warnings, + }; + let header_val = warning_header_value(&combined); + let mut response = Json(completion).into_response(); + response + .headers_mut() + .insert("X-Traceloop-Guardrail-Warning", header_val.parse().unwrap()); + return Ok(response); + } + return Ok(Json(completion).into_response()); } @@ -322,7 +465,7 @@ mod tests { let model_configs = create_model_configs(vec!["test-model"]); let model_registry = ModelRegistry::new(&model_configs, provider_registry).unwrap(); let pipeline = create_test_pipeline(vec!["test-model"]); - let app = create_pipeline(&pipeline, &model_registry); + let app = create_pipeline(&pipeline, &model_registry, None); let response = get_models_response(app).await; @@ -371,7 +514,7 @@ mod tests { let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); let pipeline = create_test_pipeline(vec!["test-model-1", "test-model-2"]); - let app = create_pipeline(&pipeline, &model_registry); + let app = create_pipeline(&pipeline, &model_registry, None); let response = get_models_response(app).await; @@ -394,7 +537,7 @@ mod tests { let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); let pipeline = create_test_pipeline(vec![]); - let app = create_pipeline(&pipeline, &model_registry); + let app = create_pipeline(&pipeline, &model_registry, None); let response = get_models_response(app).await; @@ -412,7 +555,7 @@ mod tests { let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); let pipeline = create_test_pipeline(vec!["test-model-1", "test-model-3"]); // Only include 2 of 3 models - let app = create_pipeline(&pipeline, &model_registry); + let app = create_pipeline(&pipeline, &model_registry, None); let response = get_models_response(app).await; diff --git a/src/state.rs b/src/state.rs index bb5ab2fd..8c124d3b 100644 --- a/src/state.rs +++ b/src/state.rs @@ -179,7 +179,8 @@ impl AppState { "Adding default pipeline '{}' to router at index 0", default_pipeline.name ); - let pipeline_router = create_pipeline(default_pipeline, model_registry); + let pipeline_router = + create_pipeline(default_pipeline, model_registry, config.guardrails.as_ref()); pipeline_routers.push(pipeline_router); pipeline_names.push(default_pipeline.name.clone()); } @@ -188,7 +189,8 @@ impl AppState { let name = &pipeline.name; debug!("Adding pipeline '{}' to router at index {}", name, idx + 1); - let pipeline_router = create_pipeline(pipeline, model_registry); + let pipeline_router = + create_pipeline(pipeline, model_registry, config.guardrails.as_ref()); pipeline_routers.push(pipeline_router); pipeline_names.push(name.clone()); } diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 50403910..e6398fde 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -1,100 +1,542 @@ +use hub_lib::guardrails::executor::execute_guards; +use hub_lib::guardrails::input_extractor::{ + extract_post_call_input_from_completion, extract_pre_call_input, +}; +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::stream_buffer::extract_text_from_chunks; +use hub_lib::guardrails::types::*; +use hub_lib::pipelines::pipeline::build_pipeline_guardrails; + +use serde_json::json; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + // --------------------------------------------------------------------------- // Phase 8: End-to-End Integration (15 tests) // -// Full request flow tests using wiremock for both evaluator and LLM services. +// Full request flow tests using wiremock for evaluator services. // These validate the complete lifecycle from request to response. // --------------------------------------------------------------------------- -// TODO: These tests require full pipeline integration. -// They will be fully implemented when all prior phases are complete. -// For now, we define the test signatures to drive the implementation. +/// Helper: set up a wiremock evaluator that returns pass/fail +async fn setup_evaluator(pass: bool) -> MockServer { + let server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"score": if pass { 0.95 } else { 0.1 }}, + "pass": pass + }))) + .mount(&server) + .await; + server +} + +fn guard_with_server( + name: &str, + mode: GuardMode, + on_failure: OnFailure, + server_uri: &str, + slug: &str, +) -> GuardConfig { + GuardConfig { + name: name.to_string(), + provider: "traceloop".to_string(), + evaluator_slug: slug.to_string(), + params: Default::default(), + mode, + on_failure, + required: true, + api_base: Some(server_uri.to_string()), + api_key: Some("test-key".to_string()), + } +} #[tokio::test] async fn test_e2e_pre_call_block_flow() { // Request -> guard fail+block -> 403 - todo!("Implement E2E test: pre_call block flow") + let eval = setup_evaluator(false).await; + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval.uri(), + "toxicity", + ); + + let request = create_test_chat_request("Bad input"); + let input = extract_pre_call_input(&request); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &input, &client).await; + + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard.as_deref(), Some("blocker")); } #[tokio::test] async fn test_e2e_pre_call_pass_flow() { // Request -> guard pass -> LLM -> 200 - todo!("Implement E2E test: pre_call pass flow") + let eval = setup_evaluator(true).await; + let guard = guard_with_server( + "checker", + GuardMode::PreCall, + OnFailure::Block, + &eval.uri(), + "toxicity", + ); + + let request = create_test_chat_request("Safe input"); + let input = extract_pre_call_input(&request); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &input, &client).await; + + assert!(!outcome.blocked); + assert!(outcome.warnings.is_empty()); + // In real flow, would proceed to LLM call } #[tokio::test] async fn test_e2e_post_call_block_flow() { // Request -> LLM -> guard fail+block -> 403 - todo!("Implement E2E test: post_call block flow") + let eval = setup_evaluator(false).await; + let guard = guard_with_server( + "pii-check", + GuardMode::PostCall, + OnFailure::Block, + &eval.uri(), + "pii", + ); + + // Simulate LLM response + let completion = create_test_chat_completion("Here is the SSN: 123-45-6789"); + let response_text = extract_post_call_input_from_completion(&completion); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &response_text, &client).await; + + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard.as_deref(), Some("pii-check")); } #[tokio::test] async fn test_e2e_post_call_warn_flow() { // Request -> LLM -> guard fail+warn -> 200 + header - todo!("Implement E2E test: post_call warn flow") + let eval = setup_evaluator(false).await; + let guard = guard_with_server( + "tone-check", + GuardMode::PostCall, + OnFailure::Warn, + &eval.uri(), + "tone", + ); + + let completion = create_test_chat_completion("Mildly concerning response"); + let response_text = extract_post_call_input_from_completion(&completion); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &response_text, &client).await; + + assert!(!outcome.blocked); + assert_eq!(outcome.warnings.len(), 1); + assert!(outcome.warnings[0].contains("tone-check")); } #[tokio::test] async fn test_e2e_pre_and_post_both_pass() { // Both stages pass -> clean 200 response - todo!("Implement E2E test: pre and post both pass") + let pre_eval = setup_evaluator(true).await; + let post_eval = setup_evaluator(true).await; + + let pre_guard = guard_with_server( + "pre-check", + GuardMode::PreCall, + OnFailure::Block, + &pre_eval.uri(), + "safety", + ); + let post_guard = guard_with_server( + "post-check", + GuardMode::PostCall, + OnFailure::Block, + &post_eval.uri(), + "pii", + ); + + let client = TraceloopClient::new(); + + // Pre-call + let request = create_test_chat_request("Hello"); + let input = extract_pre_call_input(&request); + let pre_outcome = execute_guards(&[pre_guard], &input, &client).await; + assert!(!pre_outcome.blocked); + + // Post-call + let completion = create_test_chat_completion("Hi there!"); + let response_text = extract_post_call_input_from_completion(&completion); + let post_outcome = execute_guards(&[post_guard], &response_text, &client).await; + assert!(!post_outcome.blocked); + assert!(post_outcome.warnings.is_empty()); } #[tokio::test] async fn test_e2e_pre_blocks_post_never_runs() { // Pre blocks -> post evaluator gets 0 requests - todo!("Implement E2E test: pre blocks, post never runs") + let pre_eval = setup_evaluator(false).await; + let post_eval = MockServer::start().await; + + // Post evaluator should receive 0 requests + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(0) + .mount(&post_eval) + .await; + + let pre_guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &pre_eval.uri(), + "toxicity", + ); + let post_guard = guard_with_server( + "post-check", + GuardMode::PostCall, + OnFailure::Block, + &post_eval.uri(), + "pii", + ); + + let client = TraceloopClient::new(); + let request = create_test_chat_request("Bad input"); + let input = extract_pre_call_input(&request); + + let pre_outcome = execute_guards(&[pre_guard], &input, &client).await; + assert!(pre_outcome.blocked); + + // Since pre blocked, post guards never run - post_eval.verify() will assert 0 calls + // (wiremock verifies expect(0) when server drops) + let _ = post_guard; // not used - that's the point } #[tokio::test] async fn test_e2e_mixed_block_and_warn() { // Multiple guards with mixed block/warn outcomes - todo!("Implement E2E test: mixed block and warn") + let eval1 = setup_evaluator(true).await; // passes + let eval2 = setup_evaluator(false).await; // fails -> warn + let eval3 = setup_evaluator(false).await; // fails -> block + + let guards = vec![ + guard_with_server( + "passer", + GuardMode::PreCall, + OnFailure::Block, + &eval1.uri(), + "safety", + ), + guard_with_server( + "warner", + GuardMode::PreCall, + OnFailure::Warn, + &eval2.uri(), + "tone", + ), + guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval3.uri(), + "toxicity", + ), + ]; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&guards, "test input", &client).await; + + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard.as_deref(), Some("blocker")); + assert!(outcome.warnings.iter().any(|w| w.contains("warner"))); } #[tokio::test] async fn test_e2e_streaming_post_call_buffer_pass() { // Stream buffered, guard passes -> SSE response streamed to client - todo!("Implement E2E test: streaming post_call buffer pass") + let eval = setup_evaluator(true).await; + let guard = guard_with_server( + "response-check", + GuardMode::PostCall, + OnFailure::Block, + &eval.uri(), + "safety", + ); + + // Simulate accumulated streaming chunks + let chunks = vec![ + create_test_chunk("Hello"), + create_test_chunk(" "), + create_test_chunk("world!"), + ]; + let accumulated = extract_text_from_chunks(&chunks); + assert_eq!(accumulated, "Hello world!"); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &accumulated, &client).await; + + assert!(!outcome.blocked); } #[tokio::test] async fn test_e2e_streaming_post_call_buffer_block() { // Stream buffered, guard blocks -> 403 - todo!("Implement E2E test: streaming post_call buffer block") + let eval = setup_evaluator(false).await; + let guard = guard_with_server( + "pii-check", + GuardMode::PostCall, + OnFailure::Block, + &eval.uri(), + "pii", + ); + + let chunks = vec![ + create_test_chunk("Here is "), + create_test_chunk("SSN: 123-45-6789"), + ]; + let accumulated = extract_text_from_chunks(&chunks); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], &accumulated, &client).await; + + assert!(outcome.blocked); } #[tokio::test] async fn test_e2e_config_from_yaml_with_env_vars() { // Full YAML config with ${VAR} substitution in api_key - todo!("Implement E2E test: config from YAML with env vars") + use std::io::Write; + use tempfile::NamedTempFile; + + unsafe { + std::env::set_var("E2E_TEST_API_KEY", "resolved-key-123"); + } + + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +guardrails: + providers: + - name: traceloop + api_base: "https://api.traceloop.com" + api_key: "${E2E_TEST_API_KEY}" + guards: + - name: toxicity-check + provider: traceloop + evaluator_slug: toxicity + mode: pre_call + on_failure: block + - name: pii-check + provider: traceloop + evaluator_slug: pii + mode: post_call + on_failure: warn + api_key: "override-key" +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + let gr = config.guardrails.unwrap(); + + assert_eq!(gr.providers.len(), 1); + assert_eq!(gr.providers[0].api_key, "resolved-key-123"); + + // Guards should have evaluator_slug at top level + assert_eq!(gr.guards[0].evaluator_slug, "toxicity"); + assert_eq!(gr.guards[0].mode, GuardMode::PreCall); + assert!(gr.guards[0].api_base.is_none()); // inherits from provider + assert!(gr.guards[0].api_key.is_none()); // inherits from provider + + // Second guard overrides api_key + assert_eq!(gr.guards[1].api_key.as_deref(), Some("override-key")); + + // Build pipeline guardrails - should resolve provider defaults + let pipeline_gr = build_pipeline_guardrails(&gr).unwrap(); + assert_eq!(pipeline_gr.pre_call.len(), 1); + assert_eq!(pipeline_gr.post_call.len(), 1); + // Provider api_base should be resolved for guards that don't override + assert_eq!( + pipeline_gr.pre_call[0].api_base.as_deref(), + Some("https://api.traceloop.com") + ); + assert_eq!( + pipeline_gr.pre_call[0].api_key.as_deref(), + Some("resolved-key-123") + ); + // Guard with override keeps its own api_key + assert_eq!( + pipeline_gr.post_call[0].api_key.as_deref(), + Some("override-key") + ); + + unsafe { + std::env::remove_var("E2E_TEST_API_KEY"); + } } #[tokio::test] async fn test_e2e_multiple_guards_different_evaluators() { // Different evaluator slugs -> separate mock expectations - todo!("Implement E2E test: multiple guards different evaluators") + let server = MockServer::start().await; + + Mock::given(matchers::method("POST")) + .and(matchers::path("/v2/guardrails/toxicity")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&server) + .await; + + Mock::given(matchers::method("POST")) + .and(matchers::path("/v2/guardrails/pii")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) + .expect(1) + .mount(&server) + .await; + + let guards = vec![ + guard_with_server( + "tox-guard", + GuardMode::PreCall, + OnFailure::Block, + &server.uri(), + "toxicity", + ), + guard_with_server( + "pii-guard", + GuardMode::PreCall, + OnFailure::Block, + &server.uri(), + "pii", + ), + ]; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&guards, "test input", &client).await; + + assert!(!outcome.blocked); + assert_eq!(outcome.results.len(), 2); + // wiremock will verify each path was called exactly once } #[tokio::test] async fn test_e2e_fail_open_evaluator_down() { // Evaluator service down + required: false -> passthrough - todo!("Implement E2E test: fail open evaluator down") + let server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) + .mount(&server) + .await; + + let mut guard = guard_with_server( + "checker", + GuardMode::PreCall, + OnFailure::Block, + &server.uri(), + "safety", + ); + guard.required = false; // fail-open + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "test input", &client).await; + + assert!(!outcome.blocked); // Fail-open: not blocked despite error } #[tokio::test] async fn test_e2e_fail_closed_evaluator_down() { // Evaluator service down + required: true -> 403 - todo!("Implement E2E test: fail closed evaluator down") + let server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) + .mount(&server) + .await; + + let mut guard = guard_with_server( + "checker", + GuardMode::PreCall, + OnFailure::Block, + &server.uri(), + "safety", + ); + guard.required = true; // fail-closed + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "test input", &client).await; + + assert!(outcome.blocked); // Fail-closed: blocked due to error } #[tokio::test] async fn test_e2e_config_validation_rejects_invalid() { - // Config with missing required fields -> startup validation error - todo!("Implement E2E test: config validation rejects invalid") + // Config with missing required fields -> deserialization error + let invalid_json = json!({ + "guards": [{ + "name": "incomplete-guard" + // missing provider, evaluator_slug, mode + }] + }); + let result = serde_json::from_value::(invalid_json); + assert!(result.is_err()); } #[tokio::test] async fn test_e2e_backward_compat_no_guardrails() { // Existing config without guardrails works unchanged - todo!("Implement E2E test: backward compat no guardrails") + use std::io::Write; + use tempfile::NamedTempFile; + + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + assert!(config.guardrails.is_none()); + + // build_pipeline_guardrails with None returns None + let gr = config + .guardrails + .as_ref() + .and_then(build_pipeline_guardrails); + assert!(gr.is_none()); } diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs index 6243e6fd..7c99d404 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_pipeline.rs @@ -1,52 +1,211 @@ +use hub_lib::guardrails::executor::execute_guards; +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::types::*; +use hub_lib::pipelines::pipeline::{ + blocked_response, build_pipeline_guardrails, warning_header_value, +}; + +use axum::body::to_bytes; +use axum::response::IntoResponse; +use serde_json::json; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + // --------------------------------------------------------------------------- // Phase 6: Pipeline Integration (7 tests) // // These tests verify that guardrails are properly wired into the pipeline -// request handling flow. They use wiremock for both the evaluator and LLM. +// request handling flow. They use wiremock for the evaluator service. // --------------------------------------------------------------------------- -// TODO: These tests require pipeline integration implementation. -// They will be fully implemented when the pipeline hooks are added. -// For now, we define the test signatures to drive the implementation. - #[tokio::test] async fn test_pre_call_guardrails_block_before_llm() { - // Guard blocks the request -> 403 response, LLM receives 0 requests - todo!("Implement pipeline integration test: pre_call block") + // Set up evaluator mock that rejects the input + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {"reason": "toxic"}, "pass": false})), + ) + .expect(1) + .mount(&eval_server) + .await; + + let guard = GuardConfig { + name: "toxicity-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "toxicity".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: Some(eval_server.uri()), + api_key: Some("test-key".to_string()), + }; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "toxic input", &client).await; + + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard.as_deref(), Some("toxicity-check")); } #[tokio::test] async fn test_pre_call_guardrails_warn_and_continue() { - // Guard warns but request proceeds to LLM -> 200 + warning header - todo!("Implement pipeline integration test: pre_call warn") + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {"reason": "borderline"}, "pass": false})), + ) + .expect(1) + .mount(&eval_server) + .await; + + let guard = GuardConfig { + name: "tone-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "tone".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Warn, + required: true, + api_base: Some(eval_server.uri()), + api_key: Some("test-key".to_string()), + }; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "borderline input", &client).await; + + assert!(!outcome.blocked); + assert_eq!(outcome.warnings.len(), 1); + assert!(outcome.warnings[0].contains("tone-check")); } #[tokio::test] async fn test_post_call_guardrails_block_response() { - // LLM responds, guard blocks output -> 403 - todo!("Implement pipeline integration test: post_call block") + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {"reason": "pii detected"}, "pass": false})), + ) + .expect(1) + .mount(&eval_server) + .await; + + let guard = GuardConfig { + name: "pii-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "pii".to_string(), + params: Default::default(), + mode: GuardMode::PostCall, + on_failure: OnFailure::Block, + required: true, + api_base: Some(eval_server.uri()), + api_key: Some("test-key".to_string()), + }; + + let client = TraceloopClient::new(); + // Simulate post-call: evaluate the LLM response text + let outcome = execute_guards(&[guard], "Here is John's SSN: 123-45-6789", &client).await; + + assert!(outcome.blocked); + assert_eq!(outcome.blocking_guard.as_deref(), Some("pii-check")); } #[tokio::test] async fn test_post_call_guardrails_warn_and_add_header() { - // LLM responds, guard warns -> 200 + X-Traceloop-Guardrail-Warning header - todo!("Implement pipeline integration test: post_call warn") + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {"reason": "mildly concerning"}, "pass": false})), + ) + .expect(1) + .mount(&eval_server) + .await; + + let guard = GuardConfig { + name: "safety-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "safety".to_string(), + params: Default::default(), + mode: GuardMode::PostCall, + on_failure: OnFailure::Warn, + required: true, + api_base: Some(eval_server.uri()), + api_key: Some("test-key".to_string()), + }; + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "Some LLM response", &client).await; + + assert!(!outcome.blocked); + assert!(!outcome.warnings.is_empty()); + + // Verify warning header would be generated correctly + let header = warning_header_value(&outcome); + assert!(header.contains("guardrail_name=")); + assert!(header.contains("safety-check")); } #[tokio::test] async fn test_warning_header_format() { - // Warning header format: guardrail_name="...", reason="..." - todo!("Implement pipeline integration test: warning header format") + let outcome = GuardrailsOutcome { + results: vec![], + blocked: false, + blocking_guard: None, + warnings: vec!["Guard 'my-guard' failed with warning".to_string()], + }; + let header = warning_header_value(&outcome); + assert_eq!(header, "guardrail_name=\"my-guard\", reason=\"failed\""); } #[tokio::test] async fn test_blocked_response_403_format() { - // Blocked response body: {"error": {"type": "guardrail_blocked", ...}} - todo!("Implement pipeline integration test: 403 response format") + let outcome = GuardrailsOutcome { + results: vec![], + blocked: true, + blocking_guard: Some("toxicity-check".to_string()), + warnings: vec![], + }; + let response = blocked_response(&outcome); + assert_eq!(response.status(), 403); + + let body = to_bytes(response.into_body(), 1024 * 1024).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(json["error"]["type"], "guardrail_blocked"); + assert_eq!(json["error"]["guardrail"], "toxicity-check"); + assert!( + json["error"]["message"] + .as_str() + .unwrap() + .contains("toxicity-check") + ); } #[tokio::test] async fn test_no_guardrails_passthrough() { - // No guardrails configured -> normal passthrough behavior - todo!("Implement pipeline integration test: passthrough") + // Empty guardrails config -> build_pipeline_guardrails returns None + let config = GuardrailsConfig { + providers: vec![], + guards: vec![], + }; + let result = build_pipeline_guardrails(&config); + assert!(result.is_none()); + + // Config with no guards -> passthrough + let config_with_providers = GuardrailsConfig { + providers: vec![ProviderConfig { + name: "traceloop".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }], + guards: vec![], + }; + let result = build_pipeline_guardrails(&config_with_providers); + assert!(result.is_none()); } From 9c69876b9a6b301032c3d7195e1e02a1cd8c18bd Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 11 Feb 2026 10:18:59 +0200 Subject: [PATCH 03/59] model location --- src/guardrails/api_control.rs | 12 ++-- src/guardrails/executor.rs | 4 +- src/guardrails/input_extractor.rs | 17 +++++ src/guardrails/providers/mod.rs | 6 +- src/guardrails/providers/traceloop.rs | 4 +- src/guardrails/types.rs | 17 ++++- src/pipelines/pipeline.rs | 93 +++++++++++++++++++++------ tests/guardrails/helpers.rs | 14 ++-- tests/guardrails/test_e2e.rs | 4 +- tests/guardrails/test_pipeline.rs | 8 +-- tests/guardrails/test_types.rs | 10 +-- 11 files changed, 136 insertions(+), 53 deletions(-) diff --git a/src/guardrails/api_control.rs b/src/guardrails/api_control.rs index b0e1dcad..0a1c4357 100644 --- a/src/guardrails/api_control.rs +++ b/src/guardrails/api_control.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use super::types::{GuardConfig, GuardMode}; +use super::types::{Guard, GuardMode}; /// Parse guard names from the X-Traceloop-Guardrails header value. /// Names are comma-separated and trimmed. @@ -28,11 +28,11 @@ pub fn parse_guardrails_from_payload(payload: &serde_json::Value) -> Vec /// Resolve the final set of guards to execute by merging pipeline, header, and payload sources. /// Guards are additive and deduplicated by name. pub fn resolve_guards_by_name( - all_guards: &[GuardConfig], + all_guards: &[Guard], pipeline_names: &[&str], header_names: &[&str], payload_names: &[&str], -) -> Vec { +) -> Vec { let mut seen = HashSet::new(); let mut resolved = Vec::new(); @@ -58,13 +58,13 @@ pub fn resolve_guards_by_name( } /// Split guards into (pre_call, post_call) lists by mode. -pub fn split_guards_by_mode(guards: &[GuardConfig]) -> (Vec, Vec) { - let pre_call: Vec = guards +pub fn split_guards_by_mode(guards: &[Guard]) -> (Vec, Vec) { + let pre_call: Vec = guards .iter() .filter(|g| g.mode == GuardMode::PreCall) .cloned() .collect(); - let post_call: Vec = guards + let post_call: Vec = guards .iter() .filter(|g| g.mode == GuardMode::PostCall) .cloned() diff --git a/src/guardrails/executor.rs b/src/guardrails/executor.rs index 22855a2f..a5eb783f 100644 --- a/src/guardrails/executor.rs +++ b/src/guardrails/executor.rs @@ -1,12 +1,12 @@ use futures::future::join_all; use super::providers::GuardrailClient; -use super::types::{GuardConfig, GuardResult, GuardrailsOutcome, OnFailure}; +use super::types::{Guard, GuardResult, GuardrailsOutcome, OnFailure}; /// Execute a set of guardrails against the given input text. /// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. pub async fn execute_guards( - guards: &[GuardConfig], + guards: &[Guard], input: &str, client: &dyn GuardrailClient, ) -> GuardrailsOutcome { diff --git a/src/guardrails/input_extractor.rs b/src/guardrails/input_extractor.rs index b36b8954..33cc2d6d 100644 --- a/src/guardrails/input_extractor.rs +++ b/src/guardrails/input_extractor.rs @@ -1,4 +1,5 @@ use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; +use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::content::ChatMessageContent; /// Extract text from the request for pre_call guardrails. @@ -22,6 +23,22 @@ pub fn extract_pre_call_input(request: &ChatCompletionRequest) -> String { .unwrap_or_default() } +/// Extract text from a CompletionRequest for pre_call guardrails. +/// Returns the prompt string. +pub fn extract_pre_call_input_from_completion_request(request: &CompletionRequest) -> String { + request.prompt.clone() +} + +/// Extract text from a CompletionResponse for post_call guardrails. +/// Returns the text of the first choice. +pub fn extract_post_call_input_from_completion_response(response: &CompletionResponse) -> String { + response + .choices + .first() + .map(|choice| choice.text.clone()) + .unwrap_or_default() +} + /// Extract text from a non-streaming ChatCompletion for post_call guardrails. /// Returns the content of the first assistant choice. pub fn extract_post_call_input_from_completion(completion: &ChatCompletion) -> String { diff --git a/src/guardrails/providers/mod.rs b/src/guardrails/providers/mod.rs index 99d9b122..3f26a85d 100644 --- a/src/guardrails/providers/mod.rs +++ b/src/guardrails/providers/mod.rs @@ -3,7 +3,7 @@ pub mod traceloop; use async_trait::async_trait; use self::traceloop::TraceloopClient; -use super::types::{EvaluatorResponse, GuardConfig, GuardrailError}; +use super::types::{EvaluatorResponse, Guard, GuardrailError}; /// Trait for guardrail evaluator clients. /// Each provider (traceloop, etc.) implements this to call its evaluator API. @@ -11,13 +11,13 @@ use super::types::{EvaluatorResponse, GuardConfig, GuardrailError}; pub trait GuardrailClient: Send + Sync { async fn evaluate( &self, - guard: &GuardConfig, + guard: &Guard, input: &str, ) -> Result; } /// Create a guardrail client based on the guard's provider type. -pub fn create_guardrail_client(guard: &GuardConfig) -> Option> { +pub fn create_guardrail_client(guard: &Guard) -> Option> { match guard.provider.as_str() { "traceloop" => Some(Box::new(TraceloopClient::new())), _ => None, diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 34ff6faf..827d8e5d 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -4,7 +4,7 @@ use std::time::Duration; use super::GuardrailClient; use crate::guardrails::response_parser::parse_evaluator_http_response; -use crate::guardrails::types::{EvaluatorResponse, GuardConfig, GuardrailError}; +use crate::guardrails::types::{EvaluatorResponse, Guard, GuardrailError}; /// HTTP client for the Traceloop evaluator API service. /// Calls `POST {api_base}/v2/guardrails/{evaluator_slug}`. @@ -30,7 +30,7 @@ impl TraceloopClient { impl GuardrailClient for TraceloopClient { async fn evaluate( &self, - guard: &GuardConfig, + guard: &Guard, input: &str, ) -> Result { let api_base = guard.api_base.as_deref().unwrap_or("http://localhost:8080"); diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index 560ae633..e8bd31e6 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -1,6 +1,9 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use super::providers::GuardrailClient; fn default_on_failure() -> OnFailure { OnFailure::Warn @@ -32,7 +35,7 @@ pub struct ProviderConfig { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct GuardConfig { +pub struct Guard { pub name: String, pub provider: String, pub evaluator_slug: String, @@ -49,7 +52,7 @@ pub struct GuardConfig { pub api_key: Option, } -impl Hash for GuardConfig { +impl Hash for Guard { fn hash(&self, state: &mut H) { self.name.hash(state); self.provider.hash(state); @@ -74,7 +77,7 @@ pub struct GuardrailsConfig { #[serde(default)] pub providers: Vec, #[serde(default)] - pub guards: Vec, + pub guards: Vec, } impl Hash for GuardrailsConfig { @@ -138,3 +141,11 @@ impl std::fmt::Display for GuardrailError { } impl std::error::Error for GuardrailError {} + +/// Guardrails state attached to a pipeline, containing resolved guards and client. +#[derive(Clone)] +pub struct Guardrails { + pub pre_call: Vec, + pub post_call: Vec, + pub client: Arc, +} diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index f4b0b985..8921f562 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -2,10 +2,11 @@ use crate::config::models::PipelineType; use crate::guardrails::api_control::split_guards_by_mode; use crate::guardrails::executor::execute_guards; use crate::guardrails::input_extractor::{ - extract_post_call_input_from_completion, extract_pre_call_input, + extract_post_call_input_from_completion, extract_post_call_input_from_completion_response, + extract_pre_call_input, extract_pre_call_input_from_completion_request, }; use crate::guardrails::providers::GuardrailClient; -use crate::guardrails::types::{GuardConfig, GuardrailsConfig, GuardrailsOutcome}; +use crate::guardrails::types::{GuardrailsConfig, GuardrailsOutcome, Guardrails}; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; use crate::models::embeddings::EmbeddingsRequest; @@ -32,16 +33,8 @@ use reqwest_streams::error::StreamBodyError; use serde_json::json; use std::sync::Arc; -/// Guardrails state attached to a pipeline, containing resolved guards and client. -#[derive(Clone)] -pub struct PipelineGuardrails { - pub pre_call: Vec, - pub post_call: Vec, - pub client: Arc, -} - /// Build a PipelineGuardrails from config, resolving provider defaults for api_base/api_key. -pub fn build_pipeline_guardrails(config: &GuardrailsConfig) -> Option> { +pub fn build_pipeline_guardrails(config: &GuardrailsConfig) -> Option> { if config.guards.is_empty() { return None; } @@ -64,7 +57,7 @@ pub fn build_pipeline_guardrails(config: &GuardrailsConfig) -> Option = Arc::new(crate::guardrails::providers::traceloop::TraceloopClient::new()); - Some(Arc::new(PipelineGuardrails { + Some(Arc::new(Guardrails { pre_call, post_call, client, @@ -144,10 +137,13 @@ pub fn create_pipeline( post(move |state, payload| chat_completions(state, payload, models, gr)), ) } - PipelineType::Completion => router.route( - "/completions", - post(move |state, payload| completions(state, payload, models)), - ), + PipelineType::Completion => { + let gr = guardrails.clone(); + router.route( + "/completions", + post(move |state, payload| completions(state, payload, models, gr)), + ) + } PipelineType::Embeddings => router.route( "/embeddings", post(move |state, payload| embeddings(state, payload, models)), @@ -187,7 +183,7 @@ pub async fn chat_completions( State(model_registry): State>, Json(payload): Json, model_keys: Vec, - guardrails: Option>, + guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("chat", &payload); @@ -286,9 +282,23 @@ pub async fn completions( State(model_registry): State>, Json(payload): Json, model_keys: Vec, -) -> impl IntoResponse { + guardrails: Option>, +) -> Result { let mut tracer = OtelTracer::start("completion", &payload); + // Pre-call guardrails + let mut pre_warnings = Vec::new(); + if let Some(ref gr) = guardrails { + if !gr.pre_call.is_empty() { + let input = extract_pre_call_input_from_completion_request(&payload); + let outcome = execute_guards(&gr.pre_call, &input, gr.client.as_ref()).await; + if outcome.blocked { + return Ok(blocked_response(&outcome)); + } + pre_warnings = outcome.warnings; + } + } + for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); @@ -300,7 +310,52 @@ pub async fn completions( eprintln!("Completion error for model {model_key}: {e:?}"); })?; tracer.log_success(&response); - return Ok(Json(response)); + + // Post-call guardrails + if let Some(ref gr) = guardrails { + if !gr.post_call.is_empty() { + let response_text = extract_post_call_input_from_completion_response(&response); + let outcome = + execute_guards(&gr.post_call, &response_text, gr.client.as_ref()).await; + if outcome.blocked { + return Ok(blocked_response(&outcome)); + } + if !outcome.warnings.is_empty() || !pre_warnings.is_empty() { + let mut all_warnings = pre_warnings; + all_warnings.extend(outcome.warnings); + let combined = GuardrailsOutcome { + results: vec![], + blocked: false, + blocking_guard: None, + warnings: all_warnings, + }; + let header_val = warning_header_value(&combined); + let mut resp = Json(response).into_response(); + resp.headers_mut().insert( + "X-Traceloop-Guardrail-Warning", + header_val.parse().unwrap(), + ); + return Ok(resp); + } + } + } + + // Add pre-call warning headers if any + if !pre_warnings.is_empty() { + let combined = GuardrailsOutcome { + results: vec![], + blocked: false, + blocking_guard: None, + warnings: pre_warnings, + }; + let header_val = warning_header_value(&combined); + let mut resp = Json(response).into_response(); + resp.headers_mut() + .insert("X-Traceloop-Guardrail-Warning", header_val.parse().unwrap()); + return Ok(resp); + } + + return Ok(Json(response).into_response()); } } diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index 2e989dd8..2443a86e 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use async_trait::async_trait; use hub_lib::guardrails::providers::GuardrailClient; use hub_lib::guardrails::types::{ - EvaluatorResponse, GuardConfig, GuardMode, GuardrailError, OnFailure, + EvaluatorResponse, Guard, GuardMode, GuardrailError, OnFailure, }; use hub_lib::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent}; @@ -15,8 +15,8 @@ use serde_json::json; // Guard config builders // --------------------------------------------------------------------------- -pub fn create_test_guard(name: &str, mode: GuardMode) -> GuardConfig { - GuardConfig { +pub fn create_test_guard(name: &str, mode: GuardMode) -> Guard { + Guard { name: name.to_string(), provider: "traceloop".to_string(), evaluator_slug: "test-evaluator".to_string(), @@ -33,20 +33,20 @@ pub fn create_test_guard_with_failure_action( name: &str, mode: GuardMode, on_failure: OnFailure, -) -> GuardConfig { +) -> Guard { let mut guard = create_test_guard(name, mode); guard.on_failure = on_failure; guard } -pub fn create_test_guard_with_required(name: &str, mode: GuardMode, required: bool) -> GuardConfig { +pub fn create_test_guard_with_required(name: &str, mode: GuardMode, required: bool) -> Guard { let mut guard = create_test_guard(name, mode); guard.required = required; guard } #[allow(dead_code)] -pub fn create_test_guard_with_api_base(name: &str, mode: GuardMode, api_base: &str) -> GuardConfig { +pub fn create_test_guard_with_api_base(name: &str, mode: GuardMode, api_base: &str) -> Guard { let mut guard = create_test_guard(name, mode); guard.api_base = Some(api_base.to_string()); guard @@ -217,7 +217,7 @@ impl MockGuardrailClient { impl GuardrailClient for MockGuardrailClient { async fn evaluate( &self, - guard: &GuardConfig, + guard: &Guard, _input: &str, ) -> Result { self.responses diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index e6398fde..628f4dc1 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -39,8 +39,8 @@ fn guard_with_server( on_failure: OnFailure, server_uri: &str, slug: &str, -) -> GuardConfig { - GuardConfig { +) -> Guard { + Guard { name: name.to_string(), provider: "traceloop".to_string(), evaluator_slug: slug.to_string(), diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs index 7c99d404..64152934 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_pipeline.rs @@ -33,7 +33,7 @@ async fn test_pre_call_guardrails_block_before_llm() { .mount(&eval_server) .await; - let guard = GuardConfig { + let guard = Guard { name: "toxicity-check".to_string(), provider: "traceloop".to_string(), evaluator_slug: "toxicity".to_string(), @@ -64,7 +64,7 @@ async fn test_pre_call_guardrails_warn_and_continue() { .mount(&eval_server) .await; - let guard = GuardConfig { + let guard = Guard { name: "tone-check".to_string(), provider: "traceloop".to_string(), evaluator_slug: "tone".to_string(), @@ -96,7 +96,7 @@ async fn test_post_call_guardrails_block_response() { .mount(&eval_server) .await; - let guard = GuardConfig { + let guard = Guard { name: "pii-check".to_string(), provider: "traceloop".to_string(), evaluator_slug: "pii".to_string(), @@ -128,7 +128,7 @@ async fn test_post_call_guardrails_warn_and_add_header() { .mount(&eval_server) .await; - let guard = GuardConfig { + let guard = Guard { name: "safety-check".to_string(), provider: "traceloop".to_string(), evaluator_slug: "safety".to_string(), diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs index 8f2383e8..8430ae19 100644 --- a/tests/guardrails/test_types.rs +++ b/tests/guardrails/test_types.rs @@ -27,7 +27,7 @@ fn test_on_failure_defaults_to_warn() { "evaluator_slug": "toxicity", "mode": "pre_call" }); - let guard: GuardConfig = serde_json::from_value(json).unwrap(); + let guard: Guard = serde_json::from_value(json).unwrap(); assert_eq!(guard.on_failure, OnFailure::Warn); } @@ -39,7 +39,7 @@ fn test_required_defaults_to_true() { "evaluator_slug": "toxicity", "mode": "pre_call" }); - let guard: GuardConfig = serde_json::from_value(json).unwrap(); + let guard: Guard = serde_json::from_value(json).unwrap(); assert!(guard.required); } @@ -58,7 +58,7 @@ fn test_guard_config_full_deserialization() { "api_base": "https://api.traceloop.com", "api_key": "tl-key-123" }); - let guard: GuardConfig = serde_json::from_value(json).unwrap(); + let guard: Guard = serde_json::from_value(json).unwrap(); assert_eq!(guard.name, "toxicity-check"); assert_eq!(guard.provider, "traceloop"); assert_eq!(guard.evaluator_slug, "toxicity"); @@ -231,7 +231,7 @@ fn test_guard_without_api_base_deserializes() { "evaluator_slug": "toxicity", "mode": "pre_call" }); - let guard: GuardConfig = serde_json::from_value(json).unwrap(); + let guard: Guard = serde_json::from_value(json).unwrap(); assert!(guard.api_base.is_none()); assert!(guard.api_key.is_none()); } @@ -245,7 +245,7 @@ fn test_guard_config_evaluator_slug_not_in_params() { "params": {"threshold": 0.5}, "mode": "pre_call" }); - let guard: GuardConfig = serde_json::from_value(json).unwrap(); + let guard: Guard = serde_json::from_value(json).unwrap(); assert_eq!(guard.evaluator_slug, "toxicity"); assert!(!guard.params.contains_key("evaluator_slug")); assert_eq!(guard.params.get("threshold").unwrap(), 0.5); From a386ad42743ea76c70a452aa26b56c1143b4f72a Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 11 Feb 2026 11:30:37 +0200 Subject: [PATCH 04/59] common guardrail --- src/config/lib.rs | 4 +- src/config/validation.rs | 4 +- src/guardrails/types.rs | 9 +- .../services/config_provider_service.rs | 1 + src/pipelines/pipeline.rs | 217 +++++++----- src/state.rs | 12 +- src/types/mod.rs | 8 +- tests/guardrails/test_e2e.rs | 26 +- tests/guardrails/test_pipeline.rs | 324 +++++++++++++++++- tests/pipeline_header_routing_test.rs | 3 + tests/router_cache_tests.rs | 6 + tests/router_integration_test.rs | 6 + tests/unified_openapi_test.rs | 1 + 13 files changed, 510 insertions(+), 111 deletions(-) diff --git a/src/config/lib.rs b/src/config/lib.rs index e6be59fd..eec1f495 100644 --- a/src/config/lib.rs +++ b/src/config/lib.rs @@ -12,6 +12,8 @@ struct YamlCompatiblePipeline { r#type: PipelineType, #[serde(with = "serde_yaml::with::singleton_map_recursive")] plugins: Vec, + #[serde(default)] + guards: Vec, #[serde(default = "default_enabled_true_lib")] #[allow(dead_code)] enabled: bool, // Keep for YAML parsing, but won't be mapped to core Pipeline @@ -86,7 +88,7 @@ pub fn load_config(path: &str) -> Result, - pub post_call: Vec, + pub all_guards: Arc>, + pub pipeline_guard_names: Vec, pub client: Arc, } diff --git a/src/management/services/config_provider_service.rs b/src/management/services/config_provider_service.rs index 2ad80bef..94b31331 100644 --- a/src/management/services/config_provider_service.rs +++ b/src/management/services/config_provider_service.rs @@ -257,6 +257,7 @@ impl ConfigProviderService { name: dto.name, r#type: core_pipeline_type, plugins: core_plugins, + guards: vec![], }) } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 8921f562..f9d30858 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,12 +1,12 @@ use crate::config::models::PipelineType; -use crate::guardrails::api_control::split_guards_by_mode; +use crate::guardrails::api_control::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; use crate::guardrails::executor::execute_guards; use crate::guardrails::input_extractor::{ extract_post_call_input_from_completion, extract_post_call_input_from_completion_response, extract_pre_call_input, extract_pre_call_input_from_completion_request, }; use crate::guardrails::providers::GuardrailClient; -use crate::guardrails::types::{GuardrailsConfig, GuardrailsOutcome, Guardrails}; +use crate::guardrails::types::{Guard, GuardrailsConfig, GuardrailsOutcome, Guardrails}; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; use crate::models::embeddings::EmbeddingsRequest; @@ -19,6 +19,7 @@ use crate::{ models::chat::ChatCompletionRequest, }; use async_stream::stream; +use axum::http::HeaderMap; use axum::response::sse::{Event, KeepAlive}; use axum::response::{IntoResponse, Response, Sse}; use axum::{ @@ -33,12 +34,8 @@ use reqwest_streams::error::StreamBodyError; use serde_json::json; use std::sync::Arc; -/// Build a PipelineGuardrails from config, resolving provider defaults for api_base/api_key. -pub fn build_pipeline_guardrails(config: &GuardrailsConfig) -> Option> { - if config.guards.is_empty() { - return None; - } - +/// Resolve provider defaults (api_base/api_key) for all guards in the config. +pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { let mut guards = config.guards.clone(); for guard in &mut guards { if guard.api_base.is_none() || guard.api_key.is_none() { @@ -52,16 +49,35 @@ pub fn build_pipeline_guardrails(config: &GuardrailsConfig) -> Option Option<(Arc>, Arc)> { + if config.guards.is_empty() { + return None; + } + let all_guards = Arc::new(resolve_guard_defaults(config)); let client: Arc = Arc::new(crate::guardrails::providers::traceloop::TraceloopClient::new()); + Some((all_guards, client)) +} - Some(Arc::new(Guardrails { - pre_call, - post_call, - client, - })) +/// Build per-pipeline Guardrails from shared resources. +/// `shared` contains the Arc-wrapped guards and client built once by `build_guardrail_resources`. +pub fn build_pipeline_guardrails( + shared: &(Arc>, Arc), + pipeline_guard_names: &[String], +) -> Arc { + Arc::new(Guardrails { + all_guards: shared.0.clone(), + pipeline_guard_names: pipeline_guard_names.to_vec(), + client: shared.1.clone(), + }) } pub fn blocked_response(outcome: &GuardrailsOutcome) -> Response { @@ -95,9 +111,10 @@ pub fn warning_header_value(outcome: &GuardrailsOutcome) -> String { pub fn create_pipeline( pipeline: &Pipeline, model_registry: &ModelRegistry, - guardrails_config: Option<&GuardrailsConfig>, + guardrail_resources: Option<&(Arc>, Arc)>, ) -> Router { - let guardrails = guardrails_config.and_then(build_pipeline_guardrails); + let guardrails: Option> = guardrail_resources + .map(|shared| build_pipeline_guardrails(shared, &pipeline.guards)); let mut router = Router::new(); let available_models: Vec = pipeline @@ -134,14 +151,14 @@ pub fn create_pipeline( let gr = guardrails.clone(); router.route( "/chat/completions", - post(move |state, payload| chat_completions(state, payload, models, gr)), + post(move |state, headers, payload| chat_completions(state, headers, payload, models, gr)), ) } PipelineType::Completion => { let gr = guardrails.clone(); router.route( "/completions", - post(move |state, payload| completions(state, payload, models, gr)), + post(move |state, headers, payload| completions(state, headers, payload, models, gr)), ) } PipelineType::Embeddings => router.route( @@ -179,25 +196,48 @@ fn trace_and_stream( } } +/// Resolve guards for this request by merging pipeline guards with header-specified guards. +fn resolve_request_guards( + gr: &Guardrails, + headers: &HeaderMap, +) -> (Vec, Vec) { + let header_guard_names = headers + .get("x-traceloop-guardrails") + .and_then(|v| v.to_str().ok()) + .map(parse_guardrails_header) + .unwrap_or_default(); + + let pipeline_names: Vec<&str> = gr.pipeline_guard_names.iter().map(|s| s.as_str()).collect(); + let header_names: Vec<&str> = header_guard_names.iter().map(|s| s.as_str()).collect(); + let resolved = resolve_guards_by_name(&gr.all_guards, &pipeline_names, &header_names, &[]); + split_guards_by_mode(&resolved) +} + pub async fn chat_completions( State(model_registry): State>, + headers: HeaderMap, Json(payload): Json, model_keys: Vec, guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("chat", &payload); + // Resolve guards for this request (pipeline + header) + let (pre_call, post_call) = match guardrails.as_ref() { + Some(gr) => resolve_request_guards(gr, &headers), + None => (vec![], vec![]), + }; + // Pre-call guardrails let mut pre_warnings = Vec::new(); - if let Some(ref gr) = guardrails { - if !gr.pre_call.is_empty() { - let input = extract_pre_call_input(&payload); - let outcome = execute_guards(&gr.pre_call, &input, gr.client.as_ref()).await; - if outcome.blocked { - return Ok(blocked_response(&outcome)); - } - pre_warnings = outcome.warnings; - } + if !pre_call.is_empty() { + let gr = guardrails.as_ref().unwrap(); + let input = extract_pre_call_input(&payload); + let outcome = execute_guards(&pre_call, &input, gr.client.as_ref()).await; + if outcome.blocked { + return Ok(blocked_response(&outcome)); + } + pre_warnings = outcome.warnings; } for model_key in model_keys { @@ -218,31 +258,30 @@ pub async fn chat_completions( tracer.log_success(&completion); // Post-call guardrails (non-streaming) - if let Some(ref gr) = guardrails { - if !gr.post_call.is_empty() { - let response_text = extract_post_call_input_from_completion(&completion); - let outcome = - execute_guards(&gr.post_call, &response_text, gr.client.as_ref()).await; - if outcome.blocked { - return Ok(blocked_response(&outcome)); - } - if !outcome.warnings.is_empty() || !pre_warnings.is_empty() { - let mut all_warnings = pre_warnings; - all_warnings.extend(outcome.warnings); - let combined = GuardrailsOutcome { - results: vec![], - blocked: false, - blocking_guard: None, - warnings: all_warnings, - }; - let header_val = warning_header_value(&combined); - let mut response = Json(completion).into_response(); - response.headers_mut().insert( - "X-Traceloop-Guardrail-Warning", - header_val.parse().unwrap(), - ); - return Ok(response); - } + if !post_call.is_empty() { + let gr = guardrails.as_ref().unwrap(); + let response_text = extract_post_call_input_from_completion(&completion); + let outcome = + execute_guards(&post_call, &response_text, gr.client.as_ref()).await; + if outcome.blocked { + return Ok(blocked_response(&outcome)); + } + if !outcome.warnings.is_empty() || !pre_warnings.is_empty() { + let mut all_warnings = pre_warnings; + all_warnings.extend(outcome.warnings); + let combined = GuardrailsOutcome { + results: vec![], + blocked: false, + blocking_guard: None, + warnings: all_warnings, + }; + let header_val = warning_header_value(&combined); + let mut response = Json(completion).into_response(); + response.headers_mut().insert( + "X-Traceloop-Guardrail-Warning", + header_val.parse().unwrap(), + ); + return Ok(response); } } @@ -280,23 +319,29 @@ pub async fn chat_completions( pub async fn completions( State(model_registry): State>, + headers: HeaderMap, Json(payload): Json, model_keys: Vec, guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("completion", &payload); + // Resolve guards for this request (pipeline + header) + let (pre_call, post_call) = match guardrails.as_ref() { + Some(gr) => resolve_request_guards(gr, &headers), + None => (vec![], vec![]), + }; + // Pre-call guardrails let mut pre_warnings = Vec::new(); - if let Some(ref gr) = guardrails { - if !gr.pre_call.is_empty() { - let input = extract_pre_call_input_from_completion_request(&payload); - let outcome = execute_guards(&gr.pre_call, &input, gr.client.as_ref()).await; - if outcome.blocked { - return Ok(blocked_response(&outcome)); - } - pre_warnings = outcome.warnings; - } + if !pre_call.is_empty() { + let gr = guardrails.as_ref().unwrap(); + let input = extract_pre_call_input_from_completion_request(&payload); + let outcome = execute_guards(&pre_call, &input, gr.client.as_ref()).await; + if outcome.blocked { + return Ok(blocked_response(&outcome)); + } + pre_warnings = outcome.warnings; } for model_key in model_keys { @@ -312,31 +357,30 @@ pub async fn completions( tracer.log_success(&response); // Post-call guardrails - if let Some(ref gr) = guardrails { - if !gr.post_call.is_empty() { - let response_text = extract_post_call_input_from_completion_response(&response); - let outcome = - execute_guards(&gr.post_call, &response_text, gr.client.as_ref()).await; - if outcome.blocked { - return Ok(blocked_response(&outcome)); - } - if !outcome.warnings.is_empty() || !pre_warnings.is_empty() { - let mut all_warnings = pre_warnings; - all_warnings.extend(outcome.warnings); - let combined = GuardrailsOutcome { - results: vec![], - blocked: false, - blocking_guard: None, - warnings: all_warnings, - }; - let header_val = warning_header_value(&combined); - let mut resp = Json(response).into_response(); - resp.headers_mut().insert( - "X-Traceloop-Guardrail-Warning", - header_val.parse().unwrap(), - ); - return Ok(resp); - } + if !post_call.is_empty() { + let gr = guardrails.as_ref().unwrap(); + let response_text = extract_post_call_input_from_completion_response(&response); + let outcome = + execute_guards(&post_call, &response_text, gr.client.as_ref()).await; + if outcome.blocked { + return Ok(blocked_response(&outcome)); + } + if !outcome.warnings.is_empty() || !pre_warnings.is_empty() { + let mut all_warnings = pre_warnings; + all_warnings.extend(outcome.warnings); + let combined = GuardrailsOutcome { + results: vec![], + blocked: false, + blocking_guard: None, + warnings: all_warnings, + }; + let header_val = warning_header_value(&combined); + let mut resp = Json(response).into_response(); + resp.headers_mut().insert( + "X-Traceloop-Guardrail-Warning", + header_val.parse().unwrap(), + ); + return Ok(resp); } } @@ -491,6 +535,7 @@ mod tests { plugins: vec![PluginConfig::ModelRouter { models: model_keys.into_iter().map(|s| s.to_string()).collect(), }], + guards: vec![], } } diff --git a/src/state.rs b/src/state.rs index 8c124d3b..63717080 100644 --- a/src/state.rs +++ b/src/state.rs @@ -161,10 +161,16 @@ impl AppState { _provider_registry: &Arc, model_registry: &Arc, ) -> axum::Router { - use crate::pipelines::pipeline::create_pipeline; + use crate::pipelines::pipeline::{build_guardrail_resources, create_pipeline}; debug!("Building router with {} pipelines", config.pipelines.len()); + // Build shared guardrail resources once for all pipelines + let guardrail_resources = config + .guardrails + .as_ref() + .and_then(build_guardrail_resources); + let (default_pipeline, other_pipelines): (Vec<_>, Vec<_>) = config .pipelines .iter() @@ -180,7 +186,7 @@ impl AppState { default_pipeline.name ); let pipeline_router = - create_pipeline(default_pipeline, model_registry, config.guardrails.as_ref()); + create_pipeline(default_pipeline, model_registry, guardrail_resources.as_ref()); pipeline_routers.push(pipeline_router); pipeline_names.push(default_pipeline.name.clone()); } @@ -190,7 +196,7 @@ impl AppState { debug!("Adding pipeline '{}' to router at index {}", name, idx + 1); let pipeline_router = - create_pipeline(pipeline, model_registry, config.guardrails.as_ref()); + create_pipeline(pipeline, model_registry, guardrail_resources.as_ref()); pipeline_routers.push(pipeline_router); pipeline_names.push(name.clone()); } diff --git a/src/types/mod.rs b/src/types/mod.rs index b2e3d63c..f4b1467f 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -149,8 +149,12 @@ pub struct Pipeline { // #[serde(with = "serde_yaml::with::singleton_map_recursive")] #[serde(default, skip_serializing_if = "Vec::is_empty")] pub plugins: Vec, - // ee_id: Option, // Removed - // enabled: bool, // Removed + + /// Guard names associated with this pipeline. Guards listed here + /// are always executed for every request to this pipeline. Additional + /// guards can be added per-request via the `X-Traceloop-Guardrails` header. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub guards: Vec, } #[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Hash)] diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 628f4dc1..e71911fa 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -5,7 +5,7 @@ use hub_lib::guardrails::input_extractor::{ use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::stream_buffer::extract_text_from_chunks; use hub_lib::guardrails::types::*; -use hub_lib::pipelines::pipeline::build_pipeline_guardrails; +use hub_lib::pipelines::pipeline::{build_guardrail_resources, build_pipeline_guardrails}; use serde_json::json; use wiremock::matchers; @@ -375,21 +375,25 @@ guardrails: assert_eq!(gr.guards[1].api_key.as_deref(), Some("override-key")); // Build pipeline guardrails - should resolve provider defaults - let pipeline_gr = build_pipeline_guardrails(&gr).unwrap(); - assert_eq!(pipeline_gr.pre_call.len(), 1); - assert_eq!(pipeline_gr.post_call.len(), 1); + let shared = build_guardrail_resources(&gr).unwrap(); + let guard_names: Vec = gr.guards.iter().map(|g| g.name.clone()).collect(); + let pipeline_gr = build_pipeline_guardrails(&shared, &guard_names); + assert_eq!(pipeline_gr.all_guards.len(), 2); + assert_eq!(pipeline_gr.pipeline_guard_names.len(), 2); // Provider api_base should be resolved for guards that don't override + let pre_guard = pipeline_gr.all_guards.iter().find(|g| g.mode == GuardMode::PreCall).unwrap(); + let post_guard = pipeline_gr.all_guards.iter().find(|g| g.mode == GuardMode::PostCall).unwrap(); assert_eq!( - pipeline_gr.pre_call[0].api_base.as_deref(), + pre_guard.api_base.as_deref(), Some("https://api.traceloop.com") ); assert_eq!( - pipeline_gr.pre_call[0].api_key.as_deref(), + pre_guard.api_key.as_deref(), Some("resolved-key-123") ); // Guard with override keeps its own api_key assert_eq!( - pipeline_gr.post_call[0].api_key.as_deref(), + post_guard.api_key.as_deref(), Some("override-key") ); @@ -533,10 +537,10 @@ pipelines: let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); assert!(config.guardrails.is_none()); - // build_pipeline_guardrails with None returns None - let gr = config + // build_guardrail_resources with None guardrails returns None + let shared = config .guardrails .as_ref() - .and_then(build_pipeline_guardrails); - assert!(gr.is_none()); + .and_then(build_guardrail_resources); + assert!(shared.is_none()); } diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs index 64152934..4ce2f1ab 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_pipeline.rs @@ -1,12 +1,13 @@ +use hub_lib::guardrails::api_control::{resolve_guards_by_name, split_guards_by_mode}; use hub_lib::guardrails::executor::execute_guards; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::types::*; use hub_lib::pipelines::pipeline::{ - blocked_response, build_pipeline_guardrails, warning_header_value, + blocked_response, build_guardrail_resources, build_pipeline_guardrails, resolve_guard_defaults, + warning_header_value, }; use axum::body::to_bytes; -use axum::response::IntoResponse; use serde_json::json; use wiremock::matchers; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -189,12 +190,12 @@ async fn test_blocked_response_403_format() { #[tokio::test] async fn test_no_guardrails_passthrough() { - // Empty guardrails config -> build_pipeline_guardrails returns None + // Empty guardrails config -> build_guardrail_resources returns None let config = GuardrailsConfig { providers: vec![], guards: vec![], }; - let result = build_pipeline_guardrails(&config); + let result = build_guardrail_resources(&config); assert!(result.is_none()); // Config with no guards -> passthrough @@ -206,6 +207,319 @@ async fn test_no_guardrails_passthrough() { }], guards: vec![], }; - let result = build_pipeline_guardrails(&config_with_providers); + let result = build_guardrail_resources(&config_with_providers); assert!(result.is_none()); } + +// --------------------------------------------------------------------------- +// Pipeline-specific guard association tests +// --------------------------------------------------------------------------- + +fn test_guardrails_config() -> GuardrailsConfig { + GuardrailsConfig { + providers: vec![ProviderConfig { + name: "traceloop".to_string(), + api_base: "https://api.traceloop.com".to_string(), + api_key: "test-key".to_string(), + }], + guards: vec![ + Guard { + name: "pii-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "pii".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }, + Guard { + name: "toxicity-filter".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "toxicity".to_string(), + params: Default::default(), + mode: GuardMode::PostCall, + on_failure: OnFailure::Warn, + required: true, + api_base: None, + api_key: None, + }, + Guard { + name: "injection-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "injection".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }, + Guard { + name: "secrets-check".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "secrets".to_string(), + params: Default::default(), + mode: GuardMode::PostCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }, + ], + } +} + +#[test] +fn test_build_pipeline_guardrails_with_specific_guards() { + let config = test_guardrails_config(); + let shared = build_guardrail_resources(&config).unwrap(); + let pipeline_guards = vec!["pii-check".to_string(), "toxicity-filter".to_string()]; + let gr = build_pipeline_guardrails(&shared, &pipeline_guards); + + // all_guards should contain ALL guards from config, resolved with provider defaults + assert_eq!(gr.all_guards.len(), 4); + // pipeline_guard_names should only contain the ones specified + assert_eq!(gr.pipeline_guard_names, vec!["pii-check", "toxicity-filter"]); +} + +#[test] +fn test_build_pipeline_guardrails_empty_pipeline_guards() { + let config = test_guardrails_config(); + let shared = build_guardrail_resources(&config).unwrap(); + // Pipeline with no guards specified - shared resources still exist + // (header guards can still be used at request time) + let gr = build_pipeline_guardrails(&shared, &[]); + + assert_eq!(gr.all_guards.len(), 4); + assert!(gr.pipeline_guard_names.is_empty()); +} + +#[test] +fn test_build_pipeline_guardrails_resolves_provider_defaults() { + let config = test_guardrails_config(); + let shared = build_guardrail_resources(&config).unwrap(); + let gr = build_pipeline_guardrails(&shared, &["pii-check".to_string()]); + + // Guards should have provider api_base/api_key resolved + for guard in gr.all_guards.iter() { + assert_eq!(guard.api_base.as_deref(), Some("https://api.traceloop.com")); + assert_eq!(guard.api_key.as_deref(), Some("test-key")); + } +} + +#[test] +fn test_resolve_guard_defaults_preserves_guard_overrides() { + let config = GuardrailsConfig { + providers: vec![ProviderConfig { + name: "traceloop".to_string(), + api_base: "https://default.api.com".to_string(), + api_key: "default-key".to_string(), + }], + guards: vec![Guard { + name: "custom-guard".to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "custom".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: Some("https://custom.api.com".to_string()), + api_key: Some("custom-key".to_string()), + }], + }; + + let resolved = resolve_guard_defaults(&config); + assert_eq!(resolved[0].api_base.as_deref(), Some("https://custom.api.com")); + assert_eq!(resolved[0].api_key.as_deref(), Some("custom-key")); +} + +#[test] +fn test_pipeline_guards_resolved_at_request_time() { + // Simulates what happens at request time: merge pipeline + header guards + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + // Pipeline declares only pii-check + let pipeline_names = vec!["pii-check"]; + // Header adds injection-check + let header_names = vec!["injection-check"]; + + let resolved = resolve_guards_by_name( + &all_guards, + &pipeline_names, + &header_names, + &[], + ); + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].name, "pii-check"); + assert_eq!(resolved[1].name, "injection-check"); +} + +#[test] +fn test_pipeline_guards_plus_header_guards_split_by_mode() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + // Pipeline declares pii-check (pre_call) and toxicity-filter (post_call) + let pipeline_names = vec!["pii-check", "toxicity-filter"]; + // Header adds injection-check (pre_call) and secrets-check (post_call) + let header_names = vec!["injection-check", "secrets-check"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names, &[]); + assert_eq!(resolved.len(), 4); + + let (pre_call, post_call) = split_guards_by_mode(&resolved); + assert_eq!(pre_call.len(), 2); + assert_eq!(post_call.len(), 2); + assert!(pre_call.iter().any(|g| g.name == "pii-check")); + assert!(pre_call.iter().any(|g| g.name == "injection-check")); + assert!(post_call.iter().any(|g| g.name == "toxicity-filter")); + assert!(post_call.iter().any(|g| g.name == "secrets-check")); +} + +#[test] +fn test_header_guard_not_in_config_is_ignored() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + let pipeline_names = vec!["pii-check"]; + let header_names = vec!["nonexistent-guard"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names, &[]); + // Only pii-check should be resolved; nonexistent guard is silently ignored + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "pii-check"); +} + +#[test] +fn test_duplicate_guard_in_header_and_pipeline_deduped() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + let pipeline_names = vec!["pii-check", "toxicity-filter"]; + // Header specifies same guard as pipeline + let header_names = vec!["pii-check"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names, &[]); + assert_eq!(resolved.len(), 2); // pii-check only appears once +} + +#[test] +fn test_no_pipeline_guards_header_only() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + // Pipeline has no guards + let pipeline_names: Vec<&str> = vec![]; + // Header adds guards + let header_names = vec!["injection-check", "secrets-check"]; + + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names, &[]); + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].name, "injection-check"); + assert_eq!(resolved[1].name, "secrets-check"); +} + +#[test] +fn test_no_pipeline_guards_no_header_no_guards_executed() { + let config = test_guardrails_config(); + let all_guards = resolve_guard_defaults(&config); + + let resolved = resolve_guards_by_name(&all_guards, &[], &[], &[]); + assert!(resolved.is_empty()); + + let (pre_call, post_call) = split_guards_by_mode(&resolved); + assert!(pre_call.is_empty()); + assert!(post_call.is_empty()); +} + +#[test] +fn test_pipeline_guards_field_in_yaml_config() { + use std::io::Write; + use tempfile::NamedTempFile; + + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + guards: + - pii-check + - injection-check + plugins: + - model-router: + models: + - gpt-4 + - name: embeddings + type: embeddings + plugins: + - model-router: + models: + - gpt-4 +guardrails: + providers: + - name: traceloop + api_base: "https://api.traceloop.com" + api_key: "test-key" + guards: + - name: pii-check + provider: traceloop + evaluator_slug: pii + mode: pre_call + on_failure: block + - name: injection-check + provider: traceloop + evaluator_slug: injection + mode: pre_call + on_failure: block +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + + // Default pipeline should have guards + assert_eq!(config.pipelines[0].guards, vec!["pii-check", "injection-check"]); + // Embeddings pipeline should have no guards + assert!(config.pipelines[1].guards.is_empty()); +} + +#[test] +fn test_pipeline_guards_field_absent_defaults_to_empty() { + use std::io::Write; + use tempfile::NamedTempFile; + + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + assert!(config.pipelines[0].guards.is_empty()); +} diff --git a/tests/pipeline_header_routing_test.rs b/tests/pipeline_header_routing_test.rs index 02840f59..2e98cd32 100644 --- a/tests/pipeline_header_routing_test.rs +++ b/tests/pipeline_header_routing_test.rs @@ -25,6 +25,7 @@ fn create_test_config_with_multiple_pipelines() -> GatewayConfig { plugins: vec![PluginConfig::ModelRouter { models: vec!["test-model".to_string()], }], + guards: vec![], }; let pipeline2 = Pipeline { @@ -33,6 +34,7 @@ fn create_test_config_with_multiple_pipelines() -> GatewayConfig { plugins: vec![PluginConfig::ModelRouter { models: vec!["test-model".to_string()], }], + guards: vec![], }; GatewayConfig { @@ -76,6 +78,7 @@ async fn test_pipeline_header_routing_configuration_updates() { plugins: vec![PluginConfig::ModelRouter { models: vec!["test-model".to_string()], }], + guards: vec![], }; updated_config.pipelines.push(pipeline3); diff --git a/tests/router_cache_tests.rs b/tests/router_cache_tests.rs index bd84450e..4cef3445 100644 --- a/tests/router_cache_tests.rs +++ b/tests/router_cache_tests.rs @@ -28,6 +28,7 @@ async fn test_router_always_available() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -63,6 +64,7 @@ async fn test_configuration_change_detection() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -153,6 +155,7 @@ async fn test_concurrent_router_access() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -226,6 +229,7 @@ async fn test_pipeline_with_failing_tracing_endpoint() { models: vec!["gpt-4".to_string()], }, ], + guards: vec![], }], }; @@ -284,6 +288,7 @@ async fn test_tracing_isolation_between_pipelines() { models: vec!["gpt-4".to_string()], }, ], + guards: vec![], }, // Pipeline without tracing Pipeline { @@ -292,6 +297,7 @@ async fn test_tracing_isolation_between_pipelines() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }, ], }; diff --git a/tests/router_integration_test.rs b/tests/router_integration_test.rs index dd3060e5..5054f508 100644 --- a/tests/router_integration_test.rs +++ b/tests/router_integration_test.rs @@ -42,6 +42,7 @@ async fn test_router_integration_flow() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -84,6 +85,7 @@ async fn test_router_integration_flow() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -126,6 +128,7 @@ async fn test_router_integration_flow() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }, Pipeline { name: "fast".to_string(), @@ -133,6 +136,7 @@ async fn test_router_integration_flow() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-3.5-turbo".to_string()], }], + guards: vec![], }, ], }; @@ -181,6 +185,7 @@ async fn test_concurrent_configuration_updates() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; @@ -213,6 +218,7 @@ async fn test_concurrent_configuration_updates() { plugins: vec![PluginConfig::ModelRouter { models: vec![format!("model-{}", i)], }], + guards: vec![], }], }; diff --git a/tests/unified_openapi_test.rs b/tests/unified_openapi_test.rs index 9bbbd8cd..400b5b23 100644 --- a/tests/unified_openapi_test.rs +++ b/tests/unified_openapi_test.rs @@ -144,6 +144,7 @@ async fn test_router_creation_no_conflicts() { plugins: vec![PluginConfig::ModelRouter { models: vec!["gpt-4".to_string()], }], + guards: vec![], }], }; From d0bf203d490fe9cb022f4cd7fe0f0b6ad9285681 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:15:46 +0200 Subject: [PATCH 05/59] redesign --- Cargo.lock | 1 + Cargo.toml | 1 + src/config/validation.rs | 39 +++- src/guardrails/api_control.rs | 23 +- src/guardrails/builder.rs | 49 +++++ src/guardrails/executor.rs | 31 ++- src/guardrails/guardrails_orchestrator.rs | 152 +++++++++++++ src/guardrails/mod.rs | 2 + src/guardrails/providers/traceloop.rs | 23 +- src/guardrails/types.rs | 41 ++-- src/pipelines/pipeline.rs | 247 ++++------------------ src/state.rs | 3 +- tests/guardrails/test_e2e.rs | 4 +- tests/guardrails/test_executor.rs | 4 +- tests/guardrails/test_pipeline.rs | 25 +-- 15 files changed, 376 insertions(+), 269 deletions(-) create mode 100644 src/guardrails/builder.rs create mode 100644 src/guardrails/guardrails_orchestrator.rs diff --git a/Cargo.lock b/Cargo.lock index f2f8d530..7349a059 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2302,6 +2302,7 @@ dependencies = [ "tempfile", "testcontainers", "testcontainers-modules", + "thiserror 2.0.17", "tokio", "tower 0.5.3", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index 4c2a8838..32a7cbc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,7 @@ utoipa = { version = "5.0.0", features = ["axum_extras", "chrono", "uuid", "open utoipa-swagger-ui = { version = "8", features = ["axum"] } utoipa-scalar = { version = "0.3.0", features = ["axum"] } log = "0.4" +thiserror = "2" [lib] name = "hub_lib" diff --git a/src/config/validation.rs b/src/config/validation.rs index 4c17eb4e..86771cf4 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -37,10 +37,41 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec } } - // Add more validation checks as needed: - // - Duplicate keys for providers, models, pipelines? - // - Empty names/keys? - // - Specific validation for provider params based on type (more complex, might be out of scope for basic validation) + // Check 3: Guardrails validation + if let Some(gr_config) = &config.guardrails { + // Guard provider references must exist in guardrails.providers + let gr_provider_names: HashSet<&String> = + gr_config.providers.iter().map(|p| &p.name).collect(); + for guard in &gr_config.guards { + if !gr_provider_names.contains(&guard.provider) { + errors.push(format!( + "Guard '{}' references non-existent guardrail provider '{}'.", + guard.name, guard.provider + )); + } + } + + // Pipeline guard references must exist in guardrails.guards + let guard_names: HashSet<&String> = gr_config.guards.iter().map(|g| &g.name).collect(); + for pipeline in &config.pipelines { + for guard_name in &pipeline.guards { + if !guard_names.contains(guard_name) { + errors.push(format!( + "Pipeline '{}' references non-existent guard '{}'.", + pipeline.name, guard_name + )); + } + } + } + + // Guard names must be unique + let mut seen_guard_names = HashSet::new(); + for guard in &gr_config.guards { + if !seen_guard_names.insert(&guard.name) { + errors.push(format!("Duplicate guard name: '{}'.", guard.name)); + } + } + } if errors.is_empty() { Ok(()) diff --git a/src/guardrails/api_control.rs b/src/guardrails/api_control.rs index 0a1c4357..15491f20 100644 --- a/src/guardrails/api_control.rs +++ b/src/guardrails/api_control.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use super::types::{Guard, GuardMode}; @@ -26,31 +26,30 @@ pub fn parse_guardrails_from_payload(payload: &serde_json::Value) -> Vec } /// Resolve the final set of guards to execute by merging pipeline, header, and payload sources. -/// Guards are additive and deduplicated by name. +/// Guards are additive and deduplicated by name. Uses HashMap for O(1) guard lookup. pub fn resolve_guards_by_name( all_guards: &[Guard], pipeline_names: &[&str], header_names: &[&str], payload_names: &[&str], ) -> Vec { + let guard_map: HashMap<&str, &Guard> = + all_guards.iter().map(|g| (g.name.as_str(), g)).collect(); + let mut seen = HashSet::new(); let mut resolved = Vec::new(); - // Collect all requested names, pipeline first, then header, then payload - let all_names: Vec<&str> = pipeline_names + let all_names = pipeline_names .iter() .chain(header_names.iter()) .chain(payload_names.iter()) - .copied() - .collect(); + .copied(); for name in all_names { - if seen.contains(name) { - continue; - } - if let Some(guard) = all_guards.iter().find(|g| g.name == name) { - seen.insert(name); - resolved.push(guard.clone()); + if seen.insert(name) { + if let Some(guard) = guard_map.get(name) { + resolved.push((*guard).clone()); + } } } diff --git a/src/guardrails/builder.rs b/src/guardrails/builder.rs new file mode 100644 index 00000000..5b710402 --- /dev/null +++ b/src/guardrails/builder.rs @@ -0,0 +1,49 @@ +use std::sync::Arc; + +use super::types::{Guard, GuardrailResources, GuardrailsConfig, Guardrails}; + +/// Resolve provider defaults (api_base/api_key) for all guards in the config. +pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { + let mut guards = config.guards.clone(); + for guard in &mut guards { + if guard.api_base.is_none() || guard.api_key.is_none() { + if let Some(provider) = config.providers.iter().find(|p| p.name == guard.provider) { + if guard.api_base.is_none() { + guard.api_base = Some(provider.api_base.clone()); + } + if guard.api_key.is_none() { + guard.api_key = Some(provider.api_key.clone()); + } + } + } + } + guards +} + +/// Build the shared guardrail resources (resolved guards + client). +/// Returns None if the config has no guards. +/// Called once per router build; the result is shared across all pipelines. +pub fn build_guardrail_resources( + config: &GuardrailsConfig, +) -> Option { + if config.guards.is_empty() { + return None; + } + let all_guards = Arc::new(resolve_guard_defaults(config)); + let client: Arc = + Arc::new(super::providers::traceloop::TraceloopClient::new()); + Some((all_guards, client)) +} + +/// Build per-pipeline Guardrails from shared resources. +/// `shared` contains the Arc-wrapped guards and client built once by `build_guardrail_resources`. +pub fn build_pipeline_guardrails( + shared: &GuardrailResources, + pipeline_guard_names: &[String], +) -> Arc { + Arc::new(Guardrails { + all_guards: shared.0.clone(), + pipeline_guard_names: pipeline_guard_names.to_vec(), + client: shared.1.clone(), + }) +} diff --git a/src/guardrails/executor.rs b/src/guardrails/executor.rs index a5eb783f..ec38b6c9 100644 --- a/src/guardrails/executor.rs +++ b/src/guardrails/executor.rs @@ -1,7 +1,8 @@ use futures::future::join_all; +use tracing::{debug, warn}; use super::providers::GuardrailClient; -use super::types::{Guard, GuardResult, GuardrailsOutcome, OnFailure}; +use super::types::{Guard, GuardResult, GuardWarning, GuardrailsOutcome, OnFailure}; /// Execute a set of guardrails against the given input text. /// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. @@ -10,10 +11,29 @@ pub async fn execute_guards( input: &str, client: &dyn GuardrailClient, ) -> GuardrailsOutcome { + debug!(guard_count = guards.len(), "Executing guardrails"); + let futures: Vec<_> = guards .iter() .map(|guard| async move { + let start = std::time::Instant::now(); let result = client.evaluate(guard, input).await; + let elapsed = start.elapsed(); + match &result { + Ok(resp) => debug!( + guard = %guard.name, + pass = resp.pass, + elapsed_ms = elapsed.as_millis(), + "Guard evaluation complete" + ), + Err(err) => warn!( + guard = %guard.name, + error = %err, + required = guard.required, + elapsed_ms = elapsed.as_millis(), + "Guard evaluation failed" + ), + } (guard, result) }) .collect(); @@ -47,7 +67,10 @@ pub async fn execute_guards( } } OnFailure::Warn => { - warnings.push(format!("Guard '{}' failed with warning", guard.name)); + warnings.push(GuardWarning { + guard_name: guard.name.clone(), + reason: "failed".to_string(), + }); } } } @@ -69,6 +92,10 @@ pub async fn execute_guards( } } + if blocked { + warn!(blocking_guard = ?blocking_guard, "Request blocked by guardrail"); + } + GuardrailsOutcome { results, blocked, diff --git a/src/guardrails/guardrails_orchestrator.rs b/src/guardrails/guardrails_orchestrator.rs new file mode 100644 index 00000000..200f5b31 --- /dev/null +++ b/src/guardrails/guardrails_orchestrator.rs @@ -0,0 +1,152 @@ +use std::collections::HashSet; + +use axum::http::HeaderMap; +use axum::response::{IntoResponse, Response}; +use axum::http::StatusCode; +use axum::Json; +use serde_json::json; +use tracing::warn; + +use super::api_control::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; +use super::executor::execute_guards; +use super::providers::GuardrailClient; +use super::types::{Guard, GuardWarning, Guardrails}; + +/// Result of running pre-call or post-call guards. +pub struct GuardPhaseResult { + pub blocked_response: Option, + pub warnings: Vec, +} + +/// Orchestrates guardrail execution across pre-call and post-call phases. +/// Shared between chat_completions and completions handlers. +pub struct GuardrailOrchestrator<'a> { + pre_call: Vec, + post_call: Vec, + client: &'a dyn GuardrailClient, +} + +impl<'a> GuardrailOrchestrator<'a> { + /// Create an orchestrator by resolving guards from pipeline config + request headers. + /// Returns None if no guards are active for this request. + pub fn new(guardrails: Option<&'a Guardrails>, headers: &HeaderMap) -> Option { + let gr = guardrails?; + let (pre_call, post_call) = resolve_request_guards(gr, headers); + if pre_call.is_empty() && post_call.is_empty() { + return None; + } + Some(Self { + pre_call, + post_call, + client: gr.client.as_ref(), + }) + } + + /// Run pre-call guards against the extracted input text. + pub async fn run_pre_call(&self, input: &str) -> GuardPhaseResult { + if self.pre_call.is_empty() { + return GuardPhaseResult { + blocked_response: None, + warnings: Vec::new(), + }; + } + let outcome = execute_guards(&self.pre_call, input, self.client).await; + if outcome.blocked { + return GuardPhaseResult { + blocked_response: Some(blocked_response(&outcome.blocking_guard)), + warnings: Vec::new(), + }; + } + GuardPhaseResult { + blocked_response: None, + warnings: outcome.warnings, + } + } + + /// Run post-call guards against the LLM response text. + pub async fn run_post_call(&self, response_text: &str) -> GuardPhaseResult { + if self.post_call.is_empty() { + return GuardPhaseResult { + blocked_response: None, + warnings: Vec::new(), + }; + } + let outcome = execute_guards(&self.post_call, response_text, self.client).await; + if outcome.blocked { + return GuardPhaseResult { + blocked_response: Some(blocked_response(&outcome.blocking_guard)), + warnings: Vec::new(), + }; + } + GuardPhaseResult { + blocked_response: None, + warnings: outcome.warnings, + } + } + + /// Returns true if post-call guards are configured for this request. + pub fn has_post_call_guards(&self) -> bool { + !self.post_call.is_empty() + } + + /// Attach warning headers to a response if there are any warnings. + /// Returns the response unchanged if there are no warnings. + pub fn finalize_response(response: Response, warnings: &[GuardWarning]) -> Response { + if warnings.is_empty() { + return response; + } + let header_val = warning_header_value(warnings); + let mut response = response; + response.headers_mut().insert( + "X-Traceloop-Guardrail-Warning", + header_val.parse().unwrap(), + ); + response + } +} + +/// Build a 403 blocked response with the guard name. +pub fn blocked_response(blocking_guard: &Option) -> Response { + let guard_name = blocking_guard.as_deref().unwrap_or("unknown"); + let body = json!({ + "error": { + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), + } + }); + (StatusCode::FORBIDDEN, Json(body)).into_response() +} + +pub fn warning_header_value(warnings: &[GuardWarning]) -> String { + warnings + .iter() + .map(|w| format!("guardrail_name=\"{}\", reason=\"{}\"", w.guard_name, w.reason)) + .collect::>() + .join("; ") +} + +/// Resolve guards for this request by merging pipeline guards with header-specified guards. +fn resolve_request_guards(gr: &Guardrails, headers: &HeaderMap) -> (Vec, Vec) { + let header_guard_names = headers + .get("x-traceloop-guardrails") + .and_then(|v| v.to_str().ok()) + .map(parse_guardrails_header) + .unwrap_or_default(); + + let pipeline_names: Vec<&str> = gr.pipeline_guard_names.iter().map(|s| s.as_str()).collect(); + let header_names: Vec<&str> = header_guard_names.iter().map(|s| s.as_str()).collect(); + let resolved = resolve_guards_by_name(&gr.all_guards, &pipeline_names, &header_names, &[]); + + // Log unknown guard names from headers + if !header_guard_names.is_empty() { + let resolved_names: HashSet<&str> = resolved.iter().map(|g| g.name.as_str()).collect(); + for name in &header_guard_names { + if !resolved_names.contains(name.as_str()) { + warn!(guard_name = %name, "Unknown guard name in X-Traceloop-Guardrails header, ignoring"); + } + } + } + + split_guards_by_mode(&resolved) +} diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index 81ae9edf..165e70a4 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -1,6 +1,8 @@ pub mod api_control; +pub mod builder; pub mod executor; pub mod input_extractor; +pub mod guardrails_orchestrator; pub mod providers; pub mod response_parser; pub mod stream_buffer; diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 827d8e5d..74e7226f 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use serde_json::json; use std::time::Duration; +use tracing::debug; use super::GuardrailClient; use crate::guardrails::response_parser::parse_evaluator_http_response; @@ -12,6 +13,12 @@ pub struct TraceloopClient { http_client: reqwest::Client, } +impl Default for TraceloopClient { + fn default() -> Self { + Self::new() + } +} + impl TraceloopClient { pub fn new() -> Self { Self { @@ -50,6 +57,8 @@ impl GuardrailClient for TraceloopClient { "config": config, }); + debug!(guard = %guard.name, slug = %guard.evaluator_slug, %url, "Calling evaluator API"); + let response = self .http_client .post(&url) @@ -57,20 +66,10 @@ impl GuardrailClient for TraceloopClient { .header("Content-Type", "application/json") .json(&body) .send() - .await - .map_err(|e| { - if e.is_timeout() { - GuardrailError::Timeout(e.to_string()) - } else { - GuardrailError::Unavailable(e.to_string()) - } - })?; + .await?; let status = response.status().as_u16(); - let response_body = response - .text() - .await - .map_err(|e| GuardrailError::Unavailable(e.to_string()))?; + let response_body = response.text().await?; parse_evaluator_http_response(status, &response_body) } diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index 81adcd09..695f05d1 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -2,9 +2,14 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use thiserror::Error; use super::providers::GuardrailClient; +/// Shared guardrail resources: resolved guards + client. +/// Built once per router build and shared across all pipelines. +pub type GuardrailResources = (Arc>, Arc); + fn default_on_failure() -> OnFailure { OnFailure::Warn } @@ -52,6 +57,8 @@ pub struct Guard { pub api_key: Option, } +impl Eq for Guard {} + impl Hash for Guard { fn hash(&self, state: &mut H) { self.name.hash(state); @@ -111,37 +118,45 @@ pub enum GuardResult { }, } +#[derive(Debug, Clone)] +pub struct GuardWarning { + pub guard_name: String, + pub reason: String, +} + #[derive(Debug, Clone)] pub struct GuardrailsOutcome { pub results: Vec, pub blocked: bool, pub blocking_guard: Option, - pub warnings: Vec, + pub warnings: Vec, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Error)] pub enum GuardrailError { + #[error("Evaluator unavailable: {0}")] Unavailable(String), + + #[error("HTTP error {status}: {body}")] HttpError { status: u16, body: String }, + + #[error("Timeout: {0}")] Timeout(String), + + #[error("Parse error: {0}")] ParseError(String), } -impl std::fmt::Display for GuardrailError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - GuardrailError::Unavailable(msg) => write!(f, "Evaluator unavailable: {msg}"), - GuardrailError::HttpError { status, body } => { - write!(f, "HTTP error {status}: {body}") - } - GuardrailError::Timeout(msg) => write!(f, "Timeout: {msg}"), - GuardrailError::ParseError(msg) => write!(f, "Parse error: {msg}"), +impl From for GuardrailError { + fn from(e: reqwest::Error) -> Self { + if e.is_timeout() { + GuardrailError::Timeout(e.to_string()) + } else { + GuardrailError::Unavailable(e.to_string()) } } } -impl std::error::Error for GuardrailError {} - /// Guardrails state attached to a pipeline, containing resolved guards and client. /// /// `all_guards` and `client` are shared across all pipelines via `Arc` (built once). diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index f9d30858..52f54dd3 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,12 +1,10 @@ use crate::config::models::PipelineType; -use crate::guardrails::api_control::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; -use crate::guardrails::executor::execute_guards; use crate::guardrails::input_extractor::{ extract_post_call_input_from_completion, extract_post_call_input_from_completion_response, extract_pre_call_input, extract_pre_call_input_from_completion_request, }; -use crate::guardrails::providers::GuardrailClient; -use crate::guardrails::types::{Guard, GuardrailsConfig, GuardrailsOutcome, Guardrails}; +use crate::guardrails::guardrails_orchestrator::GuardrailOrchestrator; +use crate::guardrails::types::{GuardrailResources, Guardrails}; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; use crate::models::embeddings::EmbeddingsRequest; @@ -31,87 +29,18 @@ use axum::{ use futures::stream::BoxStream; use futures::{Stream, StreamExt}; use reqwest_streams::error::StreamBodyError; -use serde_json::json; use std::sync::Arc; -/// Resolve provider defaults (api_base/api_key) for all guards in the config. -pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { - let mut guards = config.guards.clone(); - for guard in &mut guards { - if guard.api_base.is_none() || guard.api_key.is_none() { - if let Some(provider) = config.providers.iter().find(|p| p.name == guard.provider) { - if guard.api_base.is_none() { - guard.api_base = Some(provider.api_base.clone()); - } - if guard.api_key.is_none() { - guard.api_key = Some(provider.api_key.clone()); - } - } - } - } - guards -} - -/// Build the shared guardrail resources (resolved guards + client). -/// Returns None if the config has no guards. -/// Called once per router build; the result is shared across all pipelines. -pub fn build_guardrail_resources( - config: &GuardrailsConfig, -) -> Option<(Arc>, Arc)> { - if config.guards.is_empty() { - return None; - } - let all_guards = Arc::new(resolve_guard_defaults(config)); - let client: Arc = - Arc::new(crate::guardrails::providers::traceloop::TraceloopClient::new()); - Some((all_guards, client)) -} - -/// Build per-pipeline Guardrails from shared resources. -/// `shared` contains the Arc-wrapped guards and client built once by `build_guardrail_resources`. -pub fn build_pipeline_guardrails( - shared: &(Arc>, Arc), - pipeline_guard_names: &[String], -) -> Arc { - Arc::new(Guardrails { - all_guards: shared.0.clone(), - pipeline_guard_names: pipeline_guard_names.to_vec(), - client: shared.1.clone(), - }) -} - -pub fn blocked_response(outcome: &GuardrailsOutcome) -> Response { - let guard_name = outcome.blocking_guard.as_deref().unwrap_or("unknown"); - let body = json!({ - "error": { - "type": "guardrail_blocked", - "guardrail": guard_name, - "message": format!("Request blocked by guardrail '{guard_name}'"), - } - }); - (StatusCode::FORBIDDEN, Json(body)).into_response() -} - -pub fn warning_header_value(outcome: &GuardrailsOutcome) -> String { - outcome - .warnings - .iter() - .map(|w| { - // Extract guard name from the warning string "Guard 'name' failed with warning" - let name = w - .strip_prefix("Guard '") - .and_then(|s| s.strip_suffix("' failed with warning")) - .unwrap_or("unknown"); - format!("guardrail_name=\"{name}\", reason=\"failed\"") - }) - .collect::>() - .join("; ") -} +// Re-export builder and orchestrator functions for backward compatibility with tests +pub use crate::guardrails::builder::{ + build_guardrail_resources, build_pipeline_guardrails, resolve_guard_defaults, +}; +pub use crate::guardrails::guardrails_orchestrator::{blocked_response, warning_header_value}; pub fn create_pipeline( pipeline: &Pipeline, model_registry: &ModelRegistry, - guardrail_resources: Option<&(Arc>, Arc)>, + guardrail_resources: Option<&GuardrailResources>, ) -> Router { let guardrails: Option> = guardrail_resources .map(|shared| build_pipeline_guardrails(shared, &pipeline.guards)); @@ -196,23 +125,6 @@ fn trace_and_stream( } } -/// Resolve guards for this request by merging pipeline guards with header-specified guards. -fn resolve_request_guards( - gr: &Guardrails, - headers: &HeaderMap, -) -> (Vec, Vec) { - let header_guard_names = headers - .get("x-traceloop-guardrails") - .and_then(|v| v.to_str().ok()) - .map(parse_guardrails_header) - .unwrap_or_default(); - - let pipeline_names: Vec<&str> = gr.pipeline_guard_names.iter().map(|s| s.as_str()).collect(); - let header_names: Vec<&str> = header_guard_names.iter().map(|s| s.as_str()).collect(); - let resolved = resolve_guards_by_name(&gr.all_guards, &pipeline_names, &header_names, &[]); - split_guards_by_mode(&resolved) -} - pub async fn chat_completions( State(model_registry): State>, headers: HeaderMap, @@ -221,30 +133,22 @@ pub async fn chat_completions( guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("chat", &payload); - - // Resolve guards for this request (pipeline + header) - let (pre_call, post_call) = match guardrails.as_ref() { - Some(gr) => resolve_request_guards(gr, &headers), - None => (vec![], vec![]), - }; + let orchestrator = GuardrailOrchestrator::new(guardrails.as_deref(), &headers); // Pre-call guardrails - let mut pre_warnings = Vec::new(); - if !pre_call.is_empty() { - let gr = guardrails.as_ref().unwrap(); - let input = extract_pre_call_input(&payload); - let outcome = execute_guards(&pre_call, &input, gr.client.as_ref()).await; - if outcome.blocked { - return Ok(blocked_response(&outcome)); + let mut all_warnings = Vec::new(); + if let Some(orch) = &orchestrator { + let pre = orch.run_pre_call(&extract_pre_call_input(&payload)).await; + if let Some(resp) = pre.blocked_response { + return Ok(resp); } - pre_warnings = outcome.warnings; + all_warnings = pre.warnings; } for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - // Set vendor now that we know which model/provider we're using tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); let response = model @@ -258,50 +162,19 @@ pub async fn chat_completions( tracer.log_success(&completion); // Post-call guardrails (non-streaming) - if !post_call.is_empty() { - let gr = guardrails.as_ref().unwrap(); + if let Some(orch) = &orchestrator { let response_text = extract_post_call_input_from_completion(&completion); - let outcome = - execute_guards(&post_call, &response_text, gr.client.as_ref()).await; - if outcome.blocked { - return Ok(blocked_response(&outcome)); - } - if !outcome.warnings.is_empty() || !pre_warnings.is_empty() { - let mut all_warnings = pre_warnings; - all_warnings.extend(outcome.warnings); - let combined = GuardrailsOutcome { - results: vec![], - blocked: false, - blocking_guard: None, - warnings: all_warnings, - }; - let header_val = warning_header_value(&combined); - let mut response = Json(completion).into_response(); - response.headers_mut().insert( - "X-Traceloop-Guardrail-Warning", - header_val.parse().unwrap(), - ); - return Ok(response); + let post = orch.run_post_call(&response_text).await; + if let Some(resp) = post.blocked_response { + return Ok(resp); } + all_warnings.extend(post.warnings); } - // Add pre-call warning headers if any - if !pre_warnings.is_empty() { - let combined = GuardrailsOutcome { - results: vec![], - blocked: false, - blocking_guard: None, - warnings: pre_warnings, - }; - let header_val = warning_header_value(&combined); - let mut response = Json(completion).into_response(); - response - .headers_mut() - .insert("X-Traceloop-Guardrail-Warning", header_val.parse().unwrap()); - return Ok(response); - } - - return Ok(Json(completion).into_response()); + return Ok(GuardrailOrchestrator::finalize_response( + Json(completion).into_response(), + &all_warnings, + )); } if let ChatCompletionResponse::Stream(stream) = response { @@ -325,30 +198,24 @@ pub async fn completions( guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("completion", &payload); - - // Resolve guards for this request (pipeline + header) - let (pre_call, post_call) = match guardrails.as_ref() { - Some(gr) => resolve_request_guards(gr, &headers), - None => (vec![], vec![]), - }; + let orchestrator = GuardrailOrchestrator::new(guardrails.as_deref(), &headers); // Pre-call guardrails - let mut pre_warnings = Vec::new(); - if !pre_call.is_empty() { - let gr = guardrails.as_ref().unwrap(); - let input = extract_pre_call_input_from_completion_request(&payload); - let outcome = execute_guards(&pre_call, &input, gr.client.as_ref()).await; - if outcome.blocked { - return Ok(blocked_response(&outcome)); - } - pre_warnings = outcome.warnings; + let mut all_warnings = Vec::new(); + if let Some(orch) = &orchestrator { + let pre = orch + .run_pre_call(&extract_pre_call_input_from_completion_request(&payload)) + .await; + if let Some(resp) = pre.blocked_response { + return Ok(resp); + } + all_warnings = pre.warnings; } for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - // Set vendor now that we know which model/provider we're using tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); let response = model.completions(payload.clone()).await.inspect_err(|e| { @@ -357,49 +224,19 @@ pub async fn completions( tracer.log_success(&response); // Post-call guardrails - if !post_call.is_empty() { - let gr = guardrails.as_ref().unwrap(); + if let Some(orch) = &orchestrator { let response_text = extract_post_call_input_from_completion_response(&response); - let outcome = - execute_guards(&post_call, &response_text, gr.client.as_ref()).await; - if outcome.blocked { - return Ok(blocked_response(&outcome)); - } - if !outcome.warnings.is_empty() || !pre_warnings.is_empty() { - let mut all_warnings = pre_warnings; - all_warnings.extend(outcome.warnings); - let combined = GuardrailsOutcome { - results: vec![], - blocked: false, - blocking_guard: None, - warnings: all_warnings, - }; - let header_val = warning_header_value(&combined); - let mut resp = Json(response).into_response(); - resp.headers_mut().insert( - "X-Traceloop-Guardrail-Warning", - header_val.parse().unwrap(), - ); + let post = orch.run_post_call(&response_text).await; + if let Some(resp) = post.blocked_response { return Ok(resp); } + all_warnings.extend(post.warnings); } - // Add pre-call warning headers if any - if !pre_warnings.is_empty() { - let combined = GuardrailsOutcome { - results: vec![], - blocked: false, - blocking_guard: None, - warnings: pre_warnings, - }; - let header_val = warning_header_value(&combined); - let mut resp = Json(response).into_response(); - resp.headers_mut() - .insert("X-Traceloop-Guardrail-Warning", header_val.parse().unwrap()); - return Ok(resp); - } - - return Ok(Json(response).into_response()); + return Ok(GuardrailOrchestrator::finalize_response( + Json(response).into_response(), + &all_warnings, + )); } } diff --git a/src/state.rs b/src/state.rs index 63717080..f6647d45 100644 --- a/src/state.rs +++ b/src/state.rs @@ -161,7 +161,8 @@ impl AppState { _provider_registry: &Arc, model_registry: &Arc, ) -> axum::Router { - use crate::pipelines::pipeline::{build_guardrail_resources, create_pipeline}; + use crate::guardrails::builder::build_guardrail_resources; + use crate::pipelines::pipeline::create_pipeline; debug!("Building router with {} pipelines", config.pipelines.len()); diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index e71911fa..94a2ec5d 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -141,7 +141,7 @@ async fn test_e2e_post_call_warn_flow() { assert!(!outcome.blocked); assert_eq!(outcome.warnings.len(), 1); - assert!(outcome.warnings[0].contains("tone-check")); + assert_eq!(outcome.warnings[0].guard_name, "tone-check"); } #[tokio::test] @@ -257,7 +257,7 @@ async fn test_e2e_mixed_block_and_warn() { assert!(outcome.blocked); assert_eq!(outcome.blocking_guard.as_deref(), Some("blocker")); - assert!(outcome.warnings.iter().any(|w| w.contains("warner"))); + assert!(outcome.warnings.iter().any(|w| w.guard_name == "warner")); } #[tokio::test] diff --git a/tests/guardrails/test_executor.rs b/tests/guardrails/test_executor.rs index 0e136c7d..a4b13371 100644 --- a/tests/guardrails/test_executor.rs +++ b/tests/guardrails/test_executor.rs @@ -36,7 +36,7 @@ async fn test_execute_single_pre_call_guard_fails_warn() { let outcome = execute_guards(&[guard], "borderline input", &mock_client).await; assert!(!outcome.blocked); assert_eq!(outcome.warnings.len(), 1); - assert!(outcome.warnings[0].contains("check")); + assert_eq!(outcome.warnings[0].guard_name, "check"); } #[tokio::test] @@ -179,5 +179,5 @@ async fn test_executor_returns_correct_guardrails_outcome() { let outcome = execute_guards(&guards, "input", &mock_client).await; assert!(outcome.blocked); assert_eq!(outcome.blocking_guard, Some("blocker".to_string())); - assert!(outcome.warnings.iter().any(|w| w.contains("warner"))); + assert!(outcome.warnings.iter().any(|w| w.guard_name == "warner")); } diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs index 4ce2f1ab..e442364f 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_pipeline.rs @@ -82,7 +82,7 @@ async fn test_pre_call_guardrails_warn_and_continue() { assert!(!outcome.blocked); assert_eq!(outcome.warnings.len(), 1); - assert!(outcome.warnings[0].contains("tone-check")); + assert_eq!(outcome.warnings[0].guard_name, "tone-check"); } #[tokio::test] @@ -148,32 +148,25 @@ async fn test_post_call_guardrails_warn_and_add_header() { assert!(!outcome.warnings.is_empty()); // Verify warning header would be generated correctly - let header = warning_header_value(&outcome); + let header = warning_header_value(&outcome.warnings); assert!(header.contains("guardrail_name=")); assert!(header.contains("safety-check")); } #[tokio::test] async fn test_warning_header_format() { - let outcome = GuardrailsOutcome { - results: vec![], - blocked: false, - blocking_guard: None, - warnings: vec!["Guard 'my-guard' failed with warning".to_string()], - }; - let header = warning_header_value(&outcome); + let warnings = vec![GuardWarning { + guard_name: "my-guard".to_string(), + reason: "failed".to_string(), + }]; + let header = warning_header_value(&warnings); assert_eq!(header, "guardrail_name=\"my-guard\", reason=\"failed\""); } #[tokio::test] async fn test_blocked_response_403_format() { - let outcome = GuardrailsOutcome { - results: vec![], - blocked: true, - blocking_guard: Some("toxicity-check".to_string()), - warnings: vec![], - }; - let response = blocked_response(&outcome); + let blocking_guard = Some("toxicity-check".to_string()); + let response = blocked_response(&blocking_guard); assert_eq!(response.status(), 403); let body = to_bytes(response.into_body(), 1024 * 1024).await.unwrap(); From 8643fab08987156f6a1c5785b12d61491947cab1 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:26:37 +0200 Subject: [PATCH 06/59] added validation test --- src/config/validation.rs | 118 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/src/config/validation.rs b/src/config/validation.rs index 86771cf4..b31b7d20 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -83,6 +83,7 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec #[cfg(test)] mod tests { use super::*; // To import validate_gateway_config + use crate::guardrails::types::{Guard, GuardMode, GuardrailsConfig, OnFailure, ProviderConfig as GrProviderConfig}; use crate::types::{ModelConfig, Pipeline, PipelineType, PluginConfig, Provider, ProviderType}; // For test data #[test] @@ -172,4 +173,121 @@ mod tests { assert_eq!(errors.len(), 1); assert!(errors[0].contains("references non-existent model 'm2_non_existent'")); } + + #[test] + fn test_guard_references_non_existent_guardrail_provider() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: vec![GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }], + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p2_non_existent".to_string(), + evaluator_slug: "slug".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].contains("references non-existent guardrail provider 'gr_p2_non_existent'")); + } + + #[test] + fn test_pipeline_references_non_existent_guard() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: vec![GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }], + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "slug".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![Pipeline { + name: "pipe1".to_string(), + r#type: PipelineType::Chat, + plugins: vec![], + guards: vec!["g_non_existent".to_string()], + }], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].contains("references non-existent guard 'g_non_existent'")); + } + + #[test] + fn test_duplicate_guard_names() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: vec![GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }], + guards: vec![ + Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "slug".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }, + Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "slug2".to_string(), + params: Default::default(), + mode: GuardMode::PostCall, + on_failure: OnFailure::Warn, + required: true, + api_base: None, + api_key: None, + }, + ], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].contains("Duplicate guard name: 'g1'")); + } } From 3eed5486bb4e6c2738af9cf1d4f8e2cc3aa74ee4 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:17:10 +0200 Subject: [PATCH 07/59] provider hash --- src/config/validation.rs | 17 +++++++-------- src/guardrails/builder.rs | 6 +++--- src/guardrails/types.rs | 36 +++++++++++++++++++++++++++---- src/pipelines/pipeline.rs | 3 +-- tests/guardrails/test_e2e.rs | 2 +- tests/guardrails/test_pipeline.rs | 15 +++++++------ tests/guardrails/test_types.rs | 6 +++--- 7 files changed, 56 insertions(+), 29 deletions(-) diff --git a/src/config/validation.rs b/src/config/validation.rs index b31b7d20..0d673e3d 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -40,10 +40,8 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec // Check 3: Guardrails validation if let Some(gr_config) = &config.guardrails { // Guard provider references must exist in guardrails.providers - let gr_provider_names: HashSet<&String> = - gr_config.providers.iter().map(|p| &p.name).collect(); for guard in &gr_config.guards { - if !gr_provider_names.contains(&guard.provider) { + if !gr_config.providers.contains_key(&guard.provider) { errors.push(format!( "Guard '{}' references non-existent guardrail provider '{}'.", guard.name, guard.provider @@ -83,6 +81,7 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec #[cfg(test)] mod tests { use super::*; // To import validate_gateway_config + use std::collections::HashMap; use crate::guardrails::types::{Guard, GuardMode, GuardrailsConfig, OnFailure, ProviderConfig as GrProviderConfig}; use crate::types::{ModelConfig, Pipeline, PipelineType, PluginConfig, Provider, ProviderType}; // For test data @@ -178,11 +177,11 @@ mod tests { fn test_guard_references_non_existent_guardrail_provider() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: vec![GrProviderConfig { + providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { name: "gr_p1".to_string(), api_base: "http://localhost".to_string(), api_key: "key".to_string(), - }], + })]), guards: vec![Guard { name: "g1".to_string(), provider: "gr_p2_non_existent".to_string(), @@ -211,11 +210,11 @@ mod tests { fn test_pipeline_references_non_existent_guard() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: vec![GrProviderConfig { + providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { name: "gr_p1".to_string(), api_base: "http://localhost".to_string(), api_key: "key".to_string(), - }], + })]), guards: vec![Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), @@ -249,11 +248,11 @@ mod tests { fn test_duplicate_guard_names() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: vec![GrProviderConfig { + providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { name: "gr_p1".to_string(), api_base: "http://localhost".to_string(), api_key: "key".to_string(), - }], + })]), guards: vec![ Guard { name: "g1".to_string(), diff --git a/src/guardrails/builder.rs b/src/guardrails/builder.rs index 5b710402..c57550cb 100644 --- a/src/guardrails/builder.rs +++ b/src/guardrails/builder.rs @@ -1,13 +1,13 @@ use std::sync::Arc; -use super::types::{Guard, GuardrailResources, GuardrailsConfig, Guardrails}; +use super::types::{GuardrailResources, GuardrailsConfig, Guardrails}; /// Resolve provider defaults (api_base/api_key) for all guards in the config. -pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { +pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { let mut guards = config.guards.clone(); for guard in &mut guards { if guard.api_base.is_none() || guard.api_key.is_none() { - if let Some(provider) = config.providers.iter().find(|p| p.name == guard.provider) { + if let Some(provider) = config.providers.get(&guard.provider) { if guard.api_base.is_none() { guard.api_base = Some(provider.api_base.clone()); } diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index 695f05d1..efd7cd72 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -1,4 +1,4 @@ -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -81,15 +81,43 @@ impl Hash for Guard { #[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)] pub struct GuardrailsConfig { - #[serde(default)] - pub providers: Vec, + #[serde( + default, + deserialize_with = "deserialize_providers", + serialize_with = "serialize_providers" + )] + pub providers: HashMap, #[serde(default)] pub guards: Vec, } +fn deserialize_providers<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let list: Vec = Vec::deserialize(deserializer)?; + Ok(list.into_iter().map(|p| (p.name.clone(), p)).collect()) +} + +fn serialize_providers( + providers: &HashMap, + serializer: S, +) -> Result +where + S: Serializer, +{ + let list: Vec<&ProviderConfig> = providers.values().collect(); + list.serialize(serializer) +} + impl Hash for GuardrailsConfig { fn hash(&self, state: &mut H) { - self.providers.hash(state); + let mut entries: Vec<_> = self.providers.iter().collect(); + entries.sort_by_key(|(k, _)| (*k).clone()); + for (k, v) in entries { + k.hash(state); + v.hash(state); + } self.guards.hash(state); } } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 52f54dd3..d703cbc8 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -69,6 +69,7 @@ pub fn create_pipeline( ); for plugin in pipeline.plugins.clone() { + let gr = guardrails.clone(); router = match plugin { PluginConfig::Tracing { endpoint, api_key } => { tracing::info!("Initializing OtelTracer for pipeline {}", pipeline.name); @@ -77,14 +78,12 @@ pub fn create_pipeline( } PluginConfig::ModelRouter { models } => match pipeline.r#type { PipelineType::Chat => { - let gr = guardrails.clone(); router.route( "/chat/completions", post(move |state, headers, payload| chat_completions(state, headers, payload, models, gr)), ) } PipelineType::Completion => { - let gr = guardrails.clone(); router.route( "/completions", post(move |state, headers, payload| completions(state, headers, payload, models, gr)), diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 94a2ec5d..c9e6877c 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -363,7 +363,7 @@ guardrails: let gr = config.guardrails.unwrap(); assert_eq!(gr.providers.len(), 1); - assert_eq!(gr.providers[0].api_key, "resolved-key-123"); + assert_eq!(gr.providers["traceloop"].api_key, "resolved-key-123"); // Guards should have evaluator_slug at top level assert_eq!(gr.guards[0].evaluator_slug, "toxicity"); diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs index e442364f..c500358d 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_pipeline.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use hub_lib::guardrails::api_control::{resolve_guards_by_name, split_guards_by_mode}; use hub_lib::guardrails::executor::execute_guards; use hub_lib::guardrails::providers::traceloop::TraceloopClient; @@ -185,7 +186,7 @@ async fn test_blocked_response_403_format() { async fn test_no_guardrails_passthrough() { // Empty guardrails config -> build_guardrail_resources returns None let config = GuardrailsConfig { - providers: vec![], + providers: Default::default(), guards: vec![], }; let result = build_guardrail_resources(&config); @@ -193,11 +194,11 @@ async fn test_no_guardrails_passthrough() { // Config with no guards -> passthrough let config_with_providers = GuardrailsConfig { - providers: vec![ProviderConfig { + providers: HashMap::from([("traceloop".to_string(), ProviderConfig { name: "traceloop".to_string(), api_base: "http://localhost".to_string(), api_key: "key".to_string(), - }], + })]), guards: vec![], }; let result = build_guardrail_resources(&config_with_providers); @@ -210,11 +211,11 @@ async fn test_no_guardrails_passthrough() { fn test_guardrails_config() -> GuardrailsConfig { GuardrailsConfig { - providers: vec![ProviderConfig { + providers: HashMap::from([("traceloop".to_string(), ProviderConfig { name: "traceloop".to_string(), api_base: "https://api.traceloop.com".to_string(), api_key: "test-key".to_string(), - }], + })]), guards: vec![ Guard { name: "pii-check".to_string(), @@ -305,11 +306,11 @@ fn test_build_pipeline_guardrails_resolves_provider_defaults() { #[test] fn test_resolve_guard_defaults_preserves_guard_overrides() { let config = GuardrailsConfig { - providers: vec![ProviderConfig { + providers: HashMap::from([("traceloop".to_string(), ProviderConfig { name: "traceloop".to_string(), api_base: "https://default.api.com".to_string(), api_key: "default-key".to_string(), - }], + })]), guards: vec![Guard { name: "custom-guard".to_string(), provider: "traceloop".to_string(), diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs index 8430ae19..8dcfa847 100644 --- a/tests/guardrails/test_types.rs +++ b/tests/guardrails/test_types.rs @@ -110,7 +110,7 @@ fn test_gateway_config_with_guardrails() { models: vec![], pipelines: vec![], guardrails: Some(GuardrailsConfig { - providers: vec![], + providers: Default::default(), guards: vec![create_test_guard("test", GuardMode::PreCall)], }), }; @@ -209,8 +209,8 @@ guards: "#; let config: GuardrailsConfig = serde_yaml::from_str(yaml).unwrap(); assert_eq!(config.providers.len(), 1); - assert_eq!(config.providers[0].name, "traceloop"); - assert_eq!(config.providers[0].api_base, "https://api.traceloop.com"); + assert_eq!(config.providers["traceloop"].name, "traceloop"); + assert_eq!(config.providers["traceloop"].api_base, "https://api.traceloop.com"); assert_eq!(config.guards.len(), 2); // First guard has no api_base/api_key (inherits from provider) assert!(config.guards[0].api_base.is_none()); From 6cf07de9b450c29e375fa3a661a82f06b644e710 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 11 Feb 2026 15:19:56 +0200 Subject: [PATCH 08/59] input extractor --- src/config/validation.rs | 2 +- ...s_orchestrator.rs => guardrails_runner.rs} | 16 +++--- src/guardrails/input_extractor.rs | 50 ++++++++++--------- src/guardrails/mod.rs | 2 +- src/pipelines/pipeline.rs | 17 +++---- tests/guardrails/test_e2e.rs | 10 ++-- tests/guardrails/test_input_extractor.rs | 8 +-- 7 files changed, 55 insertions(+), 50 deletions(-) rename src/guardrails/{guardrails_orchestrator.rs => guardrails_runner.rs} (89%) diff --git a/src/config/validation.rs b/src/config/validation.rs index 0d673e3d..c3e72367 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -38,7 +38,7 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec } // Check 3: Guardrails validation - if let Some(gr_config) = &config.guardrails { + if let Some(gr_config&config.guardrails { // Guard provider references must exist in guardrails.providers for guard in &gr_config.guards { if !gr_config.providers.contains_key(&guard.provider) { diff --git a/src/guardrails/guardrails_orchestrator.rs b/src/guardrails/guardrails_runner.rs similarity index 89% rename from src/guardrails/guardrails_orchestrator.rs rename to src/guardrails/guardrails_runner.rs index 200f5b31..8307f7af 100644 --- a/src/guardrails/guardrails_orchestrator.rs +++ b/src/guardrails/guardrails_runner.rs @@ -9,6 +9,7 @@ use tracing::warn; use super::api_control::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; use super::executor::execute_guards; +use super::input_extractor::PreCallInput; use super::providers::GuardrailClient; use super::types::{Guard, GuardWarning, Guardrails}; @@ -18,16 +19,16 @@ pub struct GuardPhaseResult { pub warnings: Vec, } -/// Orchestrates guardrail execution across pre-call and post-call phases. +/// Runs guardrails across pre-call and post-call phases. /// Shared between chat_completions and completions handlers. -pub struct GuardrailOrchestrator<'a> { +pub struct GuardrailsRunner<'a> { pre_call: Vec, post_call: Vec, client: &'a dyn GuardrailClient, } -impl<'a> GuardrailOrchestrator<'a> { - /// Create an orchestrator by resolving guards from pipeline config + request headers. +impl<'a> GuardrailsRunner<'a> { + /// Create a runner by resolving guards from pipeline config + request headers. /// Returns None if no guards are active for this request. pub fn new(guardrails: Option<&'a Guardrails>, headers: &HeaderMap) -> Option { let gr = guardrails?; @@ -42,15 +43,16 @@ impl<'a> GuardrailOrchestrator<'a> { }) } - /// Run pre-call guards against the extracted input text. - pub async fn run_pre_call(&self, input: &str) -> GuardPhaseResult { + /// Run pre-call guards, extracting input from the request only if guards exist. + pub async fn run_pre_call(&self, request: &impl PreCallInput) -> GuardPhaseResult { if self.pre_call.is_empty() { return GuardPhaseResult { blocked_response: None, warnings: Vec::new(), }; } - let outcome = execute_guards(&self.pre_call, input, self.client).await; + let input = request.extract_pre_call_input(); + let outcome = execute_guards(&self.pre_call, &input, self.client).await; if outcome.blocked { return GuardPhaseResult { blocked_response: Some(blocked_response(&outcome.blocking_guard)), diff --git a/src/guardrails/input_extractor.rs b/src/guardrails/input_extractor.rs index 33cc2d6d..5d0ff017 100644 --- a/src/guardrails/input_extractor.rs +++ b/src/guardrails/input_extractor.rs @@ -2,31 +2,35 @@ use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::content::ChatMessageContent; -/// Extract text from the request for pre_call guardrails. -/// Returns the content of the last user message. -pub fn extract_pre_call_input(request: &ChatCompletionRequest) -> String { - request - .messages - .iter() - .rev() - .find(|m| m.role == "user") - .and_then(|m| m.content.as_ref()) - .map(|content| match content { - ChatMessageContent::String(s) => s.clone(), - ChatMessageContent::Array(parts) => parts - .iter() - .filter(|p| p.r#type == "text") - .map(|p| p.text.as_str()) - .collect::>() - .join(" "), - }) - .unwrap_or_default() +/// Trait for extracting pre-call guardrail input from a request. +pub trait PreCallInput { + fn extract_pre_call_input(&self) -> String; +} + +impl PreCallInput for ChatCompletionRequest { + fn extract_pre_call_input(&self) -> String { + self.messages + .iter() + .filter_map(|m| { + m.content.as_ref().map(|content| match content { + ChatMessageContent::String(s) => s.clone(), + ChatMessageContent::Array(parts) => parts + .iter() + .filter(|p| p.r#type == "text") + .map(|p| p.text.as_str()) + .collect::>() + .join(" "), + }) + }) + .collect::>() + .join("\n") + } } -/// Extract text from a CompletionRequest for pre_call guardrails. -/// Returns the prompt string. -pub fn extract_pre_call_input_from_completion_request(request: &CompletionRequest) -> String { - request.prompt.clone() +impl PreCallInput for CompletionRequest { + fn extract_pre_call_input(&self) -> String { + self.prompt.clone() + } } /// Extract text from a CompletionResponse for post_call guardrails. diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index 165e70a4..b280c502 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -2,7 +2,7 @@ pub mod api_control; pub mod builder; pub mod executor; pub mod input_extractor; -pub mod guardrails_orchestrator; +pub mod guardrails_runner; pub mod providers; pub mod response_parser; pub mod stream_buffer; diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index d703cbc8..c8427a83 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,9 +1,8 @@ use crate::config::models::PipelineType; use crate::guardrails::input_extractor::{ extract_post_call_input_from_completion, extract_post_call_input_from_completion_response, - extract_pre_call_input, extract_pre_call_input_from_completion_request, }; -use crate::guardrails::guardrails_orchestrator::GuardrailOrchestrator; +use crate::guardrails::guardrails_runner::GuardrailsRunner; use crate::guardrails::types::{GuardrailResources, Guardrails}; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; @@ -35,7 +34,7 @@ use std::sync::Arc; pub use crate::guardrails::builder::{ build_guardrail_resources, build_pipeline_guardrails, resolve_guard_defaults, }; -pub use crate::guardrails::guardrails_orchestrator::{blocked_response, warning_header_value}; +pub use crate::guardrails::guardrails_runner::{blocked_response, warning_header_value}; pub fn create_pipeline( pipeline: &Pipeline, @@ -132,12 +131,12 @@ pub async fn chat_completions( guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("chat", &payload); - let orchestrator = GuardrailOrchestrator::new(guardrails.as_deref(), &headers); + let orchestrator = GuardrailsRunner::new(guardrails.as_deref(), &headers); // Pre-call guardrails let mut all_warnings = Vec::new(); if let Some(orch) = &orchestrator { - let pre = orch.run_pre_call(&extract_pre_call_input(&payload)).await; + let pre = orch.run_pre_call(&payload).await; if let Some(resp) = pre.blocked_response { return Ok(resp); } @@ -170,7 +169,7 @@ pub async fn chat_completions( all_warnings.extend(post.warnings); } - return Ok(GuardrailOrchestrator::finalize_response( + return Ok(GuardrailsRunner::finalize_response( Json(completion).into_response(), &all_warnings, )); @@ -197,13 +196,13 @@ pub async fn completions( guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("completion", &payload); - let orchestrator = GuardrailOrchestrator::new(guardrails.as_deref(), &headers); + let orchestrator = GuardrailsRunner::new(guardrails.as_deref(), &headers); // Pre-call guardrails let mut all_warnings = Vec::new(); if let Some(orch) = &orchestrator { let pre = orch - .run_pre_call(&extract_pre_call_input_from_completion_request(&payload)) + .run_pre_call(&payload) .await; if let Some(resp) = pre.blocked_response { return Ok(resp); @@ -232,7 +231,7 @@ pub async fn completions( all_warnings.extend(post.warnings); } - return Ok(GuardrailOrchestrator::finalize_response( + return Ok(GuardrailsRunner::finalize_response( Json(response).into_response(), &all_warnings, )); diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index c9e6877c..c74fe7c9 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -1,6 +1,6 @@ use hub_lib::guardrails::executor::execute_guards; use hub_lib::guardrails::input_extractor::{ - extract_post_call_input_from_completion, extract_pre_call_input, + extract_post_call_input_from_completion, extract_prompt, }; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::stream_buffer::extract_text_from_chunks; @@ -66,7 +66,7 @@ async fn test_e2e_pre_call_block_flow() { ); let request = create_test_chat_request("Bad input"); - let input = extract_pre_call_input(&request); + let input = extract_prompt(&request); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &input, &client).await; @@ -88,7 +88,7 @@ async fn test_e2e_pre_call_pass_flow() { ); let request = create_test_chat_request("Safe input"); - let input = extract_pre_call_input(&request); + let input = extract_prompt(&request); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &input, &client).await; @@ -169,7 +169,7 @@ async fn test_e2e_pre_and_post_both_pass() { // Pre-call let request = create_test_chat_request("Hello"); - let input = extract_pre_call_input(&request); + let input = extract_prompt(&request); let pre_outcome = execute_guards(&[pre_guard], &input, &client).await; assert!(!pre_outcome.blocked); @@ -211,7 +211,7 @@ async fn test_e2e_pre_blocks_post_never_runs() { let client = TraceloopClient::new(); let request = create_test_chat_request("Bad input"); - let input = extract_pre_call_input(&request); + let input = extract_prompt(&request); let pre_outcome = execute_guards(&[pre_guard], &input, &client).await; assert!(pre_outcome.blocked); diff --git a/tests/guardrails/test_input_extractor.rs b/tests/guardrails/test_input_extractor.rs index 732b6d40..d9dd0232 100644 --- a/tests/guardrails/test_input_extractor.rs +++ b/tests/guardrails/test_input_extractor.rs @@ -10,7 +10,7 @@ use super::helpers::*; #[test] fn test_extract_text_single_user_message() { let request = create_test_chat_request("Hello world"); - let text = extract_pre_call_input(&request); + let text = extract_prompt(&request); assert_eq!(text, "Hello world"); } @@ -39,7 +39,7 @@ fn test_extract_text_multi_turn_conversation() { ..default_message() }, ]; - let text = extract_pre_call_input(&request); + let text = extract_prompt(&request); assert_eq!(text, "Follow-up question"); } @@ -56,7 +56,7 @@ fn test_extract_text_from_array_content_parts() { text: "Part 2".to_string(), }, ])); - let text = extract_pre_call_input(&request); + let text = extract_prompt(&request); assert_eq!(text, "Part 1 Part 2"); } @@ -71,6 +71,6 @@ fn test_extract_response_from_chat_completion() { fn test_extract_handles_empty_content() { let mut request = create_test_chat_request(""); request.messages[0].content = None; - let text = extract_pre_call_input(&request); + let text = extract_prompt(&request); assert_eq!(text, ""); } From f08c7d50f163195166e1bc2d1e4e98da4a21c661 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 12 Feb 2026 09:27:39 +0200 Subject: [PATCH 09/59] add check --- src/config/validation.rs | 95 +++++++++++++++++++++++++-- src/guardrails/builder.rs | 4 +- src/guardrails/providers/traceloop.rs | 17 ++--- 3 files changed, 100 insertions(+), 16 deletions(-) diff --git a/src/config/validation.rs b/src/config/validation.rs index c3e72367..3e238921 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -38,8 +38,8 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec } // Check 3: Guardrails validation - if let Some(gr_config&config.guardrails { - // Guard provider references must exist in guardrails.providers + if let Some(gr_config) = &config.guardrails { + // Guard provider references must exist in guardrails.providers for guard in &gr_config.guards { if !gr_config.providers.contains_key(&guard.provider) { errors.push(format!( @@ -62,6 +62,29 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec } } + // Guards must have api_base and api_key (either directly or via provider) + for guard in &gr_config.guards { + let has_api_base = guard.api_base.as_ref().is_some_and(|s| !s.is_empty()) + || gr_config.providers.get(&guard.provider) + .is_some_and(|p| !p.api_base.is_empty()); + let has_api_key = guard.api_key.as_ref().is_some_and(|s| !s.is_empty()) + || gr_config.providers.get(&guard.provider) + .is_some_and(|p| !p.api_key.is_empty()); + + if !has_api_base { + errors.push(format!( + "Guard '{}' has no api_base configured (neither on the guard nor on provider '{}').", + guard.name, guard.provider + )); + } + if !has_api_key { + errors.push(format!( + "Guard '{}' has no api_key configured (neither on the guard nor on provider '{}').", + guard.name, guard.provider + )); + } + } + // Guard names must be unique let mut seen_guard_names = HashSet::new(); for guard in &gr_config.guards { @@ -202,8 +225,9 @@ mod tests { let result = validate_gateway_config(&config); assert!(result.is_err()); let errors = result.unwrap_err(); - assert_eq!(errors.len(), 1); - assert!(errors[0].contains("references non-existent guardrail provider 'gr_p2_non_existent'")); + assert!(errors.iter().any(|e| e.contains("references non-existent guardrail provider 'gr_p2_non_existent'"))); + assert!(errors.iter().any(|e| e.contains("no api_base configured"))); + assert!(errors.iter().any(|e| e.contains("no api_key configured"))); } #[test] @@ -289,4 +313,67 @@ mod tests { assert_eq!(errors.len(), 1); assert!(errors[0].contains("Duplicate guard name: 'g1'")); } + + #[test] + fn test_guard_missing_api_base_and_api_key() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "".to_string(), + api_key: "".to_string(), + })]), + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "slug".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 2); + assert!(errors[0].contains("no api_base configured")); + assert!(errors[1].contains("no api_key configured")); + } + + #[test] + fn test_guard_inherits_api_base_from_provider() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + })]), + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "slug".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + assert!(validate_gateway_config(&config).is_ok()); + } } diff --git a/src/guardrails/builder.rs b/src/guardrails/builder.rs index c57550cb..4ceeff8b 100644 --- a/src/guardrails/builder.rs +++ b/src/guardrails/builder.rs @@ -8,10 +8,10 @@ pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec Self { - Self { - http_client: reqwest::Client::builder().timeout(timeout).build().unwrap(), - } - } } #[async_trait] @@ -40,10 +35,12 @@ impl GuardrailClient for TraceloopClient { guard: &Guard, input: &str, ) -> Result { - let api_base = guard.api_base.as_deref().unwrap_or("http://localhost:8080"); + let api_base = guard.api_base.as_deref() + .filter(|s| !s.is_empty()) + .unwrap_or(DEFAULT_TRACELOOP_API); let url = format!( - "{}/v2/guardrails/{}", - api_base.trim_end_matches('/'), + "{}/v2/guardrails/execute/{}", + api_base, guard.evaluator_slug ); From d4139b058e86e7a5265d90df580046bb7c6b6cb3 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 12 Feb 2026 11:40:18 +0200 Subject: [PATCH 10/59] eval structs --- src/config/validation.rs | 54 +++++- src/guardrails/evaluator_types.rs | 223 ++++++++++++++++++++++ src/guardrails/mod.rs | 1 + src/guardrails/providers/traceloop.rs | 74 ++++++- src/guardrails/types.rs | 4 +- tests/guardrails/helpers.rs | 2 +- tests/guardrails/test_e2e.rs | 44 ++--- tests/guardrails/test_traceloop_client.rs | 6 +- 8 files changed, 367 insertions(+), 41 deletions(-) create mode 100644 src/guardrails/evaluator_types.rs diff --git a/src/config/validation.rs b/src/config/validation.rs index 3e238921..7576d5b7 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -85,6 +85,16 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec } } + // Evaluator slugs must be recognised + for guard in &gr_config.guards { + if crate::guardrails::evaluator_types::get_evaluator(&guard.evaluator_slug).is_none() { + errors.push(format!( + "Guard '{}' has unknown evaluator_slug '{}'.", + guard.name, guard.evaluator_slug + )); + } + } + // Guard names must be unique let mut seen_guard_names = HashSet::new(); for guard in &gr_config.guards { @@ -208,7 +218,7 @@ mod tests { guards: vec![Guard { name: "g1".to_string(), provider: "gr_p2_non_existent".to_string(), - evaluator_slug: "slug".to_string(), + evaluator_slug: "pii-detector".to_string(), params: Default::default(), mode: GuardMode::PreCall, on_failure: OnFailure::Block, @@ -242,7 +252,7 @@ mod tests { guards: vec![Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), - evaluator_slug: "slug".to_string(), + evaluator_slug: "pii-detector".to_string(), params: Default::default(), mode: GuardMode::PreCall, on_failure: OnFailure::Block, @@ -281,7 +291,7 @@ mod tests { Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), - evaluator_slug: "slug".to_string(), + evaluator_slug: "pii-detector".to_string(), params: Default::default(), mode: GuardMode::PreCall, on_failure: OnFailure::Block, @@ -292,7 +302,7 @@ mod tests { Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), - evaluator_slug: "slug2".to_string(), + evaluator_slug: "toxicity-detector".to_string(), params: Default::default(), mode: GuardMode::PostCall, on_failure: OnFailure::Warn, @@ -326,7 +336,7 @@ mod tests { guards: vec![Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), - evaluator_slug: "slug".to_string(), + evaluator_slug: "pii-detector".to_string(), params: Default::default(), mode: GuardMode::PreCall, on_failure: OnFailure::Block, @@ -360,7 +370,7 @@ mod tests { guards: vec![Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), - evaluator_slug: "slug".to_string(), + evaluator_slug: "pii-detector".to_string(), params: Default::default(), mode: GuardMode::PreCall, on_failure: OnFailure::Block, @@ -376,4 +386,36 @@ mod tests { }; assert!(validate_gateway_config(&config).is_ok()); } + + #[test] + fn test_guard_unknown_evaluator_slug() { + let config = GatewayConfig { + guardrails: Some(GuardrailsConfig { + providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + })]), + guards: vec![Guard { + name: "g1".to_string(), + provider: "gr_p1".to_string(), + evaluator_slug: "made-up-slug".to_string(), + params: Default::default(), + mode: GuardMode::PreCall, + on_failure: OnFailure::Block, + required: true, + api_base: None, + api_key: None, + }], + }), + general: None, + providers: vec![], + models: vec![], + pipelines: vec![], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert!(errors.iter().any(|e| e.contains("unknown evaluator_slug 'made-up-slug'"))); + } } diff --git a/src/guardrails/evaluator_types.rs b/src/guardrails/evaluator_types.rs new file mode 100644 index 00000000..6bef0ffd --- /dev/null +++ b/src/guardrails/evaluator_types.rs @@ -0,0 +1,223 @@ +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::json; +use std::collections::HashMap; + +use super::types::GuardrailError; + +// --------------------------------------------------------------------------- +// Slugs +// --------------------------------------------------------------------------- + +// Safety +pub const PII_DETECTOR: &str = "pii-detector"; +pub const SECRETS_DETECTOR: &str = "secrets-detector"; +pub const PROMPT_INJECTION: &str = "prompt-injection"; +pub const PROFANITY_DETECTOR: &str = "profanity-detector"; +pub const SEXISM_DETECTOR: &str = "sexism-detector"; +pub const TOXICITY_DETECTOR: &str = "toxicity-detector"; +// Validators +pub const REGEX_VALIDATOR: &str = "regex-validator"; +pub const JSON_VALIDATOR: &str = "json-validator"; +pub const SQL_VALIDATOR: &str = "sql-validator"; +// Quality and adherence +pub const TONE_DETECTION: &str = "tone-detection"; +pub const PROMPT_PERPLEXITY: &str = "prompt-perplexity"; +pub const UNCERTAINTY_DETECTOR: &str = "uncertainty-detector"; + +// --------------------------------------------------------------------------- +// Trait +// --------------------------------------------------------------------------- + +/// Each supported evaluator implements this trait to build its typed request body. +pub trait EvaluatorRequest: Send + Sync { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result; +} + +/// Look up the evaluator implementation for a given slug. +pub fn get_evaluator(slug: &str) -> Option<&'static dyn EvaluatorRequest> { + match slug { + // Safety + PII_DETECTOR => Some(&PiiDetector), + SECRETS_DETECTOR => Some(&SecretsDetector), + PROMPT_INJECTION => Some(&PromptInjection), + PROFANITY_DETECTOR => Some(&ProfanityDetector), + SEXISM_DETECTOR => Some(&SexismDetector), + TOXICITY_DETECTOR => Some(&ToxicityDetector), + // Validators + REGEX_VALIDATOR => Some(&RegexValidator), + JSON_VALIDATOR => Some(&JsonValidator), + SQL_VALIDATOR => Some(&SqlValidator), + // Quality and adherence + TONE_DETECTION => Some(&ToneDetection), + PROMPT_PERPLEXITY => Some(&PromptPerplexity), + UNCERTAINTY_DETECTOR => Some(&UncertaintyDetector), + _ => None, + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn text_body(input: &str) -> serde_json::Value { + json!({ "input": { "text": input } }) +} + +fn prompt_body(input: &str) -> serde_json::Value { + json!({ "input": { "prompt": input } }) +} + +/// Deserialize `params` into a typed config `C`, then attach it to the body. +/// Skips the `config` key entirely when `params` is empty. +fn attach_config( + mut body: serde_json::Value, + params: &HashMap, + slug: &str, +) -> Result { + if params.is_empty() { + return Ok(body); + } + let params_value: serde_json::Value = params.clone().into_iter().collect(); + let config: C = serde_json::from_value(params_value) + .map_err(|e| GuardrailError::ParseError(format!("{slug}: invalid config — {e}")))?; + let config_json = serde_json::to_value(config) + .map_err(|e| GuardrailError::ParseError(e.to_string()))?; + if config_json.as_object().is_some_and(|m| !m.is_empty()) { + body["config"] = config_json; + } + Ok(body) +} + +macro_rules! evaluator_with_no_config { + ($name:ident, $body_fn:ident) => { + pub struct $name; + impl EvaluatorRequest for $name { + fn build_body( + &self, + input: &str, + _params: &HashMap, + ) -> Result { + Ok($body_fn(input)) + } + } + }; +} + +evaluator_with_no_config!(SecretsDetector, text_body); +evaluator_with_no_config!(ProfanityDetector, text_body); +evaluator_with_no_config!(SqlValidator, text_body); +evaluator_with_no_config!(ToneDetection, text_body); +evaluator_with_no_config!(PromptPerplexity, prompt_body); +evaluator_with_no_config!(UncertaintyDetector, prompt_body); + +// --------------------------------------------------------------------------- +// Config structs (mirroring the Go DTOs in evaluator_mbt.go) +// --------------------------------------------------------------------------- + +#[derive(Default, Deserialize, Serialize)] +pub struct PiiDetectorConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub probability_threshold: Option, +} + +#[derive(Default, Deserialize, Serialize)] +pub struct ThresholdConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub threshold: Option, +} + +#[derive(Default, Deserialize, Serialize)] +pub struct RegexValidatorConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub should_match: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub case_sensitive: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub dot_include_nl: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub multi_line: Option, +} + +#[derive(Default, Deserialize, Serialize)] +pub struct JsonValidatorConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_schema_validation: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub schema_string: Option, +} + +// --------------------------------------------------------------------------- +// Evaluators with config +// --------------------------------------------------------------------------- + +pub struct PiiDetector; +impl EvaluatorRequest for PiiDetector { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result { + attach_config::(text_body(input), params, PII_DETECTOR) + } +} + +pub struct PromptInjection; +impl EvaluatorRequest for PromptInjection { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result { + attach_config::(prompt_body(input), params, PROMPT_INJECTION) + } +} + +pub struct SexismDetector; +impl EvaluatorRequest for SexismDetector { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result { + attach_config::(text_body(input), params, SEXISM_DETECTOR) + } +} + +pub struct ToxicityDetector; +impl EvaluatorRequest for ToxicityDetector { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result { + attach_config::(text_body(input), params, TOXICITY_DETECTOR) + } +} + +pub struct RegexValidator; +impl EvaluatorRequest for RegexValidator { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result { + attach_config::(text_body(input), params, REGEX_VALIDATOR) + } +} + +pub struct JsonValidator; +impl EvaluatorRequest for JsonValidator { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result { + attach_config::(text_body(input), params, JSON_VALIDATOR) + } +} diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index b280c502..fdc23eaa 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -1,5 +1,6 @@ pub mod api_control; pub mod builder; +pub mod evaluator_types; pub mod executor; pub mod input_extractor; pub mod guardrails_runner; diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index d1879591..6b0c9158 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -1,8 +1,9 @@ use async_trait::async_trait; -use serde_json::json; +use std::collections::HashMap; use tracing::debug; use super::GuardrailClient; +use crate::guardrails::evaluator_types::get_evaluator; use crate::guardrails::response_parser::parse_evaluator_http_response; use crate::guardrails::types::{EvaluatorResponse, Guard, GuardrailError}; @@ -26,6 +27,15 @@ impl TraceloopClient { http_client: reqwest::Client::new(), } } + + pub fn with_timeout(timeout: std::time::Duration) -> Self { + Self { + http_client: reqwest::Client::builder() + .timeout(timeout) + .build() + .unwrap_or_default(), + } + } } #[async_trait] @@ -46,13 +56,10 @@ impl GuardrailClient for TraceloopClient { let api_key = guard.api_key.as_deref().unwrap_or(""); - // Build config from params (excluding evaluator_slug which is top-level) - let config: serde_json::Value = guard.params.clone().into_iter().collect(); - - let body = json!({ - "inputs": [input], - "config": config, - }); + let evaluator = get_evaluator(&guard.evaluator_slug).ok_or_else(|| { + GuardrailError::Unavailable(format!("Unknown evaluator slug '{}'", guard.evaluator_slug)) + })?; + let body = evaluator.build_body(input, &guard.params)?; debug!(guard = %guard.name, slug = %guard.evaluator_slug, %url, "Calling evaluator API"); @@ -71,3 +78,54 @@ impl GuardrailClient for TraceloopClient { parse_evaluator_http_response(status, &response_body) } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_build_body_text_slug() { + let params = HashMap::new(); + let body = get_evaluator("pii-detector").unwrap().build_body("hello world", ¶ms).unwrap(); + assert_eq!(body, json!({"input": {"text": "hello world"}})); + } + + #[test] + fn test_build_body_prompt_slug() { + let params = HashMap::new(); + let body = get_evaluator("prompt-injection").unwrap().build_body("hello world", ¶ms).unwrap(); + assert_eq!(body, json!({"input": {"prompt": "hello world"}})); + } + + #[test] + fn test_build_body_with_config() { + let mut params = HashMap::new(); + params.insert("threshold".to_string(), json!(0.8)); + let body = get_evaluator("toxicity-detector").unwrap().build_body("test", ¶ms).unwrap(); + assert_eq!( + body, + json!({"input": {"text": "test"}, "config": {"threshold": 0.8}}) + ); + } + + #[test] + fn test_build_body_no_config_when_params_empty() { + let params = HashMap::new(); + let body = get_evaluator("secrets-detector").unwrap().build_body("test", ¶ms).unwrap(); + assert!(body.get("config").is_none()); + } + + #[test] + fn test_get_evaluator_unknown_slug() { + assert!(get_evaluator("nonexistent").is_none()); + } + + #[test] + fn test_build_body_rejects_invalid_config_type() { + let mut params = HashMap::new(); + params.insert("threshold".to_string(), json!("not-a-number")); + let result = get_evaluator("toxicity-detector").unwrap().build_body("test", ¶ms); + assert!(result.is_err()); + } +} diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index efd7cd72..4f637838 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -91,7 +91,9 @@ pub struct GuardrailsConfig { pub guards: Vec, } -fn deserialize_providers<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_providers<'de, D>( + deserializer: D, +) -> Result, D::Error> where D: Deserializer<'de>, { diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index 2443a86e..0c8412d4 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -19,7 +19,7 @@ pub fn create_test_guard(name: &str, mode: GuardMode) -> Guard { Guard { name: name.to_string(), provider: "traceloop".to_string(), - evaluator_slug: "test-evaluator".to_string(), + evaluator_slug: "pii-detector".to_string(), params: HashMap::new(), mode, on_failure: OnFailure::Block, diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index c74fe7c9..f0bc95b8 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -62,7 +62,7 @@ async fn test_e2e_pre_call_block_flow() { GuardMode::PreCall, OnFailure::Block, &eval.uri(), - "toxicity", + "toxicity-detector", ); let request = create_test_chat_request("Bad input"); @@ -84,7 +84,7 @@ async fn test_e2e_pre_call_pass_flow() { GuardMode::PreCall, OnFailure::Block, &eval.uri(), - "toxicity", + "toxicity-detector", ); let request = create_test_chat_request("Safe input"); @@ -107,7 +107,7 @@ async fn test_e2e_post_call_block_flow() { GuardMode::PostCall, OnFailure::Block, &eval.uri(), - "pii", + "pii-detector", ); // Simulate LLM response @@ -130,7 +130,7 @@ async fn test_e2e_post_call_warn_flow() { GuardMode::PostCall, OnFailure::Warn, &eval.uri(), - "tone", + "tone-detection", ); let completion = create_test_chat_completion("Mildly concerning response"); @@ -155,14 +155,14 @@ async fn test_e2e_pre_and_post_both_pass() { GuardMode::PreCall, OnFailure::Block, &pre_eval.uri(), - "safety", + "profanity-detector", ); let post_guard = guard_with_server( "post-check", GuardMode::PostCall, OnFailure::Block, &post_eval.uri(), - "pii", + "pii-detector", ); let client = TraceloopClient::new(); @@ -199,14 +199,14 @@ async fn test_e2e_pre_blocks_post_never_runs() { GuardMode::PreCall, OnFailure::Block, &pre_eval.uri(), - "toxicity", + "toxicity-detector", ); let post_guard = guard_with_server( "post-check", GuardMode::PostCall, OnFailure::Block, &post_eval.uri(), - "pii", + "pii-detector", ); let client = TraceloopClient::new(); @@ -234,21 +234,21 @@ async fn test_e2e_mixed_block_and_warn() { GuardMode::PreCall, OnFailure::Block, &eval1.uri(), - "safety", + "profanity-detector", ), guard_with_server( "warner", GuardMode::PreCall, OnFailure::Warn, &eval2.uri(), - "tone", + "tone-detection", ), guard_with_server( "blocker", GuardMode::PreCall, OnFailure::Block, &eval3.uri(), - "toxicity", + "toxicity-detector", ), ]; @@ -269,7 +269,7 @@ async fn test_e2e_streaming_post_call_buffer_pass() { GuardMode::PostCall, OnFailure::Block, &eval.uri(), - "safety", + "profanity-detector", ); // Simulate accumulated streaming chunks @@ -296,7 +296,7 @@ async fn test_e2e_streaming_post_call_buffer_block() { GuardMode::PostCall, OnFailure::Block, &eval.uri(), - "pii", + "pii-detector", ); let chunks = vec![ @@ -345,12 +345,12 @@ guardrails: guards: - name: toxicity-check provider: traceloop - evaluator_slug: toxicity + evaluator_slug: toxicity-detector mode: pre_call on_failure: block - name: pii-check provider: traceloop - evaluator_slug: pii + evaluator_slug: pii-detector mode: post_call on_failure: warn api_key: "override-key" @@ -366,7 +366,7 @@ guardrails: assert_eq!(gr.providers["traceloop"].api_key, "resolved-key-123"); // Guards should have evaluator_slug at top level - assert_eq!(gr.guards[0].evaluator_slug, "toxicity"); + assert_eq!(gr.guards[0].evaluator_slug, "toxicity-detector"); assert_eq!(gr.guards[0].mode, GuardMode::PreCall); assert!(gr.guards[0].api_base.is_none()); // inherits from provider assert!(gr.guards[0].api_key.is_none()); // inherits from provider @@ -408,14 +408,14 @@ async fn test_e2e_multiple_guards_different_evaluators() { let server = MockServer::start().await; Mock::given(matchers::method("POST")) - .and(matchers::path("/v2/guardrails/toxicity")) + .and(matchers::path("/v2/guardrails/execute/toxicity-detector")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) .expect(1) .mount(&server) .await; Mock::given(matchers::method("POST")) - .and(matchers::path("/v2/guardrails/pii")) + .and(matchers::path("/v2/guardrails/execute/pii-detector")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) .expect(1) .mount(&server) @@ -427,14 +427,14 @@ async fn test_e2e_multiple_guards_different_evaluators() { GuardMode::PreCall, OnFailure::Block, &server.uri(), - "toxicity", + "toxicity-detector", ), guard_with_server( "pii-guard", GuardMode::PreCall, OnFailure::Block, &server.uri(), - "pii", + "pii-detector", ), ]; @@ -460,7 +460,7 @@ async fn test_e2e_fail_open_evaluator_down() { GuardMode::PreCall, OnFailure::Block, &server.uri(), - "safety", + "profanity-detector", ); guard.required = false; // fail-open @@ -484,7 +484,7 @@ async fn test_e2e_fail_closed_evaluator_down() { GuardMode::PreCall, OnFailure::Block, &server.uri(), - "safety", + "profanity-detector", ); guard.required = true; // fail-closed diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs index a92c337e..afe1977e 100644 --- a/tests/guardrails/test_traceloop_client.rs +++ b/tests/guardrails/test_traceloop_client.rs @@ -15,14 +15,14 @@ use super::helpers::*; async fn test_traceloop_client_constructs_correct_url() { let mock_server = MockServer::start().await; Mock::given(matchers::method("POST")) - .and(matchers::path("/v2/guardrails/toxicity")) + .and(matchers::path("/v2/guardrails/execute/toxicity-detector")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) .expect(1) .mount(&mock_server) .await; let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); - guard.evaluator_slug = "toxicity".to_string(); + guard.evaluator_slug = "toxicity-detector".to_string(); let client = TraceloopClient::new(); let result = client.evaluate(&guard, "test input").await; @@ -51,7 +51,7 @@ async fn test_traceloop_client_sends_correct_body() { let mock_server = MockServer::start().await; Mock::given(matchers::method("POST")) .and(matchers::body_json(json!({ - "inputs": ["test input text"], + "input": {"text": "test input text"}, "config": {"threshold": 0.5} }))) .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) From 77c8c97e910e2d226050fe067049205751a0b996 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 12 Feb 2026 13:34:17 +0200 Subject: [PATCH 11/59] change client pos --- src/guardrails/builder.rs | 2 +- src/guardrails/executor.rs | 3 +-- src/guardrails/guardrails_runner.rs | 3 +-- src/guardrails/providers/mod.rs | 15 +-------------- src/guardrails/providers/traceloop.rs | 6 ++++-- src/guardrails/types.rs | 14 ++++++++++++-- 6 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/guardrails/builder.rs b/src/guardrails/builder.rs index 4ceeff8b..37892163 100644 --- a/src/guardrails/builder.rs +++ b/src/guardrails/builder.rs @@ -30,7 +30,7 @@ pub fn build_guardrail_resources( return None; } let all_guards = Arc::new(resolve_guard_defaults(config)); - let client: Arc = + let client: Arc = Arc::new(super::providers::traceloop::TraceloopClient::new()); Some((all_guards, client)) } diff --git a/src/guardrails/executor.rs b/src/guardrails/executor.rs index ec38b6c9..dd7150cc 100644 --- a/src/guardrails/executor.rs +++ b/src/guardrails/executor.rs @@ -1,8 +1,7 @@ use futures::future::join_all; use tracing::{debug, warn}; -use super::providers::GuardrailClient; -use super::types::{Guard, GuardResult, GuardWarning, GuardrailsOutcome, OnFailure}; +use super::types::{Guard, GuardResult, GuardWarning, GuardrailClient, GuardrailsOutcome, OnFailure}; /// Execute a set of guardrails against the given input text. /// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. diff --git a/src/guardrails/guardrails_runner.rs b/src/guardrails/guardrails_runner.rs index 8307f7af..30f3f59b 100644 --- a/src/guardrails/guardrails_runner.rs +++ b/src/guardrails/guardrails_runner.rs @@ -10,8 +10,7 @@ use tracing::warn; use super::api_control::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; use super::executor::execute_guards; use super::input_extractor::PreCallInput; -use super::providers::GuardrailClient; -use super::types::{Guard, GuardWarning, Guardrails}; +use super::types::{Guard, GuardWarning, GuardrailClient, Guardrails}; /// Result of running pre-call or post-call guards. pub struct GuardPhaseResult { diff --git a/src/guardrails/providers/mod.rs b/src/guardrails/providers/mod.rs index 3f26a85d..9e45663e 100644 --- a/src/guardrails/providers/mod.rs +++ b/src/guardrails/providers/mod.rs @@ -1,20 +1,7 @@ pub mod traceloop; -use async_trait::async_trait; - use self::traceloop::TraceloopClient; -use super::types::{EvaluatorResponse, Guard, GuardrailError}; - -/// Trait for guardrail evaluator clients. -/// Each provider (traceloop, etc.) implements this to call its evaluator API. -#[async_trait] -pub trait GuardrailClient: Send + Sync { - async fn evaluate( - &self, - guard: &Guard, - input: &str, - ) -> Result; -} +use super::types::{Guard, GuardrailClient}; /// Create a guardrail client based on the guard's provider type. pub fn create_guardrail_client(guard: &Guard) -> Option> { diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 6b0c9158..bb103845 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -61,10 +61,10 @@ impl GuardrailClient for TraceloopClient { })?; let body = evaluator.build_body(input, &guard.params)?; - debug!(guard = %guard.name, slug = %guard.evaluator_slug, %url, "Calling evaluator API"); + debug!(guard = %guard.name, slug = %guard.evaluator_slug, %url, %body, "Calling evaluator API"); let response = self - .http_client + .http_client .post(&url) .header("Authorization", format!("Bearer {api_key}")) .header("Content-Type", "application/json") @@ -75,6 +75,8 @@ impl GuardrailClient for TraceloopClient { let status = response.status().as_u16(); let response_body = response.text().await?; + debug!(guard = %guard.name, %status, %response_body, "Evaluator API response"); + parse_evaluator_http_response(status, &response_body) } } diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index 4f637838..cef97360 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -1,11 +1,10 @@ +use async_trait::async_trait; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; use thiserror::Error; -use super::providers::GuardrailClient; - /// Shared guardrail resources: resolved guards + client. /// Built once per router build and shared across all pipelines. pub type GuardrailResources = (Arc>, Arc); @@ -187,6 +186,17 @@ impl From for GuardrailError { } } +/// Trait for guardrail evaluator clients. +/// Each provider (traceloop, etc.) implements this to call its evaluator API. +#[async_trait] +pub trait GuardrailClient: Send + Sync { + async fn evaluate( + &self, + guard: &Guard, + input: &str, + ) -> Result; +} + /// Guardrails state attached to a pipeline, containing resolved guards and client. /// /// `all_guards` and `client` are shared across all pipelines via `Arc` (built once). From 582decf03f99ee2138383e7b1cf854276f5549f1 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 12 Feb 2026 13:50:13 +0200 Subject: [PATCH 12/59] guard response --- src/guardrails/executor.rs | 1 - src/guardrails/response_parser.rs | 14 ++++++++++++-- src/guardrails/types.rs | 1 - 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/guardrails/executor.rs b/src/guardrails/executor.rs index dd7150cc..9da6a1aa 100644 --- a/src/guardrails/executor.rs +++ b/src/guardrails/executor.rs @@ -50,7 +50,6 @@ pub async fn execute_guards( if response.pass { results.push(GuardResult::Passed { name: guard.name.clone(), - result: response.result, }); } else { results.push(GuardResult::Failed { diff --git a/src/guardrails/response_parser.rs b/src/guardrails/response_parser.rs index 40286c7e..05f5ae27 100644 --- a/src/guardrails/response_parser.rs +++ b/src/guardrails/response_parser.rs @@ -1,9 +1,19 @@ use super::types::{EvaluatorResponse, GuardrailError}; +use tracing::debug; /// Parse the evaluator response body (JSON string) into an EvaluatorResponse. pub fn parse_evaluator_response(body: &str) -> Result { - serde_json::from_str::(body) - .map_err(|e| GuardrailError::ParseError(e.to_string())) + let response = serde_json::from_str::(body) + .map_err(|e| GuardrailError::ParseError(e.to_string()))?; + + // Log for debugging + debug!( + pass = response.pass, + result = %response.result, + "Parsed evaluator response" + ); + + Ok(response) } /// Parse an HTTP response from the evaluator, handling non-200 status codes. diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index cef97360..8ba6b88a 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -133,7 +133,6 @@ pub struct EvaluatorResponse { pub enum GuardResult { Passed { name: String, - result: serde_json::Value, }, Failed { name: String, From f82250f310d5219e0e6dc9ae2ad07e6c79a99885 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 12 Feb 2026 15:09:10 +0200 Subject: [PATCH 13/59] added error --- src/guardrails/guardrails_runner.rs | 68 ++++++++++++++++++--------- src/guardrails/input_extractor.rs | 65 +++++++++++++------------ src/guardrails/providers/traceloop.rs | 4 +- src/pipelines/pipeline.rs | 9 +--- 4 files changed, 85 insertions(+), 61 deletions(-) diff --git a/src/guardrails/guardrails_runner.rs b/src/guardrails/guardrails_runner.rs index 30f3f59b..b333aa53 100644 --- a/src/guardrails/guardrails_runner.rs +++ b/src/guardrails/guardrails_runner.rs @@ -9,7 +9,7 @@ use tracing::warn; use super::api_control::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; use super::executor::execute_guards; -use super::input_extractor::PreCallInput; +use super::input_extractor::{PromptExtractor, CompletionExtractor}; use super::types::{Guard, GuardWarning, GuardrailClient, Guardrails}; /// Result of running pre-call or post-call guards. @@ -43,18 +43,18 @@ impl<'a> GuardrailsRunner<'a> { } /// Run pre-call guards, extracting input from the request only if guards exist. - pub async fn run_pre_call(&self, request: &impl PreCallInput) -> GuardPhaseResult { + pub async fn run_pre_call(&self, request: &impl PromptExtractor) -> GuardPhaseResult { if self.pre_call.is_empty() { return GuardPhaseResult { blocked_response: None, warnings: Vec::new(), }; } - let input = request.extract_pre_call_input(); + let input = request.extract_pompt(); let outcome = execute_guards(&self.pre_call, &input, self.client).await; if outcome.blocked { return GuardPhaseResult { - blocked_response: Some(blocked_response(&outcome.blocking_guard)), + blocked_response: Some(blocked_response(&outcome)), warnings: Vec::new(), }; } @@ -64,18 +64,19 @@ impl<'a> GuardrailsRunner<'a> { } } - /// Run post-call guards against the LLM response text. - pub async fn run_post_call(&self, response_text: &str) -> GuardPhaseResult { + /// Run post-call guards, extracting input from the response only if guards exist. + pub async fn run_post_call(&self, response: &impl CompletionExtractor) -> GuardPhaseResult { if self.post_call.is_empty() { return GuardPhaseResult { blocked_response: None, warnings: Vec::new(), }; } - let outcome = execute_guards(&self.post_call, response_text, self.client).await; + let input = response.extract_completion(); + let outcome = execute_guards(&self.post_call, &input, self.client).await; if outcome.blocked { return GuardPhaseResult { - blocked_response: Some(blocked_response(&outcome.blocking_guard)), + blocked_response: Some(blocked_response(&outcome)), warnings: Vec::new(), }; } @@ -85,11 +86,6 @@ impl<'a> GuardrailsRunner<'a> { } } - /// Returns true if post-call guards are configured for this request. - pub fn has_post_call_guards(&self) -> bool { - !self.post_call.is_empty() - } - /// Attach warning headers to a response if there are any warnings. /// Returns the response unchanged if there are no warnings. pub fn finalize_response(response: Response, warnings: &[GuardWarning]) -> Response { @@ -107,15 +103,45 @@ impl<'a> GuardrailsRunner<'a> { } /// Build a 403 blocked response with the guard name. -pub fn blocked_response(blocking_guard: &Option) -> Response { - let guard_name = blocking_guard.as_deref().unwrap_or("unknown"); - let body = json!({ - "error": { - "type": "guardrail_blocked", - "guardrail": guard_name, - "message": format!("Request blocked by guardrail '{guard_name}'"), - } +pub fn blocked_response(outcome: &super::types::GuardrailsOutcome) -> Response { + use super::types::GuardResult; + + let guard_name = outcome.blocking_guard.as_deref().unwrap_or("unknown"); + + // Find the blocking guard result to get details + let details = outcome.results.iter() + .find(|r| match r { + GuardResult::Failed { name, .. } => name == guard_name, + GuardResult::Error { name, .. } => name == guard_name, + _ => false, + }) + .and_then(|r| match r { + GuardResult::Failed { result, .. } => Some(json!({ + "evaluation_result": result, + "reason": "evaluation_failed" + })), + GuardResult::Error { error, .. } => Some(json!({ + "error_details": error, + "reason": "evaluator_error" + })), + _ => None, + }); + + let mut error_obj = json!({ + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), }); + + if let Some(details) = details { + if let Some(obj) = error_obj.as_object_mut() { + if let Some(details_obj) = details.as_object() { + obj.extend(details_obj.clone()); + } + } + } + + let body = json!({ "error": error_obj }); (StatusCode::FORBIDDEN, Json(body)).into_response() } diff --git a/src/guardrails/input_extractor.rs b/src/guardrails/input_extractor.rs index 5d0ff017..7f1023eb 100644 --- a/src/guardrails/input_extractor.rs +++ b/src/guardrails/input_extractor.rs @@ -3,12 +3,17 @@ use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::content::ChatMessageContent; /// Trait for extracting pre-call guardrail input from a request. -pub trait PreCallInput { - fn extract_pre_call_input(&self) -> String; +pub trait PromptExtractor { + fn extract_pompt(&self) -> String; } -impl PreCallInput for ChatCompletionRequest { - fn extract_pre_call_input(&self) -> String { +/// Trait for extracting post-call guardrail input from a response. +pub trait CompletionExtractor { + fn extract_completion(&self) -> String; +} + +impl PromptExtractor for ChatCompletionRequest { + fn extract_pompt(&self) -> String { self.messages .iter() .filter_map(|m| { @@ -27,37 +32,35 @@ impl PreCallInput for ChatCompletionRequest { } } -impl PreCallInput for CompletionRequest { - fn extract_pre_call_input(&self) -> String { +impl PromptExtractor for CompletionRequest { + fn extract_pompt(&self) -> String { self.prompt.clone() } } -/// Extract text from a CompletionResponse for post_call guardrails. -/// Returns the text of the first choice. -pub fn extract_post_call_input_from_completion_response(response: &CompletionResponse) -> String { - response - .choices - .first() - .map(|choice| choice.text.clone()) - .unwrap_or_default() +impl CompletionExtractor for ChatCompletion { + fn extract_completion(&self) -> String { + self.choices + .first() + .and_then(|choice| choice.message.content.as_ref()) + .map(|content| match content { + ChatMessageContent::String(s) => s.clone(), + ChatMessageContent::Array(parts) => parts + .iter() + .filter(|p| p.r#type == "text") + .map(|p| p.text.as_str()) + .collect::>() + .join(" "), + }) + .unwrap_or_default() + } } -/// Extract text from a non-streaming ChatCompletion for post_call guardrails. -/// Returns the content of the first assistant choice. -pub fn extract_post_call_input_from_completion(completion: &ChatCompletion) -> String { - completion - .choices - .first() - .and_then(|choice| choice.message.content.as_ref()) - .map(|content| match content { - ChatMessageContent::String(s) => s.clone(), - ChatMessageContent::Array(parts) => parts - .iter() - .filter(|p| p.r#type == "text") - .map(|p| p.text.as_str()) - .collect::>() - .join(" "), - }) - .unwrap_or_default() +impl CompletionExtractor for CompletionResponse { + fn extract_completion(&self) -> String { + self.choices + .first() + .map(|choice| choice.text.clone()) + .unwrap_or_default() + } } diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index bb103845..7ca3e460 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -61,7 +61,7 @@ impl GuardrailClient for TraceloopClient { })?; let body = evaluator.build_body(input, &guard.params)?; - debug!(guard = %guard.name, slug = %guard.evaluator_slug, %url, %body, "Calling evaluator API"); + debug!(guard = %guard.name, slug = %guard.evaluator_slug, %url, %body, "NOMI - Calling evaluator API"); let response = self .http_client @@ -75,7 +75,7 @@ impl GuardrailClient for TraceloopClient { let status = response.status().as_u16(); let response_body = response.text().await?; - debug!(guard = %guard.name, %status, %response_body, "Evaluator API response"); + debug!(guard = %guard.name, %status, %response_body, "RON - Evaluator API response"); parse_evaluator_http_response(status, &response_body) } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index c8427a83..afa2816f 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,7 +1,4 @@ use crate::config::models::PipelineType; -use crate::guardrails::input_extractor::{ - extract_post_call_input_from_completion, extract_post_call_input_from_completion_response, -}; use crate::guardrails::guardrails_runner::GuardrailsRunner; use crate::guardrails::types::{GuardrailResources, Guardrails}; use crate::models::chat::ChatCompletionResponse; @@ -161,8 +158,7 @@ pub async fn chat_completions( // Post-call guardrails (non-streaming) if let Some(orch) = &orchestrator { - let response_text = extract_post_call_input_from_completion(&completion); - let post = orch.run_post_call(&response_text).await; + let post = orch.run_post_call(&completion).await; if let Some(resp) = post.blocked_response { return Ok(resp); } @@ -223,8 +219,7 @@ pub async fn completions( // Post-call guardrails if let Some(orch) = &orchestrator { - let response_text = extract_post_call_input_from_completion_response(&response); - let post = orch.run_post_call(&response_text).await; + let post = orch.run_post_call(&response).await; if let Some(resp) = post.blocked_response { return Ok(resp); } From 89b5e479418ef99fccdd13e91fa89d4bdf35e87d Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 09:43:44 +0200 Subject: [PATCH 14/59] no guards in completion --- src/guardrails/input_extractor.rs | 15 --------------- src/pipelines/pipeline.rs | 31 ++----------------------------- 2 files changed, 2 insertions(+), 44 deletions(-) diff --git a/src/guardrails/input_extractor.rs b/src/guardrails/input_extractor.rs index 7f1023eb..ec0a7fc4 100644 --- a/src/guardrails/input_extractor.rs +++ b/src/guardrails/input_extractor.rs @@ -1,5 +1,4 @@ use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; -use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::content::ChatMessageContent; /// Trait for extracting pre-call guardrail input from a request. @@ -32,12 +31,6 @@ impl PromptExtractor for ChatCompletionRequest { } } -impl PromptExtractor for CompletionRequest { - fn extract_pompt(&self) -> String { - self.prompt.clone() - } -} - impl CompletionExtractor for ChatCompletion { fn extract_completion(&self) -> String { self.choices @@ -56,11 +49,3 @@ impl CompletionExtractor for ChatCompletion { } } -impl CompletionExtractor for CompletionResponse { - fn extract_completion(&self) -> String { - self.choices - .first() - .map(|choice| choice.text.clone()) - .unwrap_or_default() - } -} diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index afa2816f..fa70dda5 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -82,7 +82,7 @@ pub fn create_pipeline( PipelineType::Completion => { router.route( "/completions", - post(move |state, headers, payload| completions(state, headers, payload, models, gr)), + post(move |state, payload| completions(state, payload, models)), ) } PipelineType::Embeddings => router.route( @@ -186,25 +186,10 @@ pub async fn chat_completions( pub async fn completions( State(model_registry): State>, - headers: HeaderMap, Json(payload): Json, model_keys: Vec, - guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("completion", &payload); - let orchestrator = GuardrailsRunner::new(guardrails.as_deref(), &headers); - - // Pre-call guardrails - let mut all_warnings = Vec::new(); - if let Some(orch) = &orchestrator { - let pre = orch - .run_pre_call(&payload) - .await; - if let Some(resp) = pre.blocked_response { - return Ok(resp); - } - all_warnings = pre.warnings; - } for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); @@ -217,19 +202,7 @@ pub async fn completions( })?; tracer.log_success(&response); - // Post-call guardrails - if let Some(orch) = &orchestrator { - let post = orch.run_post_call(&response).await; - if let Some(resp) = post.blocked_response { - return Ok(resp); - } - all_warnings.extend(post.warnings); - } - - return Ok(GuardrailsRunner::finalize_response( - Json(response).into_response(), - &all_warnings, - )); + return Ok(Json(response).into_response()); } } From 36b1720d552125051c6b845e7c0a3caae8a67bbc Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:03:18 +0200 Subject: [PATCH 15/59] cleanup --- src/guardrails/api_control.rs | 18 +------ src/guardrails/guardrails_runner.rs | 2 +- src/guardrails/mod.rs | 1 - src/guardrails/providers/mod.rs | 4 +- src/guardrails/stream_buffer.rs | 11 ---- tests/guardrails/helpers.rs | 23 --------- tests/guardrails/main.rs | 1 - tests/guardrails/test_api_control.rs | 70 +++----------------------- tests/guardrails/test_e2e.rs | 16 +----- tests/guardrails/test_pipeline.rs | 13 +++-- tests/guardrails/test_stream_buffer.rs | 19 ------- 11 files changed, 19 insertions(+), 159 deletions(-) delete mode 100644 src/guardrails/stream_buffer.rs delete mode 100644 tests/guardrails/test_stream_buffer.rs diff --git a/src/guardrails/api_control.rs b/src/guardrails/api_control.rs index 15491f20..d908992d 100644 --- a/src/guardrails/api_control.rs +++ b/src/guardrails/api_control.rs @@ -12,26 +12,11 @@ pub fn parse_guardrails_header(header: &str) -> Vec { .collect() } -/// Parse guard names from the request payload's `guardrails` field. -pub fn parse_guardrails_from_payload(payload: &serde_json::Value) -> Vec { - payload - .get("guardrails") - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(String::from)) - .collect() - }) - .unwrap_or_default() -} - -/// Resolve the final set of guards to execute by merging pipeline, header, and payload sources. -/// Guards are additive and deduplicated by name. Uses HashMap for O(1) guard lookup. +/// Resolve the final set of guards to execute by merging pipeline and header sources. pub fn resolve_guards_by_name( all_guards: &[Guard], pipeline_names: &[&str], header_names: &[&str], - payload_names: &[&str], ) -> Vec { let guard_map: HashMap<&str, &Guard> = all_guards.iter().map(|g| (g.name.as_str(), g)).collect(); @@ -42,7 +27,6 @@ pub fn resolve_guards_by_name( let all_names = pipeline_names .iter() .chain(header_names.iter()) - .chain(payload_names.iter()) .copied(); for name in all_names { diff --git a/src/guardrails/guardrails_runner.rs b/src/guardrails/guardrails_runner.rs index b333aa53..8bf7839e 100644 --- a/src/guardrails/guardrails_runner.rs +++ b/src/guardrails/guardrails_runner.rs @@ -163,7 +163,7 @@ fn resolve_request_guards(gr: &Guardrails, headers: &HeaderMap) -> (Vec, let pipeline_names: Vec<&str> = gr.pipeline_guard_names.iter().map(|s| s.as_str()).collect(); let header_names: Vec<&str> = header_guard_names.iter().map(|s| s.as_str()).collect(); - let resolved = resolve_guards_by_name(&gr.all_guards, &pipeline_names, &header_names, &[]); + let resolved = resolve_guards_by_name(&gr.all_guards, &pipeline_names, &header_names); // Log unknown guard names from headers if !header_guard_names.is_empty() { diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index fdc23eaa..959a8419 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -6,5 +6,4 @@ pub mod input_extractor; pub mod guardrails_runner; pub mod providers; pub mod response_parser; -pub mod stream_buffer; pub mod types; diff --git a/src/guardrails/providers/mod.rs b/src/guardrails/providers/mod.rs index 9e45663e..600e5726 100644 --- a/src/guardrails/providers/mod.rs +++ b/src/guardrails/providers/mod.rs @@ -3,10 +3,12 @@ pub mod traceloop; use self::traceloop::TraceloopClient; use super::types::{Guard, GuardrailClient}; +pub const TRACELOOP_PROVIDER: &str = "traceloop"; + /// Create a guardrail client based on the guard's provider type. pub fn create_guardrail_client(guard: &Guard) -> Option> { match guard.provider.as_str() { - "traceloop" => Some(Box::new(TraceloopClient::new())), + TRACELOOP_PROVIDER => Some(Box::new(TraceloopClient::new())), _ => None, } } diff --git a/src/guardrails/stream_buffer.rs b/src/guardrails/stream_buffer.rs deleted file mode 100644 index dfe964f2..00000000 --- a/src/guardrails/stream_buffer.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::models::streaming::ChatCompletionChunk; - -/// Extract and concatenate text from accumulated streaming chunks. -/// Joins the delta content from all chunks into a single string. -pub fn extract_text_from_chunks(chunks: &[ChatCompletionChunk]) -> String { - chunks - .iter() - .flat_map(|chunk| &chunk.choices) - .filter_map(|choice| choice.delta.content.as_deref()) - .collect() -} diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index 0c8412d4..e3f76538 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -7,7 +7,6 @@ use hub_lib::guardrails::types::{ }; use hub_lib::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent}; -use hub_lib::models::streaming::{ChatCompletionChunk, Choice, ChoiceDelta}; use hub_lib::models::usage::Usage; use serde_json::json; @@ -167,28 +166,6 @@ pub fn create_test_chat_completion(response_text: &str) -> ChatCompletion { } } -pub fn create_test_chunk(content: &str) -> ChatCompletionChunk { - ChatCompletionChunk { - id: "chunk-1".to_string(), - choices: vec![Choice { - delta: ChoiceDelta { - content: Some(content.to_string()), - role: None, - tool_calls: None, - reasoning: None, - }, - finish_reason: None, - index: 0, - logprobs: None, - }], - created: 1234567890, - model: "gpt-4".to_string(), - service_tier: None, - system_fingerprint: None, - usage: None, - } -} - // --------------------------------------------------------------------------- // Mock GuardrailClient // --------------------------------------------------------------------------- diff --git a/tests/guardrails/main.rs b/tests/guardrails/main.rs index 57873d30..dedbde9e 100644 --- a/tests/guardrails/main.rs +++ b/tests/guardrails/main.rs @@ -5,6 +5,5 @@ mod test_executor; mod test_input_extractor; mod test_pipeline; mod test_response_parser; -mod test_stream_buffer; mod test_traceloop_client; mod test_types; diff --git a/tests/guardrails/test_api_control.rs b/tests/guardrails/test_api_control.rs index 6f16fa01..62c3ab4a 100644 --- a/tests/guardrails/test_api_control.rs +++ b/tests/guardrails/test_api_control.rs @@ -1,6 +1,5 @@ use hub_lib::guardrails::api_control::*; use hub_lib::guardrails::types::GuardMode; -use serde_json::json; use super::helpers::*; @@ -23,17 +22,10 @@ fn test_parse_guardrails_header_multiple() { ); } -#[test] -fn test_parse_guardrails_from_payload() { - let payload = json!({"guardrails": ["toxicity-check", "pii-check"]}); - let names = parse_guardrails_from_payload(&payload); - assert_eq!(names, vec!["toxicity-check", "pii-check"]); -} - #[test] fn test_pipeline_guardrails_always_included() { let pipeline_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; - let resolved = resolve_guards_by_name(&pipeline_guards, &["pipeline-guard"], &[], &[]); + let resolved = resolve_guards_by_name(&pipeline_guards, &["pipeline-guard"], &[]); assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].name, "pipeline-guard"); } @@ -44,37 +36,10 @@ fn test_header_guardrails_additive_to_pipeline() { create_test_guard("pipeline-guard", GuardMode::PreCall), create_test_guard("header-guard", GuardMode::PreCall), ]; - let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["header-guard"], &[]); - assert_eq!(resolved.len(), 2); -} - -#[test] -fn test_payload_guardrails_additive_to_pipeline() { - let all_guards = vec![ - create_test_guard("pipeline-guard", GuardMode::PreCall), - create_test_guard("payload-guard", GuardMode::PreCall), - ]; - let resolved = - resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[], &["payload-guard"]); + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["header-guard"]); assert_eq!(resolved.len(), 2); } -#[test] -fn test_header_and_payload_both_additive() { - let all_guards = vec![ - create_test_guard("pipeline-guard", GuardMode::PreCall), - create_test_guard("header-guard", GuardMode::PreCall), - create_test_guard("payload-guard", GuardMode::PreCall), - ]; - let resolved = resolve_guards_by_name( - &all_guards, - &["pipeline-guard"], - &["header-guard"], - &["payload-guard"], - ); - assert_eq!(resolved.len(), 3); -} - #[test] fn test_deduplication_by_name() { let all_guards = vec![create_test_guard("shared-guard", GuardMode::PreCall)]; @@ -82,7 +47,6 @@ fn test_deduplication_by_name() { &all_guards, &["shared-guard"], &["shared-guard"], // duplicate - &["shared-guard"], // duplicate ); assert_eq!(resolved.len(), 1); } @@ -94,39 +58,19 @@ fn test_unknown_guard_name_in_header_ignored() { &all_guards, &["known-guard"], &["nonexistent-guard"], // unknown - &[], ); assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].name, "known-guard"); } -#[test] -fn test_unknown_guard_name_in_payload_ignored() { - let all_guards = vec![create_test_guard("known-guard", GuardMode::PreCall)]; - let resolved = resolve_guards_by_name( - &all_guards, - &["known-guard"], - &[], - &["nonexistent-guard"], // unknown - ); - assert_eq!(resolved.len(), 1); -} - #[test] fn test_empty_header_pipeline_guards_only() { let all_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; - let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[], &[]); + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[]); assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].name, "pipeline-guard"); } -#[test] -fn test_empty_payload_pipeline_guards_only() { - let all_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; - let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[], &[]); - assert_eq!(resolved.len(), 1); -} - #[test] fn test_cannot_remove_pipeline_guardrails_via_api() { let all_guards = vec![ @@ -134,7 +78,7 @@ fn test_cannot_remove_pipeline_guardrails_via_api() { create_test_guard("extra", GuardMode::PreCall), ]; // Even if header/payload only mention "extra", pipeline guard still included - let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["extra"], &[]); + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["extra"]); assert!(resolved.iter().any(|g| g.name == "pipeline-guard")); } @@ -159,16 +103,14 @@ fn test_complete_resolution_merged() { create_test_guard("pipeline-pre", GuardMode::PreCall), create_test_guard("pipeline-post", GuardMode::PostCall), create_test_guard("header-pre", GuardMode::PreCall), - create_test_guard("payload-post", GuardMode::PostCall), ]; let resolved = resolve_guards_by_name( &all_guards, &["pipeline-pre", "pipeline-post"], &["header-pre"], - &["payload-post"], ); - assert_eq!(resolved.len(), 4); + assert_eq!(resolved.len(), 3); let (pre, post) = split_guards_by_mode(&resolved); assert_eq!(pre.len(), 2); // pipeline-pre + header-pre - assert_eq!(post.len(), 2); // pipeline-post + payload-post + assert_eq!(post.len(), 1); // pipeline-post } diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index f0bc95b8..056b286a 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -3,7 +3,6 @@ use hub_lib::guardrails::input_extractor::{ extract_post_call_input_from_completion, extract_prompt, }; use hub_lib::guardrails::providers::traceloop::TraceloopClient; -use hub_lib::guardrails::stream_buffer::extract_text_from_chunks; use hub_lib::guardrails::types::*; use hub_lib::pipelines::pipeline::{build_guardrail_resources, build_pipeline_guardrails}; @@ -272,14 +271,7 @@ async fn test_e2e_streaming_post_call_buffer_pass() { "profanity-detector", ); - // Simulate accumulated streaming chunks - let chunks = vec![ - create_test_chunk("Hello"), - create_test_chunk(" "), - create_test_chunk("world!"), - ]; - let accumulated = extract_text_from_chunks(&chunks); - assert_eq!(accumulated, "Hello world!"); + let accumulated = "Hello world!"; let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &accumulated, &client).await; @@ -299,11 +291,7 @@ async fn test_e2e_streaming_post_call_buffer_block() { "pii-detector", ); - let chunks = vec![ - create_test_chunk("Here is "), - create_test_chunk("SSN: 123-45-6789"), - ]; - let accumulated = extract_text_from_chunks(&chunks); + let accumulated = "Here is SSN: 123-45-6789"; let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &accumulated, &client).await; diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs index c500358d..e38ba6f9 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_pipeline.rs @@ -284,7 +284,7 @@ fn test_build_pipeline_guardrails_empty_pipeline_guards() { let shared = build_guardrail_resources(&config).unwrap(); // Pipeline with no guards specified - shared resources still exist // (header guards can still be used at request time) - let gr = build_pipeline_guardrails(&shared, &[]); + let gr = build_pipeline_guardrails(&shared); assert_eq!(gr.all_guards.len(), 4); assert!(gr.pipeline_guard_names.is_empty()); @@ -344,7 +344,6 @@ fn test_pipeline_guards_resolved_at_request_time() { &all_guards, &pipeline_names, &header_names, - &[], ); assert_eq!(resolved.len(), 2); assert_eq!(resolved[0].name, "pii-check"); @@ -361,7 +360,7 @@ fn test_pipeline_guards_plus_header_guards_split_by_mode() { // Header adds injection-check (pre_call) and secrets-check (post_call) let header_names = vec!["injection-check", "secrets-check"]; - let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names, &[]); + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); assert_eq!(resolved.len(), 4); let (pre_call, post_call) = split_guards_by_mode(&resolved); @@ -381,7 +380,7 @@ fn test_header_guard_not_in_config_is_ignored() { let pipeline_names = vec!["pii-check"]; let header_names = vec!["nonexistent-guard"]; - let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names, &[]); + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); // Only pii-check should be resolved; nonexistent guard is silently ignored assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].name, "pii-check"); @@ -396,7 +395,7 @@ fn test_duplicate_guard_in_header_and_pipeline_deduped() { // Header specifies same guard as pipeline let header_names = vec!["pii-check"]; - let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names, &[]); + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); assert_eq!(resolved.len(), 2); // pii-check only appears once } @@ -410,7 +409,7 @@ fn test_no_pipeline_guards_header_only() { // Header adds guards let header_names = vec!["injection-check", "secrets-check"]; - let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names, &[]); + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); assert_eq!(resolved.len(), 2); assert_eq!(resolved[0].name, "injection-check"); assert_eq!(resolved[1].name, "secrets-check"); @@ -421,7 +420,7 @@ fn test_no_pipeline_guards_no_header_no_guards_executed() { let config = test_guardrails_config(); let all_guards = resolve_guard_defaults(&config); - let resolved = resolve_guards_by_name(&all_guards, &[], &[], &[]); + let resolved = resolve_guards_by_name(&all_guards, &[], &[]); assert!(resolved.is_empty()); let (pre_call, post_call) = split_guards_by_mode(&resolved); diff --git a/tests/guardrails/test_stream_buffer.rs b/tests/guardrails/test_stream_buffer.rs deleted file mode 100644 index 8fe1b2f5..00000000 --- a/tests/guardrails/test_stream_buffer.rs +++ /dev/null @@ -1,19 +0,0 @@ -use hub_lib::guardrails::stream_buffer::*; - -use super::helpers::*; - -// --------------------------------------------------------------------------- -// Phase 2: Stream Buffer (1 test) -// --------------------------------------------------------------------------- - -#[test] -fn test_extract_from_accumulated_stream_chunks() { - let chunks = vec![ - create_test_chunk("Hello"), - create_test_chunk(" "), - create_test_chunk("world"), - create_test_chunk("!"), - ]; - let text = extract_text_from_chunks(&chunks); - assert_eq!(text, "Hello world!"); -} From accfd5f2497322a97b3ae862e88612f0cfed1ecb Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:14:55 +0200 Subject: [PATCH 16/59] merge files --- src/guardrails/executor.rs | 103 ---------------- src/guardrails/mod.rs | 6 +- .../{input_extractor.rs => parsing.rs} | 31 +++++ src/guardrails/providers/traceloop.rs | 2 +- src/guardrails/response_parser.rs | 31 ----- .../{guardrails_runner.rs => runner.rs} | 111 ++++++++++++++++-- src/pipelines/pipeline.rs | 4 +- tests/guardrails/test_e2e.rs | 4 +- tests/guardrails/test_executor.rs | 4 +- tests/guardrails/test_input_extractor.rs | 2 +- tests/guardrails/test_pipeline.rs | 2 +- tests/guardrails/test_response_parser.rs | 2 +- 12 files changed, 147 insertions(+), 155 deletions(-) delete mode 100644 src/guardrails/executor.rs rename src/guardrails/{input_extractor.rs => parsing.rs} (64%) delete mode 100644 src/guardrails/response_parser.rs rename src/guardrails/{guardrails_runner.rs => runner.rs} (62%) diff --git a/src/guardrails/executor.rs b/src/guardrails/executor.rs deleted file mode 100644 index 9da6a1aa..00000000 --- a/src/guardrails/executor.rs +++ /dev/null @@ -1,103 +0,0 @@ -use futures::future::join_all; -use tracing::{debug, warn}; - -use super::types::{Guard, GuardResult, GuardWarning, GuardrailClient, GuardrailsOutcome, OnFailure}; - -/// Execute a set of guardrails against the given input text. -/// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. -pub async fn execute_guards( - guards: &[Guard], - input: &str, - client: &dyn GuardrailClient, -) -> GuardrailsOutcome { - debug!(guard_count = guards.len(), "Executing guardrails"); - - let futures: Vec<_> = guards - .iter() - .map(|guard| async move { - let start = std::time::Instant::now(); - let result = client.evaluate(guard, input).await; - let elapsed = start.elapsed(); - match &result { - Ok(resp) => debug!( - guard = %guard.name, - pass = resp.pass, - elapsed_ms = elapsed.as_millis(), - "Guard evaluation complete" - ), - Err(err) => warn!( - guard = %guard.name, - error = %err, - required = guard.required, - elapsed_ms = elapsed.as_millis(), - "Guard evaluation failed" - ), - } - (guard, result) - }) - .collect(); - - let results_raw = join_all(futures).await; - - let mut results = Vec::new(); - let mut blocked = false; - let mut blocking_guard = None; - let mut warnings = Vec::new(); - - for (guard, result) in results_raw { - match result { - Ok(response) => { - if response.pass { - results.push(GuardResult::Passed { - name: guard.name.clone(), - }); - } else { - results.push(GuardResult::Failed { - name: guard.name.clone(), - result: response.result, - on_failure: guard.on_failure.clone(), - }); - match guard.on_failure { - OnFailure::Block => { - blocked = true; - if blocking_guard.is_none() { - blocking_guard = Some(guard.name.clone()); - } - } - OnFailure::Warn => { - warnings.push(GuardWarning { - guard_name: guard.name.clone(), - reason: "failed".to_string(), - }); - } - } - } - } - Err(err) => { - let is_required = guard.required; - results.push(GuardResult::Error { - name: guard.name.clone(), - error: err.to_string(), - required: is_required, - }); - if is_required { - blocked = true; - if blocking_guard.is_none() { - blocking_guard = Some(guard.name.clone()); - } - } - } - } - } - - if blocked { - warn!(blocking_guard = ?blocking_guard, "Request blocked by guardrail"); - } - - GuardrailsOutcome { - results, - blocked, - blocking_guard, - warnings, - } -} diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index 959a8419..74a27473 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -1,9 +1,7 @@ pub mod api_control; pub mod builder; pub mod evaluator_types; -pub mod executor; -pub mod input_extractor; -pub mod guardrails_runner; +pub mod runner; +pub mod parsing; pub mod providers; -pub mod response_parser; pub mod types; diff --git a/src/guardrails/input_extractor.rs b/src/guardrails/parsing.rs similarity index 64% rename from src/guardrails/input_extractor.rs rename to src/guardrails/parsing.rs index ec0a7fc4..eef38b8a 100644 --- a/src/guardrails/input_extractor.rs +++ b/src/guardrails/parsing.rs @@ -1,5 +1,8 @@ use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; use crate::models::content::ChatMessageContent; +use tracing::debug; + +use super::types::{EvaluatorResponse, GuardrailError}; /// Trait for extracting pre-call guardrail input from a request. pub trait PromptExtractor { @@ -49,3 +52,31 @@ impl CompletionExtractor for ChatCompletion { } } +/// Parse the evaluator response body (JSON string) into an EvaluatorResponse. +pub fn parse_evaluator_response(body: &str) -> Result { + let response = serde_json::from_str::(body) + .map_err(|e| GuardrailError::ParseError(e.to_string()))?; + + // Log for debugging + debug!( + pass = response.pass, + result = %response.result, + "Parsed evaluator response" + ); + + Ok(response) +} + +/// Parse an HTTP response from the evaluator, handling non-200 status codes. +pub fn parse_evaluator_http_response( + status: u16, + body: &str, +) -> Result { + if !(200..300).contains(&status) { + return Err(GuardrailError::HttpError { + status, + body: body.to_string(), + }); + } + parse_evaluator_response(body) +} diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 7ca3e460..3795f3e1 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -4,7 +4,7 @@ use tracing::debug; use super::GuardrailClient; use crate::guardrails::evaluator_types::get_evaluator; -use crate::guardrails::response_parser::parse_evaluator_http_response; +use crate::guardrails::parsing::parse_evaluator_http_response; use crate::guardrails::types::{EvaluatorResponse, Guard, GuardrailError}; diff --git a/src/guardrails/response_parser.rs b/src/guardrails/response_parser.rs deleted file mode 100644 index 05f5ae27..00000000 --- a/src/guardrails/response_parser.rs +++ /dev/null @@ -1,31 +0,0 @@ -use super::types::{EvaluatorResponse, GuardrailError}; -use tracing::debug; - -/// Parse the evaluator response body (JSON string) into an EvaluatorResponse. -pub fn parse_evaluator_response(body: &str) -> Result { - let response = serde_json::from_str::(body) - .map_err(|e| GuardrailError::ParseError(e.to_string()))?; - - // Log for debugging - debug!( - pass = response.pass, - result = %response.result, - "Parsed evaluator response" - ); - - Ok(response) -} - -/// Parse an HTTP response from the evaluator, handling non-200 status codes. -pub fn parse_evaluator_http_response( - status: u16, - body: &str, -) -> Result { - if !(200..300).contains(&status) { - return Err(GuardrailError::HttpError { - status, - body: body.to_string(), - }); - } - parse_evaluator_response(body) -} diff --git a/src/guardrails/guardrails_runner.rs b/src/guardrails/runner.rs similarity index 62% rename from src/guardrails/guardrails_runner.rs rename to src/guardrails/runner.rs index 8bf7839e..bee222b5 100644 --- a/src/guardrails/guardrails_runner.rs +++ b/src/guardrails/runner.rs @@ -4,13 +4,112 @@ use axum::http::HeaderMap; use axum::response::{IntoResponse, Response}; use axum::http::StatusCode; use axum::Json; +use futures::future::join_all; use serde_json::json; -use tracing::warn; +use tracing::{debug, warn}; use super::api_control::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; -use super::executor::execute_guards; -use super::input_extractor::{PromptExtractor, CompletionExtractor}; -use super::types::{Guard, GuardWarning, GuardrailClient, Guardrails}; +use super::parsing::{PromptExtractor, CompletionExtractor}; +use super::types::{Guard, GuardResult, GuardWarning, GuardrailClient, GuardrailsOutcome, Guardrails, OnFailure}; + +/// Execute a set of guardrails against the given input text. +/// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. +pub async fn execute_guards( + guards: &[Guard], + input: &str, + client: &dyn GuardrailClient, +) -> GuardrailsOutcome { + debug!(guard_count = guards.len(), "Executing guardrails"); + + let futures: Vec<_> = guards + .iter() + .map(|guard| async move { + let start = std::time::Instant::now(); + let result = client.evaluate(guard, input).await; + let elapsed = start.elapsed(); + match &result { + Ok(resp) => debug!( + guard = %guard.name, + pass = resp.pass, + elapsed_ms = elapsed.as_millis(), + "Guard evaluation complete" + ), + Err(err) => warn!( + guard = %guard.name, + error = %err, + required = guard.required, + elapsed_ms = elapsed.as_millis(), + "Guard evaluation failed" + ), + } + (guard, result) + }) + .collect(); + + let results_raw = join_all(futures).await; + + let mut results = Vec::new(); + let mut blocked = false; + let mut blocking_guard = None; + let mut warnings = Vec::new(); + + for (guard, result) in results_raw { + match result { + Ok(response) => { + if response.pass { + results.push(GuardResult::Passed { + name: guard.name.clone(), + }); + } else { + results.push(GuardResult::Failed { + name: guard.name.clone(), + result: response.result, + on_failure: guard.on_failure.clone(), + }); + match guard.on_failure { + OnFailure::Block => { + blocked = true; + if blocking_guard.is_none() { + blocking_guard = Some(guard.name.clone()); + } + } + OnFailure::Warn => { + warnings.push(GuardWarning { + guard_name: guard.name.clone(), + reason: "failed".to_string(), + }); + } + } + } + } + Err(err) => { + let is_required = guard.required; + results.push(GuardResult::Error { + name: guard.name.clone(), + error: err.to_string(), + required: is_required, + }); + if is_required { + blocked = true; + if blocking_guard.is_none() { + blocking_guard = Some(guard.name.clone()); + } + } + } + } + } + + if blocked { + warn!(blocking_guard = ?blocking_guard, "Request blocked by guardrail"); + } + + GuardrailsOutcome { + results, + blocked, + blocking_guard, + warnings, + } +} /// Result of running pre-call or post-call guards. pub struct GuardPhaseResult { @@ -103,9 +202,7 @@ impl<'a> GuardrailsRunner<'a> { } /// Build a 403 blocked response with the guard name. -pub fn blocked_response(outcome: &super::types::GuardrailsOutcome) -> Response { - use super::types::GuardResult; - +pub fn blocked_response(outcome: &GuardrailsOutcome) -> Response { let guard_name = outcome.blocking_guard.as_deref().unwrap_or("unknown"); // Find the blocking guard result to get details diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index fa70dda5..adf269de 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,5 +1,5 @@ use crate::config::models::PipelineType; -use crate::guardrails::guardrails_runner::GuardrailsRunner; +use crate::guardrails::runner::GuardrailsRunner; use crate::guardrails::types::{GuardrailResources, Guardrails}; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; @@ -31,7 +31,7 @@ use std::sync::Arc; pub use crate::guardrails::builder::{ build_guardrail_resources, build_pipeline_guardrails, resolve_guard_defaults, }; -pub use crate::guardrails::guardrails_runner::{blocked_response, warning_header_value}; +pub use crate::guardrails::runner::{blocked_response, warning_header_value}; pub fn create_pipeline( pipeline: &Pipeline, diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 056b286a..e010c037 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -1,5 +1,5 @@ -use hub_lib::guardrails::executor::execute_guards; -use hub_lib::guardrails::input_extractor::{ +use hub_lib::guardrails::runner::execute_guards; +use hub_lib::guardrails::parsing::{ extract_post_call_input_from_completion, extract_prompt, }; use hub_lib::guardrails::providers::traceloop::TraceloopClient; diff --git a/tests/guardrails/test_executor.rs b/tests/guardrails/test_executor.rs index a4b13371..0ae292f0 100644 --- a/tests/guardrails/test_executor.rs +++ b/tests/guardrails/test_executor.rs @@ -1,5 +1,5 @@ -use hub_lib::guardrails::executor::*; -use hub_lib::guardrails::input_extractor::*; +use hub_lib::guardrails::runner::*; +use hub_lib::guardrails::parsing::*; use hub_lib::guardrails::types::*; use super::helpers::*; diff --git a/tests/guardrails/test_input_extractor.rs b/tests/guardrails/test_input_extractor.rs index d9dd0232..cc88d62d 100644 --- a/tests/guardrails/test_input_extractor.rs +++ b/tests/guardrails/test_input_extractor.rs @@ -1,4 +1,4 @@ -use hub_lib::guardrails::input_extractor::*; +use hub_lib::guardrails::parsing::*; use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMessageContentPart}; use super::helpers::*; diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs index e38ba6f9..78e2b694 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_pipeline.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use hub_lib::guardrails::api_control::{resolve_guards_by_name, split_guards_by_mode}; -use hub_lib::guardrails::executor::execute_guards; +use hub_lib::guardrails::runner::execute_guards; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::types::*; use hub_lib::pipelines::pipeline::{ diff --git a/tests/guardrails/test_response_parser.rs b/tests/guardrails/test_response_parser.rs index a9f2dc85..7ed9813a 100644 --- a/tests/guardrails/test_response_parser.rs +++ b/tests/guardrails/test_response_parser.rs @@ -1,4 +1,4 @@ -use hub_lib::guardrails::response_parser::*; +use hub_lib::guardrails::parsing::*; use hub_lib::guardrails::types::GuardrailError; // --------------------------------------------------------------------------- From 5c78a5645203bc61d50204b48ac0ff22b3853064 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:18:47 +0200 Subject: [PATCH 17/59] setup --- src/guardrails/api_control.rs | 56 ----------------------- src/guardrails/mod.rs | 5 +-- src/guardrails/runner.rs | 2 +- src/guardrails/{builder.rs => setup.rs} | 60 +++++++++++++++++++++++-- src/pipelines/pipeline.rs | 2 +- src/state.rs | 2 +- tests/guardrails/test_api_control.rs | 2 +- tests/guardrails/test_pipeline.rs | 2 +- 8 files changed, 64 insertions(+), 67 deletions(-) delete mode 100644 src/guardrails/api_control.rs rename src/guardrails/{builder.rs => setup.rs} (50%) diff --git a/src/guardrails/api_control.rs b/src/guardrails/api_control.rs deleted file mode 100644 index d908992d..00000000 --- a/src/guardrails/api_control.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use super::types::{Guard, GuardMode}; - -/// Parse guard names from the X-Traceloop-Guardrails header value. -/// Names are comma-separated and trimmed. -pub fn parse_guardrails_header(header: &str) -> Vec { - header - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect() -} - -/// Resolve the final set of guards to execute by merging pipeline and header sources. -pub fn resolve_guards_by_name( - all_guards: &[Guard], - pipeline_names: &[&str], - header_names: &[&str], -) -> Vec { - let guard_map: HashMap<&str, &Guard> = - all_guards.iter().map(|g| (g.name.as_str(), g)).collect(); - - let mut seen = HashSet::new(); - let mut resolved = Vec::new(); - - let all_names = pipeline_names - .iter() - .chain(header_names.iter()) - .copied(); - - for name in all_names { - if seen.insert(name) { - if let Some(guard) = guard_map.get(name) { - resolved.push((*guard).clone()); - } - } - } - - resolved -} - -/// Split guards into (pre_call, post_call) lists by mode. -pub fn split_guards_by_mode(guards: &[Guard]) -> (Vec, Vec) { - let pre_call: Vec = guards - .iter() - .filter(|g| g.mode == GuardMode::PreCall) - .cloned() - .collect(); - let post_call: Vec = guards - .iter() - .filter(|g| g.mode == GuardMode::PostCall) - .cloned() - .collect(); - (pre_call, post_call) -} diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index 74a27473..140ab4c4 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -1,7 +1,6 @@ -pub mod api_control; -pub mod builder; pub mod evaluator_types; -pub mod runner; pub mod parsing; pub mod providers; +pub mod runner; +pub mod setup; pub mod types; diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index bee222b5..f5139b19 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -8,7 +8,7 @@ use futures::future::join_all; use serde_json::json; use tracing::{debug, warn}; -use super::api_control::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; +use super::setup::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; use super::parsing::{PromptExtractor, CompletionExtractor}; use super::types::{Guard, GuardResult, GuardWarning, GuardrailClient, GuardrailsOutcome, Guardrails, OnFailure}; diff --git a/src/guardrails/builder.rs b/src/guardrails/setup.rs similarity index 50% rename from src/guardrails/builder.rs rename to src/guardrails/setup.rs index 37892163..f343b0c4 100644 --- a/src/guardrails/builder.rs +++ b/src/guardrails/setup.rs @@ -1,9 +1,63 @@ +use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use super::types::{GuardrailResources, GuardrailsConfig, Guardrails}; +use super::types::{Guard, GuardMode, GuardrailClient, GuardrailResources, GuardrailsConfig, Guardrails}; + +/// Parse guard names from the X-Traceloop-Guardrails header value. +/// Names are comma-separated and trimmed. +pub fn parse_guardrails_header(header: &str) -> Vec { + header + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() +} + +/// Resolve the final set of guards to execute by merging pipeline and header sources. +pub fn resolve_guards_by_name( + all_guards: &[Guard], + pipeline_names: &[&str], + header_names: &[&str], +) -> Vec { + let guard_map: HashMap<&str, &Guard> = + all_guards.iter().map(|g| (g.name.as_str(), g)).collect(); + + let mut seen = HashSet::new(); + let mut resolved = Vec::new(); + + let all_names = pipeline_names + .iter() + .chain(header_names.iter()) + .copied(); + + for name in all_names { + if seen.insert(name) { + if let Some(guard) = guard_map.get(name) { + resolved.push((*guard).clone()); + } + } + } + + resolved +} + +/// Split guards into (pre_call, post_call) lists by mode. +pub fn split_guards_by_mode(guards: &[Guard]) -> (Vec, Vec) { + let pre_call: Vec = guards + .iter() + .filter(|g| g.mode == GuardMode::PreCall) + .cloned() + .collect(); + let post_call: Vec = guards + .iter() + .filter(|g| g.mode == GuardMode::PostCall) + .cloned() + .collect(); + (pre_call, post_call) +} /// Resolve provider defaults (api_base/api_key) for all guards in the config. -pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { +pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { let mut guards = config.guards.clone(); for guard in &mut guards { if guard.api_base.is_none() || guard.api_key.is_none() { @@ -30,7 +84,7 @@ pub fn build_guardrail_resources( return None; } let all_guards = Arc::new(resolve_guard_defaults(config)); - let client: Arc = + let client: Arc = Arc::new(super::providers::traceloop::TraceloopClient::new()); Some((all_guards, client)) } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index adf269de..da4f75d1 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -28,7 +28,7 @@ use reqwest_streams::error::StreamBodyError; use std::sync::Arc; // Re-export builder and orchestrator functions for backward compatibility with tests -pub use crate::guardrails::builder::{ +pub use crate::guardrails::setup::{ build_guardrail_resources, build_pipeline_guardrails, resolve_guard_defaults, }; pub use crate::guardrails::runner::{blocked_response, warning_header_value}; diff --git a/src/state.rs b/src/state.rs index f6647d45..d5dcbf4f 100644 --- a/src/state.rs +++ b/src/state.rs @@ -161,7 +161,7 @@ impl AppState { _provider_registry: &Arc, model_registry: &Arc, ) -> axum::Router { - use crate::guardrails::builder::build_guardrail_resources; + use crate::guardrails::setup::build_guardrail_resources; use crate::pipelines::pipeline::create_pipeline; debug!("Building router with {} pipelines", config.pipelines.len()); diff --git a/tests/guardrails/test_api_control.rs b/tests/guardrails/test_api_control.rs index 62c3ab4a..e44059d9 100644 --- a/tests/guardrails/test_api_control.rs +++ b/tests/guardrails/test_api_control.rs @@ -1,4 +1,4 @@ -use hub_lib::guardrails::api_control::*; +use hub_lib::guardrails::setup::*; use hub_lib::guardrails::types::GuardMode; use super::helpers::*; diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_pipeline.rs index 78e2b694..fec01a98 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_pipeline.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use hub_lib::guardrails::api_control::{resolve_guards_by_name, split_guards_by_mode}; +use hub_lib::guardrails::setup::{resolve_guards_by_name, split_guards_by_mode}; use hub_lib::guardrails::runner::execute_guards; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::types::*; From c95024cac35cebcae602270e98c1e519bc310c37 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:33:35 +0200 Subject: [PATCH 18/59] comments --- src/guardrails/runner.rs | 3 --- src/types/mod.rs | 4 ---- 2 files changed, 7 deletions(-) diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index f5139b19..a8d4d557 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -111,14 +111,11 @@ pub async fn execute_guards( } } -/// Result of running pre-call or post-call guards. pub struct GuardPhaseResult { pub blocked_response: Option, pub warnings: Vec, } -/// Runs guardrails across pre-call and post-call phases. -/// Shared between chat_completions and completions handlers. pub struct GuardrailsRunner<'a> { pre_call: Vec, post_call: Vec, diff --git a/src/types/mod.rs b/src/types/mod.rs index f4b1467f..fbbb21e8 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -149,10 +149,6 @@ pub struct Pipeline { // #[serde(with = "serde_yaml::with::singleton_map_recursive")] #[serde(default, skip_serializing_if = "Vec::is_empty")] pub plugins: Vec, - - /// Guard names associated with this pipeline. Guards listed here - /// are always executed for every request to this pipeline. Additional - /// guards can be added per-request via the `X-Traceloop-Guardrails` header. #[serde(default, skip_serializing_if = "Vec::is_empty")] pub guards: Vec, } From c12b0e7144b660243c5eac71a04b9a33f4baafbd Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:48:04 +0200 Subject: [PATCH 19/59] test rename --- tests/guardrails/helpers.rs | 2 +- tests/guardrails/main.rs | 8 +- tests/guardrails/test_api_control.rs | 116 ------ tests/guardrails/test_e2e.rs | 125 +++++- tests/guardrails/test_input_extractor.rs | 76 ---- tests/guardrails/test_parsing.rs | 147 +++++++ tests/guardrails/test_response_parser.rs | 69 --- .../{test_executor.rs => test_runner.rs} | 6 +- .../{test_pipeline.rs => test_setup.rs} | 393 ++++++------------ tests/guardrails/test_traceloop_client.rs | 3 +- tests/guardrails/test_types.rs | 87 ++++ 11 files changed, 476 insertions(+), 556 deletions(-) delete mode 100644 tests/guardrails/test_api_control.rs delete mode 100644 tests/guardrails/test_input_extractor.rs create mode 100644 tests/guardrails/test_parsing.rs delete mode 100644 tests/guardrails/test_response_parser.rs rename tests/guardrails/{test_executor.rs => test_runner.rs} (98%) rename tests/guardrails/{test_pipeline.rs => test_setup.rs} (53%) diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index e3f76538..1d69039b 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use async_trait::async_trait; -use hub_lib::guardrails::providers::GuardrailClient; +use hub_lib::guardrails::types::GuardrailClient; use hub_lib::guardrails::types::{ EvaluatorResponse, Guard, GuardMode, GuardrailError, OnFailure, }; diff --git a/tests/guardrails/main.rs b/tests/guardrails/main.rs index dedbde9e..1e18de92 100644 --- a/tests/guardrails/main.rs +++ b/tests/guardrails/main.rs @@ -1,9 +1,7 @@ mod helpers; -mod test_api_control; mod test_e2e; -mod test_executor; -mod test_input_extractor; -mod test_pipeline; -mod test_response_parser; +mod test_parsing; +mod test_runner; +mod test_setup; mod test_traceloop_client; mod test_types; diff --git a/tests/guardrails/test_api_control.rs b/tests/guardrails/test_api_control.rs deleted file mode 100644 index e44059d9..00000000 --- a/tests/guardrails/test_api_control.rs +++ /dev/null @@ -1,116 +0,0 @@ -use hub_lib::guardrails::setup::*; -use hub_lib::guardrails::types::GuardMode; - -use super::helpers::*; - -// --------------------------------------------------------------------------- -// Phase 7: API Control (15 tests) -// --------------------------------------------------------------------------- - -#[test] -fn test_parse_guardrails_header_single() { - let names = parse_guardrails_header("pii-check"); - assert_eq!(names, vec!["pii-check"]); -} - -#[test] -fn test_parse_guardrails_header_multiple() { - let names = parse_guardrails_header("toxicity-check,relevance-check,pii-check"); - assert_eq!( - names, - vec!["toxicity-check", "relevance-check", "pii-check"] - ); -} - -#[test] -fn test_pipeline_guardrails_always_included() { - let pipeline_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; - let resolved = resolve_guards_by_name(&pipeline_guards, &["pipeline-guard"], &[]); - assert_eq!(resolved.len(), 1); - assert_eq!(resolved[0].name, "pipeline-guard"); -} - -#[test] -fn test_header_guardrails_additive_to_pipeline() { - let all_guards = vec![ - create_test_guard("pipeline-guard", GuardMode::PreCall), - create_test_guard("header-guard", GuardMode::PreCall), - ]; - let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["header-guard"]); - assert_eq!(resolved.len(), 2); -} - -#[test] -fn test_deduplication_by_name() { - let all_guards = vec![create_test_guard("shared-guard", GuardMode::PreCall)]; - let resolved = resolve_guards_by_name( - &all_guards, - &["shared-guard"], - &["shared-guard"], // duplicate - ); - assert_eq!(resolved.len(), 1); -} - -#[test] -fn test_unknown_guard_name_in_header_ignored() { - let all_guards = vec![create_test_guard("known-guard", GuardMode::PreCall)]; - let resolved = resolve_guards_by_name( - &all_guards, - &["known-guard"], - &["nonexistent-guard"], // unknown - ); - assert_eq!(resolved.len(), 1); - assert_eq!(resolved[0].name, "known-guard"); -} - -#[test] -fn test_empty_header_pipeline_guards_only() { - let all_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; - let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[]); - assert_eq!(resolved.len(), 1); - assert_eq!(resolved[0].name, "pipeline-guard"); -} - -#[test] -fn test_cannot_remove_pipeline_guardrails_via_api() { - let all_guards = vec![ - create_test_guard("pipeline-guard", GuardMode::PreCall), - create_test_guard("extra", GuardMode::PreCall), - ]; - // Even if header/payload only mention "extra", pipeline guard still included - let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["extra"]); - assert!(resolved.iter().any(|g| g.name == "pipeline-guard")); -} - -#[test] -fn test_guards_split_into_pre_and_post_call() { - let guards = vec![ - create_test_guard("pre-1", GuardMode::PreCall), - create_test_guard("post-1", GuardMode::PostCall), - create_test_guard("pre-2", GuardMode::PreCall), - create_test_guard("post-2", GuardMode::PostCall), - ]; - let (pre_call, post_call) = split_guards_by_mode(&guards); - assert_eq!(pre_call.len(), 2); - assert_eq!(post_call.len(), 2); - assert!(pre_call.iter().all(|g| g.mode == GuardMode::PreCall)); - assert!(post_call.iter().all(|g| g.mode == GuardMode::PostCall)); -} - -#[test] -fn test_complete_resolution_merged() { - let all_guards = vec![ - create_test_guard("pipeline-pre", GuardMode::PreCall), - create_test_guard("pipeline-post", GuardMode::PostCall), - create_test_guard("header-pre", GuardMode::PreCall), - ]; - let resolved = resolve_guards_by_name( - &all_guards, - &["pipeline-pre", "pipeline-post"], - &["header-pre"], - ); - assert_eq!(resolved.len(), 3); - let (pre, post) = split_guards_by_mode(&resolved); - assert_eq!(pre.len(), 2); // pipeline-pre + header-pre - assert_eq!(post.len(), 1); // pipeline-post -} diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index e010c037..2875d657 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -1,10 +1,10 @@ -use hub_lib::guardrails::runner::execute_guards; -use hub_lib::guardrails::parsing::{ - extract_post_call_input_from_completion, extract_prompt, -}; +use hub_lib::guardrails::parsing::{CompletionExtractor, PromptExtractor}; use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::runner::{blocked_response, execute_guards, warning_header_value}; +use hub_lib::guardrails::setup::{build_guardrail_resources, build_pipeline_guardrails}; use hub_lib::guardrails::types::*; -use hub_lib::pipelines::pipeline::{build_guardrail_resources, build_pipeline_guardrails}; + +use axum::body::to_bytes; use serde_json::json; use wiremock::matchers; @@ -65,7 +65,7 @@ async fn test_e2e_pre_call_block_flow() { ); let request = create_test_chat_request("Bad input"); - let input = extract_prompt(&request); + let input = request.extract_pompt(); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &input, &client).await; @@ -87,7 +87,7 @@ async fn test_e2e_pre_call_pass_flow() { ); let request = create_test_chat_request("Safe input"); - let input = extract_prompt(&request); + let input = request.extract_pompt(); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &input, &client).await; @@ -111,7 +111,7 @@ async fn test_e2e_post_call_block_flow() { // Simulate LLM response let completion = create_test_chat_completion("Here is the SSN: 123-45-6789"); - let response_text = extract_post_call_input_from_completion(&completion); + let response_text = completion.extract_completion(); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &response_text, &client).await; @@ -133,7 +133,7 @@ async fn test_e2e_post_call_warn_flow() { ); let completion = create_test_chat_completion("Mildly concerning response"); - let response_text = extract_post_call_input_from_completion(&completion); + let response_text = completion.extract_completion(); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &response_text, &client).await; @@ -168,13 +168,13 @@ async fn test_e2e_pre_and_post_both_pass() { // Pre-call let request = create_test_chat_request("Hello"); - let input = extract_prompt(&request); + let input = request.extract_pompt(); let pre_outcome = execute_guards(&[pre_guard], &input, &client).await; assert!(!pre_outcome.blocked); // Post-call let completion = create_test_chat_completion("Hi there!"); - let response_text = extract_post_call_input_from_completion(&completion); + let response_text = completion.extract_completion(); let post_outcome = execute_guards(&[post_guard], &response_text, &client).await; assert!(!post_outcome.blocked); assert!(post_outcome.warnings.is_empty()); @@ -210,7 +210,7 @@ async fn test_e2e_pre_blocks_post_never_runs() { let client = TraceloopClient::new(); let request = create_test_chat_request("Bad input"); - let input = extract_prompt(&request); + let input = request.extract_pompt(); let pre_outcome = execute_guards(&[pre_guard], &input, &client).await; assert!(pre_outcome.blocked); @@ -532,3 +532,104 @@ pipelines: .and_then(build_guardrail_resources); assert!(shared.is_none()); } + +// --------------------------------------------------------------------------- +// Pipeline Integration (4 tests) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_pre_call_guardrails_warn_and_continue() { + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {"reason": "borderline"}, "pass": false})), + ) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "tone-check", + GuardMode::PreCall, + OnFailure::Warn, + &eval_server.uri(), + "tone", + ); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "borderline input", &client).await; + + assert!(!outcome.blocked); + assert_eq!(outcome.warnings.len(), 1); + assert_eq!(outcome.warnings[0].guard_name, "tone-check"); +} + +#[tokio::test] +async fn test_post_call_guardrails_warn_and_add_header() { + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {"reason": "mildly concerning"}, "pass": false})), + ) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "safety-check", + GuardMode::PostCall, + OnFailure::Warn, + &eval_server.uri(), + "safety", + ); + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "Some LLM response", &client).await; + + assert!(!outcome.blocked); + assert!(!outcome.warnings.is_empty()); + + // Verify warning header would be generated correctly + let header = warning_header_value(&outcome.warnings); + assert!(header.contains("guardrail_name=")); + assert!(header.contains("safety-check")); +} + +#[tokio::test] +async fn test_warning_header_format() { + let warnings = vec![GuardWarning { + guard_name: "my-guard".to_string(), + reason: "failed".to_string(), + }]; + let header = warning_header_value(&warnings); + assert_eq!(header, "guardrail_name=\"my-guard\", reason=\"failed\""); +} + +#[tokio::test] +async fn test_blocked_response_403_format() { + let outcome = GuardrailsOutcome { + results: vec![GuardResult::Failed { + name: "toxicity-check".to_string(), + result: json!({"reason": "toxic content"}), + on_failure: OnFailure::Block, + }], + blocked: true, + blocking_guard: Some("toxicity-check".to_string()), + warnings: vec![], + }; + let response = blocked_response(&outcome); + assert_eq!(response.status(), 403); + + let body = to_bytes(response.into_body(), 1024 * 1024).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(json["error"]["type"], "guardrail_blocked"); + assert_eq!(json["error"]["guardrail"], "toxicity-check"); + assert!( + json["error"]["message"] + .as_str() + .unwrap() + .contains("toxicity-check") + ); +} diff --git a/tests/guardrails/test_input_extractor.rs b/tests/guardrails/test_input_extractor.rs deleted file mode 100644 index cc88d62d..00000000 --- a/tests/guardrails/test_input_extractor.rs +++ /dev/null @@ -1,76 +0,0 @@ -use hub_lib::guardrails::parsing::*; -use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMessageContentPart}; - -use super::helpers::*; - -// --------------------------------------------------------------------------- -// Phase 2: Input Extractor (5 tests) -// --------------------------------------------------------------------------- - -#[test] -fn test_extract_text_single_user_message() { - let request = create_test_chat_request("Hello world"); - let text = extract_prompt(&request); - assert_eq!(text, "Hello world"); -} - -#[test] -fn test_extract_text_multi_turn_conversation() { - let mut request = default_request(); - request.messages = vec![ - ChatCompletionMessage { - role: "system".to_string(), - content: Some(ChatMessageContent::String("You are helpful".to_string())), - ..default_message() - }, - ChatCompletionMessage { - role: "user".to_string(), - content: Some(ChatMessageContent::String("First question".to_string())), - ..default_message() - }, - ChatCompletionMessage { - role: "assistant".to_string(), - content: Some(ChatMessageContent::String("First answer".to_string())), - ..default_message() - }, - ChatCompletionMessage { - role: "user".to_string(), - content: Some(ChatMessageContent::String("Follow-up question".to_string())), - ..default_message() - }, - ]; - let text = extract_prompt(&request); - assert_eq!(text, "Follow-up question"); -} - -#[test] -fn test_extract_text_from_array_content_parts() { - let mut request = create_test_chat_request(""); - request.messages[0].content = Some(ChatMessageContent::Array(vec![ - ChatMessageContentPart { - r#type: "text".to_string(), - text: "Part 1".to_string(), - }, - ChatMessageContentPart { - r#type: "text".to_string(), - text: "Part 2".to_string(), - }, - ])); - let text = extract_prompt(&request); - assert_eq!(text, "Part 1 Part 2"); -} - -#[test] -fn test_extract_response_from_chat_completion() { - let completion = create_test_chat_completion("Here is my response"); - let text = extract_post_call_input_from_completion(&completion); - assert_eq!(text, "Here is my response"); -} - -#[test] -fn test_extract_handles_empty_content() { - let mut request = create_test_chat_request(""); - request.messages[0].content = None; - let text = extract_prompt(&request); - assert_eq!(text, ""); -} diff --git a/tests/guardrails/test_parsing.rs b/tests/guardrails/test_parsing.rs new file mode 100644 index 00000000..95ad82f3 --- /dev/null +++ b/tests/guardrails/test_parsing.rs @@ -0,0 +1,147 @@ +use hub_lib::guardrails::parsing::{ + parse_evaluator_response, parse_evaluator_http_response, + CompletionExtractor, PromptExtractor, +}; +use hub_lib::guardrails::types::GuardrailError; +use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMessageContentPart}; + +use super::helpers::*; + +// --------------------------------------------------------------------------- +// Input Extraction (5 tests) +// --------------------------------------------------------------------------- + +#[test] +fn test_extract_text_single_user_message() { + let request = create_test_chat_request("Hello world"); + let text = request.extract_pompt(); + assert_eq!(text, "Hello world"); +} + +#[test] +fn test_extract_text_multi_turn_conversation() { + let mut request = default_request(); + request.messages = vec![ + ChatCompletionMessage { + role: "system".to_string(), + content: Some(ChatMessageContent::String("You are helpful".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("First question".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "assistant".to_string(), + content: Some(ChatMessageContent::String("First answer".to_string())), + ..default_message() + }, + ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Follow-up question".to_string())), + ..default_message() + }, + ]; + let text = request.extract_pompt(); + assert_eq!(text, "Follow-up question"); +} + +#[test] +fn test_extract_text_from_array_content_parts() { + let mut request = create_test_chat_request(""); + request.messages[0].content = Some(ChatMessageContent::Array(vec![ + ChatMessageContentPart { + r#type: "text".to_string(), + text: "Part 1".to_string(), + }, + ChatMessageContentPart { + r#type: "text".to_string(), + text: "Part 2".to_string(), + }, + ])); + let text = request.extract_pompt(); + assert_eq!(text, "Part 1 Part 2"); +} + +#[test] +fn test_extract_response_from_chat_completion() { + let completion = create_test_chat_completion("Here is my response"); + let text = completion.extract_completion(); + assert_eq!(text, "Here is my response"); +} + +#[test] +fn test_extract_handles_empty_content() { + let mut request = create_test_chat_request(""); + request.messages[0].content = None; + let text = request.extract_pompt(); + assert_eq!(text, ""); +} + +// --------------------------------------------------------------------------- +// Response Parsing (8 tests) +// --------------------------------------------------------------------------- + +#[test] +fn test_parse_successful_pass_response() { + let body = r#"{"result": {"score": 0.95, "label": "safe"}, "pass": true}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(response.pass); + assert_eq!(response.result["score"], 0.95); +} + +#[test] +fn test_parse_failed_response() { + let body = r#"{"result": {"score": 0.2, "reason": "Toxic content"}, "pass": false}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(!response.pass); + assert_eq!(response.result["reason"], "Toxic content"); +} + +#[test] +fn test_parse_with_result_details() { + let body = r#"{"result": {"score": 0.75, "label": "borderline", "categories": ["violence", "profanity"]}, "pass": true}"#; + let response = parse_evaluator_response(body).unwrap(); + assert!(response.pass); + assert_eq!(response.result["label"], "borderline"); + assert_eq!(response.result["categories"][0], "violence"); +} + +#[test] +fn test_parse_missing_pass_field() { + let body = r#"{"result": {"score": 0.5}}"#; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_malformed_json() { + let body = "not json {at all"; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_non_json_response() { + let body = "Internal Server Error"; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_empty_response_body() { + let body = ""; + let result = parse_evaluator_response(body); + assert!(result.is_err()); +} + +#[test] +fn test_parse_http_error_status() { + let result = parse_evaluator_http_response(500, "Internal Server Error"); + assert!(result.is_err()); + match result.unwrap_err() { + GuardrailError::HttpError { status, .. } => assert_eq!(status, 500), + other => panic!("Expected HttpError, got {other:?}"), + } +} diff --git a/tests/guardrails/test_response_parser.rs b/tests/guardrails/test_response_parser.rs deleted file mode 100644 index 7ed9813a..00000000 --- a/tests/guardrails/test_response_parser.rs +++ /dev/null @@ -1,69 +0,0 @@ -use hub_lib::guardrails::parsing::*; -use hub_lib::guardrails::types::GuardrailError; - -// --------------------------------------------------------------------------- -// Phase 3: Response Parser (8 tests) -// --------------------------------------------------------------------------- - -#[test] -fn test_parse_successful_pass_response() { - let body = r#"{"result": {"score": 0.95, "label": "safe"}, "pass": true}"#; - let response = parse_evaluator_response(body).unwrap(); - assert!(response.pass); - assert_eq!(response.result["score"], 0.95); -} - -#[test] -fn test_parse_failed_response() { - let body = r#"{"result": {"score": 0.2, "reason": "Toxic content"}, "pass": false}"#; - let response = parse_evaluator_response(body).unwrap(); - assert!(!response.pass); - assert_eq!(response.result["reason"], "Toxic content"); -} - -#[test] -fn test_parse_with_result_details() { - let body = r#"{"result": {"score": 0.75, "label": "borderline", "categories": ["violence", "profanity"]}, "pass": true}"#; - let response = parse_evaluator_response(body).unwrap(); - assert!(response.pass); - assert_eq!(response.result["label"], "borderline"); - assert_eq!(response.result["categories"][0], "violence"); -} - -#[test] -fn test_parse_missing_pass_field() { - let body = r#"{"result": {"score": 0.5}}"#; - let result = parse_evaluator_response(body); - assert!(result.is_err()); -} - -#[test] -fn test_parse_malformed_json() { - let body = "not json {at all"; - let result = parse_evaluator_response(body); - assert!(result.is_err()); -} - -#[test] -fn test_parse_non_json_response() { - let body = "Internal Server Error"; - let result = parse_evaluator_response(body); - assert!(result.is_err()); -} - -#[test] -fn test_parse_empty_response_body() { - let body = ""; - let result = parse_evaluator_response(body); - assert!(result.is_err()); -} - -#[test] -fn test_parse_http_error_status() { - let result = parse_evaluator_http_response(500, "Internal Server Error"); - assert!(result.is_err()); - match result.unwrap_err() { - GuardrailError::HttpError { status, .. } => assert_eq!(status, 500), - other => panic!("Expected HttpError, got {other:?}"), - } -} diff --git a/tests/guardrails/test_executor.rs b/tests/guardrails/test_runner.rs similarity index 98% rename from tests/guardrails/test_executor.rs rename to tests/guardrails/test_runner.rs index 0ae292f0..fbf5574d 100644 --- a/tests/guardrails/test_executor.rs +++ b/tests/guardrails/test_runner.rs @@ -1,11 +1,11 @@ +use hub_lib::guardrails::parsing::CompletionExtractor; use hub_lib::guardrails::runner::*; -use hub_lib::guardrails::parsing::*; use hub_lib::guardrails::types::*; use super::helpers::*; // --------------------------------------------------------------------------- -// Phase 5: Executor (12 tests) +// Guard Execution (12 tests) // --------------------------------------------------------------------------- #[tokio::test] @@ -130,7 +130,7 @@ async fn test_execute_post_call_guards_non_streaming() { let guard = create_test_guard("response-check", GuardMode::PostCall); let mock_client = MockGuardrailClient::with_response("response-check", Ok(passing_response())); let completion = create_test_chat_completion("Safe response text"); - let response_text = extract_post_call_input_from_completion(&completion); + let response_text = completion.extract_completion(); let outcome = execute_guards(&[guard], &response_text, &mock_client).await; assert!(!outcome.blocked); } diff --git a/tests/guardrails/test_pipeline.rs b/tests/guardrails/test_setup.rs similarity index 53% rename from tests/guardrails/test_pipeline.rs rename to tests/guardrails/test_setup.rs index fec01a98..04707ad1 100644 --- a/tests/guardrails/test_pipeline.rs +++ b/tests/guardrails/test_setup.rs @@ -1,212 +1,124 @@ use std::collections::HashMap; -use hub_lib::guardrails::setup::{resolve_guards_by_name, split_guards_by_mode}; -use hub_lib::guardrails::runner::execute_guards; -use hub_lib::guardrails::providers::traceloop::TraceloopClient; -use hub_lib::guardrails::types::*; -use hub_lib::pipelines::pipeline::{ - blocked_response, build_guardrail_resources, build_pipeline_guardrails, resolve_guard_defaults, - warning_header_value, -}; -use axum::body::to_bytes; -use serde_json::json; -use wiremock::matchers; -use wiremock::{Mock, MockServer, ResponseTemplate}; +use hub_lib::guardrails::setup::*; +use hub_lib::guardrails::types::*; use super::helpers::*; // --------------------------------------------------------------------------- -// Phase 6: Pipeline Integration (7 tests) -// -// These tests verify that guardrails are properly wired into the pipeline -// request handling flow. They use wiremock for the evaluator service. +// Header Parsing & Guard Resolution (15 tests) // --------------------------------------------------------------------------- -#[tokio::test] -async fn test_pre_call_guardrails_block_before_llm() { - // Set up evaluator mock that rejects the input - let eval_server = MockServer::start().await; - Mock::given(matchers::method("POST")) - .respond_with( - ResponseTemplate::new(200) - .set_body_json(json!({"result": {"reason": "toxic"}, "pass": false})), - ) - .expect(1) - .mount(&eval_server) - .await; - - let guard = Guard { - name: "toxicity-check".to_string(), - provider: "traceloop".to_string(), - evaluator_slug: "toxicity".to_string(), - params: Default::default(), - mode: GuardMode::PreCall, - on_failure: OnFailure::Block, - required: true, - api_base: Some(eval_server.uri()), - api_key: Some("test-key".to_string()), - }; - - let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], "toxic input", &client).await; - - assert!(outcome.blocked); - assert_eq!(outcome.blocking_guard.as_deref(), Some("toxicity-check")); +#[test] +fn test_parse_guardrails_header_single() { + let names = parse_guardrails_header("pii-check"); + assert_eq!(names, vec!["pii-check"]); } -#[tokio::test] -async fn test_pre_call_guardrails_warn_and_continue() { - let eval_server = MockServer::start().await; - Mock::given(matchers::method("POST")) - .respond_with( - ResponseTemplate::new(200) - .set_body_json(json!({"result": {"reason": "borderline"}, "pass": false})), - ) - .expect(1) - .mount(&eval_server) - .await; - - let guard = Guard { - name: "tone-check".to_string(), - provider: "traceloop".to_string(), - evaluator_slug: "tone".to_string(), - params: Default::default(), - mode: GuardMode::PreCall, - on_failure: OnFailure::Warn, - required: true, - api_base: Some(eval_server.uri()), - api_key: Some("test-key".to_string()), - }; - - let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], "borderline input", &client).await; - - assert!(!outcome.blocked); - assert_eq!(outcome.warnings.len(), 1); - assert_eq!(outcome.warnings[0].guard_name, "tone-check"); +#[test] +fn test_parse_guardrails_header_multiple() { + let names = parse_guardrails_header("toxicity-check,relevance-check,pii-check"); + assert_eq!( + names, + vec!["toxicity-check", "relevance-check", "pii-check"] + ); } -#[tokio::test] -async fn test_post_call_guardrails_block_response() { - let eval_server = MockServer::start().await; - Mock::given(matchers::method("POST")) - .respond_with( - ResponseTemplate::new(200) - .set_body_json(json!({"result": {"reason": "pii detected"}, "pass": false})), - ) - .expect(1) - .mount(&eval_server) - .await; - - let guard = Guard { - name: "pii-check".to_string(), - provider: "traceloop".to_string(), - evaluator_slug: "pii".to_string(), - params: Default::default(), - mode: GuardMode::PostCall, - on_failure: OnFailure::Block, - required: true, - api_base: Some(eval_server.uri()), - api_key: Some("test-key".to_string()), - }; - - let client = TraceloopClient::new(); - // Simulate post-call: evaluate the LLM response text - let outcome = execute_guards(&[guard], "Here is John's SSN: 123-45-6789", &client).await; - - assert!(outcome.blocked); - assert_eq!(outcome.blocking_guard.as_deref(), Some("pii-check")); +#[test] +fn test_pipeline_guardrails_always_included() { + let pipeline_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name(&pipeline_guards, &["pipeline-guard"], &[]); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "pipeline-guard"); } -#[tokio::test] -async fn test_post_call_guardrails_warn_and_add_header() { - let eval_server = MockServer::start().await; - Mock::given(matchers::method("POST")) - .respond_with( - ResponseTemplate::new(200) - .set_body_json(json!({"result": {"reason": "mildly concerning"}, "pass": false})), - ) - .expect(1) - .mount(&eval_server) - .await; - - let guard = Guard { - name: "safety-check".to_string(), - provider: "traceloop".to_string(), - evaluator_slug: "safety".to_string(), - params: Default::default(), - mode: GuardMode::PostCall, - on_failure: OnFailure::Warn, - required: true, - api_base: Some(eval_server.uri()), - api_key: Some("test-key".to_string()), - }; - - let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], "Some LLM response", &client).await; +#[test] +fn test_header_guardrails_additive_to_pipeline() { + let all_guards = vec![ + create_test_guard("pipeline-guard", GuardMode::PreCall), + create_test_guard("header-guard", GuardMode::PreCall), + ]; + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["header-guard"]); + assert_eq!(resolved.len(), 2); +} - assert!(!outcome.blocked); - assert!(!outcome.warnings.is_empty()); +#[test] +fn test_deduplication_by_name() { + let all_guards = vec![create_test_guard("shared-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name( + &all_guards, + &["shared-guard"], + &["shared-guard"], // duplicate + ); + assert_eq!(resolved.len(), 1); +} - // Verify warning header would be generated correctly - let header = warning_header_value(&outcome.warnings); - assert!(header.contains("guardrail_name=")); - assert!(header.contains("safety-check")); +#[test] +fn test_unknown_guard_name_in_header_ignored() { + let all_guards = vec![create_test_guard("known-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name( + &all_guards, + &["known-guard"], + &["nonexistent-guard"], // unknown + ); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "known-guard"); } -#[tokio::test] -async fn test_warning_header_format() { - let warnings = vec![GuardWarning { - guard_name: "my-guard".to_string(), - reason: "failed".to_string(), - }]; - let header = warning_header_value(&warnings); - assert_eq!(header, "guardrail_name=\"my-guard\", reason=\"failed\""); +#[test] +fn test_empty_header_pipeline_guards_only() { + let all_guards = vec![create_test_guard("pipeline-guard", GuardMode::PreCall)]; + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &[]); + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].name, "pipeline-guard"); } -#[tokio::test] -async fn test_blocked_response_403_format() { - let blocking_guard = Some("toxicity-check".to_string()); - let response = blocked_response(&blocking_guard); - assert_eq!(response.status(), 403); - - let body = to_bytes(response.into_body(), 1024 * 1024).await.unwrap(); - let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - assert_eq!(json["error"]["type"], "guardrail_blocked"); - assert_eq!(json["error"]["guardrail"], "toxicity-check"); - assert!( - json["error"]["message"] - .as_str() - .unwrap() - .contains("toxicity-check") - ); +#[test] +fn test_cannot_remove_pipeline_guardrails_via_api() { + let all_guards = vec![ + create_test_guard("pipeline-guard", GuardMode::PreCall), + create_test_guard("extra", GuardMode::PreCall), + ]; + // Even if header/payload only mention "extra", pipeline guard still included + let resolved = resolve_guards_by_name(&all_guards, &["pipeline-guard"], &["extra"]); + assert!(resolved.iter().any(|g| g.name == "pipeline-guard")); } -#[tokio::test] -async fn test_no_guardrails_passthrough() { - // Empty guardrails config -> build_guardrail_resources returns None - let config = GuardrailsConfig { - providers: Default::default(), - guards: vec![], - }; - let result = build_guardrail_resources(&config); - assert!(result.is_none()); +#[test] +fn test_guards_split_into_pre_and_post_call() { + let guards = vec![ + create_test_guard("pre-1", GuardMode::PreCall), + create_test_guard("post-1", GuardMode::PostCall), + create_test_guard("pre-2", GuardMode::PreCall), + create_test_guard("post-2", GuardMode::PostCall), + ]; + let (pre_call, post_call) = split_guards_by_mode(&guards); + assert_eq!(pre_call.len(), 2); + assert_eq!(post_call.len(), 2); + assert!(pre_call.iter().all(|g| g.mode == GuardMode::PreCall)); + assert!(post_call.iter().all(|g| g.mode == GuardMode::PostCall)); +} - // Config with no guards -> passthrough - let config_with_providers = GuardrailsConfig { - providers: HashMap::from([("traceloop".to_string(), ProviderConfig { - name: "traceloop".to_string(), - api_base: "http://localhost".to_string(), - api_key: "key".to_string(), - })]), - guards: vec![], - }; - let result = build_guardrail_resources(&config_with_providers); - assert!(result.is_none()); +#[test] +fn test_complete_resolution_merged() { + let all_guards = vec![ + create_test_guard("pipeline-pre", GuardMode::PreCall), + create_test_guard("pipeline-post", GuardMode::PostCall), + create_test_guard("header-pre", GuardMode::PreCall), + ]; + let resolved = resolve_guards_by_name( + &all_guards, + &["pipeline-pre", "pipeline-post"], + &["header-pre"], + ); + assert_eq!(resolved.len(), 3); + let (pre, post) = split_guards_by_mode(&resolved); + assert_eq!(pre.len(), 2); // pipeline-pre + header-pre + assert_eq!(post.len(), 1); // pipeline-post } // --------------------------------------------------------------------------- -// Pipeline-specific guard association tests +// Pipeline Guard Building & Provider Defaults // --------------------------------------------------------------------------- fn test_guardrails_config() -> GuardrailsConfig { @@ -265,6 +177,29 @@ fn test_guardrails_config() -> GuardrailsConfig { } } +#[test] +fn test_no_guardrails_passthrough() { + // Empty guardrails config -> build_guardrail_resources returns None + let config = GuardrailsConfig { + providers: Default::default(), + guards: vec![], + }; + let result = build_guardrail_resources(&config); + assert!(result.is_none()); + + // Config with no guards -> passthrough + let config_with_providers = GuardrailsConfig { + providers: HashMap::from([("traceloop".to_string(), ProviderConfig { + name: "traceloop".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + })]), + guards: vec![], + }; + let result = build_guardrail_resources(&config_with_providers); + assert!(result.is_none()); +} + #[test] fn test_build_pipeline_guardrails_with_specific_guards() { let config = test_guardrails_config(); @@ -284,7 +219,8 @@ fn test_build_pipeline_guardrails_empty_pipeline_guards() { let shared = build_guardrail_resources(&config).unwrap(); // Pipeline with no guards specified - shared resources still exist // (header guards can still be used at request time) - let gr = build_pipeline_guardrails(&shared); + let empty: Vec = vec![]; + let gr = build_pipeline_guardrails(&shared, &empty); assert_eq!(gr.all_guards.len(), 4); assert!(gr.pipeline_guard_names.is_empty()); @@ -427,92 +363,3 @@ fn test_no_pipeline_guards_no_header_no_guards_executed() { assert!(pre_call.is_empty()); assert!(post_call.is_empty()); } - -#[test] -fn test_pipeline_guards_field_in_yaml_config() { - use std::io::Write; - use tempfile::NamedTempFile; - - let config_yaml = r#" -providers: - - key: openai - type: openai - api_key: "sk-test" -models: - - key: gpt-4 - type: gpt-4 - provider: openai -pipelines: - - name: default - type: chat - guards: - - pii-check - - injection-check - plugins: - - model-router: - models: - - gpt-4 - - name: embeddings - type: embeddings - plugins: - - model-router: - models: - - gpt-4 -guardrails: - providers: - - name: traceloop - api_base: "https://api.traceloop.com" - api_key: "test-key" - guards: - - name: pii-check - provider: traceloop - evaluator_slug: pii - mode: pre_call - on_failure: block - - name: injection-check - provider: traceloop - evaluator_slug: injection - mode: pre_call - on_failure: block -"#; - - let mut temp_file = NamedTempFile::new().unwrap(); - temp_file.write_all(config_yaml.as_bytes()).unwrap(); - - let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); - - // Default pipeline should have guards - assert_eq!(config.pipelines[0].guards, vec!["pii-check", "injection-check"]); - // Embeddings pipeline should have no guards - assert!(config.pipelines[1].guards.is_empty()); -} - -#[test] -fn test_pipeline_guards_field_absent_defaults_to_empty() { - use std::io::Write; - use tempfile::NamedTempFile; - - let config_yaml = r#" -providers: - - key: openai - type: openai - api_key: "sk-test" -models: - - key: gpt-4 - type: gpt-4 - provider: openai -pipelines: - - name: default - type: chat - plugins: - - model-router: - models: - - gpt-4 -"#; - - let mut temp_file = NamedTempFile::new().unwrap(); - temp_file.write_all(config_yaml.as_bytes()).unwrap(); - - let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); - assert!(config.pipelines[0].guards.is_empty()); -} diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs index afe1977e..e25ac042 100644 --- a/tests/guardrails/test_traceloop_client.rs +++ b/tests/guardrails/test_traceloop_client.rs @@ -1,5 +1,6 @@ use hub_lib::guardrails::providers::traceloop::TraceloopClient; -use hub_lib::guardrails::providers::{GuardrailClient, create_guardrail_client}; +use hub_lib::guardrails::providers::create_guardrail_client; +use hub_lib::guardrails::types::GuardrailClient; use hub_lib::guardrails::types::GuardMode; use serde_json::json; use wiremock::matchers; diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs index 8dcfa847..c81dd712 100644 --- a/tests/guardrails/test_types.rs +++ b/tests/guardrails/test_types.rs @@ -250,3 +250,90 @@ fn test_guard_config_evaluator_slug_not_in_params() { assert!(!guard.params.contains_key("evaluator_slug")); assert_eq!(guard.params.get("threshold").unwrap(), 0.5); } + +// --------------------------------------------------------------------------- +// Pipeline config YAML parsing +// --------------------------------------------------------------------------- + +#[test] +fn test_pipeline_guards_field_in_yaml_config() { + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + guards: + - pii-check + - injection-check + plugins: + - model-router: + models: + - gpt-4 + - name: embeddings + type: embeddings + plugins: + - model-router: + models: + - gpt-4 +guardrails: + providers: + - name: traceloop + api_base: "https://api.traceloop.com" + api_key: "test-key" + guards: + - name: pii-check + provider: traceloop + evaluator_slug: pii + mode: pre_call + on_failure: block + - name: injection-check + provider: traceloop + evaluator_slug: injection + mode: pre_call + on_failure: block +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + + // Default pipeline should have guards + assert_eq!(config.pipelines[0].guards, vec!["pii-check", "injection-check"]); + // Embeddings pipeline should have no guards + assert!(config.pipelines[1].guards.is_empty()); +} + +#[test] +fn test_pipeline_guards_field_absent_defaults_to_empty() { + let config_yaml = r#" +providers: + - key: openai + type: openai + api_key: "sk-test" +models: + - key: gpt-4 + type: gpt-4 + provider: openai +pipelines: + - name: default + type: chat + plugins: + - model-router: + models: + - gpt-4 +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_yaml.as_bytes()).unwrap(); + + let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); + assert!(config.pipelines[0].guards.is_empty()); +} From 70a0ff14efdd73680bd4fdd9275fc0e2bf56e0ff Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:50:19 +0200 Subject: [PATCH 20/59] format --- Cargo.lock | 2 +- src/config/validation.rs | 106 ++++++++++++++-------- src/guardrails/evaluator_types.rs | 6 +- src/guardrails/providers/traceloop.rs | 39 +++++--- src/guardrails/runner.rs | 28 ++++-- src/guardrails/setup.rs | 13 +-- src/pipelines/pipeline.rs | 28 +++--- src/state.rs | 7 +- tests/guardrails/helpers.rs | 4 +- tests/guardrails/test_e2e.rs | 22 +++-- tests/guardrails/test_parsing.rs | 3 +- tests/guardrails/test_setup.rs | 55 ++++++----- tests/guardrails/test_traceloop_client.rs | 4 +- tests/guardrails/test_types.rs | 10 +- 14 files changed, 198 insertions(+), 129 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7349a059..d730fa83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2269,7 +2269,7 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hub" -version = "0.7.6" +version = "0.7.7" dependencies = [ "anyhow", "async-stream", diff --git a/src/config/validation.rs b/src/config/validation.rs index 7576d5b7..16da3c4c 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -39,7 +39,7 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec // Check 3: Guardrails validation if let Some(gr_config) = &config.guardrails { - // Guard provider references must exist in guardrails.providers + // Guard provider references must exist in guardrails.providers for guard in &gr_config.guards { if !gr_config.providers.contains_key(&guard.provider) { errors.push(format!( @@ -65,10 +65,14 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec // Guards must have api_base and api_key (either directly or via provider) for guard in &gr_config.guards { let has_api_base = guard.api_base.as_ref().is_some_and(|s| !s.is_empty()) - || gr_config.providers.get(&guard.provider) + || gr_config + .providers + .get(&guard.provider) .is_some_and(|p| !p.api_base.is_empty()); let has_api_key = guard.api_key.as_ref().is_some_and(|s| !s.is_empty()) - || gr_config.providers.get(&guard.provider) + || gr_config + .providers + .get(&guard.provider) .is_some_and(|p| !p.api_key.is_empty()); if !has_api_base { @@ -114,9 +118,11 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec #[cfg(test)] mod tests { use super::*; // To import validate_gateway_config - use std::collections::HashMap; - use crate::guardrails::types::{Guard, GuardMode, GuardrailsConfig, OnFailure, ProviderConfig as GrProviderConfig}; - use crate::types::{ModelConfig, Pipeline, PipelineType, PluginConfig, Provider, ProviderType}; // For test data + use crate::guardrails::types::{ + Guard, GuardMode, GuardrailsConfig, OnFailure, ProviderConfig as GrProviderConfig, + }; + use crate::types::{ModelConfig, Pipeline, PipelineType, PluginConfig, Provider, ProviderType}; + use std::collections::HashMap; // For test data #[test] fn test_valid_config() { @@ -210,11 +216,14 @@ mod tests { fn test_guard_references_non_existent_guardrail_provider() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { - name: "gr_p1".to_string(), - api_base: "http://localhost".to_string(), - api_key: "key".to_string(), - })]), + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), guards: vec![Guard { name: "g1".to_string(), provider: "gr_p2_non_existent".to_string(), @@ -235,7 +244,9 @@ mod tests { let result = validate_gateway_config(&config); assert!(result.is_err()); let errors = result.unwrap_err(); - assert!(errors.iter().any(|e| e.contains("references non-existent guardrail provider 'gr_p2_non_existent'"))); + assert!(errors.iter().any(|e| { + e.contains("references non-existent guardrail provider 'gr_p2_non_existent'") + })); assert!(errors.iter().any(|e| e.contains("no api_base configured"))); assert!(errors.iter().any(|e| e.contains("no api_key configured"))); } @@ -244,11 +255,14 @@ mod tests { fn test_pipeline_references_non_existent_guard() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { - name: "gr_p1".to_string(), - api_base: "http://localhost".to_string(), - api_key: "key".to_string(), - })]), + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), guards: vec![Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), @@ -282,11 +296,14 @@ mod tests { fn test_duplicate_guard_names() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { - name: "gr_p1".to_string(), - api_base: "http://localhost".to_string(), - api_key: "key".to_string(), - })]), + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), guards: vec![ Guard { name: "g1".to_string(), @@ -328,11 +345,14 @@ mod tests { fn test_guard_missing_api_base_and_api_key() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { - name: "gr_p1".to_string(), - api_base: "".to_string(), - api_key: "".to_string(), - })]), + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "".to_string(), + api_key: "".to_string(), + }, + )]), guards: vec![Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), @@ -362,11 +382,14 @@ mod tests { fn test_guard_inherits_api_base_from_provider() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { - name: "gr_p1".to_string(), - api_base: "http://localhost".to_string(), - api_key: "key".to_string(), - })]), + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), guards: vec![Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), @@ -391,11 +414,14 @@ mod tests { fn test_guard_unknown_evaluator_slug() { let config = GatewayConfig { guardrails: Some(GuardrailsConfig { - providers: HashMap::from([("gr_p1".to_string(), GrProviderConfig { - name: "gr_p1".to_string(), - api_base: "http://localhost".to_string(), - api_key: "key".to_string(), - })]), + providers: HashMap::from([( + "gr_p1".to_string(), + GrProviderConfig { + name: "gr_p1".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), guards: vec![Guard { name: "g1".to_string(), provider: "gr_p1".to_string(), @@ -416,6 +442,10 @@ mod tests { let result = validate_gateway_config(&config); assert!(result.is_err()); let errors = result.unwrap_err(); - assert!(errors.iter().any(|e| e.contains("unknown evaluator_slug 'made-up-slug'"))); + assert!( + errors + .iter() + .any(|e| e.contains("unknown evaluator_slug 'made-up-slug'")) + ); } } diff --git a/src/guardrails/evaluator_types.rs b/src/guardrails/evaluator_types.rs index 6bef0ffd..3937e8aa 100644 --- a/src/guardrails/evaluator_types.rs +++ b/src/guardrails/evaluator_types.rs @@ -1,4 +1,4 @@ -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::json; use std::collections::HashMap; @@ -84,8 +84,8 @@ fn attach_config( let params_value: serde_json::Value = params.clone().into_iter().collect(); let config: C = serde_json::from_value(params_value) .map_err(|e| GuardrailError::ParseError(format!("{slug}: invalid config — {e}")))?; - let config_json = serde_json::to_value(config) - .map_err(|e| GuardrailError::ParseError(e.to_string()))?; + let config_json = + serde_json::to_value(config).map_err(|e| GuardrailError::ParseError(e.to_string()))?; if config_json.as_object().is_some_and(|m| !m.is_empty()) { body["config"] = config_json; } diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 3795f3e1..9cd3ba8c 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -7,7 +7,6 @@ use crate::guardrails::evaluator_types::get_evaluator; use crate::guardrails::parsing::parse_evaluator_http_response; use crate::guardrails::types::{EvaluatorResponse, Guard, GuardrailError}; - const DEFAULT_TRACELOOP_API: &str = "https://api.traceloop.com"; /// HTTP client for the Traceloop evaluator API service. /// Calls `POST {api_base}/v2/guardrails/{evaluator_slug}`. @@ -45,26 +44,30 @@ impl GuardrailClient for TraceloopClient { guard: &Guard, input: &str, ) -> Result { - let api_base = guard.api_base.as_deref() + let api_base = guard + .api_base + .as_deref() .filter(|s| !s.is_empty()) .unwrap_or(DEFAULT_TRACELOOP_API); let url = format!( "{}/v2/guardrails/execute/{}", - api_base, - guard.evaluator_slug + api_base, guard.evaluator_slug ); let api_key = guard.api_key.as_deref().unwrap_or(""); let evaluator = get_evaluator(&guard.evaluator_slug).ok_or_else(|| { - GuardrailError::Unavailable(format!("Unknown evaluator slug '{}'", guard.evaluator_slug)) + GuardrailError::Unavailable(format!( + "Unknown evaluator slug '{}'", + guard.evaluator_slug + )) })?; let body = evaluator.build_body(input, &guard.params)?; debug!(guard = %guard.name, slug = %guard.evaluator_slug, %url, %body, "NOMI - Calling evaluator API"); let response = self - .http_client + .http_client .post(&url) .header("Authorization", format!("Bearer {api_key}")) .header("Content-Type", "application/json") @@ -89,14 +92,20 @@ mod tests { #[test] fn test_build_body_text_slug() { let params = HashMap::new(); - let body = get_evaluator("pii-detector").unwrap().build_body("hello world", ¶ms).unwrap(); + let body = get_evaluator("pii-detector") + .unwrap() + .build_body("hello world", ¶ms) + .unwrap(); assert_eq!(body, json!({"input": {"text": "hello world"}})); } #[test] fn test_build_body_prompt_slug() { let params = HashMap::new(); - let body = get_evaluator("prompt-injection").unwrap().build_body("hello world", ¶ms).unwrap(); + let body = get_evaluator("prompt-injection") + .unwrap() + .build_body("hello world", ¶ms) + .unwrap(); assert_eq!(body, json!({"input": {"prompt": "hello world"}})); } @@ -104,7 +113,10 @@ mod tests { fn test_build_body_with_config() { let mut params = HashMap::new(); params.insert("threshold".to_string(), json!(0.8)); - let body = get_evaluator("toxicity-detector").unwrap().build_body("test", ¶ms).unwrap(); + let body = get_evaluator("toxicity-detector") + .unwrap() + .build_body("test", ¶ms) + .unwrap(); assert_eq!( body, json!({"input": {"text": "test"}, "config": {"threshold": 0.8}}) @@ -114,7 +126,10 @@ mod tests { #[test] fn test_build_body_no_config_when_params_empty() { let params = HashMap::new(); - let body = get_evaluator("secrets-detector").unwrap().build_body("test", ¶ms).unwrap(); + let body = get_evaluator("secrets-detector") + .unwrap() + .build_body("test", ¶ms) + .unwrap(); assert!(body.get("config").is_none()); } @@ -127,7 +142,9 @@ mod tests { fn test_build_body_rejects_invalid_config_type() { let mut params = HashMap::new(); params.insert("threshold".to_string(), json!("not-a-number")); - let result = get_evaluator("toxicity-detector").unwrap().build_body("test", ¶ms); + let result = get_evaluator("toxicity-detector") + .unwrap() + .build_body("test", ¶ms); assert!(result.is_err()); } } diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index a8d4d557..4de82d3b 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -1,16 +1,18 @@ use std::collections::HashSet; +use axum::Json; use axum::http::HeaderMap; -use axum::response::{IntoResponse, Response}; use axum::http::StatusCode; -use axum::Json; +use axum::response::{IntoResponse, Response}; use futures::future::join_all; use serde_json::json; use tracing::{debug, warn}; +use super::parsing::{CompletionExtractor, PromptExtractor}; use super::setup::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; -use super::parsing::{PromptExtractor, CompletionExtractor}; -use super::types::{Guard, GuardResult, GuardWarning, GuardrailClient, GuardrailsOutcome, Guardrails, OnFailure}; +use super::types::{ + Guard, GuardResult, GuardWarning, GuardrailClient, Guardrails, GuardrailsOutcome, OnFailure, +}; /// Execute a set of guardrails against the given input text. /// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. @@ -190,10 +192,9 @@ impl<'a> GuardrailsRunner<'a> { } let header_val = warning_header_value(warnings); let mut response = response; - response.headers_mut().insert( - "X-Traceloop-Guardrail-Warning", - header_val.parse().unwrap(), - ); + response + .headers_mut() + .insert("X-Traceloop-Guardrail-Warning", header_val.parse().unwrap()); response } } @@ -203,7 +204,9 @@ pub fn blocked_response(outcome: &GuardrailsOutcome) -> Response { let guard_name = outcome.blocking_guard.as_deref().unwrap_or("unknown"); // Find the blocking guard result to get details - let details = outcome.results.iter() + let details = outcome + .results + .iter() .find(|r| match r { GuardResult::Failed { name, .. } => name == guard_name, GuardResult::Error { name, .. } => name == guard_name, @@ -242,7 +245,12 @@ pub fn blocked_response(outcome: &GuardrailsOutcome) -> Response { pub fn warning_header_value(warnings: &[GuardWarning]) -> String { warnings .iter() - .map(|w| format!("guardrail_name=\"{}\", reason=\"{}\"", w.guard_name, w.reason)) + .map(|w| { + format!( + "guardrail_name=\"{}\", reason=\"{}\"", + w.guard_name, w.reason + ) + }) .collect::>() .join("; ") } diff --git a/src/guardrails/setup.rs b/src/guardrails/setup.rs index f343b0c4..9ffa3b19 100644 --- a/src/guardrails/setup.rs +++ b/src/guardrails/setup.rs @@ -1,7 +1,9 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use super::types::{Guard, GuardMode, GuardrailClient, GuardrailResources, GuardrailsConfig, Guardrails}; +use super::types::{ + Guard, GuardMode, GuardrailClient, GuardrailResources, Guardrails, GuardrailsConfig, +}; /// Parse guard names from the X-Traceloop-Guardrails header value. /// Names are comma-separated and trimmed. @@ -25,10 +27,7 @@ pub fn resolve_guards_by_name( let mut seen = HashSet::new(); let mut resolved = Vec::new(); - let all_names = pipeline_names - .iter() - .chain(header_names.iter()) - .copied(); + let all_names = pipeline_names.iter().chain(header_names.iter()).copied(); for name in all_names { if seen.insert(name) { @@ -77,9 +76,7 @@ pub fn resolve_guard_defaults(config: &GuardrailsConfig) -> Vec { /// Build the shared guardrail resources (resolved guards + client). /// Returns None if the config has no guards. /// Called once per router build; the result is shared across all pipelines. -pub fn build_guardrail_resources( - config: &GuardrailsConfig, -) -> Option { +pub fn build_guardrail_resources(config: &GuardrailsConfig) -> Option { if config.guards.is_empty() { return None; } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index da4f75d1..80da8910 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -28,18 +28,18 @@ use reqwest_streams::error::StreamBodyError; use std::sync::Arc; // Re-export builder and orchestrator functions for backward compatibility with tests +pub use crate::guardrails::runner::{blocked_response, warning_header_value}; pub use crate::guardrails::setup::{ build_guardrail_resources, build_pipeline_guardrails, resolve_guard_defaults, }; -pub use crate::guardrails::runner::{blocked_response, warning_header_value}; pub fn create_pipeline( pipeline: &Pipeline, model_registry: &ModelRegistry, guardrail_resources: Option<&GuardrailResources>, ) -> Router { - let guardrails: Option> = guardrail_resources - .map(|shared| build_pipeline_guardrails(shared, &pipeline.guards)); + let guardrails: Option> = + guardrail_resources.map(|shared| build_pipeline_guardrails(shared, &pipeline.guards)); let mut router = Router::new(); let available_models: Vec = pipeline @@ -73,18 +73,16 @@ pub fn create_pipeline( router } PluginConfig::ModelRouter { models } => match pipeline.r#type { - PipelineType::Chat => { - router.route( - "/chat/completions", - post(move |state, headers, payload| chat_completions(state, headers, payload, models, gr)), - ) - } - PipelineType::Completion => { - router.route( - "/completions", - post(move |state, payload| completions(state, payload, models)), - ) - } + PipelineType::Chat => router.route( + "/chat/completions", + post(move |state, headers, payload| { + chat_completions(state, headers, payload, models, gr) + }), + ), + PipelineType::Completion => router.route( + "/completions", + post(move |state, payload| completions(state, payload, models)), + ), PipelineType::Embeddings => router.route( "/embeddings", post(move |state, payload| embeddings(state, payload, models)), diff --git a/src/state.rs b/src/state.rs index d5dcbf4f..a3e17c0b 100644 --- a/src/state.rs +++ b/src/state.rs @@ -186,8 +186,11 @@ impl AppState { "Adding default pipeline '{}' to router at index 0", default_pipeline.name ); - let pipeline_router = - create_pipeline(default_pipeline, model_registry, guardrail_resources.as_ref()); + let pipeline_router = create_pipeline( + default_pipeline, + model_registry, + guardrail_resources.as_ref(), + ); pipeline_routers.push(pipeline_router); pipeline_names.push(default_pipeline.name.clone()); } diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index 1d69039b..6dd1a8de 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -2,9 +2,7 @@ use std::collections::HashMap; use async_trait::async_trait; use hub_lib::guardrails::types::GuardrailClient; -use hub_lib::guardrails::types::{ - EvaluatorResponse, Guard, GuardMode, GuardrailError, OnFailure, -}; +use hub_lib::guardrails::types::{EvaluatorResponse, Guard, GuardMode, GuardrailError, OnFailure}; use hub_lib::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent}; use hub_lib::models::usage::Usage; diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 2875d657..80767264 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -369,21 +369,23 @@ guardrails: assert_eq!(pipeline_gr.all_guards.len(), 2); assert_eq!(pipeline_gr.pipeline_guard_names.len(), 2); // Provider api_base should be resolved for guards that don't override - let pre_guard = pipeline_gr.all_guards.iter().find(|g| g.mode == GuardMode::PreCall).unwrap(); - let post_guard = pipeline_gr.all_guards.iter().find(|g| g.mode == GuardMode::PostCall).unwrap(); + let pre_guard = pipeline_gr + .all_guards + .iter() + .find(|g| g.mode == GuardMode::PreCall) + .unwrap(); + let post_guard = pipeline_gr + .all_guards + .iter() + .find(|g| g.mode == GuardMode::PostCall) + .unwrap(); assert_eq!( pre_guard.api_base.as_deref(), Some("https://api.traceloop.com") ); - assert_eq!( - pre_guard.api_key.as_deref(), - Some("resolved-key-123") - ); + assert_eq!(pre_guard.api_key.as_deref(), Some("resolved-key-123")); // Guard with override keeps its own api_key - assert_eq!( - post_guard.api_key.as_deref(), - Some("override-key") - ); + assert_eq!(post_guard.api_key.as_deref(), Some("override-key")); unsafe { std::env::remove_var("E2E_TEST_API_KEY"); diff --git a/tests/guardrails/test_parsing.rs b/tests/guardrails/test_parsing.rs index 95ad82f3..423c5f4c 100644 --- a/tests/guardrails/test_parsing.rs +++ b/tests/guardrails/test_parsing.rs @@ -1,6 +1,5 @@ use hub_lib::guardrails::parsing::{ - parse_evaluator_response, parse_evaluator_http_response, - CompletionExtractor, PromptExtractor, + CompletionExtractor, PromptExtractor, parse_evaluator_http_response, parse_evaluator_response, }; use hub_lib::guardrails::types::GuardrailError; use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMessageContentPart}; diff --git a/tests/guardrails/test_setup.rs b/tests/guardrails/test_setup.rs index 04707ad1..0a15a2a5 100644 --- a/tests/guardrails/test_setup.rs +++ b/tests/guardrails/test_setup.rs @@ -123,11 +123,14 @@ fn test_complete_resolution_merged() { fn test_guardrails_config() -> GuardrailsConfig { GuardrailsConfig { - providers: HashMap::from([("traceloop".to_string(), ProviderConfig { - name: "traceloop".to_string(), - api_base: "https://api.traceloop.com".to_string(), - api_key: "test-key".to_string(), - })]), + providers: HashMap::from([( + "traceloop".to_string(), + ProviderConfig { + name: "traceloop".to_string(), + api_base: "https://api.traceloop.com".to_string(), + api_key: "test-key".to_string(), + }, + )]), guards: vec![ Guard { name: "pii-check".to_string(), @@ -189,11 +192,14 @@ fn test_no_guardrails_passthrough() { // Config with no guards -> passthrough let config_with_providers = GuardrailsConfig { - providers: HashMap::from([("traceloop".to_string(), ProviderConfig { - name: "traceloop".to_string(), - api_base: "http://localhost".to_string(), - api_key: "key".to_string(), - })]), + providers: HashMap::from([( + "traceloop".to_string(), + ProviderConfig { + name: "traceloop".to_string(), + api_base: "http://localhost".to_string(), + api_key: "key".to_string(), + }, + )]), guards: vec![], }; let result = build_guardrail_resources(&config_with_providers); @@ -210,7 +216,10 @@ fn test_build_pipeline_guardrails_with_specific_guards() { // all_guards should contain ALL guards from config, resolved with provider defaults assert_eq!(gr.all_guards.len(), 4); // pipeline_guard_names should only contain the ones specified - assert_eq!(gr.pipeline_guard_names, vec!["pii-check", "toxicity-filter"]); + assert_eq!( + gr.pipeline_guard_names, + vec!["pii-check", "toxicity-filter"] + ); } #[test] @@ -242,11 +251,14 @@ fn test_build_pipeline_guardrails_resolves_provider_defaults() { #[test] fn test_resolve_guard_defaults_preserves_guard_overrides() { let config = GuardrailsConfig { - providers: HashMap::from([("traceloop".to_string(), ProviderConfig { - name: "traceloop".to_string(), - api_base: "https://default.api.com".to_string(), - api_key: "default-key".to_string(), - })]), + providers: HashMap::from([( + "traceloop".to_string(), + ProviderConfig { + name: "traceloop".to_string(), + api_base: "https://default.api.com".to_string(), + api_key: "default-key".to_string(), + }, + )]), guards: vec![Guard { name: "custom-guard".to_string(), provider: "traceloop".to_string(), @@ -261,7 +273,10 @@ fn test_resolve_guard_defaults_preserves_guard_overrides() { }; let resolved = resolve_guard_defaults(&config); - assert_eq!(resolved[0].api_base.as_deref(), Some("https://custom.api.com")); + assert_eq!( + resolved[0].api_base.as_deref(), + Some("https://custom.api.com") + ); assert_eq!(resolved[0].api_key.as_deref(), Some("custom-key")); } @@ -276,11 +291,7 @@ fn test_pipeline_guards_resolved_at_request_time() { // Header adds injection-check let header_names = vec!["injection-check"]; - let resolved = resolve_guards_by_name( - &all_guards, - &pipeline_names, - &header_names, - ); + let resolved = resolve_guards_by_name(&all_guards, &pipeline_names, &header_names); assert_eq!(resolved.len(), 2); assert_eq!(resolved[0].name, "pii-check"); assert_eq!(resolved[1].name, "injection-check"); diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs index e25ac042..9ebd42e4 100644 --- a/tests/guardrails/test_traceloop_client.rs +++ b/tests/guardrails/test_traceloop_client.rs @@ -1,7 +1,7 @@ -use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::providers::create_guardrail_client; -use hub_lib::guardrails::types::GuardrailClient; +use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::types::GuardMode; +use hub_lib::guardrails::types::GuardrailClient; use serde_json::json; use wiremock::matchers; use wiremock::{Mock, MockServer, ResponseTemplate}; diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs index c81dd712..a72a091e 100644 --- a/tests/guardrails/test_types.rs +++ b/tests/guardrails/test_types.rs @@ -210,7 +210,10 @@ guards: let config: GuardrailsConfig = serde_yaml::from_str(yaml).unwrap(); assert_eq!(config.providers.len(), 1); assert_eq!(config.providers["traceloop"].name, "traceloop"); - assert_eq!(config.providers["traceloop"].api_base, "https://api.traceloop.com"); + assert_eq!( + config.providers["traceloop"].api_base, + "https://api.traceloop.com" + ); assert_eq!(config.guards.len(), 2); // First guard has no api_base/api_key (inherits from provider) assert!(config.guards[0].api_base.is_none()); @@ -306,7 +309,10 @@ guardrails: let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); // Default pipeline should have guards - assert_eq!(config.pipelines[0].guards, vec!["pii-check", "injection-check"]); + assert_eq!( + config.pipelines[0].guards, + vec!["pii-check", "injection-check"] + ); // Embeddings pipeline should have no guards assert!(config.pipelines[1].guards.is_empty()); } From f0f461c9afed249d816bc7dfddb39ef1974a58a4 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 11:55:13 +0200 Subject: [PATCH 21/59] add warn --- src/guardrails/runner.rs | 16 ++++++- src/guardrails/types.rs | 2 +- src/pipelines/pipeline.rs | 5 +++ tests/guardrails/helpers.rs | 2 +- tests/guardrails/test_e2e.rs | 77 +++++++++++++++++++++++++++++++++- tests/guardrails/test_setup.rs | 8 ++-- tests/guardrails/test_types.rs | 4 +- 7 files changed, 102 insertions(+), 12 deletions(-) diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index 4de82d3b..936380b8 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -170,8 +170,20 @@ impl<'a> GuardrailsRunner<'a> { warnings: Vec::new(), }; } - let input = response.extract_completion(); - let outcome = execute_guards(&self.post_call, &input, self.client).await; + let completion = response.extract_completion(); + + if completion.is_empty() { + warn!("Skipping post-call guardrails: LLM response content is empty"); + return GuardPhaseResult { + blocked_response: None, + warnings: vec![GuardWarning { + guard_name: "all post_call guards".to_string(), + reason: "skipped due to empty response content".to_string(), + }], + }; + } + + let outcome = execute_guards(&self.post_call, &completion, self.client).await; if outcome.blocked { return GuardPhaseResult { blocked_response: Some(blocked_response(&outcome)), diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index 8ba6b88a..7cc4e2d8 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -14,7 +14,7 @@ fn default_on_failure() -> OnFailure { } fn default_required() -> bool { - true + false } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 80da8910..cadcccf2 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -154,6 +154,11 @@ pub async fn chat_completions( if let ChatCompletionResponse::NonStream(completion) = response { tracer.log_success(&completion); + tracing::debug!( + completion = %serde_json::to_string(&completion).unwrap_or_default(), + "AASA - LLM response before post-call guardrails" + ); + // Post-call guardrails (non-streaming) if let Some(orch) = &orchestrator { let post = orch.run_post_call(&completion).await; diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index 6dd1a8de..79252ada 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -20,7 +20,7 @@ pub fn create_test_guard(name: &str, mode: GuardMode) -> Guard { params: HashMap::new(), mode, on_failure: OnFailure::Block, - required: true, + required: false, api_base: Some("http://localhost:8080".to_string()), api_key: Some("test-api-key".to_string()), } diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 80767264..77ff0e85 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -1,12 +1,16 @@ use hub_lib::guardrails::parsing::{CompletionExtractor, PromptExtractor}; use hub_lib::guardrails::providers::traceloop::TraceloopClient; -use hub_lib::guardrails::runner::{blocked_response, execute_guards, warning_header_value}; +use hub_lib::guardrails::runner::{ + blocked_response, execute_guards, warning_header_value, GuardrailsRunner, +}; use hub_lib::guardrails::setup::{build_guardrail_resources, build_pipeline_guardrails}; use hub_lib::guardrails::types::*; use axum::body::to_bytes; +use axum::http::HeaderMap; use serde_json::json; +use std::sync::Arc; use wiremock::matchers; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -46,7 +50,7 @@ fn guard_with_server( params: Default::default(), mode, on_failure, - required: true, + required: false, api_base: Some(server_uri.to_string()), api_key: Some("test-key".to_string()), } @@ -635,3 +639,72 @@ async fn test_blocked_response_403_format() { .contains("toxicity-check") ); } + +#[tokio::test] +async fn test_post_call_skipped_on_empty_response() { + // When the LLM returns empty content (e.g. max_tokens too low), + // post-call guards should be skipped and a warning returned. + let eval_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(json!({"result": {}, "pass": true})), + ) + .expect(0) // evaluator should never be called + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "toxicity-filter", + GuardMode::PostCall, + OnFailure::Block, + &eval_server.uri(), + "toxicity-detector", + ); + + let guardrails = Guardrails { + all_guards: Arc::new(vec![guard]), + pipeline_guard_names: vec!["toxicity-filter".to_string()], + client: Arc::new(TraceloopClient::new()), + }; + + let headers = HeaderMap::new(); + let runner = GuardrailsRunner::new(Some(&guardrails), &headers).unwrap(); + + // Completion with content: None (simulates empty LLM response) + let empty_completion = create_test_chat_completion(""); + let result = runner.run_post_call(&empty_completion).await; + + assert!(result.blocked_response.is_none()); + assert_eq!(result.warnings.len(), 1); + assert!(result.warnings[0].reason.contains("empty response content")); + + let header = warning_header_value(&result.warnings); + assert!(header.contains("skipped")); + // wiremock will verify expect(0) — evaluator was never called +} + +#[tokio::test] +async fn test_evaluator_error_not_blocked_by_default() { + // Guards default to required: false (fail-open), so an evaluator HTTP 500 + // should NOT block the request. + let server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) + .mount(&server) + .await; + + let guard = guard_with_server( + "warn-guard", + GuardMode::PreCall, + OnFailure::Warn, + &server.uri(), + "profanity-detector", + ); + // guard.required is false by default + + let client = TraceloopClient::new(); + let outcome = execute_guards(&[guard], "test input", &client).await; + + assert!(!outcome.blocked); +} diff --git a/tests/guardrails/test_setup.rs b/tests/guardrails/test_setup.rs index 0a15a2a5..cb1d3b19 100644 --- a/tests/guardrails/test_setup.rs +++ b/tests/guardrails/test_setup.rs @@ -139,7 +139,7 @@ fn test_guardrails_config() -> GuardrailsConfig { params: Default::default(), mode: GuardMode::PreCall, on_failure: OnFailure::Block, - required: true, + required: false, api_base: None, api_key: None, }, @@ -150,7 +150,7 @@ fn test_guardrails_config() -> GuardrailsConfig { params: Default::default(), mode: GuardMode::PostCall, on_failure: OnFailure::Warn, - required: true, + required: false, api_base: None, api_key: None, }, @@ -161,7 +161,7 @@ fn test_guardrails_config() -> GuardrailsConfig { params: Default::default(), mode: GuardMode::PreCall, on_failure: OnFailure::Block, - required: true, + required: false, api_base: None, api_key: None, }, @@ -172,7 +172,7 @@ fn test_guardrails_config() -> GuardrailsConfig { params: Default::default(), mode: GuardMode::PostCall, on_failure: OnFailure::Block, - required: true, + required: false, api_base: None, api_key: None, }, diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs index a72a091e..47543852 100644 --- a/tests/guardrails/test_types.rs +++ b/tests/guardrails/test_types.rs @@ -32,7 +32,7 @@ fn test_on_failure_defaults_to_warn() { } #[test] -fn test_required_defaults_to_true() { +fn test_required_defaults_to_false() { let json = serde_json::json!({ "name": "test-guard", "provider": "traceloop", @@ -40,7 +40,7 @@ fn test_required_defaults_to_true() { "mode": "pre_call" }); let guard: Guard = serde_json::from_value(json).unwrap(); - assert!(guard.required); + assert!(!guard.required); } #[test] From a00198061399f181edfba3d3574c5e1a4331df52 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 13:50:10 +0200 Subject: [PATCH 22/59] add spans --- src/guardrails/mod.rs | 1 + src/guardrails/runner.rs | 137 ++++++++++++++++++++++++------ src/guardrails/span_attributes.rs | 10 +++ src/pipelines/otel.rs | 10 ++- src/pipelines/pipeline.rs | 3 +- tests/guardrails/test_e2e.rs | 34 ++++---- tests/guardrails/test_runner.rs | 24 +++--- 7 files changed, 163 insertions(+), 56 deletions(-) create mode 100644 src/guardrails/span_attributes.rs diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index 140ab4c4..d31a5a0b 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -3,4 +3,5 @@ pub mod parsing; pub mod providers; pub mod runner; pub mod setup; +pub mod span_attributes; pub mod types; diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index 936380b8..cf11ff5a 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -5,46 +5,121 @@ use axum::http::HeaderMap; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use futures::future::join_all; +use opentelemetry::global::{BoxedSpan, ObjectSafeSpan}; +use opentelemetry::trace::{SpanKind, Status as OtelStatus, Tracer}; +use opentelemetry::{Context, KeyValue, global}; use serde_json::json; use tracing::{debug, warn}; +use crate::config::lib::get_trace_content_enabled; + use super::parsing::{CompletionExtractor, PromptExtractor}; use super::setup::{parse_guardrails_header, resolve_guards_by_name, split_guards_by_mode}; +use super::span_attributes::*; use super::types::{ - Guard, GuardResult, GuardWarning, GuardrailClient, Guardrails, GuardrailsOutcome, OnFailure, + EvaluatorResponse, Guard, GuardResult, GuardWarning, GuardrailClient, GuardrailError, + Guardrails, GuardrailsOutcome, OnFailure, }; +fn error_type_name(err: &GuardrailError) -> &'static str { + match err { + GuardrailError::Unavailable(_) => "Unavailable", + GuardrailError::HttpError { .. } => "HttpError", + GuardrailError::Timeout(_) => "Timeout", + GuardrailError::ParseError(_) => "ParseError", + } +} + +fn record_guard_span( + span: &mut BoxedSpan, + guard: &Guard, + result: &Result, + elapsed: std::time::Duration, + input: &str, + guard_count: usize, +) { + span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_NAME, guard.name.clone())); + span.set_attribute(KeyValue::new( + GEN_AI_GUARDRAIL_DURATION, + elapsed.as_millis() as i64, + )); + + if get_trace_content_enabled() { + span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_INPUT, input.to_string())); + } + + match result { + Ok(resp) => { + let status = if resp.pass { GUARDRAIL_PASSED } else { GUARDRAIL_FAILED }; + span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_STATUS, status)); + } + Err(err) => { + span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_STATUS, GUARDRAIL_ERROR)); + span.set_attribute(KeyValue::new( + GEN_AI_GUARDRAIL_ERROR_TYPE, + error_type_name(err), + )); + span.set_attribute(KeyValue::new( + GEN_AI_GUARDRAIL_ERROR_MESSAGE, + err.to_string(), + )); + span.set_status(OtelStatus::error(err.to_string())); + } + } +} + /// Execute a set of guardrails against the given input text. /// Guards are run concurrently. Returns a GuardrailsOutcome with results, blocked status, and warnings. +/// When `parent_cx` is provided, creates a child OTel span per guard evaluation. pub async fn execute_guards( guards: &[Guard], input: &str, client: &dyn GuardrailClient, + parent_cx: Option<&Context>, ) -> GuardrailsOutcome { debug!(guard_count = guards.len(), "Executing guardrails"); + let guard_count = guards.len(); + let parent_cx = parent_cx.cloned(); + let futures: Vec<_> = guards .iter() - .map(|guard| async move { - let start = std::time::Instant::now(); - let result = client.evaluate(guard, input).await; - let elapsed = start.elapsed(); - match &result { - Ok(resp) => debug!( - guard = %guard.name, - pass = resp.pass, - elapsed_ms = elapsed.as_millis(), - "Guard evaluation complete" - ), - Err(err) => warn!( - guard = %guard.name, - error = %err, - required = guard.required, - elapsed_ms = elapsed.as_millis(), - "Guard evaluation failed" - ), + .map(|guard| { + let parent_cx = parent_cx.clone(); + async move { + let start = std::time::Instant::now(); + let result = client.evaluate(guard, input).await; + let elapsed = start.elapsed(); + + // Create child span if tracing context is available + let span = parent_cx.as_ref().map(|cx| { + let tracer = global::tracer("traceloop_hub"); + let mut span = tracer + .span_builder(format!("{}.guard", guard.name)) + .with_kind(SpanKind::Internal) + .start_with_context(&tracer, cx); + + record_guard_span(&mut span, guard, &result, elapsed, input, guard_count); + span + }); + + match &result { + Ok(resp) => debug!( + guard = %guard.name, + pass = resp.pass, + elapsed_ms = elapsed.as_millis(), + "Guard evaluation complete" + ), + Err(err) => warn!( + guard = %guard.name, + error = %err, + required = guard.required, + elapsed_ms = elapsed.as_millis(), + "Guard evaluation failed" + ), + } + (guard, result, span) } - (guard, result) }) .collect(); @@ -54,8 +129,12 @@ pub async fn execute_guards( let mut blocked = false; let mut blocking_guard = None; let mut warnings = Vec::new(); + let mut guard_spans: Vec = Vec::new(); - for (guard, result) in results_raw { + for (guard, result, span) in results_raw { + if let Some(s) = span { + guard_spans.push(s); + } match result { Ok(response) => { if response.pass { @@ -122,12 +201,18 @@ pub struct GuardrailsRunner<'a> { pre_call: Vec, post_call: Vec, client: &'a dyn GuardrailClient, + parent_cx: Option, } impl<'a> GuardrailsRunner<'a> { /// Create a runner by resolving guards from pipeline config + request headers. /// Returns None if no guards are active for this request. - pub fn new(guardrails: Option<&'a Guardrails>, headers: &HeaderMap) -> Option { + /// When `parent_cx` is provided, guardrail evaluations are traced as child spans. + pub fn new( + guardrails: Option<&'a Guardrails>, + headers: &HeaderMap, + parent_cx: Option, + ) -> Option { let gr = guardrails?; let (pre_call, post_call) = resolve_request_guards(gr, headers); if pre_call.is_empty() && post_call.is_empty() { @@ -137,6 +222,7 @@ impl<'a> GuardrailsRunner<'a> { pre_call, post_call, client: gr.client.as_ref(), + parent_cx, }) } @@ -149,7 +235,8 @@ impl<'a> GuardrailsRunner<'a> { }; } let input = request.extract_pompt(); - let outcome = execute_guards(&self.pre_call, &input, self.client).await; + let outcome = + execute_guards(&self.pre_call, &input, self.client, self.parent_cx.as_ref()).await; if outcome.blocked { return GuardPhaseResult { blocked_response: Some(blocked_response(&outcome)), @@ -183,7 +270,9 @@ impl<'a> GuardrailsRunner<'a> { }; } - let outcome = execute_guards(&self.post_call, &completion, self.client).await; + let outcome = + execute_guards(&self.post_call, &completion, self.client, self.parent_cx.as_ref()) + .await; if outcome.blocked { return GuardPhaseResult { blocked_response: Some(blocked_response(&outcome)), diff --git a/src/guardrails/span_attributes.rs b/src/guardrails/span_attributes.rs new file mode 100644 index 00000000..28739acc --- /dev/null +++ b/src/guardrails/span_attributes.rs @@ -0,0 +1,10 @@ +pub const GEN_AI_GUARDRAIL_NAME: &str = "gen_ai.guardrail.name"; +pub const GEN_AI_GUARDRAIL_STATUS: &str = "gen_ai.guardrail.status"; +pub const GEN_AI_GUARDRAIL_DURATION: &str = "gen_ai.guardrail.duration"; +pub const GEN_AI_GUARDRAIL_INPUT: &str = "gen_ai.guardrail.input"; +pub const GEN_AI_GUARDRAIL_ERROR_TYPE: &str = "gen_ai.guardrail.error.type"; +pub const GEN_AI_GUARDRAIL_ERROR_MESSAGE: &str = "gen_ai.guardrail.error.message"; + +pub const GUARDRAIL_PASSED: &str = "PASSED"; +pub const GUARDRAIL_FAILED: &str = "FAILED"; +pub const GUARDRAIL_ERROR: &str = "ERROR"; diff --git a/src/pipelines/otel.rs b/src/pipelines/otel.rs index d71e4411..ea9cd6cb 100644 --- a/src/pipelines/otel.rs +++ b/src/pipelines/otel.rs @@ -6,8 +6,8 @@ use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest, EmbeddingsRe use crate::models::streaming::ChatCompletionChunk; use crate::models::usage::{EmbeddingUsage, Usage}; use opentelemetry::global::{BoxedSpan, ObjectSafeSpan}; -use opentelemetry::trace::{SpanKind, Status, Tracer}; -use opentelemetry::{KeyValue, global}; +use opentelemetry::trace::{SpanKind, Status, TraceContextExt, Tracer}; +use opentelemetry::{Context, KeyValue, global}; use opentelemetry_otlp::{SpanExporter, WithExportConfig, WithHttpConfig}; use opentelemetry_sdk::propagation::TraceContextPropagator; use opentelemetry_sdk::trace::TracerProvider; @@ -174,6 +174,12 @@ impl OtelTracer { self.span.set_status(Status::error(description)); } + /// Returns an OTel Context carrying this tracer's root span as parent, + /// suitable for creating child spans. + pub fn parent_context(&self) -> Context { + Context::current().with_remote_span_context(self.span.span_context().clone()) + } + pub fn set_vendor(&mut self, vendor: &str) { self.span .set_attribute(KeyValue::new(GEN_AI_SYSTEM, vendor.to_string())); diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index cadcccf2..9ab4ca7a 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -126,7 +126,8 @@ pub async fn chat_completions( guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start("chat", &payload); - let orchestrator = GuardrailsRunner::new(guardrails.as_deref(), &headers); + let parent_cx = tracer.parent_context(); + let orchestrator = GuardrailsRunner::new(guardrails.as_deref(), &headers, Some(parent_cx)); // Pre-call guardrails let mut all_warnings = Vec::new(); diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 77ff0e85..f315ad0d 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -72,7 +72,7 @@ async fn test_e2e_pre_call_block_flow() { let input = request.extract_pompt(); let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], &input, &client).await; + let outcome = execute_guards(&[guard], &input, &client, None).await; assert!(outcome.blocked); assert_eq!(outcome.blocking_guard.as_deref(), Some("blocker")); @@ -94,7 +94,7 @@ async fn test_e2e_pre_call_pass_flow() { let input = request.extract_pompt(); let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], &input, &client).await; + let outcome = execute_guards(&[guard], &input, &client, None).await; assert!(!outcome.blocked); assert!(outcome.warnings.is_empty()); @@ -118,7 +118,7 @@ async fn test_e2e_post_call_block_flow() { let response_text = completion.extract_completion(); let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], &response_text, &client).await; + let outcome = execute_guards(&[guard], &response_text, &client, None).await; assert!(outcome.blocked); assert_eq!(outcome.blocking_guard.as_deref(), Some("pii-check")); @@ -140,7 +140,7 @@ async fn test_e2e_post_call_warn_flow() { let response_text = completion.extract_completion(); let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], &response_text, &client).await; + let outcome = execute_guards(&[guard], &response_text, &client, None).await; assert!(!outcome.blocked); assert_eq!(outcome.warnings.len(), 1); @@ -173,13 +173,13 @@ async fn test_e2e_pre_and_post_both_pass() { // Pre-call let request = create_test_chat_request("Hello"); let input = request.extract_pompt(); - let pre_outcome = execute_guards(&[pre_guard], &input, &client).await; + let pre_outcome = execute_guards(&[pre_guard], &input, &client, None).await; assert!(!pre_outcome.blocked); // Post-call let completion = create_test_chat_completion("Hi there!"); let response_text = completion.extract_completion(); - let post_outcome = execute_guards(&[post_guard], &response_text, &client).await; + let post_outcome = execute_guards(&[post_guard], &response_text, &client, None).await; assert!(!post_outcome.blocked); assert!(post_outcome.warnings.is_empty()); } @@ -216,7 +216,7 @@ async fn test_e2e_pre_blocks_post_never_runs() { let request = create_test_chat_request("Bad input"); let input = request.extract_pompt(); - let pre_outcome = execute_guards(&[pre_guard], &input, &client).await; + let pre_outcome = execute_guards(&[pre_guard], &input, &client, None).await; assert!(pre_outcome.blocked); // Since pre blocked, post guards never run - post_eval.verify() will assert 0 calls @@ -256,7 +256,7 @@ async fn test_e2e_mixed_block_and_warn() { ]; let client = TraceloopClient::new(); - let outcome = execute_guards(&guards, "test input", &client).await; + let outcome = execute_guards(&guards, "test input", &client, None).await; assert!(outcome.blocked); assert_eq!(outcome.blocking_guard.as_deref(), Some("blocker")); @@ -278,7 +278,7 @@ async fn test_e2e_streaming_post_call_buffer_pass() { let accumulated = "Hello world!"; let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], &accumulated, &client).await; + let outcome = execute_guards(&[guard], &accumulated, &client, None).await; assert!(!outcome.blocked); } @@ -298,7 +298,7 @@ async fn test_e2e_streaming_post_call_buffer_block() { let accumulated = "Here is SSN: 123-45-6789"; let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], &accumulated, &client).await; + let outcome = execute_guards(&[guard], &accumulated, &client, None).await; assert!(outcome.blocked); } @@ -433,7 +433,7 @@ async fn test_e2e_multiple_guards_different_evaluators() { ]; let client = TraceloopClient::new(); - let outcome = execute_guards(&guards, "test input", &client).await; + let outcome = execute_guards(&guards, "test input", &client, None).await; assert!(!outcome.blocked); assert_eq!(outcome.results.len(), 2); @@ -459,7 +459,7 @@ async fn test_e2e_fail_open_evaluator_down() { guard.required = false; // fail-open let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], "test input", &client).await; + let outcome = execute_guards(&[guard], "test input", &client, None).await; assert!(!outcome.blocked); // Fail-open: not blocked despite error } @@ -483,7 +483,7 @@ async fn test_e2e_fail_closed_evaluator_down() { guard.required = true; // fail-closed let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], "test input", &client).await; + let outcome = execute_guards(&[guard], "test input", &client, None).await; assert!(outcome.blocked); // Fail-closed: blocked due to error } @@ -564,7 +564,7 @@ async fn test_pre_call_guardrails_warn_and_continue() { ); let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], "borderline input", &client).await; + let outcome = execute_guards(&[guard], "borderline input", &client, None).await; assert!(!outcome.blocked); assert_eq!(outcome.warnings.len(), 1); @@ -592,7 +592,7 @@ async fn test_post_call_guardrails_warn_and_add_header() { ); let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], "Some LLM response", &client).await; + let outcome = execute_guards(&[guard], "Some LLM response", &client, None).await; assert!(!outcome.blocked); assert!(!outcome.warnings.is_empty()); @@ -669,7 +669,7 @@ async fn test_post_call_skipped_on_empty_response() { }; let headers = HeaderMap::new(); - let runner = GuardrailsRunner::new(Some(&guardrails), &headers).unwrap(); + let runner = GuardrailsRunner::new(Some(&guardrails), &headers, None).unwrap(); // Completion with content: None (simulates empty LLM response) let empty_completion = create_test_chat_completion(""); @@ -704,7 +704,7 @@ async fn test_evaluator_error_not_blocked_by_default() { // guard.required is false by default let client = TraceloopClient::new(); - let outcome = execute_guards(&[guard], "test input", &client).await; + let outcome = execute_guards(&[guard], "test input", &client, None).await; assert!(!outcome.blocked); } diff --git a/tests/guardrails/test_runner.rs b/tests/guardrails/test_runner.rs index fbf5574d..a1a1ead3 100644 --- a/tests/guardrails/test_runner.rs +++ b/tests/guardrails/test_runner.rs @@ -12,7 +12,7 @@ use super::helpers::*; async fn test_execute_single_pre_call_guard_passes() { let guard = create_test_guard("check", GuardMode::PreCall); let mock_client = MockGuardrailClient::with_response("check", Ok(passing_response())); - let outcome = execute_guards(&[guard], "test input", &mock_client).await; + let outcome = execute_guards(&[guard], "test input", &mock_client, None).await; assert!(!outcome.blocked); assert_eq!(outcome.results.len(), 1); assert!(matches!(&outcome.results[0], GuardResult::Passed { .. })); @@ -24,7 +24,7 @@ async fn test_execute_single_pre_call_guard_fails_block() { let guard = create_test_guard_with_failure_action("check", GuardMode::PreCall, OnFailure::Block); let mock_client = MockGuardrailClient::with_response("check", Ok(failing_response())); - let outcome = execute_guards(&[guard], "toxic input", &mock_client).await; + let outcome = execute_guards(&[guard], "toxic input", &mock_client, None).await; assert!(outcome.blocked); assert_eq!(outcome.blocking_guard, Some("check".to_string())); } @@ -33,7 +33,7 @@ async fn test_execute_single_pre_call_guard_fails_block() { async fn test_execute_single_pre_call_guard_fails_warn() { let guard = create_test_guard_with_failure_action("check", GuardMode::PreCall, OnFailure::Warn); let mock_client = MockGuardrailClient::with_response("check", Ok(failing_response())); - let outcome = execute_guards(&[guard], "borderline input", &mock_client).await; + let outcome = execute_guards(&[guard], "borderline input", &mock_client, None).await; assert!(!outcome.blocked); assert_eq!(outcome.warnings.len(), 1); assert_eq!(outcome.warnings[0].guard_name, "check"); @@ -51,7 +51,7 @@ async fn test_execute_multiple_pre_call_guards_all_pass() { ("guard-2", Ok(passing_response())), ("guard-3", Ok(passing_response())), ]); - let outcome = execute_guards(&guards, "safe input", &mock_client).await; + let outcome = execute_guards(&guards, "safe input", &mock_client, None).await; assert!(!outcome.blocked); assert_eq!(outcome.results.len(), 3); assert!(outcome.warnings.is_empty()); @@ -69,7 +69,7 @@ async fn test_execute_multiple_guards_one_blocks() { ("guard-2", Ok(failing_response())), ("guard-3", Ok(passing_response())), ]); - let outcome = execute_guards(&guards, "input", &mock_client).await; + let outcome = execute_guards(&guards, "input", &mock_client, None).await; assert!(outcome.blocked); assert_eq!(outcome.blocking_guard, Some("guard-2".to_string())); } @@ -86,7 +86,7 @@ async fn test_execute_multiple_guards_one_warns_continue() { ("guard-2", Ok(failing_response())), ("guard-3", Ok(passing_response())), ]); - let outcome = execute_guards(&guards, "input", &mock_client).await; + let outcome = execute_guards(&guards, "input", &mock_client, None).await; assert!(!outcome.blocked); assert_eq!(outcome.results.len(), 3); assert_eq!(outcome.warnings.len(), 1); @@ -101,7 +101,7 @@ async fn test_guard_evaluator_unavailable_required_false() { "connection refused".to_string(), )), ); - let outcome = execute_guards(&[guard], "input", &mock_client).await; + let outcome = execute_guards(&[guard], "input", &mock_client, None).await; assert!(!outcome.blocked); // Fail-open assert!(matches!( &outcome.results[0], @@ -121,7 +121,7 @@ async fn test_guard_evaluator_unavailable_required_true() { "connection refused".to_string(), )), ); - let outcome = execute_guards(&[guard], "input", &mock_client).await; + let outcome = execute_guards(&[guard], "input", &mock_client, None).await; assert!(outcome.blocked); // Fail-closed } @@ -131,7 +131,7 @@ async fn test_execute_post_call_guards_non_streaming() { let mock_client = MockGuardrailClient::with_response("response-check", Ok(passing_response())); let completion = create_test_chat_completion("Safe response text"); let response_text = completion.extract_completion(); - let outcome = execute_guards(&[guard], &response_text, &mock_client).await; + let outcome = execute_guards(&[guard], &response_text, &mock_client, None).await; assert!(!outcome.blocked); } @@ -140,7 +140,7 @@ async fn test_execute_post_call_guards_streaming_accumulated() { let guard = create_test_guard("response-check", GuardMode::PostCall); let mock_client = MockGuardrailClient::with_response("response-check", Ok(passing_response())); let accumulated_text = "Hello world from streaming!"; - let outcome = execute_guards(&[guard], accumulated_text, &mock_client).await; + let outcome = execute_guards(&[guard], accumulated_text, &mock_client, None).await; assert!(!outcome.blocked); } @@ -158,7 +158,7 @@ async fn test_parallel_execution_of_independent_guards() { ("guard-2", Ok(passing_response())), ]); let start = std::time::Instant::now(); - let outcome = execute_guards(&guards, "input", &mock_client).await; + let outcome = execute_guards(&guards, "input", &mock_client, None).await; let _elapsed = start.elapsed(); assert!(!outcome.blocked); assert_eq!(outcome.results.len(), 2); @@ -176,7 +176,7 @@ async fn test_executor_returns_correct_guardrails_outcome() { ("warner", Ok(failing_response())), ("blocker", Ok(failing_response())), ]); - let outcome = execute_guards(&guards, "input", &mock_client).await; + let outcome = execute_guards(&guards, "input", &mock_client, None).await; assert!(outcome.blocked); assert_eq!(outcome.blocking_guard, Some("blocker".to_string())); assert!(outcome.warnings.iter().any(|w| w.guard_name == "warner")); From c123126e1b8a7679483d04222395c0cc808b6867 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 14:29:57 +0200 Subject: [PATCH 23/59] add spans --- Cargo.lock | 39 +++++++ Cargo.toml | 2 + src/guardrails/runner.rs | 25 +++-- src/pipelines/otel.rs | 57 +++++++--- src/pipelines/pipeline.rs | 10 +- tests/guardrails/test_runner.rs | 185 ++++++++++++++++++++++++++++++++ 6 files changed, 284 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d730fa83..e0bae977 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -204,6 +204,42 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-process" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc50921ec0055cdd8a16de48773bfeec5c972598674347252c0399676be7da75" +dependencies = [ + "async-channel 2.5.0", + "async-io", + "async-lock", + "async-signal", + "async-task", + "blocking", + "cfg-if", + "event-listener 5.4.1", + "futures-lite 2.6.1", + "rustix", +] + +[[package]] +name = "async-signal" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43c070bbf59cd3570b6b2dd54cd772527c7c3620fce8be898406dd3ed6adc64c" +dependencies = [ + "async-io", + "async-lock", + "atomic-waker", + "cfg-if", + "futures-core", + "futures-io", + "rustix", + "signal-hook-registry", + "slab", + "windows-sys 0.61.2", +] + [[package]] name = "async-std" version = "1.13.2" @@ -214,6 +250,7 @@ dependencies = [ "async-global-executor", "async-io", "async-lock", + "async-process", "crossbeam-utils", "futures-channel", "futures-core", @@ -3284,6 +3321,7 @@ version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "231e9d6ceef9b0b2546ddf52335785ce41252bc7474ee8ba05bfad277be13ab8" dependencies = [ + "async-std", "async-trait", "futures-channel", "futures-executor", @@ -3296,6 +3334,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-stream", + "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 61846189..fd13ebd6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,3 +76,5 @@ axum-test = "17" sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "migrate"] } tokio = { version = "1.45.0", features = ["full"] } reqwest = { version = "0.12", features = ["json"] } +opentelemetry = { version = "0.27" } +opentelemetry_sdk = { version = "0.27", features = ["testing"] } diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index cf11ff5a..da11e2d4 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -36,7 +36,6 @@ fn record_guard_span( result: &Result, elapsed: std::time::Duration, input: &str, - guard_count: usize, ) { span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_NAME, guard.name.clone())); span.set_attribute(KeyValue::new( @@ -79,7 +78,6 @@ pub async fn execute_guards( ) -> GuardrailsOutcome { debug!(guard_count = guards.len(), "Executing guardrails"); - let guard_count = guards.len(); let parent_cx = parent_cx.cloned(); let futures: Vec<_> = guards @@ -87,22 +85,23 @@ pub async fn execute_guards( .map(|guard| { let parent_cx = parent_cx.clone(); async move { - let start = std::time::Instant::now(); - let result = client.evaluate(guard, input).await; - let elapsed = start.elapsed(); - - // Create child span if tracing context is available - let span = parent_cx.as_ref().map(|cx| { + // Create child span BEFORE evaluation so its start time is correct + let mut span = parent_cx.as_ref().map(|cx| { let tracer = global::tracer("traceloop_hub"); - let mut span = tracer + tracer .span_builder(format!("{}.guard", guard.name)) .with_kind(SpanKind::Internal) - .start_with_context(&tracer, cx); - - record_guard_span(&mut span, guard, &result, elapsed, input, guard_count); - span + .start_with_context(&tracer, cx) }); + let start = std::time::Instant::now(); + let result = client.evaluate(guard, input).await; + let elapsed = start.elapsed(); + + if let Some(s) = &mut span { + record_guard_span(s, guard, &result, elapsed, input); + } + match &result { Ok(resp) => debug!( guard = %guard.name, diff --git a/src/pipelines/otel.rs b/src/pipelines/otel.rs index ea9cd6cb..5cbdd365 100644 --- a/src/pipelines/otel.rs +++ b/src/pipelines/otel.rs @@ -20,7 +20,8 @@ pub trait RecordSpan { } pub struct OtelTracer { - span: BoxedSpan, + root_span: BoxedSpan, + llm_span: Option, accumulated_completion: Option, } @@ -87,21 +88,32 @@ impl OtelTracer { } } - pub fn start(operation: &str, request: &T) -> Self { + pub fn start() -> Self { let tracer = global::tracer("traceloop_hub"); - let mut span = tracer - .span_builder(format!("traceloop_hub.{operation}")) - .with_kind(SpanKind::Client) + let span = tracer + .span_builder("traceloop_hub") + .with_kind(SpanKind::Server) .start(&tracer); - request.record_span(&mut span); - Self { - span, + root_span: span, + llm_span: None, accumulated_completion: None, } } + pub fn start_llm_span(&mut self, operation: &str, request: &T) { + let tracer = global::tracer("traceloop_hub"); + let parent_cx = self.parent_context(); + let mut span = tracer + .span_builder(format!("traceloop_hub.{operation}")) + .with_kind(SpanKind::Client) + .start_with_context(&tracer, &parent_cx); + + request.record_span(&mut span); + self.llm_span = Some(span); + } + pub fn log_chunk(&mut self, chunk: &ChatCompletionChunk) { if self.accumulated_completion.is_none() { self.accumulated_completion = Some(ChatCompletion { @@ -160,29 +172,39 @@ impl OtelTracer { pub fn streaming_end(&mut self) { if let Some(completion) = self.accumulated_completion.take() { - completion.record_span(&mut self.span); - self.span.set_status(Status::Ok); + if let Some(span) = &mut self.llm_span { + completion.record_span(span); + span.set_status(Status::Ok); + } } + self.root_span.set_status(Status::Ok); } pub fn log_success(&mut self, response: &T) { - response.record_span(&mut self.span); - self.span.set_status(Status::Ok); + if let Some(span) = &mut self.llm_span { + response.record_span(span); + span.set_status(Status::Ok); + } + self.root_span.set_status(Status::Ok); } pub fn log_error(&mut self, description: String) { - self.span.set_status(Status::error(description)); + if let Some(span) = &mut self.llm_span { + span.set_status(Status::error(description.clone())); + } + self.root_span.set_status(Status::error(description)); } /// Returns an OTel Context carrying this tracer's root span as parent, /// suitable for creating child spans. pub fn parent_context(&self) -> Context { - Context::current().with_remote_span_context(self.span.span_context().clone()) + Context::current().with_remote_span_context(self.root_span.span_context().clone()) } pub fn set_vendor(&mut self, vendor: &str) { - self.span - .set_attribute(KeyValue::new(GEN_AI_SYSTEM, vendor.to_string())); + if let Some(span) = &mut self.llm_span { + span.set_attribute(KeyValue::new(GEN_AI_SYSTEM, vendor.to_string())); + } } } @@ -419,7 +441,8 @@ mod tests { // Test that set_vendor method compiles and can be called // This ensures the method signature is correct let mut tracer = OtelTracer { - span: opentelemetry::global::tracer("test").start("test"), + root_span: opentelemetry::global::tracer("test").start("test"), + llm_span: Some(opentelemetry::global::tracer("test").start("test_llm")), accumulated_completion: None, }; diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 9ab4ca7a..df7691ce 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -125,7 +125,7 @@ pub async fn chat_completions( model_keys: Vec, guardrails: Option>, ) -> Result { - let mut tracer = OtelTracer::start("chat", &payload); + let mut tracer = OtelTracer::start(); let parent_cx = tracer.parent_context(); let orchestrator = GuardrailsRunner::new(guardrails.as_deref(), &headers, Some(parent_cx)); @@ -143,6 +143,7 @@ pub async fn chat_completions( let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { + tracer.start_llm_span("chat", &payload); tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); let response = model @@ -193,12 +194,13 @@ pub async fn completions( Json(payload): Json, model_keys: Vec, ) -> Result { - let mut tracer = OtelTracer::start("completion", &payload); + let mut tracer = OtelTracer::start(); for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { + tracer.start_llm_span("completion", &payload); tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); let response = model.completions(payload.clone()).await.inspect_err(|e| { @@ -220,13 +222,13 @@ pub async fn embeddings( Json(payload): Json, model_keys: Vec, ) -> impl IntoResponse { - let mut tracer = OtelTracer::start("embeddings", &payload); + let mut tracer = OtelTracer::start(); for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - // Set vendor now that we know which model/provider we're using + tracer.start_llm_span("embeddings", &payload); tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); let response = model.embeddings(payload.clone()).await.inspect_err(|e| { diff --git a/tests/guardrails/test_runner.rs b/tests/guardrails/test_runner.rs index a1a1ead3..5fab19fa 100644 --- a/tests/guardrails/test_runner.rs +++ b/tests/guardrails/test_runner.rs @@ -1,6 +1,11 @@ use hub_lib::guardrails::parsing::CompletionExtractor; use hub_lib::guardrails::runner::*; use hub_lib::guardrails::types::*; +use opentelemetry::trace::{Span, SpanKind, TraceContextExt, Tracer}; +use opentelemetry::Context; +use opentelemetry_sdk::export::trace::SpanData; +use opentelemetry_sdk::testing::trace::InMemorySpanExporter; +use opentelemetry_sdk::trace::TracerProvider; use super::helpers::*; @@ -181,3 +186,183 @@ async fn test_executor_returns_correct_guardrails_outcome() { assert_eq!(outcome.blocking_guard, Some("blocker".to_string())); assert!(outcome.warnings.iter().any(|w| w.guard_name == "warner")); } + +// --------------------------------------------------------------------------- +// Guard Span Creation +// --------------------------------------------------------------------------- + +use std::sync::LazyLock; + +/// Shared OTel exporter + provider, initialized once for all span tests. +/// Each test creates a unique parent span with a unique trace_id, then filters +/// exported spans by that trace_id — so tests are isolated despite sharing state. +static TEST_EXPORTER: LazyLock = LazyLock::new(|| { + let exporter = InMemorySpanExporter::default(); + let provider = TracerProvider::builder() + .with_simple_exporter(exporter.clone()) + .build(); + opentelemetry::global::set_tracer_provider(provider); + exporter +}); + +/// Helper: create a parent Context from the global tracer, returning +/// the Context and the parent's SpanContext for later assertions. +fn create_parent_context() -> (Context, opentelemetry::trace::SpanContext) { + let tracer = opentelemetry::global::tracer("traceloop_hub"); + let parent_span = tracer.start("traceloop_hub"); + let span_ctx = parent_span.span_context().clone(); + let cx = Context::current().with_span(parent_span); + (cx, span_ctx) +} + +/// Helper: collect guard spans from the shared exporter, filtering by trace_id. +fn get_guard_spans(trace_id: opentelemetry::trace::TraceId) -> Vec { + TEST_EXPORTER + .get_finished_spans() + .unwrap() + .into_iter() + .filter(|s| s.span_context.trace_id() == trace_id) + .filter(|s| s.name.ends_with(".guard")) + .collect() +} + +#[tokio::test] +async fn test_guard_spans_created_with_parent_context() { + let _ = &*TEST_EXPORTER; // ensure global provider is set + let (parent_cx, parent_span_ctx) = create_parent_context(); + + let guards = vec![ + create_test_guard("pii-check", GuardMode::PreCall), + create_test_guard("secrets-check", GuardMode::PostCall), + ]; + let mock_client = MockGuardrailClient::with_responses(vec![ + ("pii-check", Ok(passing_response())), + ("secrets-check", Ok(failing_response())), + ]); + + let _outcome = execute_guards(&guards, "test input", &mock_client, Some(&parent_cx)).await; + drop(parent_cx); + + let spans = get_guard_spans(parent_span_ctx.trace_id()); + assert_eq!(spans.len(), 2, "Expected 2 guard spans, got {}", spans.len()); + + let span_names: Vec<&str> = spans.iter().map(|s| s.name.as_ref()).collect(); + assert!(span_names.contains(&"pii-check.guard")); + assert!(span_names.contains(&"secrets-check.guard")); + + // All guard spans should be children of the parent + for span in &spans { + assert_eq!( + span.parent_span_id, parent_span_ctx.span_id(), + "Guard span '{}' should be child of the parent span", + span.name + ); + assert_eq!(span.span_context.trace_id(), parent_span_ctx.trace_id()); + assert_eq!(span.span_kind, SpanKind::Internal); + } +} + +#[tokio::test] +async fn test_guard_span_attributes_on_pass() { + let _ = &*TEST_EXPORTER; + let (parent_cx, parent_span_ctx) = create_parent_context(); + + let guard = create_test_guard("pii-check", GuardMode::PreCall); + let mock_client = MockGuardrailClient::with_response("pii-check", Ok(passing_response())); + + let _outcome = execute_guards(&[guard], "hello world", &mock_client, Some(&parent_cx)).await; + drop(parent_cx); + + let spans = get_guard_spans(parent_span_ctx.trace_id()); + assert_eq!(spans.len(), 1); + + let span = &spans[0]; + let attrs: std::collections::HashMap = span + .attributes + .iter() + .map(|kv| (kv.key.to_string(), kv.value.to_string())) + .collect(); + + assert_eq!(attrs.get("gen_ai.guardrail.name").unwrap(), "pii-check"); + assert_eq!(attrs.get("gen_ai.guardrail.status").unwrap(), "PASSED"); + assert!(attrs.contains_key("gen_ai.guardrail.duration")); +} + +#[tokio::test] +async fn test_guard_span_attributes_on_fail() { + let _ = &*TEST_EXPORTER; + let (parent_cx, parent_span_ctx) = create_parent_context(); + + let guard = + create_test_guard_with_failure_action("toxicity", GuardMode::PreCall, OnFailure::Block); + let mock_client = MockGuardrailClient::with_response("toxicity", Ok(failing_response())); + + let _outcome = execute_guards(&[guard], "bad input", &mock_client, Some(&parent_cx)).await; + drop(parent_cx); + + let spans = get_guard_spans(parent_span_ctx.trace_id()); + assert_eq!(spans.len(), 1); + + let span = &spans[0]; + let attrs: std::collections::HashMap = span + .attributes + .iter() + .map(|kv| (kv.key.to_string(), kv.value.to_string())) + .collect(); + + assert_eq!(attrs.get("gen_ai.guardrail.name").unwrap(), "toxicity"); + assert_eq!(attrs.get("gen_ai.guardrail.status").unwrap(), "FAILED"); +} + +#[tokio::test] +async fn test_guard_span_attributes_on_error() { + let _ = &*TEST_EXPORTER; + let (parent_cx, parent_span_ctx) = create_parent_context(); + + let guard = create_test_guard_with_required("failing-guard", GuardMode::PreCall, true); + let mock_client = MockGuardrailClient::with_response( + "failing-guard", + Err(GuardrailError::Timeout("timed out".to_string())), + ); + + let _outcome = + execute_guards(&[guard], "test input", &mock_client, Some(&parent_cx)).await; + drop(parent_cx); + + let spans = get_guard_spans(parent_span_ctx.trace_id()); + assert_eq!(spans.len(), 1); + + let span = &spans[0]; + let attrs: std::collections::HashMap = span + .attributes + .iter() + .map(|kv| (kv.key.to_string(), kv.value.to_string())) + .collect(); + + assert_eq!(attrs.get("gen_ai.guardrail.name").unwrap(), "failing-guard"); + assert_eq!(attrs.get("gen_ai.guardrail.status").unwrap(), "ERROR"); + assert_eq!(attrs.get("gen_ai.guardrail.error.type").unwrap(), "Timeout"); + assert!(attrs.get("gen_ai.guardrail.error.message").unwrap().contains("timed out")); +} + +#[tokio::test] +async fn test_no_guard_spans_without_parent_context() { + let _ = &*TEST_EXPORTER; + + // Create a unique trace to establish a "before" baseline + let (marker_cx, marker_span_ctx) = create_parent_context(); + drop(marker_cx); + + let guard = create_test_guard("pii-check", GuardMode::PreCall); + let mock_client = MockGuardrailClient::with_response("pii-check", Ok(passing_response())); + + // Run with None parent — no guard spans should be created + let _outcome = execute_guards(&[guard], "test input", &mock_client, None).await; + + // No guard spans should share the marker's trace_id (nothing was parented to it) + let guard_spans = get_guard_spans(marker_span_ctx.trace_id()); + assert!( + guard_spans.is_empty(), + "No guard spans should be created when parent_cx is None" + ); +} From 3950409e2f72f00872b17f53b831fa6349abab4b Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 14:50:19 +0200 Subject: [PATCH 24/59] lint --- src/guardrails/providers/traceloop.rs | 2 +- src/guardrails/runner.rs | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 9cd3ba8c..0698e0f0 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use std::collections::HashMap; use tracing::debug; use super::GuardrailClient; @@ -88,6 +87,7 @@ impl GuardrailClient for TraceloopClient { mod tests { use super::*; use serde_json::json; + use std::collections::HashMap; #[test] fn test_build_body_text_slug() { diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index da11e2d4..ad081c1f 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -49,7 +49,11 @@ fn record_guard_span( match result { Ok(resp) => { - let status = if resp.pass { GUARDRAIL_PASSED } else { GUARDRAIL_FAILED }; + let status = if resp.pass { + GUARDRAIL_PASSED + } else { + GUARDRAIL_FAILED + }; span.set_attribute(KeyValue::new(GEN_AI_GUARDRAIL_STATUS, status)); } Err(err) => { @@ -269,9 +273,13 @@ impl<'a> GuardrailsRunner<'a> { }; } - let outcome = - execute_guards(&self.post_call, &completion, self.client, self.parent_cx.as_ref()) - .await; + let outcome = execute_guards( + &self.post_call, + &completion, + self.client, + self.parent_cx.as_ref(), + ) + .await; if outcome.blocked { return GuardPhaseResult { blocked_response: Some(blocked_response(&outcome)), From fd1bed2b679791adcc755a184b774c4649b8d12b Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 14:55:04 +0200 Subject: [PATCH 25/59] build fix --- tests/guardrails/test_e2e.rs | 11 ++++------- tests/guardrails/test_parsing.rs | 5 ++++- tests/guardrails/test_runner.rs | 22 ++++++++++++++++------ tests/guardrails/test_traceloop_client.rs | 1 + 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index f315ad0d..62eb8ab5 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -1,7 +1,7 @@ use hub_lib::guardrails::parsing::{CompletionExtractor, PromptExtractor}; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::runner::{ - blocked_response, execute_guards, warning_header_value, GuardrailsRunner, + GuardrailsRunner, blocked_response, execute_guards, warning_header_value, }; use hub_lib::guardrails::setup::{build_guardrail_resources, build_pipeline_guardrails}; use hub_lib::guardrails::types::*; @@ -560,7 +560,7 @@ async fn test_pre_call_guardrails_warn_and_continue() { GuardMode::PreCall, OnFailure::Warn, &eval_server.uri(), - "tone", + "tone-detection", ); let client = TraceloopClient::new(); @@ -588,7 +588,7 @@ async fn test_post_call_guardrails_warn_and_add_header() { GuardMode::PostCall, OnFailure::Warn, &eval_server.uri(), - "safety", + "pii-detector", ); let client = TraceloopClient::new(); @@ -646,10 +646,7 @@ async fn test_post_call_skipped_on_empty_response() { // post-call guards should be skipped and a warning returned. let eval_server = MockServer::start().await; Mock::given(matchers::any()) - .respond_with( - ResponseTemplate::new(200) - .set_body_json(json!({"result": {}, "pass": true})), - ) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"result": {}, "pass": true}))) .expect(0) // evaluator should never be called .mount(&eval_server) .await; diff --git a/tests/guardrails/test_parsing.rs b/tests/guardrails/test_parsing.rs index 423c5f4c..195c6553 100644 --- a/tests/guardrails/test_parsing.rs +++ b/tests/guardrails/test_parsing.rs @@ -43,7 +43,10 @@ fn test_extract_text_multi_turn_conversation() { }, ]; let text = request.extract_pompt(); - assert_eq!(text, "Follow-up question"); + assert_eq!( + text, + "You are helpful\nFirst question\nFirst answer\nFollow-up question" + ); } #[test] diff --git a/tests/guardrails/test_runner.rs b/tests/guardrails/test_runner.rs index 5fab19fa..418fadff 100644 --- a/tests/guardrails/test_runner.rs +++ b/tests/guardrails/test_runner.rs @@ -1,8 +1,8 @@ use hub_lib::guardrails::parsing::CompletionExtractor; use hub_lib::guardrails::runner::*; use hub_lib::guardrails::types::*; -use opentelemetry::trace::{Span, SpanKind, TraceContextExt, Tracer}; use opentelemetry::Context; +use opentelemetry::trace::{Span, SpanKind, TraceContextExt, Tracer}; use opentelemetry_sdk::export::trace::SpanData; use opentelemetry_sdk::testing::trace::InMemorySpanExporter; use opentelemetry_sdk::trace::TracerProvider; @@ -244,7 +244,12 @@ async fn test_guard_spans_created_with_parent_context() { drop(parent_cx); let spans = get_guard_spans(parent_span_ctx.trace_id()); - assert_eq!(spans.len(), 2, "Expected 2 guard spans, got {}", spans.len()); + assert_eq!( + spans.len(), + 2, + "Expected 2 guard spans, got {}", + spans.len() + ); let span_names: Vec<&str> = spans.iter().map(|s| s.name.as_ref()).collect(); assert!(span_names.contains(&"pii-check.guard")); @@ -253,7 +258,8 @@ async fn test_guard_spans_created_with_parent_context() { // All guard spans should be children of the parent for span in &spans { assert_eq!( - span.parent_span_id, parent_span_ctx.span_id(), + span.parent_span_id, + parent_span_ctx.span_id(), "Guard span '{}' should be child of the parent span", span.name ); @@ -325,8 +331,7 @@ async fn test_guard_span_attributes_on_error() { Err(GuardrailError::Timeout("timed out".to_string())), ); - let _outcome = - execute_guards(&[guard], "test input", &mock_client, Some(&parent_cx)).await; + let _outcome = execute_guards(&[guard], "test input", &mock_client, Some(&parent_cx)).await; drop(parent_cx); let spans = get_guard_spans(parent_span_ctx.trace_id()); @@ -342,7 +347,12 @@ async fn test_guard_span_attributes_on_error() { assert_eq!(attrs.get("gen_ai.guardrail.name").unwrap(), "failing-guard"); assert_eq!(attrs.get("gen_ai.guardrail.status").unwrap(), "ERROR"); assert_eq!(attrs.get("gen_ai.guardrail.error.type").unwrap(), "Timeout"); - assert!(attrs.get("gen_ai.guardrail.error.message").unwrap().contains("timed out")); + assert!( + attrs + .get("gen_ai.guardrail.error.message") + .unwrap() + .contains("timed out") + ); } #[tokio::test] diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs index 9ebd42e4..4b5b9994 100644 --- a/tests/guardrails/test_traceloop_client.rs +++ b/tests/guardrails/test_traceloop_client.rs @@ -61,6 +61,7 @@ async fn test_traceloop_client_sends_correct_body() { .await; let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.evaluator_slug = "toxicity-detector".to_string(); guard.params.insert("threshold".to_string(), json!(0.5)); let client = TraceloopClient::new(); From 764e64d775a65eb0e974a12397aacab972af3410 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 15:41:40 +0200 Subject: [PATCH 26/59] test comments --- tests/guardrails/test_e2e.rs | 6 ------ tests/guardrails/test_parsing.rs | 3 --- tests/guardrails/test_runner.rs | 3 --- tests/guardrails/test_setup.rs | 3 --- tests/guardrails/test_traceloop_client.rs | 3 --- tests/guardrails/test_types.rs | 4 ---- 6 files changed, 22 deletions(-) diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 62eb8ab5..7f0bdbcf 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -16,12 +16,6 @@ use wiremock::{Mock, MockServer, ResponseTemplate}; use super::helpers::*; -// --------------------------------------------------------------------------- -// Phase 8: End-to-End Integration (15 tests) -// -// Full request flow tests using wiremock for evaluator services. -// These validate the complete lifecycle from request to response. -// --------------------------------------------------------------------------- /// Helper: set up a wiremock evaluator that returns pass/fail async fn setup_evaluator(pass: bool) -> MockServer { diff --git a/tests/guardrails/test_parsing.rs b/tests/guardrails/test_parsing.rs index 195c6553..266f0547 100644 --- a/tests/guardrails/test_parsing.rs +++ b/tests/guardrails/test_parsing.rs @@ -6,9 +6,6 @@ use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMe use super::helpers::*; -// --------------------------------------------------------------------------- -// Input Extraction (5 tests) -// --------------------------------------------------------------------------- #[test] fn test_extract_text_single_user_message() { diff --git a/tests/guardrails/test_runner.rs b/tests/guardrails/test_runner.rs index 418fadff..89653e46 100644 --- a/tests/guardrails/test_runner.rs +++ b/tests/guardrails/test_runner.rs @@ -9,9 +9,6 @@ use opentelemetry_sdk::trace::TracerProvider; use super::helpers::*; -// --------------------------------------------------------------------------- -// Guard Execution (12 tests) -// --------------------------------------------------------------------------- #[tokio::test] async fn test_execute_single_pre_call_guard_passes() { diff --git a/tests/guardrails/test_setup.rs b/tests/guardrails/test_setup.rs index cb1d3b19..46353e2c 100644 --- a/tests/guardrails/test_setup.rs +++ b/tests/guardrails/test_setup.rs @@ -5,9 +5,6 @@ use hub_lib::guardrails::types::*; use super::helpers::*; -// --------------------------------------------------------------------------- -// Header Parsing & Guard Resolution (15 tests) -// --------------------------------------------------------------------------- #[test] fn test_parse_guardrails_header_single() { diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs index 4b5b9994..b0b6368b 100644 --- a/tests/guardrails/test_traceloop_client.rs +++ b/tests/guardrails/test_traceloop_client.rs @@ -8,9 +8,6 @@ use wiremock::{Mock, MockServer, ResponseTemplate}; use super::helpers::*; -// --------------------------------------------------------------------------- -// Phase 4: Provider Client System (7 tests) -// --------------------------------------------------------------------------- #[tokio::test] async fn test_traceloop_client_constructs_correct_url() { diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs index 47543852..d7071c60 100644 --- a/tests/guardrails/test_types.rs +++ b/tests/guardrails/test_types.rs @@ -3,10 +3,6 @@ use hub_lib::types::GatewayConfig; use std::io::Write; use tempfile::NamedTempFile; -// --------------------------------------------------------------------------- -// Phase 1: Core Types & Configuration (9 tests + 4 provider tests) -// --------------------------------------------------------------------------- - #[test] fn test_guard_mode_deserialize_pre_call() { let mode: GuardMode = serde_json::from_str("\"pre_call\"").unwrap(); From 40db3abf6be3ee374b21e24ce909c8afcebffa8f Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 16:34:53 +0200 Subject: [PATCH 27/59] add evaluator test --- .../guardrails/pii_detector_pass.json | 18 ++ .../guardrails/profanity_detector_fail.json | 18 ++ .../guardrails/prompt_injection_pass.json | 18 ++ .../guardrails/sexism_detector_fail.json | 23 ++ .../guardrails/tone_detection_fail.json | 19 ++ .../guardrails/toxicity_detector_fail.json | 23 ++ .../guardrails/uncertainty_detector_pass.json | 19 ++ tests/guardrails/main.rs | 1 + tests/guardrails/test_run_evaluator.rs | 281 ++++++++++++++++++ 9 files changed, 420 insertions(+) create mode 100644 tests/cassettes/guardrails/pii_detector_pass.json create mode 100644 tests/cassettes/guardrails/profanity_detector_fail.json create mode 100644 tests/cassettes/guardrails/prompt_injection_pass.json create mode 100644 tests/cassettes/guardrails/sexism_detector_fail.json create mode 100644 tests/cassettes/guardrails/tone_detection_fail.json create mode 100644 tests/cassettes/guardrails/toxicity_detector_fail.json create mode 100644 tests/cassettes/guardrails/uncertainty_detector_pass.json create mode 100644 tests/guardrails/test_run_evaluator.rs diff --git a/tests/cassettes/guardrails/pii_detector_pass.json b/tests/cassettes/guardrails/pii_detector_pass.json new file mode 100644 index 00000000..5cf34332 --- /dev/null +++ b/tests/cassettes/guardrails/pii_detector_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "pii-detector", + "input_text": "The weather is sunny today and I like programming in Rust.", + "params": {}, + "request_body": { + "input": { + "text": "The weather is sunny today and I like programming in Rust." + } + }, + "response_status": 200, + "response_body": { + "result": { + "has_pii": false + }, + "pass": true + }, + "expected_pass": true +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/profanity_detector_fail.json b/tests/cassettes/guardrails/profanity_detector_fail.json new file mode 100644 index 00000000..e0c2ac4b --- /dev/null +++ b/tests/cassettes/guardrails/profanity_detector_fail.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "profanity-detector", + "input_text": "This is damn bullshit and I think it's a total crap product.", + "params": {}, + "request_body": { + "input": { + "text": "This is damn bullshit and I think it's a total crap product." + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_safe": false + }, + "pass": false + }, + "expected_pass": false +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/prompt_injection_pass.json b/tests/cassettes/guardrails/prompt_injection_pass.json new file mode 100644 index 00000000..0824f460 --- /dev/null +++ b/tests/cassettes/guardrails/prompt_injection_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "prompt-injection", + "input_text": "What is the capital of France?", + "params": {}, + "request_body": { + "input": { + "prompt": "What is the capital of France?" + } + }, + "response_status": 200, + "response_body": { + "result": { + "has_injection": false + }, + "pass": true + }, + "expected_pass": true +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/sexism_detector_fail.json b/tests/cassettes/guardrails/sexism_detector_fail.json new file mode 100644 index 00000000..a909716e --- /dev/null +++ b/tests/cassettes/guardrails/sexism_detector_fail.json @@ -0,0 +1,23 @@ +{ + "evaluator_slug": "sexism-detector", + "input_text": "Women should not be in leadership positions because they are too emotional.", + "params": { + "threshold": 0.5 + }, + "request_body": { + "input": { + "text": "Women should not be in leadership positions because they are too emotional." + }, + "config": { + "threshold": 0.5 + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_safe": false + }, + "pass": false + }, + "expected_pass": false +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/tone_detection_fail.json b/tests/cassettes/guardrails/tone_detection_fail.json new file mode 100644 index 00000000..5e17a6ec --- /dev/null +++ b/tests/cassettes/guardrails/tone_detection_fail.json @@ -0,0 +1,19 @@ +{ + "evaluator_slug": "tone-detection", + "input_text": "This is ABSOLUTELY UNACCEPTABLE. I DEMAND to speak to someone competent immediately!", + "params": {}, + "request_body": { + "input": { + "text": "This is ABSOLUTELY UNACCEPTABLE. I DEMAND to speak to someone competent immediately!" + } + }, + "response_status": 200, + "response_body": { + "result": { + "tone": "", + "score": 0 + }, + "pass": false + }, + "expected_pass": false +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/toxicity_detector_fail.json b/tests/cassettes/guardrails/toxicity_detector_fail.json new file mode 100644 index 00000000..97001c50 --- /dev/null +++ b/tests/cassettes/guardrails/toxicity_detector_fail.json @@ -0,0 +1,23 @@ +{ + "evaluator_slug": "toxicity-detector", + "input_text": "You are a complete idiot and everyone hates you. You should be ashamed.", + "params": { + "threshold": 0.5 + }, + "request_body": { + "input": { + "text": "You are a complete idiot and everyone hates you. You should be ashamed." + }, + "config": { + "threshold": 0.5 + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_safe": false + }, + "pass": false + }, + "expected_pass": false +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/uncertainty_detector_pass.json b/tests/cassettes/guardrails/uncertainty_detector_pass.json new file mode 100644 index 00000000..629d33fd --- /dev/null +++ b/tests/cassettes/guardrails/uncertainty_detector_pass.json @@ -0,0 +1,19 @@ +{ + "evaluator_slug": "uncertainty-detector", + "input_text": "What is 2 + 2?", + "params": {}, + "request_body": { + "input": { + "prompt": "What is 2 + 2?" + } + }, + "response_status": 200, + "response_body": { + "result": { + "answer": "", + "uncertainty": 0 + }, + "pass": true + }, + "expected_pass": true +} \ No newline at end of file diff --git a/tests/guardrails/main.rs b/tests/guardrails/main.rs index 1e18de92..9ac26a4f 100644 --- a/tests/guardrails/main.rs +++ b/tests/guardrails/main.rs @@ -1,4 +1,5 @@ mod helpers; +mod test_run_evaluator; mod test_e2e; mod test_parsing; mod test_runner; diff --git a/tests/guardrails/test_run_evaluator.rs b/tests/guardrails/test_run_evaluator.rs new file mode 100644 index 00000000..15267d71 --- /dev/null +++ b/tests/guardrails/test_run_evaluator.rs @@ -0,0 +1,281 @@ +use hub_lib::guardrails::evaluator_types::get_evaluator; +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::types::*; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + +// --------------------------------------------------------------------------- +// Infrastructure +// --------------------------------------------------------------------------- + +struct EvaluatorTestCase { + slug: &'static str, + cassette_name: &'static str, + input: &'static str, + params: HashMap, + expected_pass: bool, +} + +#[derive(Deserialize)] +struct Cassette { + response_body: Value, +} + +fn load_cassette(name: &str) -> Cassette { + let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests/cassettes/guardrails") + .join(format!("{}.json", name)); + let content = fs::read_to_string(&path) + .unwrap_or_else(|_| panic!("Cassette '{}' not found at {:?}.", name, path)); + serde_json::from_str(&content).expect("Failed to parse cassette JSON") +} + +/// Set up a wiremock server, execute the guard, and verify the request was correct. +async fn run_evaluator_test(tc: &EvaluatorTestCase) { + let cassette = load_cassette(tc.cassette_name); + + // Build the expected request body using the evaluator's build_body() + let evaluator = get_evaluator(tc.slug).unwrap(); + let expected_body = evaluator.build_body(tc.input, &tc.params).unwrap(); + + // Set up wiremock with strict matchers that verify the request + let server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::path(format!( + "/v2/guardrails/execute/{}", + tc.slug + ))) + .and(matchers::header("Authorization", "Bearer test-api-key")) + .and(matchers::header("Content-Type", "application/json")) + .and(matchers::body_json(&expected_body)) + .respond_with( + ResponseTemplate::new(200).set_body_json(&cassette.response_body), + ) + .expect(1) + .mount(&server) + .await; + + // Create guard pointing at the mock server and execute + let mut guard = + create_test_guard_with_api_base(tc.slug, GuardMode::PreCall, &server.uri()); + guard.evaluator_slug = tc.slug.to_string(); + guard.params = tc.params.clone(); + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, tc.input).await.unwrap(); + + // Verify the response was interpreted correctly + assert_eq!( + result.pass, tc.expected_pass, + "{}: expected pass={}, got pass={}", + tc.slug, tc.expected_pass, result.pass + ); + // wiremock .expect(1) verifies the request matched all matchers +} + +// --------------------------------------------------------------------------- +// 1. PII Detector (text body, optional probability_threshold config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_pii_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "pii-detector", + cassette_name: "pii_detector_pass", + input: "The weather is sunny today and I like programming in Rust.", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 2. Secrets Detector (text body, no config) +// --------------------------------------------------------------------------- + +#[ignore = "secrets-detector returns HTTP 500 on current API"] +#[tokio::test] +async fn test_cassette_secrets_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "secrets-detector", + cassette_name: "secrets_detector_pass", + input: "Here is a simple function that adds two numbers together.", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 3. Prompt Injection (prompt body, optional threshold config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_prompt_injection() { + run_evaluator_test(&EvaluatorTestCase { + slug: "prompt-injection", + cassette_name: "prompt_injection_pass", + input: "What is the capital of France?", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 4. Profanity Detector (text body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_profanity_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "profanity-detector", + cassette_name: "profanity_detector_fail", + input: "This is damn bullshit and I think it's a total crap product.", + params: HashMap::new(), + expected_pass: false, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 5. Sexism Detector (text body, threshold config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_sexism_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "sexism-detector", + cassette_name: "sexism_detector_fail", + input: "Women should not be in leadership positions because they are too emotional.", + params: HashMap::from([("threshold".to_string(), json!(0.5))]), + expected_pass: false, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 6. Toxicity Detector (text body, threshold config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_toxicity_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "toxicity-detector", + cassette_name: "toxicity_detector_fail", + input: "You are a complete idiot and everyone hates you. You should be ashamed.", + params: HashMap::from([("threshold".to_string(), json!(0.5))]), + expected_pass: false, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 7. Regex Validator (text body, regex config) +// --------------------------------------------------------------------------- + +#[ignore = "regex-validator returns HTTP 500 on current API"] +#[tokio::test] +async fn test_cassette_regex_validator() { + run_evaluator_test(&EvaluatorTestCase { + slug: "regex-validator", + cassette_name: "regex_validator_pass", + input: "Order ID: ABC-12345", + params: HashMap::from([ + ("regex".to_string(), json!(r"[A-Z]{3}-\d{5}")), + ("should_match".to_string(), json!(true)), + ]), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 8. JSON Validator (text body, optional schema config) +// --------------------------------------------------------------------------- + +#[ignore = "json-validator returns HTTP 500 on current API"] +#[tokio::test] +async fn test_cassette_json_validator() { + run_evaluator_test(&EvaluatorTestCase { + slug: "json-validator", + cassette_name: "json_validator_pass", + input: r#"{"name": "Alice", "age": 30}"#, + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 9. SQL Validator (text body, no config) +// --------------------------------------------------------------------------- + +#[ignore = "sql-validator returns HTTP 500 on current API"] +#[tokio::test] +async fn test_cassette_sql_validator() { + run_evaluator_test(&EvaluatorTestCase { + slug: "sql-validator", + cassette_name: "sql_validator_pass", + input: "SELECT id, name FROM users WHERE active = true ORDER BY name", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 10. Tone Detection (text body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_tone_detection() { + run_evaluator_test(&EvaluatorTestCase { + slug: "tone-detection", + cassette_name: "tone_detection_fail", + input: "This is ABSOLUTELY UNACCEPTABLE. I DEMAND to speak to someone competent immediately!", + params: HashMap::new(), + expected_pass: false, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 11. Prompt Perplexity (prompt body, no config) +// --------------------------------------------------------------------------- + +#[ignore = "prompt-perplexity returns HTTP 500 on current API"] +#[tokio::test] +async fn test_cassette_prompt_perplexity() { + run_evaluator_test(&EvaluatorTestCase { + slug: "prompt-perplexity", + cassette_name: "prompt_perplexity_pass", + input: "Please explain the concept of photosynthesis in simple terms.", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} + +// --------------------------------------------------------------------------- +// 12. Uncertainty Detector (prompt body, no config) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cassette_uncertainty_detector() { + run_evaluator_test(&EvaluatorTestCase { + slug: "uncertainty-detector", + cassette_name: "uncertainty_detector_pass", + input: "What is 2 + 2?", + params: HashMap::new(), + expected_pass: true, + }) + .await; +} From 89f0a49c5e191c5f58018b372febf5eb31f9fe20 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 17:03:29 +0200 Subject: [PATCH 28/59] update recording --- .../guardrails/json_validator_pass.json | 25 +++++++++++++++ .../guardrails/prompt_injection_pass.json | 7 ++++- .../guardrails/prompt_perplexity_pass.json | 18 +++++++++++ .../guardrails/regex_validator_pass.json | 31 +++++++++++++++++++ .../guardrails/secrets_detector_pass.json | 18 +++++++++++ .../guardrails/sexism_detector_fail.json | 4 +-- .../guardrails/sql_validator_pass.json | 18 +++++++++++ .../guardrails/toxicity_detector_fail.json | 6 ---- tests/guardrails/test_run_evaluator.rs | 5 --- 9 files changed, 118 insertions(+), 14 deletions(-) create mode 100644 tests/cassettes/guardrails/json_validator_pass.json create mode 100644 tests/cassettes/guardrails/prompt_perplexity_pass.json create mode 100644 tests/cassettes/guardrails/regex_validator_pass.json create mode 100644 tests/cassettes/guardrails/secrets_detector_pass.json create mode 100644 tests/cassettes/guardrails/sql_validator_pass.json diff --git a/tests/cassettes/guardrails/json_validator_pass.json b/tests/cassettes/guardrails/json_validator_pass.json new file mode 100644 index 00000000..793c6aa1 --- /dev/null +++ b/tests/cassettes/guardrails/json_validator_pass.json @@ -0,0 +1,25 @@ +{ + "evaluator_slug": "json-validator", + "input_text": "{\"name\": \"Alice\", \"age\": 30}", + "params": { + "enable_schema_validation": true, + "schema_string": "{\"name\": \"string\", \"age\": \"number\"}" + }, + "request_body": { + "input": { + "text": "{\"name\": \"Alice\", \"age\": 30}" + }, + "config": { + "enable_schema_validation": true, + "schema_string": "{\"name\": \"string\", \"age\": \"number\"}" + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_valid_json": true + }, + "pass": true + }, + "expected_pass": true +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/prompt_injection_pass.json b/tests/cassettes/guardrails/prompt_injection_pass.json index 0824f460..df30fb7b 100644 --- a/tests/cassettes/guardrails/prompt_injection_pass.json +++ b/tests/cassettes/guardrails/prompt_injection_pass.json @@ -1,10 +1,15 @@ { "evaluator_slug": "prompt-injection", "input_text": "What is the capital of France?", - "params": {}, + "params": { + "threshold": 0.8 + }, "request_body": { "input": { "prompt": "What is the capital of France?" + }, + "config": { + "threshold": 0.8 } }, "response_status": 200, diff --git a/tests/cassettes/guardrails/prompt_perplexity_pass.json b/tests/cassettes/guardrails/prompt_perplexity_pass.json new file mode 100644 index 00000000..dc8ec0e2 --- /dev/null +++ b/tests/cassettes/guardrails/prompt_perplexity_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "prompt-perplexity", + "input_text": "Please explain the concept of photosynthesis in simple terms.", + "params": {}, + "request_body": { + "input": { + "prompt": "Please explain the concept of photosynthesis in simple terms." + } + }, + "response_status": 200, + "response_body": { + "result": { + "perplexity_score": 80.0 + }, + "pass": true + }, + "expected_pass": true +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/regex_validator_pass.json b/tests/cassettes/guardrails/regex_validator_pass.json new file mode 100644 index 00000000..713dc9c6 --- /dev/null +++ b/tests/cassettes/guardrails/regex_validator_pass.json @@ -0,0 +1,31 @@ +{ + "evaluator_slug": "regex-validator", + "input_text": "Order ID: ABC-12345", + "params": { + "regex": "^[A-Z]{3}-\\d{5}$", + "should_match": true, + "case_sensitive": true, + "dot_include_nl": true, + "multi_line": true + }, + "request_body": { + "input": { + "text": "Order ID: ABC-12345" + }, + "config": { + "regex": "^[A-Z]{3}-\\d{5}$", + "should_match": true, + "case_sensitive": true, + "dot_include_nl": true, + "multi_line": true + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_valid_regex": true + }, + "pass": true + }, + "expected_pass": true +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/secrets_detector_pass.json b/tests/cassettes/guardrails/secrets_detector_pass.json new file mode 100644 index 00000000..1044cb82 --- /dev/null +++ b/tests/cassettes/guardrails/secrets_detector_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "secrets-detector", + "input_text": "You are a complete idiot and everyone hates you. You should be ashamed.", + "params": {}, + "request_body": { + "input": { + "text": "You are a complete idiot and everyone hates you. You should be ashamed." + } + }, + "response_status": 200, + "response_body": { + "result": { + "has_secret": false + }, + "pass": true + }, + "expected_pass": true +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/sexism_detector_fail.json b/tests/cassettes/guardrails/sexism_detector_fail.json index a909716e..2d682c50 100644 --- a/tests/cassettes/guardrails/sexism_detector_fail.json +++ b/tests/cassettes/guardrails/sexism_detector_fail.json @@ -2,14 +2,14 @@ "evaluator_slug": "sexism-detector", "input_text": "Women should not be in leadership positions because they are too emotional.", "params": { - "threshold": 0.5 + "threshold": 0.7 }, "request_body": { "input": { "text": "Women should not be in leadership positions because they are too emotional." }, "config": { - "threshold": 0.5 + "threshold": 0.7 } }, "response_status": 200, diff --git a/tests/cassettes/guardrails/sql_validator_pass.json b/tests/cassettes/guardrails/sql_validator_pass.json new file mode 100644 index 00000000..944aef11 --- /dev/null +++ b/tests/cassettes/guardrails/sql_validator_pass.json @@ -0,0 +1,18 @@ +{ + "evaluator_slug": "sql-validator", + "input_text": "SELECT id, name FROM users WHERE active = true ORDER BY name", + "params": {}, + "request_body": { + "input": { + "text": "SELECT id, name FROM users WHERE active = true ORDER BY name" + } + }, + "response_status": 200, + "response_body": { + "result": { + "is_valid_sql": true + }, + "pass": true + }, + "expected_pass": true +} \ No newline at end of file diff --git a/tests/cassettes/guardrails/toxicity_detector_fail.json b/tests/cassettes/guardrails/toxicity_detector_fail.json index 97001c50..6bd610e6 100644 --- a/tests/cassettes/guardrails/toxicity_detector_fail.json +++ b/tests/cassettes/guardrails/toxicity_detector_fail.json @@ -1,15 +1,9 @@ { "evaluator_slug": "toxicity-detector", "input_text": "You are a complete idiot and everyone hates you. You should be ashamed.", - "params": { - "threshold": 0.5 - }, "request_body": { "input": { "text": "You are a complete idiot and everyone hates you. You should be ashamed." - }, - "config": { - "threshold": 0.5 } }, "response_status": 200, diff --git a/tests/guardrails/test_run_evaluator.rs b/tests/guardrails/test_run_evaluator.rs index 15267d71..5b3a5b8c 100644 --- a/tests/guardrails/test_run_evaluator.rs +++ b/tests/guardrails/test_run_evaluator.rs @@ -100,7 +100,6 @@ async fn test_cassette_pii_detector() { // 2. Secrets Detector (text body, no config) // --------------------------------------------------------------------------- -#[ignore = "secrets-detector returns HTTP 500 on current API"] #[tokio::test] async fn test_cassette_secrets_detector() { run_evaluator_test(&EvaluatorTestCase { @@ -181,7 +180,6 @@ async fn test_cassette_toxicity_detector() { // 7. Regex Validator (text body, regex config) // --------------------------------------------------------------------------- -#[ignore = "regex-validator returns HTTP 500 on current API"] #[tokio::test] async fn test_cassette_regex_validator() { run_evaluator_test(&EvaluatorTestCase { @@ -201,7 +199,6 @@ async fn test_cassette_regex_validator() { // 8. JSON Validator (text body, optional schema config) // --------------------------------------------------------------------------- -#[ignore = "json-validator returns HTTP 500 on current API"] #[tokio::test] async fn test_cassette_json_validator() { run_evaluator_test(&EvaluatorTestCase { @@ -218,7 +215,6 @@ async fn test_cassette_json_validator() { // 9. SQL Validator (text body, no config) // --------------------------------------------------------------------------- -#[ignore = "sql-validator returns HTTP 500 on current API"] #[tokio::test] async fn test_cassette_sql_validator() { run_evaluator_test(&EvaluatorTestCase { @@ -251,7 +247,6 @@ async fn test_cassette_tone_detection() { // 11. Prompt Perplexity (prompt body, no config) // --------------------------------------------------------------------------- -#[ignore = "prompt-perplexity returns HTTP 500 on current API"] #[tokio::test] async fn test_cassette_prompt_perplexity() { run_evaluator_test(&EvaluatorTestCase { From cd2381d402f8b645efa9a75e353ac5b12a1ef84a Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 17:06:22 +0200 Subject: [PATCH 29/59] ci tests --- tests/guardrails/main.rs | 2 +- tests/guardrails/test_e2e.rs | 1 - tests/guardrails/test_parsing.rs | 1 - tests/guardrails/test_run_evaluator.rs | 9 +++------ tests/guardrails/test_runner.rs | 1 - tests/guardrails/test_setup.rs | 1 - tests/guardrails/test_traceloop_client.rs | 1 - 7 files changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/guardrails/main.rs b/tests/guardrails/main.rs index 9ac26a4f..bbaf1496 100644 --- a/tests/guardrails/main.rs +++ b/tests/guardrails/main.rs @@ -1,7 +1,7 @@ mod helpers; -mod test_run_evaluator; mod test_e2e; mod test_parsing; +mod test_run_evaluator; mod test_runner; mod test_setup; mod test_traceloop_client; diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 7f0bdbcf..1180f1d9 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -16,7 +16,6 @@ use wiremock::{Mock, MockServer, ResponseTemplate}; use super::helpers::*; - /// Helper: set up a wiremock evaluator that returns pass/fail async fn setup_evaluator(pass: bool) -> MockServer { let server = MockServer::start().await; diff --git a/tests/guardrails/test_parsing.rs b/tests/guardrails/test_parsing.rs index 266f0547..d041f704 100644 --- a/tests/guardrails/test_parsing.rs +++ b/tests/guardrails/test_parsing.rs @@ -6,7 +6,6 @@ use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMe use super::helpers::*; - #[test] fn test_extract_text_single_user_message() { let request = create_test_chat_request("Hello world"); diff --git a/tests/guardrails/test_run_evaluator.rs b/tests/guardrails/test_run_evaluator.rs index 5b3a5b8c..3678a795 100644 --- a/tests/guardrails/test_run_evaluator.rs +++ b/tests/guardrails/test_run_evaluator.rs @@ -2,7 +2,7 @@ use hub_lib::guardrails::evaluator_types::get_evaluator; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::types::*; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::collections::HashMap; use std::fs; use std::path::PathBuf; @@ -55,16 +55,13 @@ async fn run_evaluator_test(tc: &EvaluatorTestCase) { .and(matchers::header("Authorization", "Bearer test-api-key")) .and(matchers::header("Content-Type", "application/json")) .and(matchers::body_json(&expected_body)) - .respond_with( - ResponseTemplate::new(200).set_body_json(&cassette.response_body), - ) + .respond_with(ResponseTemplate::new(200).set_body_json(&cassette.response_body)) .expect(1) .mount(&server) .await; // Create guard pointing at the mock server and execute - let mut guard = - create_test_guard_with_api_base(tc.slug, GuardMode::PreCall, &server.uri()); + let mut guard = create_test_guard_with_api_base(tc.slug, GuardMode::PreCall, &server.uri()); guard.evaluator_slug = tc.slug.to_string(); guard.params = tc.params.clone(); diff --git a/tests/guardrails/test_runner.rs b/tests/guardrails/test_runner.rs index 89653e46..f29708fb 100644 --- a/tests/guardrails/test_runner.rs +++ b/tests/guardrails/test_runner.rs @@ -9,7 +9,6 @@ use opentelemetry_sdk::trace::TracerProvider; use super::helpers::*; - #[tokio::test] async fn test_execute_single_pre_call_guard_passes() { let guard = create_test_guard("check", GuardMode::PreCall); diff --git a/tests/guardrails/test_setup.rs b/tests/guardrails/test_setup.rs index 46353e2c..3652b5f3 100644 --- a/tests/guardrails/test_setup.rs +++ b/tests/guardrails/test_setup.rs @@ -5,7 +5,6 @@ use hub_lib::guardrails::types::*; use super::helpers::*; - #[test] fn test_parse_guardrails_header_single() { let names = parse_guardrails_header("pii-check"); diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs index b0b6368b..58a3f34c 100644 --- a/tests/guardrails/test_traceloop_client.rs +++ b/tests/guardrails/test_traceloop_client.rs @@ -8,7 +8,6 @@ use wiremock::{Mock, MockServer, ResponseTemplate}; use super::helpers::*; - #[tokio::test] async fn test_traceloop_client_constructs_correct_url() { let mock_server = MockServer::start().await; From 99dff3ddaca839285bd2c94754479f2467f5a261 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 15 Feb 2026 17:24:45 +0200 Subject: [PATCH 30/59] add md --- src/guardrails/GUARDRAILS.md | 168 +++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 src/guardrails/GUARDRAILS.md diff --git a/src/guardrails/GUARDRAILS.md b/src/guardrails/GUARDRAILS.md new file mode 100644 index 00000000..7ce5962d --- /dev/null +++ b/src/guardrails/GUARDRAILS.md @@ -0,0 +1,168 @@ +# Guardrails + +Guardrails intercept LLM gateway traffic to evaluate requests and responses against configurable safety, validation, and quality checks. Each guard calls an external evaluator API and either **blocks** the request (HTTP 403) or **warns** via a response header, depending on configuration. + +> **Full reference documentation:** *(coming soon — link to external docs will be added here)* + +--- + +## How It Works + +``` + ┌──────────────┐ + Request ──────►│ Pre-call │──── Block (403) ──► Client + │ Guards │ + └──────┬───────┘ + │ pass + ▼ + ┌──────────────┐ + │ LLM Call │ + └──────┬───────┘ + │ + ▼ + ┌──────────────┐ + │ Post-call │──── Block (403) ──► Client + │ Guards │ + └──────┬───────┘ + │ pass + ▼ + Response (+ warning headers if any) +``` + +1. **Pre-call guards** run on the user's prompt *before* it reaches the LLM. +2. **Post-call guards** run on the LLM's response *before* it is returned to the client. +3. All guards in a phase execute **concurrently** for minimal latency. + +--- + +## Configuration + +Guards are defined in the gateway YAML config under `guardrails`. Provider-level defaults for `api_base` and `api_key` can be inherited by guards or overridden per-guard. + +```yaml +guardrails: + providers: + - name: traceloop + api_base: https://api.traceloop.com + api_key: ${TRACELOOP_API_KEY} + + guards: + - name: pii-check + provider: traceloop + evaluator_slug: pii-detector + mode: pre_call # pre_call | post_call + on_failure: block # block | warn + required: false # when true, evaluator errors block the request; when false, they warn and continue (default: false) + params: # evaluator-specific parameters + probability_threshold: 0.7 +``` + +Pipelines reference guards by name: + +```yaml +pipelines: + - name: default + guards: [pii-check, injection-check] + plugins: + - model-router: + models: [gpt-4] +``` + +### Runtime Guard Addition + +Guards can be added (never removed) at request time via: +- **Header:** `X-Traceloop-Guardrails: extra-guard-1, extra-guard-2` + +These are **additive** to the pipeline-configured guards, preserving the security baseline. + +--- + +## Supported Evaluators + +| Slug | Category | Configurable Params | +|---|---|---| +| `pii-detector` | Safety | `probability_threshold` | +| `secrets-detector` | Safety | — | +| `prompt-injection` | Safety | `threshold` | +| `profanity-detector` | Safety | — | +| `sexism-detector` | Safety | `threshold` | +| `toxicity-detector` | Safety | `threshold` | +| `regex-validator` | Validation | `regex`, `should_match`, `case_sensitive`, `dot_include_nl`, `multi_line` | +| `json-validator` | Validation | `enable_schema_validation`, `schema_string` | +| `sql-validator` | Validation | — | +| `tone-detection` | Quality | — | +| `prompt-perplexity` | Quality | — | +| `uncertainty-detector` | Quality | — | + +--- + +## Failure Behavior + +| Evaluation Result | `on_failure` | `required` | Action | +|---|---|---|---| +| Pass | — | — | Continue | +| Fail | `block` | — | Return 403 | +| Fail | `warn` | — | Add warning header, continue | +| Evaluator error | — | `true` | Return 403 (fail-closed) | +| Evaluator error | — | `false` | Add warning header, continue (fail-open) | + +**Blocked response (403):** +```json +{ + "error": { + "type": "guardrail_blocked", + "guardrail": "pii-check", + "message": "Request blocked by guardrail 'pii-check'", + "evaluation_result": { ... }, + "reason": "evaluation_failed" + } +} +``` + +**Warning header:** +``` +X-Traceloop-Guardrail-Warning: guardrail_name="toxicity-filter", reason="failed" +``` + +--- + +## Observability + +Each guard evaluation emits an OpenTelemetry child span with these attributes: + +| Attribute | Description | +|---|---| +| `gen_ai.guardrail.name` | Guard name | +| `gen_ai.guardrail.status` | `PASSED`, `FAILED`, or `ERROR` | +| `gen_ai.guardrail.duration` | Evaluation time in ms | +| `gen_ai.guardrail.input` | Input text (when `trace_content_enabled`) | +| `gen_ai.guardrail.error.type` | Error category (`Unavailable`, `HttpError`, `Timeout`, `ParseError`) | +| `gen_ai.guardrail.error.message` | Error details | + +--- + +## Source Layout + +``` +src/guardrails/ +├── mod.rs # Module exports +├── types.rs # Core types: Guard, GuardrailsConfig, GuardResult, GuardrailsOutcome +├── evaluator_types.rs # Evaluator slug registry and request body builders +├── parsing.rs # Input/output extraction from chat requests/responses +├── runner.rs # Execution orchestration (GuardrailsRunner, execute_guards) +├── setup.rs # Config resolution, resource building, guard merging +├── span_attributes.rs # OpenTelemetry attribute constants +└── providers/ + ├── mod.rs # GuardrailClient trait re-export + └── traceloop.rs # Traceloop evaluator API HTTP client +``` + +### Key Types + +- **`Guard`** — a single guardrail definition (name, evaluator slug, mode, failure policy, credentials) +- **`GuardrailsConfig`** — top-level config containing provider defaults and guard list +- **`Guardrails`** — per-pipeline runtime state holding shared guards + client +- **`GuardrailsRunner`** — per-request orchestrator that runs pre/post phases +- **`GuardrailClient`** (trait) — provider implementation for calling evaluator APIs +- **`EvaluatorRequest`** (trait) — evaluator-specific request body builder +- **`GuardrailsOutcome`** — aggregated results from a guard execution phase From d4efb4389978be8075dad2a85ca59d3817f9a572 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Mon, 16 Feb 2026 11:41:40 +0200 Subject: [PATCH 31/59] add validation --- src/config/validation.rs | 45 +++++++++++++++++++++++++++++++++++++++- src/guardrails/runner.rs | 2 +- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/config/validation.rs b/src/config/validation.rs index 16da3c4c..60f44c24 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -37,7 +37,15 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec } } - // Check 3: Guardrails validation + // Check 3: If any pipeline specifies guards, guardrails section must exist + let has_pipeline_guards = config.pipelines.iter().any(|p| !p.guards.is_empty()); + if has_pipeline_guards && config.guardrails.is_none() { + errors.push( + "One or more pipelines specify guards, but the 'guardrails' section is missing.".to_string() + ); + } + + // Check 4: Guardrails validation if let Some(gr_config) = &config.guardrails { // Guard provider references must exist in guardrails.providers for guard in &gr_config.guards { @@ -448,4 +456,39 @@ mod tests { .any(|e| e.contains("unknown evaluator_slug 'made-up-slug'")) ); } + + #[test] + fn test_pipeline_guards_without_guardrails_section() { + // Test the case where a pipeline specifies guards but guardrails section is missing + let config = GatewayConfig { + guardrails: None, // Missing guardrails section + general: None, + providers: vec![Provider { + key: "p1".to_string(), + r#type: ProviderType::OpenAI, + api_key: "key1".to_string(), + params: Default::default(), + }], + models: vec![ModelConfig { + key: "m1".to_string(), + r#type: "gpt-4".to_string(), + provider: "p1".to_string(), + params: Default::default(), + }], + pipelines: vec![Pipeline { + name: "pipe1".to_string(), + r#type: PipelineType::Chat, + plugins: vec![PluginConfig::ModelRouter { + models: vec!["m1".to_string()], + }], + guards: vec!["g1".to_string()], // Pipeline specifies a guard + }], + }; + let result = validate_gateway_config(&config); + assert!(result.is_err()); + let errors = result.unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].contains("pipelines specify guards")); + assert!(errors[0].contains("guardrails' section is missing")); + } } diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index ad081c1f..568a0cd7 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -302,7 +302,7 @@ impl<'a> GuardrailsRunner<'a> { let mut response = response; response .headers_mut() - .insert("X-Traceloop-Guardrail-Warning", header_val.parse().unwrap()); + .insert("x-traceloop-guardrail-warning", header_val.parse().unwrap()); response } } From 3d4743a8f05c1072eb00563e9cbb1e40e5cd5bbb Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Mon, 16 Feb 2026 11:46:10 +0200 Subject: [PATCH 32/59] added timeout --- src/guardrails/providers/traceloop.rs | 11 +++-------- src/pipelines/pipeline.rs | 5 ----- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 0698e0f0..24fdfde5 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use tracing::debug; use super::GuardrailClient; use crate::guardrails::evaluator_types::get_evaluator; @@ -7,6 +6,8 @@ use crate::guardrails::parsing::parse_evaluator_http_response; use crate::guardrails::types::{EvaluatorResponse, Guard, GuardrailError}; const DEFAULT_TRACELOOP_API: &str = "https://api.traceloop.com"; +const DEFAULT_TIMEOUT_SEC: u64 = 3; + /// HTTP client for the Traceloop evaluator API service. /// Calls `POST {api_base}/v2/guardrails/{evaluator_slug}`. pub struct TraceloopClient { @@ -21,9 +22,7 @@ impl Default for TraceloopClient { impl TraceloopClient { pub fn new() -> Self { - Self { - http_client: reqwest::Client::new(), - } + Self::with_timeout(std::time::Duration::from_secs(DEFAULT_TIMEOUT_SEC)) } pub fn with_timeout(timeout: std::time::Duration) -> Self { @@ -63,8 +62,6 @@ impl GuardrailClient for TraceloopClient { })?; let body = evaluator.build_body(input, &guard.params)?; - debug!(guard = %guard.name, slug = %guard.evaluator_slug, %url, %body, "NOMI - Calling evaluator API"); - let response = self .http_client .post(&url) @@ -77,8 +74,6 @@ impl GuardrailClient for TraceloopClient { let status = response.status().as_u16(); let response_body = response.text().await?; - debug!(guard = %guard.name, %status, %response_body, "RON - Evaluator API response"); - parse_evaluator_http_response(status, &response_body) } } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index df7691ce..63a63b75 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -156,11 +156,6 @@ pub async fn chat_completions( if let ChatCompletionResponse::NonStream(completion) = response { tracer.log_success(&completion); - tracing::debug!( - completion = %serde_json::to_string(&completion).unwrap_or_default(), - "AASA - LLM response before post-call guardrails" - ); - // Post-call guardrails (non-streaming) if let Some(orch) = &orchestrator { let post = orch.run_post_call(&completion).await; From 72e6a83f9fe9fa44068f7ed2f6d27a4668bfd3ed Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Mon, 16 Feb 2026 11:49:30 +0200 Subject: [PATCH 33/59] add header parsing --- src/guardrails/runner.rs | 17 ++++++++++--- tests/guardrails/test_runner.rs | 43 +++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index 568a0cd7..cc4d9eab 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -300,9 +300,20 @@ impl<'a> GuardrailsRunner<'a> { } let header_val = warning_header_value(warnings); let mut response = response; - response - .headers_mut() - .insert("x-traceloop-guardrail-warning", header_val.parse().unwrap()); + match header_val.parse() { + Ok(parsed_header) => { + response + .headers_mut() + .insert("x-traceloop-guardrail-warning", parsed_header); + } + Err(e) => { + warn!( + error = %e, + header_value = %header_val, + "Failed to parse guardrail warning header, skipping header" + ); + } + } response } } diff --git a/tests/guardrails/test_runner.rs b/tests/guardrails/test_runner.rs index f29708fb..5940aa73 100644 --- a/tests/guardrails/test_runner.rs +++ b/tests/guardrails/test_runner.rs @@ -372,3 +372,46 @@ async fn test_no_guard_spans_without_parent_context() { "No guard spans should be created when parent_cx is None" ); } + +// --------------------------------------------------------------------------- +// Response Finalization Tests +// --------------------------------------------------------------------------- + +#[test] +fn test_finalize_response_with_valid_warnings() { + use axum::response::IntoResponse; + let response = (axum::http::StatusCode::OK, "test body").into_response(); + let warnings = vec![GuardWarning { + guard_name: "test-guard".to_string(), + reason: "failed".to_string(), + }]; + let finalized = GuardrailsRunner::finalize_response(response, &warnings); + let headers = finalized.headers(); + assert!(headers.contains_key("x-traceloop-guardrail-warning")); +} + +#[test] +fn test_finalize_response_with_invalid_header_characters() { + use axum::response::IntoResponse; + let response = (axum::http::StatusCode::OK, "test body").into_response(); + // Guard name with newline character (invalid in HTTP headers) + let warnings = vec![GuardWarning { + guard_name: "test-guard\nwith-newline".to_string(), + reason: "failed".to_string(), + }]; + // Should not panic, header should be skipped + let finalized = GuardrailsRunner::finalize_response(response, &warnings); + let headers = finalized.headers(); + // Header should not be present since it couldn't be parsed + assert!(!headers.contains_key("x-traceloop-guardrail-warning")); +} + +#[test] +fn test_finalize_response_with_no_warnings() { + use axum::response::IntoResponse; + let response = (axum::http::StatusCode::OK, "test body").into_response(); + let warnings = vec![]; + let finalized = GuardrailsRunner::finalize_response(response, &warnings); + let headers = finalized.headers(); + assert!(!headers.contains_key("x-traceloop-guardrail-warning")); +} From 60b0a68da62cafa3a6363b0ac98e2ff89ffeb014 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Mon, 16 Feb 2026 12:11:34 +0200 Subject: [PATCH 34/59] extract prompt --- src/guardrails/parsing.rs | 4 ++-- src/guardrails/runner.rs | 2 +- src/pipelines/otel.rs | 6 +++--- tests/guardrails/test_e2e.rs | 8 ++++---- tests/guardrails/test_parsing.rs | 8 ++++---- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/guardrails/parsing.rs b/src/guardrails/parsing.rs index eef38b8a..4bf28bc1 100644 --- a/src/guardrails/parsing.rs +++ b/src/guardrails/parsing.rs @@ -6,7 +6,7 @@ use super::types::{EvaluatorResponse, GuardrailError}; /// Trait for extracting pre-call guardrail input from a request. pub trait PromptExtractor { - fn extract_pompt(&self) -> String; + fn extract_prompt(&self) -> String; } /// Trait for extracting post-call guardrail input from a response. @@ -15,7 +15,7 @@ pub trait CompletionExtractor { } impl PromptExtractor for ChatCompletionRequest { - fn extract_pompt(&self) -> String { + fn extract_prompt(&self) -> String { self.messages .iter() .filter_map(|m| { diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index cc4d9eab..50a3a04a 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -237,7 +237,7 @@ impl<'a> GuardrailsRunner<'a> { warnings: Vec::new(), }; } - let input = request.extract_pompt(); + let input = request.extract_prompt(); let outcome = execute_guards(&self.pre_call, &input, self.client, self.parent_cx.as_ref()).await; if outcome.blocked { diff --git a/src/pipelines/otel.rs b/src/pipelines/otel.rs index 5cbdd365..6e18e1c3 100644 --- a/src/pipelines/otel.rs +++ b/src/pipelines/otel.rs @@ -20,8 +20,8 @@ pub trait RecordSpan { } pub struct OtelTracer { - root_span: BoxedSpan, llm_span: Option, + root_span: BoxedSpan, accumulated_completion: Option, } @@ -96,8 +96,8 @@ impl OtelTracer { .start(&tracer); Self { - root_span: span, llm_span: None, + root_span: span, accumulated_completion: None, } } @@ -441,8 +441,8 @@ mod tests { // Test that set_vendor method compiles and can be called // This ensures the method signature is correct let mut tracer = OtelTracer { - root_span: opentelemetry::global::tracer("test").start("test"), llm_span: Some(opentelemetry::global::tracer("test").start("test_llm")), + root_span: opentelemetry::global::tracer("test").start("test"), accumulated_completion: None, }; diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 1180f1d9..29d64c1e 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -62,7 +62,7 @@ async fn test_e2e_pre_call_block_flow() { ); let request = create_test_chat_request("Bad input"); - let input = request.extract_pompt(); + let input = request.extract_prompt(); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &input, &client, None).await; @@ -84,7 +84,7 @@ async fn test_e2e_pre_call_pass_flow() { ); let request = create_test_chat_request("Safe input"); - let input = request.extract_pompt(); + let input = request.extract_prompt(); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], &input, &client, None).await; @@ -165,7 +165,7 @@ async fn test_e2e_pre_and_post_both_pass() { // Pre-call let request = create_test_chat_request("Hello"); - let input = request.extract_pompt(); + let input = request.extract_prompt(); let pre_outcome = execute_guards(&[pre_guard], &input, &client, None).await; assert!(!pre_outcome.blocked); @@ -207,7 +207,7 @@ async fn test_e2e_pre_blocks_post_never_runs() { let client = TraceloopClient::new(); let request = create_test_chat_request("Bad input"); - let input = request.extract_pompt(); + let input = request.extract_prompt(); let pre_outcome = execute_guards(&[pre_guard], &input, &client, None).await; assert!(pre_outcome.blocked); diff --git a/tests/guardrails/test_parsing.rs b/tests/guardrails/test_parsing.rs index d041f704..5dab2db8 100644 --- a/tests/guardrails/test_parsing.rs +++ b/tests/guardrails/test_parsing.rs @@ -9,7 +9,7 @@ use super::helpers::*; #[test] fn test_extract_text_single_user_message() { let request = create_test_chat_request("Hello world"); - let text = request.extract_pompt(); + let text = request.extract_prompt(); assert_eq!(text, "Hello world"); } @@ -38,7 +38,7 @@ fn test_extract_text_multi_turn_conversation() { ..default_message() }, ]; - let text = request.extract_pompt(); + let text = request.extract_prompt(); assert_eq!( text, "You are helpful\nFirst question\nFirst answer\nFollow-up question" @@ -58,7 +58,7 @@ fn test_extract_text_from_array_content_parts() { text: "Part 2".to_string(), }, ])); - let text = request.extract_pompt(); + let text = request.extract_prompt(); assert_eq!(text, "Part 1 Part 2"); } @@ -73,7 +73,7 @@ fn test_extract_response_from_chat_completion() { fn test_extract_handles_empty_content() { let mut request = create_test_chat_request(""); request.messages[0].content = None; - let text = request.extract_pompt(); + let text = request.extract_prompt(); assert_eq!(text, ""); } From 7edb582d2f1f227df0b14c5fde68194c85470728 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Mon, 16 Feb 2026 12:53:55 +0200 Subject: [PATCH 35/59] records --- tests/cassettes/guardrails/json_validator_pass.json | 2 +- tests/cassettes/guardrails/pii_detector_pass.json | 2 +- tests/cassettes/guardrails/profanity_detector_fail.json | 2 +- tests/cassettes/guardrails/prompt_injection_pass.json | 2 +- tests/cassettes/guardrails/prompt_perplexity_pass.json | 2 +- tests/cassettes/guardrails/regex_validator_pass.json | 2 +- tests/cassettes/guardrails/secrets_detector_pass.json | 2 +- tests/cassettes/guardrails/sexism_detector_fail.json | 2 +- tests/cassettes/guardrails/sql_validator_pass.json | 2 +- tests/cassettes/guardrails/tone_detection_fail.json | 2 +- tests/cassettes/guardrails/toxicity_detector_fail.json | 2 +- tests/cassettes/guardrails/uncertainty_detector_pass.json | 2 +- tests/guardrails/test_run_evaluator.rs | 5 ++++- 13 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/cassettes/guardrails/json_validator_pass.json b/tests/cassettes/guardrails/json_validator_pass.json index 793c6aa1..0a049f7c 100644 --- a/tests/cassettes/guardrails/json_validator_pass.json +++ b/tests/cassettes/guardrails/json_validator_pass.json @@ -22,4 +22,4 @@ "pass": true }, "expected_pass": true -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/pii_detector_pass.json b/tests/cassettes/guardrails/pii_detector_pass.json index 5cf34332..2bcb62b9 100644 --- a/tests/cassettes/guardrails/pii_detector_pass.json +++ b/tests/cassettes/guardrails/pii_detector_pass.json @@ -15,4 +15,4 @@ "pass": true }, "expected_pass": true -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/profanity_detector_fail.json b/tests/cassettes/guardrails/profanity_detector_fail.json index e0c2ac4b..f2fff61e 100644 --- a/tests/cassettes/guardrails/profanity_detector_fail.json +++ b/tests/cassettes/guardrails/profanity_detector_fail.json @@ -15,4 +15,4 @@ "pass": false }, "expected_pass": false -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/prompt_injection_pass.json b/tests/cassettes/guardrails/prompt_injection_pass.json index df30fb7b..67b6c61e 100644 --- a/tests/cassettes/guardrails/prompt_injection_pass.json +++ b/tests/cassettes/guardrails/prompt_injection_pass.json @@ -20,4 +20,4 @@ "pass": true }, "expected_pass": true -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/prompt_perplexity_pass.json b/tests/cassettes/guardrails/prompt_perplexity_pass.json index dc8ec0e2..5db73ae6 100644 --- a/tests/cassettes/guardrails/prompt_perplexity_pass.json +++ b/tests/cassettes/guardrails/prompt_perplexity_pass.json @@ -15,4 +15,4 @@ "pass": true }, "expected_pass": true -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/regex_validator_pass.json b/tests/cassettes/guardrails/regex_validator_pass.json index 713dc9c6..a38772aa 100644 --- a/tests/cassettes/guardrails/regex_validator_pass.json +++ b/tests/cassettes/guardrails/regex_validator_pass.json @@ -28,4 +28,4 @@ "pass": true }, "expected_pass": true -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/secrets_detector_pass.json b/tests/cassettes/guardrails/secrets_detector_pass.json index 1044cb82..de879761 100644 --- a/tests/cassettes/guardrails/secrets_detector_pass.json +++ b/tests/cassettes/guardrails/secrets_detector_pass.json @@ -15,4 +15,4 @@ "pass": true }, "expected_pass": true -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/sexism_detector_fail.json b/tests/cassettes/guardrails/sexism_detector_fail.json index 2d682c50..9473c761 100644 --- a/tests/cassettes/guardrails/sexism_detector_fail.json +++ b/tests/cassettes/guardrails/sexism_detector_fail.json @@ -20,4 +20,4 @@ "pass": false }, "expected_pass": false -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/sql_validator_pass.json b/tests/cassettes/guardrails/sql_validator_pass.json index 944aef11..fdfedfc5 100644 --- a/tests/cassettes/guardrails/sql_validator_pass.json +++ b/tests/cassettes/guardrails/sql_validator_pass.json @@ -15,4 +15,4 @@ "pass": true }, "expected_pass": true -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/tone_detection_fail.json b/tests/cassettes/guardrails/tone_detection_fail.json index 5e17a6ec..b37d5475 100644 --- a/tests/cassettes/guardrails/tone_detection_fail.json +++ b/tests/cassettes/guardrails/tone_detection_fail.json @@ -16,4 +16,4 @@ "pass": false }, "expected_pass": false -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/toxicity_detector_fail.json b/tests/cassettes/guardrails/toxicity_detector_fail.json index 6bd610e6..2b25d772 100644 --- a/tests/cassettes/guardrails/toxicity_detector_fail.json +++ b/tests/cassettes/guardrails/toxicity_detector_fail.json @@ -14,4 +14,4 @@ "pass": false }, "expected_pass": false -} \ No newline at end of file +} diff --git a/tests/cassettes/guardrails/uncertainty_detector_pass.json b/tests/cassettes/guardrails/uncertainty_detector_pass.json index 629d33fd..8b713854 100644 --- a/tests/cassettes/guardrails/uncertainty_detector_pass.json +++ b/tests/cassettes/guardrails/uncertainty_detector_pass.json @@ -16,4 +16,4 @@ "pass": true }, "expected_pass": true -} \ No newline at end of file +} diff --git a/tests/guardrails/test_run_evaluator.rs b/tests/guardrails/test_run_evaluator.rs index 3678a795..4391e389 100644 --- a/tests/guardrails/test_run_evaluator.rs +++ b/tests/guardrails/test_run_evaluator.rs @@ -184,8 +184,11 @@ async fn test_cassette_regex_validator() { cassette_name: "regex_validator_pass", input: "Order ID: ABC-12345", params: HashMap::from([ - ("regex".to_string(), json!(r"[A-Z]{3}-\d{5}")), + ("regex".to_string(), json!(r"^[A-Z]{3}-\d{5}$")), ("should_match".to_string(), json!(true)), + ("case_sensitive".to_string(), json!(true)), + ("dot_include_nl".to_string(), json!(true)), + ("multi_line".to_string(), json!(true)), ]), expected_pass: true, }) From 254bacd0659900e6b409140c82362f5ba2e450c8 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Mon, 16 Feb 2026 13:22:15 +0200 Subject: [PATCH 36/59] ci --- src/config/validation.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/config/validation.rs b/src/config/validation.rs index 60f44c24..ec5b9cbb 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -41,7 +41,8 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec let has_pipeline_guards = config.pipelines.iter().any(|p| !p.guards.is_empty()); if has_pipeline_guards && config.guardrails.is_none() { errors.push( - "One or more pipelines specify guards, but the 'guardrails' section is missing.".to_string() + "One or more pipelines specify guards, but the 'guardrails' section is missing." + .to_string(), ); } From 6244dbccea6c4a24e54044492185268a8d5bf1e3 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Mon, 16 Feb 2026 14:51:05 +0200 Subject: [PATCH 37/59] add example --- config-example.yaml | 36 ++++++++++++++++++++++++++++++++++++ src/guardrails/GUARDRAILS.md | 11 ++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/config-example.yaml b/config-example.yaml index e8143fb9..aaca3627 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -94,6 +94,9 @@ pipelines: # Default pipeline for chat completions - name: default type: chat + guards: # Optional: reference guards by name (defined in guardrails section below) + - pii-check + - injection-check plugins: - logging: level: info # Supported levels: debug, info, warning, error @@ -125,3 +128,36 @@ pipelines: - model-router: models: # List the models you want to use for embeddings - textembedding-gecko + +guardrails: + providers: + - name: traceloop + api_base: ${TRACELOOP_BASE_URL} # or use direct URL + api_key: ${TRACELOOP_API_KEY} # or use "" + guards: + # PII Detection - Pre-call guard + - name: pii-check + provider: traceloop + evaluator_slug: pii-detector + mode: pre_call # Runs before the model call + on_failure: block # Options: block, warn + required: true # If true, request fails if guard is unavailable + + # Prompt Injection Detection - Pre-call guard + - name: injection-check + provider: traceloop + evaluator_slug: prompt-injection + params: + threshold: 0.8 + mode: pre_call + on_failure: block + required: false + + # Toxicity Detection - Post-call guard + - name: toxicity-filter + provider: traceloop + evaluator_slug: toxicity-detector + params: + threshold: 0.8 + mode: post_call # Runs after the model call + on_failure: block diff --git a/src/guardrails/GUARDRAILS.md b/src/guardrails/GUARDRAILS.md index 7ce5962d..cdecae80 100644 --- a/src/guardrails/GUARDRAILS.md +++ b/src/guardrails/GUARDRAILS.md @@ -2,7 +2,16 @@ Guardrails intercept LLM gateway traffic to evaluate requests and responses against configurable safety, validation, and quality checks. Each guard calls an external evaluator API and either **blocks** the request (HTTP 403) or **warns** via a response header, depending on configuration. -> **Full reference documentation:** *(coming soon — link to external docs will be added here)* +> **Available in:** Traceloop Hub v1 +> **Full reference documentation:** [Guardrails Documentation](https://docs.traceloop.com/evaluators/guardrails) + +--- + +Guardrails can be implemented in **Config Mode (Hub v1)** : +Guardrails fully defined in YAML configuration, applied automatically to gateway requests + + +This document focuses on **config mode** available in Traceloop Hub v1. --- From 5ea42d1d10c06050bd2aef0c8627e4d4fa417446 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:08:50 +0200 Subject: [PATCH 38/59] init middle --- src/guardrails/middleware.rs | 157 +++++++++++++++++++++++++++++++++++ src/guardrails/mod.rs | 1 + src/pipelines/pipeline.rs | 45 ++-------- 3 files changed, 166 insertions(+), 37 deletions(-) create mode 100644 src/guardrails/middleware.rs diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs new file mode 100644 index 00000000..cff0bd91 --- /dev/null +++ b/src/guardrails/middleware.rs @@ -0,0 +1,157 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use axum::body::Body; +use axum::extract::Request; +use axum::response::{IntoResponse, Response}; +use tower::{Layer, Service}; +use tracing::debug; + +use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; + +use super::runner::GuardrailsRunner; +use super::types::Guardrails; + +/// Maximum request/response body size to buffer (10 MB). +const MAX_BODY_SIZE: usize = 10 * 1024 * 1024; + +/// Tower layer that applies guardrail checks around a service. +/// +/// - **Pre-call guards** run before the inner service, inspecting the request body. +/// - **Post-call guards** run after the inner service, inspecting the response body. +/// - Streaming requests (`"stream": true`) bypass guardrails entirely. +#[derive(Clone)] +pub struct GuardrailsLayer { + guardrails: Option>, +} + +impl GuardrailsLayer { + pub fn new(guardrails: Option>) -> Self { + Self { guardrails } + } +} + +impl Layer for GuardrailsLayer { + type Service = GuardrailsMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + GuardrailsMiddleware { + inner, + guardrails: self.guardrails.clone(), + } + } +} + +/// Tower service that wraps an inner service with guardrail checks. +#[derive(Clone)] +pub struct GuardrailsMiddleware { + inner: S, + guardrails: Option>, +} + +impl Service> for GuardrailsMiddleware +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, +{ + type Response = Response; + type Error = S::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let guardrails = self.guardrails.clone(); + // Clone inner and swap so the clone is used in the future + // (standard Tower pattern to satisfy borrow checker) + let inner = self.inner.clone(); + let mut inner = std::mem::replace(&mut self.inner, inner); + + Box::pin(async move { + // No guardrails configured — pass through without buffering + let guardrails = match guardrails { + Some(gr) => gr, + None => return inner.call(request).await, + }; + + let (parts, body) = request.into_parts(); + + // Buffer request body + let bytes = match axum::body::to_bytes(body, MAX_BODY_SIZE).await { + Ok(b) => b, + Err(_) => { + debug!("Guardrails middleware: failed to buffer request body, passing through"); + return Ok(axum::http::StatusCode::BAD_REQUEST.into_response()); + } + }; + + // Try to parse as ChatCompletionRequest + let chat_request: ChatCompletionRequest = match serde_json::from_slice(&bytes) { + Ok(r) => r, + Err(_) => { + // Not a chat completion request — pass through unchanged + let request = Request::from_parts(parts, Body::from(bytes)); + return inner.call(request).await; + } + }; + + // Skip guardrails for streaming requests + if chat_request.stream.unwrap_or(false) { + debug!("Guardrails middleware: streaming request, skipping guardrails"); + let request = Request::from_parts(parts, Body::from(bytes)); + return inner.call(request).await; + } + + // Resolve guards from pipeline config + request headers + let runner = GuardrailsRunner::new(Some(&guardrails), &parts.headers, None); + + let runner = match runner { + Some(r) => r, + None => { + // No active guards for this request + let request = Request::from_parts(parts, Body::from(bytes)); + return inner.call(request).await; + } + }; + + // --- Pre-call guards --- + let pre_result = runner.run_pre_call(&chat_request).await; + if let Some(blocked) = pre_result.blocked_response { + return Ok(blocked); + } + let mut all_warnings = pre_result.warnings; + + // --- Call inner service --- + let request = Request::from_parts(parts, Body::from(bytes)); + let response = inner.call(request).await?; + + // --- Post-call guards --- + let (resp_parts, resp_body) = response.into_parts(); + let resp_bytes = match axum::body::to_bytes(resp_body, MAX_BODY_SIZE).await { + Ok(b) => b, + Err(_) => { + debug!("Guardrails middleware: failed to buffer response body, skipping post-call"); + let response = Response::from_parts(resp_parts, Body::empty()); + return Ok(GuardrailsRunner::finalize_response(response, &all_warnings)); + } + }; + + if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { + let post_result = runner.run_post_call(&completion).await; + if let Some(blocked) = post_result.blocked_response { + return Ok(blocked); + } + all_warnings.extend(post_result.warnings); + } + + // Reconstruct response with original bytes and attach warning headers + let response = Response::from_parts(resp_parts, Body::from(resp_bytes)); + Ok(GuardrailsRunner::finalize_response(response, &all_warnings)) + }) + } +} diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index d31a5a0b..d496100b 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -1,4 +1,5 @@ pub mod evaluator_types; +pub mod middleware; pub mod parsing; pub mod providers; pub mod runner; diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 63a63b75..0bda5fcd 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,6 +1,6 @@ use crate::config::models::PipelineType; -use crate::guardrails::runner::GuardrailsRunner; -use crate::guardrails::types::{GuardrailResources, Guardrails}; +use crate::guardrails::middleware::GuardrailsLayer; +use crate::guardrails::types::GuardrailResources; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; use crate::models::embeddings::EmbeddingsRequest; @@ -13,7 +13,6 @@ use crate::{ models::chat::ChatCompletionRequest, }; use async_stream::stream; -use axum::http::HeaderMap; use axum::response::sse::{Event, KeepAlive}; use axum::response::{IntoResponse, Response, Sse}; use axum::{ @@ -38,7 +37,7 @@ pub fn create_pipeline( model_registry: &ModelRegistry, guardrail_resources: Option<&GuardrailResources>, ) -> Router { - let guardrails: Option> = + let guardrails = guardrail_resources.map(|shared| build_pipeline_guardrails(shared, &pipeline.guards)); let mut router = Router::new(); @@ -65,7 +64,6 @@ pub fn create_pipeline( ); for plugin in pipeline.plugins.clone() { - let gr = guardrails.clone(); router = match plugin { PluginConfig::Tracing { endpoint, api_key } => { tracing::info!("Initializing OtelTracer for pipeline {}", pipeline.name); @@ -75,9 +73,7 @@ pub fn create_pipeline( PluginConfig::ModelRouter { models } => match pipeline.r#type { PipelineType::Chat => router.route( "/chat/completions", - post(move |state, headers, payload| { - chat_completions(state, headers, payload, models, gr) - }), + post(move |state, payload| chat_completions(state, payload, models)), ), PipelineType::Completion => router.route( "/completions", @@ -92,7 +88,9 @@ pub fn create_pipeline( }; } - router.with_state(Arc::new(model_registry.clone())) + router + .with_state(Arc::new(model_registry.clone())) + .layer(GuardrailsLayer::new(guardrails)) } fn trace_and_stream( @@ -120,24 +118,10 @@ fn trace_and_stream( pub async fn chat_completions( State(model_registry): State>, - headers: HeaderMap, Json(payload): Json, model_keys: Vec, - guardrails: Option>, ) -> Result { let mut tracer = OtelTracer::start(); - let parent_cx = tracer.parent_context(); - let orchestrator = GuardrailsRunner::new(guardrails.as_deref(), &headers, Some(parent_cx)); - - // Pre-call guardrails - let mut all_warnings = Vec::new(); - if let Some(orch) = &orchestrator { - let pre = orch.run_pre_call(&payload).await; - if let Some(resp) = pre.blocked_response { - return Ok(resp); - } - all_warnings = pre.warnings; - } for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); @@ -155,20 +139,7 @@ pub async fn chat_completions( if let ChatCompletionResponse::NonStream(completion) = response { tracer.log_success(&completion); - - // Post-call guardrails (non-streaming) - if let Some(orch) = &orchestrator { - let post = orch.run_post_call(&completion).await; - if let Some(resp) = post.blocked_response { - return Ok(resp); - } - all_warnings.extend(post.warnings); - } - - return Ok(GuardrailsRunner::finalize_response( - Json(completion).into_response(), - &all_warnings, - )); + return Ok(Json(completion).into_response()); } if let ChatCompletionResponse::Stream(stream) = response { From ddfc09cd2f34e7c5f6153adcf4a1583ea8a3bf52 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:43:52 +0200 Subject: [PATCH 39/59] dhange --- src/guardrails/middleware.rs | 176 ++++++++++++++++++++++++++++------- src/guardrails/parsing.rs | 32 +++++++ 2 files changed, 174 insertions(+), 34 deletions(-) diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index cff0bd91..db45dcae 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -10,12 +10,57 @@ use tower::{Layer, Service}; use tracing::debug; use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; +use crate::models::completion::{CompletionRequest, CompletionResponse}; +use crate::models::embeddings::EmbeddingsRequest; use super::runner::GuardrailsRunner; use super::types::Guardrails; -/// Maximum request/response body size to buffer (10 MB). -const MAX_BODY_SIZE: usize = 10 * 1024 * 1024; +/// Enum representing the endpoint type. +#[derive(Debug, Clone, Copy)] +enum EndpointType { + Chat, + Completion, + Embeddings, +} + +impl EndpointType { + /// Determine endpoint type from request path. + fn from_path(path: &str) -> Option { + match path { + p if p.contains("/chat/completions") => Some(Self::Chat), + p if p.contains("/completions") => Some(Self::Completion), + p if p.contains("/embeddings") => Some(Self::Embeddings), + _ => None, + } + } +} + +/// Enum representing the type of request being processed. +enum ParsedRequest { + Chat(ChatCompletionRequest), + Completion(CompletionRequest), + Embeddings(EmbeddingsRequest), +} + +impl ParsedRequest { + /// Returns true if this is a streaming request. + fn is_streaming(&self) -> bool { + match self { + ParsedRequest::Chat(req) => req.stream.unwrap_or(false), + ParsedRequest::Completion(req) => req.stream.unwrap_or(false), + ParsedRequest::Embeddings(_) => false, + } + } + + /// Returns true if this request type supports post-call guards. + fn supports_post_call(&self) -> bool { + match self { + ParsedRequest::Chat(_) | ParsedRequest::Completion(_) => true, + ParsedRequest::Embeddings(_) => false, + } + } +} /// Tower layer that applies guardrail checks around a service. /// @@ -44,10 +89,9 @@ impl Layer for GuardrailsLayer { } } -/// Tower service that wraps an inner service with guardrail checks. #[derive(Clone)] pub struct GuardrailsMiddleware { - inner: S, + inner: S, // pipeline router guardrails: Option>, } @@ -67,8 +111,6 @@ where fn call(&mut self, request: Request) -> Self::Future { let guardrails = self.guardrails.clone(); - // Clone inner and swap so the clone is used in the future - // (standard Tower pattern to satisfy borrow checker) let inner = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, inner); @@ -81,8 +123,19 @@ where let (parts, body) = request.into_parts(); + // Determine endpoint type from path (more efficient than parsing JSON) + let endpoint_type = match EndpointType::from_path(parts.uri.path()) { + Some(t) => t, + None => { + // Unsupported endpoint — pass through + debug!("Guardrails middleware: unsupported endpoint {}, passing through", parts.uri.path()); + let request = Request::from_parts(parts, body); + return inner.call(request).await; + } + }; + // Buffer request body - let bytes = match axum::body::to_bytes(body, MAX_BODY_SIZE).await { + let bytes = match axum::body::to_bytes(body, usize::MAX).await { Ok(b) => b, Err(_) => { debug!("Guardrails middleware: failed to buffer request body, passing through"); @@ -90,18 +143,42 @@ where } }; - // Try to parse as ChatCompletionRequest - let chat_request: ChatCompletionRequest = match serde_json::from_slice(&bytes) { - Ok(r) => r, - Err(_) => { - // Not a chat completion request — pass through unchanged - let request = Request::from_parts(parts, Body::from(bytes)); - return inner.call(request).await; + // Parse request based on endpoint type + let parsed_request = match endpoint_type { + EndpointType::Chat => { + match serde_json::from_slice::(&bytes) { + Ok(req) => ParsedRequest::Chat(req), + Err(e) => { + debug!("Guardrails middleware: failed to parse chat request: {}", e); + let request = Request::from_parts(parts, Body::from(bytes)); + return inner.call(request).await; + } + } + } + EndpointType::Completion => { + match serde_json::from_slice::(&bytes) { + Ok(req) => ParsedRequest::Completion(req), + Err(e) => { + debug!("Guardrails middleware: failed to parse completion request: {}", e); + let request = Request::from_parts(parts, Body::from(bytes)); + return inner.call(request).await; + } + } + } + EndpointType::Embeddings => { + match serde_json::from_slice::(&bytes) { + Ok(req) => ParsedRequest::Embeddings(req), + Err(e) => { + debug!("Guardrails middleware: failed to parse embeddings request: {}", e); + let request = Request::from_parts(parts, Body::from(bytes)); + return inner.call(request).await; + } + } } }; // Skip guardrails for streaming requests - if chat_request.stream.unwrap_or(false) { + if parsed_request.is_streaming() { debug!("Guardrails middleware: streaming request, skipping guardrails"); let request = Request::from_parts(parts, Body::from(bytes)); return inner.call(request).await; @@ -120,7 +197,11 @@ where }; // --- Pre-call guards --- - let pre_result = runner.run_pre_call(&chat_request).await; + let pre_result = match &parsed_request { + ParsedRequest::Chat(req) => runner.run_pre_call(req).await, + ParsedRequest::Completion(req) => runner.run_pre_call(req).await, + ParsedRequest::Embeddings(req) => runner.run_pre_call(req).await, + }; if let Some(blocked) = pre_result.blocked_response { return Ok(blocked); } @@ -130,28 +211,55 @@ where let request = Request::from_parts(parts, Body::from(bytes)); let response = inner.call(request).await?; - // --- Post-call guards --- + // --- Post-call guards (only for request types that produce text) --- let (resp_parts, resp_body) = response.into_parts(); - let resp_bytes = match axum::body::to_bytes(resp_body, MAX_BODY_SIZE).await { - Ok(b) => b, - Err(_) => { - debug!("Guardrails middleware: failed to buffer response body, skipping post-call"); - let response = Response::from_parts(resp_parts, Body::empty()); - return Ok(GuardrailsRunner::finalize_response(response, &all_warnings)); - } - }; - if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { - let post_result = runner.run_post_call(&completion).await; - if let Some(blocked) = post_result.blocked_response { - return Ok(blocked); + if parsed_request.supports_post_call() { + let resp_bytes = match axum::body::to_bytes(resp_body, usize::MAX).await { + Ok(b) => b, + Err(_) => { + debug!("Guardrails middleware: failed to buffer response body, skipping post-call"); + let response = Response::from_parts(resp_parts, Body::empty()); + return Ok(GuardrailsRunner::finalize_response(response, &all_warnings)); + } + }; + + let post_result = match &parsed_request { + ParsedRequest::Chat(_) => { + if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { + Some(runner.run_post_call(&completion).await) + } else { + debug!("Guardrails middleware: failed to parse chat completion response"); + None + } + } + ParsedRequest::Completion(_) => { + if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { + Some(runner.run_post_call(&completion).await) + } else { + debug!("Guardrails middleware: failed to parse completion response"); + None + } + } + ParsedRequest::Embeddings(_) => None, + }; + + if let Some(result) = post_result { + if let Some(blocked) = result.blocked_response { + return Ok(blocked); + } + all_warnings.extend(result.warnings); } - all_warnings.extend(post_result.warnings); - } - // Reconstruct response with original bytes and attach warning headers - let response = Response::from_parts(resp_parts, Body::from(resp_bytes)); - Ok(GuardrailsRunner::finalize_response(response, &all_warnings)) + // Reconstruct response with original bytes and attach warning headers + let response = Response::from_parts(resp_parts, Body::from(resp_bytes)); + Ok(GuardrailsRunner::finalize_response(response, &all_warnings)) + } else { + // No post-call guards for this request type (e.g., embeddings) + // Pass through response with pre-call warnings attached + let response = Response::from_parts(resp_parts, resp_body); + Ok(GuardrailsRunner::finalize_response(response, &all_warnings)) + } }) } } diff --git a/src/guardrails/parsing.rs b/src/guardrails/parsing.rs index 4bf28bc1..fd984248 100644 --- a/src/guardrails/parsing.rs +++ b/src/guardrails/parsing.rs @@ -1,5 +1,7 @@ use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; +use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::content::ChatMessageContent; +use crate::models::embeddings::EmbeddingsRequest; use tracing::debug; use super::types::{EvaluatorResponse, GuardrailError}; @@ -52,6 +54,36 @@ impl CompletionExtractor for ChatCompletion { } } +impl PromptExtractor for CompletionRequest { + fn extract_prompt(&self) -> String { + self.prompt.clone() + } +} + +impl CompletionExtractor for CompletionResponse { + fn extract_completion(&self) -> String { + self.choices + .first() + .map(|choice| choice.text.clone()) + .unwrap_or_default() + } +} + +impl PromptExtractor for EmbeddingsRequest { + fn extract_prompt(&self) -> String { + match &self.input { + crate::models::embeddings::EmbeddingsInput::Single(s) => s.clone(), + crate::models::embeddings::EmbeddingsInput::Multiple(v) => v.join("\n"), + crate::models::embeddings::EmbeddingsInput::SingleTokenIds(_) => { + "[token IDs]".to_string() + } + crate::models::embeddings::EmbeddingsInput::MultipleTokenIds(_) => { + "[multiple token IDs]".to_string() + } + } + } +} + /// Parse the evaluator response body (JSON string) into an EvaluatorResponse. pub fn parse_evaluator_response(body: &str) -> Result { let response = serde_json::from_str::(body) From f5ba8fba37ed6d6e88c39b2c7c4fe862929743a8 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 10:48:17 +0200 Subject: [PATCH 40/59] added middle test --- src/guardrails/middleware.rs | 107 +++--- tests/guardrails/helpers.rs | 140 ++++++- tests/guardrails/main.rs | 1 + tests/guardrails/test_middleware.rs | 562 ++++++++++++++++++++++++++++ 4 files changed, 764 insertions(+), 46 deletions(-) create mode 100644 tests/guardrails/test_middleware.rs diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index db45dcae..91738f0e 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -13,6 +13,7 @@ use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::embeddings::EmbeddingsRequest; +use super::parsing::PromptExtractor; use super::runner::GuardrailsRunner; use super::types::Guardrails; @@ -62,6 +63,64 @@ impl ParsedRequest { } } +impl PromptExtractor for ParsedRequest { + fn extract_prompt(&self) -> String { + match self { + ParsedRequest::Chat(req) => req.extract_prompt(), + ParsedRequest::Completion(req) => req.extract_prompt(), + ParsedRequest::Embeddings(req) => req.extract_prompt(), + } + } +} + +/// Helper function to handle post-call guards for supported request types. +async fn handle_post_call_guards( + parsed_request: &ParsedRequest, + resp_parts: axum::http::response::Parts, + resp_body: Body, + runner: &GuardrailsRunner<'_>, + mut warnings: Vec, +) -> Response { + let resp_bytes = match axum::body::to_bytes(resp_body, usize::MAX).await { + Ok(b) => b, + Err(_) => { + debug!("Guardrails middleware: failed to buffer response body, skipping post-call"); + let response = Response::from_parts(resp_parts, Body::empty()); + return GuardrailsRunner::finalize_response(response, &warnings); + } + }; + + let post_result = match parsed_request { + ParsedRequest::Chat(_) => { + if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { + Some(runner.run_post_call(&completion).await) + } else { + debug!("Guardrails middleware: failed to parse chat completion response"); + None + } + } + ParsedRequest::Completion(_) => { + if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { + Some(runner.run_post_call(&completion).await) + } else { + debug!("Guardrails middleware: failed to parse completion response"); + None + } + } + ParsedRequest::Embeddings(_) => None, + }; + + if let Some(result) = post_result { + if let Some(blocked) = result.blocked_response { + return blocked; + } + warnings.extend(result.warnings); + } + + let response = Response::from_parts(resp_parts, Body::from(resp_bytes)); + GuardrailsRunner::finalize_response(response, &warnings) +} + /// Tower layer that applies guardrail checks around a service. /// /// - **Pre-call guards** run before the inner service, inspecting the request body. @@ -197,15 +256,11 @@ where }; // --- Pre-call guards --- - let pre_result = match &parsed_request { - ParsedRequest::Chat(req) => runner.run_pre_call(req).await, - ParsedRequest::Completion(req) => runner.run_pre_call(req).await, - ParsedRequest::Embeddings(req) => runner.run_pre_call(req).await, - }; + let pre_result = runner.run_pre_call(&parsed_request).await; if let Some(blocked) = pre_result.blocked_response { return Ok(blocked); } - let mut all_warnings = pre_result.warnings; + let all_warnings = pre_result.warnings; // --- Call inner service --- let request = Request::from_parts(parts, Body::from(bytes)); @@ -215,45 +270,7 @@ where let (resp_parts, resp_body) = response.into_parts(); if parsed_request.supports_post_call() { - let resp_bytes = match axum::body::to_bytes(resp_body, usize::MAX).await { - Ok(b) => b, - Err(_) => { - debug!("Guardrails middleware: failed to buffer response body, skipping post-call"); - let response = Response::from_parts(resp_parts, Body::empty()); - return Ok(GuardrailsRunner::finalize_response(response, &all_warnings)); - } - }; - - let post_result = match &parsed_request { - ParsedRequest::Chat(_) => { - if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { - Some(runner.run_post_call(&completion).await) - } else { - debug!("Guardrails middleware: failed to parse chat completion response"); - None - } - } - ParsedRequest::Completion(_) => { - if let Ok(completion) = serde_json::from_slice::(&resp_bytes) { - Some(runner.run_post_call(&completion).await) - } else { - debug!("Guardrails middleware: failed to parse completion response"); - None - } - } - ParsedRequest::Embeddings(_) => None, - }; - - if let Some(result) = post_result { - if let Some(blocked) = result.blocked_response { - return Ok(blocked); - } - all_warnings.extend(result.warnings); - } - - // Reconstruct response with original bytes and attach warning headers - let response = Response::from_parts(resp_parts, Body::from(resp_bytes)); - Ok(GuardrailsRunner::finalize_response(response, &all_warnings)) + Ok(handle_post_call_guards(&parsed_request, resp_parts, resp_body, &runner, all_warnings).await) } else { // No post-call guards for this request type (e.g., embeddings) // Pass through response with pre-call warnings attached diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index 79252ada..d314640e 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -1,12 +1,24 @@ use std::collections::HashMap; +use std::convert::Infallible; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use async_trait::async_trait; +use axum::body::Body; +use axum::extract::Request; +use axum::http::StatusCode; +use axum::response::Response; use hub_lib::guardrails::types::GuardrailClient; use hub_lib::guardrails::types::{EvaluatorResponse, Guard, GuardMode, GuardrailError, OnFailure}; use hub_lib::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; +use hub_lib::models::completion::{CompletionChoice, CompletionRequest, CompletionResponse}; use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent}; -use hub_lib::models::usage::Usage; +use hub_lib::models::embeddings::{Embedding, Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse}; +use hub_lib::models::usage::{EmbeddingUsage, Usage}; +use serde::Serialize; use serde_json::json; +use tower::Service; // --------------------------------------------------------------------------- // Guard config builders @@ -164,6 +176,132 @@ pub fn create_test_chat_completion(response_text: &str) -> ChatCompletion { } } +// --------------------------------------------------------------------------- +// Completion request/response builders +// --------------------------------------------------------------------------- + +pub fn create_test_completion_request(prompt: &str) -> CompletionRequest { + CompletionRequest { + model: "gpt-3.5-turbo-instruct".to_string(), + prompt: prompt.to_string(), + suffix: None, + max_tokens: Some(100), + temperature: None, + top_p: None, + n: None, + stream: None, + logprobs: None, + echo: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + } +} + +pub fn create_test_completion_response(text: &str) -> CompletionResponse { + CompletionResponse { + id: "cmpl-test".to_string(), + object: "text_completion".to_string(), + created: 1234567890, + model: "gpt-3.5-turbo-instruct".to_string(), + choices: vec![CompletionChoice { + text: text.to_string(), + index: 0, + logprobs: None, + finish_reason: Some("stop".to_string()), + }], + usage: Usage::default(), + } +} + +// --------------------------------------------------------------------------- +// Embeddings request/response builders +// --------------------------------------------------------------------------- + +pub fn create_test_embeddings_request(text: &str) -> EmbeddingsRequest { + EmbeddingsRequest { + model: "text-embedding-ada-002".to_string(), + input: EmbeddingsInput::Single(text.to_string()), + user: None, + encoding_format: None, + } +} + +pub fn create_test_embeddings_response() -> EmbeddingsResponse { + EmbeddingsResponse { + object: "list".to_string(), + data: vec![Embeddings { + object: "embedding".to_string(), + embedding: Embedding::Float(vec![0.1, 0.2, 0.3]), + index: 0, + }], + model: "text-embedding-ada-002".to_string(), + usage: EmbeddingUsage { + prompt_tokens: Some(8), + total_tokens: Some(8), + }, + } +} + +// --------------------------------------------------------------------------- +// Streaming request builders +// --------------------------------------------------------------------------- + +pub fn create_streaming_chat_request(message: &str) -> ChatCompletionRequest { + let mut req = create_test_chat_request(message); + req.stream = Some(true); + req +} + +pub fn create_streaming_completion_request(prompt: &str) -> CompletionRequest { + let mut req = create_test_completion_request(prompt); + req.stream = Some(true); + req +} + +// --------------------------------------------------------------------------- +// Mock Service for middleware testing +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct MockService { + status: StatusCode, + body: Vec, +} + +impl MockService { + pub fn with_json(status: StatusCode, data: &T) -> Self { + let body = serde_json::to_vec(data).unwrap(); + Self { status, body } + } +} + +impl Service> for MockService { + type Response = Response; + type Error = Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + let status = self.status; + let body = self.body.clone(); + Box::pin(async move { + let response = Response::builder() + .status(status) + .header("content-type", "application/json") + .body(Body::from(body)) + .unwrap(); + Ok(response) + }) + } +} + // --------------------------------------------------------------------------- // Mock GuardrailClient // --------------------------------------------------------------------------- diff --git a/tests/guardrails/main.rs b/tests/guardrails/main.rs index bbaf1496..dd32663e 100644 --- a/tests/guardrails/main.rs +++ b/tests/guardrails/main.rs @@ -1,5 +1,6 @@ mod helpers; mod test_e2e; +mod test_middleware; mod test_parsing; mod test_run_evaluator; mod test_runner; diff --git a/tests/guardrails/test_middleware.rs b/tests/guardrails/test_middleware.rs new file mode 100644 index 00000000..c92cae95 --- /dev/null +++ b/tests/guardrails/test_middleware.rs @@ -0,0 +1,562 @@ +use hub_lib::guardrails::middleware::GuardrailsLayer; +use hub_lib::guardrails::providers::traceloop::TraceloopClient; +use hub_lib::guardrails::types::{Guard, GuardMode, Guardrails, OnFailure}; + +use axum::body::{to_bytes, Body}; +use axum::extract::Request; +use axum::http::StatusCode; +use serde_json::json; +use std::sync::Arc; +use tower::{Layer, Service, ServiceExt}; +use wiremock::matchers; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use super::helpers::*; + +/// Helper to create a guard with a specific wiremock server +fn guard_with_server( + name: &str, + mode: GuardMode, + on_failure: OnFailure, + server_uri: &str, +) -> Guard { + Guard { + name: name.to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "toxicity-detector".to_string(), + params: Default::default(), + mode, + on_failure, + required: false, + api_base: Some(server_uri.to_string()), + api_key: Some("test-key".to_string()), + } +} + +/// Helper to create a complete guardrails configuration +fn create_guardrails(guards: Vec) -> Guardrails { + let guard_names: Vec = guards.iter().map(|g| g.name.clone()).collect(); + Guardrails { + all_guards: Arc::new(guards), + pipeline_guard_names: guard_names, + client: Arc::new(TraceloopClient::new()), + } +} + +// =========================================================================== +// Category 1: Endpoint Type Detection +// =========================================================================== + +#[tokio::test] +async fn test_chat_completions_endpoint_detected() { + // Set up mock evaluator for pre-call guard + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) + .mount(&eval_server) + .await; + + // Create guard + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + // Create mock inner service + let completion = create_test_chat_completion("Response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + // Apply middleware + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create chat request + let request = create_test_chat_request("Test input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + // Call middleware + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify response is 200 OK (guard passed) + assert_eq!(response.status(), StatusCode::OK); + + // Wiremock verifies evaluator was called (expect(1)) +} + +#[tokio::test] +async fn test_completions_endpoint_detected() { + // Set up mock evaluator + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_completion_response("Response text"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create completion request + let request = create_test_completion_request("Complete this"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +#[tokio::test] +async fn test_embeddings_endpoint_detected() { + // Set up mock evaluator + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let embeddings_response = create_test_embeddings_response(); + let inner_service = MockService::with_json(StatusCode::OK, &embeddings_response); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create embeddings request + let request = create_test_embeddings_request("Embed this text"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/embeddings") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +// =========================================================================== +// Category 2: Pre-Call Guard Behavior +// =========================================================================== + +#[tokio::test] +async fn test_pre_call_guard_blocks_chat() { + // Set up mock evaluator that fails + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "toxic content"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("This won't be returned"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Bad input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify blocked + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["error"]["guardrail"], "blocker"); +} + +#[tokio::test] +async fn test_pre_call_guard_warns_chat() { + // Set up mock evaluator that fails with warn + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "borderline"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "warner", + GuardMode::PreCall, + OnFailure::Warn, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Response text"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Borderline input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify passes with warning + assert_eq!(response.status(), StatusCode::OK); + assert!(response.headers().contains_key("x-traceloop-guardrail-warning")); + + let warning_header = response.headers() + .get("x-traceloop-guardrail-warning") + .unwrap() + .to_str() + .unwrap(); + assert!(warning_header.contains("warner")); +} + +// =========================================================================== +// Category 3: Post-Call Guard Behavior +// =========================================================================== + +#[tokio::test] +async fn test_post_call_guard_blocks_chat() { + // Set up mock evaluator for post-call that fails + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "unsafe output"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "output-blocker", + GuardMode::PostCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Unsafe response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Safe input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify blocked by post-call guard + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["error"]["guardrail"], "output-blocker"); +} + +#[tokio::test] +async fn test_post_call_guard_skipped_for_embeddings() { + // Set up mock evaluator for pre-call (should be called) + let pre_eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) // Pre-call should run + .mount(&pre_eval_server) + .await; + + // Set up mock evaluator for post-call (should NOT be called) + let post_eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": false // Would block if called + }))) + .expect(0) // Post-call should NOT run for embeddings + .mount(&post_eval_server) + .await; + + // Create both pre-call and post-call guards + let pre_guard = guard_with_server( + "pre-guard", + GuardMode::PreCall, + OnFailure::Block, + &pre_eval_server.uri(), + ); + let post_guard = guard_with_server( + "post-guard", + GuardMode::PostCall, + OnFailure::Block, + &post_eval_server.uri(), + ); + let guardrails = create_guardrails(vec![pre_guard, post_guard]); + + let embeddings_response = create_test_embeddings_response(); + let inner_service = MockService::with_json(StatusCode::OK, &embeddings_response); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_embeddings_request("Embed this"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/embeddings") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify response is 200 OK (no post-call blocking) + assert_eq!(response.status(), StatusCode::OK); + + // Verify no warning header + assert!(!response.headers().contains_key("x-traceloop-guardrail-warning")); + + // Verify response body contains embeddings + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["object"], "list"); + + // Wiremock verifies: + // - pre-call evaluator called exactly once (expect(1)) + // - post-call evaluator never called (expect(0)) +} + +// =========================================================================== +// Category 4: Streaming Behavior +// =========================================================================== + +#[tokio::test] +async fn test_streaming_chat_bypasses_guards() { + // Set up mock evaluator (should never be called) + let eval_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": false // Would block if evaluated + }))) + .expect(0) // Should never be called for streaming + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Streamed response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create STREAMING chat request + let request = create_streaming_chat_request("This would fail guards if checked"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify response is 200 OK (streaming bypasses guards) + assert_eq!(response.status(), StatusCode::OK); + + // Verify no warning header + assert!(!response.headers().contains_key("x-traceloop-guardrail-warning")); + + // Wiremock verifies evaluator was never called (expect(0)) +} + +#[tokio::test] +async fn test_streaming_completion_bypasses_guards() { + // Set up mock evaluator (should never be called) + let eval_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": false // Would block if evaluated + }))) + .expect(0) // Should never be called for streaming + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_completion_response("Streamed completion"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create STREAMING completion request + let request = create_streaming_completion_request("This would fail guards if checked"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify response is 200 OK (streaming bypasses guards) + assert_eq!(response.status(), StatusCode::OK); + + // Verify no warning header + assert!(!response.headers().contains_key("x-traceloop-guardrail-warning")); + + // Wiremock verifies evaluator was never called (expect(0)) +} + +// =========================================================================== +// Category 5: Pass-Through Scenarios +// =========================================================================== + +#[tokio::test] +async fn test_no_guardrails_configured_passes() { + // No guardrails configured + let completion = create_test_chat_completion("Response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(None); // No guardrails + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Any input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify response passes through unchanged + assert_eq!(response.status(), StatusCode::OK); +} + +#[tokio::test] +async fn test_unsupported_endpoint_passes() { + let eval_server = MockServer::start().await; + Mock::given(matchers::any()) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": false + }))) + .expect(0) // Should never be called + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let inner_service = MockService::with_json( + StatusCode::OK, + &json!({"data": [{"id": "model-1"}]}), + ); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Request to unsupported endpoint + let http_request = Request::builder() + .method("GET") + .uri("/v1/models") // Unsupported endpoint + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + + // Verify passes through + assert_eq!(response.status(), StatusCode::OK); +} From 6e2fc9221dc8d18c12d3fb59090e157e9c27805f Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:00:01 +0200 Subject: [PATCH 41/59] require api key --- src/guardrails/providers/traceloop.rs | 10 +++++- tests/guardrails/test_traceloop_client.rs | 38 +++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 24fdfde5..4ffbf214 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -52,7 +52,15 @@ impl GuardrailClient for TraceloopClient { api_base, guard.evaluator_slug ); - let api_key = guard.api_key.as_deref().unwrap_or(""); + let api_key = guard + .api_key + .as_deref() + .filter(|k| !k.is_empty()) + .ok_or_else(|| { + GuardrailError::Unavailable( + "Traceloop API key is required but not provided".to_string(), + ) + })?; let evaluator = get_evaluator(&guard.evaluator_slug).ok_or_else(|| { GuardrailError::Unavailable(format!( diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs index 58a3f34c..e15e5a77 100644 --- a/tests/guardrails/test_traceloop_client.rs +++ b/tests/guardrails/test_traceloop_client.rs @@ -115,6 +115,44 @@ async fn test_traceloop_client_handles_timeout() { assert!(result.is_err()); } +#[tokio::test] +async fn test_traceloop_client_rejects_missing_api_key() { + let mock_server = MockServer::start().await; + let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.api_key = None; + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input").await; + + assert!(result.is_err()); + if let Err(e) = result { + assert!( + e.to_string().contains("API key is required"), + "Expected error message about missing API key, got: {}", + e + ); + } +} + +#[tokio::test] +async fn test_traceloop_client_rejects_empty_api_key() { + let mock_server = MockServer::start().await; + let mut guard = create_test_guard_with_api_base("test", GuardMode::PreCall, &mock_server.uri()); + guard.api_key = Some("".to_string()); + + let client = TraceloopClient::new(); + let result = client.evaluate(&guard, "test input").await; + + assert!(result.is_err()); + if let Err(e) = result { + assert!( + e.to_string().contains("API key is required"), + "Expected error message about missing API key, got: {}", + e + ); + } +} + #[test] fn test_client_creation_from_guard_config() { let guard = create_test_guard("test", GuardMode::PreCall); From 1d866581306077d0d536bd98ebf1d3e41e71f3cc Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:07:26 +0200 Subject: [PATCH 42/59] more comments --- src/guardrails/providers/traceloop.rs | 2 +- src/pipelines/pipeline.rs | 34 ++++++++++++++++++--------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 4ffbf214..7ca19c1e 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -30,7 +30,7 @@ impl TraceloopClient { http_client: reqwest::Client::builder() .timeout(timeout) .build() - .unwrap_or_default(), + .expect("Failed to build HTTP client for Traceloop"), } } } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 0bda5fcd..5d0ccde8 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -130,12 +130,14 @@ pub async fn chat_completions( tracer.start_llm_span("chat", &payload); tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); - let response = model - .chat_completions(payload.clone()) - .await - .inspect_err(|e| { + let response = match model.chat_completions(payload.clone()).await { + Ok(response) => response, + Err(e) => { eprintln!("Chat completion error for model {model_key}: {e:?}"); - })?; + tracer.log_error(format!("Chat completion failed: {e:?}")); + return Err(e); + } + }; if let ChatCompletionResponse::NonStream(completion) = response { tracer.log_success(&completion); @@ -169,9 +171,14 @@ pub async fn completions( tracer.start_llm_span("completion", &payload); tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); - let response = model.completions(payload.clone()).await.inspect_err(|e| { - eprintln!("Completion error for model {model_key}: {e:?}"); - })?; + let response = match model.completions(payload.clone()).await { + Ok(response) => response, + Err(e) => { + eprintln!("Completion error for model {model_key}: {e:?}"); + tracer.log_error(format!("Completion failed: {e:?}")); + return Err(e); + } + }; tracer.log_success(&response); return Ok(Json(response).into_response()); @@ -197,9 +204,14 @@ pub async fn embeddings( tracer.start_llm_span("embeddings", &payload); tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); - let response = model.embeddings(payload.clone()).await.inspect_err(|e| { - eprintln!("Embeddings error for model {model_key}: {e:?}"); - })?; + let response = match model.embeddings(payload.clone()).await { + Ok(response) => response, + Err(e) => { + eprintln!("Embeddings error for model {model_key}: {e:?}"); + tracer.log_error(format!("Embeddings failed: {e:?}")); + return Err(e); + } + }; tracer.log_success(&response); return Ok(Json(response)); } From 96542957d35b365922cd8cecf7cfd867da83052e Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:28:00 +0200 Subject: [PATCH 43/59] comm --- src/config/validation.rs | 42 ++--- src/guardrails/evaluator_types.rs | 86 +++------- tests/guardrails/helpers.rs | 73 +++++--- tests/guardrails/test_e2e.rs | 272 ++++++++++++------------------ 4 files changed, 195 insertions(+), 278 deletions(-) diff --git a/src/config/validation.rs b/src/config/validation.rs index ec5b9cbb..63e49314 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -48,31 +48,18 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec // Check 4: Guardrails validation if let Some(gr_config) = &config.guardrails { - // Guard provider references must exist in guardrails.providers + // Validate all guards in a single pass + let mut seen_guard_names = HashSet::new(); for guard in &gr_config.guards { + // Check provider reference exists if !gr_config.providers.contains_key(&guard.provider) { errors.push(format!( "Guard '{}' references non-existent guardrail provider '{}'.", guard.name, guard.provider )); } - } - // Pipeline guard references must exist in guardrails.guards - let guard_names: HashSet<&String> = gr_config.guards.iter().map(|g| &g.name).collect(); - for pipeline in &config.pipelines { - for guard_name in &pipeline.guards { - if !guard_names.contains(guard_name) { - errors.push(format!( - "Pipeline '{}' references non-existent guard '{}'.", - pipeline.name, guard_name - )); - } - } - } - - // Guards must have api_base and api_key (either directly or via provider) - for guard in &gr_config.guards { + // Check api_base and api_key (either directly or via provider) let has_api_base = guard.api_base.as_ref().is_some_and(|s| !s.is_empty()) || gr_config .providers @@ -96,25 +83,32 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec guard.name, guard.provider )); } - } - // Evaluator slugs must be recognised - for guard in &gr_config.guards { + // Check evaluator slug is recognized if crate::guardrails::evaluator_types::get_evaluator(&guard.evaluator_slug).is_none() { errors.push(format!( "Guard '{}' has unknown evaluator_slug '{}'.", guard.name, guard.evaluator_slug )); } - } - // Guard names must be unique - let mut seen_guard_names = HashSet::new(); - for guard in &gr_config.guards { + // Check for duplicate guard names if !seen_guard_names.insert(&guard.name) { errors.push(format!("Duplicate guard name: '{}'.", guard.name)); } } + + // Pipeline guard references must exist in guardrails.guards + for pipeline in &config.pipelines { + for guard_name in &pipeline.guards { + if !seen_guard_names.contains(guard_name) { + errors.push(format!( + "Pipeline '{}' references non-existent guard '{}'.", + pipeline.name, guard_name + )); + } + } + } } if errors.is_empty() { diff --git a/src/guardrails/evaluator_types.rs b/src/guardrails/evaluator_types.rs index 3937e8aa..85a9944f 100644 --- a/src/guardrails/evaluator_types.rs +++ b/src/guardrails/evaluator_types.rs @@ -107,6 +107,21 @@ macro_rules! evaluator_with_no_config { }; } +macro_rules! evaluator_with_config { + ($name:ident, $body_fn:ident, $config:ty, $slug:expr) => { + pub struct $name; + impl EvaluatorRequest for $name { + fn build_body( + &self, + input: &str, + params: &HashMap, + ) -> Result { + attach_config::<$config>($body_fn(input), params, $slug) + } + } + }; +} + evaluator_with_no_config!(SecretsDetector, text_body); evaluator_with_no_config!(ProfanityDetector, text_body); evaluator_with_no_config!(SqlValidator, text_body); @@ -156,68 +171,9 @@ pub struct JsonValidatorConfig { // Evaluators with config // --------------------------------------------------------------------------- -pub struct PiiDetector; -impl EvaluatorRequest for PiiDetector { - fn build_body( - &self, - input: &str, - params: &HashMap, - ) -> Result { - attach_config::(text_body(input), params, PII_DETECTOR) - } -} - -pub struct PromptInjection; -impl EvaluatorRequest for PromptInjection { - fn build_body( - &self, - input: &str, - params: &HashMap, - ) -> Result { - attach_config::(prompt_body(input), params, PROMPT_INJECTION) - } -} - -pub struct SexismDetector; -impl EvaluatorRequest for SexismDetector { - fn build_body( - &self, - input: &str, - params: &HashMap, - ) -> Result { - attach_config::(text_body(input), params, SEXISM_DETECTOR) - } -} - -pub struct ToxicityDetector; -impl EvaluatorRequest for ToxicityDetector { - fn build_body( - &self, - input: &str, - params: &HashMap, - ) -> Result { - attach_config::(text_body(input), params, TOXICITY_DETECTOR) - } -} - -pub struct RegexValidator; -impl EvaluatorRequest for RegexValidator { - fn build_body( - &self, - input: &str, - params: &HashMap, - ) -> Result { - attach_config::(text_body(input), params, REGEX_VALIDATOR) - } -} - -pub struct JsonValidator; -impl EvaluatorRequest for JsonValidator { - fn build_body( - &self, - input: &str, - params: &HashMap, - ) -> Result { - attach_config::(text_body(input), params, JSON_VALIDATOR) - } -} +evaluator_with_config!(PiiDetector, text_body, PiiDetectorConfig, PII_DETECTOR); +evaluator_with_config!(PromptInjection, prompt_body, ThresholdConfig, PROMPT_INJECTION); +evaluator_with_config!(SexismDetector, text_body, ThresholdConfig, SEXISM_DETECTOR); +evaluator_with_config!(ToxicityDetector, text_body, ThresholdConfig, TOXICITY_DETECTOR); +evaluator_with_config!(RegexValidator, text_body, RegexValidatorConfig, REGEX_VALIDATOR); +evaluator_with_config!(JsonValidator, text_body, JsonValidatorConfig, JSON_VALIDATOR); diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index d314640e..502d6905 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -24,18 +24,55 @@ use tower::Service; // Guard config builders // --------------------------------------------------------------------------- -pub fn create_test_guard(name: &str, mode: GuardMode) -> Guard { - Guard { - name: name.to_string(), - provider: "traceloop".to_string(), - evaluator_slug: "pii-detector".to_string(), - params: HashMap::new(), - mode, - on_failure: OnFailure::Block, - required: false, - api_base: Some("http://localhost:8080".to_string()), - api_key: Some("test-api-key".to_string()), +pub struct TestGuardBuilder { + guard: Guard, +} + +impl TestGuardBuilder { + pub fn new(name: &str, mode: GuardMode) -> Self { + Self { + guard: Guard { + name: name.to_string(), + provider: "traceloop".to_string(), + evaluator_slug: "pii-detector".to_string(), + params: HashMap::new(), + mode, + on_failure: OnFailure::Block, + required: false, + api_base: Some("http://localhost:8080".to_string()), + api_key: Some("test-api-key".to_string()), + }, + } + } + + pub fn on_failure(mut self, on_failure: OnFailure) -> Self { + self.guard.on_failure = on_failure; + self + } + + pub fn required(mut self, required: bool) -> Self { + self.guard.required = required; + self } + + pub fn api_base(mut self, api_base: &str) -> Self { + self.guard.api_base = Some(api_base.to_string()); + self + } + + pub fn evaluator_slug(mut self, slug: &str) -> Self { + self.guard.evaluator_slug = slug.to_string(); + self + } + + pub fn build(self) -> Guard { + self.guard + } +} + +// Backward-compatible helper functions +pub fn create_test_guard(name: &str, mode: GuardMode) -> Guard { + TestGuardBuilder::new(name, mode).build() } pub fn create_test_guard_with_failure_action( @@ -43,22 +80,15 @@ pub fn create_test_guard_with_failure_action( mode: GuardMode, on_failure: OnFailure, ) -> Guard { - let mut guard = create_test_guard(name, mode); - guard.on_failure = on_failure; - guard + TestGuardBuilder::new(name, mode).on_failure(on_failure).build() } pub fn create_test_guard_with_required(name: &str, mode: GuardMode, required: bool) -> Guard { - let mut guard = create_test_guard(name, mode); - guard.required = required; - guard + TestGuardBuilder::new(name, mode).required(required).build() } -#[allow(dead_code)] pub fn create_test_guard_with_api_base(name: &str, mode: GuardMode, api_base: &str) -> Guard { - let mut guard = create_test_guard(name, mode); - guard.api_base = Some(api_base.to_string()); - guard + TestGuardBuilder::new(name, mode).api_base(api_base).build() } // --------------------------------------------------------------------------- @@ -94,7 +124,6 @@ pub fn default_message() -> ChatCompletionMessage { } } -#[allow(dead_code)] pub fn default_request() -> ChatCompletionRequest { ChatCompletionRequest { model: "gpt-4".to_string(), diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 29d64c1e..87b5cd8c 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -29,37 +29,15 @@ async fn setup_evaluator(pass: bool) -> MockServer { server } -fn guard_with_server( - name: &str, - mode: GuardMode, - on_failure: OnFailure, - server_uri: &str, - slug: &str, -) -> Guard { - Guard { - name: name.to_string(), - provider: "traceloop".to_string(), - evaluator_slug: slug.to_string(), - params: Default::default(), - mode, - on_failure, - required: false, - api_base: Some(server_uri.to_string()), - api_key: Some("test-key".to_string()), - } -} - #[tokio::test] async fn test_e2e_pre_call_block_flow() { // Request -> guard fail+block -> 403 let eval = setup_evaluator(false).await; - let guard = guard_with_server( - "blocker", - GuardMode::PreCall, - OnFailure::Block, - &eval.uri(), - "toxicity-detector", - ); + let guard = TestGuardBuilder::new("blocker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("toxicity-detector") + .build(); let request = create_test_chat_request("Bad input"); let input = request.extract_prompt(); @@ -75,13 +53,11 @@ async fn test_e2e_pre_call_block_flow() { async fn test_e2e_pre_call_pass_flow() { // Request -> guard pass -> LLM -> 200 let eval = setup_evaluator(true).await; - let guard = guard_with_server( - "checker", - GuardMode::PreCall, - OnFailure::Block, - &eval.uri(), - "toxicity-detector", - ); + let guard = TestGuardBuilder::new("checker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("toxicity-detector") + .build(); let request = create_test_chat_request("Safe input"); let input = request.extract_prompt(); @@ -98,13 +74,11 @@ async fn test_e2e_pre_call_pass_flow() { async fn test_e2e_post_call_block_flow() { // Request -> LLM -> guard fail+block -> 403 let eval = setup_evaluator(false).await; - let guard = guard_with_server( - "pii-check", - GuardMode::PostCall, - OnFailure::Block, - &eval.uri(), - "pii-detector", - ); + let guard = TestGuardBuilder::new("pii-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("pii-detector") + .build(); // Simulate LLM response let completion = create_test_chat_completion("Here is the SSN: 123-45-6789"); @@ -121,13 +95,11 @@ async fn test_e2e_post_call_block_flow() { async fn test_e2e_post_call_warn_flow() { // Request -> LLM -> guard fail+warn -> 200 + header let eval = setup_evaluator(false).await; - let guard = guard_with_server( - "tone-check", - GuardMode::PostCall, - OnFailure::Warn, - &eval.uri(), - "tone-detection", - ); + let guard = TestGuardBuilder::new("tone-check", GuardMode::PostCall) + .on_failure(OnFailure::Warn) + .api_base(&eval.uri()) + .evaluator_slug("tone-detection") + .build(); let completion = create_test_chat_completion("Mildly concerning response"); let response_text = completion.extract_completion(); @@ -146,20 +118,16 @@ async fn test_e2e_pre_and_post_both_pass() { let pre_eval = setup_evaluator(true).await; let post_eval = setup_evaluator(true).await; - let pre_guard = guard_with_server( - "pre-check", - GuardMode::PreCall, - OnFailure::Block, - &pre_eval.uri(), - "profanity-detector", - ); - let post_guard = guard_with_server( - "post-check", - GuardMode::PostCall, - OnFailure::Block, - &post_eval.uri(), - "pii-detector", - ); + let pre_guard = TestGuardBuilder::new("pre-check", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&pre_eval.uri()) + .evaluator_slug("profanity-detector") + .build(); + let post_guard = TestGuardBuilder::new("post-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&post_eval.uri()) + .evaluator_slug("pii-detector") + .build(); let client = TraceloopClient::new(); @@ -190,20 +158,16 @@ async fn test_e2e_pre_blocks_post_never_runs() { .mount(&post_eval) .await; - let pre_guard = guard_with_server( - "blocker", - GuardMode::PreCall, - OnFailure::Block, - &pre_eval.uri(), - "toxicity-detector", - ); - let post_guard = guard_with_server( - "post-check", - GuardMode::PostCall, - OnFailure::Block, - &post_eval.uri(), - "pii-detector", - ); + let pre_guard = TestGuardBuilder::new("blocker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&pre_eval.uri()) + .evaluator_slug("toxicity-detector") + .build(); + let post_guard = TestGuardBuilder::new("post-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&post_eval.uri()) + .evaluator_slug("pii-detector") + .build(); let client = TraceloopClient::new(); let request = create_test_chat_request("Bad input"); @@ -225,27 +189,21 @@ async fn test_e2e_mixed_block_and_warn() { let eval3 = setup_evaluator(false).await; // fails -> block let guards = vec![ - guard_with_server( - "passer", - GuardMode::PreCall, - OnFailure::Block, - &eval1.uri(), - "profanity-detector", - ), - guard_with_server( - "warner", - GuardMode::PreCall, - OnFailure::Warn, - &eval2.uri(), - "tone-detection", - ), - guard_with_server( - "blocker", - GuardMode::PreCall, - OnFailure::Block, - &eval3.uri(), - "toxicity-detector", - ), + TestGuardBuilder::new("passer", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&eval1.uri()) + .evaluator_slug("profanity-detector") + .build(), + TestGuardBuilder::new("warner", GuardMode::PreCall) + .on_failure(OnFailure::Warn) + .api_base(&eval2.uri()) + .evaluator_slug("tone-detection") + .build(), + TestGuardBuilder::new("blocker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&eval3.uri()) + .evaluator_slug("toxicity-detector") + .build(), ]; let client = TraceloopClient::new(); @@ -260,13 +218,11 @@ async fn test_e2e_mixed_block_and_warn() { async fn test_e2e_streaming_post_call_buffer_pass() { // Stream buffered, guard passes -> SSE response streamed to client let eval = setup_evaluator(true).await; - let guard = guard_with_server( - "response-check", - GuardMode::PostCall, - OnFailure::Block, - &eval.uri(), - "profanity-detector", - ); + let guard = TestGuardBuilder::new("response-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("profanity-detector") + .build(); let accumulated = "Hello world!"; @@ -280,13 +236,11 @@ async fn test_e2e_streaming_post_call_buffer_pass() { async fn test_e2e_streaming_post_call_buffer_block() { // Stream buffered, guard blocks -> 403 let eval = setup_evaluator(false).await; - let guard = guard_with_server( - "pii-check", - GuardMode::PostCall, - OnFailure::Block, - &eval.uri(), - "pii-detector", - ); + let guard = TestGuardBuilder::new("pii-check", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&eval.uri()) + .evaluator_slug("pii-detector") + .build(); let accumulated = "Here is SSN: 123-45-6789"; @@ -409,20 +363,16 @@ async fn test_e2e_multiple_guards_different_evaluators() { .await; let guards = vec![ - guard_with_server( - "tox-guard", - GuardMode::PreCall, - OnFailure::Block, - &server.uri(), - "toxicity-detector", - ), - guard_with_server( - "pii-guard", - GuardMode::PreCall, - OnFailure::Block, - &server.uri(), - "pii-detector", - ), + TestGuardBuilder::new("tox-guard", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("toxicity-detector") + .build(), + TestGuardBuilder::new("pii-guard", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("pii-detector") + .build(), ]; let client = TraceloopClient::new(); @@ -442,13 +392,11 @@ async fn test_e2e_fail_open_evaluator_down() { .mount(&server) .await; - let mut guard = guard_with_server( - "checker", - GuardMode::PreCall, - OnFailure::Block, - &server.uri(), - "profanity-detector", - ); + let mut guard = TestGuardBuilder::new("checker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("profanity-detector") + .build(); guard.required = false; // fail-open let client = TraceloopClient::new(); @@ -466,13 +414,11 @@ async fn test_e2e_fail_closed_evaluator_down() { .mount(&server) .await; - let mut guard = guard_with_server( - "checker", - GuardMode::PreCall, - OnFailure::Block, - &server.uri(), - "profanity-detector", - ); + let mut guard = TestGuardBuilder::new("checker", GuardMode::PreCall) + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("profanity-detector") + .build(); guard.required = true; // fail-closed let client = TraceloopClient::new(); @@ -548,13 +494,11 @@ async fn test_pre_call_guardrails_warn_and_continue() { .mount(&eval_server) .await; - let guard = guard_with_server( - "tone-check", - GuardMode::PreCall, - OnFailure::Warn, - &eval_server.uri(), - "tone-detection", - ); + let guard = TestGuardBuilder::new("tone-check", GuardMode::PreCall) + .on_failure(OnFailure::Warn) + .api_base(&eval_server.uri()) + .evaluator_slug("tone-detection") + .build(); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], "borderline input", &client, None).await; @@ -576,13 +520,11 @@ async fn test_post_call_guardrails_warn_and_add_header() { .mount(&eval_server) .await; - let guard = guard_with_server( - "safety-check", - GuardMode::PostCall, - OnFailure::Warn, - &eval_server.uri(), - "pii-detector", - ); + let guard = TestGuardBuilder::new("safety-check", GuardMode::PostCall) + .on_failure(OnFailure::Warn) + .api_base(&eval_server.uri()) + .evaluator_slug("pii-detector") + .build(); let client = TraceloopClient::new(); let outcome = execute_guards(&[guard], "Some LLM response", &client, None).await; @@ -644,13 +586,11 @@ async fn test_post_call_skipped_on_empty_response() { .mount(&eval_server) .await; - let guard = guard_with_server( - "toxicity-filter", - GuardMode::PostCall, - OnFailure::Block, - &eval_server.uri(), - "toxicity-detector", - ); + let guard = TestGuardBuilder::new("toxicity-filter", GuardMode::PostCall) + .on_failure(OnFailure::Block) + .api_base(&eval_server.uri()) + .evaluator_slug("toxicity-detector") + .build(); let guardrails = Guardrails { all_guards: Arc::new(vec![guard]), @@ -684,13 +624,11 @@ async fn test_evaluator_error_not_blocked_by_default() { .mount(&server) .await; - let guard = guard_with_server( - "warn-guard", - GuardMode::PreCall, - OnFailure::Warn, - &server.uri(), - "profanity-detector", - ); + let guard = TestGuardBuilder::new("warn-guard", GuardMode::PreCall) + .on_failure(OnFailure::Warn) + .api_base(&server.uri()) + .evaluator_slug("profanity-detector") + .build(); // guard.required is false by default let client = TraceloopClient::new(); From 865672652830a56a1f9b279b33ab90cefeb50795 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:53:25 +0200 Subject: [PATCH 44/59] doc --- src/guardrails/GUARDRAILS.md | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/src/guardrails/GUARDRAILS.md b/src/guardrails/GUARDRAILS.md index cdecae80..edf75dcc 100644 --- a/src/guardrails/GUARDRAILS.md +++ b/src/guardrails/GUARDRAILS.md @@ -42,6 +42,12 @@ This document focuses on **config mode** available in Traceloop Hub v1. 2. **Post-call guards** run on the LLM's response *before* it is returned to the client. 3. All guards in a phase execute **concurrently** for minimal latency. +**Supported Routes:** +- ✅ Chat Completions (`/v1/chat/completions`) — pre-call and post-call guards +- ✅ Completions (`/v1/completions`) — pre-call and post-call guards +- ✅ Embeddings (`/v1/embeddings`) — **pre-call guards only** +- ❌ Streaming requests are **not supported** — guardrails require the complete request/response for evaluation + --- ## Configuration @@ -150,28 +156,6 @@ Each guard evaluation emits an OpenTelemetry child span with these attributes: --- -## Source Layout - -``` -src/guardrails/ -├── mod.rs # Module exports -├── types.rs # Core types: Guard, GuardrailsConfig, GuardResult, GuardrailsOutcome -├── evaluator_types.rs # Evaluator slug registry and request body builders -├── parsing.rs # Input/output extraction from chat requests/responses -├── runner.rs # Execution orchestration (GuardrailsRunner, execute_guards) -├── setup.rs # Config resolution, resource building, guard merging -├── span_attributes.rs # OpenTelemetry attribute constants -└── providers/ - ├── mod.rs # GuardrailClient trait re-export - └── traceloop.rs # Traceloop evaluator API HTTP client -``` - -### Key Types +## Implementation -- **`Guard`** — a single guardrail definition (name, evaluator slug, mode, failure policy, credentials) -- **`GuardrailsConfig`** — top-level config containing provider defaults and guard list -- **`Guardrails`** — per-pipeline runtime state holding shared guards + client -- **`GuardrailsRunner`** — per-request orchestrator that runs pre/post phases -- **`GuardrailClient`** (trait) — provider implementation for calling evaluator APIs -- **`EvaluatorRequest`** (trait) — evaluator-specific request body builder -- **`GuardrailsOutcome`** — aggregated results from a guard execution phase +See `src/guardrails/mod.rs` for module structure and key type definitions. From f6b83a73916d11f38035e9a6400aa0f1b6160d80 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 12:17:55 +0200 Subject: [PATCH 45/59] fix tracing --- src/guardrails/middleware.rs | 8 +++- src/pipelines/mod.rs | 3 +- src/pipelines/otel.rs | 15 +++++++ src/pipelines/pipeline.rs | 68 ++++++++++++++++------------- src/pipelines/tracing_middleware.rs | 68 +++++++++++++++++++++++++++++ 5 files changed, 130 insertions(+), 32 deletions(-) create mode 100644 src/pipelines/tracing_middleware.rs diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index 91738f0e..1142345a 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -12,6 +12,7 @@ use tracing::debug; use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::embeddings::EmbeddingsRequest; +use crate::pipelines::otel::SharedTracer; use super::parsing::PromptExtractor; use super::runner::GuardrailsRunner; @@ -244,7 +245,12 @@ where } // Resolve guards from pipeline config + request headers - let runner = GuardrailsRunner::new(Some(&guardrails), &parts.headers, None); + // Extract parent context from the tracer in request extensions + let parent_cx = parts.extensions.get::() + .and_then(|tracer| { + tracer.lock().ok().map(|t| t.parent_context()) + }); + let runner = GuardrailsRunner::new(Some(&guardrails), &parts.headers, parent_cx); let runner = match runner { Some(r) => r, diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index deff9c45..b39f2b96 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -1,2 +1,3 @@ -mod otel; +pub mod otel; pub mod pipeline; +mod tracing_middleware; diff --git a/src/pipelines/otel.rs b/src/pipelines/otel.rs index 6e18e1c3..6a41f43c 100644 --- a/src/pipelines/otel.rs +++ b/src/pipelines/otel.rs @@ -14,11 +14,14 @@ use opentelemetry_sdk::trace::TracerProvider; use opentelemetry_semantic_conventions::attribute::GEN_AI_REQUEST_MODEL; use opentelemetry_semantic_conventions::trace::*; use std::collections::HashMap; +use std::sync::{Arc, Mutex}; pub trait RecordSpan { fn record_span(&self, span: &mut BoxedSpan); } +pub type SharedTracer = Arc>; + pub struct OtelTracer { llm_span: Option, root_span: BoxedSpan, @@ -102,6 +105,18 @@ impl OtelTracer { } } + /// Helper to extract SharedTracer from request extensions with fallback. + /// This is primarily used by handlers to get the tracer created by TracingMiddleware. + pub fn from_extensions(extensions: &axum::http::Extensions) -> SharedTracer { + extensions + .get::() + .cloned() + .unwrap_or_else(|| { + // Fallback for backwards compatibility + Arc::new(Mutex::new(OtelTracer::start())) + }) + } + pub fn start_llm_span(&mut self, operation: &str, request: &T) { let tracer = global::tracer("traceloop_hub"); let parent_cx = self.parent_context(); diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 5d0ccde8..1153c148 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -5,7 +5,8 @@ use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; use crate::models::embeddings::EmbeddingsRequest; use crate::models::streaming::ChatCompletionChunk; -use crate::pipelines::otel::OtelTracer; +use crate::pipelines::otel::{OtelTracer, SharedTracer}; +use crate::pipelines::tracing_middleware::TracingLayer; use crate::providers::provider::get_vendor_name; use crate::{ ai_models::registry::ModelRegistry, @@ -17,7 +18,7 @@ use axum::response::sse::{Event, KeepAlive}; use axum::response::{IntoResponse, Response, Sse}; use axum::{ Json, Router, - extract::State, + extract::{Extension, State}, http::StatusCode, routing::{get, post}, }; @@ -73,15 +74,15 @@ pub fn create_pipeline( PluginConfig::ModelRouter { models } => match pipeline.r#type { PipelineType::Chat => router.route( "/chat/completions", - post(move |state, payload| chat_completions(state, payload, models)), + post(move |tracer, state, payload| chat_completions(tracer, state, payload, models)), ), PipelineType::Completion => router.route( "/completions", - post(move |state, payload| completions(state, payload, models)), + post(move |tracer, state, payload| completions(tracer, state, payload, models)), ), PipelineType::Embeddings => router.route( "/embeddings", - post(move |state, payload| embeddings(state, payload, models)), + post(move |tracer, state, payload| embeddings(tracer, state, payload, models)), ), }, _ => router, @@ -91,10 +92,11 @@ pub fn create_pipeline( router .with_state(Arc::new(model_registry.clone())) .layer(GuardrailsLayer::new(guardrails)) + .layer(TracingLayer::new()) } fn trace_and_stream( - mut tracer: OtelTracer, + tracer: SharedTracer, stream: BoxStream<'static, Result>, ) -> impl Stream> { stream! { @@ -102,45 +104,47 @@ fn trace_and_stream( while let Some(result) = stream.next().await { yield match result { Ok(chunk) => { - tracer.log_chunk(&chunk); + tracer.lock().unwrap().log_chunk(&chunk); Event::default().json_data(chunk) } Err(e) => { eprintln!("Error in stream: {e:?}"); - tracer.log_error(e.to_string()); + tracer.lock().unwrap().log_error(e.to_string()); Err(axum::Error::new(e)) } }; } - tracer.streaming_end(); + tracer.lock().unwrap().streaming_end(); } } pub async fn chat_completions( + Extension(tracer): Extension, State(model_registry): State>, Json(payload): Json, model_keys: Vec, ) -> Result { - let mut tracer = OtelTracer::start(); - for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - tracer.start_llm_span("chat", &payload); - tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); + { + let mut tracer_guard = tracer.lock().unwrap(); + tracer_guard.start_llm_span("chat", &payload); + tracer_guard.set_vendor(&get_vendor_name(&model.provider.r#type())); + } let response = match model.chat_completions(payload.clone()).await { Ok(response) => response, Err(e) => { eprintln!("Chat completion error for model {model_key}: {e:?}"); - tracer.log_error(format!("Chat completion failed: {e:?}")); + tracer.lock().unwrap().log_error(format!("Chat completion failed: {e:?}")); return Err(e); } }; if let ChatCompletionResponse::NonStream(completion) = response { - tracer.log_success(&completion); + tracer.lock().unwrap().log_success(&completion); return Ok(Json(completion).into_response()); } @@ -152,72 +156,76 @@ pub async fn chat_completions( } } - tracer.log_error("No matching model found".to_string()); + tracer.lock().unwrap().log_error("No matching model found".to_string()); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } pub async fn completions( + Extension(tracer): Extension, State(model_registry): State>, Json(payload): Json, model_keys: Vec, ) -> Result { - let mut tracer = OtelTracer::start(); - for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - tracer.start_llm_span("completion", &payload); - tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); + { + let mut tracer_guard = tracer.lock().unwrap(); + tracer_guard.start_llm_span("completion", &payload); + tracer_guard.set_vendor(&get_vendor_name(&model.provider.r#type())); + } let response = match model.completions(payload.clone()).await { Ok(response) => response, Err(e) => { eprintln!("Completion error for model {model_key}: {e:?}"); - tracer.log_error(format!("Completion failed: {e:?}")); + tracer.lock().unwrap().log_error(format!("Completion failed: {e:?}")); return Err(e); } }; - tracer.log_success(&response); + tracer.lock().unwrap().log_success(&response); return Ok(Json(response).into_response()); } } - tracer.log_error("No matching model found".to_string()); + tracer.lock().unwrap().log_error("No matching model found".to_string()); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } pub async fn embeddings( + Extension(tracer): Extension, State(model_registry): State>, Json(payload): Json, model_keys: Vec, ) -> impl IntoResponse { - let mut tracer = OtelTracer::start(); - for model_key in model_keys { let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - tracer.start_llm_span("embeddings", &payload); - tracer.set_vendor(&get_vendor_name(&model.provider.r#type())); + { + let mut tracer_guard = tracer.lock().unwrap(); + tracer_guard.start_llm_span("embeddings", &payload); + tracer_guard.set_vendor(&get_vendor_name(&model.provider.r#type())); + } let response = match model.embeddings(payload.clone()).await { Ok(response) => response, Err(e) => { eprintln!("Embeddings error for model {model_key}: {e:?}"); - tracer.log_error(format!("Embeddings failed: {e:?}")); + tracer.lock().unwrap().log_error(format!("Embeddings failed: {e:?}")); return Err(e); } }; - tracer.log_success(&response); + tracer.lock().unwrap().log_success(&response); return Ok(Json(response)); } } - tracer.log_error("No matching model found".to_string()); + tracer.lock().unwrap().log_error("No matching model found".to_string()); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } diff --git a/src/pipelines/tracing_middleware.rs b/src/pipelines/tracing_middleware.rs new file mode 100644 index 00000000..11b33850 --- /dev/null +++ b/src/pipelines/tracing_middleware.rs @@ -0,0 +1,68 @@ +use axum::body::Body; +use axum::http::Request; +use axum::response::Response; +use std::sync::{Arc, Mutex}; +use tower::{Layer, Service}; + +use super::otel::OtelTracer; + +pub type SharedTracer = Arc>; + +#[derive(Clone)] +pub struct TracingLayer; + +impl TracingLayer { + pub fn new() -> Self { + Self + } +} + +impl Layer for TracingLayer { + type Service = TracingMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + TracingMiddleware { inner } + } +} + +#[derive(Clone)] +pub struct TracingMiddleware { + inner: S, +} + +impl Service> for TracingMiddleware +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let mut inner = self.inner.clone(); + + Box::pin(async move { + // Create root traceloop_hub span and wrap in Arc> + let tracer = Arc::new(Mutex::new(OtelTracer::start())); + + // Insert into request extensions for downstream middleware and handlers + req.extensions_mut().insert(tracer.clone()); + + // Call inner service + let response = inner.call(req).await?; + + // Tracer will be finalized when the Arc is dropped + Ok(response) + }) + } +} From 74dbe1dedf7e901866ee33a4ac7925e5e433e2c7 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 12:45:04 +0200 Subject: [PATCH 46/59] test for otel --- src/pipelines/pipeline.rs | 1039 +++++++++++++++++++++++++++++++++++++ src/providers/registry.rs | 5 + 2 files changed, 1044 insertions(+) diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 1153c148..99d9d119 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -334,6 +334,18 @@ mod tests { } } + // Helper function to create test pipeline with specific type (for otel tests) + fn create_test_pipeline_with_type(model_keys: Vec<&str>, pipeline_type: PipelineType) -> Pipeline { + Pipeline { + name: "test".to_string(), + r#type: pipeline_type, + plugins: vec![PluginConfig::ModelRouter { + models: model_keys.into_iter().map(|s| s.to_string()).collect(), + }], + guards: vec![], + } + } + // Helper function to make GET request to /models async fn get_models_response(app: Router) -> serde_json::Value { let response = app @@ -642,4 +654,1031 @@ mod tests { assert_eq!(get_vendor_name(&anthropic_provider.r#type()), "Anthropic"); assert_eq!(get_vendor_name(&azure_provider.r#type()), "Azure"); } + + // OpenTelemetry span verification tests + mod otel_span_tests { + use super::*; + use opentelemetry::trace::{SpanKind, Status as OtelStatus}; + use opentelemetry_sdk::export::trace::SpanData; + use opentelemetry_sdk::testing::trace::InMemorySpanExporter; + use opentelemetry_sdk::trace::TracerProvider; + use std::sync::LazyLock; + + /// Shared OTel exporter, initialized once for all span tests + /// Tests are isolated by tracking span count before/after each request + static TEST_EXPORTER: LazyLock = LazyLock::new(|| { + let exporter = InMemorySpanExporter::default(); + let provider = TracerProvider::builder() + .with_simple_exporter(exporter.clone()) + .build(); + opentelemetry::global::set_tracer_provider(provider); + exporter + }); + + // Mock provider that returns realistic responses with Usage data + #[derive(Clone)] + struct MockProviderForSpanTests { + provider_type: ProviderType, + } + + #[async_trait] + impl Provider for MockProviderForSpanTests { + fn new(_config: &ProviderConfig) -> Self { + Self { + provider_type: ProviderType::OpenAI, + } + } + + fn key(&self) -> String { + "test-key".to_string() + } + + fn r#type(&self) -> ProviderType { + self.provider_type.clone() + } + + async fn chat_completions( + &self, + payload: crate::models::chat::ChatCompletionRequest, + _model_config: &ModelConfig, + ) -> Result { + use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionResponse}; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + use crate::models::usage::Usage; + + Ok(ChatCompletionResponse::NonStream(ChatCompletion { + id: "chatcmpl-test123".to_string(), + object: Some("chat.completion".to_string()), + created: Some(1234567890), + model: payload.model.clone(), + choices: vec![ChatCompletionChoice { + index: 0, + message: ChatCompletionMessage { + role: "assistant".to_string(), + content: Some(ChatMessageContent::String("Test response".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }, + finish_reason: Some("stop".to_string()), + logprobs: None, + }], + usage: Usage { + prompt_tokens: 10, + completion_tokens: 15, + total_tokens: 25, + completion_tokens_details: None, + prompt_tokens_details: None, + }, + system_fingerprint: None, + })) + } + + async fn completions( + &self, + payload: crate::models::completion::CompletionRequest, + _model_config: &ModelConfig, + ) -> Result { + use crate::models::completion::{CompletionChoice, CompletionResponse}; + use crate::models::usage::Usage; + + Ok(CompletionResponse { + id: "cmpl-test456".to_string(), + object: "text_completion".to_string(), + created: 1234567890, + model: payload.model.clone(), + choices: vec![CompletionChoice { + text: "Test completion".to_string(), + index: 0, + logprobs: None, + finish_reason: Some("stop".to_string()), + }], + usage: Usage { + prompt_tokens: 5, + completion_tokens: 10, + total_tokens: 15, + completion_tokens_details: None, + prompt_tokens_details: None, + }, + }) + } + + async fn embeddings( + &self, + payload: crate::models::embeddings::EmbeddingsRequest, + _model_config: &ModelConfig, + ) -> Result { + use crate::models::embeddings::{Embedding, Embeddings, EmbeddingsResponse}; + use crate::models::usage::EmbeddingUsage; + + Ok(EmbeddingsResponse { + object: "list".to_string(), + data: vec![Embeddings { + object: "embedding".to_string(), + embedding: Embedding::Float(vec![0.1, 0.2, 0.3]), + index: 0, + }], + model: payload.model.clone(), + usage: EmbeddingUsage { + prompt_tokens: Some(8), + total_tokens: Some(8), + }, + }) + } + } + + // Mock provider that returns errors + #[derive(Clone)] + struct MockProviderError { + provider_type: ProviderType, + } + + #[async_trait] + impl Provider for MockProviderError { + fn new(_config: &ProviderConfig) -> Self { + Self { + provider_type: ProviderType::OpenAI, + } + } + + fn key(&self) -> String { + "test-key".to_string() + } + + fn r#type(&self) -> ProviderType { + self.provider_type.clone() + } + + async fn chat_completions( + &self, + _payload: crate::models::chat::ChatCompletionRequest, + _model_config: &ModelConfig, + ) -> Result { + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + + async fn completions( + &self, + _payload: crate::models::completion::CompletionRequest, + _model_config: &ModelConfig, + ) -> Result { + Err(StatusCode::BAD_GATEWAY) + } + + async fn embeddings( + &self, + _payload: crate::models::embeddings::EmbeddingsRequest, + _model_config: &ModelConfig, + ) -> Result { + Err(StatusCode::SERVICE_UNAVAILABLE) + } + } + + // Helper: Create provider registry with specific provider type + fn create_test_provider_registry_for_spans( + provider_type: ProviderType, + return_errors: bool, + ) -> Arc { + let provider_config = ProviderConfig { + key: "test-provider".to_string(), + r#type: provider_type.clone(), + api_key: String::new(), + params: HashMap::new(), + }; + + let mut registry = ProviderRegistry::new(&[provider_config.clone()]).unwrap(); + + // Replace the provider with our mock + if return_errors { + let mock = MockProviderError { provider_type }; + let providers = registry.providers_mut(); + providers.insert("test-provider".to_string(), Arc::new(mock)); + } else { + let mock = MockProviderForSpanTests { provider_type }; + let providers = registry.providers_mut(); + providers.insert("test-provider".to_string(), Arc::new(mock)); + } + + Arc::new(registry) + } + + // Helper: Collect spans added since before_count + fn get_spans_for_test(before_count: usize) -> Vec { + // Get all spans + let all_spans = TEST_EXPORTER.get_finished_spans().unwrap(); + + // Skip the before_count spans and take only the next few + // We expect exactly 2 spans per test (root + LLM) + let new_spans: Vec = all_spans.into_iter().skip(before_count).collect(); + + // If we have more than 2 spans, try to find the most recent root span + // and its immediate child + if new_spans.len() > 2 { + // Find the last root span (traceloop_hub with Server kind) + if let Some(root_idx) = new_spans.iter().rposition(|s| { + s.name == "traceloop_hub" && s.span_kind == SpanKind::Server + }) { + let root = &new_spans[root_idx]; + let root_trace_id = root.span_context.trace_id(); + + // Collect all spans with the same trace_id + return new_spans + .into_iter() + .filter(|s| s.span_context.trace_id() == root_trace_id) + .collect(); + } + } + + new_spans + } + + // Helper: Find root span (name="traceloop_hub", SpanKind::Server) + fn get_root_span(spans: &[SpanData]) -> Option<&SpanData> { + spans + .iter() + .find(|s| s.name == "traceloop_hub" && s.span_kind == SpanKind::Server) + } + + // Helper: Find LLM span by operation + fn get_llm_span<'a>(spans: &'a [SpanData], operation: &str) -> Option<&'a SpanData> { + let expected_name = format!("traceloop_hub.{}", operation); + spans + .iter() + .find(|s| s.name == expected_name && s.span_kind == SpanKind::Client) + } + + // Helper: Extract attribute value from span + fn get_span_attribute(span: &SpanData, key: &str) -> Option { + span.attributes + .iter() + .find(|kv| kv.key.to_string() == key) + .map(|kv| kv.value.to_string()) + } + + // Helper: Check if span is child of another span + fn is_child_of(child: &SpanData, parent: &SpanData) -> bool { + child.parent_span_id == parent.span_context.span_id() + && child.span_context.trace_id() == parent.span_context.trace_id() + } + + // Helper: Assert span has expected attributes + fn assert_span_attributes(span: &SpanData, expected: &[(&str, &str)]) { + for (key, expected_value) in expected { + let actual = get_span_attribute(span, key); + assert_eq!( + actual.as_deref(), + Some(*expected_value), + "Attribute {} mismatch. Expected: {}, Got: {:?}", + key, + expected_value, + actual + ); + } + } + + #[tokio::test] + async fn test_chat_completions_success_spans() { + // Initialize exporter + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + // Create test infrastructure + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + + // Create router (includes TracingLayer) + let app = create_pipeline(&pipeline, &model_registry, None); + + // Prepare request + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Hello".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: Some(0.7), + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + // Send request + let response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + // Verify response success + assert_eq!(response.status(), StatusCode::OK); + + // Collect new spans + let spans = get_spans_for_test(before_count); + assert_eq!(spans.len(), 2, "Expected root + LLM span, got {}", spans.len()); + + // Verify root span + let root = get_root_span(&spans).expect("Root span not found"); + assert_eq!(root.name, "traceloop_hub"); + assert_eq!(root.span_kind, SpanKind::Server); + assert_eq!(root.status, OtelStatus::Ok); + + // Verify LLM span + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + assert_eq!(llm.name, "traceloop_hub.chat"); + assert_eq!(llm.span_kind, SpanKind::Client); + assert_eq!(llm.status, OtelStatus::Ok); + + // Verify hierarchy + assert!( + is_child_of(llm, root), + "LLM span should be child of root span" + ); + + // Verify attributes + assert_span_attributes( + llm, + &[ + ("gen_ai.system", "openai"), + ("gen_ai.request.model", "test"), + ("llm.request.type", "chat"), + ], + ); + + // Verify usage attributes exist + assert!( + get_span_attribute(llm, "gen_ai.usage.prompt_tokens").is_some(), + "prompt_tokens attribute missing" + ); + assert!( + get_span_attribute(llm, "gen_ai.usage.completion_tokens").is_some(), + "completion_tokens attribute missing" + ); + assert!( + get_span_attribute(llm, "gen_ai.usage.total_tokens").is_some(), + "total_tokens attribute missing" + ); + } + + #[tokio::test] + async fn test_completions_success_spans() { + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::Anthropic, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline_with_type(vec!["test-model"], PipelineType::Completion); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::completion::CompletionRequest; + let request_body = CompletionRequest { + model: "test".to_string(), + prompt: "Test prompt".to_string(), + max_tokens: Some(50), + temperature: None, + top_p: None, + n: None, + stream: None, + logprobs: None, + echo: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + suffix: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let spans = get_spans_for_test(before_count); + assert_eq!(spans.len(), 2, "Expected root + LLM span"); + + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "completion").expect("LLM span not found"); + + assert_eq!(llm.name, "traceloop_hub.completion"); + assert_eq!(llm.span_kind, SpanKind::Client); + assert!(is_child_of(llm, root)); + + assert_span_attributes( + llm, + &[ + ("gen_ai.system", "Anthropic"), + ("gen_ai.request.model", "test"), + ("llm.request.type", "completion"), + ], + ); + } + + #[tokio::test] + async fn test_embeddings_success_spans() { + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::Azure, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline_with_type(vec!["test-model"], PipelineType::Embeddings); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest}; + let request_body = EmbeddingsRequest { + input: EmbeddingsInput::Single("Test text".to_string()), + model: "test".to_string(), + encoding_format: None, + user: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/embeddings") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let spans = get_spans_for_test(before_count); + assert_eq!(spans.len(), 2, "Expected root + LLM span"); + + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "embeddings").expect("LLM span not found"); + + assert_eq!(llm.name, "traceloop_hub.embeddings"); + assert_eq!(llm.span_kind, SpanKind::Client); + assert!(is_child_of(llm, root)); + + assert_span_attributes( + llm, + &[ + ("gen_ai.system", "Azure"), + ("gen_ai.request.model", "test"), + ("llm.request.type", "embeddings"), + ], + ); + } + + #[tokio::test] + async fn test_chat_completions_error_spans() { + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, true); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Hello".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + // Verify error response + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + // Collect spans + let spans = get_spans_for_test(before_count); + assert_eq!(spans.len(), 2, "Expected root + LLM span"); + + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + + // Verify error status on both spans + assert!( + matches!(root.status, OtelStatus::Error { .. }), + "Root span should have error status" + ); + assert!( + matches!(llm.status, OtelStatus::Error { .. }), + "LLM span should have error status" + ); + + // Hierarchy should still be correct + assert!(is_child_of(llm, root)); + } + + #[tokio::test] + async fn test_completions_error_spans() { + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, true); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline_with_type(vec!["test-model"], PipelineType::Completion); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::completion::CompletionRequest; + let request_body = CompletionRequest { + model: "test".to_string(), + prompt: "Test".to_string(), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: None, + logprobs: None, + echo: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + suffix: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_GATEWAY); + + let spans = get_spans_for_test(before_count); + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "completion").expect("LLM span not found"); + + assert!(matches!(root.status, OtelStatus::Error { .. })); + assert!(matches!(llm.status, OtelStatus::Error { .. })); + } + + #[tokio::test] + async fn test_embeddings_error_spans() { + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, true); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline_with_type(vec!["test-model"], PipelineType::Embeddings); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest}; + let request_body = EmbeddingsRequest { + input: EmbeddingsInput::Single("Test".to_string()), + model: "test".to_string(), + encoding_format: None, + user: None, + }; + + let response = app + .oneshot( + Request::builder() + .uri("/embeddings") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + + let spans = get_spans_for_test(before_count); + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "embeddings").expect("LLM span not found"); + + assert!(matches!(root.status, OtelStatus::Error { .. })); + assert!(matches!(llm.status, OtelStatus::Error { .. })); + } + + #[tokio::test] + async fn test_span_request_attributes_recorded() { + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Hello".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: Some(0.8), + top_p: Some(0.9), + frequency_penalty: Some(0.5), + presence_penalty: Some(0.3), + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let _response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + let spans = get_spans_for_test(before_count); + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + + // Verify request parameters are recorded + assert_span_attributes(llm, &[("gen_ai.request.model", "test")]); + + // Verify float attributes exist (but don't check exact values due to float precision) + assert!( + get_span_attribute(llm, "gen_ai.request.temperature").is_some(), + "Temperature attribute should exist" + ); + assert!( + get_span_attribute(llm, "gen_ai.request.top_p").is_some(), + "top_p attribute should exist" + ); + assert!( + get_span_attribute(llm, "gen_ai.request.frequency_penalty").is_some(), + "frequency_penalty attribute should exist" + ); + assert!( + get_span_attribute(llm, "gen_ai.request.presence_penalty").is_some(), + "presence_penalty attribute should exist" + ); + } + + #[tokio::test] + async fn test_span_response_attributes_recorded() { + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Hello".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let _response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + let spans = get_spans_for_test(before_count); + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + + // Verify response attributes + assert_span_attributes( + llm, + &[ + ("gen_ai.response.model", "test"), + ("gen_ai.response.id", "chatcmpl-test123"), + ], + ); + + // Verify usage tokens + assert_eq!( + get_span_attribute(llm, "gen_ai.usage.prompt_tokens"), + Some("10".to_string()) + ); + assert_eq!( + get_span_attribute(llm, "gen_ai.usage.completion_tokens"), + Some("15".to_string()) + ); + assert_eq!( + get_span_attribute(llm, "gen_ai.usage.total_tokens"), + Some("25".to_string()) + ); + } + + #[tokio::test] + async fn test_vendor_attribute_mapping() { + let _ = &*TEST_EXPORTER; + + // Test each provider type maps to correct vendor name + // Note: Skip VertexAI as it requires project_id and location params + let test_cases = vec![ + (ProviderType::OpenAI, "openai"), + (ProviderType::Anthropic, "Anthropic"), + (ProviderType::Azure, "Azure"), + (ProviderType::Bedrock, "AWS"), + ]; + + for (provider_type, expected_vendor) in test_cases { + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(provider_type.clone(), false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Test".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let _response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + let spans = get_spans_for_test(before_count); + let llm = get_llm_span(&spans, "chat") + .unwrap_or_else(|| panic!("LLM span not found for {:?}", provider_type)); + + assert_eq!( + get_span_attribute(llm, "gen_ai.system").as_deref(), + Some(expected_vendor), + "Vendor attribute mismatch for {:?}", + provider_type + ); + } + } + + #[tokio::test] + async fn test_span_parent_child_relationship() { + let _ = &*TEST_EXPORTER; + let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); + + let provider_registry = + create_test_provider_registry_for_spans(ProviderType::OpenAI, false); + let model_configs = create_model_configs(vec!["test-model"]); + let model_registry = + Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); + let pipeline = create_test_pipeline(vec!["test-model"]); + let app = create_pipeline(&pipeline, &model_registry, None); + + use crate::models::chat::ChatCompletionRequest; + use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; + let request_body = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: Some(ChatMessageContent::String("Test".to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + refusal: None, + }], + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + logprobs: None, + top_logprobs: None, + reasoning: None, + }; + + let _response = app + .oneshot( + Request::builder() + .uri("/chat/completions") + .method("POST") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request_body).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + + let spans = get_spans_for_test(before_count); + let root = get_root_span(&spans).expect("Root span not found"); + let llm = get_llm_span(&spans, "chat").expect("LLM span not found"); + + // Verify parent-child relationship + assert_eq!( + llm.parent_span_id, + root.span_context.span_id(), + "LLM span's parent_span_id should equal root span's span_id" + ); + + // Verify trace_id consistency + assert_eq!( + llm.span_context.trace_id(), + root.span_context.trace_id(), + "Both spans should share the same trace_id" + ); + + // Verify root span has no parent (is actually root) + assert_eq!( + root.parent_span_id.to_string(), + "0000000000000000", + "Root span should not have a parent" + ); + } + } } diff --git a/src/providers/registry.rs b/src/providers/registry.rs index 4e25e91e..3a012de9 100644 --- a/src/providers/registry.rs +++ b/src/providers/registry.rs @@ -34,4 +34,9 @@ impl ProviderRegistry { pub fn get(&self, name: &str) -> Option> { self.providers.get(name).cloned() } + + #[cfg(test)] + pub fn providers_mut(&mut self) -> &mut HashMap> { + &mut self.providers + } } From f97082dc31138d0e2e644faa333158c388acc085 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 13:02:30 +0200 Subject: [PATCH 47/59] ci --- src/guardrails/evaluator_types.rs | 28 +++++- src/guardrails/middleware.rs | 44 ++++++--- src/pipelines/pipeline.rs | 71 ++++++++++---- src/pipelines/tracing_middleware.rs | 2 - tests/guardrails/helpers.rs | 8 +- tests/guardrails/test_e2e.rs | 40 ++++---- tests/guardrails/test_middleware.rs | 137 ++++++++++++++++++++++------ 7 files changed, 242 insertions(+), 88 deletions(-) diff --git a/src/guardrails/evaluator_types.rs b/src/guardrails/evaluator_types.rs index 85a9944f..be601f1a 100644 --- a/src/guardrails/evaluator_types.rs +++ b/src/guardrails/evaluator_types.rs @@ -172,8 +172,28 @@ pub struct JsonValidatorConfig { // --------------------------------------------------------------------------- evaluator_with_config!(PiiDetector, text_body, PiiDetectorConfig, PII_DETECTOR); -evaluator_with_config!(PromptInjection, prompt_body, ThresholdConfig, PROMPT_INJECTION); +evaluator_with_config!( + PromptInjection, + prompt_body, + ThresholdConfig, + PROMPT_INJECTION +); evaluator_with_config!(SexismDetector, text_body, ThresholdConfig, SEXISM_DETECTOR); -evaluator_with_config!(ToxicityDetector, text_body, ThresholdConfig, TOXICITY_DETECTOR); -evaluator_with_config!(RegexValidator, text_body, RegexValidatorConfig, REGEX_VALIDATOR); -evaluator_with_config!(JsonValidator, text_body, JsonValidatorConfig, JSON_VALIDATOR); +evaluator_with_config!( + ToxicityDetector, + text_body, + ThresholdConfig, + TOXICITY_DETECTOR +); +evaluator_with_config!( + RegexValidator, + text_body, + RegexValidatorConfig, + REGEX_VALIDATOR +); +evaluator_with_config!( + JsonValidator, + text_body, + JsonValidatorConfig, + JSON_VALIDATOR +); diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index 1142345a..eb3bfbe6 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -40,9 +40,9 @@ impl EndpointType { /// Enum representing the type of request being processed. enum ParsedRequest { - Chat(ChatCompletionRequest), - Completion(CompletionRequest), - Embeddings(EmbeddingsRequest), + Chat(Box), + Completion(Box), + Embeddings(Box), } impl ParsedRequest { @@ -188,7 +188,10 @@ where Some(t) => t, None => { // Unsupported endpoint — pass through - debug!("Guardrails middleware: unsupported endpoint {}, passing through", parts.uri.path()); + debug!( + "Guardrails middleware: unsupported endpoint {}, passing through", + parts.uri.path() + ); let request = Request::from_parts(parts, body); return inner.call(request).await; } @@ -207,7 +210,7 @@ where let parsed_request = match endpoint_type { EndpointType::Chat => { match serde_json::from_slice::(&bytes) { - Ok(req) => ParsedRequest::Chat(req), + Ok(req) => ParsedRequest::Chat(Box::new(req)), Err(e) => { debug!("Guardrails middleware: failed to parse chat request: {}", e); let request = Request::from_parts(parts, Body::from(bytes)); @@ -217,9 +220,12 @@ where } EndpointType::Completion => { match serde_json::from_slice::(&bytes) { - Ok(req) => ParsedRequest::Completion(req), + Ok(req) => ParsedRequest::Completion(Box::new(req)), Err(e) => { - debug!("Guardrails middleware: failed to parse completion request: {}", e); + debug!( + "Guardrails middleware: failed to parse completion request: {}", + e + ); let request = Request::from_parts(parts, Body::from(bytes)); return inner.call(request).await; } @@ -227,9 +233,12 @@ where } EndpointType::Embeddings => { match serde_json::from_slice::(&bytes) { - Ok(req) => ParsedRequest::Embeddings(req), + Ok(req) => ParsedRequest::Embeddings(Box::new(req)), Err(e) => { - debug!("Guardrails middleware: failed to parse embeddings request: {}", e); + debug!( + "Guardrails middleware: failed to parse embeddings request: {}", + e + ); let request = Request::from_parts(parts, Body::from(bytes)); return inner.call(request).await; } @@ -246,10 +255,10 @@ where // Resolve guards from pipeline config + request headers // Extract parent context from the tracer in request extensions - let parent_cx = parts.extensions.get::() - .and_then(|tracer| { - tracer.lock().ok().map(|t| t.parent_context()) - }); + let parent_cx = parts + .extensions + .get::() + .and_then(|tracer| tracer.lock().ok().map(|t| t.parent_context())); let runner = GuardrailsRunner::new(Some(&guardrails), &parts.headers, parent_cx); let runner = match runner { @@ -276,7 +285,14 @@ where let (resp_parts, resp_body) = response.into_parts(); if parsed_request.supports_post_call() { - Ok(handle_post_call_guards(&parsed_request, resp_parts, resp_body, &runner, all_warnings).await) + Ok(handle_post_call_guards( + &parsed_request, + resp_parts, + resp_body, + &runner, + all_warnings, + ) + .await) } else { // No post-call guards for this request type (e.g., embeddings) // Pass through response with pre-call warnings attached diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 99d9d119..feee6f5b 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -74,7 +74,9 @@ pub fn create_pipeline( PluginConfig::ModelRouter { models } => match pipeline.r#type { PipelineType::Chat => router.route( "/chat/completions", - post(move |tracer, state, payload| chat_completions(tracer, state, payload, models)), + post(move |tracer, state, payload| { + chat_completions(tracer, state, payload, models) + }), ), PipelineType::Completion => router.route( "/completions", @@ -138,7 +140,10 @@ pub async fn chat_completions( Ok(response) => response, Err(e) => { eprintln!("Chat completion error for model {model_key}: {e:?}"); - tracer.lock().unwrap().log_error(format!("Chat completion failed: {e:?}")); + tracer + .lock() + .unwrap() + .log_error(format!("Chat completion failed: {e:?}")); return Err(e); } }; @@ -156,7 +161,10 @@ pub async fn chat_completions( } } - tracer.lock().unwrap().log_error("No matching model found".to_string()); + tracer + .lock() + .unwrap() + .log_error("No matching model found".to_string()); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } @@ -181,7 +189,10 @@ pub async fn completions( Ok(response) => response, Err(e) => { eprintln!("Completion error for model {model_key}: {e:?}"); - tracer.lock().unwrap().log_error(format!("Completion failed: {e:?}")); + tracer + .lock() + .unwrap() + .log_error(format!("Completion failed: {e:?}")); return Err(e); } }; @@ -191,7 +202,10 @@ pub async fn completions( } } - tracer.lock().unwrap().log_error("No matching model found".to_string()); + tracer + .lock() + .unwrap() + .log_error("No matching model found".to_string()); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } @@ -216,7 +230,10 @@ pub async fn embeddings( Ok(response) => response, Err(e) => { eprintln!("Embeddings error for model {model_key}: {e:?}"); - tracer.lock().unwrap().log_error(format!("Embeddings failed: {e:?}")); + tracer + .lock() + .unwrap() + .log_error(format!("Embeddings failed: {e:?}")); return Err(e); } }; @@ -225,7 +242,10 @@ pub async fn embeddings( } } - tracer.lock().unwrap().log_error("No matching model found".to_string()); + tracer + .lock() + .unwrap() + .log_error("No matching model found".to_string()); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } @@ -335,7 +355,10 @@ mod tests { } // Helper function to create test pipeline with specific type (for otel tests) - fn create_test_pipeline_with_type(model_keys: Vec<&str>, pipeline_type: PipelineType) -> Pipeline { + fn create_test_pipeline_with_type( + model_keys: Vec<&str>, + pipeline_type: PipelineType, + ) -> Pipeline { Pipeline { name: "test".to_string(), r#type: pipeline_type, @@ -702,7 +725,9 @@ mod tests { payload: crate::models::chat::ChatCompletionRequest, _model_config: &ModelConfig, ) -> Result { - use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionResponse}; + use crate::models::chat::{ + ChatCompletion, ChatCompletionChoice, ChatCompletionResponse, + }; use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; use crate::models::usage::Usage; @@ -876,9 +901,10 @@ mod tests { // and its immediate child if new_spans.len() > 2 { // Find the last root span (traceloop_hub with Server kind) - if let Some(root_idx) = new_spans.iter().rposition(|s| { - s.name == "traceloop_hub" && s.span_kind == SpanKind::Server - }) { + if let Some(root_idx) = new_spans + .iter() + .rposition(|s| s.name == "traceloop_hub" && s.span_kind == SpanKind::Server) + { let root = &new_spans[root_idx]; let root_trace_id = root.span_context.trace_id(); @@ -1005,7 +1031,12 @@ mod tests { // Collect new spans let spans = get_spans_for_test(before_count); - assert_eq!(spans.len(), 2, "Expected root + LLM span, got {}", spans.len()); + assert_eq!( + spans.len(), + 2, + "Expected root + LLM span, got {}", + spans.len() + ); // Verify root span let root = get_root_span(&spans).expect("Root span not found"); @@ -1060,7 +1091,8 @@ mod tests { let model_configs = create_model_configs(vec!["test-model"]); let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); - let pipeline = create_test_pipeline_with_type(vec!["test-model"], PipelineType::Completion); + let pipeline = + create_test_pipeline_with_type(vec!["test-model"], PipelineType::Completion); let app = create_pipeline(&pipeline, &model_registry, None); use crate::models::completion::CompletionRequest; @@ -1127,7 +1159,8 @@ mod tests { let model_configs = create_model_configs(vec!["test-model"]); let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); - let pipeline = create_test_pipeline_with_type(vec!["test-model"], PipelineType::Embeddings); + let pipeline = + create_test_pipeline_with_type(vec!["test-model"], PipelineType::Embeddings); let app = create_pipeline(&pipeline, &model_registry, None); use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest}; @@ -1263,7 +1296,8 @@ mod tests { let model_configs = create_model_configs(vec!["test-model"]); let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); - let pipeline = create_test_pipeline_with_type(vec!["test-model"], PipelineType::Completion); + let pipeline = + create_test_pipeline_with_type(vec!["test-model"], PipelineType::Completion); let app = create_pipeline(&pipeline, &model_registry, None); use crate::models::completion::CompletionRequest; @@ -1318,7 +1352,8 @@ mod tests { let model_configs = create_model_configs(vec!["test-model"]); let model_registry = Arc::new(ModelRegistry::new(&model_configs, provider_registry).unwrap()); - let pipeline = create_test_pipeline_with_type(vec!["test-model"], PipelineType::Embeddings); + let pipeline = + create_test_pipeline_with_type(vec!["test-model"], PipelineType::Embeddings); let app = create_pipeline(&pipeline, &model_registry, None); use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest}; @@ -1565,7 +1600,7 @@ mod tests { logit_bias: None, user: None, response_format: None, - tools: None, + tools: None, tool_choice: None, parallel_tool_calls: None, logprobs: None, diff --git a/src/pipelines/tracing_middleware.rs b/src/pipelines/tracing_middleware.rs index 11b33850..1153506e 100644 --- a/src/pipelines/tracing_middleware.rs +++ b/src/pipelines/tracing_middleware.rs @@ -6,8 +6,6 @@ use tower::{Layer, Service}; use super::otel::OtelTracer; -pub type SharedTracer = Arc>; - #[derive(Clone)] pub struct TracingLayer; diff --git a/tests/guardrails/helpers.rs b/tests/guardrails/helpers.rs index 502d6905..83f84109 100644 --- a/tests/guardrails/helpers.rs +++ b/tests/guardrails/helpers.rs @@ -14,7 +14,9 @@ use hub_lib::guardrails::types::{EvaluatorResponse, Guard, GuardMode, GuardrailE use hub_lib::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; use hub_lib::models::completion::{CompletionChoice, CompletionRequest, CompletionResponse}; use hub_lib::models::content::{ChatCompletionMessage, ChatMessageContent}; -use hub_lib::models::embeddings::{Embedding, Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse}; +use hub_lib::models::embeddings::{ + Embedding, Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse, +}; use hub_lib::models::usage::{EmbeddingUsage, Usage}; use serde::Serialize; use serde_json::json; @@ -80,7 +82,9 @@ pub fn create_test_guard_with_failure_action( mode: GuardMode, on_failure: OnFailure, ) -> Guard { - TestGuardBuilder::new(name, mode).on_failure(on_failure).build() + TestGuardBuilder::new(name, mode) + .on_failure(on_failure) + .build() } pub fn create_test_guard_with_required(name: &str, mode: GuardMode, required: bool) -> Guard { diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 87b5cd8c..07fb50c1 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -190,20 +190,20 @@ async fn test_e2e_mixed_block_and_warn() { let guards = vec![ TestGuardBuilder::new("passer", GuardMode::PreCall) - .on_failure(OnFailure::Block) - .api_base(&eval1.uri()) - .evaluator_slug("profanity-detector") - .build(), + .on_failure(OnFailure::Block) + .api_base(&eval1.uri()) + .evaluator_slug("profanity-detector") + .build(), TestGuardBuilder::new("warner", GuardMode::PreCall) - .on_failure(OnFailure::Warn) - .api_base(&eval2.uri()) - .evaluator_slug("tone-detection") - .build(), + .on_failure(OnFailure::Warn) + .api_base(&eval2.uri()) + .evaluator_slug("tone-detection") + .build(), TestGuardBuilder::new("blocker", GuardMode::PreCall) - .on_failure(OnFailure::Block) - .api_base(&eval3.uri()) - .evaluator_slug("toxicity-detector") - .build(), + .on_failure(OnFailure::Block) + .api_base(&eval3.uri()) + .evaluator_slug("toxicity-detector") + .build(), ]; let client = TraceloopClient::new(); @@ -364,15 +364,15 @@ async fn test_e2e_multiple_guards_different_evaluators() { let guards = vec![ TestGuardBuilder::new("tox-guard", GuardMode::PreCall) - .on_failure(OnFailure::Block) - .api_base(&server.uri()) - .evaluator_slug("toxicity-detector") - .build(), + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("toxicity-detector") + .build(), TestGuardBuilder::new("pii-guard", GuardMode::PreCall) - .on_failure(OnFailure::Block) - .api_base(&server.uri()) - .evaluator_slug("pii-detector") - .build(), + .on_failure(OnFailure::Block) + .api_base(&server.uri()) + .evaluator_slug("pii-detector") + .build(), ]; let client = TraceloopClient::new(); diff --git a/tests/guardrails/test_middleware.rs b/tests/guardrails/test_middleware.rs index c92cae95..b18e1341 100644 --- a/tests/guardrails/test_middleware.rs +++ b/tests/guardrails/test_middleware.rs @@ -2,7 +2,7 @@ use hub_lib::guardrails::middleware::GuardrailsLayer; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::types::{Guard, GuardMode, Guardrails, OnFailure}; -use axum::body::{to_bytes, Body}; +use axum::body::{Body, to_bytes}; use axum::extract::Request; use axum::http::StatusCode; use serde_json::json; @@ -87,7 +87,13 @@ async fn test_chat_completions_endpoint_detected() { .unwrap(); // Call middleware - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify response is 200 OK (guard passed) assert_eq!(response.status(), StatusCode::OK); @@ -131,7 +137,13 @@ async fn test_completions_endpoint_detected() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); assert_eq!(response.status(), StatusCode::OK); } @@ -172,7 +184,13 @@ async fn test_embeddings_endpoint_detected() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); assert_eq!(response.status(), StatusCode::OK); } @@ -216,7 +234,13 @@ async fn test_pre_call_guard_blocks_chat() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify blocked assert_eq!(response.status(), StatusCode::FORBIDDEN); @@ -261,13 +285,24 @@ async fn test_pre_call_guard_warns_chat() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify passes with warning assert_eq!(response.status(), StatusCode::OK); - assert!(response.headers().contains_key("x-traceloop-guardrail-warning")); + assert!( + response + .headers() + .contains_key("x-traceloop-guardrail-warning") + ); - let warning_header = response.headers() + let warning_header = response + .headers() .get("x-traceloop-guardrail-warning") .unwrap() .to_str() @@ -314,7 +349,13 @@ async fn test_post_call_guard_blocks_chat() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify blocked by post-call guard assert_eq!(response.status(), StatusCode::FORBIDDEN); @@ -333,7 +374,7 @@ async fn test_post_call_guard_skipped_for_embeddings() { "result": {}, "pass": true }))) - .expect(1) // Pre-call should run + .expect(1) // Pre-call should run .mount(&pre_eval_server) .await; @@ -344,7 +385,7 @@ async fn test_post_call_guard_skipped_for_embeddings() { "result": {}, "pass": false // Would block if called }))) - .expect(0) // Post-call should NOT run for embeddings + .expect(0) // Post-call should NOT run for embeddings .mount(&post_eval_server) .await; @@ -377,13 +418,23 @@ async fn test_post_call_guard_skipped_for_embeddings() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify response is 200 OK (no post-call blocking) assert_eq!(response.status(), StatusCode::OK); // Verify no warning header - assert!(!response.headers().contains_key("x-traceloop-guardrail-warning")); + assert!( + !response + .headers() + .contains_key("x-traceloop-guardrail-warning") + ); // Verify response body contains embeddings let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); @@ -408,7 +459,7 @@ async fn test_streaming_chat_bypasses_guards() { "result": {}, "pass": false // Would block if evaluated }))) - .expect(0) // Should never be called for streaming + .expect(0) // Should never be called for streaming .mount(&eval_server) .await; @@ -435,13 +486,23 @@ async fn test_streaming_chat_bypasses_guards() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify response is 200 OK (streaming bypasses guards) assert_eq!(response.status(), StatusCode::OK); // Verify no warning header - assert!(!response.headers().contains_key("x-traceloop-guardrail-warning")); + assert!( + !response + .headers() + .contains_key("x-traceloop-guardrail-warning") + ); // Wiremock verifies evaluator was never called (expect(0)) } @@ -455,7 +516,7 @@ async fn test_streaming_completion_bypasses_guards() { "result": {}, "pass": false // Would block if evaluated }))) - .expect(0) // Should never be called for streaming + .expect(0) // Should never be called for streaming .mount(&eval_server) .await; @@ -482,13 +543,23 @@ async fn test_streaming_completion_bypasses_guards() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify response is 200 OK (streaming bypasses guards) assert_eq!(response.status(), StatusCode::OK); // Verify no warning header - assert!(!response.headers().contains_key("x-traceloop-guardrail-warning")); + assert!( + !response + .headers() + .contains_key("x-traceloop-guardrail-warning") + ); // Wiremock verifies evaluator was never called (expect(0)) } @@ -503,7 +574,7 @@ async fn test_no_guardrails_configured_passes() { let completion = create_test_chat_completion("Response"); let inner_service = MockService::with_json(StatusCode::OK, &completion); - let layer = GuardrailsLayer::new(None); // No guardrails + let layer = GuardrailsLayer::new(None); // No guardrails let mut service = layer.layer(inner_service); let request = create_test_chat_request("Any input"); @@ -514,7 +585,13 @@ async fn test_no_guardrails_configured_passes() { .body(Body::from(serde_json::to_vec(&request).unwrap())) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify response passes through unchanged assert_eq!(response.status(), StatusCode::OK); @@ -528,7 +605,7 @@ async fn test_unsupported_endpoint_passes() { "result": {}, "pass": false }))) - .expect(0) // Should never be called + .expect(0) // Should never be called .mount(&eval_server) .await; @@ -540,10 +617,8 @@ async fn test_unsupported_endpoint_passes() { ); let guardrails = create_guardrails(vec![guard]); - let inner_service = MockService::with_json( - StatusCode::OK, - &json!({"data": [{"id": "model-1"}]}), - ); + let inner_service = + MockService::with_json(StatusCode::OK, &json!({"data": [{"id": "model-1"}]})); let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); let mut service = layer.layer(inner_service); @@ -551,11 +626,17 @@ async fn test_unsupported_endpoint_passes() { // Request to unsupported endpoint let http_request = Request::builder() .method("GET") - .uri("/v1/models") // Unsupported endpoint + .uri("/v1/models") // Unsupported endpoint .body(Body::empty()) .unwrap(); - let response = service.ready().await.unwrap().call(http_request).await.unwrap(); + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); // Verify passes through assert_eq!(response.status(), StatusCode::OK); From fee49637c48576163b231750d2a4c0dbf7dbde06 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 14:15:15 +0200 Subject: [PATCH 48/59] code simplify --- src/guardrails/middleware.rs | 64 +++++++++++-------------- src/guardrails/runner.rs | 93 ++++++++++++++++-------------------- src/pipelines/pipeline.rs | 87 ++++++++++++++++----------------- 3 files changed, 113 insertions(+), 131 deletions(-) diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index eb3bfbe6..b269b1c3 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -14,6 +14,8 @@ use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::embeddings::EmbeddingsRequest; use crate::pipelines::otel::SharedTracer; +use serde::de::DeserializeOwned; + use super::parsing::PromptExtractor; use super::runner::GuardrailsRunner; use super::types::Guardrails; @@ -122,6 +124,21 @@ async fn handle_post_call_guards( GuardrailsRunner::finalize_response(response, &warnings) } +/// Try to deserialize bytes into a request type, logging on failure. +/// Returns None if deserialization fails, allowing the caller to pass through. +fn try_parse(bytes: &[u8], label: &str) -> Option { + match serde_json::from_slice::(bytes) { + Ok(req) => Some(req), + Err(e) => { + debug!( + "Guardrails middleware: failed to parse {} request: {}", + label, e + ); + None + } + } +} + /// Tower layer that applies guardrail checks around a service. /// /// - **Pre-call guards** run before the inner service, inspecting the request body. @@ -208,41 +225,18 @@ where // Parse request based on endpoint type let parsed_request = match endpoint_type { - EndpointType::Chat => { - match serde_json::from_slice::(&bytes) { - Ok(req) => ParsedRequest::Chat(Box::new(req)), - Err(e) => { - debug!("Guardrails middleware: failed to parse chat request: {}", e); - let request = Request::from_parts(parts, Body::from(bytes)); - return inner.call(request).await; - } - } - } - EndpointType::Completion => { - match serde_json::from_slice::(&bytes) { - Ok(req) => ParsedRequest::Completion(Box::new(req)), - Err(e) => { - debug!( - "Guardrails middleware: failed to parse completion request: {}", - e - ); - let request = Request::from_parts(parts, Body::from(bytes)); - return inner.call(request).await; - } - } - } - EndpointType::Embeddings => { - match serde_json::from_slice::(&bytes) { - Ok(req) => ParsedRequest::Embeddings(Box::new(req)), - Err(e) => { - debug!( - "Guardrails middleware: failed to parse embeddings request: {}", - e - ); - let request = Request::from_parts(parts, Body::from(bytes)); - return inner.call(request).await; - } - } + EndpointType::Chat => try_parse::(&bytes, "chat") + .map(|req| ParsedRequest::Chat(Box::new(req))), + EndpointType::Completion => try_parse::(&bytes, "completion") + .map(|req| ParsedRequest::Completion(Box::new(req))), + EndpointType::Embeddings => try_parse::(&bytes, "embeddings") + .map(|req| ParsedRequest::Embeddings(Box::new(req))), + }; + let parsed_request = match parsed_request { + Some(pr) => pr, + None => { + let request = Request::from_parts(parts, Body::from(bytes)); + return inner.call(request).await; } }; diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index 50a3a04a..598ab2f9 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -207,6 +207,22 @@ pub struct GuardrailsRunner<'a> { parent_cx: Option, } +/// Convert a GuardrailsOutcome into a GuardPhaseResult. +/// If the outcome is blocked, produces a blocked response; otherwise, forwards warnings. +fn outcome_to_phase_result(outcome: GuardrailsOutcome) -> GuardPhaseResult { + if outcome.blocked { + GuardPhaseResult { + blocked_response: Some(blocked_response(&outcome)), + warnings: Vec::new(), + } + } else { + GuardPhaseResult { + blocked_response: None, + warnings: outcome.warnings, + } + } +} + impl<'a> GuardrailsRunner<'a> { /// Create a runner by resolving guards from pipeline config + request headers. /// Returns None if no guards are active for this request. @@ -240,16 +256,7 @@ impl<'a> GuardrailsRunner<'a> { let input = request.extract_prompt(); let outcome = execute_guards(&self.pre_call, &input, self.client, self.parent_cx.as_ref()).await; - if outcome.blocked { - return GuardPhaseResult { - blocked_response: Some(blocked_response(&outcome)), - warnings: Vec::new(), - }; - } - GuardPhaseResult { - blocked_response: None, - warnings: outcome.warnings, - } + outcome_to_phase_result(outcome) } /// Run post-call guards, extracting input from the response only if guards exist. @@ -280,16 +287,7 @@ impl<'a> GuardrailsRunner<'a> { self.parent_cx.as_ref(), ) .await; - if outcome.blocked { - return GuardPhaseResult { - blocked_response: Some(blocked_response(&outcome)), - warnings: Vec::new(), - }; - } - GuardPhaseResult { - blocked_response: None, - warnings: outcome.warnings, - } + outcome_to_phase_result(outcome) } /// Attach warning headers to a response if there are any warnings. @@ -323,39 +321,32 @@ pub fn blocked_response(outcome: &GuardrailsOutcome) -> Response { let guard_name = outcome.blocking_guard.as_deref().unwrap_or("unknown"); // Find the blocking guard result to get details - let details = outcome - .results - .iter() - .find(|r| match r { - GuardResult::Failed { name, .. } => name == guard_name, - GuardResult::Error { name, .. } => name == guard_name, - _ => false, - }) - .and_then(|r| match r { - GuardResult::Failed { result, .. } => Some(json!({ - "evaluation_result": result, - "reason": "evaluation_failed" - })), - GuardResult::Error { error, .. } => Some(json!({ - "error_details": error, - "reason": "evaluator_error" - })), - _ => None, - }); - - let mut error_obj = json!({ - "type": "guardrail_blocked", - "guardrail": guard_name, - "message": format!("Request blocked by guardrail '{guard_name}'"), + let blocking_result = outcome.results.iter().find(|r| match r { + GuardResult::Failed { name, .. } | GuardResult::Error { name, .. } => name == guard_name, + _ => false, }); - if let Some(details) = details { - if let Some(obj) = error_obj.as_object_mut() { - if let Some(details_obj) = details.as_object() { - obj.extend(details_obj.clone()); - } - } - } + let error_obj = match blocking_result { + Some(GuardResult::Failed { result, .. }) => json!({ + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), + "evaluation_result": result, + "reason": "evaluation_failed", + }), + Some(GuardResult::Error { error, .. }) => json!({ + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), + "error_details": error, + "reason": "evaluator_error", + }), + _ => json!({ + "type": "guardrail_blocked", + "guardrail": guard_name, + "message": format!("Request blocked by guardrail '{guard_name}'"), + }), + }; let body = json!({ "error": error_obj }); (StatusCode::FORBIDDEN, Json(body)).into_response() diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index feee6f5b..e892d5c8 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -27,6 +27,12 @@ use futures::{Stream, StreamExt}; use reqwest_streams::error::StreamBodyError; use std::sync::Arc; +/// Access the tracer behind the shared mutex, executing a closure with the locked guard. +/// Panics if the mutex is poisoned (same behavior as the previous .lock().unwrap() calls). +fn with_tracer(tracer: &SharedTracer, f: impl FnOnce(&mut OtelTracer) -> R) -> R { + f(&mut tracer.lock().unwrap()) +} + // Re-export builder and orchestrator functions for backward compatibility with tests pub use crate::guardrails::runner::{blocked_response, warning_header_value}; pub use crate::guardrails::setup::{ @@ -106,17 +112,17 @@ fn trace_and_stream( while let Some(result) = stream.next().await { yield match result { Ok(chunk) => { - tracer.lock().unwrap().log_chunk(&chunk); + with_tracer(&tracer, |t| t.log_chunk(&chunk)); Event::default().json_data(chunk) } Err(e) => { eprintln!("Error in stream: {e:?}"); - tracer.lock().unwrap().log_error(e.to_string()); + with_tracer(&tracer, |t| t.log_error(e.to_string())); Err(axum::Error::new(e)) } }; } - tracer.lock().unwrap().streaming_end(); + with_tracer(&tracer, |t| t.streaming_end()); } } @@ -130,26 +136,24 @@ pub async fn chat_completions( let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - { - let mut tracer_guard = tracer.lock().unwrap(); - tracer_guard.start_llm_span("chat", &payload); - tracer_guard.set_vendor(&get_vendor_name(&model.provider.r#type())); - } + with_tracer(&tracer, |t| { + t.start_llm_span("chat", &payload); + t.set_vendor(&get_vendor_name(&model.provider.r#type())); + }); let response = match model.chat_completions(payload.clone()).await { Ok(response) => response, Err(e) => { eprintln!("Chat completion error for model {model_key}: {e:?}"); - tracer - .lock() - .unwrap() - .log_error(format!("Chat completion failed: {e:?}")); + with_tracer(&tracer, |t| { + t.log_error(format!("Chat completion failed: {e:?}")) + }); return Err(e); } }; if let ChatCompletionResponse::NonStream(completion) = response { - tracer.lock().unwrap().log_success(&completion); + with_tracer(&tracer, |t| t.log_success(&completion)); return Ok(Json(completion).into_response()); } @@ -161,10 +165,9 @@ pub async fn chat_completions( } } - tracer - .lock() - .unwrap() - .log_error("No matching model found".to_string()); + with_tracer(&tracer, |t| { + t.log_error("No matching model found".to_string()) + }); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } @@ -179,33 +182,30 @@ pub async fn completions( let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - { - let mut tracer_guard = tracer.lock().unwrap(); - tracer_guard.start_llm_span("completion", &payload); - tracer_guard.set_vendor(&get_vendor_name(&model.provider.r#type())); - } + with_tracer(&tracer, |t| { + t.start_llm_span("completion", &payload); + t.set_vendor(&get_vendor_name(&model.provider.r#type())); + }); let response = match model.completions(payload.clone()).await { Ok(response) => response, Err(e) => { eprintln!("Completion error for model {model_key}: {e:?}"); - tracer - .lock() - .unwrap() - .log_error(format!("Completion failed: {e:?}")); + with_tracer(&tracer, |t| { + t.log_error(format!("Completion failed: {e:?}")) + }); return Err(e); } }; - tracer.lock().unwrap().log_success(&response); + with_tracer(&tracer, |t| t.log_success(&response)); return Ok(Json(response).into_response()); } } - tracer - .lock() - .unwrap() - .log_error("No matching model found".to_string()); + with_tracer(&tracer, |t| { + t.log_error("No matching model found".to_string()) + }); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } @@ -220,32 +220,29 @@ pub async fn embeddings( let model = model_registry.get(&model_key).unwrap(); if payload.model == model.model_type { - { - let mut tracer_guard = tracer.lock().unwrap(); - tracer_guard.start_llm_span("embeddings", &payload); - tracer_guard.set_vendor(&get_vendor_name(&model.provider.r#type())); - } + with_tracer(&tracer, |t| { + t.start_llm_span("embeddings", &payload); + t.set_vendor(&get_vendor_name(&model.provider.r#type())); + }); let response = match model.embeddings(payload.clone()).await { Ok(response) => response, Err(e) => { eprintln!("Embeddings error for model {model_key}: {e:?}"); - tracer - .lock() - .unwrap() - .log_error(format!("Embeddings failed: {e:?}")); + with_tracer(&tracer, |t| { + t.log_error(format!("Embeddings failed: {e:?}")) + }); return Err(e); } }; - tracer.lock().unwrap().log_success(&response); + with_tracer(&tracer, |t| t.log_success(&response)); return Ok(Json(response)); } } - tracer - .lock() - .unwrap() - .log_error("No matching model found".to_string()); + with_tracer(&tracer, |t| { + t.log_error("No matching model found".to_string()) + }); eprintln!("No matching model found for: {}", payload.model); Err(StatusCode::NOT_FOUND) } From b973e9e093cc0cea4fc774ca69050ea95beb755a Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Wed, 18 Feb 2026 14:49:33 +0200 Subject: [PATCH 49/59] fix test --- src/pipelines/pipeline.rs | 80 +++++++++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 25 deletions(-) diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index e892d5c8..bf333ac6 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -682,7 +682,7 @@ mod tests { use opentelemetry_sdk::export::trace::SpanData; use opentelemetry_sdk::testing::trace::InMemorySpanExporter; use opentelemetry_sdk::trace::TracerProvider; - use std::sync::LazyLock; + use std::sync::{LazyLock, Mutex}; /// Shared OTel exporter, initialized once for all span tests /// Tests are isolated by tracking span count before/after each request @@ -695,6 +695,10 @@ mod tests { exporter }); + /// Mutex to serialize span tests and prevent race conditions + /// with the shared TEST_EXPORTER + static SPAN_TEST_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + // Mock provider that returns realistic responses with Usage data #[derive(Clone)] struct MockProviderForSpanTests { @@ -886,34 +890,49 @@ mod tests { } // Helper: Collect spans added since before_count + // Retries for up to 100ms to handle async span finalization fn get_spans_for_test(before_count: usize) -> Vec { - // Get all spans - let all_spans = TEST_EXPORTER.get_finished_spans().unwrap(); - - // Skip the before_count spans and take only the next few - // We expect exactly 2 spans per test (root + LLM) - let new_spans: Vec = all_spans.into_iter().skip(before_count).collect(); - - // If we have more than 2 spans, try to find the most recent root span - // and its immediate child - if new_spans.len() > 2 { - // Find the last root span (traceloop_hub with Server kind) - if let Some(root_idx) = new_spans - .iter() - .rposition(|s| s.name == "traceloop_hub" && s.span_kind == SpanKind::Server) - { - let root = &new_spans[root_idx]; - let root_trace_id = root.span_context.trace_id(); - - // Collect all spans with the same trace_id - return new_spans - .into_iter() - .filter(|s| s.span_context.trace_id() == root_trace_id) - .collect(); + use std::time::Duration; + + // Retry for up to 100ms to wait for spans to be flushed + for _ in 0..10 { + // Get all spans + let all_spans = TEST_EXPORTER.get_finished_spans().unwrap(); + + // Skip the before_count spans and take only the next few + // We expect exactly 2 spans per test (root + LLM) + let new_spans: Vec = all_spans.into_iter().skip(before_count).collect(); + + // If we have at least 2 spans, we can proceed + if new_spans.len() >= 2 { + // If we have more than 2 spans, try to find the most recent root span + // and its immediate child + if new_spans.len() > 2 { + // Find the last root span (traceloop_hub with Server kind) + if let Some(root_idx) = new_spans.iter().rposition(|s| { + s.name == "traceloop_hub" && s.span_kind == SpanKind::Server + }) { + let root = &new_spans[root_idx]; + let root_trace_id = root.span_context.trace_id(); + + // Collect all spans with the same trace_id + return new_spans + .into_iter() + .filter(|s| s.span_context.trace_id() == root_trace_id) + .collect(); + } + } + + return new_spans; } + + // Wait a bit before retrying + std::thread::sleep(Duration::from_millis(10)); } - new_spans + // Last attempt - return whatever we have + let all_spans = TEST_EXPORTER.get_finished_spans().unwrap(); + all_spans.into_iter().skip(before_count).collect() } // Helper: Find root span (name="traceloop_hub", SpanKind::Server) @@ -962,6 +981,8 @@ mod tests { #[tokio::test] async fn test_chat_completions_success_spans() { + // Serialize span tests to avoid race conditions + let _lock = SPAN_TEST_LOCK.lock().unwrap(); // Initialize exporter let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); @@ -1080,6 +1101,7 @@ mod tests { #[tokio::test] async fn test_completions_success_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); @@ -1148,6 +1170,7 @@ mod tests { #[tokio::test] async fn test_embeddings_success_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); @@ -1204,6 +1227,7 @@ mod tests { #[tokio::test] async fn test_chat_completions_error_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); @@ -1285,6 +1309,7 @@ mod tests { #[tokio::test] async fn test_completions_error_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); @@ -1341,6 +1366,7 @@ mod tests { #[tokio::test] async fn test_embeddings_error_spans() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); @@ -1385,6 +1411,7 @@ mod tests { #[tokio::test] async fn test_span_request_attributes_recorded() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); @@ -1467,6 +1494,7 @@ mod tests { #[tokio::test] async fn test_span_response_attributes_recorded() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); @@ -1551,6 +1579,7 @@ mod tests { #[tokio::test] async fn test_vendor_attribute_mapping() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; // Test each provider type maps to correct vendor name @@ -1632,6 +1661,7 @@ mod tests { #[tokio::test] async fn test_span_parent_child_relationship() { + let _lock = SPAN_TEST_LOCK.lock().unwrap(); let _ = &*TEST_EXPORTER; let before_count = TEST_EXPORTER.get_finished_spans().unwrap().len(); From 0b33b562c2a01c3c2971c14bbbd9827a8fa3e6a5 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 09:59:10 +0200 Subject: [PATCH 50/59] s1 --- src/guardrails/providers/mod.rs | 13 ------------- src/guardrails/providers/traceloop.rs | 2 +- src/guardrails/setup.rs | 9 ++++++--- src/guardrails/types.rs | 5 ++++- tests/guardrails/test_traceloop_client.rs | 8 -------- 5 files changed, 11 insertions(+), 26 deletions(-) diff --git a/src/guardrails/providers/mod.rs b/src/guardrails/providers/mod.rs index 600e5726..6d88b394 100644 --- a/src/guardrails/providers/mod.rs +++ b/src/guardrails/providers/mod.rs @@ -1,14 +1 @@ pub mod traceloop; - -use self::traceloop::TraceloopClient; -use super::types::{Guard, GuardrailClient}; - -pub const TRACELOOP_PROVIDER: &str = "traceloop"; - -/// Create a guardrail client based on the guard's provider type. -pub fn create_guardrail_client(guard: &Guard) -> Option> { - match guard.provider.as_str() { - TRACELOOP_PROVIDER => Some(Box::new(TraceloopClient::new())), - _ => None, - } -} diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 7ca19c1e..8eea4e9d 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; -use super::GuardrailClient; +use crate::guardrails::types::GuardrailClient; use crate::guardrails::evaluator_types::get_evaluator; use crate::guardrails::parsing::parse_evaluator_http_response; use crate::guardrails::types::{EvaluatorResponse, Guard, GuardrailError}; diff --git a/src/guardrails/setup.rs b/src/guardrails/setup.rs index 9ffa3b19..b213622e 100644 --- a/src/guardrails/setup.rs +++ b/src/guardrails/setup.rs @@ -83,7 +83,10 @@ pub fn build_guardrail_resources(config: &GuardrailsConfig) -> Option = Arc::new(super::providers::traceloop::TraceloopClient::new()); - Some((all_guards, client)) + Some(GuardrailResources { + guards: all_guards, + client, + }) } /// Build per-pipeline Guardrails from shared resources. @@ -93,8 +96,8 @@ pub fn build_pipeline_guardrails( pipeline_guard_names: &[String], ) -> Arc { Arc::new(Guardrails { - all_guards: shared.0.clone(), + all_guards: shared.guards.clone(), pipeline_guard_names: pipeline_guard_names.to_vec(), - client: shared.1.clone(), + client: shared.client.clone(), }) } diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index 7cc4e2d8..bdddb2cc 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -7,7 +7,10 @@ use thiserror::Error; /// Shared guardrail resources: resolved guards + client. /// Built once per router build and shared across all pipelines. -pub type GuardrailResources = (Arc>, Arc); +pub struct GuardrailResources { + pub guards: Arc>, + pub client: Arc, +} fn default_on_failure() -> OnFailure { OnFailure::Warn diff --git a/tests/guardrails/test_traceloop_client.rs b/tests/guardrails/test_traceloop_client.rs index e15e5a77..0b22cfa4 100644 --- a/tests/guardrails/test_traceloop_client.rs +++ b/tests/guardrails/test_traceloop_client.rs @@ -1,4 +1,3 @@ -use hub_lib::guardrails::providers::create_guardrail_client; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::types::GuardMode; use hub_lib::guardrails::types::GuardrailClient; @@ -152,10 +151,3 @@ async fn test_traceloop_client_rejects_empty_api_key() { ); } } - -#[test] -fn test_client_creation_from_guard_config() { - let guard = create_test_guard("test", GuardMode::PreCall); - let client = create_guardrail_client(&guard); - assert!(client.is_some()); -} From 65e43830bbd9262207a2b3e855b707de5fe0e3df Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 10:08:00 +0200 Subject: [PATCH 51/59] s2 --- src/guardrails/runner.rs | 31 +++++---- src/guardrails/setup.rs | 11 +-- src/guardrails/types.rs | 8 +-- src/pipelines/pipeline.rs | 138 ++++++-------------------------------- 4 files changed, 40 insertions(+), 148 deletions(-) diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index 598ab2f9..71a1ff7f 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -138,47 +138,46 @@ pub async fn execute_guards( if let Some(s) = span { guard_spans.push(s); } + let name = guard.name.clone(); match result { Ok(response) => { if response.pass { - results.push(GuardResult::Passed { - name: guard.name.clone(), - }); + results.push(GuardResult::Passed { name }); } else { - results.push(GuardResult::Failed { - name: guard.name.clone(), - result: response.result, - on_failure: guard.on_failure.clone(), - }); match guard.on_failure { OnFailure::Block => { blocked = true; if blocking_guard.is_none() { - blocking_guard = Some(guard.name.clone()); + blocking_guard = Some(name.clone()); } } OnFailure::Warn => { warnings.push(GuardWarning { - guard_name: guard.name.clone(), + guard_name: name.clone(), reason: "failed".to_string(), }); } } + results.push(GuardResult::Failed { + name, + result: response.result, + on_failure: guard.on_failure, + }); } } Err(err) => { let is_required = guard.required; - results.push(GuardResult::Error { - name: guard.name.clone(), - error: err.to_string(), - required: is_required, - }); if is_required { blocked = true; if blocking_guard.is_none() { - blocking_guard = Some(guard.name.clone()); + blocking_guard = Some(name.clone()); } } + results.push(GuardResult::Error { + name, + error: err.to_string(), + required: is_required, + }); } } } diff --git a/src/guardrails/setup.rs b/src/guardrails/setup.rs index b213622e..0d639002 100644 --- a/src/guardrails/setup.rs +++ b/src/guardrails/setup.rs @@ -42,17 +42,10 @@ pub fn resolve_guards_by_name( /// Split guards into (pre_call, post_call) lists by mode. pub fn split_guards_by_mode(guards: &[Guard]) -> (Vec, Vec) { - let pre_call: Vec = guards - .iter() - .filter(|g| g.mode == GuardMode::PreCall) - .cloned() - .collect(); - let post_call: Vec = guards + guards .iter() - .filter(|g| g.mode == GuardMode::PostCall) .cloned() - .collect(); - (pre_call, post_call) + .partition(|g| g.mode == GuardMode::PreCall) } /// Resolve provider defaults (api_base/api_key) for all guards in the config. diff --git a/src/guardrails/types.rs b/src/guardrails/types.rs index bdddb2cc..a738f92a 100644 --- a/src/guardrails/types.rs +++ b/src/guardrails/types.rs @@ -20,14 +20,14 @@ fn default_required() -> bool { false } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] #[serde(rename_all = "snake_case")] pub enum GuardMode { PreCall, PostCall, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] #[serde(rename_all = "snake_case")] pub enum OnFailure { Block, @@ -68,7 +68,7 @@ impl Hash for Guard { self.evaluator_slug.hash(state); // Hash params by sorting keys and hashing serialized values let mut params_vec: Vec<_> = self.params.iter().collect(); - params_vec.sort_by_key(|(k, _)| (*k).clone()); + params_vec.sort_by(|a, b| a.0.cmp(b.0)); for (k, v) in params_vec { k.hash(state); v.to_string().hash(state); @@ -117,7 +117,7 @@ where impl Hash for GuardrailsConfig { fn hash(&self, state: &mut H) { let mut entries: Vec<_> = self.providers.iter().collect(); - entries.sort_by_key(|(k, _)| (*k).clone()); + entries.sort_by(|a, b| a.0.cmp(b.0)); for (k, v) in entries { k.hash(state); v.hash(state); diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index bf333ac6..32f08f97 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -497,24 +497,20 @@ mod tests { assert!(!ids.contains(&"test-model-2")); } - // Test providers with different types for vendor testing + // Parameterized test provider for vendor testing #[derive(Clone)] - struct TestProviderOpenAI; - #[derive(Clone)] - struct TestProviderAnthropic; - #[derive(Clone)] - struct TestProviderAzure; + struct TestProvider(ProviderType); #[async_trait] - impl Provider for TestProviderOpenAI { + impl Provider for TestProvider { fn new(_config: &ProviderConfig) -> Self { - Self + Self(ProviderType::OpenAI) } fn key(&self) -> String { - "openai-key".to_string() + format!("{:?}-key", self.0).to_lowercase() } fn r#type(&self) -> ProviderType { - ProviderType::OpenAI + self.0.clone() } async fn chat_completions( @@ -527,101 +523,7 @@ mod tests { id: "test".to_string(), object: None, created: None, - model: "gpt-4".to_string(), - choices: vec![], - usage: crate::models::usage::Usage::default(), - system_fingerprint: None, - }, - )) - } - - async fn completions( - &self, - _payload: CompletionRequest, - _model_config: &ModelConfig, - ) -> Result { - Err(StatusCode::NOT_IMPLEMENTED) - } - - async fn embeddings( - &self, - _payload: EmbeddingsRequest, - _model_config: &ModelConfig, - ) -> Result { - Err(StatusCode::NOT_IMPLEMENTED) - } - } - - #[async_trait] - impl Provider for TestProviderAnthropic { - fn new(_config: &ProviderConfig) -> Self { - Self - } - fn key(&self) -> String { - "anthropic-key".to_string() - } - fn r#type(&self) -> ProviderType { - ProviderType::Anthropic - } - - async fn chat_completions( - &self, - _payload: crate::models::chat::ChatCompletionRequest, - _model_config: &ModelConfig, - ) -> Result { - Ok(crate::models::chat::ChatCompletionResponse::NonStream( - crate::models::chat::ChatCompletion { - id: "test".to_string(), - object: None, - created: None, - model: "claude-3".to_string(), - choices: vec![], - usage: crate::models::usage::Usage::default(), - system_fingerprint: None, - }, - )) - } - - async fn completions( - &self, - _payload: CompletionRequest, - _model_config: &ModelConfig, - ) -> Result { - Err(StatusCode::NOT_IMPLEMENTED) - } - - async fn embeddings( - &self, - _payload: EmbeddingsRequest, - _model_config: &ModelConfig, - ) -> Result { - Err(StatusCode::NOT_IMPLEMENTED) - } - } - - #[async_trait] - impl Provider for TestProviderAzure { - fn new(_config: &ProviderConfig) -> Self { - Self - } - fn key(&self) -> String { - "azure-key".to_string() - } - fn r#type(&self) -> ProviderType { - ProviderType::Azure - } - - async fn chat_completions( - &self, - _payload: crate::models::chat::ChatCompletionRequest, - _model_config: &ModelConfig, - ) -> Result { - Ok(crate::models::chat::ChatCompletionResponse::NonStream( - crate::models::chat::ChatCompletion { - id: "test".to_string(), - object: None, - created: None, - model: "gpt-4".to_string(), + model: "test".to_string(), choices: vec![], usage: crate::models::usage::Usage::default(), system_fingerprint: None, @@ -659,20 +561,18 @@ mod tests { #[test] fn test_provider_type_methods() { - // Test that our test providers return the correct types - // This ensures the pipeline would call set_vendor with the right values - let openai_provider = TestProviderOpenAI; - let anthropic_provider = TestProviderAnthropic; - let azure_provider = TestProviderAzure; - - assert_eq!(openai_provider.r#type(), ProviderType::OpenAI); - assert_eq!(anthropic_provider.r#type(), ProviderType::Anthropic); - assert_eq!(azure_provider.r#type(), ProviderType::Azure); - - // Test that these map to the correct vendor names - assert_eq!(get_vendor_name(&openai_provider.r#type()), "openai"); - assert_eq!(get_vendor_name(&anthropic_provider.r#type()), "Anthropic"); - assert_eq!(get_vendor_name(&azure_provider.r#type()), "Azure"); + // Test that our test provider returns the correct types + // and maps to the correct vendor names + let cases = [ + (ProviderType::OpenAI, "openai"), + (ProviderType::Anthropic, "Anthropic"), + (ProviderType::Azure, "Azure"), + ]; + for (provider_type, expected_vendor) in cases { + let provider = TestProvider(provider_type.clone()); + assert_eq!(provider.r#type(), provider_type); + assert_eq!(get_vendor_name(&provider.r#type()), expected_vendor); + } } // OpenTelemetry span verification tests From 36b0e1bb2c0061122c60a5afde6d4ab6bae92bc1 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 10:17:09 +0200 Subject: [PATCH 52/59] s3 --- src/guardrails/middleware.rs | 21 +++++------ src/guardrails/runner.rs | 37 +++++-------------- src/pipelines/otel.rs | 71 ++++++++++++------------------------ src/pipelines/pipeline.rs | 7 +--- src/state.rs | 2 +- 5 files changed, 45 insertions(+), 93 deletions(-) diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index b269b1c3..b05e789a 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -32,9 +32,9 @@ impl EndpointType { /// Determine endpoint type from request path. fn from_path(path: &str) -> Option { match path { - p if p.contains("/chat/completions") => Some(Self::Chat), - p if p.contains("/completions") => Some(Self::Completion), - p if p.contains("/embeddings") => Some(Self::Embeddings), + p if p.ends_with("/chat/completions") => Some(Self::Chat), + p if p.ends_with("/completions") => Some(Self::Completion), + p if p.ends_with("/embeddings") => Some(Self::Embeddings), _ => None, } } @@ -114,10 +114,10 @@ async fn handle_post_call_guards( }; if let Some(result) = post_result { - if let Some(blocked) = result.blocked_response { - return blocked; + match result { + Err(blocked) => return blocked, + Ok(w) => warnings.extend(w), } - warnings.extend(result.warnings); } let response = Response::from_parts(resp_parts, Body::from(resp_bytes)); @@ -265,11 +265,10 @@ where }; // --- Pre-call guards --- - let pre_result = runner.run_pre_call(&parsed_request).await; - if let Some(blocked) = pre_result.blocked_response { - return Ok(blocked); - } - let all_warnings = pre_result.warnings; + let all_warnings = match runner.run_pre_call(&parsed_request).await { + Ok(warnings) => warnings, + Err(blocked) => return Ok(blocked), + }; // --- Call inner service --- let request = Request::from_parts(parts, Body::from(bytes)); diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index 71a1ff7f..180250a8 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -194,10 +194,8 @@ pub async fn execute_guards( } } -pub struct GuardPhaseResult { - pub blocked_response: Option, - pub warnings: Vec, -} +/// Result of a guard phase: Ok(warnings) on pass, Err(blocked_response) on block. +pub type GuardPhaseResult = Result, Response>; pub struct GuardrailsRunner<'a> { pre_call: Vec, @@ -210,15 +208,9 @@ pub struct GuardrailsRunner<'a> { /// If the outcome is blocked, produces a blocked response; otherwise, forwards warnings. fn outcome_to_phase_result(outcome: GuardrailsOutcome) -> GuardPhaseResult { if outcome.blocked { - GuardPhaseResult { - blocked_response: Some(blocked_response(&outcome)), - warnings: Vec::new(), - } + Err(blocked_response(&outcome)) } else { - GuardPhaseResult { - blocked_response: None, - warnings: outcome.warnings, - } + Ok(outcome.warnings) } } @@ -247,10 +239,7 @@ impl<'a> GuardrailsRunner<'a> { /// Run pre-call guards, extracting input from the request only if guards exist. pub async fn run_pre_call(&self, request: &impl PromptExtractor) -> GuardPhaseResult { if self.pre_call.is_empty() { - return GuardPhaseResult { - blocked_response: None, - warnings: Vec::new(), - }; + return Ok(Vec::new()); } let input = request.extract_prompt(); let outcome = @@ -261,22 +250,16 @@ impl<'a> GuardrailsRunner<'a> { /// Run post-call guards, extracting input from the response only if guards exist. pub async fn run_post_call(&self, response: &impl CompletionExtractor) -> GuardPhaseResult { if self.post_call.is_empty() { - return GuardPhaseResult { - blocked_response: None, - warnings: Vec::new(), - }; + return Ok(Vec::new()); } let completion = response.extract_completion(); if completion.is_empty() { warn!("Skipping post-call guardrails: LLM response content is empty"); - return GuardPhaseResult { - blocked_response: None, - warnings: vec![GuardWarning { - guard_name: "all post_call guards".to_string(), - reason: "skipped due to empty response content".to_string(), - }], - }; + return Ok(vec![GuardWarning { + guard_name: "all post_call guards".to_string(), + reason: "skipped due to empty response content".to_string(), + }]); } let outcome = execute_guards( diff --git a/src/pipelines/otel.rs b/src/pipelines/otel.rs index 6a41f43c..50fbf2ec 100644 --- a/src/pipelines/otel.rs +++ b/src/pipelines/otel.rs @@ -223,29 +223,28 @@ impl OtelTracer { } } +fn set_optional_f64(span: &mut BoxedSpan, key: &'static str, value: Option) { + if let Some(v) = value { + span.set_attribute(KeyValue::new(key, v as f64)); + } +} + +fn content_to_string(content: &ChatMessageContent) -> String { + match content { + ChatMessageContent::String(s) => s.clone(), + ChatMessageContent::Array(parts) => serde_json::to_string(parts).unwrap_or_default(), + } +} + impl RecordSpan for ChatCompletionRequest { fn record_span(&self, span: &mut BoxedSpan) { span.set_attribute(KeyValue::new("llm.request.type", "chat")); span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone())); - if let Some(freq_penalty) = self.frequency_penalty { - span.set_attribute(KeyValue::new( - GEN_AI_REQUEST_FREQUENCY_PENALTY, - freq_penalty as f64, - )); - } - if let Some(pres_penalty) = self.presence_penalty { - span.set_attribute(KeyValue::new( - GEN_AI_REQUEST_PRESENCE_PENALTY, - pres_penalty as f64, - )); - } - if let Some(top_p) = self.top_p { - span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TOP_P, top_p as f64)); - } - if let Some(temp) = self.temperature { - span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temp as f64)); - } + set_optional_f64(span, GEN_AI_REQUEST_FREQUENCY_PENALTY, self.frequency_penalty); + set_optional_f64(span, GEN_AI_REQUEST_PRESENCE_PENALTY, self.presence_penalty); + set_optional_f64(span, GEN_AI_REQUEST_TOP_P, self.top_p); + set_optional_f64(span, GEN_AI_REQUEST_TEMPERATURE, self.temperature); if get_trace_content_enabled() { for (i, message) in self.messages.iter().enumerate() { @@ -256,12 +255,7 @@ impl RecordSpan for ChatCompletionRequest { )); span.set_attribute(KeyValue::new( format!("gen_ai.prompt.{i}.content"), - match &content { - ChatMessageContent::String(content) => content.clone(), - ChatMessageContent::Array(content) => { - serde_json::to_string(content).unwrap_or_default() - } - }, + content_to_string(content), )); } } @@ -285,12 +279,7 @@ impl RecordSpan for ChatCompletion { )); span.set_attribute(KeyValue::new( format!("gen_ai.completion.{}.content", choice.index), - match &content { - ChatMessageContent::String(content) => content.clone(), - ChatMessageContent::Array(content) => { - serde_json::to_string(content).unwrap_or_default() - } - }, + content_to_string(content), )); } span.set_attribute(KeyValue::new( @@ -308,24 +297,10 @@ impl RecordSpan for CompletionRequest { span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone())); span.set_attribute(KeyValue::new("gen_ai.prompt", self.prompt.clone())); - if let Some(freq_penalty) = self.frequency_penalty { - span.set_attribute(KeyValue::new( - GEN_AI_REQUEST_FREQUENCY_PENALTY, - freq_penalty as f64, - )); - } - if let Some(pres_penalty) = self.presence_penalty { - span.set_attribute(KeyValue::new( - GEN_AI_REQUEST_PRESENCE_PENALTY, - pres_penalty as f64, - )); - } - if let Some(top_p) = self.top_p { - span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TOP_P, top_p as f64)); - } - if let Some(temp) = self.temperature { - span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temp as f64)); - } + set_optional_f64(span, GEN_AI_REQUEST_FREQUENCY_PENALTY, self.frequency_penalty); + set_optional_f64(span, GEN_AI_REQUEST_PRESENCE_PENALTY, self.presence_penalty); + set_optional_f64(span, GEN_AI_REQUEST_TOP_P, self.top_p); + set_optional_f64(span, GEN_AI_REQUEST_TEMPERATURE, self.temperature); } } diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index 32f08f97..43ffd60c 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -1,5 +1,6 @@ use crate::config::models::PipelineType; use crate::guardrails::middleware::GuardrailsLayer; +use crate::guardrails::setup::build_pipeline_guardrails; use crate::guardrails::types::GuardrailResources; use crate::models::chat::ChatCompletionResponse; use crate::models::completion::CompletionRequest; @@ -33,12 +34,6 @@ fn with_tracer(tracer: &SharedTracer, f: impl FnOnce(&mut OtelTracer) -> R) - f(&mut tracer.lock().unwrap()) } -// Re-export builder and orchestrator functions for backward compatibility with tests -pub use crate::guardrails::runner::{blocked_response, warning_header_value}; -pub use crate::guardrails::setup::{ - build_guardrail_resources, build_pipeline_guardrails, resolve_guard_defaults, -}; - pub fn create_pipeline( pipeline: &Pipeline, model_registry: &ModelRegistry, diff --git a/src/state.rs b/src/state.rs index a3e17c0b..2850e6c8 100644 --- a/src/state.rs +++ b/src/state.rs @@ -326,7 +326,7 @@ impl tower::Service> for PipelineSteeringService { }; Box::pin(async move { - let router = Arc::try_unwrap(router).unwrap_or_else(|arc_router| (*arc_router).clone()); + let router: Router = (*router).clone(); match router.oneshot(request).await { Ok(response) => Ok(response), From 4817abc79ed0bb9ddf45feb9b230ef59bf81a762 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 10:32:17 +0200 Subject: [PATCH 53/59] ci --- src/guardrails/middleware.rs | 4 ++-- src/guardrails/providers/traceloop.rs | 2 +- src/guardrails/runner.rs | 4 ++-- src/pipelines/otel.rs | 12 ++++++++++-- tests/guardrails/test_e2e.rs | 8 ++++---- 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index b05e789a..c5db4e9e 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -115,7 +115,7 @@ async fn handle_post_call_guards( if let Some(result) = post_result { match result { - Err(blocked) => return blocked, + Err(blocked) => return *blocked, Ok(w) => warnings.extend(w), } } @@ -267,7 +267,7 @@ where // --- Pre-call guards --- let all_warnings = match runner.run_pre_call(&parsed_request).await { Ok(warnings) => warnings, - Err(blocked) => return Ok(blocked), + Err(blocked) => return Ok(*blocked), }; // --- Call inner service --- diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 8eea4e9d..01bd72d3 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; -use crate::guardrails::types::GuardrailClient; use crate::guardrails::evaluator_types::get_evaluator; use crate::guardrails::parsing::parse_evaluator_http_response; +use crate::guardrails::types::GuardrailClient; use crate::guardrails::types::{EvaluatorResponse, Guard, GuardrailError}; const DEFAULT_TRACELOOP_API: &str = "https://api.traceloop.com"; diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index 180250a8..860f34e7 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -195,7 +195,7 @@ pub async fn execute_guards( } /// Result of a guard phase: Ok(warnings) on pass, Err(blocked_response) on block. -pub type GuardPhaseResult = Result, Response>; +pub type GuardPhaseResult = Result, Box>; pub struct GuardrailsRunner<'a> { pre_call: Vec, @@ -208,7 +208,7 @@ pub struct GuardrailsRunner<'a> { /// If the outcome is blocked, produces a blocked response; otherwise, forwards warnings. fn outcome_to_phase_result(outcome: GuardrailsOutcome) -> GuardPhaseResult { if outcome.blocked { - Err(blocked_response(&outcome)) + Err(Box::new(blocked_response(&outcome))) } else { Ok(outcome.warnings) } diff --git a/src/pipelines/otel.rs b/src/pipelines/otel.rs index 50fbf2ec..82bd8de6 100644 --- a/src/pipelines/otel.rs +++ b/src/pipelines/otel.rs @@ -241,7 +241,11 @@ impl RecordSpan for ChatCompletionRequest { span.set_attribute(KeyValue::new("llm.request.type", "chat")); span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone())); - set_optional_f64(span, GEN_AI_REQUEST_FREQUENCY_PENALTY, self.frequency_penalty); + set_optional_f64( + span, + GEN_AI_REQUEST_FREQUENCY_PENALTY, + self.frequency_penalty, + ); set_optional_f64(span, GEN_AI_REQUEST_PRESENCE_PENALTY, self.presence_penalty); set_optional_f64(span, GEN_AI_REQUEST_TOP_P, self.top_p); set_optional_f64(span, GEN_AI_REQUEST_TEMPERATURE, self.temperature); @@ -297,7 +301,11 @@ impl RecordSpan for CompletionRequest { span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone())); span.set_attribute(KeyValue::new("gen_ai.prompt", self.prompt.clone())); - set_optional_f64(span, GEN_AI_REQUEST_FREQUENCY_PENALTY, self.frequency_penalty); + set_optional_f64( + span, + GEN_AI_REQUEST_FREQUENCY_PENALTY, + self.frequency_penalty, + ); set_optional_f64(span, GEN_AI_REQUEST_PRESENCE_PENALTY, self.presence_penalty); set_optional_f64(span, GEN_AI_REQUEST_TOP_P, self.top_p); set_optional_f64(span, GEN_AI_REQUEST_TEMPERATURE, self.temperature); diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 07fb50c1..b76923f6 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -605,11 +605,11 @@ async fn test_post_call_skipped_on_empty_response() { let empty_completion = create_test_chat_completion(""); let result = runner.run_post_call(&empty_completion).await; - assert!(result.blocked_response.is_none()); - assert_eq!(result.warnings.len(), 1); - assert!(result.warnings[0].reason.contains("empty response content")); + let warnings = result.expect("should not be blocked"); + assert_eq!(warnings.len(), 1); + assert!(warnings[0].reason.contains("empty response content")); - let header = warning_header_value(&result.warnings); + let header = warning_header_value(&warnings); assert!(header.contains("skipped")); // wiremock will verify expect(0) — evaluator was never called } From 853c8c9d77ba55c9d918934a6afcac97df5dd1c6 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 15:31:59 +0200 Subject: [PATCH 54/59] enable pre_call for streaming --- src/guardrails/GUARDRAILS.md | 2 +- src/guardrails/middleware.rs | 14 +- tests/guardrails/test_middleware.rs | 272 +++++++++++++++++++++++++--- 3 files changed, 252 insertions(+), 36 deletions(-) diff --git a/src/guardrails/GUARDRAILS.md b/src/guardrails/GUARDRAILS.md index edf75dcc..f934ad61 100644 --- a/src/guardrails/GUARDRAILS.md +++ b/src/guardrails/GUARDRAILS.md @@ -46,7 +46,7 @@ This document focuses on **config mode** available in Traceloop Hub v1. - ✅ Chat Completions (`/v1/chat/completions`) — pre-call and post-call guards - ✅ Completions (`/v1/completions`) — pre-call and post-call guards - ✅ Embeddings (`/v1/embeddings`) — **pre-call guards only** -- ❌ Streaming requests are **not supported** — guardrails require the complete request/response for evaluation +- ⚠️ Streaming requests (`"stream": true`) — **pre-call guards only** (post-call guards are skipped because the response is sent as incremental chunks) --- diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index c5db4e9e..fa5d1941 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -58,7 +58,12 @@ impl ParsedRequest { } /// Returns true if this request type supports post-call guards. + /// Streaming requests do not support post-call guards because the response + /// is sent as chunks and cannot be buffered for evaluation. fn supports_post_call(&self) -> bool { + if self.is_streaming() { + return false; + } match self { ParsedRequest::Chat(_) | ParsedRequest::Completion(_) => true, ParsedRequest::Embeddings(_) => false, @@ -143,7 +148,7 @@ fn try_parse(bytes: &[u8], label: &str) -> Option { /// /// - **Pre-call guards** run before the inner service, inspecting the request body. /// - **Post-call guards** run after the inner service, inspecting the response body. -/// - Streaming requests (`"stream": true`) bypass guardrails entirely. +/// - Streaming requests (`"stream": true`) run pre-call guards but skip post-call guards. #[derive(Clone)] pub struct GuardrailsLayer { guardrails: Option>, @@ -240,13 +245,6 @@ where } }; - // Skip guardrails for streaming requests - if parsed_request.is_streaming() { - debug!("Guardrails middleware: streaming request, skipping guardrails"); - let request = Request::from_parts(parts, Body::from(bytes)); - return inner.call(request).await; - } - // Resolve guards from pipeline config + request headers // Extract parent context from the tracer in request extensions let parent_cx = parts diff --git a/tests/guardrails/test_middleware.rs b/tests/guardrails/test_middleware.rs index b18e1341..51cebc33 100644 --- a/tests/guardrails/test_middleware.rs +++ b/tests/guardrails/test_middleware.rs @@ -451,20 +451,20 @@ async fn test_post_call_guard_skipped_for_embeddings() { // =========================================================================== #[tokio::test] -async fn test_streaming_chat_bypasses_guards() { - // Set up mock evaluator (should never be called) +async fn test_streaming_chat_runs_pre_call_guards() { + // Set up mock evaluator (should be called for pre-call) let eval_server = MockServer::start().await; - Mock::given(matchers::any()) + Mock::given(matchers::method("POST")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "result": {}, - "pass": false // Would block if evaluated + "pass": true }))) - .expect(0) // Should never be called for streaming + .expect(1) // Pre-call guard should run even for streaming .mount(&eval_server) .await; let guard = guard_with_server( - "blocker", + "detector", GuardMode::PreCall, OnFailure::Block, &eval_server.uri(), @@ -478,7 +478,7 @@ async fn test_streaming_chat_bypasses_guards() { let mut service = layer.layer(inner_service); // Create STREAMING chat request - let request = create_streaming_chat_request("This would fail guards if checked"); + let request = create_streaming_chat_request("Safe input"); let http_request = Request::builder() .method("POST") .uri("/v1/chat/completions") @@ -494,34 +494,27 @@ async fn test_streaming_chat_bypasses_guards() { .await .unwrap(); - // Verify response is 200 OK (streaming bypasses guards) + // Verify response is 200 OK (pre-call guard passed) assert_eq!(response.status(), StatusCode::OK); - // Verify no warning header - assert!( - !response - .headers() - .contains_key("x-traceloop-guardrail-warning") - ); - - // Wiremock verifies evaluator was never called (expect(0)) + // Wiremock verifies evaluator was called exactly once (expect(1)) } #[tokio::test] -async fn test_streaming_completion_bypasses_guards() { - // Set up mock evaluator (should never be called) +async fn test_streaming_completion_runs_pre_call_guards() { + // Set up mock evaluator (should be called for pre-call) let eval_server = MockServer::start().await; - Mock::given(matchers::any()) + Mock::given(matchers::method("POST")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "result": {}, - "pass": false // Would block if evaluated + "pass": true }))) - .expect(0) // Should never be called for streaming + .expect(1) // Pre-call guard should run even for streaming .mount(&eval_server) .await; let guard = guard_with_server( - "blocker", + "detector", GuardMode::PreCall, OnFailure::Block, &eval_server.uri(), @@ -535,7 +528,7 @@ async fn test_streaming_completion_bypasses_guards() { let mut service = layer.layer(inner_service); // Create STREAMING completion request - let request = create_streaming_completion_request("This would fail guards if checked"); + let request = create_streaming_completion_request("Safe input"); let http_request = Request::builder() .method("POST") .uri("/v1/completions") @@ -551,17 +544,242 @@ async fn test_streaming_completion_bypasses_guards() { .await .unwrap(); - // Verify response is 200 OK (streaming bypasses guards) + // Verify response is 200 OK (pre-call guard passed) assert_eq!(response.status(), StatusCode::OK); - // Verify no warning header + // Wiremock verifies evaluator was called exactly once (expect(1)) +} + +#[tokio::test] +async fn test_streaming_chat_pre_call_blocks() { + // Set up mock evaluator that fails + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "toxic content"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("This won't be returned"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create STREAMING chat request with bad input + let request = create_streaming_chat_request("Bad input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify blocked by pre-call guard even for streaming + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["error"]["guardrail"], "blocker"); +} + +#[tokio::test] +async fn test_streaming_chat_pre_call_warns() { + // Set up mock evaluator that fails with warn + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "borderline"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "warner", + GuardMode::PreCall, + OnFailure::Warn, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Response text"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_streaming_chat_request("Borderline input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify passes with warning even for streaming + assert_eq!(response.status(), StatusCode::OK); assert!( - !response + response .headers() .contains_key("x-traceloop-guardrail-warning") ); - // Wiremock verifies evaluator was never called (expect(0)) + let warning_header = response + .headers() + .get("x-traceloop-guardrail-warning") + .unwrap() + .to_str() + .unwrap(); + assert!(warning_header.contains("warner")); +} + +#[tokio::test] +async fn test_streaming_chat_post_call_skipped() { + // Set up mock evaluator for pre-call (should pass) + let pre_eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) // Pre-call should run + .mount(&pre_eval_server) + .await; + + // Set up mock evaluator for post-call (should NOT be called) + let post_eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": false // Would block if called + }))) + .expect(0) // Post-call should NOT run for streaming + .mount(&post_eval_server) + .await; + + let pre_guard = guard_with_server( + "pre-guard", + GuardMode::PreCall, + OnFailure::Block, + &pre_eval_server.uri(), + ); + let post_guard = guard_with_server( + "post-guard", + GuardMode::PostCall, + OnFailure::Block, + &post_eval_server.uri(), + ); + let guardrails = create_guardrails(vec![pre_guard, post_guard]); + + let completion = create_test_chat_completion("Streamed response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_streaming_chat_request("Safe input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify 200 OK — post-call guard was skipped for streaming + assert_eq!(response.status(), StatusCode::OK); + + // Wiremock verifies: + // - pre-call evaluator called exactly once (expect(1)) + // - post-call evaluator never called (expect(0)) +} + +#[tokio::test] +async fn test_streaming_completion_pre_call_blocks() { + // Set up mock evaluator that fails + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {"reason": "toxic content"}, + "pass": false + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "blocker", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_completion_response("This won't be returned"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Create STREAMING completion request with bad input + let request = create_streaming_completion_request("Bad input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Verify blocked by pre-call guard even for streaming + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(response_json["error"]["guardrail"], "blocker"); } // =========================================================================== From f259218e6078b0ddfbd009cbb48a268a06d5f70e Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 15:45:06 +0200 Subject: [PATCH 55/59] fix1 --- src/guardrails/middleware.rs | 8 ++- tests/guardrails/test_middleware.rs | 98 ++++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index fa5d1941..23f1f6a9 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -20,6 +20,10 @@ use super::parsing::PromptExtractor; use super::runner::GuardrailsRunner; use super::types::Guardrails; +/// Maximum allowed body size for request/response buffering (10 MB). +/// Prevents unbounded memory allocation when guardrails need to inspect bodies. +pub const MAX_BODY_SIZE: usize = 10 * 1024 * 1024; + /// Enum representing the endpoint type. #[derive(Debug, Clone, Copy)] enum EndpointType { @@ -89,7 +93,7 @@ async fn handle_post_call_guards( runner: &GuardrailsRunner<'_>, mut warnings: Vec, ) -> Response { - let resp_bytes = match axum::body::to_bytes(resp_body, usize::MAX).await { + let resp_bytes = match axum::body::to_bytes(resp_body, MAX_BODY_SIZE).await { Ok(b) => b, Err(_) => { debug!("Guardrails middleware: failed to buffer response body, skipping post-call"); @@ -220,7 +224,7 @@ where }; // Buffer request body - let bytes = match axum::body::to_bytes(body, usize::MAX).await { + let bytes = match axum::body::to_bytes(body, MAX_BODY_SIZE).await { Ok(b) => b, Err(_) => { debug!("Guardrails middleware: failed to buffer request body, passing through"); diff --git a/tests/guardrails/test_middleware.rs b/tests/guardrails/test_middleware.rs index 51cebc33..d0163293 100644 --- a/tests/guardrails/test_middleware.rs +++ b/tests/guardrails/test_middleware.rs @@ -1,4 +1,4 @@ -use hub_lib::guardrails::middleware::GuardrailsLayer; +use hub_lib::guardrails::middleware::{GuardrailsLayer, MAX_BODY_SIZE}; use hub_lib::guardrails::providers::traceloop::TraceloopClient; use hub_lib::guardrails::types::{Guard, GuardMode, Guardrails, OnFailure}; @@ -859,3 +859,99 @@ async fn test_unsupported_endpoint_passes() { // Verify passes through assert_eq!(response.status(), StatusCode::OK); } + +// =========================================================================== +// Category 6: Body Size Limits +// =========================================================================== + +#[tokio::test] +async fn test_request_body_exceeding_limit_returns_400() { + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(0) // Should never reach the evaluator + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + // Body larger than MAX_BODY_SIZE (10 MB) + let oversized = vec![b'x'; MAX_BODY_SIZE + 1]; + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(oversized)) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn test_request_body_within_limit_is_accepted() { + let eval_server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "result": {}, + "pass": true + }))) + .expect(1) + .mount(&eval_server) + .await; + + let guard = guard_with_server( + "detector", + GuardMode::PreCall, + OnFailure::Block, + &eval_server.uri(), + ); + let guardrails = create_guardrails(vec![guard]); + + let completion = create_test_chat_completion("Response"); + let inner_service = MockService::with_json(StatusCode::OK, &completion); + + let layer = GuardrailsLayer::new(Some(Arc::new(guardrails))); + let mut service = layer.layer(inner_service); + + let request = create_test_chat_request("Test input"); + let http_request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&request).unwrap())) + .unwrap(); + + let response = service + .ready() + .await + .unwrap() + .call(http_request) + .await + .unwrap(); + + // Should pass through (200 from inner service) + assert_eq!(response.status(), StatusCode::OK); +} From d3b7063742603387753b0a7b28119480cf417f45 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 15:53:01 +0200 Subject: [PATCH 56/59] test unsafe fix --- Cargo.lock | 10 ++++++++++ Cargo.toml | 1 + tests/guardrails/test_e2e.rs | 17 +++++++---------- tests/guardrails/test_types.rs | 15 ++++++--------- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e0bae977..7f109bb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2336,6 +2336,7 @@ dependencies = [ "sqlx", "surf", "surf-vcr", + "temp-env", "tempfile", "testcontainers", "testcontainers-modules", @@ -5102,6 +5103,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" +[[package]] +name = "temp-env" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96374855068f47402c3121c6eed88d29cb1de8f3ab27090e273e420bdabcf050" +dependencies = [ + "parking_lot", +] + [[package]] name = "tempfile" version = "3.24.0" diff --git a/Cargo.toml b/Cargo.toml index fd13ebd6..5d6f3662 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,6 +73,7 @@ tempfile = "3.8" testcontainers = "0.20.0" testcontainers-modules = { version = "0.8.0", features = ["postgres"] } axum-test = "17" +temp-env = "0.3" sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "migrate"] } tokio = { version = "1.45.0", features = ["full"] } reqwest = { version = "0.12", features = ["json"] } diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index b76923f6..3035b5e3 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -256,10 +256,6 @@ async fn test_e2e_config_from_yaml_with_env_vars() { use std::io::Write; use tempfile::NamedTempFile; - unsafe { - std::env::set_var("E2E_TEST_API_KEY", "resolved-key-123"); - } - let config_yaml = r#" providers: - key: openai @@ -297,9 +293,14 @@ guardrails: let mut temp_file = NamedTempFile::new().unwrap(); temp_file.write_all(config_yaml.as_bytes()).unwrap(); + let temp_path = temp_file.path().to_str().unwrap().to_owned(); - let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); - let gr = config.guardrails.unwrap(); + let gr = + temp_env::with_var("E2E_TEST_API_KEY", Some("resolved-key-123"), || { + let config = + hub_lib::config::load_config(&temp_path).unwrap(); + config.guardrails.unwrap() + }); assert_eq!(gr.providers.len(), 1); assert_eq!(gr.providers["traceloop"].api_key, "resolved-key-123"); @@ -337,10 +338,6 @@ guardrails: assert_eq!(pre_guard.api_key.as_deref(), Some("resolved-key-123")); // Guard with override keeps its own api_key assert_eq!(post_guard.api_key.as_deref(), Some("override-key")); - - unsafe { - std::env::remove_var("E2E_TEST_API_KEY"); - } } #[tokio::test] diff --git a/tests/guardrails/test_types.rs b/tests/guardrails/test_types.rs index d7071c60..f701358c 100644 --- a/tests/guardrails/test_types.rs +++ b/tests/guardrails/test_types.rs @@ -127,9 +127,6 @@ fn test_gateway_config_without_guardrails_backward_compat() { #[test] fn test_guard_config_env_var_in_api_key() { - unsafe { - std::env::set_var("TEST_GUARD_API_KEY_UNIQUE", "tl-secret-key"); - } let config_content = r#" providers: - key: openai @@ -157,12 +154,12 @@ guardrails: "#; let mut temp_file = NamedTempFile::new().unwrap(); temp_file.write_all(config_content.as_bytes()).unwrap(); - let config = hub_lib::config::load_config(temp_file.path().to_str().unwrap()).unwrap(); - let guards = config.guardrails.unwrap().guards; - assert_eq!(guards[0].api_key.as_deref(), Some("tl-secret-key")); - unsafe { - std::env::remove_var("TEST_GUARD_API_KEY_UNIQUE"); - } + let temp_path = temp_file.path().to_str().unwrap().to_owned(); + temp_env::with_var("TEST_GUARD_API_KEY_UNIQUE", Some("tl-secret-key"), || { + let config = hub_lib::config::load_config(&temp_path).unwrap(); + let guards = config.guardrails.unwrap().guards; + assert_eq!(guards[0].api_key.as_deref(), Some("tl-secret-key")); + }); } // --------------------------------------------------------------------------- From 1a9013860afdc2adb6de8184c650248d83fe5972 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 16:00:31 +0200 Subject: [PATCH 57/59] fix 3 --- src/config/validation.rs | 51 +++++++++++++++++---------------- src/guardrails/runner.rs | 8 +++++- tests/guardrails/test_runner.rs | 6 ++++ 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/src/config/validation.rs b/src/config/validation.rs index 63e49314..bd8ab5ac 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -51,37 +51,38 @@ pub fn validate_gateway_config(config: &GatewayConfig) -> Result<(), Vec // Validate all guards in a single pass let mut seen_guard_names = HashSet::new(); for guard in &gr_config.guards { - // Check provider reference exists + // Check provider reference exists; skip api_base/api_key checks + // when the provider is missing since those would be redundant. if !gr_config.providers.contains_key(&guard.provider) { errors.push(format!( "Guard '{}' references non-existent guardrail provider '{}'.", guard.name, guard.provider )); - } - - // Check api_base and api_key (either directly or via provider) - let has_api_base = guard.api_base.as_ref().is_some_and(|s| !s.is_empty()) - || gr_config - .providers - .get(&guard.provider) - .is_some_and(|p| !p.api_base.is_empty()); - let has_api_key = guard.api_key.as_ref().is_some_and(|s| !s.is_empty()) - || gr_config - .providers - .get(&guard.provider) - .is_some_and(|p| !p.api_key.is_empty()); + } else { + // Check api_base and api_key (either directly or via provider) + let has_api_base = guard.api_base.as_ref().is_some_and(|s| !s.is_empty()) + || gr_config + .providers + .get(&guard.provider) + .is_some_and(|p| !p.api_base.is_empty()); + let has_api_key = guard.api_key.as_ref().is_some_and(|s| !s.is_empty()) + || gr_config + .providers + .get(&guard.provider) + .is_some_and(|p| !p.api_key.is_empty()); - if !has_api_base { - errors.push(format!( - "Guard '{}' has no api_base configured (neither on the guard nor on provider '{}').", - guard.name, guard.provider - )); - } - if !has_api_key { - errors.push(format!( - "Guard '{}' has no api_key configured (neither on the guard nor on provider '{}').", - guard.name, guard.provider - )); + if !has_api_base { + errors.push(format!( + "Guard '{}' has no api_base configured (neither on the guard nor on provider '{}').", + guard.name, guard.provider + )); + } + if !has_api_key { + errors.push(format!( + "Guard '{}' has no api_key configured (neither on the guard nor on provider '{}').", + guard.name, guard.provider + )); + } } // Check evaluator slug is recognized diff --git a/src/guardrails/runner.rs b/src/guardrails/runner.rs index 860f34e7..52268e17 100644 --- a/src/guardrails/runner.rs +++ b/src/guardrails/runner.rs @@ -167,15 +167,21 @@ pub async fn execute_guards( } Err(err) => { let is_required = guard.required; + let error_msg = err.to_string(); if is_required { blocked = true; if blocking_guard.is_none() { blocking_guard = Some(name.clone()); } + } else { + warnings.push(GuardWarning { + guard_name: name.clone(), + reason: format!("evaluator error: {error_msg}"), + }); } results.push(GuardResult::Error { name, - error: err.to_string(), + error: error_msg, required: is_required, }); } diff --git a/tests/guardrails/test_runner.rs b/tests/guardrails/test_runner.rs index 5940aa73..c6e91322 100644 --- a/tests/guardrails/test_runner.rs +++ b/tests/guardrails/test_runner.rs @@ -111,6 +111,10 @@ async fn test_guard_evaluator_unavailable_required_false() { .. } )); + // Non-required guard error should produce a warning header (fail-open) + assert_eq!(outcome.warnings.len(), 1); + assert_eq!(outcome.warnings[0].guard_name, "check"); + assert!(outcome.warnings[0].reason.contains("evaluator error")); } #[tokio::test] @@ -124,6 +128,8 @@ async fn test_guard_evaluator_unavailable_required_true() { ); let outcome = execute_guards(&[guard], "input", &mock_client, None).await; assert!(outcome.blocked); // Fail-closed + // Required guard error should NOT produce a warning (it blocks instead) + assert!(outcome.warnings.is_empty()); } #[tokio::test] From efaae4495b8e116179d1e5e148d13a8d98421a50 Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Thu, 19 Feb 2026 16:11:52 +0200 Subject: [PATCH 58/59] fix 4 --- src/guardrails/middleware.rs | 16 +++++++++++++--- src/guardrails/providers/traceloop.rs | 3 +-- .../guardrails/secrets_detector_pass.json | 4 ++-- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index 23f1f6a9..acf8a641 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -7,7 +7,7 @@ use axum::body::Body; use axum::extract::Request; use axum::response::{IntoResponse, Response}; use tower::{Layer, Service}; -use tracing::debug; +use tracing::{debug, warn}; use crate::models::chat::{ChatCompletion, ChatCompletionRequest}; use crate::models::completion::{CompletionRequest, CompletionResponse}; @@ -227,8 +227,18 @@ where let bytes = match axum::body::to_bytes(body, MAX_BODY_SIZE).await { Ok(b) => b, Err(_) => { - debug!("Guardrails middleware: failed to buffer request body, passing through"); - return Ok(axum::http::StatusCode::BAD_REQUEST.into_response()); + warn!("Guardrails middleware: request body too large or unreadable"); + let body = serde_json::json!({ + "error": { + "message": "Request body too large or unreadable", + "type": "invalid_request_error", + } + }); + return Ok(( + axum::http::StatusCode::BAD_REQUEST, + axum::Json(body), + ) + .into_response()); } }; diff --git a/src/guardrails/providers/traceloop.rs b/src/guardrails/providers/traceloop.rs index 01bd72d3..a2c2f6cb 100644 --- a/src/guardrails/providers/traceloop.rs +++ b/src/guardrails/providers/traceloop.rs @@ -9,7 +9,7 @@ const DEFAULT_TRACELOOP_API: &str = "https://api.traceloop.com"; const DEFAULT_TIMEOUT_SEC: u64 = 3; /// HTTP client for the Traceloop evaluator API service. -/// Calls `POST {api_base}/v2/guardrails/{evaluator_slug}`. +/// Calls `POST {api_base}/v2/guardrails/execute/{evaluator_slug}`. pub struct TraceloopClient { http_client: reqwest::Client, } @@ -74,7 +74,6 @@ impl GuardrailClient for TraceloopClient { .http_client .post(&url) .header("Authorization", format!("Bearer {api_key}")) - .header("Content-Type", "application/json") .json(&body) .send() .await?; diff --git a/tests/cassettes/guardrails/secrets_detector_pass.json b/tests/cassettes/guardrails/secrets_detector_pass.json index de879761..6f5615ee 100644 --- a/tests/cassettes/guardrails/secrets_detector_pass.json +++ b/tests/cassettes/guardrails/secrets_detector_pass.json @@ -1,10 +1,10 @@ { "evaluator_slug": "secrets-detector", - "input_text": "You are a complete idiot and everyone hates you. You should be ashamed.", + "input_text": "Here is a simple function that adds two numbers together.", "params": {}, "request_body": { "input": { - "text": "You are a complete idiot and everyone hates you. You should be ashamed." + "text": "Here is a simple function that adds two numbers together." } }, "response_status": 200, From 318f60d3e196955d6ffe415c07f206216ef0cfaf Mon Sep 17 00:00:00 2001 From: nina-kollman <59646487+nina-kollman@users.noreply.github.com> Date: Sun, 22 Feb 2026 09:24:31 +0200 Subject: [PATCH 59/59] fix build --- src/config/validation.rs | 2 -- src/guardrails/middleware.rs | 8 +++----- tests/guardrails/test_e2e.rs | 10 ++++------ 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/config/validation.rs b/src/config/validation.rs index bd8ab5ac..0782b4ff 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -251,8 +251,6 @@ mod tests { assert!(errors.iter().any(|e| { e.contains("references non-existent guardrail provider 'gr_p2_non_existent'") })); - assert!(errors.iter().any(|e| e.contains("no api_base configured"))); - assert!(errors.iter().any(|e| e.contains("no api_key configured"))); } #[test] diff --git a/src/guardrails/middleware.rs b/src/guardrails/middleware.rs index acf8a641..6067cfd9 100644 --- a/src/guardrails/middleware.rs +++ b/src/guardrails/middleware.rs @@ -234,11 +234,9 @@ where "type": "invalid_request_error", } }); - return Ok(( - axum::http::StatusCode::BAD_REQUEST, - axum::Json(body), - ) - .into_response()); + return Ok( + (axum::http::StatusCode::BAD_REQUEST, axum::Json(body)).into_response() + ); } }; diff --git a/tests/guardrails/test_e2e.rs b/tests/guardrails/test_e2e.rs index 3035b5e3..0d4c3aa3 100644 --- a/tests/guardrails/test_e2e.rs +++ b/tests/guardrails/test_e2e.rs @@ -295,12 +295,10 @@ guardrails: temp_file.write_all(config_yaml.as_bytes()).unwrap(); let temp_path = temp_file.path().to_str().unwrap().to_owned(); - let gr = - temp_env::with_var("E2E_TEST_API_KEY", Some("resolved-key-123"), || { - let config = - hub_lib::config::load_config(&temp_path).unwrap(); - config.guardrails.unwrap() - }); + let gr = temp_env::with_var("E2E_TEST_API_KEY", Some("resolved-key-123"), || { + let config = hub_lib::config::load_config(&temp_path).unwrap(); + config.guardrails.unwrap() + }); assert_eq!(gr.providers.len(), 1); assert_eq!(gr.providers["traceloop"].api_key, "resolved-key-123");