Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/pooldose/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ async def get_access_point(self) -> Tuple[RequestStatus, Optional[AccessPointDic
json_end = text.rfind("}") + 1
data = None
if json_start != -1 and json_end != -1:
data = await resp.json()
data = json.loads(text[json_start:json_end])
if not data:
_LOGGER.error("No data found for access point info")
return RequestStatus.NO_DATA, None
Expand Down Expand Up @@ -403,10 +403,13 @@ async def set_value(self, device_id: str, path: str, value: Any, value_type: str
_LOGGER.info("Sending payload: %s", self._last_payload)
try:
timeout_obj = aiohttp.ClientTimeout(total=self.timeout)
connector = self._get_ssl_connector()
async with aiohttp.ClientSession(connector=connector) as session:
session, close_session = await self._get_session()
try:
async with session.post(url, json=payload, headers=self._headers, timeout=timeout_obj) as resp:
resp.raise_for_status()
finally:
if close_session:
await session.close()
except aiohttp.ClientError as e:
_LOGGER.warning("Client error setting value: %s", e)
return False
Expand Down Expand Up @@ -473,15 +476,18 @@ async def get_device_language(self, device_id: str | None = None):

try:
timeout_obj = aiohttp.ClientTimeout(total=self.timeout)
connector = self._get_ssl_connector()
async with aiohttp.ClientSession(connector=connector) as session:
session, close_session = await self._get_session()
try:
async with session.post(url, json=payload, headers=self._headers, timeout=timeout_obj) as resp:
resp.raise_for_status()
data = await resp.json()
if not data:
_LOGGER.error("No data found for device language")
return RequestStatus.NO_DATA, None
return RequestStatus.SUCCESS, data
finally:
if close_session:
await session.close()
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
_LOGGER.error("Failed to fetch device language: %s", err)
return RequestStatus.UNKNOWN_ERROR, None
Expand All @@ -496,11 +502,14 @@ async def reboot_device(self):
url = self._build_url("/api/v1/system/reboot")
try:
timeout_obj = aiohttp.ClientTimeout(total=self.timeout)
connector = self._get_ssl_connector()
async with aiohttp.ClientSession(connector=connector) as session:
session, close_session = await self._get_session()
try:
async with session.post(url, headers=self._headers, timeout=timeout_obj) as resp:
resp.raise_for_status()
return RequestStatus.SUCCESS, True
finally:
if close_session:
await session.close()
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
_LOGGER.warning("Error sending reboot command: %s", err)
return RequestStatus.UNKNOWN_ERROR, False
248 changes: 248 additions & 0 deletions tests/test_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,251 @@ async def mock_get_debug_config(self):

# Verify the session was closed despite the exception
mock_session.close.assert_awaited_once()


class TestSessionConsistency:
"""Tests to verify all HTTP methods use _get_session() consistently.

Previously, set_value(), get_device_language() and reboot_device() created
their own aiohttp.ClientSession instead of using _get_session(). This meant
externally provided sessions (e.g. from Home Assistant) were bypassed.
"""

@pytest.mark.asyncio
async def test_set_value_uses_get_session_with_external(self):
"""Test that set_value() uses the external session via _get_session()."""
external_session = MagicMock()
external_session.close = AsyncMock()

mock_response = MagicMock()
mock_response.status = 200
mock_response.raise_for_status = MagicMock()

handler = RequestHandler("192.168.1.1", websession=external_session)

with patch.object(handler, '_get_session', return_value=(external_session, False)) as mock_get_session:
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_response
external_session.post = MagicMock(return_value=async_cm)

result = await handler.set_value("DEVICE_1", "w_123", 7.0, "NUMBER")

mock_get_session.assert_called_once()
external_session.post.assert_called_once()
assert result is True
external_session.close.assert_not_awaited()

@pytest.mark.asyncio
async def test_set_value_closes_internal_session(self):
"""Test that set_value() creates and closes an internal session when no external session is provided."""
handler = RequestHandler("192.168.1.1")

mock_session = MagicMock()
mock_session.close = AsyncMock()

mock_response = MagicMock()
mock_response.status = 200
mock_response.raise_for_status = MagicMock()

with patch.object(handler, '_get_session', return_value=(mock_session, True)):
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_response
mock_session.post = MagicMock(return_value=async_cm)

result = await handler.set_value("DEVICE_1", "w_123", 7.0, "NUMBER")

assert result is True
mock_session.close.assert_awaited_once()

@pytest.mark.asyncio
async def test_set_value_closes_session_on_error(self):
"""Test that set_value() closes the internal session even on error."""
handler = RequestHandler("192.168.1.1")

mock_session = MagicMock()
mock_session.close = AsyncMock()

with patch.object(handler, '_get_session', return_value=(mock_session, True)):
async_cm = AsyncMock()
async_cm.__aenter__.side_effect = aiohttp.ClientError("Connection refused")
mock_session.post = MagicMock(return_value=async_cm)

result = await handler.set_value("DEVICE_1", "w_123", 7.0, "NUMBER")

assert result is False
mock_session.close.assert_awaited_once()

@pytest.mark.asyncio
async def test_get_device_language_uses_get_session_with_external(self):
"""Test that get_device_language() uses the external session via _get_session()."""
external_session = MagicMock()
external_session.close = AsyncMock()

mock_response = MagicMock()
mock_response.status = 200
mock_response.raise_for_status = MagicMock()
mock_response.json = AsyncMock(return_value={"LABEL_test": "Test Label"})

handler = RequestHandler("192.168.1.1", websession=external_session)

with patch.object(handler, '_get_session', return_value=(external_session, False)) as mock_get_session:
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_response
external_session.post = MagicMock(return_value=async_cm)

status, data = await handler.get_device_language("TEST_DEVICE")

mock_get_session.assert_called_once()
external_session.post.assert_called_once()
assert status == RequestStatus.SUCCESS
assert data == {"LABEL_test": "Test Label"}
external_session.close.assert_not_awaited()

@pytest.mark.asyncio
async def test_get_device_language_closes_internal_session(self):
"""Test that get_device_language() closes an internal session after use."""
handler = RequestHandler("192.168.1.1")

mock_session = MagicMock()
mock_session.close = AsyncMock()

mock_response = MagicMock()
mock_response.status = 200
mock_response.raise_for_status = MagicMock()
mock_response.json = AsyncMock(return_value={"LABEL_test": "Test Label"})

with patch.object(handler, '_get_session', return_value=(mock_session, True)):
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_response
mock_session.post = MagicMock(return_value=async_cm)

status, data = await handler.get_device_language("TEST_DEVICE")

assert status == RequestStatus.SUCCESS
mock_session.close.assert_awaited_once()

@pytest.mark.asyncio
async def test_reboot_device_uses_get_session_with_external(self):
"""Test that reboot_device() uses the external session via _get_session()."""
external_session = MagicMock()
external_session.close = AsyncMock()

mock_response = MagicMock()
mock_response.status = 200
mock_response.raise_for_status = MagicMock()

handler = RequestHandler("192.168.1.1", websession=external_session)

with patch.object(handler, '_get_session', return_value=(external_session, False)) as mock_get_session:
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_response
external_session.post = MagicMock(return_value=async_cm)

status, result = await handler.reboot_device()

mock_get_session.assert_called_once()
external_session.post.assert_called_once()
assert status == RequestStatus.SUCCESS
assert result is True
external_session.close.assert_not_awaited()

@pytest.mark.asyncio
async def test_reboot_device_closes_internal_session(self):
"""Test that reboot_device() closes an internal session after use."""
handler = RequestHandler("192.168.1.1")

mock_session = MagicMock()
mock_session.close = AsyncMock()

mock_response = MagicMock()
mock_response.status = 200
mock_response.raise_for_status = MagicMock()

with patch.object(handler, '_get_session', return_value=(mock_session, True)):
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_response
mock_session.post = MagicMock(return_value=async_cm)

status, result = await handler.reboot_device()

assert status == RequestStatus.SUCCESS
assert result is True
mock_session.close.assert_awaited_once()

@pytest.mark.asyncio
async def test_reboot_device_closes_session_on_error(self):
"""Test that reboot_device() closes the internal session even on error."""
handler = RequestHandler("192.168.1.1")

mock_session = MagicMock()
mock_session.close = AsyncMock()

with patch.object(handler, '_get_session', return_value=(mock_session, True)):
async_cm = AsyncMock()
async_cm.__aenter__.side_effect = aiohttp.ClientError("Connection refused")
mock_session.post = MagicMock(return_value=async_cm)

status, result = await handler.reboot_device()

assert status == RequestStatus.UNKNOWN_ERROR
assert result is False
mock_session.close.assert_awaited_once()


class TestAccessPointParsing:
"""Test for the get_access_point() response parsing fix."""

@pytest.mark.asyncio
async def test_get_access_point_parses_text_response(self):
"""Test that get_access_point() correctly parses JSON from text response.

Previously, get_access_point() called both resp.text() and resp.json()
on the same response, which could fail because the body can only be
consumed once. Now it uses json.loads() on the already-read text.
"""
handler = RequestHandler("192.168.1.1")

mock_session = MagicMock()
mock_session.close = AsyncMock()

json_body = '{"SSID": "KOMMSPOT-TEST", "KEY": "secret123"}'
mock_response = MagicMock()
mock_response.status = 200
mock_response.raise_for_status = MagicMock()
mock_response.text = AsyncMock(return_value=json_body)

with patch.object(handler, '_get_session', return_value=(mock_session, True)):
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_response
mock_session.post = MagicMock(return_value=async_cm)

status, data = await handler.get_access_point()

assert status == RequestStatus.SUCCESS
assert data == {"SSID": "KOMMSPOT-TEST", "KEY": "secret123"}
# Verify resp.json() was NOT called (only resp.text() + json.loads)
mock_response.json.assert_not_called()
mock_session.close.assert_awaited_once()

@pytest.mark.asyncio
async def test_get_access_point_handles_non_json_response(self):
"""Test that get_access_point() handles responses without valid JSON."""
handler = RequestHandler("192.168.1.1")

mock_session = MagicMock()
mock_session.close = AsyncMock()

mock_response = MagicMock()
mock_response.status = 200
mock_response.raise_for_status = MagicMock()
mock_response.text = AsyncMock(return_value="not json at all")

with patch.object(handler, '_get_session', return_value=(mock_session, True)):
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_response
mock_session.post = MagicMock(return_value=async_cm)

status, data = await handler.get_access_point()

assert status == RequestStatus.NO_DATA
assert data is None
Loading