diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c7e4940..e8407f6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,6 +34,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "1.8.5" virtualenvs-create: true - name: Install dependencies diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5551f08..8690c95 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -47,6 +47,7 @@ jobs: if: steps.version-check.outputs.skip == 'false' uses: snok/install-poetry@v1 with: + version: "1.8.5" virtualenvs-create: true - name: Install dependencies diff --git a/dev_runner.py b/dev_runner.py new file mode 100755 index 0000000..d7279d6 --- /dev/null +++ b/dev_runner.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +""" +Development runner to observe SDK behavior including SSE streaming and watchdog. + +Usage: + REFORGE_SDK_KEY=your-key python dev_runner.py + +Or set a specific config key to watch: + REFORGE_SDK_KEY=your-key python dev_runner.py my.config.key +""" + +import logging +import sys +import time +import os + +from sdk_reforge import ReforgeSDK, Options + + +def setup_logging() -> None: + """Configure logging to show SDK internals.""" + root_logger = logging.getLogger() + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + "%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", + ) + ) + root_logger.addHandler(handler) + + # Set root to DEBUG to see everything + root_logger.setLevel(logging.DEBUG) + + # Quiet down noisy libraries + logging.getLogger("urllib3").setLevel(logging.WARNING) + logging.getLogger("requests").setLevel(logging.WARNING) + + +def main() -> None: + setup_logging() + + sdk_key = os.environ.get("REFORGE_SDK_KEY") + if not sdk_key: + print("Error: REFORGE_SDK_KEY environment variable not set") + print("Usage: REFORGE_SDK_KEY=your-key python dev_runner.py [config.key]") + sys.exit(1) + + # Optional: config key to watch + watch_key = sys.argv[1] if len(sys.argv) > 1 else None + + print(f"Starting SDK with key: {sdk_key[:10]}...") + print(f"Watching config key: {watch_key or '(none)'}") + print("Press Ctrl+C to stop\n") + print("=" * 60) + + options = Options( + sdk_key=sdk_key, + connection_timeout_seconds=10, + ) + + sdk = ReforgeSDK(options) + config_sdk = sdk.config_sdk() + + print("=" * 60) + print("SDK initialized, entering main loop...") + print("=" * 60 + "\n") + + try: + iteration = 0 + while True: + iteration += 1 + + status_parts = [ + f"[{iteration}]", + f"initialized={config_sdk.is_ready()}", + f"hwm={config_sdk.highwater_mark()}", + ] + + if watch_key: + try: + value = config_sdk.get(watch_key, default="") + status_parts.append(f"{watch_key}={value!r}") + except Exception as e: + status_parts.append(f"{watch_key}=") + + print(" | ".join(status_parts)) + time.sleep(5) + + except KeyboardInterrupt: + print("\n\nShutting down...") + + sdk.close() + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index b7ed9b3..a293bed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sdk-reforge" -version = "1.1.1" +version = "1.2.0" description = "Python sdk for Reforge Feature Flags and Config as a Service: https://www.reforge.com" license = "MIT" authors = ["Michael Berkowitz ", "James Kebinger "] diff --git a/sdk_reforge/VERSION b/sdk_reforge/VERSION index 524cb55..26aaba0 100644 --- a/sdk_reforge/VERSION +++ b/sdk_reforge/VERSION @@ -1 +1 @@ -1.1.1 +1.2.0 diff --git a/sdk_reforge/_sse_connection_manager.py b/sdk_reforge/_sse_connection_manager.py index 8ad64f1..72d93a7 100644 --- a/sdk_reforge/_sse_connection_manager.py +++ b/sdk_reforge/_sse_connection_manager.py @@ -1,15 +1,20 @@ import base64 import time -from typing import Optional, Callable +from typing import Optional, Callable, TYPE_CHECKING import sseclient # type: ignore from requests import Response +from requests.exceptions import HTTPError from sdk_reforge._internal_logging import InternalLogger -from sdk_reforge._requests import ApiClient, UnauthorizedException +from sdk_reforge._requests import ApiClient +from sdk_reforge._sse_watchdog import WatchdogResponseWrapper import prefab_pb2 as Prefab from sdk_reforge.config_sdk_interface import ConfigSDKInterface +if TYPE_CHECKING: + from sdk_reforge._sse_watchdog import SSEWatchdog + SHORT_CONNECTION_THRESHOLD = 2 # seconds CONSECUTIVE_SHORT_CONNECTION_LIMIT = 2 # times MIN_BACKOFF_TIME = 1 # seconds @@ -29,68 +34,98 @@ def __init__( api_client: ApiClient, config_client: ConfigSDKInterface, urls: list[str], + watchdog: Optional["SSEWatchdog"] = None, ): self.api_client = api_client self.config_client = config_client self.sse_client: Optional[sseclient.SSEClient] = None self.timing = Timing() self.urls = urls + self.watchdog = watchdog def streaming_loop(self) -> None: too_short_connection_count = 0 backoff_time = MIN_BACKOFF_TIME - while self.config_client.continue_connection_processing(): - try: - logger.debug("Starting streaming connection") - headers = { - "Last-Event-ID": f"{self.config_client.highwater_mark()}", - "accept": "text/event-stream", - } - response = self.api_client.resilient_request( - "/api/v2/sse/config", - headers=headers, - stream=True, - auth=("authuser", self.config_client.options.api_key), - timeout=(5, 60), - hosts=self.urls, - ) - response.raise_for_status() - if response.ok: - elapsed_time = self.timing.time_execution( - lambda: self.process_response(response) + try: + while self.config_client.continue_connection_processing(): + try: + logger.debug("Starting streaming connection") + headers = { + "Last-Event-ID": f"{self.config_client.highwater_mark()}", + "accept": "text/event-stream", + } + response = self.api_client.resilient_request( + "/api/v2/sse/config", + headers=headers, + stream=True, + auth=("authuser", self.config_client.options.api_key), + timeout=(5, 60), + hosts=self.urls, ) - if elapsed_time < SHORT_CONNECTION_THRESHOLD: - too_short_connection_count += 1 - if ( - too_short_connection_count - >= CONSECUTIVE_SHORT_CONNECTION_LIMIT - ): - raise TooQuickConnectionException() - else: - too_short_connection_count = 0 - backoff_time = MIN_BACKOFF_TIME - time.sleep(backoff_time) - except UnauthorizedException: - self.config_client.handle_unauthorized_response() - except TooQuickConnectionException as e: - logger.debug(f"Connection ended quickly: {str(e)}. Will apply backoff.") - backoff_time = min(backoff_time * 2, MAX_BACKOFF_TIME) - time.sleep(backoff_time) - except Exception as e: - if not self.config_client.is_shutting_down(): - logger.warning( - f"Streaming connection error: {str(e)}, Will retry in {backoff_time} seconds" + response.raise_for_status() + if response.ok: + elapsed_time = self.timing.time_execution( + lambda: self.process_response(response) + ) + if elapsed_time < SHORT_CONNECTION_THRESHOLD: + too_short_connection_count += 1 + if ( + too_short_connection_count + >= CONSECUTIVE_SHORT_CONNECTION_LIMIT + ): + raise TooQuickConnectionException() + else: + too_short_connection_count = 0 + backoff_time = MIN_BACKOFF_TIME + time.sleep(backoff_time) + except TooQuickConnectionException as e: + logger.debug( + f"Connection ended quickly: {str(e)}. Will apply backoff." ) backoff_time = min(backoff_time * 2, MAX_BACKOFF_TIME) time.sleep(backoff_time) - - """ - Hand off a successful response here for processing - """ + except HTTPError as e: + # Check for unauthorized (401/403) responses + if e.response is not None and e.response.status_code in (401, 403): + logger.warning( + f"Received {e.response.status_code} response, stopping SSE" + ) + self.config_client.handle_unauthorized_response() + else: + if not self.config_client.is_shutting_down(): + backoff_time = min(backoff_time * 2, MAX_BACKOFF_TIME) + logger.warning( + f"Streaming connection error ({type(e).__name__}): {str(e)}, " + f"Will retry in {backoff_time} seconds" + ) + time.sleep(backoff_time) + except BaseException as e: + # Re-raise system exceptions that should terminate the thread + if isinstance(e, (KeyboardInterrupt, SystemExit)): + raise + if not self.config_client.is_shutting_down(): + backoff_time = min(backoff_time * 2, MAX_BACKOFF_TIME) + logger.warning( + f"Streaming connection error ({type(e).__name__}): {str(e)}, " + f"Will retry in {backoff_time} seconds" + ) + time.sleep(backoff_time) + finally: + logger.info( + f"Streaming loop exited " + f"(shutdown={self.config_client.is_shutting_down()})" + ) def process_response(self, response: Response) -> None: - self.sse_client = sseclient.SSEClient(response) + """Hand off a successful response here for processing.""" + # Wrap response to track data received for watchdog + if self.watchdog: + wrapped_response = WatchdogResponseWrapper(response, self.watchdog.touch) + self.sse_client = sseclient.SSEClient(wrapped_response) + else: + self.sse_client = sseclient.SSEClient(response) + if self.sse_client is not None: for event in self.sse_client.events(): if self.config_client.is_shutting_down(): diff --git a/sdk_reforge/_sse_watchdog.py b/sdk_reforge/_sse_watchdog.py new file mode 100644 index 0000000..e39eff6 --- /dev/null +++ b/sdk_reforge/_sse_watchdog.py @@ -0,0 +1,122 @@ +import threading +import time +from typing import Any, Callable, Iterator, Optional, TYPE_CHECKING + +from ._internal_logging import InternalLogger + +if TYPE_CHECKING: + from .config_sdk_interface import ConfigSDKInterface + +logger = InternalLogger(__name__) + +DEFAULT_CHECK_INTERVAL: float = 60 # seconds +DEFAULT_MAX_SILENCE: float = 120 # seconds (4 missed 30s keepalives) + + +class WatchdogResponseWrapper: + """Wraps a response to touch the watchdog on any data received. + + This allows the watchdog to track when ANY data is received from the SSE + connection, including keepalive comments that sseclient filters out. + """ + + def __init__(self, response: Any, on_data_received: Callable[[], None]) -> None: + self._response = response + self._on_data_received = on_data_received + + def __iter__(self) -> Iterator[Any]: + for chunk in self._response: + self._on_data_received() + yield chunk + + def close(self) -> None: + self._response.close() + + +class SSEWatchdog: + """Monitors SSE connection health and triggers recovery when stuck. + + The watchdog runs in a separate thread and periodically checks if SSE data + has been received recently. If no data (including keepalives) has been + received for max_silence seconds, it: + 1. Logs a warning + 2. Polls the checkpoint API to get fresh config data + 3. Closes the SSE connection to force reconnection + """ + + def __init__( + self, + config_client: "ConfigSDKInterface", + poll_fallback_fn: Callable[[], None], + get_sse_client_fn: Callable[[], Any], + check_interval: float = DEFAULT_CHECK_INTERVAL, + max_silence: float = DEFAULT_MAX_SILENCE, + ) -> None: + """Initialize the watchdog. + + Args: + config_client: The config client interface for checking shutdown state + poll_fallback_fn: Function to call to poll for fresh config data + get_sse_client_fn: Function that returns the current SSE client (or None) + check_interval: How often to check for silence (seconds) + max_silence: Trigger recovery after this many seconds of no data + """ + self.config_client = config_client + self.poll_fallback_fn = poll_fallback_fn + self.get_sse_client_fn = get_sse_client_fn + self.check_interval = check_interval + self.max_silence = max_silence + self.last_activity = time.time() + self._stop = threading.Event() + self._thread: Optional[threading.Thread] = None + + def touch(self) -> None: + """Called when any SSE data is received (including keepalives).""" + self.last_activity = time.time() + + def start(self) -> None: + """Start the watchdog thread.""" + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def stop(self) -> None: + """Stop the watchdog thread.""" + self._stop.set() + if self._thread: + self._thread.join(timeout=5) + + def _run(self) -> None: + """Main watchdog loop.""" + while not self._stop.wait(self.check_interval): + if self.config_client.is_shutting_down(): + break + + silence = time.time() - self.last_activity + if silence > self.max_silence: + self._trigger_recovery(silence) + + def _trigger_recovery(self, silence: float) -> None: + """Trigger recovery actions when SSE appears stuck.""" + logger.warning( + f"SSE connection appears stuck (no activity for {silence:.0f}s), " + "triggering recovery" + ) + + # 1. Poll for fresh data immediately + try: + self.poll_fallback_fn() + logger.info("Fallback poll completed successfully") + except Exception as e: + logger.warning(f"Fallback poll failed: {e}") + + # 2. Force SSE reconnection by closing current connection + try: + sse_client = self.get_sse_client_fn() + if sse_client: + sse_client.close() + logger.debug("Closed SSE client to force reconnection") + except Exception: + pass # Best effort + + # Reset activity timer after recovery attempt + self.last_activity = time.time() diff --git a/sdk_reforge/config_sdk.py b/sdk_reforge/config_sdk.py index 6c85a70..8641c57 100644 --- a/sdk_reforge/config_sdk.py +++ b/sdk_reforge/config_sdk.py @@ -11,6 +11,7 @@ from ._count_down_latch import CountDownLatch from ._requests import ApiClient, UnauthorizedException from ._sse_connection_manager import SSEConnectionManager +from ._sse_watchdog import SSEWatchdog from .config_sdk_interface import ConfigSDKInterface from .config_loader import ConfigLoader from .config_resolver import ConfigResolver @@ -47,7 +48,7 @@ def __init__(self, base_client): self.is_initialized = threading.Event() self.checkpointing_thread = None self.streaming_thread = None - self.sse_client = None + self.watchdog: Optional[SSEWatchdog] = None logger.info("Initializing ConfigClient") self.base_client = base_client self._options = base_client.options @@ -60,8 +61,15 @@ def __init__(self, base_client): self._cache_path = None self.set_cache_path() self.api_client = ApiClient(self.options) + + # Create watchdog for SSE connection health monitoring + self.watchdog = SSEWatchdog( + config_client=self, + poll_fallback_fn=self._watchdog_poll_fallback, + get_sse_client_fn=lambda: self.sse_connection_manager.sse_client, + ) self.sse_connection_manager = SSEConnectionManager( - self.api_client, self, self.options.reforge_stream_urls + self.api_client, self, self.options.reforge_stream_urls, self.watchdog ) if self.options.is_local_only(): @@ -127,6 +135,16 @@ def load_checkpoint(self): logger.warning("No success loading checkpoints") except UnauthorizedException: self.handle_unauthorized_response() + return + except Exception as e: + logger.error(f"Unexpected error loading checkpoint: {e}") + + # If we get here, checkpoint loading failed - start streaming as fallback + # Don't call finish_init() - let SSE load configs and call it, + # or let the timeout in get() kick in as designed + if self.options.is_loading_from_api(): + logger.info("Starting streaming as fallback after checkpoint load failure") + self.start_streaming() def start_checkpointing_thread(self): self.checkpointing_thread = threading.Thread( @@ -139,6 +157,17 @@ def start_streaming(self): target=self.sse_connection_manager.streaming_loop, daemon=True ) self.streaming_thread.start() + # Start watchdog to monitor SSE connection health + if self.watchdog: + self.watchdog.start() + + def _watchdog_poll_fallback(self) -> None: + """Called by watchdog when SSE connection appears stuck. + + Polls the checkpoint API to get fresh config data. + """ + logger.info("Watchdog triggered poll fallback") + self.load_checkpoint_from_api_cdn() def is_shutting_down(self): return self.base_client.shutdown_flag.is_set() @@ -292,5 +321,9 @@ def handle_unauthorized_response(self): self.init_latch.count_down() def close(self) -> None: - if self.sse_client: - self.sse_client.close() + """Clean up resources.""" + if self.watchdog: + self.watchdog.stop() + # Close SSE client if active + if self.sse_connection_manager and self.sse_connection_manager.sse_client: + self.sse_connection_manager.sse_client.close() diff --git a/tests/test_config_sdk.py b/tests/test_config_sdk.py index 5541815..7c40502 100644 --- a/tests/test_config_sdk.py +++ b/tests/test_config_sdk.py @@ -162,3 +162,129 @@ def my_on_ready_callback(): ) on_ready_called.wait(timeout=2) assert on_ready_called.is_set() + + +class TestLoadCheckpointErrorHandling: + """Test that load_checkpoint handles errors gracefully and starts streaming. + + The design is that streaming should start as a fallback even if checkpoint + loading fails, but finish_init() should NOT be called - let SSE load configs + (which will call finish_init), or let the timeout in get() kick in as designed. + """ + + def test_starts_streaming_when_no_checkpoint_found(self): + """When both CDN and cache fail to load, streaming should still start.""" + from unittest.mock import Mock, patch + + mock_base_client = Mock() + mock_base_client.options = Options( + sdk_key="123-test-key", + x_use_local_cache=False, + ) + mock_base_client.shutdown_flag = threading.Event() + + with patch.object(ConfigSDK, "__init__", lambda self, x: None): + config_sdk = ConfigSDK(None) + config_sdk.base_client = mock_base_client + config_sdk._options = mock_base_client.options + config_sdk.config_loader = Mock() + config_sdk.config_loader.highwater_mark = 0 + config_sdk.api_client = Mock() + config_sdk.is_initialized = threading.Event() + config_sdk.init_latch = Mock() + config_sdk.finish_init_mutex = threading.Lock() + config_sdk.watchdog = None + config_sdk.streaming_thread = None + config_sdk.sse_connection_manager = Mock() + + # Mock load methods to return False (no data found) + config_sdk.load_checkpoint_from_api_cdn = Mock(return_value=False) + config_sdk.load_cache = Mock(return_value=False) + config_sdk.start_streaming = Mock() + + config_sdk.load_checkpoint() + + # finish_init should NOT have been called - let SSE or timeout handle it + assert not config_sdk.is_initialized.is_set() + config_sdk.init_latch.count_down.assert_not_called() + # But streaming should start as fallback + config_sdk.start_streaming.assert_called_once() + + def test_starts_streaming_on_unexpected_exception(self): + """When an unexpected exception occurs, streaming should still start.""" + from unittest.mock import Mock, patch + + mock_base_client = Mock() + mock_base_client.options = Options( + sdk_key="123-test-key", + x_use_local_cache=False, + ) + mock_base_client.shutdown_flag = threading.Event() + + with patch.object(ConfigSDK, "__init__", lambda self, x: None): + config_sdk = ConfigSDK(None) + config_sdk.base_client = mock_base_client + config_sdk._options = mock_base_client.options + config_sdk.config_loader = Mock() + config_sdk.config_loader.highwater_mark = 0 + config_sdk.api_client = Mock() + config_sdk.is_initialized = threading.Event() + config_sdk.init_latch = Mock() + config_sdk.finish_init_mutex = threading.Lock() + config_sdk.watchdog = None + config_sdk.streaming_thread = None + config_sdk.sse_connection_manager = Mock() + + # Mock load_checkpoint_from_api_cdn to raise an unexpected exception + config_sdk.load_checkpoint_from_api_cdn = Mock( + side_effect=RuntimeError("Unexpected network error") + ) + config_sdk.start_streaming = Mock() + + config_sdk.load_checkpoint() + + # finish_init should NOT have been called - let SSE or timeout handle it + assert not config_sdk.is_initialized.is_set() + config_sdk.init_latch.count_down.assert_not_called() + # But streaming should start as fallback + config_sdk.start_streaming.assert_called_once() + + def test_does_not_start_streaming_on_unauthorized(self): + """When UnauthorizedException occurs, streaming should NOT start.""" + from unittest.mock import Mock, patch + from sdk_reforge._requests import UnauthorizedException + + mock_base_client = Mock() + mock_base_client.options = Options( + sdk_key="123-test-key", + x_use_local_cache=False, + ) + mock_base_client.shutdown_flag = threading.Event() + + with patch.object(ConfigSDK, "__init__", lambda self, x: None): + config_sdk = ConfigSDK(None) + config_sdk.base_client = mock_base_client + config_sdk._options = mock_base_client.options + config_sdk.config_loader = Mock() + config_sdk.config_loader.highwater_mark = 0 + config_sdk.api_client = Mock() + config_sdk.is_initialized = threading.Event() + config_sdk.init_latch = Mock() + config_sdk.finish_init_mutex = threading.Lock() + config_sdk.unauthorized_event = threading.Event() + config_sdk.watchdog = None + config_sdk.streaming_thread = None + config_sdk.sse_connection_manager = Mock() + + # Mock load_checkpoint_from_api_cdn to raise UnauthorizedException + config_sdk.load_checkpoint_from_api_cdn = Mock( + side_effect=UnauthorizedException("bad-key") + ) + config_sdk.start_streaming = Mock() + + config_sdk.load_checkpoint() + + # Unauthorized should be handled, streaming should NOT start + assert config_sdk.unauthorized_event.is_set() + config_sdk.init_latch.count_down.assert_called_once() + config_sdk.start_streaming.assert_not_called() diff --git a/tests/test_sse_connection_manager.py b/tests/test_sse_connection_manager.py index 1c74a44..10a681e 100644 --- a/tests/test_sse_connection_manager.py +++ b/tests/test_sse_connection_manager.py @@ -8,7 +8,6 @@ SSEConnectionManager, MIN_BACKOFF_TIME, ) -from sdk_reforge._requests import UnauthorizedException class TestSSEConnectionManager(unittest.TestCase): @@ -119,9 +118,39 @@ def test_backoff_on_too_quick_connection(self, mock_sleep, mock_timing): ) @patch("sdk_reforge._sse_connection_manager.time.sleep") - def test_backoff_on_unauthorized_exception(self, mock_sleep): + def test_handles_401_unauthorized_response(self, mock_sleep): + """Verify 401 response triggers handle_unauthorized_response and stops loop""" + mock_response = Mock() + mock_response.status_code = 401 + mock_response.ok = False + + # Create HTTPError with response attached + http_error = HTTPError("401 Client Error: Unauthorized") + http_error.response = mock_response + mock_response.raise_for_status.side_effect = http_error + + self.api_client.resilient_request.return_value = mock_response + self.config_client.continue_connection_processing.side_effect = [True, False] + + self.sse_manager.streaming_loop() + + self.config_client.handle_unauthorized_response.assert_called_once() + mock_sleep.assert_not_called() + + @patch("sdk_reforge._sse_connection_manager.time.sleep") + def test_handles_403_forbidden_response(self, mock_sleep): + """Verify 403 response triggers handle_unauthorized_response and stops loop""" + mock_response = Mock() + mock_response.status_code = 403 + mock_response.ok = False + + # Create HTTPError with response attached + http_error = HTTPError("403 Client Error: Forbidden") + http_error.response = mock_response + mock_response.raise_for_status.side_effect = http_error + + self.api_client.resilient_request.return_value = mock_response self.config_client.continue_connection_processing.side_effect = [True, False] - self.api_client.resilient_request.side_effect = UnauthorizedException("the key") self.sse_manager.streaming_loop() diff --git a/tests/test_sse_watchdog.py b/tests/test_sse_watchdog.py new file mode 100644 index 0000000..2112ec6 --- /dev/null +++ b/tests/test_sse_watchdog.py @@ -0,0 +1,300 @@ +import unittest +import time +from unittest.mock import Mock, patch + +from sdk_reforge._sse_watchdog import ( + SSEWatchdog, + WatchdogResponseWrapper, + DEFAULT_CHECK_INTERVAL, + DEFAULT_MAX_SILENCE, +) + + +class TestWatchdogResponseWrapper(unittest.TestCase): + def test_iterates_through_all_chunks(self) -> None: + """Verify all chunks are yielded unchanged""" + chunks = [b"chunk1", b"chunk2", b"chunk3"] + mock_response = iter(chunks) + on_data_received: Mock = Mock() + + wrapper = WatchdogResponseWrapper(mock_response, on_data_received) + result = list(wrapper) + + self.assertEqual(result, chunks) + + def test_calls_callback_for_each_chunk(self) -> None: + """Verify callback is called for each chunk received""" + chunks = [b"chunk1", b"chunk2", b"chunk3"] + mock_response = iter(chunks) + on_data_received: Mock = Mock() + + wrapper = WatchdogResponseWrapper(mock_response, on_data_received) + list(wrapper) # Consume the iterator + + self.assertEqual(on_data_received.call_count, 3) + + def test_close_delegates_to_response(self) -> None: + """Verify close() is delegated to the wrapped response""" + mock_response: Mock = Mock() + on_data_received: Mock = Mock() + + wrapper = WatchdogResponseWrapper(mock_response, on_data_received) + wrapper.close() + + mock_response.close.assert_called_once() + + +class TestSSEWatchdog(unittest.TestCase): + def setUp(self) -> None: + self.config_client: Mock = Mock() + self.config_client.is_shutting_down.return_value = False + self.poll_fallback_fn: Mock = Mock() + self.get_sse_client_fn: Mock = Mock(return_value=None) + + def test_touch_updates_last_activity(self) -> None: + """Verify touch() updates the last_activity timestamp""" + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + ) + + initial_time = watchdog.last_activity + time.sleep(0.01) # Small delay to ensure time difference + watchdog.touch() + + self.assertGreater(watchdog.last_activity, initial_time) + + def test_no_recovery_when_active(self) -> None: + """Verify no recovery is triggered when activity is recent""" + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + check_interval=1, + max_silence=10, + ) + + # Touch to reset activity + watchdog.touch() + + # Manually run the check logic + silence = time.time() - watchdog.last_activity + self.assertLess(silence, watchdog.max_silence) + + # Poll should not have been called + self.poll_fallback_fn.assert_not_called() + + @patch("sdk_reforge._sse_watchdog.time.time") + def test_triggers_recovery_when_silent(self, mock_time: Mock) -> None: + """Verify recovery is triggered after max_silence seconds""" + # Set up time mocking: initial time, then time during check + mock_time.side_effect = [ + 1000, # Initial last_activity in __init__ + 1000, # touch() call + 1200, # time check in _run (200s silence > 120s max) + 1200, # time in _trigger_recovery for logging + 1200, # reset last_activity after recovery + ] + + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + max_silence=120, + ) + watchdog.touch() + + # Manually trigger recovery check + silence = 1200 - watchdog.last_activity # 200 seconds + if silence > watchdog.max_silence: + watchdog._trigger_recovery(silence) + + self.poll_fallback_fn.assert_called_once() + + def test_recovery_closes_sse_client(self) -> None: + """Verify recovery attempts to close the SSE client""" + mock_sse_client: Mock = Mock() + self.get_sse_client_fn.return_value = mock_sse_client + + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + ) + + watchdog._trigger_recovery(999) + + mock_sse_client.close.assert_called_once() + + def test_recovery_handles_none_sse_client(self) -> None: + """Verify recovery handles case when SSE client is None""" + self.get_sse_client_fn.return_value = None + + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + ) + + # Should not raise + watchdog._trigger_recovery(999) + + self.poll_fallback_fn.assert_called_once() + + def test_recovery_handles_poll_exception(self) -> None: + """Verify recovery continues even if poll fails""" + self.poll_fallback_fn.side_effect = Exception("Poll failed") + mock_sse_client: Mock = Mock() + self.get_sse_client_fn.return_value = mock_sse_client + + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + ) + + # Should not raise + watchdog._trigger_recovery(999) + + # Should still try to close SSE client + mock_sse_client.close.assert_called_once() + + def test_recovery_handles_close_exception(self) -> None: + """Verify recovery continues even if close fails""" + mock_sse_client: Mock = Mock() + mock_sse_client.close.side_effect = Exception("Close failed") + self.get_sse_client_fn.return_value = mock_sse_client + + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + ) + + # Should not raise + watchdog._trigger_recovery(999) + + self.poll_fallback_fn.assert_called_once() + + def test_recovery_resets_last_activity(self) -> None: + """Verify last_activity is reset after recovery""" + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + ) + + # Set last_activity to old time + watchdog.last_activity = time.time() - 1000 + + watchdog._trigger_recovery(999) + + # last_activity should be recent now + self.assertLess(time.time() - watchdog.last_activity, 1) + + def test_stop_terminates_thread(self) -> None: + """Verify stop() terminates the watchdog thread""" + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + check_interval=1, + ) + + watchdog.start() + assert watchdog._thread is not None + self.assertTrue(watchdog._thread.is_alive()) + + watchdog.stop() + self.assertFalse(watchdog._thread.is_alive()) + + def test_stops_when_shutting_down(self) -> None: + """Verify watchdog stops when config_client is shutting down""" + self.config_client.is_shutting_down.return_value = True + + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + check_interval=0.1, + ) + + watchdog.start() + time.sleep(0.3) # Give it time to check + + # Should have stopped on its own + watchdog.stop() + self.poll_fallback_fn.assert_not_called() + + def test_default_values(self) -> None: + """Verify default configuration values""" + watchdog = SSEWatchdog( + self.config_client, + self.poll_fallback_fn, + self.get_sse_client_fn, + ) + + self.assertEqual(watchdog.check_interval, DEFAULT_CHECK_INTERVAL) + self.assertEqual(watchdog.max_silence, DEFAULT_MAX_SILENCE) + + +class TestSSEWatchdogIntegration(unittest.TestCase): + """Integration tests for the watchdog with realistic timing""" + + def test_watchdog_fires_after_silence(self) -> None: + """Integration test: watchdog fires recovery after silence period""" + config_client: Mock = Mock() + config_client.is_shutting_down.return_value = False + poll_fallback_fn: Mock = Mock() + get_sse_client_fn: Mock = Mock(return_value=None) + + # Use short intervals for testing + watchdog = SSEWatchdog( + config_client, + poll_fallback_fn, + get_sse_client_fn, + check_interval=0.1, # Check every 100ms + max_silence=0.2, # Fire after 200ms of silence + ) + + watchdog.start() + + # Wait for silence period + check interval + time.sleep(0.5) + + watchdog.stop() + + # Recovery should have been triggered + self.assertTrue(poll_fallback_fn.called) + + def test_watchdog_does_not_fire_with_activity(self) -> None: + """Integration test: watchdog does not fire when touched regularly""" + config_client: Mock = Mock() + config_client.is_shutting_down.return_value = False + poll_fallback_fn: Mock = Mock() + get_sse_client_fn: Mock = Mock(return_value=None) + + watchdog = SSEWatchdog( + config_client, + poll_fallback_fn, + get_sse_client_fn, + check_interval=0.1, + max_silence=0.3, + ) + + watchdog.start() + + # Keep touching to simulate activity + for _ in range(5): + watchdog.touch() + time.sleep(0.1) + + watchdog.stop() + + # Recovery should NOT have been triggered + poll_fallback_fn.assert_not_called() + + +if __name__ == "__main__": + unittest.main()