From 26fd6bda8a903b2c0838938f99035df531d1e976 Mon Sep 17 00:00:00 2001
From: Ekin Bozyel
Date: Sun, 1 Feb 2026 21:20:14 +0300
Subject: [PATCH 1/6] feat: add MultiProviderClient for multi-LLM provider routing

---
 .../utils/multi_provider_ai_client.py | 135 ++++++++++++++++++
 test.py                               |   8 ++
 2 files changed, 143 insertions(+)
 create mode 100644 agentlightning/utils/multi_provider_ai_client.py
 create mode 100644 test.py

diff --git a/agentlightning/utils/multi_provider_ai_client.py b/agentlightning/utils/multi_provider_ai_client.py
new file mode 100644
index 000000000..358eb2025
--- /dev/null
+++ b/agentlightning/utils/multi_provider_ai_client.py
@@ -0,0 +1,135 @@
+"""
+Multi-Provider Client
+========================
+Async client that routes to different LLM providers based on model name.
+
+Routing:
+    - google-gemini-2.0-flash → Google API
+    - groq-llama-3.3-70b → Groq API
+
+Model name must be in "provider-model" format.
+
+## Usage
+```python
+from multi_provider_client import MultiProviderClient
+client = MultiProviderClient()
+# Use with APO
+algo = agl.APO(
+    client,
+    gradient_model="google-gemini-2.0-flash",
+    apply_edit_model="groq-llama-3.3-70b-versatile",
+)
+```
+
+"""
+
+import os
+from openai import AsyncOpenAI
+
+
+class MultiProviderClient:
+    """Async client that routes to different providers based on model name.
+    Model format: "provider-model_name"
+    Examples:
+        - google-gemini-2.0-flash
+        - groq-meta-llama/llama-4-maverick-17b-128e-instruct
+        - etc.
+    """
+
+    def __init__(self, custom_providers: dict[str, dict] | None = None):
+        """
+        Args:
+            custom_providers: Additional providers. Format:
+                {
+                    "provider_name": {
+                        "api_key": "...",  # or env var name
+                        "base_url": "https://..."
+                    }
+                }
+        """
+
+        self.clients = {}
+
+        # Only create clients for providers with API keys
+        if os.getenv("GOOGLE_API_KEY"):
+            self.clients["google"] = AsyncOpenAI(
+                api_key=os.getenv("GOOGLE_API_KEY"),
+                base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
+
+        if os.getenv("GROQ_API_KEY"):
+            self.clients["groq"] = AsyncOpenAI(
+                api_key=os.getenv("GROQ_API_KEY"),
+                base_url="https://api.groq.com/openai/v1")
+
+        if os.getenv("OPENAI_API_KEY"):
+            self.clients["openai"] = AsyncOpenAI(
+                api_key=os.getenv("OPENAI_API_KEY"))
+
+        if os.getenv("AZURE_OPENAI_API_KEY") and os.getenv("AZURE_OPENAI_ENDPOINT"):
+            self.clients["azure"] = AsyncOpenAI(
+                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+                base_url=os.getenv("AZURE_OPENAI_ENDPOINT"))
+
+        if os.getenv("OPENROUTER_API_KEY"):
+            self.clients["openrouter"] = AsyncOpenAI(
+                api_key=os.getenv("OPENROUTER_API_KEY"),
+                base_url="https://openrouter.ai/api/v1")
+
+        # Add custom providers
+        if custom_providers:
+            for name, config in custom_providers.items():
+                api_key = config.get("api_key") or os.getenv(config.get("api_key_env", ""))
+                base_url = config.get("base_url")
+                self.clients[name] = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+
+    def _parse_model(self, model: str) -> tuple[str, str]:
+        """Parse model name into provider and actual model name.
+
+        Args:
+            model: String in "provider-model_name" format
+
+        Returns:
+            (provider, actual_model_name) tuple
+        """
+        if "-" not in model:
+            raise ValueError(f"Model format must be 'provider-model_name': {model}")
+
+        idx = model.find("-")
+        provider = model[:idx]
+        actual_model = model[idx + 1:]
+
+        if provider not in self.clients:
+            for name in self.clients:
+                if model.startswith(name + "-"):
+                    provider = name
+                    actual_model = model[len(name) + 1:]
+                    break
+            else:
+                raise ValueError(f"Unknown provider: {provider}. Supported: {list(self.clients.keys())}")
+
+        return provider, actual_model
+
+
+    @property
+    def chat(self):
+        return self._ChatProxy(self)
+
+    class _ChatProxy:
+        def __init__(self, parent):
+            self.parent = parent
+
+        @property
+        def completions(self):
+            return self.parent._CompletionsProxy(self.parent)
+
+    class _CompletionsProxy:
+        def __init__(self, parent):
+            self.parent = parent
+
+        async def create(self, model: str, **kwargs):
+            provider, actual_model = self.parent._parse_model(model)
+            client = self.parent.clients[provider]
+            print("--- Multi Provider Client ---")
+            print(f"{provider.upper()}: {actual_model}")
+            return await client.chat.completions.create(model=actual_model, **kwargs)
diff --git a/test.py b/test.py
new file mode 100644
index 000000000..de63ac8ac
--- /dev/null
+++ b/test.py
@@ -0,0 +1,8 @@
+
+model="groq-meta-llama/llama-4-maverick-17b-128e-instruct"
+
+parts = model.split("-", 1)
+provider = parts[0]
+actual_model = parts[1] if len(parts) > 1 else model
+
+print(actual_model)
\ No newline at end of file

From 5fd14578d0ac8d8dd87c6c5f3d76435ec2e90976 Mon Sep 17 00:00:00 2001
From: Ekin Bozyel
Date: Sun, 1 Feb 2026 21:20:54 +0300
Subject: [PATCH 2/6] delete test file

---
 test.py | 8 --------
 1 file changed, 8 deletions(-)
 delete mode 100644 test.py

diff --git a/test.py b/test.py
deleted file mode 100644
index de63ac8ac..000000000
--- a/test.py
+++ /dev/null
@@ -1,8 +0,0 @@
-
-model="groq-meta-llama/llama-4-maverick-17b-128e-instruct"
-
-parts = model.split("-", 1)
-provider = parts[0]
-actual_model = parts[1] if len(parts) > 1 else model
-
-print(actual_model)
\ No newline at end of file

From 13fcd247f6ffe161486c05331dec3cccad171da4 Mon Sep 17 00:00:00 2001
From: Ekin Bozyel <50263592+john-fante@users.noreply.github.com>
Date: Sun, 1 Feb 2026 21:44:31 +0300
Subject: [PATCH 3/6] Update agentlightning/utils/multi_provider_ai_client.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 agentlightning/utils/multi_provider_ai_client.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/agentlightning/utils/multi_provider_ai_client.py b/agentlightning/utils/multi_provider_ai_client.py
index 358eb2025..03f6facd6 100644
--- a/agentlightning/utils/multi_provider_ai_client.py
+++ b/agentlightning/utils/multi_provider_ai_client.py
@@ -11,7 +11,7 @@
 
 ## Usage
 ```python
-from multi_provider_client import MultiProviderClient
+from agentlightning.utils.multi_provider_ai_client import MultiProviderClient
 client = MultiProviderClient()
 # Use with APO
 algo = agl.APO(

From 13a96df2665c56a6962f6321ae0e464daf3c3c97 Mon Sep 17 00:00:00 2001
From: Ekin Bozyel <50263592+john-fante@users.noreply.github.com>
Date: Sun, 1 Feb 2026 21:44:53 +0300
Subject: [PATCH 4/6] Update agentlightning/utils/multi_provider_ai_client.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 agentlightning/utils/multi_provider_ai_client.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/agentlightning/utils/multi_provider_ai_client.py b/agentlightning/utils/multi_provider_ai_client.py
index 03f6facd6..7e8ffba24 100644
--- a/agentlightning/utils/multi_provider_ai_client.py
+++ b/agentlightning/utils/multi_provider_ai_client.py
@@ -42,7 +42,8 @@ def __init__(self, custom_providers: dict[str, dict] | None = None):
             custom_providers: Additional providers. Format:
                 {
                     "provider_name": {
-                        "api_key": "...",  # or env var name
+                        "api_key": "...",  # literal API key value (optional)
+                        "api_key_env": "ENV_VAR",  # name of env var containing the API key (optional)
                         "base_url": "https://..."
                     }
                 }

From 9b16ce4a15f268ef62c684466800b52913251d59 Mon Sep 17 00:00:00 2001
From: Ekin Bozyel
Date: Mon, 9 Feb 2026 14:56:41 +0300
Subject: [PATCH 5/6] refactor: rename client file and update to LiteLLM

---
 .../utils/multi_provider_ai_client.py         | 136 ------------------
 agentlightning/utils/multi_provider_client.py |  69 +++++++++
 2 files changed, 69 insertions(+), 136 deletions(-)
 delete mode 100644 agentlightning/utils/multi_provider_ai_client.py
 create mode 100644 agentlightning/utils/multi_provider_client.py

diff --git a/agentlightning/utils/multi_provider_ai_client.py b/agentlightning/utils/multi_provider_ai_client.py
deleted file mode 100644
index 7e8ffba24..000000000
--- a/agentlightning/utils/multi_provider_ai_client.py
+++ /dev/null
@@ -1,136 +0,0 @@
-"""
-Multi-Provider Client
-========================
-Async client that routes to different LLM providers based on model name.
-
-Routing:
-    - google-gemini-2.0-flash → Google API
-    - groq-llama-3.3-70b → Groq API
-
-Model name must be in "provider-model" format.
-
-## Usage
-```python
-from agentlightning.utils.multi_provider_ai_client import MultiProviderClient
-client = MultiProviderClient()
-# Use with APO
-algo = agl.APO(
-    client,
-    gradient_model="google-gemini-2.0-flash",
-    apply_edit_model="groq-llama-3.3-70b-versatile",
-)
-```
-
-"""
-
-import os
-from openai import AsyncOpenAI
-
-
-class MultiProviderClient:
-    """Async client that routes to different providers based on model name.
-    Model format: "provider-model_name"
-    Examples:
-        - google-gemini-2.0-flash
-        - groq-meta-llama/llama-4-maverick-17b-128e-instruct
-        - etc.
-    """
-
-    def __init__(self, custom_providers: dict[str, dict] | None = None):
-        """
-        Args:
-            custom_providers: Additional providers. Format:
-                {
-                    "provider_name": {
-                        "api_key": "...",  # literal API key value (optional)
-                        "api_key_env": "ENV_VAR",  # name of env var containing the API key (optional)
-                        "base_url": "https://..."
-                    }
-                }
-        """
-
-        self.clients = {}
-
-        # Only create clients for providers with API keys
-        if os.getenv("GOOGLE_API_KEY"):
-            self.clients["google"] = AsyncOpenAI(
-                api_key=os.getenv("GOOGLE_API_KEY"),
-                base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
-
-        if os.getenv("GROQ_API_KEY"):
-            self.clients["groq"] = AsyncOpenAI(
-                api_key=os.getenv("GROQ_API_KEY"),
-                base_url="https://api.groq.com/openai/v1")
-
-        if os.getenv("OPENAI_API_KEY"):
-            self.clients["openai"] = AsyncOpenAI(
-                api_key=os.getenv("OPENAI_API_KEY"))
-
-        if os.getenv("AZURE_OPENAI_API_KEY") and os.getenv("AZURE_OPENAI_ENDPOINT"):
-            self.clients["azure"] = AsyncOpenAI(
-                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-                base_url=os.getenv("AZURE_OPENAI_ENDPOINT"))
-
-        if os.getenv("OPENROUTER_API_KEY"):
-            self.clients["openrouter"] = AsyncOpenAI(
-                api_key=os.getenv("OPENROUTER_API_KEY"),
-                base_url="https://openrouter.ai/api/v1")
-
-        # Add custom providers
-        if custom_providers:
-            for name, config in custom_providers.items():
-                api_key = config.get("api_key") or os.getenv(config.get("api_key_env", ""))
-                base_url = config.get("base_url")
-                self.clients[name] = AsyncOpenAI(api_key=api_key, base_url=base_url)
-
-
-    def _parse_model(self, model: str) -> tuple[str, str]:
-        """Parse model name into provider and actual model name.
-
-        Args:
-            model: String in "provider-model_name" format
-
-        Returns:
-            (provider, actual_model_name) tuple
-        """
-        if "-" not in model:
-            raise ValueError(f"Model format must be 'provider-model_name': {model}")
-
-        idx = model.find("-")
-        provider = model[:idx]
-        actual_model = model[idx + 1:]
-
-        if provider not in self.clients:
-            for name in self.clients:
-                if model.startswith(name + "-"):
-                    provider = name
-                    actual_model = model[len(name) + 1:]
-                    break
-            else:
-                raise ValueError(f"Unknown provider: {provider}. Supported: {list(self.clients.keys())}")
-
-        return provider, actual_model
-
-
-    @property
-    def chat(self):
-        return self._ChatProxy(self)
-
-    class _ChatProxy:
-        def __init__(self, parent):
-            self.parent = parent
-
-        @property
-        def completions(self):
-            return self.parent._CompletionsProxy(self.parent)
-
-    class _CompletionsProxy:
-        def __init__(self, parent):
-            self.parent = parent
-
-        async def create(self, model: str, **kwargs):
-            provider, actual_model = self.parent._parse_model(model)
-            client = self.parent.clients[provider]
-            print("--- Multi Provider Client ---")
-            print(f"{provider.upper()}: {actual_model}")
-            return await client.chat.completions.create(model=actual_model, **kwargs)
diff --git a/agentlightning/utils/multi_provider_client.py b/agentlightning/utils/multi_provider_client.py
new file mode 100644
index 000000000..57690d22f
--- /dev/null
+++ b/agentlightning/utils/multi_provider_client.py
@@ -0,0 +1,69 @@
+"""
+Multi-Provider Client (LiteLLM Version)
+========================================
+Async client that routes to different LLM providers using LiteLLM.
+
+Routing:
+    - gemini/gemini-2.0-flash → Google API
+    - groq/llama-3.3-70b → Groq API
+    - ollama/llama3 → Local Ollama
+    - openai/ → OpenAI or Custom Base URL
+
+Model name should follow the standard LiteLLM "provider/model" format.
+
+## Usage
+```python
+from agentlightning.utils.multi_provider_ai_client import MultiProviderClient
+client = MultiProviderClient()
+
+# Use with APO
+algo = agl.APO(
+    client,
+    gradient_model="gemini/gemini-2.0-flash",
+    apply_edit_model="groq/llama-3.3-70b-versatile",
+)
+
+"""
+
+from litellm import acompletion
+
+class MultiProviderClient:
+    """Async client that routes to different providers using LiteLLM.
+    Uses standard LiteLLM 'provider/model' format.
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Initializes the client. LiteLLM automatically picks up API keys
+        from environment variables (e.g., GOOGLE_API_KEY, GROQ_API_KEY).
+        """
+        pass
+
+    @property
+    def chat(self):
+        return self._ChatProxy(self)
+
+    class _ChatProxy:
+        def __init__(self, parent):
+            self.parent = parent
+
+        @property
+        def completions(self):
+            return self.parent._CompletionsProxy(self.parent)
+
+    class _CompletionsProxy:
+        def __init__(self, parent):
+            self.parent = parent
+
+        async def create(self, model: str, **kwargs):
+            """
+            Passes the request directly to LiteLLM for routing.
+
+            Args:
+                model: String in "provider/model_name" format.
+                **kwargs: Additional arguments for the completion call.
+            """
+            print("--- Multi Provider Client (LiteLLM) ---")
+            print(f"Routing to: {model}")
+
+            return await acompletion(model=model, **kwargs)
\ No newline at end of file

From 610651fd1e9ae0b0d4a61788801645e978a6e146 Mon Sep 17 00:00:00 2001
From: Ekin Bozyel
Date: Wed, 11 Feb 2026 00:22:14 +0300
Subject: [PATCH 6/6] feat(apo): add MultiProviderClient for hybrid-model optimization

- Implement MultiProviderClient to support routing different APO stages to different LLM providers.
- Add examples/apo/apo_multi_provider.py showcasing a Gemini and Groq hybrid workflow.
- Update the examples README with documentation for the new hybrid sample.
- Improve static type safety in APO samples to resolve Pylance warnings.
---
 agentlightning/utils/multi_provider_client.py |  30 +---
 examples/apo/README.md                        |   7 +-
 examples/apo/apo_multi_provider.py            | 138 ++++++++++++++++++
 3 files changed, 146 insertions(+), 29 deletions(-)
 create mode 100644 examples/apo/apo_multi_provider.py

diff --git a/agentlightning/utils/multi_provider_client.py b/agentlightning/utils/multi_provider_client.py
index 57690d22f..48d424925 100644
--- a/agentlightning/utils/multi_provider_client.py
+++ b/agentlightning/utils/multi_provider_client.py
@@ -13,31 +13,19 @@
 
 ## Usage
 ```python
-from agentlightning.utils.multi_provider_ai_client import MultiProviderClient
+from agentlightning.utils.multi_provider_client import MultiProviderClient
 client = MultiProviderClient()
 
-# Use with APO
-algo = agl.APO(
-    client,
-    gradient_model="gemini/gemini-2.0-flash",
-    apply_edit_model="groq/llama-3.3-70b-versatile",
-)
-
+```
 """
 
 from litellm import acompletion
 
 class MultiProviderClient:
-    """Async client that routes to different providers using LiteLLM.
-    Uses standard LiteLLM 'provider/model' format.
-    """
+    """Async client that routes to different providers using LiteLLM."""
 
     def __init__(self, **kwargs):
-        """
-        Initializes the client. LiteLLM automatically picks up API keys
-        from environment variables (e.g., GOOGLE_API_KEY, GROQ_API_KEY).
-        """
-        pass
+        print("--- Multi Provider Client (LiteLLM) Initialized ---")
 
     @property
     def chat(self):
@@ -56,14 +44,4 @@ def __init__(self, parent):
             self.parent = parent
 
         async def create(self, model: str, **kwargs):
-            """
-            Passes the request directly to LiteLLM for routing.
-
-            Args:
-                model: String in "provider/model_name" format.
-                **kwargs: Additional arguments for the completion call.
-            """
-            print("--- Multi Provider Client (LiteLLM) ---")
-            print(f"Routing to: {model}")
-
             return await acompletion(model=model, **kwargs)
\ No newline at end of file
diff --git a/examples/apo/README.md b/examples/apo/README.md
index a3f647853..4dd99ed08 100644
--- a/examples/apo/README.md
+++ b/examples/apo/README.md
@@ -19,11 +19,12 @@ Follow the [installation guide](../../docs/tutorials/installation.md) to install
 | `room_selector.py` | Room booking agent implementation using function calling |
 | `room_selector_apo.py` | Training script using the built-in APO algorithm to optimize prompts |
 | `room_tasks.jsonl` | Dataset with room booking scenarios and expected selections |
-| `apo_custom_algorithm.py` | Tutorial on creating custom algorithms (runnable as algo or runner) |
+| `apo_custom_algorithm.py` | Tutorial on creating custom algorithms |
 | `apo_custom_algorithm_trainer.py` | Shows how to integrate custom algorithms into the Trainer |
 | `apo_debug.py` | Tutorial demonstrating various agent debugging techniques |
-| `legacy_apo_client.py` | Deprecated APO client implementation compatible with Agent-lightning v0.1.x |
-| `legacy_apo_server.py` | Deprecated APO server implementation compatible with Agent-lightning v0.1.x |
+| `apo_multi_provider.py` | Hybrid optimization sample using multiple LLM backends |
+| `legacy_apo_client.py` | Deprecated APO client implementation compatible with v0.1.x |
+| `legacy_apo_server.py` | Deprecated APO server implementation compatible with v0.1.x |
 
 ## Sample 1: Using Built-in APO Algorithm
 
diff --git a/examples/apo/apo_multi_provider.py b/examples/apo/apo_multi_provider.py
new file mode 100644
index 000000000..4acb2d070
--- /dev/null
+++ b/examples/apo/apo_multi_provider.py
@@ -0,0 +1,138 @@
+"""
+This sample code demonstrates how to use the MultiProviderClient with the APO algorithm
+to tune mathematical reasoning prompts using a hybrid model setup.
+"""
+
+import logging
+import re
+import asyncio
+import multiprocessing
+from typing import Tuple, cast, Dict, Any, List
+
+from dotenv import load_dotenv
+load_dotenv()
+import agentlightning as agl
+from agentlightning import Trainer, setup_logging, PromptTemplate
+from agentlightning.adapter import TraceToMessages
+from agentlightning.algorithm.apo import APO
+from agentlightning.types import Dataset
+from litellm import completion
+from agentlightning.utils.multi_provider_client import MultiProviderClient
+
+
+# --- 1. Dataset Logic ---
+def load_math_tasks() -> List[Dict[str, str]]:
+    """Small mock GSM8k-style dataset."""
+    return [
+        {"question": "If I have 3 apples and buy 2 more, how many do I have?", "expected": "5"},
+        {"question": "A train travels 60 miles in 1 hour. How far in 3 hours?", "expected": "180"},
+        {"question": "What is the square root of 144?", "expected": "12"},
+        {"question": "If a shirt costs $20 and is 10% off, what is the price?", "expected": "18"},
+    ]
+
+def load_train_val_dataset() -> Tuple[Dataset[Dict[str, str]], Dataset[Dict[str, str]]]:
+    dataset_full = load_math_tasks()
+    train_split = len(dataset_full) // 2
+    # Use list() and cast to satisfy Pylance's SupportsIndex/slice checks
+    dataset_train = cast(Dataset[Dict[str, str]], list(dataset_full[:train_split]))
+    dataset_val = cast(Dataset[Dict[str, str]], list(dataset_full[train_split:]))
+    return dataset_train, dataset_val
+
+# --- 2. Agent Logic ---
+class MathAgent(agl.LitAgent):
+    def __init__(self):
+        super().__init__()
+
+    def rollout(self, task: Any, resources: Dict[str, Any], rollout: Any) -> float:
+        # Pylance fix: Explicitly cast task to Dict
+        t = cast(Dict[str, str], task)
+        prompt_template: PromptTemplate = resources.get("prompt_template")  # type: ignore
+
+        # Ensure template access is type-safe
+        template_str = getattr(prompt_template, "template", str(prompt_template))
+        prompt = template_str.format(question=t["question"])
+
+        # Direct LiteLLM call
+        response = completion(
+            model="gemini/gemini-2.0-flash",
+            messages=[{"role": "user", "content": prompt}]
+        )
+        answer = str(response.choices[0].message.content)
+
+        # Reward: Numerical exact match check
+        pred_nums = re.findall(r"[-+]?\d*\.\d+|\d+", answer.split("Answer:")[-1])
+        reward = 1.0 if pred_nums and pred_nums[-1] == t["expected"] else 0.0
+
+        agl.emit_reward(reward)
+        return reward
+
+# --- 3. Logging & Main ---
+def setup_apo_logger(file_path: str = "apo_math.log") -> None:
+    file_handler = logging.FileHandler(file_path)
+    file_handler.setLevel(logging.INFO)
+    formatter = logging.Formatter("%(asctime)s [%(levelname)s] (%(name)s) %(message)s")
+    file_handler.setFormatter(formatter)
+    logging.getLogger("agentlightning.algorithm.apo").addHandler(file_handler)
+
+def main() -> None:
+    setup_logging()
+    setup_apo_logger()
+
+    multi_client = MultiProviderClient()
+
+    initial_prompt_str = "Solve: {question}"
+
+    algo = APO[Dict[str, str]](
+        multi_client,
+        gradient_model="gemini/gemini-2.0-flash",
+        apply_edit_model="groq/llama-3.3-70b-versatile",
+        val_batch_size=2,
+        gradient_batch_size=2,
+        beam_width=1,
+        branch_factor=1,
+        beam_rounds=1,
+    )
+
+    trainer = Trainer(
+        algorithm=algo,
+        n_runners=2,
+        initial_resources={
+            "prompt_template": PromptTemplate(template=initial_prompt_str, engine="f-string")
+        },
+        adapter=TraceToMessages(),
+    )
+
+    dataset_train, dataset_val = load_train_val_dataset()
+    agent = MathAgent()
+
+    print("\n" + "="*60)
+    print("🚀 HYBRID APO OPTIMIZATION STARTING")
+    print("-" * 60)
+
+    trainer.fit(agent=agent, train_dataset=dataset_train, val_dataset=dataset_val)
+
+    # Print Final Prompt from the store
+    print("\n" + "="*60)
+    print("✅ OPTIMIZATION COMPLETE")
+    print("-" * 60)
+    print(f"INITIAL PROMPT:\n{initial_prompt_str}")
+
+
+    # Accessing the latest optimized prompt from the trainer store
+    try:
+        latest_resources = asyncio.run(trainer.store.query_resources())
+        if latest_resources:
+            final_res = latest_resources[-1].resources.get("prompt_template")
+            final_prompt = getattr(final_res, "template", str(final_res))
+            print(f"FINAL OPTIMIZED PROMPT:\n{final_prompt}")
+    except Exception as e:
+        print(f"Optimization finished. Check apo_math.log for detailed iteration results. Error: {e}")
+
+    print("="*60 + "\n")
+
+if __name__ == "__main__":
+    try:
+        multiprocessing.set_start_method("fork", force=True)
+    except RuntimeError:
+        pass
+    main()
\ No newline at end of file
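
Reviewer note: for anyone who wants to smoke-test the final LiteLLM-based client outside of the full APO loop, a minimal sketch is below. It only assumes that `litellm` is installed and that a `GROQ_API_KEY` environment variable is exported; the model name is the one already used in this PR's own examples, and the one-liner prompt is purely illustrative.

```python
# Minimal smoke test for MultiProviderClient as introduced in PATCH 5/6.
# Assumptions (not part of this PR): GROQ_API_KEY is set in the environment.
import asyncio

from agentlightning.utils.multi_provider_client import MultiProviderClient


async def main() -> None:
    client = MultiProviderClient()
    # The _ChatProxy/_CompletionsProxy chain mirrors the OpenAI SDK surface,
    # so this call shape matches client.chat.completions.create(...) and
    # ultimately delegates to litellm.acompletion.
    response = await client.chat.completions.create(
        model="groq/llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": "Reply with the single word: pong"}],
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
```

Because the proxy classes reproduce the OpenAI client surface, the same object can be dropped into any consumer that already calls `chat.completions.create`, which is what lets APO accept it in place of an `AsyncOpenAI` instance.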