From fdb79cfd865b10da6471df447221eee7142a4da7 Mon Sep 17 00:00:00 2001 From: Ronald van der Meer Date: Wed, 11 Mar 2026 12:38:13 +0100 Subject: [PATCH] Fix session consistency in set_value, get_device_language, reboot_device and get_access_point - set_value(): Use _get_session() instead of creating a standalone aiohttp.ClientSession, so externally provided sessions (e.g. from Home Assistant) are properly used and internal sessions are closed. - get_device_language(): Same fix - use _get_session() for session consistency with the rest of the request handler. - reboot_device(): Same fix - use _get_session() for session consistency. - get_access_point(): Fix double-read bug where resp.text() consumed the body and resp.json() would fail. Now parses JSON from the already-read text using json.loads(). Added 10 tests covering: - External session passthrough (not closed by handler) - Internal session creation and cleanup - Session cleanup on errors - Access point text/JSON parsing --- src/pooldose/request_handler.py | 23 ++- tests/test_request_handler.py | 248 ++++++++++++++++++++++++++++++++ 2 files changed, 264 insertions(+), 7 deletions(-) diff --git a/src/pooldose/request_handler.py b/src/pooldose/request_handler.py index a8dce2e..6a1fef8 100644 --- a/src/pooldose/request_handler.py +++ b/src/pooldose/request_handler.py @@ -311,7 +311,7 @@ async def get_access_point(self) -> Tuple[RequestStatus, Optional[AccessPointDic json_end = text.rfind("}") + 1 data = None if json_start != -1 and json_end != -1: - data = await resp.json() + data = json.loads(text[json_start:json_end]) if not data: _LOGGER.error("No data found for access point info") return RequestStatus.NO_DATA, None @@ -403,10 +403,13 @@ async def set_value(self, device_id: str, path: str, value: Any, value_type: str _LOGGER.info("Sending payload: %s", self._last_payload) try: timeout_obj = aiohttp.ClientTimeout(total=self.timeout) - connector = self._get_ssl_connector() - async with aiohttp.ClientSession(connector=connector) as session: + session, close_session = await self._get_session() + try: async with session.post(url, json=payload, headers=self._headers, timeout=timeout_obj) as resp: resp.raise_for_status() + finally: + if close_session: + await session.close() except aiohttp.ClientError as e: _LOGGER.warning("Client error setting value: %s", e) return False @@ -473,8 +476,8 @@ async def get_device_language(self, device_id: str | None = None): try: timeout_obj = aiohttp.ClientTimeout(total=self.timeout) - connector = self._get_ssl_connector() - async with aiohttp.ClientSession(connector=connector) as session: + session, close_session = await self._get_session() + try: async with session.post(url, json=payload, headers=self._headers, timeout=timeout_obj) as resp: resp.raise_for_status() data = await resp.json() @@ -482,6 +485,9 @@ async def get_device_language(self, device_id: str | None = None): _LOGGER.error("No data found for device language") return RequestStatus.NO_DATA, None return RequestStatus.SUCCESS, data + finally: + if close_session: + await session.close() except (aiohttp.ClientError, asyncio.TimeoutError) as err: _LOGGER.error("Failed to fetch device language: %s", err) return RequestStatus.UNKNOWN_ERROR, None @@ -496,11 +502,14 @@ async def reboot_device(self): url = self._build_url("/api/v1/system/reboot") try: timeout_obj = aiohttp.ClientTimeout(total=self.timeout) - connector = self._get_ssl_connector() - async with aiohttp.ClientSession(connector=connector) as session: + session, close_session = await self._get_session() + try: async with session.post(url, headers=self._headers, timeout=timeout_obj) as resp: resp.raise_for_status() return RequestStatus.SUCCESS, True + finally: + if close_session: + await session.close() except (aiohttp.ClientError, asyncio.TimeoutError) as err: _LOGGER.warning("Error sending reboot command: %s", err) return RequestStatus.UNKNOWN_ERROR, False diff --git a/tests/test_request_handler.py b/tests/test_request_handler.py index de4c275..04a8748 100644 --- a/tests/test_request_handler.py +++ b/tests/test_request_handler.py @@ -194,3 +194,251 @@ async def mock_get_debug_config(self): # Verify the session was closed despite the exception mock_session.close.assert_awaited_once() + + +class TestSessionConsistency: + """Tests to verify all HTTP methods use _get_session() consistently. + + Previously, set_value(), get_device_language() and reboot_device() created + their own aiohttp.ClientSession instead of using _get_session(). This meant + externally provided sessions (e.g. from Home Assistant) were bypassed. + """ + + @pytest.mark.asyncio + async def test_set_value_uses_get_session_with_external(self): + """Test that set_value() uses the external session via _get_session().""" + external_session = MagicMock() + external_session.close = AsyncMock() + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.raise_for_status = MagicMock() + + handler = RequestHandler("192.168.1.1", websession=external_session) + + with patch.object(handler, '_get_session', return_value=(external_session, False)) as mock_get_session: + async_cm = AsyncMock() + async_cm.__aenter__.return_value = mock_response + external_session.post = MagicMock(return_value=async_cm) + + result = await handler.set_value("DEVICE_1", "w_123", 7.0, "NUMBER") + + mock_get_session.assert_called_once() + external_session.post.assert_called_once() + assert result is True + external_session.close.assert_not_awaited() + + @pytest.mark.asyncio + async def test_set_value_closes_internal_session(self): + """Test that set_value() creates and closes an internal session when no external session is provided.""" + handler = RequestHandler("192.168.1.1") + + mock_session = MagicMock() + mock_session.close = AsyncMock() + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.raise_for_status = MagicMock() + + with patch.object(handler, '_get_session', return_value=(mock_session, True)): + async_cm = AsyncMock() + async_cm.__aenter__.return_value = mock_response + mock_session.post = MagicMock(return_value=async_cm) + + result = await handler.set_value("DEVICE_1", "w_123", 7.0, "NUMBER") + + assert result is True + mock_session.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_set_value_closes_session_on_error(self): + """Test that set_value() closes the internal session even on error.""" + handler = RequestHandler("192.168.1.1") + + mock_session = MagicMock() + mock_session.close = AsyncMock() + + with patch.object(handler, '_get_session', return_value=(mock_session, True)): + async_cm = AsyncMock() + async_cm.__aenter__.side_effect = aiohttp.ClientError("Connection refused") + mock_session.post = MagicMock(return_value=async_cm) + + result = await handler.set_value("DEVICE_1", "w_123", 7.0, "NUMBER") + + assert result is False + mock_session.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_get_device_language_uses_get_session_with_external(self): + """Test that get_device_language() uses the external session via _get_session().""" + external_session = MagicMock() + external_session.close = AsyncMock() + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json = AsyncMock(return_value={"LABEL_test": "Test Label"}) + + handler = RequestHandler("192.168.1.1", websession=external_session) + + with patch.object(handler, '_get_session', return_value=(external_session, False)) as mock_get_session: + async_cm = AsyncMock() + async_cm.__aenter__.return_value = mock_response + external_session.post = MagicMock(return_value=async_cm) + + status, data = await handler.get_device_language("TEST_DEVICE") + + mock_get_session.assert_called_once() + external_session.post.assert_called_once() + assert status == RequestStatus.SUCCESS + assert data == {"LABEL_test": "Test Label"} + external_session.close.assert_not_awaited() + + @pytest.mark.asyncio + async def test_get_device_language_closes_internal_session(self): + """Test that get_device_language() closes an internal session after use.""" + handler = RequestHandler("192.168.1.1") + + mock_session = MagicMock() + mock_session.close = AsyncMock() + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json = AsyncMock(return_value={"LABEL_test": "Test Label"}) + + with patch.object(handler, '_get_session', return_value=(mock_session, True)): + async_cm = AsyncMock() + async_cm.__aenter__.return_value = mock_response + mock_session.post = MagicMock(return_value=async_cm) + + status, data = await handler.get_device_language("TEST_DEVICE") + + assert status == RequestStatus.SUCCESS + mock_session.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_reboot_device_uses_get_session_with_external(self): + """Test that reboot_device() uses the external session via _get_session().""" + external_session = MagicMock() + external_session.close = AsyncMock() + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.raise_for_status = MagicMock() + + handler = RequestHandler("192.168.1.1", websession=external_session) + + with patch.object(handler, '_get_session', return_value=(external_session, False)) as mock_get_session: + async_cm = AsyncMock() + async_cm.__aenter__.return_value = mock_response + external_session.post = MagicMock(return_value=async_cm) + + status, result = await handler.reboot_device() + + mock_get_session.assert_called_once() + external_session.post.assert_called_once() + assert status == RequestStatus.SUCCESS + assert result is True + external_session.close.assert_not_awaited() + + @pytest.mark.asyncio + async def test_reboot_device_closes_internal_session(self): + """Test that reboot_device() closes an internal session after use.""" + handler = RequestHandler("192.168.1.1") + + mock_session = MagicMock() + mock_session.close = AsyncMock() + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.raise_for_status = MagicMock() + + with patch.object(handler, '_get_session', return_value=(mock_session, True)): + async_cm = AsyncMock() + async_cm.__aenter__.return_value = mock_response + mock_session.post = MagicMock(return_value=async_cm) + + status, result = await handler.reboot_device() + + assert status == RequestStatus.SUCCESS + assert result is True + mock_session.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_reboot_device_closes_session_on_error(self): + """Test that reboot_device() closes the internal session even on error.""" + handler = RequestHandler("192.168.1.1") + + mock_session = MagicMock() + mock_session.close = AsyncMock() + + with patch.object(handler, '_get_session', return_value=(mock_session, True)): + async_cm = AsyncMock() + async_cm.__aenter__.side_effect = aiohttp.ClientError("Connection refused") + mock_session.post = MagicMock(return_value=async_cm) + + status, result = await handler.reboot_device() + + assert status == RequestStatus.UNKNOWN_ERROR + assert result is False + mock_session.close.assert_awaited_once() + + +class TestAccessPointParsing: + """Test for the get_access_point() response parsing fix.""" + + @pytest.mark.asyncio + async def test_get_access_point_parses_text_response(self): + """Test that get_access_point() correctly parses JSON from text response. + + Previously, get_access_point() called both resp.text() and resp.json() + on the same response, which could fail because the body can only be + consumed once. Now it uses json.loads() on the already-read text. + """ + handler = RequestHandler("192.168.1.1") + + mock_session = MagicMock() + mock_session.close = AsyncMock() + + json_body = '{"SSID": "KOMMSPOT-TEST", "KEY": "secret123"}' + mock_response = MagicMock() + mock_response.status = 200 + mock_response.raise_for_status = MagicMock() + mock_response.text = AsyncMock(return_value=json_body) + + with patch.object(handler, '_get_session', return_value=(mock_session, True)): + async_cm = AsyncMock() + async_cm.__aenter__.return_value = mock_response + mock_session.post = MagicMock(return_value=async_cm) + + status, data = await handler.get_access_point() + + assert status == RequestStatus.SUCCESS + assert data == {"SSID": "KOMMSPOT-TEST", "KEY": "secret123"} + # Verify resp.json() was NOT called (only resp.text() + json.loads) + mock_response.json.assert_not_called() + mock_session.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_get_access_point_handles_non_json_response(self): + """Test that get_access_point() handles responses without valid JSON.""" + handler = RequestHandler("192.168.1.1") + + mock_session = MagicMock() + mock_session.close = AsyncMock() + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.raise_for_status = MagicMock() + mock_response.text = AsyncMock(return_value="not json at all") + + with patch.object(handler, '_get_session', return_value=(mock_session, True)): + async_cm = AsyncMock() + async_cm.__aenter__.return_value = mock_response + mock_session.post = MagicMock(return_value=async_cm) + + status, data = await handler.get_access_point() + + assert status == RequestStatus.NO_DATA + assert data is None