Skip to content

Commit 66cbcb2

Browse files
committed
more concurrency stuff
1 parent 194b24e commit 66cbcb2

File tree

2 files changed

+85
-11
lines changed

2 files changed

+85
-11
lines changed

durabletask/worker.py

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ def run_loop():
219219
self._is_running = True
220220

221221
async def _async_run_loop(self):
222-
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
223222
worker_task = asyncio.create_task(self._async_worker_manager.run())
224223
# Connection state management for retry fix
225224
current_channel = None
@@ -1245,40 +1244,92 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:
12451244

12461245
class _AsyncWorkerManager:
12471246
def __init__(self, concurrency_options: ConcurrencyOptions):
1248-
self.activity_semaphore = asyncio.Semaphore(
1249-
concurrency_options.maximum_concurrent_activity_work_items
1250-
)
1251-
self.orchestration_semaphore = asyncio.Semaphore(
1252-
concurrency_options.maximum_concurrent_orchestration_work_items
1253-
)
1247+
self.concurrency_options = concurrency_options
1248+
self.activity_semaphore = None
1249+
self.orchestration_semaphore = None
12541250
self.activity_queue: asyncio.Queue = asyncio.Queue()
12551251
self.orchestration_queue: asyncio.Queue = asyncio.Queue()
1252+
self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None
1253+
# Try to capture the current event loop when queues are created
1254+
try:
1255+
self._queue_event_loop = asyncio.get_running_loop()
1256+
except RuntimeError:
1257+
# No event loop running when manager was created
1258+
pass
12561259
self.thread_pool = ThreadPoolExecutor(
12571260
max_workers=concurrency_options.maximum_thread_pool_workers,
12581261
thread_name_prefix="DurableTask",
12591262
)
12601263
self._shutdown = False
12611264

1265+
def _ensure_queues_for_current_loop(self):
1266+
"""Ensure queues are bound to the current event loop."""
1267+
try:
1268+
current_loop = asyncio.get_running_loop()
1269+
except RuntimeError:
1270+
# No event loop running, can't create queues
1271+
return
1272+
1273+
if self._queue_event_loop is current_loop and hasattr(self, 'activity_queue') and hasattr(self, 'orchestration_queue'):
1274+
# Queues are already bound to the current loop and exist
1275+
return
1276+
1277+
# Need to recreate queues for the current event loop
1278+
# Create fresh queues - any items from previous event loops are dropped
1279+
self.activity_queue = asyncio.Queue()
1280+
self.orchestration_queue = asyncio.Queue()
1281+
self._queue_event_loop = current_loop
1282+
12621283
async def run(self):
1284+
# Reset shutdown flag in case this manager is being reused
1285+
self._shutdown = False
1286+
1287+
# Ensure queues are properly bound to the current event loop
1288+
self._ensure_queues_for_current_loop()
1289+
1290+
# Create semaphores in the current event loop
1291+
self.activity_semaphore = asyncio.Semaphore(
1292+
self.concurrency_options.maximum_concurrent_activity_work_items
1293+
)
1294+
self.orchestration_semaphore = asyncio.Semaphore(
1295+
self.concurrency_options.maximum_concurrent_orchestration_work_items
1296+
)
1297+
12631298
# Start background consumers for each work type
12641299
await asyncio.gather(
12651300
self._consume_queue(self.activity_queue, self.activity_semaphore),
12661301
self._consume_queue(self.orchestration_queue, self.orchestration_semaphore),
12671302
)
12681303

12691304
async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
1305+
# List to track running tasks
1306+
running_tasks: set[asyncio.Task] = set()
1307+
12701308
while True:
1271-
# Exit if shutdown is set and the queue is empty
1272-
if self._shutdown and queue.empty():
1309+
# Clean up completed tasks
1310+
done_tasks = {task for task in running_tasks if task.done()}
1311+
running_tasks -= done_tasks
1312+
1313+
# Exit if shutdown is set and the queue is empty and no tasks are running
1314+
if self._shutdown and queue.empty() and not running_tasks:
12731315
break
1316+
12741317
try:
12751318
work = await asyncio.wait_for(queue.get(), timeout=1.0)
12761319
except asyncio.TimeoutError:
12771320
continue
1321+
12781322
func, args, kwargs = work
1279-
async with semaphore:
1323+
# Create a concurrent task for processing
1324+
task = asyncio.create_task(self._process_work_item(semaphore, queue, func, args, kwargs))
1325+
running_tasks.add(task)
1326+
1327+
async def _process_work_item(self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs):
1328+
async with semaphore:
1329+
try:
12801330
await self._run_func(func, *args, **kwargs)
1281-
queue.task_done()
1331+
finally:
1332+
queue.task_done()
12821333

12831334
async def _run_func(self, func, *args, **kwargs):
12841335
if inspect.iscoroutinefunction(func):
@@ -1291,11 +1342,32 @@ async def _run_func(self, func, *args, **kwargs):
12911342
return await loop.run_in_executor(self.thread_pool, lambda: func(*args, **kwargs))
12921343

12931344
def submit_activity(self, func, *args, **kwargs):
1345+
self._ensure_queues_for_current_loop()
12941346
self.activity_queue.put_nowait((func, args, kwargs))
12951347

12961348
def submit_orchestration(self, func, *args, **kwargs):
1349+
self._ensure_queues_for_current_loop()
12971350
self.orchestration_queue.put_nowait((func, args, kwargs))
12981351

12991352
def shutdown(self):
13001353
self._shutdown = True
13011354
self.thread_pool.shutdown(wait=True)
1355+
1356+
def reset_for_new_run(self):
1357+
"""Reset the manager state for a new run."""
1358+
self._shutdown = False
1359+
# Clear any existing queues - they'll be recreated when needed
1360+
if hasattr(self, 'activity_queue'):
1361+
# Clear existing queue by creating a new one
1362+
# This ensures no items from previous runs remain
1363+
try:
1364+
while not self.activity_queue.empty():
1365+
self.activity_queue.get_nowait()
1366+
except Exception:
1367+
pass
1368+
if hasattr(self, 'orchestration_queue'):
1369+
try:
1370+
while not self.orchestration_queue.empty():
1371+
self.orchestration_queue.get_nowait()
1372+
except Exception:
1373+
pass

tests/durabletask/test_worker_concurrency_loop_async.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ async def dummy_activity(req, stub, completionToken):
6363
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
6464

6565
async def run_test():
66+
# Clear stub state before each run
67+
stub.completed.clear()
6668
worker_task = asyncio.create_task(worker._async_worker_manager.run())
6769
for req in orchestrator_requests:
6870
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())

0 commit comments

Comments
 (0)