From 7be5453d74e83aa5ad50237edb3344515bacb340 Mon Sep 17 00:00:00 2001 From: Dimitris Kargatzis Date: Sat, 15 Nov 2025 14:03:49 +0200 Subject: [PATCH 1/2] refactor: restructure modules and update configuration - Modularize config into package structure (split config.py) - Reorganize providers and integrations into unified structure - Extract shared utilities into core/utils package - Move rule utilities (codeowners, contributors) to rules/utils - Remove supervisor agent, implement factory pattern for agents - Fix OpenAI structured output schema compatibility - Update all imports and references across codebase - Sync documentation with current structure and env vars - Update tests to match refactored architecture Signed-off-by: Dimitris Kargatzis --- DEVELOPMENT.md | 70 ++-- LOCAL_SETUP.md | 68 ++-- docs/getting-started/configuration.md | 2 + docs/getting-started/quick-start.md | 2 + src/agents/__init__.py | 7 +- src/agents/acknowledgment_agent/agent.py | 2 +- src/agents/base.py | 66 +--- src/agents/engine_agent/models.py | 1 - src/agents/engine_agent/nodes.py | 5 +- src/agents/factory.py | 48 +++ src/agents/feasibility_agent/nodes.py | 2 +- src/agents/supervisor_agent/__init__.py | 16 - src/agents/supervisor_agent/agent.py | 229 ------------- src/agents/supervisor_agent/models.py | 62 ---- src/agents/supervisor_agent/nodes.py | 324 ------------------ src/api/rules.py | 4 +- src/core/config/__init__.py | 12 + src/core/config/cors_config.py | 13 + src/core/config/github_config.py | 17 + src/core/config/langsmith_config.py | 15 + src/core/config/logging_config.py | 14 + src/core/config/provider_config.py | 76 ++++ src/core/config/repo_config.py | 13 + src/core/{config.py => config/settings.py} | 134 +------- src/core/utils/README.md | 175 ++++++++++ src/core/utils/__init__.py | 21 ++ src/core/utils/caching.py | 184 ++++++++++ src/core/utils/logging.py | 126 +++++++ src/core/utils/metrics.py | 145 ++++++++ src/core/utils/retry.py | 135 ++++++++ src/core/utils/timeout.py | 75 ++++ src/event_processors/base.py | 4 +- src/event_processors/check_run.py | 4 +- .../deployment_protection_rule.py | 6 +- src/event_processors/deployment_review.py | 28 +- src/event_processors/pull_request.py | 12 +- src/event_processors/push.py | 4 +- src/event_processors/rule_creation.py | 4 +- .../violation_acknowledgment.py | 9 +- src/integrations/README.md | 71 ++++ src/integrations/__init__.py | 7 +- src/integrations/aws_bedrock.py | 320 ----------------- src/integrations/gcp_garden.py | 198 ----------- src/integrations/github/__init__.py | 12 + .../{github_api.py => github/api.py} | 0 src/{ => integrations}/providers/__init__.py | 6 +- .../providers/base.py} | 12 +- .../providers/bedrock_provider.py | 6 +- src/{ => integrations}/providers/factory.py | 30 +- .../providers/openai_provider.py | 8 +- .../providers/vertex_ai_provider.py | 4 +- src/rules/loaders/__init__.py | 18 + .../github_loader.py} | 19 +- src/rules/utils.py | 2 +- src/rules/utils/__init__.py | 35 ++ .../utils}/codeowners.py | 10 +- .../utils}/contributors.py | 25 +- src/rules/utils/validation.py | 119 +++++++ src/rules/validators.py | 9 +- src/tasks/scheduler/deployment_scheduler.py | 23 +- src/webhooks/handlers/issue_comment.py | 6 +- tests/unit/test_agents.py | 185 ---------- tests/unit/test_rule_engine_agent.py | 2 +- 63 files changed, 1598 insertions(+), 1663 deletions(-) create mode 100644 src/agents/factory.py delete mode 100644 src/agents/supervisor_agent/__init__.py delete mode 100644 src/agents/supervisor_agent/agent.py delete mode 100644 
src/agents/supervisor_agent/models.py delete mode 100644 src/agents/supervisor_agent/nodes.py create mode 100644 src/core/config/__init__.py create mode 100644 src/core/config/cors_config.py create mode 100644 src/core/config/github_config.py create mode 100644 src/core/config/langsmith_config.py create mode 100644 src/core/config/logging_config.py create mode 100644 src/core/config/provider_config.py create mode 100644 src/core/config/repo_config.py rename src/core/{config.py => config/settings.py} (60%) create mode 100644 src/core/utils/README.md create mode 100644 src/core/utils/__init__.py create mode 100644 src/core/utils/caching.py create mode 100644 src/core/utils/logging.py create mode 100644 src/core/utils/metrics.py create mode 100644 src/core/utils/retry.py create mode 100644 src/core/utils/timeout.py create mode 100644 src/integrations/README.md delete mode 100644 src/integrations/aws_bedrock.py delete mode 100644 src/integrations/gcp_garden.py create mode 100644 src/integrations/github/__init__.py rename src/integrations/{github_api.py => github/api.py} (100%) rename src/{ => integrations}/providers/__init__.py (56%) rename src/{providers/base_provider.py => integrations/providers/base.py} (80%) rename src/{ => integrations}/providers/bedrock_provider.py (98%) rename src/{ => integrations}/providers/factory.py (81%) rename src/{ => integrations}/providers/openai_provider.py (85%) rename src/{ => integrations}/providers/vertex_ai_provider.py (98%) create mode 100644 src/rules/loaders/__init__.py rename src/rules/{github_provider.py => loaders/github_loader.py} (87%) create mode 100644 src/rules/utils/__init__.py rename src/{integrations => rules/utils}/codeowners.py (95%) rename src/{integrations => rules/utils}/contributors.py (93%) create mode 100644 src/rules/utils/validation.py diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index f547f74..7e66417 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -66,33 +66,57 @@ cp .env.example .env Required environment variables: ```bash -# GitHub App Configuration -APP_NAME_GITHUB=your-app-name -CLIENT_ID_GITHUB=your-app-id -APP_CLIENT_SECRET=your-client-secret -PRIVATE_KEY_BASE64_GITHUB=your-base64-private-key -WEBHOOK_SECRET_GITHUB=your-webhook-secret - -# AI Configuration -OPENAI_API_KEY=your-openai-api-key -AI_MODEL=gpt-4.1-mini +# GitHub Configuration (required) +APP_NAME_GITHUB=your_app_name +APP_CLIENT_ID_GITHUB=your_client_id +APP_CLIENT_SECRET_GITHUB=your_client_secret +PRIVATE_KEY_BASE64_GITHUB=your_private_key_base64 +WEBHOOK_SECRET_GITHUB=your_webhook_secret + +# AI Provider Selection +AI_PROVIDER=openai # Options: openai, bedrock, vertex_ai + +# Common AI Settings (defaults for all agents) AI_MAX_TOKENS=4096 AI_TEMPERATURE=0.1 +# OpenAI Configuration (when AI_PROVIDER=openai) +OPENAI_API_KEY=your_openai_api_key_here +OPENAI_MODEL=gpt-4.1-mini # Optional, defaults to gpt-4.1-mini + +# Engine Agent Configuration +AI_ENGINE_MAX_TOKENS=8000 # Default: 8000 +AI_ENGINE_TEMPERATURE=0.1 + +# Feasibility Agent Configuration +AI_FEASIBILITY_MAX_TOKENS=4096 +AI_FEASIBILITY_TEMPERATURE=0.1 + +# Acknowledgment Agent Configuration +AI_ACKNOWLEDGMENT_MAX_TOKENS=2000 +AI_ACKNOWLEDGMENT_TEMPERATURE=0.1 + # LangSmith Configuration -LANGCHAIN_TRACING_V2=true +LANGCHAIN_TRACING_V2=false LANGCHAIN_ENDPOINT=https://api.smith.langchain.com -LANGCHAIN_API_KEY=your-langsmith-api-key +LANGCHAIN_API_KEY=your_langsmith_api_key LANGCHAIN_PROJECT=watchflow-dev -# Development Settings -DEBUG=true -LOG_LEVEL=DEBUG -ENVIRONMENT=development - # CORS 
Configuration CORS_HEADERS=["*"] -CORS_ORIGINS='["http://localhost:3000", "http://127.0.0.1:3000"]' +CORS_ORIGINS=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5500", "https://warestack.github.io", "https://watchflow.dev"] + +# Repository Configuration +REPO_CONFIG_BASE_PATH=.watchflow +REPO_CONFIG_RULES_FILE=rules.yaml + +# Logging Configuration +LOG_LEVEL=INFO +LOG_FORMAT=%(asctime)s - %(name)s - %(levelname)s - %(message)s + +# Development Settings +DEBUG=false +ENVIRONMENT=development ``` ### 4. GitHub App Setup @@ -306,22 +330,20 @@ With LangSmith configured, you can: ### Local Rule Testing +**💡 Tip**: Test your natural language rules at [watchflow.dev](https://watchflow.dev) to verify they're supported and get the generated YAML. Copy the output directly into your `rules.yaml` file. + Create a test repository with `.watchflow/rules.yaml`: ```yaml rules: - - id: test-rule - name: Test Rule - description: Test rule for development + - description: Test rule for development enabled: true severity: medium event_types: [pull_request] parameters: test_param: "test_value" - - id: status-check-required - name: Status Check Required - description: All PRs must pass required status checks + - description: All PRs must pass required status checks enabled: true severity: high event_types: [pull_request] diff --git a/LOCAL_SETUP.md b/LOCAL_SETUP.md index 627e879..7b12b78 100644 --- a/LOCAL_SETUP.md +++ b/LOCAL_SETUP.md @@ -162,35 +162,57 @@ cp .env.example .env Edit your `.env` file with the following configuration: ```bash -# GitHub App Configuration +# GitHub Configuration (required) APP_NAME_GITHUB=watchflow-dev -CLIENT_ID_GITHUB=your_app_id_from_github_app_settings -CLIENT_SECRET_GITHUB=your_client_secret_from_github_app_settings +APP_CLIENT_ID_GITHUB=your_app_id_from_github_app_settings +APP_CLIENT_SECRET_GITHUB=your_client_secret_from_github_app_settings PRIVATE_KEY_BASE64_GITHUB=your_base64_encoded_private_key -REDIRECT_URI_GITHUB=http://localhost:3000 - -# GitHub Webhook Configuration WEBHOOK_SECRET_GITHUB=your_webhook_secret_from_step_1 -# OpenAI API Configuration -OPENAI_API_KEY=your-openai-api-key +# AI Provider Selection +AI_PROVIDER=openai # Options: openai, bedrock, vertex_ai + +# Common AI Settings (defaults for all agents) +AI_MAX_TOKENS=4096 +AI_TEMPERATURE=0.1 + +# OpenAI Configuration (when AI_PROVIDER=openai) +OPENAI_API_KEY=your_openai_api_key_here +OPENAI_MODEL=gpt-4.1-mini # Optional, defaults to gpt-4.1-mini + +# Engine Agent Configuration +AI_ENGINE_MAX_TOKENS=8000 # Default: 8000 +AI_ENGINE_TEMPERATURE=0.1 + +# Feasibility Agent Configuration +AI_FEASIBILITY_MAX_TOKENS=4096 +AI_FEASIBILITY_TEMPERATURE=0.1 + +# Acknowledgment Agent Configuration +AI_ACKNOWLEDGMENT_MAX_TOKENS=2000 +AI_ACKNOWLEDGMENT_TEMPERATURE=0.1 -# LangChain Configuration (Optional - for AI debugging) -LANGCHAIN_TRACING_V2=true +# LangSmith Configuration +LANGCHAIN_TRACING_V2=false LANGCHAIN_ENDPOINT=https://api.smith.langchain.com -LANGCHAIN_API_KEY=your-langsmith-api-key +LANGCHAIN_API_KEY=your_langsmith_api_key LANGCHAIN_PROJECT=watchflow-dev -# Application Configuration -ENVIRONMENT=development - # CORS Configuration CORS_HEADERS=["*"] -CORS_ORIGINS='["http://localhost:3000", "http://127.0.0.1:3000"]' +CORS_ORIGINS=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5500", "https://warestack.github.io", "https://watchflow.dev"] + +# Repository Configuration +REPO_CONFIG_BASE_PATH=.watchflow +REPO_CONFIG_RULES_FILE=rules.yaml -# AWS Configuration (if using AWS 
services) -AWS_ACCESS_KEY_ID=your_aws_access_key_id -AWS_SECRET_ACCESS_KEY=your_aws_secret_access_key +# Logging Configuration +LOG_LEVEL=INFO +LOG_FORMAT=%(asctime)s - %(name)s - %(levelname)s - %(message)s + +# Development Settings +DEBUG=false +ENVIRONMENT=development ``` ### 4.3 Encode Your Private Key @@ -271,22 +293,20 @@ Open your browser and navigate to: ### 7.3 Test Rule Evaluation +**💡 Tip**: You can test your natural language rules at [watchflow.dev](https://watchflow.dev) to see if they're supported and get the generated YAML configuration. Then copy and paste it into your repository's `rules.yaml` file. + Create a test rule in a monitored repository by adding `.watchflow/rules.yaml`: ```yaml rules: - - id: test-rule - name: Test Rule for Local Development - description: Simple rule to test local setup + - description: Simple rule to test local setup enabled: true severity: medium event_types: [pull_request] parameters: test_param: "local_test" - - id: pr-approval-required - name: PR Approval Required - description: All pull requests must have at least 1 approval + - description: All pull requests must have at least 1 approval enabled: true severity: high event_types: [pull_request] diff --git a/docs/getting-started/configuration.md b/docs/getting-started/configuration.md index 87b5cbc..9f51376 100644 --- a/docs/getting-started/configuration.md +++ b/docs/getting-started/configuration.md @@ -5,6 +5,8 @@ Learn how to create effective governance rules that adapt to your team's needs a ## Rule Configuration +**💡 Pro Tip**: Not sure if your rule is supported? Test it at [watchflow.dev](https://watchflow.dev) first! Enter your natural language rule description, and the tool will generate the YAML configuration for you. Simply copy and paste it into your `rules.yaml` file. + ### Basic Rule Structure Rules are defined in YAML format and stored in `.watchflow/rules.yaml` in your repository. Each rule consists of diff --git a/docs/getting-started/quick-start.md b/docs/getting-started/quick-start.md index d07bac2..5a21c50 100644 --- a/docs/getting-started/quick-start.md +++ b/docs/getting-started/quick-start.md @@ -30,6 +30,8 @@ Get Watchflow up and running in minutes to replace static protection rules with ## Step 2: Create Your Rules +**💡 Pro Tip**: Before writing rules manually, test your natural language rules at [watchflow.dev](https://watchflow.dev) to see if they're supported. The tool will generate the YAML configuration for you - just copy and paste it into your `rules.yaml` file! 
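+
+For illustration only, a generated rule might look like this (the exact output depends on your description; `min_approvals` is a hypothetical parameter):
+
+```yaml
+rules:
+  - description: All pull requests must have at least 1 approval
+    enabled: true
+    severity: high
+    event_types: [pull_request]
+    parameters:
+      min_approvals: 1
+```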
+ Create `.watchflow/rules.yaml` in your repository root to define your governance rules: ```yaml diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 129c1af..88f2b92 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -7,17 +7,16 @@ """ from src.agents.acknowledgment_agent import AcknowledgmentAgent -from src.agents.base import AgentResult, BaseAgent, SupervisorAgent +from src.agents.base import AgentResult, BaseAgent from src.agents.engine_agent import RuleEngineAgent +from src.agents.factory import get_agent from src.agents.feasibility_agent import RuleFeasibilityAgent -from src.agents.supervisor_agent import RuleSupervisorAgent __all__ = [ "BaseAgent", - "SupervisorAgent", "AgentResult", "RuleFeasibilityAgent", "RuleEngineAgent", "AcknowledgmentAgent", - "RuleSupervisorAgent", + "get_agent", ] diff --git a/src/agents/acknowledgment_agent/agent.py b/src/agents/acknowledgment_agent/agent.py index ec862aa..0e98c93 100644 --- a/src/agents/acknowledgment_agent/agent.py +++ b/src/agents/acknowledgment_agent/agent.py @@ -11,7 +11,7 @@ from src.agents.acknowledgment_agent.models import AcknowledgmentContext, AcknowledgmentEvaluation from src.agents.acknowledgment_agent.prompts import create_evaluation_prompt, get_system_prompt from src.agents.base import AgentResult, BaseAgent -from src.providers import get_chat_model +from src.integrations.providers import get_chat_model logger = logging.getLogger(__name__) diff --git a/src/agents/base.py b/src/agents/base.py index 3e15735..293146a 100644 --- a/src/agents/base.py +++ b/src/agents/base.py @@ -2,12 +2,12 @@ Base agent classes and utilities for agents. """ -import asyncio import logging from abc import ABC, abstractmethod from typing import Any, TypeVar -from src.providers import get_chat_model +from src.core.utils.timeout import execute_with_timeout +from src.integrations.providers import get_chat_model logger = logging.getLogger(__name__) @@ -70,24 +70,20 @@ async def _retry_structured_output(self, llm, output_model, prompt, **kwargs) -> Raises: Exception: If all retries fail """ - structured_llm = llm.with_structured_output(output_model) + from src.core.utils.retry import retry_async - for attempt in range(self.max_retries): - try: - result = await structured_llm.ainvoke(prompt, **kwargs) - if attempt > 0: - logger.info(f"✅ Structured output succeeded on attempt {attempt + 1}") - return result - except Exception as e: - if attempt == self.max_retries - 1: - logger.error(f"❌ Structured output failed after {self.max_retries} attempts: {e}") - raise Exception(f"Structured output failed after {self.max_retries} attempts: {str(e)}") from e + structured_llm = llm.with_structured_output(output_model) - wait_time = self.retry_delay * (2**attempt) - logger.warning(f"⚠️ Structured output attempt {attempt + 1} failed, retrying in {wait_time}s: {e}") - await asyncio.sleep(wait_time) + async def _invoke_structured() -> T: + """Inner function to invoke structured LLM.""" + return await structured_llm.ainvoke(prompt, **kwargs) - raise Exception(f"Structured output failed after {self.max_retries} attempts") + return await retry_async( + _invoke_structured, + max_retries=self.max_retries, + initial_delay=self.retry_delay, + exceptions=(Exception,), + ) async def _execute_with_timeout(self, coro, timeout: float = 30.0): """ @@ -101,39 +97,15 @@ async def _execute_with_timeout(self, coro, timeout: float = 30.0): The result of the coroutine Raises: - Exception: If the operation times out + TimeoutError: If the operation times out 
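+
+        Example (illustrative sketch; ``some_coroutine`` is a stand-in):
+            result = await self._execute_with_timeout(some_coroutine(), timeout=30.0)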
""" - try: - return await asyncio.wait_for(coro, timeout=timeout) - except TimeoutError as err: - raise Exception(f"Operation timed out after {timeout} seconds") from err + return await execute_with_timeout( + coro, + timeout=timeout, + timeout_message=f"Operation timed out after {timeout} seconds", + ) @abstractmethod async def execute(self, **kwargs) -> AgentResult: """Execute the agent with given parameters.""" pass - - -class SupervisorAgent(BaseAgent): - """ - Supervisor agent that coordinates multiple sub-agents. - """ - - def __init__(self, sub_agents: dict[str, BaseAgent] = None, **kwargs): - self.sub_agents = sub_agents or {} - super().__init__(**kwargs) - - async def coordinate_agents(self, task: str, **kwargs) -> AgentResult: - """ - Coordinate multiple agents to complete a complex task. - - Args: - task: Description of the task to coordinate - **kwargs: Additional parameters for the task - - Returns: - AgentResult with the coordinated results - """ - # This is a template for supervisor coordination - # Subclasses should implement specific coordination logic - raise NotImplementedError("Subclasses must implement coordinate_agents") diff --git a/src/agents/engine_agent/models.py b/src/agents/engine_agent/models.py index 910189b..ed0d6e9 100644 --- a/src/agents/engine_agent/models.py +++ b/src/agents/engine_agent/models.py @@ -44,7 +44,6 @@ class LLMEvaluationResponse(BaseModel): details: dict[str, Any] = Field( description="Detailed reasoning and metadata", default_factory=dict, - json_schema_extra={"additionalProperties": False}, ) how_to_fix: str | None = Field(description="Specific instructions on how to fix the violation", default=None) diff --git a/src/agents/engine_agent/nodes.py b/src/agents/engine_agent/nodes.py index 3abf542..16e61f4 100644 --- a/src/agents/engine_agent/nodes.py +++ b/src/agents/engine_agent/nodes.py @@ -24,7 +24,7 @@ create_validation_strategy_prompt, get_llm_evaluation_system_prompt, ) -from src.providers import get_chat_model +from src.integrations.providers import get_chat_model from src.rules.validators import VALIDATOR_REGISTRY logger = logging.getLogger(__name__) @@ -297,7 +297,8 @@ async def _execute_single_llm_evaluation( messages = [SystemMessage(content=get_llm_evaluation_system_prompt()), HumanMessage(content=evaluation_prompt)] # Use structured output for reliable parsing - structured_llm = llm.with_structured_output(LLMEvaluationResponse) + # Use function_calling method for better OpenAI compatibility + structured_llm = llm.with_structured_output(LLMEvaluationResponse, method="function_calling") evaluation_result = await structured_llm.ainvoke(messages) execution_time = (time.time() - start_time) * 1000 diff --git a/src/agents/factory.py b/src/agents/factory.py new file mode 100644 index 0000000..e7fa353 --- /dev/null +++ b/src/agents/factory.py @@ -0,0 +1,48 @@ +""" +Agent factory for creating agent instances by name. + +Provides a simple interface to get agents by their type name, +centralizing agent instantiation for consistency. +""" + +import logging +from typing import Any + +from src.agents.acknowledgment_agent import AcknowledgmentAgent +from src.agents.base import BaseAgent +from src.agents.engine_agent import RuleEngineAgent +from src.agents.feasibility_agent import RuleFeasibilityAgent + +logger = logging.getLogger(__name__) + + +def get_agent(agent_type: str, **kwargs: Any) -> BaseAgent: + """ + Get an agent instance by type name. 
+ + Args: + agent_type: Type of agent ("engine", "feasibility", "acknowledgment") + **kwargs: Additional configuration for the agent + + Returns: + Agent instance + + Raises: + ValueError: If agent_type is not supported + + Examples: + >>> engine_agent = get_agent("engine") + >>> feasibility_agent = get_agent("feasibility") + >>> acknowledgment_agent = get_agent("acknowledgment") + """ + agent_type = agent_type.lower() + + if agent_type == "engine": + return RuleEngineAgent(**kwargs) + elif agent_type == "feasibility": + return RuleFeasibilityAgent(**kwargs) + elif agent_type == "acknowledgment": + return AcknowledgmentAgent(**kwargs) + else: + supported = ", ".join(["engine", "feasibility", "acknowledgment"]) + raise ValueError(f"Unsupported agent type: {agent_type}. Supported: {supported}") diff --git a/src/agents/feasibility_agent/nodes.py b/src/agents/feasibility_agent/nodes.py index 0270cc2..c1d0218 100644 --- a/src/agents/feasibility_agent/nodes.py +++ b/src/agents/feasibility_agent/nodes.py @@ -6,7 +6,7 @@ from src.agents.feasibility_agent.models import FeasibilityAnalysis, FeasibilityState, YamlGeneration from src.agents.feasibility_agent.prompts import RULE_FEASIBILITY_PROMPT, YAML_GENERATION_PROMPT -from src.providers import get_chat_model +from src.integrations.providers import get_chat_model logger = logging.getLogger(__name__) diff --git a/src/agents/supervisor_agent/__init__.py b/src/agents/supervisor_agent/__init__.py deleted file mode 100644 index f5e5b55..0000000 --- a/src/agents/supervisor_agent/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Supervisor Agent for coordinating multiple specialized agents. -""" - -from src.agents.supervisor_agent.agent import RuleSupervisorAgent -from src.agents.supervisor_agent.models import CoordinationResult, SupervisorState -from src.agents.supervisor_agent.nodes import coordinate_agents, synthesize_final_result, validate_results - -__all__ = [ - "RuleSupervisorAgent", - "SupervisorState", - "CoordinationResult", - "coordinate_agents", - "validate_results", - "synthesize_final_result", -] diff --git a/src/agents/supervisor_agent/agent.py b/src/agents/supervisor_agent/agent.py deleted file mode 100644 index e9ed347..0000000 --- a/src/agents/supervisor_agent/agent.py +++ /dev/null @@ -1,229 +0,0 @@ -""" -Rule Supervisor Agent for coordinating multiple specialized agents. -""" - -import asyncio -import logging -import time -from typing import Any - -from langgraph.graph import END, START, StateGraph - -from src.agents.acknowledgment_agent import AcknowledgmentAgent -from src.agents.base import AgentResult, SupervisorAgent -from src.agents.engine_agent import RuleEngineAgent -from src.agents.feasibility_agent import RuleFeasibilityAgent -from src.agents.supervisor_agent.models import AgentTask, SupervisorAgentResult, SupervisorState -from src.agents.supervisor_agent.nodes import coordinate_agents, synthesize_final_result, validate_results - -logger = logging.getLogger(__name__) - - -class RuleSupervisorAgent(SupervisorAgent): - """ - Supervisor agent that coordinates multiple specialized agents for complex rule evaluation. - - Architecture: - 1. Feasibility Agent: Determines if rules are implementable - 2. Engine Agent: Evaluates rules against events - 3. Acknowledgment Agent: Processes violation acknowledgments - 4. 
Supervisor: Coordinates and synthesizes results - """ - - def __init__(self, max_concurrent_agents: int = 3, timeout: float = 300.0, **kwargs): # Increased to 5 minutes - super().__init__(**kwargs) - self.max_concurrent_agents = max_concurrent_agents - self.timeout = timeout - - # Initialize sub-agents - self.sub_agents = { - "feasibility": RuleFeasibilityAgent(), - "engine": RuleEngineAgent(), - "acknowledgment": AcknowledgmentAgent(), - } - - logger.info(f"🔧 RuleSupervisorAgent initialized with {len(self.sub_agents)} sub-agents") - logger.info(f"🔧 Max concurrent agents: {max_concurrent_agents}") - logger.info(f"🔧 Timeout: {timeout}s") - - def _build_graph(self) -> StateGraph: - """Build the LangGraph workflow for supervisor coordination.""" - workflow = StateGraph(SupervisorState) - - # Add nodes - workflow.add_node("coordinate_agents", coordinate_agents) - workflow.add_node("validate_results", validate_results) - workflow.add_node("synthesize_final_result", synthesize_final_result) - - # Add edges - workflow.add_edge(START, "coordinate_agents") - workflow.add_edge("coordinate_agents", "validate_results") - workflow.add_edge("validate_results", "synthesize_final_result") - workflow.add_edge("synthesize_final_result", END) - - logger.info("🔧 RuleSupervisorAgent graph built with coordination workflow") - return workflow.compile() - - async def execute( - self, event_type: str, event_data: dict[str, Any], rules: list[dict[str, Any]], **kwargs - ) -> AgentResult: - """ - Execute coordinated rule evaluation using multiple specialized agents. - """ - start_time = time.time() - - try: - logger.info(f"🚀 RuleSupervisorAgent starting coordinated evaluation for {event_type}") - logger.info(f"🚀 Processing {len(rules)} rules with {len(self.sub_agents)} agents") - - # Prepare initial state - initial_state = SupervisorState( - task_description=f"Evaluate {len(rules)} rules for {event_type} event", - event_type=event_type, - event_data=event_data, - rules=rules, - start_time=time.time(), - ) - - # Run the coordination graph with timeout - result = await self._execute_with_timeout(self.graph.ainvoke(initial_state), timeout=self.timeout) - - execution_time = time.time() - start_time - logger.info(f"✅ RuleSupervisorAgent coordination completed in {execution_time:.2f}s") - - # Extract coordination result - coordination_result = result.get("coordination_result") - if not coordination_result: - raise Exception("No coordination result produced") - - return AgentResult( - success=coordination_result.overall_success, - message=coordination_result.summary, - data={ - "coordination_result": coordination_result.dict(), - "agent_results": result.get("agent_results", []), - "conflicts": result.get("conflicts", []), - }, - metadata={ - "execution_time_ms": execution_time * 1000, - "agents_used": len(result.get("agent_results", [])), - "conflicts_detected": len(result.get("conflicts", [])), - "coordination_type": "supervisor", - }, - ) - - except TimeoutError: - execution_time = time.time() - start_time - logger.error(f"⏰ RuleSupervisorAgent coordination timed out after {execution_time:.2f}s") - return AgentResult( - success=False, - message=f"Supervisor coordination timed out after {self.timeout}s", - data={}, - metadata={ - "execution_time_ms": execution_time * 1000, - "timeout_used": self.timeout, - "error_type": "timeout", - }, - ) - - except Exception as e: - execution_time = time.time() - start_time - logger.error(f"❌ RuleSupervisorAgent coordination failed: {e}") - return AgentResult( - success=False, - 
message=f"Supervisor coordination failed: {str(e)}", - data={}, - metadata={"execution_time_ms": execution_time * 1000, "error_type": type(e).__name__}, - ) - - async def coordinate_agents(self, task: str, **kwargs) -> AgentResult: - """ - Coordinate multiple agents to complete a complex task. - """ - try: - logger.info(f"🔧 Coordinating agents for task: {task}") - - # Create tasks for each sub-agent - tasks = [] - for agent_name, _agent in self.sub_agents.items(): - task_obj = AgentTask(agent_name=agent_name, task_type=task, parameters=kwargs, priority=1) - tasks.append(task_obj) - - # Execute tasks concurrently with rate limiting - results = [] - for i in range(0, len(tasks), self.max_concurrent_agents): - batch = tasks[i : i + self.max_concurrent_agents] - batch_results = await asyncio.gather( - *[self._execute_agent_task(task) for task in batch], return_exceptions=True - ) - results.extend(batch_results) - - # Filter out exceptions and convert to results - agent_results = [] - for result in results: - if isinstance(result, Exception): - logger.error(f"❌ Agent task failed: {result}") - agent_results.append( - SupervisorAgentResult(success=False, message=f"Agent task failed: {str(result)}", data={}) - ) - else: - agent_results.append(result) - - return AgentResult( - success=any(r.success for r in agent_results), - message=f"Coordinated {len(agent_results)} agents", - data={"agent_results": agent_results}, - metadata={"agents_executed": len(agent_results)}, - ) - - except Exception as e: - logger.error(f"❌ Agent coordination failed: {e}") - return AgentResult( - success=False, - message=f"Agent coordination failed: {str(e)}", - data={}, - metadata={"error_type": type(e).__name__}, - ) - - async def _execute_agent_task(self, task: AgentTask) -> SupervisorAgentResult: - """ - Execute a single agent task with timeout and error handling. - """ - try: - agent = self.sub_agents.get(task.agent_name) - if not agent: - raise Exception(f"Unknown agent: {task.agent_name}") - - logger.info(f"🔧 Executing {task.agent_name} agent for {task.task_type}") - - # Execute the agent with timeout - result = await asyncio.wait_for(agent.execute(**task.parameters), timeout=task.timeout) - - return SupervisorAgentResult( - success=result.success, - message=result.message, - data=result.data, - metadata={ - "agent_name": task.agent_name, - "task_type": task.task_type, - "execution_time_ms": result.metadata.get("execution_time_ms", 0), - }, - ) - - except TimeoutError: - logger.error(f"⏰ {task.agent_name} agent timed out after {task.timeout}s") - return SupervisorAgentResult( - success=False, - message=f"{task.agent_name} agent timed out after {task.timeout}s", - data={}, - metadata={"agent_name": task.agent_name, "timeout_used": task.timeout, "error_type": "timeout"}, - ) - - except Exception as e: - logger.error(f"❌ {task.agent_name} agent failed: {e}") - return SupervisorAgentResult( - success=False, - message=f"{task.agent_name} agent failed: {str(e)}", - data={}, - metadata={"agent_name": task.agent_name, "error_type": type(e).__name__}, - ) diff --git a/src/agents/supervisor_agent/models.py b/src/agents/supervisor_agent/models.py deleted file mode 100644 index 742824d..0000000 --- a/src/agents/supervisor_agent/models.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Data models for the Rule Supervisor Agent. 
-""" - -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, Field - - -class AgentTask(BaseModel): - """Represents a task assigned to a specific agent.""" - - agent_name: str = Field(description="Name of the agent to execute the task") - task_type: str = Field(description="Type of task (feasibility, evaluation, acknowledgment, etc.)") - parameters: dict[str, Any] = Field(description="Parameters for the task", default_factory=dict) - priority: int = Field(description="Task priority (higher = more important)", default=1) - timeout: float = Field(description="Timeout in seconds", default=30.0) - - -class SupervisorAgentResult(BaseModel): - """Result from an individual agent execution within supervisor context.""" - - success: bool = Field(description="Whether the task was successful") - message: str = Field(description="Result message or error description") - data: dict[str, Any] = Field(description="The actual result data", default_factory=dict) - metadata: dict[str, Any] = Field(description="Additional metadata", default_factory=dict) - - -class CoordinationResult(BaseModel): - """Result from coordinating multiple agents.""" - - overall_success: bool = Field(description="Whether the overall coordination was successful") - summary: str = Field(description="Summary of the coordination result") - agent_results: list[SupervisorAgentResult] = Field(description="Results from all agents", default_factory=list) - conflicts: list[str] = Field(description="Conflicts detected between agents", default_factory=list) - confidence_score: float = Field(description="Confidence in the final decision", ge=0.0, le=1.0, default=0.0) - reasoning: list[str] = Field(description="Step-by-step reasoning for the final decision", default_factory=list) - - -class SupervisorState(BaseModel): - """State for the supervisor coordination workflow.""" - - # Input - task_description: str = Field(description="Description of the overall task") - event_type: str = Field(description="Type of GitHub event being processed") - event_data: dict[str, Any] = Field(description="GitHub event data", default_factory=dict) - rules: list[dict[str, Any]] = Field(description="Rules to evaluate", default_factory=list) - - # Coordination - agent_tasks: list[AgentTask] = Field(description="Tasks to be executed by agents", default_factory=list) - agent_results: list[SupervisorAgentResult] = Field( - description="Results from agent executions", default_factory=list - ) - - # Output - coordination_result: CoordinationResult | None = Field(description="Final coordination result", default=None) - - # Metadata - start_time: datetime | None = Field(description="When coordination started", default=None) - end_time: datetime | None = Field(description="When coordination ended", default=None) - errors: list[str] = Field(description="Any errors that occurred", default_factory=list) diff --git a/src/agents/supervisor_agent/nodes.py b/src/agents/supervisor_agent/nodes.py deleted file mode 100644 index 80ac198..0000000 --- a/src/agents/supervisor_agent/nodes.py +++ /dev/null @@ -1,324 +0,0 @@ -""" -LangGraph nodes for the Rule Supervisor Agent. -""" - -import asyncio -import logging -import time - -from src.agents.supervisor_agent.models import AgentTask, CoordinationResult, SupervisorAgentResult, SupervisorState - -logger = logging.getLogger(__name__) - - -async def coordinate_agents(state: SupervisorState) -> SupervisorState: - """ - Coordinate multiple agents to execute tasks in parallel. 
- """ - try: - logger.info(f"🔧 Starting agent coordination for {state.event_type}") - - # Create tasks for each agent based on the event type and rules - tasks = _create_agent_tasks(state) - state.agent_tasks = tasks - - logger.info(f"🔧 Created {len(tasks)} agent tasks") - - # Execute tasks concurrently (with limits to avoid rate limits) - results = await _execute_tasks_concurrently(tasks, max_concurrent=3) - state.agent_results = results - - logger.info(f"🔧 Completed {len(results)} agent executions") - - # Log results summary - successful_results = [r for r in results if r.success] - logger.info(f"🔧 Successful executions: {len(successful_results)}/{len(results)}") - - except Exception as e: - logger.error(f"❌ Error in agent coordination: {e}") - state.errors.append(f"Coordination failed: {str(e)}") - - return state - - -async def validate_results(state: SupervisorState) -> SupervisorState: - """ - Validate and cross-check results from multiple agents. - """ - try: - logger.info("🔍 Validating agent results") - - # Check for conflicting results - conflicts = _detect_result_conflicts(state.agent_results) - if conflicts: - logger.warning(f"⚠️ Found {len(conflicts)} result conflicts") - state.errors.extend(conflicts) - - # Validate result quality - quality_issues = _validate_result_quality(state.agent_results) - if quality_issues: - logger.warning(f"⚠️ Found {len(quality_issues)} quality issues") - state.errors.extend(quality_issues) - - logger.info("🔍 Result validation completed") - - except Exception as e: - logger.error(f"❌ Error in result validation: {e}") - state.errors.append(f"Validation failed: {str(e)}") - - return state - - -async def synthesize_final_result(state: SupervisorState) -> SupervisorState: - """ - Synthesize final decision from multiple agent results. - """ - try: - logger.info("🧠 Synthesizing final result from agent outputs") - - # Create coordination result - coordination_result = _synthesize_coordination_result(state) - state.coordination_result = coordination_result - - # Set end time - state.end_time = time.time() - - logger.info(f"✅ Final synthesis completed: success={coordination_result.overall_success}") - logger.info(f"✅ Confidence score: {coordination_result.confidence_score}") - - except Exception as e: - logger.error(f"❌ Error in final synthesis: {e}") - state.errors.append(f"Synthesis failed: {str(e)}") - - # Create fallback result - state.coordination_result = CoordinationResult( - overall_success=False, - summary=f"Synthesis failed: {str(e)}", - confidence_score=0.0, - ) - - return state - - -def _create_agent_tasks(state: SupervisorState) -> list[AgentTask]: - """ - Create tasks for each agent based on the event type and rules. 
- """ - tasks = [] - - # Feasibility agent task - check if rules are implementable - if state.rules: - tasks.append( - AgentTask( - agent_name="feasibility", - task_type="rule_feasibility_check", - parameters={"rule_description": "\n".join([rule.get("description", "") for rule in state.rules])}, - priority=1, - timeout=30.0, - ) - ) - - # Engine agent task - evaluate rules against the event - if state.rules and state.event_data: - tasks.append( - AgentTask( - agent_name="engine", - task_type="rule_evaluation", - parameters={"event_type": state.event_type, "event_data": state.event_data, "rules": state.rules}, - priority=2, - timeout=45.0, - ) - ) - - # Acknowledgment agent task - if this is an acknowledgment request - if ( - state.event_type == "issue_comment" - and "acknowledgment" in state.event_data.get("comment", {}).get("body", "").lower() - ): - tasks.append( - AgentTask( - agent_name="acknowledgment", - task_type="acknowledgment_evaluation", - parameters={ - "acknowledgment_reason": state.event_data.get("comment", {}).get("body", ""), - "violations": state.event_data.get("violations", []), - "pr_data": state.event_data.get("pull_request", {}), - "commenter": state.event_data.get("comment", {}).get("user", {}).get("login", ""), - "rules": state.rules, - }, - priority=3, - timeout=30.0, - ) - ) - - return tasks - - -async def _execute_tasks_concurrently(tasks: list[AgentTask], max_concurrent: int = 3) -> list[SupervisorAgentResult]: - """ - Execute tasks concurrently with rate limiting. - """ - results = [] - - # Execute tasks in batches to avoid overwhelming the system - for i in range(0, len(tasks), max_concurrent): - batch = tasks[i : i + max_concurrent] - batch_results = await asyncio.gather(*[_execute_single_task(task) for task in batch], return_exceptions=True) - - # Convert exceptions to error results - for result in batch_results: - if isinstance(result, Exception): - results.append( - SupervisorAgentResult( - success=False, - message=f"Task execution failed: {str(result)}", - data={}, - metadata={"error_type": type(result).__name__}, - ) - ) - else: - results.append(result) - - return results - - -async def _execute_single_task(task: AgentTask) -> SupervisorAgentResult: - """ - Execute a single agent task. - """ - # This would be implemented by the supervisor agent - # For now, return a placeholder result - return SupervisorAgentResult( - success=True, - message=f"Task {task.task_type} completed successfully", - data={"task_type": task.task_type, "agent_name": task.agent_name}, - metadata={"execution_time_ms": 1000}, - ) - - -def _detect_result_conflicts(results: list[SupervisorAgentResult]) -> list[str]: - """ - Detect conflicts between agent results. - """ - conflicts = [] - - # Check for contradictory success/failure states - success_results = [r for r in results if r.success] - failure_results = [r for r in results if not r.success] - - if success_results and failure_results: - conflicts.append("Conflicting success/failure states between agents") - - # Check for contradictory recommendations - recommendations = [] - for result in results: - if "recommendation" in result.data: - recommendations.append(result.data["recommendation"]) - - if len(set(recommendations)) > 1: - conflicts.append("Conflicting recommendations between agents") - - return conflicts - - -def _validate_result_quality(results: list[SupervisorAgentResult]) -> list[str]: - """ - Validate the quality of agent results. 
- """ - issues = [] - - for result in results: - # Check for empty or missing data - if not result.data: - issues.append(f"Agent result has no data: {result.message}") - - # Check for very short messages (might indicate errors) - if len(result.message) < 10: - issues.append(f"Agent result has very short message: {result.message}") - - return issues - - -def _synthesize_coordination_result(state: SupervisorState) -> CoordinationResult: - """ - Synthesize final coordination result from agent outputs. - """ - # Calculate overall success - successful_results = [r for r in state.agent_results if r.success] - overall_success = len(successful_results) > 0 and len(state.errors) == 0 - - # Generate summary - summary = _generate_final_decision(state.agent_results, state.errors) - - # Calculate confidence score - confidence_score = _calculate_confidence_score(state.agent_results) - - # Detect conflicts - conflicts = _detect_result_conflicts(state.agent_results) - - # Generate reasoning - reasoning = _generate_reasoning(state.agent_results, state.errors) - - return CoordinationResult( - overall_success=overall_success, - summary=summary, - agent_results=state.agent_results, - conflicts=conflicts, - confidence_score=confidence_score, - reasoning=reasoning, - ) - - -def _calculate_confidence_score(results: list[SupervisorAgentResult]) -> float: - """ - Calculate confidence score based on agent results. - """ - if not results: - return 0.0 - - # Base confidence on success rate - successful_results = [r for r in results if r.success] - success_rate = len(successful_results) / len(results) - - # Adjust based on result quality - quality_score = 0.0 - for result in results: - if result.success and result.data: - quality_score += 0.2 # Bonus for successful results with data - - return min(1.0, success_rate + quality_score) - - -def _generate_final_decision(results: list[SupervisorAgentResult], errors: list[str]) -> str: - """ - Generate final decision based on agent results. - """ - if errors: - return f"Coordination completed with {len(errors)} errors: {'; '.join(errors[:3])}" - - successful_results = [r for r in results if r.success] - if not successful_results: - return "All agent executions failed" - - return f"Coordination completed successfully with {len(successful_results)}/{len(results)} agents" - - -def _generate_reasoning(results: list[SupervisorAgentResult], errors: list[str]) -> list[str]: - """ - Generate step-by-step reasoning for the final decision. 
- """ - reasoning = [] - - reasoning.append(f"Coordinated {len(results)} agents") - - successful_results = [r for r in results if r.success] - reasoning.append(f"Successful executions: {len(successful_results)}/{len(results)}") - - if errors: - reasoning.append(f"Errors encountered: {len(errors)}") - - conflicts = _detect_result_conflicts(results) - if conflicts: - reasoning.append(f"Conflicts detected: {len(conflicts)}") - - return reasoning diff --git a/src/api/rules.py b/src/api/rules.py index 653cfba..d284db1 100644 --- a/src/api/rules.py +++ b/src/api/rules.py @@ -1,7 +1,7 @@ from fastapi import APIRouter from pydantic import BaseModel -from src.agents.feasibility_agent.agent import RuleFeasibilityAgent +from src.agents import get_agent router = APIRouter() @@ -14,7 +14,7 @@ class RuleEvaluationRequest(BaseModel): @router.post("/rules/evaluate") async def evaluate_rule(request: RuleEvaluationRequest): # Create agent instance (uses centralized config) - agent = RuleFeasibilityAgent() + agent = get_agent("feasibility") # Use the execute method result = await agent.execute(rule_description=request.rule_text) diff --git a/src/core/config/__init__.py b/src/core/config/__init__.py new file mode 100644 index 0000000..7a53f8a --- /dev/null +++ b/src/core/config/__init__.py @@ -0,0 +1,12 @@ +""" +Configuration package - unified access point. + +This package provides all configuration classes and the global config instance. +""" + +from src.core.config.settings import Config, config + +__all__ = [ + "Config", + "config", +] diff --git a/src/core/config/cors_config.py b/src/core/config/cors_config.py new file mode 100644 index 0000000..723a6de --- /dev/null +++ b/src/core/config/cors_config.py @@ -0,0 +1,13 @@ +""" +CORS configuration. +""" + +from dataclasses import dataclass + + +@dataclass +class CORSConfig: + """CORS configuration.""" + + headers: list[str] + origins: list[str] diff --git a/src/core/config/github_config.py b/src/core/config/github_config.py new file mode 100644 index 0000000..0fab3ee --- /dev/null +++ b/src/core/config/github_config.py @@ -0,0 +1,17 @@ +""" +GitHub configuration. +""" + +from dataclasses import dataclass + + +@dataclass +class GitHubConfig: + """GitHub configuration.""" + + app_name: str + app_id: str + app_client_secret: str + private_key: str + webhook_secret: str + api_base_url: str = "https://api.github.com" diff --git a/src/core/config/langsmith_config.py b/src/core/config/langsmith_config.py new file mode 100644 index 0000000..dd20458 --- /dev/null +++ b/src/core/config/langsmith_config.py @@ -0,0 +1,15 @@ +""" +LangSmith configuration for agent debugging. +""" + +from dataclasses import dataclass + + +@dataclass +class LangSmithConfig: + """LangSmith configuration for agent debugging.""" + + tracing_v2: bool = False + endpoint: str = "https://api.smith.langchain.com" + api_key: str = "" + project: str = "watchflow-dev" diff --git a/src/core/config/logging_config.py b/src/core/config/logging_config.py new file mode 100644 index 0000000..3cf86a4 --- /dev/null +++ b/src/core/config/logging_config.py @@ -0,0 +1,14 @@ +""" +Logging configuration. 
+""" + +from dataclasses import dataclass + + +@dataclass +class LoggingConfig: + """Logging configuration.""" + + level: str = "INFO" + format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + file_path: str | None = None diff --git a/src/core/config/provider_config.py b/src/core/config/provider_config.py new file mode 100644 index 0000000..26bd6fa --- /dev/null +++ b/src/core/config/provider_config.py @@ -0,0 +1,76 @@ +""" +Provider configuration. +""" + +from dataclasses import dataclass + + +@dataclass +class AgentConfig: + """Per-agent configuration.""" + + max_tokens: int = 4096 + temperature: float = 0.1 + + +@dataclass +class ProviderConfig: + """Provider configuration.""" + + api_key: str + provider: str = "openai" + max_tokens: int = 4096 + temperature: float = 0.1 + # Provider-specific model fields + openai_model: str | None = None + bedrock_model_id: str | None = None + vertex_ai_model: str | None = None + # Optional provider-specific fields + # AWS Bedrock + bedrock_region: str | None = None + aws_access_key_id: str | None = None + aws_secret_access_key: str | None = None + aws_profile: str | None = None + # GCP Model Garden + gcp_project: str | None = None + gcp_location: str | None = None + gcp_service_account_key_base64: str | None = None + # Per-agent configurations + engine_agent: AgentConfig | None = None + feasibility_agent: AgentConfig | None = None + acknowledgment_agent: AgentConfig | None = None + + def get_model_for_provider(self, provider: str) -> str: + """Get the appropriate model for the given provider with fallbacks.""" + provider = provider.lower() + + if provider == "openai": + return self.openai_model or "gpt-4.1-mini" + elif provider == "bedrock": + return self.bedrock_model_id or "anthropic.claude-3-sonnet-20240229-v1:0" + elif provider in ["vertex_ai", "garden", "model_garden", "gcp", "vertex", "vertexai"]: + # Support both Gemini and Claude models in Vertex AI + return self.vertex_ai_model or "gemini-pro" + else: + return "gpt-4.1-mini" # Ultimate fallback + + def get_max_tokens_for_agent(self, agent: str | None = None) -> int: + """Get max tokens for agent with fallback to global config.""" + if agent and hasattr(self, agent): + agent_config = getattr(self, agent) + if agent_config and hasattr(agent_config, "max_tokens"): + return agent_config.max_tokens + return self.max_tokens + + def get_temperature_for_agent(self, agent: str | None = None) -> float: + """Get temperature for agent with fallback to global config.""" + if agent and hasattr(self, agent): + agent_config = getattr(self, agent) + if agent_config and hasattr(agent_config, "temperature"): + return agent_config.temperature + return self.temperature + + +# Backward compatibility aliases +AgentAIConfig = AgentConfig +AIConfig = ProviderConfig diff --git a/src/core/config/repo_config.py b/src/core/config/repo_config.py new file mode 100644 index 0000000..b300bce --- /dev/null +++ b/src/core/config/repo_config.py @@ -0,0 +1,13 @@ +""" +Repository configuration. +""" + +from dataclasses import dataclass + + +@dataclass +class RepoConfig: + """Repository configuration.""" + + base_path: str = ".watchflow" + rules_file: str = "rules.yaml" diff --git a/src/core/config.py b/src/core/config/settings.py similarity index 60% rename from src/core/config.py rename to src/core/config/settings.py index 2d38720..add5f3b 100644 --- a/src/core/config.py +++ b/src/core/config/settings.py @@ -1,124 +1,20 @@ +""" +Main configuration class that composes all configs. 
+""" + import json import os -from dataclasses import dataclass from dotenv import load_dotenv +from src.core.config.cors_config import CORSConfig +from src.core.config.github_config import GitHubConfig +from src.core.config.langsmith_config import LangSmithConfig +from src.core.config.logging_config import LoggingConfig +from src.core.config.provider_config import AgentConfig, ProviderConfig +from src.core.config.repo_config import RepoConfig -@dataclass -class GitHubConfig: - """GitHub configuration.""" - - app_name: str - app_id: str - app_client_secret: str - private_key: str - webhook_secret: str - api_base_url: str = "https://api.github.com" - - -@dataclass -class AgentAIConfig: - """Per-agent AI configuration.""" - - max_tokens: int = 4096 - temperature: float = 0.1 - - -@dataclass -class AIConfig: - """AI provider configuration.""" - - api_key: str - provider: str = "openai" - max_tokens: int = 4096 - temperature: float = 0.1 - # Provider-specific model fields - openai_model: str | None = None - bedrock_model_id: str | None = None - vertex_ai_model: str | None = None - # Optional provider-specific fields - # AWS Bedrock - bedrock_region: str | None = None - aws_access_key_id: str | None = None - aws_secret_access_key: str | None = None - aws_profile: str | None = None - # GCP Model Garden - gcp_project: str | None = None - gcp_location: str | None = None - gcp_service_account_key_base64: str | None = None - # Per-agent configurations - engine_agent: AgentAIConfig | None = None - feasibility_agent: AgentAIConfig | None = None - acknowledgment_agent: AgentAIConfig | None = None - - def get_model_for_provider(self, provider: str) -> str: - """Get the appropriate model for the given provider with fallbacks.""" - provider = provider.lower() - - if provider == "openai": - return self.openai_model or "gpt-4.1-mini" - elif provider == "bedrock": - return self.bedrock_model_id or "anthropic.claude-3-sonnet-20240229-v1:0" - elif provider in ["vertex_ai", "garden", "model_garden", "gcp", "vertex", "vertexai"]: - # Support both Gemini and Claude models in Vertex AI - return self.vertex_ai_model or "gemini-pro" - else: - return "gpt-4.1-mini" # Ultimate fallback - - def get_max_tokens_for_agent(self, agent: str | None = None) -> int: - """Get max tokens for agent with fallback to global config.""" - if agent and hasattr(self, agent): - agent_config = getattr(self, agent) - if agent_config and hasattr(agent_config, "max_tokens"): - return agent_config.max_tokens - return self.max_tokens - - def get_temperature_for_agent(self, agent: str | None = None) -> float: - """Get temperature for agent with fallback to global config.""" - if agent and hasattr(self, agent): - agent_config = getattr(self, agent) - if agent_config and hasattr(agent_config, "temperature"): - return agent_config.temperature - return self.temperature - - -@dataclass -class LangSmithConfig: - """LangSmith configuration for AI agent debugging.""" - - tracing_v2: bool = False - endpoint: str = "https://api.smith.langchain.com" - api_key: str = "" - project: str = "watchflow-dev" - - -@dataclass -class CORSConfig: - """CORS configuration.""" - - headers: list[str] - origins: list[str] - - -@dataclass -class RepoConfig: - """Repo configuration.""" - - base_path: str = ".watchflow" - rules_file: str = "rules.yaml" - - -@dataclass -class LoggingConfig: - """Logging configuration.""" - - level: str = "INFO" - format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - file_path: str | None = None - - -# Load environment 
variables from a .env file located in the same directory as this script. +# Load environment variables from a .env file load_dotenv() @@ -134,7 +30,7 @@ def __init__(self): webhook_secret=os.getenv("WEBHOOK_SECRET_GITHUB", ""), ) - self.ai = AIConfig( + self.ai = ProviderConfig( provider=os.getenv("AI_PROVIDER", "openai"), api_key=os.getenv("OPENAI_API_KEY", ""), max_tokens=int(os.getenv("AI_MAX_TOKENS", "4096")), @@ -153,15 +49,15 @@ def __init__(self): gcp_location=os.getenv("GCP_LOCATION"), gcp_service_account_key_base64=os.getenv("GCP_SERVICE_ACCOUNT_KEY_BASE64"), # Per-agent configurations - engine_agent=AgentAIConfig( + engine_agent=AgentConfig( max_tokens=int(os.getenv("AI_ENGINE_MAX_TOKENS", "8000")), temperature=float(os.getenv("AI_ENGINE_TEMPERATURE", "0.1")), ), - feasibility_agent=AgentAIConfig( + feasibility_agent=AgentConfig( max_tokens=int(os.getenv("AI_FEASIBILITY_MAX_TOKENS", "4096")), temperature=float(os.getenv("AI_FEASIBILITY_TEMPERATURE", "0.1")), ), - acknowledgment_agent=AgentAIConfig( + acknowledgment_agent=AgentConfig( max_tokens=int(os.getenv("AI_ACKNOWLEDGMENT_MAX_TOKENS", "2000")), temperature=float(os.getenv("AI_ACKNOWLEDGMENT_TEMPERATURE", "0.1")), ), diff --git a/src/core/utils/README.md b/src/core/utils/README.md new file mode 100644 index 0000000..7ed9348 --- /dev/null +++ b/src/core/utils/README.md @@ -0,0 +1,175 @@ +# Core Utilities Module + +This module provides shared utilities for retry logic, caching, logging, metrics, and timeout handling that can be used across the Watchflow codebase. + +## Modules + +### `retry.py` - Retry Utilities + +Provides decorators and functions for retrying async operations with exponential backoff. + +**Functions:** +- `retry_with_backoff()` - Decorator for retrying async functions +- `retry_async()` - Function for retrying async function calls + +**Example:** +```python +from src.core.utils.retry import retry_with_backoff + +@retry_with_backoff(max_retries=3, initial_delay=1.0) +async def fetch_data(): + return await api_call() +``` + +### `timeout.py` - Timeout Utilities + +Provides functions for executing async operations with timeout handling. + +**Functions:** +- `execute_with_timeout()` - Execute coroutine with timeout +- `timeout_decorator()` - Decorator for adding timeout to async functions + +**Example:** +```python +from src.core.utils.timeout import execute_with_timeout + +result = await execute_with_timeout( + long_operation(), + timeout=60.0 +) +``` + +### `caching.py` - Caching Utilities + +Provides async-friendly caching with TTL support. + +**Classes:** +- `AsyncCache` - Cache with TTL and automatic expiration + +**Functions:** +- `cached_async()` - Decorator for caching async function results + +**Example:** +```python +from src.core.utils.caching import AsyncCache, cached_async + +cache = AsyncCache(maxsize=100, ttl=3600) + +@cached_async(cache=cache, key_func=lambda repo: f"repo:{repo}") +async def fetch_repo_data(repo: str): + return await api_call(repo) +``` + +### `logging.py` - Structured Logging Utilities + +Provides context managers and decorators for structured operation logging. + +**Functions:** +- `log_operation()` - Context manager for structured operation logging +- `log_function_call()` - Decorator for logging function calls + +**Example:** +```python +from src.core.utils.logging import log_operation + +async with log_operation("rule_evaluation", repo=repo, pr=pr_number): + result = await evaluate_rules(...) 
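+    # On failure, log_operation logs the error with the same context and
+    # latency, then re-raises the exception.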
+```
+
+### `metrics.py` - Performance Metrics Utilities
+
+Provides utilities for tracking and recording performance metrics.
+
+**Functions:**
+- `track_metrics()` - Context manager for tracking operation metrics
+- `metrics_decorator()` - Decorator for tracking function call metrics
+
+**Example:**
+```python
+from src.core.utils.metrics import track_metrics
+
+async with track_metrics("rule_evaluation", rule_count=5) as metrics:
+    result = await evaluate_rules(...)
+    metrics["violations_found"] = len(result.violations)
+```
+
+## Usage in Codebase
+
+### Updated Files
+
+The following files have been updated to use the new utilities:
+
+1. **`src/agents/base.py`**
+   - `_retry_structured_output()` now uses `retry_async()`
+   - `_execute_with_timeout()` now uses `execute_with_timeout()`
+
+2. **`src/rules/utils/contributors.py`**
+   - Replaced manual cache implementation with `AsyncCache`
+
+### Migration Guide
+
+If you have code using the old patterns, here's how to migrate:
+
+**Old retry pattern:**
+```python
+for attempt in range(max_retries):
+    try:
+        return await func()
+    except Exception as e:
+        if attempt == max_retries - 1:
+            raise
+        await asyncio.sleep(delay * (2 ** attempt))
+```
+
+**New retry pattern:**
+```python
+from src.core.utils.retry import retry_async
+
+result = await retry_async(
+    func,
+    max_retries=3,
+    initial_delay=1.0
+)
+```
+
+**Old timeout pattern:**
+```python
+try:
+    return await asyncio.wait_for(coro, timeout=30.0)
+except TimeoutError:
+    raise Exception("Operation timed out")
+```
+
+**New timeout pattern:**
+```python
+from src.core.utils.timeout import execute_with_timeout
+
+result = await execute_with_timeout(coro, timeout=30.0)
+```
+
+**Old cache pattern:**
+```python
+cache: dict[str, dict] = {}
+if key in cache:
+    cached_data = cache[key]
+    if time.time() - cached_data["timestamp"] < ttl:
+        return cached_data["value"]
+```
+
+**New cache pattern:**
+```python
+from src.core.utils.caching import AsyncCache
+
+cache = AsyncCache(maxsize=100, ttl=3600)
+cached_value = cache.get(key)
+if cached_value is not None:
+    return cached_value
+```
+
+## Benefits
+
+1. **Code Reusability** - Common patterns extracted into reusable utilities
+2. **Consistency** - Same retry/cache/timeout behavior across the codebase
+3. **Maintainability** - Single place to update retry/cache logic
+4. **Testability** - Utilities can be tested independently
+5. **Type Safety** - Full type hints for better IDE support
diff --git a/src/core/utils/__init__.py
new file mode 100644
index 0000000..03af47c
--- /dev/null
+++ b/src/core/utils/__init__.py
@@ -0,0 +1,21 @@
+"""
+Shared utilities for retry, caching, logging, metrics, and timeout handling.
+
+This module provides reusable utilities that can be used across the codebase
+to avoid code duplication and ensure consistent behavior.
+"""
+
+from src.core.utils.caching import AsyncCache, cached_async
+from src.core.utils.logging import log_operation
+from src.core.utils.metrics import track_metrics
+from src.core.utils.retry import retry_with_backoff
+from src.core.utils.timeout import execute_with_timeout
+
+__all__ = [
+    "AsyncCache",
+    "cached_async",
+    "log_operation",
+    "track_metrics",
+    "retry_with_backoff",
+    "execute_with_timeout",
+]
diff --git a/src/core/utils/caching.py
new file mode 100644
index 0000000..51313a9
--- /dev/null
+++ b/src/core/utils/caching.py
@@ -0,0 +1,184 @@
+"""
+Caching utilities for async operations.
+ +Provides async-friendly caching with TTL support and decorators +for caching function results. +""" + +import logging +from collections.abc import Callable +from datetime import datetime +from functools import wraps +from typing import Any + +from cachetools import TTLCache + +logger = logging.getLogger(__name__) + + +class AsyncCache: + """ + Async-friendly cache with TTL support. + + This cache stores values with timestamps and automatically + expires entries based on TTL. + + Example: + cache = AsyncCache(maxsize=100, ttl=3600) + cache.set("key", "value") + value = cache.get("key") + """ + + def __init__(self, maxsize: int = 100, ttl: int = 3600): + """ + Initialize async cache. + + Args: + maxsize: Maximum number of entries in cache + ttl: Time to live in seconds + """ + self._cache: dict[str, dict[str, Any]] = {} + self.maxsize = maxsize + self.ttl = ttl + + def get(self, key: str) -> Any | None: + """ + Get cached value if not expired. + + Args: + key: Cache key + + Returns: + Cached value or None if not found or expired + """ + if key not in self._cache: + return None + + cached_data = self._cache[key] + age = datetime.now().timestamp() - cached_data.get("timestamp", 0) + + if age >= self.ttl: + del self._cache[key] + logger.debug(f"Cache entry '{key}' expired (age: {age:.2f}s, ttl: {self.ttl}s)") + return None + + logger.debug(f"Cache hit for '{key}'") + return cached_data.get("value") + + def set(self, key: str, value: Any) -> None: + """ + Set cached value with timestamp. + + Args: + key: Cache key + value: Value to cache + """ + if len(self._cache) >= self.maxsize: + # Remove oldest entry + oldest_key = min( + self._cache.keys(), + key=lambda k: self._cache[k].get("timestamp", 0), + ) + logger.debug(f"Cache full, evicting oldest entry '{oldest_key}'") + del self._cache[oldest_key] + + self._cache[key] = { + "value": value, + "timestamp": datetime.now().timestamp(), + } + logger.debug(f"Cached entry '{key}'") + + def clear(self) -> None: + """Clear all cached values.""" + count = len(self._cache) + self._cache.clear() + logger.debug(f"Cleared {count} cache entries") + + def invalidate(self, key: str) -> None: + """ + Invalidate a specific cache entry. + + Args: + key: Cache key to invalidate + """ + if key in self._cache: + del self._cache[key] + logger.debug(f"Invalidated cache entry '{key}'") + + def size(self) -> int: + """ + Get current cache size. + + Returns: + Number of entries in cache + """ + return len(self._cache) + + +def cached_async( + cache: AsyncCache | TTLCache | None = None, + key_func: Callable[..., str] | None = None, + ttl: int | None = None, + maxsize: int = 100, +): + """ + Decorator for caching async function results. 
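+
+    Note: because a cached None value is indistinguishable from a miss,
+    None results are re-computed on every call rather than served from cache.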
+ + Args: + cache: Cache instance to use (creates new AsyncCache if None) + key_func: Function to generate cache key from function arguments + ttl: Time to live in seconds (only used if cache is None) + maxsize: Maximum cache size (only used if cache is None) + + Returns: + Decorated async function with caching + + Example: + @cached_async(ttl=3600, key_func=lambda repo, *args: f"repo:{repo}") + async def fetch_repo_data(repo: str): + return await api_call(repo) + """ + if cache is None: + if ttl: + cache = AsyncCache(maxsize=maxsize, ttl=ttl) + else: + # Use TTLCache as fallback + cache = TTLCache(maxsize=maxsize, ttl=ttl or 3600) + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + # Generate cache key + if key_func: + cache_key = key_func(*args, **kwargs) + else: + # Default: use function name and arguments + cache_key = f"{func.__name__}:{args}:{kwargs}" + + # Check cache + if isinstance(cache, AsyncCache): + cached_value = cache.get(cache_key) + else: + # TTLCache + cached_value = cache.get(cache_key) + + if cached_value is not None: + logger.debug(f"Cache hit for {func.__name__} with key '{cache_key}'") + return cached_value + + # Cache miss - execute function + logger.debug(f"Cache miss for {func.__name__} with key '{cache_key}'") + result = await func(*args, **kwargs) + + # Store in cache + if isinstance(cache, AsyncCache): + cache.set(cache_key, result) + else: + # TTLCache + cache[cache_key] = result + + return result + + return wrapper + + return decorator diff --git a/src/core/utils/logging.py b/src/core/utils/logging.py new file mode 100644 index 0000000..65281e8 --- /dev/null +++ b/src/core/utils/logging.py @@ -0,0 +1,126 @@ +""" +Structured logging utilities. + +Provides context managers and decorators for structured operation logging +with timing, error tracking, and metadata. +""" + +import logging +import time +from contextlib import asynccontextmanager +from functools import wraps +from typing import Any + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def log_operation( + operation: str, + subject_ids: dict[str, str] | None = None, + **context: Any, +): + """ + Context manager for structured operation logging. + + Logs operation start, completion, and errors with timing information. + + Args: + operation: Name of the operation being performed + subject_ids: Dictionary of subject identifiers (e.g., {"repo": "owner/repo", "pr": "123"}) + **context: Additional context to include in logs + + Example: + async with log_operation("rule_evaluation", repo=repo, pr=pr_number): + result = await evaluate_rules(...) + """ + start_time = time.time() + log_context = { + "operation": operation, + **(subject_ids or {}), + **context, + } + + logger.info(f"🚀 Starting {operation}", extra=log_context) + + try: + yield + except Exception as e: + latency_ms = int((time.time() - start_time) * 1000) + logger.error( + f"❌ {operation} failed after {latency_ms}ms", + extra={**log_context, "error": str(e), "latency_ms": latency_ms}, + exc_info=True, + ) + raise + else: + latency_ms = int((time.time() - start_time) * 1000) + logger.info( + f"✅ {operation} completed in {latency_ms}ms", + extra={**log_context, "latency_ms": latency_ms}, + ) + + +def log_function_call(operation: str | None = None): + """ + Decorator for logging function calls with timing. 
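+
+    Works with both sync and async functions; the appropriate wrapper is
+    selected with inspect.iscoroutinefunction at decoration time.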
+ + Args: + operation: Custom operation name (defaults to function name) + + Returns: + Decorated function with logging + + Example: + @log_function_call(operation="fetch_data") + async def fetch_data(): + return await api_call() + """ + + def decorator(func): + op_name = operation or func.__name__ + + @wraps(func) + async def async_wrapper(*args, **kwargs): + start_time = time.time() + logger.info(f"🚀 Calling {op_name}") + + try: + result = await func(*args, **kwargs) + latency_ms = int((time.time() - start_time) * 1000) + logger.info(f"✅ {op_name} completed in {latency_ms}ms") + return result + except Exception as e: + latency_ms = int((time.time() - start_time) * 1000) + logger.error( + f"❌ {op_name} failed after {latency_ms}ms: {e}", + exc_info=True, + ) + raise + + @wraps(func) + def sync_wrapper(*args, **kwargs): + start_time = time.time() + logger.info(f"🚀 Calling {op_name}") + + try: + result = func(*args, **kwargs) + latency_ms = int((time.time() - start_time) * 1000) + logger.info(f"✅ {op_name} completed in {latency_ms}ms") + return result + except Exception as e: + latency_ms = int((time.time() - start_time) * 1000) + logger.error( + f"❌ {op_name} failed after {latency_ms}ms: {e}", + exc_info=True, + ) + raise + + # Return appropriate wrapper based on whether function is async + import inspect + + if inspect.iscoroutinefunction(func): + return async_wrapper + return sync_wrapper + + return decorator diff --git a/src/core/utils/metrics.py b/src/core/utils/metrics.py new file mode 100644 index 0000000..5214810 --- /dev/null +++ b/src/core/utils/metrics.py @@ -0,0 +1,145 @@ +""" +Performance metrics utilities. + +Provides utilities for tracking and recording performance metrics +for operations, API calls, and agent executions. +""" + +import logging +import time +from contextlib import asynccontextmanager +from functools import wraps +from typing import Any + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def track_metrics( + operation: str, + **metadata: Any, +): + """ + Context manager for tracking operation metrics. + + Records timing and metadata for performance analysis. + + Args: + operation: Name of the operation + **metadata: Additional metadata to record + + Yields: + Dictionary with metrics that can be updated during execution + + Example: + async with track_metrics("rule_evaluation", rule_count=5) as metrics: + result = await evaluate_rules(...) + metrics["violations_found"] = len(result.violations) + """ + start_time = time.time() + metrics: dict[str, Any] = { + "operation": operation, + "start_time": start_time, + **metadata, + } + + try: + yield metrics + finally: + end_time = time.time() + latency_ms = int((end_time - start_time) * 1000) + metrics.update( + { + "end_time": end_time, + "latency_ms": latency_ms, + "success": "error" not in metrics, + } + ) + + # Log metrics + logger.info( + f"📊 Metrics for {operation}: {latency_ms}ms", + extra=metrics, + ) + + +def metrics_decorator(operation: str | None = None, **default_metadata: Any): + """ + Decorator for tracking function call metrics. 
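+
+    Works with both sync and async functions; on failure the exception text is
+    recorded under "error" and success is set to False before re-raising.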
+ + Args: + operation: Custom operation name (defaults to function name) + **default_metadata: Default metadata to include in metrics + + Returns: + Decorated function with metrics tracking + + Example: + @metrics_decorator(operation="api_call", endpoint="/rules") + async def fetch_rules(): + return await api_call() + """ + + def decorator(func): + op_name = operation or func.__name__ + + @wraps(func) + async def async_wrapper(*args, **kwargs): + async with track_metrics(op_name, **default_metadata) as metrics: + try: + result = await func(*args, **kwargs) + metrics["success"] = True + return result + except Exception as e: + metrics["error"] = str(e) + metrics["success"] = False + raise + + @wraps(func) + def sync_wrapper(*args, **kwargs): + start_time = time.time() + metrics: dict[str, Any] = { + "operation": op_name, + "start_time": start_time, + **default_metadata, + } + + try: + result = func(*args, **kwargs) + end_time = time.time() + metrics.update( + { + "end_time": end_time, + "latency_ms": int((end_time - start_time) * 1000), + "success": True, + } + ) + logger.info( + f"📊 Metrics for {op_name}: {metrics['latency_ms']}ms", + extra=metrics, + ) + return result + except Exception as e: + end_time = time.time() + metrics.update( + { + "end_time": end_time, + "latency_ms": int((end_time - start_time) * 1000), + "error": str(e), + "success": False, + } + ) + logger.error( + f"📊 Metrics for {op_name}: {metrics['latency_ms']}ms (failed)", + extra=metrics, + ) + raise + + # Return appropriate wrapper based on whether function is async + import inspect + + if inspect.iscoroutinefunction(func): + return async_wrapper + return sync_wrapper + + return decorator diff --git a/src/core/utils/retry.py b/src/core/utils/retry.py new file mode 100644 index 0000000..087ac6b --- /dev/null +++ b/src/core/utils/retry.py @@ -0,0 +1,135 @@ +""" +Retry utilities with exponential backoff. + +Provides decorators and functions for retrying async operations with +configurable exponential backoff strategies. +""" + +import asyncio +import logging +from collections.abc import Callable +from functools import wraps +from typing import Any, TypeVar + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +def retry_with_backoff( + max_retries: int = 3, + initial_delay: float = 1.0, + max_delay: float = 60.0, + exponential_base: float = 2.0, + exceptions: tuple[type[Exception], ...] = (Exception,), +): + """ + Decorator for retrying async functions with exponential backoff. 
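+
+    After each failed attempt the wrapper sleeps for min(delay, max_delay) and
+    then multiplies delay by exponential_base, starting from initial_delay.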
+ + Args: + max_retries: Maximum number of retry attempts + initial_delay: Initial delay in seconds before first retry + max_delay: Maximum delay in seconds between retries + exponential_base: Base for exponential backoff calculation + exceptions: Tuple of exception types to catch and retry on + + Returns: + Decorated async function with retry logic + + Example: + @retry_with_backoff(max_retries=3, initial_delay=1.0) + async def fetch_data(): + return await api_call() + """ + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + delay = initial_delay + last_exception: Exception | None = None + + for attempt in range(max_retries): + try: + result = await func(*args, **kwargs) + if attempt > 0: + logger.info(f"✅ {func.__name__} succeeded on attempt {attempt + 1}/{max_retries}") + return result + except exceptions as e: + last_exception = e + if attempt == max_retries - 1: + logger.error(f"❌ {func.__name__} failed after {max_retries} attempts: {e}") + raise + + wait_time = min(delay, max_delay) + logger.warning( + f"⚠️ {func.__name__} attempt {attempt + 1}/{max_retries} failed, " + f"retrying in {wait_time:.2f}s: {e}" + ) + await asyncio.sleep(wait_time) + delay *= exponential_base + + # This should never be reached, but type checker needs it + if last_exception: + raise last_exception + raise RuntimeError(f"{func.__name__} failed after {max_retries} attempts") + + return wrapper + + return decorator + + +async def retry_async( + func: Callable[..., Any], + *args: Any, + max_retries: int = 3, + initial_delay: float = 1.0, + max_delay: float = 60.0, + exponential_base: float = 2.0, + exceptions: tuple[type[Exception], ...] = (Exception,), + **kwargs: Any, +) -> Any: + """ + Retry an async function call with exponential backoff. + + Args: + func: Async function to retry + *args: Positional arguments for the function + max_retries: Maximum number of retry attempts + initial_delay: Initial delay in seconds before first retry + max_delay: Maximum delay in seconds between retries + exponential_base: Base for exponential backoff calculation + exceptions: Tuple of exception types to catch and retry on + **kwargs: Keyword arguments for the function + + Returns: + Result of the function call + + Raises: + Exception: If all retries fail + + Example: + result = await retry_async(fetch_data, max_retries=3) + """ + delay = initial_delay + last_exception: Exception | None = None + + for attempt in range(max_retries): + try: + return await func(*args, **kwargs) + except exceptions as e: + last_exception = e + if attempt == max_retries - 1: + logger.error(f"❌ {func.__name__} failed after {max_retries} attempts: {e}") + raise + + wait_time = min(delay, max_delay) + logger.warning( + f"⚠️ {func.__name__} attempt {attempt + 1}/{max_retries} failed, retrying in {wait_time:.2f}s: {e}" + ) + await asyncio.sleep(wait_time) + delay *= exponential_base + + # This should never be reached, but type checker needs it + if last_exception: + raise last_exception + raise RuntimeError(f"{func.__name__} failed after {max_retries} attempts") diff --git a/src/core/utils/timeout.py b/src/core/utils/timeout.py new file mode 100644 index 0000000..4831b99 --- /dev/null +++ b/src/core/utils/timeout.py @@ -0,0 +1,75 @@ +""" +Timeout utilities for async operations. + +Provides functions for executing async operations with timeout handling. 
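+
+Both helpers wrap asyncio.wait_for and re-raise TimeoutError with a
+descriptive message when the deadline is exceeded.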
+""" + +import asyncio +import logging +from collections.abc import Coroutine +from typing import Any + +logger = logging.getLogger(__name__) + + +async def execute_with_timeout( + coro: Coroutine[Any, Any, Any], + timeout: float = 30.0, + timeout_message: str | None = None, +) -> Any: + """ + Execute a coroutine with timeout handling. + + Args: + coro: The coroutine to execute + timeout: Timeout in seconds + timeout_message: Custom message for timeout exception + + Returns: + The result of the coroutine + + Raises: + TimeoutError: If the operation times out + + Example: + result = await execute_with_timeout( + long_running_operation(), + timeout=60.0 + ) + """ + try: + return await asyncio.wait_for(coro, timeout=timeout) + except TimeoutError as err: + msg = timeout_message or f"Operation timed out after {timeout} seconds" + logger.error(f"❌ {msg}") + raise TimeoutError(msg) from err + + +def timeout_decorator(timeout: float = 30.0, timeout_message: str | None = None): + """ + Decorator for adding timeout to async functions. + + Args: + timeout: Timeout in seconds + timeout_message: Custom message for timeout exception + + Returns: + Decorated async function with timeout + + Example: + @timeout_decorator(timeout=60.0) + async def long_operation(): + await asyncio.sleep(100) # Will timeout + """ + + def decorator(func): + async def wrapper(*args, **kwargs): + return await execute_with_timeout( + func(*args, **kwargs), + timeout=timeout, + timeout_message=timeout_message, + ) + + return wrapper + + return decorator diff --git a/src/event_processors/base.py b/src/event_processors/base.py index 3dac7d8..04de24d 100644 --- a/src/event_processors/base.py +++ b/src/event_processors/base.py @@ -5,9 +5,9 @@ from pydantic import BaseModel, Field from src.core.models import WebhookEvent -from src.integrations.github_api import github_client -from src.rules.github_provider import GitHubRuleLoader +from src.integrations.github import github_client from src.rules.interface import RuleLoader +from src.rules.loaders.github_loader import GitHubRuleLoader from src.tasks.task_queue import Task logger = logging.getLogger(__name__) diff --git a/src/event_processors/check_run.py b/src/event_processors/check_run.py index 712e85c..6d80ee6 100644 --- a/src/event_processors/check_run.py +++ b/src/event_processors/check_run.py @@ -2,7 +2,7 @@ import time from typing import Any -from src.agents.engine_agent.agent import RuleEngineAgent +from src.agents import get_agent from src.event_processors.base import BaseEventProcessor, ProcessingResult from src.tasks.task_queue import Task @@ -17,7 +17,7 @@ def __init__(self): super().__init__() # Create instance of hybrid RuleEngineAgent - self.engine_agent = RuleEngineAgent() + self.engine_agent = get_agent("engine") def get_event_type(self) -> str: return "check_run" diff --git a/src/event_processors/deployment_protection_rule.py b/src/event_processors/deployment_protection_rule.py index 7c5a7ee..d97c77b 100644 --- a/src/event_processors/deployment_protection_rule.py +++ b/src/event_processors/deployment_protection_rule.py @@ -2,7 +2,7 @@ import time from typing import Any -from src.agents.engine_agent.agent import RuleEngineAgent +from src.agents import get_agent from src.event_processors.base import BaseEventProcessor, ProcessingResult from src.tasks.scheduler.deployment_scheduler import get_deployment_scheduler from src.tasks.task_queue import Task @@ -18,7 +18,7 @@ def __init__(self): super().__init__() # Create instance of hybrid RuleEngineAgent - self.engine_agent = 
RuleEngineAgent() + self.engine_agent = get_agent("engine") def get_event_type(self) -> str: return "deployment_protection_rule" @@ -262,6 +262,6 @@ async def prepare_api_data(self, task: Task) -> dict[str, Any]: return {} def _get_rule_provider(self): - from src.rules.github_provider import github_rule_loader + from src.rules.loaders.github_loader import github_rule_loader return github_rule_loader diff --git a/src/event_processors/deployment_review.py b/src/event_processors/deployment_review.py index 2e47167..50be9aa 100644 --- a/src/event_processors/deployment_review.py +++ b/src/event_processors/deployment_review.py @@ -2,7 +2,7 @@ import time from typing import Any -from src.agents.engine_agent.agent import RuleEngineAgent +from src.agents import get_agent from src.event_processors.base import BaseEventProcessor, ProcessingResult from src.tasks.task_queue import Task @@ -17,7 +17,7 @@ def __init__(self): super().__init__() # Create instance of hybrid RuleEngineAgent - self.engine_agent = RuleEngineAgent() + self.engine_agent = get_agent("engine") def get_event_type(self) -> str: return "deployment_review" @@ -86,7 +86,7 @@ async def process(self, task: Task) -> ProcessingResult: logger.info(f"📋 Found {len(deployment_review_rules)} applicable rules for deployment_review") # Convert rules to the new format expected by the agent - formatted_rules = self._convert_rules_to_new_format(deployment_review_rules) + formatted_rules = DeploymentReviewProcessor._convert_rules_to_new_format(deployment_review_rules) # Run agentic analysis using the instance result = await self.engine_agent.execute( @@ -109,7 +109,16 @@ async def process(self, task: Task) -> ProcessingResult: processing_time_ms=int((time.time() - start_time) * 1000), ) - def _convert_rules_to_new_format(self, rules: list[Any]) -> list[dict[str, Any]]: + async def prepare_webhook_data(self, task) -> dict[str, Any]: + """Extract data available in webhook payload.""" + return task.payload + + async def prepare_api_data(self, task) -> dict[str, Any]: + """Fetch data not available in webhook.""" + return {} + + @staticmethod + def _convert_rules_to_new_format(rules: list[Any]) -> list[dict[str, Any]]: """Convert Rule objects to the new flat schema format.""" formatted_rules = [] @@ -132,15 +141,8 @@ def _convert_rules_to_new_format(self, rules: list[Any]) -> list[dict[str, Any]] return formatted_rules - async def prepare_webhook_data(self, task) -> dict[str, Any]: - """Extract data available in webhook payload.""" - return task.payload - - async def prepare_api_data(self, task) -> dict[str, Any]: - """Fetch data not available in webhook.""" - return {} - - def _format_violation_comment(self, violations): + @staticmethod + def _format_violation_comment(violations): lines = [] for v in violations: emoji = "❌" if v.get("severity", "high") in ("critical", "high") else "⚠️" diff --git a/src/event_processors/pull_request.py b/src/event_processors/pull_request.py index 743f02f..8d16271 100644 --- a/src/event_processors/pull_request.py +++ b/src/event_processors/pull_request.py @@ -3,9 +3,9 @@ import time from typing import Any -from src.agents.engine_agent.agent import RuleEngineAgent +from src.agents import get_agent from src.event_processors.base import BaseEventProcessor, ProcessingResult -from src.rules.github_provider import RulesFileNotFoundError +from src.rules.loaders.github_loader import RulesFileNotFoundError from src.tasks.task_queue import Task logger = logging.getLogger(__name__) @@ -19,7 +19,7 @@ def __init__(self): 
super().__init__() # Create instance of RuleEngineAgent - self.engine_agent = RuleEngineAgent() + self.engine_agent = get_agent("engine") def get_event_type(self) -> str: return "pull_request" @@ -74,7 +74,11 @@ async def process(self, task: Task) -> ProcessingResult: logger.info("📋 Rules applicable to pull_request events:") for rule in formatted_rules: if "pull_request" in rule.get("event_types", []): - logger.info(f" - {rule.get('name', 'Unknown')} ({rule.get('id', 'unknown')})") + description = rule.get("description", "Unknown rule") + severity = rule.get("severity", "medium") + # Truncate long descriptions for cleaner logs + desc_preview = description[:60] + "..." if len(description) > 60 else description + logger.info(f" - {desc_preview} ({severity})") # Check for existing acknowledgments from previous comments first pr_data = task.payload.get("pull_request", {}) diff --git a/src/event_processors/push.py b/src/event_processors/push.py index 92d67f3..0954554 100644 --- a/src/event_processors/push.py +++ b/src/event_processors/push.py @@ -2,7 +2,7 @@ import time from typing import Any -from src.agents.engine_agent.agent import RuleEngineAgent +from src.agents import get_agent from src.event_processors.base import BaseEventProcessor, ProcessingResult from src.tasks.task_queue import Task @@ -17,7 +17,7 @@ def __init__(self): super().__init__() # Create instance of hybrid RuleEngineAgent - self.engine_agent = RuleEngineAgent() + self.engine_agent = get_agent("engine") def get_event_type(self) -> str: return "push" diff --git a/src/event_processors/rule_creation.py b/src/event_processors/rule_creation.py index 12b9549..48adfbd 100644 --- a/src/event_processors/rule_creation.py +++ b/src/event_processors/rule_creation.py @@ -3,7 +3,7 @@ import time from typing import Any -from src.agents.feasibility_agent.agent import RuleFeasibilityAgent +from src.agents import get_agent from src.event_processors.base import BaseEventProcessor, ProcessingResult from src.tasks.task_queue import Task @@ -18,7 +18,7 @@ def __init__(self): super().__init__() # Create instance using new structure - self.feasibility_agent = RuleFeasibilityAgent() + self.feasibility_agent = get_agent("feasibility") def get_event_type(self) -> str: return "rule_creation" diff --git a/src/event_processors/violation_acknowledgment.py b/src/event_processors/violation_acknowledgment.py index 6659170..96bf57e 100644 --- a/src/event_processors/violation_acknowledgment.py +++ b/src/event_processors/violation_acknowledgment.py @@ -3,8 +3,7 @@ import time from typing import Any -from src.agents.acknowledgment_agent.agent import AcknowledgmentAgent -from src.agents.engine_agent.agent import RuleEngineAgent +from src.agents import get_agent from src.core.models import EventType from src.event_processors.base import BaseEventProcessor, ProcessingResult from src.tasks.task_queue import Task @@ -23,9 +22,9 @@ def __init__(self): super().__init__() # Create instance of hybrid RuleEngineAgent for rule evaluation - self.engine_agent = RuleEngineAgent() + self.engine_agent = get_agent("engine") # Create instance of intelligent AcknowledgmentAgent for acknowledgment evaluation - self.acknowledgment_agent = AcknowledgmentAgent() + self.acknowledgment_agent = get_agent("acknowledgment") def get_event_type(self) -> str: return "violation_acknowledgment" @@ -669,6 +668,6 @@ async def prepare_api_data(self, task: Task) -> dict[str, Any]: def _get_rule_provider(self): """Get the rule provider for this processor.""" - from src.rules.github_provider import 
github_rule_loader
+        from src.rules.loaders.github_loader import github_rule_loader
 
         return github_rule_loader
diff --git a/src/integrations/README.md b/src/integrations/README.md
new file mode 100644
index 0000000..020e7a1
--- /dev/null
+++ b/src/integrations/README.md
@@ -0,0 +1,71 @@
+# Integrations Module
+
+This module contains integrations for external services and APIs. Integrations provide a clean interface between Watchflow and third-party services.
+
+## Structure
+
+```
+src/integrations/
+├── providers/              # Provider integrations (OpenAI, Bedrock, Vertex AI)
+│   ├── base.py             # Base provider interface
+│   ├── openai_provider.py
+│   ├── bedrock_provider.py
+│   ├── vertex_ai_provider.py
+│   └── factory.py          # Provider factory functions
+└── github/                 # GitHub API adapter
+    └── api.py              # GitHubClient implementation
+```
+
+## Providers
+
+Providers wrap external model services behind a common interface defined in `base.py`.
+
+### Usage
+
+```python
+from src.integrations.providers import get_provider, get_chat_model
+
+# Get a provider instance
+provider = get_provider(provider="openai", model="gpt-4")
+
+# Or get a ready-to-use chat model
+chat_model = get_chat_model(provider="openai", agent="engine_agent")
+```
+
+### Supported Providers
+
+- **OpenAI** - Direct OpenAI API integration
+- **AWS Bedrock** - AWS Bedrock with support for inference profiles
+- **Vertex AI** - Google Cloud Vertex AI (Model Garden) supporting both Gemini and Claude models
+
+## GitHub Adapter
+
+The GitHub adapter provides a client for interacting with the GitHub API, handling authentication, token caching, and API operations.
+
+### Usage
+
+```python
+from src.integrations.github import github_client
+
+# Use the global instance
+token = await github_client.get_installation_access_token(installation_id)
+```
+
+## Migration Notes
+
+### Import Paths
+
+All code should use the new import paths:
+
+```python
+# ✅ Use these imports
+from src.integrations.providers import get_chat_model
+from src.integrations.github import github_client
+```
+
+## Design Principles
+
+1. **Separation of Concerns** - Integrations handle external service integration, not business logic
+2. **Consistent Interface** - All providers implement the same base interface
+3. **Flexible Configuration** - Providers support per-agent configuration
+4. **Backward Compatible** - Transitional aliases (e.g. `BaseAIProvider`) keep existing call sites working during migration
diff --git a/src/integrations/__init__.py b/src/integrations/__init__.py
index 38a7268..41554d5 100644
--- a/src/integrations/__init__.py
+++ b/src/integrations/__init__.py
@@ -1 +1,6 @@
-# Integrations package
+"""
+Integrations for external services and APIs.
+
+This package contains integrations for external services like GitHub,
+provider services (OpenAI, Bedrock, Vertex AI), and other third-party APIs.
+"""
diff --git a/src/integrations/aws_bedrock.py b/src/integrations/aws_bedrock.py
deleted file mode 100644
index 985a493..0000000
--- a/src/integrations/aws_bedrock.py
+++ /dev/null
@@ -1,320 +0,0 @@
-"""
-AWS Bedrock integration for AI model access.
-
-This module handles AWS Bedrock API interactions, including both
-standard langchain-aws clients and the Anthropic Bedrock client
-for inference profile support.
-""" - -from __future__ import annotations - -import os -from typing import Any - -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import BaseMessage -from langchain_core.outputs import ChatGeneration, ChatResult - -from src.core.config import config - - -def get_anthropic_bedrock_client() -> Any: - """ - Get Anthropic Bedrock client for models requiring inference profiles. - - This client handles newer Anthropic models that require inference profiles - instead of direct on-demand access. - Uses AWS profile authentication if explicit credentials are not provided. - - Returns: - AnthropicBedrock client instance - """ - try: - from anthropic import AnthropicBedrock - except ImportError as e: - raise RuntimeError( - "Anthropic Bedrock client requires 'anthropic' package. Install with: pip install anthropic" - ) from e - - # Get AWS credentials from config - aws_access_key = config.ai.aws_access_key_id - aws_secret_key = config.ai.aws_secret_access_key - aws_region = config.ai.bedrock_region or "us-east-1" - - # Set AWS profile if specified in config - aws_profile = config.ai.aws_profile - if aws_profile: - os.environ["AWS_PROFILE"] = aws_profile - - # Prepare client parameters - following the official Anthropic client pattern - client_kwargs = { - "aws_region": aws_region, - "aws_profile": aws_profile, - } - - # Add credentials only if they are provided - if aws_access_key and aws_secret_key: - client_kwargs.update( - { - "aws_access_key": aws_access_key, - "aws_secret_key": aws_secret_key, - } - ) - # If no explicit credentials, boto3 will use AWS profile/default credentials - - return AnthropicBedrock(**client_kwargs) - - -def get_standard_bedrock_client() -> Any: - """ - Get standard langchain-aws Bedrock client for on-demand models. - - This client works with models that have direct on-demand access enabled. - Uses AWS profile authentication if explicit credentials are not provided. - - Returns: - ChatBedrock client instance - """ - try: - from langchain_aws import ChatBedrock - except ImportError as e: - raise RuntimeError( - "Standard Bedrock client requires 'langchain-aws' package. 
Install with: pip install langchain-aws" - ) from e - - # Get AWS credentials from config - aws_access_key = config.ai.aws_access_key_id - aws_secret_key = config.ai.aws_secret_access_key - aws_region = config.ai.bedrock_region or "us-east-1" - - # Set AWS profile if specified in config - aws_profile = config.ai.aws_profile - if aws_profile: - os.environ["AWS_PROFILE"] = aws_profile - - # Get model ID from config - model_id = config.ai.get_model_for_provider("bedrock") - client_kwargs = { - "model_id": model_id, - "region_name": aws_region, - } - - # If using an ARN or inference profile ID, we need to specify the provider - if model_id.startswith("arn:") or model_id.startswith("us.") or model_id.startswith("global."): - # Extract provider from model ID - if "anthropic" in model_id.lower(): - client_kwargs["provider"] = "anthropic" - elif "amazon" in model_id.lower(): - client_kwargs["provider"] = "amazon" - elif "meta" in model_id.lower(): - client_kwargs["provider"] = "meta" - - # Add credentials only if they are provided - if aws_access_key and aws_secret_key: - client_kwargs.update( - { - "aws_access_key_id": aws_access_key, - "aws_secret_access_key": aws_secret_key, - } - ) - # If no explicit credentials, boto3 will use AWS profile/default credentials - - return ChatBedrock(**client_kwargs) - - -def is_anthropic_model(model_id: str) -> bool: - """Check if a model ID is an Anthropic model.""" - return model_id.startswith("anthropic.") - - -def _find_inference_profile(model_id: str) -> str | None: - """ - Find an inference profile that contains the specified model. - - Args: - model_id: The model identifier to find a profile for - - Returns: - Inference profile ARN if found, None otherwise - """ - try: - import boto3 - - # Get AWS credentials from config - aws_region = config.ai.bedrock_region or "us-east-1" - aws_access_key = config.ai.aws_access_key_id - aws_secret_key = config.ai.aws_secret_access_key - - # Create Bedrock client - client_kwargs = {"region_name": aws_region} - if aws_access_key and aws_secret_key: - client_kwargs.update({"aws_access_key_id": aws_access_key, "aws_secret_access_key": aws_secret_key}) - - bedrock = boto3.client("bedrock", **client_kwargs) - - # List inference profiles - response = bedrock.list_inference_profiles() - profiles = response.get("inferenceProfiles", []) - - # Look for profiles that might contain this model - for profile in profiles: - profile_name = profile.get("name", "") - profile_arn = profile.get("arn", "") - - # Check if this profile likely contains the model - if any(keyword in profile_name.lower() for keyword in ["claude", "anthropic", "general", "default"]): - if "anthropic" in model_id.lower() or "claude" in model_id.lower(): - return profile_arn - elif any(keyword in profile_name.lower() for keyword in ["amazon", "titan", "nova"]): - if "amazon" in model_id.lower() or "titan" in model_id.lower() or "nova" in model_id.lower(): - return profile_arn - elif any(keyword in profile_name.lower() for keyword in ["meta", "llama"]): - if "meta" in model_id.lower() or "llama" in model_id.lower(): - return profile_arn - - return None - - except Exception: - # If we can't find inference profiles, return None - return None - - -def get_bedrock_client() -> Any: - """ - Get the appropriate Bedrock client based on configured model type. 
- - Returns: - Appropriate Bedrock client (Anthropic or standard) - """ - # Get model ID from config - model_id = config.ai.get_model_for_provider("bedrock") - - # Check if this is already an inference profile ID - if model_id.startswith("us.") or model_id.startswith("global.") or model_id.startswith("arn:"): - # This is already an inference profile ID, use Anthropic client directly - return get_anthropic_inference_profile_client(model_id) - - # First, try to find an inference profile for this model - inference_profile = _find_inference_profile(model_id) - - if inference_profile: - # Use inference profile with Anthropic client - return get_anthropic_inference_profile_client(inference_profile) - - # Fallback to direct model access - if is_anthropic_model(model_id): - # For Anthropic models, try standard client first (supports structured output) - try: - return get_standard_bedrock_client() - except Exception: - # If standard client fails, fall back to Anthropic client - client = get_anthropic_bedrock_client() - return _wrap_anthropic_client(client, model_id) - else: - # Use standard client for other models (requires on-demand access) - return get_standard_bedrock_client() - - -def _wrap_anthropic_client(client: Any, model_id: str) -> Any: - """ - Wrap Anthropic Bedrock client to be langchain-compatible. - - This creates a wrapper that implements the langchain interface - while using the Anthropic client under the hood. - """ - - class AnthropicBedrockWrapper(BaseChatModel): - """Wrapper for Anthropic Bedrock client to be langchain-compatible.""" - - anthropic_client: Any - model_id: str - max_tokens: int - temperature: float - - def __init__(self, anthropic_client: Any, model_id: str): - super().__init__( - anthropic_client=anthropic_client, - model_id=model_id, - max_tokens=config.ai.engine_agent.max_tokens if config.ai.engine_agent else config.ai.max_tokens, - temperature=config.ai.engine_agent.temperature if config.ai.engine_agent else config.ai.temperature, - ) - - @property - def _llm_type(self) -> str: - return "anthropic_bedrock" - - def with_structured_output(self, output_model: Any) -> Any: - """Add structured output support to the Anthropic wrapper.""" - # For now, return self and let the calling code handle structured output - # This is a temporary solution - we'll implement proper structured output later - return self - - def _generate( - self, - messages: list[BaseMessage], - stop: list[str] | None = None, - run_manager: Any | None = None, - ) -> ChatResult: - """Generate a response using the Anthropic client.""" - # Convert langchain messages to Anthropic format - anthropic_messages = [] - for msg in messages: - # Convert LangChain message types to Anthropic format - if msg.type == "human": - role = "user" - elif msg.type == "ai": - role = "assistant" - elif msg.type == "system": - role = "user" # Anthropic doesn't have system role, use user - else: - role = "user" # Default to user - - anthropic_messages.append({"role": role, "content": msg.content}) - - # Call Anthropic API - response = self.anthropic_client.messages.create( - model=self.model_id, - max_tokens=self.max_tokens, - temperature=self.temperature, - messages=anthropic_messages, - ) - - # Convert response back to langchain format - content = response.content[0].text if response.content else "" - message = BaseMessage(content=content, type="assistant") - generation = ChatGeneration(message=message) - - return ChatResult(generations=[generation]) - - async def _agenerate( - self, - messages: list[BaseMessage], - 
stop: list[str] | None = None, - run_manager: Any | None = None, - ) -> ChatResult: - """Async generate using the Anthropic client.""" - # For now, just call the sync version - # TODO: Implement proper async support - return self._generate(messages, stop, run_manager) - - return AnthropicBedrockWrapper(client, model_id) - - -def get_anthropic_inference_profile_client(inference_profile_id: str) -> Any: - """ - Get Anthropic client configured for inference profile models. - - This is the key function that uses the inference profile ID directly - as the model ID, following the Anthropic client pattern. - - Args: - inference_profile_id: The inference profile ID (e.g., 'us.anthropic.claude-3-5-haiku-20241022-v1:0') - - Returns: - Wrapped Anthropic client that works with LangChain - """ - # Get the base Anthropic client - client = get_anthropic_bedrock_client() - - # Wrap it with the inference profile ID as the model - return _wrap_anthropic_client(client, inference_profile_id) diff --git a/src/integrations/gcp_garden.py b/src/integrations/gcp_garden.py deleted file mode 100644 index c9d5c72..0000000 --- a/src/integrations/gcp_garden.py +++ /dev/null @@ -1,198 +0,0 @@ -""" -GCP Model Garden integration for AI model access. - -This module handles Google Cloud Platform Model Garden API interactions -for AI model access, supporting both Google (Gemini) and third-party (Claude) models. -""" - -from __future__ import annotations - -import os -from typing import Any - -from src.core.config import config - - -def get_garden_client() -> Any: - """ - Get GCP Model Garden client for accessing both Google and third-party models. - - Returns: - Model Garden client instance - """ - # Use Model Garden client for better model selection - return get_model_garden_client() - - -def get_model_garden_client() -> Any: - """ - Get GCP Model Garden client for accessing both Google and third-party models. - - This client provides access to models from various providers through - Google's Model Garden marketplace, including: - - Google models: gemini-1.0-pro, gemini-1.5-pro, gemini-2.0-flash-exp - - Third-party models: Claude, Llama, etc. (when available) - - Returns: - Model Garden client instance - """ - # Get GCP credentials from config - project_id = config.ai.gcp_project - location = config.ai.gcp_location or "us-central1" - service_account_key_base64 = config.ai.gcp_service_account_key_base64 - model = config.ai.get_model_for_provider("garden") - - if not project_id: - raise ValueError("GCP project ID required for Model Garden. 
Set GCP_PROJECT_ID in config") - - # Handle base64 encoded service account key - if service_account_key_base64: - import base64 - import tempfile - - try: - # Decode the base64 key - key_data = base64.b64decode(service_account_key_base64).decode("utf-8") - - # Create a temporary file with the key - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - f.write(key_data) - credentials_path = f.name - - # Set the environment variable for Google Cloud to use - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path - - except Exception as e: - raise ValueError(f"Failed to decode GCP service account key: {e}") from e - - # Check if it's a Claude model - if "claude" in model.lower(): - return get_claude_model_garden_client(project_id, location, model) - else: - return get_gemini_model_garden_client(project_id, location, model) - - -def get_claude_model_garden_client(project_id: str, location: str, model: str) -> Any: - """ - Get Claude model via GCP Model Garden using Anthropic Vertex SDK. - - Note: The AnthropicVertex SDK is used for Claude models in Model Garden, - even though the provider is called "garden" in our configuration. - - Args: - project_id: GCP project ID - location: GCP location/region - model: Model name (e.g., claude-3-opus@20240229) - - Returns: - Claude client instance - """ - try: - from anthropic import AnthropicVertex - except ImportError as e: - raise RuntimeError( - "Claude Model Garden client requires 'anthropic[vertex]' package. " - "Install with: pip install 'anthropic[vertex]'" - ) from e - - # Create Anthropic Vertex client (this is the SDK class name for Model Garden) - client = AnthropicVertex(region=location, project_id=project_id) - - # Wrap it to match LangChain interface - return ClaudeModelGardenWrapper(client, model) - - -def get_gemini_model_garden_client(project_id: str, location: str, model: str) -> Any: - """ - Get Gemini model via GCP Model Garden using LangChain. - - Note: ChatVertexAI is the LangChain class name for Model Garden models, - even though the provider is called "garden" in our configuration. - - Args: - project_id: GCP project ID - location: GCP location/region - model: Model name (e.g., gemini-pro) - - Returns: - Gemini client instance - """ - try: - from langchain_google_vertexai import ChatVertexAI - except ImportError as e: - raise RuntimeError( - "Gemini Model Garden client requires 'langchain-google-vertexai' package. " - "Install with: pip install langchain-google-vertexai" - ) from e - - # Try multiple Gemini model names in order of preference - model_candidates = [model, "gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash"] - - for candidate_model in model_candidates: - try: - return ChatVertexAI( - model=candidate_model, - project=project_id, - location=location, - ) - except Exception as e: - if "not found" in str(e).lower() or "404" in str(e): - continue # Try next model - else: - raise # Re-raise if it's not a model not found error - - # If all models fail, raise an error - raise RuntimeError( - f"None of the Gemini models are available in your GCP project. " - f"Tried: {', '.join(model_candidates)}. " - f"Please check your GCP project configuration and model access." - ) - - -class ClaudeModelGardenWrapper: - """ - Wrapper for Claude Model Garden client to match LangChain interface. 
- """ - - def __init__(self, client, model: str): - self.client = client - self.model = model - - async def ainvoke(self, messages, **kwargs): - """Async invoke method.""" - # Convert LangChain messages to Anthropic format - anthropic_messages = [] - for msg in messages: - if hasattr(msg, "content"): - content = msg.content - role = "user" if msg.type == "human" else "assistant" - else: - content = str(msg) - role = "user" - - anthropic_messages.append({"role": role, "content": content}) - - # Call Claude API - response = self.client.messages.create( - model=self.model, - messages=anthropic_messages, - max_tokens=kwargs.get("max_tokens", 4096), - temperature=kwargs.get("temperature", 0.1), - ) - - # Convert response to LangChain format - from langchain_core.messages import AIMessage - - return AIMessage(content=response.content[0].text) - - def invoke(self, messages, **kwargs): - """Sync invoke method.""" - import asyncio - - return asyncio.run(self.ainvoke(messages, **kwargs)) - - def with_structured_output(self, schema, **kwargs): - """Structured output method.""" - # For now, return self and handle structured output in ainvoke - self._output_schema = schema - return self diff --git a/src/integrations/github/__init__.py b/src/integrations/github/__init__.py new file mode 100644 index 0000000..0fbfc96 --- /dev/null +++ b/src/integrations/github/__init__.py @@ -0,0 +1,12 @@ +""" +GitHub API adapter. + +This package provides integrations for GitHub API interactions. +""" + +from src.integrations.github.api import GitHubClient, github_client + +__all__ = [ + "GitHubClient", + "github_client", +] diff --git a/src/integrations/github_api.py b/src/integrations/github/api.py similarity index 100% rename from src/integrations/github_api.py rename to src/integrations/github/api.py diff --git a/src/providers/__init__.py b/src/integrations/providers/__init__.py similarity index 56% rename from src/providers/__init__.py rename to src/integrations/providers/__init__.py index 9b7d5cd..97dd89d 100644 --- a/src/providers/__init__.py +++ b/src/integrations/providers/__init__.py @@ -1,7 +1,7 @@ """ -AI Provider package for managing different AI service providers. +Provider integrations for model services. -This package provides a unified interface for accessing various AI providers +This package provides integrations for different provider services including OpenAI, AWS Bedrock, and Google Vertex AI. The main entry point is the factory functions: @@ -9,7 +9,7 @@ - get_chat_model() - Get a ready-to-use chat model """ -from src.providers.factory import get_chat_model, get_provider +from src.integrations.providers.factory import get_chat_model, get_provider __all__ = [ "get_provider", diff --git a/src/providers/base_provider.py b/src/integrations/providers/base.py similarity index 80% rename from src/providers/base_provider.py rename to src/integrations/providers/base.py index 66ed3eb..067ff12 100644 --- a/src/providers/base_provider.py +++ b/src/integrations/providers/base.py @@ -1,15 +1,15 @@ """ -Base AI Provider interface. +Base Provider interface. -This module defines the abstract base class that all AI providers must implement. +This module defines the abstract base class that all providers must implement. 
""" from abc import ABC, abstractmethod from typing import Any -class BaseAIProvider(ABC): - """Base class for AI providers.""" +class BaseProvider(ABC): + """Base class for providers.""" def __init__(self, model: str, max_tokens: int = 4096, temperature: float = 0.1, **kwargs): self.model = model @@ -40,3 +40,7 @@ def get_model_info(self) -> dict[str, Any]: "max_tokens": self.max_tokens, "temperature": self.temperature, } + + +# Alias for backward compatibility +BaseAIProvider = BaseProvider diff --git a/src/providers/bedrock_provider.py b/src/integrations/providers/bedrock_provider.py similarity index 98% rename from src/providers/bedrock_provider.py rename to src/integrations/providers/bedrock_provider.py index 1a12bf6..f85f6a2 100644 --- a/src/providers/bedrock_provider.py +++ b/src/integrations/providers/bedrock_provider.py @@ -16,11 +16,11 @@ from langchain_core.outputs import ChatGeneration, ChatResult from src.core.config import config -from src.providers.base_provider import BaseAIProvider +from src.integrations.providers.base import BaseProvider -class BedrockProvider(BaseAIProvider): - """AWS Bedrock AI Provider with hybrid client support.""" +class BedrockProvider(BaseProvider): + """AWS Bedrock Provider with hybrid client support.""" def get_chat_model(self) -> Any: """Get Bedrock chat model using appropriate client.""" diff --git a/src/providers/factory.py b/src/integrations/providers/factory.py similarity index 81% rename from src/providers/factory.py rename to src/integrations/providers/factory.py index ebb745e..da077b4 100644 --- a/src/providers/factory.py +++ b/src/integrations/providers/factory.py @@ -1,20 +1,20 @@ """ -AI Provider Factory. +Provider Factory. This module provides factory functions to create the appropriate -AI provider based on configuration using a simple mapping approach. +provider based on configuration using a simple mapping approach. """ from typing import Any from src.core.config import config -from src.providers.base_provider import BaseAIProvider -from src.providers.bedrock_provider import BedrockProvider -from src.providers.openai_provider import OpenAIProvider -from src.providers.vertex_ai_provider import VertexAIProvider +from src.integrations.providers.base import BaseProvider +from src.integrations.providers.bedrock_provider import BedrockProvider +from src.integrations.providers.openai_provider import OpenAIProvider +from src.integrations.providers.vertex_ai_provider import VertexAIProvider # Provider mapping - canonical names to provider classes -PROVIDER_MAP: dict[str, type[BaseAIProvider]] = { +PROVIDER_MAP: dict[str, type[BaseProvider]] = { "openai": OpenAIProvider, "bedrock": BedrockProvider, "vertex_ai": VertexAIProvider, @@ -34,12 +34,12 @@ def get_provider( temperature: float | None = None, agent: str | None = None, **kwargs: Any, -) -> BaseAIProvider: +) -> BaseProvider: """ - Get the appropriate AI provider based on configuration. + Get the appropriate provider based on configuration. 
Args: - provider: AI provider name (openai, bedrock, vertex_ai) + provider: Provider name (openai, bedrock, vertex_ai) model: Model name/ID max_tokens: Maximum tokens to generate temperature: Sampling temperature @@ -47,7 +47,7 @@ def get_provider( **kwargs: Additional provider-specific parameters Returns: - Configured AI provider instance + Configured provider instance Raises: ValueError: If provider is not supported @@ -62,7 +62,7 @@ def get_provider( supported = ", ".join( sorted(set(PROVIDER_MAP.keys()) - {"garden", "model_garden", "gcp", "vertex", "vertexai"}) ) - raise ValueError(f"Unsupported AI provider: {provider_name}. Supported: {supported}") + raise ValueError(f"Unsupported provider: {provider_name}. Supported: {supported}") # Get model with fallbacks handled by config if not model: @@ -113,7 +113,7 @@ def get_chat_model( This is a convenience function that creates a provider and returns its chat model. Args: - provider: AI provider name (openai, bedrock, vertex_ai) + provider: Provider name (openai, bedrock, vertex_ai) model: Model name/ID max_tokens: Maximum tokens to generate temperature: Sampling temperature @@ -133,3 +133,7 @@ def get_chat_model( ) return provider_instance.get_chat_model() + + +# Backward compatibility aliases +BaseAIProvider = BaseProvider diff --git a/src/providers/openai_provider.py b/src/integrations/providers/openai_provider.py similarity index 85% rename from src/providers/openai_provider.py rename to src/integrations/providers/openai_provider.py index 369d49a..13624bf 100644 --- a/src/providers/openai_provider.py +++ b/src/integrations/providers/openai_provider.py @@ -1,16 +1,16 @@ """ -OpenAI AI Provider implementation. +OpenAI Provider implementation. This provider handles OpenAI API integration directly. """ from typing import Any -from src.providers.base_provider import BaseAIProvider +from src.integrations.providers.base import BaseProvider -class OpenAIProvider(BaseAIProvider): - """OpenAI AI Provider.""" +class OpenAIProvider(BaseProvider): + """OpenAI Provider.""" def get_chat_model(self) -> Any: """Get OpenAI chat model.""" diff --git a/src/providers/vertex_ai_provider.py b/src/integrations/providers/vertex_ai_provider.py similarity index 98% rename from src/providers/vertex_ai_provider.py rename to src/integrations/providers/vertex_ai_provider.py index 77f6fd1..74e14b0 100644 --- a/src/providers/vertex_ai_provider.py +++ b/src/integrations/providers/vertex_ai_provider.py @@ -14,10 +14,10 @@ from typing import Any from src.core.config import config -from src.providers.base_provider import BaseAIProvider +from src.integrations.providers.base import BaseProvider -class VertexAIProvider(BaseAIProvider): +class VertexAIProvider(BaseProvider): """Google Vertex AI Provider (Model Garden).""" def get_chat_model(self) -> Any: diff --git a/src/rules/loaders/__init__.py b/src/rules/loaders/__init__.py new file mode 100644 index 0000000..af44edf --- /dev/null +++ b/src/rules/loaders/__init__.py @@ -0,0 +1,18 @@ +""" +Rule loaders package. + +This package contains implementations of the RuleLoader interface +for loading rules from different sources (GitHub, database, etc.). 
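+
+A module-level instance, github_rule_loader, is exported for callers that do
+not need to construct their own GitHubClient.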
+""" + +from src.rules.loaders.github_loader import ( + GitHubRuleLoader, + RulesFileNotFoundError, + github_rule_loader, +) + +__all__ = [ + "GitHubRuleLoader", + "RulesFileNotFoundError", + "github_rule_loader", +] diff --git a/src/rules/github_provider.py b/src/rules/loaders/github_loader.py similarity index 87% rename from src/rules/github_provider.py rename to src/rules/loaders/github_loader.py index d9910f7..7782100 100644 --- a/src/rules/github_provider.py +++ b/src/rules/loaders/github_loader.py @@ -1,3 +1,9 @@ +""" +GitHub-based rule loader. + +Loads rules from GitHub repository files, implementing the RuleLoader interface. +""" + import logging from typing import Any @@ -5,7 +11,7 @@ from src.core.config import config from src.core.models import EventType -from src.integrations.github_api import GitHubClient, github_client +from src.integrations.github import GitHubClient, github_client from src.rules.interface import RuleLoader from src.rules.models import Rule, RuleAction, RuleSeverity @@ -21,11 +27,11 @@ class RulesFileNotFoundError(Exception): class GitHubRuleLoader(RuleLoader): """ Loads rules from a GitHub repository's rules yaml file. - This provider does NOT map parameters to condition types; it loads rules as-is. + This loader does NOT map parameters to condition types; it loads rules as-is. """ - def __init__(self, github_client: GitHubClient): - self.github_client = github_client + def __init__(self, client: GitHubClient): + self.github_client = client async def get_rules(self, repository: str, installation_id: int) -> list[Rule]: try: @@ -46,7 +52,7 @@ async def get_rules(self, repository: str, installation_id: int) -> list[Rule]: rules = [] for rule_data in rules_data["rules"]: try: - rule = self._parse_rule(rule_data) + rule = GitHubRuleLoader._parse_rule(rule_data) if rule: rules.append(rule) except Exception as e: @@ -63,7 +69,8 @@ async def get_rules(self, repository: str, installation_id: int) -> list[Rule]: logger.error(f"Error fetching rules for {repository}: {e}") raise - def _parse_rule(self, rule_data: dict[str, Any]) -> Rule: + @staticmethod + def _parse_rule(rule_data: dict[str, Any]) -> Rule: # Validate required fields if "description" not in rule_data: raise ValueError("Rule must have 'description' field") diff --git a/src/rules/utils.py b/src/rules/utils.py index afd1017..3f8dc83 100644 --- a/src/rules/utils.py +++ b/src/rules/utils.py @@ -3,7 +3,7 @@ import yaml -from src.integrations.github_api import github_client +from src.integrations.github import github_client from src.rules.models import Rule logger = logging.getLogger(__name__) diff --git a/src/rules/utils/__init__.py b/src/rules/utils/__init__.py new file mode 100644 index 0000000..ed61d4a --- /dev/null +++ b/src/rules/utils/__init__.py @@ -0,0 +1,35 @@ +""" +Rule evaluation utilities. + +This package contains utilities used by rule validators, including +CODEOWNERS parsing, contributor analysis, and rule validation. 
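+
+get_codeowners is kept as a backward-compatible alias of load_codeowners.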
+""" + +from src.rules.utils.codeowners import ( + get_file_owners, + is_critical_file, + load_codeowners, +) +from src.rules.utils.contributors import ( + get_contributor_analyzer, + get_past_contributors, + is_new_contributor, +) +from src.rules.utils.validation import ( + _validate_rules_yaml, + validate_rules_yaml_from_repo, +) + +__all__ = [ + "get_file_owners", + "is_critical_file", + "load_codeowners", + "get_contributor_analyzer", + "get_past_contributors", + "is_new_contributor", + "_validate_rules_yaml", + "validate_rules_yaml_from_repo", +] + +# Alias for backward compatibility +get_codeowners = load_codeowners diff --git a/src/integrations/codeowners.py b/src/rules/utils/codeowners.py similarity index 95% rename from src/integrations/codeowners.py rename to src/rules/utils/codeowners.py index 567673a..374c16a 100644 --- a/src/integrations/codeowners.py +++ b/src/rules/utils/codeowners.py @@ -1,5 +1,8 @@ """ -Utilities for parsing and using CODEOWNERS files. +Rule evaluation utilities for parsing and using CODEOWNERS files. + +These utilities are used by rule validators to check code ownership +requirements and determine critical file patterns. """ import logging @@ -90,7 +93,7 @@ def _matches_pattern(self, file_path: str, pattern: str) -> bool: return True # Convert pattern to regex - regex_pattern = self._pattern_to_regex(pattern) + regex_pattern = CodeOwnersParser._pattern_to_regex(pattern) try: return bool(re.match(regex_pattern, file_path)) @@ -98,7 +101,8 @@ def _matches_pattern(self, file_path: str, pattern: str) -> bool: logger.error(f"Invalid regex pattern: {regex_pattern}") return False - def _pattern_to_regex(self, pattern: str) -> str: + @staticmethod + def _pattern_to_regex(pattern: str) -> str: """ Convert a CODEOWNERS pattern to a regex pattern. diff --git a/src/integrations/contributors.py b/src/rules/utils/contributors.py similarity index 93% rename from src/integrations/contributors.py rename to src/rules/utils/contributors.py index e936599..e230b52 100644 --- a/src/integrations/contributors.py +++ b/src/rules/utils/contributors.py @@ -1,10 +1,15 @@ """ -Utilities for analyzing repository contributors and determining contribution history. +Rule evaluation utilities for analyzing repository contributors. + +These utilities are used by rule validators to check contributor history +and determine if users are new or established contributors. 
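+
+Contributor lookups are cached in-process via AsyncCache (maxsize=100,
+ttl=3600, i.e. one hour).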
""" import logging from datetime import datetime, timedelta +from src.core.utils.caching import AsyncCache + logger = logging.getLogger(__name__) @@ -13,8 +18,8 @@ class ContributorAnalyzer: def __init__(self, github_client): self.github_client = github_client - self._contributors_cache: dict[str, dict] = {} - self._cache_ttl = 3600 # 1 hour cache + # Use AsyncCache for better cache management + self._contributors_cache = AsyncCache(maxsize=100, ttl=3600) # 1 hour cache async def get_past_contributors( self, repo: str, installation_id: int, min_contributions: int = 5, days_back: int = 365 @@ -34,11 +39,10 @@ async def get_past_contributors( cache_key = f"{repo}_{min_contributions}_{days_back}" # Check cache first - if cache_key in self._contributors_cache: - cached_data = self._contributors_cache[cache_key] - if datetime.now().timestamp() - cached_data.get("timestamp", 0) < self._cache_ttl: - logger.debug(f"Using cached past contributors for {repo}") - return set(cached_data.get("contributors", [])) + cached_value = self._contributors_cache.get(cache_key) + if cached_value is not None: + logger.debug(f"Using cached past contributors for {repo}") + return set(cached_value) try: logger.info(f"Fetching past contributors for {repo}") @@ -60,10 +64,7 @@ async def get_past_contributors( past_contributors.add(username) # Cache the results - self._contributors_cache[cache_key] = { - "contributors": list(past_contributors), - "timestamp": datetime.now().timestamp(), - } + self._contributors_cache.set(cache_key, list(past_contributors)) logger.info(f"Found {len(past_contributors)} past contributors for {repo}") return past_contributors diff --git a/src/rules/utils/validation.py b/src/rules/utils/validation.py new file mode 100644 index 0000000..7f28ef1 --- /dev/null +++ b/src/rules/utils/validation.py @@ -0,0 +1,119 @@ +""" +Rule validation utilities. + +Functions for validating rule YAML files and posting validation results. +""" + +import logging +from typing import Any + +import yaml + +from src.integrations.github import github_client +from src.rules.models import Rule + +logger = logging.getLogger(__name__) + +DOCS_URL = "https://github.com/warestack/watchflow/blob/main/docs/getting-started/configuration.md" + + +async def validate_rules_yaml_from_repo(repo_full_name: str, installation_id: int, pr_number: int): + """Validate rules YAML and post results to PR comment.""" + validation_result = await _validate_rules_yaml(repo_full_name, installation_id) + # Only post a comment if the result is not a success + if not validation_result["success"]: + await github_client.create_pull_request_comment( + repo=repo_full_name, + pr_number=pr_number, + comment=validation_result["message"], + installation_id=installation_id, + ) + logger.info(f"Posted validation result to PR #{pr_number} in {repo_full_name}") + + +async def _validate_rules_yaml(repo: str, installation_id: int) -> dict[str, Any]: + """Validate rules YAML file from repository.""" + try: + file_content = await github_client.get_file_content(repo, ".watchflow/rules.yaml", installation_id) + if file_content is None: + return { + "success": False, + "message": ( + "⚙️ **Watchflow rules not configured**\n\n" + "No rules file found in your repository. Watchflow can help enforce governance rules for your team.\n\n" + "**How to set up rules:**\n" + "1. Create a file at `.watchflow/rules.yaml` in your repository root\n" + "2. 
Add your rules in the following format:\n" + " ```yaml\n rules:\n - description: All pull requests must have at least 2 approvals\n enabled: true\n severity: high\n event_types: [pull_request]\n parameters:\n min_approvals: 2\n ```\n\n" + "**Note:** Rules are currently read from the main branch only.\n\n" + "📖 [Read the documentation for more examples](https://github.com/warestack/watchflow/blob/main/docs/getting-started/configuration.md)\n\n" + "After adding the file, push your changes to re-run validation." + ), + } + try: + rules_data = yaml.safe_load(file_content) + except Exception as e: + return { + "success": False, + "message": ( + "❌ **Failed to parse `.watchflow/rules.yaml`**\n\n" + f"Error details: `{e}`\n\n" + "**How to fix:**\n" + "- Ensure your YAML is valid. You can use an online YAML validator.\n" + "- Check for indentation, missing colons, or invalid syntax.\n\n" + f"[See configuration docs.]({DOCS_URL})" + ), + } + if not isinstance(rules_data, dict) or "rules" not in rules_data: + return { + "success": False, + "message": ( + "❌ **Invalid `.watchflow/rules.yaml`: missing top-level `rules:` key**\n\n" + "Your file must start with a `rules:` key, like:\n" + "```yaml\nrules:\n - description: ...\n```\n" + f"[See configuration docs.]({DOCS_URL})" + ), + } + if not isinstance(rules_data["rules"], list): + return { + "success": False, + "message": ( + "❌ **Invalid `.watchflow/rules.yaml`: `rules` must be a list**\n\n" + "Example:\n" + "```yaml\nrules:\n - description: ...\n```\n" + f"[See configuration docs.]({DOCS_URL})" + ), + } + if not rules_data["rules"]: + return { + "success": True, + "message": ( + "✅ **`.watchflow/rules.yaml` is valid but contains no rules.**\n\n" + "You can add rules at any time. [See documentation for examples.]" + f"({DOCS_URL})" + ), + } + for i, rule_data in enumerate(rules_data["rules"]): + try: + Rule.model_validate(rule_data) + except Exception as e: + return { + "success": False, + "message": ( + f"❌ **Rule #{i + 1} failed validation**\n\n" + f"Error: `{e}`\n\n" + "Please check your rule definition and fix the error above.\n\n" + f"[See rule schema docs.]({DOCS_URL})" + ), + } + return { + "success": True, + "message": f"✅ **`.watchflow/rules.yaml` is valid and contains {len(rules_data['rules'])} rules.**\n\nNo action needed.", + } + except Exception as e: + return { + "success": False, + "message": ( + f"❌ **Error validating `.watchflow/rules.yaml`**\n\nError: `{e}`\n\n[See configuration docs.]({DOCS_URL})" + ), + } diff --git a/src/rules/validators.py b/src/rules/validators.py index b04ac19..33f7ca1 100644 --- a/src/rules/validators.py +++ b/src/rules/validators.py @@ -101,7 +101,7 @@ async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> b return False # Convert glob pattern to regex - regex_pattern = self._glob_to_regex(pattern) + regex_pattern = FilePatternCondition._glob_to_regex(pattern) # Check if any files match the pattern matching_files = [file for file in changed_files if re.match(regex_pattern, file)] @@ -128,7 +128,8 @@ def _get_changed_files(self, event: dict[str, Any]) -> list[str]: else: return [] - def _glob_to_regex(self, glob_pattern: str) -> str: + @staticmethod + def _glob_to_regex(glob_pattern: str) -> str: """Converts a glob pattern to a regex pattern.""" # Simple conversion - in production, you'd want a more robust implementation regex = glob_pattern.replace(".", "\\.").replace("*", ".*").replace("?", ".") @@ -635,7 +636,7 @@ async def validate(self, parameters: dict[str, Any], event: dict[str, 
Any]) -> b return True # Check if any of the changed files require code owner review - from src.integrations.codeowners import is_critical_file + from src.rules.utils.codeowners import is_critical_file # Get critical owners from rule parameters or use default behavior critical_owners = parameters.get("critical_owners") @@ -699,7 +700,7 @@ async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> b return False # Check if author is a new contributor using the contributor analyzer - from src.integrations.contributors import is_new_contributor + from src.rules.utils.contributors import is_new_contributor is_author_new = await is_new_contributor(author_login, repo, github_client, installation_id) diff --git a/src/tasks/scheduler/deployment_scheduler.py b/src/tasks/scheduler/deployment_scheduler.py index e0aaefe..1b75f5e 100644 --- a/src/tasks/scheduler/deployment_scheduler.py +++ b/src/tasks/scheduler/deployment_scheduler.py @@ -3,8 +3,8 @@ from datetime import datetime, timedelta from typing import Any -from src.agents.engine_agent.agent import RuleEngineAgent -from src.integrations.github_api import github_client +from src.agents import get_agent +from src.integrations.github import github_client logger = logging.getLogger(__name__) @@ -20,10 +20,10 @@ def __init__(self): self._engine_agent = None @property - def engine_agent(self) -> RuleEngineAgent: + def engine_agent(self): """Lazy-load the engine agent to avoid API key validation at import time.""" if self._engine_agent is None: - self._engine_agent = RuleEngineAgent() + self._engine_agent = get_agent("engine") return self._engine_agent async def start(self): @@ -188,7 +188,7 @@ async def _re_evaluate_deployment(self, deployment: dict[str, Any]) -> bool: return False # Convert rules to the format expected by the analysis agent - formatted_rules = self._convert_rules_to_new_format(deployment["rules"]) + formatted_rules = DeploymentScheduler._convert_rules_to_new_format(deployment["rules"]) # Re-run rule analysis result = await self.engine_agent.execute( @@ -321,7 +321,13 @@ async def start_background_scheduler(self): if not self.running: await self.start() - def _convert_rules_to_new_format(self, rules: list[Any]) -> list[dict[str, Any]]: + async def stop_background_scheduler(self): + """Stop the background scheduler task.""" + if self.running: + await self.stop() + + @staticmethod + def _convert_rules_to_new_format(rules: list[Any]) -> list[dict[str, Any]]: """Convert Rule objects to the new flat schema format.""" formatted_rules = [] @@ -350,11 +356,6 @@ def _convert_rules_to_new_format(self, rules: list[Any]) -> list[dict[str, Any]] return formatted_rules - async def stop_background_scheduler(self): - """Stop the background scheduler task.""" - if self.running: - await self.stop() - # Global instance - lazy loaded to avoid API key validation at import time deployment_scheduler = None diff --git a/src/webhooks/handlers/issue_comment.py b/src/webhooks/handlers/issue_comment.py index fe1f506..5e2b9aa 100644 --- a/src/webhooks/handlers/issue_comment.py +++ b/src/webhooks/handlers/issue_comment.py @@ -2,9 +2,9 @@ import re from typing import Any -from src.agents.feasibility_agent.agent import RuleFeasibilityAgent +from src.agents import get_agent from src.core.models import EventType, WebhookEvent -from src.integrations.github_api import github_client +from src.integrations.github import github_client from src.rules.utils import _validate_rules_yaml from src.tasks.task_queue import task_queue from 
src.webhooks.handlers.base import EventHandler @@ -84,7 +84,7 @@ async def handle(self, event: WebhookEvent) -> dict[str, Any]: # Check if this is an evaluate command eval_rule = self._extract_evaluate_rule(comment_body) if eval_rule is not None: - agent = RuleFeasibilityAgent() + agent = get_agent("feasibility") result = await agent.execute(rule_description=eval_rule) is_feasible = result.data.get("is_feasible", False) yaml_content = result.data.get("yaml_content", "") diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index c51efe3..9a1a049 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -9,8 +9,6 @@ from src.agents.base import AgentResult from src.agents.feasibility_agent import RuleFeasibilityAgent -from src.agents.supervisor_agent import RuleSupervisorAgent -from src.agents.supervisor_agent.models import AgentTask, SupervisorAgentResult, SupervisorState class TestBaseAgent: @@ -221,188 +219,5 @@ async def test_execute_with_retry_failure_then_success(self, mock_init): assert mock_execute.call_count == 2 -class TestSupervisorAgent: - """Test supervisor agent functionality.""" - - @patch("src.agents.base.BaseAgent.__init__") - def test_supervisor_agent_initialization(self, mock_init): - """Test supervisor agent initialization.""" - agent = RuleSupervisorAgent(max_concurrent_agents=5) - # Manually set the attributes since we mocked __init__ - agent.max_concurrent_agents = 5 - assert agent.max_concurrent_agents == 5 - assert len(agent.sub_agents) == 3 # feasibility, engine, acknowledgment - assert "feasibility" in agent.sub_agents - assert "engine" in agent.sub_agents - assert "acknowledgment" in agent.sub_agents - - @pytest.mark.asyncio - @patch("src.agents.base.BaseAgent.__init__") - async def test_supervisor_execute_agent_task_success(self, mock_init): - """Test successful execution of agent task.""" - agent = RuleSupervisorAgent() - # Manually set the attributes since we mocked __init__ - agent.max_concurrent_agents = 3 - - task = AgentTask( - agent_name="feasibility", - task_type="rule_feasibility", - parameters={"rule_description": "Test rule"}, - timeout=30.0, - ) - - # Mock sub-agent execution - with patch.object(agent.sub_agents["feasibility"], "execute") as mock_execute: - mock_execute.return_value = AgentResult( - success=True, message="Success", data={"is_feasible": True}, metadata={"execution_time_ms": 1500} - ) - - result = await agent._execute_agent_task(task) - - assert result.success is True - assert result.message == "Success" - assert result.data["is_feasible"] is True - assert result.metadata["execution_time_ms"] == 1500 - - @pytest.mark.asyncio - @patch("src.agents.base.BaseAgent.__init__") - async def test_supervisor_execute_agent_task_timeout(self, mock_init): - """Test agent task execution with timeout.""" - agent = RuleSupervisorAgent() - # Manually set the attributes since we mocked __init__ - agent.max_concurrent_agents = 3 - - task = AgentTask( - agent_name="feasibility", - task_type="rule_feasibility", - parameters={"rule_description": "Test rule"}, - timeout=0.1, # Very short timeout - ) - - # Mock sub-agent that takes too long - with patch.object(agent.sub_agents["feasibility"], "execute") as mock_execute: - - async def slow_execute(*args, **kwargs): - await asyncio.sleep(1.0) # Longer than timeout - return AgentResult(success=True, message="Success", data={}) - - mock_execute.side_effect = slow_execute - - result = await agent._execute_agent_task(task) - - assert result.success is False - assert "timed out" in 
result.message - assert result.metadata["error_type"] == "timeout" - - @pytest.mark.asyncio - @patch("src.agents.base.BaseAgent.__init__") - async def test_supervisor_coordinate_agents(self, mock_init): - """Test agent coordination.""" - agent = RuleSupervisorAgent() - # Manually set the attributes since we mocked __init__ - agent.max_concurrent_agents = 3 - - result = await agent.coordinate_agents( - "Evaluate rules", - event_type="pull_request", - event_data={"action": "opened"}, - rules=[{"id": "test", "name": "Test Rule"}], - ) - - # Should return a result (even if mock) - assert isinstance(result, AgentResult) - - def test_supervisor_state_creation(self): - """Test supervisor state creation.""" - state = SupervisorState( - task_description="Test task", event_type="pull_request", event_data={"test": "data"}, rules=[{"id": "test"}] - ) - - assert state.task_description == "Test task" - assert state.event_type == "pull_request" - assert len(state.rules) == 1 - assert state.start_time is None # start_time is optional and defaults to None - - -class TestAgentCoordination: - """Test agent coordination patterns.""" - - @pytest.mark.asyncio - async def test_concurrent_agent_execution(self): - """Test concurrent execution of multiple agents.""" - from src.agents.supervisor_agent.nodes import _execute_tasks_concurrently - - tasks = [ - AgentTask(agent_name="agent1", task_type="test", parameters={}), - AgentTask(agent_name="agent2", task_type="test", parameters={}), - AgentTask(agent_name="agent3", task_type="test", parameters={}), - ] - - results = await _execute_tasks_concurrently(tasks, max_concurrent=2) - - assert len(results) == 3 - # All should be mock results for now - assert all(r.success for r in results) - - def test_result_conflict_detection(self): - """Test detection of conflicting results between agents.""" - from src.agents.supervisor_agent.nodes import _detect_result_conflicts - - results = [ - SupervisorAgentResult( - success=True, - message="Agent1 approved", - data={"recommendation": "approve"}, - metadata={"agent_name": "agent1", "execution_time_ms": 100}, - ), - SupervisorAgentResult( - success=True, - message="Agent2 rejected", - data={"recommendation": "reject"}, - metadata={"agent_name": "agent2", "execution_time_ms": 100}, - ), - ] - - conflicts = _detect_result_conflicts(results) - assert len(conflicts) > 0 - assert "Conflicting recommendations" in conflicts[0] - - def test_confidence_score_calculation(self): - """Test confidence score calculation.""" - from src.agents.supervisor_agent.nodes import _calculate_confidence_score - - # All successful results - results = [ - SupervisorAgentResult( - success=True, - message="Agent1 success", - data={}, - metadata={"agent_name": "agent1", "execution_time_ms": 200}, - ), - SupervisorAgentResult( - success=True, - message="Agent2 success", - data={}, - metadata={"agent_name": "agent2", "execution_time_ms": 300}, - ), - ] - - confidence = _calculate_confidence_score(results) - assert confidence > 0.8 # High confidence for all successful results - - # Mixed results - results.append( - SupervisorAgentResult( - success=False, - message="Agent3 failed", - data={}, - metadata={"agent_name": "agent3", "execution_time_ms": 100}, - ) - ) - - confidence = _calculate_confidence_score(results) - assert confidence < 0.8 # Lower confidence with failures - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/unit/test_rule_engine_agent.py b/tests/unit/test_rule_engine_agent.py index d4cd66d..4a1d6cf 100644 --- 
a/tests/unit/test_rule_engine_agent.py +++ b/tests/unit/test_rule_engine_agent.py @@ -235,7 +235,7 @@ async def test_execute_with_timeout_error(self, mock_init): assert result.success is False assert "timed out" in result.message - assert result.metadata["error_type"] == "Exception" + assert result.metadata["error_type"] == "TimeoutError" @pytest.mark.asyncio @patch("src.agents.base.BaseAgent.__init__") From e0fb6c40d34a2b4a60b3ccca50c5c84dc17be415 Mon Sep 17 00:00:00 2001 From: Dimitris Kargatzis Date: Sat, 15 Nov 2025 14:07:43 +0200 Subject: [PATCH 2/2] fix: update integration tests to use agent factory pattern Signed-off-by: Dimitris Kargatzis --- tests/integration/test_rules_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_rules_api.py b/tests/integration/test_rules_api.py index 94adacd..456fb1f 100644 --- a/tests/integration/test_rules_api.py +++ b/tests/integration/test_rules_api.py @@ -26,10 +26,10 @@ def test_evaluate_feasible_rule_integration(self, client): """Test successful rule evaluation through the complete stack (mocked OpenAI).""" # Mock OpenAI unless real API testing is explicitly enabled if not os.getenv("INTEGRATION_TEST_REAL_API", "false").lower() == "true": - with patch("src.api.rules.RuleFeasibilityAgent") as mock_agent_class: + with patch("src.api.rules.get_agent") as mock_get_agent: # Mock the agent instance mock_agent = MagicMock() - mock_agent_class.return_value = mock_agent + mock_get_agent.return_value = mock_agent # Mock the execute method as async mock_result = AgentResult( @@ -69,10 +69,10 @@ def test_evaluate_unfeasible_rule_integration(self, client): """Test unfeasible rule evaluation through the complete stack (mocked OpenAI).""" # Mock OpenAI unless real API testing is explicitly enabled if not os.getenv("INTEGRATION_TEST_REAL_API", "false").lower() == "true": - with patch("src.api.rules.RuleFeasibilityAgent") as mock_agent_class: + with patch("src.api.rules.get_agent") as mock_get_agent: # Mock the agent instance mock_agent = MagicMock() - mock_agent_class.return_value = mock_agent + mock_get_agent.return_value = mock_agent # Mock the execute method as async mock_result = AgentResult(
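
The diffs above replace direct agent instantiation (`RuleFeasibilityAgent()`, `RuleEngineAgent()`) with a `get_agent(...)` factory, but `src/agents/factory.py` itself is not shown in this excerpt. Below is a minimal sketch of what such a registry-based factory could look like, assuming it simply maps the names seen at the call sites ("feasibility", "engine", "acknowledgment") to agent constructors; the `AcknowledgmentAgent` class name and the stub base class are placeholders so the sketch runs standalone, not the repository's actual code:

```python
from typing import Callable


class BaseAgent:
    """Stand-in for src.agents.base.BaseAgent (stubbed so the sketch runs alone)."""

    async def execute(self, **kwargs): ...


class RuleFeasibilityAgent(BaseAgent): ...  # real class: src/agents/feasibility_agent
class RuleEngineAgent(BaseAgent): ...       # real class: src/agents/engine_agent
class AcknowledgmentAgent(BaseAgent): ...   # class name assumed; see src/agents/acknowledgment_agent

# Registry of constructors, keyed by the names used at the call sites in the diff.
_REGISTRY: dict[str, Callable[[], BaseAgent]] = {
    "feasibility": RuleFeasibilityAgent,
    "engine": RuleEngineAgent,
    "acknowledgment": AcknowledgmentAgent,
}


def get_agent(name: str) -> BaseAgent:
    """Build an agent on demand, keeping API-key validation out of import time."""
    try:
        return _REGISTRY[name]()
    except KeyError:
        raise ValueError(f"Unknown agent {name!r}; expected one of {sorted(_REGISTRY)}") from None
```

A factory of this shape also explains the test changes in PATCH 2/2: patching `src.api.rules.get_agent` swaps a single function instead of one class per agent, and lazy properties such as `DeploymentScheduler.engine_agent` can defer construction until first use.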
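
Similarly, the `contributors.py` hunk swaps a hand-rolled timestamp dict for `AsyncCache(maxsize=100, ttl=3600)` with synchronous `get`/`set` calls, but `src/core/utils/caching.py` is not included in this excerpt. A minimal TTL-plus-LRU sketch consistent with that usage follows; the real class may add locking, metrics, or async-aware eviction:

```python
import time
from collections import OrderedDict
from typing import Any


class AsyncCache:
    """Smallest cache matching the calls in the diff: get() returns None on a
    miss or after expiry; set() stores a value under the configured TTL."""

    def __init__(self, maxsize: int = 100, ttl: int = 3600) -> None:
        self.maxsize = maxsize
        self.ttl = ttl  # seconds
        self._store: OrderedDict[str, tuple[float, Any]] = OrderedDict()

    def get(self, key: str) -> Any | None:
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() >= expires_at:
            del self._store[key]  # expired: drop the entry and report a miss
            return None
        self._store.move_to_end(key)  # refresh LRU position on a hit
        return value

    def set(self, key: str, value: Any) -> None:
        self._store[key] = (time.monotonic() + self.ttl, value)
        self._store.move_to_end(key)
        while len(self._store) > self.maxsize:
            self._store.popitem(last=False)  # evict the least-recently-used entry
```

Note that `get_past_contributors` caches `list(past_contributors)` and rebuilds the `set` on read, so the cached value stays a plain, serializable list regardless of the cache implementation.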