Skip to content

Commit e915ce3

Browse files
committed
PR feedback
1 parent 91f7da8 commit e915ce3

File tree

3 files changed

+48
-24
lines changed

3 files changed

+48
-24
lines changed

durabletask/client.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,15 @@ def __init__(self, *,
159159
secure_channel=secure_channel,
160160
interceptors=interceptors
161161
)
162+
self._channel = channel
162163
self._stub = stubs.TaskHubSidecarServiceStub(channel)
163164
self._logger = shared.get_logger("client", log_handler, log_formatter)
164165
self.default_version = default_version
165166

167+
def close(self) -> None:
168+
"""Closes the underlying gRPC channel."""
169+
self._channel.close()
170+
166171
def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
167172
input: Optional[TInput] = None,
168173
instance_id: Optional[str] = None,
@@ -239,26 +244,26 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
239244
raise
240245

241246
def raise_orchestration_event(self, instance_id: str, event_name: str, *,
242-
data: Optional[Any] = None):
247+
data: Optional[Any] = None) -> None:
243248
req = build_raise_event_req(instance_id, event_name, data)
244249

245250
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
246251
self._stub.RaiseEvent(req)
247252

248253
def terminate_orchestration(self, instance_id: str, *,
249254
output: Optional[Any] = None,
250-
recursive: bool = True):
255+
recursive: bool = True) -> None:
251256
req = build_terminate_req(instance_id, output, recursive)
252257

253258
self._logger.info(f"Terminating instance '{instance_id}'.")
254259
self._stub.TerminateInstance(req)
255260

256-
def suspend_orchestration(self, instance_id: str):
261+
def suspend_orchestration(self, instance_id: str) -> None:
257262
req = pb.SuspendRequest(instanceId=instance_id)
258263
self._logger.info(f"Suspending instance '{instance_id}'.")
259264
self._stub.SuspendInstance(req)
260265

261-
def resume_orchestration(self, instance_id: str):
266+
def resume_orchestration(self, instance_id: str) -> None:
262267
req = pb.ResumeRequest(instanceId=instance_id)
263268
self._logger.info(f"Resuming instance '{instance_id}'.")
264269
self._stub.ResumeInstance(req)
@@ -370,10 +375,15 @@ def __init__(self, *,
370375
secure_channel=secure_channel,
371376
interceptors=interceptors
372377
)
378+
self._channel = channel
373379
self._stub = stubs.TaskHubSidecarServiceStub(channel)
374380
self._logger = shared.get_logger("client", log_handler, log_formatter)
375381
self.default_version = default_version
376382

383+
async def close(self) -> None:
384+
"""Closes the underlying gRPC channel."""
385+
await self._channel.close()
386+
377387
async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
378388
input: Optional[TInput] = None,
379389
instance_id: Optional[str] = None,
@@ -450,26 +460,26 @@ async def wait_for_orchestration_completion(self, instance_id: str, *,
450460
raise
451461

452462
async def raise_orchestration_event(self, instance_id: str, event_name: str, *,
453-
data: Optional[Any] = None):
463+
data: Optional[Any] = None) -> None:
454464
req = build_raise_event_req(instance_id, event_name, data)
455465

456466
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
457467
await self._stub.RaiseEvent(req)
458468

459469
async def terminate_orchestration(self, instance_id: str, *,
460470
output: Optional[Any] = None,
461-
recursive: bool = True):
471+
recursive: bool = True) -> None:
462472
req = build_terminate_req(instance_id, output, recursive)
463473

464474
self._logger.info(f"Terminating instance '{instance_id}'.")
465475
await self._stub.TerminateInstance(req)
466476

467-
async def suspend_orchestration(self, instance_id: str):
477+
async def suspend_orchestration(self, instance_id: str) -> None:
468478
req = pb.SuspendRequest(instanceId=instance_id)
469479
self._logger.info(f"Suspending instance '{instance_id}'.")
470480
await self._stub.SuspendInstance(req)
471481

472-
async def resume_orchestration(self, instance_id: str):
482+
async def resume_orchestration(self, instance_id: str) -> None:
473483
req = pb.ResumeRequest(instanceId=instance_id)
474484
self._logger.info(f"Resuming instance '{instance_id}'.")
475485
await self._stub.ResumeInstance(req)

durabletask/internal/grpc_interceptor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ class _ClientCallDetails(
2222
class _AsyncClientCallDetails(
2323
namedtuple(
2424
'_AsyncClientCallDetails',
25-
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready']),
25+
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
2626
grpc.aio.ClientCallDetails):
2727
"""This is an implementation of the aio ClientCallDetails interface needed for async interceptors.
28-
This class takes five named values and inherits the ClientCallDetails from grpc.aio package.
28+
This class takes six named values and inherits the ClientCallDetails from grpc.aio package.
2929
This class encloses the values that describe a RPC to be invoked.
3030
"""
3131
pass

tests/durabletask/test_client.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,43 +61,53 @@ def test_grpc_channel_with_host_name_protocol_stripping():
6161

6262
prefix = "grpc://"
6363
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
64-
mock_insecure_channel.assert_called_with(host_name)
64+
mock_insecure_channel.assert_called_once_with(host_name)
65+
mock_insecure_channel.reset_mock()
6566

6667
prefix = "http://"
6768
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
68-
mock_insecure_channel.assert_called_with(host_name)
69+
mock_insecure_channel.assert_called_once_with(host_name)
70+
mock_insecure_channel.reset_mock()
6971

7072
prefix = "HTTP://"
7173
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
72-
mock_insecure_channel.assert_called_with(host_name)
74+
mock_insecure_channel.assert_called_once_with(host_name)
75+
mock_insecure_channel.reset_mock()
7376

7477
prefix = "GRPC://"
7578
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
76-
mock_insecure_channel.assert_called_with(host_name)
79+
mock_insecure_channel.assert_called_once_with(host_name)
80+
mock_insecure_channel.reset_mock()
7781

7882
prefix = ""
7983
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
80-
mock_insecure_channel.assert_called_with(host_name)
84+
mock_insecure_channel.assert_called_once_with(host_name)
85+
mock_insecure_channel.reset_mock()
8186

8287
prefix = "grpcs://"
8388
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
84-
mock_secure_channel.assert_called_with(host_name, ANY)
89+
mock_secure_channel.assert_called_once_with(host_name, ANY)
90+
mock_secure_channel.reset_mock()
8591

8692
prefix = "https://"
8793
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
88-
mock_secure_channel.assert_called_with(host_name, ANY)
94+
mock_secure_channel.assert_called_once_with(host_name, ANY)
95+
mock_secure_channel.reset_mock()
8996

9097
prefix = "HTTPS://"
9198
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
92-
mock_secure_channel.assert_called_with(host_name, ANY)
99+
mock_secure_channel.assert_called_once_with(host_name, ANY)
100+
mock_secure_channel.reset_mock()
93101

94102
prefix = "GRPCS://"
95103
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
96-
mock_secure_channel.assert_called_with(host_name, ANY)
104+
mock_secure_channel.assert_called_once_with(host_name, ANY)
105+
mock_secure_channel.reset_mock()
97106

98107
prefix = ""
99108
get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS)
100-
mock_secure_channel.assert_called_with(host_name, ANY)
109+
mock_secure_channel.assert_called_once_with(host_name, ANY)
110+
mock_secure_channel.reset_mock()
101111

102112

103113
# ==== Async channel tests ====
@@ -136,16 +146,20 @@ def test_async_grpc_channel_protocol_stripping():
136146
host_name = "myserver.com:1234"
137147

138148
get_async_grpc_channel("http://" + host_name)
139-
mock_insecure.assert_called_with(host_name, interceptors=None)
149+
mock_insecure.assert_called_once_with(host_name, interceptors=None)
150+
mock_insecure.reset_mock()
140151

141152
get_async_grpc_channel("grpc://" + host_name)
142-
mock_insecure.assert_called_with(host_name, interceptors=None)
153+
mock_insecure.assert_called_once_with(host_name, interceptors=None)
154+
mock_insecure.reset_mock()
143155

144156
get_async_grpc_channel("https://" + host_name)
145-
mock_secure.assert_called_with(host_name, ANY, interceptors=None)
157+
mock_secure.assert_called_once_with(host_name, ANY, interceptors=None)
158+
mock_secure.reset_mock()
146159

147160
get_async_grpc_channel("grpcs://" + host_name)
148-
mock_secure.assert_called_with(host_name, ANY, interceptors=None)
161+
mock_secure.assert_called_once_with(host_name, ANY, interceptors=None)
162+
mock_secure.reset_mock()
149163

150164

151165
# ==== Async client construction tests ====

0 commit comments

Comments
 (0)