Skip to content

Commit 1f11d5f

Browse files
committed
Initial work - not complete
1 parent 48830dc commit 1f11d5f

File tree

3 files changed

+128
-27
lines changed

3 files changed

+128
-27
lines changed

durabletask/worker.py

Lines changed: 108 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def __init__(
346346
else:
347347
self._interceptors = None
348348

349-
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
349+
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)
350350

351351
@property
352352
def concurrency_options(self) -> ConcurrencyOptions:
@@ -533,27 +533,31 @@ def stream_reader():
533533
if work_item.HasField("orchestratorRequest"):
534534
self._async_worker_manager.submit_orchestration(
535535
self._execute_orchestrator,
536+
self._cancel_orchestrator,
536537
work_item.orchestratorRequest,
537538
stub,
538539
work_item.completionToken,
539540
)
540541
elif work_item.HasField("activityRequest"):
541542
self._async_worker_manager.submit_activity(
542543
self._execute_activity,
544+
self._cancel_activity,
543545
work_item.activityRequest,
544546
stub,
545547
work_item.completionToken,
546548
)
547549
elif work_item.HasField("entityRequest"):
548550
self._async_worker_manager.submit_entity_batch(
549551
self._execute_entity_batch,
552+
self._cancel_entity_batch,
550553
work_item.entityRequest,
551554
stub,
552555
work_item.completionToken,
553556
)
554557
elif work_item.HasField("entityRequestV2"):
555558
self._async_worker_manager.submit_entity_batch(
556559
self._execute_entity_batch,
560+
self._cancel_entity_batch,
557561
work_item.entityRequestV2,
558562
stub,
559563
work_item.completionToken
@@ -670,6 +674,19 @@ def _execute_orchestrator(
670674
f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
671675
)
672676

677+
def _cancel_orchestrator(
678+
self,
679+
req: pb.OrchestratorRequest,
680+
stub: stubs.TaskHubSidecarServiceStub,
681+
completionToken,
682+
):
683+
stub.AbandonTaskOrchestratorWorkItem(
684+
pb.AbandonOrchestrationTaskRequest(
685+
completionToken=completionToken
686+
)
687+
)
688+
self._logger.info(f"Cancelled orchestration task for invocation ID: {req.instanceId}")
689+
673690
def _execute_activity(
674691
self,
675692
req: pb.ActivityRequest,
@@ -703,6 +720,19 @@ def _execute_activity(
703720
f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
704721
)
705722

723+
def _cancel_activity(
724+
self,
725+
req: pb.ActivityRequest,
726+
stub: stubs.TaskHubSidecarServiceStub,
727+
completionToken,
728+
):
729+
stub.AbandonTaskActivityWorkItem(
730+
pb.AbandonActivityTaskRequest(
731+
completionToken=completionToken
732+
)
733+
)
734+
self._logger.info(f"Cancelled activity task for task ID: {req.taskId} on orchestration ID: {req.orchestrationInstance.instanceId}")
735+
706736
def _execute_entity_batch(
707737
self,
708738
req: Union[pb.EntityBatchRequest, pb.EntityRequest],
@@ -771,6 +801,19 @@ def _execute_entity_batch(
771801

772802
return batch_result
773803

804+
def _cancel_entity_batch(
805+
self,
806+
req: pb.EntityBatchRequest,
807+
stub: stubs.TaskHubSidecarServiceStub,
808+
completionToken,
809+
):
810+
stub.AbandonTaskEntityWorkItem(
811+
pb.AbandonEntityTaskRequest(
812+
completionToken=completionToken
813+
)
814+
)
815+
self._logger.info(f"Cancelled entity batch task for entity instance ID: {req.instanceId}")
816+
774817

775818
class _RuntimeOrchestrationContext(task.OrchestrationContext):
776819
_generator: Optional[Generator[task.Task, Any, Any]]
@@ -1368,7 +1411,7 @@ def process_event(
13681411
timer_id = event.timerFired.timerId
13691412
timer_task = ctx._pending_tasks.pop(timer_id, None)
13701413
if not timer_task:
1371-
# TODO: Should this be an error? When would it ever happen?
1414+
# TODO: Should this be an error? would it ever happen?
13721415
if not ctx._is_replaying:
13731416
self._logger.warning(
13741417
f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}."
@@ -1920,8 +1963,10 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:
19201963

19211964

19221965
class _AsyncWorkerManager:
1923-
def __init__(self, concurrency_options: ConcurrencyOptions):
1966+
def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
19241967
self.concurrency_options = concurrency_options
1968+
self._logger = logger
1969+
19251970
self.activity_semaphore = None
19261971
self.orchestration_semaphore = None
19271972
self.entity_semaphore = None
@@ -2031,17 +2076,47 @@ async def run(self):
20312076
)
20322077

20332078
# Start background consumers for each work type
2034-
if self.activity_queue is not None and self.orchestration_queue is not None \
2035-
and self.entity_batch_queue is not None:
2036-
await asyncio.gather(
2037-
self._consume_queue(self.activity_queue, self.activity_semaphore),
2038-
self._consume_queue(
2039-
self.orchestration_queue, self.orchestration_semaphore
2040-
),
2041-
self._consume_queue(
2042-
self.entity_batch_queue, self.entity_semaphore
2079+
try:
2080+
if self.activity_queue is not None and self.orchestration_queue is not None \
2081+
and self.entity_batch_queue is not None:
2082+
await asyncio.gather(
2083+
self._consume_queue(self.activity_queue, self.activity_semaphore),
2084+
self._consume_queue(
2085+
self.orchestration_queue, self.orchestration_semaphore
2086+
),
2087+
self._consume_queue(
2088+
self.entity_batch_queue, self.entity_semaphore
2089+
)
20432090
)
2044-
)
2091+
except Exception as queue_exception:
2092+
self._logger.error(f"Uncaught error in activity manager thread pool: {queue_exception}")
2093+
while self.activity_queue is not None and not self.activity_queue.empty():
2094+
try:
2095+
func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
2096+
await self._run_func(cancellation_func, *args, **kwargs)
2097+
self._logger.error(f"Activity work item args: {args}, kwargs: {kwargs}")
2098+
except asyncio.QueueEmpty:
2099+
pass
2100+
except Exception as cancellation_exception:
2101+
self._logger.error(f"Uncaught error while cancelling activity work item: {cancellation_exception}")
2102+
while self.orchestration_queue is not None and not self.orchestration_queue.empty():
2103+
try:
2104+
func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
2105+
await self._run_func(cancellation_func, *args, **kwargs)
2106+
self._logger.error(f"Orchestration work item args: {args}, kwargs: {kwargs}")
2107+
except asyncio.QueueEmpty:
2108+
pass
2109+
except Exception as cancellation_exception:
2110+
self._logger.error(f"Uncaught error while cancelling orchestration work item: {cancellation_exception}")
2111+
while self.entity_batch_queue is not None and not self.entity_batch_queue.empty():
2112+
try:
2113+
func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
2114+
await self._run_func(cancellation_func, *args, **kwargs)
2115+
self._logger.error(f"Entity batch work item args: {args}, kwargs: {kwargs}")
2116+
except asyncio.QueueEmpty:
2117+
pass
2118+
except Exception as cancellation_exception:
2119+
self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}")
20452120

20462121
async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
20472122
# List to track running tasks
@@ -2061,7 +2136,7 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor
20612136
except asyncio.TimeoutError:
20622137
continue
20632138

2064-
func, args, kwargs = work
2139+
func, cancellation_func, args, kwargs = work
20652140
# Create a concurrent task for processing
20662141
task = asyncio.create_task(
20672142
self._process_work_item(semaphore, queue, func, args, kwargs)
@@ -2092,26 +2167,26 @@ async def _run_func(self, func, *args, **kwargs):
20922167
self.thread_pool, lambda: func(*args, **kwargs)
20932168
)
20942169

2095-
def submit_activity(self, func, *args, **kwargs):
2096-
work_item = (func, args, kwargs)
2170+
def submit_activity(self, func, cancellation_func, *args, **kwargs):
2171+
work_item = (func, cancellation_func, args, kwargs)
20972172
self._ensure_queues_for_current_loop()
20982173
if self.activity_queue is not None:
20992174
self.activity_queue.put_nowait(work_item)
21002175
else:
21012176
# No event loop running, store in pending list
21022177
self._pending_activity_work.append(work_item)
21032178

2104-
def submit_orchestration(self, func, *args, **kwargs):
2105-
work_item = (func, args, kwargs)
2179+
def submit_orchestration(self, func, cancellation_func, *args, **kwargs):
2180+
work_item = (func, cancellation_func, args, kwargs)
21062181
self._ensure_queues_for_current_loop()
21072182
if self.orchestration_queue is not None:
21082183
self.orchestration_queue.put_nowait(work_item)
21092184
else:
21102185
# No event loop running, store in pending list
21112186
self._pending_orchestration_work.append(work_item)
21122187

2113-
def submit_entity_batch(self, func, *args, **kwargs):
2114-
work_item = (func, args, kwargs)
2188+
def submit_entity_batch(self, func, cancellation_func, *args, **kwargs):
2189+
work_item = (func, cancellation_func, args, kwargs)
21152190
self._ensure_queues_for_current_loop()
21162191
if self.entity_batch_queue is not None:
21172192
self.entity_batch_queue.put_nowait(work_item)
@@ -2123,7 +2198,7 @@ def shutdown(self):
21232198
self._shutdown = True
21242199
self.thread_pool.shutdown(wait=True)
21252200

2126-
def reset_for_new_run(self):
2201+
async def reset_for_new_run(self):
21272202
"""Reset the manager state for a new run."""
21282203
self._shutdown = False
21292204
# Clear any existing queues - they'll be recreated when needed
@@ -2132,18 +2207,28 @@ def reset_for_new_run(self):
21322207
# This ensures no items from previous runs remain
21332208
try:
21342209
while not self.activity_queue.empty():
2135-
self.activity_queue.get_nowait()
2210+
func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
2211+
await self._run_func(cancellation_func, *args, **kwargs)
21362212
except Exception:
21372213
pass
21382214
if self.orchestration_queue is not None:
21392215
try:
21402216
while not self.orchestration_queue.empty():
2141-
self.orchestration_queue.get_nowait()
2217+
func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
2218+
await self._run_func(cancellation_func, *args, **kwargs)
2219+
except Exception:
2220+
pass
2221+
if self.entity_batch_queue is not None:
2222+
try:
2223+
while not self.entity_batch_queue.empty():
2224+
func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
2225+
await self._run_func(cancellation_func, *args, **kwargs)
21422226
except Exception:
21432227
pass
21442228
# Clear pending work lists
21452229
self._pending_activity_work.clear()
21462230
self._pending_orchestration_work.clear()
2231+
self._pending_entity_batch_work.clear()
21472232

21482233

21492234
# Export public API

tests/durabletask/test_worker_concurrency_loop.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,21 @@ def dummy_orchestrator(req, stub, completionToken):
5252
time.sleep(0.1)
5353
stub.CompleteOrchestratorTask('ok')
5454

55+
def cancel_dummy_orchestrator(req, stub, completionToken):
56+
pass
57+
5558
def dummy_activity(req, stub, completionToken):
5659
time.sleep(0.1)
5760
stub.CompleteActivityTask('ok')
5861

62+
def cancel_dummy_activity(req, stub, completionToken):
63+
pass
64+
5965
# Patch the worker's _execute_orchestrator and _execute_activity
6066
worker._execute_orchestrator = dummy_orchestrator
67+
worker._cancel_orchestrator = cancel_dummy_orchestrator
6168
worker._execute_activity = dummy_activity
69+
worker._cancel_activity = cancel_dummy_activity
6270

6371
orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)]
6472
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
@@ -67,9 +75,9 @@ async def run_test():
6775
# Start the worker manager's run loop in the background
6876
worker_task = asyncio.create_task(worker._async_worker_manager.run())
6977
for req in orchestrator_requests:
70-
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())
78+
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
7179
for req in activity_requests:
72-
worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken())
80+
worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken())
7381
await asyncio.sleep(1.0)
7482
orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator')
7583
activity_count = sum(1 for t, _ in stub.completed if t == 'activity')

tests/durabletask/test_worker_concurrency_loop_async.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,21 @@ async def dummy_orchestrator(req, stub, completionToken):
5050
await asyncio.sleep(0.1)
5151
stub.CompleteOrchestratorTask('ok')
5252

53+
async def cancel_dummy_orchestrator(req, stub, completionToken):
54+
pass
55+
5356
async def dummy_activity(req, stub, completionToken):
5457
await asyncio.sleep(0.1)
5558
stub.CompleteActivityTask('ok')
5659

60+
async def cancel_dummy_activity(req, stub, completionToken):
61+
pass
62+
5763
# Patch the worker's _execute_orchestrator and _execute_activity
5864
grpc_worker._execute_orchestrator = dummy_orchestrator
65+
grpc_worker._cancel_orchestrator = cancel_dummy_orchestrator
5966
grpc_worker._execute_activity = dummy_activity
67+
grpc_worker._cancel_activity = cancel_dummy_activity
6068

6169
orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)]
6270
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
@@ -66,9 +74,9 @@ async def run_test():
6674
stub.completed.clear()
6775
worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run())
6876
for req in orchestrator_requests:
69-
grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())
77+
grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
7078
for req in activity_requests:
71-
grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken())
79+
grpc_worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken())
7280
await asyncio.sleep(1.0)
7381
orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator')
7482
activity_count = sum(1 for t, _ in stub.completed if t == 'activity')

0 commit comments

Comments
 (0)