Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ class MyTool:
def description(self) -> str:
return "Does something useful"

@property
def input_schema(self) -> dict:
    """JSON Schema for this tool's input: an object with one string field, ``param``."""
    # Describe the single parameter first, then wrap it in the object schema.
    param_spec = {"type": "string", "description": "Tool-specific parameter"}
    return {"type": "object", "properties": {"param": param_spec}}

async def execute(self, input: dict) -> ToolResult:
"""Execute tool with input dict."""
return ToolResult(
Expand Down
32 changes: 27 additions & 5 deletions amplifier_core/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,18 @@ def description(self) -> str:
"""Human-readable tool description."""
...

@property
def input_schema(self) -> dict[str, Any]:
"""JSON Schema describing the tool's input parameters.

Orchestrators use this to build ToolSpec objects for the LLM,
so the model knows what parameters to pass when calling the tool.

The tool validator only requires the value to be a dict; the
behavioral test suite additionally expects a top-level "type" key.

Returns:
JSON Schema dict (e.g. {"type": "object", "properties": {...}, "required": [...]})
"""
...

async def execute(self, input: dict[str, Any]) -> ToolResult:
"""
Execute tool with given input.
Expand Down Expand Up @@ -225,9 +237,15 @@ class ApprovalRequest(BaseModel):

tool_name: str = Field(..., description="Name of the tool requesting approval")
action: str = Field(..., description="Human-readable description of the action")
details: dict[str, Any] = Field(default_factory=dict, description="Tool-specific context and parameters")
risk_level: str = Field(..., description="Risk level: low, medium, high, or critical")
timeout: float | None = Field(default=None, description="Timeout in seconds (None = wait indefinitely)")
details: dict[str, Any] = Field(
default_factory=dict, description="Tool-specific context and parameters"
)
risk_level: str = Field(
..., description="Risk level: low, medium, high, or critical"
)
timeout: float | None = Field(
default=None, description="Timeout in seconds (None = wait indefinitely)"
)

def model_post_init(self, __context: Any) -> None:
"""Validate timeout if provided."""
Expand All @@ -239,8 +257,12 @@ class ApprovalResponse(BaseModel):
"""Response to an approval request."""

approved: bool = Field(..., description="Whether the action was approved")
reason: str | None = Field(default=None, description="Explanation for approval/denial")
remember: bool = Field(default=False, description="Cache this decision for future requests")
reason: str | None = Field(
default=None, description="Explanation for approval/denial"
)
remember: bool = Field(
default=False, description="Cache this decision for future requests"
)


@runtime_checkable
Expand Down
28 changes: 24 additions & 4 deletions amplifier_core/validation/behavioral/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,25 @@ async def test_tool_has_name(self, tool_module):
@pytest.mark.asyncio
async def test_tool_has_description(self, tool_module):
"""Tool must have a description property."""
assert hasattr(tool_module, "description"), "Tool must have description attribute"
assert hasattr(tool_module, "description"), (
"Tool must have description attribute"
)
assert tool_module.description, "Tool description must not be empty"
assert isinstance(tool_module.description, str), "Tool description must be string"
assert isinstance(tool_module.description, str), (
"Tool description must be string"
)

@pytest.mark.asyncio
async def test_tool_has_input_schema(self, tool_module):
    """Check that the tool exposes ``input_schema`` as a dict carrying a 'type' key."""
    has_schema = hasattr(tool_module, "input_schema")
    assert has_schema, "Tool must have input_schema attribute"

    # Read the attribute once so both checks below inspect the same object.
    declared = tool_module.input_schema
    assert isinstance(declared, dict), "input_schema must return a dict"
    assert "type" in declared, "input_schema must have a 'type' key (JSON Schema structure)"

@pytest.mark.asyncio
async def test_tool_has_execute_method(self, tool_module):
Expand Down Expand Up @@ -69,7 +85,11 @@ async def test_invalid_input_returns_error_result(self, tool_module):
try:
result = await tool_module.execute({})
# Should return error result, not raise
assert isinstance(result, ToolResult), "Must return ToolResult even on error"
assert isinstance(result, ToolResult), (
"Must return ToolResult even on error"
)
except Exception as e:
# Only allow expected validation errors, not code bugs
assert not isinstance(e, AttributeError | TypeError | KeyError), f"Tool crashed with code error: {e}"
assert not isinstance(e, AttributeError | TypeError | KeyError), (
f"Tool crashed with code error: {e}"
)
9 changes: 9 additions & 0 deletions amplifier_core/validation/structural/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ class ToolStructuralTests:
All test methods use fixtures provided by the amplifier-core pytest plugin.
"""

@pytest.mark.asyncio
async def test_tool_has_input_schema(self, tool_module):
    """Check that the tool exposes ``input_schema`` and that its value is a dict."""
    assert hasattr(tool_module, "input_schema"), "Tool must have input_schema attribute"
    assert isinstance(tool_module.input_schema, dict), "input_schema must return a dict"

@pytest.mark.asyncio
async def test_structural_validation(self, module_path):
"""Module must pass all structural validation checks."""
Expand Down
31 changes: 31 additions & 0 deletions amplifier_core/validation/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,37 @@ def _check_tool_methods(self, result: ValidationResult, tool: Tool) -> None:
)
)

# Check input_schema property
try:
schema = tool.input_schema
if isinstance(schema, dict):
result.add(
ValidationCheck(
name="tool_input_schema",
passed=True,
message="Tool has input_schema",
severity="info",
)
)
else:
result.add(
ValidationCheck(
name="tool_input_schema",
passed=False,
message="Tool.input_schema should return a dict",
severity="error",
)
)
except Exception as e:
result.add(
ValidationCheck(
name="tool_input_schema",
passed=False,
message=f"Error accessing Tool.input_schema: {e}",
severity="warning",
)
)

# Check execute method
execute = getattr(tool, "execute", None)
if execute is None:
Expand Down
31 changes: 25 additions & 6 deletions docs/contracts/TOOL_CONTRACT.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class Tool(Protocol):
"""Human-readable tool description."""
...

@property
def input_schema(self) -> dict[str, Any]:
"""JSON Schema describing the tool's input parameters.

Orchestrators use this to build ToolSpec objects for the LLM,
so the model knows what parameters to pass when calling the tool.
"""
...

async def execute(self, input: dict[str, Any]) -> ToolResult:
"""
Execute tool with given input.
Expand Down Expand Up @@ -136,11 +141,22 @@ class MyTool:
@property
def description(self) -> str:
return "Performs specific action with given parameters."

@property
def input_schema(self) -> dict:
    """JSON Schema for the tool input: an object whose only field, ``param``, is a required string."""
    param_spec = {"type": "string", "description": "A required parameter"}
    return dict(
        type="object",
        properties={"param": param_spec},
        required=["param"],
    )
```

**Best practices**:
- `name`: Short, snake_case, unique across mounted tools
- `description`: Clear explanation of what the tool does and expects
- `input_schema`: Valid JSON Schema dict describing the tool's parameters

### execute() Method

Expand Down Expand Up @@ -172,13 +188,16 @@ async def execute(self, input: dict[str, Any]) -> ToolResult:
)
```

### Tool Schema (Optional but Recommended)
### input_schema Property

Provide JSON schema for input validation:
Provide a JSON Schema describing the tool's input parameters. Orchestrators use this
to build `ToolSpec` objects for the LLM, so the model knows what parameters to pass
when calling the tool.

```python
def get_schema(self) -> dict:
"""Return JSON schema for tool input."""
@property
def input_schema(self) -> dict:
"""Return JSON schema for tool parameters."""
return {
"type": "object",
"properties": {
Expand Down Expand Up @@ -254,14 +273,14 @@ Additional examples:

### Required

- [ ] Implements Tool protocol (name, description, execute)
- [ ] Implements Tool protocol (name, description, input_schema, execute)
- [ ] `mount()` function with entry point in pyproject.toml
- [ ] Returns `ToolResult` from execute()
- [ ] Handles errors gracefully (returns success=False, doesn't crash)

### Recommended

- [ ] Provides JSON schema via `get_schema()`
- [ ] Provides meaningful `input_schema` with property descriptions
- [ ] Validates input before processing
- [ ] Logs operations at appropriate levels
- [ ] Registers observability events
Expand Down
74 changes: 59 additions & 15 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ def test_create_failing_check(self):
def test_severity_levels(self):
from typing import Literal

severities: list[Literal["error", "warning", "info"]] = ["error", "warning", "info"]
severities: list[Literal["error", "warning", "info"]] = [
"error",
"warning",
"info",
]
for severity in severities:
check = ValidationCheck(
name="test",
Expand Down Expand Up @@ -263,7 +267,11 @@ def mount(coordinator, config):

validator = ProviderValidator()
result = await validator.validate(str(module_file))
assert any("async" in c.message.lower() for c in result.checks if not c.passed and c.name == "mount_signature")
assert any(
"async" in c.message.lower()
for c in result.checks
if not c.passed and c.name == "mount_signature"
)

@pytest.mark.asyncio
async def test_mount_missing_params_fails(self, tmp_path):
Expand All @@ -279,7 +287,9 @@ async def mount(coordinator):
validator = ProviderValidator()
result = await validator.validate(str(module_file))
assert any(
"2 parameters" in c.message.lower() for c in result.checks if not c.passed and c.name == "mount_signature"
"2 parameters" in c.message.lower()
for c in result.checks
if not c.passed and c.name == "mount_signature"
)

@pytest.mark.asyncio
Expand All @@ -297,7 +307,9 @@ async def mount(coordinator, config):
result = await validator.validate(str(module_file))

# Should pass importable and mount_exists and mount_signature checks
signature_check = next((c for c in result.checks if c.name == "mount_signature"), None)
signature_check = next(
(c for c in result.checks if c.name == "mount_signature"), None
)
assert signature_check is not None
assert signature_check.passed is True

Expand Down Expand Up @@ -333,7 +345,9 @@ async def mount(coordinator, config):
result = await validator.validate(str(module_dir))

# Should pass importable check
importable_check = next((c for c in result.checks if c.name == "module_importable"), None)
importable_check = next(
(c for c in result.checks if c.name == "module_importable"), None
)
assert importable_check is not None
assert importable_check.passed is True

Expand Down Expand Up @@ -489,7 +503,9 @@ def test_valid_full_mount_plan(self):
"orchestrator": {"module": "orchestrator-default"},
"context": {"module": "context-default"},
},
"providers": [{"module": "provider-anthropic", "config": {"model": "sonnet"}}],
"providers": [
{"module": "provider-anthropic", "config": {"model": "sonnet"}}
],
"tools": [{"module": "tool-web-search"}],
"hooks": [],
}
Expand Down Expand Up @@ -648,7 +664,10 @@ def test_session_not_dict_fails(self):
mount_plan = {"session": "not a dict"}
result = MountPlanValidator().validate(mount_plan)
assert not result.passed
assert any("session" in e.message.lower() and "dict" in e.message.lower() for e in result.errors)
assert any(
"session" in e.message.lower() and "dict" in e.message.lower()
for e in result.errors
)

def test_agents_section_not_validated_as_module_list(self):
"""Agents section is not validated as a module list (it's dict of configs)."""
Expand Down Expand Up @@ -734,15 +753,25 @@ def parse_tool_calls(self, response):
result_no_config = await validator.validate(str(module_dir))
# Should fail with error about no provider mounted
assert not result_no_config.passed, "Should fail when mount returns None"
assert any("No provider was mounted" in c.message for c in result_no_config.checks if c.severity == "error")
assert any(
"No provider was mounted" in c.message
for c in result_no_config.checks
if c.severity == "error"
)

# Test 2: With config - should mount successfully and pass validation
result_with_config = await validator.validate(str(module_dir), config={"api_key": "test-key"})
result_with_config = await validator.validate(
str(module_dir), config={"api_key": "test-key"}
)
# Should pass with provider mounted
assert result_with_config.passed, (
f"Should pass with config. Errors: {[c.message for c in result_with_config.errors]}"
)
assert any("implements Provider protocol" in c.message for c in result_with_config.checks if c.passed)
assert any(
"implements Provider protocol" in c.message
for c in result_with_config.checks
if c.passed
)

@pytest.mark.asyncio
async def test_tool_validation_with_config(self, tmp_path):
Expand All @@ -764,6 +793,7 @@ async def mount(coordinator: ModuleCoordinator, config):
class MockTool:
name = "mock-tool"
description = "A mock tool"
input_schema = {"type": "object", "properties": {}}

async def execute(self, input):
return ToolResult(success=True, output={"result": "mock"})
Expand All @@ -777,11 +807,25 @@ async def execute(self, input):
validator = ToolValidator()

# With enabled=False - should FAIL (returns None, no tool mounted)
result_disabled = await validator.validate(str(module_dir), config={"enabled": False})
result_disabled = await validator.validate(
str(module_dir), config={"enabled": False}
)
assert not result_disabled.passed, "Should fail when mount returns None"
assert any("No tool was mounted" in c.message for c in result_disabled.checks if c.severity == "error")
assert any(
"No tool was mounted" in c.message
for c in result_disabled.checks
if c.severity == "error"
)

# With enabled=True - should mount successfully and pass validation
result_enabled = await validator.validate(str(module_dir), config={"enabled": True})
assert result_enabled.passed, f"Should pass with config. Errors: {[c.message for c in result_enabled.errors]}"
assert any("implements Tool protocol" in c.message for c in result_enabled.checks if c.passed)
result_enabled = await validator.validate(
str(module_dir), config={"enabled": True}
)
assert result_enabled.passed, (
f"Should pass with config. Errors: {[c.message for c in result_enabled.errors]}"
)
assert any(
"implements Tool protocol" in c.message
for c in result_enabled.checks
if c.passed
)