diff --git a/.github/workflows/cerebras-manual-test.yml b/.github/workflows/cerebras-manual-test.yml new file mode 100644 index 00000000..9f5c80d6 --- /dev/null +++ b/.github/workflows/cerebras-manual-test.yml @@ -0,0 +1,54 @@ +name: Cerebras Manual Test + +on: + workflow_dispatch: + inputs: + test_type: + description: 'Type of test to run' + required: true + default: 'compile' + type: choice + options: + - compile + - live + +permissions: + contents: read + +jobs: + test: + name: Cerebras Test + runs-on: ubuntu-latest + + env: + RUSTFLAGS: -Dwarnings + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + + - name: Compile Cerebras code + run: | + echo "Testing Cerebras adapter compilation..." + cargo build --verbose + cargo test --test tests_p_cerebras --no-run + + - name: Run live tests (if API key available) + if: ${{ github.event.inputs.test_type == 'live' && vars.CEREBRAS_API_KEY != '' }} + run: | + echo "Running live Cerebras tests..." + cargo test --test tests_p_cerebras -- --nocapture + env: + CEREBRAS_API_KEY: ${{ vars.CEREBRAS_API_KEY }} + + - name: Skip live tests (no API key) + if: ${{ github.event.inputs.test_type == 'live' && vars.CEREBRAS_API_KEY == '' }} + run: | + echo "CEREBRAS_API_KEY not configured - skipping live tests" + echo "To enable live tests, add CEREBRAS_API_KEY as a repository variable" \ No newline at end of file diff --git a/.github/workflows/cerebras-tests.yml b/.github/workflows/cerebras-tests.yml new file mode 100644 index 00000000..2468421d --- /dev/null +++ b/.github/workflows/cerebras-tests.yml @@ -0,0 +1,79 @@ +name: Cerebras Provider Tests + +# Tests Cerebras adapter integration with live API calls +on: + push: + branches: [ main, master, develop ] + paths: + - 'src/adapter/adapters/cerebras/**' + - 'tests/tests_p_cerebras.rs' + - 'examples/c11-cerebras.rs' + pull_request: + branches: [ "**" ] + paths: + - 'src/adapter/adapters/cerebras/**' + - 'tests/tests_p_cerebras.rs' + - 'examples/c11-cerebras.rs' + workflow_dispatch: + +permissions: + contents: read + +jobs: + cerebras-tests: + name: Cerebras Integration Tests + runs-on: ubuntu-latest + + if: vars.CEREBRAS_API_KEY != '' + + env: + CEREBRAS_API_KEY: ${{ vars.CEREBRAS_API_KEY }} + RUSTFLAGS: -Dwarnings + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + + - name: Format check + run: cargo fmt --all -- --check + + - name: Clippy + run: cargo clippy --all-targets -- -D warnings + + - name: Build + run: cargo build --verbose + + - name: Run Cerebras tests + run: | + echo "Running Cerebras provider tests..." + cargo test --test tests_p_cerebras -- --nocapture + env: + CEREBRAS_API_KEY: ${{ vars.CEREBRAS_API_KEY }} + + - name: Run Cerebras example + run: | + echo "Running Cerebras example..." 
+ cargo run --example c11-cerebras + env: + CEREBRAS_API_KEY: ${{ vars.CEREBRAS_API_KEY }} + + cerebras-tests-skipped: + name: Cerebras Tests Skipped + runs-on: ubuntu-latest + + if: vars.CEREBRAS_API_KEY == '' + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Skip tests + run: | + echo "CEREBRAS_API_KEY not configured - skipping live tests" + echo "To enable Cerebras tests, add CEREBRAS_API_KEY as a repository variable" \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..128e5d7b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,46 @@ +name: CI + +on: + push: + branches: [ main, master, develop ] + pull_request: + branches: [ "**" ] + +permissions: + contents: read + +jobs: + build: + name: Lint, Build, Test + runs-on: ubuntu-latest + + env: + RUSTFLAGS: -Dwarnings + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + + - name: Format check + run: cargo fmt --all -- --check + + - name: Clippy + run: cargo clippy --all-targets -- -D warnings + + - name: Build + run: cargo build --verbose + + - name: Tests (compile only) + run: cargo test --no-run + + - name: Provider tests skipped + run: | + echo "Live provider tests require API keys and are not run in CI." + echo "Cerebras adapter compilation test:" + cargo test --test tests_p_cerebras --no-run diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..222c922f --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,42 @@ +name: Release + +on: + push: + tags: + - 'v*.*.*' + +permissions: + contents: write + +jobs: + build-and-release: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + + - name: Build + run: cargo build --release --verbose + + - name: Package crate + run: cargo package --allow-dirty + + - name: Upload crate tarball + uses: actions/upload-artifact@v4 + with: + name: crate-tarball + path: target/package/*.crate + + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 + with: + files: | + target/package/*.crate + generate_release_notes: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..dac7f735 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: local + hooks: + - id: rust-fmt + name: rustfmt + entry: cargo fmt --all -- --check + language: system + pass_filenames: false + - id: rust-clippy + name: clippy + entry: cargo clippy --all-targets -- -D warnings + language: system + pass_filenames: false + - id: rust-build + name: cargo build + entry: cargo build --verbose + language: system + pass_filenames: false + - id: rust-test-compile + name: cargo test (no run) + entry: cargo test --no-run + language: system + pass_filenames: false diff --git a/Cargo.toml b/Cargo.toml index 8ab02a8c..2fd9fe4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,3 +40,8 @@ serial_test = "3.2.0" base64 = "0.22.0" # Check for the latest version bitflags = "2.8.0" gcp_auth = "0.12.3" +# Mock server dependencies +wiremock = "0.6.5" +uuid = { version = "1.11.0", features = ["v4", "serde"] } +# Test utilities +scopeguard = "1.2.0" diff --git a/README.md b/README.md index 2581d6a9..e0889ce9 100644 --- a/README.md +++ b/README.md 
@@ -1,6 +1,6 @@ # genai - Multi-AI Providers Library for Rust -Currently natively supports: **OpenAI**, **Anthropic**, **Gemini**, **XAI/Grok**, **Ollama**, **Groq**, **DeepSeek** (deepseek.com & Groq), **Cohere** (more to come) +Currently natively supports: **OpenAI**, **Anthropic**, **Gemini**, **XAI/Grok**, **Ollama**, **Groq**, **DeepSeek** (deepseek.com & Groq), **Cohere**, **Cerebras** (more to come) Also allows a custom URL with `ServiceTargetResolver` (see [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs)) @@ -65,7 +65,7 @@ See: ## Key Features -- Native Multi-AI Provider/Model: OpenAI, Anthropic, Gemini, Ollama, Groq, xAI, DeepSeek (Direct chat and stream) (see [examples/c00-readme.rs](examples/c00-readme.rs)) +- Native Multi-AI Provider/Model: OpenAI, Anthropic, Gemini, Ollama, Groq, xAI, DeepSeek, Cerebras (Direct chat and stream) (see [examples/c00-readme.rs](examples/c00-readme.rs)) - DeepSeekR1 support, with `reasoning_content` (and stream support), plus DeepSeek Groq and Ollama support (and `reasoning_content` normalization) - Image Analysis (for OpenAI, Gemini flash-2, Anthropic) (see [examples/c07-image.rs](examples/c07-image.rs)) - Custom Auth/API Key (see [examples/c02-auth.rs](examples/c02-auth.rs)) @@ -170,6 +170,7 @@ async fn main() -> Result<(), Box> { - [examples/c05-model-names.rs](examples/c05-model-names.rs) - Shows how to get model names per AdapterKind. - [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs) - For custom auth, endpoint, and model. - [examples/c07-image.rs](examples/c07-image.rs) - Image analysis support +- [examples/c11-cerebras.rs](examples/c11-cerebras.rs) - Cerebras chat + streaming (set `CEREBRAS_API_KEY`)
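A minimal sketch of the `cerebras::` call the README entries above advertise, condensed from `examples/c11-cerebras.rs` added later in this diff (requires `CEREBRAS_API_KEY`; streaming and most error handling elided):

```rust
use genai::Client;
use genai::chat::{ChatMessage, ChatRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The adapter's default auth reads CEREBRAS_API_KEY from the environment.
    let client = Client::default();
    let req = ChatRequest::new(vec![ChatMessage::user("Why do stars twinkle?")]);
    // The `cerebras::` prefix routes the request to the new Cerebras adapter.
    let res = client.exec_chat("cerebras::llama-3.1-8b", req, None).await?;
    println!("{}", res.first_text().unwrap_or("NO ANSWER"));
    Ok(())
}
```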
diff --git a/doc/test-specification.md b/doc/test-specification.md new file mode 100644 index 00000000..1c538448 --- /dev/null +++ b/doc/test-specification.md @@ -0,0 +1,449 @@ +# Test Specification for rust-genai Library + +## Overview + +This document outlines the comprehensive testing strategy for the rust-genai library, focusing on Anthropic and OpenRouter API compatibility. The testing approach includes both live API tests and mock server tests to ensure reliability and offline development capabilities. + +## Testing Architecture + +### 1. Live API Tests +- **Purpose**: Validate real-world API compatibility +- **Execution**: Run against actual provider APIs +- **Requirements**: Valid API keys and network access +- **Frequency**: Nightly builds and before releases + +### 2. Mock Server Tests +- **Purpose**: Enable offline testing and CI/CD reliability +- **Execution**: Run against local mock servers +- **Requirements**: No external dependencies +- **Frequency**: Every commit and PR + +## Test Categories + +### A. Core Chat Functionality + +#### A1. Simple Chat Completion +**Input**: Basic user message +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 100 +} +``` + +**Expected Output**: +```json +{ + "content": [{"type": "text", "text": "Hello! I'm doing well, thank you for asking."}], + "usage": {"prompt_tokens": 12, "completion_tokens": 15, "total_tokens": 27} +} +``` + +**Actions**: +- Verify response contains text content +- Validate token usage counts +- Ensure response time < 30 seconds +- Check content is non-empty + +#### A2. System Message Handling +**Input**: System message + user message +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Be concise."}, + {"role": "user", "content": "Explain quantum computing"} + ], + "max_tokens": 150 +} +``` + +**Expected Output**: Concise explanation of quantum computing + +**Actions**: +- Verify system message influences response style +- Check response is concise (< 100 words) +- Validate content accuracy + +#### A3. Multi-turn Conversation +**Input**: Conversation history +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + {"role": "user", "content": "What is 4+4?"} + ], + "max_tokens": 50 +} +``` + +**Expected Output**: "4+4 equals 8." + +**Actions**: +- Verify context preservation +- Check mathematical accuracy +- Validate conversation flow + +### B. Advanced Features + +#### B1. Streaming Responses +**Input**: Streaming request +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Count to 10"}], + "stream": true, + "max_tokens": 100 +} +``` + +**Expected Output**: Server-sent events with incremental content + +**Actions**: +- Verify streaming format compliance +- Check content chunk integrity +- Validate final assembled content +- Measure streaming latency +
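A minimal sketch of the consuming side of B1, using the `exec_chat_stream` and `print_chat_stream` calls that appear in this diff's example and live tests (model name as in the spec above; error handling elided):

```rust
use genai::Client;
use genai::chat::printer::print_chat_stream;
use genai::chat::{ChatMessage, ChatRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::default();
    let req = ChatRequest::new(vec![ChatMessage::user("Count to 10")]);
    // Kick off the SSE stream; the printer prints chunks as they arrive
    // and returns once the stream signals its end.
    let stream = client.exec_chat_stream("claude-3-5-haiku-latest", req, None).await?;
    print_chat_stream(stream, None).await?;
    Ok(())
}
```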
+#### B2. Tool/Function Calling +**Input**: Tool definition + user query +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "What's the weather in Paris?"}], + "tools": [{ + "name": "get_weather", + "description": "Get weather information", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} + }, + "required": ["location"] + } + }], + "max_tokens": 100 +} +``` + +**Expected Output**: Tool call request +```json +{ + "content": [{ + "type": "tool_use", + "id": "toolu_01...", + "name": "get_weather", + "input": {"location": "Paris", "unit": "celsius"} + }] +} +``` + +**Actions**: +- Verify tool call structure +- Validate parameter extraction +- Check tool response handling + +#### B3. JSON Mode +**Input**: JSON mode request +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "List 3 colors in JSON format"}], + "response_format": {"type": "json_object"}, + "max_tokens": 100 +} +``` + +**Expected Output**: Valid JSON array of colors + +**Actions**: +- Verify JSON validity +- Check content structure +- Validate schema compliance + +### C. Error Handling + +#### C1. Authentication Errors +**Input**: Invalid API key +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Hello"}], + "headers": {"Authorization": "Bearer invalid-key"} +} +``` + +**Expected Output**: 401 Unauthorized + +**Actions**: +- Verify error code 401 +- Check error message clarity +- Validate error handling in client + +#### C2. Rate Limiting +**Input**: Rapid successive requests +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Hello"}] +} +``` + +**Expected Output**: 429 Too Many Requests + +**Actions**: +- Verify rate limit detection +- Check retry-after header +- Validate backoff mechanism + +#### C3. Invalid Requests +**Input**: Malformed request +```json +{ + "model": "invalid-model", + "messages": [{"role": "invalid", "content": 123}] +} +``` + +**Expected Output**: 400 Bad Request + +**Actions**: +- Verify error validation +- Check error message helpfulness +- Validate input sanitization + +### D. Performance Testing + +#### D1. Response Time +**Input**: Standard request +**Actions**: +- Measure response time +- Verify < 30 seconds for simple queries +- Track percentiles (p50, p95, p99) + +#### D2. Throughput +**Input**: Concurrent requests +**Actions**: +- Send 10 concurrent requests +- Measure total completion time +- Verify no request failures + +#### D3. Token Efficiency +**Input**: Various prompt sizes +**Actions**: +- Test with 1K, 10K, 100K token prompts +- Measure processing time per token +- Verify linear scaling + +### E. Media Handling + +#### E1. Image Input +**Input**: Image + text +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image", "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "base64-encoded-image" + }} + ] + }], + "max_tokens": 100 +} +``` + +**Expected Output**: Image description + +**Actions**: +- Verify image processing +- Check content accuracy +- Validate media type handling + +#### E2.
Document Input +**Input**: PDF document + text +```json +{ + "model": "claude-3-5-haiku-latest", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "Summarize this document"}, + {"type": "document", "source": { + "type": "base64", + "media_type": "application/pdf", + "data": "base64-encoded-pdf" + }} + ] + }], + "max_tokens": 200 +} +``` + +**Expected Output**: Document summary + +**Actions**: +- Verify PDF processing +- Check content extraction +- Validate summary accuracy + +## Mock Server Specifications + +### Anthropic Mock Server + +#### Endpoints: +- `POST /v1/messages` - Chat completions +- `POST /v1/messages/beta/stream` - Streaming chat + +#### Response Templates: +```json +// Success response +{ + "id": "msg_01...", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Mock response"}], + "model": "claude-3-5-haiku-latest", + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 10, + "output_tokens": 5 + } +} + +// Error response +{ + "type": "error", + "error": { + "type": "authentication_error", + "message": "Invalid API key" + } +} +``` + +### OpenRouter Mock Server + +#### Endpoints: +- `POST /api/v1/chat/completions` - Chat completions +- `POST /api/v1/chat/completions/stream` - Streaming chat + +#### Response Templates: +```json +// Success response +{ + "id": "chatcmpl-...", + "object": "chat.completion", + "created": 1234567890, + "model": "anthropic/claude-3.5-sonnet", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Mock response" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } +} +``` + +## Test Configuration + +### Environment Variables +```bash +# Live API Tests +ANTHROPIC_API_KEY=your_key_here +OPENROUTER_API_KEY=your_key_here + +# Test Configuration +GENAI_TEST_MODE=live|mock +GENAI_TEST_TIMEOUT=30 +GENAI_TEST_CONCURRENT=10 +``` + +### Test Categories +```bash +# Run all tests +cargo test + +# Run only live tests +cargo test --features live-tests + +# Run only mock tests +cargo test --features mock-tests + +# Run performance tests +cargo test --features perf-tests + +# Run error scenario tests +cargo test --features error-tests +``` + +## Implementation Plan + +### Phase 1: Mock Server Infrastructure +1. Create mock server framework +2. Implement Anthropic mock endpoints +3. Implement OpenRouter mock endpoints +4. Add response template system + +### Phase 2: Enhanced Test Suite +1. Implement error scenario tests +2. Add performance benchmarks +3. Create contract validation tests +4. Enhance streaming tests + +### Phase 3: Integration & CI +1. Configure CI/CD pipelines +2. Add test reporting +3. Implement test data management +4. 
Add test documentation + +## Success Criteria + +### Functional Requirements +- [ ] All existing tests pass with mock servers +- [ ] New error scenarios are covered +- [ ] Performance benchmarks are established +- [ ] Streaming is thoroughly tested + +### Non-Functional Requirements +- [ ] Tests run in < 5 minutes +- [ ] Mock servers start in < 2 seconds +- [ ] 95% test coverage maintained +- [ ] No external dependencies for CI + +### Quality Requirements +- [ ] Clear error messages for failures +- [ ] Comprehensive test documentation +- [ ] Reproducible test results +- [ ] Proper test isolation + +## Maintenance + +### Regular Updates +- Update mock responses when APIs change +- Review test coverage monthly +- Update performance benchmarks quarterly +- Refresh test data as needed + +### Monitoring +- Track test execution times +- Monitor flaky tests +- Alert on test failures +- Generate test reports + +This specification provides a comprehensive foundation for improving the rust-genai library's test suite, ensuring reliability, performance, and compatibility with Anthropic and OpenRouter APIs. \ No newline at end of file diff --git a/examples/c11-cerebras.rs b/examples/c11-cerebras.rs new file mode 100644 index 00000000..b6e7a7a1 --- /dev/null +++ b/examples/c11-cerebras.rs @@ -0,0 +1,42 @@ +//! Cerebras basic chat and streaming example + +use genai::Client; +use genai::chat::printer::{PrintChatStreamOptions, print_chat_stream}; +use genai::chat::{ChatMessage, ChatRequest}; + +const MODEL_CEREBRAS: &str = "cerebras::llama-3.1-8b"; +const CEREBRAS_ENV: &str = "CEREBRAS_API_KEY"; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + if std::env::var(CEREBRAS_ENV).is_err() { + println!( + "Skipping: set {} to run this example (e.g., export {}=...)", + CEREBRAS_ENV, CEREBRAS_ENV + ); + return Ok(()); + } + + let question = "Why do stars twinkle?"; + + let chat_req = ChatRequest::new(vec![ + ChatMessage::system("Answer briefly in one sentence."), + ChatMessage::user(question), + ]); + + let client = Client::default(); + + println!("\n--- MODEL: {}", MODEL_CEREBRAS); + println!("\n--- Question:\n{}", question); + + println!("\n--- Answer:"); + let chat_res = client.exec_chat(MODEL_CEREBRAS, chat_req.clone(), None).await?; + println!("{}", chat_res.first_text().unwrap_or("NO ANSWER")); + + println!("\n--- Answer (streaming):"); + let stream = client.exec_chat_stream(MODEL_CEREBRAS, chat_req, None).await?; + let print_options = PrintChatStreamOptions::from_print_events(false); + print_chat_stream(stream, Some(&print_options)).await?; + + Ok(()) +}
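The `cerebras::` prefix used by the example above is wired up in the `adapter_kind.rs` hunks that follow; a hedged sketch of the resulting resolution (assuming the lowercase parser these hunks extend is exposed as `AdapterKind::from_lower_str`, as in the upstream crate):

```rust
use genai::adapter::AdapterKind;

fn main() {
    // The namespace before `::` selects the adapter; the remainder is the provider model id.
    let (ns, model) = "cerebras::llama-3.1-8b".split_once("::").unwrap();
    assert!(matches!(AdapterKind::from_lower_str(ns), Some(AdapterKind::Cerebras)));
    assert_eq!(model, "llama-3.1-8b");
}
```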
diff --git a/src/adapter/adapter_kind.rs b/src/adapter/adapter_kind.rs index 9fa0f6ad..46821207 100644 --- a/src/adapter/adapter_kind.rs +++ b/src/adapter/adapter_kind.rs @@ -1,5 +1,6 @@ use crate::adapter::adapters::together::TogetherAdapter; use crate::adapter::anthropic::AnthropicAdapter; +use crate::adapter::cerebras::CerebrasAdapter; use crate::adapter::cohere::CohereAdapter; use crate::adapter::deepseek::{self, DeepSeekAdapter}; use crate::adapter::fireworks::FireworksAdapter; @@ -7,6 +8,7 @@ use crate::adapter::gemini::GeminiAdapter; use crate::adapter::groq::{self, GroqAdapter}; use crate::adapter::nebius::NebiusAdapter; use crate::adapter::openai::OpenAIAdapter; +use crate::adapter::openrouter::OpenRouterAdapter; use crate::adapter::xai::XaiAdapter; use crate::adapter::zhipu::ZhipuAdapter; use crate::{ModelName, Result}; @@ -25,6 +27,8 @@ pub enum AdapterKind { OpenAIResp, /// Gemini adapter supports gemini native protocol, e.g., support for thinking budget. Gemini, + /// For OpenRouter (OpenAI-compatible protocol) + OpenRouter, /// Anthropic native protocol as well Anthropic, /// For fireworks.ai, mostly OpenAI. @@ -45,6 +49,8 @@ pub enum AdapterKind { Cohere, /// OpenAI shared behavior + some custom. (currently, localhost only, can be customized with ServiceTargetResolver). Ollama, + /// Cerebras (OpenAI-compatible protocol) + Cerebras, } /// Serialization/Parse implementations @@ -56,6 +62,7 @@ impl AdapterKind { AdapterKind::OpenAIResp => "OpenAIResp", AdapterKind::Gemini => "Gemini", AdapterKind::Anthropic => "Anthropic", + AdapterKind::OpenRouter => "OpenRouter", AdapterKind::Fireworks => "Fireworks", AdapterKind::Together => "Together", AdapterKind::Groq => "Groq", @@ -65,6 +72,7 @@ AdapterKind::Zhipu => "Zhipu", AdapterKind::Cohere => "Cohere", AdapterKind::Ollama => "Ollama", + AdapterKind::Cerebras => "Cerebras", } } @@ -75,6 +83,7 @@ AdapterKind::OpenAIResp => "openai_resp", AdapterKind::Gemini => "gemini", AdapterKind::Anthropic => "anthropic", + AdapterKind::OpenRouter => "openrouter", AdapterKind::Fireworks => "fireworks", AdapterKind::Together => "together", AdapterKind::Groq => "groq", @@ -84,6 +93,7 @@ AdapterKind::Zhipu => "zhipu", AdapterKind::Cohere => "cohere", AdapterKind::Ollama => "ollama", + AdapterKind::Cerebras => "cerebras", } } @@ -93,6 +103,7 @@ "openai_resp" => Some(AdapterKind::OpenAIResp), "gemini" => Some(AdapterKind::Gemini), "anthropic" => Some(AdapterKind::Anthropic), + "openrouter" => Some(AdapterKind::OpenRouter), "fireworks" => Some(AdapterKind::Fireworks), "together" => Some(AdapterKind::Together), "groq" => Some(AdapterKind::Groq), @@ -102,6 +113,7 @@ "zhipu" => Some(AdapterKind::Zhipu), "cohere" => Some(AdapterKind::Cohere), "ollama" => Some(AdapterKind::Ollama), + "cerebras" => Some(AdapterKind::Cerebras), _ => None, } } @@ -116,6 +128,7 @@ AdapterKind::OpenAIResp => Some(OpenAIAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Gemini => Some(GeminiAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Anthropic => Some(AnthropicAdapter::API_KEY_DEFAULT_ENV_NAME), + AdapterKind::OpenRouter => Some(OpenRouterAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Fireworks => Some(FireworksAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Together => Some(TogetherAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Groq => Some(GroqAdapter::API_KEY_DEFAULT_ENV_NAME), @@ -125,6 +138,7 @@ AdapterKind::Zhipu => Some(ZhipuAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Cohere => Some(CohereAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Ollama => None, + AdapterKind::Cerebras => Some(CerebrasAdapter::API_KEY_DEFAULT_ENV_NAME), } } } @@ -149,6 +163,7 @@ /// Some other adapters have to have the model name namespaced to be used, /// - e.g., for together.ai `together::meta-llama/Llama-3-8b-chat-hf` /// - e.g., for nebius with `nebius::Qwen/Qwen3-235B-A22B` + /// - e.g., for cerebras with `cerebras::llama-3.1-8b` /// /// And all adapters can be force-namespaced as well.
/// @@ -164,6 +179,16 @@ } } + // -- Special handling for OpenRouter models (they start with provider names) + if model.contains('/') + && (model.starts_with("openai/") + || model.starts_with("anthropic/") + || model.starts_with("meta-llama/") + || model.starts_with("google/")) + { + return Ok(Self::OpenRouter); + } + // -- Resolve from modelname if model.starts_with("o3") || model.starts_with("o4") diff --git a/src/adapter/adapters/cerebras/adapter_impl.rs b/src/adapter/adapters/cerebras/adapter_impl.rs new file mode 100644 index 00000000..d155d329 --- /dev/null +++ b/src/adapter/adapters/cerebras/adapter_impl.rs @@ -0,0 +1,99 @@ +use crate::ModelIden; +use crate::adapter::openai::OpenAIAdapter; +use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; +use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::resolver::{AuthData, Endpoint}; +use crate::webc::WebResponse; +use crate::{Result, ServiceTarget}; +use reqwest::RequestBuilder; +use reqwest_eventsource::EventSource; + +pub struct CerebrasAdapter; + +// A non-exhaustive set of commonly available Cerebras models +pub(in crate::adapter) const MODELS: &[&str] = &[ + "llama-3.3-70b", + "llama-3.1-70b", + "llama-3.1-8b", + "llama-3.2-11b-vision", + "llama-3.2-90b-vision", + "llama-guard-3-8b", +]; + +impl CerebrasAdapter { + pub const API_KEY_DEFAULT_ENV_NAME: &str = "CEREBRAS_API_KEY"; +} + +// The Cerebras API is compatible with OpenAI Chat Completions. +impl Adapter for CerebrasAdapter { + fn default_endpoint() -> Endpoint { + const BASE_URL: &str = "https://api.cerebras.ai/v1/"; + Endpoint::from_static(BASE_URL) + } + + fn default_auth() -> AuthData { + AuthData::from_env(Self::API_KEY_DEFAULT_ENV_NAME) + } + + async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> { + Ok(MODELS.iter().map(|s| s.to_string()).collect()) + } + + fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result<String> { + OpenAIAdapter::util_get_service_url(model, service_type, endpoint) + } + + fn to_web_request_data( + target: ServiceTarget, + service_type: ServiceType, + chat_req: ChatRequest, + chat_options: ChatOptionsSet<'_, '_>, + ) -> Result<WebRequestData> { + OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options, None) + } + + fn to_chat_response( + model_iden: ModelIden, + web_response: WebResponse, + options_set: ChatOptionsSet<'_, '_>, + ) -> Result<ChatResponse> { + OpenAIAdapter::to_chat_response(model_iden, web_response, options_set) + } + + fn to_chat_stream( + model_iden: ModelIden, + reqwest_builder: RequestBuilder, + options_set: ChatOptionsSet<'_, '_>, + ) -> Result<ChatStreamResponse> { + let event_source = EventSource::new(reqwest_builder)?; + let cerebras_stream = super::streamer::CerebrasStreamer::new(event_source, model_iden.clone(), options_set); + let chat_stream = crate::chat::ChatStream::from_inter_stream(cerebras_stream); + + Ok(ChatStreamResponse { + model_iden, + stream: chat_stream, + }) + } + + fn to_embed_request_data( + _service_target: crate::ServiceTarget, + _embed_req: crate::embed::EmbedRequest, + _options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result<WebRequestData> { + Err(crate::Error::AdapterNotSupported { + adapter_kind: crate::adapter::AdapterKind::Cerebras, + feature: "embeddings".to_string(), + }) + } + + fn to_embed_response( + _model_iden: crate::ModelIden, + _web_response: crate::webc::WebResponse, + _options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result<crate::embed::EmbedResponse> { + Err(crate::Error::AdapterNotSupported { + adapter_kind: crate::adapter::AdapterKind::Cerebras, + feature: "embeddings".to_string(), + }) + } +}
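Since `all_model_names` above just maps the static MODELS slice, listing Cerebras models never hits the network; a hedged usage sketch (assuming the upstream crate's public `Client::all_model_names` passthrough, as exercised by examples/c05-model-names.rs):

```rust
use genai::Client;
use genai::adapter::AdapterKind;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::default();
    // For Cerebras this returns the static list defined in adapter_impl.rs above.
    let names = client.all_model_names(AdapterKind::Cerebras).await?;
    assert!(names.iter().any(|n| n == "llama-3.1-8b"));
    println!("{names:?}");
    Ok(())
}
```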
diff --git a/src/adapter/adapters/cerebras/mod.rs b/src/adapter/adapters/cerebras/mod.rs new file mode 100644 index 00000000..561b3fc3 --- /dev/null +++ b/src/adapter/adapters/cerebras/mod.rs @@ -0,0 +1,12 @@ +//! API Documentation: https://inference-docs.cerebras.ai/ +//! Model Names: https://inference.cerebras.ai/models +//! Pricing: https://inference.cerebras.ai/pricing + +// region: --- Modules + +mod adapter_impl; +mod streamer; + +pub use adapter_impl::*; + +// endregion: --- Modules diff --git a/src/adapter/adapters/cerebras/streamer.rs b/src/adapter/adapters/cerebras/streamer.rs new file mode 100644 index 00000000..9bf3c355 --- /dev/null +++ b/src/adapter/adapters/cerebras/streamer.rs @@ -0,0 +1,170 @@ +use crate::adapter::adapters::support::{StreamerCapturedData, StreamerOptions}; +use crate::adapter::inter_stream::{InterStreamEnd, InterStreamEvent}; +use crate::chat::ChatOptionsSet; +use crate::{Error, ModelIden, Result}; +use reqwest_eventsource::{Event, EventSource}; +use serde_json::Value; +use std::pin::Pin; +use std::task::{Context, Poll}; +use value_ext::JsonValueExt; + +pub struct CerebrasStreamer { + inner: EventSource, + options: StreamerOptions, + + // -- Set by the poll_next + /// Flag to prevent polling the EventSource after the stream has ended + done: bool, + captured_data: StreamerCapturedData, +} + +impl CerebrasStreamer { + pub fn new(inner: EventSource, model_iden: ModelIden, options_set: ChatOptionsSet<'_, '_>) -> Self { + Self { + inner, + done: false, + options: StreamerOptions::new(model_iden, options_set), + captured_data: Default::default(), + } + } +} + +impl futures::Stream for CerebrasStreamer { + type Item = Result<InterStreamEvent>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + if self.done { + // The last poll was definitely the end, so end the stream. + // This will prevent triggering a stream ended error + return Poll::Ready(None); + } + + while let Poll::Ready(event) = Pin::new(&mut self.inner).poll_next(cx) { + match event { + Some(Ok(Event::Open)) => return Poll::Ready(Some(Ok(InterStreamEvent::Start))), + Some(Ok(Event::Message(message))) => { + // -- End Message + // Cerebras may not send [DONE] like OpenAI, so we need to handle stream ending differently + if message.data == "[DONE]" { + self.done = true; + + // -- Build the usage and captured_content + let captured_usage = if self.options.capture_usage { + self.captured_data.usage.take() + } else { + None + }; + + let inter_stream_end = InterStreamEnd { + captured_usage, + captured_text_content: self.captured_data.content.take(), + captured_reasoning_content: self.captured_data.reasoning_content.take(), + captured_tool_calls: self.captured_data.tool_calls.take(), + }; + + return Poll::Ready(Some(Ok(InterStreamEvent::End(inter_stream_end)))); + } + + // -- Other Content Messages + // Parse to get the choice + let mut message_data: Value = + serde_json::from_str(&message.data).map_err(|serde_error| Error::StreamParse { + model_iden: self.options.model_iden.clone(), + serde_error, + })?; + + let first_choice: Option<Value> = message_data.x_take("/choices/0").ok(); + + // If we have a first choice, then it's a normal message + if let Some(mut first_choice) = first_choice { + // -- Finish Reason + // If finish_reason exists, it's the end of this choice.
+ // Since we support only a single choice, we can proceed, + // as there might be other messages, and the last one contains data: `[DONE]` + // NOTE: Cerebras may have different finish_reason behavior + if let Ok(_finish_reason) = first_choice.x_take::<Option<String>>("finish_reason") { + // For Cerebras, we capture usage when we see finish_reason + if self.options.capture_usage + && let Ok(usage) = message_data.x_take("usage") + && let Ok(usage) = serde_json::from_value(usage) + { + self.captured_data.usage = Some(usage); + } + } + + // -- Content + if let Ok(Some(content)) = first_choice.x_take::<Option<String>>("/delta/content") { + // Add to the captured_content if chat options allow it + if self.options.capture_content { + match self.captured_data.content { + Some(ref mut c) => c.push_str(&content), + None => self.captured_data.content = Some(content.clone()), + } + } + + // Return the Event + return Poll::Ready(Some(Ok(InterStreamEvent::Chunk(content)))); + } + // If we do not have content, then log a trace message + // TODO: use tracing debug + tracing::warn!("EMPTY CHOICE CONTENT"); + } + // -- Usage message + else { + // For Cerebras, capture usage when choices are empty or null + if self.captured_data.usage.is_none() // this might be redundant + && self.options.capture_usage + && let Ok(usage) = message_data.x_take("usage") + && let Ok(usage) = serde_json::from_value(usage) + { + self.captured_data.usage = Some(usage); + } + } + } + Some(Err(err)) => { + // Cerebras sometimes ends the stream with a StreamEnded error instead of clean None + // We'll treat this as a normal stream end + tracing::debug!("Cerebras stream ended with error (this is expected): {}", err); + self.done = true; + + // -- Build the usage and captured_content + let captured_usage = if self.options.capture_usage { + self.captured_data.usage.take() + } else { + None + }; + + let inter_stream_end = InterStreamEnd { + captured_usage, + captured_text_content: self.captured_data.content.take(), + captured_reasoning_content: self.captured_data.reasoning_content.take(), + captured_tool_calls: self.captured_data.tool_calls.take(), + }; + + return Poll::Ready(Some(Ok(InterStreamEvent::End(inter_stream_end)))); + } + None => { + // Cerebras stream ends without [DONE], so we need to create the StreamEnd event here + self.done = true; + + // -- Build the usage and captured_content + let captured_usage = if self.options.capture_usage { + self.captured_data.usage.take() + } else { + None + }; + + let inter_stream_end = InterStreamEnd { + captured_usage, + captured_text_content: self.captured_data.content.take(), + captured_reasoning_content: self.captured_data.reasoning_content.take(), + captured_tool_calls: self.captured_data.tool_calls.take(), + }; + + return Poll::Ready(Some(Ok(InterStreamEvent::End(inter_stream_end)))); + } + } + } + Poll::Pending + } +} diff --git a/src/adapter/adapters/mod.rs b/src/adapter/adapters/mod.rs index a31217cd..6c315e49 100644 --- a/src/adapter/adapters/mod.rs +++ b/src/adapter/adapters/mod.rs @@ -1,6 +1,7 @@ mod support; pub(super) mod anthropic; +pub(super) mod cerebras; pub(super) mod cohere; pub(super) mod deepseek; pub(super) mod fireworks; @@ -10,6 +11,7 @@ pub(super) mod nebius; pub(super) mod ollama; pub(super) mod openai; pub(super) mod openai_resp; +pub(super) mod openrouter; pub(super) mod together; pub(super) mod xai; pub(super) mod zhipu;
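For reference while reading `CerebrasStreamer::poll_next` above, this is the OpenAI-compatible chunk shape its `"/choices/0"` and `"/delta/content"` x_take paths assume (payload invented for illustration):

```rust
fn main() {
    // One SSE `data:` payload as the streamer sees it after JSON parsing.
    let chunk: serde_json::Value = serde_json::from_str(
        r#"{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#,
    )
    .unwrap();
    // The streamer pulls the first choice, then the delta content out of it.
    assert_eq!(chunk["choices"][0]["delta"]["content"], "Hello");
}
```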
diff --git a/src/adapter/adapters/openrouter/mod.rs b/src/adapter/adapters/openrouter/mod.rs new file mode 100644 index 00000000..3c29453e --- /dev/null +++ b/src/adapter/adapters/openrouter/mod.rs @@ -0,0 +1,95 @@ +use crate::ServiceTarget; +use crate::adapter::adapters::openai::OpenAIAdapter; +use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; +use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::embed::{EmbedOptionsSet, EmbedResponse}; +use crate::resolver::{AuthData, Endpoint}; +use crate::webc::WebResponse; +use crate::{Headers, ModelIden, Result}; +use reqwest::RequestBuilder; + +pub struct OpenRouterAdapter; + +impl OpenRouterAdapter { + pub const API_KEY_DEFAULT_ENV_NAME: &'static str = "OPENROUTER_API_KEY"; + + /// Add OpenRouter-specific headers to the request + fn add_openrouter_headers(headers: Headers) -> Headers { + let openrouter_headers = Headers::from([ + ("HTTP-Referer".to_string(), "https://github.com/sst/genai".to_string()), + ("X-Title".to_string(), "genai-rust".to_string()), + ]); + openrouter_headers.applied_to(headers) + } +} + +impl Adapter for OpenRouterAdapter { + fn default_auth() -> AuthData { + AuthData::from_env(Self::API_KEY_DEFAULT_ENV_NAME) + } + + fn default_endpoint() -> Endpoint { + const BASE_URL: &str = "https://openrouter.ai/api/v1/"; + Endpoint::from_static(BASE_URL) + } + + async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> { + // For now, return empty - OpenRouter has many models and they should be specified directly + Ok(vec![]) + } + + fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result<String> { + OpenAIAdapter::get_service_url(model, service_type, endpoint) + } + + fn to_web_request_data( + target: ServiceTarget, + service_type: ServiceType, + chat_req: ChatRequest, + chat_options: ChatOptionsSet<'_, '_>, + ) -> Result<WebRequestData> { + let mut web_request_data = OpenAIAdapter::to_web_request_data(target, service_type, chat_req, chat_options)?; + + // Add OpenRouter-specific headers + web_request_data.headers = Self::add_openrouter_headers(web_request_data.headers); + + Ok(web_request_data) + } + + fn to_chat_response( + model_iden: ModelIden, + web_response: WebResponse, + options_set: ChatOptionsSet<'_, '_>, + ) -> Result<ChatResponse> { + OpenAIAdapter::to_chat_response(model_iden, web_response, options_set) + } + + fn to_chat_stream( + model_iden: ModelIden, + reqwest_builder: RequestBuilder, + options_set: ChatOptionsSet<'_, '_>, + ) -> Result<ChatStreamResponse> { + OpenAIAdapter::to_chat_stream(model_iden, reqwest_builder, options_set) + } + + fn to_embed_request_data( + _service_target: ServiceTarget, + _embed_req: crate::embed::EmbedRequest, + _options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result<WebRequestData> { + // For now, OpenRouter embeddings are not supported + // This would require access to the private embed module in openai + Err(crate::Error::AdapterNotSupported { + adapter_kind: AdapterKind::OpenRouter, + feature: "embed".to_string(), + }) + } + + fn to_embed_response( + model_iden: ModelIden, + web_response: WebResponse, + _options_set: EmbedOptionsSet<'_, '_>, + ) -> Result<EmbedResponse> { + OpenAIAdapter::to_embed_response(model_iden, web_response, _options_set) + } +} diff --git a/src/adapter/dispatcher.rs b/src/adapter/dispatcher.rs index 4e6030e7..6e80fd26 100644 --- a/src/adapter/dispatcher.rs +++ b/src/adapter/dispatcher.rs @@ -1,6 +1,7 @@ use super::groq::GroqAdapter; use crate::adapter::adapters::together::TogetherAdapter; use crate::adapter::anthropic::AnthropicAdapter; +use crate::adapter::cerebras::CerebrasAdapter; use crate::adapter::cohere::CohereAdapter; use crate::adapter::deepseek::DeepSeekAdapter; use
crate::adapter::fireworks::FireworksAdapter; @@ -9,6 +10,8 @@ use crate::adapter::nebius::NebiusAdapter; use crate::adapter::ollama::OllamaAdapter; use crate::adapter::openai::OpenAIAdapter; use crate::adapter::openai_resp::OpenAIRespAdapter; +use crate::adapter::openrouter::OpenRouterAdapter; + use crate::adapter::xai::XaiAdapter; use crate::adapter::zhipu::ZhipuAdapter; use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; @@ -43,6 +46,8 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::default_endpoint(), AdapterKind::Cohere => CohereAdapter::default_endpoint(), AdapterKind::Ollama => OllamaAdapter::default_endpoint(), + AdapterKind::Cerebras => CerebrasAdapter::default_endpoint(), + AdapterKind::OpenRouter => Endpoint::from_static("https://openrouter.ai/api/v1/"), } } @@ -61,6 +66,8 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::default_auth(), AdapterKind::Cohere => CohereAdapter::default_auth(), AdapterKind::Ollama => OllamaAdapter::default_auth(), + AdapterKind::Cerebras => CerebrasAdapter::default_auth(), + AdapterKind::OpenRouter => AuthData::from_env(OpenRouterAdapter::API_KEY_DEFAULT_ENV_NAME), } } @@ -79,6 +86,8 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::all_model_names(kind).await, AdapterKind::Cohere => CohereAdapter::all_model_names(kind).await, AdapterKind::Ollama => OllamaAdapter::all_model_names(kind).await, + AdapterKind::Cerebras => CerebrasAdapter::all_model_names(kind).await, + AdapterKind::OpenRouter => OpenRouterAdapter::all_model_names(kind).await, } } @@ -97,6 +106,8 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Cohere => CohereAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Ollama => OllamaAdapter::get_service_url(model, service_type, endpoint), + AdapterKind::Cerebras => CerebrasAdapter::get_service_url(model, service_type, endpoint), + AdapterKind::OpenRouter => OpenRouterAdapter::get_service_url(model, service_type, endpoint), } } @@ -127,6 +138,10 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::Cohere => CohereAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::Ollama => OllamaAdapter::to_web_request_data(target, service_type, chat_req, options_set), + AdapterKind::Cerebras => CerebrasAdapter::to_web_request_data(target, service_type, chat_req, options_set), + AdapterKind::OpenRouter => { + OpenRouterAdapter::to_web_request_data(target, service_type, chat_req, options_set) + } } } @@ -149,6 +164,8 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::Cohere => CohereAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::Ollama => OllamaAdapter::to_chat_response(model_iden, web_response, options_set), + AdapterKind::Cerebras => CerebrasAdapter::to_chat_response(model_iden, web_response, options_set), + AdapterKind::OpenRouter => OpenRouterAdapter::to_chat_response(model_iden, web_response, options_set), } } @@ -174,6 +191,8 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Cohere => CohereAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Ollama => OllamaAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), + AdapterKind::Cerebras => 
CerebrasAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), + AdapterKind::OpenRouter => OpenRouterAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), } } @@ -200,6 +219,8 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::Cohere => CohereAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::Ollama => OllamaAdapter::to_embed_request_data(target, embed_req, options_set), + AdapterKind::Cerebras => CerebrasAdapter::to_embed_request_data(target, embed_req, options_set), + AdapterKind::OpenRouter => OpenRouterAdapter::to_embed_request_data(target, embed_req, options_set), } } @@ -225,6 +246,8 @@ impl AdapterDispatcher { AdapterKind::Zhipu => ZhipuAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Cohere => CohereAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Ollama => OllamaAdapter::to_embed_response(model_iden, web_response, options_set), + AdapterKind::Cerebras => CerebrasAdapter::to_embed_response(model_iden, web_response, options_set), + AdapterKind::OpenRouter => OpenRouterAdapter::to_embed_response(model_iden, web_response, options_set), } } } diff --git a/tests/live_api_tests.rs b/tests/live_api_tests.rs new file mode 100644 index 00000000..737934fe --- /dev/null +++ b/tests/live_api_tests.rs @@ -0,0 +1,434 @@ +//! Live API integration tests +//! +//! These tests run against actual Anthropic, OpenRouter, and Together.ai APIs. +//! They require valid API keys to be set in environment variables: +//! - ANTHROPIC_API_KEY for Anthropic tests +//! - OPENROUTER_API_KEY for OpenRouter tests +//! - TOGETHER_API_KEY for Together.ai tests +//! +//! To run these tests: +//! cargo test --test live_api_tests -- --ignored +//! +//! Tests will be skipped if API keys are not available. + +mod support; + +use genai::Client; +use genai::chat::{ChatMessage, ChatOptions, ChatRequest, Tool}; +use serial_test::serial; +use support::{TestResult, extract_stream_end}; + +/// Helper to check if environment variable is set +fn has_env_key(key: &str) -> bool { + std::env::var(key).is_ok_and(|v| !v.is_empty()) +} + +// ===== ANTHROPIC LIVE API TESTS ===== + +#[tokio::test] +#[serial] +#[ignore] // Ignored by default to avoid accidental API calls +async fn test_anthropic_live_basic_chat() -> TestResult<()> { + if !has_env_key("ANTHROPIC_API_KEY") { + println!("Skipping ANTHROPIC_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + let chat_req = ChatRequest::new(vec![ + ChatMessage::system("You are a helpful assistant."), + ChatMessage::user("Say 'Hello from live test!'"), + ]); + + let result = client.exec_chat("claude-3-5-haiku-latest", chat_req, None).await?; + + let content = result.first_text().ok_or("Should have content")?; + assert!(!content.is_empty()); + assert!(content.contains("Hello")); + println!("Anthropic basic chat response: {}", content); + Ok(()) +} + +#[tokio::test] +#[serial] +#[ignore] +async fn test_anthropic_live_tool_calling() -> TestResult<()> { + if !has_env_key("ANTHROPIC_API_KEY") { + println!("Skipping ANTHROPIC_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + + let tool = Tool::new("get_weather").with_schema(serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + })); + + let chat_req = ChatRequest::new(vec![ChatMessage::user("What's the weather in Paris?")]).append_tool(tool); + + let result = client.exec_chat("claude-3-5-haiku-latest", chat_req, None).await?; + + let content = result.first_text().ok_or("Should have content")?; + assert!(!content.is_empty()); + println!("Anthropic tool call response: {}", content); + Ok(()) +} + +#[tokio::test] +#[serial] +#[ignore] +async fn test_anthropic_live_streaming() -> TestResult<()> { + if !has_env_key("ANTHROPIC_API_KEY") { + println!("Skipping ANTHROPIC_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + let chat_req = ChatRequest::new(vec![ChatMessage::user("Count from 1 to 5 slowly")]); + + let options = ChatOptions::default().with_capture_content(true); + + let chat_res = client + .exec_chat_stream("claude-3-5-haiku-latest", chat_req, Some(&options)) + .await?; + + let stream_extract = extract_stream_end(chat_res.stream).await?; + let content = stream_extract.content.ok_or("Should have content")?; + + assert!(!content.is_empty()); + println!("Anthropic streaming content: {}", content); + Ok(()) +} + +// ===== OPENROUTER LIVE API TESTS ===== + +#[tokio::test] +#[serial] +#[ignore] +async fn test_openrouter_live_basic_chat() -> TestResult<()> { + if !has_env_key("OPENROUTER_API_KEY") { + println!("Skipping OPENROUTER_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + let chat_req = ChatRequest::new(vec![ + ChatMessage::system("You are a helpful assistant."), + ChatMessage::user("Say 'Hello from OpenRouter live test!'"), + ]); + + let result = client.exec_chat("anthropic/claude-3.5-sonnet", chat_req, None).await?; + + let content = result.first_text().ok_or("Should have content")?; + assert!(!content.is_empty()); + assert!(content.contains("Hello")); + println!("OpenRouter basic chat response: {}", content); + Ok(()) +} + +#[tokio::test] +#[serial] +#[ignore] +async fn test_openrouter_live_tool_calling() -> TestResult<()> { + if !has_env_key("OPENROUTER_API_KEY") { + println!("Skipping OPENROUTER_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + + let tool = Tool::new("get_weather").with_schema(serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + })); + + let chat_req = ChatRequest::new(vec![ChatMessage::user("What's the weather in Tokyo?")]).append_tool(tool); + + let result = client.exec_chat("anthropic/claude-3.5-sonnet", chat_req, None).await?; + + let content = result.first_text().ok_or("Should have content")?; + assert!(!content.is_empty()); + println!("OpenRouter tool call response: {}", content); + Ok(()) +} + +#[tokio::test] +#[serial] +#[ignore] +async fn test_openrouter_live_streaming() -> TestResult<()> { + if !has_env_key("OPENROUTER_API_KEY") { + println!("Skipping OPENROUTER_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + let chat_req = ChatRequest::new(vec![ChatMessage::user("Count from 1 to 5 slowly")]); + + let options = ChatOptions::default().with_capture_content(true); + + let chat_res = client + .exec_chat_stream("openrouter::anthropic/claude-3.5-sonnet", chat_req, Some(&options)) + .await?; + + let stream_extract = extract_stream_end(chat_res.stream).await?; + let content = stream_extract.content.ok_or("Should have content")?; + + assert!(!content.is_empty()); + println!("OpenRouter streaming content: {}", content); + Ok(()) +} + +// ===== TOGETHER.AI LIVE API TESTS ===== + +#[tokio::test] +#[serial] +#[ignore] +async fn test_together_live_basic_chat() -> TestResult<()> { + if !has_env_key("TOGETHER_API_KEY") { + println!("Skipping TOGETHER_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + let chat_req = ChatRequest::new(vec![ + ChatMessage::system("You are a helpful assistant."), + ChatMessage::user("Say 'Hello from Together.ai live test!'"), + ]); + + let result = client + .exec_chat("together::meta-llama/Llama-3.2-3B-Instruct-Turbo", chat_req, None) + .await?; + + let content = result.first_text().ok_or("Should have content")?; + assert!(!content.is_empty()); + assert!(content.contains("Hello")); + println!("Together.ai basic chat response: {}", content); + Ok(()) +} + +#[tokio::test] +#[serial] +#[ignore] +async fn test_together_live_tool_calling() -> TestResult<()> { + if !has_env_key("TOGETHER_API_KEY") { + println!("Skipping TOGETHER_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + + let tool = Tool::new("get_weather").with_schema(serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + })); + + let chat_req = ChatRequest::new(vec![ChatMessage::user("What's the weather in Tokyo?")]).append_tool(tool); + + let result = client + .exec_chat("together::meta-llama/Llama-3.2-3B-Instruct-Turbo", chat_req, None) + .await?; + + let content = result.first_text().ok_or("Should have content")?; + assert!(!content.is_empty()); + println!("Together.ai tool call response: {}", content); + Ok(()) +} + +#[tokio::test] +#[serial] +#[ignore] +async fn test_together_live_streaming() -> TestResult<()> { + if !has_env_key("TOGETHER_API_KEY") { + println!("Skipping TOGETHER_API_KEY not set"); + return Ok(()); + } + + let client = Client::default(); + let chat_req = ChatRequest::new(vec![ChatMessage::user("Count from 1 to 5 slowly")]); + + let options = ChatOptions::default().with_capture_content(true); + + let chat_res = client + .exec_chat_stream( + "together::meta-llama/Llama-3.2-3B-Instruct-Turbo", + chat_req, + Some(&options), + ) + .await?; + + let stream_extract = extract_stream_end(chat_res.stream).await?; + let content = stream_extract.content.ok_or("Should have content")?; + + assert!(!content.is_empty()); + println!("Together.ai streaming content: {}", content); + Ok(()) +} + +// ===== CROSS-PROVIDER COMPARISON TESTS ===== + +#[tokio::test] +#[serial] +#[ignore] +async fn test_cross_provider_model_comparison() -> TestResult<()> { + if !has_env_key("ANTHROPIC_API_KEY") || !has_env_key("OPENROUTER_API_KEY") || !has_env_key("TOGETHER_API_KEY") { + println!("Skipping comparison test - missing API keys"); + return Ok(()); + } + + // Test same prompt across both providers + let prompt = "What is 2 + 2? Answer with just the number."; + + // Anthropic + let anthropic_client = Client::default(); + let anthropic_chat_req = ChatRequest::new(vec![ChatMessage::user(prompt)]); + + let anthropic_result = anthropic_client + .exec_chat("claude-3-5-haiku-latest", anthropic_chat_req, None) + .await?; + + // OpenRouter (using Anthropic model via OpenRouter) + let openrouter_client = Client::default(); + let openrouter_chat_req = ChatRequest::new(vec![ChatMessage::user(prompt)]); + + let openrouter_result = openrouter_client + .exec_chat("openrouter::anthropic/claude-3.5-sonnet", openrouter_chat_req, None) + .await?; + + // Together.ai + let together_client = Client::default(); + let together_chat_req = ChatRequest::new(vec![ChatMessage::user(prompt)]); + + let together_result = together_client + .exec_chat( + "together::meta-llama/Llama-3.2-3B-Instruct-Turbo", + together_chat_req, + None, + ) + .await?; + + // All should give similar answers + let anthropic_content = anthropic_result.first_text().ok_or("Should have content")?; + let openrouter_content = openrouter_result.first_text().ok_or("Should have content")?; + let together_content = together_result.first_text().ok_or("Should have content")?; + + assert!(!anthropic_content.is_empty()); + assert!(!openrouter_content.is_empty()); + assert!(!together_content.is_empty()); + + println!("Anthropic response: {}", anthropic_content); + println!("OpenRouter response: {}", openrouter_content); + println!("Together.ai response: {}", together_content); + + // All should contain "4" somewhere + assert!(anthropic_content.contains("4") || anthropic_content.contains("four")); + assert!(openrouter_content.contains("4") || openrouter_content.contains("four")); + assert!(together_content.contains("4") || together_content.contains("four")); 
+    Ok(())
+}
+
+// ===== PERFORMANCE TESTS =====
+
+#[tokio::test]
+#[serial]
+#[ignore]
+async fn test_anthropic_live_response_time() -> TestResult<()> {
+    if !has_env_key("ANTHROPIC_API_KEY") {
+        println!("Skipping: ANTHROPIC_API_KEY not set");
+        return Ok(());
+    }
+
+    let client = Client::default();
+    let chat_req = ChatRequest::new(vec![ChatMessage::user("What is 2 + 2?")]);
+
+    let start = std::time::Instant::now();
+    let result = client.exec_chat("claude-3-5-haiku-latest", chat_req, None).await?;
+    let duration = start.elapsed();
+
+    let content = result.first_text().ok_or("Should have content")?;
+    assert!(!content.is_empty());
+
+    println!("Anthropic response time: {:?} for content: {}", duration, content);
+    assert!(duration.as_secs() < 30, "Response should be under 30 seconds");
+    Ok(())
+}
+
+#[tokio::test]
+#[serial]
+#[ignore]
+async fn test_openrouter_live_response_time() -> TestResult<()> {
+    if !has_env_key("OPENROUTER_API_KEY") {
+        println!("Skipping: OPENROUTER_API_KEY not set");
+        return Ok(());
+    }
+
+    let client = Client::default();
+    let chat_req = ChatRequest::new(vec![ChatMessage::user("What is 2 + 2?")]);
+
+    let start = std::time::Instant::now();
+    let result = client.exec_chat("openrouter::anthropic/claude-3.5-sonnet", chat_req, None).await?;
+    let duration = start.elapsed();
+
+    let content = result.first_text().ok_or("Should have content")?;
+    assert!(!content.is_empty());
+
+    println!("OpenRouter response time: {:?} for content: {}", duration, content);
+    assert!(duration.as_secs() < 30, "Response should be under 30 seconds");
+    Ok(())
+}
+
+#[tokio::test]
+#[serial]
+#[ignore]
+async fn test_together_live_response_time() -> TestResult<()> {
+    if !has_env_key("TOGETHER_API_KEY") {
+        println!("Skipping: TOGETHER_API_KEY not set");
+        return Ok(());
+    }
+
+    let client = Client::default();
+    let chat_req = ChatRequest::new(vec![ChatMessage::user("What is 2 + 2?")]);
+
+    let start = std::time::Instant::now();
+    let result = client
+        .exec_chat("together::meta-llama/Llama-3.2-3B-Instruct-Turbo", chat_req, None)
+        .await?;
+    let duration = start.elapsed();
+
+    let content = result.first_text().ok_or("Should have content")?;
+    assert!(!content.is_empty());
+
+    println!("Together.ai response time: {:?} for content: {}", duration, content);
+    assert!(duration.as_secs() < 30, "Response should be under 30 seconds");
+    Ok(())
+}
diff --git a/tests/mock_tests.rs b/tests/mock_tests.rs
new file mode 100644
index 00000000..82a1ef00
--- /dev/null
+++ b/tests/mock_tests.rs
@@ -0,0 +1,817 @@
+//! Mock server integration tests using wiremock
+
+use serial_test::serial;
+use uuid::Uuid;
+use wiremock::{
+    Mock, MockServer, ResponseTemplate,
+    matchers::{header, method, path},
+};
+
+/// Generate a mock message ID
+fn generate_message_id() -> String {
+    format!("msg_{}", Uuid::new_v4().simple())
+}
+
+/// Generate a mock chat completion ID
+fn generate_chat_id() -> String {
+    format!("chatcmpl-{}", Uuid::new_v4().simple())
+}
+
+/// Create a standard token-usage block
+fn create_standard_usage() -> serde_json::Value {
+    serde_json::json!({
+        "prompt_tokens": 10,
+        "completion_tokens": 5,
+        "total_tokens": 15
+    })
+}
+
+/// Create an Anthropic-style response
+fn create_anthropic_response() -> serde_json::Value {
+    serde_json::json!({
+        "id": generate_message_id(),
+        "type": "message",
+        "role": "assistant",
+        "content": [{"type": "text", "text": "Hello! I'm a mock Anthropic response."}],
+        "model": "claude-3-5-haiku-latest",
+        "stop_reason": "end_turn",
+        "stop_sequence": null,
+        "usage": create_standard_usage()
+    })
+}
+
+/// Create an OpenRouter-style response
+fn create_openrouter_response() -> serde_json::Value {
+    serde_json::json!({
+        "id": generate_chat_id(),
+        "object": "chat.completion",
+        "created": 1234567890,
+        "model": "anthropic/claude-3.5-sonnet",
+        "choices": [{
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": "Hello! I'm a mock OpenRouter response."
+            },
+            "finish_reason": "stop"
+        }],
+        "usage": create_standard_usage()
+    })
+}
+
+/// Create an Anthropic tool-use response
+fn create_anthropic_tool_response() -> serde_json::Value {
+    serde_json::json!({
+        "id": generate_message_id(),
+        "type": "message",
+        "role": "assistant",
+        "content": [{
+            "type": "tool_use",
+            "id": format!("toolu_{}", Uuid::new_v4().simple()),
+            "name": "get_weather",
+            "input": {
+                "location": "Paris",
+                "unit": "celsius"
+            }
+        }],
+        "model": "claude-3-5-haiku-latest",
+        "stop_reason": "tool_use",
+        "stop_sequence": null,
+        "usage": create_standard_usage()
+    })
+}
+
+#[tokio::test]
+#[serial]
+async fn test_anthropic_mock_server_basic() {
+    let mock_server = MockServer::start().await;
+
+    // Mock the messages endpoint
+    Mock::given(method("POST"))
+        .and(path("/v1/messages"))
+        .and(header("x-api-key", "test-key"))
+        .respond_with(ResponseTemplate::new(200).set_body_json(create_anthropic_response()))
+        .mount(&mock_server)
+        .await;
+
+    // Test basic request
+    let client = reqwest::Client::new();
+    let response = client
+        .post(format!("{}/v1/messages", mock_server.uri()))
+        .header("x-api-key", "test-key")
+        .json(&serde_json::json!({
+            "model": "claude-3-5-haiku-latest",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 10
+        }))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(response.status(), 200);
+
+    let json: serde_json::Value = response.json().await.unwrap();
+    assert_eq!(json["type"], "message");
+    assert_eq!(json["role"], "assistant");
+    assert_eq!(json["content"][0]["text"], "Hello! I'm a mock Anthropic response.");
+}
+
+#[tokio::test]
+#[serial]
+async fn test_openrouter_mock_server_basic() {
+    let mock_server = MockServer::start().await;
+
+    // Mock the chat completions endpoint
+    Mock::given(method("POST"))
+        .and(path("/api/v1/chat/completions"))
+        .and(header("authorization", "Bearer test-key"))
+        .respond_with(ResponseTemplate::new(200).set_body_json(create_openrouter_response()))
+        .mount(&mock_server)
+        .await;
+
+    // Test basic request
+    let client = reqwest::Client::new();
+    let response = client
+        .post(format!("{}/api/v1/chat/completions", mock_server.uri()))
+        .header("authorization", "Bearer test-key")
+        .json(&serde_json::json!({
+            "model": "anthropic/claude-3.5-sonnet",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 10
+        }))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(response.status(), 200);
+
+    let json: serde_json::Value = response.json().await.unwrap();
+    assert_eq!(json["object"], "chat.completion");
+    assert_eq!(json["choices"][0]["message"]["role"], "assistant");
+    assert_eq!(
+        json["choices"][0]["message"]["content"],
+        "Hello! I'm a mock OpenRouter response."
+ ); +} + +#[tokio::test] +#[serial] +async fn test_anthropic_tool_call() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/messages")) + .and(header("x-api-key", "test-key")) + .respond_with(ResponseTemplate::new(200).set_body_json(create_anthropic_tool_response())) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages", mock_server.uri())) + .header("x-api-key", "test-key") + .json(&serde_json::json!({ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "What's the weather?"}], + "tools": [{ + "name": "get_weather", + "description": "Get weather information", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string"} + }, + "required": ["location"] + } + }], + "max_tokens": 100 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 200); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["content"][0]["type"], "tool_use"); + assert_eq!(json["content"][0]["name"], "get_weather"); + assert_eq!(json["content"][0]["input"]["location"], "Paris"); +} + +#[tokio::test] +#[serial] +async fn test_anthropic_streaming() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/messages/beta/stream")) + .and(header("x-api-key", "test-key")) + .respond_with(ResponseTemplate::new(200).set_body_string( + "event: message_start\ndata: {\"type\": \"message_start\"}\n\nevent: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"delta\": {\"text\": \"Hello\"}}\n\nevent: message_stop\ndata: {\"type\": \"message_stop\"}\n\n" + )) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages/beta/stream", mock_server.uri())) + .header("x-api-key", "test-key") + .json(&serde_json::json!({ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 200); + // Note: wiremock may not preserve content-type header exactly +} + +#[tokio::test] +#[serial] +async fn test_openrouter_streaming() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/api/v1/chat/completions/stream")) + .and(header("authorization", "Bearer test-key")) + .respond_with(ResponseTemplate::new(200).set_body_string( + "data: {\"id\": \"chatcmpl-...\", \"object\": \"chat.completion.chunk\", \"choices\": [{\"index\": 0, \"delta\": {\"content\": \"Hello\"}}]}\n\ndata: [DONE]\n\n" + )) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/api/v1/chat/completions/stream", mock_server.uri())) + .header("authorization", "Bearer test-key") + .json(&serde_json::json!({ + "model": "anthropic/claude-3.5-sonnet", + "messages": [{"role": "user", "content": "Hello"}], + "stream": true, + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 200); + // Note: wiremock may not preserve content-type header exactly +} + +#[tokio::test] +#[serial] +async fn test_anthropic_auth_error() { + let mock_server = MockServer::start().await; + + let error_response = serde_json::json!({ + "type": "error", + "error": { + "type": "authentication_error", + "message": "Invalid API key" + } + }); + + Mock::given(method("POST")) + 
.and(path("/v1/messages")) + .and(header("x-api-key", "invalid-key")) + .respond_with(ResponseTemplate::new(401).set_body_json(error_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages", mock_server.uri())) + .header("x-api-key", "invalid-key") + .json(&serde_json::json!({ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 401); +} + +#[tokio::test] +#[serial] +async fn test_openrouter_auth_error() { + let mock_server = MockServer::start().await; + + let error_response = serde_json::json!({ + "error": { + "message": "Invalid API key", + "type": "invalid_api_key", + "code": "invalid_api_key" + } + }); + + Mock::given(method("POST")) + .and(path("/api/v1/chat/completions")) + .and(header("authorization", "Bearer invalid-key")) + .respond_with(ResponseTemplate::new(401).set_body_json(error_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/api/v1/chat/completions", mock_server.uri())) + .header("authorization", "Bearer invalid-key") + .json(&serde_json::json!({ + "model": "anthropic/claude-3.5-sonnet", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 401); +} + +#[tokio::test] +#[serial] +async fn test_anthropic_json_mode() { + let mock_server = MockServer::start().await; + + let json_response = serde_json::json!({ + "id": generate_message_id(), + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "{\"colors\": [\"red\", \"green\", \"blue\"]}"}], + "model": "claude-3-5-haiku-latest", + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": create_standard_usage() + }); + + Mock::given(method("POST")) + .and(path("/v1/messages")) + .and(header("x-api-key", "test-key")) + .respond_with(ResponseTemplate::new(200).set_body_json(json_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages", mock_server.uri())) + .header("x-api-key", "test-key") + .json(&serde_json::json!({ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "List 3 colors in JSON format"}], + "response_format": {"type": "json_object"}, + "max_tokens": 100 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 200); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["type"], "message"); + assert_eq!( + json["content"][0]["text"], + "{\"colors\": [\"red\", \"green\", \"blue\"]}" + ); +} + +// ===== ENHANCED ERROR SCENARIO TESTS ===== + +#[tokio::test] +#[serial] +async fn test_anthropic_rate_limit_error() { + let mock_server = MockServer::start().await; + + let rate_limit_response = serde_json::json!({ + "type": "error", + "error": { + "type": "rate_limit_error", + "message": "Rate limit exceeded. 
Please try again later.", + "error": { + "type": "rate_limit_error", + "message": "Rate limit exceeded" + } + } + }); + + Mock::given(method("POST")) + .and(path("/v1/messages")) + .and(header("x-api-key", "test-key")) + .respond_with( + ResponseTemplate::new(429) + .set_body_json(rate_limit_response) + .insert_header("Retry-After", "60"), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages", mock_server.uri())) + .header("x-api-key", "test-key") + .json(&serde_json::json!({ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 429); + + // Check retry-after header + let retry_after = response.headers().get("Retry-After").unwrap(); + assert_eq!(retry_after, "60"); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["type"], "error"); + assert_eq!(json["error"]["type"], "rate_limit_error"); +} + +#[tokio::test] +#[serial] +async fn test_openrouter_rate_limit_error() { + let mock_server = MockServer::start().await; + + let rate_limit_response = serde_json::json!({ + "error": { + "message": "Rate limit exceeded. Please try again later.", + "type": "rate_limit_exceeded", + "code": "rate_limit_exceeded" + } + }); + + Mock::given(method("POST")) + .and(path("/api/v1/chat/completions")) + .and(header("authorization", "Bearer test-key")) + .respond_with( + ResponseTemplate::new(429) + .set_body_json(rate_limit_response) + .insert_header("Retry-After", "30"), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/api/v1/chat/completions", mock_server.uri())) + .header("authorization", "Bearer test-key") + .json(&serde_json::json!({ + "model": "anthropic/claude-3.5-sonnet", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 429); + + // Check retry-after header + let retry_after = response.headers().get("Retry-After").unwrap(); + assert_eq!(retry_after, "30"); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["error"]["type"], "rate_limit_exceeded"); +} + +#[tokio::test] +#[serial] +async fn test_anthropic_invalid_request_error() { + let mock_server = MockServer::start().await; + + let invalid_request_response = serde_json::json!({ + "type": "error", + "error": { + "type": "invalid_request_error", + "message": "Invalid request: model 'invalid-model' not found", + "error": { + "type": "invalid_request_error", + "message": "model 'invalid-model' not found" + } + } + }); + + Mock::given(method("POST")) + .and(path("/v1/messages")) + .and(header("x-api-key", "test-key")) + .respond_with(ResponseTemplate::new(400).set_body_json(invalid_request_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages", mock_server.uri())) + .header("x-api-key", "test-key") + .json(&serde_json::json!({ + "model": "invalid-model", + "messages": [{"role": "invalid", "content": 123}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 400); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["type"], "error"); + assert_eq!(json["error"]["type"], "invalid_request_error"); + 
assert!(json["error"]["message"].as_str().unwrap().contains("invalid-model")); +} + +#[tokio::test] +#[serial] +async fn test_openrouter_invalid_request_error() { + let mock_server = MockServer::start().await; + + let invalid_request_response = serde_json::json!({ + "error": { + "message": "Invalid request: model 'invalid-model' not found", + "type": "invalid_request_error", + "code": "model_not_found" + } + }); + + Mock::given(method("POST")) + .and(path("/api/v1/chat/completions")) + .and(header("authorization", "Bearer test-key")) + .respond_with(ResponseTemplate::new(400).set_body_json(invalid_request_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/api/v1/chat/completions", mock_server.uri())) + .header("authorization", "Bearer test-key") + .json(&serde_json::json!({ + "model": "invalid-model", + "messages": [{"role": "invalid", "content": 123}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 400); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["error"]["type"], "invalid_request_error"); + assert_eq!(json["error"]["code"], "model_not_found"); +} + +#[tokio::test] +#[serial] +async fn test_anthropic_server_error() { + let mock_server = MockServer::start().await; + + let server_error_response = serde_json::json!({ + "type": "error", + "error": { + "type": "api_error", + "message": "Internal server error. Please try again.", + "error": { + "type": "api_error", + "message": "Internal server error" + } + } + }); + + Mock::given(method("POST")) + .and(path("/v1/messages")) + .and(header("x-api-key", "test-key")) + .respond_with(ResponseTemplate::new(500).set_body_json(server_error_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages", mock_server.uri())) + .header("x-api-key", "test-key") + .json(&serde_json::json!({ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 500); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["type"], "error"); + assert_eq!(json["error"]["type"], "api_error"); +} + +#[tokio::test] +#[serial] +async fn test_openrouter_server_error() { + let mock_server = MockServer::start().await; + + let server_error_response = serde_json::json!({ + "error": { + "message": "Internal server error. 
Please try again.", + "type": "internal_server_error", + "code": "internal_error" + } + }); + + Mock::given(method("POST")) + .and(path("/api/v1/chat/completions")) + .and(header("authorization", "Bearer test-key")) + .respond_with(ResponseTemplate::new(500).set_body_json(server_error_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/api/v1/chat/completions", mock_server.uri())) + .header("authorization", "Bearer test-key") + .json(&serde_json::json!({ + "model": "anthropic/claude-3.5-sonnet", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 500); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["error"]["type"], "internal_server_error"); + assert_eq!(json["error"]["code"], "internal_error"); +} + +#[tokio::test] +#[serial] +async fn test_anthropic_timeout_error() { + let mock_server = MockServer::start().await; + + // Simulate timeout by not responding and using a timeout template + Mock::given(method("POST")) + .and(path("/v1/messages")) + .and(header("x-api-key", "test-key")) + .respond_with(ResponseTemplate::new(408).set_body_json(serde_json::json!({ + "type": "error", + "error": { + "type": "timeout_error", + "message": "Request timeout. Please try again.", + "error": { + "type": "timeout_error", + "message": "Request timeout" + } + } + }))) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages", mock_server.uri())) + .header("x-api-key", "test-key") + .json(&serde_json::json!({ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 408); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["type"], "error"); + assert_eq!(json["error"]["type"], "timeout_error"); +} + +#[tokio::test] +#[serial] +async fn test_openrouter_timeout_error() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/api/v1/chat/completions")) + .and(header("authorization", "Bearer test-key")) + .respond_with(ResponseTemplate::new(408).set_body_json(serde_json::json!({ + "error": { + "message": "Request timeout. 
Please try again.", + "type": "timeout", + "code": "request_timeout" + } + }))) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/api/v1/chat/completions", mock_server.uri())) + .header("authorization", "Bearer test-key") + .json(&serde_json::json!({ + "model": "anthropic/claude-3.5-sonnet", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 408); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["error"]["type"], "timeout"); + assert_eq!(json["error"]["code"], "request_timeout"); +} + +#[tokio::test] +#[serial] +async fn test_anthropic_content_filter_error() { + let mock_server = MockServer::start().await; + + let content_filter_response = serde_json::json!({ + "type": "error", + "error": { + "type": "content_filter", + "message": "Content filtered due to policy violation.", + "error": { + "type": "content_filter", + "message": "Content policy violation" + } + } + }); + + Mock::given(method("POST")) + .and(path("/v1/messages")) + .and(header("x-api-key", "test-key")) + .respond_with(ResponseTemplate::new(400).set_body_json(content_filter_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/v1/messages", mock_server.uri())) + .header("x-api-key", "test-key") + .json(&serde_json::json!({ + "model": "claude-3-5-haiku-latest", + "messages": [{"role": "user", "content": "Inappropriate content"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 400); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["type"], "error"); + assert_eq!(json["error"]["type"], "content_filter"); +} + +#[tokio::test] +#[serial] +async fn test_openrouter_content_filter_error() { + let mock_server = MockServer::start().await; + + let content_filter_response = serde_json::json!({ + "error": { + "message": "Content filtered due to policy violation.", + "type": "content_filter", + "code": "content_policy_violation" + } + }); + + Mock::given(method("POST")) + .and(path("/api/v1/chat/completions")) + .and(header("authorization", "Bearer test-key")) + .respond_with(ResponseTemplate::new(400).set_body_json(content_filter_response)) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/api/v1/chat/completions", mock_server.uri())) + .header("authorization", "Bearer test-key") + .json(&serde_json::json!({ + "model": "anthropic/claude-3.5-sonnet", + "messages": [{"role": "user", "content": "Inappropriate content"}], + "max_tokens": 10 + })) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 400); + + let json: serde_json::Value = response.json().await.unwrap(); + assert_eq!(json["error"]["type"], "content_filter"); + assert_eq!(json["error"]["code"], "content_policy_violation"); +} diff --git a/tests/tests_p_cerebras.rs b/tests/tests_p_cerebras.rs new file mode 100644 index 00000000..534f894b --- /dev/null +++ b/tests/tests_p_cerebras.rs @@ -0,0 +1,80 @@ +mod support; + +use crate::support::{Check, TestResult, common_tests}; +use genai::adapter::AdapterKind; +use genai::resolver::AuthData; + +// Cerebras uses OpenAI-compatible chat completions +const MODEL: &str = "cerebras::llama-3.1-8b"; +const MODEL_NS: &str = "cerebras::llama-3.3-70b"; + +// region: --- Chat + +#[tokio::test] +async fn test_chat_simple_ok() 
-> TestResult<()> { + common_tests::common_test_chat_simple_ok(MODEL, None).await +} + +#[tokio::test] +async fn test_chat_namespaced_ok() -> TestResult<()> { + common_tests::common_test_chat_simple_ok(MODEL_NS, None).await +} + +#[tokio::test] +async fn test_chat_multi_system_ok() -> TestResult<()> { + common_tests::common_test_chat_multi_system_ok(MODEL).await +} + +#[tokio::test] +async fn test_chat_json_mode_ok() -> TestResult<()> { + common_tests::common_test_chat_json_mode_ok(MODEL, Some(Check::USAGE)).await +} + +#[tokio::test] +async fn test_chat_temperature_ok() -> TestResult<()> { + common_tests::common_test_chat_temperature_ok(MODEL).await +} + +#[tokio::test] +async fn test_chat_stop_sequences_ok() -> TestResult<()> { + common_tests::common_test_chat_stop_sequences_ok(MODEL).await +} + +// endregion: --- Chat + +// region: --- Chat Stream Tests + +#[tokio::test] +async fn test_chat_stream_simple_ok() -> TestResult<()> { + common_tests::common_test_chat_stream_simple_ok(MODEL, None).await +} + +#[tokio::test] +async fn test_chat_stream_capture_content_ok() -> TestResult<()> { + common_tests::common_test_chat_stream_capture_content_ok(MODEL).await +} + +#[tokio::test] +async fn test_chat_stream_capture_all_ok() -> TestResult<()> { + common_tests::common_test_chat_stream_capture_all_ok(MODEL, None).await +} + +// endregion: --- Chat Stream Tests + +// region: --- Resolver Tests + +#[tokio::test] +async fn test_resolver_auth_ok() -> TestResult<()> { + common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("CEREBRAS_API_KEY")).await +} + +// endregion: --- Resolver Tests + +// region: --- List + +#[tokio::test] +async fn test_list_models() -> TestResult<()> { + common_tests::common_test_list_models(AdapterKind::Cerebras, "llama-3.1-8b").await +} + +// endregion: --- List