diff --git a/pyproject.toml b/pyproject.toml index 7906f74..bbb2c32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "sailhouse" -version = "0.1.1" +version = "1.5.0" description = "Python SDK for Sailhouse - Event Streaming Platform" readme = "README.md" authors = [ diff --git a/src/sailhouse/__init__.py b/src/sailhouse/__init__.py index b863c90..c9fe5c6 100644 --- a/src/sailhouse/__init__.py +++ b/src/sailhouse/__init__.py @@ -1,11 +1,41 @@ -from .client import SailhouseClient, GetEventsResponse, Event +from .client import SailhouseClient, GetEventsResponse, Event, WaitOptions, WaitGroup from .exceptions import SailhouseError +from .admin import AdminClient, FilterCondition, ComplexFilter, Filter, RegisterResult +from .subscriber import SailhouseSubscriber, SubscriberOptions, SubscriptionHandler +from .push_subscriptions import ( + PushSubscriptionVerifier, + PushSubscriptionVerificationError, + SignatureComponents, + PushSubscriptionHeaders, + PushSubscriptionPayload, + VerificationOptions, + verify_push_subscription_signature, + verify_push_subscription_signature_safe +) -__version__ = "0.1.0" +__version__ = "1.5.0" __all__ = [ "SailhouseClient", "Event", "GetEventsResponse", "SailhouseError", + "AdminClient", + "FilterCondition", + "ComplexFilter", + "Filter", + "RegisterResult", + "WaitOptions", + "WaitGroup", + "SailhouseSubscriber", + "SubscriberOptions", + "SubscriptionHandler", + "PushSubscriptionVerifier", + "PushSubscriptionVerificationError", + "SignatureComponents", + "PushSubscriptionHeaders", + "PushSubscriptionPayload", + "VerificationOptions", + "verify_push_subscription_signature", + "verify_push_subscription_signature_safe", ] diff --git a/src/sailhouse/admin.py b/src/sailhouse/admin.py new file mode 100644 index 0000000..7eb2260 --- /dev/null +++ b/src/sailhouse/admin.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass +from typing import Dict, Any, Optional, List, Union, Literal +from .exceptions import SailhouseError +import requests + + +@dataclass +class FilterCondition: + path: str + condition: str + value: Any + + +@dataclass +class ComplexFilter: + conditions: List[FilterCondition] + operator: Literal["and", "or"] = "and" + + +# Union type for filters - can be boolean, None, or complex filter +Filter = Union[bool, None, ComplexFilter] + + +@dataclass +class RegisterResult: + outcome: Literal["created", "updated", "none"] + + +class AdminClient: + def __init__(self, sailhouse_client: 'SailhouseClient'): + self._client = sailhouse_client + + def register_push_subscription( + self, + topic: str, + subscription: str, + endpoint: str, + **kwargs + ) -> RegisterResult: + """Register a push subscription for webhook delivery""" + url = f"{self._client.base_url}/admin/topics/{topic}/subscriptions/{subscription}/push" + + body: Dict[str, Any] = { + "endpoint": endpoint + } + + # Handle optional parameters + filter_condition = kwargs.get('filter_condition') + rate_limit = kwargs.get('rate_limit') + deduplication = kwargs.get('deduplication') + + # Handle filter_condition - only add to body if explicitly provided + if 'filter_condition' in kwargs: + if isinstance(filter_condition, bool): + body["filter"] = filter_condition + elif filter_condition is None: + body["filter"] = None + elif isinstance(filter_condition, ComplexFilter): + body["filter"] = { + "conditions": [ + { + "path": cond.path, + "condition": cond.condition, + "value": cond.value + } + for cond in filter_condition.conditions + ], + "operator": filter_condition.operator + } + + if rate_limit is not None: + body["rate_limit"] = rate_limit + + if deduplication is not None: + body["deduplication"] = deduplication + + response = self._client.session.post( + url, + json=body, + timeout=self._client.timeout + ) + + if response.status_code not in (200, 201): + raise SailhouseError( + f"Failed to register push subscription: {response.status_code} - {response.text}" + ) + + data = response.json() + return RegisterResult(outcome=data.get("outcome", "none")) \ No newline at end of file diff --git a/src/sailhouse/client.py b/src/sailhouse/client.py index 0102e20..1e0ff30 100644 --- a/src/sailhouse/client.py +++ b/src/sailhouse/client.py @@ -8,10 +8,39 @@ import websockets import asyncio import json +import uuid T = TypeVar('T') +@dataclass +class WaitOptions: + ttl: Optional[int] = None # TTL in seconds + + +@dataclass +class WaitGroup: + instance_id: str + client: 'SailhouseClient' + + async def publish( + self, + topic: str, + data: Dict[str, Any], + *, + scheduled_time: Optional[datetime] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> None: + """Publish an event under this waitgroup""" + await self.client.publish( + topic=topic, + data=data, + scheduled_time=scheduled_time, + metadata=metadata, + wait_group_instance_id=self.instance_id + ) + + @dataclass class Event(Generic[T]): id: str @@ -19,6 +48,7 @@ class Event(Generic[T]): _topic: str _subscription: str _client: 'SailhouseClient' + metadata: Optional[Dict[str, Any]] = None def as_type(self, cls: type[T]) -> T: """Convert event data to specified type""" @@ -57,6 +87,10 @@ def __init__( "Authorization": token, "x-source": "sailhouse-python" }) + + # Import here to avoid circular import + from .admin import AdminClient + self.admin = AdminClient(self) async def pull( self, @@ -64,7 +98,7 @@ async def pull( subscription: str, ) -> Event: """Pull an event from a subscription, locking it for processing""" - url = f"{self.BASE_URL}/topics/{topic}/subscriptions/{subscription}/events/pull" + url = f"{self.base_url}/topics/{topic}/subscriptions/{subscription}/events/pull" response = self.session.get(url, timeout=self.timeout) if response.status_code == 204: @@ -80,7 +114,8 @@ async def pull( data=data['data'], _topic=topic, _subscription=subscription, - _client=self + _client=self, + metadata=data.get('metadata') ) async def get_events( @@ -101,7 +136,7 @@ async def get_events( if time_window is not None: params['time_window'] = time_window - url = f"{self.BASE_URL}/topics/{topic}/subscriptions/{subscription}/events" + url = f"{self.base_url}/topics/{topic}/subscriptions/{subscription}/events" response = self.session.get(url, params=params, timeout=self.timeout) if response.status_code != 200: @@ -115,7 +150,8 @@ async def get_events( data=e['data'], _topic=topic, _subscription=subscription, - _client=self + _client=self, + metadata=e.get('metadata') ) for e in data['events'] ] @@ -126,22 +162,50 @@ async def get_events( limit=data.get('limit', 0) ) + def wait(self, options: Optional[WaitOptions] = None) -> WaitGroup: + """Create a waitgroup for coordinated event publishing""" + instance_id = str(uuid.uuid4()) + + if options and options.ttl: + # Create waitgroup with TTL + url = f"{self.base_url}/waitgroups" + body = { + "instance_id": instance_id, + "ttl": options.ttl + } + + response = self.session.post( + url, + json=body, + timeout=self.timeout + ) + + if response.status_code not in (200, 201): + raise SailhouseError( + f"Failed to create waitgroup: {response.status_code} - {response.text}" + ) + + return WaitGroup(instance_id=instance_id, client=self) + async def publish( self, topic: str, data: Dict[str, Any], *, scheduled_time: Optional[datetime] = None, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, + wait_group_instance_id: Optional[str] = None ) -> None: """Publish an event to a topic""" - url = f"{self.BASE_URL}/topics/{topic}/events" + url = f"{self.base_url}/topics/{topic}/events" body = {"data": data} if scheduled_time: body["send_at"] = scheduled_time.isoformat() if metadata: body["metadata"] = metadata + if wait_group_instance_id: + body["wait_group_instance_id"] = wait_group_instance_id response = self.session.post( url, @@ -161,7 +225,7 @@ async def acknowledge_message( event_id: str ) -> None: """Acknowledge a message""" - url = f"{self.BASE_URL}/topics/{topic}/subscriptions/{subscription}/events/{event_id}" + url = f"{self.base_url}/topics/{topic}/subscriptions/{subscription}/events/{event_id}" response = self.session.post(url, timeout=self.timeout) if response.status_code not in (200, 204): @@ -175,7 +239,7 @@ async def nack_message( event_id: str ) -> None: """Nacknowledge a message""" - url = f"{self.BASE_URL}/topics/{topic}/subscriptions/{subscription}/events/{event_id}/nack" + url = f"{self.base_url}/topics/{topic}/subscriptions/{subscription}/events/{event_id}/nack" response = self.session.post(url, timeout=self.timeout) if response.status_code not in (200, 204): @@ -209,3 +273,52 @@ async def subscribe( if exit_on_error: break continue + + def subscriber(self, options=None): + """Create a SailhouseSubscriber for long-running event processing""" + from .subscriber import SailhouseSubscriber, SubscriberOptions + return SailhouseSubscriber(self, options) + + def verify_push_subscription( + self, + webhook_secret: str, + signature_header: str, + raw_body, + tolerance: int = 300 + ) -> bool: + """Verify push subscription webhook signature""" + from .push_subscriptions import ( + PushSubscriptionVerifier, + PushSubscriptionHeaders, + PushSubscriptionPayload, + VerificationOptions + ) + + verifier = PushSubscriptionVerifier(webhook_secret) + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=raw_body) + options = VerificationOptions(tolerance=tolerance) + + return verifier.verify(headers, payload, options) + + def verify_push_subscription_safe( + self, + webhook_secret: str, + signature_header: str, + raw_body, + tolerance: int = 300 + ) -> bool: + """Safe push subscription verification - returns boolean instead of raising""" + from .push_subscriptions import ( + PushSubscriptionVerifier, + PushSubscriptionHeaders, + PushSubscriptionPayload, + VerificationOptions + ) + + verifier = PushSubscriptionVerifier(webhook_secret) + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=raw_body) + options = VerificationOptions(tolerance=tolerance) + + return verifier.verify_safe(headers, payload, options) diff --git a/src/sailhouse/push_subscriptions.py b/src/sailhouse/push_subscriptions.py new file mode 100644 index 0000000..4b39b6c --- /dev/null +++ b/src/sailhouse/push_subscriptions.py @@ -0,0 +1,191 @@ +import hmac +import hashlib +import time +from dataclasses import dataclass +from typing import Dict, Optional, Union, Literal +from .exceptions import SailhouseError + + +class PushSubscriptionVerificationError(SailhouseError): + """Custom error for push subscription verification failures""" + + INVALID_SIGNATURE_FORMAT = "invalid_signature_format" + MISSING_TIMESTAMP = "missing_timestamp" + TIMESTAMP_TOO_OLD = "timestamp_too_old" + TIMESTAMP_TOO_NEW = "timestamp_too_new" + SIGNATURE_MISMATCH = "signature_mismatch" + + def __init__(self, message: str, error_code: str): + super().__init__(message) + self.error_code = error_code + + +@dataclass +class SignatureComponents: + timestamp: int + signature: str + + +@dataclass +class PushSubscriptionHeaders: + signature: str + + +@dataclass +class PushSubscriptionPayload: + raw_body: Union[str, bytes] + + +@dataclass +class VerificationOptions: + tolerance: int = 300 # 5 minutes in seconds + + +class PushSubscriptionVerifier: + def __init__(self, webhook_secret: str): + self.webhook_secret = webhook_secret.encode('utf-8') + + def parse_signature_header(self, signature_header: str) -> SignatureComponents: + """Parse the signature header format: 't=,v1='""" + if not signature_header: + raise PushSubscriptionVerificationError( + "Signature header is missing", + PushSubscriptionVerificationError.INVALID_SIGNATURE_FORMAT + ) + + parts = signature_header.split(',') + timestamp = None + signature = None + + for part in parts: + if '=' not in part: + continue + key, value = part.split('=', 1) + if key == 't': + try: + timestamp = int(value) + except ValueError: + raise PushSubscriptionVerificationError( + "Invalid timestamp format", + PushSubscriptionVerificationError.INVALID_SIGNATURE_FORMAT + ) + elif key == 'v1': + signature = value + + if timestamp is None: + raise PushSubscriptionVerificationError( + "Missing timestamp in signature header", + PushSubscriptionVerificationError.MISSING_TIMESTAMP + ) + + if signature is None: + raise PushSubscriptionVerificationError( + "Missing signature in header", + PushSubscriptionVerificationError.INVALID_SIGNATURE_FORMAT + ) + + return SignatureComponents(timestamp=timestamp, signature=signature) + + def validate_timestamp(self, timestamp: int, tolerance: int = 300) -> None: + """Validate timestamp is within acceptable range""" + current_time = int(time.time()) + + if timestamp > current_time + tolerance: + raise PushSubscriptionVerificationError( + f"Timestamp is too new: {timestamp} > {current_time + tolerance}", + PushSubscriptionVerificationError.TIMESTAMP_TOO_NEW + ) + + if timestamp < current_time - tolerance: + raise PushSubscriptionVerificationError( + f"Timestamp is too old: {timestamp} < {current_time - tolerance}", + PushSubscriptionVerificationError.TIMESTAMP_TOO_OLD + ) + + def compute_signature(self, timestamp: int, payload: bytes) -> str: + """Compute HMAC-SHA256 signature""" + signed_payload = f"{timestamp}.".encode('utf-8') + payload + signature = hmac.new( + self.webhook_secret, + signed_payload, + hashlib.sha256 + ).hexdigest() + return signature + + def constant_time_compare(self, a: str, b: str) -> bool: + """Constant-time string comparison to prevent timing attacks""" + if len(a) != len(b): + return False + + result = 0 + for x, y in zip(a, b): + result |= ord(x) ^ ord(y) + return result == 0 + + def verify( + self, + headers: PushSubscriptionHeaders, + payload: PushSubscriptionPayload, + options: Optional[VerificationOptions] = None + ) -> bool: + """Verify push subscription signature - raises exception on failure""" + if options is None: + options = VerificationOptions() + + # Parse signature header + components = self.parse_signature_header(headers.signature) + + # Validate timestamp + self.validate_timestamp(components.timestamp, options.tolerance) + + # Convert payload to bytes if needed + if isinstance(payload.raw_body, str): + payload_bytes = payload.raw_body.encode('utf-8') + else: + payload_bytes = payload.raw_body + + # Compute expected signature + expected_signature = self.compute_signature(components.timestamp, payload_bytes) + + # Compare signatures using constant-time comparison + if not self.constant_time_compare(expected_signature, components.signature): + raise PushSubscriptionVerificationError( + "Signature verification failed", + PushSubscriptionVerificationError.SIGNATURE_MISMATCH + ) + + return True + + def verify_safe( + self, + headers: PushSubscriptionHeaders, + payload: PushSubscriptionPayload, + options: Optional[VerificationOptions] = None + ) -> bool: + """Safe version of verify that returns boolean instead of raising""" + try: + return self.verify(headers, payload, options) + except PushSubscriptionVerificationError: + return False + + +def verify_push_subscription_signature( + webhook_secret: str, + headers: PushSubscriptionHeaders, + payload: PushSubscriptionPayload, + options: Optional[VerificationOptions] = None +) -> bool: + """One-off verification function - raises exception on failure""" + verifier = PushSubscriptionVerifier(webhook_secret) + return verifier.verify(headers, payload, options) + + +def verify_push_subscription_signature_safe( + webhook_secret: str, + headers: PushSubscriptionHeaders, + payload: PushSubscriptionPayload, + options: Optional[VerificationOptions] = None +) -> bool: + """Safe one-off verification function - returns boolean""" + verifier = PushSubscriptionVerifier(webhook_secret) + return verifier.verify_safe(headers, payload, options) \ No newline at end of file diff --git a/src/sailhouse/subscriber.py b/src/sailhouse/subscriber.py new file mode 100644 index 0000000..4fde2de --- /dev/null +++ b/src/sailhouse/subscriber.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from typing import Dict, Any, Optional, Callable, Awaitable, List +from .client import Event, SailhouseClient +import asyncio + + +@dataclass +class SubscriberOptions: + per_subscription_processors: int = 1 + + +SubscriptionHandler = Callable[[Event], Awaitable[None]] + + +@dataclass +class Subscription: + topic: str + subscription: str + handler: SubscriptionHandler + + +class SailhouseSubscriber: + def __init__(self, client: SailhouseClient, options: Optional[SubscriberOptions] = None): + self.client = client + self.options = options or SubscriberOptions() + self.subscriptions: List[Subscription] = [] + self._running = False + self._tasks: List[asyncio.Task] = [] + + def subscribe(self, topic: str, subscription: str, handler: SubscriptionHandler) -> None: + """Register a topic/subscription handler""" + self.subscriptions.append(Subscription( + topic=topic, + subscription=subscription, + handler=handler + )) + + async def start(self) -> None: + """Start processing events for all registered subscriptions""" + if self._running: + return + + self._running = True + + # Create processor tasks for each subscription + for subscription in self.subscriptions: + for _ in range(self.options.per_subscription_processors): + task = asyncio.create_task( + self._process_subscription(subscription) + ) + self._tasks.append(task) + + # Wait for all tasks to complete (they run indefinitely until stopped) + if self._tasks: + await asyncio.gather(*self._tasks, return_exceptions=True) + + async def stop(self) -> None: + """Stop processing events""" + self._running = False + + # Cancel all running tasks + for task in self._tasks: + task.cancel() + + # Wait for tasks to be cancelled + if self._tasks: + await asyncio.gather(*self._tasks, return_exceptions=True) + + self._tasks.clear() + + async def _process_subscription(self, subscription: Subscription) -> None: + """Process events for a single subscription""" + while self._running: + try: + # Pull an event + event = await self.client.pull(subscription.topic, subscription.subscription) + + if event: + try: + # Process the event with the handler + await subscription.handler(event) + # Automatically acknowledge the event + await event.ack() + except Exception as e: + # If handler fails, we don't acknowledge the event + # This allows for retry logic at the server level + # Log the error but continue processing + pass + + # Continue immediately to try to fetch another event + continue + + # No events available, wait before polling again + await asyncio.sleep(1.0) + + except Exception as e: + # Error in pulling events, wait before retrying + await asyncio.sleep(1.0) + continue \ No newline at end of file diff --git a/test/admin_test.py b/test/admin_test.py new file mode 100644 index 0000000..7ba0fed --- /dev/null +++ b/test/admin_test.py @@ -0,0 +1,235 @@ +import pytest +from unittest.mock import patch +import json + +from sailhouse import SailhouseClient, SailhouseError +from sailhouse.admin import AdminClient, FilterCondition, ComplexFilter, RegisterResult + + +@pytest.fixture +def client(): + return SailhouseClient(token="test-token") + + +@pytest.fixture +def mock_response(): + class MockResponse: + def __init__(self, status_code, json_data=None): + self.status_code = status_code + self._json_data = json_data + self.text = json.dumps(json_data) if json_data else "" + + def json(self): + if self._json_data is None: + raise ValueError("No JSON data available") + return self._json_data + + return MockResponse + + +def test_admin_client_initialization(client): + assert hasattr(client, 'admin') + assert isinstance(client.admin, AdminClient) + assert client.admin._client is client + + +def test_filter_condition_creation(): + condition = FilterCondition(path="data.type", condition="eq", value="user") + assert condition.path == "data.type" + assert condition.condition == "eq" + assert condition.value == "user" + + +def test_complex_filter_creation(): + condition1 = FilterCondition(path="data.type", condition="eq", value="user") + condition2 = FilterCondition(path="data.status", condition="eq", value="active") + + complex_filter = ComplexFilter(conditions=[condition1, condition2], operator="and") + assert len(complex_filter.conditions) == 2 + assert complex_filter.operator == "and" + + +def test_complex_filter_default_operator(): + condition = FilterCondition(path="data.type", condition="eq", value="user") + complex_filter = ComplexFilter(conditions=[condition]) + assert complex_filter.operator == "and" + + +def test_register_push_subscription_basic(client, mock_response): + response_data = {"outcome": "created"} + + with patch.object(client.session, 'post', return_value=mock_response(201, response_data)): + result = client.admin.register_push_subscription( + topic="test-topic", + subscription="test-sub", + endpoint="https://example.com/webhook" + ) + + assert isinstance(result, RegisterResult) + assert result.outcome == "created" + + +def test_register_push_subscription_with_boolean_filter(client, mock_response): + response_data = {"outcome": "updated"} + + with patch.object(client.session, 'post', return_value=mock_response(200, response_data)) as mock_post: + result = client.admin.register_push_subscription( + topic="test-topic", + subscription="test-sub", + endpoint="https://example.com/webhook", + filter_condition=True + ) + + assert result.outcome == "updated" + + # Verify the request body + call_args = mock_post.call_args + body = call_args[1]['json'] + assert body['filter'] is True + + +def test_register_push_subscription_with_none_filter(client, mock_response): + response_data = {"outcome": "none"} + + with patch.object(client.session, 'post', return_value=mock_response(200, response_data)) as mock_post: + result = client.admin.register_push_subscription( + topic="test-topic", + subscription="test-sub", + endpoint="https://example.com/webhook", + filter_condition=None + ) + + assert result.outcome == "none" + + # Verify the request body + call_args = mock_post.call_args + body = call_args[1]['json'] + assert body['filter'] is None + + +def test_register_push_subscription_with_complex_filter(client, mock_response): + condition1 = FilterCondition(path="data.type", condition="eq", value="user") + condition2 = FilterCondition(path="data.status", condition="eq", value="active") + complex_filter = ComplexFilter(conditions=[condition1, condition2], operator="or") + + response_data = {"outcome": "created"} + + with patch.object(client.session, 'post', return_value=mock_response(201, response_data)) as mock_post: + result = client.admin.register_push_subscription( + topic="test-topic", + subscription="test-sub", + endpoint="https://example.com/webhook", + filter_condition=complex_filter + ) + + assert result.outcome == "created" + + # Verify the request body + call_args = mock_post.call_args + body = call_args[1]['json'] + expected_filter = { + "conditions": [ + {"path": "data.type", "condition": "eq", "value": "user"}, + {"path": "data.status", "condition": "eq", "value": "active"} + ], + "operator": "or" + } + assert body['filter'] == expected_filter + + +def test_register_push_subscription_with_rate_limit_and_deduplication(client, mock_response): + response_data = {"outcome": "updated"} + + with patch.object(client.session, 'post', return_value=mock_response(200, response_data)) as mock_post: + result = client.admin.register_push_subscription( + topic="test-topic", + subscription="test-sub", + endpoint="https://example.com/webhook", + rate_limit=100, + deduplication=True + ) + + assert result.outcome == "updated" + + # Verify the request body + call_args = mock_post.call_args + body = call_args[1]['json'] + assert body['rate_limit'] == 100 + assert body['deduplication'] is True + + +def test_register_push_subscription_with_all_options(client, mock_response): + condition = FilterCondition(path="data.priority", condition="gte", value=5) + complex_filter = ComplexFilter(conditions=[condition]) + + response_data = {"outcome": "created"} + + with patch.object(client.session, 'post', return_value=mock_response(201, response_data)) as mock_post: + result = client.admin.register_push_subscription( + topic="notifications", + subscription="high-priority", + endpoint="https://api.example.com/webhooks/notifications", + filter_condition=complex_filter, + rate_limit=50, + deduplication=False + ) + + assert result.outcome == "created" + + # Verify the request was made to the correct URL + call_args = mock_post.call_args + url = call_args[0][0] + assert url == "https://api.sailhouse.dev/admin/topics/notifications/subscriptions/high-priority/push" + + # Verify the complete request body + body = call_args[1]['json'] + expected_body = { + "endpoint": "https://api.example.com/webhooks/notifications", + "filter": { + "conditions": [ + {"path": "data.priority", "condition": "gte", "value": 5} + ], + "operator": "and" + }, + "rate_limit": 50, + "deduplication": False + } + assert body == expected_body + + +def test_register_push_subscription_failure(client, mock_response): + with patch.object(client.session, 'post', return_value=mock_response(400, {})): + with pytest.raises(SailhouseError) as exc_info: + client.admin.register_push_subscription( + topic="test-topic", + subscription="test-sub", + endpoint="https://example.com/webhook" + ) + + assert "Failed to register push subscription" in str(exc_info.value) + + +def test_register_push_subscription_server_error(client, mock_response): + with patch.object(client.session, 'post', return_value=mock_response(500, {})): + with pytest.raises(SailhouseError) as exc_info: + client.admin.register_push_subscription( + topic="test-topic", + subscription="test-sub", + endpoint="https://example.com/webhook" + ) + + assert "Failed to register push subscription: 500" in str(exc_info.value) + + +def test_register_result_default_outcome(client, mock_response): + # Test when response doesn't include outcome field + response_data = {} + + with patch.object(client.session, 'post', return_value=mock_response(200, response_data)): + result = client.admin.register_push_subscription( + topic="test-topic", + subscription="test-sub", + endpoint="https://example.com/webhook" + ) + + assert result.outcome == "none" \ No newline at end of file diff --git a/test/client_test.py b/test/client_test.py index b0144db..6e3e9c3 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -43,6 +43,24 @@ def test_event_creation(): assert event.id == "test-id" assert event.data == {"message": "test"} + assert event.metadata is None + + +def test_event_creation_with_metadata(): + client = SailhouseClient(token="test-token") + metadata = {"source": "api", "version": "1.0"} + event = Event( + id="test-id", + data={"message": "test"}, + _topic="test-topic", + _subscription="test-sub", + _client=client, + metadata=metadata + ) + + assert event.id == "test-id" + assert event.data == {"message": "test"} + assert event.metadata == metadata def test_event_as_type(): @@ -75,6 +93,34 @@ def test_client_initialization(): assert client.session.headers["x-source"] == "sailhouse-python" +def test_client_initialization_custom_base_url(): + custom_url = "https://custom.api.example.com" + client = SailhouseClient(token="test-token", base_url=custom_url) + assert client.base_url == custom_url + + +def test_client_initialization_custom_timeout(): + client = SailhouseClient(token="test-token", timeout=10.0) + assert client.timeout == 10.0 + + +@pytest.mark.asyncio +async def test_custom_base_url_used_in_publish(mock_response): + custom_url = "https://custom.api.example.com" + client = SailhouseClient(token="test-token", base_url=custom_url) + + test_data = {"message": "test"} + + with patch.object(client.session, 'post', return_value=mock_response(201, {})) as mock_post: + await client.publish("test-topic", test_data) + + # Verify the custom URL was used + call_args = mock_post.call_args + url = call_args[0][0] + assert url.startswith(custom_url) + assert url == f"{custom_url}/topics/test-topic/events" + + @pytest.mark.asyncio async def test_get_events_success(client, mock_response): test_events = { @@ -95,6 +141,27 @@ async def test_get_events_success(client, mock_response): assert response.limit == 10 assert isinstance(response.events[0], Event) assert response.events[0].data["message"] == "test1" + assert response.events[0].metadata is None + + +@pytest.mark.asyncio +async def test_get_events_with_metadata(client, mock_response): + test_events = { + "events": [ + {"id": "1", "data": {"message": "test1"}, "metadata": {"source": "api"}}, + {"id": "2", "data": {"message": "test2"}, "metadata": {"source": "webhook"}} + ], + "offset": 0, + "limit": 10 + } + + with patch.object(client.session, 'get', return_value=mock_response(200, test_events)): + response = await client.get_events("test-topic", "test-sub") + + assert isinstance(response, GetEventsResponse) + assert len(response.events) == 2 + assert response.events[0].metadata == {"source": "api"} + assert response.events[1].metadata == {"source": "webhook"} @pytest.mark.asyncio @@ -156,6 +223,24 @@ async def test_pull_success(client, mock_response): assert event.data == {"message": "test1"} assert event._topic == "test-topic" assert event._subscription == "test-sub" + assert event.metadata is None + + +@pytest.mark.asyncio +async def test_pull_success_with_metadata(client, mock_response): + test_event = { + "id": "1", + "data": {"message": "test1"}, + "metadata": {"source": "api", "version": "1.0"} + } + + with patch.object(client.session, 'get', return_value=mock_response(200, test_event)): + event = await client.pull("test-topic", "test-sub") + + assert isinstance(event, Event) + assert event.id == "1" + assert event.data == {"message": "test1"} + assert event.metadata == {"source": "api", "version": "1.0"} @pytest.mark.asyncio diff --git a/test/push_subscriptions_test.py b/test/push_subscriptions_test.py new file mode 100644 index 0000000..e3ab4d9 --- /dev/null +++ b/test/push_subscriptions_test.py @@ -0,0 +1,361 @@ +import pytest +import time +import hmac +import hashlib +from unittest.mock import patch + +from sailhouse import ( + SailhouseClient, + PushSubscriptionVerifier, + PushSubscriptionVerificationError, + SignatureComponents, + PushSubscriptionHeaders, + PushSubscriptionPayload, + VerificationOptions, + verify_push_subscription_signature, + verify_push_subscription_signature_safe +) + + +@pytest.fixture +def webhook_secret(): + return "test_webhook_secret_key" + + +@pytest.fixture +def test_payload(): + return '{"event": "user.created", "data": {"id": 123, "name": "John"}}' + + +@pytest.fixture +def verifier(webhook_secret): + return PushSubscriptionVerifier(webhook_secret) + + +def create_valid_signature(webhook_secret: str, timestamp: int, payload: str) -> str: + """Helper to create a valid signature for testing""" + signed_payload = f"{timestamp}.{payload}".encode('utf-8') + signature = hmac.new( + webhook_secret.encode('utf-8'), + signed_payload, + hashlib.sha256 + ).hexdigest() + return f"t={timestamp},v1={signature}" + + +def test_verification_options_creation(): + options = VerificationOptions(tolerance=600) + assert options.tolerance == 600 + + +def test_verification_options_default(): + options = VerificationOptions() + assert options.tolerance == 300 + + +def test_push_subscription_headers_creation(): + headers = PushSubscriptionHeaders(signature="t=123,v1=abc") + assert headers.signature == "t=123,v1=abc" + + +def test_push_subscription_payload_creation(): + payload = PushSubscriptionPayload(raw_body="test body") + assert payload.raw_body == "test body" + + +def test_verifier_initialization(webhook_secret): + verifier = PushSubscriptionVerifier(webhook_secret) + assert verifier.webhook_secret == webhook_secret.encode('utf-8') + + +def test_parse_signature_header_valid(verifier): + signature_header = "t=1640995200,v1=abc123def456" + components = verifier.parse_signature_header(signature_header) + + assert components.timestamp == 1640995200 + assert components.signature == "abc123def456" + + +def test_parse_signature_header_multiple_parts(verifier): + signature_header = "t=1640995200,v1=abc123,v2=def456" + components = verifier.parse_signature_header(signature_header) + + assert components.timestamp == 1640995200 + assert components.signature == "abc123" # Takes first v1 + + +def test_parse_signature_header_missing_timestamp(verifier): + signature_header = "v1=abc123def456" + + with pytest.raises(PushSubscriptionVerificationError) as exc_info: + verifier.parse_signature_header(signature_header) + + assert exc_info.value.error_code == PushSubscriptionVerificationError.MISSING_TIMESTAMP + + +def test_parse_signature_header_missing_signature(verifier): + signature_header = "t=1640995200" + + with pytest.raises(PushSubscriptionVerificationError) as exc_info: + verifier.parse_signature_header(signature_header) + + assert exc_info.value.error_code == PushSubscriptionVerificationError.INVALID_SIGNATURE_FORMAT + + +def test_parse_signature_header_invalid_timestamp(verifier): + signature_header = "t=invalid,v1=abc123" + + with pytest.raises(PushSubscriptionVerificationError) as exc_info: + verifier.parse_signature_header(signature_header) + + assert exc_info.value.error_code == PushSubscriptionVerificationError.INVALID_SIGNATURE_FORMAT + + +def test_parse_signature_header_empty(verifier): + with pytest.raises(PushSubscriptionVerificationError) as exc_info: + verifier.parse_signature_header("") + + assert exc_info.value.error_code == PushSubscriptionVerificationError.INVALID_SIGNATURE_FORMAT + + +def test_validate_timestamp_valid(verifier): + current_time = int(time.time()) + # Should not raise + verifier.validate_timestamp(current_time, tolerance=300) + verifier.validate_timestamp(current_time - 100, tolerance=300) + verifier.validate_timestamp(current_time + 100, tolerance=300) + + +def test_validate_timestamp_too_old(verifier): + current_time = int(time.time()) + old_timestamp = current_time - 400 # 400 seconds ago + + with pytest.raises(PushSubscriptionVerificationError) as exc_info: + verifier.validate_timestamp(old_timestamp, tolerance=300) + + assert exc_info.value.error_code == PushSubscriptionVerificationError.TIMESTAMP_TOO_OLD + + +def test_validate_timestamp_too_new(verifier): + current_time = int(time.time()) + future_timestamp = current_time + 400 # 400 seconds in future + + with pytest.raises(PushSubscriptionVerificationError) as exc_info: + verifier.validate_timestamp(future_timestamp, tolerance=300) + + assert exc_info.value.error_code == PushSubscriptionVerificationError.TIMESTAMP_TOO_NEW + + +def test_compute_signature(verifier, test_payload): + timestamp = 1640995200 + signature = verifier.compute_signature(timestamp, test_payload.encode('utf-8')) + + # Verify it's a valid hex string of correct length (SHA256 = 64 chars) + assert len(signature) == 64 + assert all(c in '0123456789abcdef' for c in signature) + + +def test_constant_time_compare(verifier): + # Same strings + assert verifier.constant_time_compare("abc123", "abc123") is True + + # Different strings same length + assert verifier.constant_time_compare("abc123", "def456") is False + + # Different lengths + assert verifier.constant_time_compare("abc", "abcdef") is False + + # Empty strings + assert verifier.constant_time_compare("", "") is True + + +def test_verify_valid_signature(verifier, webhook_secret, test_payload): + current_time = int(time.time()) + signature_header = create_valid_signature(webhook_secret, current_time, test_payload) + + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=test_payload) + + # Should not raise and return True + result = verifier.verify(headers, payload) + assert result is True + + +def test_verify_with_bytes_payload(verifier, webhook_secret, test_payload): + current_time = int(time.time()) + signature_header = create_valid_signature(webhook_secret, current_time, test_payload) + + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=test_payload.encode('utf-8')) + + result = verifier.verify(headers, payload) + assert result is True + + +def test_verify_invalid_signature(verifier, test_payload): + current_time = int(time.time()) + # Create signature with wrong secret + wrong_signature = create_valid_signature("wrong_secret", current_time, test_payload) + + headers = PushSubscriptionHeaders(signature=wrong_signature) + payload = PushSubscriptionPayload(raw_body=test_payload) + + with pytest.raises(PushSubscriptionVerificationError) as exc_info: + verifier.verify(headers, payload) + + assert exc_info.value.error_code == PushSubscriptionVerificationError.SIGNATURE_MISMATCH + + +def test_verify_old_timestamp(verifier, webhook_secret, test_payload): + old_timestamp = int(time.time()) - 400 # 400 seconds ago + signature_header = create_valid_signature(webhook_secret, old_timestamp, test_payload) + + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=test_payload) + + with pytest.raises(PushSubscriptionVerificationError) as exc_info: + verifier.verify(headers, payload) + + assert exc_info.value.error_code == PushSubscriptionVerificationError.TIMESTAMP_TOO_OLD + + +def test_verify_with_custom_tolerance(verifier, webhook_secret, test_payload): + old_timestamp = int(time.time()) - 400 # 400 seconds ago + signature_header = create_valid_signature(webhook_secret, old_timestamp, test_payload) + + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=test_payload) + options = VerificationOptions(tolerance=500) # Allow 500 seconds + + # Should not raise with larger tolerance + result = verifier.verify(headers, payload, options) + assert result is True + + +def test_verify_safe_valid_signature(verifier, webhook_secret, test_payload): + current_time = int(time.time()) + signature_header = create_valid_signature(webhook_secret, current_time, test_payload) + + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=test_payload) + + result = verifier.verify_safe(headers, payload) + assert result is True + + +def test_verify_safe_invalid_signature(verifier, test_payload): + current_time = int(time.time()) + wrong_signature = create_valid_signature("wrong_secret", current_time, test_payload) + + headers = PushSubscriptionHeaders(signature=wrong_signature) + payload = PushSubscriptionPayload(raw_body=test_payload) + + # Should return False instead of raising + result = verifier.verify_safe(headers, payload) + assert result is False + + +def test_one_off_verify_function(webhook_secret, test_payload): + current_time = int(time.time()) + signature_header = create_valid_signature(webhook_secret, current_time, test_payload) + + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=test_payload) + + result = verify_push_subscription_signature(webhook_secret, headers, payload) + assert result is True + + +def test_one_off_verify_safe_function(webhook_secret, test_payload): + current_time = int(time.time()) + wrong_signature = create_valid_signature("wrong_secret", current_time, test_payload) + + headers = PushSubscriptionHeaders(signature=wrong_signature) + payload = PushSubscriptionPayload(raw_body=test_payload) + + result = verify_push_subscription_signature_safe(webhook_secret, headers, payload) + assert result is False + + +def test_client_verify_push_subscription(): + client = SailhouseClient(token="test-token") + webhook_secret = "test_secret" + test_payload = '{"test": "data"}' + current_time = int(time.time()) + signature_header = create_valid_signature(webhook_secret, current_time, test_payload) + + result = client.verify_push_subscription( + webhook_secret=webhook_secret, + signature_header=signature_header, + raw_body=test_payload + ) + assert result is True + + +def test_client_verify_push_subscription_safe(): + client = SailhouseClient(token="test-token") + webhook_secret = "test_secret" + test_payload = '{"test": "data"}' + current_time = int(time.time()) + wrong_signature = create_valid_signature("wrong_secret", current_time, test_payload) + + result = client.verify_push_subscription_safe( + webhook_secret=webhook_secret, + signature_header=wrong_signature, + raw_body=test_payload + ) + assert result is False + + +def test_client_verify_push_subscription_with_custom_tolerance(): + client = SailhouseClient(token="test-token") + webhook_secret = "test_secret" + test_payload = '{"test": "data"}' + old_timestamp = int(time.time()) - 400 + signature_header = create_valid_signature(webhook_secret, old_timestamp, test_payload) + + result = client.verify_push_subscription( + webhook_secret=webhook_secret, + signature_header=signature_header, + raw_body=test_payload, + tolerance=500 # Allow 500 seconds + ) + assert result is True + + +def test_error_code_constants(): + # Verify error code constants exist and are correct + assert PushSubscriptionVerificationError.INVALID_SIGNATURE_FORMAT == "invalid_signature_format" + assert PushSubscriptionVerificationError.MISSING_TIMESTAMP == "missing_timestamp" + assert PushSubscriptionVerificationError.TIMESTAMP_TOO_OLD == "timestamp_too_old" + assert PushSubscriptionVerificationError.TIMESTAMP_TOO_NEW == "timestamp_too_new" + assert PushSubscriptionVerificationError.SIGNATURE_MISMATCH == "signature_mismatch" + + +def test_real_world_webhook_scenario(webhook_secret): + """Test a realistic webhook verification scenario""" + # Simulate a real webhook payload + webhook_payload = '''{ + "event": "subscription.created", + "data": { + "id": "sub_123", + "status": "active", + "created_at": "2024-01-01T00:00:00Z" + }, + "timestamp": "2024-01-01T00:00:00Z" + }''' + + current_time = int(time.time()) + signature_header = create_valid_signature(webhook_secret, current_time, webhook_payload) + + verifier = PushSubscriptionVerifier(webhook_secret) + headers = PushSubscriptionHeaders(signature=signature_header) + payload = PushSubscriptionPayload(raw_body=webhook_payload) + + # Should verify successfully + result = verifier.verify(headers, payload) + assert result is True + + # Should also work with safe version + safe_result = verifier.verify_safe(headers, payload) + assert safe_result is True \ No newline at end of file diff --git a/test/subscriber_test.py b/test/subscriber_test.py new file mode 100644 index 0000000..acddddb --- /dev/null +++ b/test/subscriber_test.py @@ -0,0 +1,323 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, Mock, patch + +from sailhouse import SailhouseClient, SailhouseSubscriber, SubscriberOptions, Event + + +@pytest.fixture +def client(): + return SailhouseClient(token="test-token") + + +@pytest.fixture +def mock_event(client): + event = Event( + id="test-event-id", + data={"message": "test"}, + _topic="test-topic", + _subscription="test-sub", + _client=client + ) + event.ack = AsyncMock() + return event + + +def test_subscriber_options_creation(): + options = SubscriberOptions(per_subscription_processors=5) + assert options.per_subscription_processors == 5 + + +def test_subscriber_options_default(): + options = SubscriberOptions() + assert options.per_subscription_processors == 1 + + +def test_subscriber_creation_through_client(client): + subscriber = client.subscriber() + + assert isinstance(subscriber, SailhouseSubscriber) + assert subscriber.client is client + assert subscriber.options.per_subscription_processors == 1 + assert len(subscriber.subscriptions) == 0 + assert not subscriber._running + + +def test_subscriber_creation_with_options(client): + options = SubscriberOptions(per_subscription_processors=3) + subscriber = client.subscriber(options) + + assert isinstance(subscriber, SailhouseSubscriber) + assert subscriber.options.per_subscription_processors == 3 + + +def test_subscriber_direct_creation(client): + options = SubscriberOptions(per_subscription_processors=2) + subscriber = SailhouseSubscriber(client, options) + + assert subscriber.client is client + assert subscriber.options.per_subscription_processors == 2 + + +def test_subscribe_method(client): + subscriber = client.subscriber() + handler = AsyncMock() + + subscriber.subscribe("test-topic", "test-subscription", handler) + + assert len(subscriber.subscriptions) == 1 + subscription = subscriber.subscriptions[0] + assert subscription.topic == "test-topic" + assert subscription.subscription == "test-subscription" + assert subscription.handler is handler + + +def test_subscribe_multiple_subscriptions(client): + subscriber = client.subscriber() + handler1 = AsyncMock() + handler2 = AsyncMock() + + subscriber.subscribe("topic1", "sub1", handler1) + subscriber.subscribe("topic2", "sub2", handler2) + + assert len(subscriber.subscriptions) == 2 + assert subscriber.subscriptions[0].topic == "topic1" + assert subscriber.subscriptions[1].topic == "topic2" + + +@pytest.mark.asyncio +async def test_start_and_stop_subscriber(client): + subscriber = client.subscriber() + handler = AsyncMock() + + subscriber.subscribe("test-topic", "test-sub", handler) + + # Mock pull to return None (no events) + with patch.object(client, 'pull', new_callable=AsyncMock) as mock_pull: + mock_pull.return_value = None + + # Start the subscriber + start_task = asyncio.create_task(subscriber.start()) + + # Give it a moment to start + await asyncio.sleep(0.1) + assert subscriber._running is True + assert len(subscriber._tasks) == 1 # One processor per subscription + + # Stop the subscriber + await subscriber.stop() + assert subscriber._running is False + assert len(subscriber._tasks) == 0 + + # Clean up the start task + start_task.cancel() + try: + await start_task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_subscriber_processes_events(client, mock_event): + subscriber = client.subscriber() + handler = AsyncMock() + + subscriber.subscribe("test-topic", "test-sub", handler) + + with patch.object(client, 'pull', new_callable=AsyncMock) as mock_pull: + # Return one event, then None + mock_pull.side_effect = [mock_event, None] + + # Start subscriber in background + start_task = asyncio.create_task(subscriber.start()) + + # Wait for event processing + await asyncio.sleep(0.2) + + # Verify handler was called + handler.assert_called_once_with(mock_event) + # Verify event was acknowledged + mock_event.ack.assert_called_once() + + # Stop subscriber + await subscriber.stop() + start_task.cancel() + try: + await start_task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_subscriber_multiple_processors(client, mock_event): + options = SubscriberOptions(per_subscription_processors=3) + subscriber = client.subscriber(options) + handler = AsyncMock() + + subscriber.subscribe("test-topic", "test-sub", handler) + + with patch.object(client, 'pull', new_callable=AsyncMock) as mock_pull: + mock_pull.return_value = None + + # Start subscriber + start_task = asyncio.create_task(subscriber.start()) + + # Give it a moment to start + await asyncio.sleep(0.1) + + # Should have 3 processor tasks (3 processors per subscription) + assert len(subscriber._tasks) == 3 + + # Stop subscriber + await subscriber.stop() + start_task.cancel() + try: + await start_task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_subscriber_multiple_subscriptions_multiple_processors(client): + options = SubscriberOptions(per_subscription_processors=2) + subscriber = client.subscriber(options) + handler1 = AsyncMock() + handler2 = AsyncMock() + + subscriber.subscribe("topic1", "sub1", handler1) + subscriber.subscribe("topic2", "sub2", handler2) + + with patch.object(client, 'pull', new_callable=AsyncMock) as mock_pull: + mock_pull.return_value = None + + start_task = asyncio.create_task(subscriber.start()) + await asyncio.sleep(0.1) + + # Should have 4 processor tasks (2 subscriptions × 2 processors each) + assert len(subscriber._tasks) == 4 + + await subscriber.stop() + start_task.cancel() + try: + await start_task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_subscriber_handler_exception_no_ack(client, mock_event): + subscriber = client.subscriber() + handler = AsyncMock() + handler.side_effect = Exception("Handler failed") + + subscriber.subscribe("test-topic", "test-sub", handler) + + with patch.object(client, 'pull', new_callable=AsyncMock) as mock_pull: + # Return one event, then None + mock_pull.side_effect = [mock_event, None] + + start_task = asyncio.create_task(subscriber.start()) + await asyncio.sleep(0.2) + + # Handler was called but failed + handler.assert_called_once_with(mock_event) + # Event should NOT be acknowledged when handler fails + mock_event.ack.assert_not_called() + + await subscriber.stop() + start_task.cancel() + try: + await start_task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_subscriber_pull_exception_continues(client): + subscriber = client.subscriber() + handler = AsyncMock() + + subscriber.subscribe("test-topic", "test-sub", handler) + + with patch.object(client, 'pull', new_callable=AsyncMock) as mock_pull: + # Multiple calls: first raises exception, then returns None repeatedly + mock_pull.side_effect = [Exception("Pull failed")] + [None] * 10 + + start_task = asyncio.create_task(subscriber.start()) + await asyncio.sleep(1.5) # Give more time for retries + + # Should have attempted pull multiple times despite first failure + assert mock_pull.call_count >= 2 + # Handler should not have been called + handler.assert_not_called() + + await subscriber.stop() + start_task.cancel() + try: + await start_task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_subscriber_start_when_already_running(client): + subscriber = client.subscriber() + handler = AsyncMock() + subscriber.subscribe("test-topic", "test-sub", handler) + + with patch.object(client, 'pull', new_callable=AsyncMock) as mock_pull: + mock_pull.return_value = None + + # Start first time + start_task1 = asyncio.create_task(subscriber.start()) + await asyncio.sleep(0.1) + + initial_task_count = len(subscriber._tasks) + + # Start again - should not create duplicate tasks + start_task2 = asyncio.create_task(subscriber.start()) + await asyncio.sleep(0.1) + + assert len(subscriber._tasks) == initial_task_count + + await subscriber.stop() + start_task1.cancel() + start_task2.cancel() + try: + await asyncio.gather(start_task1, start_task2) + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_subscriber_continuous_processing(client): + subscriber = client.subscriber() + handler = AsyncMock() + + subscriber.subscribe("test-topic", "test-sub", handler) + + # Create multiple mock events + events = [ + Event("id1", {"msg": "1"}, "test-topic", "test-sub", client), + Event("id2", {"msg": "2"}, "test-topic", "test-sub", client), + None # No more events + ] + + for event in events[:-1]: # Add ack mock to real events + event.ack = AsyncMock() + + with patch.object(client, 'pull', new_callable=AsyncMock) as mock_pull: + mock_pull.side_effect = events + + start_task = asyncio.create_task(subscriber.start()) + await asyncio.sleep(0.3) # Give time to process events + + # Should have processed both events + assert handler.call_count >= 2 + + await subscriber.stop() + start_task.cancel() + try: + await start_task + except asyncio.CancelledError: + pass \ No newline at end of file diff --git a/test/waitgroups_test.py b/test/waitgroups_test.py new file mode 100644 index 0000000..6134600 --- /dev/null +++ b/test/waitgroups_test.py @@ -0,0 +1,193 @@ +import pytest +from unittest.mock import patch +import json +from datetime import datetime + +from sailhouse import SailhouseClient, SailhouseError, WaitOptions, WaitGroup + + +@pytest.fixture +def client(): + return SailhouseClient(token="test-token") + + +@pytest.fixture +def mock_response(): + class MockResponse: + def __init__(self, status_code, json_data=None): + self.status_code = status_code + self._json_data = json_data + self.text = json.dumps(json_data) if json_data else "" + + def json(self): + if self._json_data is None: + raise ValueError("No JSON data available") + return self._json_data + + return MockResponse + + +def test_wait_options_creation(): + options = WaitOptions(ttl=300) + assert options.ttl == 300 + + +def test_wait_options_default(): + options = WaitOptions() + assert options.ttl is None + + +def test_wait_without_options(client): + waitgroup = client.wait() + + assert isinstance(waitgroup, WaitGroup) + assert waitgroup.instance_id is not None + assert len(waitgroup.instance_id) > 0 + assert waitgroup.client is client + + +def test_wait_with_options_no_ttl(client): + options = WaitOptions() + waitgroup = client.wait(options) + + assert isinstance(waitgroup, WaitGroup) + assert waitgroup.instance_id is not None + assert waitgroup.client is client + + +def test_wait_with_ttl(client, mock_response): + options = WaitOptions(ttl=600) + + with patch.object(client.session, 'post', return_value=mock_response(201, {})) as mock_post: + waitgroup = client.wait(options) + + assert isinstance(waitgroup, WaitGroup) + assert waitgroup.instance_id is not None + assert waitgroup.client is client + + # Verify API call was made + mock_post.assert_called_once() + call_args = mock_post.call_args + url = call_args[0][0] + body = call_args[1]['json'] + + assert url == "https://api.sailhouse.dev/waitgroups" + assert body['ttl'] == 600 + assert 'instance_id' in body + + +def test_wait_with_ttl_failure(client, mock_response): + options = WaitOptions(ttl=600) + + with patch.object(client.session, 'post', return_value=mock_response(400, {})): + with pytest.raises(SailhouseError) as exc_info: + client.wait(options) + + assert "Failed to create waitgroup" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_publish_with_wait_group_instance_id(client, mock_response): + test_data = {"message": "test"} + wait_group_id = "test-wait-group-id" + + with patch.object(client.session, 'post', return_value=mock_response(201, {})) as mock_post: + await client.publish( + "test-topic", + test_data, + wait_group_instance_id=wait_group_id + ) + + # Verify the request body includes wait_group_instance_id + call_args = mock_post.call_args + body = call_args[1]['json'] + assert body['wait_group_instance_id'] == wait_group_id + assert body['data'] == test_data + + +@pytest.mark.asyncio +async def test_publish_without_wait_group_instance_id(client, mock_response): + test_data = {"message": "test"} + + with patch.object(client.session, 'post', return_value=mock_response(201, {})) as mock_post: + await client.publish("test-topic", test_data) + + # Verify the request body does not include wait_group_instance_id + call_args = mock_post.call_args + body = call_args[1]['json'] + assert 'wait_group_instance_id' not in body + assert body['data'] == test_data + + +@pytest.mark.asyncio +async def test_waitgroup_publish(client, mock_response): + # Create a waitgroup + waitgroup = client.wait() + test_data = {"message": "test from waitgroup"} + metadata = {"source": "waitgroup"} + scheduled_time = datetime.now() + + with patch.object(client.session, 'post', return_value=mock_response(201, {})) as mock_post: + await waitgroup.publish( + "test-topic", + test_data, + metadata=metadata, + scheduled_time=scheduled_time + ) + + # Verify the request includes the waitgroup instance_id + call_args = mock_post.call_args + body = call_args[1]['json'] + assert body['wait_group_instance_id'] == waitgroup.instance_id + assert body['data'] == test_data + assert body['metadata'] == metadata + assert 'send_at' in body + + +@pytest.mark.asyncio +async def test_waitgroup_publish_minimal(client, mock_response): + waitgroup = client.wait() + test_data = {"simple": "message"} + + with patch.object(client.session, 'post', return_value=mock_response(201, {})) as mock_post: + await waitgroup.publish("test-topic", test_data) + + call_args = mock_post.call_args + body = call_args[1]['json'] + assert body['wait_group_instance_id'] == waitgroup.instance_id + assert body['data'] == test_data + assert 'metadata' not in body + assert 'send_at' not in body + + +def test_multiple_waitgroups_have_different_ids(client): + waitgroup1 = client.wait() + waitgroup2 = client.wait() + + assert waitgroup1.instance_id != waitgroup2.instance_id + assert isinstance(waitgroup1, WaitGroup) + assert isinstance(waitgroup2, WaitGroup) + + +@pytest.mark.asyncio +async def test_publish_with_all_parameters_including_waitgroup(client, mock_response): + test_data = {"complex": "message"} + metadata = {"key": "value"} + scheduled_time = datetime.now() + wait_group_id = "complex-wait-group" + + with patch.object(client.session, 'post', return_value=mock_response(201, {})) as mock_post: + await client.publish( + "complex-topic", + test_data, + metadata=metadata, + scheduled_time=scheduled_time, + wait_group_instance_id=wait_group_id + ) + + call_args = mock_post.call_args + body = call_args[1]['json'] + assert body['data'] == test_data + assert body['metadata'] == metadata + assert body['wait_group_instance_id'] == wait_group_id + assert 'send_at' in body \ No newline at end of file