diff --git a/src/predicate_secure/__init__.py b/src/predicate_secure/__init__.py index a476c27..0c81e6a 100644 --- a/src/predicate_secure/__init__.py +++ b/src/predicate_secure/__init__.py @@ -21,13 +21,28 @@ from __future__ import annotations +from pathlib import Path +from typing import Any + +from .config import SecureAgentConfig, WrappedAgent +from .detection import ( + DetectionResult, + Framework, + FrameworkDetector, + UnsupportedFrameworkError, +) + __version__ = "0.1.0" -# Public API - will be implemented in subsequent phases +# Public API __all__ = [ "SecureAgent", "SecureAgentConfig", - "PolicyLoader", + "WrappedAgent", + # Framework detection + "Framework", + "FrameworkDetector", + "DetectionResult", # Modes "MODE_STRICT", "MODE_PERMISSIVE", @@ -37,6 +52,7 @@ "AuthorizationDenied", "VerificationFailed", "PolicyLoadError", + "UnsupportedFrameworkError", ] # Mode constants @@ -49,13 +65,17 @@ class AuthorizationDenied(Exception): """Raised when an action is denied by the policy engine.""" - pass + def __init__(self, message: str, decision: Any = None): + super().__init__(message) + self.decision = decision class VerificationFailed(Exception): """Raised when post-execution verification fails.""" - pass + def __init__(self, message: str, predicate: str | None = None): + super().__init__(message) + self.predicate = predicate class PolicyLoadError(Exception): @@ -64,19 +84,6 @@ class PolicyLoadError(Exception): pass -# Placeholder classes - to be implemented -class SecureAgentConfig: - """Configuration for SecureAgent.""" - - pass - - -class PolicyLoader: - """Loads and validates policy files.""" - - pass - - class SecureAgent: """ Drop-in security wrapper for AI agents. @@ -94,34 +101,192 @@ class SecureAgent: mode="strict", ) secure_agent.run() + + Attributes: + config: SecureAgentConfig with all configuration + wrapped: WrappedAgent with detected framework info + authority_context: Initialized AuthorityClient context (lazy) """ def __init__( self, - agent, - policy: str | dict | None = None, + agent: Any, + policy: str | Path | None = None, mode: str = MODE_STRICT, principal_id: str | None = None, + tenant_id: str | None = None, + session_id: str | None = None, sidecar_url: str | None = None, + signing_key: str | None = None, + mandate_ttl_seconds: int = 300, ): """ Initialize SecureAgent wrapper. Args: agent: The agent to wrap (browser-use Agent, Playwright page, etc.) - policy: Policy file path, dict, or None for default + policy: Policy file path or None for env var fallback mode: Execution mode (strict, permissive, debug, audit) principal_id: Agent principal ID (auto-detect from env if not provided) + tenant_id: Tenant ID for multi-tenant deployments + session_id: Session ID for tracking sidecar_url: Sidecar URL (None for embedded mode) + signing_key: Secret key for mandate signing + mandate_ttl_seconds: TTL for issued mandates """ + # Build config from kwargs + self._config = SecureAgentConfig.from_kwargs( + policy=policy, + mode=mode, + principal_id=principal_id, + tenant_id=tenant_id, + session_id=session_id, + sidecar_url=sidecar_url, + signing_key=signing_key, + mandate_ttl_seconds=mandate_ttl_seconds, + ) + + # Detect framework and wrap agent + self._wrapped = self._wrap_agent(agent) + + # Lazy-initialized authority context + self._authority_context: Any = None + + # Legacy attribute access (for backward compat with tests) self._agent = agent self._policy = policy self._mode = mode self._principal_id = principal_id self._sidecar_url = sidecar_url - # TODO: Wire up AgentRuntime, AuthorityClient, RuntimeAgent - def run(self, task: str | None = None): + @property + def config(self) -> SecureAgentConfig: + """Get the configuration.""" + return self._config + + @property + def wrapped(self) -> WrappedAgent: + """Get the wrapped agent with framework info.""" + return self._wrapped + + @property + def framework(self) -> Framework: + """Get the detected framework.""" + return Framework(self._wrapped.framework) + + def _wrap_agent(self, agent: Any) -> WrappedAgent: + """ + Detect framework and wrap agent. + + Args: + agent: The agent to wrap + + Returns: + WrappedAgent with framework info + """ + detection = FrameworkDetector.detect(agent) + + return WrappedAgent( + original=agent, + framework=detection.framework.value, + agent_runtime=None, # Initialized lazily when needed + executor=self._extract_executor(agent, detection), + metadata=detection.metadata, + ) + + def _extract_executor(self, agent: Any, detection: DetectionResult) -> Any | None: + """ + Extract LLM executor from agent if available. + + Args: + agent: The agent object + detection: Detection result + + Returns: + LLM executor or None + """ + # browser-use Agent has .llm attribute + if detection.framework == Framework.BROWSER_USE: + return getattr(agent, "llm", None) + + # LangChain AgentExecutor has .agent or .llm + if detection.framework == Framework.LANGCHAIN: + return getattr(agent, "llm", None) or getattr(agent, "agent", None) + + # PydanticAI agents have .model + if detection.framework == Framework.PYDANTIC_AI: + return getattr(agent, "model", None) + + return None + + def _get_authority_context(self) -> Any: + """ + Get or initialize the authority context. + + Returns: + LocalAuthorizationContext from AuthorityClient + + Raises: + PolicyLoadError: If policy cannot be loaded + """ + if self._authority_context is not None: + return self._authority_context + + policy_path = self._config.effective_policy_path + if policy_path is None: + # No policy = no authorization enforcement + return None + + try: + # Import here to avoid hard dependency + from predicate_authority.client import AuthorityClient + + self._authority_context = AuthorityClient.from_policy_file( + policy_file=policy_path, + secret_key=self._config.effective_signing_key, + ttl_seconds=self._config.mandate_ttl_seconds, + ) + return self._authority_context + except ImportError: + raise PolicyLoadError( + "predicate-authority is required for policy enforcement. " + "Install with: pip install predicate-secure[authority]" + ) + except FileNotFoundError as e: + raise PolicyLoadError(f"Policy file not found: {policy_path}") from e + except Exception as e: + raise PolicyLoadError(f"Failed to load policy: {e}") from e + + def _create_pre_action_authorizer(self) -> Any: + """ + Create a pre-action authorizer callback for RuntimeAgent. + + Returns: + Callable that takes ActionRequest and returns decision + """ + context = self._get_authority_context() + if context is None: + # No policy = allow all + return None + + def authorizer(request: Any) -> Any: + """Pre-action authorization callback.""" + decision = context.client.authorize(request) + + if self._config.mode == "debug": + print(f"[predicate-secure] authorize({request.action}): {decision}") + + if not decision.allowed and self._config.fail_closed: + raise AuthorizationDenied( + f"Action denied: {decision.reason.value if decision.reason else 'policy'}", + decision=decision, + ) + + return decision + + return authorizer + + def run(self, task: str | None = None) -> Any: """ Execute the agent with full authorization + verification loop. @@ -130,12 +295,91 @@ def run(self, task: str | None = None): Returns: Agent execution result + + Raises: + AuthorizationDenied: If an action is denied (in strict mode) + VerificationFailed: If post-execution verification fails + UnsupportedFrameworkError: If framework is not supported """ - # TODO: Implement the full loop - raise NotImplementedError("SecureAgent.run() not yet implemented") + if self._wrapped.framework == Framework.UNKNOWN.value: + detection = FrameworkDetector.detect(self._wrapped.original) + raise UnsupportedFrameworkError(detection) + + # Framework-specific execution + if self._wrapped.framework == Framework.BROWSER_USE.value: + return self._run_browser_use(task) + + if self._wrapped.framework == Framework.PLAYWRIGHT.value: + return self._run_playwright(task) + + if self._wrapped.framework == Framework.LANGCHAIN.value: + return self._run_langchain(task) + + if self._wrapped.framework == Framework.PYDANTIC_AI.value: + return self._run_pydantic_ai(task) + + raise NotImplementedError( + f"run() not implemented for framework: {self._wrapped.framework}" + ) + + def _run_browser_use(self, task: str | None) -> Any: + """Run browser-use agent with authorization.""" + # Import here to avoid hard dependency + try: + import asyncio + + agent = self._wrapped.original + + # Override task if provided + if task is not None: + agent.task = task + + # Check if agent has a run method + if hasattr(agent, "run"): + # browser-use Agent.run() is typically async + if asyncio.iscoroutinefunction(agent.run): + return asyncio.get_event_loop().run_until_complete(agent.run()) + return agent.run() + + raise NotImplementedError( + "browser-use Agent.run() integration not fully implemented. " + "For now, use the agent directly with pre_action_authorizer callback." + ) + except ImportError: + raise NotImplementedError( + "browser-use integration requires the browser-use package. " + "Install with: pip install predicate-secure[browser-use]" + ) + + def _run_playwright(self, task: str | None) -> Any: + """Run Playwright page with authorization.""" + raise NotImplementedError( + "Playwright direct integration not yet implemented. " + "Use with RuntimeAgent for Playwright pages." + ) + + def _run_langchain(self, task: str | None) -> Any: + """Run LangChain agent with authorization.""" + agent = self._wrapped.original + + # LangChain agents have .invoke() method + if hasattr(agent, "invoke"): + if task is not None: + return agent.invoke({"input": task}) + raise ValueError("Task is required for LangChain agents") + + raise NotImplementedError( + "LangChain integration requires AgentExecutor with invoke() method." + ) + + def _run_pydantic_ai(self, task: str | None) -> Any: + """Run PydanticAI agent with authorization.""" + raise NotImplementedError( + "PydanticAI integration not yet implemented." + ) @classmethod - def attach(cls, agent, **kwargs) -> "SecureAgent": + def attach(cls, agent: Any, **kwargs: Any) -> SecureAgent: """ Attach SecureAgent to an existing agent (factory method). @@ -149,3 +393,32 @@ def attach(cls, agent, **kwargs) -> "SecureAgent": SecureAgent instance """ return cls(agent=agent, **kwargs) + + def get_pre_action_authorizer(self) -> Any: + """ + Get a pre-action authorizer callback for use with RuntimeAgent. + + This allows integrating SecureAgent authorization with existing + RuntimeAgent-based workflows. + + Returns: + Callable for pre_action_authorizer parameter + + Example: + secure = SecureAgent(agent=my_agent, policy="policy.yaml") + runtime_agent = RuntimeAgent( + runtime=runtime, + executor=executor, + pre_action_authorizer=secure.get_pre_action_authorizer(), + ) + """ + return self._create_pre_action_authorizer() + + def __repr__(self) -> str: + return ( + f"SecureAgent(" + f"framework={self._wrapped.framework}, " + f"mode={self._config.mode}, " + f"policy={self._config.effective_policy_path or 'None'}" + f")" + ) diff --git a/src/predicate_secure/config.py b/src/predicate_secure/config.py new file mode 100644 index 0000000..a366a15 --- /dev/null +++ b/src/predicate_secure/config.py @@ -0,0 +1,114 @@ +"""Configuration classes for predicate-secure.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal + +# Mode type alias +Mode = Literal["strict", "permissive", "debug", "audit"] + + +@dataclass(frozen=True) +class SecureAgentConfig: + """ + Configuration for SecureAgent. + + Attributes: + policy: Path to policy file or inline policy dict + mode: Execution mode (strict, permissive, debug, audit) + principal_id: Agent principal ID (auto-detect from env if not provided) + tenant_id: Tenant ID for multi-tenant deployments + session_id: Session ID for tracking + sidecar_url: Sidecar URL (None for embedded mode) + signing_key: Secret key for mandate signing (auto-detect from env if not provided) + mandate_ttl_seconds: TTL for issued mandates + fail_closed: Whether to fail closed on authorization errors (based on mode) + """ + + policy: str | Path | None = None + mode: Mode = "strict" + principal_id: str | None = None + tenant_id: str | None = None + session_id: str | None = None + sidecar_url: str | None = None + signing_key: str | None = None + mandate_ttl_seconds: int = 300 + + @property + def fail_closed(self) -> bool: + """Whether to fail closed on authorization errors.""" + return self.mode == "strict" + + @property + def effective_principal_id(self) -> str: + """Get principal ID, falling back to environment variable.""" + return self.principal_id or os.getenv("PREDICATE_PRINCIPAL_ID", "agent:default") + + @property + def effective_signing_key(self) -> str: + """Get signing key, falling back to environment variable.""" + key = self.signing_key or os.getenv("PREDICATE_AUTHORITY_SIGNING_KEY") + if not key: + # For embedded mode without sidecar, generate a local key + # This is fine for local-only operation; production should use env var + key = "local-dev-key-not-for-production" + return key + + @property + def effective_policy_path(self) -> str | None: + """Get policy path as string.""" + if self.policy is None: + return os.getenv("PREDICATE_AUTHORITY_POLICY_FILE") + if isinstance(self.policy, Path): + return str(self.policy) + return self.policy + + @classmethod + def from_kwargs( + cls, + policy: str | Path | None = None, + mode: str = "strict", + principal_id: str | None = None, + tenant_id: str | None = None, + session_id: str | None = None, + sidecar_url: str | None = None, + signing_key: str | None = None, + mandate_ttl_seconds: int = 300, + ) -> SecureAgentConfig: + """Create config from keyword arguments with validation.""" + valid_modes = ("strict", "permissive", "debug", "audit") + if mode not in valid_modes: + raise ValueError(f"Invalid mode '{mode}'. Must be one of: {valid_modes}") + + return cls( + policy=policy, + mode=mode, # type: ignore[arg-type] + principal_id=principal_id, + tenant_id=tenant_id, + session_id=session_id, + sidecar_url=sidecar_url, + signing_key=signing_key, + mandate_ttl_seconds=mandate_ttl_seconds, + ) + + +@dataclass +class WrappedAgent: + """ + Container for a wrapped agent with its detected framework. + + Attributes: + original: The original agent object + framework: Detected framework name + agent_runtime: Initialized AgentRuntime (if applicable) + executor: LLM executor extracted from agent (if applicable) + """ + + original: object + framework: str + agent_runtime: object | None = None + executor: object | None = None + metadata: dict = field(default_factory=dict) diff --git a/src/predicate_secure/detection.py b/src/predicate_secure/detection.py new file mode 100644 index 0000000..58cf842 --- /dev/null +++ b/src/predicate_secure/detection.py @@ -0,0 +1,194 @@ +"""Framework detection for SecureAgent.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any + + +class Framework(Enum): + """Supported agent frameworks.""" + + BROWSER_USE = "browser_use" + PLAYWRIGHT = "playwright" + LANGCHAIN = "langchain" + PYDANTIC_AI = "pydantic_ai" + UNKNOWN = "unknown" + + +@dataclass(frozen=True) +class DetectionResult: + """Result of framework detection.""" + + framework: Framework + agent_type: str + confidence: float # 0.0 to 1.0 + metadata: dict + + +class FrameworkDetector: + """Detects the framework of a given agent object.""" + + @classmethod + def detect(cls, agent: Any) -> DetectionResult: + """ + Detect the framework of an agent. + + Args: + agent: The agent object to detect + + Returns: + DetectionResult with framework info + """ + # Check browser-use + result = cls._check_browser_use(agent) + if result: + return result + + # Check Playwright + result = cls._check_playwright(agent) + if result: + return result + + # Check LangChain + result = cls._check_langchain(agent) + if result: + return result + + # Check PydanticAI + result = cls._check_pydantic_ai(agent) + if result: + return result + + # Unknown framework + return DetectionResult( + framework=Framework.UNKNOWN, + agent_type=type(agent).__name__, + confidence=0.0, + metadata={"module": getattr(type(agent), "__module__", "unknown")}, + ) + + @classmethod + def _check_browser_use(cls, agent: Any) -> DetectionResult | None: + """Check if agent is a browser-use Agent.""" + agent_type = type(agent) + module = getattr(agent_type, "__module__", "") + + # Check by module path + if "browser_use" in module: + return DetectionResult( + framework=Framework.BROWSER_USE, + agent_type=agent_type.__name__, + confidence=1.0, + metadata={ + "module": module, + "has_task": hasattr(agent, "task"), + "has_llm": hasattr(agent, "llm"), + }, + ) + + # Check by class name and attributes (duck typing) + if agent_type.__name__ == "Agent" and hasattr(agent, "task") and hasattr(agent, "llm"): + # Could be browser-use, check for more specific attributes + if hasattr(agent, "browser") or hasattr(agent, "controller"): + return DetectionResult( + framework=Framework.BROWSER_USE, + agent_type=agent_type.__name__, + confidence=0.8, + metadata={"module": module, "detection": "duck_typing"}, + ) + + return None + + @classmethod + def _check_playwright(cls, agent: Any) -> DetectionResult | None: + """Check if agent is a Playwright Page.""" + agent_type = type(agent) + module = getattr(agent_type, "__module__", "") + + # Check by module path + if "playwright" in module: + # Determine if sync or async + is_async = "async_api" in module + return DetectionResult( + framework=Framework.PLAYWRIGHT, + agent_type=agent_type.__name__, + confidence=1.0, + metadata={ + "module": module, + "is_async": is_async, + "is_page": agent_type.__name__ == "Page", + }, + ) + + # Check by duck typing for Page-like objects + if hasattr(agent, "goto") and hasattr(agent, "click") and hasattr(agent, "evaluate"): + return DetectionResult( + framework=Framework.PLAYWRIGHT, + agent_type=agent_type.__name__, + confidence=0.7, + metadata={"module": module, "detection": "duck_typing"}, + ) + + return None + + @classmethod + def _check_langchain(cls, agent: Any) -> DetectionResult | None: + """Check if agent is a LangChain agent.""" + agent_type = type(agent) + module = getattr(agent_type, "__module__", "") + + # Check by module path + if "langchain" in module: + is_executor = "AgentExecutor" in agent_type.__name__ + return DetectionResult( + framework=Framework.LANGCHAIN, + agent_type=agent_type.__name__, + confidence=1.0, + metadata={ + "module": module, + "is_executor": is_executor, + "has_invoke": hasattr(agent, "invoke"), + }, + ) + + # Check by duck typing + if hasattr(agent, "invoke") and hasattr(agent, "agent"): + return DetectionResult( + framework=Framework.LANGCHAIN, + agent_type=agent_type.__name__, + confidence=0.6, + metadata={"module": module, "detection": "duck_typing"}, + ) + + return None + + @classmethod + def _check_pydantic_ai(cls, agent: Any) -> DetectionResult | None: + """Check if agent is a PydanticAI agent.""" + agent_type = type(agent) + module = getattr(agent_type, "__module__", "") + + # Check by module path + if "pydantic_ai" in module: + return DetectionResult( + framework=Framework.PYDANTIC_AI, + agent_type=agent_type.__name__, + confidence=1.0, + metadata={"module": module}, + ) + + return None + + +class UnsupportedFrameworkError(Exception): + """Raised when an unsupported framework is detected.""" + + def __init__(self, detection: DetectionResult): + self.detection = detection + super().__init__( + f"Unsupported framework: {detection.agent_type} " + f"(module: {detection.metadata.get('module', 'unknown')}). " + f"Supported frameworks: browser-use, Playwright, LangChain, PydanticAI" + ) diff --git a/tests/test_secure_agent.py b/tests/test_secure_agent.py index 15b72e1..237be77 100644 --- a/tests/test_secure_agent.py +++ b/tests/test_secure_agent.py @@ -8,9 +8,15 @@ MODE_PERMISSIVE, MODE_STRICT, AuthorizationDenied, + DetectionResult, + Framework, + FrameworkDetector, PolicyLoadError, SecureAgent, + SecureAgentConfig, + UnsupportedFrameworkError, VerificationFailed, + WrappedAgent, ) @@ -45,6 +51,152 @@ def test_attach_factory_method(self): assert secure._agent is mock_agent assert secure._mode == MODE_DEBUG + def test_config_property(self): + """SecureAgent exposes config property.""" + mock_agent = object() + secure = SecureAgent( + agent=mock_agent, + policy="test.yaml", + mode=MODE_PERMISSIVE, + principal_id="test:agent", + ) + assert isinstance(secure.config, SecureAgentConfig) + assert secure.config.mode == "permissive" + assert secure.config.effective_principal_id == "test:agent" + assert secure.config.effective_policy_path == "test.yaml" + + def test_wrapped_property(self): + """SecureAgent exposes wrapped property.""" + mock_agent = object() + secure = SecureAgent(agent=mock_agent) + assert isinstance(secure.wrapped, WrappedAgent) + assert secure.wrapped.original is mock_agent + assert secure.wrapped.framework == "unknown" + + def test_framework_property(self): + """SecureAgent exposes framework property.""" + mock_agent = object() + secure = SecureAgent(agent=mock_agent) + assert secure.framework == Framework.UNKNOWN + + def test_repr(self): + """SecureAgent has useful repr.""" + mock_agent = object() + secure = SecureAgent(agent=mock_agent, policy="test.yaml", mode=MODE_DEBUG) + repr_str = repr(secure) + assert "SecureAgent" in repr_str + assert "unknown" in repr_str + assert "debug" in repr_str + assert "test.yaml" in repr_str + + +class TestSecureAgentConfig: + """Tests for SecureAgentConfig.""" + + def test_defaults(self): + """Config has sensible defaults.""" + config = SecureAgentConfig() + assert config.mode == "strict" + assert config.fail_closed is True + assert config.mandate_ttl_seconds == 300 + + def test_from_kwargs(self): + """Config can be created from kwargs.""" + config = SecureAgentConfig.from_kwargs( + policy="test.yaml", + mode="permissive", + principal_id="test:agent", + ) + assert config.effective_policy_path == "test.yaml" + assert config.mode == "permissive" + assert config.fail_closed is False + assert config.effective_principal_id == "test:agent" + + def test_from_kwargs_invalid_mode(self): + """Config raises on invalid mode.""" + with pytest.raises(ValueError, match="Invalid mode"): + SecureAgentConfig.from_kwargs(mode="invalid") + + def test_effective_principal_id_fallback(self): + """Config falls back to default principal.""" + config = SecureAgentConfig() + assert config.effective_principal_id == "agent:default" + + def test_fail_closed_by_mode(self): + """fail_closed depends on mode.""" + assert SecureAgentConfig(mode="strict").fail_closed is True + assert SecureAgentConfig(mode="permissive").fail_closed is False + assert SecureAgentConfig(mode="debug").fail_closed is False + assert SecureAgentConfig(mode="audit").fail_closed is False + + +class TestFrameworkDetector: + """Tests for framework detection.""" + + def test_detect_unknown(self): + """Unknown objects detected as UNKNOWN.""" + result = FrameworkDetector.detect(object()) + assert result.framework == Framework.UNKNOWN + assert result.confidence == 0.0 + + def test_detect_by_module(self): + """Objects detected by module name.""" + + class MockBrowserUseAgent: + __module__ = "browser_use.agent" + + result = FrameworkDetector.detect(MockBrowserUseAgent()) + assert result.framework == Framework.BROWSER_USE + assert result.confidence == 1.0 + + def test_detect_playwright_page(self): + """Playwright pages detected by attributes.""" + + class MockPage: + __module__ = "some.module" + + def goto(self): + pass + + def click(self): + pass + + def evaluate(self): + pass + + result = FrameworkDetector.detect(MockPage()) + assert result.framework == Framework.PLAYWRIGHT + assert result.confidence == 0.7 + + def test_detection_result_immutable(self): + """DetectionResult is immutable.""" + result = DetectionResult( + framework=Framework.BROWSER_USE, + agent_type="Agent", + confidence=1.0, + metadata={}, + ) + with pytest.raises(Exception): # frozen dataclass + result.confidence = 0.5 + + +class TestWrappedAgent: + """Tests for WrappedAgent.""" + + def test_wrapped_agent_fields(self): + """WrappedAgent has expected fields.""" + original = object() + wrapped = WrappedAgent( + original=original, + framework="browser_use", + agent_runtime=None, + executor=None, + metadata={"key": "value"}, + ) + assert wrapped.original is original + assert wrapped.framework == "browser_use" + assert wrapped.metadata == {"key": "value"} + class TestExceptions: """Tests for custom exceptions.""" @@ -54,16 +206,40 @@ def test_authorization_denied(self): with pytest.raises(AuthorizationDenied): raise AuthorizationDenied("Action denied by policy") + def test_authorization_denied_with_decision(self): + """AuthorizationDenied can include decision.""" + decision = {"allowed": False, "reason": "policy"} + exc = AuthorizationDenied("denied", decision=decision) + assert exc.decision == decision + def test_verification_failed(self): """VerificationFailed is an Exception.""" with pytest.raises(VerificationFailed): raise VerificationFailed("Post-execution check failed") + def test_verification_failed_with_predicate(self): + """VerificationFailed can include predicate.""" + exc = VerificationFailed("failed", predicate="url_matches('example.com')") + assert exc.predicate == "url_matches('example.com')" + def test_policy_load_error(self): """PolicyLoadError is an Exception.""" with pytest.raises(PolicyLoadError): raise PolicyLoadError("Invalid policy file") + def test_unsupported_framework_error(self): + """UnsupportedFrameworkError includes detection info.""" + detection = DetectionResult( + framework=Framework.UNKNOWN, + agent_type="CustomAgent", + confidence=0.0, + metadata={"module": "custom.module"}, + ) + exc = UnsupportedFrameworkError(detection) + assert exc.detection is detection + assert "CustomAgent" in str(exc) + assert "custom.module" in str(exc) + class TestModeConstants: """Tests for mode constants.""" @@ -74,3 +250,118 @@ def test_mode_values(self): assert MODE_PERMISSIVE == "permissive" assert MODE_DEBUG == "debug" assert MODE_AUDIT == "audit" + + +class TestSecureAgentRun: + """Tests for SecureAgent.run().""" + + def test_run_unknown_framework_raises(self): + """run() raises for unknown frameworks.""" + mock_agent = object() + secure = SecureAgent(agent=mock_agent) + with pytest.raises(UnsupportedFrameworkError): + secure.run() + + def test_get_pre_action_authorizer_no_policy(self): + """get_pre_action_authorizer returns None without policy.""" + mock_agent = object() + secure = SecureAgent(agent=mock_agent) + authorizer = secure.get_pre_action_authorizer() + assert authorizer is None + + +class TestSecureAgentBrowserUseMock: + """Tests for browser-use framework detection with mocks.""" + + def test_detect_browser_use_by_attributes(self): + """browser-use Agent detected by task/llm/browser attributes.""" + + class MockBrowserUseAgent: + __module__ = "some.module" + + def __init__(self): + self.task = "Test task" + self.llm = object() + self.browser = object() + + MockBrowserUseAgent.__name__ = "Agent" + + result = FrameworkDetector.detect(MockBrowserUseAgent()) + assert result.framework == Framework.BROWSER_USE + assert result.confidence == 0.8 + assert result.metadata.get("detection") == "duck_typing" + + def test_secure_agent_detects_browser_use(self): + """SecureAgent correctly detects browser-use agent.""" + + class MockBrowserUseAgent: + __module__ = "browser_use.agent" + + def __init__(self): + self.task = "Test task" + self.llm = "mock_llm" + + agent = MockBrowserUseAgent() + secure = SecureAgent(agent=agent) + + assert secure.framework == Framework.BROWSER_USE + assert secure.wrapped.executor == "mock_llm" + + +class TestSecureAgentLangChainMock: + """Tests for LangChain framework detection with mocks.""" + + def test_detect_langchain_by_module(self): + """LangChain detected by module.""" + + class MockAgentExecutor: + __module__ = "langchain.agents" + + result = FrameworkDetector.detect(MockAgentExecutor()) + assert result.framework == Framework.LANGCHAIN + assert result.confidence == 1.0 + + def test_secure_agent_detects_langchain(self): + """SecureAgent correctly detects LangChain agent.""" + + class MockAgentExecutor: + __module__ = "langchain.agents.executor" + + def __init__(self): + self.llm = "mock_llm" + + def invoke(self, inputs): + return {"output": "result"} + + agent = MockAgentExecutor() + secure = SecureAgent(agent=agent) + + assert secure.framework == Framework.LANGCHAIN + assert secure.wrapped.executor == "mock_llm" + + +class TestSecureAgentPlaywrightMock: + """Tests for Playwright framework detection with mocks.""" + + def test_detect_playwright_by_module(self): + """Playwright detected by module.""" + + class MockPage: + __module__ = "playwright.async_api._generated" + + result = FrameworkDetector.detect(MockPage()) + assert result.framework == Framework.PLAYWRIGHT + assert result.confidence == 1.0 + assert result.metadata.get("is_async") is True + + def test_secure_agent_detects_playwright(self): + """SecureAgent correctly detects Playwright page.""" + + class MockPage: + __module__ = "playwright.sync_api._generated" + + page = MockPage() + secure = SecureAgent(agent=page) + + assert secure.framework == Framework.PLAYWRIGHT + assert secure.wrapped.metadata.get("is_async") is False