From 7fba6ef1ce1063354a19a2f1f7b171529494b1a6 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Thu, 8 Jan 2026 16:40:43 +0000 Subject: [PATCH 1/3] Fix monitor worker disconnection --- mosaic/runtime/monitor.py | 25 +++++++++++++++++++++++++ stride/core.py | 12 ++++-------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index 89a2ad9..e1e4014 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -75,6 +75,7 @@ def __init__(self, **kwargs): self._monitored_nodes = dict() self._monitored_tessera = dict() self._monitored_tasks = dict() + self._disconnected_runtimes = set() self._runtime_tessera = defaultdict(list) self._runtime_tasks = defaultdict(list) @@ -303,6 +304,8 @@ def set_profiler(self): self._loop.interval(self.append_description, interval=10) def update_node(self, sender_id, update, sub_resources): + if sender_id in self._disconnected_runtimes: + return if sender_id not in self._monitored_nodes: self._monitored_nodes[sender_id] = MonitoredResource(sender_id) @@ -311,26 +314,36 @@ def update_node(self, sender_id, update, sub_resources): self._monitor_strategy.update_node(node) def add_tessera_event(self, sender_id, msgs): + if sender_id in self._disconnected_runtimes: + return msgs = [msgs] if not isinstance(msgs, list) else msgs for msg in msgs: self._add_tessera_event(sender_id, **msg) def add_task_event(self, sender_id, msgs): + if sender_id in self._disconnected_runtimes: + return msgs = [msgs] if not isinstance(msgs, list) else msgs for msg in msgs: self._add_task_event(sender_id, **msg) def add_tessera_profile(self, sender_id, msgs): + if sender_id in self._disconnected_runtimes: + return msgs = [msgs] if not isinstance(msgs, list) else msgs for msg in msgs: self._add_tessera_profile(sender_id, **msg) def add_task_profile(self, sender_id, msgs): + if sender_id in self._disconnected_runtimes: + return msgs = [msgs] if not isinstance(msgs, list) else msgs for msg in msgs: self._add_task_profile(sender_id, **msg) def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs): + if runtime_id in self._disconnected_runtimes: + return if uid not in self._monitored_tessera: self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid) self._runtime_tessera[runtime_id].append(uid) @@ -341,6 +354,8 @@ def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs): self._dirty_tessera.add(uid) def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs): + if runtime_id in self._disconnected_runtimes: + return if uid not in self._monitored_tasks: self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id) self._runtime_tasks[runtime_id].append(uid) @@ -351,6 +366,8 @@ def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs): self._dirty_tasks.add(uid) def _add_tessera_profile(self, sender_id, runtime_id, uid, profile): + if runtime_id in self._disconnected_runtimes: + return if uid not in self._monitored_tessera: self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid) self._runtime_tessera[runtime_id].append(uid) @@ -360,6 +377,8 @@ def _add_tessera_profile(self, sender_id, runtime_id, uid, profile): self._dirty_tessera.add(uid) def _add_task_profile(self, sender_id, runtime_id, uid, tessera_id, profile): + if runtime_id in self._disconnected_runtimes: + return if uid not in self._monitored_tasks: self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id) self._runtime_tasks[runtime_id].append(uid) @@ -479,6 +498,7 @@ def disconnect(self, sender_id, uid): """ super().disconnect(sender_id, uid) + self._disconnected_runtimes.add(uid) # remove runtime from monitored nodes try: @@ -502,6 +522,11 @@ def disconnect(self, sender_id, uid): pass del self._runtime_tasks[uid] + # disconnect associated workers + if uid in self._monitored_nodes: + for worker_id in self._monitored_nodes[uid].sub_resources['workers'].keys(): + self.disconnect(sender_id, worker_id) + async def select_worker(self, sender_id): """ Select appropriate worker to allocate a tessera. diff --git a/stride/core.py b/stride/core.py index 392d736..6b5b89f 100644 --- a/stride/core.py +++ b/stride/core.py @@ -324,14 +324,10 @@ async def adjoint(self, grad=None, **kwargs): runtime = mosaic.runtime() async def redux(rec_grads, *grads): - if rec_grads is None: - sums = [ - _maybe_sum(None, g) for g in grads - ] - else: - sums = [ - _maybe_sum(r, g) for r, g in zip(rec_grads, grads) - ] + rec_grads = (None,)*len(grads) if rec_grads is None else rec_grads + sums = [ + _maybe_sum(r, g) for r, g in zip(rec_grads, grads) + ] return await asyncio.gather(*sums) def dealloc(objs): From edab7ddef492dfbf8cf9d675b53f95e191fe34f0 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Thu, 8 Jan 2026 17:22:14 +0000 Subject: [PATCH 2/3] Minor code modification --- mosaic/runtime/monitor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index e1e4014..4a22f81 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -498,8 +498,15 @@ def disconnect(self, sender_id, uid): """ super().disconnect(sender_id, uid) + + # ensure runtime marked as disconnected self._disconnected_runtimes.add(uid) + # disconnect associated workers + if uid in self._monitored_nodes: + for worker_id in self._monitored_nodes[uid].sub_resources['workers'].keys(): + self.disconnect(sender_id, worker_id) + # remove runtime from monitored nodes try: del self._monitored_nodes[uid] @@ -522,11 +529,6 @@ def disconnect(self, sender_id, uid): pass del self._runtime_tasks[uid] - # disconnect associated workers - if uid in self._monitored_nodes: - for worker_id in self._monitored_nodes[uid].sub_resources['workers'].keys(): - self.disconnect(sender_id, worker_id) - async def select_worker(self, sender_id): """ Select appropriate worker to allocate a tessera. From e65f618754635c79bc358fccccf492b216955988 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Fri, 9 Jan 2026 16:29:23 +0000 Subject: [PATCH 3/3] Retrieve future exception --- mosaic/comms/comms.py | 2 +- mosaic/core/base.py | 1 + mosaic/core/task.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mosaic/comms/comms.py b/mosaic/comms/comms.py index 089fe52..b8ec109 100644 --- a/mosaic/comms/comms.py +++ b/mosaic/comms/comms.py @@ -579,7 +579,7 @@ def __init__(self, uid, address, port, **kwargs): self._heartbeat_timeout = None self._heartbeat_attempts = 0 self._heartbeat_max_attempts = 5 - self._heartbeat_interval = 5 + self._heartbeat_interval = 15 self._shaken = False diff --git a/mosaic/core/base.py b/mosaic/core/base.py index 84d7038..e33e35d 100644 --- a/mosaic/core/base.py +++ b/mosaic/core/base.py @@ -103,6 +103,7 @@ def deregister_runtime(self, uid): self.init_future.set_exception( RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid) ) + self.init_future.exception() def __repr__(self): NotImplementedError('Unimplemented Base method __repr__') diff --git a/mosaic/core/task.py b/mosaic/core/task.py index d51c26a..331deb9 100644 --- a/mosaic/core/task.py +++ b/mosaic/core/task.py @@ -737,6 +737,7 @@ def deregister_runtime(self, uid): self._done_future.set_exception( RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid) ) + self._done_future.exception() except asyncio.InvalidStateError: pass else: