Skip to content

Commit f5367b6

Browse files
committed
Support restart in async client
1 parent 03cb002 commit f5367b6

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

durabletask/client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,26 @@ async def resume_orchestration(self, instance_id: str) -> None:
504504
self._logger.info(f"Resuming instance '{instance_id}'.")
505505
await self._stub.ResumeInstance(req)
506506

507+
async def restart_orchestration(self, instance_id: str, *,
508+
restart_with_new_instance_id: bool = False) -> str:
509+
"""Restarts an existing orchestration instance.
510+
511+
Args:
512+
instance_id: The ID of the orchestration instance to restart.
513+
restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
514+
If False (default), the restarted orchestration will reuse the same instance ID.
515+
516+
Returns:
517+
The instance ID of the restarted orchestration.
518+
"""
519+
req = pb.RestartInstanceRequest(
520+
instanceId=instance_id,
521+
restartWithNewInstanceId=restart_with_new_instance_id)
522+
523+
self._logger.info(f"Restarting instance '{instance_id}'.")
524+
res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req)
525+
return res.instanceId
526+
507527
async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
508528
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
509529
self._logger.info(f"Purging instance '{instance_id}'.")

tests/durabletask/test_orchestration_async_e2e.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,66 @@ def orchestrator(ctx: task.OrchestrationContext, _):
178178

179179
state = await c.get_orchestration_state(id)
180180
assert state is None
181+
182+
183+
@pytest.mark.asyncio
184+
@pytest.mark.skip(reason="durabletask-go does not yet support RestartInstance")
185+
async def test_async_restart_with_same_instance_id():
186+
def orchestrator(ctx: task.OrchestrationContext, _):
187+
result = yield ctx.call_activity(say_hello, input="World")
188+
return result
189+
190+
def say_hello(ctx: task.ActivityContext, input: str):
191+
return f"Hello, {input}!"
192+
193+
with worker.TaskHubGrpcWorker() as w:
194+
w.add_orchestrator(orchestrator)
195+
w.add_activity(say_hello)
196+
w.start()
197+
198+
c = client.AsyncTaskHubGrpcClient()
199+
id = await c.schedule_new_orchestration(orchestrator)
200+
state = await c.wait_for_orchestration_completion(id, timeout=30)
201+
assert state is not None
202+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
203+
assert state.serialized_output == json.dumps("Hello, World!")
204+
205+
# Restart the orchestration with the same instance ID
206+
restarted_id = await c.restart_orchestration(id)
207+
assert restarted_id == id
208+
209+
state = await c.wait_for_orchestration_completion(restarted_id, timeout=30)
210+
assert state is not None
211+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
212+
assert state.serialized_output == json.dumps("Hello, World!")
213+
214+
215+
@pytest.mark.asyncio
216+
@pytest.mark.skip(reason="durabletask-go does not yet support RestartInstance")
217+
async def test_async_restart_with_new_instance_id():
218+
def orchestrator(ctx: task.OrchestrationContext, _):
219+
result = yield ctx.call_activity(say_hello, input="World")
220+
return result
221+
222+
def say_hello(ctx: task.ActivityContext, input: str):
223+
return f"Hello, {input}!"
224+
225+
with worker.TaskHubGrpcWorker() as w:
226+
w.add_orchestrator(orchestrator)
227+
w.add_activity(say_hello)
228+
w.start()
229+
230+
c = client.AsyncTaskHubGrpcClient()
231+
id = await c.schedule_new_orchestration(orchestrator)
232+
state = await c.wait_for_orchestration_completion(id, timeout=30)
233+
assert state is not None
234+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
235+
236+
# Restart the orchestration with a new instance ID
237+
restarted_id = await c.restart_orchestration(id, restart_with_new_instance_id=True)
238+
assert restarted_id != id
239+
240+
state = await c.wait_for_orchestration_completion(restarted_id, timeout=30)
241+
assert state is not None
242+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
243+
assert state.serialized_output == json.dumps("Hello, World!")

0 commit comments

Comments
 (0)