Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions dimos/core/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

from abc import ABC
from collections import defaultdict
from collections.abc import Callable, Mapping
from collections.abc import Callable, Mapping, MutableMapping
from dataclasses import dataclass, field, replace
from functools import cached_property, reduce
import operator
import sys
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, get_args, get_origin, get_type_hints

from pydantic import BaseModel, create_model

if TYPE_CHECKING:
from dimos.protocol.service.system_configurator.base import SystemConfigurator

Expand Down Expand Up @@ -130,6 +132,11 @@ def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint":
def disabled_modules(self, *modules: type[ModuleBase]) -> "Blueprint":
return replace(self, disabled_modules_tuple=self.disabled_modules_tuple + modules)

def config(self) -> type[BaseModel]:
configs = {b.module.name: (b.module.default_config | None, None) for b in self.blueprints}
configs["g"] = (GlobalConfig | None, None)
return create_model("BlueprintConfig", __config__={"extra": "forbid"}, **configs) # type: ignore[call-overload,no-any-return]

def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint":
return replace(self, transport_map=MappingProxyType({**self.transport_map, **transports}))

Expand Down Expand Up @@ -274,13 +281,16 @@ def _verify_no_name_conflicts(self) -> None:
raise ValueError("\n".join(error_lines))

def _deploy_all_modules(
self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig
self,
module_coordinator: ModuleCoordinator,
global_config: GlobalConfig,
blueprint_args: Mapping[str, Mapping[str, Any]],
) -> None:
module_specs: list[ModuleSpec] = []
for blueprint in self._active_blueprints:
module_specs.append((blueprint.module, global_config, blueprint.kwargs))
module_specs.append((blueprint.module, global_config, blueprint.kwargs.copy()))

module_coordinator.deploy_parallel(module_specs)
module_coordinator.deploy_parallel(module_specs, blueprint_args)

def _connect_streams(self, module_coordinator: ModuleCoordinator) -> None:
# dict when given (final/remapped) stream name+type, provides a list of modules + original (non-remapped) stream names
Expand Down Expand Up @@ -472,12 +482,13 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None:

def build(
self,
cli_config_overrides: Mapping[str, Any] | None = None,
blueprint_args: MutableMapping[str, Any] | None = None,
) -> ModuleCoordinator:
logger.info("Building the blueprint")
global_config.update(**dict(self.global_config_overrides))
if cli_config_overrides:
global_config.update(**dict(cli_config_overrides))
blueprint_args = blueprint_args or {}
if "g" in blueprint_args:
global_config.update(**blueprint_args.pop("g"))

self._run_configurators()
self._check_requirements()
Expand All @@ -488,7 +499,7 @@ def build(
module_coordinator.start()

# all module constructors are called here (each of them setup their own)
self._deploy_all_modules(module_coordinator, global_config)
self._deploy_all_modules(module_coordinator, global_config, blueprint_args)
self._connect_streams(module_coordinator)
self._connect_rpc_methods(module_coordinator)
self._connect_module_refs(module_coordinator)
Expand Down
5 changes: 5 additions & 0 deletions dimos/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ def __init__(self, config_args: dict[str, Any]):
except ValueError:
...

@classproperty
def name(self) -> str:
"""Name for this module to be used for blueprint configs."""
return self.__name__.lower() # type: ignore[attr-defined,no-any-return]

@property
def frame_id(self) -> str:
base = self.config.frame_id or self.__class__.__name__
Expand Down
7 changes: 5 additions & 2 deletions dimos/core/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -128,11 +129,13 @@ def deploy(
self._deployed_modules[module_class] = module # type: ignore[assignment]
return module # type: ignore[return-value]

def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]:
def deploy_parallel(
self, module_specs: list[ModuleSpec], blueprint_args: Mapping[str, Mapping[str, Any]]
) -> list[ModuleProxy]:
if not self._client:
raise ValueError("Not started")

modules = self._client.deploy_parallel(module_specs)
modules = self._client.deploy_parallel(module_specs, blueprint_args)
for (module_class, _, _), module in zip(module_specs, modules, strict=True):
self._deployed_modules[module_class] = module # type: ignore[assignment]
return modules # type: ignore[return-value]
Expand Down
8 changes: 8 additions & 0 deletions dimos/core/test_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ def test_autoconnect() -> None:
)


def test_config() -> None:
blueprint = autoconnect(module_a(), module_b())
config = blueprint.config()
assert config.model_fields.keys() == {"modulea", "moduleb"}
assert config.model_fields["modulea"].annotation == ModuleA.default_config
assert config.model_fields["moduleb"].annotation == ModuleB.default_config


def test_transports() -> None:
custom_transport = LCMTransport("/custom_topic", Data1)
blueprint_set = autoconnect(module_a(), module_b()).transports(
Expand Down
9 changes: 7 additions & 2 deletions dimos/core/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

Expand Down Expand Up @@ -61,7 +61,11 @@ def deploy(
actor = worker.deploy_module(module_class, global_config, kwargs=kwargs)
return RPCClient(actor, module_class)

def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient]:
def deploy_parallel(
self,
module_specs: Iterable[ModuleSpec],
blueprint_args: Mapping[str, Mapping[str, Any]],
) -> list[RPCClient]:
if self._closed:
raise RuntimeError("WorkerManager is closed")

Expand All @@ -76,6 +80,7 @@ def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient]
for module_class, global_config, kwargs in module_specs:
worker = self._select_worker()
worker.reserve_slot()
kwargs.update(blueprint_args.get(module_class.name, {}))
assignments.append((worker, module_class, global_config, kwargs))

def _deploy(
Expand Down
94 changes: 92 additions & 2 deletions dimos/robot/cli/dimos.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,38 @@

from __future__ import annotations

from collections.abc import Iterable
from datetime import datetime, timezone
import inspect
import json
import os
from pathlib import Path
import sys
import time
import types
from typing import Any, get_args, get_origin

import click
from dotenv import load_dotenv
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
import requests
import typer

from dimos.agents.mcp.mcp_adapter import McpAdapter, McpError
from dimos.core.blueprints import Blueprint, _BlueprintAtom
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.run_registry import get_most_recent, is_pid_alive, stop_entry
from dimos.utils.logging_config import setup_logger

try:
# Not a dependency, just the best way to get config path if available.
from gi.repository import GLib # type: ignore[import-untyped]
except ImportError:
CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"))
else:
CONFIG_DIR = Path(GLib.get_user_config_dir())

logger = setup_logger()

main = typer.Typer(
Expand Down Expand Up @@ -108,12 +122,79 @@ def callback(**kwargs) -> None: # type: ignore[no-untyped-def]
main.callback()(create_dynamic_callback()) # type: ignore[no-untyped-call]


def arghelp(
config: type[BaseModel],
blueprint: Blueprint,
indent: str = " ",
module: str = "",
_atom: _BlueprintAtom | None = None,
) -> str:
output = ""
for k, info in config.model_fields.items():
if k == "g":
continue
t = info.annotation
if isinstance(t, types.GenericAlias):
# Can't be specified on CLI
continue

if t is not None and issubclass(t, BaseModel):
output += f"{indent}{module}{k}:\n"
# Find blueprint atom
bp = next(bp for bp in blueprint.blueprints if bp.module.name == k)
output += arghelp(t, blueprint, indent=indent + " ", module=module + k + ".", _atom=bp)
else:
assert _atom is not None
# Use __name__ to avoid "<class 'int'>" style output on basic types.
display_type = t.__name__ if isinstance(t, type) else t
required = "[Required] " if info.is_required() and k not in _atom.kwargs else ""
d = _atom.kwargs.get(k, info.default)
default = f" (default: {d})" if d is not PydanticUndefined else ""
output += f"{indent}* {required}{module}{k}: {display_type}{default}\n"
return output


def load_config_args(config: type[BaseModel], args: Iterable[str], path: Path) -> dict[str, Any]:
try:
kwargs = json.loads(path.read_text())
except (OSError, json.JSONDecodeError):
kwargs = {}

for k, v in os.environ.items():
parts = k.lower().split("__")
if parts[0] not in config.model_fields:
continue
d = kwargs
for p in parts[:-1]:
d = d.setdefault(p, {})
d[parts[-1]] = v

for arg in args:
k, _, v = arg.partition("=")
parts = k.split(".")
d = kwargs
for p in parts[:-1]:
d = d.setdefault(p, {})
d[parts[-1]] = v

# We don't need this config, but this atleast validates the user input first.
# This will help catch misspellings and similar mistakes.
config(**kwargs)

return kwargs # type: ignore[no-any-return]


@main.command()
def run(
ctx: typer.Context,
robot_types: list[str] = typer.Argument(..., help="Blueprints or modules to run"),
daemon: bool = typer.Option(False, "--daemon", "-d", help="Run in background"),
disable: list[str] = typer.Option([], "--disable", help="Module names to disable"),
blueprint_args: list[str] = typer.Option((), "--option", "-o"),
config_path: Path = typer.Option(
CONFIG_DIR / "dimos", "--config", "-c", help="Path to config file"
),
show_help: bool = typer.Option(False, "--help"),
) -> None:
"""Start a robot blueprint"""
logger.info("Starting DimOS")
Expand All @@ -132,7 +213,6 @@ def run(
setup_exception_handler()

cli_config_overrides: dict[str, Any] = ctx.obj
global_config.update(**cli_config_overrides)

# Clean stale registry entries
stale = cleanup_stale()
Expand Down Expand Up @@ -163,7 +243,17 @@ def run(
disabled_classes = tuple(get_module_by_name(name).blueprints[0].module for name in disable)
blueprint = blueprint.disabled_modules(*disabled_classes)

coordinator = blueprint.build(cli_config_overrides=cli_config_overrides)
if show_help:
print("Blueprint arguments:")
print(arghelp(blueprint.config(), blueprint))
return

blueprint_config = blueprint.config()
kwargs = load_config_args(blueprint_config, blueprint_args, config_path)
if cli_config_overrides:
kwargs["g"] = cli_config_overrides

coordinator = blueprint.build(kwargs)

if daemon:
from dimos.core.daemon import (
Expand Down
87 changes: 87 additions & 0 deletions dimos/robot/cli/test_dimos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dimos.core.blueprints import autoconnect
from dimos.core.module import Module, ModuleConfig

# from dimos.robot.cli.dimos import arghelp
from dimos.robot.unitree.go2.connection import GO2Connection
from dimos.visualization.rerun.bridge import RerunBridgeModule, _default_blueprint


def test_blueprint_arghelp():
blueprint = autoconnect(RerunBridgeModule.blueprint(), GO2Connection.blueprint())
output = arghelp(blueprint.config(), blueprint)
assert output.split("\n") == [
" rerunbridgemodule:",
" * rerunbridgemodule.frame_id_prefix: str | None (default: None)",
" * rerunbridgemodule.frame_id: str | None (default: None)",
" * rerunbridgemodule.entity_prefix: str (default: world)",
" * rerunbridgemodule.topic_to_entity: collections.abc.Callable[[typing.Any], str] | None (default: None)",
" * rerunbridgemodule.viewer_mode: typing.Literal['native', 'web', 'connect', 'none']",
" * rerunbridgemodule.connect_url: str (default: rerun+http://127.0.0.1:9877/proxy)",
" * rerunbridgemodule.memory_limit: str (default: 25%)",
f" * rerunbridgemodule.blueprint: collections.abc.Callable[rerun.blueprint.api.Blueprint] | None (default: {_default_blueprint})",
" go2connection:",
" * go2connection.frame_id_prefix: str | None (default: None)",
" * go2connection.frame_id: str | None (default: None)",
" * go2connection.ip: str",
"",
]


def test_blueprint_arghelp_extra_args():
"""Test defaults passed to .blueprint() override."""

bridge = RerunBridgeModule.blueprint(frame_id_prefix="foo", viewer_mode="web")
blueprint = autoconnect(bridge, GO2Connection.blueprint(ip="1.1.1.1"))
output = arghelp(blueprint.config(), blueprint)
assert output.split("\n") == [
" rerunbridgemodule:",
" * rerunbridgemodule.frame_id_prefix: str | None (default: foo)",
" * rerunbridgemodule.frame_id: str | None (default: None)",
" * rerunbridgemodule.entity_prefix: str (default: world)",
" * rerunbridgemodule.topic_to_entity: collections.abc.Callable[[typing.Any], str] | None (default: None)",
" * rerunbridgemodule.viewer_mode: typing.Literal['native', 'web', 'connect', 'none'] (default: web)",
" * rerunbridgemodule.connect_url: str (default: rerun+http://127.0.0.1:9877/proxy)",
" * rerunbridgemodule.memory_limit: str (default: 25%)",
f" * rerunbridgemodule.blueprint: collections.abc.Callable[rerun.blueprint.api.Blueprint] | None (default: {_default_blueprint})",
" go2connection:",
" * go2connection.frame_id_prefix: str | None (default: None)",
" * go2connection.frame_id: str | None (default: None)",
" * go2connection.ip: str (default: 1.1.1.1)",
"",
]


def test_blueprint_arghelp_required():
"""Test required arguments."""

class Config(ModuleConfig):
foo: int
spam: str = "eggs"

class TestModule(Module[Config]):
default_config = Config

blueprint = TestModule.blueprint()
output = arghelp(blueprint.config(), blueprint)
assert output.split("\n") == [
" testmodule:",
" * testmodule.frame_id_prefix: str | None (default: None)",
" * testmodule.frame_id: str | None (default: None)",
" * [Required] testmodule.foo: int",
" * testmodule.spam: str (default: eggs)",
"",
]
Loading
Loading