Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 81 additions & 6 deletions src/adcp/utils/response_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,67 @@ def parse_mcp_content(content: list[dict[str, Any]], response_type: type[T]) ->
)


# Protocol-level fields from ProtocolResponse (core/response.json) and
# ProtocolEnvelope (core/protocol_envelope.json). These are separated from
# task data for schema validation, but preserved at the TaskResult level.
# Note: 'data' and 'payload' are handled separately as wrapper fields.
PROTOCOL_FIELDS = {
"message", # Human-readable summary
"context_id", # Session continuity identifier
"task_id", # Async operation identifier
"status", # Task execution state
"timestamp", # Response timestamp
}


def _extract_task_data(data: dict[str, Any]) -> dict[str, Any]:
"""
Extract task-specific data from a protocol response.

Servers may return responses in ProtocolResponse format:
{"message": "...", "context_id": "...", "data": {...}}

Or ProtocolEnvelope format:
{"message": "...", "status": "...", "payload": {...}}

Or task data directly with protocol fields mixed in:
{"message": "...", "products": [...], ...}

This function separates task-specific data for schema validation.
Protocol fields are preserved at the TaskResult level.

Args:
data: Response data dict

Returns:
Task-specific data suitable for schema validation.
Returns the same dict object if no extraction is needed.
"""
# Check for wrapped payload fields
# (ProtocolResponse uses 'data', ProtocolEnvelope uses 'payload')
if "data" in data and isinstance(data["data"], dict):
return data["data"]
if "payload" in data and isinstance(data["payload"], dict):
return data["payload"]

# Check if any protocol fields are present
if not any(k in PROTOCOL_FIELDS for k in data):
return data # Return same object for identity check

# Separate task data from protocol fields
return {k: v for k, v in data.items() if k not in PROTOCOL_FIELDS}


def parse_json_or_text(data: Any, response_type: type[T]) -> T:
"""
Parse data that might be JSON string, dict, or other format.

Used by A2A adapter for flexible response parsing.

Handles protocol-level wrapping where servers return:
- {"message": "...", "data": {...task_data...}}
- {"message": "...", ...task_fields...}

Args:
data: Response data (string, dict, or other)
response_type: Expected Pydantic model type
Expand All @@ -147,22 +202,42 @@ def parse_json_or_text(data: Any, response_type: type[T]) -> T:
"""
# If already a dict, try direct validation
if isinstance(data, dict):
# Try direct validation first
original_error: Exception | None = None
try:
return _validate_union_type(data, response_type)
except ValidationError as e:
# Get the type name, handling Union types
type_name = getattr(response_type, "__name__", str(response_type))
raise ValueError(f"Response doesn't match expected schema {type_name}: {e}") from e
except (ValidationError, ValueError) as e:
original_error = e

# Try extracting task data (separates protocol fields)
task_data = _extract_task_data(data)
if task_data is not data:
try:
return _validate_union_type(task_data, response_type)
except (ValidationError, ValueError):
pass # Fall through to raise original error

# Report the original validation error
type_name = getattr(response_type, "__name__", str(response_type))
raise ValueError(
f"Response doesn't match expected schema {type_name}: {original_error}"
) from original_error

# If string, try JSON parsing
if isinstance(data, str):
try:
parsed = json.loads(data)
return _validate_union_type(parsed, response_type)
except json.JSONDecodeError as e:
raise ValueError(f"Response is not valid JSON: {e}") from e

# Recursively handle dict parsing (which includes protocol field extraction)
if isinstance(parsed, dict):
return parse_json_or_text(parsed, response_type)

# Non-dict JSON (shouldn't happen for AdCP responses)
try:
return _validate_union_type(parsed, response_type)
except ValidationError as e:
# Get the type name, handling Union types
type_name = getattr(response_type, "__name__", str(response_type))
raise ValueError(f"Response doesn't match expected schema {type_name}: {e}") from e

Expand Down
148 changes: 148 additions & 0 deletions tests/test_response_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,151 @@ def test_json_string_not_matching_schema_raises_error(self):

with pytest.raises(ValueError, match="doesn't match expected schema"):
parse_json_or_text(data, SampleResponse)


class ProductResponse(BaseModel):
"""Response type without protocol fields for testing protocol field stripping."""

products: list[str]
total: int = 0


class TestProtocolFieldExtraction:
"""Tests for protocol field extraction from A2A responses.

A2A servers may include protocol-level fields (message, context_id, data)
that are not part of task-specific response schemas. These are separated
for task data validation, but preserved at the TaskResult level.

See: https://github.com/adcontextprotocol/adcp-client-python/issues/109
"""

def test_response_with_message_field_separated(self):
"""Test that protocol 'message' field is separated before validation."""
# A2A server returns task data with protocol message mixed in
data = {
"message": "No products matched your requirements.",
"products": ["product-1", "product-2"],
"total": 2,
}

result = parse_json_or_text(data, ProductResponse)

assert isinstance(result, ProductResponse)
assert result.products == ["product-1", "product-2"]
assert result.total == 2

def test_response_with_context_id_separated(self):
"""Test that protocol 'context_id' field is separated before validation."""
data = {
"context_id": "session-123",
"products": ["product-1"],
"total": 1,
}

result = parse_json_or_text(data, ProductResponse)

assert isinstance(result, ProductResponse)
assert result.products == ["product-1"]

def test_response_with_multiple_protocol_fields_separated(self):
"""Test that multiple protocol fields are separated."""
data = {
"message": "Found products",
"context_id": "session-456",
"products": ["a", "b", "c"],
"total": 3,
}

result = parse_json_or_text(data, ProductResponse)

assert isinstance(result, ProductResponse)
assert result.products == ["a", "b", "c"]
assert result.total == 3

def test_response_with_data_wrapper_extracted(self):
"""Test that ProtocolResponse 'data' wrapper is extracted."""
# Full ProtocolResponse format: {"message": "...", "data": {...task_data...}}
data = {
"message": "Operation completed",
"context_id": "ctx-789",
"data": {
"products": ["wrapped-product"],
"total": 1,
},
}

result = parse_json_or_text(data, ProductResponse)

assert isinstance(result, ProductResponse)
assert result.products == ["wrapped-product"]
assert result.total == 1

def test_response_with_payload_wrapper_extracted(self):
"""Test that ProtocolEnvelope 'payload' wrapper is extracted."""
# Full ProtocolEnvelope format
data = {
"message": "Operation completed",
"status": "completed",
"task_id": "task-123",
"timestamp": "2025-01-01T00:00:00Z",
"payload": {
"products": ["envelope-product"],
"total": 1,
},
}

result = parse_json_or_text(data, ProductResponse)

assert isinstance(result, ProductResponse)
assert result.products == ["envelope-product"]
assert result.total == 1

def test_exact_match_still_works(self):
"""Test that responses exactly matching schema still work."""
data = {
"products": ["exact-match"],
"total": 1,
}

result = parse_json_or_text(data, ProductResponse)

assert result.products == ["exact-match"]
assert result.total == 1

def test_json_string_with_protocol_fields(self):
"""Test JSON string with protocol fields is parsed correctly."""
data = json.dumps(
{
"message": "Success",
"products": ["from-json-string"],
"total": 1,
}
)

result = parse_json_or_text(data, ProductResponse)

assert result.products == ["from-json-string"]

def test_invalid_data_after_separation_raises_error(self):
"""Test that invalid data still raises error after separation."""
data = {
"message": "Some message",
"wrong_field": "value",
}

with pytest.raises(ValueError, match="doesn't match expected schema"):
parse_json_or_text(data, ProductResponse)

def test_model_with_message_field_validates_directly(self):
"""Test that models containing 'message' field validate without separation."""
# SampleResponse has a 'message' field, so it should validate directly
data = {
"message": "Hello",
"count": 42,
}

result = parse_json_or_text(data, SampleResponse)

assert result.message == "Hello"
assert result.count == 42