From 8ebd080c04ee5c96dc441842eef371d934de27f5 Mon Sep 17 00:00:00 2001 From: momuno Date: Thu, 26 Feb 2026 13:11:24 -0800 Subject: [PATCH] feat: add input_schema to Tool protocol to match ecosystem convention MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Every official tool module (tool-bash, tool-filesystem, tool-todo, etc.) already implements input_schema as a @property returning a JSON Schema dict, but the kernel's Tool Protocol in interfaces.py did not require it. This created a gap between the formal contract and what all real tools do. Additionally, TOOL_CONTRACT.md referenced a get_schema() method that no real module implements. This commit: - Adds input_schema property to the Tool Protocol in interfaces.py - Updates TOOL_CONTRACT.md to use input_schema property throughout - Adds input_schema validation check in tool.py - Adds behavioral and structural tests for input_schema - Updates MockTool classes in test_validation.py to include input_schema - Adds input_schema to the MyTool example in README.md Backward compatible: all existing tool modules already implement this property. 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-Authored-By: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com> --- README.md | 9 +++ amplifier_core/interfaces.py | 32 ++++++-- .../validation/behavioral/test_tool.py | 28 ++++++- .../validation/structural/test_tool.py | 9 +++ amplifier_core/validation/tool.py | 31 ++++++++ docs/contracts/TOOL_CONTRACT.md | 31 ++++++-- tests/test_validation.py | 74 +++++++++++++++---- 7 files changed, 184 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index be707d1..3e7df81 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,15 @@ class MyTool: def description(self) -> str: return "Does something useful" + @property + def input_schema(self) -> dict: + return { + "type": "object", + "properties": { + "param": {"type": "string", "description": "Tool-specific parameter"}, + }, + } + async def execute(self, input: dict) -> ToolResult: """Execute tool with input dict.""" return ToolResult( diff --git a/amplifier_core/interfaces.py b/amplifier_core/interfaces.py index 3e04714..1fa7872 100644 --- a/amplifier_core/interfaces.py +++ b/amplifier_core/interfaces.py @@ -139,6 +139,18 @@ def description(self) -> str: """Human-readable tool description.""" ... + @property + def input_schema(self) -> dict[str, Any]: + """JSON Schema describing the tool's input parameters. + + Orchestrators use this to build ToolSpec objects for the LLM, + so the model knows what parameters to pass when calling the tool. + + Returns: + JSON Schema dict (e.g. {"type": "object", "properties": {...}, "required": [...]}) + """ + ... + async def execute(self, input: dict[str, Any]) -> ToolResult: """ Execute tool with given input. @@ -225,9 +237,15 @@ class ApprovalRequest(BaseModel): tool_name: str = Field(..., description="Name of the tool requesting approval") action: str = Field(..., description="Human-readable description of the action") - details: dict[str, Any] = Field(default_factory=dict, description="Tool-specific context and parameters") - risk_level: str = Field(..., description="Risk level: low, medium, high, or critical") - timeout: float | None = Field(default=None, description="Timeout in seconds (None = wait indefinitely)") + details: dict[str, Any] = Field( + default_factory=dict, description="Tool-specific context and parameters" + ) + risk_level: str = Field( + ..., description="Risk level: low, medium, high, or critical" + ) + timeout: float | None = Field( + default=None, description="Timeout in seconds (None = wait indefinitely)" + ) def model_post_init(self, __context: Any) -> None: """Validate timeout if provided.""" @@ -239,8 +257,12 @@ class ApprovalResponse(BaseModel): """Response to an approval request.""" approved: bool = Field(..., description="Whether the action was approved") - reason: str | None = Field(default=None, description="Explanation for approval/denial") - remember: bool = Field(default=False, description="Cache this decision for future requests") + reason: str | None = Field( + default=None, description="Explanation for approval/denial" + ) + remember: bool = Field( + default=False, description="Cache this decision for future requests" + ) @runtime_checkable diff --git a/amplifier_core/validation/behavioral/test_tool.py b/amplifier_core/validation/behavioral/test_tool.py index 27c0db8..5c6954e 100644 --- a/amplifier_core/validation/behavioral/test_tool.py +++ b/amplifier_core/validation/behavioral/test_tool.py @@ -38,9 +38,25 @@ async def test_tool_has_name(self, tool_module): @pytest.mark.asyncio async def test_tool_has_description(self, tool_module): """Tool must have a description property.""" - assert hasattr(tool_module, "description"), "Tool must have description attribute" + assert hasattr(tool_module, "description"), ( + "Tool must have description attribute" + ) assert tool_module.description, "Tool description must not be empty" - assert isinstance(tool_module.description, str), "Tool description must be string" + assert isinstance(tool_module.description, str), ( + "Tool description must be string" + ) + + @pytest.mark.asyncio + async def test_tool_has_input_schema(self, tool_module): + """Tool must have an input_schema property returning a dict with a 'type' key.""" + assert hasattr(tool_module, "input_schema"), ( + "Tool must have input_schema attribute" + ) + schema = tool_module.input_schema + assert isinstance(schema, dict), "input_schema must return a dict" + assert "type" in schema, ( + "input_schema must have a 'type' key (JSON Schema structure)" + ) @pytest.mark.asyncio async def test_tool_has_execute_method(self, tool_module): @@ -69,7 +85,11 @@ async def test_invalid_input_returns_error_result(self, tool_module): try: result = await tool_module.execute({}) # Should return error result, not raise - assert isinstance(result, ToolResult), "Must return ToolResult even on error" + assert isinstance(result, ToolResult), ( + "Must return ToolResult even on error" + ) except Exception as e: # Only allow expected validation errors, not code bugs - assert not isinstance(e, AttributeError | TypeError | KeyError), f"Tool crashed with code error: {e}" + assert not isinstance(e, AttributeError | TypeError | KeyError), ( + f"Tool crashed with code error: {e}" + ) diff --git a/amplifier_core/validation/structural/test_tool.py b/amplifier_core/validation/structural/test_tool.py index ab17974..59bec42 100644 --- a/amplifier_core/validation/structural/test_tool.py +++ b/amplifier_core/validation/structural/test_tool.py @@ -21,6 +21,15 @@ class ToolStructuralTests: All test methods use fixtures provided by the amplifier-core pytest plugin. """ + @pytest.mark.asyncio + async def test_tool_has_input_schema(self, tool_module): + """Tool must have an input_schema property returning a dict.""" + assert hasattr(tool_module, "input_schema"), ( + "Tool must have input_schema attribute" + ) + schema = tool_module.input_schema + assert isinstance(schema, dict), "input_schema must return a dict" + @pytest.mark.asyncio async def test_structural_validation(self, module_path): """Module must pass all structural validation checks.""" diff --git a/amplifier_core/validation/tool.py b/amplifier_core/validation/tool.py index bb662f8..6f3ab3e 100644 --- a/amplifier_core/validation/tool.py +++ b/amplifier_core/validation/tool.py @@ -384,6 +384,37 @@ def _check_tool_methods(self, result: ValidationResult, tool: Tool) -> None: ) ) + # Check input_schema property + try: + schema = tool.input_schema + if isinstance(schema, dict): + result.add( + ValidationCheck( + name="tool_input_schema", + passed=True, + message="Tool has input_schema", + severity="info", + ) + ) + else: + result.add( + ValidationCheck( + name="tool_input_schema", + passed=False, + message="Tool.input_schema should return a dict", + severity="error", + ) + ) + except Exception as e: + result.add( + ValidationCheck( + name="tool_input_schema", + passed=False, + message=f"Error accessing Tool.input_schema: {e}", + severity="warning", + ) + ) + # Check execute method execute = getattr(tool, "execute", None) if execute is None: diff --git a/docs/contracts/TOOL_CONTRACT.md b/docs/contracts/TOOL_CONTRACT.md index 628eb51..f6bf727 100644 --- a/docs/contracts/TOOL_CONTRACT.md +++ b/docs/contracts/TOOL_CONTRACT.md @@ -52,6 +52,11 @@ class Tool(Protocol): """Human-readable tool description.""" ... + @property + def input_schema(self) -> dict[str, Any]: + """JSON Schema describing the tool's input parameters.""" + ... + async def execute(self, input: dict[str, Any]) -> ToolResult: """ Execute tool with given input. @@ -136,11 +141,22 @@ class MyTool: @property def description(self) -> str: return "Performs specific action with given parameters." + + @property + def input_schema(self) -> dict: + return { + "type": "object", + "properties": { + "param": {"type": "string", "description": "A required parameter"}, + }, + "required": ["param"], + } ``` **Best practices**: - `name`: Short, snake_case, unique across mounted tools - `description`: Clear explanation of what the tool does and expects +- `input_schema`: Valid JSON Schema dict describing the tool's parameters ### execute() Method @@ -172,13 +188,16 @@ async def execute(self, input: dict[str, Any]) -> ToolResult: ) ``` -### Tool Schema (Optional but Recommended) +### input_schema Property -Provide JSON schema for input validation: +Provide a JSON Schema describing the tool's input parameters. Orchestrators use this +to build `ToolSpec` objects for the LLM, so the model knows what parameters to pass +when calling the tool. ```python -def get_schema(self) -> dict: - """Return JSON schema for tool input.""" +@property +def input_schema(self) -> dict: + """Return JSON schema for tool parameters.""" return { "type": "object", "properties": { @@ -254,14 +273,14 @@ Additional examples: ### Required -- [ ] Implements Tool protocol (name, description, execute) +- [ ] Implements Tool protocol (name, description, input_schema, execute) - [ ] `mount()` function with entry point in pyproject.toml - [ ] Returns `ToolResult` from execute() - [ ] Handles errors gracefully (returns success=False, doesn't crash) ### Recommended -- [ ] Provides JSON schema via `get_schema()` +- [ ] Provides meaningful `input_schema` with property descriptions - [ ] Validates input before processing - [ ] Logs operations at appropriate levels - [ ] Registers observability events diff --git a/tests/test_validation.py b/tests/test_validation.py index 3a7078e..ad7bdf7 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -45,7 +45,11 @@ def test_create_failing_check(self): def test_severity_levels(self): from typing import Literal - severities: list[Literal["error", "warning", "info"]] = ["error", "warning", "info"] + severities: list[Literal["error", "warning", "info"]] = [ + "error", + "warning", + "info", + ] for severity in severities: check = ValidationCheck( name="test", @@ -263,7 +267,11 @@ def mount(coordinator, config): validator = ProviderValidator() result = await validator.validate(str(module_file)) - assert any("async" in c.message.lower() for c in result.checks if not c.passed and c.name == "mount_signature") + assert any( + "async" in c.message.lower() + for c in result.checks + if not c.passed and c.name == "mount_signature" + ) @pytest.mark.asyncio async def test_mount_missing_params_fails(self, tmp_path): @@ -279,7 +287,9 @@ async def mount(coordinator): validator = ProviderValidator() result = await validator.validate(str(module_file)) assert any( - "2 parameters" in c.message.lower() for c in result.checks if not c.passed and c.name == "mount_signature" + "2 parameters" in c.message.lower() + for c in result.checks + if not c.passed and c.name == "mount_signature" ) @pytest.mark.asyncio @@ -297,7 +307,9 @@ async def mount(coordinator, config): result = await validator.validate(str(module_file)) # Should pass importable and mount_exists and mount_signature checks - signature_check = next((c for c in result.checks if c.name == "mount_signature"), None) + signature_check = next( + (c for c in result.checks if c.name == "mount_signature"), None + ) assert signature_check is not None assert signature_check.passed is True @@ -333,7 +345,9 @@ async def mount(coordinator, config): result = await validator.validate(str(module_dir)) # Should pass importable check - importable_check = next((c for c in result.checks if c.name == "module_importable"), None) + importable_check = next( + (c for c in result.checks if c.name == "module_importable"), None + ) assert importable_check is not None assert importable_check.passed is True @@ -489,7 +503,9 @@ def test_valid_full_mount_plan(self): "orchestrator": {"module": "orchestrator-default"}, "context": {"module": "context-default"}, }, - "providers": [{"module": "provider-anthropic", "config": {"model": "sonnet"}}], + "providers": [ + {"module": "provider-anthropic", "config": {"model": "sonnet"}} + ], "tools": [{"module": "tool-web-search"}], "hooks": [], } @@ -648,7 +664,10 @@ def test_session_not_dict_fails(self): mount_plan = {"session": "not a dict"} result = MountPlanValidator().validate(mount_plan) assert not result.passed - assert any("session" in e.message.lower() and "dict" in e.message.lower() for e in result.errors) + assert any( + "session" in e.message.lower() and "dict" in e.message.lower() + for e in result.errors + ) def test_agents_section_not_validated_as_module_list(self): """Agents section is not validated as a module list (it's dict of configs).""" @@ -734,15 +753,25 @@ def parse_tool_calls(self, response): result_no_config = await validator.validate(str(module_dir)) # Should fail with error about no provider mounted assert not result_no_config.passed, "Should fail when mount returns None" - assert any("No provider was mounted" in c.message for c in result_no_config.checks if c.severity == "error") + assert any( + "No provider was mounted" in c.message + for c in result_no_config.checks + if c.severity == "error" + ) # Test 2: With config - should mount successfully and pass validation - result_with_config = await validator.validate(str(module_dir), config={"api_key": "test-key"}) + result_with_config = await validator.validate( + str(module_dir), config={"api_key": "test-key"} + ) # Should pass with provider mounted assert result_with_config.passed, ( f"Should pass with config. Errors: {[c.message for c in result_with_config.errors]}" ) - assert any("implements Provider protocol" in c.message for c in result_with_config.checks if c.passed) + assert any( + "implements Provider protocol" in c.message + for c in result_with_config.checks + if c.passed + ) @pytest.mark.asyncio async def test_tool_validation_with_config(self, tmp_path): @@ -764,6 +793,7 @@ async def mount(coordinator: ModuleCoordinator, config): class MockTool: name = "mock-tool" description = "A mock tool" + input_schema = {"type": "object", "properties": {}} async def execute(self, input): return ToolResult(success=True, output={"result": "mock"}) @@ -777,11 +807,25 @@ async def execute(self, input): validator = ToolValidator() # With enabled=False - should FAIL (returns None, no tool mounted) - result_disabled = await validator.validate(str(module_dir), config={"enabled": False}) + result_disabled = await validator.validate( + str(module_dir), config={"enabled": False} + ) assert not result_disabled.passed, "Should fail when mount returns None" - assert any("No tool was mounted" in c.message for c in result_disabled.checks if c.severity == "error") + assert any( + "No tool was mounted" in c.message + for c in result_disabled.checks + if c.severity == "error" + ) # With enabled=True - should mount successfully and pass validation - result_enabled = await validator.validate(str(module_dir), config={"enabled": True}) - assert result_enabled.passed, f"Should pass with config. Errors: {[c.message for c in result_enabled.errors]}" - assert any("implements Tool protocol" in c.message for c in result_enabled.checks if c.passed) + result_enabled = await validator.validate( + str(module_dir), config={"enabled": True} + ) + assert result_enabled.passed, ( + f"Should pass with config. Errors: {[c.message for c in result_enabled.errors]}" + ) + assert any( + "implements Tool protocol" in c.message + for c in result_enabled.checks + if c.passed + )