diff --git a/bin/pr-name-check b/bin/pr-name-check new file mode 100755 index 0000000000..0f67e6172a --- /dev/null +++ b/bin/pr-name-check @@ -0,0 +1,69 @@ +#!/usr/bin/env bash + +set -euo pipefail + +branch="$(git rev-parse --abbrev-ref HEAD)" + +# based on: https://github.com/dimensionalOS/wiki/wiki +allowed_types="feat fix chore refactor docs" +allowed_names="stash ivan paul alexl mustafa miguel christie ruthwik jalaj yashas yash matt jing juan jeff unknown" + +if [[ "$branch" != */*/* ]]; then + echo "Invalid branch name: '$branch'" + echo "Expected format: <name>/<type>/<description>" + echo "Allowed names: $allowed_names" + echo "Allowed types: $allowed_types" + exit 1 +fi + +branch_name="${branch%%/*}" +rest="${branch#*/}" +branch_type="${rest%%/*}" +branch_description="${branch#*/*/}" + +if [[ -z "$branch_description" || "$branch_description" == "$branch" ]]; then + echo "Invalid branch name: '$branch'" + echo "Expected format: <name>/<type>/<description>" + exit 1 +fi + +name_ok=0 +for n in $allowed_names; do + if [[ "$branch_name" == "$n" ]]; then + name_ok=1 + break + fi +done + +type_ok=0 +for t in $allowed_types; do + if [[ "$branch_type" == "$t" ]]; then + type_ok=1 + break + fi +done + +if [[ "$name_ok" -ne 1 || "$type_ok" -ne 1 ]]; then + echo + echo + echo + echo + echo + echo "Invalid branch name: '$branch'" + echo + echo " Expected format: <name>/<type>/<description>" + echo " Example: jeff/fix/ci-divergence" + echo " Parsed name: $branch_name" + echo " Allowed names: $allowed_names" + echo " Parsed type: $branch_type" + echo " Allowed types: $allowed_types" + echo + echo "Wait 4 seconds if you want to ignore this error" + sleep 1; echo 4 + sleep 1; echo 3 + sleep 1; echo 2 + sleep 1; echo 1 + exit 1 +else + echo "Branch naming check passed: $branch" +fi diff --git a/dimos/agents_deprecated/memory/image_embedding.py b/dimos/agents_deprecated/memory/image_embedding.py index 27e16f1aa8..d6b0967642 --- a/dimos/agents_deprecated/memory/image_embedding.py +++ b/dimos/agents_deprecated/memory/image_embedding.py @@
-63,7 +63,7 @@ def __init__(self, model_name: str = "clip", dimensions: int = 512) -> None: def _initialize_model(self): # type: ignore[no-untyped-def] """Initialize the specified embedding model.""" try: - import onnxruntime as ort # type: ignore[import-untyped] + import onnxruntime as ort # type: ignore[import-untyped,import-not-found] import torch # noqa: F401 from transformers import ( # type: ignore[import-untyped] AutoFeatureExtractor, diff --git a/dimos/core/docker_build.py b/dimos/core/docker_build.py index 7ee90fc5c3..24fd2b3e44 100644 --- a/dimos/core/docker_build.py +++ b/dimos/core/docker_build.py @@ -20,6 +20,7 @@ from __future__ import annotations +import hashlib import subprocess from typing import TYPE_CHECKING @@ -32,10 +33,11 @@ logger = setup_logger() -# Timeout for quick Docker commands +_BUILD_HASH_LABEL = "dimos.build.hash" + DOCKER_CMD_TIMEOUT = 20 -# Sentinel value to detect already-converted Dockerfiles (UUID ensures uniqueness) +# the way of detecting already-converted Dockerfiles (UUID ensures uniqueness) DIMOS_SENTINEL = "DIMOS-MODULE-CONVERSION-427593ae-c6e8-4cf1-9b2d-ee81a420a5dc" # Footer appended to Dockerfiles for DimOS module conversion @@ -53,28 +55,6 @@ """ -def _run(cmd: list[str], *, timeout: float | None = None) -> subprocess.CompletedProcess[str]: - """Run a command and return the result.""" - return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, check=False) - - -def _run_streaming(cmd: list[str]) -> int: - """Run command and stream output to terminal. 
Returns exit code.""" - result = subprocess.run(cmd, text=True) - return result.returncode - - -def _docker_bin(cfg: DockerModuleConfig) -> str: - """Get docker binary path.""" - return cfg.docker_bin or "docker" - - -def _image_exists(docker_bin: str, image_name: str) -> bool: - """Check if a Docker image exists locally.""" - r = _run([docker_bin, "image", "inspect", image_name], timeout=DOCKER_CMD_TIMEOUT) - return r.returncode == 0 - - def _convert_dockerfile(dockerfile: Path) -> Path: """Append DimOS footer to Dockerfile. Returns path to converted file.""" content = dockerfile.read_text() @@ -85,32 +65,82 @@ def _convert_dockerfile(dockerfile: Path) -> Path: logger.info(f"Converting {dockerfile.name} to DimOS format") - converted = dockerfile.parent / f".{dockerfile.name}.dimos" + converted = dockerfile.parent / f".{dockerfile.name}.ignore" converted.write_text(content.rstrip() + "\n" + DIMOS_FOOTER.lstrip("\n")) return converted +def _compute_build_hash(cfg: DockerModuleConfig) -> str: + """Hash Dockerfile contents and build args.""" + assert cfg.docker_file is not None + digest = hashlib.sha256() + digest.update(cfg.docker_file.read_bytes()) + for key, val in sorted(cfg.docker_build_args.items()): + digest.update(f"{key}={val}".encode()) + for arg in cfg.docker_build_extra_args: + digest.update(arg.encode()) + return digest.hexdigest() + + +def _get_image_build_hash(cfg: DockerModuleConfig) -> str | None: + """Read the build hash label from an existing Docker image.""" + r = subprocess.run( + [ + cfg.docker_bin, + "image", + "inspect", + "-f", + '{{index .Config.Labels "' + _BUILD_HASH_LABEL + '"}}', + cfg.docker_image, + ], + capture_output=True, + text=True, + timeout=DOCKER_CMD_TIMEOUT, + check=False, + ) + if r.returncode != 0: + return None + value = r.stdout.strip() + # docker prints "" when the label is missing + return value if value and value != "" else None + + def build_image(cfg: DockerModuleConfig) -> None: """Build Docker image using footer mode 
conversion.""" if cfg.docker_file is None: raise ValueError("docker_file is required for building Docker images") + + build_hash = _compute_build_hash(cfg) dockerfile = _convert_dockerfile(cfg.docker_file) context = cfg.docker_build_context or cfg.docker_file.parent - cmd = [_docker_bin(cfg), "build", "-t", cfg.docker_image, "-f", str(dockerfile)] + cmd = [cfg.docker_bin, "build", "-t", cfg.docker_image, "-f", str(dockerfile)] + cmd.extend(["--label", f"{_BUILD_HASH_LABEL}={build_hash}"]) for k, v in cfg.docker_build_args.items(): cmd.extend(["--build-arg", f"{k}={v}"]) + cmd.extend(cfg.docker_build_extra_args) cmd.append(str(context)) logger.info(f"Building Docker image: {cfg.docker_image}") - exit_code = _run_streaming(cmd) - if exit_code != 0: - raise RuntimeError(f"Docker build failed with exit code {exit_code}") + # Stream stdout to terminal so the user sees build progress, but capture + # stderr separately so we can include it in the error message on failure. + result = subprocess.run(cmd, text=True, stderr=subprocess.PIPE) + if result.returncode != 0: + raise RuntimeError( + f"Docker build failed with exit code {result.returncode}\nSTDERR:\n{result.stderr}" + ) def image_exists(cfg: DockerModuleConfig) -> bool: """Check if the configured Docker image exists locally.""" - return _image_exists(_docker_bin(cfg), cfg.docker_image) + r = subprocess.run( + [cfg.docker_bin, "image", "inspect", cfg.docker_image], + capture_output=True, + text=True, + timeout=DOCKER_CMD_TIMEOUT, + check=False, + ) + return r.returncode == 0 __all__ = [ diff --git a/dimos/core/docker_runner.py b/dimos/core/docker_runner.py index dcb75fbdee..16727a8dd1 100644 --- a/dimos/core/docker_runner.py +++ b/dimos/core/docker_runner.py @@ -18,16 +18,14 @@ from dataclasses import field import importlib import json -import os import signal import subprocess import threading import time from typing import TYPE_CHECKING, Any -from dimos.core.docker_build import build_image, image_exists -from 
dimos.core.module import Module, ModuleConfig -from dimos.core.rpc_client import RpcCall +from dimos.core.module import ModuleConfig +from dimos.core.rpc_client import ModuleProxyProtocol, RpcCall from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.utils.logging_config import setup_logger from dimos.visualization.rerun.bridge import RERUN_GRPC_PORT, RERUN_WEB_PORT @@ -36,9 +34,12 @@ from collections.abc import Callable from pathlib import Path + from dimos.core.module import Module + logger = setup_logger() DOCKER_RUN_TIMEOUT = 120 # Timeout for `docker run` command execution +DOCKER_PULL_TIMEOUT_DEFAULT = None # No timeout for `docker pull` (images can be large) DOCKER_CMD_TIMEOUT = 20 # Timeout for quick Docker commands (inspect, rm, logs) DOCKER_STATUS_TIMEOUT = 10 # Timeout for container status checks DOCKER_STOP_TIMEOUT = 30 # Timeout for `docker stop` command (graceful shutdown) @@ -52,6 +53,8 @@ class DockerModuleConfig(ModuleConfig): For advanced Docker options not listed here, use docker_extra_args. 
Example: docker_extra_args=["--cap-add=SYS_ADMIN", "--read-only"] + + NOTE: a DockerModule will rebuild automatically if the Dockerfile or build args change """ # Build / image @@ -59,6 +62,7 @@ class DockerModuleConfig(ModuleConfig): docker_file: Path | None = None # Required on host for building, not needed in container docker_build_context: Path | None = None docker_build_args: dict[str, str] = field(default_factory=dict) + docker_build_extra_args: list[str] = field(default_factory=list) # Extra args for docker build # Identity docker_container_name: str | None = None @@ -72,9 +76,9 @@ class DockerModuleConfig(ModuleConfig): ) # (host, container, proto) # Runtime resources - docker_gpus: str | None = "all" - docker_shm_size: str = "2g" - docker_restart_policy: str = "on-failure:3" + docker_gpus: str | None = None + docker_shm_size: str = "4g" + docker_restart_policy: str = "no" # Env + volumes + devices docker_env_files: list[str] = field(default_factory=list) @@ -93,10 +97,14 @@ class DockerModuleConfig(ModuleConfig): docker_command: list[str] | None = None docker_extra_args: list[str] = field(default_factory=list) - # Startup readiness + # Timeouts + docker_pull_timeout: float | None = DOCKER_PULL_TIMEOUT_DEFAULT docker_startup_timeout: float = 120.0 docker_poll_interval: float = 1.0 + # Reconnect to a running container instead of restarting it + docker_reconnect_container: bool = False + # Advanced docker_bin: str = "docker" @@ -104,7 +112,11 @@ class DockerModuleConfig(ModuleConfig): def is_docker_module(module_class: type) -> bool: """Check if a module class should run in Docker based on its default_config.""" default_config = getattr(module_class, "default_config", None) - return default_config is not None and issubclass(default_config, DockerModuleConfig) + return ( + default_config is not None + and isinstance(default_config, type) + and issubclass(default_config, DockerModuleConfig) + ) # Docker helpers @@ -115,25 +127,20 @@ def _run(cmd: list[str], *, 
timeout: float | None = None) -> subprocess.Complete return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, check=False) -def _docker_bin(cfg: DockerModuleConfig) -> str: - """Get docker binary path, defaulting to 'docker' if empty/None.""" - return cfg.docker_bin or "docker" - - def _remove_container(cfg: DockerModuleConfig, name: str) -> None: - _run([_docker_bin(cfg), "rm", "-f", name], timeout=DOCKER_CMD_TIMEOUT) + _run([cfg.docker_bin, "rm", "-f", name], timeout=DOCKER_CMD_TIMEOUT) def _is_container_running(cfg: DockerModuleConfig, name: str) -> bool: r = _run( - [_docker_bin(cfg), "inspect", "-f", "{{.State.Running}}", name], + [cfg.docker_bin, "inspect", "-f", "{{.State.Running}}", name], timeout=DOCKER_STATUS_TIMEOUT, ) return r.returncode == 0 and r.stdout.strip() == "true" def _tail_logs(cfg: DockerModuleConfig, name: str, n: int = LOG_TAIL_LINES) -> str: - r = _run([_docker_bin(cfg), "logs", "--tail", str(n), name], timeout=DOCKER_CMD_TIMEOUT) + r = _run([cfg.docker_bin, "logs", "--tail", str(n), name], timeout=DOCKER_CMD_TIMEOUT) out = (r.stdout or "").rstrip() err = (r.stderr or "").rstrip() return out + ("\n" + err if err else "") @@ -156,107 +163,163 @@ def _extract_module_config(cfg: DockerModuleConfig) -> dict[str, Any]: # Host-side Docker-backed Module handle -class DockerModule: +class DockerModule(ModuleProxyProtocol): """ Host-side handle for a module running inside Docker. Lifecycle: - - start(): launches container, waits for module ready via RPC - - stop(): stops container - - __getattr__: exposes RpcCall for @rpc methods on remote module + - start(): builds the image if needed, launches the container, waits for readiness, calls the remote module's start() RPC (after streams are wired) + - stop(): stops the container and cleans up Communication: All RPC happens via LCM multicast (requires --network=host). 
""" + config: DockerModuleConfig + def __init__(self, module_class: type[Module], *args: Any, **kwargs: Any) -> None: - # Config + from dimos.core.docker_build import ( + _compute_build_hash, + _get_image_build_hash, + build_image, + image_exists, + ) + config_class = getattr(module_class, "default_config", DockerModuleConfig) + if not issubclass(config_class, DockerModuleConfig): + raise TypeError( + f"{module_class.__name__}.default_config must be a DockerModuleConfig subclass, " + f"got {config_class.__name__}" + ) config = config_class(**kwargs) - # Module info self._module_class = module_class - self._config = config + self.config = config self._args = args self._kwargs = kwargs self._running = False self.remote_name = module_class.__name__ + # Derive container name from image + class name: "my-registry/foo:v2" → "dimos_myclass_foo_v2" + image_ref = config.docker_image.rsplit("/", 1)[-1] self._container_name = ( config.docker_container_name - or f"dimos_{module_class.__name__.lower()}_{os.getpid()}_{int(time.time())}" + or f"dimos_{module_class.__name__.lower()}_{image_ref.replace(':', '_')}" ) - # RPC setup self.rpc = LCMRPC() self.rpcs = set(module_class.rpcs.keys()) # type: ignore[attr-defined] self.rpc_calls: list[str] = getattr(module_class, "rpc_calls", []) self._unsub_fns: list[Callable[[], None]] = [] self._bound_rpc_calls: dict[str, RpcCall] = {} - # Build image if needed (but don't start - caller must call start() explicitly) - if not image_exists(config): - logger.info(f"Building {config.docker_image}") - build_image(config) + # Build or pull image, launch container, wait for RPC server + try: + if config.docker_file is not None: + current_hash = _compute_build_hash(config) + stored_hash = _get_image_build_hash(config) + if current_hash != stored_hash: + logger.info(f"Building {config.docker_image}") + build_image(config) + elif not image_exists(config): + logger.info(f"Pulling {config.docker_image}") + r = subprocess.run( + [config.docker_bin, 
"pull", config.docker_image], + text=True, + stderr=subprocess.PIPE, + timeout=config.docker_pull_timeout, + ) + if r.returncode != 0: + raise RuntimeError( + f"Failed to pull image '{config.docker_image}'.\nSTDERR:\n{r.stderr}" + ) + + reconnect = False + if _is_container_running(config, self._container_name): + if config.docker_reconnect_container: + logger.info(f"Reconnecting to running container: {self._container_name}") + reconnect = True + else: + logger.info(f"Stopping existing container: {self._container_name}") + _run( + [config.docker_bin, "stop", self._container_name], + timeout=DOCKER_STOP_TIMEOUT, + ) + + if not reconnect: + _remove_container(config, self._container_name) + cmd = self._build_docker_run_command() + logger.info(f"Starting docker container: {self._container_name}") + r = _run(cmd, timeout=DOCKER_RUN_TIMEOUT) + if r.returncode != 0: + raise RuntimeError( + f"Failed to start container.\nSTDOUT:\n{r.stdout}\nSTDERR:\n{r.stderr}" + ) + self.rpc.start() + self._running = True + # docker run -d returns before Module.__init__ finishes in the container, + # so we poll until the RPC server is reachable before returning. 
+ self._wait_for_rpc() + except Exception: + with suppress(Exception): + self._cleanup() + raise + + def get_rpc_method_names(self) -> list[str]: + return self.rpc_calls def set_rpc_method(self, method: str, callable: RpcCall) -> None: callable.set_rpc(self.rpc) self._bound_rpc_calls[method] = callable + # Forward to container — Module.set_rpc_method unpickles the RpcCall + # and wires it with the container's own LCMRPC + self.rpc.call_sync(f"{self.remote_name}/set_rpc_method", ([method, callable], {})) def get_rpc_calls(self, *methods: str) -> RpcCall | tuple[RpcCall, ...]: - # Check all requested methods exist missing = set(methods) - self._bound_rpc_calls.keys() if missing: raise ValueError(f"RPC methods not found: {missing}") - # Return single RpcCall or tuple calls = tuple(self._bound_rpc_calls[m] for m in methods) return calls[0] if len(calls) == 1 else calls def start(self) -> None: - if self._running: - return - - cfg = self._config - - # Prevent accidental kill of running container with same name - if _is_container_running(cfg, self._container_name): - raise RuntimeError( - f"Container '{self._container_name}' already running. " - "Choose a different container_name or stop the existing container." 
- ) - _remove_container(cfg, self._container_name) - - cmd = self._build_docker_run_command() - logger.info(f"Starting docker container: {self._container_name}") - r = _run(cmd, timeout=DOCKER_RUN_TIMEOUT) - if r.returncode != 0: - raise RuntimeError( - f"Failed to start container.\nSTDOUT:\n{r.stdout}\nSTDERR:\n{r.stderr}" - ) - - self.rpc.start() - self._running = True - self._wait_for_ready() + """Invoke the remote module's start() RPC.""" + try: + self.rpc.call_sync(f"{self.remote_name}/start", ([], {})) + except Exception: + with suppress(Exception): + self.stop() + raise def stop(self) -> None: """Gracefully stop the Docker container and clean up resources.""" - # Signal remote module, stop RPC, unsubscribe handlers (ignore failures) + if not self._running: + return + self._running = False # claim shutdown before any side-effects with suppress(Exception): - if self._running: - self.rpc.call_nowait(f"{self.remote_name}/stop", ([], {})) + self.rpc.call_nowait(f"{self.remote_name}/stop", ([], {})) + self._cleanup() + + def _cleanup(self) -> None: + """Release all resources. 
Idempotent — safe to call from partial init or after stop().""" with suppress(Exception): self.rpc.stop() - for unsub in self._unsub_fns: + for unsub in getattr(self, "_unsub_fns", []): with suppress(Exception): unsub() - self._unsub_fns.clear() - - # Stop and remove container - _run([_docker_bin(self._config), "stop", self._container_name], timeout=DOCKER_STOP_TIMEOUT) - _remove_container(self._config, self._container_name) + with suppress(Exception): + self._unsub_fns.clear() + if not getattr(getattr(self, "config", None), "docker_reconnect_container", False): + with suppress(Exception): + _run( + [self.config.docker_bin, "stop", self._container_name], + timeout=DOCKER_STOP_TIMEOUT, + ) + with suppress(Exception): + _remove_container(self.config, self._container_name) self._running = False - logger.info(f"Stopped container: {self._container_name}") + logger.info(f"Cleaned up container handle: {self._container_name}") def status(self) -> dict[str, Any]: - cfg = self._config + cfg = self.config return { "module": self.remote_name, "container_name": self._container_name, @@ -265,34 +328,30 @@ def status(self) -> dict[str, Any]: } def tail_logs(self, n: int = 200) -> str: - return _tail_logs(self._config, self._container_name, n=n) + return _tail_logs(self.config, self._container_name, n=n) def set_transport(self, stream_name: str, transport: Any) -> bool: - """Configure stream transport in container. 
Mirrors Module.set_transport() for autoconnect().""" - topic = getattr(transport, "topic", None) - if topic is None: - return False - if hasattr(topic, "topic"): - topic = topic.topic + """Forward to the container's Module.set_transport RPC.""" result, _ = self.rpc.call_sync( - f"{self.remote_name}/configure_stream", ([stream_name, str(topic)], {}) + f"{self.remote_name}/set_transport", ([stream_name, transport], {}) ) return bool(result) def __getattr__(self, name: str) -> Any: - if name in self.rpcs: + rpcs = self.__dict__.get("rpcs") + if rpcs is not None and name in rpcs: original_method = getattr(self._module_class, name, None) return RpcCall(original_method, self.rpc, name, self.remote_name, self._unsub_fns, None) - raise AttributeError(f"{name} not found on {self._module_class.__name__}") + raise AttributeError(f"{name} not found on {type(self).__name__}") # Docker command building (split into focused helpers for readability) def _build_docker_run_command(self) -> list[str]: """Build the complete `docker run` command.""" - cfg = self._config + cfg = self.config self._validate_config(cfg) - cmd = [_docker_bin(cfg), "run", "-d"] + cmd = [cfg.docker_bin, "run", "-d"] self._add_lifecycle_args(cmd, cfg) self._add_network_args(cmd, cfg) self._add_port_args(cmd, cfg) @@ -399,16 +458,53 @@ def _build_container_command(self, cfg: DockerModuleConfig) -> list[str]: if cfg.docker_command: return list(cfg.docker_command) - module_path = f"{self._module_class.__module__}.{self._module_class.__name__}" + module_name = self._module_class.__module__ + if module_name == "__main__": + # When run as `python script.py`, __module__ is "__main__". + # Resolve to the actual dotted module path so the container can import it. 
+ import __main__ + + spec = getattr(__main__, "__spec__", None) + if spec and spec.name: + module_name = spec.name + else: + # Fallback: derive from file path relative to cwd + main_file = getattr(__main__, "__file__", None) + if main_file: + import pathlib + + try: + rel = pathlib.Path(main_file).resolve().relative_to(pathlib.Path.cwd()) + except ValueError: + raise RuntimeError( + f"Cannot derive module path: '{main_file}' is not under cwd " + f"'{pathlib.Path.cwd()}'. " + "Run with `python -m` or set docker_command explicitly." + ) from None + module_name = str(rel.with_suffix("")).replace("/", ".") + else: + raise RuntimeError( + "Cannot determine module path for __main__. " + "Run with `python -m` or set docker_command explicitly." + ) + module_path = f"{module_name}.{self._module_class.__name__}" # Filter out docker-specific kwargs (paths, etc.) - only pass module config kwargs = {"config": _extract_module_config(cfg)} payload = {"module_path": module_path, "args": list(self._args), "kwargs": kwargs} # DimOS base image entrypoint already runs "dimos.core.docker_runner run" - return ["--payload", json.dumps(payload, separators=(",", ":"))] - - def _wait_for_ready(self) -> None: - """Poll the module's RPC endpoint until ready, crashed, or timeout.""" - cfg = self._config + try: + payload_json = json.dumps(payload, separators=(",", ":")) + except TypeError as e: + raise TypeError( + f"Cannot serialize DockerModule payload to JSON: {e}\n" + f"Ensure all constructor args/kwargs for {self._module_class.__name__} are " + f"JSON-serializable, or use docker_command to bypass automatic payload generation." 
+ ) from e + return ["--payload", payload_json] + + def _wait_for_rpc(self) -> None: + """Poll until the container's RPC server is reachable.""" + cfg = self.config start_time = time.time() logger.info(f"Waiting for {self.remote_name} to be ready...") @@ -420,13 +516,14 @@ def _wait_for_ready(self) -> None: try: self.rpc.call_sync( - f"{self.remote_name}/start", ([], {}), rpc_timeout=RPC_READY_TIMEOUT + f"{self.remote_name}/get_rpc_method_names", + ([], {}), + rpc_timeout=RPC_READY_TIMEOUT, ) elapsed = time.time() - start_time logger.info(f"{self.remote_name} ready ({elapsed:.1f}s)") return except (TimeoutError, ConnectionError, OSError): - # Module not ready yet - retry after poll interval time.sleep(cfg.docker_poll_interval) logs = _tail_logs(cfg, self._container_name) diff --git a/dimos/core/docker_worker_manager.py b/dimos/core/docker_worker_manager.py new file mode 100644 index 0000000000..520468182f --- /dev/null +++ b/dimos/core/docker_worker_manager.py @@ -0,0 +1,52 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from contextlib import suppress +from typing import TYPE_CHECKING, Any + +from dimos.core.module import ModuleSpec +from dimos.utils.safe_thread_map import ExceptionGroup, safe_thread_map + +if TYPE_CHECKING: + from dimos.core.docker_runner import DockerModule + + +class DockerWorkerManager: + """Parallel deployment of Docker-backed modules.""" + + @staticmethod + def deploy_parallel( + specs: list[ModuleSpec], + ) -> list[DockerModule]: + """Deploy multiple DockerModules in parallel. + + If any deployment fails, all successfully-started containers are + stopped before an ExceptionGroup is raised. + """ + from dimos.core.docker_runner import DockerModule + + def _on_errors( + _outcomes: list[Any], successes: list[DockerModule], errors: list[Exception] + ) -> None: + for mod in successes: + with suppress(Exception): + mod.stop() + raise ExceptionGroup("docker deploy_parallel failed", errors) + + return safe_thread_map( + specs, + lambda spec: DockerModule(spec[0], global_config=spec[1], **spec[2]), # type: ignore[arg-type] + _on_errors, + ) diff --git a/dimos/core/module.py b/dimos/core/module.py index 1c5b311883..6b12843a3a 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -456,17 +456,6 @@ def set_transport(self, stream_name: str, transport: Transport) -> bool: # type stream._transport = transport return True - @rpc - def configure_stream(self, stream_name: str, topic: str) -> bool: - """Configure a stream's transport by topic. 
Called by DockerModule for stream wiring.""" - from dimos.core.transport import pLCMTransport - - stream = getattr(self, stream_name, None) - if not isinstance(stream, (Out, In)): - return False - stream._transport = pLCMTransport(topic) - return True - # called from remote def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): # type: ignore[no-untyped-def] input_stream = getattr(self, input_name, None) diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 10227eae93..43e3e44f0a 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -14,19 +14,20 @@ from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor import threading from typing import TYPE_CHECKING, Any +from dimos.core.docker_worker_manager import DockerWorkerManager from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import ModuleBase, ModuleSpec from dimos.core.resource import Resource from dimos.core.worker_manager import WorkerManager from dimos.utils.logging_config import setup_logger +from dimos.utils.safe_thread_map import ExceptionGroup, safe_thread_map if TYPE_CHECKING: from dimos.core.resource_monitor.monitor import StatsMonitor - from dimos.core.rpc_client import ModuleProxy + from dimos.core.rpc_client import ModuleProxy, ModuleProxyProtocol from dimos.core.worker import Worker logger = setup_logger() @@ -37,7 +38,7 @@ class ModuleCoordinator(Resource): # type: ignore[misc] _global_config: GlobalConfig _n: int | None = None _memory_limit: str = "auto" - _deployed_modules: dict[type[ModuleBase], ModuleProxy] + _deployed_modules: dict[type[ModuleBase], ModuleProxyProtocol] _stats_monitor: StatsMonitor | None = None def __init__( @@ -113,7 +114,8 @@ def stop(self) -> None: logger.error("Error stopping module", module=module_class.__name__, exc_info=True) logger.info("Module stopped.", module=module_class.__name__) - self._client.close_all() # type: 
ignore[union-attr] + if self._client is not None: + self._client.close_all() def deploy( self, @@ -121,35 +123,90 @@ def deploy( global_config: GlobalConfig = global_config, **kwargs: Any, ) -> ModuleProxy: + # Inline to avoid circular import: module_coordinator → docker_runner → module → blueprints → module_coordinator + from dimos.core.docker_runner import DockerModule, is_docker_module + if not self._client: raise ValueError("Trying to dimos.deploy before the client has started") - module = self._client.deploy(module_class, global_config, kwargs) - self._deployed_modules[module_class] = module # type: ignore[assignment] - return module # type: ignore[return-value] + deployed_module: ModuleProxyProtocol + if is_docker_module(module_class): + deployed_module = DockerModule(module_class, global_config=global_config, **kwargs) # type: ignore[arg-type] + else: + deployed_module = self._client.deploy(module_class, global_config, kwargs) + self._deployed_modules[module_class] = deployed_module # type: ignore[assignment] + return deployed_module # type: ignore[return-value] def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]: + # Inline to avoid circular import: module_coordinator → docker_runner → module → blueprints → module_coordinator + from dimos.core.docker_runner import is_docker_module + if not self._client: raise ValueError("Not started") - modules = self._client.deploy_parallel(module_specs) - for (module_class, _, _), module in zip(module_specs, modules, strict=True): - self._deployed_modules[module_class] = module # type: ignore[assignment] - return modules # type: ignore[return-value] + # Split by type, tracking original indices for reassembly + docker_indices: list[int] = [] + worker_indices: list[int] = [] + docker_specs: list[ModuleSpec] = [] + worker_specs: list[ModuleSpec] = [] + for i, spec in enumerate(module_specs): + if is_docker_module(spec[0]): + docker_indices.append(i) + docker_specs.append(spec) + else: + 
worker_indices.append(i) + worker_specs.append(spec) + + # Deploy worker and docker modules in parallel. + results: list[Any] = [None] * len(module_specs) + + def _deploy_workers() -> None: + if not worker_specs: + return + assert self._client is not None + for index, module in zip( + worker_indices, self._client.deploy_parallel(worker_specs), strict=False + ): + results[index] = module + + def _deploy_docker() -> None: + if not docker_specs: + return + for index, module in zip( + docker_indices, DockerWorkerManager.deploy_parallel(docker_specs), strict=False + ): + results[index] = module + + def _register() -> None: + for (module_class, _, _), module in zip(module_specs, results, strict=False): + if module is not None: + self._deployed_modules[module_class] = module + + def _on_errors( + _outcomes: list[Any], _successes: list[Any], errors: list[Exception] + ) -> None: + _register() + raise ExceptionGroup("deploy_parallel failed", errors) + + safe_thread_map([_deploy_workers, _deploy_docker], lambda fn: fn(), _on_errors) + _register() + return results def start_all_modules(self) -> None: modules = list(self._deployed_modules.values()) - if isinstance(self._client, WorkerManager): - with ThreadPoolExecutor(max_workers=len(modules)) as executor: - list(executor.map(lambda m: m.start(), modules)) - else: - for module in modules: - module.start() + if not modules: + raise ValueError("No modules deployed. 
Call deploy() before start_all_modules().") + + def _on_start_errors( + _outcomes: list[Any], _successes: list[Any], errors: list[Exception] + ) -> None: + raise ExceptionGroup("start_all_modules failed", errors) + + safe_thread_map(modules, lambda m: m.start(), _on_start_errors) - module_list = list(self._deployed_modules.values()) for module in modules: if hasattr(module, "on_system_modules"): - module.on_system_modules(module_list) + module.on_system_modules(modules) def get_instance(self, module: type[ModuleBase]) -> ModuleProxy: return self._deployed_modules.get(module) # type: ignore[return-value, no-any-return] diff --git a/dimos/core/rpc_client.py b/dimos/core/rpc_client.py index 84de18d671..13add06a02 100644 --- a/dimos/core/rpc_client.py +++ b/dimos/core/rpc_client.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol from dimos.core.stream import RemoteStream from dimos.core.worker import MethodCallProxy @@ -81,6 +81,17 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] self._stop_rpc_client = None +class ModuleProxyProtocol(Protocol): + """Protocol for host-side handles to remote modules (worker or Docker).""" + + def start(self) -> None: ... + def stop(self) -> None: ... + def set_transport(self, stream_name: str, transport: Any) -> bool: ... + def get_rpc_method_names(self) -> list[str]: ... + def set_rpc_method(self, method: str, callable: RpcCall) -> None: ... + def get_rpc_calls(self, *methods: str) -> RpcCall | tuple[RpcCall, ...]: ... 
+ + class RPCClient: def __init__(self, actor_instance, actor_class) -> None: # type: ignore[no-untyped-def] self.rpc = LCMRPC() diff --git a/dimos/core/run_registry.py b/dimos/core/run_registry.py index 617872011c..a3807194f6 100644 --- a/dimos/core/run_registry.py +++ b/dimos/core/run_registry.py @@ -21,6 +21,7 @@ import os from pathlib import Path import re +import signal import time from dimos.utils.logging_config import setup_logger @@ -143,9 +144,6 @@ def get_most_recent(alive_only: bool = True) -> RunEntry | None: return runs[-1] if runs else None -import signal - - def stop_entry(entry: RunEntry, force: bool = False) -> tuple[str, bool]: """Stop a DimOS instance by registry entry. diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index f9a89829d5..7cd0f89b36 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -77,7 +77,7 @@ def test_classmethods() -> None: # Check that we have the expected RPC methods assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" assert "start" in class_rpcs, "start should be in rpcs" - assert len(class_rpcs) == 9 + assert len(class_rpcs) == 8 # Check that the values are callable assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" diff --git a/dimos/core/tests/test_docker_deployment.py b/dimos/core/tests/test_docker_deployment.py new file mode 100644 index 0000000000..a3bb0b716d --- /dev/null +++ b/dimos/core/tests/test_docker_deployment.py @@ -0,0 +1,290 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Smoke tests for Docker module deployment routing. + +These tests verify that the ModuleCoordinator correctly detects and routes +docker modules to DockerModule WITHOUT actually running Docker. +""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from dimos.core.docker_runner import DockerModuleConfig, is_docker_module +from dimos.core.global_config import global_config +from dimos.core.module import Module +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.stream import Out + +# -- Fixtures: fake module classes ------------------------------------------- + + +class FakeDockerConfig(DockerModuleConfig): + docker_image: str = "fake:latest" + docker_file: Path | None = None + docker_gpus: str | None = None + docker_rm: bool = True + docker_restart_policy: str = "no" + + +class FakeDockerModule(Module["FakeDockerConfig"]): + default_config = FakeDockerConfig + output: Out[str] + + +class FakeRegularModule(Module): + output: Out[str] + + +# -- Tests ------------------------------------------------------------------- + + +class TestIsDockerModule: + def test_docker_module_detected(self): + assert is_docker_module(FakeDockerModule) is True + + def test_regular_module_not_detected(self): + assert is_docker_module(FakeRegularModule) is False + + def test_plain_class_not_detected(self): + assert is_docker_module(str) is False + + def test_no_default_config(self): + class Bare(Module): + pass + + # Module has default_config = ModuleConfig, which is not DockerModuleConfig + assert is_docker_module(Bare) is False + + +class TestModuleCoordinatorDockerRouting: + @patch("dimos.core.docker_runner.DockerModule") + @patch("dimos.core.module_coordinator.WorkerManager") + def test_deploy_routes_docker_module(self, mock_worker_manager_cls, mock_docker_module_cls): + 
mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + + mock_dm = MagicMock() + mock_docker_module_cls.return_value = mock_dm + + coordinator = ModuleCoordinator() + coordinator.start() + + result = coordinator.deploy(FakeDockerModule) + + # Should NOT go through worker manager + mock_worker_mgr.deploy.assert_not_called() + # Should construct a DockerModule (container launch happens inside __init__) + mock_docker_module_cls.assert_called_once_with( + FakeDockerModule, global_config=global_config + ) + # start() is NOT called during deploy — it's called in start_all_modules + mock_dm.start.assert_not_called() + assert result is mock_dm + assert coordinator.get_instance(FakeDockerModule) is mock_dm + + coordinator.stop() + + @patch("dimos.core.docker_runner.DockerModule") + @patch("dimos.core.module_coordinator.WorkerManager") + def test_deploy_docker_propagates_constructor_failure( + self, mock_worker_manager_cls, mock_docker_module_cls + ): + mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + + # Container launch fails inside __init__; DockerModule handles its own cleanup + mock_docker_module_cls.side_effect = RuntimeError("launch failed") + + coordinator = ModuleCoordinator() + coordinator.start() + + with pytest.raises(RuntimeError, match="launch failed"): + coordinator.deploy(FakeDockerModule) + + coordinator.stop() + + @patch("dimos.core.module_coordinator.WorkerManager") + def test_deploy_routes_regular_module_to_worker_manager(self, mock_worker_manager_cls): + mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + mock_proxy = MagicMock() + mock_worker_mgr.deploy.return_value = mock_proxy + + coordinator = ModuleCoordinator() + coordinator.start() + + result = coordinator.deploy(FakeRegularModule) + + mock_worker_mgr.deploy.assert_called_once_with(FakeRegularModule, global_config, {}) + assert result is mock_proxy + + coordinator.stop() + + 
@patch("dimos.core.docker_worker_manager.DockerWorkerManager.deploy_parallel") + @patch("dimos.core.module_coordinator.WorkerManager") + def test_deploy_parallel_separates_docker_and_regular( + self, mock_worker_manager_cls, mock_docker_deploy + ): + mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + + regular_proxy = MagicMock() + mock_worker_mgr.deploy_parallel.return_value = [regular_proxy] + + mock_dm = MagicMock() + mock_docker_deploy.return_value = [mock_dm] + + coordinator = ModuleCoordinator() + coordinator.start() + + specs = [ + (FakeRegularModule, (), {}), + (FakeDockerModule, (), {}), + ] + results = coordinator.deploy_parallel(specs) + + # Regular module goes through worker manager + mock_worker_mgr.deploy_parallel.assert_called_once_with([(FakeRegularModule, (), {})]) + # Docker specs go through DockerWorkerManager + mock_docker_deploy.assert_called_once_with([(FakeDockerModule, (), {})]) + # start() is NOT called during deploy — it's called in start_all_modules + mock_dm.start.assert_not_called() + + # Results preserve input order + assert results[0] is regular_proxy + assert results[1] is mock_dm + + coordinator.stop() + + @patch("dimos.core.docker_runner.DockerModule") + @patch("dimos.core.module_coordinator.WorkerManager") + def test_stop_cleans_up_docker_modules(self, mock_worker_manager_cls, mock_docker_module_cls): + mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + + mock_dm = MagicMock() + mock_docker_module_cls.return_value = mock_dm + + coordinator = ModuleCoordinator() + coordinator.start() + coordinator.deploy(FakeDockerModule) + coordinator.stop() + + # stop() called exactly once (no double cleanup) + assert mock_dm.stop.call_count == 1 + # Worker manager also closed + mock_worker_mgr.close_all.assert_called_once() + + +class TestDockerModuleGetattr: + """Tests for DockerModule.__getattr__ avoiding infinite recursion.""" + + def 
test_getattr_no_recursion_when_rpcs_not_set(self): + """If __init__ fails before self.rpcs is assigned, __getattr__ must not recurse.""" + from dimos.core.docker_runner import DockerModule + + dm = DockerModule.__new__(DockerModule) + # Don't set rpcs, _module_class, or any instance attrs — simulates early __init__ failure + with pytest.raises(AttributeError): + _ = dm.some_method + + def test_getattr_no_recursion_on_cleanup_attrs(self): + """Accessing cleanup-related attrs before they exist must raise, not recurse.""" + from dimos.core.docker_runner import DockerModule + + dm = DockerModule.__new__(DockerModule) + # These are accessed during _cleanup() — if rpcs isn't set, they must not recurse + for attr in ("rpc", "config", "_container_name", "_unsub_fns"): + with pytest.raises(AttributeError): + getattr(dm, attr) + + def test_getattr_delegates_to_rpc_when_rpcs_set(self): + from dimos.core.docker_runner import DockerModule + from dimos.core.rpc_client import RpcCall + + dm = DockerModule.__new__(DockerModule) + dm.rpcs = {"do_thing"} + + # _module_class needs a real method with __name__ for RpcCall + class FakeMod: + def do_thing(self) -> None: ... 
+ + dm._module_class = FakeMod + dm.rpc = MagicMock() + dm.remote_name = "FakeMod" + dm._unsub_fns = [] + + result = dm.do_thing + assert isinstance(result, RpcCall) + + def test_getattr_raises_for_unknown_method(self): + from dimos.core.docker_runner import DockerModule + + dm = DockerModule.__new__(DockerModule) + dm.rpcs = {"do_thing"} + + with pytest.raises(AttributeError, match="not found"): + _ = dm.nonexistent + + +class TestDockerModuleCleanupReconnect: + """Tests for DockerModule._cleanup with docker_reconnect_container.""" + + def test_cleanup_skips_stop_when_reconnect(self): + from dimos.core.docker_runner import DockerModule + + with patch.object(DockerModule, "__init__", lambda self: None): + dm = DockerModule.__new__(DockerModule) + dm._running = True + dm._container_name = "test_container" + dm._unsub_fns = [] + dm.rpc = MagicMock() + dm.remote_name = "TestModule" + + # reconnect mode: should NOT stop/rm the container + dm.config = FakeDockerConfig(docker_reconnect_container=True) + with ( + patch("dimos.core.docker_runner._run") as mock_run, + patch("dimos.core.docker_runner._remove_container") as mock_rm, + ): + dm._cleanup() + mock_run.assert_not_called() + mock_rm.assert_not_called() + + def test_cleanup_stops_container_when_not_reconnect(self): + from dimos.core.docker_runner import DockerModule + + with patch.object(DockerModule, "__init__", lambda self: None): + dm = DockerModule.__new__(DockerModule) + dm._running = True + dm._container_name = "test_container" + dm._unsub_fns = [] + dm.rpc = MagicMock() + dm.remote_name = "TestModule" + + # normal mode: should stop and rm the container + dm.config = FakeDockerConfig(docker_reconnect_container=False) + with ( + patch("dimos.core.docker_runner._run") as mock_run, + patch("dimos.core.docker_runner._remove_container") as mock_rm, + ): + dm._cleanup() + mock_run.assert_called_once() # docker stop + mock_rm.assert_called_once() # docker rm -f diff --git 
a/dimos/core/tests/test_parallel_deploy_cleanup.py b/dimos/core/tests/test_parallel_deploy_cleanup.py new file mode 100644 index 0000000000..1987fa4be7 --- /dev/null +++ b/dimos/core/tests/test_parallel_deploy_cleanup.py @@ -0,0 +1,219 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests that deploy_parallel cleans up successfully-started modules when a +sibling deployment fails ("middle module throws" scenario). +""" + +from __future__ import annotations + +import threading +from unittest.mock import MagicMock, patch + +import pytest + + +class TestDockerWorkerManagerPartialFailure: + """DockerWorkerManager.deploy_parallel must stop successful containers when one fails.""" + + @patch("dimos.core.docker_runner.DockerModule") + def test_middle_module_fails_stops_siblings(self, mock_docker_module_cls): + """Deploy 3 modules where the middle one fails. 
The other two must be stopped.""" + from dimos.core.docker_worker_manager import DockerWorkerManager + + mod_a = MagicMock(name="ModuleA") + mod_c = MagicMock(name="ModuleC") + + barrier = threading.Barrier(3, timeout=5) + + def fake_constructor(cls, *args, **kwargs): + label = cls.__name__ + barrier.wait() + if label == "B": + raise RuntimeError("B failed to start") + return mod_a if label == "A" else mod_c + + mock_docker_module_cls.side_effect = fake_constructor + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + FakeC = type("C", (), {}) + + with pytest.raises(ExceptionGroup, match="docker deploy_parallel failed") as exc_info: + DockerWorkerManager.deploy_parallel( + [ + (FakeA, (), {}), + (FakeB, (), {}), + (FakeC, (), {}), + ] + ) + + assert len(exc_info.value.exceptions) == 1 + assert "B failed to start" in str(exc_info.value.exceptions[0]) + + # Both successful modules must have been stopped exactly once + mod_a.stop.assert_called_once() + mod_c.stop.assert_called_once() + + @patch("dimos.core.docker_runner.DockerModule") + def test_multiple_failures_raises_exception_group(self, mock_docker_module_cls): + """Deploy 3 modules where two fail. 
Should raise ExceptionGroup with both errors.""" + from dimos.core.docker_worker_manager import DockerWorkerManager + + mod_a = MagicMock(name="ModuleA") + + barrier = threading.Barrier(3, timeout=5) + + def fake_constructor(cls, *args, **kwargs): + label = cls.__name__ + barrier.wait() + if label == "B": + raise RuntimeError("B failed") + if label == "C": + raise ValueError("C failed") + return mod_a + + mock_docker_module_cls.side_effect = fake_constructor + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + FakeC = type("C", (), {}) + + with pytest.raises(ExceptionGroup, match="docker deploy_parallel failed") as exc_info: + DockerWorkerManager.deploy_parallel( + [ + (FakeA, (), {}), + (FakeB, (), {}), + (FakeC, (), {}), + ] + ) + + assert len(exc_info.value.exceptions) == 2 + messages = {str(e) for e in exc_info.value.exceptions} + assert "B failed" in messages + assert "C failed" in messages + + # The one successful module must have been stopped + mod_a.stop.assert_called_once() + + @patch("dimos.core.docker_runner.DockerModule") + def test_all_succeed_no_stops(self, mock_docker_module_cls): + """When all deployments succeed, no modules should be stopped.""" + from dimos.core.docker_worker_manager import DockerWorkerManager + + mocks = [MagicMock(name=f"Mod{i}") for i in range(3)] + + def fake_constructor(cls, *args, **kwargs): + return mocks[["A", "B", "C"].index(cls.__name__)] + + mock_docker_module_cls.side_effect = fake_constructor + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + FakeC = type("C", (), {}) + + results = DockerWorkerManager.deploy_parallel( + [ + (FakeA, (), {}), + (FakeB, (), {}), + (FakeC, (), {}), + ] + ) + + assert len(results) == 3 + for m in mocks: + m.stop.assert_not_called() + + @patch("dimos.core.docker_runner.DockerModule") + def test_stop_failure_does_not_mask_deploy_error(self, mock_docker_module_cls): + """If stop() itself raises during cleanup, the original deploy error still propagates.""" + from 
dimos.core.docker_worker_manager import DockerWorkerManager + + mod_a = MagicMock(name="ModuleA") + mod_a.stop.side_effect = OSError("stop failed") + + barrier = threading.Barrier(2, timeout=5) + + def fake_constructor(cls, *args, **kwargs): + barrier.wait() + if cls.__name__ == "B": + raise RuntimeError("B exploded") + return mod_a + + mock_docker_module_cls.side_effect = fake_constructor + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + + with pytest.raises(ExceptionGroup, match="docker deploy_parallel failed"): + DockerWorkerManager.deploy_parallel([(FakeA, (), {}), (FakeB, (), {})]) + + # stop was attempted despite it raising + mod_a.stop.assert_called_once() + + +class TestWorkerManagerPartialFailure: + """WorkerManager.deploy_parallel must clean up successful RPCClients when one fails.""" + + def test_middle_module_fails_cleans_up_siblings(self): + from dimos.core.worker_manager import WorkerManager + + manager = WorkerManager(n_workers=2) + + mock_workers = [MagicMock(name=f"Worker{i}") for i in range(2)] + for w in mock_workers: + w.module_count = 0 + w.reserve_slot = MagicMock( + side_effect=lambda w=w: setattr(w, "module_count", w.module_count + 1) + ) + + manager._workers = mock_workers + manager._started = True + + def fake_deploy_module(module_class, args=(), kwargs=None): + if module_class.__name__ == "B": + raise RuntimeError("B failed to deploy") + return MagicMock(name=f"actor_{module_class.__name__}") + + for w in mock_workers: + w.deploy_module = fake_deploy_module + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + FakeC = type("C", (), {}) + + rpc_clients_created: list[MagicMock] = [] + + with patch("dimos.core.worker_manager.RPCClient") as mock_rpc_cls: + + def make_rpc(actor, cls): + client = MagicMock(name=f"rpc_{cls.__name__}") + rpc_clients_created.append(client) + return client + + mock_rpc_cls.side_effect = make_rpc + + with pytest.raises(ExceptionGroup, match="worker deploy_parallel failed"): + manager.deploy_parallel( 
+ [ + (FakeA, (), {}), + (FakeB, (), {}), + (FakeC, (), {}), + ] + ) + + # Every successfully-created RPC client must have been cleaned up exactly once + for client in rpc_clients_created: + client.stop_rpc_client.assert_called_once() diff --git a/dimos/core/worker.py b/dimos/core/worker.py index dca561f16c..93dbec6e2b 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -214,25 +214,27 @@ def deploy_module( "module_class": module_class, "kwargs": kwargs, } - with self._lock: - self._conn.send(request) - response = self._conn.recv() + try: + with self._lock: + self._conn.send(request) + response = self._conn.recv() - if response.get("error"): - raise RuntimeError(f"Failed to deploy module: {response['error']}") - - actor = Actor(self._conn, module_class, self._worker_id, module_id, self._lock) - actor.set_ref(actor).result() - - self._modules[module_id] = actor - self._reserved = max(0, self._reserved - 1) - logger.info( - "Deployed module.", - module=module_class.__name__, - worker_id=self._worker_id, - module_id=module_id, - ) - return actor + if response.get("error"): + raise RuntimeError(f"Failed to deploy module: {response['error']}") + + actor = Actor(self._conn, module_class, self._worker_id, module_id, self._lock) + actor.set_ref(actor).result() + + self._modules[module_id] = actor + logger.info( + "Deployed module.", + module=module_class.__name__, + worker_id=self._worker_id, + module_id=module_id, + ) + return actor + finally: + self._reserved = max(0, self._reserved - 1) def suppress_console(self) -> None: if self._conn is None: diff --git a/dimos/core/worker_manager.py b/dimos/core/worker_manager.py index 4cd5eec8d7..2b778c433e 100644 --- a/dimos/core/worker_manager.py +++ b/dimos/core/worker_manager.py @@ -15,7 +15,7 @@ from __future__ import annotations from collections.abc import Iterable -from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from typing import Any from dimos.core.global_config import 
GlobalConfig @@ -23,6 +23,7 @@ from dimos.core.rpc_client import RPCClient from dimos.core.worker import Worker from dimos.utils.logging_config import setup_logger +from dimos.utils.safe_thread_map import ExceptionGroup, safe_thread_map logger = setup_logger() @@ -65,6 +66,10 @@ def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient] if self._closed: raise RuntimeError("WorkerManager is closed") + module_specs = list(module_specs) + if len(module_specs) == 0: + return [] + # Auto-start for backward compatibility if not self._started: self.start() @@ -78,17 +83,19 @@ def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient] worker.reserve_slot() assignments.append((worker, module_class, global_config, kwargs)) - def _deploy( - item: tuple[Worker, type[ModuleBase], GlobalConfig, dict[str, Any]], - ) -> RPCClient: - worker, module_class, global_config, kwargs = item - actor = worker.deploy_module(module_class, global_config=global_config, kwargs=kwargs) - return RPCClient(actor, module_class) - - with ThreadPoolExecutor(max_workers=len(assignments)) as pool: - results = list(pool.map(_deploy, assignments)) - - return results + def _on_errors( + _outcomes: list[Any], successes: list[RPCClient], errors: list[Exception] + ) -> None: + for rpc_client in successes: + with suppress(Exception): + rpc_client.stop_rpc_client() + raise ExceptionGroup("worker deploy_parallel failed", errors) + + return safe_thread_map( + assignments, + lambda item: RPCClient(item[0].deploy_module(item[1], item[2], item[3]), item[1]), + _on_errors, + ) def suppress_console(self) -> None: """Tell all workers to redirect stdout/stderr to /dev/null.""" diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index 212c7ac60a..1d0598ce46 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -20,7 +20,7 @@ import mujoco import numpy as np -import onnxruntime as ort # type: 
ignore[import-untyped] +import onnxruntime as ort # type: ignore[import-untyped,import-not-found] from dimos.simulation.mujoco.input_controller import InputController from dimos.utils.logging_config import setup_logger diff --git a/dimos/utils/safe_thread_map.py b/dimos/utils/safe_thread_map.py new file mode 100644 index 0000000000..f480f2c97d --- /dev/null +++ b/dimos/utils/safe_thread_map.py @@ -0,0 +1,110 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +import sys +from typing import TYPE_CHECKING, Any, TypeVar + +if sys.version_info < (3, 11): + + class ExceptionGroup(Exception): # type: ignore[no-redef] # noqa: N818 + """Minimal ExceptionGroup polyfill for Python 3.10.""" + + exceptions: tuple[BaseException, ...] 
+ + def __init__(self, message: str, exceptions: Sequence[BaseException]) -> None: + super().__init__(message) + self.exceptions = tuple(exceptions) +else: + import builtins + + ExceptionGroup = builtins.ExceptionGroup # type: ignore[misc] + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + +T = TypeVar("T") +R = TypeVar("R") + + +def safe_thread_map( + items: Sequence[T], + fn: Callable[[T], R], + on_errors: Callable[[list[tuple[T, R | Exception]], list[R], list[Exception]], Any] + | None = None, +) -> list[R]: + """Thread-pool map that waits for all items to finish before raising, with an optional cleanup handler. + + - Empty *items* → returns ``[]`` immediately. + - All succeed → returns results in input order. + - Any fail → calls ``on_errors(outcomes, successes, errors)`` where + *outcomes* is a list of ``(input, result_or_exception)`` pairs in input + order, *successes* is the list of successful results, and *errors* is + the list of exceptions. If *on_errors* raises, that exception propagates. + If *on_errors* returns normally, its return value is returned from + ``safe_thread_map``. If *on_errors* is ``None``, raises an + ``ExceptionGroup``. 
+ + Example:: + + def start_service(name: str) -> Connection: + return connect(name) + + def cleanup( + outcomes: list[tuple[str, Connection | Exception]], + successes: list[Connection], + errors: list[Exception], + ) -> None: + for conn in successes: + conn.close() + raise ExceptionGroup("failed to start services", errors) + + connections = safe_thread_map( + ["db", "cache", "queue"], + start_service, + cleanup, # called only if any start_service() raises + ) + """ + if not items: + return [] + + outcomes: dict[int, R | Exception] = {} + + with ThreadPoolExecutor(max_workers=len(items)) as pool: + futures: dict[Future[R], int] = {pool.submit(fn, item): i for i, item in enumerate(items)} + for fut in as_completed(futures): + idx = futures[fut] + try: + outcomes[idx] = fut.result() + except Exception as e: + outcomes[idx] = e + + # Note: successes/errors are in completion order, not input order. + # This is fine — on_errors only needs them for cleanup, not ordering. + successes: list[R] = [] + errors: list[Exception] = [] + for v in outcomes.values(): + if isinstance(v, Exception): + errors.append(v) + else: + successes.append(v) + + if errors: + if on_errors is not None: + zipped = [(items[i], outcomes[i]) for i in range(len(items))] + return on_errors(zipped, successes, errors) # type: ignore[return-value, no-any-return] + raise ExceptionGroup("safe_thread_map failed", errors) + + return [outcomes[i] for i in range(len(items))] # type: ignore[misc] diff --git a/examples/docker_hello_world/Dockerfile b/examples/docker_hello_world/Dockerfile new file mode 100644 index 0000000000..3ceb24b3b4 --- /dev/null +++ b/examples/docker_hello_world/Dockerfile @@ -0,0 +1,15 @@ +FROM python:3.12-slim + +RUN apt-get update && apt-get install -y \ + iproute2 \ + libx11-6 libgl1 libglib2.0-0 \ + libidn2-0 libgfortran5 libgomp1 \ + cowsay \ + && rm -rf /var/lib/apt/lists/* + + +# Copy example module so it's importable inside the container +COPY 
examples/docker_hello_world/hello_docker.py /dimos/source/examples/docker_hello_world/hello_docker.py +RUN touch /dimos/source/examples/__init__.py /dimos/source/examples/docker_hello_world/__init__.py + +WORKDIR /app diff --git a/examples/docker_hello_world/hello_docker.py b/examples/docker_hello_world/hello_docker.py new file mode 100644 index 0000000000..3b8e96e49b --- /dev/null +++ b/examples/docker_hello_world/hello_docker.py @@ -0,0 +1,144 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Hello World Docker Module +========================== + +Minimal example showing a DimOS module running inside Docker. + +The module receives a string on its ``prompt`` input stream, runs it through +cowsay inside the container, and publishes the ASCII art on its ``greeting`` +output stream. + +NOTE: Requires Linux. Docker Desktop on macOS does not support host networking, +which is needed for LCM multicast between host and container. 
+ +Usage: + python examples/docker_hello_world/hello_docker.py +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +import subprocess +import time + +from dimos.core.core import rpc +from dimos.core.docker_runner import DockerModuleConfig +from dimos.core.module import Module +from dimos.core.stream import In, Out + + +@dataclass(kw_only=True) +class HelloDockerConfig(DockerModuleConfig): + docker_image: str = "dimos-hello-docker:latest" + docker_file: Path | None = Path(__file__).parent / "Dockerfile" + docker_build_context: Path | None = Path(__file__).parents[2] # repo root + docker_gpus: str | None = None # no GPU needed + docker_rm: bool = True + docker_restart_policy: str = "no" + docker_env: dict[str, str] = field(default_factory=lambda: {"CI": "1"}) + + # Custom (non-docker) config field — passed to the container via JSON + greeting_prefix: str = "Hello" + + +class HelloDockerModule(Module["HelloDockerConfig"]): + """A trivial module that runs inside Docker and echoes greetings.""" + + default_config = HelloDockerConfig + + prompt: In[str] + greeting: Out[str] + + @rpc + def start(self) -> None: + super().start() + self.prompt.subscribe(self._on_prompt) + + def _cowsay(self, text: str) -> str: + """Run cowsay inside the container and return the ASCII art.""" + result = subprocess.run( + ["/usr/games/cowsay", text], + capture_output=True, + text=True, + check=True, + ) + return result.stdout + + def _on_prompt(self, text: str) -> None: + art = self._cowsay(text) + print(f"[HelloDockerModule]\n{art}") + self.greeting.publish(art) + + @rpc + def greet(self, name: str) -> str: + """RPC method that can be called directly.""" + prefix = self.config.greeting_prefix + return self._cowsay(f"{prefix}, {name}!") + + @rpc + def get_greeting_prefix(self) -> str: + """Return the config value to verify it was passed to the container.""" + return self.config.greeting_prefix + + +class PromptModule(Module): + 
"""Publishes prompts and listens to greetings.""" + + prompt: Out[str] + greeting: In[str] + + @rpc + def start(self) -> None: + super().start() + self.greeting.subscribe(self._on_greeting) + + @rpc + def send(self, text: str) -> None: + """Publish a prompt message onto the stream.""" + self.prompt.publish(text) + + def _on_greeting(self, text: str) -> None: + print(f"[PromptModule] Received: {text}") + + +if __name__ == "__main__": + from dimos.core.blueprints import autoconnect + + coordinator = autoconnect( + PromptModule.blueprint(), + HelloDockerModule.blueprint(greeting_prefix="Howdy"), + ).build() + + # Get module proxies + prompt_mod = coordinator.get_instance(PromptModule) + docker_mod = coordinator.get_instance(HelloDockerModule) + + # Test that custom config was passed to the container + prefix = docker_mod.get_greeting_prefix() + assert prefix == "Howdy", f"Expected 'Howdy', got {prefix!r}" + print(f"Config passed to container: greeting_prefix={prefix!r}") + + # Test RPC (should use the custom prefix) + print(docker_mod.greet("World")) + + # Test stream + prompt_mod.send("stream test") + time.sleep(2) + + coordinator.stop() + print("Done!") diff --git a/pyproject.toml b/pyproject.toml index 722e3b0485..8cc7964d99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -324,8 +324,12 @@ docker = [ "sortedcontainers", "PyTurboJPEG", "rerun-sdk", + "typing_extensions", "open3d-unofficial-arm; platform_system == 'Linux' and platform_machine == 'aarch64'", "open3d>=0.18.0; platform_system != 'Linux' or platform_machine != 'aarch64'", + # these below should be removed later, right now they are needed even for running `dimos --help` (seperate non-docker issue) + "langchain-core", + "matplotlib", ] base = [ diff --git a/uv.lock b/uv.lock index 5ec39fff59..0994e8866d 100644 --- a/uv.lock +++ b/uv.lock @@ -1850,7 +1850,9 @@ dev = [ ] docker = [ { name = "dimos-lcm" }, + { name = "langchain-core" }, { name = "lcm" }, + { name = "matplotlib" }, { name = "numpy", 
version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "open3d", marker = "platform_machine != 'aarch64' or sys_platform != 'linux'" }, @@ -1868,6 +1870,7 @@ docker = [ { name = "sortedcontainers" }, { name = "structlog" }, { name = "typer" }, + { name = "typing-extensions" }, ] drone = [ { name = "pymavlink" }, @@ -2011,6 +2014,7 @@ requires-dist = [ { name = "langchain", marker = "extra == 'agents'", specifier = "==1.2.3" }, { name = "langchain-chroma", marker = "extra == 'agents'", specifier = ">=1,<2" }, { name = "langchain-core", marker = "extra == 'agents'", specifier = "==1.2.3" }, + { name = "langchain-core", marker = "extra == 'docker'" }, { name = "langchain-huggingface", marker = "extra == 'agents'", specifier = ">=1,<2" }, { name = "langchain-ollama", marker = "extra == 'agents'", specifier = ">=1,<2" }, { name = "langchain-openai", marker = "extra == 'agents'", specifier = ">=1,<2" }, @@ -2021,6 +2025,7 @@ requires-dist = [ { name = "lcm", marker = "extra == 'docker'" }, { name = "llvmlite", specifier = ">=0.42.0" }, { name = "lxml-stubs", marker = "extra == 'dev'", specifier = ">=0.5.1,<1" }, + { name = "matplotlib", marker = "extra == 'docker'" }, { name = "matplotlib", marker = "extra == 'manipulation'", specifier = ">=3.7.1" }, { name = "md-babel-py", marker = "extra == 'dev'", specifier = "==1.1.1" }, { name = "moondream", marker = "extra == 'perception'" }, @@ -2129,6 +2134,7 @@ requires-dist = [ { name = "types-tabulate", marker = "extra == 'dev'", specifier = ">=0.9.0.20241207,<1" }, { name = "types-tensorflow", marker = "extra == 'dev'", specifier = ">=2.18.0.20251008,<3" }, { name = "types-tqdm", marker = "extra == 'dev'", specifier = ">=4.67.0.20250809,<5" }, + { name = "typing-extensions", marker = "extra == 'docker'" }, { name = "ultralytics", 
marker = "extra == 'perception'", specifier = ">=8.3.70" }, { name = "unitree-webrtc-connect-leshy", marker = "extra == 'unitree'", specifier = ">=2.0.7" }, { name = "uvicorn", marker = "extra == 'web'", specifier = ">=0.34.0" },