Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/cpln/models/workloads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
import random
from typing import Any, Optional, cast
from typing import Any, Optional

import inflection

Expand Down Expand Up @@ -59,11 +59,7 @@ def get_deployment(self, location: Optional[str] = None) -> Deployment:
deployment_data = self.client.api.get_workload_deployment(
self.config(location=location)
)
return Deployment.parse(
deployment_data,
api_client=cast(Any, self.client.api),
config=self.config(location=location),
)
return deployment_data

def delete(self) -> None:
"""
Expand Down
61 changes: 38 additions & 23 deletions tests/unit/cpln/models/test_workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ def test_exec_success(self) -> None:
mock_replica.exec.return_value = expected_response
mock_deployment.get_replicas.return_value = {container: [mock_replica]}

# Mock Deployment.parse to return the mock deployment directly
with patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
# Mock API to return the mock deployment directly
with patch.object(
self.client.api,
"get_workload_deployment",
return_value=mock_deployment,
):
result = self.workload.exec(command, location, container=container)

Expand All @@ -150,8 +152,10 @@ def test_exec_error(self) -> None:
# Mock print to avoid output during test
with (
patch("builtins.print"),
patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
patch.object(
self.client.api,
"get_workload_deployment",
return_value=mock_deployment,
),
self.assertRaises(WebSocketExitCodeError),
):
Expand All @@ -171,8 +175,10 @@ def test_ping_success(self) -> None:
mock_replica.exec.return_value = {"output": "ping"}
mock_deployment.get_replicas.return_value = {container: [mock_replica]}

with patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
with patch.object(
self.client.api,
"get_workload_deployment",
return_value=mock_deployment,
):
result = self.workload.ping(location, container=container)

Expand All @@ -195,8 +201,10 @@ def test_ping_websocket_error(self) -> None:
mock_replica.exec.side_effect = error
mock_deployment.get_replicas.return_value = {container: [mock_replica]}

with patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
with patch.object(
self.client.api,
"get_workload_deployment",
return_value=mock_deployment,
):
result = self.workload.ping(location, container=container)

Expand All @@ -215,8 +223,10 @@ def test_ping_general_exception(self) -> None:
mock_replica.exec.side_effect = RuntimeError("General error")
mock_deployment.get_replicas.return_value = {container: [mock_replica]}

with patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
with patch.object(
self.client.api,
"get_workload_deployment",
return_value=mock_deployment,
):
result = self.workload.ping(location, container=container)

Expand All @@ -235,8 +245,10 @@ def test_exec_no_replicas(self) -> None:
mock_deployment.get_replicas.return_value = {}

with (
patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
patch.object(
self.client.api,
"get_workload_deployment",
return_value=mock_deployment,
),
self.assertRaises(ValueError) as context,
):
Expand All @@ -255,8 +267,10 @@ def test_exec_container_not_found(self) -> None:
mock_deployment.get_replicas.return_value = {"other-container": []}

with (
patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
patch.object(
self.client.api,
"get_workload_deployment",
return_value=mock_deployment,
),
self.assertRaises(ValueError) as context,
):
Expand All @@ -276,15 +290,14 @@ def test_get_replicas(self) -> None:
mock_deployment.get_replicas.return_value = expected_replicas

with patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
):
"cpln.models.workloads.Workload.get_deployment",
return_value=mock_deployment,
) as mock_get_deployment:
result = self.workload.get_replicas(location)

self.assertEqual(result, expected_replicas)
# Verify the API was called with correct config
self.client.api.get_workload_deployment.assert_called_once_with(
self.workload.config(location=location)
)
# Verify get_deployment was called with correct location
mock_get_deployment.assert_called_once_with(location=location)

def test_get_containers(self) -> None:
"""Test get_containers method"""
Expand Down Expand Up @@ -330,8 +343,10 @@ def test_ping_general_exception_original(self) -> None:
mock_replica.exec.side_effect = Exception("Connection failed")
mock_deployment.get_replicas.return_value = {container: [mock_replica]}

with patch(
"cpln.models.workloads.Deployment.parse", return_value=mock_deployment
with patch.object(
self.client.api,
"get_workload_deployment",
return_value=mock_deployment,
):
result = self.workload.ping(location, container=container)

Expand Down