diff --git a/cortex/cli.py b/cortex/cli.py index 267228b0..38be04f6 100644 --- a/cortex/cli.py +++ b/cortex/cli.py @@ -2,6 +2,7 @@ import json import logging import os +import re import select import sys import time @@ -193,6 +194,15 @@ def _get_api_key(self) -> str | None: self._detected_provider = detected_provider return key + # 2b. Fallback: allow multiple OpenAI keys in OPENAI_API_KEYS + openai_keys = os.environ.get("OPENAI_API_KEYS") + if openai_keys: + parsed_keys = [k.strip() for k in re.split(r"[\s,;]+", openai_keys) if k.strip()] + if parsed_keys: + self._debug("Using OpenAI API key from OPENAI_API_KEYS") + self._detected_provider = "openai" + return parsed_keys[0] + # Still no key self._print_error(t("api_key.not_found")) cx_print(t("api_key.configure_prompt"), "info") @@ -216,7 +226,7 @@ def _get_provider(self) -> str: # 3. Check env vars (may have been set by auto-detect) if os.environ.get("ANTHROPIC_API_KEY"): return "claude" - elif os.environ.get("OPENAI_API_KEY"): + elif os.environ.get("OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEYS"): return "openai" # 4. Fallback to Ollama for offline mode diff --git a/cortex/llm_router.py b/cortex/llm_router.py index 38403d0d..210ed6be 100644 --- a/cortex/llm_router.py +++ b/cortex/llm_router.py @@ -15,6 +15,7 @@ import json import logging import os +import re import threading import time from dataclasses import dataclass @@ -47,6 +48,7 @@ class LLMProvider(Enum): CLAUDE = "claude" KIMI_K2 = "kimi_k2" + OPENAI = "openai" OLLAMA = "ollama" @@ -96,6 +98,10 @@ class LLMRouter: "input": 1.0, # Estimated lower cost "output": 5.0, # Estimated lower cost }, + LLMProvider.OPENAI: { + "input": 0.0, # Unknown/varies by model + "output": 0.0, # Unknown/varies by model + }, LLMProvider.OLLAMA: { "input": 0.0, # Free - local inference "output": 0.0, # Free - local inference @@ -118,8 +124,10 @@ def __init__( self, claude_api_key: str | None = None, kimi_api_key: str | None = None, + openai_api_key: str | None = None, ollama_base_url: str | None = None, ollama_model: str | None = None, + openai_model: str | None = None, default_provider: LLMProvider = LLMProvider.CLAUDE, enable_fallback: bool = True, track_costs: bool = True, @@ -130,14 +138,21 @@ def __init__( Args: claude_api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env) kimi_api_key: Moonshot API key (defaults to MOONSHOT_API_KEY env) + openai_api_key: OpenAI API key or delimited list (defaults to OPENAI_API_KEY/OPENAI_API_KEYS env) + Multiple keys are supported for rotation/fallback. Accepted delimiters are + commas, semicolons, or whitespace. Keys from the explicit argument, OPENAI_API_KEY, and + OPENAI_API_KEYS are merged in that order, deduplicated, and tried in sequence.
ollama_base_url: Ollama API base URL (defaults to http://localhost:11434) ollama_model: Ollama model to use (defaults to llama3.2) + openai_model: OpenAI model to use (defaults to gpt-4o-mini) default_provider: Fallback provider if routing fails enable_fallback: Try alternate LLM if primary fails track_costs: Track token usage and costs """ self.claude_api_key = claude_api_key or os.getenv("ANTHROPIC_API_KEY") self.kimi_api_key = kimi_api_key or os.getenv("MOONSHOT_API_KEY") + self.openai_api_keys = self._load_openai_keys(openai_api_key) + self.openai_model = openai_model or os.getenv("OPENAI_MODEL", "gpt-4o-mini") self.default_provider = default_provider self.enable_fallback = enable_fallback self.track_costs = track_costs @@ -145,10 +160,12 @@ def __init__( # Initialize clients (sync) self.claude_client = None self.kimi_client = None + self.openai_client = None # Initialize async clients self.claude_client_async = None self.kimi_client_async = None + self.openai_client_async = None if self.claude_api_key: self.claude_client = Anthropic(api_key=self.claude_api_key) @@ -168,6 +185,13 @@ def __init__( else: logger.warning("⚠️ No Kimi K2 API key provided") + if self.openai_api_keys: + self.openai_client = OpenAI(api_key=self.openai_api_keys[0]) + self.openai_client_async = AsyncOpenAI(api_key=self.openai_api_keys[0]) + logger.info("✅ OpenAI API client initialized") + else: + logger.warning("⚠️ No OpenAI API key provided") + # Initialize Ollama client (local inference) self.ollama_base_url = ollama_base_url or os.getenv( "OLLAMA_BASE_URL", "http://localhost:11434" @@ -200,9 +224,32 @@ def __init__( self.provider_stats = { LLMProvider.CLAUDE: {"requests": 0, "tokens": 0, "cost": 0.0}, LLMProvider.KIMI_K2: {"requests": 0, "tokens": 0, "cost": 0.0}, + LLMProvider.OPENAI: {"requests": 0, "tokens": 0, "cost": 0.0}, LLMProvider.OLLAMA: {"requests": 0, "tokens": 0, "cost": 0.0}, } + @staticmethod + def _split_api_keys(value: str | None) -> list[str]: + if not value: + return [] + return [key.strip() for key in re.split(r"[\s,;]+", value) if key.strip()] + + def _load_openai_keys(self, explicit_key: str | None) -> list[str]: + keys: list[str] = [] + if explicit_key: + keys.extend(self._split_api_keys(explicit_key)) + + keys.extend(self._split_api_keys(os.getenv("OPENAI_API_KEY"))) + keys.extend(self._split_api_keys(os.getenv("OPENAI_API_KEYS"))) + + seen = set() + unique_keys = [] + for key in keys: + if key and key not in seen: + unique_keys.append(key) + seen.add(key) + return unique_keys + def route_task( self, task_type: TaskType, force_provider: LLMProvider | None = None ) -> RoutingDecision: @@ -232,6 +279,9 @@ def route_task( if self.kimi_client and self.enable_fallback: logger.warning("Claude unavailable, falling back to Kimi K2") provider = LLMProvider.KIMI_K2 + elif self.openai_api_keys and self.enable_fallback: + logger.warning("Claude unavailable, falling back to OpenAI") + provider = LLMProvider.OPENAI elif self.ollama_client and self.enable_fallback: logger.warning("Claude unavailable, falling back to Ollama") provider = LLMProvider.OLLAMA @@ -242,12 +292,28 @@ def route_task( if self.claude_client and self.enable_fallback: logger.warning("Kimi K2 unavailable, falling back to Claude") provider = LLMProvider.CLAUDE + elif self.openai_api_keys and self.enable_fallback: + logger.warning("Kimi K2 unavailable, falling back to OpenAI") + provider = LLMProvider.OPENAI elif self.ollama_client and self.enable_fallback: logger.warning("Kimi K2 unavailable, falling back to Ollama") provider = 
LLMProvider.OLLAMA else: raise RuntimeError("Kimi K2 API not configured and no fallback available") + if provider == LLMProvider.OPENAI and not self.openai_api_keys: + if self.claude_client and self.enable_fallback: + logger.warning("OpenAI unavailable, falling back to Claude") + provider = LLMProvider.CLAUDE + elif self.kimi_client and self.enable_fallback: + logger.warning("OpenAI unavailable, falling back to Kimi K2") + provider = LLMProvider.KIMI_K2 + elif self.ollama_client and self.enable_fallback: + logger.warning("OpenAI unavailable, falling back to Ollama") + provider = LLMProvider.OLLAMA + else: + raise RuntimeError("OpenAI API not configured and no fallback available") + if provider == LLMProvider.OLLAMA and not self.ollama_client: if self.claude_client and self.enable_fallback: logger.warning("Ollama unavailable, falling back to Claude") @@ -255,6 +321,9 @@ def route_task( elif self.kimi_client and self.enable_fallback: logger.warning("Ollama unavailable, falling back to Kimi K2") provider = LLMProvider.KIMI_K2 + elif self.openai_api_keys and self.enable_fallback: + logger.warning("Ollama unavailable, falling back to OpenAI") + provider = LLMProvider.OPENAI else: raise RuntimeError("Ollama not available and no fallback configured") @@ -268,13 +337,16 @@ def _get_fallback_provider(self, current: LLMProvider) -> LLMProvider | None: """Find the next available provider that isn't the current one.""" candidates = [] - # Priority order: Claude -> Kimi -> Ollama + # Priority order: Claude -> Kimi -> OpenAI -> Ollama if self.claude_client and current != LLMProvider.CLAUDE: candidates.append(LLMProvider.CLAUDE) if self.kimi_client and current != LLMProvider.KIMI_K2: candidates.append(LLMProvider.KIMI_K2) + if self.openai_api_keys and current != LLMProvider.OPENAI: + candidates.append(LLMProvider.OPENAI) + if self.ollama_client and current != LLMProvider.OLLAMA: candidates.append(LLMProvider.OLLAMA) @@ -314,6 +386,8 @@ def complete( response = self._complete_claude(messages, temperature, max_tokens, tools) elif routing.provider == LLMProvider.KIMI_K2: response = self._complete_kimi(messages, temperature, max_tokens, tools) + elif routing.provider == LLMProvider.OPENAI: + response = self._complete_openai(messages, temperature, max_tokens, tools) else: # OLLAMA response = self._complete_ollama(messages, temperature, max_tokens, tools) @@ -326,7 +400,7 @@ def complete( return response except Exception as e: - logger.error(f"❌ Error with {routing.provider.value}: {e}") + self._log_provider_failure(routing.provider, e) # Try fallback if enabled if self.enable_fallback: @@ -343,11 +417,22 @@ def complete( tools=tools, ) else: - logger.error("❌ No fallback providers available") - raise + raise RuntimeError("No AI providers are available. 
Check your API keys.") else: raise + @staticmethod + def _log_provider_failure(provider: LLMProvider, error: Exception) -> None: + provider_names = { + LLMProvider.CLAUDE: "Claude", + LLMProvider.KIMI_K2: "Kimi K2", + LLMProvider.OPENAI: "OpenAI", + LLMProvider.OLLAMA: "Ollama", + } + name = provider_names.get(provider, provider.value) + logger.warning(f"{name} request failed") + logger.debug("%s error: %s", name, error) + def _complete_claude( self, messages: list[dict[str, str]], @@ -453,6 +538,75 @@ def _complete_kimi( raw_response=response.model_dump() if hasattr(response, "model_dump") else None, ) + def _complete_openai( + self, + messages: list[dict[str, str]], + temperature: float, + max_tokens: int, + tools: list[dict] | None = None, + ) -> LLMResponse: + """Generate completion using OpenAI API.""" + if not self.openai_api_keys: + raise RuntimeError("OpenAI client not initialized") + + last_error: Exception | None = None + for index, key in enumerate(self.openai_api_keys): + try: + client = ( + self.openai_client if index == 0 and self.openai_client else OpenAI(api_key=key) + ) + kwargs = self._build_openai_kwargs(messages, temperature, max_tokens, tools) + + response = client.chat.completions.create(**kwargs) + return self._openai_response_to_llmresponse(response) + except Exception as e: + last_error = e + if index < len(self.openai_api_keys) - 1: + logger.warning("OpenAI API key failed, trying next key") + logger.debug("OpenAI error: %s", e) + continue + break + + raise RuntimeError( + "OpenAI API request failed. Check your OpenAI API key or billing." + ) from last_error + + def _build_openai_kwargs( + self, + messages: list[dict[str, str]], + temperature: float, + max_tokens: int, + tools: list[dict] | None = None, + ) -> dict[str, Any]: + kwargs: dict[str, Any] = { + "model": self.openai_model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + } + + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = "auto" + + return kwargs + + def _openai_response_to_llmresponse(self, response: Any) -> LLMResponse: + content = response.choices[0].message.content or "" + input_tokens = response.usage.prompt_tokens + output_tokens = response.usage.completion_tokens + cost = self._calculate_cost(LLMProvider.OPENAI, input_tokens, output_tokens) + + return LLMResponse( + content=content, + provider=LLMProvider.OPENAI, + model=self.openai_model, + tokens_used=input_tokens + output_tokens, + cost_usd=cost, + latency_seconds=0.0, # Set by caller + raw_response=response.model_dump() if hasattr(response, "model_dump") else None, + ) + def _complete_ollama( self, messages: list[dict[str, str]], @@ -497,7 +651,8 @@ def _complete_ollama( ) except Exception as e: - logger.error(f"Ollama error: {e}") + logger.warning("Ollama request failed") + logger.debug("Ollama error: %s", e) raise RuntimeError( f"Ollama request failed. Is Ollama running? 
(ollama serve) Error: {e}" ) @@ -544,6 +699,11 @@ def get_stats(self) -> dict[str, Any]: "tokens": self.provider_stats[LLMProvider.KIMI_K2]["tokens"], "cost_usd": round(self.provider_stats[LLMProvider.KIMI_K2]["cost"], 4), }, + "openai": { + "requests": self.provider_stats[LLMProvider.OPENAI]["requests"], + "tokens": self.provider_stats[LLMProvider.OPENAI]["tokens"], + "cost_usd": round(self.provider_stats[LLMProvider.OPENAI]["cost"], 4), + }, "ollama": { "requests": self.provider_stats[LLMProvider.OLLAMA]["requests"], "tokens": self.provider_stats[LLMProvider.OLLAMA]["tokens"], @@ -602,6 +762,8 @@ async def acomplete( response = await self._acomplete_claude(messages, temperature, max_tokens, tools) elif routing.provider == LLMProvider.KIMI_K2: response = await self._acomplete_kimi(messages, temperature, max_tokens, tools) + elif routing.provider == LLMProvider.OPENAI: + response = await self._acomplete_openai(messages, temperature, max_tokens, tools) else: # OLLAMA response = await self._acomplete_ollama(messages, temperature, max_tokens, tools) @@ -614,7 +776,7 @@ async def acomplete( return response except Exception as e: - logger.error(f"❌ Error with {routing.provider.value}: {e}") + self._log_provider_failure(routing.provider, e) # Try fallback if enabled if self.enable_fallback: @@ -631,8 +793,7 @@ async def acomplete( tools=tools, ) else: - logger.error("❌ No fallback providers available") - raise + raise RuntimeError("No AI providers are available. Check your API keys.") else: raise @@ -739,6 +900,41 @@ async def _acomplete_kimi( raw_response=response.model_dump() if hasattr(response, "model_dump") else None, ) + async def _acomplete_openai( + self, + messages: list[dict[str, str]], + temperature: float, + max_tokens: int, + tools: list[dict] | None = None, + ) -> LLMResponse: + """Async: Generate completion using OpenAI API.""" + if not self.openai_api_keys: + raise RuntimeError("OpenAI client not initialized") + + last_error: Exception | None = None + for index, key in enumerate(self.openai_api_keys): + try: + client = ( + self.openai_client_async + if index == 0 and self.openai_client_async + else AsyncOpenAI(api_key=key) + ) + kwargs = self._build_openai_kwargs(messages, temperature, max_tokens, tools) + + response = await client.chat.completions.create(**kwargs) + return self._openai_response_to_llmresponse(response) + except Exception as e: + last_error = e + if index < len(self.openai_api_keys) - 1: + logger.warning("OpenAI API key failed, trying next key") + logger.debug("OpenAI error: %s", e) + continue + break + + raise RuntimeError( + "OpenAI API request failed. Check your OpenAI API key or billing." + ) from last_error + async def _acomplete_ollama( self, messages: list[dict[str, str]], @@ -785,7 +981,8 @@ async def _acomplete_ollama( ) except Exception as e: - logger.error(f"Ollama async error: {e}") + logger.warning("Ollama request failed") + logger.debug("Ollama error: %s", e) raise RuntimeError( f"Ollama request failed. Is Ollama running? 
(ollama serve) Error: {e}" ) diff --git a/cortex/predictive_prevention.py b/cortex/predictive_prevention.py index ab7171d9..df7aa02f 100644 --- a/cortex/predictive_prevention.py +++ b/cortex/predictive_prevention.py @@ -77,6 +77,7 @@ def __init__(self, api_key: str | None = None, provider: str | None = None): self.router = LLMRouter( claude_api_key=api_key if normalized_provider == "claude" else None, kimi_api_key=api_key if normalized_provider == "kimi_k2" else None, + openai_api_key=api_key if normalized_provider == "openai" else None, default_provider=llm_provider, ) diff --git a/tests/test_llm_router.py b/tests/test_llm_router.py index 799d1940..80979418 100644 --- a/tests/test_llm_router.py +++ b/tests/test_llm_router.py @@ -158,7 +158,9 @@ class TestCostTracking(unittest.TestCase): def setUp(self): """Set up router with tracking enabled.""" self.router = LLMRouter( - claude_api_key="test-claude-key", kimi_api_key="test-kimi-key", track_costs=True + claude_api_key="test-claude-key", + kimi_api_key="test-kimi-key", + track_costs=True, ) def test_cost_calculation_claude(self): @@ -272,7 +274,9 @@ def test_claude_completion(self, mock_anthropic): # Test completion result = router._complete_claude( - messages=[{"role": "user", "content": "Hello"}], temperature=0.7, max_tokens=1024 + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1024, ) self.assertEqual(result.content, "Hello from Claude") @@ -342,7 +346,9 @@ def test_kimi_completion(self, mock_openai): # Test completion result = router._complete_kimi( - messages=[{"role": "user", "content": "Hello"}], temperature=0.7, max_tokens=1024 + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1024, ) self.assertEqual(result.content, "Hello from Kimi K2") @@ -373,7 +379,9 @@ def test_kimi_temperature_mapping(self, mock_openai): # Call with temperature=1.0 router._complete_kimi( - messages=[{"role": "user", "content": "Hello"}], temperature=1.0, max_tokens=1024 + messages=[{"role": "user", "content": "Hello"}], + temperature=1.0, + max_tokens=1024, ) # Verify temperature was scaled to 0.6 @@ -416,6 +424,225 @@ def test_kimi_with_tools(self, mock_openai): self.assertEqual(call_args.kwargs["tool_choice"], "auto") +class TestOpenAIIntegration(unittest.TestCase): + """Test OpenAI API integration.""" + + @patch("cortex.llm_router.AsyncOpenAI") + @patch("cortex.llm_router.OpenAI") + def test_openai_completion(self, mock_openai, mock_async_openai): + """Test OpenAI completion with mocked API.""" + mock_message = Mock() + mock_message.content = "Hello from OpenAI" + + mock_choice = Mock() + mock_choice.message = mock_message + + mock_response = Mock() + mock_response.choices = [mock_choice] + mock_response.usage = Mock(prompt_tokens=120, completion_tokens=30) + mock_response.model_dump = lambda: {"mock": "response"} + + mock_client = Mock() + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + mock_async_openai.return_value = AsyncMock() + + router = LLMRouter(openai_api_key="test-openai-key") + router.openai_client = mock_client + router.openai_api_keys = ["test-openai-key"] + + result = router._complete_openai( + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1024, + ) + + self.assertEqual(result.content, "Hello from OpenAI") + self.assertEqual(result.provider, LLMProvider.OPENAI) + self.assertEqual(result.tokens_used, 150) + self.assertEqual(result.cost_usd, 0.0) + 
self.assertEqual(result.latency_seconds, 0.0) + self.assertEqual(result.raw_response, {"mock": "response"}) + + @patch("cortex.llm_router.AsyncOpenAI") + @patch("cortex.llm_router.OpenAI") + def test_openai_key_rotation(self, mock_openai, mock_async_openai): + """Test OpenAI key rotation on failure.""" + bad_client = Mock() + bad_client.chat.completions.create.side_effect = Exception("bad key") + + mock_message = Mock() + mock_message.content = "Recovered response" + + mock_choice = Mock() + mock_choice.message = mock_message + + mock_response = Mock() + mock_response.choices = [mock_choice] + mock_response.usage = Mock(prompt_tokens=90, completion_tokens=10) + mock_response.model_dump = lambda: {"mock": "response"} + + good_client = Mock() + good_client.chat.completions.create.return_value = mock_response + + mock_openai.return_value = good_client + mock_async_openai.return_value = AsyncMock() + + router = LLMRouter(openai_api_key="bad-key,good-key") + router.openai_client = bad_client + router.openai_api_keys = ["bad-key", "good-key"] + + result = router._complete_openai( + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1024, + ) + + self.assertEqual(result.provider, LLMProvider.OPENAI) + self.assertEqual(result.content, "Recovered response") + self.assertEqual(result.tokens_used, 100) + self.assertEqual(result.raw_response, {"mock": "response"}) + + @patch("cortex.llm_router.AsyncOpenAI") + @patch("cortex.llm_router.OpenAI") + def test_openai_all_keys_fail(self, mock_openai, mock_async_openai): + """Test OpenAI error propagation when all keys fail.""" + bad_client_1 = Mock() + bad_client_1.chat.completions.create.side_effect = Exception("bad key 1") + + bad_client_2 = Mock() + bad_client_2.chat.completions.create.side_effect = Exception("bad key 2") + + mock_openai.return_value = bad_client_2 + mock_async_openai.return_value = AsyncMock() + + router = LLMRouter(openai_api_key="bad-key-1,bad-key-2") + router.openai_client = bad_client_1 + router.openai_api_keys = ["bad-key-1", "bad-key-2"] + + with self.assertRaises(RuntimeError) as context: + router._complete_openai( + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1024, + ) + + self.assertIsNotNone(context.exception.__cause__) + + @patch("cortex.llm_router.AsyncOpenAI") + @patch("cortex.llm_router.OpenAI") + def test_acomplete_openai_success(self, mock_openai, mock_async_openai): + """Test async OpenAI completion with mocked API.""" + mock_message = Mock() + mock_message.content = "Async OpenAI response" + + mock_choice = Mock() + mock_choice.message = mock_message + + mock_response = Mock() + mock_response.choices = [mock_choice] + mock_response.usage = Mock(prompt_tokens=80, completion_tokens=20) + mock_response.model_dump = lambda: {"mock": "response"} + + mock_openai.return_value = Mock() + + mock_async_client = AsyncMock() + mock_async_client.chat.completions.create = AsyncMock(return_value=mock_response) + mock_async_openai.return_value = mock_async_client + + router = LLMRouter(openai_api_key="test-openai-key") + router.openai_client_async = mock_async_client + router.openai_api_keys = ["test-openai-key"] + + async def run_test(): + result = await router._acomplete_openai( + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1024, + ) + + self.assertEqual(result.provider, LLMProvider.OPENAI) + self.assertEqual(result.content, "Async OpenAI response") + self.assertEqual(result.tokens_used, 100) + self.assertEqual(result.cost_usd, 
0.0) + self.assertEqual(result.latency_seconds, 0.0) + self.assertEqual(result.raw_response, {"mock": "response"}) + + asyncio.run(run_test()) + + @patch("cortex.llm_router.AsyncOpenAI") + @patch("cortex.llm_router.OpenAI") + def test_acomplete_openai_key_rotation(self, mock_openai, mock_async_openai): + """Test async OpenAI key rotation on failure.""" + mock_openai.return_value = Mock() + + bad_async_client = AsyncMock() + bad_async_client.chat.completions.create = AsyncMock(side_effect=Exception("bad key")) + + mock_message = Mock() + mock_message.content = "Async recovered" + + mock_choice = Mock() + mock_choice.message = mock_message + + mock_response = Mock() + mock_response.choices = [mock_choice] + mock_response.usage = Mock(prompt_tokens=60, completion_tokens=20) + mock_response.model_dump = lambda: {"mock": "response"} + + good_async_client = AsyncMock() + good_async_client.chat.completions.create = AsyncMock(return_value=mock_response) + + mock_async_openai.return_value = good_async_client + + router = LLMRouter(openai_api_key="bad-key good-key") + router.openai_client_async = bad_async_client + router.openai_api_keys = ["bad-key", "good-key"] + + async def run_test(): + result = await router._acomplete_openai( + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1024, + ) + + self.assertEqual(result.provider, LLMProvider.OPENAI) + self.assertEqual(result.content, "Async recovered") + self.assertEqual(result.tokens_used, 80) + + asyncio.run(run_test()) + + @patch("cortex.llm_router.AsyncOpenAI") + @patch("cortex.llm_router.OpenAI") + def test_acomplete_openai_all_keys_fail(self, mock_openai, mock_async_openai): + """Test async OpenAI error propagation when all keys fail.""" + mock_openai.return_value = Mock() + + bad_async_client_1 = AsyncMock() + bad_async_client_1.chat.completions.create = AsyncMock(side_effect=Exception("bad key 1")) + + bad_async_client_2 = AsyncMock() + bad_async_client_2.chat.completions.create = AsyncMock(side_effect=Exception("bad key 2")) + + mock_async_openai.return_value = bad_async_client_2 + + router = LLMRouter(openai_api_key="bad1,bad2") + router.openai_client_async = bad_async_client_1 + router.openai_api_keys = ["bad1", "bad2"] + + async def run_test(): + with self.assertRaises(RuntimeError) as context: + await router._acomplete_openai( + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=1024, + ) + + self.assertIsNotNone(context.exception.__cause__) + + asyncio.run(run_test()) + + class TestEndToEnd(unittest.TestCase): """End-to-end integration tests.""" @@ -758,6 +985,7 @@ def run_tests(): suite.addTests(loader.loadTestsFromTestCase(TestCostTracking)) suite.addTests(loader.loadTestsFromTestCase(TestClaudeIntegration)) suite.addTests(loader.loadTestsFromTestCase(TestKimiIntegration)) + suite.addTests(loader.loadTestsFromTestCase(TestOpenAIIntegration)) suite.addTests(loader.loadTestsFromTestCase(TestEndToEnd)) suite.addTests(loader.loadTestsFromTestCase(TestConvenienceFunction)) suite.addTests(loader.loadTestsFromTestCase(TestParallelProcessing))
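
A minimal usage sketch of the multi-key handling introduced above, assuming the cortex package layout from this diff; the key strings are placeholders and would normally come from the environment rather than literals.

import os
import re

from cortex.llm_router import LLMRouter

# Same delimiter rule as LLMRouter._split_api_keys and the CLI fallback:
# commas, semicolons, or whitespace separate individual keys.
raw_keys = "sk-proj-aaa, sk-proj-bbb;sk-proj-ccc"
parsed = [k.strip() for k in re.split(r"[\s,;]+", raw_keys) if k.strip()]
print(parsed)  # ['sk-proj-aaa', 'sk-proj-bbb', 'sk-proj-ccc']

# Both the CLI and the router read OPENAI_API_KEYS in addition to OPENAI_API_KEY.
os.environ["OPENAI_API_KEYS"] = raw_keys

# _load_openai_keys merges the explicit argument with both env vars and dedupes them;
# _complete_openai/_acomplete_openai rotate through the list when a key fails.
router = LLMRouter(openai_api_key="sk-proj-aaa")
print(router.openai_api_keys)  # explicit key first, remaining env keys after dedup
print(router.openai_model)     # "gpt-4o-mini" unless OPENAI_MODEL is set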