Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mosaic/comms/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def __init__(self, uid, address, port, **kwargs):
self._heartbeat_timeout = None
self._heartbeat_attempts = 0
self._heartbeat_max_attempts = 5
self._heartbeat_interval = 5
self._heartbeat_interval = 15

self._shaken = False

Expand Down
1 change: 1 addition & 0 deletions mosaic/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def deregister_runtime(self, uid):
self.init_future.set_exception(
RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid)
)
self.init_future.exception()

def __repr__(self):
NotImplementedError('Unimplemented Base method __repr__')
Expand Down
1 change: 1 addition & 0 deletions mosaic/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ def deregister_runtime(self, uid):
self._done_future.set_exception(
RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid)
)
self._done_future.exception()
except asyncio.InvalidStateError:
pass
else:
Expand Down
27 changes: 27 additions & 0 deletions mosaic/runtime/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, **kwargs):
self._monitored_nodes = dict()
self._monitored_tessera = dict()
self._monitored_tasks = dict()
self._disconnected_runtimes = set()

self._runtime_tessera = defaultdict(list)
self._runtime_tasks = defaultdict(list)
Expand Down Expand Up @@ -303,6 +304,8 @@ def set_profiler(self):
self._loop.interval(self.append_description, interval=10)

def update_node(self, sender_id, update, sub_resources):
if sender_id in self._disconnected_runtimes:
return
if sender_id not in self._monitored_nodes:
self._monitored_nodes[sender_id] = MonitoredResource(sender_id)

Expand All @@ -311,26 +314,36 @@ def update_node(self, sender_id, update, sub_resources):
self._monitor_strategy.update_node(node)

def add_tessera_event(self, sender_id, msgs):
    """Record one or more tessera events reported by ``sender_id``.

    Parameters
    ----------
    sender_id
        Identifier of the runtime that sent the events. If it is present in
        ``self._disconnected_runtimes`` the call is a no-op.
    msgs : dict or list of dict
        A single event or a list of events. Each event is forwarded as
        keyword arguments to ``self._add_tessera_event``.

    """
    # Drop late messages from a sender already marked as disconnected.
    if sender_id in self._disconnected_runtimes:
        return

    # A lone event is treated as a batch of one.
    if not isinstance(msgs, list):
        msgs = [msgs]

    for msg in msgs:
        self._add_tessera_event(sender_id, **msg)

def add_task_event(self, sender_id, msgs):
    """Record one or more task events reported by ``sender_id``.

    Parameters
    ----------
    sender_id
        Identifier of the runtime that sent the events. If it is present in
        ``self._disconnected_runtimes`` the call is a no-op.
    msgs : dict or list of dict
        A single event or a list of events. Each event is forwarded as
        keyword arguments to ``self._add_task_event``.

    """
    # Drop late messages from a sender already marked as disconnected.
    if sender_id in self._disconnected_runtimes:
        return

    # A lone event is treated as a batch of one.
    if not isinstance(msgs, list):
        msgs = [msgs]

    for msg in msgs:
        self._add_task_event(sender_id, **msg)

def add_tessera_profile(self, sender_id, msgs):
    """Record one or more tessera profile updates reported by ``sender_id``.

    Parameters
    ----------
    sender_id
        Identifier of the runtime that sent the profiles. If it is present in
        ``self._disconnected_runtimes`` the call is a no-op.
    msgs : dict or list of dict
        A single profile message or a list of them. Each message is forwarded
        as keyword arguments to ``self._add_tessera_profile``.

    """
    # Drop late messages from a sender already marked as disconnected.
    if sender_id in self._disconnected_runtimes:
        return

    # A lone message is treated as a batch of one.
    if not isinstance(msgs, list):
        msgs = [msgs]

    for msg in msgs:
        self._add_tessera_profile(sender_id, **msg)

def add_task_profile(self, sender_id, msgs):
    """Record one or more task profile updates reported by ``sender_id``.

    Parameters
    ----------
    sender_id
        Identifier of the runtime that sent the profiles. If it is present in
        ``self._disconnected_runtimes`` the call is a no-op.
    msgs : dict or list of dict
        A single profile message or a list of them. Each message is forwarded
        as keyword arguments to ``self._add_task_profile``.

    """
    # Drop late messages from a sender already marked as disconnected.
    if sender_id in self._disconnected_runtimes:
        return

    # A lone message is treated as a batch of one.
    if not isinstance(msgs, list):
        msgs = [msgs]

    for msg in msgs:
        self._add_task_profile(sender_id, **msg)

def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs):
if runtime_id in self._disconnected_runtimes:
return
if uid not in self._monitored_tessera:
self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid)
self._runtime_tessera[runtime_id].append(uid)
Expand All @@ -341,6 +354,8 @@ def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs):
self._dirty_tessera.add(uid)

def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs):
if runtime_id in self._disconnected_runtimes:
return
if uid not in self._monitored_tasks:
self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id)
self._runtime_tasks[runtime_id].append(uid)
Expand All @@ -351,6 +366,8 @@ def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs):
self._dirty_tasks.add(uid)

def _add_tessera_profile(self, sender_id, runtime_id, uid, profile):
if runtime_id in self._disconnected_runtimes:
return
if uid not in self._monitored_tessera:
self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid)
self._runtime_tessera[runtime_id].append(uid)
Expand All @@ -360,6 +377,8 @@ def _add_tessera_profile(self, sender_id, runtime_id, uid, profile):
self._dirty_tessera.add(uid)

def _add_task_profile(self, sender_id, runtime_id, uid, tessera_id, profile):
if runtime_id in self._disconnected_runtimes:
return
if uid not in self._monitored_tasks:
self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id)
self._runtime_tasks[runtime_id].append(uid)
Expand Down Expand Up @@ -480,6 +499,14 @@ def disconnect(self, sender_id, uid):
"""
super().disconnect(sender_id, uid)

# ensure runtime marked as disconnected
self._disconnected_runtimes.add(uid)

# disconnect associated workers
if uid in self._monitored_nodes:
for worker_id in self._monitored_nodes[uid].sub_resources['workers'].keys():
self.disconnect(sender_id, worker_id)

# remove runtime from monitored nodes
try:
del self._monitored_nodes[uid]
Expand Down
12 changes: 4 additions & 8 deletions stride/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,14 +324,10 @@ async def adjoint(self, grad=None, **kwargs):
runtime = mosaic.runtime()

async def redux(rec_grads, *grads):
if rec_grads is None:
sums = [
_maybe_sum(None, g) for g in grads
]
else:
sums = [
_maybe_sum(r, g) for r, g in zip(rec_grads, grads)
]
rec_grads = (None,)*len(grads) if rec_grads is None else rec_grads
sums = [
_maybe_sum(r, g) for r, g in zip(rec_grads, grads)
]
return await asyncio.gather(*sums)

def dealloc(objs):
Expand Down
Loading