From ec93173360016873f48f72739a6c0d3c3a71de58 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 4 Mar 2026 22:11:47 +0800 Subject: [PATCH 001/118] memory plans --- plans/memory.md | 112 ++++++ plans/memory1.md | 319 +++++++++++++++++ plans/memory2.md | 898 +++++++++++++++++++++++++++++++++++++++++++++++ plans/memory3.md | 364 +++++++++++++++++++ 4 files changed, 1693 insertions(+) create mode 100644 plans/memory.md create mode 100644 plans/memory1.md create mode 100644 plans/memory2.md create mode 100644 plans/memory3.md diff --git a/plans/memory.md b/plans/memory.md new file mode 100644 index 0000000000..f884e8dfb6 --- /dev/null +++ b/plans/memory.md @@ -0,0 +1,112 @@ + +## Goals +We are building this (seriously - exactly this) https://www.youtube.com/watch?v=Zkj5WSae3Uc + +We are designing a human-centric interface, not an agentic interface for now. +Good human centric interface allows us to test our own tooling before giving it to an agent, + +Reason for this is that if you give something directly to an agent without using it yourself, it might be shit and agent might be rightfully underperforming. + +## all sensor data is stored by default for every run + auto rotation/cleanup, set max_size + +## Data streams + +- These are sensor streams (all sensor data is stored by default for every run) +- But also other data streams created in real time or async. + +### All datapoints that have a temporal index (are `Timestamped`) + +for all temporal datapoints we need to be able to get spatial info for - this is important for multi embodiment, we do this either by having robot_id associated to the stream, + +robots can see the same shoe from different angles, we can deduplicate once temporal or spatial matches are there + +or by directly storing a 4D index or time + 3D index. how we actually store stuff is not important and storage system dependent, how we query is what we care about. + +so I can quickly ask for a video frame, where/when it was captured. 
I can detect on top of it, fetch a related LIDAR frame, and project detections into 3D.
keep in mind we might not be looking for a global optimum but local hills) we can also use clustering algos etc + +Once best matches are found, project into 3d + +## logs + +system logs, human-agent-tool interaction are also temporal/textual streams that can be stored, embedded, searched over + +### Embedding data streams + +# milestone 1 + +I can query for "a shoe" in a textbox, get a semantic map overlay + +# milestone 2 + +I can query for "a shoe" in a textbox, get PointStamped for shoes detected by VLM + +## example interaction 1: memory search + +search for "a shoe" - independent stored streams offer textual queries + +3 agent narration matches (temporal textual stream2) +1 tool call match (temporal textual stream 2) +temporal-semantic graph returned (image embeddings) + +temporal-spatial-semantic graph analysis - 3 clusters identified, feed each cluster to some description VLM - "a shoe on a kitchen floor", "a shoe on a desk" etc + +return to an agent: + +- narration block, timestamp +- tool call match, timestamp +- return 3 clusters, timestamps, rough locations + +agent calls goto (event cluster 3) + +cluster 3 - find best image, correlate to lidar, project into space, navigate, once there, use VLM and visual nav + + +## example interaction 2: arm + +mustafa is able to ask for an object in proximity to the robot. robot searches memory biasing distance in time and space. if close match is not found, search can be expanded + +"do you remember the red sock" + +"yes I saw it 35 seconds ago" + +"yes I saw it 3 days ago behind me" + +"yes I saw it an hour ago, it was 15 meters away" + + +# Questions + +"where was I, when this log line was added" diff --git a/plans/memory1.md b/plans/memory1.md new file mode 100644 index 0000000000..9e9c7028b1 --- /dev/null +++ b/plans/memory1.md @@ -0,0 +1,319 @@ +# DB → Session → Store: DimOS Memory2 + +## Context + +PR #1080 introduced `TimeSeriesStore[T]` with pluggable backends. 
Paul's review identified it mixes DB lifecycle, connection, and query concerns. Additionally, `memory.md` describes a system where all sensor data is stored as temporal streams with 4D spatial-temporal indexing, cross-stream correlation is the primary operation, and search (text/embedding) must work across streams. This plan builds a clean 3-layer architecture from scratch in `dimos/memory2/`, SQLite-first, with R\*Tree indexing for spatial-temporal queries. + +## Architecture + +``` +SqliteDB (config + factory + WAL + sqlite-vec + R*Tree) + └─ Session (connection, thread-bound) + ├─ .timeseries(table, type) → TimeSeries[T] (temporal store + optional 4D spatial index) + ├─ .embeddings(table, dim) → EmbeddingStore (KNN search store + optional spatial index) + ├─ .at(t, *stores) → tuple (multi-stream temporal lookup) + ├─ .between(t1, t2, *stores)→ Iterator[tuple] (batch temporal join) + └─ .execute(sql, params) → rows (raw SQL escape hatch) +``` + +Every stream gets an R\*Tree (4D: time + xyz). Spatial info is optional per-row — rows without spatial data are indexed by time only (x/y/z set to NaN sentinels or excluded). This eliminates the need for cross-stream pose joins: each datapoint carries its own spatial context at write time. 
+ +## API Examples + +```python +db = SqliteDB("run_001.db") + +with db.session() as s: + images = s.timeseries("color_images", Image) + poses = s.timeseries("poses", PoseStamped) + lidar = s.timeseries("lidar", PointCloud) + img_emb = s.embeddings("image_embeddings", dim=512) + + # --- Save with optional spatial context --- + images.save(frame) # temporal only + images.save(frame, pose=robot_pose) # temporal + spatial (baked in) + images.save(frame, pose=(pos, quat)) # PoseLike tuple also works + + # --- Temporal queries (chainable) --- + hit = images.at(now).one() # closest to now → Hit | None + hit = images.at(now, tolerance=0.1).one() # within 100ms or None + hit = images.before(now).one() # last item before now + hit = images.last() # most recent (shortcut) + + # Lazy fetch actual data from Hit + image = images.load(hit.ts) # → Image + + # --- Spatial queries (R*Tree, chainable) --- + hits = images.near(Point(1, 2, 3), radius=0.5).fetch() + hits = images.near(robot_pose, radius=2.0).between(t1, t2).fetch() + + # Each hit has pose (full 6DOF) for reconstruction + for hit in hits: + print(f"Seen at {hit.pose}, dist={hit.spatial_distance}m") + + # --- Embedding search (chainable) --- + query_vec = clip.encode_text("a shoe") + + # Embedding only + hits = img_emb.search(query_vec, k=20).fetch() + + # Embedding + spatial + hits = img_emb.search(query_vec, k=10).near(robot_pose, radius=3.0).fetch() + + # Embedding + temporal + hits = img_emb.search(query_vec, k=10).between(t1, t2).fetch() + + # All three: embedding + spatial + temporal + hits = (img_emb.search(query_vec, k=10) + .near(robot_pose, radius=5.0) + .between(now - 3600, now) + .fetch()) + + for hit in hits: + hit.ts # when + hit.pose # where + orientation (6DOF) + hit.embedding_distance # similarity score + hit.spatial_distance # meters from query point + image = images.at(hit.ts).one() # correlate to image stream + vec = img_emb.load_embedding(hit.id) # lazy fetch embedding + + # --- Cross-stream temporal 
lookup --- + pose_hit = poses.at(hit.ts).one() + + # --- Raw SQL escape hatch --- + rows = s.execute("SELECT ... FROM ... JOIN ...", params) +``` + +## File Structure + +``` +dimos/memory2/ + __init__.py # public exports + _sql.py # _validate_identifier(), shared SQL helpers + db.py # DB ABC + SqliteDB + session.py # Session ABC + SqliteSession + hit.py # Hit hierarchy (7 classes: Hit, Temporal, Spatial, Embedding, combos) + query.py # Query hierarchy (7 classes: matching Hit types, chainable) + timeseries.py # TimeSeries[T] ABC + SqliteTimeSeries + embeddings.py # EmbeddingStore ABC + SqliteEmbeddingStore + test_memory2.py # tests +``` + +## Interfaces + +### DB (`db.py`) + +```python +class DB(Resource, ABC): + def session(self) -> Session: ... + def close(self) -> None: ... # closes all tracked sessions + # Resource protocol + def start(self) -> None: pass # usable after __init__ + def stop(self) -> None: self.close() +``` + +`SqliteDB`: +- Stores file path, creates parent dirs on first connect +- `_connect()`: `sqlite3.connect()`, enables WAL mode, loads sqlite-vec +- Tracks sessions via `WeakSet` for cleanup +- `:memory:` uses `file::memory:?cache=shared` URI so sessions share data + +### Session (`session.py`) + +```python +class Session(ABC): + def timeseries(self, table: str, type: type[T]) -> TimeSeries[T]: ... + def embeddings(self, table: str, dim: int) -> EmbeddingStore: ... + def execute(self, sql: str, params=()) -> list: ... + def close(self) -> None: ... + def __enter__ / __exit__ # context manager +``` + +`SqliteSession`: +- Holds one `sqlite3.Connection` +- `timeseries()` / `embeddings()` validate table name, create store, cache it +- `execute()`: raw SQL passthrough +- Cross-stream correlation done via Query builder (e.g. 
`poses.at(hit.ts).one()`) + +### TimeSeries (`timeseries.py`) + +```python +from dimos.msgs.geometry_msgs.Pose import Pose, PoseLike +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Point import Point + +# --- Hit type hierarchy (type-state, 7 classes) --- + +@dataclass +class Hit: + """Base result: just ts + optional pose. All data lazy-fetched.""" + ts: float + pose: Pose | None = None + +@dataclass +class TemporalHit(Hit): + temporal_distance: float = 0.0 # |query_time - ts| + +@dataclass +class SpatialHit(Hit): + spatial_distance: float = 0.0 # meters from query point + pose: Pose = field(default=...) # guaranteed present for spatial hits + +@dataclass +class EmbeddingHit(Hit): + embedding_distance: float = 0.0 # cosine/L2 in embedding space + id: str = "" + metadata: dict | None = None + +# Combinations (multiple inheritance) +@dataclass +class TemporalSpatialHit(TemporalHit, SpatialHit): ... + +@dataclass +class TemporalEmbeddingHit(TemporalHit, EmbeddingHit): ... + +@dataclass +class SpatialEmbeddingHit(SpatialHit, EmbeddingHit): ... + +@dataclass +class FullHit(TemporalHit, SpatialHit, EmbeddingHit): ... + +# --- Query type-state hierarchy (7 classes, narrows on chain) --- + +class Query: + """Base query builder. Accumulates filters, executes on .fetch().""" + def fetch(self, limit: int | None = None) -> list[Hit]: ... + def one(self) -> Hit | None: ... + def count(self) -> int: ... + +class TemporalQuery(Query): + def near(self, point: Point | PoseLike | PoseStamped, + radius: float) -> TemporalSpatialQuery: ... + def fetch(self, limit=None) -> list[TemporalHit]: ... + def one(self) -> TemporalHit | None: ... + +class SpatialQuery(Query): + def at(self, t: float, tolerance: float | None = None) -> TemporalSpatialQuery: ... + def before(self, t: float) -> TemporalSpatialQuery: ... + def after(self, t: float) -> TemporalSpatialQuery: ... + def between(self, t1: float, t2: float) -> TemporalSpatialQuery: ... 
+ def fetch(self, limit=None) -> list[SpatialHit]: ... + def one(self) -> SpatialHit | None: ... + +class EmbeddingQuery(Query): + def near(self, ...) -> SpatialEmbeddingQuery: ... + def at(self, ...) -> TemporalEmbeddingQuery: ... + def between(self, ...) -> TemporalEmbeddingQuery: ... + def fetch(self, limit=None) -> list[EmbeddingHit]: ... + +class TemporalSpatialQuery(Query): + def fetch(self, limit=None) -> list[TemporalSpatialHit]: ... + +class TemporalEmbeddingQuery(Query): + def near(self, ...) -> FullQuery: ... + def fetch(self, limit=None) -> list[TemporalEmbeddingHit]: ... + +class SpatialEmbeddingQuery(Query): + def at(self, ...) -> FullQuery: ... + def between(self, ...) -> FullQuery: ... + def fetch(self, limit=None) -> list[SpatialEmbeddingHit]: ... + +class FullQuery(Query): + def fetch(self, limit=None) -> list[FullHit]: ... + +# All query logic (SQL generation) lives in base Query. +# Subclasses only override type signatures — no duplicated logic. + +# --- TimeSeries --- + +class TimeSeries(Generic[T], ABC): + # Write + def save(self, *items: T, pose: PoseLike | PoseStamped | None = None) -> None: ... + + # Start a query chain (returns typed Query) + def at(self, t: float, tolerance: float | None = None) -> TemporalQuery: ... + def before(self, t: float) -> TemporalQuery: ... + def after(self, t: float) -> TemporalQuery: ... + def between(self, t1: float, t2: float) -> TemporalQuery: ... + def near(self, point: Point | PoseLike | PoseStamped, + radius: float) -> SpatialQuery: ... + + # Convenience terminals (no chain needed) + def last(self) -> TemporalHit | None: ... + def first(self) -> TemporalHit | None: ... + + # Lazy data fetch (from Hit.ts) + def load(self, ts: float) -> T | None: ... + + def delete(self, t: float) -> bool: ... + def count(self) -> int: ... +``` + +All spatial parameters accept DimOS types with `.x`, `.y`, `.z` — `Point`, `Pose`, `PoseStamped`, `PoseLike`. 
Full pose (with orientation) stored per row for post-filter reconstruction. + +`SqliteTimeSeries`: +- Data table: `CREATE TABLE {table} (rowid INTEGER PRIMARY KEY, timestamp REAL NOT NULL, data BLOB NOT NULL)` +- R\*Tree: `CREATE VIRTUAL TABLE {table}_rtree USING rtree(id, min_t, max_t, min_x, max_x, min_y, max_y, min_z, max_z)` +- R\*Tree `id` matches `rowid` in data table +- `save(item, pose=p)`: inserts data row + R\*Tree entry with `(ts, ts, x, x, y, y, z, z)` (point) +- `save(item)` without pose: inserts data row + R\*Tree entry with time only (x/y/z set to ±inf to match any spatial query) +- `at()`: `SELECT data FROM {table} ORDER BY ABS(timestamp - ?) LIMIT 1` +- `between()`: `SELECT data FROM {table} WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp` +- `near()`: `SELECT d.data FROM {table} d JOIN {table}_rtree r ON d.rowid = r.id WHERE r.min_t >= ? AND r.max_t <= ? AND r.min_x >= ? AND r.max_x <= ? ...` +- Lazy table creation on first operation + +### EmbeddingStore (`embeddings.py`) + +```python +class EmbeddingStore(ABC): + def save(self, id: str, vector: np.ndarray, timestamp: float, + pose: PoseLike | PoseStamped | None = None, + metadata: dict | None = None) -> None: ... + + # Start a query chain with embedding search (returns typed Query) + def search(self, query: np.ndarray, k: int = 10) -> EmbeddingQuery: ... + + # Chain: .search(vec, 10).near(p, 3.0).between(t1, t2).fetch() → list[FullHit] + + # Lazy fetch + def load_embedding(self, id: str) -> np.ndarray | None: ... + + def delete(self, id: str) -> bool: ... + def count(self) -> int: ... +``` + +Uses the same `Query` builder and `Hit` result type as TimeSeries. `.search()` returns a Query with embedding filter set; chain `.near()`, `.between()`, etc. to add spatial/temporal constraints. 
+ +`SqliteEmbeddingStore`: +- Three tables: `{table}_vec` (sqlite-vec virtual, `float[dim]`), `{table}_meta` (rowid, id, timestamp, x, y, z, metadata JSON), `{table}_rtree` (R\*Tree for spatial-temporal filtering) +- `search()`: KNN via `{table}_vec MATCH ?`, joined with meta for time/spatial filters +- `near=` param: pre-filters candidates via R\*Tree before KNN +- Each `SearchHit` carries position (x, y, z) directly — no pose join needed + +## SQLite Details + +- **WAL mode**: enabled on first connection per DB file. Allows concurrent readers + one writer across threads. +- **R\*Tree**: built into SQLite (compile-time option, enabled by default). Every stream gets a 4D R\*Tree (time + xyz). No extra extension needed. +- **sqlite-vec**: loaded via `conn.load_extension()`. Required for EmbeddingStore. TimeSeries works without it. +- **Thread safety**: each session = one connection = one thread. No `check_same_thread=False`. +- **Pickle BLOBs**: same serialization as current SqliteTSStore. Works with any `Timestamped` subclass. +- **Spatial data without pose**: rows saved without `pose=` get R\*Tree entry with x/y/z bounds set to ±1e38 (effectively unbounded), so they match any spatial query but don't constrain results. + +## Implementation Order + +1. `_sql.py` — `_validate_identifier()` +2. `hit.py` — `Hit` dataclass (unified result type) +3. `query.py` — `Query` builder (accumulates filters, generates SQL, returns `list[Hit]`) +4. `timeseries.py` — `TimeSeries[T]` ABC + `SqliteTimeSeries` (chain methods return Query) +5. `embeddings.py` — `EmbeddingStore` ABC + `SqliteEmbeddingStore` (.search() returns Query) +6. `session.py` — `Session` ABC + `SqliteSession` +7. `db.py` — `DB` ABC + `SqliteDB` (config, connect, WAL, sqlite-vec, Resource) +8. `__init__.py` — public exports +9. `test_memory2.py` — tests: lifecycle, temporal/spatial/embedding queries, combined chains, lazy fetch +10. `pyproject.toml` — add `sqlite-vec` dependency + +## Verification + +1. 
`uv run pytest dimos/memory2/test_memory2.py -v` — all new tests pass +2. `uv run mypy dimos/memory2/` — type checks clean +3. Existing `dimos/memory/timeseries/test_base.py` still passes (untouched) diff --git a/plans/memory2.md b/plans/memory2.md new file mode 100644 index 0000000000..816d2cb0af --- /dev/null +++ b/plans/memory2.md @@ -0,0 +1,898 @@ +# DimOS Memory2 Spec v2.1 + +Status: implementation-oriented draft for a coding agent. + +This spec is intentionally code/example focused. It defines the public API shape, core invariants, and the minimum execution model needed to implement a useful local-first multimodal memory system. + +--- + +# 0. Goals + +Memory2 stores and queries multimodal robot observations. + +Primary use cases: + +1. Store raw streams: images, lidar, poses, logs, narration. +2. Generate streams from streams: embeddings from images, captions from images, detections from frames. +3. Narrow data without loading payloads: top-k matches, time windows, spatial subsets. +4. Re-query narrowed results. +5. Correlate across streams. +6. Keep payload loading lazy. + +Non-goal: + +- Do not implement high-level search policies here (hotspot search, VLM orchestration, semantic map UI). + +--- + +# 1. Terminology + +## 1.1 Observation + +A single stored item. + +Examples: + +- one RGB frame +- one lidar scan +- one log line +- one CLIP embedding +- one VLM caption + +## 1.2 Stream + +Appendable collection of observations with a shared payload type and capability set. + +Examples: + +- `rgb_front` +- `lidar_mid360` +- `robot_pose` +- `tool_logs` +- `image_embeddings_clip` + +## 1.3 ObservationSet + +A lazy, read-only, queryable view over observations. 
+ +Important: + +- an `ObservationSet` is **not** a Python set +- it is usually **lazy** +- it usually contains **refs + metadata**, not payloads +- it may represent a subset of one stream or a projection/correlation over multiple streams + +## 1.4 DerivedStream + +A stream generated from upstream streams or observation sets. + +Examples: + +- embeddings generated from images +- captions generated from images +- detections generated from frames + +Rule: + +- same observation identity -> `ObservationSet` +- new observation identity -> `DerivedStream` + +--- + +# 2. Core invariants + +These are hard requirements. + +## 2.1 Stable identity + +Every observation has a stable reference independent of timestamp. + +```python +from dataclasses import dataclass + +@dataclass(frozen=True) +class ObservationRef: + stream: str + id: str +``` + +Never use timestamp as the primary load key. + +Bad: + +```python +images.load(hit.ts) +``` + +Good: + +```python +images.load(hit.ref) +``` + +## 2.2 Payloads are lazy + +Queries and observation sets must not load full payloads unless explicitly requested. + +Examples of payloads that must stay lazy: + +- images +- point clouds +- audio chunks +- voxel blocks + +## 2.3 Metadata may be materialized + +It is acceptable to materialize lightweight metadata for result sets: + +- ref +- timestamp +- pose +- scores +- tags +- lineage pointers + +## 2.4 Query results are re-queryable + +A narrowed result should still support `.query()` and further filtering/ranking. + +## 2.5 Query results are not appendable + +`ObservationSet` is read-only. + +Only `Stream` is appendable. + +## 2.6 Spatially unknown != spatially everywhere + +Unlocalized observations do not match spatial queries by default. + +## 2.7 Derived outputs must carry lineage + +Any derived stream should record parent streams and parent refs or parent query provenance. + +--- + +# 3. Public API + +## 3.1 Top-level objects + +```python +class DB: ... +class Session: ... 
+class Stream[T]: ... +class ObservationSet[T]: ... +class Query[T]: ... +class Correlator: ... +``` + +## 3.2 Shared read/query protocol + +`Stream` and `ObservationSet` should share the same read/query protocol. + +```python +from typing import Protocol, Iterable, Iterator, Generic, TypeVar, Any + +T = TypeVar("T") + +class QueryableObservationSet(Protocol, Generic[T]): + def query(self) -> "Query[T]": ... + def load(self, ref: ObservationRef) -> T: ... + def load_many(self, refs: list[ObservationRef], *, batch_size: int = 32) -> list[T]: ... + def iter_meta(self, *, page_size: int = 128) -> Iterator[list["ObservationRow"]]: ... + def count(self) -> int: ... + def capabilities(self) -> set[str]: ... +``` + +`Stream` extends this with append/introspection. + +--- + +# 4. Core data structures + +## 4.1 Observation metadata + +```python +from dataclasses import dataclass, field +from typing import Any + +@dataclass +class Pose: + xyz: tuple[float, float, float] + quat_xyzw: tuple[float, float, float, float] | None = None + +@dataclass +class ObservationMeta: + ref: ObservationRef + ts_start: float | None = None + ts_end: float | None = None + robot_id: str | None = None + frame_id: str | None = None + pose: Pose | None = None + pose_source: str | None = None + pose_confidence: float | None = None + transform_version: str | None = None + timestamp_uncertainty: float | None = None + payload_codec: str | None = None + payload_size_bytes: int | None = None + tags: dict[str, Any] = field(default_factory=dict) +``` + +Notes: + +- point observations use `ts_start == ts_end` +- interval observations use `[ts_start, ts_end]` +- `pose` is a denormalized snapshot for fast filtering +- provenance fields allow later reinterpretation after better localization + +## 4.2 Query/ObservationSet row + +An `ObservationSet` should expose rows with lightweight metadata and scores. 
+ +```python +@dataclass +class ObservationRow: + ref: ObservationRef + ts_start: float | None = None + ts_end: float | None = None + pose: Pose | None = None + scores: dict[str, float] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) +``` + +Expected score keys: + +- `embedding_distance` +- `text_rank` +- `spatial_distance` +- `temporal_distance` +- `final_rank` + +## 4.3 Lineage + +```python +@dataclass +class Lineage: + parents: list[str] = field(default_factory=list) + parent_refs: list[ObservationRef] = field(default_factory=list) + query_repr: str | None = None + transform_name: str | None = None + transform_version: str | None = None +``` + +This can be attached to streams, rows, or derived outputs. + +--- + +# 5. Stream API + +## 5.1 Stream creation + +```python +with db.session() as s: + images = s.stream( + name="rgb_front", + payload_type=Image, + capabilities={"temporal", "spatial", "load"}, + retention="run", + ) + + logs = s.stream( + name="tool_logs", + payload_type=str, + capabilities={"temporal", "text", "load"}, + retention="run", + ) + + image_embeddings = s.stream( + name="image_embeddings_clip", + payload_type=Embedding, + capabilities={"temporal", "spatial", "embedding", "load"}, + retention="derived", + config={"dim": 512, "metric": "cosine"}, + ) +``` + +## 5.2 Stream interface + +```python +class Stream(QueryableObservationSet[T], Generic[T]): + def append(self, payload: T, **meta: Any) -> ObservationRef: ... + def append_many(self, payloads: Iterable[T], metas: Iterable[dict[str, Any]]) -> list[ObservationRef]: ... + def meta(self, ref: ObservationRef) -> ObservationMeta: ... + def info(self) -> dict[str, Any]: ... + def stats(self) -> dict[str, Any]: ... + def retention(self) -> str: ... 
+``` + +## 5.3 Append examples + +```python +frame_ref = images.append( + frame, + ts_start=now, + ts_end=now, + robot_id="go2_01", + frame_id="map", + pose=current_pose, + pose_source="slam_localization", + transform_version="loc_epoch_17", +) + +log_ref = logs.append( + "planner timeout on task 42", + ts_start=now, + ts_end=now, + tags={"level": "warning", "module": "planner"}, +) +``` + +--- + +# 6. ObservationSet API + +## 6.1 Design intent + +`ObservationSet` is the key abstraction for narrowed/re-queryable results. + +It should: + +- be lazy by default +- usually avoid payload loading +- support `.query()` +- support loading payloads one-by-one or in batches +- support projection to related streams +- support materialization when needed + +## 6.2 Interface + +```python +class ObservationSet(QueryableObservationSet[T], Generic[T]): + def refs(self, *, limit: int | None = None) -> list[ObservationRef]: ... + def rows(self, *, limit: int | None = None) -> list[ObservationRow]: ... + def one(self) -> ObservationRow: ... + def fetch_page(self, *, limit: int = 128, offset: int = 0) -> list[ObservationRow]: ... + def project_to(self, stream: "Stream[Any]") -> "ObservationSet[Any]": ... + def materialize(self, *, name: str | None = None, retention: str = "ephemeral") -> "ObservationSet[T]": ... + def derive(self, *, name: str, transform: "Transform[T, Any]", retention: str = "derived", payload_type: type | None = None) -> "Stream[Any]": ... + def lineage(self) -> Lineage: ... 
+``` + +## 6.3 Example: narrowing data and re-querying + +```python +recent_images = ( + images.query() + .filter_time(now - 600, now) + .fetch_set() +) + +recent_nearby_images = ( + recent_images.query() + .filter_near(current_pose, radius=3.0) + .fetch_set() +) +``` + +## 6.4 Example: embedding search without loading images + +```python +matches = ( + image_embeddings.query() + .search_embedding(query_vec, candidate_k=2000) + .filter_time(now - 3600, now) + .filter_near(current_pose, radius=8.0) + .rank(embedding=1.0, recency=0.2, distance=0.3) + .limit(1000) + .fetch_set() +) +``` + +Important: + +- `matches` should not contain 1000 image payloads in RAM +- it should usually contain refs + lightweight metadata/scores only + +## 6.5 Example: payload access stays explicit + +```python +rows = matches.fetch_page(limit=20, offset=0) +first_payload = image_embeddings.load(rows[0].ref) + +candidate_refs = matches.refs(limit=16) +embeddings_batch = image_embeddings.load_many(candidate_refs, batch_size=16) +``` + +## 6.6 Example: projecting embedding matches to images + +Assume each embedding row records its parent image ref. + +```python +matched_frames = matches.project_to(images) +preview_rows = matched_frames.fetch_page(limit=12) +preview_frames = images.load_many([r.ref for r in preview_rows], batch_size=12) +``` + +## 6.7 Example: deriving a caption stream from a narrowed image set + +```python +captions = matched_frames.derive( + name="vlm_captions_shoe_candidates", + transform=caption_model, + retention="derived", + payload_type=str, +) +``` + +This creates a new stream because it creates new observation identities. + +--- + +# 7. Query API + +## 7.1 Query design + +Query should be composable and capability-based. + +It should support: + +- hard filters +- candidate generation +- soft ranking +- terminal materialization + +## 7.2 Interface + +```python +class Query(Generic[T]): + def filter_time(self, t1: float, t2: float) -> "Query[T]": ... 
+ def filter_before(self, t: float) -> "Query[T]": ... + def filter_after(self, t: float) -> "Query[T]": ... + def filter_near(self, pose: Pose, radius: float, *, include_unlocalized: bool = False) -> "Query[T]": ... + def filter_tags(self, **tags: Any) -> "Query[T]": ... + def filter_refs(self, refs: list[ObservationRef]) -> "Query[T]": ... + + def search_text(self, text: str, *, candidate_k: int | None = None) -> "Query[T]": ... + def search_embedding(self, vector: list[float], *, candidate_k: int) -> "Query[T]": ... + + def rank(self, **weights: float) -> "Query[T]": ... + def limit(self, k: int) -> "Query[T]": ... + + def fetch(self) -> list[ObservationRow]: ... + def fetch_set(self) -> ObservationSet[T]: ... + def count(self) -> int: ... + def one(self) -> ObservationRow: ... +``` + +## 7.3 Hard filters vs ranking + +This distinction must stay explicit. + +Example: + +```python +hits = ( + image_embeddings.query() + .search_embedding(query_vec, candidate_k=1000) + .filter_time(t1, t2) + .filter_near(current_pose, radius=5.0) + .rank(embedding=1.0, recency=0.15, distance=0.35) + .limit(50) + .fetch() +) +``` + +Execution meaning: + +1. embedding search creates candidates +2. time/space filters remove candidates +3. ranking combines scores on remaining rows +4. limit applies at the end + +Do not leave this ambiguous. + +--- + +# 8. Under-the-hood model for ObservationSet + +## 8.1 Default behavior + +`ObservationSet` should be lazy/unresolved until needed. + +It must not eagerly decode payloads. + +## 8.2 Internal backing kinds + +Publicly there is one `ObservationSet` class. Internally it may have multiple backing strategies. 
+ +```python +from dataclasses import dataclass +from typing import Literal + +@dataclass +class PredicateBacking: + source_name: str + query_repr: str + +@dataclass +class RefTableBacking: + table_name: str + source_streams: list[str] + ordered: bool = False + +@dataclass +class CompositeBacking: + op: Literal["union", "intersection", "difference", "project", "join"] + input_ids: list[str] + query_repr: str +``` + +Recommended internal shape: + +```python +class ObservationSet(QueryableObservationSet[T], Generic[T]): + _backing: PredicateBacking | RefTableBacking | CompositeBacking + _capabilities: set[str] + _lineage: Lineage +``` + +## 8.3 Predicate-backed set + +Use when the set is still naturally expressible as a query over the underlying source. + +Examples: + +- time range over one stream +- tag filter over one stream +- spatial filter over one stream +- text search over one stream + +No payloads need to be materialized. + +## 8.4 Ref-table-backed set + +Use when a query creates an explicit candidate pool. + +Examples: + +- top-k embedding matches +- correlation results +- reranked subsets +- cluster representatives + +Important: + +- refs do not need to live in Python memory +- they can live in a SQLite temp table +- metadata rows may be fetched page-wise + +## 8.5 Composite-backed set + +Use for union/intersection/project/join style operations over other observation sets. + +--- + +# 9. Payload loading rules + +## 9.1 Allowed eager data + +Eagerly loaded into Python is acceptable for: + +- small metadata rows +- refs +- scores +- tags + +## 9.2 Disallowed by default + +Do not eagerly load by default: + +- all image payloads +- all point clouds +- all audio blobs +- all voxel blocks + +## 9.3 Required explicit methods + +```python +payload = images.load(ref) +payloads = images.load_many(refs, batch_size=32) + +for page in image_set.iter_meta(page_size=128): + ... +``` + +No API should silently decode a thousand images just because `.fetch_set()` was called. 
+ +--- + +# 10. Stream generation from streams + +This is a central use case. + +## 10.1 Example: embeddings from images + +```python +frames = ( + images.query() + .filter_time(now - 60, now) + .fetch_set() +) + +embeddings = frames.derive( + name="image_embeddings_clip_recent", + transform=clip_embedder, + retention="derived", + payload_type=Embedding, +) +``` + +Implementation expectation: + +- `derive()` iterates source payloads in batches +- output rows record lineage to input refs +- output stream stores its own payloads/metadata/indexes + +## 10.2 Example transform protocol + +```python +U = TypeVar("U") + +class Transform(Protocol, Generic[T, U]): + name: str + version: str + + def map_batch(self, rows: list[ObservationRow], payloads: list[T]) -> list[tuple[U, dict[str, Any]]]: ... +``` + +This allows a coding agent to implement batch transforms cleanly. + +--- + +# 11. Correlation API + +Correlation is first-class. + +## 11.1 Example + +```python +bundle = s.correlate( + anchor=log_ref, + with_streams=[images, lidar, poses], + by={ + "rgb_front": {"mode": "nearest_time", "tolerance": 0.2}, + "lidar_mid360": {"mode": "nearest_time", "tolerance": 0.1}, + "robot_pose": {"mode": "nearest_time", "tolerance": 0.05}, + }, +) +``` + +## 11.2 Correlation result shape + +```python +@dataclass +class CorrelatedItem: + stream: str + row: ObservationRow | None + reason: dict[str, Any] + +@dataclass +class CorrelationBundle: + anchor: ObservationRef + items: list[CorrelatedItem] +``` + +At minimum support: + +- nearest in time +- overlapping interval + +Later support: + +- nearest in space +- same robot\_id +- same frame\_id + +--- + +# 12. Introspection + +These are needed for human tooling and debugging. 
+ +```python +s.list_streams() +images.info() +images.stats() +matches.capabilities() +matches.lineage() +``` + +Recommended fields for `stream.info()`: + +```python +{ + "name": "rgb_front", + "payload_type": "Image", + "row_count": 12345, + "retention": "run", + "capabilities": ["temporal", "spatial", "load"], + "time_bounds": [1700000000.0, 1700003600.0], + "spatial_bounds": [xmin, ymin, zmin, xmax, ymax, zmax], + "payload_codec": "jpeg", +} +``` + +--- + +# 13. Backend implementation target + +SQLite-first, but backend-specific details should stay behind the API. + +## 13.1 Expected SQLite tools + +- normal tables for metadata +- temp tables for candidate refs +- FTS5 for text search +- R-tree for spatial indexing +- vector extension when available + +## 13.2 Suggested mapping per stream + +- metadata table +- payload table or blob column +- optional FTS table +- optional vector index table +- optional spatial index table + +## 13.3 Important backend rule + +Unlocalized rows should not be inserted into the spatial index. + +--- + +# 14. 
Concrete execution examples + +## 14.1 Time-filtered image subset stays lazy + +```python +recent = ( + images.query() + .filter_time(now - 300, now) + .fetch_set() +) +``` + +Expected implementation: + +- create predicate-backed `ObservationSet` +- do not decode image payloads +- only execute SQL when rows/count/payloads are requested + +## 14.2 Embedding search becomes ref-table-backed + +```python +matches = ( + image_embeddings.query() + .search_embedding(query_vec, candidate_k=5000) + .filter_time(now - 7200, now) + .limit(1000) + .fetch_set() +) +``` + +Expected implementation: + +- run vector search +- write candidate refs + scores to temp table +- return ref-table-backed `ObservationSet` +- allow further `.query()` by restricting to that candidate table + +## 14.3 Re-query candidate set without loading payloads + +```python +nearby_matches = ( + matches.query() + .filter_near(current_pose, radius=6.0) + .limit(100) + .fetch_set() +) +``` + +Expected implementation: + +- join source metadata with candidate ref table +- apply spatial filter in backend +- return new lazy observation set + +## 14.4 Paginated preview + +```python +page = nearby_matches.fetch_page(limit=24, offset=0) +preview_refs = [row.ref for row in page] +preview_embeddings = image_embeddings.load_many(preview_refs, batch_size=24) +``` + +Again: explicit payload loading only. + +--- + +# 15. What the coding agent should implement first + +Implementation priority order: + +1. `ObservationRef`, `ObservationMeta`, `ObservationRow`, `Lineage` +2. `DB`, `Session`, `Stream` +3. `Query` with time filters and `.fetch_set()` +4. lazy `ObservationSet` with predicate backing +5. explicit payload loading methods +6. text search +7. ref-table-backed observation sets +8. embedding search +9. `project_to()` +10. `derive()` +11. correlation +12. introspection/stats + +--- + +# 16. Minimal acceptance examples + +These examples should work. 
+ +## 16.1 Re-query narrowed data + +```python +recent = images.query().filter_time(t1, t2).fetch_set() +recent2 = recent.query().filter_near(pose, radius=2.0).fetch_set() +assert recent2.count() <= recent.count() +``` + +## 16.2 Fetch set does not load payloads + +```python +matches = images.query().filter_time(t1, t2).limit(1000).fetch_set() +# should be cheap even for large image payloads +rows = matches.fetch_page(limit=10) +assert len(rows) == 10 +``` + +## 16.3 Derived stream from narrowed set + +```python +subset = images.query().filter_time(t1, t2).limit(100).fetch_set() +captions = subset.derive( + name="captions_test", + transform=caption_model, + retention="derived", + payload_type=str, +) +assert captions.count() == subset.count() +``` + +## 16.4 Projection from embeddings to images + +```python +emb_matches = image_embeddings.query().search_embedding(qvec, candidate_k=100).fetch_set() +frame_matches = emb_matches.project_to(images) +rows = frame_matches.fetch_page(limit=5) +frames = images.load_many([r.ref for r in rows], batch_size=5) +assert len(frames) == 5 +``` + +--- + +# 17. Summary + +Memory2 should expose: + +- appendable `Stream` +- lazy read-only `ObservationSet` +- composable `Query` +- explicit payload loading +- derived stream generation +- re-queryable narrowed results +- stable observation refs +- backend-backed candidate sets instead of eager payload lists + +The most important implementation rule is this: + +> `fetch_set()` returns a lazy queryable view over observations, not a Python list of decoded payloads. diff --git a/plans/memory3.md b/plans/memory3.md new file mode 100644 index 0000000000..e4a468d456 --- /dev/null +++ b/plans/memory3.md @@ -0,0 +1,364 @@ +# Memory2 Implementation Plan + +Source of truth: `plans/memory_2_spec_v_2.md` + +## Context + +PR #1080 introduced `TimeSeriesStore[T]` with pluggable backends. Paul's review identified it mixes DB lifecycle, connection, and query concerns. 
`memory.md` describes a system where all sensor data is stored as temporal streams with spatial indexing, cross-stream correlation, and multimodal search. The spec (`memory_2_spec_v_2.md`) defines the full public API. This plan maps the spec to concrete SQLite implementation in `dimos/memory2/`. + +## File Structure + +``` +dimos/memory2/ + __init__.py # public exports + _sql.py # _validate_identifier(), SQL helpers + types.py # ObservationRef, ObservationMeta, ObservationRow, Lineage, Pose (spec's own Pose) + db.py # DB (Resource lifecycle, SqliteDB) + session.py # Session (connection, stream factory, correlate) + stream.py # Stream (append + QueryableObservationSet) + observation_set.py # ObservationSet (lazy, re-queryable, predicate/ref-table backed) + query.py # Query (filter/search/rank/limit → fetch/fetch_set) + test_memory2.py # tests +``` + +## Implementation Priority (per spec §15) + +### Phase 1: Core types + storage + +1. **`types.py`** — Data classes + +```python +@dataclass(frozen=True) +class ObservationRef: + stream: str + id: str + +@dataclass +class Pose: + xyz: tuple[float, float, float] + quat_xyzw: tuple[float, float, float, float] | None = None + +@dataclass +class ObservationMeta: + ref: ObservationRef + ts_start: float | None = None + ts_end: float | None = None + robot_id: str | None = None + frame_id: str | None = None + pose: Pose | None = None + pose_source: str | None = None + pose_confidence: float | None = None + payload_codec: str | None = None + payload_size_bytes: int | None = None + tags: dict[str, Any] = field(default_factory=dict) + +@dataclass +class ObservationRow: + ref: ObservationRef + ts_start: float | None = None + ts_end: float | None = None + pose: Pose | None = None + scores: dict[str, float] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + +@dataclass +class Lineage: + parents: list[str] = field(default_factory=list) + parent_refs: list[ObservationRef] = 
field(default_factory=list) + query_repr: str | None = None +``` + +Note: `Pose` here is the spec's lightweight tuple-based pose for storage/filtering. Conversion to/from DimOS `dimos.msgs.geometry_msgs.Pose` via helper: + +```python +def to_storage_pose(p: DimOSPose | DimOSPoseStamped | Pose) -> Pose: ... +def to_dimos_pose(p: Pose) -> DimOSPose: ... +``` + +2. **`_sql.py`** — SQL helpers + +```python +def validate_identifier(name: str) -> str: ... # regex check, length limit +``` + +3. **`db.py`** — DB + SqliteDB + +```python +class DB(Resource, ABC): + def session(self) -> Session: ... + def close(self) -> None: ... + def start(self) -> None: pass + def stop(self) -> None: self.close() +``` + +SqliteDB internals: +- Stores file path, creates parent dirs on connect +- `_connect()`: `sqlite3.connect()`, WAL mode, loads sqlite-vec (optional), loads FTS5 +- Tracks sessions via `WeakSet` for cleanup +- `:memory:` uses `file::memory:?cache=shared` URI +- Thread safety: each session = one connection, no `check_same_thread=False` + +4. **`session.py`** — Session + SqliteSession + +```python +class Session(ABC): + def stream(self, name: str, payload_type: type, + capabilities: set[str], retention: str = "run", + config: dict | None = None) -> Stream: ... + def list_streams(self) -> list[str]: ... + def execute(self, sql: str, params=()) -> list: ... + def close(self) -> None: ... + def __enter__ / __exit__ +``` + +SqliteSession: +- Holds one `sqlite3.Connection` +- `stream()`: creates tables if needed (see schema below), caches Stream instances +- Registers stream metadata in a `_streams` registry table + +### Phase 2: Stream + Query + ObservationSet + +5. **`stream.py`** — Stream (implements `QueryableObservationSet`) + +```python +class Stream(Generic[T]): + # Write + def append(self, payload: T, **meta: Any) -> ObservationRef: ... + def append_many(self, payloads, metas) -> list[ObservationRef]: ... 
+ + # QueryableObservationSet protocol + def query(self) -> Query[T]: ... + def load(self, ref: ObservationRef) -> T: ... + def load_many(self, refs: list[ObservationRef], *, batch_size=32) -> list[T]: ... + def iter_meta(self, *, page_size=128) -> Iterator[list[ObservationRow]]: ... + def count(self) -> int: ... + def capabilities(self) -> set[str]: ... + + # Introspection + def meta(self, ref: ObservationRef) -> ObservationMeta: ... + def info(self) -> dict[str, Any]: ... + def stats(self) -> dict[str, Any]: ... +``` + +`append()` generates a UUID for `ObservationRef.id`, pickles payload into BLOB, inserts metadata row + R*Tree entry (if pose provided) + FTS entry (if text capable) + vector entry (if embedding capable). + +6. **`query.py`** — Query (chainable, capability-aware) + +```python +class Query(Generic[T]): + # Hard filters + def filter_time(self, t1: float, t2: float) -> Query[T]: ... + def filter_before(self, t: float) -> Query[T]: ... + def filter_after(self, t: float) -> Query[T]: ... + def filter_near(self, pose: Pose, radius: float, *, + include_unlocalized: bool = False) -> Query[T]: ... + def filter_tags(self, **tags: Any) -> Query[T]: ... + def filter_refs(self, refs: list[ObservationRef]) -> Query[T]: ... + + # Candidate generation + def search_text(self, text: str, *, candidate_k: int | None = None) -> Query[T]: ... + def search_embedding(self, vector: list[float], *, candidate_k: int) -> Query[T]: ... + + # Ranking + limit + def rank(self, **weights: float) -> Query[T]: ... + def limit(self, k: int) -> Query[T]: ... + + # Terminals + def fetch(self) -> list[ObservationRow]: ... + def fetch_set(self) -> ObservationSet[T]: ... + def count(self) -> int: ... + def one(self) -> ObservationRow: ... 
+``` + +Query internals: +- Accumulates filter predicates, search ops, rank spec, limit +- `fetch()`: generates SQL, executes, returns rows +- `fetch_set()`: creates an ObservationSet (predicate-backed or ref-table-backed) +- search_embedding → sqlite-vec `MATCH`, writes top-k to temp table → ref-table-backed +- search_text → FTS5 `MATCH` +- filter_near → R*Tree range query +- rank → computes composite score from available score columns + +7. **`observation_set.py`** — ObservationSet (lazy, re-queryable) + +```python +class ObservationSet(Generic[T]): + # Re-query + def query(self) -> Query[T]: ... + + # Read + def load(self, ref: ObservationRef) -> T: ... + def load_many(self, refs, *, batch_size=32) -> list[T]: ... + def refs(self, *, limit=None) -> list[ObservationRef]: ... + def rows(self, *, limit=None) -> list[ObservationRow]: ... + def one(self) -> ObservationRow: ... + def fetch_page(self, *, limit=128, offset=0) -> list[ObservationRow]: ... + def count(self) -> int: ... + def capabilities(self) -> set[str]: ... + def lineage(self) -> Lineage: ... + + # Cross-stream + def project_to(self, stream: Stream) -> ObservationSet: ... 
+``` + +Internal backing (spec §8): + +```python +@dataclass +class PredicateBacking: + """Lazy: expressible as SQL WHERE over source stream.""" + source_name: str + query_repr: str # serialized query filters for replay + +@dataclass +class RefTableBacking: + """Materialized: temp table of refs + scores.""" + table_name: str # SQLite temp table + source_streams: list[str] + ordered: bool = False +``` + +- `.query()` on predicate-backed → adds more predicates +- `.query()` on ref-table-backed → filters within that temp table +- `project_to()` → joins backing refs via lineage parent_refs to target stream + +### Phase 3: Later (not in first PR) + +- `derive()` with Transform protocol +- `CompositeBacking` (union/intersection/difference) +- `Correlator` / `s.correlate()` +- `retention` enforcement / cleanup +- Full introspection (stats, spatial_bounds) + +## SQLite Schema (per stream) + +### Metadata table: `{name}_meta` + +```sql +CREATE TABLE {name}_meta ( + id TEXT PRIMARY KEY, -- UUID, part of ObservationRef + ts_start REAL, + ts_end REAL, + robot_id TEXT, + frame_id TEXT, + pose_x REAL, pose_y REAL, pose_z REAL, + pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, + pose_source TEXT, + pose_confidence REAL, + payload_codec TEXT, + payload_size_bytes INTEGER, + tags TEXT, -- JSON + parent_stream TEXT, -- lineage: source stream name + parent_id TEXT -- lineage: source observation id +); +CREATE INDEX idx_{name}_meta_ts ON {name}_meta(ts_start); +``` + +### Payload table: `{name}_payload` + +```sql +CREATE TABLE {name}_payload ( + id TEXT PRIMARY KEY, -- matches _meta.id + data BLOB NOT NULL +); +``` + +Separate from meta so queries never touch payload BLOBs. 
+ +### R*Tree (spatial index): `{name}_rtree` + +```sql +CREATE VIRTUAL TABLE {name}_rtree USING rtree( + rowid, -- matches _meta rowid + min_t, max_t, -- ts_start, ts_end + min_x, max_x, + min_y, max_y, + min_z, max_z +); +``` + +Only rows with pose get R*Tree entries (spec §2.6: unlocalized != everywhere). +R*Tree `rowid` linked to meta via a mapping or using meta's rowid. + +### FTS5 (text search): `{name}_fts` + +```sql +CREATE VIRTUAL TABLE {name}_fts USING fts5( + id, + content, + content={name}_meta, + content_rowid=rowid +); +``` + +Only for streams with `"text"` capability. + +### Vector index (embedding search): `{name}_vec` + +```sql +CREATE VIRTUAL TABLE {name}_vec USING vec0( + embedding float[{dim}] +); +``` + +`rowid` matches meta rowid. Only for streams with `"embedding"` capability. + +## Key Design Decisions + +### Pose type bridging + +The spec defines its own lightweight `Pose(xyz, quat_xyzw)` for storage. DimOS has `dimos.msgs.geometry_msgs.Pose` with full algebra. Stream `append()` should accept either: + +```python +# DimOS Pose +images.append(frame, pose=robot_pose) # dimos.msgs.geometry_msgs.Pose + +# Spec Pose (tuples) +images.append(frame, pose=Pose(xyz=(1, 2, 3), quat_xyzw=(0, 0, 0, 1))) +``` + +Internal conversion via `to_storage_pose()` extracts `(x, y, z, qx, qy, qz, qw)` for SQL storage. + +### filter_near accepts DimOS types + +```python +from dimos.msgs.geometry_msgs import Point, Pose as DimOSPose + +q.filter_near(DimOSPose(1, 2, 3), radius=5.0) +q.filter_near(Point(1, 2, 3), radius=5.0) +q.filter_near(Pose(xyz=(1, 2, 3)), radius=5.0) +``` + +### ObservationRef identity + +`id` is a UUID4 string generated on `append()`. Never reuse timestamps as identity. + +### Unlocalized observations + +Rows without pose are NOT inserted into R*Tree. `filter_near()` excludes them by default. `include_unlocalized=True` bypasses R*Tree and scans meta table. 
+ +### Separate payload table + +Payload BLOBs live in `{name}_payload`, separate from `{name}_meta`. This ensures queries (which only touch meta + indexes) never page in multi-MB image blobs. + +## Existing Code to Reuse + +- `dimos/memory/timeseries/sqlite.py:29` — `_validate_identifier()` regex pattern +- `dimos/msgs/geometry_msgs/Pose.py` — DimOS Pose type, `PoseLike` type alias +- `dimos/msgs/geometry_msgs/Point.py` — Point type +- `dimos/core/resource.py` — Resource ABC (start/stop/dispose) + +## Verification + +1. `uv run pytest dimos/memory2/test_memory2.py -v` — all tests pass +2. `uv run mypy dimos/memory2/` — type checks clean +3. `uv run pytest dimos/memory/timeseries/test_base.py -v` — existing tests untouched + +### Test scenarios (map to spec §16 acceptance examples) + +- Re-query narrowed data: `filter_time → fetch_set → query → filter_near → fetch_set` +- fetch_set does not load payloads: verify no BLOB reads until explicit `load()` +- Embedding search: `search_embedding → filter_time → limit → fetch_set` → ref-table backed +- Projection: `emb_matches.project_to(images)` → fetch page → load_many +- Paginated preview: `fetch_page(limit=24, offset=0)` returns ObservationRows +- Unlocalized exclusion: rows without pose excluded from `filter_near` by default From 5c23e89dc68f2ec5cacc968a75aed1e200f76f01 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 11:51:37 +0800 Subject: [PATCH 002/118] spec iteration --- plans/analysis.md | 478 ++++++++++++++++++++ plans/answers.md | 853 ++++++++++++++++++++++++++++++++++++ plans/answers_correlator.md | 285 ++++++++++++ plans/correlator.md | 225 ++++++++++ plans/memory.md | 3 +- plans/memory1.md | 1 - plans/memory3.md | 172 ++++---- plans/questions.md | 54 +++ 8 files changed, 1982 insertions(+), 89 deletions(-) create mode 100644 plans/analysis.md create mode 100644 plans/answers.md create mode 100644 plans/answers_correlator.md create mode 100644 plans/correlator.md create mode 100644 
plans/questions.md diff --git a/plans/analysis.md b/plans/analysis.md new file mode 100644 index 0000000000..39f9af4b9c --- /dev/null +++ b/plans/analysis.md @@ -0,0 +1,478 @@ +# Analysis Utilities + +Application-level analysis on Memory2 query results. NOT part of memory2 core — operates on fetched `ObservationRow` lists, no SQLite dependency. + +Location: `dimos/memory2/analysis.py` + +Dependencies: only `dimos/memory2/types.py` (ObservationRow, ObservationRef). No numpy, no sklearn. + +--- + +## 1. `cluster_observations()` + +The most common post-query pattern across Q2, Q4, Q5, Q9, Q11, Q12, Q14. + +```python +@dataclass +class Cluster: + rows: list[ObservationRow] + representative: ObservationRow # best by rank_key + + @property + def t_start(self) -> float: + return self.rows[0].ts_start + + @property + def t_end(self) -> float: + return self.rows[-1].ts_start + + @property + def duration(self) -> float: + return self.t_end - self.t_start + + @property + def center_pose(self) -> PoseLike | None: + """Average position of all localized rows.""" + ... + + +def cluster_observations( + rows: list[ObservationRow], + *, + time_scale: float | None = None, + space_scale: float | None = None, + threshold: float = 1.0, + rank_key: Callable[[ObservationRow], float] | None = None, +) -> list[Cluster]: + """Greedy sequential clustering over time and/or space. + + Distance between consecutive rows (must be sorted by ts_start): + + d = sqrt((dt/time_scale)^2 + (ds/space_scale)^2) + + New cluster starts when d > threshold. + + Args: + rows: ObservationRows, sorted by ts_start. + time_scale: Normalize temporal gap (seconds). None = ignore time. + space_scale: Normalize spatial distance (meters). None = ignore space. + threshold: Combined normalized distance to split clusters. + rank_key: Scoring function for representative selection. + Default: embedding score, then recency. + + Returns: + List of Cluster objects, each with .rows and .representative. 
+ """ +``` + +### Modes + +```python +# Temporal only: split if gap > 10s +clusters = cluster_observations(rows, time_scale=10.0) + +# Spatial only: split if > 3m apart +clusters = cluster_observations(rows, space_scale=3.0) + +# Combined: either 10s gap OR 3m apart triggers split +clusters = cluster_observations(rows, time_scale=10.0, space_scale=3.0) + +# Bias toward spatial (space matters more): +clusters = cluster_observations(rows, time_scale=30.0, space_scale=2.0) +``` + +### Representative selection + +Default `rank_key`: `lambda r: r.scores.get("embedding", 0)` — picks the most relevant frame after a search. Override for quality-based selection: + +```python +# Quality-biased: prefer sharp, well-exposed frames +clusters = cluster_observations(rows, + time_scale=10.0, + rank_key=lambda r: ( + r.scores.get("embedding", 0) * 0.4 + + r.tags.get("sharpness", 0.5) * 0.4 + + r.tags.get("exposure", 0.5) * 0.2 + ), +) + +# Recency-biased: prefer the latest frame in each cluster +clusters = cluster_observations(rows, + time_scale=10.0, + rank_key=lambda r: r.ts_start, +) +``` + +### Which questions use this + +| Question | Mode | Purpose | +|----------|------|---------| +| Q2 — red socks viewing sessions | temporal | Group continuous sightings, VLM one per cluster | +| Q4 — where were red socks | spatial | Group nearby sightings into distinct locations | +| Q5 — door open events | temporal | Group rapid-fire "door open" detections into single events | +| Q9 — cat trail | spatial | Group into distinct locations the cat visited | +| Q11 — cat absence | temporal | (indirect — use `find_gaps` on clusters) | +| Q12 — mailman schedule | temporal | Group same-visit detections into single arrival events | +| Q14 — carrying intervals | temporal | Group "carrying" detections into continuous intervals | + +--- + +## 2. `find_gaps()` + +Find periods where observations are absent. Used in Q11 (cat absence) and Q14 (carrying interval boundaries). 
+ +```python +@dataclass +class Gap: + t_start: float # timestamp of last observation before the gap + t_end: float # timestamp of first observation after the gap + duration: float # t_end - t_start + + +def find_gaps( + rows: list[ObservationRow], + *, + min_gap: float, +) -> list[Gap]: + """Find temporal gaps in a sorted observation list. + + Args: + rows: ObservationRows, sorted by ts_start. + min_gap: Minimum gap duration (seconds) to report. + + Returns: + List of Gap objects, sorted by time. + """ +``` + +Usage: + +```python +# Q11: When was the cat last NOT seen? +cat_seen = detections.query().filter_tags(class_name="cat").order_by("ts_start").fetch() +gaps = find_gaps(cat_seen, min_gap=60.0) +if gaps: + print(f"Last absence: {gaps[-1].t_start} to {gaps[-1].t_end}") +``` + +Works on clusters too — find gaps between cluster end and next cluster start: + +```python +# Gaps between sighting sessions (not between individual frames) +clusters = cluster_observations(cat_seen, time_scale=10.0) +# Synthesize one row per cluster (the representative) for gap analysis +cluster_reps = [c.representative for c in clusters] +session_gaps = find_gaps(cluster_reps, min_gap=300.0) +``` + +--- + +## 3. `compute_path_distance()` + +Sum of Euclidean distances along a pose trail. Used in Q9 (cat trail length) and Q14 (distance while carrying). + +```python +def compute_path_distance( + rows: list[ObservationRow], +) -> float: + """Total Euclidean path distance from consecutive poses. + + Args: + rows: ObservationRows with poses, sorted by ts_start. + Rows without pose are skipped. + + Returns: + Total distance in meters. + """ +``` + +Usage: + +```python +# Q14: How far did I travel while carrying? +for cluster in carrying_clusters: + pose_rows = poses.query().filter_time(cluster.t_start, cluster.t_end).order_by("ts_start").fetch() + dist = compute_path_distance(pose_rows) + print(f"Carried for {cluster.duration:.0f}s, traveled {dist:.1f}m") +``` + +--- + +## 4. 
`extract_time_pattern()` + +Extract time-of-day statistics from observations spread across multiple days. Used in Q12 (mailman schedule). + +```python +@dataclass +class TimePattern: + mean_hour: float # e.g. 10.5 = 10:30 AM + std_minutes: float # standard deviation in minutes + count: int # number of observations + times: list[float] # individual hours (for histogram) + + def __str__(self) -> str: + h = int(self.mean_hour) + m = int((self.mean_hour % 1) * 60) + return f"{h}:{m:02d} +/- {self.std_minutes:.0f}min (n={self.count})" + + +def extract_time_pattern( + rows: list[ObservationRow], + *, + tz: timezone | None = None, +) -> TimePattern: + """Extract time-of-day pattern from observations across multiple days. + + Best used on cluster representatives (one per event) rather than raw rows, + to avoid dense clusters biasing the average. + + Args: + rows: ObservationRows with ts_start. + tz: Timezone for time-of-day extraction. Default: UTC. + + Returns: + TimePattern with mean, std, and individual times. + """ +``` + +Usage: + +```python +# Q12: When does the mailman usually come? +sightings = faces.query().search_embedding(mailman_emb, candidate_k=100).fetch() +sightings = [r for r in sightings if r.scores.get("embedding", 0) > 0.8] + +# Cluster into individual visits (one per day) +visits = cluster_observations(sightings, time_scale=300.0) +pattern = extract_time_pattern([v.representative for v in visits]) +print(f"Mailman comes at {pattern}") # "10:30 +/- 12min (n=23)" +``` + +--- + +## 5. `match_viewpoints()` + +Match observations from two sets by embedding similarity — find corresponding views across time. Used in Q8 (room diff: today vs yesterday). 
+ +```python +@dataclass +class ViewpointMatch: + current: ObservationRow + reference: ObservationRow + similarity: float + + +def match_viewpoints( + current: list[ObservationRow], + reference: list[ObservationRow], + vectors_current: list[list[float]], + vectors_reference: list[list[float]], + *, + min_similarity: float = 0.85, +) -> list[ViewpointMatch]: + """Match observations by embedding similarity (cosine via dot product). + + Assumes vectors are L2-normalized (as CLIP embeddings are). + + Pure Python — no numpy required (but callers may use numpy for + batch vector retrieval before calling this). + + Args: + current: ObservationRows from the "current" time window. + reference: ObservationRows from the "reference" time window. + vectors_current: Embedding vectors for current rows. + vectors_reference: Embedding vectors for reference rows. + min_similarity: Minimum cosine similarity for a valid match. + + Returns: + List of ViewpointMatch objects, one per matched current row. + Unmatched rows are excluded. + """ +``` + +Usage: + +```python +# Q8: What changed in this room vs yesterday? +current_imgs = images.query().filter_time(now - 300, now).filter_near(pose, radius=5.0).fetch() +yesterday_imgs = images.query().filter_time(yest - 300, yest + 300).filter_near(pose, radius=5.0).fetch() + +current_vecs = [images.vector(r.ref) for r in current_imgs] +yesterday_vecs = [images.vector(r.ref) for r in yesterday_imgs] + +matches = match_viewpoints(current_imgs, yesterday_imgs, current_vecs, yesterday_vecs) +for m in matches: + diff = vlm.ask([images.load(m.current.ref), images.load(m.reference.ref)], + "What changed between these two views?") +``` + +Note: this is O(n*m) dot products. Fine for typical sizes (tens to low hundreds of images per spatial query). For very large sets, callers can use numpy directly. + +--- + +## 6. `diff_observation_sets()` + +Find observations in set A that have no similar match in set B. Used in Q13 (cross-robot diff). 
+ +```python +@dataclass +class UnmatchedObservation: + row: ObservationRow + best_similarity: float # highest similarity to anything in the other set + + +def diff_observation_sets( + source: list[ObservationRow], + reference: list[ObservationRow], + vectors_source: list[list[float]], + vectors_reference: list[list[float]], + *, + similarity_threshold: float = 0.7, +) -> list[UnmatchedObservation]: + """Find observations in source that have no close match in reference. + + Args: + source: Observations to check ("what did robot-2 see?") + reference: Observations to compare against ("what did I see?") + vectors_source: Embeddings for source rows. + vectors_reference: Embeddings for reference rows. + similarity_threshold: Below this = "unmatched" = novel observation. + + Returns: + List of UnmatchedObservation from source with no reference match. + """ +``` + +Usage: + +```python +# Q13: What did robot-2 see that I missed? +r2 = detections.query().filter_tags(robot_id="robot-2").filter_near(warehouse, radius=20).fetch() +me = detections.query().filter_tags(robot_id="robot-1").filter_near(warehouse, radius=20).fetch() +r2_vecs = [detections.vector(r.ref) for r in r2] +me_vecs = [detections.vector(r.ref) for r in me] + +missed = diff_observation_sets(r2, me, r2_vecs, me_vecs) +for m in missed: + print(f"Missed: {m.row.tags.get('class_name')} at {m.row.pose}") +``` + +--- + +## Quality Conventions + +Image quality metrics are stored in tags at ingest time by the pipeline. The analysis utilities don't compute quality — they consume it via `rank_key`. 
+ +### Recommended tag keys + +| Tag | Type | Description | Range | +|-----|------|-------------|-------| +| `sharpness` | float | Laplacian variance of grayscale image | 0.0–1.0 (normalized) | +| `blur` | float | Inverse of sharpness (lower = sharper) | 0.0–1.0 | +| `exposure` | float | How well-exposed (0 = dark/blown out, 1 = good) | 0.0–1.0 | +| `occlusion` | float | Fraction of frame occluded | 0.0–1.0 | + +### Pipeline example + +```python +def compute_quality(frame) -> dict: + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + lap_var = cv2.Laplacian(gray, cv2.CV_64F).var() + sharpness = min(lap_var / 500.0, 1.0) # normalize + mean_brightness = gray.mean() / 255.0 + exposure = 1.0 - abs(mean_brightness - 0.45) * 2 # penalize too dark/bright + return {"sharpness": round(sharpness, 3), "exposure": round(max(exposure, 0), 3)} + +images.append(frame, + pose=robot_pose, + tags=compute_quality(frame), +) +``` + +Then in analysis: + +```python +clusters = cluster_observations(candidates, + time_scale=10.0, + rank_key=lambda r: ( + r.scores.get("embedding", 0) * 0.5 + + r.tags.get("sharpness", 0.5) * 0.3 + + r.tags.get("exposure", 0.5) * 0.2 + ), +) +``` + +--- + +## The "Embed → Cluster → VLM" Pipeline + +The dominant analysis pattern across Q2, Q4, Q5, Q8. Not a function — a recipe. + +``` +1. Embedding search → candidate_k rows (cheap, fast, noisy) +2. Score filter → discard low-similarity noise +3. cluster_observations → group into distinct events/locations +4. Representative pick → best frame per cluster (by quality + relevance) +5. VLM verify → confirm/describe each representative (expensive, precise) +6. 
Expand → confirmed representative → entire cluster is valid +``` + +This is NOT worth wrapping in a single function because: +- The VLM prompt varies per question +- The cluster parameters vary per domain +- The expand step varies (sometimes you want all rows, sometimes just the cluster metadata) +- Steps 1-4 compose naturally with existing tools + +But documenting it as a pattern means every new question follows the same structure. + +### Example: complete pipeline for Q2 + +```python +# 1. Embedding search +candidates = images.query().search_embedding(clip_text_encode("red socks"), candidate_k=1000).order_by("ts_start").fetch() + +# 2. Score filter +candidates = [r for r in candidates if r.scores.get("embedding", 0) > 0.7] + +# 3. Cluster (temporal — group continuous viewing) +clusters = cluster_observations(candidates, + time_scale=10.0, + rank_key=lambda r: ( + r.scores.get("embedding", 0) * 0.5 + + r.tags.get("sharpness", 0.5) * 0.5 + ), +) + +# 4. VLM verify representatives only +confirmed = [] +for c in clusters: + img = images.load(c.representative.ref) + if vlm.ask(img, "Are there red socks in this image? yes/no") == "yes": + confirmed.append(c) + +# 5. Use results +print(f"Currently watching for {confirmed[-1].duration:.0f}s") +print(f"Seen {len(confirmed) - 1} time(s) before") +``` + +--- + +## Summary + +| Utility | Pure Python | Used in | Core purpose | +|---------|-------------|---------|-------------| +| `cluster_observations` | yes | Q2,Q4,Q5,Q9,Q11,Q12,Q14 | Group by time/space, pick representative | +| `find_gaps` | yes | Q11 | Detect absence periods | +| `compute_path_distance` | yes | Q9,Q14 | Trajectory length | +| `extract_time_pattern` | yes | Q12 | Time-of-day statistics | +| `match_viewpoints` | yes | Q8 | Cross-temporal view matching | +| `diff_observation_sets` | yes | Q13 | Set difference by embedding similarity | + +All utilities are stateless functions on `list[ObservationRow]`. 
No DB access, no numpy dependency (callers use numpy for batch vector ops if they want). Quality metrics live in tags, set by the ingest pipeline. + +### Not included (stays in application code) + +- **Identity clustering** (Q3, Q6, Q7): Requires DBSCAN/sklearn + domain-specific parameters. Too varied for a generic utility. +- **State transition detection** (Q5): "door went from closed→open" needs domain knowledge about what states exist. +- **Absence reasoning** (Q11): Distinguishing "cat not here" from "robot not looking" requires cross-referencing robot coverage — application context. +- **VLM prompting**: Every question has different prompts and response parsing. diff --git a/plans/answers.md b/plans/answers.md new file mode 100644 index 0000000000..e5cb509ee0 --- /dev/null +++ b/plans/answers.md @@ -0,0 +1,853 @@ +# Answers + +API reference: `memory3.md` (current) + +--- + +## 1. "Where was I, when this log line was added?" + "Where do motor faults keep happening?" + +**Streams**: `logs` (text-capable), `poses` (robot localization at high frequency) + +**Single log line**: + +```python +s = db.session() +logs = s.stream("logs", LogMsg, text=TextConfig()) +poses = s.stream("poses", PoseStamped) + +# Find the log entry by text +log_hit = logs.query().search_text("motor fault detected").one() + +# Look up pose at that time — .at() finds nearest within tolerance +pose_hit = poses.query().at(log_hit.ts_start, tolerance=0.5).one() +print(pose_hit.pose) # Pose(x=1.2, y=3.4, z=0.5) +``` + +**Multiple log lines → spatial map of faults**: + +```python +fault_logs = logs.query().search_text("motor fault").order_by("ts_start").fetch() + +# Correlate each to a pose +fault_locations = [] +for log_row in fault_logs: + pose_row = poses.query().at(log_row.ts_start, tolerance=0.5).fetch() + if pose_row: + fault_locations.append((log_row, pose_row[0])) + +# Cluster by location — "where do faults keep happening?" 
+from dimos.memory2.analysis import cluster_observations +location_clusters = cluster_observations( + [pose for _, pose in fault_locations], + space_scale=2.0, # within 2m = same spot +) + +for c in location_clusters: + print(f"{len(c.rows)} faults near {c.center_pose} " + f"({c.t_start} to {c.t_end})") + # → "12 faults near Pose(x=3.1, y=7.2) over the last 3 days" + +# Render on costmap +for c in location_clusters: + costmap.mark(pose=c.center_pose, label=f"motor faults ({len(c.rows)}x)") +``` + +**What works**: `.search_text()` finds all matching logs, `.at()` correlates each to a pose, `cluster_observations(space_scale=)` groups faults by location. The result is a heatmap of where the robot has trouble. + +**Cross-stream join**: The for-loop is the same nested-loop join pattern as Q5/Q7/Q14. `Correlator` (Phase 3) would batch this: +```python +fault_poses = s.correlate(fault_logs_set, poses, time_tolerance=0.5) +``` + +--- + +## 2. "How long have I been observing the red socks in view currently?" + "How many times did I see them before?" + +**Streams**: `images` (camera frames with CLIP embeddings and poses) + +No detection pipeline — we search raw images by embedding similarity, then VLM-verify. + +**Stage 1 — Embedding candidate retrieval**: + +```python +s = db.session() +images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512)) + +socks_embedding = clip_text_encode("red socks") + +# Find all frames that might contain red socks +candidates = (images.query() + .search_embedding(socks_embedding, candidate_k=1000) + .order_by("ts_start") + .fetch()) + +# Post-filter by similarity score to discard weak matches +candidates = [h for h in candidates if h.scores.get("embedding", 0) > 0.7] +``` + +**Stage 2 — Diverse sampling before VLM** (don't waste VLM on 200 frames of the robot staring at socks): + +The embedding top-k will cluster heavily around moments of prolonged viewing. 
We need to spread VLM budget across time and space to discover all distinct sighting sessions. + +```python +# Cluster candidates into temporal segments (frames within 10s = same cluster) +# Then pick one representative per cluster for VLM +candidates.sort(key=lambda r: r.ts_start) + +clusters = [] # list of lists +for row in candidates: + if not clusters or row.ts_start - clusters[-1][-1].ts_start > 10.0: + clusters.append([row]) + else: + clusters[-1].append(row) + +# Pick the highest-scoring representative from each cluster +representatives = [] +for cluster in clusters: + best = max(cluster, key=lambda r: r.scores.get("embedding", 0)) + representatives.append((best, cluster)) +``` + +Now VLM verifies only the representatives — one call per temporal cluster, not per frame: + +```python +confirmed_segments = [] +for rep, cluster in representatives: + img = images.load(rep.ref) + if vlm.ask(img, "Are there red socks visible in this image? yes/no") == "yes": + # Entire cluster counts as a sighting session + confirmed_segments.append((cluster[0].ts_start, cluster[-1].ts_start)) +``` + +If the robot saw socks 5 different times across the day but stared for minutes each time, this makes ~5 VLM calls instead of 200+. + +**Stage 3 — Answer the question**: + +```python +now = time.time() + +# Current viewing session = last confirmed segment +if confirmed_segments: + current_duration = now - confirmed_segments[-1][0] + print(f"Watching red socks for {current_duration:.1f}s") + print(f"Seen them {len(confirmed_segments) - 1} time(s) before") +``` + +**What works**: Embedding search is the broad net (cheap, fast), temporal clustering deduplicates the "staring" problem, VLM confirms only one frame per cluster. Scales to long sessions without blowing VLM budget. + +**What's application logic**: Cluster gap threshold (10s), VLM prompt, what counts as "same sighting" — all domain-specific. + +**Limitation**: `candidate_k=1000` is a guess. 
sqlite-vec is KNN-only — no "all vectors above threshold" query. Workaround: use a large candidate_k and post-filter by score. + +**Extension — spatial diversity**: If the robot revisits the same spot repeatedly, add pose-based deduplication within temporal clusters. But temporal clustering alone handles the dominant case (continuous staring). + +--- + +## 3. "How many people did I see during last week?" + +**Pipeline**: +``` +camera frames → face detector → face crops → embedding model → face embeddings + ↓ + faces stream (each row = one detected face with identity embedding) +``` + +Yes — the `faces` stream stores detected face crops. Each append includes the face embedding. Searching over this stream by embedding finds the same face across time. + +**Streams**: `faces` (face crops with identity embeddings) + +```python +s = db.session() +faces = s.stream("faces", FaceCrop, embedding=EmbeddingConfig(dim=512)) + +one_week_ago = time.time() - 7 * 86400 +week_faces = (faces.query() + .filter_after(one_week_ago) + .order_by("ts_start") + .fetch()) + +# Get all embedding vectors for clustering +vectors = [] +for row in week_faces: + vec = faces.vector(row.ref) # retrieve stored embedding + vectors.append(vec) + +# Cluster to find unique identities +import numpy as np +from sklearn.cluster import DBSCAN + +X = np.array(vectors) +clustering = DBSCAN(eps=0.6, min_samples=2, metric="cosine").fit(X) +n_people = len(set(clustering.labels_)) - (1 if -1 in clustering.labels_ else 0) +print(f"Saw {n_people} unique people last week") +``` + +**What works**: `filter_after` for time range, `faces.vector(ref)` to retrieve stored embeddings for clustering. + +**What's application logic**: Identity clustering (DBSCAN, threshold tuning) is domain-specific — different robots may have different accuracy needs. + +**With derive() (Phase 3)**: Could automate the dedup into a persistent `people` stream, then it's just `.count()`. + +--- + +## 4. 
"Where did you see red socks during last week?" + +**Streams**: `images` (camera frames with CLIP embeddings and poses) + +**Stage 1 — Embedding candidate retrieval**: + +```python +s = db.session() +images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512)) + +one_week_ago = time.time() - 7 * 86400 +socks_embedding = clip_text_encode("red socks") + +candidates = (images.query() + .search_embedding(socks_embedding, candidate_k=200) + .filter_after(one_week_ago) + .limit(50) + .fetch_set()) +``` + +**Stage 2 — VLM verification**: + +```python +verified_refs = [] +for row in candidates.rows(): + img = candidates.load(row.ref) + if vlm.ask(img, "Are there red socks in this image? yes/no") == "yes": + verified_refs.append(row.ref) + +# Wrap verified results back into an ObservationSet +verified = images.query().filter_refs(verified_refs).fetch_set() +``` + +`filter_refs()` gives us an ObservationSet of just the verified images — ephemeral, session-scoped. + +To persist: write to a new stream with lineage back to the originals: + +```python +red_socks = s.stream("red_socks", Image) +for ref in verified_refs: + src = images.meta(ref) + red_socks.append( + images.load(ref), + pose=src.pose, ts_start=src.ts_start, + tags={"query": "red socks"}, + parent_stream="images", parent_id=ref.id, + ) +``` + +**Stage 3 — Costmap**: + +```python +for row in verified.rows(): + costmap.mark(pose=row.pose, label="red socks", time=row.ts_start) +``` + +Every verified observation carries the robot's pose from the original image stream → direct costmap placement. + +--- + +## 5. "Did anyone ever open this door? At what times? Who opened it?" + +**Streams**: `detections` (object detections with tags), `faces` (face crops with identity embeddings) + +**Sub-question 1 & 2 — When was the door open?** + +Depends on the detection pipeline. 
If the detector tags door state:

```python
detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512))

door_open = (detections.query()
             .filter_tags(class_name="door", state="open")
             .order_by("ts_start")
             .fetch())

for row in door_open:
    print(f"Door open at {row.ts_start}")
```

If the detector doesn't tag state — embedding search + VLM verify (same pattern as Q4):

```python
images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512))

open_door_emb = clip_text_encode("open door")
candidates = (images.query()
              .search_embedding(open_door_emb, candidate_k=100)
              .filter_near(door_location, radius=3.0)  # only images near the door
              .fetch())

# VLM verify each candidate
open_times = [r for r in candidates
              if vlm.ask(images.load(r.ref), "Is this door open?") == "yes"]
```

**Sub-question 3 — Who opened it?**

Cross-stream temporal+spatial correlation: for each door-open event, find faces nearby at that time.

```python
faces = s.stream("faces", FaceCrop, embedding=EmbeddingConfig(dim=512))

for event in open_times:
    # Find faces near the door around the time it opened
    nearby = (faces.query()
              .filter_time(event.ts_start - 5.0, event.ts_start + 2.0)
              .filter_near(event.pose, radius=3.0)
              .fetch())

    if nearby:
        # Identify the person via face embedding
        vec = faces.vector(nearby[0].ref)
        identity = lookup_identity(vec)  # match against known faces DB
        print(f"Door opened at {event.ts_start} by {identity}")
```

**What works**: `filter_time` + `filter_near` compose naturally for "who was here when this happened". The R*Tree + ts_start index handle this efficiently.

**What's manual**: The for-loop is a nested-loop join. `Correlator` (Phase 3) would batch this:
```python
# Phase 3:
s.correlate(door_events, faces, time_tolerance=5.0, spatial_radius=3.0)
```

**State transition detection** ("door went from closed→open") is application logic. The memory system stores observations, not state machines.
You'd either store explicit events in a `door_events` stream, or detect transitions by comparing consecutive detections. + +--- + +## 6. "I have a transcription log (STT) and voice embeddings — how do I figure out who is saying what?" + +**Streams**: `transcripts` (STT output, text-capable), `voice_embeddings` (speaker embeddings per audio segment) + +Two separate streams because they come from different models: STT gives you text, a speaker encoder gives you a voice identity vector. + +```python +s = db.session() +transcripts = s.stream("transcripts", Transcript, text=TextConfig()) +voice_embs = s.stream("voice_segments", VoiceSegment, embedding=EmbeddingConfig(dim=192)) +``` + +**Step 1 — Align transcripts to voice segments by time**: + +Each transcript has `ts_start`/`ts_end` (when the words were spoken). Each voice segment has a speaker embedding for that time window. + +```python +for tx_row in transcripts.query().order_by("ts_start").fetch(): + # Find the voice segment that overlaps this transcript + voice = (voice_embs.query() + .filter_time(tx_row.ts_start, tx_row.ts_end) + .one()) + + # voice.ref → voice_embs.vector(voice.ref) gives us the speaker embedding + speaker_vec = voice_embs.vector(voice.ref) + transcript_text = transcripts.load(tx_row.ref).text + + print(f"[{speaker_vec_to_name(speaker_vec)}]: {transcript_text}") +``` + +**Step 2 — Build speaker identity mapping**: + +Cluster all voice embeddings to find distinct speakers, then label: + +```python +all_voices = voice_embs.query().order_by("ts_start").fetch() +vectors = [voice_embs.vector(r.ref) for r in all_voices] + +# Cluster into distinct speakers +clustering = DBSCAN(eps=0.3, min_samples=3, metric="cosine").fit(np.array(vectors)) +# label_id → speaker name mapping (manual or via face correlation — see Q7) +``` + +**What works**: `filter_time` on voice stream using transcript's time window is the natural join key. `.vector()` retrieves stored embeddings for clustering. 
+ +**Key insight**: The two streams are aligned by time, not by embedding similarity. We don't search by embedding across streams — we use temporal co-occurrence to pair them, then use the voice embedding for speaker identity. + +--- + +## 7. "I have parallel voice and facial recognition streams — how do I correlate voice to people? (I don't see all people speaking at all times)" + +**Streams**: `voices` (speaker embeddings per audio segment), `faces` (face identity embeddings per detection) + +The constraint "I don't see all people speaking at all times" means: +- Sometimes a person is speaking but out of camera view → voice segment exists, no face match +- Sometimes multiple people are visible but only one is speaking +- The correlation must be probabilistic, accumulated over time + +```python +s = db.session() +voices = s.stream("voices", VoiceSegment, embedding=EmbeddingConfig(dim=192)) +faces = s.stream("faces", FaceCrop, embedding=EmbeddingConfig(dim=512)) +``` + +**Step 1 — Collect unambiguous pairings** (only one face visible while voice active): + +```python +pairings = [] # (voice_embedding, face_embedding) pairs + +for v_row in voices.query().order_by("ts_start").fetch(): + # Find faces visible during this voice segment + visible_faces = (faces.query() + .filter_time(v_row.ts_start, v_row.ts_end) + .fetch()) + + if len(visible_faces) == 1: + # Unambiguous: only one person visible → must be the speaker + voice_vec = voices.vector(v_row.ref) + face_vec = faces.vector(visible_faces[0].ref) + pairings.append((voice_vec, face_vec)) +``` + +**Step 2 — Build cross-modal identity mapping**: + +```python +# Cluster voice embeddings → speaker IDs +voice_vecs = np.array([p[0] for p in pairings]) +voice_clusters = DBSCAN(eps=0.3, min_samples=2, metric="cosine").fit(voice_vecs) + +# For each voice cluster, find the most common face cluster +# This gives us: voice_speaker_id → face_identity +speaker_to_face = {} +for cluster_id in set(voice_clusters.labels_): + if 
cluster_id == -1: + continue + cluster_face_vecs = [p[1] for i, p in enumerate(pairings) + if voice_clusters.labels_[i] == cluster_id] + # Majority vote on face identity + face_identity = identify_majority(cluster_face_vecs) + speaker_to_face[cluster_id] = face_identity +``` + +**Step 3 — Label all voice segments** (including ambiguous ones): + +```python +for v_row in voices.query().order_by("ts_start").fetch(): + voice_vec = voices.vector(v_row.ref) + # Find nearest voice cluster → mapped face identity + speaker_id = predict_cluster(voice_vec, voice_clusters) + person = speaker_to_face.get(speaker_id, "unknown") + print(f"[{person}] spoke at {v_row.ts_start}") +``` + +**What works**: +- `filter_time` on faces using voice segment's time window — natural temporal join +- `.vector()` on both streams for cross-modal clustering +- The API provides the building blocks; the correlation logic (accumulate unambiguous pairings → build mapping → apply to ambiguous cases) is correctly application-level + +**What the constraint exposes**: "I don't see all people speaking at all times" means we can't rely on a single observation to establish identity. We need statistical accumulation — many unambiguous pairings build confidence. This is fundamentally a learning problem, not a query problem. The memory system's job is to make the data accessible; the correlation intelligence lives above. + +**With Correlator (Phase 3)**: The inner loop (for each voice segment, query faces) would become: +```python +pairs = s.correlate(voices, faces, time_tolerance=0.5) +``` +But the clustering/identity-mapping step still lives in application code. + +--- + +## 8. "What's different in this room compared to yesterday?" + +**What we need**: Compare object detections from "now" vs "yesterday" at the same location, find what changed. 
+ +**Streams**: `images` (camera frames with CLIP embeddings and poses) + +We can't rely on a precomputed detection stream — object detection for a fixed set is expensive and not run in realtime. Instead, store raw images and diff at query time using embeddings + VLM. + +```python +s = db.session() +images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512)) + +now = time.time() +yesterday = now - 86400 +robot_pose = get_current_pose() + +# Two queries: images from this room now vs yesterday +current_imgs = (images.query() + .filter_time(now - 300, now) + .filter_near(robot_pose, radius=5.0) + .fetch()) + +yesterday_imgs = (images.query() + .filter_time(yesterday - 300, yesterday + 300) + .filter_near(robot_pose, radius=5.0) + .fetch()) + +# Match viewpoints by embedding similarity (numpy, no extra queries) +current_vecs = np.array([images.vector(r.ref) for r in current_imgs]) +yesterday_vecs = np.array([images.vector(r.ref) for r in yesterday_imgs]) +similarity = current_vecs @ yesterday_vecs.T + +# Pair each current image with its closest yesterday viewpoint +pairs = [] +for i, row in enumerate(current_imgs): + j = similarity[i].argmax() + if similarity[i, j] > 0.85: # same viewpoint + pairs.append((row, yesterday_imgs[j])) + +# VLM diffs only matched viewpoint pairs +for curr, yest in pairs: + diff = vlm.ask( + [images.load(curr.ref), images.load(yest.ref)], + "What changed between these two views?") + if diff != "nothing": + print(f"At {curr.pose}: {diff}") +``` + +**What works**: Two queries retrieve the two temporal snapshots scoped to this room. Embedding similarity in numpy matches viewpoints without extra DB queries. VLM provides open-vocabulary scene comparison — no fixed object set needed. + +**What's application logic**: Viewpoint matching threshold, VLM prompting, what counts as a meaningful change. The memory system provides spatial+temporal retrieval; the VLM provides the intelligence. 
+ +**Cost structure**: 2 DB queries + N `.vector()` reads (small, fast) + numpy matmul + M VLM calls (expensive, but only on matched pairs). + +--- + +## 9. "Show me everywhere the cat went today" + +**What we need**: Retrieve all cat detections from today, extract the pose trail, render as a path on the costmap. + +**Streams**: `detections` (object detections with poses) + +```python +s = db.session() +detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) + +today_start = start_of_day() + +# All cat detections today, ordered by time +cat_trail = (detections.query() + .filter_tags(class_name="cat") + .filter_after(today_start) + .order_by("ts_start") + .fetch()) + +# Extract the pose path +path = [(row.ts_start, row.pose) for row in cat_trail if row.pose] + +# Render on costmap +for ts, pose in path: + costmap.add_point(pose=pose, time=ts, label="cat") +costmap.draw_path([pose for _, pose in path]) +``` + +**What works**: `filter_tags(class_name="cat")` + `filter_after()` + `order_by("ts_start")` is clean and direct. Every detection carries the robot's pose → we know where the robot saw the cat, which approximates where the cat was. + +**Subtlety**: `row.pose` is the *robot's* pose when it detected the cat, not the cat's position in world frame. If you need the cat's actual position, you'd need the detection bounding box + depth + robot pose to project into world coordinates. That projection would happen in the detection pipeline before appending to the stream: + +```python +# In the detection pipeline: +cat_world_pose = project_to_world(bbox, depth_frame, robot_pose) +detections.append(detection, pose=cat_world_pose, tags={"class_name": "cat"}) +``` + +If stored this way, `row.pose` *is* the cat's world position, and the path is accurate. + +**Dense vs sparse**: If the detector runs at 5Hz and the cat is visible for an hour, that's 18,000 rows. `order_by("ts_start")` + the ts_start index handles this efficiently. 
For rendering, you might want to downsample:

```python
# Fetch pages to avoid loading all 18k rows at once.
# Query.limit(k) has no offset — paginate via fetch_set + fetch_page.
cat_trail_set = (detections.query()
                 .filter_tags(class_name="cat")
                 .filter_after(today_start)
                 .order_by("ts_start")
                 .fetch_set())

offset = 0
while True:
    rows = cat_trail_set.fetch_page(limit=100, offset=offset)
    if not rows:
        break
    costmap.draw_path([r.pose for r in rows if r.pose])
    offset += 100
```

**Gap exposed**: `Query.limit(k)` has no `offset` — that's why the pagination above drops down to `fetch_set()` + `fetch_page(limit=100, offset=N)`. This works but means you can't paginate purely at the query level.

---

## 10. "What happened in the 30 seconds before the vase fell?"

**What we need**: Detect the "vase fell" event, then slice ALL streams in a 30s window before it.

**Streams**: `events` (detected events with tags), plus any number of other streams: `images`, `audio`, `detections`, `poses`, etc.

```python
s = db.session()
events = s.stream("events", Event, text=TextConfig())

# Find the vase-fall event
vase_event = events.query().search_text("vase fell").one()
t_event = vase_event.ts_start

# Now query every stream for the 30s window before the event
# list_streams() returns StreamInfo with payload_type, configs, count
timeline = {}
for info in s.list_streams():
    stream = s.stream(info.name, info.payload_type,
                      embedding=info.embedding, text=info.text)
    window = (stream.query()
              .filter_time(t_event - 30.0, t_event)
              .order_by("ts_start")
              .fetch())
    timeline[info.name] = window
```

**What works**: `list_streams()` returns `StreamInfo` with everything needed to reconstruct stream handles — no hardcoding payload types. `filter_time(t - 30, t)` on each stream gives the pre-event window.

---

## 11. "When was the last time I did NOT see the cat in the apartment?"

**What we need**: Find gaps in the cat detection stream — periods where no cat was detected.
+ +**Streams**: `detections` (object detections with tags) + +```python +s = db.session() +detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) + +# Get all cat detections, ordered by time +cat_seen = (detections.query() + .filter_tags(class_name="cat") + .order_by("ts_start") + .fetch()) + +# Find gaps — periods where the cat wasn't detected +# Gap = time between consecutive cat detections longer than some threshold +gap_threshold = 60.0 # 1 minute without seeing the cat = "not seen" + +timestamps = [r.ts_start for r in cat_seen] +gaps = [] +for i in range(1, len(timestamps)): + gap = timestamps[i] - timestamps[i - 1] + if gap > gap_threshold: + gaps.append((timestamps[i - 1], timestamps[i], gap)) + +if gaps: + # Most recent gap = last time the cat wasn't seen + last_gap = gaps[-1] + print(f"Last not seen: {last_gap[0]} to {last_gap[1]} ({last_gap[2]:.0f}s)") +else: + print("Cat has been visible continuously") +``` + +**What works**: `filter_tags` + `order_by` gives us the detection timeline. Gap analysis in Python is straightforward. + +**What the API can't do natively**: Negation queries ("when did X NOT happen") aren't expressible in the query builder. You can only query for what exists, then find gaps in Python. This is fundamentally correct — the memory system stores positive observations, not the absence of observations. Detecting absence requires knowledge of when the sensor *could* have observed (was the robot even in the apartment? was the camera on?) — that's application context. + +**Edge case**: The robot wasn't always in the apartment. A "gap" might be because the robot was in another room, not because the cat wasn't there. You'd need to cross-reference with the robot's own position to distinguish "didn't see cat because cat was absent" from "didn't see cat because robot was elsewhere." + +--- + +## 12. "What time does the mailman usually come?" + +**What we need**: We don't know who the mailman is. 
We need to discover them first, then find all their appearances, then extract the schedule. + +**Streams**: `images` (camera frames with CLIP embeddings and poses), `faces` (face crops with identity embeddings) + +**Stage 1 — Find the mailman via VLM** (retroactive identification): + +We know the mailman comes to the front door. Use spatial + embedding search to find candidates, VLM to confirm. + +```python +s = db.session() +images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512)) +faces = s.stream("faces", FaceCrop, embedding=EmbeddingConfig(dim=512)) + +mailman_emb = clip_text_encode("person delivering mail at front door") + +# Search images near the front door +candidates = (images.query() + .search_embedding(mailman_emb, candidate_k=200) + .filter_near(front_door_pose, radius=5.0) + .fetch()) + +# Cluster temporally (don't VLM 200 frames of same delivery) +clusters = cluster_observations(candidates, time_scale=60.0) + +# VLM verify representatives +mailman_times = [] +for c in clusters: + img = images.load(c.representative.ref) + if vlm.ask(img, "Is there a person delivering mail or packages? yes/no") == "yes": + mailman_times.append(c) +``` + +**Stage 2 — Extract mailman embedding** (from confirmed sightings): + +Now we know *when* the mailman was there. Find their face embedding from those time windows. + +```python +# For each confirmed mailman visit, find faces near the door at that time +mailman_face_vecs = [] +for c in mailman_times: + nearby_faces = (faces.query() + .filter_time(c.t_start - 5.0, c.t_end + 5.0) + .filter_near(front_door_pose, radius=3.0) + .fetch()) + for f in nearby_faces: + mailman_face_vecs.append(faces.vector(f.ref)) + +# Average the face embeddings → stable mailman identity vector +import numpy as np +mailman_identity = np.mean(mailman_face_vecs, axis=0).tolist() +``` + +**Stage 3 — Search broadly with the discovered embedding**: + +Now we have a face embedding. 
Search ALL face data, not just near the door — catches sightings we might have missed with the spatial filter. + +```python +all_sightings = (faces.query() + .search_embedding(mailman_identity, candidate_k=200) + .fetch()) +sightings = [r for r in all_sightings if r.scores.get("embedding", 0) > 0.8] + +# Cluster into individual visits +visits = cluster_observations(sightings, time_scale=300.0) +``` + +**Stage 4 — Extract schedule**: + +```python +pattern = extract_time_pattern([v.representative for v in visits]) +print(f"Mailman comes at {pattern}") # "10:30 +/- 12min (n=23)" +``` + +**The general pattern — retroactive identification**: +1. **Describe** → CLIP text embedding + spatial constraint to narrow candidates +2. **VLM confirm** → identify positive examples (expensive, but on clustered representatives only) +3. **Extract identity embedding** → from confirmed examples, average face/object embeddings +4. **Search broadly** → use discovered embedding to find all appearances across time +5. **Analyze** → cluster, extract patterns + +This is the inverse of the usual flow (have embedding → search). Here you don't know what you're looking for until you find it via VLM, then bootstrap an embedding for broader retrieval. + +**Cross-session note**: This only works if the DB persists across days (`retention` != `"run"`). For long-term pattern analysis, use a persistent retention policy. + +--- + +## 13. "What did robot-2 observe in the warehouse that I missed?" + +**What we need**: Compare observations between two robots at the same location, find what robot-2 saw that robot-1 (me) didn't. + +**Streams**: Both robots write to the same DB (or DBs are merged). Observations carry `robot_id` in tags. 
+ +```python +s = db.session() +detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) + +# What robot-2 saw in the warehouse +robot2_saw = (detections.query() + .filter_tags(robot_id="robot-2") + .filter_near(warehouse_pose, radius=20.0) # warehouse area + .fetch()) + +# What I saw in the same area +my_saw = (detections.query() + .filter_tags(robot_id="robot-1") + .filter_near(warehouse_pose, radius=20.0) + .fetch()) + +# Diff: find objects robot-2 detected that I didn't +# By embedding — for each of robot-2's detections, check if I have a similar one +my_vecs = [detections.vector(r.ref) for r in my_saw] + +missed = [] +for r2_row in robot2_saw: + r2_vec = detections.vector(r2_row.ref) + # Check if any of my detections are similar + similarities = [cosine_sim(r2_vec, mv) for mv in my_vecs] + if not similarities or max(similarities) < 0.7: + missed.append(r2_row) + +print(f"Robot-2 saw {len(missed)} things you missed:") +for m in missed: + print(f" {m.tags.get('class_name')} at {m.pose}") +``` + +**What works**: `filter_tags(robot_id=...)` scopes to a specific robot — `robot_id` lives in the tags JSON, queried via `filter_tags`. `filter_near` scopes to a location. `.vector()` enables cross-robot embedding comparison. No special `filter_robot()` needed. + +--- + +## 14. "How far did I travel while carrying an object?" + +**What we need**: Compute path distance from the pose stream, but only during time intervals when a parallel detection stream shows "carrying object." 
+ +**Streams**: `poses` (robot poses at high frequency), `detections` (with "carrying" state) + +```python +s = db.session() +poses = s.stream("poses", PoseStamped) +detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) + +# Step 1: Find all time intervals where the robot was carrying an object +carrying = (detections.query() + .filter_tags(action="carrying") + .order_by("ts_start") + .fetch()) + +# Segment into continuous carrying intervals +intervals = [] +if carrying: + seg_start = carrying[0].ts_start + prev_t = carrying[0].ts_start + for r in carrying[1:]: + if r.ts_start - prev_t > 2.0: # gap = stopped carrying + intervals.append((seg_start, prev_t)) + seg_start = r.ts_start + prev_t = r.ts_start + intervals.append((seg_start, prev_t)) + +# Step 2: For each carrying interval, get poses and compute distance +import math +total_distance = 0.0 + +for t_start, t_end in intervals: + pose_rows = (poses.query() + .filter_time(t_start, t_end) + .order_by("ts_start") + .fetch()) + + for i in range(1, len(pose_rows)): + p1 = pose_rows[i - 1].pose + p2 = pose_rows[i].pose + dx = p2.position.x - p1.position.x + dy = p2.position.y - p1.position.y + dz = p2.position.z - p1.position.z + total_distance += math.sqrt(dx * dx + dy * dy + dz * dz) + +print(f"Traveled {total_distance:.2f}m while carrying objects") +``` + +**What works**: `filter_tags` identifies "carrying" intervals. `filter_time` + `order_by` retrieves the pose trail for each interval. Distance computation is simple Euclidean accumulation. + +**This is the cross-stream conditional join pattern**: Query stream A (detections) for intervals, then query stream B (poses) within those intervals. Same nested-loop pattern as Q5/Q6/Q7. 
+ +**What would be cleaner with Correlator (Phase 3)**: +```python +# Get pose observations that overlap with carrying detections +carrying_poses = s.correlate(carrying_set, poses, time_tolerance=0.0) +``` + +--- + +## Summary + +| Question | Key API features used | Works well? | +|-------------------------------|-------------------------------------------------------------------------|--------------------------------------| +| Q1 — pose at log time | `.search_text()` + `.at()` | Yes | +| Q2 — continuous observation | `.search_embedding()` + VLM verify + `.order_by()` + segmentation | Yes | +| Q3 — count unique people | `.filter_after()` + `.vector()` + DBSCAN | Yes | +| Q4 — map red socks | `.search_embedding()` + VLM + `.filter_refs()` + costmap | Yes | +| Q5 — door opener | `.filter_tags()` + `.filter_time()` + `.filter_near()` | Yes, cross-stream loop | +| Q6 — STT + voice identity | `.filter_time()` + `.vector()` | Yes | +| Q7 — voice ↔ face | `.filter_time()` + `.vector()` + accumulation | Yes | +| Q8 — room diff | `.filter_time()` + `.filter_near()` + `.vector()` diff | Yes, diffing is app logic | +| Q9 — cat trail | `.filter_tags()` + `.order_by()` + pose path | Yes | +| Q10 — pre-event timeline | `.filter_time()` + `list_streams()` → `StreamInfo` | Yes | +| Q11 — absence detection | `.filter_tags()` + `.order_by()` + gap analysis | Yes, negation is app logic | +| Q12 — mailman schedule | `.search_embedding()` or `.filter_tags()` + time stats | Yes, pattern extraction is app logic | +| Q13 — cross-robot diff | `.filter_tags(robot_id=)` + `.filter_near()` + `.vector()` | Yes | +| Q14 — distance while carrying | `.filter_tags()` + `.filter_time()` + `.order_by()` + pose accumulation | Yes, cross-stream conditional join | + +**API gaps exposed by Q8-Q14**: + +**Remaining gap**: + +| Gap | Affects | Suggestion | +|-----|---------|------------| +| Cross-stream conditional join is always a manual loop | Q5,Q7,Q10,Q14 | Phase 3 `Correlator` — the most motivated 
feature | diff --git a/plans/answers_correlator.md b/plans/answers_correlator.md new file mode 100644 index 0000000000..a6be19d1c8 --- /dev/null +++ b/plans/answers_correlator.md @@ -0,0 +1,285 @@ +# Answers — with Correlator + +Side-by-side: how the cross-stream questions change with `s.correlate()`. +Only covering questions where correlator applies (Q1, Q5, Q6, Q7, Q10, Q14). + +--- + +## Q1. "Where was I when this log line was added?" + +**Before** (using `.at()`): +```python +log_hit = logs.query().search_text("motor fault detected").one() +pose_hit = poses.query().at(log_hit.ts_start, tolerance=0.5).one() +``` + +**With correlator**: +```python +log_set = logs.query().search_text("motor fault detected").fetch_set() +result = s.correlate(log_set, poses, time_tolerance=0.5) +pose = result.unambiguous()[0].matches[0].pose +``` + +**Verdict**: `.at()` is better here. Correlator adds ceremony for a single-observation lookup. Correlator wins when you have many anchors — e.g., "where was I for ALL error log lines?": + +```python +errors = logs.query().search_text("error").fetch_set() +result = s.correlate(errors, poses, time_tolerance=0.5) +for p in result.with_matches(): + error_text = logs.load(p.anchor.ref).text + print(f"{error_text} → {p.matches[0].pose}") +``` + +That replaces a loop of N `.at()` calls with one batch query. + +--- + +## Q5. "Did anyone open this door? Who?" 
+ +**Before** (manual loop): +```python +door_events = events.query().filter_tags(type="door_open").order_by("ts_start").fetch() + +for event in door_events: + nearby = (faces.query() + .filter_time(event.ts_start - 5.0, event.ts_start + 2.0) + .filter_near(event.pose, radius=3.0) + .fetch()) + if nearby: + vec = faces.vector(nearby[0].ref) + identity = lookup_identity(vec) + print(f"Door opened at {event.ts_start} by {identity}") +``` + +**With correlator**: +```python +door_events = events.query().filter_tags(type="door_open").fetch_set() + +pairs = s.correlate(door_events, faces, + time_before=5.0, time_after=2.0, + spatial_radius=3.0) + +for p in pairs.with_matches(): + vec = faces.vector(p.matches[0].ref) + identity = lookup_identity(vec) + print(f"Door opened at {p.anchor.ts_start} by {identity}") + +# Bonus: which door openings had nobody nearby? +for anchor in pairs.unmatched(): + print(f"Door opened at {anchor.ts_start} — nobody detected") +``` + +**What changed**: +- Loop of N queries → 1 batch query +- `.unmatched()` is free — no extra work to find events with zero matches +- Asymmetric window (`time_before=5.0, time_after=2.0`) expresses "who was there just before and shortly after" naturally + +--- + +## Q6. "STT + voice embeddings — who is saying what?" 
+ +**Before** (manual loop): +```python +for tx_row in transcripts.query().order_by("ts_start").fetch(): + voice = (voice_embs.query() + .filter_time(tx_row.ts_start, tx_row.ts_end) + .one()) + + speaker_vec = voice_embs.vector(voice.ref) + transcript_text = transcripts.load(tx_row.ref).text + print(f"[{speaker_vec_to_name(speaker_vec)}]: {transcript_text}") +``` + +**With correlator**: +```python +pairs = s.correlate(transcripts, voice_embs, time_tolerance=0.0) + +for p in pairs.with_matches(): + speaker_vec = voice_embs.vector(p.matches[0].ref) + transcript_text = transcripts.load(p.anchor.ref).text + print(f"[{speaker_vec_to_name(speaker_vec)}]: {transcript_text}") + +# Transcripts with no matching voice segment (e.g., gap in audio) +for anchor in pairs.unmatched(): + print(f"[unknown]: {transcripts.load(anchor.ref).text}") +``` + +**What changed**: +- `time_tolerance=0.0` means: target's `ts_start` must fall within anchor's `[ts_start, ts_end]` window. Since transcripts have both `ts_start`/`ts_end`, this matches voice segments that overlap with the spoken words. +- `.unmatched()` catches transcripts where audio processing failed or had gaps — previously silently lost in a `.one()` that would throw. + +--- + +## Q7. 
"Voice ↔ face correlation (partial overlap)" + +**Before** (manual loop + filtering): +```python +pairings = [] +for v_row in voices.query().order_by("ts_start").fetch(): + visible_faces = (faces.query() + .filter_time(v_row.ts_start, v_row.ts_end) + .fetch()) + + if len(visible_faces) == 1: + voice_vec = voices.vector(v_row.ref) + face_vec = faces.vector(visible_faces[0].ref) + pairings.append((voice_vec, face_vec)) +``` + +**With correlator**: +```python +pairs = s.correlate(voices, faces, time_tolerance=0.0) + +# Unambiguous pairings: exactly one face visible during voice segment +pairings = [] +for p in pairs.unambiguous(): + voice_vec = voices.vector(p.anchor.ref) + face_vec = faces.vector(p.matches[0].ref) + pairings.append((voice_vec, face_vec)) + +# Stats for free +total = len(pairs) +matched = len(pairs.with_matches()) +unambiguous = len(pairs.unambiguous()) +unmatched = len(pairs.unmatched()) +print(f"{total} voice segments: {unambiguous} unambiguous, " + f"{matched - unambiguous} ambiguous, {unmatched} no face visible") +``` + +**What changed**: +- `.unambiguous()` replaces the `if len(...) == 1` check +- Statistics about match quality are trivial to compute +- The "I don't see all people speaking at all times" constraint is directly visible in `.unmatched()` count + +--- + +## Q10. "What happened in the 30 seconds before the vase fell?" 
+ +**Before** (loop over streams): +```python +vase_event = events.query().search_text("vase fell").one() +t = vase_event.ts_start + +timeline = {} +for info in s.list_streams(): + stream = s.stream(info.name, info.payload_type, + embedding=info.embedding, text=info.text) + window = (stream.query() + .filter_time(t - 30.0, t) + .order_by("ts_start") + .fetch()) + timeline[info.name] = window +``` + +**With correlator**: +```python +vase_set = events.query().search_text("vase fell").fetch_set() + +timeline = {} +for info in s.list_streams(): + stream = s.stream(info.name, info.payload_type, + embedding=info.embedding, text=info.text) + result = s.correlate(vase_set, stream, time_before=30.0, time_after=0.0) + timeline[info.name] = result +``` + +**What changed**: +- `time_before=30.0, time_after=0.0` — asymmetric window expresses "30s before, nothing after" directly. No manual `t - 30.0, t` arithmetic. +- Still loops over streams (correlator is pairwise). But each iteration is cleaner. +- If `vase_set` had multiple events (vase fell twice), you'd get per-event windows for free. The manual version would need a nested loop. + +**Honest assessment**: Marginal improvement for Q10 since the anchor is typically one event. The correlator shines more when you have many anchors. + +--- + +## Q14. "How far did I travel while carrying an object?" 
+ +**Before** (segment + loop): +```python +# Step 1: Segment carrying detections into intervals +carrying = (detections.query() + .filter_tags(action="carrying") + .order_by("ts_start") + .fetch()) + +intervals = [] +seg_start = carrying[0].ts_start +prev_t = carrying[0].ts_start +for r in carrying[1:]: + if r.ts_start - prev_t > 2.0: + intervals.append((seg_start, prev_t)) + seg_start = r.ts_start + prev_t = r.ts_start +intervals.append((seg_start, prev_t)) + +# Step 2: For each interval, get poses and sum distance +total_distance = 0.0 +for t_start, t_end in intervals: + pose_rows = (poses.query() + .filter_time(t_start, t_end) + .order_by("ts_start") + .fetch()) + for i in range(1, len(pose_rows)): + total_distance += distance(pose_rows[i-1].pose, pose_rows[i].pose) +``` + +**With correlator**: +```python +carrying = detections.query().filter_tags(action="carrying").fetch_set() + +pairs = s.correlate(carrying, poses, time_tolerance=0.1) + +# Each carrying detection gets matched to nearby poses +# Deduplicate: collect all unique matched pose refs, sorted by time +seen_pose_refs = set() +all_poses = [] +for p in pairs: + for m in p.matches: + if m.ref.id not in seen_pose_refs: + seen_pose_refs.add(m.ref.id) + all_poses.append(m) + +all_poses.sort(key=lambda r: r.ts_start) + +total_distance = 0.0 +for i in range(1, len(all_poses)): + total_distance += distance(all_poses[i-1].pose, all_poses[i].pose) +``` + +**Honest assessment**: The correlator version is *not* cleaner here. The problem is that carrying detections are per-frame (one every 0.2s at 5Hz), so you get thousands of overlapping CorrelationPairs that all match the same poses. You need deduplication, which is awkward. + +The original approach is actually better: segment into intervals first (app logic), then do one time-range query per interval. Correlator is designed for "match discrete events to another stream", not "define continuous intervals and query within them." 
+ +**When correlator WOULD help for Q14**: If carrying detections had `ts_start`/`ts_end` representing the full carry interval (not per-frame), then: + +```python +# If each carrying observation spans the full interval +carry_intervals = detections.query().filter_tags(action="carrying").fetch_set() + +pairs = s.correlate(carry_intervals, poses, time_tolerance=0.0) +total_distance = 0.0 +for p in pairs: + sorted_poses = sorted(p.matches, key=lambda r: r.ts_start) + for i in range(1, len(sorted_poses)): + total_distance += distance(sorted_poses[i-1].pose, sorted_poses[i].pose) +``` + +Clean — but requires interval-shaped observations. The segmentation from point detections to intervals is still app logic. + +--- + +## Summary + +| Q | Before | With Correlator | Improvement | +|---|--------|----------------|-------------| +| Q1 | `.at()` — 1 query | overkill for single lookup | None (`.at()` is better) | +| Q1 batch | N `.at()` calls | 1 batch query | Yes — N→1 queries | +| Q5 | N queries in loop | 1 batch + `.with_matches()` / `.unmatched()` | Yes — cleaner + free stats | +| Q6 | N queries in loop | 1 batch + `.unmatched()` catches gaps | Yes — cleaner + error visibility | +| Q7 | N queries + `if len==1` | 1 batch + `.unambiguous()` | Yes — most natural fit | +| Q10 | N queries (1 per stream) | N correlate calls, asymmetric window | Marginal — still loops over streams | +| Q14 | segment + N queries | messy dedup of overlapping pairs | No — manual approach is better for continuous intervals | + +**Key insight**: Correlator is best for **discrete events correlated against another stream** (Q5, Q6, Q7). It's less useful for continuous intervals (Q14) or single-observation lookups (Q1). The sweet spot is "I have 50-5000 anchors and want matches from another stream for each." + +**API validated**: `time_before`/`time_after` asymmetric windows are needed (Q5, Q10). `.unambiguous()` and `.unmatched()` are the most-used convenience methods. 
diff --git a/plans/correlator.md b/plans/correlator.md new file mode 100644 index 0000000000..4141506f68 --- /dev/null +++ b/plans/correlator.md @@ -0,0 +1,225 @@ +# Correlator + +Cross-stream temporal+spatial join for Memory2. + +## Motivation + +5 of 14 usage questions (Q5, Q6, Q7, Q10, Q14) require the same pattern: + +```python +for anchor in stream_a.query().fetch(): + matches = (stream_b.query() + .filter_time(anchor.ts_start - tol, anchor.ts_end + tol) + .filter_near(anchor.pose, radius=r) + .fetch()) + # do something with (anchor, matches) +``` + +This is a nested-loop join — N queries, one per anchor observation. Correlator replaces it with a single batch operation. + +## API + +Method on Session: + +```python +class Session: + def correlate( + self, + anchors: Stream | ObservationSet, + targets: Stream | ObservationSet, + *, + time_tolerance: float | None = None, # symmetric: sets both before and after + time_before: float | None = None, # asymmetric: window before anchor ts_start + time_after: float | None = None, # asymmetric: window after anchor ts_end + spatial_radius: float | None = None, + ) -> CorrelationResult: ... +``` + +Accepts Stream (correlate everything) or ObservationSet (correlate a filtered subset). + +**Time window per anchor**: `[ts_start - time_before, ts_end + time_after]`. If `ts_end` is None, uses `ts_start` for both. `time_tolerance` is shorthand for `time_before = time_after = time_tolerance`. Explicit `time_before`/`time_after` override `time_tolerance`. + +### CorrelationResult + +```python +@dataclass +class CorrelationPair: + anchor: ObservationRow + matches: list[ObservationRow] + +class CorrelationResult: + def __iter__(self) -> Iterator[CorrelationPair]: ... + def __len__(self) -> int: ... + + # Filter by match cardinality + def unambiguous(self) -> list[CorrelationPair]: + """Pairs where exactly one target matched.""" + ... 
+ + def with_matches(self) -> list[CorrelationPair]: + """Pairs where at least one target matched.""" + ... + + def unmatched(self) -> list[ObservationRow]: + """Anchor observations with zero matches.""" + ... +``` + +### Usage + +**Q5 — Who opened the door?** +```python +door_events = events.query().filter_tags(type="door_open").fetch_set() + +pairs = s.correlate(door_events, faces, time_tolerance=5.0, spatial_radius=3.0) +for p in pairs.with_matches(): + identity = identify_face(faces.vector(p.matches[0].ref)) + print(f"Door opened at {p.anchor.ts_start} — {identity}") +``` + +**Q7 — Voice ↔ face (unambiguous only)** +```python +pairs = s.correlate(voices, faces, time_tolerance=0.5) +for p in pairs.unambiguous(): + voice_vec = voices.vector(p.anchor.ref) + face_vec = faces.vector(p.matches[0].ref) + pairings.append((voice_vec, face_vec)) +``` + +**Q10 — Pre-event timeline (30s before, nothing after)** +```python +vase_event = events.query().search_text("vase fell").fetch_set() + +timeline = {} +for info in s.list_streams(): + stream = s.stream(info.name, info.payload_type, + embedding=info.embedding, text=info.text) + result = s.correlate(vase_event, stream, time_before=30.0, time_after=0.0) + timeline[info.name] = result +``` + +**Q14 — Distance while carrying** +```python +carrying = detections.query().filter_tags(action="carrying").fetch_set() + +pairs = s.correlate(carrying, poses, time_tolerance=0.0) +total_distance = 0.0 +for p in pairs: + sorted_poses = sorted(p.matches, key=lambda r: r.ts_start) + for i in range(1, len(sorted_poses)): + total_distance += distance(sorted_poses[i-1].pose, sorted_poses[i].pose) +``` + +## Implementation + +### SQL batch join (single query instead of N) + +```sql +-- 1. Materialize anchors into temp table +CREATE TEMP TABLE _corr_anchors ( + anchor_id TEXT, + ts_lo REAL, -- ts_start - time_before + ts_hi REAL, -- (ts_end or ts_start) + time_after + pose_x REAL, + pose_y REAL, + pose_z REAL +); + +-- 2. 
Join to target stream's _meta on time overlap
SELECT a.anchor_id, t.*
FROM _corr_anchors a
JOIN {target}_meta t
  ON t.ts_start >= a.ts_lo AND t.ts_start <= a.ts_hi
ORDER BY a.anchor_id, t.ts_start;
```

With spatial constraint, add R*Tree join:

```sql
SELECT a.anchor_id, t.*
FROM _corr_anchors a
JOIN {target}_rtree r
  ON r.min_x >= a.pose_x - :radius AND r.max_x <= a.pose_x + :radius
  AND r.min_y >= a.pose_y - :radius AND r.max_y <= a.pose_y + :radius
  AND r.min_z >= a.pose_z - :radius AND r.max_z <= a.pose_z + :radius
  AND r.min_t >= a.ts_lo AND r.max_t <= a.ts_hi
JOIN {target}_meta t ON t.rowid = r.rowid
ORDER BY a.anchor_id, t.ts_start;
```

### Grouping

The SQL returns flat rows sorted by `anchor_id`. Group in Python into `CorrelationPair`s:

```python
pairs = []
current_anchor = None
current_matches = []
for row in cursor:
    anchor_id = row["anchor_id"]
    if anchor_id != current_anchor:
        if current_anchor is not None:
            pairs.append(CorrelationPair(anchor=..., matches=current_matches))
        current_anchor = anchor_id
        current_matches = []
    current_matches.append(to_observation_row(row))
if current_anchor is not None:  # flush the final group — the loop only emits on anchor change
    pairs.append(CorrelationPair(anchor=..., matches=current_matches))
```

Note: the inner `JOIN` never returns rows for anchors with zero matches, so after grouping we must append a `CorrelationPair(anchor, matches=[])` for every anchor absent from the result — otherwise `.unmatched()` would always be empty.

### Performance

- Temp table insert: O(A) where A = anchor count
- Join: SQLite uses the ts_start index (or R*Tree) on the target → O(A × log(T)) where T = target count
- vs naive loop: O(A × T) without indexes, O(A × log(T)) with indexes but A round-trips

The batch approach saves round-trip overhead and lets SQLite optimize the join plan. For A=1000 anchors, that's 1 query vs 1000. 
+ +## Types + +```python +# In types.py +@dataclass +class CorrelationPair: + anchor: ObservationRow + matches: list[ObservationRow] +``` + +```python +# In correlation.py +class CorrelationResult: + _pairs: list[CorrelationPair] + + def __iter__(self): return iter(self._pairs) + def __len__(self): return len(self._pairs) + + def unambiguous(self) -> list[CorrelationPair]: + return [p for p in self._pairs if len(p.matches) == 1] + + def with_matches(self) -> list[CorrelationPair]: + return [p for p in self._pairs if p.matches] + + def unmatched(self) -> list[ObservationRow]: + return [p.anchor for p in self._pairs if not p.matches] +``` + +## File structure update + +``` +dimos/memory2/ + ... + correlation.py # CorrelationResult, correlate() implementation +``` + +## Phase + +This can be Phase 2b — after Stream + Query + ObservationSet are working. It depends on: +- ObservationSet (anchors can be a set) +- Stream._meta table schema (join target) +- Session.execute() (raw SQL for the batch join) + +No dependency on derive(), CompositeBacking, or retention. + +## Not in scope + +- **Continuous/streaming correlation** — this is one-shot batch. Live correlation (new anchor arrives → auto-query targets) is a different abstraction. +- **Multi-stream correlation** — correlate(A, [B, C, D]) returning aligned tuples. Call correlate() multiple times instead. +- **Embedding cross-match** — correlation is time+space only. "Find similar embeddings across streams" is a different operation (use search_embedding on each stream). diff --git a/plans/memory.md b/plans/memory.md index f884e8dfb6..447e7d3449 100644 --- a/plans/memory.md +++ b/plans/memory.md @@ -93,7 +93,6 @@ agent calls goto (event cluster 3) cluster 3 - find best image, correlate to lidar, project into space, navigate, once there, use VLM and visual nav - ## example interaction 2: arm mustafa is able to ask for an object in proximity to the robot. robot searches memory biasing distance in time and space. 
if close match is not found, search can be expanded @@ -110,3 +109,5 @@ mustafa is able to ask for an object in proximity to the robot. robot searches m # Questions "where was I, when this log line was added" + +"how long for have I been observing this object" diff --git a/plans/memory1.md b/plans/memory1.md index 9e9c7028b1..fd2ddb7e83 100644 --- a/plans/memory1.md +++ b/plans/memory1.md @@ -32,7 +32,6 @@ with db.session() as s: # --- Save with optional spatial context --- images.save(frame) # temporal only images.save(frame, pose=robot_pose) # temporal + spatial (baked in) - images.save(frame, pose=(pos, quat)) # PoseLike tuple also works # --- Temporal queries (chainable) --- hit = images.at(now).one() # closest to now → Hit | None diff --git a/plans/memory3.md b/plans/memory3.md index e4a468d456..2f0ba8a2ad 100644 --- a/plans/memory3.md +++ b/plans/memory3.md @@ -12,10 +12,10 @@ PR #1080 introduced `TimeSeriesStore[T]` with pluggable backends. Paul's review dimos/memory2/ __init__.py # public exports _sql.py # _validate_identifier(), SQL helpers - types.py # ObservationRef, ObservationMeta, ObservationRow, Lineage, Pose (spec's own Pose) + types.py # ObservationRef, ObservationMeta, ObservationRow, Lineage, StreamInfo db.py # DB (Resource lifecycle, SqliteDB) session.py # Session (connection, stream factory, correlate) - stream.py # Stream (append + QueryableObservationSet) + stream.py # StreamBase, BlobStream, EmbeddingStream, TextStream observation_set.py # ObservationSet (lazy, re-queryable, predicate/ref-table backed) query.py # Query (filter/search/rank/limit → fetch/fetch_set) test_memory2.py # tests @@ -28,52 +28,26 @@ dimos/memory2/ 1. 
**`types.py`** — Data classes ```python + @dataclass(frozen=True) class ObservationRef: stream: str - id: str - -@dataclass -class Pose: - xyz: tuple[float, float, float] - quat_xyzw: tuple[float, float, float, float] | None = None - -@dataclass -class ObservationMeta: - ref: ObservationRef - ts_start: float | None = None - ts_end: float | None = None - robot_id: str | None = None - frame_id: str | None = None - pose: Pose | None = None - pose_source: str | None = None - pose_confidence: float | None = None - payload_codec: str | None = None - payload_size_bytes: int | None = None - tags: dict[str, Any] = field(default_factory=dict) + rowid: int @dataclass class ObservationRow: ref: ObservationRef - ts_start: float | None = None - ts_end: float | None = None - pose: Pose | None = None + ts: float | None = None + pose: PoseLike | None = None scores: dict[str, float] = field(default_factory=dict) - metadata: dict[str, Any] = field(default_factory=dict) + tags: dict[str, Any] = field(default_factory=dict) @dataclass class Lineage: - parents: list[str] = field(default_factory=list) - parent_refs: list[ObservationRef] = field(default_factory=list) - query_repr: str | None = None + parent_ref: ObservationRef | None = None # single parent via parent_stream + parent_rowid ``` -Note: `Pose` here is the spec's lightweight tuple-based pose for storage/filtering. Conversion to/from DimOS `dimos.msgs.geometry_msgs.Pose` via helper: - -```python -def to_storage_pose(p: DimOSPose | DimOSPoseStamped | Pose) -> Pose: ... -def to_dimos_pose(p: Pose) -> DimOSPose: ... -``` +Poses use DimOS's existing `PoseLike` type alias (`Pose | PoseStamped | Point | PointStamped`). Internally, `append()` extracts `(x, y, z, qx, qy, qz, qw)` floats for SQL storage; `load` reconstructs a `dimos.msgs.geometry_msgs.Pose` from stored floats. No custom Pose type. 2. **`_sql.py`** — SQL helpers @@ -101,11 +75,21 @@ SqliteDB internals: 4. 
**`session.py`** — Session + SqliteSession ```python +@dataclass +class StreamInfo: + name: str + payload_type: type + count: int + class Session(ABC): - def stream(self, name: str, payload_type: type, - capabilities: set[str], retention: str = "run", - config: dict | None = None) -> Stream: ... - def list_streams(self) -> list[str]: ... + def stream(self, name: str, payload_type: type, *, + retention: str = "run") -> BlobStream: ... + def embedding_stream(self, name: str, payload_type: type, *, + dim: int, retention: str = "run") -> EmbeddingStream: ... + def text_stream(self, name: str, payload_type: type, *, + tokenizer: str = "unicode61", + retention: str = "run") -> TextStream: ... + def list_streams(self) -> list[StreamInfo]: ... def execute(self, sql: str, params=()) -> list: ... def close(self) -> None: ... def __enter__ / __exit__ @@ -113,34 +97,48 @@ class Session(ABC): SqliteSession: - Holds one `sqlite3.Connection` -- `stream()`: creates tables if needed (see schema below), caches Stream instances +- `stream()` / `embedding_stream()` / `text_stream()`: creates tables if needed (see schema below), caches StreamBase instances - Registers stream metadata in a `_streams` registry table ### Phase 2: Stream + Query + ObservationSet -5. **`stream.py`** — Stream (implements `QueryableObservationSet`) +5. **`stream.py`** — Stream hierarchy (subclassed by data type) ```python -class Stream(Generic[T]): +class StreamBase(ABC, Generic[T]): + """Abstract base: meta + payload + spatial index. No text/vector indexes.""" # Write def append(self, payload: T, **meta: Any) -> ObservationRef: ... def append_many(self, payloads, metas) -> list[ObservationRef]: ... - # QueryableObservationSet protocol + # Read def query(self) -> Query[T]: ... def load(self, ref: ObservationRef) -> T: ... def load_many(self, refs: list[ObservationRef], *, batch_size=32) -> list[T]: ... def iter_meta(self, *, page_size=128) -> Iterator[list[ObservationRow]]: ... def count(self) -> int: ... 
- def capabilities(self) -> set[str]: ... # Introspection def meta(self, ref: ObservationRef) -> ObservationMeta: ... def info(self) -> dict[str, Any]: ... def stats(self) -> dict[str, Any]: ... + +class BlobStream(StreamBase[T]): + """Concrete stream for arbitrary LCM-serializable payloads. No special indexes.""" + +class EmbeddingStream(StreamBase[T]): + """Stream with a vec0 vector index. append() also inserts into _vec table.""" + def __init__(self, ..., *, dim: int): ... + def vector(self, ref: ObservationRef) -> list[float] | None: ... + # search_embedding() on Query is valid only for EmbeddingStream + +class TextStream(StreamBase[T]): + """Stream with an FTS5 index. append() also inserts into _fts table.""" + def __init__(self, ..., *, tokenizer: str = "unicode61"): ... + # search_text() on Query is valid only for TextStream ``` -`append()` generates a UUID for `ObservationRef.id`, pickles payload into BLOB, inserts metadata row + R*Tree entry (if pose provided) + FTS entry (if text capable) + vector entry (if embedding capable). +`append()` inserts a metadata row (SQLite auto-assigns `rowid`), serializes payload via `lcm_encode()` into `_payload` BLOB, and inserts an R*Tree entry if pose is provided. `EmbeddingStream.append()` also inserts into the `_vec` table; `TextStream.append()` also inserts into the `_fts` table. Returns `ObservationRef(stream, rowid)`. `load()` deserializes via `lcm_decode()` using the stream's `payload_type`. 6. **`query.py`** — Query (chainable, capability-aware) @@ -150,17 +148,19 @@ class Query(Generic[T]): def filter_time(self, t1: float, t2: float) -> Query[T]: ... def filter_before(self, t: float) -> Query[T]: ... def filter_after(self, t: float) -> Query[T]: ... - def filter_near(self, pose: Pose, radius: float, *, + def filter_near(self, pose: PoseLike, radius: float, *, include_unlocalized: bool = False) -> Query[T]: ... def filter_tags(self, **tags: Any) -> Query[T]: ... 
def filter_refs(self, refs: list[ObservationRef]) -> Query[T]: ... + def at(self, t: float, *, tolerance: float = 1.0) -> Query[T]: ... # Candidate generation def search_text(self, text: str, *, candidate_k: int | None = None) -> Query[T]: ... def search_embedding(self, vector: list[float], *, candidate_k: int) -> Query[T]: ... - # Ranking + limit + # Ranking + ordering + limit def rank(self, **weights: float) -> Query[T]: ... + def order_by(self, field: str, *, desc: bool = False) -> Query[T]: ... def limit(self, k: int) -> Query[T]: ... # Terminals @@ -171,7 +171,9 @@ class Query(Generic[T]): ``` Query internals: -- Accumulates filter predicates, search ops, rank spec, limit +- Accumulates filter predicates, search ops, rank spec, ordering, limit +- `at(t, tolerance)` → sugar for `filter_time(t - tol, t + tol)` + `ORDER BY ABS(ts_start - t) LIMIT 1` +- `order_by(field, desc)` → appends `ORDER BY` clause; valid fields: `ts_start`, `ts_end` - `fetch()`: generates SQL, executes, returns rows - `fetch_set()`: creates an ObservationSet (predicate-backed or ref-table-backed) - search_embedding → sqlite-vec `MATCH`, writes top-k to temp table → ref-table-backed @@ -194,11 +196,16 @@ class ObservationSet(Generic[T]): def one(self) -> ObservationRow: ... def fetch_page(self, *, limit=128, offset=0) -> list[ObservationRow]: ... def count(self) -> int: ... - def capabilities(self) -> set[str]: ... def lineage(self) -> Lineage: ... # Cross-stream - def project_to(self, stream: Stream) -> ObservationSet: ... + def project_to(self, stream: StreamBase) -> ObservationSet: ... + + # Cleanup (ref-table-backed only; no-op for predicate-backed) + def close(self) -> None: ... + def __enter__(self) -> Self: ... + def __exit__(self, *exc) -> None: ... + def __del__(self) -> None: ... 
# best-effort fallback ``` Internal backing (spec §8): @@ -221,6 +228,9 @@ class RefTableBacking: - `.query()` on predicate-backed → adds more predicates - `.query()` on ref-table-backed → filters within that temp table - `project_to()` → joins backing refs via lineage parent_refs to target stream +- `close()` drops the temp table for ref-table-backed sets; no-op for predicate-backed +- Supports context manager (`with`) for deterministic cleanup; `__del__` as fallback +- SQLite connection close is the final safety net for any leaked temp tables ### Phase 3: Later (not in first PR) @@ -236,29 +246,22 @@ class RefTableBacking: ```sql CREATE TABLE {name}_meta ( - id TEXT PRIMARY KEY, -- UUID, part of ObservationRef - ts_start REAL, - ts_end REAL, - robot_id TEXT, - frame_id TEXT, + rowid INTEGER PRIMARY KEY, -- auto-assigned, used by R*Tree/FTS/vec0 + ts REAL, pose_x REAL, pose_y REAL, pose_z REAL, pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, - pose_source TEXT, - pose_confidence REAL, - payload_codec TEXT, - payload_size_bytes INTEGER, - tags TEXT, -- JSON + tags TEXT, -- JSON (robot_id, frame_id, etc.) parent_stream TEXT, -- lineage: source stream name - parent_id TEXT -- lineage: source observation id + parent_rowid INTEGER -- lineage: source observation rowid ); -CREATE INDEX idx_{name}_meta_ts ON {name}_meta(ts_start); +CREATE INDEX idx_{name}_meta_ts ON {name}_meta(ts); ``` ### Payload table: `{name}_payload` ```sql CREATE TABLE {name}_payload ( - id TEXT PRIMARY KEY, -- matches _meta.id + rowid INTEGER PRIMARY KEY, -- matches _meta.rowid data BLOB NOT NULL ); ``` @@ -270,28 +273,29 @@ Separate from meta so queries never touch payload BLOBs. 
```sql CREATE VIRTUAL TABLE {name}_rtree USING rtree( rowid, -- matches _meta rowid - min_t, max_t, -- ts_start, ts_end - min_x, max_x, - min_y, max_y, - min_z, max_z + min_t, max_t, -- both set to ts (point, not range) + min_x, max_x, -- both set to pose_x + min_y, max_y, -- both set to pose_y + min_z, max_z -- both set to pose_z ); ``` -Only rows with pose get R*Tree entries (spec §2.6: unlocalized != everywhere). -R*Tree `rowid` linked to meta via a mapping or using meta's rowid. +Only rows with pose get R*Tree entries (unlocalized != everywhere). +R*Tree `rowid` matches `_meta.rowid` directly — no mapping needed. +Time-only queries use the B-tree index on `_meta.ts` (faster than R*Tree for 1D). +Spatial or spatio-temporal queries use the R*Tree. ### FTS5 (text search): `{name}_fts` ```sql CREATE VIRTUAL TABLE {name}_fts USING fts5( - id, content, content={name}_meta, content_rowid=rowid ); ``` -Only for streams with `"text"` capability. +Created by `TextStream` subclass only. ### Vector index (embedding search): `{name}_vec` @@ -301,33 +305,26 @@ CREATE VIRTUAL TABLE {name}_vec USING vec0( ); ``` -`rowid` matches meta rowid. Only for streams with `"embedding"` capability. +`rowid` matches meta rowid. Created by `EmbeddingStream` subclass only. ## Key Design Decisions -### Pose type bridging +### Pose handling -The spec defines its own lightweight `Pose(xyz, quat_xyzw)` for storage. DimOS has `dimos.msgs.geometry_msgs.Pose` with full algebra. Stream `append()` should accept either: +All pose parameters accept `PoseLike` (`Pose | PoseStamped | Point | PointStamped` from `dimos.msgs.geometry_msgs`). No custom pose type. 
```python -# DimOS Pose -images.append(frame, pose=robot_pose) # dimos.msgs.geometry_msgs.Pose +from dimos.msgs.geometry_msgs import Pose, Point -# Spec Pose (tuples) -images.append(frame, pose=Pose(xyz=(1, 2, 3), quat_xyzw=(0, 0, 0, 1))) +images.append(frame, pose=robot_pose) # Pose object +q.filter_near(Point(1, 2, 3), radius=5.0) # Point object ``` -Internal conversion via `to_storage_pose()` extracts `(x, y, z, qx, qy, qz, qw)` for SQL storage. +Internally, `_extract_pose(p: PoseLike) -> tuple[float, ...]` pulls `(x, y, z, qx, qy, qz, qw)` for SQL columns. `ObservationRow.pose` returns a reconstructed `dimos.msgs.geometry_msgs.Pose`. -### filter_near accepts DimOS types +### Payload serialization -```python -from dimos.msgs.geometry_msgs import Point, Pose as DimOSPose - -q.filter_near(DimOSPose(1, 2, 3), radius=5.0) -q.filter_near(Point(1, 2, 3), radius=5.0) -q.filter_near(Pose(xyz=(1, 2, 3)), radius=5.0) -``` +Only LCM message types are storable. `append()` calls `lcm_encode(payload)`, `load()` calls `lcm_decode(blob, payload_type)`. Non-LCM types are rejected at `append()` time with a `TypeError`. ### ObservationRef identity @@ -347,6 +344,7 @@ Payload BLOBs live in `{name}_payload`, separate from `{name}_meta`. This ensure - `dimos/msgs/geometry_msgs/Pose.py` — DimOS Pose type, `PoseLike` type alias - `dimos/msgs/geometry_msgs/Point.py` — Point type - `dimos/core/resource.py` — Resource ABC (start/stop/dispose) +- LCM `lcm_encode()` / `lcm_decode()` — payload serialization ## Verification diff --git a/plans/questions.md b/plans/questions.md new file mode 100644 index 0000000000..4010bd55f4 --- /dev/null +++ b/plans/questions.md @@ -0,0 +1,54 @@ +# Questions + +1. "where was I when this log line was added?" +- pose lookup from a timestamp + +2. "how long have I been observing the red socks currently in view?" +- how many times did I see them before? +- temporal duration tracking + observation frequency + +3. "how many people did I see during last week?" 
+- assume we are generating a facial recognition db — is this matching a face detection stream, then embeddings? then we are searching over that stream? + +4. "where did you see red socks during last week?" +- we query for red socks embedding similarity, then feed this data into a VLM that further filters for socks +- is this data output into some table? is it like an ObservationSet again? +- then we can create a map (costmap) of red socks? + +5. "did anyone ever open this door? at what times did I see this door open? who opened it?" +- event detection + temporal querying of state changes + +6. "I have a transcription log (STT) and voice embeddings, how do I figure out who is saying what?" +- cross-stream correlation: audio → identity + +7. "I have parallel voice and facial recognition streams, how do I correlate voice to people?" +- I don't see all people speaking at all times +- multi-modal fusion with incomplete overlap + +8. "what's different in this room compared to yesterday?" +- comparing scene snapshots across time, diffing object sets +- requires baseline modeling / temporal comparison + +9. "show me everywhere the cat went today" +- continuous spatial tracking over time, not point queries +- dense pose-stream retrieval + path aggregation + +10. "what happened in the 30 seconds before the vase fell?" +- event-anchored temporal window across all streams +- multi-stream temporal slicing relative to a detected event + +11. "when was the last time I did NOT see the cat in the apartment?" +- negation query — finding gaps in an observation stream +- architecturally different from presence queries + +12. "what time does the mailman usually come?" +- aggregation across days, extracting temporal regularity from sparse events +- cross-session pattern extraction + +13. "what did robot-2 observe in the warehouse that I missed?" +- cross-agent memory diff +- session/robot-scoped queries and set difference across streams + +14. 
"how far did I travel while carrying an object?" +- filtered pose integration — only accumulate distance when a parallel detection stream has a positive signal +- cross-stream conditional joins From dc0f946f9b10c3055aacbff089e37c645c4e6f1e Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 12:04:06 +0800 Subject: [PATCH 003/118] spec iteration --- plans/memory3.md | 64 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/plans/memory3.md b/plans/memory3.md index 2f0ba8a2ad..386b56efcc 100644 --- a/plans/memory3.md +++ b/plans/memory3.md @@ -4,7 +4,7 @@ Source of truth: `plans/memory_2_spec_v_2.md` ## Context -PR #1080 introduced `TimeSeriesStore[T]` with pluggable backends. Paul's review identified it mixes DB lifecycle, connection, and query concerns. `memory.md` describes a system where all sensor data is stored as temporal streams with spatial indexing, cross-stream correlation, and multimodal search. The spec (`memory_2_spec_v_2.md`) defines the full public API. This plan maps the spec to concrete SQLite implementation in `dimos/memory2/`. +Check `questions.md` ## File Structure @@ -12,7 +12,7 @@ PR #1080 introduced `TimeSeriesStore[T]` with pluggable backends. 
Paul's review dimos/memory2/ __init__.py # public exports _sql.py # _validate_identifier(), SQL helpers - types.py # ObservationRef, ObservationMeta, ObservationRow, Lineage, StreamInfo + types.py # ObservationRef, ObservationRow, Lineage, StreamInfo db.py # DB (Resource lifecycle, SqliteDB) session.py # Session (connection, stream factory, correlate) stream.py # StreamBase, BlobStream, EmbeddingStream, TextStream @@ -39,12 +39,13 @@ class ObservationRow: ref: ObservationRef ts: float | None = None pose: PoseLike | None = None - scores: dict[str, float] = field(default_factory=dict) + scores: dict[str, float] = field(default_factory=dict) # query-time only (from rank/search), not stored tags: dict[str, Any] = field(default_factory=dict) @dataclass class Lineage: - parent_ref: ObservationRef | None = None # single parent via parent_stream + parent_rowid + parent_stream: str | None = None # from _streams registry (stream-level) + parent_rowid: int | None = None # per-row: which row in parent stream ``` Poses use DimOS's existing `PoseLike` type alias (`Pose | PoseStamped | Point | PointStamped`). Internally, `append()` extracts `(x, y, z, qx, qy, qz, qw)` floats for SQL storage; `load` reconstructs a `dimos.msgs.geometry_msgs.Pose` from stored floats. No custom Pose type. @@ -79,16 +80,19 @@ SqliteDB internals: class StreamInfo: name: str payload_type: type + parent_stream: str | None # lineage: all rows derive from this stream count: int class Session(ABC): def stream(self, name: str, payload_type: type, *, retention: str = "run") -> BlobStream: ... def embedding_stream(self, name: str, payload_type: type, *, - dim: int, retention: str = "run") -> EmbeddingStream: ... + dim: int, retention: str = "run", + parent: StreamBase | None = None) -> EmbeddingStream: ... def text_stream(self, name: str, payload_type: type, *, tokenizer: str = "unicode61", - retention: str = "run") -> TextStream: ... 
+ retention: str = "run", + parent: StreamBase | None = None) -> TextStream: ... def list_streams(self) -> list[StreamInfo]: ... def execute(self, sql: str, params=()) -> list: ... def close(self) -> None: ... @@ -98,7 +102,18 @@ class Session(ABC): SqliteSession: - Holds one `sqlite3.Connection` - `stream()` / `embedding_stream()` / `text_stream()`: creates tables if needed (see schema below), caches StreamBase instances -- Registers stream metadata in a `_streams` registry table +- Registers stream metadata in a `_streams` registry table: + +```sql +CREATE TABLE _streams ( + rowid INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + type TEXT NOT NULL, -- 'blob', 'embedding', 'text' + payload_type TEXT NOT NULL, + parent_stream_id INTEGER, -- FK to _streams.rowid (lineage) + retention TEXT DEFAULT 'run' +); +``` ### Phase 2: Stream + Query + ObservationSet @@ -108,8 +123,12 @@ SqliteSession: class StreamBase(ABC, Generic[T]): """Abstract base: meta + payload + spatial index. No text/vector indexes.""" # Write - def append(self, payload: T, **meta: Any) -> ObservationRef: ... - def append_many(self, payloads, metas) -> list[ObservationRef]: ... + def append(self, payload: T, *, + ts: float | None = None, # defaults to time.time() + pose: PoseLike | None = None, + tags: dict[str, Any] | None = None, + parent_rowid: int | None = None, + ) -> ObservationRef: ... # Read def query(self) -> Query[T]: ... @@ -119,7 +138,7 @@ class StreamBase(ABC, Generic[T]): def count(self) -> int: ... # Introspection - def meta(self, ref: ObservationRef) -> ObservationMeta: ... + def meta(self, ref: ObservationRef) -> ObservationRow: ... def info(self) -> dict[str, Any]: ... def stats(self) -> dict[str, Any]: ... @@ -127,9 +146,11 @@ class BlobStream(StreamBase[T]): """Concrete stream for arbitrary LCM-serializable payloads. No special indexes.""" class EmbeddingStream(StreamBase[T]): - """Stream with a vec0 vector index. 
append() also inserts into _vec table.""" + """Stream with a vec0 vector index. No _payload table — the vector in _vec IS the data.""" def __init__(self, ..., *, dim: int): ... def vector(self, ref: ObservationRef) -> list[float] | None: ... + # append() inserts into _meta + _vec only (no _payload) + # load() not supported — use vector() instead # search_embedding() on Query is valid only for EmbeddingStream class TextStream(StreamBase[T]): @@ -154,7 +175,7 @@ class Query(Generic[T]): def filter_refs(self, refs: list[ObservationRef]) -> Query[T]: ... def at(self, t: float, *, tolerance: float = 1.0) -> Query[T]: ... - # Candidate generation + # Candidate generation (raise TypeError if stream lacks the required index) def search_text(self, text: str, *, candidate_k: int | None = None) -> Query[T]: ... def search_embedding(self, vector: list[float], *, candidate_k: int) -> Query[T]: ... @@ -170,10 +191,12 @@ class Query(Generic[T]): def one(self) -> ObservationRow: ... ``` +TODO: we want terminals also that generate some general spatial or temporal summary, maybe as a numpy array even + Query internals: - Accumulates filter predicates, search ops, rank spec, ordering, limit -- `at(t, tolerance)` → sugar for `filter_time(t - tol, t + tol)` + `ORDER BY ABS(ts_start - t) LIMIT 1` -- `order_by(field, desc)` → appends `ORDER BY` clause; valid fields: `ts_start`, `ts_end` +- `at(t, tolerance)` → sugar for `filter_time(t - tol, t + tol)` + `ORDER BY ABS(ts - t) LIMIT 1` +- `order_by(field, desc)` → appends `ORDER BY` clause; valid fields: `ts` - `fetch()`: generates SQL, executes, returns rows - `fetch_set()`: creates an ObservationSet (predicate-backed or ref-table-backed) - search_embedding → sqlite-vec `MATCH`, writes top-k to temp table → ref-table-backed @@ -251,8 +274,7 @@ CREATE TABLE {name}_meta ( pose_x REAL, pose_y REAL, pose_z REAL, pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, tags TEXT, -- JSON (robot_id, frame_id, etc.) 
- parent_stream TEXT, -- lineage: source stream name - parent_rowid INTEGER -- lineage: source observation rowid + parent_rowid INTEGER -- lineage: rowid in parent stream (parent defined at stream level) ); CREATE INDEX idx_{name}_meta_ts ON {name}_meta(ts); ``` @@ -289,13 +311,13 @@ Spatial or spatio-temporal queries use the R*Tree. ```sql CREATE VIRTUAL TABLE {name}_fts USING fts5( - content, - content={name}_meta, - content_rowid=rowid + content ); ``` -Created by `TextStream` subclass only. +Created by `TextStream` subclass only. Standalone FTS5 table (no `content=` sync). +`TextStream.append()` inserts the text into both `_payload` (as TEXT, not BLOB — `TextStream` overrides payload storage) and into the FTS5 table with the same rowid. +FTS5 `rowid` matches `_meta.rowid`. ### Vector index (embedding search): `{name}_vec` @@ -328,7 +350,7 @@ Only LCM message types are storable. `append()` calls `lcm_encode(payload)`, `lo ### ObservationRef identity -`id` is a UUID4 string generated on `append()`. Never reuse timestamps as identity. +`rowid` is an auto-assigned SQLite integer. Unique within a stream. `ObservationRef(stream, rowid)` is globally unique within a session. 
### Unlocalized observations From 69540aa9423e78d7d2b762cc6abce46306c59183 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 13:06:57 +0800 Subject: [PATCH 004/118] query objects spec --- dimos/memory/embedding.py | 19 +- dimos/memory/timeseries/__init__.py | 4 +- dimos/memory/timeseries/sqlite.py | 10 +- dimos/memory/timeseries/test_base.py | 6 +- plans/memory3.md | 340 ++++++++++++--------------- plans/memory3_answers.md | 67 ++++++ plans/query_objects.md | 155 ++++++++++++ plans/questions.md | 4 +- 8 files changed, 399 insertions(+), 206 deletions(-) create mode 100644 plans/memory3_answers.md create mode 100644 plans/query_objects.md diff --git a/dimos/memory/embedding.py b/dimos/memory/embedding.py index 4627ecfc35..a06c239cdf 100644 --- a/dimos/memory/embedding.py +++ b/dimos/memory/embedding.py @@ -29,25 +29,29 @@ from dimos.msgs.nav_msgs import OccupancyGrid from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier +from dimos.types.timestamped import Timestamped from dimos.utils.reactive import getter_hot @dataclass -class Config(ModuleConfig): - embedding_model: EmbeddingModel = field(default_factory=CLIPModel) +class SpatialEntry(Timestamped): + pose: PoseStamped -@dataclass -class SpatialEntry: +class SpatialImage(SpatialEntry): image: Image - pose: PoseStamped @dataclass -class SpatialEmbedding(SpatialEntry): +class SpatialEmbedding(SpatialImage): embedding: Embedding +@dataclass +class Config(ModuleConfig): + embedding_model: EmbeddingModel = field(default_factory=CLIPModel) + + class EmbeddingMemory(Module[Config]): default_config = Config config: Config @@ -104,3 +108,6 @@ def query_text(self, query: str) -> list[SpatialEmbedding]: self.config.embedding_model.embed_text(query) results: list[SpatialEmbedding] = [] return results + return results + return results + return results diff --git a/dimos/memory/timeseries/__init__.py b/dimos/memory/timeseries/__init__.py index 
debc14ab3a..6e77185c43 100644 --- a/dimos/memory/timeseries/__init__.py +++ b/dimos/memory/timeseries/__init__.py @@ -16,7 +16,7 @@ from dimos.memory.timeseries.base import TimeSeriesStore from dimos.memory.timeseries.inmemory import InMemoryStore from dimos.memory.timeseries.pickledir import PickleDirStore -from dimos.memory.timeseries.sqlite import SqliteStore +from dimos.memory.timeseries.sqlite import SqliteTSStore def __getattr__(name: str): # type: ignore[no-untyped-def] @@ -35,7 +35,7 @@ def __getattr__(name: str): # type: ignore[no-untyped-def] "InMemoryStore", "PickleDirStore", "PostgresStore", - "SqliteStore", + "SqliteTSStore", "TimeSeriesStore", "reset_db", ] diff --git a/dimos/memory/timeseries/sqlite.py b/dimos/memory/timeseries/sqlite.py index 6e2ac7a7f5..6f1f0d88e4 100644 --- a/dimos/memory/timeseries/sqlite.py +++ b/dimos/memory/timeseries/sqlite.py @@ -37,24 +37,24 @@ def _validate_identifier(name: str) -> str: return name -class SqliteStore(TimeSeriesStore[T]): +class SqliteTSStore(TimeSeriesStore[T]): """SQLite backend for sensor data. Good for indexed queries and single-file storage. Data is stored as pickled BLOBs with timestamp as indexed column. 
Usage: # Named store (uses data/ directory, auto-downloads from LFS if needed) - store = SqliteStore("recordings/lidar") # -> data/recordings/lidar.db + store = SqliteTSStore("recordings/lidar") # -> data/recordings/lidar.db store.save(data) # saves using data.ts # Absolute path - store = SqliteStore("/path/to/sensors.db") + store = SqliteTSStore("/path/to/sensors.db") # In-memory (for testing) - store = SqliteStore(":memory:") + store = SqliteTSStore(":memory:") # Multiple tables in one DB - store = SqliteStore("recordings/sensors", table="lidar") + store = SqliteTSStore("recordings/sensors", table="lidar") """ def __init__(self, name: str | Path, table: str = "sensor_data") -> None: diff --git a/dimos/memory/timeseries/test_base.py b/dimos/memory/timeseries/test_base.py index 9491d2c93c..31b811c251 100644 --- a/dimos/memory/timeseries/test_base.py +++ b/dimos/memory/timeseries/test_base.py @@ -24,7 +24,7 @@ from dimos.memory.timeseries.inmemory import InMemoryStore from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.memory.timeseries.pickledir import PickleDirStore -from dimos.memory.timeseries.sqlite import SqliteStore +from dimos.memory.timeseries.sqlite import SqliteTSStore from dimos.types.timestamped import Timestamped @@ -60,7 +60,7 @@ def make_pickle_dir_store(tmpdir: str) -> TimeSeriesStore[SampleData]: def make_sqlite_store(tmpdir: str) -> TimeSeriesStore[SampleData]: - return SqliteStore[SampleData](Path(tmpdir) / "test.db") + return SqliteTSStore[SampleData](Path(tmpdir) / "test.db") def make_legacy_pickle_store(tmpdir: str) -> TimeSeriesStore[SampleData]: @@ -71,7 +71,7 @@ def make_legacy_pickle_store(tmpdir: str) -> TimeSeriesStore[SampleData]: testdata: list[tuple[object, str]] = [ (lambda _: make_in_memory_store(), "InMemoryStore"), (lambda tmpdir: make_pickle_dir_store(tmpdir), "PickleDirStore"), - (lambda tmpdir: make_sqlite_store(tmpdir), "SqliteStore"), + (lambda tmpdir: make_sqlite_store(tmpdir), "SqliteTSStore"), (lambda 
tmpdir: make_legacy_pickle_store(tmpdir), "LegacyPickleStore"), ] diff --git a/plans/memory3.md b/plans/memory3.md index 386b56efcc..b81bdb8329 100644 --- a/plans/memory3.md +++ b/plans/memory3.md @@ -1,7 +1,4 @@ # Memory2 Implementation Plan - -Source of truth: `plans/memory_2_spec_v_2.md` - ## Context Check `questions.md` @@ -10,25 +7,32 @@ Check `questions.md` ``` dimos/memory2/ - __init__.py # public exports - _sql.py # _validate_identifier(), SQL helpers + __init__.py # public exports (re-exports from API + default backend) types.py # ObservationRef, ObservationRow, Lineage, StreamInfo - db.py # DB (Resource lifecycle, SqliteDB) - session.py # Session (connection, stream factory, correlate) - stream.py # StreamBase, BlobStream, EmbeddingStream, TextStream - observation_set.py # ObservationSet (lazy, re-queryable, predicate/ref-table backed) - query.py # Query (filter/search/rank/limit → fetch/fetch_set) - test_memory2.py # tests + store.py # Store ABC (Resource lifecycle) + session.py # Session ABC (stream factory) + stream.py # StreamBase, BlobStream, EmbeddingStream, TextStream ABCs + query.py # Query ABC (filter/search/rank/limit → fetch/fetch_set) + observation_set.py # ObservationSet ABC + + impl/ + sqlite/ + __init__.py # exports SqliteStore + store.py # SqliteStore (connection, WAL, extension loading) + session.py # SqliteSession (stream factory, _streams registry) + stream.py # SqliteBlobStream, SqliteEmbeddingStream, SqliteTextStream + query.py # SqliteQuery (SQL generation, execution) + observation_set.py # SqliteObservationSet (predicate/ref-table backing) + _sql.py # SQL helpers, identifier validation, schema DDL + + test_memory2.py # tests (against SqliteStore) ``` -## Implementation Priority (per spec §15) +## API Layer (`dimos/memory2/`) -### Phase 1: Core types + storage - -1. 
**`types.py`** — Data classes +### `types.py` — Data classes ```python - @dataclass(frozen=True) class ObservationRef: stream: str @@ -46,86 +50,61 @@ class ObservationRow: class Lineage: parent_stream: str | None = None # from _streams registry (stream-level) parent_rowid: int | None = None # per-row: which row in parent stream -``` - -Poses use DimOS's existing `PoseLike` type alias (`Pose | PoseStamped | Point | PointStamped`). Internally, `append()` extracts `(x, y, z, qx, qy, qz, qw)` floats for SQL storage; `load` reconstructs a `dimos.msgs.geometry_msgs.Pose` from stored floats. No custom Pose type. -2. **`_sql.py`** — SQL helpers - -```python -def validate_identifier(name: str) -> str: ... # regex check, length limit +@dataclass +class StreamInfo: + name: str + payload_type: type + parent_stream: str | None # lineage: all rows derive from this stream + count: int ``` -3. **`db.py`** — DB + SqliteDB +Poses use DimOS's existing `PoseLike` type alias (`Pose | PoseStamped | Point | PointStamped`). No custom Pose type. + +### `store.py` — Store ABC ```python -class DB(Resource, ABC): +class Store(Resource, ABC): def session(self) -> Session: ... def close(self) -> None: ... def start(self) -> None: pass def stop(self) -> None: self.close() ``` -SqliteDB internals: -- Stores file path, creates parent dirs on connect -- `_connect()`: `sqlite3.connect()`, WAL mode, loads sqlite-vec (optional), loads FTS5 -- Tracks sessions via `WeakSet` for cleanup -- `:memory:` uses `file::memory:?cache=shared` URI -- Thread safety: each session = one connection, no `check_same_thread=False` - -4. 
**`session.py`** — Session + SqliteSession +### `session.py` — Session ABC ```python -@dataclass -class StreamInfo: - name: str - payload_type: type - parent_stream: str | None # lineage: all rows derive from this stream - count: int +PoseProvider = Callable[[], PoseLike | None] class Session(ABC): def stream(self, name: str, payload_type: type, *, - retention: str = "run") -> BlobStream: ... + retention: str = "run", + pose_provider: PoseProvider | None = None) -> BlobStream: ... def embedding_stream(self, name: str, payload_type: type, *, dim: int, retention: str = "run", - parent: StreamBase | None = None) -> EmbeddingStream: ... + parent: StreamBase | None = None, + pose_provider: PoseProvider | None = None) -> EmbeddingStream: ... def text_stream(self, name: str, payload_type: type, *, tokenizer: str = "unicode61", retention: str = "run", - parent: StreamBase | None = None) -> TextStream: ... + parent: StreamBase | None = None, + pose_provider: PoseProvider | None = None) -> TextStream: ... def list_streams(self) -> list[StreamInfo]: ... - def execute(self, sql: str, params=()) -> list: ... def close(self) -> None: ... def __enter__ / __exit__ ``` -SqliteSession: -- Holds one `sqlite3.Connection` -- `stream()` / `embedding_stream()` / `text_stream()`: creates tables if needed (see schema below), caches StreamBase instances -- Registers stream metadata in a `_streams` registry table: - -```sql -CREATE TABLE _streams ( - rowid INTEGER PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - type TEXT NOT NULL, -- 'blob', 'embedding', 'text' - payload_type TEXT NOT NULL, - parent_stream_id INTEGER, -- FK to _streams.rowid (lineage) - retention TEXT DEFAULT 'run' -); -``` - -### Phase 2: Stream + Query + ObservationSet - -5. **`stream.py`** — Stream hierarchy (subclassed by data type) +### `stream.py` — Stream hierarchy ABCs ```python class StreamBase(ABC, Generic[T]): - """Abstract base: meta + payload + spatial index. No text/vector indexes.""" + """Abstract base. 
No text/vector indexes.""" + pose_provider: PoseProvider | None = None # auto-fills pose on append if set + # Write def append(self, payload: T, *, ts: float | None = None, # defaults to time.time() - pose: PoseLike | None = None, + pose: PoseLike | None = None, # explicit pose overrides provider tags: dict[str, Any] | None = None, parent_rowid: int | None = None, ) -> ObservationRef: ... @@ -139,38 +118,28 @@ class StreamBase(ABC, Generic[T]): # Introspection def meta(self, ref: ObservationRef) -> ObservationRow: ... - def info(self) -> dict[str, Any]: ... - def stats(self) -> dict[str, Any]: ... class BlobStream(StreamBase[T]): - """Concrete stream for arbitrary LCM-serializable payloads. No special indexes.""" + """Stream for arbitrary serializable payloads. No special indexes.""" class EmbeddingStream(StreamBase[T]): - """Stream with a vec0 vector index. No _payload table — the vector in _vec IS the data.""" - def __init__(self, ..., *, dim: int): ... + """Stream with vector index. No payload table — the vector IS the data.""" def vector(self, ref: ObservationRef) -> list[float] | None: ... - # append() inserts into _meta + _vec only (no _payload) # load() not supported — use vector() instead - # search_embedding() on Query is valid only for EmbeddingStream class TextStream(StreamBase[T]): - """Stream with an FTS5 index. append() also inserts into _fts table.""" - def __init__(self, ..., *, tokenizer: str = "unicode61"): ... - # search_text() on Query is valid only for TextStream + """Stream with FTS index.""" ``` -`append()` inserts a metadata row (SQLite auto-assigns `rowid`), serializes payload via `lcm_encode()` into `_payload` BLOB, and inserts an R*Tree entry if pose is provided. `EmbeddingStream.append()` also inserts into the `_vec` table; `TextStream.append()` also inserts into the `_fts` table. Returns `ObservationRef(stream, rowid)`. `load()` deserializes via `lcm_decode()` using the stream's `payload_type`. - -6. 
**`query.py`** — Query (chainable, capability-aware) +### `query.py` — Query ABC ```python -class Query(Generic[T]): +class Query(ABC, Generic[T]): # Hard filters - def filter_time(self, t1: float, t2: float) -> Query[T]: ... - def filter_before(self, t: float) -> Query[T]: ... - def filter_after(self, t: float) -> Query[T]: ... - def filter_near(self, pose: PoseLike, radius: float, *, - include_unlocalized: bool = False) -> Query[T]: ... + def time_range(self, t1: float, t2: float) -> Query[T]: ... + def before(self, t: float) -> Query[T]: ... + def after(self, t: float) -> Query[T]: ... + def filter_tags(self, **tags: Any) -> Query[T]: ... def filter_refs(self, refs: list[ObservationRef]) -> Query[T]: ... def at(self, t: float, *, tolerance: float = 1.0) -> Query[T]: ... @@ -189,25 +158,15 @@ class Query(Generic[T]): def fetch_set(self) -> ObservationSet[T]: ... def count(self) -> int: ... def one(self) -> ObservationRow: ... + def last(self) -> ObservationRow: ... ``` -TODO: we want terminals also that generate some general spatial or temporal summary, maybe as a numpy array even +TODO: terminals that generate spatial or temporal summaries (maybe as numpy arrays). -Query internals: -- Accumulates filter predicates, search ops, rank spec, ordering, limit -- `at(t, tolerance)` → sugar for `filter_time(t - tol, t + tol)` + `ORDER BY ABS(ts - t) LIMIT 1` -- `order_by(field, desc)` → appends `ORDER BY` clause; valid fields: `ts` -- `fetch()`: generates SQL, executes, returns rows -- `fetch_set()`: creates an ObservationSet (predicate-backed or ref-table-backed) -- search_embedding → sqlite-vec `MATCH`, writes top-k to temp table → ref-table-backed -- search_text → FTS5 `MATCH` -- filter_near → R*Tree range query -- rank → computes composite score from available score columns - -7. 
**`observation_set.py`** — ObservationSet (lazy, re-queryable) +### `observation_set.py` — ObservationSet ABC ```python -class ObservationSet(Generic[T]): +class ObservationSet(ABC, Generic[T]): # Re-query def query(self) -> Query[T]: ... @@ -224,14 +183,61 @@ class ObservationSet(Generic[T]): # Cross-stream def project_to(self, stream: StreamBase) -> ObservationSet: ... - # Cleanup (ref-table-backed only; no-op for predicate-backed) + # Cleanup def close(self) -> None: ... def __enter__(self) -> Self: ... def __exit__(self, *exc) -> None: ... def __del__(self) -> None: ... # best-effort fallback ``` -Internal backing (spec §8): +--- + +## SQLite Implementation (`dimos/memory2/impl/sqlite/`) + +### `store.py` — SqliteStore + +- Stores file path, creates parent dirs on connect +- `_connect()`: `sqlite3.connect()`, WAL mode, loads sqlite-vec (optional), loads FTS5 +- Tracks sessions via `WeakSet` for cleanup +- `:memory:` uses `file::memory:?cache=shared` URI +- Thread safety: each session = one connection, no `check_same_thread=False` + +### `session.py` — SqliteSession + +- Holds one `sqlite3.Connection` +- `stream()` / `embedding_stream()` / `text_stream()`: creates tables if needed, caches stream instances +- Registers stream metadata in a `_streams` registry table: + +```sql +CREATE TABLE _streams ( + rowid INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + type TEXT NOT NULL, -- 'blob', 'embedding', 'text' + payload_type TEXT NOT NULL, + parent_stream_id INTEGER, -- FK to _streams.rowid (lineage) + retention TEXT DEFAULT 'run' +); +``` + +### `stream.py` — SqliteBlobStream, SqliteEmbeddingStream, SqliteTextStream + +`append()` inserts a metadata row (SQLite auto-assigns `rowid`), serializes payload into `_payload`, and inserts an R*Tree entry if pose is provided. `EmbeddingStream.append()` inserts into `_vec` only (no `_payload`). `TextStream.append()` inserts into both `_payload` (as TEXT) and `_fts`. Returns `ObservationRef(stream, rowid)`. 
+ +### `query.py` — SqliteQuery + +- Accumulates filter predicates, search ops, rank spec, ordering, limit +- `at(t, tolerance)` → sugar for `filter_time(t - tol, t + tol)` + `ORDER BY ABS(ts - t) LIMIT 1` +- `order_by(field, desc)` → appends `ORDER BY` clause; valid fields: `ts` +- `fetch()`: generates SQL, executes, returns rows +- `fetch_set()`: creates ObservationSet (predicate-backed or ref-table-backed) +- `search_embedding` → sqlite-vec `MATCH`, writes top-k to temp table → ref-table-backed +- `search_text` → FTS5 `MATCH` +- `filter_near` → R*Tree range query +- `rank` → computes composite score from available score columns + +### `observation_set.py` — SqliteObservationSet + +Internal backing: ```python @dataclass @@ -250,135 +256,91 @@ class RefTableBacking: - `.query()` on predicate-backed → adds more predicates - `.query()` on ref-table-backed → filters within that temp table -- `project_to()` → joins backing refs via lineage parent_refs to target stream +- `project_to()` → joins backing refs via lineage parent_rowid to target stream - `close()` drops the temp table for ref-table-backed sets; no-op for predicate-backed - Supports context manager (`with`) for deterministic cleanup; `__del__` as fallback - SQLite connection close is the final safety net for any leaked temp tables -### Phase 3: Later (not in first PR) +### `_sql.py` — SQL helpers -- `derive()` with Transform protocol -- `CompositeBacking` (union/intersection/difference) -- `Correlator` / `s.correlate()` -- `retention` enforcement / cleanup -- Full introspection (stats, spatial_bounds) +```python +def validate_identifier(name: str) -> str: ... # regex check, length limit +``` + +Pose extraction: `_extract_pose(p: PoseLike) -> tuple[float, ...]` pulls `(x, y, z, qx, qy, qz, qw)`. `_reconstruct_pose(row) -> Pose` rebuilds from stored floats. -## SQLite Schema (per stream) +Payload serialization: `lcm_encode(payload)` / `lcm_decode(blob, payload_type)`. 
Non-LCM types rejected at `append()` with `TypeError`. -### Metadata table: `{name}_meta` +--- +### Schema (per stream) + +**`{name}_meta`** — metadata for all stream types: ```sql CREATE TABLE {name}_meta ( - rowid INTEGER PRIMARY KEY, -- auto-assigned, used by R*Tree/FTS/vec0 + rowid INTEGER PRIMARY KEY, ts REAL, pose_x REAL, pose_y REAL, pose_z REAL, pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, - tags TEXT, -- JSON (robot_id, frame_id, etc.) - parent_rowid INTEGER -- lineage: rowid in parent stream (parent defined at stream level) + tags TEXT, -- JSON + parent_rowid INTEGER -- lineage: rowid in parent stream ); CREATE INDEX idx_{name}_meta_ts ON {name}_meta(ts); ``` -### Payload table: `{name}_payload` - +**`{name}_payload`** — BlobStream and TextStream only (not EmbeddingStream): ```sql CREATE TABLE {name}_payload ( rowid INTEGER PRIMARY KEY, -- matches _meta.rowid - data BLOB NOT NULL + data BLOB NOT NULL -- TextStream stores TEXT here instead of BLOB ); ``` -Separate from meta so queries never touch payload BLOBs. - -### R*Tree (spatial index): `{name}_rtree` - +**`{name}_rtree`** — all stream types (rows with pose only): ```sql CREATE VIRTUAL TABLE {name}_rtree USING rtree( - rowid, -- matches _meta rowid - min_t, max_t, -- both set to ts (point, not range) - min_x, max_x, -- both set to pose_x - min_y, max_y, -- both set to pose_y - min_z, max_z -- both set to pose_z + rowid, + min_t, max_t, -- both set to ts + min_x, max_x, min_y, max_y, min_z, max_z -- both set to pose_xyz ); ``` -Only rows with pose get R*Tree entries (unlocalized != everywhere). -R*Tree `rowid` matches `_meta.rowid` directly — no mapping needed. -Time-only queries use the B-tree index on `_meta.ts` (faster than R*Tree for 1D). -Spatial or spatio-temporal queries use the R*Tree. 
- -### FTS5 (text search): `{name}_fts` - +**`{name}_fts`** — TextStream only: ```sql -CREATE VIRTUAL TABLE {name}_fts USING fts5( - content -); +CREATE VIRTUAL TABLE {name}_fts USING fts5(content); ``` -Created by `TextStream` subclass only. Standalone FTS5 table (no `content=` sync). -`TextStream.append()` inserts the text into both `_payload` (as TEXT, not BLOB — `TextStream` overrides payload storage) and into the FTS5 table with the same rowid. -FTS5 `rowid` matches `_meta.rowid`. - -### Vector index (embedding search): `{name}_vec` - +**`{name}_vec`** — EmbeddingStream only: ```sql -CREATE VIRTUAL TABLE {name}_vec USING vec0( - embedding float[{dim}] -); -``` - -`rowid` matches meta rowid. Created by `EmbeddingStream` subclass only. - -## Key Design Decisions - -### Pose handling - -All pose parameters accept `PoseLike` (`Pose | PoseStamped | Point | PointStamped` from `dimos.msgs.geometry_msgs`). No custom pose type. - -```python -from dimos.msgs.geometry_msgs import Pose, Point - -images.append(frame, pose=robot_pose) # Pose object -q.filter_near(Point(1, 2, 3), radius=5.0) # Point object +CREATE VIRTUAL TABLE {name}_vec USING vec0(embedding float[{dim}]); ``` -Internally, `_extract_pose(p: PoseLike) -> tuple[float, ...]` pulls `(x, y, z, qx, qy, qz, qw)` for SQL columns. `ObservationRow.pose` returns a reconstructed `dimos.msgs.geometry_msgs.Pose`. - -### Payload serialization - -Only LCM message types are storable. `append()` calls `lcm_encode(payload)`, `load()` calls `lcm_decode(blob, payload_type)`. Non-LCM types are rejected at `append()` time with a `TypeError`. - -### ObservationRef identity - -`rowid` is an auto-assigned SQLite integer. Unique within a stream. `ObservationRef(stream, rowid)` is globally unique within a session. +All virtual table rowids match `_meta.rowid` directly. -### Unlocalized observations +--- -Rows without pose are NOT inserted into R*Tree. `filter_near()` excludes them by default. 
`include_unlocalized=True` bypasses R*Tree and scans meta table. +## Phase 3: Later (not in first PR) -### Separate payload table - -Payload BLOBs live in `{name}_payload`, separate from `{name}_meta`. This ensures queries (which only touch meta + indexes) never page in multi-MB image blobs. - -## Existing Code to Reuse +- `derive()` with Transform protocol +- `CompositeBacking` (union/intersection/difference) +- `Correlator` / `s.correlate()` +- `retention` enforcement / cleanup +- Full introspection (stats, spatial_bounds) -- `dimos/memory/timeseries/sqlite.py:29` — `_validate_identifier()` regex pattern -- `dimos/msgs/geometry_msgs/Pose.py` — DimOS Pose type, `PoseLike` type alias -- `dimos/msgs/geometry_msgs/Point.py` — Point type -- `dimos/core/resource.py` — Resource ABC (start/stop/dispose) -- LCM `lcm_encode()` / `lcm_decode()` — payload serialization +## Design Decisions -## Verification +### API-level -1. `uv run pytest dimos/memory2/test_memory2.py -v` — all tests pass -2. `uv run mypy dimos/memory2/` — type checks clean -3. `uv run pytest dimos/memory/timeseries/test_base.py -v` — existing tests untouched +- **Poses**: all pose params accept `PoseLike` (`Pose | PoseStamped | Point | PointStamped`). No custom pose type. +- **ObservationRef identity**: `rowid` is auto-assigned integer. `ObservationRef(stream, rowid)` is globally unique within a session. +- **Unlocalized observations**: rows without pose excluded from `filter_near()` by default. `include_unlocalized=True` to include them. +- **Stream hierarchy**: `StreamBase` (ABC) → `BlobStream`, `EmbeddingStream`, `TextStream`. Indexing is determined by stream type, not config. +- **Lineage**: parent stream defined at stream level (in `_streams` registry). Per-row `parent_rowid` links to specific row in parent. 
-### Test scenarios (map to spec §16 acceptance examples) +### SQLite-specific -- Re-query narrowed data: `filter_time → fetch_set → query → filter_near → fetch_set` -- fetch_set does not load payloads: verify no BLOB reads until explicit `load()` -- Embedding search: `search_embedding → filter_time → limit → fetch_set` → ref-table backed -- Projection: `emb_matches.project_to(images)` → fetch page → load_many -- Paginated preview: `fetch_page(limit=24, offset=0)` returns ObservationRows -- Unlocalized exclusion: rows without pose excluded from `filter_near` by default +- **Separate payload table**: `_payload` separate from `_meta` so queries never page in multi-MB blobs. +- **EmbeddingStream has no payload table**: the vector in `_vec` IS the data. +- **R*Tree for spatio-temporal**: time-only queries use B-tree index on `_meta.ts` (faster for 1D). Spatial/spatio-temporal queries use R*Tree. +- **Payload serialization**: `lcm_encode()` / `lcm_decode()`. Non-LCM types rejected at `append()` with `TypeError`. +- **ObservationSet cleanup**: ref-table-backed sets use SQLite temp tables. Cleaned via context manager, `__del__` fallback, or connection close. diff --git a/plans/memory3_answers.md b/plans/memory3_answers.md new file mode 100644 index 0000000000..24331dc002 --- /dev/null +++ b/plans/memory3_answers.md @@ -0,0 +1,67 @@ +# Memory2 API Answers + +Worked examples against the API defined in `memory3.md`. + +## Q1: "Where was I when this log line was added?" + +> Pose lookup, correlating to log lines found. Assume log lines have poses associated. Assume there are multiple log lines matching a search. 
+ +### Setup + +```python +store = SqliteStore("/data/robot.db") +session = store.session() + +# TextStream for robot logs — pose auto-filled from TF tree +logs = session.text_stream("logs", payload_type=str, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +# At runtime, just append text — pose is filled automatically +logs.append("Motor fault on joint 3") +logs.append("Obstacle detected ahead") +logs.append("Motor fault on joint 3") +``` + +### Single log line lookup + +```python +row = logs.query().search_text("motor fault on joint 3").one() +print(f"Robot was at {row.pose} when this log was added (t={row.ts})") +``` + +`search_text()` uses FTS5 keyword matching. `one()` returns the best match. The pose comes straight from `_meta` — no joins or extra queries needed. + +### Multiple matches + +```python +rows = logs.query().search_text("motor fault").order_by("ts").fetch() + +for row in rows: + text = logs.load(row.ref) # load actual log text from _payload + print(f"t={row.ts} pose={row.pose}: {text}") +``` + +### Spatial aggregation — "where do motor faults cluster?" 
+ +```python +rows = logs.query().search_text("motor fault").fetch() + +# Group by proximity (application-level, not part of core API) +from collections import defaultdict +clusters = defaultdict(list) +for row in rows: + # bucket by 2m grid + key = (round(row.pose.x / 2) * 2, round(row.pose.y / 2) * 2) + clusters[key].append(row) + +for loc, group in clusters.items(): + print(f" {len(group)} motor faults near {loc}") +``` + +### What's exercised + +- `TextStream` with FTS index for keyword search +- `search_text()` → FTS5 `MATCH` +- Pose stored at append time, returned in `ObservationRow.pose` +- `load()` to retrieve actual text payload separately from metadata +- `order_by("ts")` for chronological ordering diff --git a/plans/query_objects.md b/plans/query_objects.md new file mode 100644 index 0000000000..bf86d39675 --- /dev/null +++ b/plans/query_objects.md @@ -0,0 +1,155 @@ +# Query Objects — 4D Region + Soft Scoring System + +## Problem + +We need to query observations across 4 dimensions (x, y, z, t) plus embedding space. Current API has flat `filter_*` methods — works for simple cases but doesn't compose. We need: + +1. **Regions** — composable hard boundaries (include/exclude) +2. **Fields** — soft scoring that biases toward a point/time/embedding without hard cutoffs +3. A way to combine both in a single query + +## Key Insight + +Hard filters and soft biases are the same thing at different extremes: +- Hard filter = step function (1 inside, 0 outside) +- Soft bias = smooth decay (gaussian, linear, etc.) + +A unified **Criterion** type handles both. Each criterion maps an observation to a score in `[0, 1]`. Hard filters are just criteria with score `{0, 1}`. 
+ +## Primitives + +### Temporal + +```python +# Hard boundaries +TimeRange(t1, t2) # 1 inside, 0 outside +Before(t) # sugar for TimeRange(-inf, t) +After(t) # sugar for TimeRange(t, inf) + +# Soft — score decays with distance from target +TimeProximity(target, sigma=60.0) # gaussian: exp(-dt²/2σ²) +``` + +### Spatial + +```python +# Hard boundaries +Sphere(center: PoseLike, radius: float) # 1 inside, 0 outside +Box(min: PoseLike, max: PoseLike) # axis-aligned bounding box +HeightRange(z_min, z_max) # horizontal slice + +# Soft +SpatialProximity(point: PoseLike, sigma=5.0) # gaussian in 3D +``` + +### Embedding + +```python +# Soft only (no hard boundary in embedding space makes sense) +EmbeddingSimilarity(vector, candidate_k=100) # cosine similarity, top-k pre-filter +``` + +### Tags + +```python +TagMatch(robot_id="robot1") # hard: exact match on tag values +``` + +## Composition + +Criteria compose via set operators: + +```python +# Intersection — all criteria must score > 0 +region = TimeRange(t1, t2) & Sphere(point, 5.0) + +# Union — any criterion scoring > 0 passes +region = Sphere(p1, 3.0) | Sphere(p2, 3.0) + +# Complement +region = ~TimeRange(t1, t2) # everything outside this window +``` + +For soft criteria, composition combines scores: +- `a & b` → `min(a.score, b.score)` (conservative) +- `a | b` → `max(a.score, b.score)` (permissive) + +## Weighted Scoring + +The interesting problem: "I care about embedding similarity, temporal proximity, AND spatial proximity" — but as soft preferences, not hard cutoffs. + +```python +Score( + time=TimeProximity(target_t, sigma=60), + space=SpatialProximity(point, sigma=5.0), + embedding=EmbeddingSimilarity(vector, candidate_k=200), + weights={"time": 0.3, "space": 0.3, "embedding": 0.4} +) +``` + +Each dimension produces a `[0, 1]` score. Final score = weighted sum. This replaces the vague `rank(**weights)` in the current API. 
+ +## Integration with Query + +```python +# Current flat API (still works, sugar for simple cases) +q.after(t).near(point, 5.0).search_embedding(vec, candidate_k=100) + +# Region object approach +region = After(t) & Sphere(point, 5.0) +q.where(region).search_embedding(vec, candidate_k=100) + +# Full soft scoring — no hard boundaries, just preferences +q.score( + time=TimeProximity(target_t, sigma=120), + space=SpatialProximity(point, sigma=10.0), + embedding=EmbeddingSimilarity(vec, candidate_k=500), +).limit(20) + +# Mixed — hard boundary + soft ranking within +q.where(TimeRange(t1, t2)).score( + space=SpatialProximity(point, sigma=5.0), + embedding=EmbeddingSimilarity(vec, candidate_k=200), +).limit(10) +``` + +## SQL Mapping (SQLite impl) + +How each primitive maps to SQL: + +| Criterion | SQL Strategy | +|--------------------------|-------------------------------------------------------| +| `TimeRange(t1, t2)` | `WHERE ts BETWEEN ? AND ?` (B-tree) | +| `Before(t)` / `After(t)` | `WHERE ts < ?` / `WHERE ts > ?` | +| `Sphere(p, r)` | R*Tree range query on `_rtree` | +| `HeightRange(z1, z2)` | `WHERE pose_z BETWEEN ? AND ?` | +| `Box(min, max)` | R*Tree range query | +| `TimeProximity(t, σ)` | `ORDER BY ABS(ts - ?) ASC` or compute score in SELECT | +| `SpatialProximity(p, σ)` | R*Tree range (pre-filter at ~3σ) + score in SELECT | +| `EmbeddingSimilarity` | sqlite-vec `MATCH` → temp table | +| `TagMatch` | `WHERE json_extract(tags, ?) = ?` | + +Soft scoring strategy: **generous hard pre-filter in SQL, then score in Python**. 
+- Each soft criterion auto-generates a hard pre-filter at ~3σ (captures 99.7% of relevant results) +- `TimeProximity(t, σ=60)` → SQL: `WHERE ts BETWEEN t-180 AND t+180` (B-tree) +- `SpatialProximity(p, σ=5)` → SQL: R*Tree range query with 15m box +- `EmbeddingSimilarity` → sqlite-vec `MATCH` top-k (already a pre-filter) +- Python computes `[0, 1]` scores on the pre-filtered set, applies weights, sorts + +This keeps SQL simple (range queries on indexes) and Python handles the math. + +## Open Questions + +1. **How does `Score` interact with `search_embedding`?** Embedding search already returns ranked results from vec0. Should `Score.embedding` just re-weight those scores, or does it need a separate search pass? + +2. **Region objects as first-class types?** Do we store/serialize regions (e.g., "the kitchen region" as a reusable spatial boundary)? Or are they always constructed in code? + +3. **Do we need `NOT` regions for exclusion zones?** E.g., "everywhere except within 2m of the charging station." `~Sphere(charger, 2.0)` — complement on spatial regions requires scanning all of `_meta`, can't use R*Tree efficiently. + +4. **Gradient fields?** "Prefer observations taken at higher elevation" — not proximity to a point but a directional preference. `HeightGradient(ascending=True)` as a scorer? + +## Priority + +- **Phase 1**: Keep the flat `filter_*` / `rank()` API. Implement primitives internally. +- **Phase 2**: Expose `Criterion` objects + `where()` + `score()` as the composable API. +- **Phase 3**: Region persistence, named regions, gradient fields. diff --git a/plans/questions.md b/plans/questions.md index 4010bd55f4..bc91b9f306 100644 --- a/plans/questions.md +++ b/plans/questions.md @@ -1,7 +1,9 @@ # Questions 1. "where was I when this log line was added?" -- pose lookup from a timestamp +- pose lookup, correlating to log lines found +- assume log line has a pose associated +- assume there are multiple log lines matching a search 2. 
"how long have I been observing the red socks currently in view?" - how many times did I see them before? From cc5939b29acfd070898d2b2182d9d2da03166051 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 13:27:59 +0800 Subject: [PATCH 005/118] mem3 iteration --- plans/memory3.md | 77 ++++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/plans/memory3.md b/plans/memory3.md index b81bdb8329..2534aa5a15 100644 --- a/plans/memory3.md +++ b/plans/memory3.md @@ -8,7 +8,7 @@ Check `questions.md` ``` dimos/memory2/ __init__.py # public exports (re-exports from API + default backend) - types.py # ObservationRef, ObservationRow, Lineage, StreamInfo + types.py # ObservationRow, Lineage, StreamInfo store.py # Store ABC (Resource lifecycle) session.py # Session ABC (stream factory) stream.py # StreamBase, BlobStream, EmbeddingStream, TextStream ABCs @@ -33,23 +33,26 @@ dimos/memory2/ ### `types.py` — Data classes ```python -@dataclass(frozen=True) -class ObservationRef: - stream: str - rowid: int - @dataclass class ObservationRow: - ref: ObservationRef + id: int ts: float | None = None - pose: PoseLike | None = None - scores: dict[str, float] = field(default_factory=dict) # query-time only (from rank/search), not stored + pose: PoseStamped | None = None tags: dict[str, Any] = field(default_factory=dict) + _data: Any = field(default=None, repr=False) + _load: Callable[[], Any] | None = field(default=None, repr=False) + + @property + def data(self) -> Any: + """Lazy payload access. 
Pre-populated for appended events, fetched on demand for query results.""" + if self._data is None and self._load is not None: + self._data = self._load() + return self._data @dataclass class Lineage: parent_stream: str | None = None # from _streams registry (stream-level) - parent_rowid: int | None = None # per-row: which row in parent stream + parent_id: int | None = None # per-row: which row in parent stream @dataclass class StreamInfo: @@ -78,16 +81,11 @@ PoseProvider = Callable[[], PoseLike | None] class Session(ABC): def stream(self, name: str, payload_type: type, *, - retention: str = "run", pose_provider: PoseProvider | None = None) -> BlobStream: ... - def embedding_stream(self, name: str, payload_type: type, *, - dim: int, retention: str = "run", - parent: StreamBase | None = None, - pose_provider: PoseProvider | None = None) -> EmbeddingStream: ... + def embedding_stream(self, name: str, *, + model: EmbeddingModel) -> EmbeddingStream: ... def text_stream(self, name: str, payload_type: type, *, tokenizer: str = "unicode61", - retention: str = "run", - parent: StreamBase | None = None, pose_provider: PoseProvider | None = None) -> TextStream: ... def list_streams(self) -> list[StreamInfo]: ... def close(self) -> None: ... @@ -106,26 +104,34 @@ class StreamBase(ABC, Generic[T]): ts: float | None = None, # defaults to time.time() pose: PoseLike | None = None, # explicit pose overrides provider tags: dict[str, Any] | None = None, - parent_rowid: int | None = None, - ) -> ObservationRef: ... + parent_id: int | None = None, + ) -> ObservationRow: ... # returned row has .data pre-populated + + # Reactive + @property + def appended(self) -> Observable[ObservationRow]: ... # .data pre-populated # Read def query(self) -> Query[T]: ... - def load(self, ref: ObservationRef) -> T: ... - def load_many(self, refs: list[ObservationRef], *, batch_size=32) -> list[T]: ... - def iter_meta(self, *, page_size=128) -> Iterator[list[ObservationRow]]: ... 
+ def load(self, row: ObservationRow) -> T: ... + def load_many(self, rows: list[ObservationRow], *, batch_size=32) -> list[T]: ... def count(self) -> int: ... - # Introspection - def meta(self, ref: ObservationRef) -> ObservationRow: ... - class BlobStream(StreamBase[T]): """Stream for arbitrary serializable payloads. No special indexes.""" class EmbeddingStream(StreamBase[T]): """Stream with vector index. No payload table — the vector IS the data.""" - def vector(self, ref: ObservationRef) -> list[float] | None: ... - # load() not supported — use vector() instead + model: EmbeddingModel + + def attach(self, parent: StreamBase) -> Self: + """Sets lineage parent + subscribes to parent.appended to auto-embed.""" + # parent.appended.pipe( + # ops.map(lambda row: self._embed_and_store(row)), + # ).subscribe() + ... + + def vector(self, row: ObservationRow) -> list[float] | None: ... class TextStream(StreamBase[T]): """Stream with FTS index.""" @@ -141,15 +147,12 @@ class Query(ABC, Generic[T]): def after(self, t: float) -> Query[T]: ... def filter_tags(self, **tags: Any) -> Query[T]: ... - def filter_refs(self, refs: list[ObservationRef]) -> Query[T]: ... def at(self, t: float, *, tolerance: float = 1.0) -> Query[T]: ... # Candidate generation (raise TypeError if stream lacks the required index) def search_text(self, text: str, *, candidate_k: int | None = None) -> Query[T]: ... def search_embedding(self, vector: list[float], *, candidate_k: int) -> Query[T]: ... - # Ranking + ordering + limit - def rank(self, **weights: float) -> Query[T]: ... def order_by(self, field: str, *, desc: bool = False) -> Query[T]: ... def limit(self, k: int) -> Query[T]: ... @@ -171,9 +174,8 @@ class ObservationSet(ABC, Generic[T]): def query(self) -> Query[T]: ... # Read - def load(self, ref: ObservationRef) -> T: ... - def load_many(self, refs, *, batch_size=32) -> list[T]: ... - def refs(self, *, limit=None) -> list[ObservationRef]: ... + def load(self, row: ObservationRow) -> T: ... 
+ def load_many(self, rows, *, batch_size=32) -> list[T]: ... def rows(self, *, limit=None) -> list[ObservationRow]: ... def one(self) -> ObservationRow: ... def fetch_page(self, *, limit=128, offset=0) -> list[ObservationRow]: ... @@ -214,14 +216,13 @@ CREATE TABLE _streams ( name TEXT UNIQUE NOT NULL, type TEXT NOT NULL, -- 'blob', 'embedding', 'text' payload_type TEXT NOT NULL, - parent_stream_id INTEGER, -- FK to _streams.rowid (lineage) - retention TEXT DEFAULT 'run' + parent_stream_id INTEGER -- FK to _streams.rowid (lineage) ); ``` ### `stream.py` — SqliteBlobStream, SqliteEmbeddingStream, SqliteTextStream -`append()` inserts a metadata row (SQLite auto-assigns `rowid`), serializes payload into `_payload`, and inserts an R*Tree entry if pose is provided. `EmbeddingStream.append()` inserts into `_vec` only (no `_payload`). `TextStream.append()` inserts into both `_payload` (as TEXT) and `_fts`. Returns `ObservationRef(stream, rowid)`. +`append()` inserts a metadata row (SQLite auto-assigns `rowid`), serializes payload into `_payload`, and inserts an R*Tree entry if pose is provided. `EmbeddingStream.append()` inserts into `_vec` only (no `_payload`). `TextStream.append()` inserts into both `_payload` (as TEXT) and `_fts`. Returns `ObservationRow` with `id`, `ts`, `pose`, `tags` populated. ### `query.py` — SqliteQuery @@ -332,10 +333,10 @@ All virtual table rowids match `_meta.rowid` directly. ### API-level - **Poses**: all pose params accept `PoseLike` (`Pose | PoseStamped | Point | PointStamped`). No custom pose type. -- **ObservationRef identity**: `rowid` is auto-assigned integer. `ObservationRef(stream, rowid)` is globally unique within a session. +- **Row identity**: `id` is auto-assigned integer per stream. Unique within a stream. Impl layer maps to SQLite `rowid`. - **Unlocalized observations**: rows without pose excluded from `filter_near()` by default. `include_unlocalized=True` to include them. 
- **Stream hierarchy**: `StreamBase` (ABC) → `BlobStream`, `EmbeddingStream`, `TextStream`. Indexing is determined by stream type, not config. -- **Lineage**: parent stream defined at stream level (in `_streams` registry). Per-row `parent_rowid` links to specific row in parent. +- **Lineage**: parent stream defined at stream level (in `_streams` registry). Per-row `parent_id` links to specific row in parent. ### SQLite-specific From a5d3f3cfb8a8fa46715d73bcd2fa4a4ee47211c9 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 13:33:32 +0800 Subject: [PATCH 006/118] live/passive transforms --- plans/transform.md | 180 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 plans/transform.md diff --git a/plans/transform.md b/plans/transform.md new file mode 100644 index 0000000000..0d82481304 --- /dev/null +++ b/plans/transform.md @@ -0,0 +1,180 @@ +# Transform — Unified Derived Stream API + +## Concept + +`.transform()` is a single method on `StreamBase` that handles both historical (batch) and live (reactive) processing. It takes data from a source, applies a function, and stores results into the target stream with lineage. + +## API + +```python +class StreamBase(ABC, Generic[T]): + def transform(self, + source: StreamBase | ObservationSet, + fn: Callable[[Any], T | list[T] | None] | None = None, + *, + live: bool = False, + ) -> Self: + """ + Process source data, store results in this stream. + + Args: + source: where to read from + fn: transform function. Returns T, list[T], or None (skip). + None allowed for EmbeddingStream (uses model.embed implicitly). + live: if True, only subscribe to new appends (no backfill) + + Behavior by source type: + StreamBase → backfill existing + subscribe to live (default) + live=True → skip backfill, only subscribe + ObservationSet → batch process snapshot (live ignored) + + Returns self for chaining. 
+ """ +``` + +## Source type determines mode + +| Source | `live=False` (default) | `live=True` | +|--------|----------------------|-------------| +| `StreamBase` | backfill all existing + subscribe to `.appended` | subscribe to `.appended` only | +| `ObservationSet` | batch process the set | N/A (ignored) | + +## Transform function contract + +```python +fn: Callable[[Any], T | list[T] | None] +``` + +- Returns `T` → single result stored +- Returns `list[T]` → multiple results stored (e.g., multiple detections per frame) +- Returns `None` or `[]` → nothing stored for this input (e.g., no detections) +- `parent_id` set automatically from source row + +## Examples + +### VLM detections on images + +```python +images = session.stream("images", Image, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +detections = session.stream("cigarette_detections", VLMDetection) + +# Backfill + live +detections.transform(images, fn=lambda img: vlm.detect(img, "people with cigarettes")) + +# After this, every new image.append() triggers detection automatically +# All results are queryable +rows = detections.query().filter_after(one_hour_ago).fetch() +``` + +### Live-only (skip backfill) + +```python +detections.transform(images, fn=detect_fn, live=True) +# Only processes images appended from now on +``` + +### Historical batch on query results + +```python +# Only process images from the kitchen in the last hour +kitchen_images = images.query().filter_near(kitchen_pose, 5.0).filter_after(one_hour_ago).fetch_set() + +detections.transform(kitchen_images, fn=lambda img: vlm.detect(img, "cigarettes")) +# Batch processes the set, no live subscription +``` + +### Embedding stream (specialized) + +```python +img_emb = session.embedding_stream("img_emb", model=CLIPModel()) + +# fn is implicit — uses model.embed() +img_emb.transform(images, live=True) + +# Equivalent to: +img_emb.transform(images, fn=lambda img: clip.embed(img), live=True) +``` + +### Chaining transforms + 
+```python +images = session.stream("images", Image, pose_provider=pose_fn) + +# Embeddings from images +img_emb = session.embedding_stream("img_emb", model=CLIPModel()) +img_emb.transform(images, live=True) + +# Detections from images +detections = session.stream("detections", VLMDetection) +detections.transform(images, fn=detect_fn, live=True) + +# Text descriptions from detections (second-level derived) +descriptions = session.text_stream("descriptions", str) +descriptions.transform(detections, fn=lambda det: det.describe(), live=True) +``` + +## Internals + +### Backfill (batch) + +```python +for page in source.iter_meta(page_size=128): + for row in page: + payload = source.load(row) # or row.data + results = fn(payload) + if results is None: + continue + if not isinstance(results, list): + results = [results] + for r in results: + self.append(r, ts=row.ts, pose=row.pose, parent_id=row.id) +``` + +### Live (reactive) + +```python +source.appended.pipe( + ops.map(lambda row: (row, fn(row.data))), + ops.filter(lambda pair: pair[1] is not None), + ops.flat_map(lambda pair: [ + (pair[0], r) for r in (pair[1] if isinstance(pair[1], list) else [pair[1]]) + ]), +).subscribe(lambda pair: self.append(pair[1], ts=pair[0].ts, pose=pair[0].pose, + parent_id=pair[0].id)) +``` + +### EmbeddingStream override + +```python +class EmbeddingStream(StreamBase[T]): + model: EmbeddingModel + + def transform(self, source, fn=None, *, live=False): + if fn is None: + fn = self.model.embed + return super().transform(source, fn, live=live) +``` + +## Lineage + +`transform()` sets `parent_id` on every appended row, linking back to the source row. This enables `project_to()`: + +```python +# Find source images for cigarette detections +with detections.query().fetch_set() as det_set: + source_images = det_set.project_to(images) + for row in source_images.rows(limit=5): + img = images.load(row) +``` + +## Open questions + +1. **Async transforms?** VLM inference is slow. 
Should `fn` support async/await or rx scheduling (e.g., `observe_on(io_scheduler)`)? + +2. **Error handling?** If `fn` raises on one row, skip it? Log and continue? Configurable? + +3. **Backfill progress?** For large backfills, should `transform()` return a progress observable or run in background? + +4. **Multiple parents?** Current design is single-parent lineage. If a stream derives from two streams (e.g., fusing image + audio), we'd need multi-parent support. Phase 3. From c87d955931d39026a44f914720ff39502ec3ce85 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 14:54:53 +0800 Subject: [PATCH 007/118] initial pass on memory --- dimos/memory/__init__.py | 26 + dimos/memory/impl/__init__.py | 0 dimos/memory/impl/sqlite.py | 629 ++++++++++++++ dimos/memory/store.py | 75 ++ dimos/memory/stream.py | 388 +++++++++ dimos/memory/tests/__init__.py | 0 dimos/memory/tests/test_sqlite.py | 302 +++++++ dimos/memory/transformer.py | 101 +++ dimos/memory/types.py | 157 ++++ dimos/{memory => memory_old}/embedding.py | 0 dimos/memory_old/impl/sqlite.py | 14 + .../{memory => memory_old}/test_embedding.py | 0 .../timeseries/__init__.py | 0 .../{memory => memory_old}/timeseries/base.py | 0 .../timeseries/inmemory.py | 0 .../timeseries/legacy.py | 0 .../timeseries/pickledir.py | 0 .../timeseries/postgres.py | 0 .../timeseries/sqlite.py | 0 .../timeseries/test_base.py | 0 .../timeseries/test_legacy.py | 0 plans/memory/api.md | 477 +++++++++++ plans/{ => memory}/query_objects.md | 0 plans/{ => memory}/questions.md | 0 plans/memory/sqlite.md | 780 ++++++++++++++++++ plans/{ => memory}/transform.md | 8 +- plans/{ => old}/analysis.md | 0 plans/{ => old}/answers.md | 0 plans/{ => old}/answers_correlator.md | 0 plans/{ => old}/correlator.md | 0 plans/{ => old}/memory.md | 0 plans/{ => old}/memory1.md | 0 plans/{ => old}/memory2.md | 0 plans/{ => old}/memory3.md | 22 +- plans/{ => old}/memory3_answers.md | 0 plans/old/memory4.md | 466 +++++++++++ plans/old/transforms.md | 
21 + 37 files changed, 3456 insertions(+), 10 deletions(-) create mode 100644 dimos/memory/__init__.py create mode 100644 dimos/memory/impl/__init__.py create mode 100644 dimos/memory/impl/sqlite.py create mode 100644 dimos/memory/store.py create mode 100644 dimos/memory/stream.py create mode 100644 dimos/memory/tests/__init__.py create mode 100644 dimos/memory/tests/test_sqlite.py create mode 100644 dimos/memory/transformer.py create mode 100644 dimos/memory/types.py rename dimos/{memory => memory_old}/embedding.py (100%) create mode 100644 dimos/memory_old/impl/sqlite.py rename dimos/{memory => memory_old}/test_embedding.py (100%) rename dimos/{memory => memory_old}/timeseries/__init__.py (100%) rename dimos/{memory => memory_old}/timeseries/base.py (100%) rename dimos/{memory => memory_old}/timeseries/inmemory.py (100%) rename dimos/{memory => memory_old}/timeseries/legacy.py (100%) rename dimos/{memory => memory_old}/timeseries/pickledir.py (100%) rename dimos/{memory => memory_old}/timeseries/postgres.py (100%) rename dimos/{memory => memory_old}/timeseries/sqlite.py (100%) rename dimos/{memory => memory_old}/timeseries/test_base.py (100%) rename dimos/{memory => memory_old}/timeseries/test_legacy.py (100%) create mode 100644 plans/memory/api.md rename plans/{ => memory}/query_objects.md (100%) rename plans/{ => memory}/questions.md (100%) create mode 100644 plans/memory/sqlite.md rename plans/{ => memory}/transform.md (92%) rename plans/{ => old}/analysis.md (100%) rename plans/{ => old}/answers.md (100%) rename plans/{ => old}/answers_correlator.md (100%) rename plans/{ => old}/correlator.md (100%) rename plans/{ => old}/memory.md (100%) rename plans/{ => old}/memory1.md (100%) rename plans/{ => old}/memory2.md (100%) rename plans/{ => old}/memory3.md (93%) rename plans/{ => old}/memory3_answers.md (100%) create mode 100644 plans/old/memory4.md create mode 100644 plans/old/transforms.md diff --git a/dimos/memory/__init__.py b/dimos/memory/__init__.py new 
file mode 100644 index 0000000000..ba76a2f5ed --- /dev/null +++ b/dimos/memory/__init__.py @@ -0,0 +1,26 @@ +from dimos.memory.store import Session, Store +from dimos.memory.stream import EmbeddingStream, Stream, TextStream +from dimos.memory.transformer import ( + EmbeddingTransformer, + PerItemTransformer, + Transformer, +) +from dimos.memory.types import ( + EmbeddingObservation, + Observation, + StreamInfo, +) + +__all__ = [ + "EmbeddingObservation", + "EmbeddingStream", + "EmbeddingTransformer", + "Observation", + "PerItemTransformer", + "Session", + "Store", + "Stream", + "StreamInfo", + "TextStream", + "Transformer", +] diff --git a/dimos/memory/impl/__init__.py b/dimos/memory/impl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py new file mode 100644 index 0000000000..d3f8e3b989 --- /dev/null +++ b/dimos/memory/impl/sqlite.py @@ -0,0 +1,629 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SQLite-backed memory store implementation. + +Each stream maps to a table: + {name} — id INTEGER PK, ts REAL, pose BLOB, tags TEXT (JSON), payload BLOB + {name}_fts — FTS5 virtual table (TextStream only) + {name}_vec — vec0 virtual table (EmbeddingStream only) + +Payloads are pickled. Poses are pickled PoseStamped. Tags are JSON. 
+""" + +from __future__ import annotations + +import json +import pickle +import sqlite3 +import time +from typing import TYPE_CHECKING, Any + +from reactivex.subject import Subject + +from dimos.memory.store import Session, Store +from dimos.memory.stream import EmbeddingStream, Stream, TextStream +from dimos.memory.types import ( + AfterFilter, + AtFilter, + BeforeFilter, + EmbeddingObservation, + EmbeddingSearchFilter, + Filter, + NearFilter, + Observation, + StreamInfo, + StreamQuery, + TagsFilter, + TextSearchFilter, + TimeRangeFilter, +) + +if TYPE_CHECKING: + from dimos.memory.types import PoseProvider + + +# ── Serialization helpers ───────────────────────────────────────────── + + +def _serialize_payload(payload: Any) -> bytes: + return pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) + + +def _deserialize_payload(blob: bytes) -> Any: + return pickle.loads(blob) + + +def _serialize_pose(pose: Any) -> bytes | None: + if pose is None: + return None + return pickle.dumps(pose, protocol=pickle.HIGHEST_PROTOCOL) + + +def _deserialize_pose(blob: bytes | None) -> Any: + if blob is None: + return None + return pickle.loads(blob) + + +def _serialize_tags(tags: dict[str, Any] | None) -> str: + if not tags: + return "{}" + return json.dumps(tags, separators=(",", ":")) + + +def _deserialize_tags(text: str) -> dict[str, Any]: + if not text: + return {} + return json.loads(text) # type: ignore[no-any-return] + + +# ── SQL building ────────────────────────────────────────────────────── + + +def _compile_filter(f: Filter, table: str) -> tuple[str, list[Any]]: + """Compile a single filter to (SQL fragment, params).""" + if isinstance(f, AfterFilter): + return "ts > ?", [f.t] + if isinstance(f, BeforeFilter): + return "ts < ?", [f.t] + if isinstance(f, TimeRangeFilter): + return "ts >= ? AND ts <= ?", [f.t1, f.t2] + if isinstance(f, AtFilter): + return "ABS(ts - ?) 
<= ?", [f.t, f.tolerance] + if isinstance(f, TagsFilter): + clauses: list[str] = [] + params: list[Any] = [] + for key, val in f.tags.items(): + clauses.append(f"json_extract(tags, '$.{key}') = ?") + params.append(val) + return " AND ".join(clauses), params + if isinstance(f, NearFilter): + # Spatial filtering requires pose deserialization — done post-query + # Return a no-op SQL clause; filtering happens in Python + return "1=1", [] + if isinstance(f, EmbeddingSearchFilter): + # Handled specially by EmbeddingStream backend + return "1=1", [] + if isinstance(f, TextSearchFilter): + # Handled specially by TextStream backend + return "1=1", [] + raise TypeError(f"Unknown filter type: {type(f)}") + + +def _compile_query(query: StreamQuery, table: str) -> tuple[str, list[Any]]: + """Compile a StreamQuery to (SQL, params) for a SELECT.""" + where_parts: list[str] = [] + params: list[Any] = [] + + for f in query.filters: + sql, p = _compile_filter(f, table) + where_parts.append(sql) + params.extend(p) + + where = " AND ".join(where_parts) if where_parts else "1=1" + order = f"ORDER BY {query.order_field}" + if query.order_field: + if query.order_desc: + order += " DESC" + else: + order = "ORDER BY id" + + sql = f"SELECT id, ts, pose, tags, payload FROM {table} WHERE {where} {order}" + if query.limit_val is not None: + sql += f" LIMIT {query.limit_val}" + if query.offset_val is not None: + sql += f" OFFSET {query.offset_val}" + return sql, params + + +def _compile_count(query: StreamQuery, table: str) -> tuple[str, list[Any]]: + where_parts: list[str] = [] + params: list[Any] = [] + for f in query.filters: + sql, p = _compile_filter(f, table) + where_parts.append(sql) + params.extend(p) + where = " AND ".join(where_parts) if where_parts else "1=1" + return f"SELECT COUNT(*) FROM {table} WHERE {where}", params + + +# ── Near-filter post-processing ─────────────────────────────────────── + + +def _has_near_filter(query: StreamQuery) -> NearFilter | None: + for f in 
query.filters: + if isinstance(f, NearFilter): + return f + return None + + +def _apply_near_filter(rows: list[Observation], near: NearFilter) -> list[Observation]: + """Post-filter observations by spatial distance.""" + from dimos.msgs.geometry_msgs.Pose import to_pose + + target = to_pose(near.pose) + result: list[Observation] = [] + for obs in rows: + if obs.pose is None: + continue + obs_pose = to_pose(obs.pose) + dist = (target - obs_pose).position.norm() + if dist <= near.radius: + result.append(obs) + return result + + +# ── Backend ─────────────────────────────────────────────────────────── + + +class SqliteStreamBackend: + """StreamBackend implementation for a single SQLite-backed stream.""" + + def __init__( + self, + conn: sqlite3.Connection, + table: str, + *, + pose_provider: PoseProvider | None = None, + ) -> None: + self._conn = conn + self._table = table + self._pose_provider = pose_provider + self._subject: Subject[Observation] = Subject() # type: ignore[type-arg] + + @property + def appended_subject(self) -> Subject[Observation]: # type: ignore[type-arg] + return self._subject + + @property + def stream_name(self) -> str: + return self._table + + def do_append( + self, + payload: Any, + ts: float | None, + pose: Any | None, + tags: dict[str, Any] | None, + ) -> Observation: + if ts is None: + ts = time.time() + if pose is None and self._pose_provider is not None: + pose = self._pose_provider() + + payload_blob = _serialize_payload(payload) + pose_blob = _serialize_pose(pose) + tags_json = _serialize_tags(tags) + + cur = self._conn.execute( + f"INSERT INTO {self._table} (ts, pose, tags, payload) VALUES (?, ?, ?, ?)", + (ts, pose_blob, tags_json, payload_blob), + ) + self._conn.commit() + row_id = cur.lastrowid + assert row_id is not None + + obs = Observation( + id=row_id, + ts=ts, + pose=pose, + tags=tags or {}, + _data=payload, + ) + self._subject.on_next(obs) + return obs + + def execute_fetch(self, query: StreamQuery) -> list[Observation]: + 
sql, params = _compile_query(query, self._table) + rows = self._conn.execute(sql, params).fetchall() + + observations = [self._row_to_obs(r) for r in rows] + + near = _has_near_filter(query) + if near is not None: + observations = _apply_near_filter(observations, near) + + return observations + + def execute_count(self, query: StreamQuery) -> int: + sql, params = _compile_count(query, self._table) + result = self._conn.execute(sql, params).fetchone() + return result[0] if result else 0 # type: ignore[no-any-return] + + def _row_to_obs(self, row: Any) -> Observation: + row_id, ts, pose_blob, tags_json, payload_blob = row + return Observation( + id=row_id, + ts=ts, + pose=_deserialize_pose(pose_blob), + tags=_deserialize_tags(tags_json), + _data=_deserialize_payload(payload_blob), + ) + + +class SqliteEmbeddingBackend(SqliteStreamBackend): + """Backend for EmbeddingStream — stores vectors in a vec0 virtual table.""" + + def __init__( + self, + conn: sqlite3.Connection, + table: str, + *, + vec_dimensions: int | None = None, + pose_provider: PoseProvider | None = None, + parent_table: str | None = None, + ) -> None: + super().__init__(conn, table, pose_provider=pose_provider) + self._vec_dimensions = vec_dimensions + self._parent_table = parent_table + + def do_append( + self, + payload: Any, + ts: float | None, + pose: Any | None, + tags: dict[str, Any] | None, + ) -> Observation: + from dimos.models.embedding.base import Embedding + + obs = super().do_append(payload, ts, pose, tags) + + # Also insert into vec0 table if payload is an Embedding + if isinstance(payload, Embedding): + vec = payload.to_numpy().tolist() + if self._vec_dimensions is None: + self._vec_dimensions = len(vec) + self._ensure_vec_table() + self._conn.execute( + f"INSERT INTO {self._table}_vec (rowid, embedding) VALUES (?, ?)", + (obs.id, json.dumps(vec)), + ) + self._conn.commit() + + return obs + + def _ensure_vec_table(self) -> None: + if self._vec_dimensions is None: + return + 
self._conn.execute( + f"CREATE VIRTUAL TABLE IF NOT EXISTS {self._table}_vec " + f"USING vec0(embedding float[{self._vec_dimensions}])" + ) + self._conn.commit() + + def execute_fetch(self, query: StreamQuery) -> list[Observation]: + # Check for embedding search filter + emb_filter = None + for f in query.filters: + if isinstance(f, EmbeddingSearchFilter): + emb_filter = f + break + + if emb_filter is not None: + return self._fetch_by_vector(query, emb_filter) + + return super().execute_fetch(query) + + def _fetch_by_vector( + self, query: StreamQuery, emb_filter: EmbeddingSearchFilter + ) -> list[Observation]: + """Fetch using vec0 similarity search, then apply remaining filters.""" + # First, get candidate rowids from vec0 + vec_sql = ( + f"SELECT rowid, distance FROM {self._table}_vec " + f"WHERE embedding MATCH ? ORDER BY distance LIMIT ?" + ) + vec_rows = self._conn.execute( + vec_sql, (json.dumps(emb_filter.query), emb_filter.k) + ).fetchall() + + if not vec_rows: + return [] + + rowids = [r[0] for r in vec_rows] + placeholders = ",".join("?" 
* len(rowids)) + + # Build remaining WHERE clauses (skip the embedding filter) + where_parts: list[str] = [f"id IN ({placeholders})"] + params: list[Any] = list(rowids) + + for f in query.filters: + if isinstance(f, EmbeddingSearchFilter): + continue + sql_frag, p = _compile_filter(f, self._table) + where_parts.append(sql_frag) + params.extend(p) + + where = " AND ".join(where_parts) + sql = f"SELECT id, ts, pose, tags, payload FROM {self._table} WHERE {where}" + rows = self._conn.execute(sql, params).fetchall() + + observations = [self._row_to_obs(r) for r in rows] + + near = _has_near_filter(query) + if near is not None: + observations = _apply_near_filter(observations, near) + + return observations + + def _row_to_obs(self, row: Any) -> Observation: + row_id, ts, pose_blob, tags_json, payload_blob = row + return EmbeddingObservation( + id=row_id, + ts=ts, + pose=_deserialize_pose(pose_blob), + tags=_deserialize_tags(tags_json), + _data=_deserialize_payload(payload_blob), + ) + + +class SqliteTextBackend(SqliteStreamBackend): + """Backend for TextStream — maintains an FTS5 index.""" + + def __init__( + self, + conn: sqlite3.Connection, + table: str, + *, + tokenizer: str = "unicode61", + pose_provider: PoseProvider | None = None, + ) -> None: + super().__init__(conn, table, pose_provider=pose_provider) + self._tokenizer = tokenizer + + def do_append( + self, + payload: Any, + ts: float | None, + pose: Any | None, + tags: dict[str, Any] | None, + ) -> Observation: + obs = super().do_append(payload, ts, pose, tags) + + # Insert into FTS table + text = str(payload) if payload is not None else "" + self._conn.execute( + f"INSERT INTO {self._table}_fts (rowid, content) VALUES (?, ?)", + (obs.id, text), + ) + self._conn.commit() + return obs + + def execute_fetch(self, query: StreamQuery) -> list[Observation]: + text_filter = None + for f in query.filters: + if isinstance(f, TextSearchFilter): + text_filter = f + break + + if text_filter is not None: + return 
self._fetch_by_text(query, text_filter) + + return super().execute_fetch(query) + + def _fetch_by_text( + self, query: StreamQuery, text_filter: TextSearchFilter + ) -> list[Observation]: + # Get matching rowids from FTS + fts_sql = f"SELECT rowid, rank FROM {self._table}_fts WHERE content MATCH ? ORDER BY rank" + fts_params: list[Any] = [text_filter.text] + if text_filter.k is not None: + fts_sql += " LIMIT ?" + fts_params.append(text_filter.k) + + fts_rows = self._conn.execute(fts_sql, fts_params).fetchall() + if not fts_rows: + return [] + + rowids = [r[0] for r in fts_rows] + placeholders = ",".join("?" * len(rowids)) + + where_parts: list[str] = [f"id IN ({placeholders})"] + params: list[Any] = list(rowids) + + for f in query.filters: + if isinstance(f, TextSearchFilter): + continue + sql_frag, p = _compile_filter(f, self._table) + where_parts.append(sql_frag) + params.extend(p) + + where = " AND ".join(where_parts) + sql = f"SELECT id, ts, pose, tags, payload FROM {self._table} WHERE {where}" + rows = self._conn.execute(sql, params).fetchall() + + observations = [self._row_to_obs(r) for r in rows] + + near = _has_near_filter(query) + if near is not None: + observations = _apply_near_filter(observations, near) + + return observations + + +# ── Session ─────────────────────────────────────────────────────────── + + +class SqliteSession(Session): + """Session against a SQLite database.""" + + def __init__(self, conn: sqlite3.Connection) -> None: + self._conn = conn + self._streams: dict[str, Stream[Any]] = {} + self._ensure_meta_table() + + def _ensure_meta_table(self) -> None: + self._conn.execute( + "CREATE TABLE IF NOT EXISTS _streams (" + " name TEXT PRIMARY KEY," + " payload_type TEXT," + " stream_kind TEXT DEFAULT 'stream'" + ")" + ) + self._conn.commit() + + def stream( + self, + name: str, + payload_type: type | None = None, + *, + pose_provider: PoseProvider | None = None, + ) -> Stream[Any]: + if name in self._streams: + return self._streams[name] + + 
self._ensure_stream_table(name) + self._register_stream(name, payload_type, "stream") + + backend = SqliteStreamBackend(self._conn, name, pose_provider=pose_provider) + s: Stream[Any] = Stream(backend=backend) + self._streams[name] = s + return s + + def text_stream( + self, + name: str, + payload_type: type | None = None, + *, + tokenizer: str = "unicode61", + pose_provider: PoseProvider | None = None, + ) -> TextStream[Any]: + if name in self._streams: + return self._streams[name] # type: ignore[return-value] + + self._ensure_stream_table(name) + self._ensure_fts_table(name, tokenizer) + self._register_stream(name, payload_type, "text") + + backend = SqliteTextBackend( + self._conn, name, tokenizer=tokenizer, pose_provider=pose_provider + ) + ts: TextStream[Any] = TextStream(backend=backend) + self._streams[name] = ts + return ts + + def embedding_stream( + self, + name: str, + payload_type: type | None = None, + *, + vec_dimensions: int | None = None, + pose_provider: PoseProvider | None = None, + parent_table: str | None = None, + ) -> EmbeddingStream[Any]: + if name in self._streams: + return self._streams[name] # type: ignore[return-value] + + self._ensure_stream_table(name) + self._register_stream(name, payload_type, "embedding") + + backend = SqliteEmbeddingBackend( + self._conn, + name, + vec_dimensions=vec_dimensions, + pose_provider=pose_provider, + parent_table=parent_table, + ) + if vec_dimensions is not None: + backend._ensure_vec_table() + + es: EmbeddingStream[Any] = EmbeddingStream(backend=backend) + self._streams[name] = es + return es + + def list_streams(self) -> list[StreamInfo]: + rows = self._conn.execute("SELECT name, payload_type FROM _streams").fetchall() + result: list[StreamInfo] = [] + for name, ptype in rows: + count_row = self._conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone() + count = count_row[0] if count_row else 0 + result.append(StreamInfo(name=name, payload_type=ptype, count=count)) + return result + + def close(self) -> 
None: + for s in self._streams.values(): + if s._backend is not None: + s._backend.appended_subject.on_completed() + self._streams.clear() + + # ── Internal helpers ────────────────────────────────────────────── + + def _ensure_stream_table(self, name: str) -> None: + self._conn.execute( + f"CREATE TABLE IF NOT EXISTS {name} (" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " ts REAL," + " pose BLOB," + " tags TEXT DEFAULT '{}'," + " payload BLOB," + " parent_id INTEGER" + ")" + ) + self._conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{name}_ts ON {name}(ts)") + self._conn.commit() + + def _ensure_fts_table(self, name: str, tokenizer: str) -> None: + self._conn.execute( + f"CREATE VIRTUAL TABLE IF NOT EXISTS {name}_fts " + f"USING fts5(content, tokenize='{tokenizer}')" + ) + self._conn.commit() + + def _register_stream(self, name: str, payload_type: type | None, kind: str) -> None: + type_name = payload_type.__qualname__ if payload_type else None + self._conn.execute( + "INSERT OR IGNORE INTO _streams (name, payload_type, stream_kind) VALUES (?, ?, ?)", + (name, type_name, kind), + ) + self._conn.commit() + + +# ── Store ───────────────────────────────────────────────────────────── + + +class SqliteStore(Store): + """SQLite-backed memory store.""" + + def __init__(self, path: str) -> None: + self._path = path + self._conn = sqlite3.connect(path) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + + def session(self) -> SqliteSession: + return SqliteSession(self._conn) + + def close(self) -> None: + self._conn.close() diff --git a/dimos/memory/store.py b/dimos/memory/store.py new file mode 100644 index 0000000000..eb05ee9a41 --- /dev/null +++ b/dimos/memory/store.py @@ -0,0 +1,75 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .stream import Stream, TextStream + from .types import PoseProvider, StreamInfo + + +class Session(ABC): + """A session against a memory store. Creates and manages streams.""" + + @abstractmethod + def stream( + self, + name: str, + payload_type: type | None = None, + *, + pose_provider: PoseProvider | None = None, + ) -> Stream[Any]: + """Get or create a stored stream backed by the database.""" + + @abstractmethod + def text_stream( + self, + name: str, + payload_type: type | None = None, + *, + tokenizer: str = "unicode61", + pose_provider: PoseProvider | None = None, + ) -> TextStream[Any]: + """Get or create a text stream with FTS index.""" + + @abstractmethod + def list_streams(self) -> list[StreamInfo]: ... + + @abstractmethod + def close(self) -> None: ... + + def __enter__(self) -> Session: + return self + + def __exit__(self, *args: object) -> None: + self.close() + + +class Store(ABC): + """Top-level entry point — wraps a database file.""" + + @abstractmethod + def session(self) -> Session: ... + + @abstractmethod + def close(self) -> None: ... + + def __enter__(self) -> Store: + return self + + def __exit__(self, *args: object) -> None: + self.close() diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py new file mode 100644 index 0000000000..e2ccc50bd8 --- /dev/null +++ b/dimos/memory/stream.py @@ -0,0 +1,388 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Protocol, + TypeVar, + overload, +) + +from .types import ( + AfterFilter, + AtFilter, + BeforeFilter, + EmbeddingObservation, + EmbeddingSearchFilter, + Filter, + NearFilter, + Observation, + StreamQuery, + TagsFilter, + TextSearchFilter, + TimeRangeFilter, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from reactivex import Observable + from reactivex.subject import Subject + + from dimos.models.embedding.base import Embedding + from dimos.msgs.geometry_msgs.Pose import PoseLike + + from .transformer import Transformer + +T = TypeVar("T") +R = TypeVar("R") + + +class StreamBackend(Protocol): + """Backend protocol — implemented by SqliteStreamBackend etc.""" + + def execute_fetch(self, query: StreamQuery) -> list[Observation]: ... + def execute_count(self, query: StreamQuery) -> int: ... + def do_append( + self, + payload: Any, + ts: float | None, + pose: Any | None, + tags: dict[str, Any] | None, + ) -> Observation: ... + @property + def appended_subject(self) -> Subject[Observation]: ... # type: ignore[type-arg] + @property + def stream_name(self) -> str: ... + + +class Stream(Generic[T]): + """Lazy, chainable stream over stored observations. + + Created by Session.stream(). Filter methods return new Stream instances. + Terminals (.fetch(), .count(), etc.) execute the query. 
+ """ + + def __init__( + self, + backend: StreamBackend | None = None, + *, + query: StreamQuery | None = None, + ) -> None: + self._backend = backend + self._query = query or StreamQuery() + + def _clone(self, **overrides: Any) -> Stream[T]: + """Return a new Stream with updated query fields.""" + q = self._query + new_query = StreamQuery( + filters=overrides.get("filters", q.filters), + order_field=overrides.get("order_field", q.order_field), + order_desc=overrides.get("order_desc", q.order_desc), + limit_val=overrides.get("limit_val", q.limit_val), + offset_val=overrides.get("offset_val", q.offset_val), + ) + clone: Stream[T] = self.__class__.__new__(self.__class__) + clone._backend = self._backend + clone._query = new_query + return clone + + def _with_filter(self, f: Filter) -> Stream[T]: + return self._clone(filters=(*self._query.filters, f)) + + def _require_backend(self) -> StreamBackend: + if self._backend is None: + raise TypeError( + "Operation requires a stored stream. Call .store() first or use session.stream()." 
+ ) + return self._backend + + # ── Write ───────────────────────────────────────────────────────── + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: PoseLike | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation: + backend = self._require_backend() + return backend.do_append(payload, ts, pose, tags) + + # ── Temporal filters ────────────────────────────────────────────── + + def after(self, t: float) -> Stream[T]: + return self._with_filter(AfterFilter(t)) + + def before(self, t: float) -> Stream[T]: + return self._with_filter(BeforeFilter(t)) + + def time_range(self, t1: float, t2: float) -> Stream[T]: + return self._with_filter(TimeRangeFilter(t1, t2)) + + def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: + return self._with_filter(AtFilter(t, tolerance)) + + # ── Spatial filter ──────────────────────────────────────────────── + + def near(self, pose: PoseLike, radius: float) -> Stream[T]: + return self._with_filter(NearFilter(pose, radius)) + + # ── Tag filter ──────────────────────────────────────────────────── + + def filter_tags(self, **tags: Any) -> Stream[T]: + return self._with_filter(TagsFilter(tags)) + + # ── Ordering / pagination ───────────────────────────────────────── + + def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: + return self._clone(order_field=field, order_desc=desc) + + def limit(self, k: int) -> Stream[T]: + return self._clone(limit_val=k) + + def offset(self, n: int) -> Stream[T]: + return self._clone(offset_val=n) + + # ── Transform ───────────────────────────────────────────────────── + + @overload + def transform( + self, + xf: Transformer[T, R], + *, + live: bool = ..., + backfill_only: bool = ..., + ) -> Stream[R]: ... + + @overload + def transform( + self, + xf: Callable[[T], Any], + *, + live: bool = ..., + backfill_only: bool = ..., + ) -> Stream[Any]: ... 
+ + def transform( + self, + xf: Transformer[Any, Any] | Callable[..., Any], + *, + live: bool = False, + backfill_only: bool = False, + ) -> Stream[Any]: + from .transformer import PerItemTransformer, Transformer as TransformerABC + + transformer: TransformerABC[Any, Any] + if not isinstance(xf, TransformerABC): + transformer = PerItemTransformer(xf) + else: + transformer = xf + + return TransformStream( + source=self, + transformer=transformer, + live=live, + backfill_only=backfill_only, + ) + + # ── Materialize ─────────────────────────────────────────────────── + + def store(self, name: str | None = None) -> Stream[T]: + raise TypeError( + "store() requires a session context. This stream is not associated with a session." + ) + + # ── Cross-stream lineage ────────────────────────────────────────── + + def project_to(self, target: Stream[Any]) -> Stream[Any]: + raise NotImplementedError("project_to requires a stored stream with lineage") + + # ── Terminals ───────────────────────────────────────────────────── + + def fetch(self) -> list[Observation]: + backend = self._require_backend() + return backend.execute_fetch(self._query) + + def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: + offset = self._query.offset_val or 0 + while True: + q = StreamQuery( + filters=self._query.filters, + order_field=self._query.order_field or "id", + order_desc=self._query.order_desc, + limit_val=batch_size, + offset_val=offset, + ) + backend = self._require_backend() + page = backend.execute_fetch(q) + if not page: + break + yield page + if len(page) < batch_size: + break + offset += batch_size + + def one(self) -> Observation: + results = self.limit(1).fetch() + if not results: + raise LookupError("No matching observation") + return results[0] + + def last(self) -> Observation: + results = self.order_by("ts", desc=True).limit(1).fetch() + if not results: + raise LookupError("No matching observation") + return results[0] + + def count(self) -> int: + 
backend = self._require_backend() + return backend.execute_count(self._query) + + # ── Reactive ────────────────────────────────────────────────────── + + @property + def appended(self) -> Observable[Observation]: # type: ignore[type-arg] + backend = self._require_backend() + return backend.appended_subject # type: ignore[return-value] + + +class EmbeddingStream(Stream[T]): + """Stream with a vector index. Adds search_embedding().""" + + def search_embedding( + self, + query: Embedding | list[float], + *, + k: int, + ) -> EmbeddingStream[T]: + from dimos.models.embedding.base import Embedding as EmbeddingCls + + if isinstance(query, EmbeddingCls): + vec = query.to_numpy().tolist() + else: + vec = list(query) + clone = self._with_filter(EmbeddingSearchFilter(vec, k)) + # Preserve EmbeddingStream type + es: EmbeddingStream[T] = EmbeddingStream(backend=clone._backend, query=clone._query) + return es + + def fetch(self) -> list[EmbeddingObservation]: # type: ignore[override] + backend = self._require_backend() + return backend.execute_fetch(self._query) # type: ignore[return-value] + + def one(self) -> EmbeddingObservation: # type: ignore[override] + q = StreamQuery( + filters=self._query.filters, + order_field=self._query.order_field, + order_desc=self._query.order_desc, + limit_val=1, + offset_val=self._query.offset_val, + ) + backend = self._require_backend() + results = backend.execute_fetch(q) + if not results: + raise LookupError("No matching observation") + return results[0] # type: ignore[return-value] + + def last(self) -> EmbeddingObservation: # type: ignore[override] + q = StreamQuery( + filters=self._query.filters, + order_field="ts", + order_desc=True, + limit_val=1, + offset_val=self._query.offset_val, + ) + backend = self._require_backend() + results = backend.execute_fetch(q) + if not results: + raise LookupError("No matching observation") + return results[0] # type: ignore[return-value] + + +class TextStream(Stream[T]): + """Stream with an FTS5 index. 
Adds search_text().""" + + def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: + clone = self._with_filter(TextSearchFilter(text, k)) + ts: TextStream[T] = TextStream(backend=clone._backend, query=clone._query) + return ts + + +class TransformStream(Stream[R]): + """In-memory stream produced by .transform(). Not yet stored.""" + + def __init__( + self, + source: Stream[Any], + transformer: Transformer[Any, R], + *, + live: bool = False, + backfill_only: bool = False, + ) -> None: + super().__init__(backend=None) + self._source = source + self._transformer = transformer + self._live = live + self._backfill_only = backfill_only + + def fetch(self) -> list[Observation]: + """Execute transform in memory, collecting results.""" + collector = _CollectorStream[R]() + if self._transformer.supports_backfill and not self._live: + self._transformer.process(self._source, collector) + return collector.results + + def store(self, name: str | None = None) -> Stream[R]: + # Delegated to session — TransformStream.store() is overridden + # by the session when the source stream has a backend + source_backend = self._source._backend + if source_backend is None: + raise TypeError("Cannot store a transform whose source has no backend session") + # The backend's session handles materialization + raise NotImplementedError( + "store() on TransformStream must be handled by the session/backend" + ) + + +class _CollectorStream(Stream[R]): + """Ephemeral stream that collects appended observations in a list.""" + + def __init__(self) -> None: + super().__init__(backend=None) + self.results: list[Observation] = [] + self._next_id = 0 + + def append( + self, + payload: R, + *, + ts: float | None = None, + pose: PoseLike | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation: + obs = Observation( + id=self._next_id, + ts=ts, + tags=tags or {}, + _data=payload, + ) + self._next_id += 1 + self.results.append(obs) + return obs diff --git 
a/dimos/memory/tests/__init__.py b/dimos/memory/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/memory/tests/test_sqlite.py b/dimos/memory/tests/test_sqlite.py new file mode 100644 index 0000000000..ec813f2a71 --- /dev/null +++ b/dimos/memory/tests/test_sqlite.py @@ -0,0 +1,302 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SQLite-backed memory store.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from dimos.memory.impl.sqlite import SqliteSession, SqliteStore + +if TYPE_CHECKING: + from dimos.memory.types import Observation + + +@pytest.fixture +def store(tmp_path: object) -> SqliteStore: + # tmp_path is a pathlib.Path + from pathlib import Path + + assert isinstance(tmp_path, Path) + return SqliteStore(str(tmp_path / "test.db")) + + +@pytest.fixture +def session(store: SqliteStore) -> SqliteSession: + return store.session() + + +class TestStreamBasics: + def test_create_stream(self, session: SqliteSession) -> None: + s = session.stream("images", bytes) + assert s is not None + + def test_append_and_fetch(self, session: SqliteSession) -> None: + s = session.stream("images", bytes) + obs = s.append(b"frame1") + assert obs.id == 1 + assert obs.data == b"frame1" + assert obs.ts is not None + + rows = s.fetch() + assert len(rows) == 1 + assert rows[0].data == b"frame1" + assert rows[0].id == 1 + + def test_append_multiple(self, 
session: SqliteSession) -> None: + s = session.stream("images", bytes) + s.append(b"frame1") + s.append(b"frame2") + s.append(b"frame3") + + assert s.count() == 3 + rows = s.fetch() + assert [r.data for r in rows] == [b"frame1", b"frame2", b"frame3"] + + def test_append_with_tags(self, session: SqliteSession) -> None: + s = session.stream("images", bytes) + s.append(b"frame1", tags={"cam": "front", "quality": "high"}) + + rows = s.fetch() + assert rows[0].tags == {"cam": "front", "quality": "high"} + + def test_last(self, session: SqliteSession) -> None: + s = session.stream("images", bytes) + s.append(b"frame1", ts=1.0) + s.append(b"frame2", ts=2.0) + s.append(b"frame3", ts=3.0) + + obs = s.last() + assert obs.data == b"frame3" + assert obs.ts == 3.0 + + def test_one(self, session: SqliteSession) -> None: + s = session.stream("images", bytes) + s.append(b"only") + + obs = s.one() + assert obs.data == b"only" + + def test_one_empty_raises(self, session: SqliteSession) -> None: + s = session.stream("images", bytes) + with pytest.raises(LookupError): + s.one() + + +class TestFilters: + def test_after(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("old", ts=1.0) + s.append("new", ts=10.0) + + rows = s.after(5.0).fetch() + assert len(rows) == 1 + assert rows[0].data == "new" + + def test_before(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("old", ts=1.0) + s.append("new", ts=10.0) + + rows = s.before(5.0).fetch() + assert len(rows) == 1 + assert rows[0].data == "old" + + def test_time_range(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("a", ts=1.0) + s.append("b", ts=5.0) + s.append("c", ts=10.0) + + rows = s.time_range(3.0, 7.0).fetch() + assert len(rows) == 1 + assert rows[0].data == "b" + + def test_at(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("a", ts=1.0) + s.append("b", ts=5.0) + s.append("c", ts=10.0) + + rows 
= s.at(5.5, tolerance=1.0).fetch() + assert len(rows) == 1 + assert rows[0].data == "b" + + def test_filter_tags(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("front", tags={"cam": "front"}) + s.append("rear", tags={"cam": "rear"}) + + rows = s.filter_tags(cam="front").fetch() + assert len(rows) == 1 + assert rows[0].data == "front" + + def test_chained_filters(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("a", ts=1.0, tags={"cam": "front"}) + s.append("b", ts=5.0, tags={"cam": "front"}) + s.append("c", ts=5.0, tags={"cam": "rear"}) + + rows = s.after(3.0).filter_tags(cam="front").fetch() + assert len(rows) == 1 + assert rows[0].data == "b" + + +class TestOrdering: + def test_order_by_ts(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("b", ts=2.0) + s.append("a", ts=1.0) + s.append("c", ts=3.0) + + rows = s.order_by("ts").fetch() + assert [r.data for r in rows] == ["a", "b", "c"] + + def test_order_by_desc(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) + s.append("c", ts=3.0) + + rows = s.order_by("ts", desc=True).fetch() + assert [r.data for r in rows] == ["c", "b", "a"] + + def test_limit_offset(self, session: SqliteSession) -> None: + s = session.stream("data", str) + for i in range(10): + s.append(f"item{i}", ts=float(i)) + + rows = s.order_by("ts").limit(3).offset(2).fetch() + assert [r.data for r in rows] == ["item2", "item3", "item4"] + + +class TestFetchPages: + def test_basic_pagination(self, session: SqliteSession) -> None: + s = session.stream("data", str) + for i in range(10): + s.append(f"item{i}", ts=float(i)) + + pages = list(s.fetch_pages(batch_size=3)) + assert len(pages) == 4 # 3+3+3+1 + assert len(pages[0]) == 3 + assert len(pages[-1]) == 1 + + all_items = [obs.data for page in pages for obs in page] + assert all_items == [f"item{i}" for i in range(10)] + + +class 
TestTextStream: + def test_create_and_append(self, session: SqliteSession) -> None: + s = session.text_stream("logs", str) + s.append("Motor fault on joint 3") + s.append("Battery low warning") + + assert s.count() == 2 + + def test_text_search(self, session: SqliteSession) -> None: + s = session.text_stream("logs", str) + s.append("Motor fault on joint 3") + s.append("Battery low warning") + s.append("Motor overheating on joint 5") + + rows = s.search_text("motor", k=10).fetch() + assert len(rows) == 2 + assert all("Motor" in r.data for r in rows) + + +class TestListStreams: + def test_list_empty(self, session: SqliteSession) -> None: + assert session.list_streams() == [] + + def test_list_after_create(self, session: SqliteSession) -> None: + session.stream("images", bytes) + session.text_stream("logs", str) + + infos = session.list_streams() + names = {i.name for i in infos} + assert names == {"images", "logs"} + + +class TestReactive: + def test_appended_observable(self, session: SqliteSession) -> None: + s = session.stream("images", bytes) + received: list[Observation] = [] + s.appended.subscribe(on_next=received.append) + + s.append(b"frame1") + s.append(b"frame2") + + assert len(received) == 2 + assert received[0].data == b"frame1" + assert received[1].data == b"frame2" + + +class TestTransformInMemory: + def test_lambda_transform(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("hello", ts=1.0) + s.append("world", ts=2.0) + + upper = s.transform(lambda x: x.upper()) + results = upper.fetch() + assert len(results) == 2 + assert results[0].data == "HELLO" + assert results[1].data == "WORLD" + + def test_lambda_filter_none(self, session: SqliteSession) -> None: + s = session.stream("data", int) + s.append(1, ts=1.0) + s.append(2, ts=2.0) + s.append(3, ts=3.0) + + evens = s.transform(lambda x: x * 2 if x % 2 == 0 else None) + results = evens.fetch() + assert len(results) == 1 + assert results[0].data == 4 + + def 
test_lambda_expand_list(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("a,b,c", ts=1.0) + + split = s.transform(lambda x: x.split(",")) + results = split.fetch() + assert len(results) == 3 + assert [r.data for r in results] == ["a", "b", "c"] + + +class TestStoreReopen: + def test_data_persists(self, tmp_path: object) -> None: + from pathlib import Path + + assert isinstance(tmp_path, Path) + db_path = str(tmp_path / "persist.db") + + # Write + store1 = SqliteStore(db_path) + s1 = store1.session() + s1.stream("data", str).append("hello", ts=1.0) + s1.close() + store1.close() + + # Re-open and read + store2 = SqliteStore(db_path) + s2 = store2.session() + rows = s2.stream("data", str).fetch() + assert len(rows) == 1 + assert rows[0].data == "hello" + s2.close() + store2.close() diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py new file mode 100644 index 0000000000..d347eb5160 --- /dev/null +++ b/dimos/memory/transformer.py @@ -0,0 +1,101 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.models.embedding.base import Embedding, EmbeddingModel + + from .stream import Stream + from .types import Observation + +T = TypeVar("T") +R = TypeVar("R") + + +class Transformer(ABC, Generic[T, R]): + """Transforms a source stream into results on a target stream.""" + + supports_backfill: bool = True + supports_live: bool = True + + @abstractmethod + def process(self, source: Stream[T], target: Stream[R]) -> None: + """Batch/historical processing. + + Has full access to the source stream — can query, filter, batch, skip, etc. + """ + + def on_append(self, obs: Observation, target: Stream[R]) -> None: + """Reactive per-item processing. Called for each new item.""" + + +class PerItemTransformer(Transformer[T, R]): + """Wraps a simple callable as a per-item Transformer.""" + + def __init__(self, fn: Callable[[T], R | list[R] | None]) -> None: + self._fn = fn + + def process(self, source: Stream[T], target: Stream[R]) -> None: + for page in source.fetch_pages(): + for obs in page: + self._apply(obs, target) + + def on_append(self, obs: Observation, target: Stream[R]) -> None: + self._apply(obs, target) + + def _apply(self, obs: Observation, target: Stream[R]) -> None: + result = self._fn(obs.data) + if result is None: + return + if isinstance(result, list): + for item in result: + target.append(item, ts=obs.ts, pose=obs.pose, tags=obs.tags) + else: + target.append(result, ts=obs.ts, pose=obs.pose, tags=obs.tags) + + +class EmbeddingTransformer(Transformer[Any, "Embedding"]): + """Wraps an EmbeddingModel as a Transformer that produces Embedding output. + + When stored, the output stream becomes an EmbeddingStream with vector index. 
+ """ + + supports_backfill: bool = True + supports_live: bool = True + + def __init__(self, model: EmbeddingModel) -> None: + self.model = model + + def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: + for page in source.fetch_pages(): + images = [obs.data for obs in page] + if not images: + continue + embeddings = self.model.embed(*images) + if not isinstance(embeddings, list): + embeddings = [embeddings] + for obs, emb in zip(page, embeddings, strict=True): + target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags) + + def on_append(self, obs: Observation, target: Stream[Embedding]) -> None: + emb = self.model.embed(obs.data) + if isinstance(emb, list): + emb = emb[0] + target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags) diff --git a/dimos/memory/types.py b/dimos/memory/types.py new file mode 100644 index 0000000000..bc94daad4e --- /dev/null +++ b/dimos/memory/types.py @@ -0,0 +1,157 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, TypeAlias + +if TYPE_CHECKING: + from dimos.models.embedding.base import Embedding + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + +PoseProvider: TypeAlias = Callable[[], Any] # () -> PoseLike | None + +_UNSET: Any = object() + + +@dataclass +class Observation: + id: int + ts: float | None = None + pose: PoseStamped | None = None + tags: dict[str, Any] = field(default_factory=dict) + _data: Any = field(default=_UNSET, repr=False) + _data_loader: Callable[[], Any] | None = field(default=None, repr=False, compare=False) + + @property + def data(self) -> Any: + if self._data is not _UNSET: + return self._data + if self._data_loader is not None: + self._data = self._data_loader() + return self._data + raise LookupError("No data available; observation was not fetched with payload") + + +@dataclass +class EmbeddingObservation(Observation): + """Returned by EmbeddingStream terminals. + + .data auto-projects to the source stream's payload type. + .embedding gives the Embedding vector. 
+ """ + + _embedding: Embedding | None = field(default=None, repr=False) + _embedding_loader: Callable[[], Embedding] | None = field( + default=None, repr=False, compare=False + ) + _source_data_loader: Callable[[], Any] | None = field(default=None, repr=False, compare=False) + + @property + def data(self) -> Any: + if self._data is not _UNSET: + return self._data + if self._source_data_loader is not None: + self._data = self._source_data_loader() + return self._data + return super().data + + @property + def embedding(self) -> Embedding: + if self._embedding is not None: + return self._embedding + if self._embedding_loader is not None: + self._embedding = self._embedding_loader() + return self._embedding + raise LookupError("No embedding available") + + +@dataclass +class StreamInfo: + name: str + payload_type: str | None = None + count: int = 0 + + +# ── Filter types ────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class AfterFilter: + t: float + + +@dataclass(frozen=True) +class BeforeFilter: + t: float + + +@dataclass(frozen=True) +class TimeRangeFilter: + t1: float + t2: float + + +@dataclass(frozen=True) +class AtFilter: + t: float + tolerance: float + + +@dataclass(frozen=True) +class NearFilter: + pose: Any # PoseLike + radius: float + + +@dataclass(frozen=True) +class TagsFilter: + tags: dict[str, Any] + + +@dataclass(frozen=True) +class EmbeddingSearchFilter: + query: list[float] + k: int + + +@dataclass(frozen=True) +class TextSearchFilter: + text: str + k: int | None + + +Filter: TypeAlias = ( + AfterFilter + | BeforeFilter + | TimeRangeFilter + | AtFilter + | NearFilter + | TagsFilter + | EmbeddingSearchFilter + | TextSearchFilter +) + + +@dataclass(frozen=True) +class StreamQuery: + """Immutable bundle of query parameters passed to backends.""" + + filters: tuple[Filter, ...] 
= () + order_field: str | None = None + order_desc: bool = False + limit_val: int | None = None + offset_val: int | None = None diff --git a/dimos/memory/embedding.py b/dimos/memory_old/embedding.py similarity index 100% rename from dimos/memory/embedding.py rename to dimos/memory_old/embedding.py diff --git a/dimos/memory_old/impl/sqlite.py b/dimos/memory_old/impl/sqlite.py new file mode 100644 index 0000000000..20caceb8a7 --- /dev/null +++ b/dimos/memory_old/impl/sqlite.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/dimos/memory/test_embedding.py b/dimos/memory_old/test_embedding.py similarity index 100% rename from dimos/memory/test_embedding.py rename to dimos/memory_old/test_embedding.py diff --git a/dimos/memory/timeseries/__init__.py b/dimos/memory_old/timeseries/__init__.py similarity index 100% rename from dimos/memory/timeseries/__init__.py rename to dimos/memory_old/timeseries/__init__.py diff --git a/dimos/memory/timeseries/base.py b/dimos/memory_old/timeseries/base.py similarity index 100% rename from dimos/memory/timeseries/base.py rename to dimos/memory_old/timeseries/base.py diff --git a/dimos/memory/timeseries/inmemory.py b/dimos/memory_old/timeseries/inmemory.py similarity index 100% rename from dimos/memory/timeseries/inmemory.py rename to dimos/memory_old/timeseries/inmemory.py diff --git a/dimos/memory/timeseries/legacy.py b/dimos/memory_old/timeseries/legacy.py similarity index 100% rename from dimos/memory/timeseries/legacy.py rename to dimos/memory_old/timeseries/legacy.py diff --git a/dimos/memory/timeseries/pickledir.py b/dimos/memory_old/timeseries/pickledir.py similarity index 100% rename from dimos/memory/timeseries/pickledir.py rename to dimos/memory_old/timeseries/pickledir.py diff --git a/dimos/memory/timeseries/postgres.py b/dimos/memory_old/timeseries/postgres.py similarity index 100% rename from dimos/memory/timeseries/postgres.py rename to dimos/memory_old/timeseries/postgres.py diff --git a/dimos/memory/timeseries/sqlite.py b/dimos/memory_old/timeseries/sqlite.py similarity index 100% rename from dimos/memory/timeseries/sqlite.py rename to dimos/memory_old/timeseries/sqlite.py diff --git a/dimos/memory/timeseries/test_base.py b/dimos/memory_old/timeseries/test_base.py similarity index 100% rename from dimos/memory/timeseries/test_base.py rename to dimos/memory_old/timeseries/test_base.py diff --git a/dimos/memory/timeseries/test_legacy.py b/dimos/memory_old/timeseries/test_legacy.py similarity index 100% rename from 
dimos/memory/timeseries/test_legacy.py rename to dimos/memory_old/timeseries/test_legacy.py diff --git a/plans/memory/api.md b/plans/memory/api.md new file mode 100644 index 0000000000..3b00a8b514 --- /dev/null +++ b/plans/memory/api.md @@ -0,0 +1,477 @@ +# Memory2 API — Unified Stream + +## Core Idea + +One type: `Stream[T]`. Everything is a stream — stored, filtered, transformed. The user never thinks about Query vs ObservationSet vs Stream. They just chain operations. + +## Creating Streams + +```python +store = SqliteStore("/data/robot.db") +session = store.session() + +# Root stored stream — backed by DB +images = session.stream("images", Image, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +logs = session.text_stream("logs", str, + pose_provider=lambda: tf.get_pose("world", "base_link")) +``` + +## Writing + +```python +images.append(frame) # ts + pose auto-filled +logs.append("Motor fault on joint 3") # ts + pose auto-filled +images.append(frame, pose=explicit_pose, tags={"cam": "front"}) +``` + +Only meaningful on stored (DB-backed) streams. + +## Filtering + +Every filter returns a new `Stream[T]`. Lazy — nothing executes until a terminal. + +```python +recent = images.after(one_hour_ago) +kitchen = recent.near(kitchen_pose, 5.0) +tagged = kitchen.filter_tags(cam="front") + +# Or chained +images.after(one_hour_ago).near(kitchen_pose, 5.0).filter_tags(cam="front") +``` + +### Filter methods + +```python +class Stream(Generic[T]): + # Temporal + def after(self, t: float) -> Stream[T]: ... + def before(self, t: float) -> Stream[T]: ... + def time_range(self, t1: float, t2: float) -> Stream[T]: ... + def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... + + # Spatial + def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... + + # Tags + def filter_tags(self, **tags: Any) -> Stream[T]: ... + +class EmbeddingStream(Stream[T]): + def search_embedding(self, query: Embedding | list[float], *, k: int) -> EmbeddingStream[T]: ... 
+ +class TextStream(Stream[T]): + def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... +``` + +## Terminals & Iteration + +`Stream` is directly iterable — pages internally, never loads everything at once. + +```python +# Direct iteration (lazy, memory-efficient — uses fetch_pages internally) +for row in images.after(t).near(kitchen_pose, 5.0): + print(row.data) + +# Explicit fetch when you want the full list in memory +all_rows = images.after(t).fetch() + +# Other terminals +row = images.after(t).one() # single best match +row = images.last() # most recent +n = images.after(t).count() # count without fetching + +# Pagination +page = images.order_by("ts").limit(50).offset(100).fetch() +``` + +### Terminal methods + +```python +class Stream(Generic[T]): + def __iter__(self) -> Iterator[Observation]: ... # lazy, pages internally + def fetch(self) -> list[Observation]: ... # all results in memory + def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... + def one(self) -> Observation: ... + def last(self) -> Observation: ... + def count(self) -> int: ... + def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... + def limit(self, k: int) -> Stream[T]: ... + def offset(self, n: int) -> Stream[T]: ... +``` + +## Observation + +```python +from dimos.models.embedding.base import Embedding, EmbeddingModel + +@dataclass +class Observation: + id: int + ts: float | None = None + pose: PoseStamped | None = None + tags: dict[str, Any] = field(default_factory=dict) + + @property + def data(self) -> Any: + """Lazy payload. Pre-populated from append/transform, fetched on demand from query.""" + ... +``` + +## Transformer + +A `Transformer` receives the full source stream and decides what to do — which items to process, how to batch, whether to use embeddings as a cheap proxy, etc. 
+ +```python +class Transformer(ABC, Generic[T, R]): + """Transforms a source stream into results on a target stream.""" + + def process(self, source: Stream[T], target: Stream[R]) -> None: + """Batch/historical processing. Has full access to source — can query, + filter, use embeddings, batch, skip frames, etc.""" + ... + + def on_append(self, obs: Observation, target: Stream[R]) -> None: + """Reactive processing. Called per new item. Default: process([obs]).""" + ... + + supports_backfill: bool = True + supports_live: bool = True +``` + +### Simple lambdas (sugar) + +`Callable[[T], R | list[R] | None]` is auto-wrapped into a naive per-item Transformer: + +```python +# These are equivalent: +images.transform(lambda img: vlm.detect(img, "cigarettes")) +images.transform(PerItemTransformer(lambda img: vlm.detect(img, "cigarettes"))) +``` + +- `R` → single result +- `list[R]` → multiple results (e.g., multiple detections per frame) +- `None` → skip (no result for this input) + +### EmbeddingTransformer + +`EmbeddingTransformer` wraps an `EmbeddingModel` as a `Transformer[T, Embedding]`. When the output type is `Embedding`, `.store()` creates an `EmbeddingStream` (vec0 index, `search_embedding`, `EmbeddingObservation`). 
+ +```python +# EmbeddingTransformer wraps the model +img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") + +# Now img_emb is an EmbeddingStream +results = img_emb.search_embedding(query_emb, k=20).fetch() +# results[0].data → Image (auto-projected from source) +# results[0].embedding → Embedding (supports @ for cosine similarity) +``` + +### Smart Transformer example + +Chains after an embedding transform — receives `EmbeddingObservation` with `.data` (Image) and `.embedding` (vector), so it can use similarity to skip irrelevant frames: + +```python +class CigaretteDetector(Transformer[EmbeddingObservation, Detection]): + def __init__(self, vlm, clip): + self.vlm = vlm + self.clip = clip + + def process(self, source: Stream[EmbeddingObservation], target: Stream[Detection]): + query = self.clip.embed_text("person smoking cigarette") + for page in source.fetch_pages(batch_size=16): + # Use embedding similarity as cheap proxy — skip distant frames + promising = [obs for obs in page if obs.embedding @ query > 0.3] + if not promising: + continue + detections = self.vlm.detect_batch( + [obs.data for obs in promising], "cigarettes" + ) + for obs, dets in zip(promising, detections): + for det in dets: + target.append(det, ts=obs.ts, pose=obs.pose) + + def on_append(self, obs: EmbeddingObservation, target: Stream[Detection]): + dets = self.vlm.detect(obs.data, "cigarettes") + for det in dets: + target.append(det, ts=obs.ts, pose=obs.pose) +``` + +### Chaining transforms + +```python +# Filter → transform → store +images.after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .transform(EmbeddingTransformer(CLIPModel())) \ + .store("kitchen_embeddings") + +# Filter → transform → fetch (in-memory, not persisted) +results = images.after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .transform(EmbeddingTransformer(CLIPModel())) \ + .fetch() + +# Filter → embed → detect → store (chained: detector gets EmbeddingObservation) +images.near(kitchen_pose, 
5.0) \
+    .transform(EmbeddingTransformer(CLIPModel())) \
+    .transform(CigaretteDetector(vlm, clip)) \
+    .store("kitchen_cigarette_detections")
+```
+
+### Backfill / Live modes
+
+```python
+# Both (default): backfill existing + subscribe to new
+images.transform(detector).store("detections")
+
+# Live only: skip backfill, only process new items
+images.transform(detector, live=True).store("detections")
+
+# Backfill only: process existing, don't subscribe
+images.transform(detector, backfill_only=True).store("detections")
+
+# Note: backfill_only and live are mutually exclusive — the default (no flags)
+# already processes existing items and subscribes to new ones.
+
+# Incremental: re-running a stored transform resumes from last processed item
+# (uses lineage parent_id to skip already-processed source rows)
+```
+
+## Storing
+
+`.store(name)` materializes a stream to DB. After storing, results are queryable and persistent.
+
+```python
+# In-memory transform result — not persisted
+detections = images.transform(detect_fn)
+
+# Persist it
+detections.store("detections")
+
+# Now it's a DB-backed stream, queryable
+stored = session.stream("detections")
+rows = stored.after(t).fetch()
+```
+
+`.store()` also sets up lineage — every stored row gets `parent_id` pointing back to its source.
+
+Stream type is determined by what the Transformer produces:
+- `Embedding` output → `EmbeddingStream` (vec0 index)
+- Everything else → `Stream` (blob)
+- `TextStream` is created explicitly via `session.text_stream()` (not auto-detected)
+
+## Reactive
+
+```python
+# .appended emits Observation with .data pre-populated
+images.appended.subscribe(lambda row: print(f"New image at {row.pose}"))
+
+# Stored transforms propagate reactively by default
+detections = images.transform(detect_fn).store("detections")
+# Now every images.append(frame) → detect_fn runs → result stored in "detections"
+
+# Filtered appended — only kitchen images
+images.near(kitchen_pose, 5.0).appended.subscribe(...)
+``` + +## Join (cross-stream lineage) + +```python +# Join detections with their source images — returns tuples +for det, img in detections.after(t).join(images): + print(f"Detected {det.data} in image at {img.pose}") +``` + +## Full Example: Cigarette Detection Pipeline + +```python +session = SqliteStore("/data/robot.db").session() + +# Root stream +images = session.stream("images", Image, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +# Embedding index — EmbeddingModel is a Transformer +img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") + +# VLM detection pipeline (live-only, no backfill) +images.transform( + lambda img: vlm.detect(img, "people with cigarettes"), + live=True, +).store("cigarette_detections") + +# Smart detection — reuse existing embeddings, detector gets EmbeddingObservation +img_emb.near(kitchen_pose, 10.0) \ + .transform(CigaretteDetector(vlm, clip)) \ + .store("kitchen_cigarette_detections") + +# # Worse: re-embeds from scratch (redundant if img_emb already exists) +# images.near(kitchen_pose, 10.0) \ +# .transform(EmbeddingTransformer(CLIPModel())) \ +# .transform(CigaretteDetector(vlm, clip)) \ +# .store("kitchen_cigarette_detections") + +# --- Later, querying --- + +# "Where did we see people with cigarettes in the kitchen?" 
+for row in session.stream("cigarette_detections") \
+        .after(one_hour_ago).near(kitchen_pose, 10.0):
+    print(f"t={row.ts} pose={row.pose}: {row.data}")
+
+# "Show me the source images alongside detections"
+for det, img in session.stream("cigarette_detections") \
+        .after(one_hour_ago).join(images):
+    print(f"Detection: {det.data}, Source image at {img.pose}")
+
+# "Find images similar to 'red shoes'"
+query_emb = clip.embed_text("red shoes")
+similar = img_emb.search_embedding(query_emb, k=20).fetch()
+# similar[0].data → Image (auto-projected from source)
+# similar[0].embedding → Embedding (supports @ for cosine similarity)
+```
+
+## Full API
+
+```python
+from dimos.models.embedding.base import Embedding, EmbeddingModel
+
+# --- Data types ---
+
+@dataclass
+class Observation:
+    id: int
+    ts: float | None = None
+    pose: PoseStamped | None = None
+    tags: dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def data(self) -> Any:
+        """Lazy payload. Pre-populated from append, fetched on demand from query."""
+        ...
+
+@dataclass
+class EmbeddingObservation(Observation):
+    """Returned by EmbeddingStream terminals. Auto-projects .data to source stream."""
+
+    @property
+    def data(self) -> Any:
+        """Lazily loads from the source stream (e.g., Image), not the embedding."""
+        ...
+
+    @property
+    def embedding(self) -> Embedding:
+        """The Embedding object (has .vector, supports @ for cosine similarity)."""
+        ...
+
+# --- Transformer ---
+
+class Transformer(ABC, Generic[T, R]):
+    """Transforms a source stream into results on a target stream."""
+
+    def process(self, source: Stream[T], target: Stream[R]) -> None:
+        """Batch/historical processing. Full access to source stream."""
+        ...
+
+    def on_append(self, obs: Observation, target: Stream[R]) -> None:
+        """Reactive processing. Called per new item."""
+        ...
+ + supports_backfill: bool = True + supports_live: bool = True + +# --- Streams --- + +class Stream(Generic[T]): + # Write (DB-backed only) + def append(self, payload: T, *, + ts: float | None = None, + pose: PoseLike | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation: ... + + # Filter (returns new Stream, lazy) + def after(self, t: float) -> Stream[T]: ... + def before(self, t: float) -> Stream[T]: ... + def time_range(self, t1: float, t2: float) -> Stream[T]: ... + def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... + def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... + def filter_tags(self, **tags: Any) -> Stream[T]: ... + + # Order / paginate + def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... + def limit(self, k: int) -> Stream[T]: ... + def offset(self, n: int) -> Stream[T]: ... + + # Transform + def transform(self, + xf: Transformer[T, R] | Callable[[T], R | list[R] | None], + *, live: bool = False, + backfill_only: bool = False, + ) -> Stream[R]: ... + + # Materialize + def store(self, name: str | None = None) -> Stream[T]: ... + + # Cross-stream (lineage join — returns tuples of (self_obs, target_obs)) + def join(self, target: Stream) -> Stream[tuple[Observation, Observation]]: ... + + # Iteration & Terminals + def __iter__(self) -> Iterator[Observation]: ... # lazy, pages internally + def fetch(self) -> list[Observation]: ... # all results in memory + def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... + def one(self) -> Observation: ... + def last(self) -> Observation: ... + def count(self) -> int: ... + + # Reactive + @property + def appended(self) -> Observable[Observation]: ... + +class EmbeddingStream(Stream[T]): + """Created automatically when a Transformer produces Embedding output. 
+ Terminals return EmbeddedObservation (auto-projects .data to source stream).""" + def search_embedding(self, query: Embedding | list[float], *, k: int) -> EmbeddingStream[T]: ... + def fetch(self) -> list[EmbeddedObservation]: ... + def one(self) -> EmbeddedObservation: ... + def last(self) -> EmbeddedObservation: ... + +class TextStream(Stream[T]): + """Stream with FTS index.""" + def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... + +# --- Session / Store --- + +PoseProvider = Callable[[], PoseLike | None] + +class Session: + def stream(self, name: str, payload_type: type | None = None, *, + pose_provider: PoseProvider | None = None) -> Stream: ... + def text_stream(self, name: str, payload_type: type | None = None, *, + tokenizer: str = "unicode61", + pose_provider: PoseProvider | None = None) -> TextStream: ... + def list_streams(self) -> list[StreamInfo]: ... + def close(self) -> None: ... + +class Store: + def session(self) -> Session: ... + def close(self) -> None: ... +``` + +## Internal Backing (impl detail) + +A `Stream` can be backed by different things — the user never sees this: + +- **DB table** — from `session.stream()`. Has `_meta`, `_payload`, indexes. +- **Predicate** — from `.after()`, `.near()`, etc. Lazy SQL WHERE. +- **Transform** — from `.transform(t)`. Source stream + Transformer. + +The impl decides how to execute based on the backing chain. + +## Open Questions + +1. **`.append()` on non-stored streams?** Runtime error, or silently ignore? Probably `TypeError`. + +2. **Multiple `.store()` calls?** Should be idempotent — second call is a no-op if already stored under the same name. + +3. ~~**Memory pressure from in-memory transforms?**~~ Solved — `Stream` is iterable, pages internally via `fetch_pages`. 
diff --git a/plans/query_objects.md b/plans/memory/query_objects.md similarity index 100% rename from plans/query_objects.md rename to plans/memory/query_objects.md diff --git a/plans/questions.md b/plans/memory/questions.md similarity index 100% rename from plans/questions.md rename to plans/memory/questions.md diff --git a/plans/memory/sqlite.md b/plans/memory/sqlite.md new file mode 100644 index 0000000000..0439b7defa --- /dev/null +++ b/plans/memory/sqlite.md @@ -0,0 +1,780 @@ +# SQLite Implementation + +Implementation spec for `dimos/memory/impl/sqlite/`. A coding agent should be able to implement the full SQLite backend from this document + `api.md`. + +## File Structure + +``` +dimos/memory/ + __init__.py # public exports: Observation, EmbeddingObservation, + # Stream, EmbeddingStream, TextStream, Transformer, + # EmbeddingTransformer, PerItemTransformer, Session, Store + types.py # Observation, EmbeddingObservation, StreamInfo + stream.py # Stream, EmbeddingStream, TextStream (base classes) + transformer.py # Transformer ABC, EmbeddingTransformer, PerItemTransformer + store.py # Store ABC + session.py # Session ABC + + impl/ + sqlite/ + __init__.py # exports SqliteStore + store.py # SqliteStore + session.py # SqliteSession + stream.py # SqliteStream, SqliteEmbeddingStream, SqliteTextStream + query.py # FilterChain — accumulates predicates, generates SQL + _sql.py # SQL helpers, identifier validation, pose helpers, serialization +``` + +## Dependencies + +- `sqlite3` (stdlib) +- `sqlite-vec` — vector similarity search via vec0 virtual table. Optional — `search_embedding` raises if unavailable. +- FTS5 — built into SQLite by default on most platforms. +- R*Tree — built into SQLite by default. +- `reactivex` — for `.appended` observable (already a DimOS dependency). 
+ +## Connection Management + +### SqliteStore + +```python +class SqliteStore(Store): + def __init__(self, path: str): + self.path = path # or ":memory:" + + def session(self) -> SqliteSession: + conn = self._connect() + return SqliteSession(conn) + + def _connect(self) -> sqlite3.Connection: + if self.path == ":memory:": + uri = "file::memory:?cache=shared" + conn = sqlite3.connect(uri, uri=True) + else: + Path(self.path).parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(self.path) + + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute("PRAGMA foreign_keys=ON") + + # Try loading sqlite-vec + try: + conn.enable_load_extension(True) + conn.load_extension("vec0") # or find via sqlite_vec.loadable_path() + conn.enable_load_extension(False) + except Exception: + pass # vec0 unavailable — search_embedding will raise + + return conn + + def close(self) -> None: ... +``` + +### SqliteSession + +```python +class SqliteSession(Session): + def __init__(self, conn: sqlite3.Connection): + self._conn = conn + self._streams: dict[str, SqliteStream] = {} # cache by name + self._ensure_registry() + + def _ensure_registry(self): + """Create _streams table if not exists.""" + self._conn.execute(""" + CREATE TABLE IF NOT EXISTS _streams ( + rowid INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + type TEXT NOT NULL, + payload_type TEXT, + parent_stream TEXT, + embedding_dim INTEGER + ) + """) + + def stream(self, name, payload_type=None, *, pose_provider=None) -> SqliteStream: + if name in self._streams: + return self._streams[name] + self._register_stream(name, "blob", payload_type) + self._create_stream_tables(name, stream_type="blob") + s = SqliteStream(name, self._conn, payload_type, pose_provider) + self._streams[name] = s + return s + + def text_stream(self, name, payload_type=None, *, tokenizer="unicode61", + pose_provider=None) -> SqliteTextStream: + # Similar — creates FTS tables too + ... 
+ + def list_streams(self) -> list[StreamInfo]: ... + def close(self) -> None: self._conn.close() +``` + +## Schema + +All table names are prefixed with the stream name. Stream names are validated: `[a-zA-Z_][a-zA-Z0-9_]*`, max 64 chars. + +### `_streams` — Global registry + +```sql +CREATE TABLE _streams ( + rowid INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + type TEXT NOT NULL, -- 'blob', 'embedding', 'text' + payload_type TEXT, -- e.g. 'dimos.msgs.sensor_msgs.Image' + parent_stream TEXT, -- FK name of parent stream (lineage) + embedding_dim INTEGER -- only for type='embedding' +); +``` + +### `{name}_meta` — Observation metadata (all stream types) + +```sql +CREATE TABLE {name}_meta ( + rowid INTEGER PRIMARY KEY, -- = Observation.id + ts REAL, + pose_x REAL, pose_y REAL, pose_z REAL, + pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, + tags TEXT, -- JSON dict, NULL if empty + parent_rowid INTEGER -- lineage: rowid in parent stream's _meta +); +CREATE INDEX idx_{name}_ts ON {name}_meta(ts); +``` + +### `{name}_payload` — Blob/Text payload (not EmbeddingStream) + +```sql +CREATE TABLE {name}_payload ( + rowid INTEGER PRIMARY KEY, -- matches _meta.rowid + data BLOB NOT NULL -- TextStream: TEXT instead of BLOB +); +``` + +Separated from `_meta` so metadata queries never page in multi-MB blobs. + +### `{name}_rtree` — Spatial index (all stream types) + +```sql +CREATE VIRTUAL TABLE {name}_rtree USING rtree( + rowid, -- matches _meta.rowid + min_x, max_x, + min_y, max_y, + min_z, max_z +); +``` + +Only rows with pose are inserted into R*Tree. Rows without pose are excluded from `.near()` results. + +### `{name}_fts` — Full-text search (TextStream only) + +```sql +CREATE VIRTUAL TABLE {name}_fts USING fts5( + content, + tokenize='{tokenizer}' +); +``` + +Standalone FTS table (not content-synced). Rowids match `_meta.rowid`. 
+
+### `{name}_vec` — Vector index (EmbeddingStream only)
+
+```sql
+CREATE VIRTUAL TABLE {name}_vec USING vec0(
+    embedding float[{dim}]
+);
+```
+
+Rowids match `_meta.rowid`. Dimension inferred from first embedding inserted, or from `EmbeddingModel.embed()` output.
+
+## Stream Implementation
+
+### SqliteStream (implements Stream[T])
+
+Internally, a stream object can be in different modes:
+
+```python
+@dataclass
+class StoredBacking:
+    """Root DB-backed stream. Created by session.stream()."""
+    name: str
+
+@dataclass
+class FilteredBacking:
+    """Lazy predicate chain. Created by .after(), .near(), etc."""
+    parent: StreamBacking # recursive — can chain filters
+    predicates: list[Predicate]
+    ordering: list[OrderClause]
+    limit_val: int | None
+    offset_val: int | None
+
+@dataclass
+class TransformBacking:
+    """Unevaluated transform. Created by .transform()."""
+    source: StreamBacking
+    transformer: Transformer
+    live: bool
+    backfill_only: bool
+
+Backing = StoredBacking | FilteredBacking | TransformBacking
+```
+
+The stream carries its backing and resolves it at terminal time.
+
+### append()
+
+Only valid on `StoredBacking`. Otherwise raises `TypeError`.
+
+```python
+def append(self, payload, *, ts=None, pose=None, tags=None):
+    if not isinstance(self._backing, StoredBacking):
+        raise TypeError("append() only valid on stored streams")
+
+    # Explicit None check — `ts or time.time()` would silently replace a
+    # legitimate ts of 0.0 with the current time.
+    ts = time.time() if ts is None else ts
+    pose = pose or (self._pose_provider() if self._pose_provider else None)
+
+    # 1. Insert into _meta
+    meta_rowid = self._insert_meta(ts, pose, tags, parent_rowid=None)
+
+    # 2. Insert into _payload
+    blob = serialize(payload) # see Serialization section
+    self._conn.execute(
+        f"INSERT INTO {name}_payload(rowid, data) VALUES (?, ?)",
+        (meta_rowid, blob)
+    )
+
+    # 3. 
Insert into _rtree (if pose) + if pose: + x, y, z = extract_position(pose) + self._conn.execute( + f"INSERT INTO {name}_rtree(rowid, min_x, max_x, min_y, max_y, min_z, max_z) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (meta_rowid, x, x, y, y, z, z) + ) + + self._conn.commit() + + # 4. Build Observation and emit + obs = Observation(id=meta_rowid, ts=ts, pose=pose, tags=tags or {}) + obs._data = payload # pre-populated + self._appended_subject.on_next(obs) + return obs +``` + +### EmbeddingStream.append() + +Same as above but inserts into `_vec` instead of `_payload`: + +```python +# Insert embedding vector +vec_data = embedding.to_numpy().tobytes() +self._conn.execute( + f"INSERT INTO {name}_vec(rowid, embedding) VALUES (?, ?)", + (meta_rowid, vec_data) +) +``` + +### TextStream.append() + +Inserts into both `_payload` (TEXT) and `_fts`: + +```python +self._conn.execute( + f"INSERT INTO {name}_payload(rowid, data) VALUES (?, ?)", + (meta_rowid, text_content) +) +self._conn.execute( + f"INSERT INTO {name}_fts(rowid, content) VALUES (?, ?)", + (meta_rowid, text_content) +) +``` + +## Filter → SQL Generation + +Each filter method returns a new stream with a `FilteredBacking` wrapping the current backing. At terminal time, the filter chain is compiled to SQL. + +### Predicate types + +```python +@dataclass +class AfterPred: + t: float + # → WHERE ts > ? + +@dataclass +class BeforePred: + t: float + # → WHERE ts < ? + +@dataclass +class TimeRangePred: + t1: float + t2: float + # → WHERE ts BETWEEN ? AND ? + +@dataclass +class AtPred: + t: float + tolerance: float + # → WHERE ts BETWEEN ? AND ? ORDER BY ABS(ts - ?) LIMIT 1 + +@dataclass +class NearPred: + x: float + y: float + z: float + radius: float + # → JOIN with _rtree bounding box query + +@dataclass +class TagsPred: + tags: dict[str, Any] + # → WHERE json_extract(tags, '$.key') = ? 
+
+@dataclass
+class TextSearchPred:
+    text: str
+    k: int | None
+    # → JOIN with _fts MATCH
+
+@dataclass
+class EmbeddingSearchPred:
+    vector: list[float]
+    k: int
+    # → query _vec for top-k, then filter
+```
+
+### SQL compilation
+
+Walk the backing chain to the root `StoredBacking`, collect all predicates, then generate SQL:
+
+```python
+def _compile(self) -> tuple[str, list[Any]]:
+    """Walk backing chain, return (sql, params)."""
+    root_name = self._find_root_name()
+    predicates = self._collect_predicates()
+    ordering = self._collect_ordering()
+    limit = self._collect_limit()
+    offset = self._collect_offset()
+
+    # Start with base SELECT
+    sql = f"SELECT rowid, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags FROM {root_name}_meta"
+    params = []
+    joins = []
+    wheres = []
+
+    for pred in predicates:
+        if isinstance(pred, AfterPred):
+            wheres.append("ts > ?")
+            params.append(pred.t)
+        elif isinstance(pred, NearPred):
+            joins.append(
+                f"JOIN {root_name}_rtree r ON r.rowid = {root_name}_meta.rowid"
+            )
+            wheres.append(
+                "r.min_x >= ? AND r.max_x <= ? AND "
+                "r.min_y >= ? AND r.max_y <= ? AND "
+                "r.min_z >= ? AND r.max_z <= ?"
+            )
+            params.extend([
+                pred.x - pred.radius, pred.x + pred.radius,
+                pred.y - pred.radius, pred.y + pred.radius,
+                pred.z - pred.radius, pred.z + pred.radius,
+            ])
+        elif isinstance(pred, TagsPred):
+            for key, val in pred.tags.items():
+                # Tag keys are user-supplied — bind the JSON path as a
+                # parameter rather than interpolating it into the SQL text
+                # (see "SQL Safety": only validated identifiers may be
+                # interpolated).
+                wheres.append("json_extract(tags, ?) = ?")
+                params.extend([f"$.{key}", val])
+        # ... etc
+
+    sql += " " + " ".join(joins)
+    if wheres:
+        sql += " WHERE " + " AND ".join(wheres)
+    if ordering:
+        sql += " ORDER BY " + ", ".join(ordering)
+    if limit is not None:
+        sql += " LIMIT ?"
+        params.append(limit)
+    if offset is not None:
+        if limit is None:
+            # SQLite only accepts OFFSET after a LIMIT clause; LIMIT -1
+            # means "no limit".
+            sql += " LIMIT -1"
+        sql += " OFFSET ?"
+        params.append(offset)
+
+    return sql, params
+```
+
+### search_embedding (vec0)
+
+```sql
+-- Top-k vector search
+SELECT rowid, distance
+FROM {name}_vec
+WHERE embedding MATCH ?
+  AND k = ? 
+ORDER BY distance
+```
+
+Returns rowids, which are then used to filter `_meta`. This is a two-step process:
+1. Get top-k rowids from vec0
+2. Fetch metadata for those rowids
+
+### search_text (FTS5)
+
+```sql
+SELECT rowid, rank
+FROM {name}_fts
+WHERE {name}_fts MATCH ?
+ORDER BY rank
+```
+
+Same two-step: get rowids from FTS5, then fetch metadata.
+
+## Terminal Execution
+
+### __iter__() — lazy iteration
+
+`Stream` is directly iterable. Pages internally via `fetch_pages`, yielding one `Observation` at a time:
+
+```python
+def __iter__(self) -> Iterator[Observation]:
+    for page in self.fetch_pages():
+        yield from page
+```
+
+### fetch()
+
+```python
+def fetch(self) -> list[Observation]:
+    sql, params = self._compile()
+    rows = self._conn.execute(sql, params).fetchall()
+    return [self._row_to_observation(row) for row in rows]
+```
+
+### fetch_pages()
+
+```python
+def fetch_pages(self, batch_size=128) -> Iterator[list[Observation]]:
+    sql, params = self._compile()
+    # Paginate by wrapping the compiled query in a subquery: the compiled
+    # SQL may already end in its own LIMIT/OFFSET (from .limit()/.offset()),
+    # so appending a second LIMIT clause would be a syntax error. Page
+    # bounds are bound as parameters, consistent with "SQL Safety".
+    offset = 0
+    while True:
+        page_sql = f"SELECT * FROM ({sql}) LIMIT ? OFFSET ?"
+        rows = self._conn.execute(page_sql, [*params, batch_size, offset]).fetchall()
+        if not rows:
+            break
+        yield [self._row_to_observation(row) for row in rows]
+        offset += batch_size
+```
+
+### count()
+
+```python
+def count(self) -> int:
+    sql, params = self._compile()
+    count_sql = f"SELECT COUNT(*) FROM ({sql})"
+    return self._conn.execute(count_sql, params).fetchone()[0]
+```
+
+### one() / last()
+
+- `one()` → adds `LIMIT 1` to the query (raises `LookupError` when no row matches — see Error Handling)
+- `last()` → adds `ORDER BY ts DESC LIMIT 1`
+
+## Lazy Data Loading
+
+`Observation.data` uses lazy loading. 
The implementation:
+
+```python
+@dataclass
+class Observation:
+    id: int
+    ts: float | None = None
+    pose: PoseStamped | None = None
+    tags: dict[str, Any] = field(default_factory=dict)
+    _data: Any = field(default=_SENTINEL, repr=False)
+    _load: Callable[[], Any] | None = field(default=None, repr=False)
+
+    @property
+    def data(self) -> Any:
+        if self._data is _SENTINEL and self._load is not None:
+            self._data = self._load()
+        return self._data
+```
+
+When building observations from query results:
+
+```python
+def _row_to_observation(self, row) -> Observation:
+    rowid = row[0]
+    obs = Observation(
+        id=rowid,
+        ts=row[1],
+        pose=reconstruct_pose(row[2:9]),
+        tags=json.loads(row[9]) if row[9] else {},
+    )
+    name = self._root_name()
+    conn = self._conn
+    obs._load = lambda: deserialize(
+        conn.execute(f"SELECT data FROM {name}_payload WHERE rowid = ?", (rowid,)).fetchone()[0]
+    )
+    return obs
+```
+
+### EmbeddingObservation
+
+For `EmbeddingStream`, terminals return `EmbeddingObservation` which auto-projects `.data` to the source stream. (Naming note: api.md calls this same type `EmbeddedObservation` — pick one name and use it in both documents before implementing.)
+
+```python
+def _row_to_embedding_observation(self, row) -> EmbeddingObservation:
+    rowid = row[0]
+    parent_stream = self._get_parent_stream_name()
+    obs = EmbeddingObservation(id=rowid, ts=row[1], ...) 
+ + # .data loads from PARENT stream (auto-projection) + obs._load = lambda: deserialize( + conn.execute( + f"SELECT data FROM {parent_stream}_payload WHERE rowid = ?", + (conn.execute( + f"SELECT parent_rowid FROM {self._name}_meta WHERE rowid = ?", + (rowid,) + ).fetchone()[0],) + ).fetchone()[0] + ) + + # .embedding loads from _vec + obs._embedding_load = lambda: Embedding( + np.frombuffer( + conn.execute( + f"SELECT embedding FROM {self._name}_vec WHERE rowid = ?", + (rowid,) + ).fetchone()[0], + dtype=np.float32 + ) + ) + return obs +``` + +## Lineage & join + +### Storing lineage + +When a Transformer appends to a target stream, `parent_rowid` links back to the source: + +```python +# Inside Transformer execution +target.append(result, ts=source_obs.ts, pose=source_obs.pose, + _parent_rowid=source_obs.id) # internal param +``` + +The `_streams` registry tracks stream-level lineage: +```python +# When .store() creates from a transform +INSERT INTO _streams (name, type, payload_type, parent_stream) +VALUES ('detections', 'blob', '...', 'images') +``` + +### join() + +Returns tuples of `(self_obs, target_obs)` linked by lineage: + +```sql +-- Join self with target via parent_rowid +SELECT + c.rowid, c.ts, c.pose_x, ..., -- self (e.g., detections) + p.rowid, p.ts, p.pose_x, ... -- target (e.g., images) +FROM {self}_meta c +JOIN {target}_meta p ON c.parent_rowid = p.rowid +WHERE c.rowid IN (/* current filtered set */) +``` + +Iteration yields `tuple[Observation, Observation]` — both sides have lazy `.data`. + +## Transform Execution + +### .transform() — returns lazy stream + +`.transform(xf)` doesn't execute immediately. It returns a new stream with `TransformBacking`. Execution happens at terminal time or `.store()`. + +### .store() — materializes + +When `.store(name)` is called on a transform-backed stream: + +1. Register target stream in `_streams` (with `parent_stream` set) +2. Create target tables (`_meta`, `_payload`, etc.) +3. 
If not `live` mode: run `xf.process(source_stream, target_stream)` (backfill) +4. If not `backfill_only`: subscribe to source's `.appended` observable, call `xf.on_append()` for each new item +5. Return the stored stream (now `StoredBacking`) + +```python +def store(self, name): + if not isinstance(self._backing, TransformBacking): + # Already stored or predicate-backed — different path + ... + + tb = self._backing + # Create target stream + target = self._session._create_stream(name, ...) + + # Register lineage + self._session._register_lineage(name, parent_stream=source_name) + + # Backfill + if not tb.live and tb.transformer.supports_backfill: + source_stream = self._resolve_source() + tb.transformer.process(source_stream, target) + + # Live subscription + if not tb.backfill_only and tb.transformer.supports_live: + source_stream = self._resolve_source() + source_stream.appended.subscribe( + lambda obs: tb.transformer.on_append(obs, target) + ) + + return target +``` + +### Incremental backfill + +When re-opening a previously stored transform, check what's already been processed: + +```python +# Find max parent_rowid already processed +max_parent = conn.execute( + f"SELECT MAX(parent_rowid) FROM {target_name}_meta" +).fetchone()[0] + +# Only process source rows after that +if max_parent is not None: + source = source.after_id(max_parent) # internal method +``` + +### .fetch() on transform-backed stream (no .store()) + +If `.fetch()` is called on a transform-backed stream without `.store()`, execute the transform in-memory: + +1. Fetch source observations +2. Apply transformer's `process()` with an in-memory target +3. Return results without persisting + +This is useful for one-off transforms but can cause memory pressure with large datasets. 
+ +## Reactive (.appended) + +Each stored stream has a `ReplaySubject` (or `Subject`) from reactivex: + +```python +class SqliteStream: + def __init__(self, ...): + self._appended_subject = Subject() + + @property + def appended(self) -> Observable[Observation]: + return self._appended_subject.pipe(...) +``` + +`append()` emits to the subject after the DB write succeeds. + +For filtered streams (`.after(t).near(pose, 5.0).appended`), the observable filters events through the predicate chain in Python: + +```python +@property +def appended(self): + root = self._find_root_stream() + predicates = self._collect_predicates() + return root.appended.pipe( + ops.filter(lambda obs: all(p.matches(obs) for p in predicates)) + ) +``` + +Each predicate type implements `matches(obs) -> bool` for Python-side filtering. + +## Serialization + +### Payload serialization + +Use Python `pickle` for general types, with an optimization path for known DimOS types (LCM-encoded messages): + +```python +def serialize(payload: Any) -> bytes: + # LCM types: use lcm_encode for compact binary + if hasattr(payload, '_get_packed_fingerprint'): + return lcm_encode(payload) + # Fallback: pickle + return pickle.dumps(payload) + +def deserialize(blob: bytes, payload_type: type | None = None) -> Any: + if payload_type and hasattr(payload_type, '_get_packed_fingerprint'): + return lcm_decode(blob, payload_type) + return pickle.loads(blob) +``` + +### Pose helpers + +```python +def extract_position(pose: PoseLike) -> tuple[float, float, float]: + """Extract (x, y, z) from any PoseLike.""" + if isinstance(pose, PoseStamped): + p = pose.pose.position + return (p.x, p.y, p.z) + # ... handle Pose, Point, PointStamped + +def extract_orientation(pose: PoseLike) -> tuple[float, float, float, float] | None: + """Extract (qx, qy, qz, qw) if available.""" + ... 
+ +def reconstruct_pose(row_slice) -> PoseStamped | None: + """Rebuild PoseStamped from (x, y, z, qx, qy, qz, qw) columns.""" + x, y, z, qx, qy, qz, qw = row_slice + if x is None: + return None + ... +``` + +### Tag serialization + +Tags are stored as JSON text. `None`/empty dict → `NULL` in the column. + +```python +tags_json = json.dumps(tags) if tags else None +``` + +## SQL Safety + +- **Identifier validation**: stream names must match `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$`. Reject anything else with `ValueError`. +- **Parameterized queries**: all user values go through `?` params, never string interpolation. +- **Table names**: constructed from validated stream names, so they're safe for SQL interpolation (e.g., `f"{name}_meta"`). + +## Thread Safety + +- Each `Session` owns one `sqlite3.Connection` — not shared across threads. +- Multiple sessions can exist on the same file (WAL mode allows concurrent reads + one writer). +- The `appended` subject emits on the thread that called `append()`. + +## Error Handling + +- `append()` on non-stored stream → `TypeError` +- `search_embedding()` on non-embedding stream → `TypeError` +- `search_text()` on non-text stream → `TypeError` +- `search_embedding()` when sqlite-vec not loaded → `RuntimeError` +- Invalid stream name → `ValueError` +- `one()` with no results → `LookupError` + +## Testing + +Tests go in `dimos/memory/tests/test_sqlite.py`. Use `:memory:` store for speed. + +Key test scenarios: +1. Create stream, append, fetch — verify data round-trips +2. Temporal filters (after, before, time_range, at) +3. Spatial filter (near) — with and without pose +4. Tag filtering +5. EmbeddingStream — store embeddings, search_embedding, verify EmbeddingObservation auto-projects .data +6. TextStream — store text, search_text +7. Transform with lambda — verify lineage +8. Transform with Transformer class — verify process() called +9. Chained filters — verify SQL composition +10. join — verify cross-stream lineage returns tuples +11. 
fetch_pages — verify pagination +12. Lazy data loading — verify .data only hits DB on access +13. .appended observable — verify reactive emission +14. Incremental backfill — verify resume from last processed +15. Multiple sessions on same file diff --git a/plans/transform.md b/plans/memory/transform.md similarity index 92% rename from plans/transform.md rename to plans/memory/transform.md index 0d82481304..409fd8fc6b 100644 --- a/plans/transform.md +++ b/plans/memory/transform.md @@ -34,10 +34,10 @@ class StreamBase(ABC, Generic[T]): ## Source type determines mode -| Source | `live=False` (default) | `live=True` | -|--------|----------------------|-------------| -| `StreamBase` | backfill all existing + subscribe to `.appended` | subscribe to `.appended` only | -| `ObservationSet` | batch process the set | N/A (ignored) | +| Source | `live=False` (default) | `live=True` | +|------------------|--------------------------------------------------|-------------------------------| +| `StreamBase` | backfill all existing + subscribe to `.appended` | subscribe to `.appended` only | +| `ObservationSet` | batch process the set | N/A (ignored) | ## Transform function contract diff --git a/plans/analysis.md b/plans/old/analysis.md similarity index 100% rename from plans/analysis.md rename to plans/old/analysis.md diff --git a/plans/answers.md b/plans/old/answers.md similarity index 100% rename from plans/answers.md rename to plans/old/answers.md diff --git a/plans/answers_correlator.md b/plans/old/answers_correlator.md similarity index 100% rename from plans/answers_correlator.md rename to plans/old/answers_correlator.md diff --git a/plans/correlator.md b/plans/old/correlator.md similarity index 100% rename from plans/correlator.md rename to plans/old/correlator.md diff --git a/plans/memory.md b/plans/old/memory.md similarity index 100% rename from plans/memory.md rename to plans/old/memory.md diff --git a/plans/memory1.md b/plans/old/memory1.md similarity index 100% rename 
from plans/memory1.md rename to plans/old/memory1.md diff --git a/plans/memory2.md b/plans/old/memory2.md similarity index 100% rename from plans/memory2.md rename to plans/old/memory2.md diff --git a/plans/memory3.md b/plans/old/memory3.md similarity index 93% rename from plans/memory3.md rename to plans/old/memory3.md index 2534aa5a15..d8f6e1daa0 100644 --- a/plans/memory3.md +++ b/plans/old/memory3.md @@ -111,10 +111,22 @@ class StreamBase(ABC, Generic[T]): @property def appended(self) -> Observable[ObservationRow]: ... # .data pre-populated + # Transform (see transform.md for details) + def transform(self, + source: StreamBase | ObservationSet, + fn: Callable[[Any], T | list[T] | None], + *, + live: bool = False) -> Self: + """Process source data, store results with lineage. + StreamBase source: backfill + subscribe (live=True skips backfill). + ObservationSet source: batch only.""" + ... + # Read def query(self) -> Query[T]: ... def load(self, row: ObservationRow) -> T: ... def load_many(self, rows: list[ObservationRow], *, batch_size=32) -> list[T]: ... + def iter_meta(self, *, page_size=128) -> Iterator[list[ObservationRow]]: ... def count(self) -> int: ... class BlobStream(StreamBase[T]): @@ -124,11 +136,8 @@ class EmbeddingStream(StreamBase[T]): """Stream with vector index. No payload table — the vector IS the data.""" model: EmbeddingModel - def attach(self, parent: StreamBase) -> Self: - """Sets lineage parent + subscribes to parent.appended to auto-embed.""" - # parent.appended.pipe( - # ops.map(lambda row: self._embed_and_store(row)), - # ).subscribe() + def transform(self, source, fn=None, *, live=False) -> Self: + """If fn is None, uses model.embed implicitly.""" ... def vector(self, row: ObservationRow) -> list[float] | None: ... @@ -322,11 +331,11 @@ All virtual table rowids match `_meta.rowid` directly. 
## Phase 3: Later (not in first PR) -- `derive()` with Transform protocol - `CompositeBacking` (union/intersection/difference) - `Correlator` / `s.correlate()` - `retention` enforcement / cleanup - Full introspection (stats, spatial_bounds) +- Query objects (`query_objects.md`) — composable criteria + soft scoring ## Design Decisions @@ -337,6 +346,7 @@ All virtual table rowids match `_meta.rowid` directly. - **Unlocalized observations**: rows without pose excluded from `filter_near()` by default. `include_unlocalized=True` to include them. - **Stream hierarchy**: `StreamBase` (ABC) → `BlobStream`, `EmbeddingStream`, `TextStream`. Indexing is determined by stream type, not config. - **Lineage**: parent stream defined at stream level (in `_streams` registry). Per-row `parent_id` links to specific row in parent. +- **Transform**: `.transform(source, fn)` on any stream — unified API for batch (ObservationSet) and live (StreamBase) derived streams. Uses `appended` observable for reactive pipeline. See `transform.md`. ### SQLite-specific diff --git a/plans/memory3_answers.md b/plans/old/memory3_answers.md similarity index 100% rename from plans/memory3_answers.md rename to plans/old/memory3_answers.md diff --git a/plans/old/memory4.md b/plans/old/memory4.md new file mode 100644 index 0000000000..08e42265c4 --- /dev/null +++ b/plans/old/memory4.md @@ -0,0 +1,466 @@ +# Memory2 API — Unified Stream + +## Core Idea + +One type: `Stream[T]`. Everything is a stream — stored, filtered, transformed. The user never thinks about Query vs ObservationSet vs Stream. They just chain operations. 
+ +## Creating Streams + +```python +store = SqliteStore("/data/robot.db") +session = store.session() + +# Root stored stream — backed by DB +images = session.stream("images", Image, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +logs = session.text_stream("logs", str, + pose_provider=lambda: tf.get_pose("world", "base_link")) +``` + +## Writing + +```python +images.append(frame) # ts + pose auto-filled +logs.append("Motor fault on joint 3") # ts + pose auto-filled +images.append(frame, pose=explicit_pose, tags={"cam": "front"}) +``` + +Only meaningful on stored (DB-backed) streams. + +## Filtering + +Every filter returns a new `Stream[T]`. Lazy — nothing executes until a terminal. + +```python +recent = images.after(one_hour_ago) +kitchen = recent.near(kitchen_pose, 5.0) +tagged = kitchen.filter_tags(cam="front") + +# Or chained +images.after(one_hour_ago).near(kitchen_pose, 5.0).filter_tags(cam="front") +``` + +### Filter methods + +```python +class Stream(Generic[T]): + # Temporal + def after(self, t: float) -> Stream[T]: ... + def before(self, t: float) -> Stream[T]: ... + def time_range(self, t1: float, t2: float) -> Stream[T]: ... + def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... + + # Spatial + def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... + + # Tags + def filter_tags(self, **tags: Any) -> Stream[T]: ... + +class EmbeddingStream(Stream[T]): + def search_embedding(self, query: Embedding | list[float], *, k: int) -> EmbeddingStream[T]: ... + +class TextStream(Stream[T]): + def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... 
+``` + +## Terminals + +```python +rows = images.after(t).fetch() # list[Observation] +row = images.after(t).one() # single best match +row = images.last() # most recent +n = images.after(t).count() # count without fetching + +# Pagination +page = images.order_by("ts").limit(50).offset(100).fetch() +``` + +### Terminal methods + +```python +class Stream(Generic[T]): + def fetch(self) -> list[Observation]: ... + def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... + def one(self) -> Observation: ... + def last(self) -> Observation: ... + def count(self) -> int: ... + def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... + def limit(self, k: int) -> Stream[T]: ... + def offset(self, n: int) -> Stream[T]: ... +``` + +## Observation + +```python +from dimos.models.embedding.base import Embedding, EmbeddingModel + +@dataclass +class Observation: + id: int + ts: float | None = None + pose: PoseStamped | None = None + tags: dict[str, Any] = field(default_factory=dict) + + @property + def data(self) -> Any: + """Lazy payload. Pre-populated from append/transform, fetched on demand from query.""" + ... +``` + +## Transformer + +A `Transformer` receives the full source stream and decides what to do — which items to process, how to batch, whether to use embeddings as a cheap proxy, etc. + +```python +class Transformer(ABC, Generic[T, R]): + """Transforms a source stream into results on a target stream.""" + + def process(self, source: Stream[T], target: Stream[R]) -> None: + """Batch/historical processing. Has full access to source — can query, + filter, use embeddings, batch, skip frames, etc.""" + ... + + def on_append(self, obs: Observation, target: Stream[R]) -> None: + """Reactive processing. Called per new item. Default: process([obs]).""" + ... 
+ + supports_backfill: bool = True + supports_live: bool = True +``` + +### Simple lambdas (sugar) + +`Callable[[T], R | list[R] | None]` is auto-wrapped into a naive per-item Transformer: + +```python +# These are equivalent: +images.transform(lambda img: vlm.detect(img, "cigarettes")) +images.transform(PerItemTransformer(lambda img: vlm.detect(img, "cigarettes"))) +``` + +- `R` → single result +- `list[R]` → multiple results (e.g., multiple detections per frame) +- `None` → skip (no result for this input) + +### EmbeddingTransformer + +`EmbeddingTransformer` wraps an `EmbeddingModel` as a `Transformer[T, Embedding]`. When the output type is `Embedding`, `.store()` creates an `EmbeddingStream` (vec0 index, `search_embedding`, `EmbeddingObservation`). + +```python +# EmbeddingTransformer wraps the model +img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") + +# Now img_emb is an EmbeddingStream +results = img_emb.search_embedding(query_emb, k=20).fetch() +# results[0].data → Image (auto-projected from source) +# results[0].embedding → Embedding (supports @ for cosine similarity) +``` + +### Smart Transformer example + +Chains after an embedding transform — receives `EmbeddingObservation` with `.data` (Image) and `.embedding` (vector), so it can use similarity to skip irrelevant frames: + +```python +class CigaretteDetector(Transformer[EmbeddingObservation, Detection]): + def __init__(self, vlm, clip): + self.vlm = vlm + self.clip = clip + + def process(self, source: Stream[EmbeddingObservation], target: Stream[Detection]): + query = self.clip.embed_text("person smoking cigarette") + for page in source.fetch_pages(batch_size=16): + # Use embedding similarity as cheap proxy — skip distant frames + promising = [obs for obs in page if obs.embedding @ query > 0.3] + if not promising: + continue + detections = self.vlm.detect_batch( + [obs.data for obs in promising], "cigarettes" + ) + for obs, dets in zip(promising, detections): + for det in 
dets: + target.append(det, ts=obs.ts, pose=obs.pose) + + def on_append(self, obs: EmbeddingObservation, target: Stream[Detection]): + dets = self.vlm.detect(obs.data, "cigarettes") + for det in dets: + target.append(det, ts=obs.ts, pose=obs.pose) +``` + +### Chaining transforms + +```python +# Filter → transform → store +images.after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .transform(EmbeddingTransformer(CLIPModel())) \ + .store("kitchen_embeddings") + +# Filter → transform → fetch (in-memory, not persisted) +results = images.after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .transform(EmbeddingTransformer(CLIPModel())) \ + .fetch() + +# Filter → embed → detect → store (chained: detector gets EmbeddingObservation) +images.near(kitchen_pose, 5.0) \ + .transform(EmbeddingTransformer(CLIPModel())) \ + .transform(CigaretteDetector(vlm, clip)) \ + .store("kitchen_cigarette_detections") +``` + +### Backfill / Live modes + +```python +# Both (default): backfill existing + subscribe to new +images.transform(detector).store("detections") + +# Live only: skip backfill, only process new items +images.transform(detector, live=True).store("detections") + +# Backfill only: process existing, don't subscribe +images.transform(detector, backfill_only=True).store("detections") + +# Incremental: re-running a stored transform resumes from last processed item +# (uses lineage parent_id to skip already-processed source rows) +``` + +## Storing + +`.store(name)` materializes a stream to DB. After storing, results are queryable and persistent. + +```python +# In-memory transform result — not persisted +detections = images.transform(detect_fn) + +# Persist it +detections.store("detections") + +# Now it's a DB-backed stream, queryable +stored = session.stream("detections") +rows = stored.after(t).fetch() +``` + +`.store()` also sets up lineage — every stored row gets `parent_id` pointing back to its source. 
+ +Stream type is determined by what the Transformer produces: +- `Embedding` output → `EmbeddingStream` (vec0 index) +- Everything else → `Stream` (blob) +- `TextStream` is created explicitly via `session.text_stream()` (not auto-detected) + +## Reactive + +```python +# .appended emits Observation with .data pre-populated +images.appended.subscribe(lambda row: print(f"New image at {row.pose}")) + +# Stored transforms propagate reactively by default +detections = images.transform(detect_fn).store("detections") +# Now every images.append(frame) → detect_fn runs → result stored in "detections" + +# Filtered appended — only kitchen images +images.near(kitchen_pose, 5.0).appended.subscribe(...) +``` + +## Project (cross-stream lineage) + +```python +# Find source images for detections +source_images = detections.after(t).project_to(images) + +# project_to returns a Stream over the parent, filtered by lineage +for row in source_images.fetch(): + img = row.data # lazy-loads from images stream +``` + +## Full Example: Cigarette Detection Pipeline + +```python +session = SqliteStore("/data/robot.db").session() + +# Root stream +images = session.stream("images", Image, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +# Embedding index — EmbeddingModel is a Transformer +img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") + +# VLM detection pipeline (live-only, no backfill) +images.transform( + lambda img: vlm.detect(img, "people with cigarettes"), + live=True, +).store("cigarette_detections") + +# Smart detection — chain embed → detect (detector uses embedding similarity to skip frames) +images.near(kitchen_pose, 10.0) \ + .transform(EmbeddingTransformer(CLIPModel())) \ + .transform(CigaretteDetector(vlm, clip)) \ + .store("kitchen_cigarette_detections") + +# --- Later, querying --- + +# "Where did we see people with cigarettes in the kitchen?" 
+rows = session.stream("cigarette_detections") \ + .after(one_hour_ago) \ + .near(kitchen_pose, 10.0) \ + .fetch() + +for row in rows: + print(f"t={row.ts} pose={row.pose}: {row.data}") + +# "Show me the source images" +source_imgs = session.stream("cigarette_detections") \ + .after(one_hour_ago) \ + .project_to(images) \ + .fetch() + +# "Find images similar to 'red shoes'" +query_emb = clip.embed_text("red shoes") +similar = img_emb.search_embedding(query_emb, k=20).fetch() +# similar[0].data → Image (auto-projected from source) +# similar[0].embedding → Embedding (supports @ for cosine similarity) +``` + +## Full API + +```python +from dimos.models.embedding.base import Embedding, EmbeddingModel + +# --- Data types --- + +@dataclass +class Observation: + id: int + ts: float | None = None + pose: PoseStamped | None = None + tags: dict[str, Any] = field(default_factory=dict) + + @property + def data(self) -> Any: + """Lazy payload. Pre-populated from append, fetched on demand from query.""" + ... + +@dataclass +class EmbeddingObservation(Observation): + """Returned by EmbeddingStream terminals. Auto-projects .data to source stream.""" + + @property + def data(self) -> Any: + """Lazily loads from the source stream (e.g., Image), not the embedding.""" + ... + + @property + def embedding(self) -> Embedding: + """The Embedding object (has .vector, supports @ for cosine similarity).""" + ... + +# --- Transformer --- + +class Transformer(ABC, Generic[T, R]): + """Transforms a source stream into results on a target stream.""" + + def process(self, source: Stream[T], target: Stream[R]) -> None: + """Batch/historical processing. Full access to source stream.""" + ... + + def on_append(self, obs: Observation, target: Stream[R]) -> None: + """Reactive processing. Called per new item.""" + ... 
+ + supports_backfill: bool = True + supports_live: bool = True + +# --- Streams --- + +class Stream(Generic[T]): + # Write (DB-backed only) + def append(self, payload: T, *, + ts: float | None = None, + pose: PoseLike | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation: ... + + # Filter (returns new Stream, lazy) + def after(self, t: float) -> Stream[T]: ... + def before(self, t: float) -> Stream[T]: ... + def time_range(self, t1: float, t2: float) -> Stream[T]: ... + def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... + def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... + def filter_tags(self, **tags: Any) -> Stream[T]: ... + + # Order / paginate + def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... + def limit(self, k: int) -> Stream[T]: ... + def offset(self, n: int) -> Stream[T]: ... + + # Transform + def transform(self, + xf: Transformer[T, R] | Callable[[T], R | list[R] | None], + *, live: bool = False, + backfill_only: bool = False, + ) -> Stream[R]: ... + + # Materialize + def store(self, name: str | None = None) -> Stream[T]: ... + + # Cross-stream + def project_to(self, target: Stream) -> Stream: ... + + # Terminals + def fetch(self) -> list[Observation]: ... + def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... + def one(self) -> Observation: ... + def last(self) -> Observation: ... + def count(self) -> int: ... + + # Reactive + @property + def appended(self) -> Observable[Observation]: ... + +class EmbeddingStream(Stream[T]): + """Created automatically when a Transformer produces Embedding output. + Terminals return EmbeddingObservation (auto-projects .data to source stream).""" + def search_embedding(self, query: Embedding | list[float], *, k: int) -> EmbeddingStream[T]: ... + def fetch(self) -> list[EmbeddingObservation]: ... + def one(self) -> EmbeddingObservation: ... + def last(self) -> EmbeddingObservation: ... 
+ +class TextStream(Stream[T]): + """Stream with FTS index.""" + def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... + +# --- Session / Store --- + +PoseProvider = Callable[[], PoseLike | None] + +class Session: + def stream(self, name: str, payload_type: type | None = None, *, + pose_provider: PoseProvider | None = None) -> Stream: ... + def text_stream(self, name: str, payload_type: type | None = None, *, + tokenizer: str = "unicode61", + pose_provider: PoseProvider | None = None) -> TextStream: ... + def list_streams(self) -> list[StreamInfo]: ... + def close(self) -> None: ... + +class Store: + def session(self) -> Session: ... + def close(self) -> None: ... +``` + +## Internal Backing (impl detail) + +A `Stream` can be backed by different things — the user never sees this: + +- **DB table** — from `session.stream()`. Has `_meta`, `_payload`, indexes. +- **Predicate** — from `.after()`, `.near()`, etc. Lazy SQL WHERE. +- **Transform** — from `.transform(t)`. Source stream + Transformer. + +The impl decides how to execute based on the backing chain. + +## Open Questions + +1. **`.append()` on non-stored streams?** Runtime error, or silently ignore? Probably `TypeError`. + +2. **Multiple `.store()` calls?** Should be idempotent — second call is a no-op if already stored under the same name. + +3. **Memory pressure from in-memory transforms?** Large `.transform().fetch()` without `.store()` loads everything into memory. Should we support streaming iteration? 
diff --git a/plans/old/transforms.md b/plans/old/transforms.md new file mode 100644 index 0000000000..edc5940512 --- /dev/null +++ b/plans/old/transforms.md @@ -0,0 +1,21 @@ +```python +# Filter → transform → store + + images.after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .transform(CLIPModel()) \ + .store("kitchen_embeddings") + + # Filter → transform → fetch (in-memory, not persisted) + results = images.after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .transform(CLIPModel()) \ + .fetch() + + # Filter → transform → transform → store + images.near(kitchen_pose, 5.0) \ + .transform(CLIPModel()) \ + .transform(CigaretteDetector(vlm)) \ + .store("kitchen_cigarette_detections") + +``` From 76fbea1c2186de437eeac240ecc5f4f0281b6303 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 15:00:12 +0800 Subject: [PATCH 008/118] transform materialize --- dimos/memory/impl/sqlite.py | 33 ++++++++++++++++++++++--- dimos/memory/store.py | 13 ++++++++++ dimos/memory/stream.py | 32 +++++++++++++++++------- dimos/memory/tests/test_sqlite.py | 41 +++++++++++++++++++++++++++++++ plans/memory/api.md | 27 +++++++++++++++----- 5 files changed, 128 insertions(+), 18 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index d3f8e3b989..136999f103 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -34,6 +34,7 @@ from dimos.memory.store import Session, Store from dimos.memory.stream import EmbeddingStream, Stream, TextStream +from dimos.memory.transformer import EmbeddingTransformer, Transformer from dimos.memory.types import ( AfterFilter, AtFilter, @@ -508,7 +509,7 @@ def stream( self._register_stream(name, payload_type, "stream") backend = SqliteStreamBackend(self._conn, name, pose_provider=pose_provider) - s: Stream[Any] = Stream(backend=backend) + s: Stream[Any] = Stream(backend=backend, session=self) self._streams[name] = s return s @@ -530,7 +531,7 @@ def text_stream( backend = SqliteTextBackend( 
self._conn, name, tokenizer=tokenizer, pose_provider=pose_provider ) - ts: TextStream[Any] = TextStream(backend=backend) + ts: TextStream[Any] = TextStream(backend=backend, session=self) self._streams[name] = ts return ts @@ -559,7 +560,7 @@ def embedding_stream( if vec_dimensions is not None: backend._ensure_vec_table() - es: EmbeddingStream[Any] = EmbeddingStream(backend=backend) + es: EmbeddingStream[Any] = EmbeddingStream(backend=backend, session=self) self._streams[name] = es return es @@ -572,6 +573,32 @@ def list_streams(self) -> list[StreamInfo]: result.append(StreamInfo(name=name, payload_type=ptype, count=count)) return result + def materialize_transform( + self, + name: str, + source: Stream[Any], + transformer: Transformer[Any, Any], + *, + live: bool = False, + backfill_only: bool = False, + ) -> Stream[Any]: + # Determine stream type from transformer + target: Stream[Any] + if isinstance(transformer, EmbeddingTransformer): + target = self.embedding_stream(name) + else: + target = self.stream(name) + + # Backfill existing data + if transformer.supports_backfill and not live: + transformer.process(source, target) + + # Subscribe to live updates + if transformer.supports_live and not backfill_only: + source.appended.subscribe(on_next=lambda obs: transformer.on_append(obs, target)) + + return target + def close(self) -> None: for s in self._streams.values(): if s._backend is not None: diff --git a/dimos/memory/store.py b/dimos/memory/store.py index eb05ee9a41..43d3abb832 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from .stream import Stream, TextStream + from .transformer import Transformer from .types import PoseProvider, StreamInfo @@ -49,6 +50,18 @@ def text_stream( @abstractmethod def list_streams(self) -> list[StreamInfo]: ... 
+ @abstractmethod + def materialize_transform( + self, + name: str, + source: Stream[Any], + transformer: Transformer[Any, Any], + *, + live: bool = False, + backfill_only: bool = False, + ) -> Stream[Any]: + """Create a stored stream from a transform pipeline.""" + @abstractmethod def close(self) -> None: ... diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index e2ccc50bd8..9170d4feb4 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -47,6 +47,7 @@ from dimos.models.embedding.base import Embedding from dimos.msgs.geometry_msgs.Pose import PoseLike + from .store import Session from .transformer import Transformer T = TypeVar("T") @@ -83,9 +84,11 @@ def __init__( backend: StreamBackend | None = None, *, query: StreamQuery | None = None, + session: Session | None = None, ) -> None: self._backend = backend self._query = query or StreamQuery() + self._session: Session | None = session def _clone(self, **overrides: Any) -> Stream[T]: """Return a new Stream with updated query fields.""" @@ -100,6 +103,7 @@ def _clone(self, **overrides: Any) -> Stream[T]: clone: Stream[T] = self.__class__.__new__(self.__class__) clone._backend = self._backend clone._query = new_query + clone._session = self._session return clone def _with_filter(self, f: Filter) -> Stream[T]: @@ -205,6 +209,9 @@ def transform( # ── Materialize ─────────────────────────────────────────────────── def store(self, name: str | None = None) -> Stream[T]: + # Already stored streams are a no-op + if self._backend is not None and name is None: + return self raise TypeError( "store() requires a session context. This stream is not associated with a session." 
) @@ -349,15 +356,22 @@ def fetch(self) -> list[Observation]: self._transformer.process(self._source, collector) return collector.results - def store(self, name: str | None = None) -> Stream[R]: - # Delegated to session — TransformStream.store() is overridden - # by the session when the source stream has a backend - source_backend = self._source._backend - if source_backend is None: - raise TypeError("Cannot store a transform whose source has no backend session") - # The backend's session handles materialization - raise NotImplementedError( - "store() on TransformStream must be handled by the session/backend" + def store(self, name: str | None = None, session: Session | None = None) -> Stream[R]: + resolved = session or self._source._session + if resolved is None: + raise TypeError( + "Cannot store: no session available. " + "Either use session.stream() to create the source, " + "or pass session= to store()." + ) + if name is None: + raise TypeError("store() requires a name for transform outputs") + return resolved.materialize_transform( + name=name, + source=self._source, + transformer=self._transformer, + live=self._live, + backfill_only=self._backfill_only, ) diff --git a/dimos/memory/tests/test_sqlite.py b/dimos/memory/tests/test_sqlite.py index ec813f2a71..ec71135b4d 100644 --- a/dimos/memory/tests/test_sqlite.py +++ b/dimos/memory/tests/test_sqlite.py @@ -278,6 +278,47 @@ def test_lambda_expand_list(self, session: SqliteSession) -> None: assert [r.data for r in results] == ["a", "b", "c"] +class TestTransformStore: + def test_transform_store_backfill(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("hello", ts=1.0) + s.append("world", ts=2.0) + + stored = s.transform(lambda x: x.upper()).store("upper_data") + rows = stored.fetch() + assert len(rows) == 2 + assert rows[0].data == "HELLO" + assert rows[1].data == "WORLD" + + # Also queryable by name + reloaded = session.stream("upper_data") + assert reloaded.count() == 2 + + 
def test_transform_store_live(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("existing", ts=1.0) + + # live=True skips backfill, only processes new items + stored = s.transform(lambda x: x.upper(), live=True).store("live_upper") + assert stored.count() == 0 # no backfill + + s.append("new", ts=2.0) + assert stored.count() == 1 + assert stored.last().data == "NEW" + + def test_transform_store_backfill_only(self, session: SqliteSession) -> None: + s = session.stream("data", str) + s.append("existing", ts=1.0) + + stored = s.transform(lambda x: x.upper(), backfill_only=True).store("backfill_upper") + assert stored.count() == 1 + assert stored.one().data == "EXISTING" + + # New appends should NOT propagate + s.append("new", ts=2.0) + assert stored.count() == 1 # still 1 + + class TestStoreReopen: def test_data_persists(self, tmp_path: object) -> None: from pathlib import Path diff --git a/plans/memory/api.md b/plans/memory/api.md index 3b00a8b514..bb62d52ca4 100644 --- a/plans/memory/api.md +++ b/plans/memory/api.md @@ -410,8 +410,8 @@ class Stream(Generic[T]): backfill_only: bool = False, ) -> Stream[R]: ... - # Materialize - def store(self, name: str | None = None) -> Stream[T]: ... + # Materialize (on TransformStream, accepts optional session= fallback) + def store(self, name: str | None = None, session: Session | None = None) -> Stream[T]: ... # Cross-stream (lineage join — returns tuples of (self_obs, target_obs)) def join(self, target: Stream) -> Stream[tuple[Observation, Observation]]: ... @@ -450,6 +450,10 @@ class Session: def text_stream(self, name: str, payload_type: type | None = None, *, tokenizer: str = "unicode61", pose_provider: PoseProvider | None = None) -> TextStream: ... + def materialize_transform(self, name: str, source: Stream, + transformer: Transformer, + *, live: bool = False, + backfill_only: bool = False) -> Stream: ... def list_streams(self) -> list[StreamInfo]: ... def close(self) -> None: ... 
@@ -468,10 +472,21 @@ A `Stream` can be backed by different things — the user never sees this: The impl decides how to execute based on the backing chain. -## Open Questions +## Implementation Notes + +- **No ORM** — raw `sqlite3` with direct SQL. The `Stream` filter chain *is* the query builder. +- **Session threading** — streams created by `session.stream()` get `_session` set. `TransformStream` inherits it from its source. `store()` also accepts an explicit `session=` fallback. +- **Serialization** — payloads are `pickle`, poses are `pickle`, tags are JSON. +- **Near filter** — compiled as no-op SQL (`1=1`), filtered post-query in Python via pose distance. -1. **`.append()` on non-stored streams?** Runtime error, or silently ignore? Probably `TypeError`. +## Resolved Questions -2. **Multiple `.store()` calls?** Should be idempotent — second call is a no-op if already stored under the same name. +1. **`.append()` on non-stored streams?** → `TypeError` (requires backend). +2. **Multiple `.store()` calls?** → Idempotent — returns existing stream if already stored. +3. ~~**Memory pressure from in-memory transforms?**~~ → Solved via `fetch_pages`. + +## Open Questions -3. ~~**Memory pressure from in-memory transforms?**~~ Solved — `Stream` is iterable, pages internally via `fetch_pages`. +1. **`project_to` / lineage** — `parent_id` column exists but not yet wired. +2. **Incremental transforms** — re-running a stored transform should resume from last processed item. +3. **`__iter__`** — spec shows `Stream.__iter__` but not yet implemented. 
From ef7fe1d83bb08846724b254984cac2f7b4ff35e9 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 15:53:25 +0800 Subject: [PATCH 009/118] sqlite schema: decomposed pose columns, separate payload table, R*Tree spatial index, lazy data loading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Pose stored as 7 real columns (x/y/z + quaternion) instead of blob, enabling R*Tree spatial indexing - Payload moved to separate {name}_payload table with lazy loading via _data_loader closure - R*Tree virtual table created per stream for .near() bounding-box queries - Added __iter__ to Stream for lazy iteration via fetch_pages - Added embedding_stream() to Session ABC - Updated _streams metadata with parent_stream and embedding_dim columns - Codec module extracted (LcmCodec, PickleCodec, codec_for_type) - Fixed broken memory_old.timeseries imports (memory.timeseries → memory_old.timeseries) - Tests now use real Image data from TimedSensorReplay("unitree_go2_bigoffice/video") - 32/32 tests passing, mypy clean --- dimos/memory/__init__.py | 5 + dimos/memory/codec.py | 95 +++++ dimos/memory/impl/sqlite.py | 396 +++++++++++++++------ dimos/memory/store.py | 14 +- dimos/memory/stream.py | 6 + dimos/memory/tests/test_sqlite.py | 350 ++++++++++-------- dimos/memory_old/timeseries/__init__.py | 12 +- dimos/memory_old/timeseries/inmemory.py | 2 +- dimos/memory_old/timeseries/legacy.py | 2 +- dimos/memory_old/timeseries/pickledir.py | 2 +- dimos/memory_old/timeseries/postgres.py | 2 +- dimos/memory_old/timeseries/sqlite.py | 2 +- dimos/memory_old/timeseries/test_base.py | 12 +- dimos/memory_old/timeseries/test_legacy.py | 2 +- dimos/protocol/tf/tf.py | 2 +- dimos/types/test_timestamped.py | 2 +- dimos/types/timestamped.py | 2 +- dimos/utils/testing/replay.py | 2 +- plans/memory/api.md | 131 ++++++- 19 files changed, 760 insertions(+), 281 deletions(-) create mode 100644 dimos/memory/codec.py diff --git a/dimos/memory/__init__.py 
b/dimos/memory/__init__.py index ba76a2f5ed..14b9c87ba3 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory/__init__.py @@ -1,3 +1,4 @@ +from dimos.memory.codec import Codec, LcmCodec, PickleCodec, codec_for_type from dimos.memory.store import Session, Store from dimos.memory.stream import EmbeddingStream, Stream, TextStream from dimos.memory.transformer import ( @@ -12,15 +13,19 @@ ) __all__ = [ + "Codec", "EmbeddingObservation", "EmbeddingStream", "EmbeddingTransformer", + "LcmCodec", "Observation", "PerItemTransformer", + "PickleCodec", "Session", "Store", "Stream", "StreamInfo", "TextStream", "Transformer", + "codec_for_type", ] diff --git a/dimos/memory/codec.py b/dimos/memory/codec.py new file mode 100644 index 0000000000..3a18fe21df --- /dev/null +++ b/dimos/memory/codec.py @@ -0,0 +1,95 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import pickle +from typing import TYPE_CHECKING, Any, Protocol, TypeVar + +if TYPE_CHECKING: + from dimos.msgs.protocol import DimosMsg + +T = TypeVar("T") + + +class Codec(Protocol[T]): + """Encodes/decodes payloads for storage.""" + + def encode(self, value: T) -> bytes: ... + def decode(self, data: bytes) -> T: ... 
+ + +class LcmCodec: + """Codec for DimosMsg types — uses lcm_encode/lcm_decode.""" + + def __init__(self, msg_type: type[DimosMsg]) -> None: + self._msg_type = msg_type + + def encode(self, value: DimosMsg) -> bytes: + return value.lcm_encode() + + def decode(self, data: bytes) -> DimosMsg: + return self._msg_type.lcm_decode(data) + + +class PickleCodec: + """Fallback codec for arbitrary Python objects.""" + + def encode(self, value: Any) -> bytes: + return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) + + def decode(self, data: bytes) -> Any: + return pickle.loads(data) + + +_POSE_CODEC: LcmCodec | None = None + + +def _pose_codec() -> LcmCodec: + global _POSE_CODEC + if _POSE_CODEC is None: + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + _POSE_CODEC = LcmCodec(PoseStamped) + return _POSE_CODEC + + +def codec_for_type(payload_type: type | None) -> LcmCodec | PickleCodec: + """Auto-select codec based on payload type.""" + if ( + payload_type is not None + and hasattr(payload_type, "lcm_encode") + and hasattr(payload_type, "lcm_decode") + ): + return LcmCodec(payload_type) # type: ignore[arg-type] + return PickleCodec() + + +def type_to_module_path(t: type) -> str: + """Return fully qualified module path for a type, e.g. 'dimos.msgs.sensor_msgs.Image.Image'.""" + return f"{t.__module__}.{t.__qualname__}" + + +def module_path_to_type(path: str) -> type | None: + """Resolve a fully qualified module path back to a type. 
Returns None on failure.""" + parts = path.rsplit(".", 1) + if len(parts) != 2: + return None + module_path, class_name = parts + try: + mod = importlib.import_module(module_path) + return getattr(mod, class_name, None) # type: ignore[no-any-return] + except (ImportError, AttributeError): + return None diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 136999f103..8a549803be 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -14,24 +14,34 @@ """SQLite-backed memory store implementation. -Each stream maps to a table: - {name} — id INTEGER PK, ts REAL, pose BLOB, tags TEXT (JSON), payload BLOB - {name}_fts — FTS5 virtual table (TextStream only) - {name}_vec — vec0 virtual table (EmbeddingStream only) +Schema per stream ``{name}``: -Payloads are pickled. Poses are pickled PoseStamped. Tags are JSON. + {name} — id, ts, pose columns (x/y/z + quaternion), tags, parent_id + {name}_payload — id, data BLOB (loaded lazily) + {name}_rtree — R*Tree spatial index on position + {name}_fts — FTS5 full-text index (TextStream only) + {name}_vec — vec0 vector index (EmbeddingStream only) + +Payloads use Codec (LCM for DimosMsg types, pickle otherwise). +Poses are decomposed into columns. Tags are JSON. 
""" from __future__ import annotations import json -import pickle import sqlite3 import time from typing import TYPE_CHECKING, Any from reactivex.subject import Subject +from dimos.memory.codec import ( + LcmCodec, + PickleCodec, + codec_for_type, + module_path_to_type, + type_to_module_path, +) from dimos.memory.store import Session, Store from dimos.memory.stream import EmbeddingStream, Stream, TextStream from dimos.memory.transformer import EmbeddingTransformer, Transformer @@ -55,27 +65,44 @@ from dimos.memory.types import PoseProvider -# ── Serialization helpers ───────────────────────────────────────────── - - -def _serialize_payload(payload: Any) -> bytes: - return pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) - +# ── Pose helpers (column-based) ────────────────────────────────────── -def _deserialize_payload(blob: bytes) -> Any: - return pickle.loads(blob) - -def _serialize_pose(pose: Any) -> bytes | None: +def _decompose_pose(pose: Any) -> tuple[float, float, float, float, float, float, float] | None: + """Extract (x, y, z, qx, qy, qz, qw) from a PoseStamped or similar.""" if pose is None: return None - return pickle.dumps(pose, protocol=pickle.HIGHEST_PROTOCOL) - - -def _deserialize_pose(blob: bytes | None) -> Any: - if blob is None: + # PoseStamped has .pose.position and .pose.orientation + p = pose.pose.position + q = pose.pose.orientation + return (p.x, p.y, p.z, q.x, q.y, q.z, q.w) + + +def _reconstruct_pose( + x: float | None, + y: float | None, + z: float | None, + qx: float | None, + qy: float | None, + qz: float | None, + qw: float | None, +) -> Any | None: + """Rebuild a PoseStamped from column values.""" + if x is None: return None - return pickle.loads(blob) + from dimos.msgs.geometry_msgs.Point import Point + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.std_msgs.Header import Header + + 
return PoseStamped( + header=Header(), + pose=Pose( + position=Point(x=x, y=y or 0.0, z=z or 0.0), + orientation=Quaternion(x=qx or 0.0, y=qy or 0.0, z=qz or 0.0, w=qw or 1.0), + ), + ) def _serialize_tags(tags: dict[str, Any] | None) -> str: @@ -92,56 +119,95 @@ def _deserialize_tags(text: str) -> dict[str, Any]: # ── SQL building ────────────────────────────────────────────────────── +# Columns selected from the meta table (no payload). +_META_COLS = "id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags" + def _compile_filter(f: Filter, table: str) -> tuple[str, list[Any]]: """Compile a single filter to (SQL fragment, params).""" if isinstance(f, AfterFilter): - return "ts > ?", [f.t] + return f"{table}.ts > ?", [f.t] if isinstance(f, BeforeFilter): - return "ts < ?", [f.t] + return f"{table}.ts < ?", [f.t] if isinstance(f, TimeRangeFilter): - return "ts >= ? AND ts <= ?", [f.t1, f.t2] + return f"{table}.ts >= ? AND {table}.ts <= ?", [f.t1, f.t2] if isinstance(f, AtFilter): - return "ABS(ts - ?) <= ?", [f.t, f.tolerance] + return f"ABS({table}.ts - ?) 
<= ?", [f.t, f.tolerance] if isinstance(f, TagsFilter): clauses: list[str] = [] params: list[Any] = [] for key, val in f.tags.items(): - clauses.append(f"json_extract(tags, '$.{key}') = ?") + clauses.append(f"json_extract({table}.tags, '$.{key}') = ?") params.append(val) return " AND ".join(clauses), params if isinstance(f, NearFilter): - # Spatial filtering requires pose deserialization — done post-query - # Return a no-op SQL clause; filtering happens in Python + # Handled via R*Tree JOIN — see _compile_query return "1=1", [] if isinstance(f, EmbeddingSearchFilter): - # Handled specially by EmbeddingStream backend return "1=1", [] if isinstance(f, TextSearchFilter): - # Handled specially by TextStream backend return "1=1", [] raise TypeError(f"Unknown filter type: {type(f)}") +def _has_near_filter(query: StreamQuery) -> NearFilter | None: + for f in query.filters: + if isinstance(f, NearFilter): + return f + return None + + def _compile_query(query: StreamQuery, table: str) -> tuple[str, list[Any]]: - """Compile a StreamQuery to (SQL, params) for a SELECT.""" + """Compile a StreamQuery to (SQL, params) for a metadata SELECT.""" where_parts: list[str] = [] params: list[Any] = [] + joins: list[str] = [] + + _has_near_filter(query) for f in query.filters: - sql, p = _compile_filter(f, table) - where_parts.append(sql) - params.extend(p) + if isinstance(f, NearFilter): + # R*Tree bounding-box join + joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") + where_parts.append( + "r.min_x >= ? AND r.max_x <= ? AND " + "r.min_y >= ? AND r.max_y <= ? AND " + "r.min_z >= ? AND r.max_z <= ?" 
+ ) + pose_parts = _decompose_pose(f.pose) + if pose_parts is not None: + x, y, z = pose_parts[0], pose_parts[1], pose_parts[2] + else: + x, y, z = 0.0, 0.0, 0.0 + params.extend( + [ + x - f.radius, + x + f.radius, + y - f.radius, + y + f.radius, + z - f.radius, + z + f.radius, + ] + ) + else: + sql_frag, p = _compile_filter(f, table) + where_parts.append(sql_frag) + params.extend(p) where = " AND ".join(where_parts) if where_parts else "1=1" + join_clause = " ".join(joins) + order = f"ORDER BY {query.order_field}" if query.order_field: if query.order_desc: order += " DESC" else: - order = "ORDER BY id" + order = f"ORDER BY {table}.id" - sql = f"SELECT id, ts, pose, tags, payload FROM {table} WHERE {where} {order}" + sql = f"SELECT {table}.{_META_COLS.replace(', ', f', {table}.')} FROM {table}" + if join_clause: + sql += f" {join_clause}" + sql += f" WHERE {where} {order}" if query.limit_val is not None: sql += f" LIMIT {query.limit_val}" if query.offset_val is not None: @@ -152,35 +218,60 @@ def _compile_query(query: StreamQuery, table: str) -> tuple[str, list[Any]]: def _compile_count(query: StreamQuery, table: str) -> tuple[str, list[Any]]: where_parts: list[str] = [] params: list[Any] = [] - for f in query.filters: - sql, p = _compile_filter(f, table) - where_parts.append(sql) - params.extend(p) - where = " AND ".join(where_parts) if where_parts else "1=1" - return f"SELECT COUNT(*) FROM {table} WHERE {where}", params - - -# ── Near-filter post-processing ─────────────────────────────────────── + joins: list[str] = [] - -def _has_near_filter(query: StreamQuery) -> NearFilter | None: for f in query.filters: if isinstance(f, NearFilter): - return f - return None + joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") + pose_parts = _decompose_pose(f.pose) + if pose_parts is not None: + x, y, z = pose_parts[0], pose_parts[1], pose_parts[2] + else: + x, y, z = 0.0, 0.0, 0.0 + where_parts.append( + "r.min_x >= ? AND r.max_x <= ? AND " + "r.min_y >= ? 
AND r.max_y <= ? AND " + "r.min_z >= ? AND r.max_z <= ?" + ) + params.extend( + [ + x - f.radius, + x + f.radius, + y - f.radius, + y + f.radius, + z - f.radius, + z + f.radius, + ] + ) + else: + sql_frag, p = _compile_filter(f, table) + where_parts.append(sql_frag) + params.extend(p) + + where = " AND ".join(where_parts) if where_parts else "1=1" + join_clause = " ".join(joins) + sql = f"SELECT COUNT(*) FROM {table}" + if join_clause: + sql += f" {join_clause}" + sql += f" WHERE {where}" + return sql, params -def _apply_near_filter(rows: list[Observation], near: NearFilter) -> list[Observation]: - """Post-filter observations by spatial distance.""" - from dimos.msgs.geometry_msgs.Pose import to_pose +# ── Near-filter post-processing (exact distance after R*Tree bbox) ─── - target = to_pose(near.pose) + +def _apply_near_post_filter(rows: list[Observation], near: NearFilter) -> list[Observation]: + """Post-filter R*Tree candidates by exact Euclidean distance.""" + pose_parts = _decompose_pose(near.pose) + if pose_parts is None: + return [] + tx, ty, tz = pose_parts[0], pose_parts[1], pose_parts[2] result: list[Observation] = [] for obs in rows: if obs.pose is None: continue - obs_pose = to_pose(obs.pose) - dist = (target - obs_pose).position.norm() + op = obs.pose.pose.position + dist = ((op.x - tx) ** 2 + (op.y - ty) ** 2 + (op.z - tz) ** 2) ** 0.5 if dist <= near.radius: result.append(obs) return result @@ -198,10 +289,12 @@ def __init__( table: str, *, pose_provider: PoseProvider | None = None, + codec: LcmCodec | PickleCodec | None = None, ) -> None: self._conn = conn self._table = table self._pose_provider = pose_provider + self._codec = codec or PickleCodec() self._subject: Subject[Observation] = Subject() # type: ignore[type-arg] @property @@ -224,18 +317,43 @@ def do_append( if pose is None and self._pose_provider is not None: pose = self._pose_provider() - payload_blob = _serialize_payload(payload) - pose_blob = _serialize_pose(pose) + pose_cols = 
_decompose_pose(pose) tags_json = _serialize_tags(tags) - cur = self._conn.execute( - f"INSERT INTO {self._table} (ts, pose, tags, payload) VALUES (?, ?, ?, ?)", - (ts, pose_blob, tags_json, payload_blob), - ) - self._conn.commit() + # 1. Insert into meta table + if pose_cols is not None: + cur = self._conn.execute( + f"INSERT INTO {self._table} " + "(ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + (ts, *pose_cols, tags_json), + ) + else: + cur = self._conn.execute( + f"INSERT INTO {self._table} (ts, tags) VALUES (?, ?)", + (ts, tags_json), + ) row_id = cur.lastrowid assert row_id is not None + # 2. Insert into payload table + payload_blob = self._codec.encode(payload) + self._conn.execute( + f"INSERT INTO {self._table}_payload (id, data) VALUES (?, ?)", + (row_id, payload_blob), + ) + + # 3. Insert into R*Tree (if pose) + if pose_cols is not None: + x, y, z = pose_cols[0], pose_cols[1], pose_cols[2] + self._conn.execute( + f"INSERT INTO {self._table}_rtree (id, min_x, max_x, min_y, max_y, min_z, max_z) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (row_id, x, x, y, y, z, z), + ) + + self._conn.commit() + obs = Observation( id=row_id, ts=ts, @@ -249,12 +367,11 @@ def do_append( def execute_fetch(self, query: StreamQuery) -> list[Observation]: sql, params = _compile_query(query, self._table) rows = self._conn.execute(sql, params).fetchall() - observations = [self._row_to_obs(r) for r in rows] near = _has_near_filter(query) if near is not None: - observations = _apply_near_filter(observations, near) + observations = _apply_near_post_filter(observations, near) return observations @@ -264,13 +381,24 @@ def execute_count(self, query: StreamQuery) -> int: return result[0] if result else 0 # type: ignore[no-any-return] def _row_to_obs(self, row: Any) -> Observation: - row_id, ts, pose_blob, tags_json, payload_blob = row + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row + pose = _reconstruct_pose(px, py, pz, 
qx, qy, qz, qw) + conn = self._conn + table = self._table + codec = self._codec + + def loader() -> Any: + r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() + if r is None: + raise LookupError(f"No payload for id={row_id}") + return codec.decode(r[0]) + return Observation( id=row_id, ts=ts, - pose=_deserialize_pose(pose_blob), + pose=pose, tags=_deserialize_tags(tags_json), - _data=_deserialize_payload(payload_blob), + _data_loader=loader, ) @@ -285,8 +413,9 @@ def __init__( vec_dimensions: int | None = None, pose_provider: PoseProvider | None = None, parent_table: str | None = None, + codec: LcmCodec | PickleCodec | None = None, ) -> None: - super().__init__(conn, table, pose_provider=pose_provider) + super().__init__(conn, table, pose_provider=pose_provider, codec=codec) self._vec_dimensions = vec_dimensions self._parent_table = parent_table @@ -325,7 +454,6 @@ def _ensure_vec_table(self) -> None: self._conn.commit() def execute_fetch(self, query: StreamQuery) -> list[Observation]: - # Check for embedding search filter emb_filter = None for f in query.filters: if isinstance(f, EmbeddingSearchFilter): @@ -341,7 +469,6 @@ def _fetch_by_vector( self, query: StreamQuery, emb_filter: EmbeddingSearchFilter ) -> list[Observation]: """Fetch using vec0 similarity search, then apply remaining filters.""" - # First, get candidate rowids from vec0 vec_sql = ( f"SELECT rowid, distance FROM {self._table}_vec " f"WHERE embedding MATCH ? ORDER BY distance LIMIT ?" @@ -356,8 +483,7 @@ def _fetch_by_vector( rowids = [r[0] for r in vec_rows] placeholders = ",".join("?" 
* len(rowids)) - # Build remaining WHERE clauses (skip the embedding filter) - where_parts: list[str] = [f"id IN ({placeholders})"] + where_parts: list[str] = [f"{self._table}.id IN ({placeholders})"] params: list[Any] = list(rowids) for f in query.filters: @@ -368,25 +494,39 @@ def _fetch_by_vector( params.extend(p) where = " AND ".join(where_parts) - sql = f"SELECT id, ts, pose, tags, payload FROM {self._table} WHERE {where}" + sql = ( + f"SELECT {self._table}.{_META_COLS.replace(', ', f', {self._table}.')} " + f"FROM {self._table} WHERE {where}" + ) rows = self._conn.execute(sql, params).fetchall() observations = [self._row_to_obs(r) for r in rows] near = _has_near_filter(query) if near is not None: - observations = _apply_near_filter(observations, near) + observations = _apply_near_post_filter(observations, near) return observations def _row_to_obs(self, row: Any) -> Observation: - row_id, ts, pose_blob, tags_json, payload_blob = row + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row + pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) + conn = self._conn + table = self._table + codec = self._codec + + def loader() -> Any: + r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() + if r is None: + raise LookupError(f"No payload for id={row_id}") + return codec.decode(r[0]) + return EmbeddingObservation( id=row_id, ts=ts, - pose=_deserialize_pose(pose_blob), + pose=pose, tags=_deserialize_tags(tags_json), - _data=_deserialize_payload(payload_blob), + _data_loader=loader, ) @@ -400,8 +540,9 @@ def __init__( *, tokenizer: str = "unicode61", pose_provider: PoseProvider | None = None, + codec: LcmCodec | PickleCodec | None = None, ) -> None: - super().__init__(conn, table, pose_provider=pose_provider) + super().__init__(conn, table, pose_provider=pose_provider, codec=codec) self._tokenizer = tokenizer def do_append( @@ -413,7 +554,6 @@ def do_append( ) -> Observation: obs = super().do_append(payload, ts, pose, tags) - # 
Insert into FTS table text = str(payload) if payload is not None else "" self._conn.execute( f"INSERT INTO {self._table}_fts (rowid, content) VALUES (?, ?)", @@ -437,7 +577,6 @@ def execute_fetch(self, query: StreamQuery) -> list[Observation]: def _fetch_by_text( self, query: StreamQuery, text_filter: TextSearchFilter ) -> list[Observation]: - # Get matching rowids from FTS fts_sql = f"SELECT rowid, rank FROM {self._table}_fts WHERE content MATCH ? ORDER BY rank" fts_params: list[Any] = [text_filter.text] if text_filter.k is not None: @@ -451,7 +590,7 @@ def _fetch_by_text( rowids = [r[0] for r in fts_rows] placeholders = ",".join("?" * len(rowids)) - where_parts: list[str] = [f"id IN ({placeholders})"] + where_parts: list[str] = [f"{self._table}.id IN ({placeholders})"] params: list[Any] = list(rowids) for f in query.filters: @@ -462,14 +601,17 @@ def _fetch_by_text( params.extend(p) where = " AND ".join(where_parts) - sql = f"SELECT id, ts, pose, tags, payload FROM {self._table} WHERE {where}" + sql = ( + f"SELECT {self._table}.{_META_COLS.replace(', ', f', {self._table}.')} " + f"FROM {self._table} WHERE {where}" + ) rows = self._conn.execute(sql, params).fetchall() observations = [self._row_to_obs(r) for r in rows] near = _has_near_filter(query) if near is not None: - observations = _apply_near_filter(observations, near) + observations = _apply_near_post_filter(observations, near) return observations @@ -489,8 +631,10 @@ def _ensure_meta_table(self) -> None: self._conn.execute( "CREATE TABLE IF NOT EXISTS _streams (" " name TEXT PRIMARY KEY," - " payload_type TEXT," - " stream_kind TEXT DEFAULT 'stream'" + " payload_module TEXT," + " stream_kind TEXT DEFAULT 'stream'," + " parent_stream TEXT," + " embedding_dim INTEGER" ")" ) self._conn.commit() @@ -505,10 +649,14 @@ def stream( if name in self._streams: return self._streams[name] - self._ensure_stream_table(name) + if payload_type is None: + payload_type = self._resolve_payload_type(name) + + 
self._ensure_stream_tables(name) self._register_stream(name, payload_type, "stream") - backend = SqliteStreamBackend(self._conn, name, pose_provider=pose_provider) + codec = codec_for_type(payload_type) + backend = SqliteStreamBackend(self._conn, name, pose_provider=pose_provider, codec=codec) s: Stream[Any] = Stream(backend=backend, session=self) self._streams[name] = s return s @@ -524,12 +672,16 @@ def text_stream( if name in self._streams: return self._streams[name] # type: ignore[return-value] - self._ensure_stream_table(name) + if payload_type is None: + payload_type = self._resolve_payload_type(name) + + self._ensure_stream_tables(name) self._ensure_fts_table(name, tokenizer) self._register_stream(name, payload_type, "text") + codec = codec_for_type(payload_type) backend = SqliteTextBackend( - self._conn, name, tokenizer=tokenizer, pose_provider=pose_provider + self._conn, name, tokenizer=tokenizer, pose_provider=pose_provider, codec=codec ) ts: TextStream[Any] = TextStream(backend=backend, session=self) self._streams[name] = ts @@ -547,15 +699,20 @@ def embedding_stream( if name in self._streams: return self._streams[name] # type: ignore[return-value] - self._ensure_stream_table(name) - self._register_stream(name, payload_type, "embedding") + if payload_type is None: + payload_type = self._resolve_payload_type(name) + self._ensure_stream_tables(name) + self._register_stream(name, payload_type, "embedding", embedding_dim=vec_dimensions) + + codec = codec_for_type(payload_type) backend = SqliteEmbeddingBackend( self._conn, name, vec_dimensions=vec_dimensions, pose_provider=pose_provider, parent_table=parent_table, + codec=codec, ) if vec_dimensions is not None: backend._ensure_vec_table() @@ -565,12 +722,12 @@ def embedding_stream( return es def list_streams(self) -> list[StreamInfo]: - rows = self._conn.execute("SELECT name, payload_type FROM _streams").fetchall() + rows = self._conn.execute("SELECT name, payload_module FROM _streams").fetchall() result: 
list[StreamInfo] = [] - for name, ptype in rows: + for name, pmodule in rows: count_row = self._conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone() count = count_row[0] if count_row else 0 - result.append(StreamInfo(name=name, payload_type=ptype, count=count)) + result.append(StreamInfo(name=name, payload_type=pmodule, count=count)) return result def materialize_transform( @@ -582,7 +739,6 @@ def materialize_transform( live: bool = False, backfill_only: bool = False, ) -> Stream[Any]: - # Determine stream type from transformer target: Stream[Any] if isinstance(transformer, EmbeddingTransformer): target = self.embedding_stream(name) @@ -607,18 +763,35 @@ def close(self) -> None: # ── Internal helpers ────────────────────────────────────────────── - def _ensure_stream_table(self, name: str) -> None: + def _ensure_stream_tables(self, name: str) -> None: + """Create the meta table, payload table, and R*Tree for a stream.""" self._conn.execute( f"CREATE TABLE IF NOT EXISTS {name} (" " id INTEGER PRIMARY KEY AUTOINCREMENT," " ts REAL," - " pose BLOB," + " pose_x REAL," + " pose_y REAL," + " pose_z REAL," + " pose_qx REAL," + " pose_qy REAL," + " pose_qz REAL," + " pose_qw REAL," " tags TEXT DEFAULT '{}'," - " payload BLOB," " parent_id INTEGER" ")" ) self._conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{name}_ts ON {name}(ts)") + self._conn.execute( + f"CREATE TABLE IF NOT EXISTS {name}_payload ( id INTEGER PRIMARY KEY, data BLOB)" + ) + self._conn.execute( + f"CREATE VIRTUAL TABLE IF NOT EXISTS {name}_rtree USING rtree(" + " id," + " min_x, max_x," + " min_y, max_y," + " min_z, max_z" + ")" + ) self._conn.commit() def _ensure_fts_table(self, name: str, tokenizer: str) -> None: @@ -628,14 +801,31 @@ def _ensure_fts_table(self, name: str, tokenizer: str) -> None: ) self._conn.commit() - def _register_stream(self, name: str, payload_type: type | None, kind: str) -> None: - type_name = payload_type.__qualname__ if payload_type else None + def _register_stream( + self, + 
name: str, + payload_type: type | None, + kind: str, + *, + embedding_dim: int | None = None, + ) -> None: + module_path = type_to_module_path(payload_type) if payload_type else None self._conn.execute( - "INSERT OR IGNORE INTO _streams (name, payload_type, stream_kind) VALUES (?, ?, ?)", - (name, type_name, kind), + "INSERT OR IGNORE INTO _streams (name, payload_module, stream_kind, embedding_dim) " + "VALUES (?, ?, ?, ?)", + (name, module_path, kind, embedding_dim), ) self._conn.commit() + def _resolve_payload_type(self, name: str) -> type | None: + """Look up payload type from _streams metadata (for restart case).""" + row = self._conn.execute( + "SELECT payload_module FROM _streams WHERE name = ?", (name,) + ).fetchone() + if row is None or row[0] is None: + return None + return module_path_to_type(row[0]) + # ── Store ───────────────────────────────────────────────────────────── diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 43d3abb832..8662d0e895 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from .stream import Stream, TextStream + from .stream import EmbeddingStream, Stream, TextStream from .transformer import Transformer from .types import PoseProvider, StreamInfo @@ -47,6 +47,18 @@ def text_stream( ) -> TextStream[Any]: """Get or create a text stream with FTS index.""" + @abstractmethod + def embedding_stream( + self, + name: str, + payload_type: type | None = None, + *, + vec_dimensions: int | None = None, + pose_provider: PoseProvider | None = None, + parent_table: str | None = None, + ) -> EmbeddingStream[Any]: + """Get or create an embedding stream with vec0 index.""" + @abstractmethod def list_streams(self) -> list[StreamInfo]: ... 
diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 9170d4feb4..dd0543bae4 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -221,6 +221,12 @@ def store(self, name: str | None = None) -> Stream[T]: def project_to(self, target: Stream[Any]) -> Stream[Any]: raise NotImplementedError("project_to requires a stored stream with lineage") + # ── Iteration ───────────────────────────────────────────────────── + + def __iter__(self) -> Iterator[Observation]: + for page in self.fetch_pages(): + yield from page + # ── Terminals ───────────────────────────────────────────────────── def fetch(self) -> list[Observation]: diff --git a/dimos/memory/tests/test_sqlite.py b/dimos/memory/tests/test_sqlite.py index ec71135b4d..f45ed44c36 100644 --- a/dimos/memory/tests/test_sqlite.py +++ b/dimos/memory/tests/test_sqlite.py @@ -21,14 +21,28 @@ import pytest from dimos.memory.impl.sqlite import SqliteSession, SqliteStore +from dimos.msgs.sensor_msgs.Image import Image +from dimos.utils.testing import TimedSensorReplay if TYPE_CHECKING: from dimos.memory.types import Observation +@pytest.fixture(scope="module") +def replay() -> TimedSensorReplay: # type: ignore[type-arg] + return TimedSensorReplay("unitree_go2_bigoffice/video") + + +@pytest.fixture(scope="module") +def images(replay: TimedSensorReplay) -> list[Image]: # type: ignore[type-arg] + """Load 5 images from replay at 1s intervals.""" + imgs = [replay.find_closest_seek(float(i)) for i in range(1, 6)] + assert all(isinstance(im, Image) for im in imgs) + return imgs # type: ignore[return-value] + + @pytest.fixture def store(tmp_path: object) -> SqliteStore: - # tmp_path is a pathlib.Path from pathlib import Path assert isinstance(tmp_path, Path) @@ -42,161 +56,161 @@ def session(store: SqliteStore) -> SqliteSession: class TestStreamBasics: def test_create_stream(self, session: SqliteSession) -> None: - s = session.stream("images", bytes) + s = session.stream("images", Image) assert s is not None - 
def test_append_and_fetch(self, session: SqliteSession) -> None: - s = session.stream("images", bytes) - obs = s.append(b"frame1") + def test_append_and_fetch(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("images", Image) + obs = s.append(images[0]) assert obs.id == 1 - assert obs.data == b"frame1" + assert obs.data == images[0] assert obs.ts is not None rows = s.fetch() assert len(rows) == 1 - assert rows[0].data == b"frame1" + assert rows[0].data == images[0] assert rows[0].id == 1 - def test_append_multiple(self, session: SqliteSession) -> None: - s = session.stream("images", bytes) - s.append(b"frame1") - s.append(b"frame2") - s.append(b"frame3") + def test_append_multiple(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("images", Image) + for img in images[:3]: + s.append(img) assert s.count() == 3 rows = s.fetch() - assert [r.data for r in rows] == [b"frame1", b"frame2", b"frame3"] + assert [r.data for r in rows] == images[:3] - def test_append_with_tags(self, session: SqliteSession) -> None: - s = session.stream("images", bytes) - s.append(b"frame1", tags={"cam": "front", "quality": "high"}) + def test_append_with_tags(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("images", Image) + s.append(images[0], tags={"cam": "front", "quality": "high"}) rows = s.fetch() assert rows[0].tags == {"cam": "front", "quality": "high"} - def test_last(self, session: SqliteSession) -> None: - s = session.stream("images", bytes) - s.append(b"frame1", ts=1.0) - s.append(b"frame2", ts=2.0) - s.append(b"frame3", ts=3.0) + def test_last(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("images", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) + s.append(images[2], ts=3.0) obs = s.last() - assert obs.data == b"frame3" + assert obs.data == images[2] assert obs.ts == 3.0 - def test_one(self, session: SqliteSession) -> None: - s = 
session.stream("images", bytes) - s.append(b"only") + def test_one(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("images", Image) + s.append(images[0]) obs = s.one() - assert obs.data == b"only" + assert obs.data == images[0] def test_one_empty_raises(self, session: SqliteSession) -> None: - s = session.stream("images", bytes) + s = session.stream("images", Image) with pytest.raises(LookupError): s.one() class TestFilters: - def test_after(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("old", ts=1.0) - s.append("new", ts=10.0) + def test_after(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=10.0) rows = s.after(5.0).fetch() assert len(rows) == 1 - assert rows[0].data == "new" + assert rows[0].data == images[1] - def test_before(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("old", ts=1.0) - s.append("new", ts=10.0) + def test_before(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=10.0) rows = s.before(5.0).fetch() assert len(rows) == 1 - assert rows[0].data == "old" + assert rows[0].data == images[0] - def test_time_range(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("a", ts=1.0) - s.append("b", ts=5.0) - s.append("c", ts=10.0) + def test_time_range(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=5.0) + s.append(images[2], ts=10.0) rows = s.time_range(3.0, 7.0).fetch() assert len(rows) == 1 - assert rows[0].data == "b" + assert rows[0].data == images[1] - def test_at(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("a", ts=1.0) - s.append("b", ts=5.0) - s.append("c", ts=10.0) + def 
test_at(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=5.0) + s.append(images[2], ts=10.0) rows = s.at(5.5, tolerance=1.0).fetch() assert len(rows) == 1 - assert rows[0].data == "b" + assert rows[0].data == images[1] - def test_filter_tags(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("front", tags={"cam": "front"}) - s.append("rear", tags={"cam": "rear"}) + def test_filter_tags(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], tags={"cam": "front"}) + s.append(images[1], tags={"cam": "rear"}) rows = s.filter_tags(cam="front").fetch() assert len(rows) == 1 - assert rows[0].data == "front" + assert rows[0].data == images[0] - def test_chained_filters(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("a", ts=1.0, tags={"cam": "front"}) - s.append("b", ts=5.0, tags={"cam": "front"}) - s.append("c", ts=5.0, tags={"cam": "rear"}) + def test_chained_filters(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0, tags={"cam": "front"}) + s.append(images[1], ts=5.0, tags={"cam": "front"}) + s.append(images[2], ts=5.0, tags={"cam": "rear"}) rows = s.after(3.0).filter_tags(cam="front").fetch() assert len(rows) == 1 - assert rows[0].data == "b" + assert rows[0].data == images[1] class TestOrdering: - def test_order_by_ts(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("b", ts=2.0) - s.append("a", ts=1.0) - s.append("c", ts=3.0) + def test_order_by_ts(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[1], ts=2.0) + s.append(images[0], ts=1.0) + s.append(images[2], ts=3.0) rows = s.order_by("ts").fetch() - assert [r.data for r in rows] == ["a", "b", "c"] + assert [r.data 
for r in rows] == [images[0], images[1], images[2]] - def test_order_by_desc(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("a", ts=1.0) - s.append("b", ts=2.0) - s.append("c", ts=3.0) + def test_order_by_desc(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) + s.append(images[2], ts=3.0) rows = s.order_by("ts", desc=True).fetch() - assert [r.data for r in rows] == ["c", "b", "a"] + assert [r.data for r in rows] == [images[2], images[1], images[0]] - def test_limit_offset(self, session: SqliteSession) -> None: - s = session.stream("data", str) - for i in range(10): - s.append(f"item{i}", ts=float(i)) + def test_limit_offset(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + for i, img in enumerate(images): + s.append(img, ts=float(i)) - rows = s.order_by("ts").limit(3).offset(2).fetch() - assert [r.data for r in rows] == ["item2", "item3", "item4"] + rows = s.order_by("ts").limit(2).offset(1).fetch() + assert len(rows) == 2 + assert [r.data for r in rows] == [images[1], images[2]] class TestFetchPages: - def test_basic_pagination(self, session: SqliteSession) -> None: - s = session.stream("data", str) - for i in range(10): - s.append(f"item{i}", ts=float(i)) - - pages = list(s.fetch_pages(batch_size=3)) - assert len(pages) == 4 # 3+3+3+1 - assert len(pages[0]) == 3 + def test_basic_pagination(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + for i, img in enumerate(images): + s.append(img, ts=float(i)) + + pages = list(s.fetch_pages(batch_size=2)) + assert len(pages) == 3 # 2+2+1 + assert len(pages[0]) == 2 assert len(pages[-1]) == 1 all_items = [obs.data for page in pages for obs in page] - assert all_items == [f"item{i}" for i in range(10)] + assert all_items == images class TestTextStream: @@ -223,7 +237,7 @@ def 
test_list_empty(self, session: SqliteSession) -> None: assert session.list_streams() == [] def test_list_after_create(self, session: SqliteSession) -> None: - session.stream("images", bytes) + session.stream("images", Image) session.text_stream("logs", str) infos = session.list_streams() @@ -232,112 +246,146 @@ def test_list_after_create(self, session: SqliteSession) -> None: class TestReactive: - def test_appended_observable(self, session: SqliteSession) -> None: - s = session.stream("images", bytes) + def test_appended_observable(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("images", Image) received: list[Observation] = [] s.appended.subscribe(on_next=received.append) - s.append(b"frame1") - s.append(b"frame2") + s.append(images[0]) + s.append(images[1]) assert len(received) == 2 - assert received[0].data == b"frame1" - assert received[1].data == b"frame2" + assert received[0].data == images[0] + assert received[1].data == images[1] class TestTransformInMemory: - def test_lambda_transform(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("hello", ts=1.0) - s.append("world", ts=2.0) + def test_lambda_transform(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) - upper = s.transform(lambda x: x.upper()) - results = upper.fetch() + shapes = s.transform(lambda im: f"{im.width}x{im.height}") + results = shapes.fetch() assert len(results) == 2 - assert results[0].data == "HELLO" - assert results[1].data == "WORLD" - - def test_lambda_filter_none(self, session: SqliteSession) -> None: - s = session.stream("data", int) - s.append(1, ts=1.0) - s.append(2, ts=2.0) - s.append(3, ts=3.0) - - evens = s.transform(lambda x: x * 2 if x % 2 == 0 else None) - results = evens.fetch() - assert len(results) == 1 - assert results[0].data == 4 - - def test_lambda_expand_list(self, session: SqliteSession) -> None: 
- s = session.stream("data", str) - s.append("a,b,c", ts=1.0) - - split = s.transform(lambda x: x.split(",")) - results = split.fetch() - assert len(results) == 3 - assert [r.data for r in results] == ["a", "b", "c"] + assert results[0].data == f"{images[0].width}x{images[0].height}" + + def test_lambda_filter_none(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) + s.append(images[2], ts=3.0) + + # Only keep images wider than 0 (all pass), filter second by index trick + idx = iter(range(3)) + big = s.transform(lambda im: im if next(idx) % 2 == 0 else None) + results = big.fetch() + assert len(results) == 2 # indices 0 and 2 + + def test_lambda_expand_list(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + + # Extract format and frame_id as two separate results + results = s.transform(lambda im: [im.format.value, im.frame_id]).fetch() + assert len(results) == 2 + assert results[0].data == images[0].format.value class TestTransformStore: - def test_transform_store_backfill(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("hello", ts=1.0) - s.append("world", ts=2.0) + def test_transform_store_backfill(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) - stored = s.transform(lambda x: x.upper()).store("upper_data") + stored = s.transform(lambda im: f"{im.width}x{im.height}").store("shapes") rows = stored.fetch() assert len(rows) == 2 - assert rows[0].data == "HELLO" - assert rows[1].data == "WORLD" + expected = f"{images[0].width}x{images[0].height}" + assert rows[0].data == expected - # Also queryable by name - reloaded = session.stream("upper_data") + reloaded = session.stream("shapes") assert reloaded.count() == 2 - def 
test_transform_store_live(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("existing", ts=1.0) + def test_transform_store_live(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) - # live=True skips backfill, only processes new items - stored = s.transform(lambda x: x.upper(), live=True).store("live_upper") + stored = s.transform(lambda im: im.height, live=True).store("heights") assert stored.count() == 0 # no backfill - s.append("new", ts=2.0) + s.append(images[1], ts=2.0) assert stored.count() == 1 - assert stored.last().data == "NEW" + assert stored.last().data == images[1].height - def test_transform_store_backfill_only(self, session: SqliteSession) -> None: - s = session.stream("data", str) - s.append("existing", ts=1.0) + def test_transform_store_backfill_only( + self, session: SqliteSession, images: list[Image] + ) -> None: + s = session.stream("data", Image) + s.append(images[0], ts=1.0) - stored = s.transform(lambda x: x.upper(), backfill_only=True).store("backfill_upper") + stored = s.transform(lambda im: im.height, backfill_only=True).store("heights_bo") assert stored.count() == 1 - assert stored.one().data == "EXISTING" + assert stored.one().data == images[0].height - # New appends should NOT propagate - s.append("new", ts=2.0) + s.append(images[1], ts=2.0) assert stored.count() == 1 # still 1 +class TestLazyData: + def test_data_lazy_loaded(self, session: SqliteSession, images: list[Image]) -> None: + """Fetched observations should not eagerly load payload.""" + s = session.stream("data", Image) + s.append(images[0], ts=1.0) + + rows = s.fetch() + obs = rows[0] + from dimos.memory.types import _UNSET + + assert obs._data is _UNSET + assert obs._data_loader is not None + assert obs.data == images[0] + assert obs._data == images[0] + + def test_metadata_without_payload(self, session: SqliteSession, images: list[Image]) -> None: + """Metadata (ts, 
tags) should be available without loading payload.""" + s = session.stream("data", Image) + s.append(images[0], ts=1.0, tags={"key": "val"}) + + rows = s.fetch() + obs = rows[0] + assert obs.ts == 1.0 + assert obs.tags == {"key": "val"} + assert obs.id == 1 + + +class TestIteration: + def test_iter(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("data", Image) + for i, img in enumerate(images[:3]): + s.append(img, ts=float(i)) + + items = [obs.data for obs in s] + assert items == images[:3] + + class TestStoreReopen: - def test_data_persists(self, tmp_path: object) -> None: + def test_data_persists(self, tmp_path: object, images: list[Image]) -> None: from pathlib import Path assert isinstance(tmp_path, Path) db_path = str(tmp_path / "persist.db") - # Write store1 = SqliteStore(db_path) s1 = store1.session() - s1.stream("data", str).append("hello", ts=1.0) + s1.stream("data", Image).append(images[0], ts=1.0) s1.close() store1.close() - # Re-open and read store2 = SqliteStore(db_path) s2 = store2.session() - rows = s2.stream("data", str).fetch() + rows = s2.stream("data", Image).fetch() assert len(rows) == 1 - assert rows[0].data == "hello" + assert rows[0].data == images[0] s2.close() store2.close() diff --git a/dimos/memory_old/timeseries/__init__.py b/dimos/memory_old/timeseries/__init__.py index 6e77185c43..51130005b3 100644 --- a/dimos/memory_old/timeseries/__init__.py +++ b/dimos/memory_old/timeseries/__init__.py @@ -13,19 +13,19 @@ # limitations under the License. 
"""Time series storage and replay.""" -from dimos.memory.timeseries.base import TimeSeriesStore -from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.memory.timeseries.pickledir import PickleDirStore -from dimos.memory.timeseries.sqlite import SqliteTSStore +from dimos.memory_old.timeseries.base import TimeSeriesStore +from dimos.memory_old.timeseries.inmemory import InMemoryStore +from dimos.memory_old.timeseries.pickledir import PickleDirStore +from dimos.memory_old.timeseries.sqlite import SqliteTSStore def __getattr__(name: str): # type: ignore[no-untyped-def] if name == "PostgresStore": - from dimos.memory.timeseries.postgres import PostgresStore + from dimos.memory_old.timeseries.postgres import PostgresStore return PostgresStore if name == "reset_db": - from dimos.memory.timeseries.postgres import reset_db + from dimos.memory_old.timeseries.postgres import reset_db return reset_db raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dimos/memory_old/timeseries/inmemory.py b/dimos/memory_old/timeseries/inmemory.py index b67faca644..608235c11d 100644 --- a/dimos/memory_old/timeseries/inmemory.py +++ b/dimos/memory_old/timeseries/inmemory.py @@ -17,7 +17,7 @@ from sortedcontainers import SortedKeyList # type: ignore[import-untyped] -from dimos.memory.timeseries.base import T, TimeSeriesStore +from dimos.memory_old.timeseries.base import T, TimeSeriesStore class InMemoryStore(TimeSeriesStore[T]): diff --git a/dimos/memory_old/timeseries/legacy.py b/dimos/memory_old/timeseries/legacy.py index 821d306d2d..6abcf8deef 100644 --- a/dimos/memory_old/timeseries/legacy.py +++ b/dimos/memory_old/timeseries/legacy.py @@ -30,7 +30,7 @@ from reactivex.observable import Observable from reactivex.scheduler import TimeoutScheduler -from dimos.memory.timeseries.base import T, TimeSeriesStore +from dimos.memory_old.timeseries.base import T, TimeSeriesStore from dimos.utils.data import get_data, get_data_dir diff --git 
a/dimos/memory_old/timeseries/pickledir.py b/dimos/memory_old/timeseries/pickledir.py index 9e8cd5a249..719c9f8a94 100644 --- a/dimos/memory_old/timeseries/pickledir.py +++ b/dimos/memory_old/timeseries/pickledir.py @@ -20,7 +20,7 @@ from pathlib import Path import pickle -from dimos.memory.timeseries.base import T, TimeSeriesStore +from dimos.memory_old.timeseries.base import T, TimeSeriesStore from dimos.utils.data import get_data, get_data_dir diff --git a/dimos/memory_old/timeseries/postgres.py b/dimos/memory_old/timeseries/postgres.py index 0daae44adb..c6774d3920 100644 --- a/dimos/memory_old/timeseries/postgres.py +++ b/dimos/memory_old/timeseries/postgres.py @@ -21,7 +21,7 @@ import psycopg2.extensions # type: ignore[import-untyped] from dimos.core.resource import Resource -from dimos.memory.timeseries.base import T, TimeSeriesStore +from dimos.memory_old.timeseries.base import T, TimeSeriesStore # Valid SQL identifier: alphanumeric and underscores, not starting with digit _VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") diff --git a/dimos/memory_old/timeseries/sqlite.py b/dimos/memory_old/timeseries/sqlite.py index 6f1f0d88e4..a7d3fcbb35 100644 --- a/dimos/memory_old/timeseries/sqlite.py +++ b/dimos/memory_old/timeseries/sqlite.py @@ -19,7 +19,7 @@ import re import sqlite3 -from dimos.memory.timeseries.base import T, TimeSeriesStore +from dimos.memory_old.timeseries.base import T, TimeSeriesStore from dimos.utils.data import get_data, get_data_dir # Valid SQL identifier: alphanumeric and underscores, not starting with digit diff --git a/dimos/memory_old/timeseries/test_base.py b/dimos/memory_old/timeseries/test_base.py index 31b811c251..491f0ed534 100644 --- a/dimos/memory_old/timeseries/test_base.py +++ b/dimos/memory_old/timeseries/test_base.py @@ -20,11 +20,11 @@ import pytest -from dimos.memory.timeseries.base import TimeSeriesStore -from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.memory.timeseries.legacy import 
LegacyPickleStore -from dimos.memory.timeseries.pickledir import PickleDirStore -from dimos.memory.timeseries.sqlite import SqliteTSStore +from dimos.memory_old.timeseries.base import TimeSeriesStore +from dimos.memory_old.timeseries.inmemory import InMemoryStore +from dimos.memory_old.timeseries.legacy import LegacyPickleStore +from dimos.memory_old.timeseries.pickledir import PickleDirStore +from dimos.memory_old.timeseries.sqlite import SqliteTSStore from dimos.types.timestamped import Timestamped @@ -81,7 +81,7 @@ def make_legacy_pickle_store(tmpdir: str) -> TimeSeriesStore[SampleData]: try: import psycopg2 - from dimos.memory.timeseries.postgres import PostgresStore + from dimos.memory_old.timeseries.postgres import PostgresStore # Test connection _test_conn = psycopg2.connect(dbname="dimensional") diff --git a/dimos/memory_old/timeseries/test_legacy.py b/dimos/memory_old/timeseries/test_legacy.py index c77ec64a76..145af0d1f4 100644 --- a/dimos/memory_old/timeseries/test_legacy.py +++ b/dimos/memory_old/timeseries/test_legacy.py @@ -15,7 +15,7 @@ import pytest -from dimos.memory.timeseries.legacy import LegacyPickleStore +from dimos.memory_old.timeseries.legacy import LegacyPickleStore class TestLegacyPickleStoreRealData: diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 825e89fc8c..6e75813d8b 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -20,7 +20,7 @@ from functools import reduce from typing import TypeVar -from dimos.memory.timeseries.inmemory import InMemoryStore +from dimos.memory_old.timeseries.inmemory import InMemoryStore from dimos.msgs.geometry_msgs import PoseStamped, Transform from dimos.msgs.tf2_msgs import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index 7de82e8f9a..eaad794384 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -19,7 +19,7 @@ from reactivex import 
operators as ops from reactivex.scheduler import ThreadPoolScheduler -from dimos.memory.timeseries.inmemory import InMemoryStore +from dimos.memory_old.timeseries.inmemory import InMemoryStore from dimos.msgs.sensor_msgs import Image from dimos.types.timestamped import ( Timestamped, diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index b229a2478e..a02cd392e1 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -22,7 +22,7 @@ # from dimos_lcm.std_msgs import Time as ROSTime from reactivex.observable import Observable -from dimos.memory.timeseries.inmemory import InMemoryStore +from dimos.memory_old.timeseries.inmemory import InMemoryStore from dimos.types.weaklist import WeakList from dimos.utils.logging_config import setup_logger diff --git a/dimos/utils/testing/replay.py b/dimos/utils/testing/replay.py index 588b63e099..68d3ca8fe8 100644 --- a/dimos/utils/testing/replay.py +++ b/dimos/utils/testing/replay.py @@ -14,7 +14,7 @@ """Shim for TimedSensorReplay/TimedSensorStorage.""" -from dimos.memory.timeseries.legacy import LegacyPickleStore +from dimos.memory_old.timeseries.legacy import LegacyPickleStore SensorReplay = LegacyPickleStore SensorStorage = LegacyPickleStore diff --git a/plans/memory/api.md b/plans/memory/api.md index bb62d52ca4..39cd7f738c 100644 --- a/plans/memory/api.md +++ b/plans/memory/api.md @@ -450,6 +450,10 @@ class Session: def text_stream(self, name: str, payload_type: type | None = None, *, tokenizer: str = "unicode61", pose_provider: PoseProvider | None = None) -> TextStream: ... + def embedding_stream(self, name: str, payload_type: type | None = None, *, + vec_dimensions: int | None = None, + pose_provider: PoseProvider | None = None, + parent_table: str | None = None) -> EmbeddingStream: ... 
def materialize_transform(self, name: str, source: Stream, transformer: Transformer, *, live: bool = False, @@ -466,27 +470,146 @@ class Store: A `Stream` can be backed by different things — the user never sees this: -- **DB table** — from `session.stream()`. Has `_meta`, `_payload`, indexes. +- **DB tables** — from `session.stream()`. Metadata + payload + indexes. - **Predicate** — from `.after()`, `.near()`, etc. Lazy SQL WHERE. - **Transform** — from `.transform(t)`. Source stream + Transformer. The impl decides how to execute based on the backing chain. +## SQLite Schema + +Each stream `{name}` creates these tables: + +```sql +-- Metadata table (compact rows, fast scans) +CREATE TABLE {name} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ts REAL, + pose_x REAL, -- position + pose_y REAL, + pose_z REAL, + pose_qx REAL, -- orientation quaternion (stored, not indexed) + pose_qy REAL, + pose_qz REAL, + pose_qw REAL, + tags TEXT DEFAULT '{}', + parent_id INTEGER -- lineage: source observation id +); +CREATE INDEX idx_{name}_ts ON {name}(ts); + +-- Payload table (blobs, loaded on demand) +CREATE TABLE {name}_payload ( + id INTEGER PRIMARY KEY, + data BLOB +); + +-- R*Tree spatial index (position only) +CREATE VIRTUAL TABLE {name}_rtree USING rtree( + id, + min_x, max_x, + min_y, max_y, + min_z, max_z +); +``` + +**Optional per stream kind:** + +```sql +-- TextStream: FTS5 full-text index +CREATE VIRTUAL TABLE {name}_fts USING fts5(content, tokenize='unicode61'); + +-- EmbeddingStream: vec0 vector index +CREATE VIRTUAL TABLE {name}_vec USING vec0(embedding float[{dim}]); +``` + +### Key design decisions + +- **Separate payload table** — metadata queries (`fetch`, `count`, `near`, filters) never touch blob data. Payload is loaded lazily via `obs.data`. +- **Decomposed pose columns** — enables R*Tree spatial index for `.near()` queries. Orientation stored for reconstruction but not spatially indexed. 
+- **R*Tree for spatial queries** — `.near(pose, radius)` compiles to an R*Tree range query (bounding box at ±radius), not post-query Python filtering. + +### Lazy payload loading + +`fetch()` returns `Observation` with lazy `.data`: +- Metadata query: `SELECT id, ts, pose_x, ..., tags FROM {name} WHERE ...` +- `_data` stays `_UNSET`, `_data_loader` is set to: `SELECT data FROM {name}_payload WHERE id = ?` +- Only `obs.data` access triggers the blob read + codec decode + +This means iterating metadata (`obs.ts`, `obs.pose`, `obs.tags`) is cheap. + +### NearFilter SQL compilation + +```python +# .near(pose, 5.0) compiles to: +# JOIN {name}_rtree AS r ON r.id = {name}.id +# WHERE r.min_x >= pose.x - 5.0 AND r.max_x <= pose.x + 5.0 +# AND r.min_y >= pose.y - 5.0 AND r.max_y <= pose.y + 5.0 +# AND r.min_z >= pose.z - 5.0 AND r.max_z <= pose.z + 5.0 +``` + +For exact distance (not just bounding box), a post-filter computes Euclidean distance on the R*Tree candidates. + +## Serialization (Codec) + +Each stream has a `Codec[T]` that handles payload encode/decode. Auto-selected from `payload_type`. + +```python +class Codec(Protocol[T]): + def encode(self, value: T) -> bytes: ... + def decode(self, data: bytes) -> T: ... + +class LcmCodec(Codec[DimosMsg]): + """For DimosMsg types — uses lcm_encode/lcm_decode.""" + def __init__(self, msg_type: type[DimosMsg]) -> None: ... + +class PickleCodec(Codec[Any]): + """Fallback for arbitrary Python objects.""" + +def codec_for_type(payload_type: type[T] | None) -> Codec[T]: + """Auto-select codec based on payload type.""" + if payload_type is not None and issubclass(payload_type, DimosMsg): + return LcmCodec(payload_type) + return PickleCodec() +``` + +Lives in `dimos.memory.codec`. Detection uses `dimos.msgs.protocol.DimosMsg` (`runtime_checkable`). 
+ +Transparent to the user — just pass `payload_type` to `session.stream()`: +```python +images = session.stream("images", Image) # auto LCM codec +numbers = session.stream("numbers", int) # auto pickle codec +``` + +Tags are JSON. Poses are decomposed into columns (not serialized). + +### Stream metadata (`_streams` table) + +``` +name TEXT PRIMARY KEY +payload_module TEXT -- fully qualified, e.g. "dimos.msgs.sensor_msgs.Image.Image" +stream_kind TEXT -- "stream" | "text" | "embedding" +parent_stream TEXT -- parent stream name (lineage for join()) +embedding_dim INTEGER -- vec0 dimension (embedding streams only) +``` + +On restart, `session.stream("images")` (no `payload_type`) resolves the class from `payload_module` via `importlib`, then selects the codec automatically. `embedding_dim` allows recreating the vec0 table without needing to see the first embedding again. + ## Implementation Notes - **No ORM** — raw `sqlite3` with direct SQL. The `Stream` filter chain *is* the query builder. - **Session threading** — streams created by `session.stream()` get `_session` set. `TransformStream` inherits it from its source. `store()` also accepts an explicit `session=` fallback. -- **Serialization** — payloads are `pickle`, poses are `pickle`, tags are JSON. -- **Near filter** — compiled as no-op SQL (`1=1`), filtered post-query in Python via pose distance. ## Resolved Questions 1. **`.append()` on non-stored streams?** → `TypeError` (requires backend). 2. **Multiple `.store()` calls?** → Idempotent — returns existing stream if already stored. 3. ~~**Memory pressure from in-memory transforms?**~~ → Solved via `fetch_pages`. +4. **Pose storage** → Decomposed columns + R*Tree index (not binary blob). +5. **Payload loading** → Lazy via separate `{name}_payload` table. +6. **`__iter__`** → `for page in self.fetch_pages(): yield from page` — lazy, memory-efficient iteration. ## Open Questions 1. **`project_to` / lineage** — `parent_id` column exists but not yet wired. 2. 
**Incremental transforms** — re-running a stored transform should resume from last processed item. -3. **`__iter__`** — spec shows `Stream.__iter__` but not yet implemented. +3. **4D indexing** — should R*Tree include time as a 4th dimension? See `query_objects.md` for the Criterion/Score direction. From 936e2ce64c9c51bf38536a51037e7ade2227fb3d Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 16:36:35 +0800 Subject: [PATCH 010/118] JpegCodec for Image storage (43x smaller), ingest helpers, QualityWindowTransformer, E2E test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add JpegCodec as default codec for Image types (2.76MB → 64KB per frame) - Preserve frame_id in JPEG header; ts stored in meta table - Add ingest() helper for bulk-loading (ts, payload) iterables into streams - Add QualityWindowTransformer: best-frame-per-window (supports backfill + live) - EmbeddingTransformer sets output_type=Embedding automatically - Require payload_type when creating new streams (no silent PickleCodec fallback) - TransformStream.store() accepts payload_type, propagated through materialize_transform - E2E test: 5min video → sharpness filter → CLIP embed → text search - Move test_sqlite.py next to sqlite.py, update Image comparisons for lossy codec - Add sqlite-vec dependency --- dimos/memory/__init__.py | 3 +- dimos/memory/codec.py | 56 +++++++- dimos/memory/impl/run_e2e_export.py | 99 +++++++++++++ dimos/memory/impl/sqlite.py | 29 +++- dimos/memory/{tests => impl}/test_sqlite.py | 150 ++++++++++++++++---- dimos/memory/impl/test_sqlite_e2e.py | 113 +++++++++++++++ dimos/memory/ingest.py | 42 ++++++ dimos/memory/store.py | 1 + dimos/memory/stream.py | 9 +- dimos/memory/transformer.py | 74 ++++++++++ pyproject.toml | 2 +- uv.lock | 14 ++ 12 files changed, 551 insertions(+), 41 deletions(-) create mode 100644 dimos/memory/impl/run_e2e_export.py rename dimos/memory/{tests => impl}/test_sqlite.py (69%) create mode 100644 
dimos/memory/impl/test_sqlite_e2e.py create mode 100644 dimos/memory/ingest.py diff --git a/dimos/memory/__init__.py b/dimos/memory/__init__.py index 14b9c87ba3..c8f0b8a336 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory/__init__.py @@ -1,4 +1,4 @@ -from dimos.memory.codec import Codec, LcmCodec, PickleCodec, codec_for_type +from dimos.memory.codec import Codec, JpegCodec, LcmCodec, PickleCodec, codec_for_type from dimos.memory.store import Session, Store from dimos.memory.stream import EmbeddingStream, Stream, TextStream from dimos.memory.transformer import ( @@ -17,6 +17,7 @@ "EmbeddingObservation", "EmbeddingStream", "EmbeddingTransformer", + "JpegCodec", "LcmCodec", "Observation", "PerItemTransformer", diff --git a/dimos/memory/codec.py b/dimos/memory/codec.py index 3a18fe21df..3fdc2f7592 100644 --- a/dimos/memory/codec.py +++ b/dimos/memory/codec.py @@ -44,6 +44,46 @@ def decode(self, data: bytes) -> DimosMsg: return self._msg_type.lcm_decode(data) +class JpegCodec: + """Codec for Image types — stores as JPEG bytes (lossy, ~10-20x smaller). + + Preserves ``frame_id`` as a short header: ````. + Pixel data is lossy-compressed; ``ts`` is NOT preserved (stored in the meta table). 
+ """ + + def __init__(self, quality: int = 90) -> None: + self._quality = quality + + def encode(self, value: Any) -> bytes: + import struct + + import cv2 + + bgr = value.to_opencv() + ok, buf = cv2.imencode(".jpg", bgr, [cv2.IMWRITE_JPEG_QUALITY, self._quality]) + if not ok: + raise ValueError("JPEG encoding failed") + frame_id = (value.frame_id or "").encode("utf-8") + header = struct.pack(" Any: + import struct + + import cv2 + import numpy as np + + from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + + fid_len = struct.unpack(" LcmCodec: return _POSE_CODEC -def codec_for_type(payload_type: type | None) -> LcmCodec | PickleCodec: +def codec_for_type(payload_type: type | None) -> LcmCodec | JpegCodec | PickleCodec: """Auto-select codec based on payload type.""" - if ( - payload_type is not None - and hasattr(payload_type, "lcm_encode") - and hasattr(payload_type, "lcm_decode") - ): - return LcmCodec(payload_type) # type: ignore[arg-type] + if payload_type is not None: + # Image → JPEG by default (much smaller than LCM raw pixels) + from dimos.msgs.sensor_msgs.Image import Image + + if issubclass(payload_type, Image): + return JpegCodec() + if hasattr(payload_type, "lcm_encode") and hasattr(payload_type, "lcm_decode"): + return LcmCodec(payload_type) # type: ignore[arg-type] return PickleCodec() diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py new file mode 100644 index 0000000000..aea96d497c --- /dev/null +++ b/dimos/memory/impl/run_e2e_export.py @@ -0,0 +1,99 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ingest 5min robot video → sharpness filter → CLIP embed → export top matches. + +Caches the DB — re-run to just search without re-ingesting/embedding. +""" + +from __future__ import annotations + +from pathlib import Path + +from dimos.memory.impl.sqlite import SqliteStore +from dimos.memory.ingest import ingest +from dimos.memory.transformer import EmbeddingTransformer, QualityWindowTransformer +from dimos.models.embedding.clip import CLIPModel +from dimos.msgs.sensor_msgs.Image import Image +from dimos.utils.testing import TimedSensorReplay + +OUT_DIR = Path(__file__).parent / "e2e_matches" +OUT_DIR.mkdir(exist_ok=True) + +db_path = OUT_DIR / "e2e.db" +store = SqliteStore(str(db_path)) +session = store.session() + +# Check if we already have data +existing = {s.name for s in session.list_streams()} +need_build = "clip_embeddings" not in existing + +if need_build: + replay = TimedSensorReplay("unitree_go2_bigoffice/video") + + print("Loading CLIP...") + clip = CLIPModel() + clip.start() + + # 1. Ingest 5 minutes + print("Ingesting 5 min of video...") + raw = session.stream("raw_video", Image) + n = ingest(raw, replay.iterate_ts(seek=5.0, duration=300.0)) + print(f" {n} frames ingested") + + # 2. Sharpness filter + print("Filtering by sharpness (0.5s windows)...") + sharp = raw.transform(QualityWindowTransformer(lambda img: img.sharpness, window=0.5)).store( + "sharp_frames", Image + ) + n_sharp = sharp.count() + print(f" {n_sharp} sharp frames (from {n}, {n_sharp / n:.0%} kept)") + + # 3. 
Embed + print("Embedding with CLIP...") + embeddings = sharp.transform(EmbeddingTransformer(clip)).store("clip_embeddings") + print(f" {embeddings.count()} embeddings stored") +else: + print(f"Using cached DB ({db_path})") + clip = CLIPModel() + clip.start() + sharp = session.stream("sharp_frames") + embeddings = session.embedding_stream("clip_embeddings") + print(f" {sharp.count()} sharp frames, {embeddings.count()} embeddings") + +# 4. Search and export +queries = [ + "a hallway in an office", + "a person standing", + "a door", + "a desk", + "supermarket", + "large room", +] + +for query_text in queries: + print(f"\nQuery: '{query_text}'") + query_emb = clip.embed_text(query_text) + results = embeddings.search_embedding(query_emb, k=5).fetch() + + slug = query_text.replace(" ", "_")[:30] + for i, r in enumerate(results): + img_obs = sharp.at(r.ts, tolerance=0.01).one() + fname = OUT_DIR / f"{slug}_{i + 1}_id{r.id}_ts{r.ts:.0f}.jpg" + img_obs.data.save(str(fname)) + print(f" [{i + 1}] id={r.id} ts={r.ts:.2f} → {fname.name}") + +session.close() +store.close() +print(f"\nDone. 
Results in {OUT_DIR}/") diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 8a549803be..c333268fe3 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -36,6 +36,7 @@ from reactivex.subject import Subject from dimos.memory.codec import ( + JpegCodec, LcmCodec, PickleCodec, codec_for_type, @@ -289,7 +290,7 @@ def __init__( table: str, *, pose_provider: PoseProvider | None = None, - codec: LcmCodec | PickleCodec | None = None, + codec: LcmCodec | JpegCodec | PickleCodec | None = None, ) -> None: self._conn = conn self._table = table @@ -413,7 +414,7 @@ def __init__( vec_dimensions: int | None = None, pose_provider: PoseProvider | None = None, parent_table: str | None = None, - codec: LcmCodec | PickleCodec | None = None, + codec: LcmCodec | JpegCodec | PickleCodec | None = None, ) -> None: super().__init__(conn, table, pose_provider=pose_provider, codec=codec) self._vec_dimensions = vec_dimensions @@ -540,7 +541,7 @@ def __init__( *, tokenizer: str = "unicode61", pose_provider: PoseProvider | None = None, - codec: LcmCodec | PickleCodec | None = None, + codec: LcmCodec | JpegCodec | PickleCodec | None = None, ) -> None: super().__init__(conn, table, pose_provider=pose_provider, codec=codec) self._tokenizer = tokenizer @@ -652,6 +653,12 @@ def stream( if payload_type is None: payload_type = self._resolve_payload_type(name) + if payload_type is None: + raise TypeError( + f"stream({name!r}): payload_type is required when creating a new stream. " + "Pass the type explicitly, e.g. session.stream('images', Image)." 
+ ) + self._ensure_stream_tables(name) self._register_stream(name, payload_type, "stream") @@ -736,14 +743,15 @@ def materialize_transform( source: Stream[Any], transformer: Transformer[Any, Any], *, + payload_type: type | None = None, live: bool = False, backfill_only: bool = False, ) -> Stream[Any]: target: Stream[Any] if isinstance(transformer, EmbeddingTransformer): - target = self.embedding_stream(name) + target = self.embedding_stream(name, payload_type) else: - target = self.stream(name) + target = self.stream(name, payload_type) # Backfill existing data if transformer.supports_backfill and not live: @@ -838,9 +846,20 @@ def __init__(self, path: str) -> None: self._conn = sqlite3.connect(path) self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA synchronous=NORMAL") + self._load_extensions() def session(self) -> SqliteSession: return SqliteSession(self._conn) + def _load_extensions(self) -> None: + try: + import sqlite_vec + + self._conn.enable_load_extension(True) + sqlite_vec.load(self._conn) + self._conn.enable_load_extension(False) + except ImportError: + pass + def close(self) -> None: self._conn.close() diff --git a/dimos/memory/tests/test_sqlite.py b/dimos/memory/impl/test_sqlite.py similarity index 69% rename from dimos/memory/tests/test_sqlite.py rename to dimos/memory/impl/test_sqlite.py index f45ed44c36..267736e83e 100644 --- a/dimos/memory/tests/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -18,9 +18,13 @@ from typing import TYPE_CHECKING +import numpy as np import pytest from dimos.memory.impl.sqlite import SqliteSession, SqliteStore +from dimos.memory.transformer import EmbeddingTransformer +from dimos.memory.types import _UNSET, EmbeddingObservation +from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.testing import TimedSensorReplay @@ -28,6 +32,15 @@ from dimos.memory.types import Observation +def _img_close(a: Image, b: Image, 
max_diff: float = 5.0) -> bool: + """Approximate Image equality (JPEG is lossy).""" + if a.data.shape != b.data.shape: + return False + if a.frame_id != b.frame_id: + return False + return float(np.abs(a.data.astype(np.float32) - b.data.astype(np.float32)).mean()) < max_diff + + @pytest.fixture(scope="module") def replay() -> TimedSensorReplay: # type: ignore[type-arg] return TimedSensorReplay("unitree_go2_bigoffice/video") @@ -63,12 +76,12 @@ def test_append_and_fetch(self, session: SqliteSession, images: list[Image]) -> s = session.stream("images", Image) obs = s.append(images[0]) assert obs.id == 1 - assert obs.data == images[0] + assert obs.data == images[0] # append returns original, not decoded assert obs.ts is not None rows = s.fetch() assert len(rows) == 1 - assert rows[0].data == images[0] + assert _img_close(rows[0].data, images[0]) assert rows[0].id == 1 def test_append_multiple(self, session: SqliteSession, images: list[Image]) -> None: @@ -78,7 +91,7 @@ def test_append_multiple(self, session: SqliteSession, images: list[Image]) -> N assert s.count() == 3 rows = s.fetch() - assert [r.data for r in rows] == images[:3] + assert all(_img_close(r.data, img) for r, img in zip(rows, images[:3], strict=True)) def test_append_with_tags(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("images", Image) @@ -94,7 +107,7 @@ def test_last(self, session: SqliteSession, images: list[Image]) -> None: s.append(images[2], ts=3.0) obs = s.last() - assert obs.data == images[2] + assert _img_close(obs.data, images[2]) assert obs.ts == 3.0 def test_one(self, session: SqliteSession, images: list[Image]) -> None: @@ -102,7 +115,7 @@ def test_one(self, session: SqliteSession, images: list[Image]) -> None: s.append(images[0]) obs = s.one() - assert obs.data == images[0] + assert _img_close(obs.data, images[0]) def test_one_empty_raises(self, session: SqliteSession) -> None: s = session.stream("images", Image) @@ -118,7 +131,7 @@ def test_after(self, 
session: SqliteSession, images: list[Image]) -> None: rows = s.after(5.0).fetch() assert len(rows) == 1 - assert rows[0].data == images[1] + assert _img_close(rows[0].data, images[1]) def test_before(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("data", Image) @@ -127,7 +140,7 @@ def test_before(self, session: SqliteSession, images: list[Image]) -> None: rows = s.before(5.0).fetch() assert len(rows) == 1 - assert rows[0].data == images[0] + assert _img_close(rows[0].data, images[0]) def test_time_range(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("data", Image) @@ -137,7 +150,7 @@ def test_time_range(self, session: SqliteSession, images: list[Image]) -> None: rows = s.time_range(3.0, 7.0).fetch() assert len(rows) == 1 - assert rows[0].data == images[1] + assert _img_close(rows[0].data, images[1]) def test_at(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("data", Image) @@ -147,7 +160,7 @@ def test_at(self, session: SqliteSession, images: list[Image]) -> None: rows = s.at(5.5, tolerance=1.0).fetch() assert len(rows) == 1 - assert rows[0].data == images[1] + assert _img_close(rows[0].data, images[1]) def test_filter_tags(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("data", Image) @@ -156,7 +169,7 @@ def test_filter_tags(self, session: SqliteSession, images: list[Image]) -> None: rows = s.filter_tags(cam="front").fetch() assert len(rows) == 1 - assert rows[0].data == images[0] + assert _img_close(rows[0].data, images[0]) def test_chained_filters(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("data", Image) @@ -166,7 +179,7 @@ def test_chained_filters(self, session: SqliteSession, images: list[Image]) -> N rows = s.after(3.0).filter_tags(cam="front").fetch() assert len(rows) == 1 - assert rows[0].data == images[1] + assert _img_close(rows[0].data, images[1]) class TestOrdering: @@ -177,7 +190,10 @@ def 
test_order_by_ts(self, session: SqliteSession, images: list[Image]) -> None: s.append(images[2], ts=3.0) rows = s.order_by("ts").fetch() - assert [r.data for r in rows] == [images[0], images[1], images[2]] + assert all( + _img_close(r.data, img) + for r, img in zip(rows, [images[0], images[1], images[2]], strict=True) + ) def test_order_by_desc(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("data", Image) @@ -186,7 +202,10 @@ def test_order_by_desc(self, session: SqliteSession, images: list[Image]) -> Non s.append(images[2], ts=3.0) rows = s.order_by("ts", desc=True).fetch() - assert [r.data for r in rows] == [images[2], images[1], images[0]] + assert all( + _img_close(r.data, img) + for r, img in zip(rows, [images[2], images[1], images[0]], strict=True) + ) def test_limit_offset(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("data", Image) @@ -195,7 +214,9 @@ def test_limit_offset(self, session: SqliteSession, images: list[Image]) -> None rows = s.order_by("ts").limit(2).offset(1).fetch() assert len(rows) == 2 - assert [r.data for r in rows] == [images[1], images[2]] + assert all( + _img_close(r.data, img) for r, img in zip(rows, [images[1], images[2]], strict=True) + ) class TestFetchPages: @@ -210,7 +231,7 @@ def test_basic_pagination(self, session: SqliteSession, images: list[Image]) -> assert len(pages[-1]) == 1 all_items = [obs.data for page in pages for obs in page] - assert all_items == images + assert all(_img_close(a, b) for a, b in zip(all_items, images, strict=True)) class TestTextStream: @@ -232,6 +253,84 @@ def test_text_search(self, session: SqliteSession) -> None: assert all("Motor" in r.data for r in rows) +class TestEmbeddingStream: + def test_create_and_append(self, session: SqliteSession) -> None: + es = session.embedding_stream("emb", vec_dimensions=4) + e1 = Embedding(np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)) + e2 = Embedding(np.array([0.0, 1.0, 0.0, 0.0], 
dtype=np.float32)) + + es.append(e1, ts=1.0) + es.append(e2, ts=2.0) + + assert es.count() == 2 + + def test_search_embedding(self, session: SqliteSession) -> None: + es = session.embedding_stream("emb_search", vec_dimensions=4) + vecs = [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.9, 0.1, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + ] + for i, v in enumerate(vecs): + es.append(Embedding(np.array(v, dtype=np.float32)), ts=float(i)) + + # Search for vector closest to [1, 0, 0, 0] — should get id=1 and id=3 + results = es.search_embedding([1.0, 0.0, 0.0, 0.0], k=2).fetch() + assert len(results) == 2 + result_ids = {r.id for r in results} + assert 1 in result_ids # exact match + assert 3 in result_ids # [0.9, 0.1, 0, 0] is close + + def test_search_returns_embedding_observation(self, session: SqliteSession) -> None: + es = session.embedding_stream("emb_obs", vec_dimensions=3) + es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) + + results = es.search_embedding([1.0, 0.0, 0.0], k=1).fetch() + assert len(results) == 1 + assert isinstance(results[0], EmbeddingObservation) + + def test_search_with_time_filter(self, session: SqliteSession) -> None: + es = session.embedding_stream("emb_time", vec_dimensions=3) + es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) + es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=10.0) + + # Both match the vector, but only one is after t=5 + results = es.search_embedding([1.0, 0.0, 0.0], k=10).after(5.0).fetch() + assert len(results) == 1 + assert results[0].ts == 10.0 + + def test_embedding_transformer_store(self, session: SqliteSession, images: list[Image]) -> None: + """Test the full pipeline: images → EmbeddingTransformer → EmbeddingStream.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + # Return a fake embedding based on mean pixel value + results = [] + for img in 
imgs: + val = float(img.data.mean()) / 255.0 + results.append( + Embedding(np.array([val, 1.0 - val, 0.0, 0.0], dtype=np.float32)) + ) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + raise NotImplementedError + + s = session.stream("cam_emb", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) + + emb_stream = s.transform(EmbeddingTransformer(FakeEmbedder())).store("cam_embeddings") + assert emb_stream.count() == 2 + + # Search by vector + results = emb_stream.search_embedding([0.5, 0.5, 0.0, 0.0], k=1).fetch() + assert len(results) == 1 + + class TestListStreams: def test_list_empty(self, session: SqliteSession) -> None: assert session.list_streams() == [] @@ -255,8 +354,8 @@ def test_appended_observable(self, session: SqliteSession, images: list[Image]) s.append(images[1]) assert len(received) == 2 - assert received[0].data == images[0] - assert received[1].data == images[1] + assert received[0].data is images[0] # appended obs holds original + assert received[1].data is images[1] class TestTransformInMemory: @@ -298,7 +397,7 @@ def test_transform_store_backfill(self, session: SqliteSession, images: list[Ima s.append(images[0], ts=1.0) s.append(images[1], ts=2.0) - stored = s.transform(lambda im: f"{im.width}x{im.height}").store("shapes") + stored = s.transform(lambda im: f"{im.width}x{im.height}").store("shapes", str) rows = stored.fetch() assert len(rows) == 2 expected = f"{images[0].width}x{images[0].height}" @@ -311,7 +410,7 @@ def test_transform_store_live(self, session: SqliteSession, images: list[Image]) s = session.stream("data", Image) s.append(images[0], ts=1.0) - stored = s.transform(lambda im: im.height, live=True).store("heights") + stored = s.transform(lambda im: im.height, live=True).store("heights", int) assert stored.count() == 0 # no backfill s.append(images[1], ts=2.0) @@ -324,7 +423,7 @@ def test_transform_store_backfill_only( s = session.stream("data", 
Image) s.append(images[0], ts=1.0) - stored = s.transform(lambda im: im.height, backfill_only=True).store("heights_bo") + stored = s.transform(lambda im: im.height, backfill_only=True).store("heights_bo", int) assert stored.count() == 1 assert stored.one().data == images[0].height @@ -340,12 +439,11 @@ def test_data_lazy_loaded(self, session: SqliteSession, images: list[Image]) -> rows = s.fetch() obs = rows[0] - from dimos.memory.types import _UNSET - assert obs._data is _UNSET assert obs._data_loader is not None - assert obs.data == images[0] - assert obs._data == images[0] + loaded = obs.data + assert _img_close(loaded, images[0]) + assert obs._data is loaded # cached after first access def test_metadata_without_payload(self, session: SqliteSession, images: list[Image]) -> None: """Metadata (ts, tags) should be available without loading payload.""" @@ -366,7 +464,7 @@ def test_iter(self, session: SqliteSession, images: list[Image]) -> None: s.append(img, ts=float(i)) items = [obs.data for obs in s] - assert items == images[:3] + assert all(_img_close(a, b) for a, b in zip(items, images[:3], strict=True)) class TestStoreReopen: @@ -386,6 +484,6 @@ def test_data_persists(self, tmp_path: object, images: list[Image]) -> None: s2 = store2.session() rows = s2.stream("data", Image).fetch() assert len(rows) == 1 - assert rows[0].data == images[0] + assert _img_close(rows[0].data, images[0]) s2.close() store2.close() diff --git a/dimos/memory/impl/test_sqlite_e2e.py b/dimos/memory/impl/test_sqlite_e2e.py new file mode 100644 index 0000000000..fd4a049b2a --- /dev/null +++ b/dimos/memory/impl/test_sqlite_e2e.py @@ -0,0 +1,113 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""E2E test: ingest robot video → sharpness filter → CLIP embed → vector search.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from pathlib import Path + +from dimos.memory.impl.sqlite import SqliteStore +from dimos.memory.ingest import ingest +from dimos.memory.transformer import EmbeddingTransformer, QualityWindowTransformer +from dimos.models.embedding.clip import CLIPModel +from dimos.msgs.sensor_msgs.Image import Image +from dimos.utils.testing import TimedSensorReplay + + +@pytest.fixture(scope="module") +def replay() -> TimedSensorReplay: # type: ignore[type-arg] + return TimedSensorReplay("unitree_go2_bigoffice/video") + + +@pytest.fixture(scope="module") +def clip() -> CLIPModel: + model = CLIPModel() + model.start() + return model + + +@pytest.mark.slow +@pytest.mark.skipif_in_ci +class TestE2EPipeline: + """Ingest 60s of robot video, filter by sharpness, embed with CLIP, search.""" + + def test_ingest_filter_embed_search( + self, + tmp_path: Path, + replay: TimedSensorReplay, # type: ignore[type-arg] + clip: CLIPModel, + ) -> None: + store = SqliteStore(str(tmp_path / "e2e.db")) + session = store.session() + + # 1. Ingest 60s of video + raw = session.stream("raw_video", Image) + n_ingested = ingest(raw, replay.iterate_ts(seek=5.0, duration=60.0)) + assert n_ingested > 0 + print(f"\nIngested {n_ingested} frames") + + # 2. 
Sharpness filter: keep best frame per 0.5s window + sharp = raw.transform( + QualityWindowTransformer(lambda img: img.sharpness, window=0.5) + ).store("sharp_frames", Image) + n_sharp = sharp.count() + assert n_sharp > 0 + assert n_sharp < n_ingested # should reduce count + print(f"Sharp frames: {n_sharp} (from {n_ingested}, {n_sharp / n_ingested:.0%} kept)") + + # 3. Embed with real CLIP model + embeddings = sharp.transform(EmbeddingTransformer(clip)).store("clip_embeddings") + n_emb = embeddings.count() + assert n_emb == n_sharp + print(f"Embeddings stored: {n_emb}") + + # 4. Text-to-image search + query_emb = clip.embed_text("a hallway in an office") + results = embeddings.search_embedding(query_emb, k=5).fetch() + assert len(results) > 0 + assert len(results) <= 5 + print(f"Search returned {len(results)} results") + + for r in results: + assert r.ts is not None + assert r.data is not None + print(f" id={r.id} ts={r.ts:.2f}") + + # 5. Search with time filter + mid_ts = (results[0].ts + results[-1].ts) / 2 if len(results) > 1 else results[0].ts + filtered = embeddings.search_embedding(query_emb, k=10).after(mid_ts).fetch() + assert all(r.ts > mid_ts for r in filtered) + print(f"Time-filtered search: {len(filtered)} results after ts={mid_ts:.2f}") + + # 6. Verify persistence — reopen and search again + session.close() + store.close() + + store2 = SqliteStore(str(tmp_path / "e2e.db")) + session2 = store2.session() + reloaded = session2.embedding_stream("clip_embeddings") + assert reloaded.count() == n_emb + + results2 = reloaded.search_embedding(query_emb, k=3).fetch() + assert len(results2) > 0 + print(f"After reopen: {len(results2)} results") + + session2.close() + store2.close() diff --git a/dimos/memory/ingest.py b/dimos/memory/ingest.py new file mode 100644 index 0000000000..5d96fc22b3 --- /dev/null +++ b/dimos/memory/ingest.py @@ -0,0 +1,42 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for ingesting timestamped data into memory streams.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Iterable + + from dimos.memory.stream import Stream + + +def ingest( + stream: Stream[Any], + source: Iterable[tuple[float, Any]], +) -> int: + """Ingest (timestamp, payload) pairs into a stream. + + Accepts any iterable of ``(ts, data)`` — e.g. ``replay.iterate_ts(seek=5, duration=60)``. + + Returns: + Number of items ingested. 
+ """ + count = 0 + for ts, payload in source: + stream.append(payload, ts=ts) + count += 1 + return count diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 8662d0e895..a5037caa04 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -69,6 +69,7 @@ def materialize_transform( source: Stream[Any], transformer: Transformer[Any, Any], *, + payload_type: type | None = None, live: bool = False, backfill_only: bool = False, ) -> Stream[Any]: diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index dd0543bae4..a9adb04981 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -362,7 +362,12 @@ def fetch(self) -> list[Observation]: self._transformer.process(self._source, collector) return collector.results - def store(self, name: str | None = None, session: Session | None = None) -> Stream[R]: + def store( + self, + name: str | None = None, + payload_type: type | None = None, + session: Session | None = None, + ) -> Stream[R]: resolved = session or self._source._session if resolved is None: raise TypeError( @@ -372,10 +377,12 @@ def store(self, name: str | None = None, session: Session | None = None) -> Stre ) if name is None: raise TypeError("store() requires a name for transform outputs") + resolved_type = payload_type or self._transformer.output_type return resolved.materialize_transform( name=name, source=self._source, transformer=self._transformer, + payload_type=resolved_type, live=self._live, backfill_only=self._backfill_only, ) diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index d347eb5160..964a016181 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -34,6 +34,7 @@ class Transformer(ABC, Generic[T, R]): supports_backfill: bool = True supports_live: bool = True + output_type: type | None = None @abstractmethod def process(self, source: Stream[T], target: Stream[R]) -> None: @@ -71,6 +72,76 @@ def _apply(self, obs: Observation, target: Stream[R]) -> None: 
target.append(result, ts=obs.ts, pose=obs.pose, tags=obs.tags) +class QualityWindowTransformer(Transformer[T, T]): + """Keeps the highest-quality item per time window. + + Like ``sharpness_barrier`` but operates on stored data (no wall-clock dependency). + In live mode, buffers the current window and emits the best item when a new + observation falls outside the window. + """ + + supports_backfill: bool = True + supports_live: bool = True + + def __init__(self, quality_fn: Callable[[T], float], window: float = 0.5) -> None: + self._quality_fn = quality_fn + self._window = window + # Live state + self._window_start: float | None = None + self._best_obs: Observation | None = None + self._best_score: float = -1.0 + + def process(self, source: Stream[T], target: Stream[T]) -> None: + window_start: float | None = None + best_obs: Observation | None = None + best_score: float = -1.0 + + for obs in source: + ts = obs.ts or 0.0 + if window_start is None: + window_start = ts + + if (ts - window_start) >= self._window: + if best_obs is not None: + target.append( + best_obs.data, ts=best_obs.ts, pose=best_obs.pose, tags=best_obs.tags + ) + window_start = ts + best_score = -1.0 + best_obs = None + + score = self._quality_fn(obs.data) + if score > best_score: + best_score = score + best_obs = obs + + if best_obs is not None: + target.append(best_obs.data, ts=best_obs.ts, pose=best_obs.pose, tags=best_obs.tags) + + def on_append(self, obs: Observation, target: Stream[T]) -> None: + ts = obs.ts or 0.0 + + if self._window_start is None: + self._window_start = ts + + if (ts - self._window_start) >= self._window: + if self._best_obs is not None: + target.append( + self._best_obs.data, + ts=self._best_obs.ts, + pose=self._best_obs.pose, + tags=self._best_obs.tags, + ) + self._window_start = ts + self._best_score = -1.0 + self._best_obs = None + + score = self._quality_fn(obs.data) + if score > self._best_score: + self._best_score = score + self._best_obs = obs + + class 
EmbeddingTransformer(Transformer[Any, "Embedding"]): """Wraps an EmbeddingModel as a Transformer that produces Embedding output. @@ -81,7 +152,10 @@ class EmbeddingTransformer(Transformer[Any, "Embedding"]): supports_live: bool = True def __init__(self, model: EmbeddingModel) -> None: + from dimos.models.embedding.base import Embedding as EmbeddingCls + self.model = model + self.output_type: type | None = EmbeddingCls def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: for page in source.fetch_pages(): diff --git a/pyproject.toml b/pyproject.toml index cb4607ced5..a448b90edd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,6 @@ dependencies = [ "annotation-protocol>=1.4.0", "lazy_loader", "plum-dispatch==2.5.7", - # Logging "structlog>=25.5.0,<26", "colorlog==6.9.0", @@ -66,6 +65,7 @@ dependencies = [ "toolz>=1.1.0", "protobuf>=6.33.5,<7", "psutil>=7.0.0", + "sqlite-vec>=0.1.6", ] diff --git a/uv.lock b/uv.lock index 2f53ef0e6f..46f9b1931d 100644 --- a/uv.lock +++ b/uv.lock @@ -1705,6 +1705,7 @@ dependencies = [ { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.17.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "sortedcontainers" }, + { name = "sqlite-vec" }, { name = "structlog" }, { name = "terminaltexteffects" }, { name = "textual" }, @@ -2085,6 +2086,7 @@ requires-dist = [ { name = "sortedcontainers", marker = "extra == 'docker'" }, { name = "sounddevice", marker = "extra == 'agents'" }, { name = "soundfile", marker = "extra == 'web'" }, + { name = "sqlite-vec", specifier = ">=0.1.6" }, { name = "sse-starlette", marker = "extra == 'web'", specifier = ">=2.2.1" }, { name = "structlog", specifier = ">=25.5.0,<26" }, { name = "structlog", marker = "extra == 'docker'", specifier = ">=25.5.0,<26" }, @@ -9063,6 +9065,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, ] +[[package]] +name = "sqlite-vec" +version = "0.1.6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ed/aabc328f29ee6814033d008ec43e44f2c595447d9cccd5f2aabe60df2933/sqlite_vec-0.1.6-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:77491bcaa6d496f2acb5cc0d0ff0b8964434f141523c121e313f9a7d8088dee3", size = 164075, upload-time = "2024-11-20T16:40:29.847Z" }, + { url = "https://files.pythonhosted.org/packages/a7/57/05604e509a129b22e303758bfa062c19afb020557d5e19b008c64016704e/sqlite_vec-0.1.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fdca35f7ee3243668a055255d4dee4dea7eed5a06da8cad409f89facf4595361", size = 165242, upload-time = "2024-11-20T16:40:31.206Z" }, + { url = "https://files.pythonhosted.org/packages/f2/48/dbb2cc4e5bad88c89c7bb296e2d0a8df58aab9edc75853728c361eefc24f/sqlite_vec-0.1.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b0519d9cd96164cd2e08e8eed225197f9cd2f0be82cb04567692a0a4be02da3", size = 103704, upload-time = "2024-11-20T16:40:33.729Z" }, + { url = "https://files.pythonhosted.org/packages/80/76/97f33b1a2446f6ae55e59b33869bed4eafaf59b7f4c662c8d9491b6a714a/sqlite_vec-0.1.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:823b0493add80d7fe82ab0fe25df7c0703f4752941aee1c7b2b02cec9656cb24", size = 151556, upload-time = "2024-11-20T16:40:35.387Z" }, + { url = "https://files.pythonhosted.org/packages/6a/98/e8bc58b178266eae2fcf4c9c7a8303a8d41164d781b32d71097924a6bebe/sqlite_vec-0.1.6-py3-none-win_amd64.whl", hash = "sha256:c65bcfd90fa2f41f9000052bcb8bb75d38240b2dae49225389eca6c3136d3f0c", size = 281540, upload-time = 
"2024-11-20T16:40:37.296Z" }, +] + [[package]] name = "sse-starlette" version = "3.2.0" From 852a7e9bea55b7e6b542d391400ea4f3795aa4ea Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 16:45:33 +0800 Subject: [PATCH 011/118] Wire parent_id lineage through transforms for automatic source data projection - Add parent_id to Observation, append(), do_append(), and _META_COLS - All transformers (PerItem, QualityWindow, Embedding) pass obs.id as parent_id - SqliteEmbeddingBackend._row_to_obs() wires _source_data_loader via parent_id - EmbeddingObservation.data now auto-projects to parent stream's payload (e.g. Image) - No more timestamp-matching hacks to find source data from embedding results --- dimos/memory/impl/run_e2e_export.py | 10 ++--- dimos/memory/impl/sqlite.py | 57 +++++++++++++++++++++++------ dimos/memory/stream.py | 6 ++- dimos/memory/transformer.py | 23 +++++++++--- dimos/memory/types.py | 1 + 5 files changed, 74 insertions(+), 23 deletions(-) diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py index aea96d497c..8191d8df56 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/run_e2e_export.py @@ -88,11 +88,11 @@ results = embeddings.search_embedding(query_emb, k=5).fetch() slug = query_text.replace(" ", "_")[:30] - for i, r in enumerate(results): - img_obs = sharp.at(r.ts, tolerance=0.01).one() - fname = OUT_DIR / f"{slug}_{i + 1}_id{r.id}_ts{r.ts:.0f}.jpg" - img_obs.data.save(str(fname)) - print(f" [{i + 1}] id={r.id} ts={r.ts:.2f} → {fname.name}") + for rank, result in enumerate(results): + # .data auto-projects to parent image via parent_id lineage + fname = OUT_DIR / f"{slug}_{rank + 1}_id{result.id}_ts{result.ts:.0f}.jpg" + result.data.save(str(fname)) + print(f" [{rank + 1}] id={result.id} ts={result.ts:.2f} → {fname.name}") session.close() store.close() diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index c333268fe3..6d1710f6fe 100644 --- 
a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -121,7 +121,7 @@ def _deserialize_tags(text: str) -> dict[str, Any]: # ── SQL building ────────────────────────────────────────────────────── # Columns selected from the meta table (no payload). -_META_COLS = "id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags" +_META_COLS = "id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags, parent_id" def _compile_filter(f: Filter, table: str) -> tuple[str, list[Any]]: @@ -312,6 +312,7 @@ def do_append( ts: float | None, pose: Any | None, tags: dict[str, Any] | None, + parent_id: int | None = None, ) -> Observation: if ts is None: ts = time.time() @@ -325,14 +326,14 @@ def do_append( if pose_cols is not None: cur = self._conn.execute( f"INSERT INTO {self._table} " - "(ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - (ts, *pose_cols, tags_json), + "(ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags, parent_id) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (ts, *pose_cols, tags_json, parent_id), ) else: cur = self._conn.execute( - f"INSERT INTO {self._table} (ts, tags) VALUES (?, ?)", - (ts, tags_json), + f"INSERT INTO {self._table} (ts, tags, parent_id) VALUES (?, ?, ?)", + (ts, tags_json, parent_id), ) row_id = cur.lastrowid assert row_id is not None @@ -360,6 +361,7 @@ def do_append( ts=ts, pose=pose, tags=tags or {}, + parent_id=parent_id, _data=payload, ) self._subject.on_next(obs) @@ -382,7 +384,7 @@ def execute_count(self, query: StreamQuery) -> int: return result[0] if result else 0 # type: ignore[no-any-return] def _row_to_obs(self, row: Any) -> Observation: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) conn = self._conn table = self._table @@ -399,6 +401,7 @@ def loader() -> Any: ts=ts, pose=pose, 
tags=_deserialize_tags(tags_json), + parent_id=pid, _data_loader=loader, ) @@ -426,10 +429,11 @@ def do_append( ts: float | None, pose: Any | None, tags: dict[str, Any] | None, + parent_id: int | None = None, ) -> Observation: from dimos.models.embedding.base import Embedding - obs = super().do_append(payload, ts, pose, tags) + obs = super().do_append(payload, ts, pose, tags, parent_id) # Also insert into vec0 table if payload is an Embedding if isinstance(payload, Embedding): @@ -510,11 +514,12 @@ def _fetch_by_vector( return observations def _row_to_obs(self, row: Any) -> Observation: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) conn = self._conn table = self._table codec = self._codec + parent_table = self._parent_table def loader() -> Any: r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() @@ -522,12 +527,36 @@ def loader() -> Any: raise LookupError(f"No payload for id={row_id}") return codec.decode(r[0]) + source_loader = None + if pid is not None and parent_table is not None: + + def _source_loader(parent_tbl: str = parent_table, parent_row_id: int = pid) -> Any: + r = conn.execute( + f"SELECT data FROM {parent_tbl}_payload WHERE id = ?", (parent_row_id,) + ).fetchone() + if r is None: + raise LookupError(f"No parent payload for id={parent_row_id}") + # Resolve parent codec from _streams metadata + meta = conn.execute( + "SELECT payload_module FROM _streams WHERE name = ?", (parent_tbl,) + ).fetchone() + if meta and meta[0]: + parent_type = module_path_to_type(meta[0]) + parent_codec = codec_for_type(parent_type) + else: + parent_codec = codec + return parent_codec.decode(r[0]) + + source_loader = _source_loader + return EmbeddingObservation( id=row_id, ts=ts, pose=pose, tags=_deserialize_tags(tags_json), + parent_id=pid, _data_loader=loader, + _source_data_loader=source_loader, ) @@ -552,8 
+581,9 @@ def do_append( ts: float | None, pose: Any | None, tags: dict[str, Any] | None, + parent_id: int | None = None, ) -> Observation: - obs = super().do_append(payload, ts, pose, tags) + obs = super().do_append(payload, ts, pose, tags, parent_id) text = str(payload) if payload is not None else "" self._conn.execute( @@ -747,9 +777,14 @@ def materialize_transform( live: bool = False, backfill_only: bool = False, ) -> Stream[Any]: + # Resolve source table name for parent lineage + source_table = None + if source._backend is not None: + source_table = source._backend.stream_name + target: Stream[Any] if isinstance(transformer, EmbeddingTransformer): - target = self.embedding_stream(name, payload_type) + target = self.embedding_stream(name, payload_type, parent_table=source_table) else: target = self.stream(name, payload_type) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index a9adb04981..f8ed0cae1f 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -65,6 +65,7 @@ def do_append( ts: float | None, pose: Any | None, tags: dict[str, Any] | None, + parent_id: int | None = None, ) -> Observation: ... @property def appended_subject(self) -> Subject[Observation]: ... 
# type: ignore[type-arg] @@ -125,9 +126,10 @@ def append( ts: float | None = None, pose: PoseLike | None = None, tags: dict[str, Any] | None = None, + parent_id: int | None = None, ) -> Observation: backend = self._require_backend() - return backend.do_append(payload, ts, pose, tags) + return backend.do_append(payload, ts, pose, tags, parent_id) # ── Temporal filters ────────────────────────────────────────────── @@ -403,11 +405,13 @@ def append( ts: float | None = None, pose: PoseLike | None = None, tags: dict[str, Any] | None = None, + parent_id: int | None = None, ) -> Observation: obs = Observation( id=self._next_id, ts=ts, tags=tags or {}, + parent_id=parent_id, _data=payload, ) self._next_id += 1 diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index 964a016181..d4559a7265 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -67,9 +67,9 @@ def _apply(self, obs: Observation, target: Stream[R]) -> None: return if isinstance(result, list): for item in result: - target.append(item, ts=obs.ts, pose=obs.pose, tags=obs.tags) + target.append(item, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) else: - target.append(result, ts=obs.ts, pose=obs.pose, tags=obs.tags) + target.append(result, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) class QualityWindowTransformer(Transformer[T, T]): @@ -104,7 +104,11 @@ def process(self, source: Stream[T], target: Stream[T]) -> None: if (ts - window_start) >= self._window: if best_obs is not None: target.append( - best_obs.data, ts=best_obs.ts, pose=best_obs.pose, tags=best_obs.tags + best_obs.data, + ts=best_obs.ts, + pose=best_obs.pose, + tags=best_obs.tags, + parent_id=best_obs.id, ) window_start = ts best_score = -1.0 @@ -116,7 +120,13 @@ def process(self, source: Stream[T], target: Stream[T]) -> None: best_obs = obs if best_obs is not None: - target.append(best_obs.data, ts=best_obs.ts, pose=best_obs.pose, tags=best_obs.tags) + target.append( + best_obs.data, 
+ ts=best_obs.ts, + pose=best_obs.pose, + tags=best_obs.tags, + parent_id=best_obs.id, + ) def on_append(self, obs: Observation, target: Stream[T]) -> None: ts = obs.ts or 0.0 @@ -131,6 +141,7 @@ def on_append(self, obs: Observation, target: Stream[T]) -> None: ts=self._best_obs.ts, pose=self._best_obs.pose, tags=self._best_obs.tags, + parent_id=self._best_obs.id, ) self._window_start = ts self._best_score = -1.0 @@ -166,10 +177,10 @@ def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: if not isinstance(embeddings, list): embeddings = [embeddings] for obs, emb in zip(page, embeddings, strict=True): - target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags) + target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) def on_append(self, obs: Observation, target: Stream[Embedding]) -> None: emb = self.model.embed(obs.data) if isinstance(emb, list): emb = emb[0] - target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags) + target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) diff --git a/dimos/memory/types.py b/dimos/memory/types.py index bc94daad4e..fb031314bf 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -33,6 +33,7 @@ class Observation: ts: float | None = None pose: PoseStamped | None = None tags: dict[str, Any] = field(default_factory=dict) + parent_id: int | None = field(default=None, repr=False) _data: Any = field(default=_UNSET, repr=False) _data_loader: Callable[[], Any] | None = field(default=None, repr=False, compare=False) From d44aaaf25206cc760b1781034f4803544567f5c9 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 16:53:15 +0800 Subject: [PATCH 012/118] Wire parent_stream into _streams registry, add tasks.md gap analysis - materialize_transform() now UPDATEs _streams.parent_stream so stream-level lineage is discoverable (prerequisite for .join()) - Fix mypy: narrow parent_table type in _source_loader closure - Add plans/memory/tasks.md documenting 
all spec-vs-impl gaps --- dimos/memory/impl/sqlite.py | 11 ++- plans/memory/tasks.md | 135 ++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 plans/memory/tasks.md diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 6d1710f6fe..5aa582d78c 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -529,8 +529,9 @@ def loader() -> Any: source_loader = None if pid is not None and parent_table is not None: + _pt: str = parent_table # narrowed from str | None by the guard above - def _source_loader(parent_tbl: str = parent_table, parent_row_id: int = pid) -> Any: + def _source_loader(parent_tbl: str = _pt, parent_row_id: int = pid) -> Any: r = conn.execute( f"SELECT data FROM {parent_tbl}_payload WHERE id = ?", (parent_row_id,) ).fetchone() @@ -788,6 +789,14 @@ def materialize_transform( else: target = self.stream(name, payload_type) + # Record parent lineage in _streams registry + if source_table is not None: + self._conn.execute( + "UPDATE _streams SET parent_stream = ? WHERE name = ?", + (source_table, name), + ) + self._conn.commit() + # Backfill existing data if transformer.supports_backfill and not live: transformer.process(source, target) diff --git a/plans/memory/tasks.md b/plans/memory/tasks.md new file mode 100644 index 0000000000..0197924b78 --- /dev/null +++ b/plans/memory/tasks.md @@ -0,0 +1,135 @@ +# Memory2 — Remaining Tasks + +Gap analysis between `plans/memory/` specs and `dimos/memory/` implementation. + +## P0 — Security / Correctness + +### 1. Stream name validation + +Stream names are interpolated directly into SQL via f-strings. No validation exists — arbitrary input is a SQL injection vector. + +**Spec** (`sqlite.md`): `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$`, reject with `ValueError`. + +**Where**: Add a `_validate_stream_name(name)` check at the top of `SqliteSession.stream()`, `.text_stream()`, `.embedding_stream()`. + +### 2. 
`_clone()` type annotation vs runtime + +`Stream._clone()` (`stream.py:94-108`) is annotated `-> Stream[T]`, but at runtime it uses `self.__class__.__new__(self.__class__)` which correctly preserves the subclass. So `EmbeddingStream.after(t)` returns an `EmbeddingStream` at runtime — no bug. + +The annotation is wrong for mypy though. Consider `-> Self` (from `typing_extensions`) if we want strict typing. Low priority — runtime works. + +## P1 — Core API Gaps + +### 3. Wire `parent_stream` into `_streams` registry + +`_register_stream()` (`sqlite.py:847-861`) never writes the `parent_stream` column. The column exists in the schema but is always NULL. + +**Where**: `materialize_transform()` (`sqlite.py:770-799`) knows both `source_table` and `name`. Pass `parent_stream=source_table` to `_register_stream()`, and update `_register_stream` to accept and INSERT it. + +This is a prerequisite for `.join()` and stream-level lineage discovery. + +### 4. Implement `.join()` — cross-stream lineage + +`api.md` specifies: +```python +for det, img in detections.after(t).join(images): + print(f"Detected {det.data} in image at {img.pose}") +``` + +Currently only a `project_to()` stub exists that raises `NotImplementedError`. + +**Design** (from `sqlite.md`): +```sql +SELECT c.*, p.* +FROM {self}_meta c +JOIN {target}_meta p ON c.parent_id = p.id +WHERE c.id IN (/* current filtered set */) +``` + +Both sides return `Observation` with lazy `.data`. Yields `tuple[Observation, Observation]`. + +**Depends on**: Task 3 (parent_stream in registry) for discovering which stream to join against. Alternatively, `.join(target)` takes the target explicitly, so parent_stream metadata is nice-to-have but not strictly required — `parent_id` column is sufficient. + +### 5. Filtered `.appended` — predicate-filtered reactive subscriptions + +`api.md` specifies: +```python +images.near(kitchen_pose, 5.0).appended.subscribe(...) 
+``` + +Current impl (`stream.py:276-278`) returns the raw Subject regardless of filters. + +**Fix** (from `sqlite.md`): When `self._query.filters` is non-empty, pipe the root subject through `ops.filter()` that evaluates each predicate in Python: + +```python +@property +def appended(self): + backend = self._require_backend() + obs = backend.appended_subject + if not self._query.filters: + return obs + return obs.pipe(ops.filter(lambda o: self._matches_filters(o))) +``` + +Each filter type needs a `matches(obs) -> bool` method for Python-side evaluation: +- `AfterFilter`: `obs.ts > self.t` +- `NearFilter`: Euclidean distance check +- `TagsFilter`: dict subset check +- etc. + +### 6. Incremental backfill + +`sqlite.md` specifies that re-running a stored transform resumes from the last processed item: + +```python +max_parent = conn.execute( + f"SELECT MAX(parent_id) FROM {target_name}" +).fetchone()[0] + +if max_parent is not None: + source = source.after_id(max_parent) # internal: WHERE id > ? +``` + +**Where**: `materialize_transform()` (`sqlite.py:791-793`). Before calling `transformer.process()`, check if target already has rows and filter source accordingly. + +**Needs**: An internal `_after_id(row_id)` filter (not exposed in public API) that adds `WHERE id > ?`. + +## P2 — Robustness + +### 7. Separate connections per session + +`SqliteStore.session()` (`sqlite.py:886-887`) shares `self._conn` across all sessions. The spec says each session should own its own connection. + +**Fix**: `session()` should call `sqlite3.connect(self._path)` + WAL pragma + extension loading each time, not reuse `self._conn`. Store keeps the path, sessions get independent connections. + +This is required for multi-threaded use (e.g., one session writing in a background thread, another querying in the main thread). + +### 8. 
`_CollectorStream` doesn't set pose on observations + +`_CollectorStream.append()` (`stream.py:401-419`) accepts `pose` but doesn't store it on the `Observation`: + +```python +obs = Observation(id=self._next_id, ts=ts, tags=tags or {}, parent_id=parent_id, _data=payload) +# pose is silently dropped +``` + +**Fix**: Add `pose=pose` to the Observation constructor call. + +## P3 — Future (not blocking) + +### 9. Query objects — composable 4D regions + soft scoring + +`query_objects.md` proposes `Criterion` types (`TimeRange`, `Sphere`, `TimeProximity`, `SpatialProximity`, `EmbeddingSimilarity`) with `&`/`|`/`~` composition and weighted `Score()`. + +Explicitly Phase 2. Current flat filter API covers all simple cases. Implement when real usage demands soft scoring or region composition. + +### 10. `questions.md` hard cases + +Unresolved query patterns from the product requirements: +- Negation queries ("when did I NOT see the cat") +- Temporal regularity ("what time does the mailman come") +- Cross-agent memory diff +- Conditional pose integration +- Event-anchored multi-stream slicing + +These require extensions beyond the current Stream API — likely built on top of the composable query layer (task 9). From fa31471e608bf5618349d7ad5859ccb04947e85b Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 17:12:28 +0800 Subject: [PATCH 013/118] Implement project_to() for cross-stream lineage projection Adds LineageFilter that compiles to nested SQL subqueries walking the parent_id chain. project_to(target) returns a chainable target Stream using the same _with_filter mechanism as .after(), .near(), etc. Also fixes _session propagation in search_embedding/search_text. 
--- dimos/memory/impl/sqlite.py | 101 +++++++++++++++++++++++++ dimos/memory/impl/test_sqlite.py | 126 +++++++++++++++++++++++++++++++ dimos/memory/store.py | 9 +++ dimos/memory/stream.py | 35 ++++++++- dimos/memory/types.py | 14 ++++ plans/memory/tasks.md | 20 ++--- 6 files changed, 288 insertions(+), 17 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 5aa582d78c..3046ecc4bf 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -53,6 +53,7 @@ EmbeddingObservation, EmbeddingSearchFilter, Filter, + LineageFilter, NearFilter, Observation, StreamInfo, @@ -148,9 +149,82 @@ def _compile_filter(f: Filter, table: str) -> tuple[str, list[Any]]: return "1=1", [] if isinstance(f, TextSearchFilter): return "1=1", [] + if isinstance(f, LineageFilter): + inner_sql, params = _compile_ids(f.source_query, f.source_table, select_col="parent_id") + for hop in f.hops: + inner_sql = f"SELECT parent_id FROM {hop} WHERE id IN ({inner_sql})" + return f"{table}.id IN ({inner_sql})", params raise TypeError(f"Unknown filter type: {type(f)}") +def _compile_ids( + query: StreamQuery, table: str, *, select_col: str = "id" +) -> tuple[str, list[Any]]: + """Compile a StreamQuery to ``SELECT {col} FROM {table} WHERE ...``. + + Unlike ``_compile_query``, this handles *all* filter types as SQL — including + EmbeddingSearchFilter and TextSearchFilter as inline subqueries — so that the + result can be nested inside another query (used by LineageFilter). + """ + where_parts: list[str] = [] + params: list[Any] = [] + joins: list[str] = [] + + for f in query.filters: + if isinstance(f, EmbeddingSearchFilter): + where_parts.append( + f"{table}.id IN (SELECT rowid FROM {table}_vec WHERE embedding MATCH ? AND k = ?)" + ) + params.extend([json.dumps(f.query), f.k]) + elif isinstance(f, TextSearchFilter): + fts_sub = f"SELECT rowid FROM {table}_fts WHERE content MATCH ?" 
+ fts_params: list[Any] = [f.text] + if f.k is not None: + fts_sub += " LIMIT ?" + fts_params.append(f.k) + where_parts.append(f"{table}.id IN ({fts_sub})") + params.extend(fts_params) + elif isinstance(f, NearFilter): + joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") + pose_parts = _decompose_pose(f.pose) + if pose_parts is not None: + x, y, z = pose_parts[0], pose_parts[1], pose_parts[2] + else: + x, y, z = 0.0, 0.0, 0.0 + where_parts.append( + "r.min_x >= ? AND r.max_x <= ? AND " + "r.min_y >= ? AND r.max_y <= ? AND " + "r.min_z >= ? AND r.max_z <= ?" + ) + params.extend( + [x - f.radius, x + f.radius, y - f.radius, y + f.radius, z - f.radius, z + f.radius] + ) + else: + # Simple filters + LineageFilter → delegate to _compile_filter + sql_frag, p = _compile_filter(f, table) + where_parts.append(sql_frag) + params.extend(p) + + where = " AND ".join(where_parts) if where_parts else "1=1" + join_clause = " ".join(joins) + + sql = f"SELECT {table}.{select_col} FROM {table}" + if join_clause: + sql += f" {join_clause}" + sql += f" WHERE {where}" + + if query.order_field: + sql += f" ORDER BY {query.order_field}" + if query.order_desc: + sql += " DESC" + if query.limit_val is not None: + sql += f" LIMIT {query.limit_val}" + if query.offset_val is not None: + sql += f" OFFSET {query.offset_val}" + + return sql, params + + def _has_near_filter(query: StreamQuery) -> NearFilter | None: for f in query.filters: if isinstance(f, NearFilter): @@ -659,6 +733,33 @@ def __init__(self, conn: sqlite3.Connection) -> None: self._streams: dict[str, Stream[Any]] = {} self._ensure_meta_table() + def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: + """Walk ``_streams.parent_stream`` from *source* toward *target*. + + Returns intermediate table names (empty tuple for direct parent). 
+ """ + current = source + intermediates: list[str] = [] + visited = {source} + + while True: + row = self._conn.execute( + "SELECT parent_stream FROM _streams WHERE name = ?", (current,) + ).fetchone() + if not row or not row[0]: + raise ValueError(f"No lineage path from {source!r} to {target!r}") + + parent_name: str = row[0] + if parent_name == target: + return tuple(intermediates) + + if parent_name in visited: + raise ValueError(f"Cycle detected in lineage chain at {parent_name!r}") + + visited.add(parent_name) + intermediates.append(parent_name) + current = parent_name + def _ensure_meta_table(self) -> None: self._conn.execute( "CREATE TABLE IF NOT EXISTS _streams (" diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 267736e83e..77a5949684 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -467,6 +467,132 @@ def test_iter(self, session: SqliteSession, images: list[Image]) -> None: assert all(_img_close(a, b) for a, b in zip(items, images[:3], strict=True)) +class TestProjectTo: + def test_single_hop(self, session: SqliteSession, images: list[Image]) -> None: + """project_to follows parent_id one hop: embeddings → images.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + results = [] + for img in imgs: + val = float(img.data.mean()) / 255.0 + results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + raise NotImplementedError + + imgs = session.stream("pt_images", Image) + imgs.append(images[0], ts=1.0) + imgs.append(images[1], ts=2.0) + imgs.append(images[2], ts=3.0) + + embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("pt_embs") + assert embs.count() == 3 + + # Search for top-2, then project back to images + results = 
embs.search_embedding([0.5, 0.5, 0.0], k=2).project_to(imgs) + projected = results.fetch() + assert len(projected) == 2 + # Projected observations have image data, not embeddings + for obs in projected: + assert ( + _img_close(obs.data, images[0]) + or _img_close(obs.data, images[1]) + or _img_close(obs.data, images[2]) + ) + + def test_project_to_with_chained_filter( + self, session: SqliteSession, images: list[Image] + ) -> None: + """project_to result is a chainable Stream — further filters work.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + results = [] + for img in imgs: + val = float(img.data.mean()) / 255.0 + results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + raise NotImplementedError + + imgs = session.stream("ptc_images", Image) + imgs.append(images[0], ts=1.0) + imgs.append(images[1], ts=5.0) + imgs.append(images[2], ts=10.0) + + embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptc_embs") + + # Project all embeddings to images, then filter by time + projected = embs.project_to(imgs).after(3.0) + results = projected.fetch() + # Only images with ts > 3.0 should remain + assert all(r.ts is not None and r.ts > 3.0 for r in results) + + def test_two_hop(self, session: SqliteSession, images: list[Image]) -> None: + """project_to walks multi-hop lineage: embeddings → filtered → raw.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + results = [] + for img in imgs: + val = float(img.data.mean()) / 255.0 + results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | 
list[Embedding]: + raise NotImplementedError + + raw = session.stream("th_raw", Image) + raw.append(images[0], ts=1.0) + raw.append(images[1], ts=2.0) + raw.append(images[2], ts=3.0) + + # Passthrough transform to create an intermediate stream + mid = raw.transform(lambda img: img).store("th_mid", Image) + assert mid.count() == 3 + + embs = mid.transform(EmbeddingTransformer(FakeEmbedder())).store("th_embs") + assert embs.count() == 3 + + # Two-hop: embeddings → mid → raw + projected = embs.search_embedding([0.5, 0.5, 0.0], k=2).project_to(raw) + results = projected.fetch() + assert len(results) == 2 + + def test_count_on_projected(self, session: SqliteSession, images: list[Image]) -> None: + """count() works on projected streams.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + results = [] + for img in imgs: + val = float(img.data.mean()) / 255.0 + results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + raise NotImplementedError + + imgs = session.stream("ptcnt_images", Image) + imgs.append(images[0], ts=1.0) + imgs.append(images[1], ts=2.0) + + embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptcnt_embs") + projected = embs.search_embedding([0.5, 0.5, 0.0], k=1).project_to(imgs) + assert projected.count() == 1 + + class TestStoreReopen: def test_data_persists(self, tmp_path: object, images: list[Image]) -> None: from pathlib import Path diff --git a/dimos/memory/store.py b/dimos/memory/store.py index a5037caa04..83505b27fe 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -75,6 +75,15 @@ def materialize_transform( ) -> Stream[Any]: """Create a stored stream from a transform pipeline.""" + @abstractmethod + def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: 
+ """Return intermediate tables in the parent_id chain from source to target. + + Single hop (source directly parents target) returns ``()``. + Two hops (source → mid → target) returns ``("mid",)``. + Raises ``ValueError`` if no lineage path exists. + """ + @abstractmethod def close(self) -> None: ... diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index f8ed0cae1f..75ce1a61ef 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -30,6 +30,7 @@ EmbeddingObservation, EmbeddingSearchFilter, Filter, + LineageFilter, NearFilter, Observation, StreamQuery, @@ -220,8 +221,30 @@ def store(self, name: str | None = None) -> Stream[T]: # ── Cross-stream lineage ────────────────────────────────────────── - def project_to(self, target: Stream[Any]) -> Stream[Any]: - raise NotImplementedError("project_to requires a stored stream with lineage") + def project_to(self, target: Stream[R]) -> Stream[R]: + """Follow parent_id lineage to project observations onto the target stream. + + Returns a filtered *target* Stream containing only observations that are + ancestors of the current (source) query results. The result is a normal + Stream — all chaining, pagination, and lazy loading work as usual. 
+ """ + backend = self._require_backend() + target_backend = target._require_backend() + session = self._session + if session is None: + raise TypeError("project_to requires a session-backed stream") + + source_table = backend.stream_name + target_table = target_backend.stream_name + hops = session.resolve_lineage_chain(source_table, target_table) + + return target._with_filter( + LineageFilter( + source_table=source_table, + source_query=self._query, + hops=hops, + ) + ) # ── Iteration ───────────────────────────────────────────────────── @@ -295,7 +318,9 @@ def search_embedding( vec = list(query) clone = self._with_filter(EmbeddingSearchFilter(vec, k)) # Preserve EmbeddingStream type - es: EmbeddingStream[T] = EmbeddingStream(backend=clone._backend, query=clone._query) + es: EmbeddingStream[T] = EmbeddingStream( + backend=clone._backend, query=clone._query, session=clone._session + ) return es def fetch(self) -> list[EmbeddingObservation]: # type: ignore[override] @@ -336,7 +361,9 @@ class TextStream(Stream[T]): def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: clone = self._with_filter(TextSearchFilter(text, k)) - ts: TextStream[T] = TextStream(backend=clone._backend, query=clone._query) + ts: TextStream[T] = TextStream( + backend=clone._backend, query=clone._query, session=clone._session + ) return ts diff --git a/dimos/memory/types.py b/dimos/memory/types.py index fb031314bf..0cca4917c7 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -135,6 +135,19 @@ class TextSearchFilter: k: int | None +@dataclass(frozen=True) +class LineageFilter: + """Filter to rows that are ancestors of observations in another stream. + + Used by ``project_to`` — compiles to a nested SQL subquery that walks the + ``parent_id`` chain from *source_table* through *hops* to the target. + """ + + source_table: str + source_query: StreamQuery + hops: tuple[str, ...] 
# intermediate tables between source and target + + Filter: TypeAlias = ( AfterFilter | BeforeFilter @@ -144,6 +157,7 @@ class TextSearchFilter: | TagsFilter | EmbeddingSearchFilter | TextSearchFilter + | LineageFilter ) diff --git a/plans/memory/tasks.md b/plans/memory/tasks.md index 0197924b78..82d2a4e964 100644 --- a/plans/memory/tasks.md +++ b/plans/memory/tasks.md @@ -28,7 +28,11 @@ The annotation is wrong for mypy though. Consider `-> Self` (from `typing_extens This is a prerequisite for `.join()` and stream-level lineage discovery. -### 4. Implement `.join()` — cross-stream lineage +### 4. ~~Implement `.project_to()` — cross-stream lineage~~ ✅ + +Implemented. `project_to(target)` adds a `LineageFilter` to the target stream (same `_with_filter` mechanism as `.after()`, `.near()`, etc.). The filter compiles to a SQL subquery walking the `parent_id` chain. Multi-hop lineage is resolved via `_streams.parent_stream` registry. Result is a fully chainable `Stream`. + +### 4b. Implement `.join()` — cross-stream lineage returning pairs `api.md` specifies: ```python @@ -36,19 +40,9 @@ for det, img in detections.after(t).join(images): print(f"Detected {det.data} in image at {img.pose}") ``` -Currently only a `project_to()` stub exists that raises `NotImplementedError`. - -**Design** (from `sqlite.md`): -```sql -SELECT c.*, p.* -FROM {self}_meta c -JOIN {target}_meta p ON c.parent_id = p.id -WHERE c.id IN (/* current filtered set */) -``` - -Both sides return `Observation` with lazy `.data`. Yields `tuple[Observation, Observation]`. +Unlike `project_to()` which returns a `Stream`, `join()` yields `tuple[Observation, Observation]` pairs. This is a terminal operation (not chainable) since the return type is pairs, not observations. -**Depends on**: Task 3 (parent_stream in registry) for discovering which stream to join against. 
Alternatively, `.join(target)` takes the target explicitly, so parent_stream metadata is nice-to-have but not strictly required — `parent_id` column is sufficient. +**Depends on**: ~~Task 3~~ Done — `parent_stream` is now written by `materialize_transform()` and read by `resolve_lineage_chain()`. ### 5. Filtered `.appended` — predicate-filtered reactive subscriptions From 41a06f0ded9728bc6cbef3f52d85888d46cfbc13 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 17:21:19 +0800 Subject: [PATCH 014/118] Make search_embedding auto-project to source stream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit EmbeddingStream is a semantic index — search results should be source observations (Images), not Embedding objects. search_embedding now auto-projects via project_to when lineage exists, falling back to EmbeddingStream for standalone streams without parent lineage. --- dimos/memory/impl/run_e2e_export.py | 2 +- dimos/memory/impl/sqlite.py | 6 ++ dimos/memory/impl/test_sqlite.py | 86 +++++++++++++++++++++++------ dimos/memory/store.py | 4 ++ dimos/memory/stream.py | 26 +++++++-- 5 files changed, 101 insertions(+), 23 deletions(-) diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py index 8191d8df56..9f1167980a 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/run_e2e_export.py @@ -89,7 +89,7 @@ slug = query_text.replace(" ", "_")[:30] for rank, result in enumerate(results): - # .data auto-projects to parent image via parent_id lineage + # search_embedding auto-projects to source images fname = OUT_DIR / f"{slug}_{rank + 1}_id{result.id}_ts{result.ts:.0f}.jpg" result.data.save(str(fname)) print(f" [{rank + 1}] id={result.id} ts={result.ts:.2f} → {fname.name}") diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 3046ecc4bf..d4a2e79e0a 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -733,6 +733,12 @@ def 
__init__(self, conn: sqlite3.Connection) -> None: self._streams: dict[str, Stream[Any]] = {} self._ensure_meta_table() + def resolve_parent_stream(self, name: str) -> str | None: + row = self._conn.execute( + "SELECT parent_stream FROM _streams WHERE name = ?", (name,) + ).fetchone() + return row[0] if row and row[0] else None + def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: """Walk ``_streams.parent_stream`` from *source* toward *target*. diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 77a5949684..b5ae14606c 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -307,7 +307,6 @@ class FakeEmbedder(EmbeddingModel): device = "cpu" def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - # Return a fake embedding based on mean pixel value results = [] for img in imgs: val = float(img.data.mean()) / 255.0 @@ -326,9 +325,10 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: emb_stream = s.transform(EmbeddingTransformer(FakeEmbedder())).store("cam_embeddings") assert emb_stream.count() == 2 - # Search by vector + # Search auto-projects to source images results = emb_stream.search_embedding([0.5, 0.5, 0.0, 0.0], k=1).fetch() assert len(results) == 1 + assert _img_close(results[0].data, images[0]) or _img_close(results[0].data, images[1]) class TestListStreams: @@ -468,8 +468,8 @@ def test_iter(self, session: SqliteSession, images: list[Image]) -> None: class TestProjectTo: - def test_single_hop(self, session: SqliteSession, images: list[Image]) -> None: - """project_to follows parent_id one hop: embeddings → images.""" + def test_search_auto_projects(self, session: SqliteSession, images: list[Image]) -> None: + """search_embedding auto-projects to source stream.""" class FakeEmbedder(EmbeddingModel): device = "cpu" @@ -492,22 +492,21 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: embs = 
imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("pt_embs") assert embs.count() == 3 - # Search for top-2, then project back to images - results = embs.search_embedding([0.5, 0.5, 0.0], k=2).project_to(imgs) - projected = results.fetch() + # search_embedding auto-projects — results are Images, not Embeddings + projected = embs.search_embedding([0.5, 0.5, 0.0], k=2).fetch() assert len(projected) == 2 - # Projected observations have image data, not embeddings for obs in projected: + assert not isinstance(obs, EmbeddingObservation) assert ( _img_close(obs.data, images[0]) or _img_close(obs.data, images[1]) or _img_close(obs.data, images[2]) ) - def test_project_to_with_chained_filter( + def test_search_auto_projects_chainable( self, session: SqliteSession, images: list[Image] ) -> None: - """project_to result is a chainable Stream — further filters work.""" + """Auto-projected search results support further chaining.""" class FakeEmbedder(EmbeddingModel): device = "cpu" @@ -529,14 +528,40 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptc_embs") - # Project all embeddings to images, then filter by time + # Chain time filter after auto-projected search + results = embs.search_embedding([0.5, 0.5, 0.0], k=10).after(3.0).fetch() + assert all(r.ts is not None and r.ts > 3.0 for r in results) + + def test_explicit_project_to(self, session: SqliteSession, images: list[Image]) -> None: + """Explicit project_to works for non-search cases.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + results = [] + for img in imgs: + val = float(img.data.mean()) / 255.0 + results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + raise 
NotImplementedError + + imgs = session.stream("pte_images", Image) + imgs.append(images[0], ts=1.0) + imgs.append(images[1], ts=5.0) + imgs.append(images[2], ts=10.0) + + embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("pte_embs") + + # Explicit project_to without search — project all embeddings to images projected = embs.project_to(imgs).after(3.0) results = projected.fetch() - # Only images with ts > 3.0 should remain assert all(r.ts is not None and r.ts > 3.0 for r in results) def test_two_hop(self, session: SqliteSession, images: list[Image]) -> None: - """project_to walks multi-hop lineage: embeddings → filtered → raw.""" + """search_embedding auto-projects to direct parent, then project_to for second hop.""" class FakeEmbedder(EmbeddingModel): device = "cpu" @@ -556,20 +581,19 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: raw.append(images[1], ts=2.0) raw.append(images[2], ts=3.0) - # Passthrough transform to create an intermediate stream mid = raw.transform(lambda img: img).store("th_mid", Image) assert mid.count() == 3 embs = mid.transform(EmbeddingTransformer(FakeEmbedder())).store("th_embs") assert embs.count() == 3 - # Two-hop: embeddings → mid → raw + # search auto-projects to mid (direct parent), then project_to(raw) for second hop projected = embs.search_embedding([0.5, 0.5, 0.0], k=2).project_to(raw) results = projected.fetch() assert len(results) == 2 def test_count_on_projected(self, session: SqliteSession, images: list[Image]) -> None: - """count() works on projected streams.""" + """count() works on auto-projected search results.""" class FakeEmbedder(EmbeddingModel): device = "cpu" @@ -589,8 +613,34 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: imgs.append(images[1], ts=2.0) embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptcnt_embs") - projected = embs.search_embedding([0.5, 0.5, 0.0], k=1).project_to(imgs) - assert projected.count() == 1 + assert 
embs.search_embedding([0.5, 0.5, 0.0], k=1).count() == 1 + + def test_project_to_plain_transform(self, session: SqliteSession, images: list[Image]) -> None: + """project_to on a non-embedding derived stream (e.g., detections → images).""" + imgs = session.stream("ptplain_images", Image) + imgs.append(images[0], ts=1.0) + imgs.append(images[1], ts=2.0) + imgs.append(images[2], ts=3.0) + + # Simulate a detection transform — extracts height as an "int" stream + heights = imgs.transform(lambda im: im.height).store("ptplain_heights", int) + assert heights.count() == 3 + + # Project heights back to source images + projected = heights.after(1.5).project_to(imgs) + results = projected.fetch() + assert len(results) == 2 # ts=2.0 and ts=3.0 + for obs in results: + assert _img_close(obs.data, images[1]) or _img_close(obs.data, images[2]) + + def test_no_lineage_fallback(self, session: SqliteSession) -> None: + """search_embedding without lineage returns EmbeddingStream (no projection).""" + es = session.embedding_stream("pt_standalone", vec_dimensions=3) + es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) + + results = es.search_embedding([1.0, 0.0, 0.0], k=1).fetch() + assert len(results) == 1 + assert isinstance(results[0], EmbeddingObservation) class TestStoreReopen: diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 83505b27fe..c86b344f06 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -75,6 +75,10 @@ def materialize_transform( ) -> Stream[Any]: """Create a stored stream from a transform pipeline.""" + @abstractmethod + def resolve_parent_stream(self, name: str) -> str | None: + """Return the direct parent stream name, or None if no lineage exists.""" + @abstractmethod def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: """Return intermediate tables in the parent_id chain from source to target. 
diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 75ce1a61ef..7ae93b56b1 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -236,6 +236,10 @@ def project_to(self, target: Stream[R]) -> Stream[R]: source_table = backend.stream_name target_table = target_backend.stream_name + + if source_table == target_table: + return self # type: ignore[return-value] + hops = session.resolve_lineage_chain(source_table, target_table) return target._with_filter( @@ -309,7 +313,12 @@ def search_embedding( query: Embedding | list[float], *, k: int, - ) -> EmbeddingStream[T]: + ) -> Stream[Any]: + """Search by vector similarity. + + Auto-projects to the source stream when lineage exists, so results + contain the source data (e.g. Images) rather than Embedding objects. + """ from dimos.models.embedding.base import Embedding as EmbeddingCls if isinstance(query, EmbeddingCls): @@ -317,11 +326,20 @@ def search_embedding( else: vec = list(query) clone = self._with_filter(EmbeddingSearchFilter(vec, k)) - # Preserve EmbeddingStream type - es: EmbeddingStream[T] = EmbeddingStream( + filtered: EmbeddingStream[T] = EmbeddingStream( backend=clone._backend, query=clone._query, session=clone._session ) - return es + + # Auto-project to source stream when lineage exists + session = filtered._session + backend = filtered._backend + if session is not None and backend is not None: + parent_name = session.resolve_parent_stream(backend.stream_name) + if parent_name is not None: + source = session.stream(parent_name) + return filtered.project_to(source) + + return filtered def fetch(self) -> list[EmbeddingObservation]: # type: ignore[override] backend = self._require_backend() From bce75869934b49abef229d7c6bf8f0c36eff5f60 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 17:32:04 +0800 Subject: [PATCH 015/118] CaptionTransformer + Florence2 batch fix - Add CaptionTransformer: wraps Captioner/VlModel, uses caption_batch() for backfill efficiency, 
auto-creates TextStream with FTS on .store() - Fix Florence2 caption_batch() emitting tokens (skip_special_tokens) - E2E script now uses transform pipeline for captioning search results --- dimos/memory/__init__.py | 2 ++ dimos/memory/impl/run_e2e_export.py | 27 ++++++++++++++++++++------- dimos/memory/impl/sqlite.py | 4 +++- dimos/memory/transformer.py | 29 +++++++++++++++++++++++++++++ dimos/models/vl/florence.py | 2 +- 5 files changed, 55 insertions(+), 9 deletions(-) diff --git a/dimos/memory/__init__.py b/dimos/memory/__init__.py index c8f0b8a336..0bd7bc4131 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory/__init__.py @@ -2,6 +2,7 @@ from dimos.memory.store import Session, Store from dimos.memory.stream import EmbeddingStream, Stream, TextStream from dimos.memory.transformer import ( + CaptionTransformer, EmbeddingTransformer, PerItemTransformer, Transformer, @@ -13,6 +14,7 @@ ) __all__ = [ + "CaptionTransformer", "Codec", "EmbeddingObservation", "EmbeddingStream", diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py index 9f1167980a..246f4f3dea 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/run_e2e_export.py @@ -23,8 +23,13 @@ from dimos.memory.impl.sqlite import SqliteStore from dimos.memory.ingest import ingest -from dimos.memory.transformer import EmbeddingTransformer, QualityWindowTransformer +from dimos.memory.transformer import ( + CaptionTransformer, + EmbeddingTransformer, + QualityWindowTransformer, +) from dimos.models.embedding.clip import CLIPModel +from dimos.models.vl.florence import Florence2Model from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.testing import TimedSensorReplay @@ -82,17 +87,25 @@ "large room", ] +print("\nLoading Florence2 for captioning...") +captioner = Florence2Model() +captioner.start() + +caption_xf = CaptionTransformer(captioner) + for query_text in queries: print(f"\nQuery: '{query_text}'") query_emb = clip.embed_text(query_text) - 
results = embeddings.search_embedding(query_emb, k=5).fetch() + search = embeddings.search_embedding(query_emb, k=5) + + captions = search.transform(caption_xf).fetch() + images = search.fetch() slug = query_text.replace(" ", "_")[:30] - for rank, result in enumerate(results): - # search_embedding auto-projects to source images - fname = OUT_DIR / f"{slug}_{rank + 1}_id{result.id}_ts{result.ts:.0f}.jpg" - result.data.save(str(fname)) - print(f" [{rank + 1}] id={result.id} ts={result.ts:.2f} → {fname.name}") + for rank, (cap, img) in enumerate(zip(captions, images, strict=False)): + fname = OUT_DIR / f"{slug}_{rank + 1}_id{img.id}_ts{img.ts:.0f}.jpg" + img.data.save(str(fname)) + print(f" [{rank + 1}] id={img.id} ts={img.ts:.2f} — {cap.data}") session.close() store.close() diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index d4a2e79e0a..8699f294dc 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -45,7 +45,7 @@ ) from dimos.memory.store import Session, Store from dimos.memory.stream import EmbeddingStream, Stream, TextStream -from dimos.memory.transformer import EmbeddingTransformer, Transformer +from dimos.memory.transformer import CaptionTransformer, EmbeddingTransformer, Transformer from dimos.memory.types import ( AfterFilter, AtFilter, @@ -893,6 +893,8 @@ def materialize_transform( target: Stream[Any] if isinstance(transformer, EmbeddingTransformer): target = self.embedding_stream(name, payload_type, parent_table=source_table) + elif isinstance(transformer, CaptionTransformer): + target = self.text_stream(name, payload_type) else: target = self.stream(name, payload_type) diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index d4559a7265..bb7489d608 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -21,6 +21,7 @@ from collections.abc import Callable from dimos.models.embedding.base import Embedding, EmbeddingModel + from dimos.models.vl.base import Captioner from 
.stream import Stream from .types import Observation @@ -153,6 +154,34 @@ def on_append(self, obs: Observation, target: Stream[T]) -> None: self._best_obs = obs +class CaptionTransformer(Transformer[Any, str]): + """Wraps a Captioner (or VlModel) to produce text captions from images. + + When stored, the output stream becomes a TextStream with FTS index. + Uses caption_batch() during backfill for efficiency. + """ + + supports_backfill: bool = True + supports_live: bool = True + + def __init__(self, model: Captioner) -> None: + self.model = model + self.output_type: type | None = str + + def process(self, source: Stream[Any], target: Stream[str]) -> None: + for page in source.fetch_pages(): + images = [obs.data for obs in page] + if not images: + continue + captions = self.model.caption_batch(*images) + for obs, cap in zip(page, captions, strict=True): + target.append(cap, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) + + def on_append(self, obs: Observation, target: Stream[str]) -> None: + caption = self.model.caption(obs.data) + target.append(caption, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) + + class EmbeddingTransformer(Transformer[Any, "Embedding"]): """Wraps an EmbeddingModel as a Transformer that produces Embedding output. 
diff --git a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py index 2e6cf822a8..a44267d620 100644 --- a/dimos/models/vl/florence.py +++ b/dimos/models/vl/florence.py @@ -136,7 +136,7 @@ def caption_batch(self, *images: Image) -> list[str]: ) # Decode all - generated_texts = self._processor.batch_decode(generated_ids, skip_special_tokens=False) + generated_texts = self._processor.batch_decode(generated_ids, skip_special_tokens=True) # Parse outputs captions = [] From 8ad469eabc5f27f3dbdcbcd1b6d87056c0423367 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 17:53:26 +0800 Subject: [PATCH 016/118] ObservationSet: fetch() returns list-like + stream-like result set fetch() now returns ObservationSet instead of plain list, keeping you in the Stream API. This enables fork-and-zip (one DB query, two uses) and in-memory re-filtering without re-querying the database. - Add matches(obs) to all filter dataclasses for in-Python evaluation - Add ListBackend (in-memory StreamBackend) and ObservationSet class - Filtered .appended reactive subscription via matches() infrastructure - Update e2e export script to use fork-and-zip pattern - 20 new tests (64 total, all passing) --- dimos/memory/__init__.py | 3 +- dimos/memory/impl/run_e2e_export.py | 8 +- dimos/memory/impl/test_sqlite.py | 203 +++++++++++++++++++++++++++- dimos/memory/stream.py | 195 ++++++++++++++++++++++---- dimos/memory/types.py | 33 +++++ 5 files changed, 406 insertions(+), 36 deletions(-) diff --git a/dimos/memory/__init__.py b/dimos/memory/__init__.py index 0bd7bc4131..132f23832b 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory/__init__.py @@ -1,6 +1,6 @@ from dimos.memory.codec import Codec, JpegCodec, LcmCodec, PickleCodec, codec_for_type from dimos.memory.store import Session, Store -from dimos.memory.stream import EmbeddingStream, Stream, TextStream +from dimos.memory.stream import EmbeddingStream, ObservationSet, Stream, TextStream from dimos.memory.transformer import ( 
CaptionTransformer, EmbeddingTransformer, @@ -22,6 +22,7 @@ "JpegCodec", "LcmCodec", "Observation", + "ObservationSet", "PerItemTransformer", "PickleCodec", "Session", diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py index 246f4f3dea..fb35cd944a 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/run_e2e_export.py @@ -96,13 +96,13 @@ for query_text in queries: print(f"\nQuery: '{query_text}'") query_emb = clip.embed_text(query_text) - search = embeddings.search_embedding(query_emb, k=5) - captions = search.transform(caption_xf).fetch() - images = search.fetch() + # Fork-and-zip: one DB query, two uses via ObservationSet + results = embeddings.search_embedding(query_emb, k=5).fetch() + captions = results.transform(caption_xf).fetch() slug = query_text.replace(" ", "_")[:30] - for rank, (cap, img) in enumerate(zip(captions, images, strict=False)): + for rank, (cap, img) in enumerate(zip(captions, results, strict=False)): fname = OUT_DIR / f"{slug}_{rank + 1}_id{img.id}_ts{img.ts:.0f}.jpg" img.data.save(str(fname)) print(f" [{rank + 1}] id={img.id} ts={img.ts:.2f} — {cap.data}") diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index b5ae14606c..1b337834d3 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -16,21 +16,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import numpy as np import pytest from dimos.memory.impl.sqlite import SqliteSession, SqliteStore from dimos.memory.transformer import EmbeddingTransformer -from dimos.memory.types import _UNSET, EmbeddingObservation +from dimos.memory.types import _UNSET, EmbeddingObservation, Observation from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.testing import TimedSensorReplay -if TYPE_CHECKING: - from dimos.memory.types import Observation - def _img_close(a: Image, b: Image, 
max_diff: float = 5.0) -> bool: """Approximate Image equality (JPEG is lossy).""" @@ -643,6 +638,202 @@ def test_no_lineage_fallback(self, session: SqliteSession) -> None: assert isinstance(results[0], EmbeddingObservation) +class TestObservationSet: + def test_fetch_returns_observation_set( + self, session: SqliteSession, images: list[Image] + ) -> None: + from dimos.memory.stream import ObservationSet + + s = session.stream("obs_set", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) + + result = s.fetch() + assert isinstance(result, ObservationSet) + assert len(result) == 2 + + def test_list_like_access(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("obs_list", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) + s.append(images[2], ts=3.0) + + result = s.fetch() + assert result[0].ts == 1.0 + assert result[-1].ts == 3.0 + assert len(result[1:]) == 2 + assert bool(result) is True + + def test_empty_observation_set(self, session: SqliteSession) -> None: + s = session.stream("obs_empty", Image) + result = s.fetch() + assert len(result) == 0 + assert bool(result) is False + + def test_iter(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("obs_iter", Image) + for i, img in enumerate(images[:3]): + s.append(img, ts=float(i)) + + result = s.fetch() + timestamps = [obs.ts for obs in result] + assert timestamps == [0.0, 1.0, 2.0] + + def test_refilter_in_memory(self, session: SqliteSession, images: list[Image]) -> None: + """ObservationSet supports chaining filters that re-evaluate in memory.""" + s = session.stream("obs_refilter", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=5.0) + s.append(images[2], ts=10.0) + + result = s.fetch() + assert len(result) == 3 + + # Re-filter in memory + recent = result.after(3.0).fetch() + assert len(recent) == 2 + assert all(r.ts is not None and r.ts > 3.0 for r in recent) + + def test_transform_on_observation_set( + 
self, session: SqliteSession, images: list[Image] + ) -> None: + """ObservationSet supports .transform() for fork-and-zip.""" + s = session.stream("obs_xf", Image) + s.append(images[0], ts=1.0) + s.append(images[1], ts=2.0) + + result = s.fetch() + shapes = result.transform(lambda im: f"{im.width}x{im.height}").fetch() + assert len(shapes) == 2 + assert shapes[0].data == f"{images[0].width}x{images[0].height}" + + def test_read_only(self, session: SqliteSession, images: list[Image]) -> None: + from dimos.memory.stream import ObservationSet + + result = ObservationSet([], session=session) + with pytest.raises(TypeError, match="read-only"): + result.append(images[0]) + + def test_ordering_in_memory(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("obs_order", Image) + s.append(images[0], ts=3.0) + s.append(images[1], ts=1.0) + s.append(images[2], ts=2.0) + + result = s.fetch() + ordered = result.order_by("ts").fetch() + assert [o.ts for o in ordered] == [1.0, 2.0, 3.0] + + desc = result.order_by("ts", desc=True).fetch() + assert [o.ts for o in desc] == [3.0, 2.0, 1.0] + + def test_limit_offset_in_memory(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("obs_lim", Image) + for i, img in enumerate(images): + s.append(img, ts=float(i)) + + result = s.fetch() + page = result.order_by("ts").limit(2).offset(1).fetch() + assert len(page) == 2 + assert [o.ts for o in page] == [1.0, 2.0] + + +class TestMatchesFilters: + def test_after_filter(self) -> None: + from dimos.memory.types import AfterFilter + + f = AfterFilter(5.0) + assert f.matches(Observation(id=1, ts=6.0)) is True + assert f.matches(Observation(id=2, ts=5.0)) is False + assert f.matches(Observation(id=3, ts=4.0)) is False + assert f.matches(Observation(id=4, ts=None)) is False + + def test_before_filter(self) -> None: + from dimos.memory.types import BeforeFilter + + f = BeforeFilter(5.0) + assert f.matches(Observation(id=1, ts=4.0)) is True + assert 
f.matches(Observation(id=2, ts=5.0)) is False + assert f.matches(Observation(id=3, ts=6.0)) is False + + def test_time_range_filter(self) -> None: + from dimos.memory.types import TimeRangeFilter + + f = TimeRangeFilter(2.0, 8.0) + assert f.matches(Observation(id=1, ts=5.0)) is True + assert f.matches(Observation(id=2, ts=2.0)) is True + assert f.matches(Observation(id=3, ts=8.0)) is True + assert f.matches(Observation(id=4, ts=1.0)) is False + assert f.matches(Observation(id=5, ts=9.0)) is False + + def test_at_filter(self) -> None: + from dimos.memory.types import AtFilter + + f = AtFilter(5.0, tolerance=1.0) + assert f.matches(Observation(id=1, ts=5.0)) is True + assert f.matches(Observation(id=2, ts=5.5)) is True + assert f.matches(Observation(id=3, ts=6.0)) is True + assert f.matches(Observation(id=4, ts=6.5)) is False + + def test_tags_filter(self) -> None: + from dimos.memory.types import TagsFilter + + f = TagsFilter({"cam": "front"}) + assert f.matches(Observation(id=1, tags={"cam": "front", "quality": "high"})) is True + assert f.matches(Observation(id=2, tags={"cam": "rear"})) is False + assert f.matches(Observation(id=3, tags={})) is False + + def test_text_search_filter(self) -> None: + from dimos.memory.types import TextSearchFilter + + f = TextSearchFilter("motor", k=None) + assert f.matches(Observation(id=1, _data="Motor fault on joint 3")) is True + assert f.matches(Observation(id=2, _data="Battery low")) is False + + def test_embedding_search_filter_always_true(self) -> None: + from dimos.memory.types import EmbeddingSearchFilter + + f = EmbeddingSearchFilter([1.0, 0.0], k=5) + assert f.matches(Observation(id=1)) is True + + def test_lineage_filter_raises(self) -> None: + from dimos.memory.types import LineageFilter, StreamQuery + + f = LineageFilter("src", StreamQuery(), ()) + with pytest.raises(NotImplementedError): + f.matches(Observation(id=1)) + + +class TestFilteredAppended: + def test_unfiltered_appended(self, session: SqliteSession, 
images: list[Image]) -> None: + s = session.stream("fa_unfilt", Image) + received: list[Observation] = [] + s.appended.subscribe(on_next=received.append) + + s.append(images[0], ts=1.0) + s.append(images[1], ts=5.0) + assert len(received) == 2 + + def test_filtered_appended(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("fa_filt", Image) + received: list[Observation] = [] + s.after(3.0).appended.subscribe(on_next=received.append) + + s.append(images[0], ts=1.0) # filtered out + s.append(images[1], ts=5.0) # passes + assert len(received) == 1 + assert received[0].ts == 5.0 + + def test_tag_filtered_appended(self, session: SqliteSession, images: list[Image]) -> None: + s = session.stream("fa_tag", Image) + received: list[Observation] = [] + s.filter_tags(cam="front").appended.subscribe(on_next=received.append) + + s.append(images[0], tags={"cam": "front"}) + s.append(images[1], tags={"cam": "rear"}) + assert len(received) == 1 + + class TestStoreReopen: def test_data_persists(self, tmp_path: object, images: list[Image]) -> None: from pathlib import Path diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 7ae93b56b1..49c938cfd8 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -23,6 +23,9 @@ overload, ) +import numpy as np +import reactivex.operators as ops + from .types import ( AfterFilter, AtFilter, @@ -258,9 +261,10 @@ def __iter__(self) -> Iterator[Observation]: # ── Terminals ───────────────────────────────────────────────────── - def fetch(self) -> list[Observation]: + def fetch(self) -> ObservationSet[T]: backend = self._require_backend() - return backend.execute_fetch(self._query) + results = backend.execute_fetch(self._query) + return ObservationSet(results, session=self._session) def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: offset = self._query.offset_val or 0 @@ -302,7 +306,19 @@ def count(self) -> int: @property def appended(self) -> Observable[Observation]: 
# type: ignore[type-arg] backend = self._require_backend() - return backend.appended_subject # type: ignore[return-value] + raw: Observable[Observation] = backend.appended_subject # type: ignore[assignment] + if not self._query.filters: + return raw + active = [ + f + for f in self._query.filters + if not isinstance(f, (EmbeddingSearchFilter, LineageFilter)) + ] + + def _check(o: Observation) -> bool: + return all(f.matches(o) for f in active) + + return raw.pipe(ops.filter(_check)) class EmbeddingStream(Stream[T]): @@ -341,34 +357,19 @@ def search_embedding( return filtered - def fetch(self) -> list[EmbeddingObservation]: # type: ignore[override] + def fetch(self) -> ObservationSet[T]: # type: ignore[override] backend = self._require_backend() - return backend.execute_fetch(self._query) # type: ignore[return-value] + results = backend.execute_fetch(self._query) + return ObservationSet(results, session=self._session) def one(self) -> EmbeddingObservation: # type: ignore[override] - q = StreamQuery( - filters=self._query.filters, - order_field=self._query.order_field, - order_desc=self._query.order_desc, - limit_val=1, - offset_val=self._query.offset_val, - ) - backend = self._require_backend() - results = backend.execute_fetch(q) + results = self.limit(1).fetch() if not results: raise LookupError("No matching observation") return results[0] # type: ignore[return-value] def last(self) -> EmbeddingObservation: # type: ignore[override] - q = StreamQuery( - filters=self._query.filters, - order_field="ts", - order_desc=True, - limit_val=1, - offset_val=self._query.offset_val, - ) - backend = self._require_backend() - results = backend.execute_fetch(q) + results = self.order_by("ts", desc=True).limit(1).fetch() if not results: raise LookupError("No matching observation") return results[0] # type: ignore[return-value] @@ -402,12 +403,12 @@ def __init__( self._live = live self._backfill_only = backfill_only - def fetch(self) -> list[Observation]: + def fetch(self) -> 
ObservationSet[R]: """Execute transform in memory, collecting results.""" collector = _CollectorStream[R]() if self._transformer.supports_backfill and not self._live: self._transformer.process(self._source, collector) - return collector.results + return ObservationSet(collector.results, session=self._source._session) def store( self, @@ -462,3 +463,147 @@ def append( self._next_id += 1 self.results.append(obs) return obs + + +class ListBackend: + """In-memory backend that evaluates StreamQuery filters in Python.""" + + def __init__(self, observations: list[Observation], name: str = "") -> None: + self._observations = observations + self._name = name + from reactivex.subject import Subject + + self._subject: Subject[Observation] = Subject() # type: ignore[type-arg] + + def execute_fetch(self, query: StreamQuery) -> list[Observation]: + results = list(self._observations) + + # Apply non-embedding filters + for f in query.filters: + if isinstance(f, (EmbeddingSearchFilter, LineageFilter)): + continue + results = [obs for obs in results if f.matches(obs)] + + # Embedding top-k pass (cosine similarity) + emb_filters = [f for f in query.filters if isinstance(f, EmbeddingSearchFilter)] + if emb_filters: + ef = emb_filters[0] + query_vec = np.array(ef.query, dtype=np.float32) + query_norm = np.linalg.norm(query_vec) + if query_norm > 0: + scored = [] + for obs in results: + if isinstance(obs, EmbeddingObservation): + obs_vec = obs.embedding.to_numpy() + else: + continue + obs_norm = np.linalg.norm(obs_vec) + if obs_norm > 0: + sim = float(np.dot(query_vec, obs_vec) / (query_norm * obs_norm)) + else: + sim = 0.0 + scored.append((sim, obs)) + scored.sort(key=lambda x: x[0], reverse=True) + results = [obs for _, obs in scored[: ef.k]] + + # Ordering + if query.order_field: + key = query.order_field + results.sort( + key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, + reverse=query.order_desc, + ) + + # Offset / limit + if query.offset_val: + 
results = results[query.offset_val :] + if query.limit_val is not None: + results = results[: query.limit_val] + + return results + + def execute_count(self, query: StreamQuery) -> int: + return len(self.execute_fetch(query)) + + def do_append( + self, + payload: Any, + ts: float | None, + pose: Any | None, + tags: dict[str, Any] | None, + parent_id: int | None = None, + ) -> Observation: + raise TypeError("ObservationSet is read-only") + + @property + def appended_subject(self) -> Subject[Observation]: # type: ignore[type-arg] + return self._subject + + @property + def stream_name(self) -> str: + return self._name + + +class ObservationSet(Stream[T]): + """Materialized result set — list-like + stream-like. + + Holds Observation objects with lazy _data_loader closures. + Metadata is in memory, payload BLOBs stay in DB until .data access. + """ + + def __init__( + self, + observations: list[Observation], + *, + session: Session | None = None, + ) -> None: + self._observations = observations + backend = ListBackend(observations) + super().__init__(backend=backend, session=session) + + def _clone(self, **overrides: Any) -> Stream[T]: + """Return a plain Stream backed by same ListBackend (preserves lazy filter chaining).""" + q = self._query + new_query = StreamQuery( + filters=overrides.get("filters", q.filters), + order_field=overrides.get("order_field", q.order_field), + order_desc=overrides.get("order_desc", q.order_desc), + limit_val=overrides.get("limit_val", q.limit_val), + offset_val=overrides.get("offset_val", q.offset_val), + ) + clone: Stream[T] = Stream.__new__(Stream) + clone._backend = self._backend + clone._query = new_query + clone._session = self._session + return clone + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: PoseLike | None = None, + tags: dict[str, Any] | None = None, + parent_id: int | None = None, + ) -> Observation: + raise TypeError("ObservationSet is read-only") + + # ── List-like interface 
────────────────────────────────────────── + + def __len__(self) -> int: + return len(self._observations) + + @overload + def __getitem__(self, index: int) -> Observation: ... + + @overload + def __getitem__(self, index: slice) -> list[Observation]: ... + + def __getitem__(self, index: int | slice) -> Observation | list[Observation]: + return self._observations[index] + + def __iter__(self) -> Iterator[Observation]: + return iter(self._observations) + + def __bool__(self) -> bool: + return len(self._observations) > 0 diff --git a/dimos/memory/types.py b/dimos/memory/types.py index 0cca4917c7..f1efbbd58f 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -16,6 +16,7 @@ from collections.abc import Callable from dataclasses import dataclass, field +import math from typing import TYPE_CHECKING, Any, TypeAlias if TYPE_CHECKING: @@ -94,46 +95,75 @@ class StreamInfo: class AfterFilter: t: float + def matches(self, obs: Observation) -> bool: + return obs.ts is not None and obs.ts > self.t + @dataclass(frozen=True) class BeforeFilter: t: float + def matches(self, obs: Observation) -> bool: + return obs.ts is not None and obs.ts < self.t + @dataclass(frozen=True) class TimeRangeFilter: t1: float t2: float + def matches(self, obs: Observation) -> bool: + return obs.ts is not None and self.t1 <= obs.ts <= self.t2 + @dataclass(frozen=True) class AtFilter: t: float tolerance: float + def matches(self, obs: Observation) -> bool: + return obs.ts is not None and abs(obs.ts - self.t) <= self.tolerance + @dataclass(frozen=True) class NearFilter: pose: Any # PoseLike radius: float + def matches(self, obs: Observation) -> bool: + if obs.pose is None: + return False + p1 = obs.pose.pose.position + p2 = self.pose.pose.position + dist = math.sqrt((p1.x - p2.x) ** 2 + (p1.y - p2.y) ** 2 + (p1.z - p2.z) ** 2) + return dist <= self.radius + @dataclass(frozen=True) class TagsFilter: tags: dict[str, Any] + def matches(self, obs: Observation) -> bool: + return all(obs.tags.get(k) 
== v for k, v in self.tags.items()) + @dataclass(frozen=True) class EmbeddingSearchFilter: query: list[float] k: int + def matches(self, obs: Observation) -> bool: + return True # top-k handled as special pass in ListBackend + @dataclass(frozen=True) class TextSearchFilter: text: str k: int | None + def matches(self, obs: Observation) -> bool: + return self.text.lower() in str(obs.data).lower() + @dataclass(frozen=True) class LineageFilter: @@ -147,6 +177,9 @@ class LineageFilter: source_query: StreamQuery hops: tuple[str, ...] # intermediate tables between source and target + def matches(self, obs: Observation) -> bool: + raise NotImplementedError("LineageFilter requires a database backend") + Filter: TypeAlias = ( AfterFilter From 216317991253f12e6826af873910fa4babdce81d Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 17:58:27 +0800 Subject: [PATCH 017/118] search_embedding accepts str/image with auto-embedding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit EmbeddingStream now holds an optional model reference, so search_embedding auto-dispatches: str → embed_text(), image → embed(), Embedding/list[float] → use directly. The model is wired through materialize_transform and also accepted via embedding_stream(). 
--- dimos/memory/impl/run_e2e_export.py | 7 ++-- dimos/memory/impl/sqlite.py | 12 +++++- dimos/memory/impl/test_sqlite.py | 63 +++++++++++++++++++++++++++++ dimos/memory/store.py | 3 ++ dimos/memory/stream.py | 56 +++++++++++++++++++++++-- 5 files changed, 131 insertions(+), 10 deletions(-) diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py index fb35cd944a..8efa5c2a6e 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/run_e2e_export.py @@ -74,7 +74,7 @@ clip = CLIPModel() clip.start() sharp = session.stream("sharp_frames") - embeddings = session.embedding_stream("clip_embeddings") + embeddings = session.embedding_stream("clip_embeddings", embedding_model=clip) print(f" {sharp.count()} sharp frames, {embeddings.count()} embeddings") # 4. Search and export @@ -95,10 +95,9 @@ for query_text in queries: print(f"\nQuery: '{query_text}'") - query_emb = clip.embed_text(query_text) - # Fork-and-zip: one DB query, two uses via ObservationSet - results = embeddings.search_embedding(query_emb, k=5).fetch() + # search_embedding auto-embeds text; ObservationSet enables fork-and-zip + results = embeddings.search_embedding(query_text, k=5).fetch() captions = results.transform(caption_xf).fetch() slug = query_text.replace(" ", "_")[:30] diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 8699f294dc..152c0f09ee 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -65,6 +65,7 @@ if TYPE_CHECKING: from dimos.memory.types import PoseProvider + from dimos.models.embedding.base import EmbeddingModel # ── Pose helpers (column-based) ────────────────────────────────────── @@ -840,9 +841,13 @@ def embedding_stream( vec_dimensions: int | None = None, pose_provider: PoseProvider | None = None, parent_table: str | None = None, + embedding_model: EmbeddingModel | None = None, ) -> EmbeddingStream[Any]: if name in self._streams: - return self._streams[name] # type: ignore[return-value] + 
existing = self._streams[name] + if embedding_model is not None and isinstance(existing, EmbeddingStream): + existing._embedding_model = embedding_model + return existing # type: ignore[return-value] if payload_type is None: payload_type = self._resolve_payload_type(name) @@ -862,7 +867,9 @@ def embedding_stream( if vec_dimensions is not None: backend._ensure_vec_table() - es: EmbeddingStream[Any] = EmbeddingStream(backend=backend, session=self) + es: EmbeddingStream[Any] = EmbeddingStream( + backend=backend, session=self, embedding_model=embedding_model + ) self._streams[name] = es return es @@ -893,6 +900,7 @@ def materialize_transform( target: Stream[Any] if isinstance(transformer, EmbeddingTransformer): target = self.embedding_stream(name, payload_type, parent_table=source_table) + target._embedding_model = transformer.model elif isinstance(transformer, CaptionTransformer): target = self.text_stream(name, payload_type) else: diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 1b337834d3..f9c8f65371 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -628,6 +628,69 @@ def test_project_to_plain_transform(self, session: SqliteSession, images: list[I for obs in results: assert _img_close(obs.data, images[1]) or _img_close(obs.data, images[2]) + def test_search_by_text(self, session: SqliteSession, images: list[Image]) -> None: + """search_embedding accepts a string and auto-embeds via model.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + results = [] + for img in imgs: + val = float(img.data.mean()) / 255.0 + results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + results = [] + for _text in texts: + results.append(Embedding(np.array([0.5, 0.5, 
0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + imgs = session.stream("pttxt_images", Image) + imgs.append(images[0], ts=1.0) + imgs.append(images[1], ts=2.0) + + embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("pttxt_embs") + + # Search with text string — auto-embeds via embed_text() + results = embs.search_embedding("a hallway", k=2).fetch() + assert len(results) == 2 + + def test_search_by_image(self, session: SqliteSession, images: list[Image]) -> None: + """search_embedding accepts an image and auto-embeds via model.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + results = [] + for img in imgs: + val = float(img.data.mean()) / 255.0 + results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + raise NotImplementedError + + imgs = session.stream("ptimg_images", Image) + imgs.append(images[0], ts=1.0) + imgs.append(images[1], ts=2.0) + + embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptimg_embs") + + # Search with image — auto-embeds via embed() + results = embs.search_embedding(images[0], k=1).fetch() + assert len(results) == 1 + + def test_search_no_model_raises(self, session: SqliteSession) -> None: + """search_embedding with str raises when no model is available.""" + es = session.embedding_stream("pt_nomodel", vec_dimensions=3) + es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) + + with pytest.raises(TypeError, match="no model reference"): + es.search_embedding("hello", k=1) + def test_no_lineage_fallback(self, session: SqliteSession) -> None: """search_embedding without lineage returns EmbeddingStream (no projection).""" es = session.embedding_stream("pt_standalone", vec_dimensions=3) diff --git 
a/dimos/memory/store.py b/dimos/memory/store.py index c86b344f06..327f5caa02 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from dimos.models.embedding.base import EmbeddingModel + from .stream import EmbeddingStream, Stream, TextStream from .transformer import Transformer from .types import PoseProvider, StreamInfo @@ -56,6 +58,7 @@ def embedding_stream( vec_dimensions: int | None = None, pose_provider: PoseProvider | None = None, parent_table: str | None = None, + embedding_model: EmbeddingModel | None = None, ) -> EmbeddingStream[Any]: """Get or create an embedding stream with vec0 index.""" diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 49c938cfd8..c63cfd9263 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -48,7 +48,7 @@ from reactivex import Observable from reactivex.subject import Subject - from dimos.models.embedding.base import Embedding + from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.geometry_msgs.Pose import PoseLike from .store import Session @@ -324,26 +324,74 @@ def _check(o: Observation) -> bool: class EmbeddingStream(Stream[T]): """Stream with a vector index. Adds search_embedding().""" + _embedding_model: EmbeddingModel | None + + def __init__( + self, + backend: StreamBackend | None = None, + *, + query: StreamQuery | None = None, + session: Session | None = None, + embedding_model: EmbeddingModel | None = None, + ) -> None: + super().__init__(backend=backend, query=query, session=session) + self._embedding_model = embedding_model + + def _require_model(self) -> EmbeddingModel: + if self._embedding_model is None: + raise TypeError( + "This embedding stream has no model reference. " + "Pass a str/image only on streams created via EmbeddingTransformer, " + "or search with a pre-computed Embedding / list[float]." 
+ ) + return self._embedding_model + + def _clone(self, **overrides: Any) -> Stream[T]: + clone = super()._clone(**overrides) + if isinstance(clone, EmbeddingStream): + clone._embedding_model = self._embedding_model + return clone + def search_embedding( self, - query: Embedding | list[float], + query: Embedding | list[float] | str | Any, *, k: int, ) -> Stream[Any]: """Search by vector similarity. + Accepts pre-computed embeddings, raw float lists, text strings, or + images/other objects. Text and non-vector inputs are auto-embedded + using the model that created this stream. + Auto-projects to the source stream when lineage exists, so results contain the source data (e.g. Images) rather than Embedding objects. """ from dimos.models.embedding.base import Embedding as EmbeddingCls + if isinstance(query, str): + emb = self._require_model().embed_text(query) + if isinstance(emb, list): + emb = emb[0] + return self.search_embedding(emb, k=k) + if isinstance(query, EmbeddingCls): vec = query.to_numpy().tolist() - else: + elif isinstance(query, list): vec = list(query) + else: + # Assume embeddable object (Image, etc.) 
+ emb = self._require_model().embed(query) + if isinstance(emb, list): + emb = emb[0] + return self.search_embedding(emb, k=k) + clone = self._with_filter(EmbeddingSearchFilter(vec, k)) filtered: EmbeddingStream[T] = EmbeddingStream( - backend=clone._backend, query=clone._query, session=clone._session + backend=clone._backend, + query=clone._query, + session=clone._session, + embedding_model=self._embedding_model, ) # Auto-project to source stream when lineage exists From 1bdd496c3be2cb9ed13bc1e465f8e926b9da2416 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 18:00:49 +0800 Subject: [PATCH 018/118] Add sqlite_vec to mypy ignore list (no type stubs available) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a448b90edd..52aeecfbae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -372,6 +372,7 @@ module = [ "rclpy.*", "sam2.*", "sensor_msgs.*", + "sqlite_vec", "std_msgs.*", "tf2_msgs.*", "torchreid", From 5130511a8aa92dac32e8f27d63bff8346e40e96d Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 18:38:47 +0800 Subject: [PATCH 019/118] Fix mypy + pytest errors across memory and memory_old modules - Fix SpatialImage/SpatialEntry dataclass hierarchy in memory_old - Fix import path in memory_old/test_embedding.py - Add None guard for obs.ts in run_viz_demo.py - Add payload_type/session kwargs to base Stream.store() signature - Type-annotate embeddings as EmbeddingStream in run_e2e_export.py - Add similarity scores, raw search mode, pose ingest, viz pipeline --- dimos/memory/impl/run_e2e_export.py | 42 ++++++-- dimos/memory/impl/run_viz_demo.py | 92 ++++++++++++++++++ dimos/memory/impl/sqlite.py | 58 +++++------ dimos/memory/impl/test_sqlite.py | 65 +++++++++++++ dimos/memory/ingest.py | 11 ++- dimos/memory/stream.py | 22 ++++- dimos/memory/types.py | 2 + dimos/memory/viz.py | 143 ++++++++++++++++++++++++++++ dimos/memory_old/embedding.py | 7 +- dimos/memory_old/test_embedding.py 
| 2 +- 10 files changed, 390 insertions(+), 54 deletions(-) create mode 100644 dimos/memory/impl/run_viz_demo.py create mode 100644 dimos/memory/viz.py diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py index 8efa5c2a6e..3747f8f14e 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/run_e2e_export.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Ingest 5min robot video → sharpness filter → CLIP embed → export top matches. +"""Ingest 5min robot video → sharpness filter → CLIP embed → search & visualize. Caches the DB — re-run to just search without re-ingesting/embedding. +Outputs heatmaps and timelines to Rerun, images to disk. """ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING, Any + +import rerun as rr from dimos.memory.impl.sqlite import SqliteStore from dimos.memory.ingest import ingest @@ -28,14 +32,20 @@ EmbeddingTransformer, QualityWindowTransformer, ) +from dimos.memory.viz import log_similarity_timeline, similarity_heatmap from dimos.models.embedding.clip import CLIPModel from dimos.models.vl.florence import Florence2Model from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.testing import TimedSensorReplay +if TYPE_CHECKING: + from dimos.memory.stream import EmbeddingStream + OUT_DIR = Path(__file__).parent / "e2e_matches" OUT_DIR.mkdir(exist_ok=True) +rr.init("memory_e2e", spawn=True) + db_path = OUT_DIR / "e2e.db" store = SqliteStore(str(db_path)) session = store.session() @@ -46,15 +56,16 @@ if need_build: replay = TimedSensorReplay("unitree_go2_bigoffice/video") + odom = TimedSensorReplay("unitree_go2_bigoffice/odom") print("Loading CLIP...") clip = CLIPModel() clip.start() - # 1. Ingest 5 minutes - print("Ingesting 5 min of video...") + # 1. 
Ingest 5 minutes with odom poses + print("Ingesting 5 min of video with odom poses...") raw = session.stream("raw_video", Image) - n = ingest(raw, replay.iterate_ts(seek=5.0, duration=300.0)) + n = ingest(raw, replay.iterate_ts(seek=5.0, duration=300.0), pose_source=odom) print(f" {n} frames ingested") # 2. Sharpness filter @@ -67,17 +78,21 @@ # 3. Embed print("Embedding with CLIP...") - embeddings = sharp.transform(EmbeddingTransformer(clip)).store("clip_embeddings") + embeddings: EmbeddingStream[Any] = sharp.transform(EmbeddingTransformer(clip)).store( + "clip_embeddings" + ) # type: ignore[assignment] print(f" {embeddings.count()} embeddings stored") else: print(f"Using cached DB ({db_path})") + print("loading Clip") clip = CLIPModel() clip.start() + print("done") sharp = session.stream("sharp_frames") embeddings = session.embedding_stream("clip_embeddings", embedding_model=clip) print(f" {sharp.count()} sharp frames, {embeddings.count()} embeddings") -# 4. Search and export +# 4. Search, visualize, export queries = [ "a hallway in an office", "a person standing", @@ -95,12 +110,23 @@ for query_text in queries: print(f"\nQuery: '{query_text}'") + slug = query_text.replace(" ", "_")[:30] - # search_embedding auto-embeds text; ObservationSet enables fork-and-zip + # raw=True: get EmbeddingObservation with .similarity and .pose + raw_results = embeddings.search_embedding(query_text, k=50, raw=True).fetch() + + # Spatial heatmap → Rerun + grid = similarity_heatmap(raw_results, resolution=0.5) + print(f" Heatmap: {grid}") + rr.log(f"world/{slug}/heatmap", grid.to_rerun(colormap="inferno")) + + # Temporal timeline → Rerun + log_similarity_timeline(raw_results, entity_path=f"plots/{slug}") + + # Caption top 5 (auto-projected results for image access) results = embeddings.search_embedding(query_text, k=5).fetch() captions = results.transform(caption_xf).fetch() - slug = query_text.replace(" ", "_")[:30] for rank, (cap, img) in enumerate(zip(captions, results, 
strict=False)): fname = OUT_DIR / f"{slug}_{rank + 1}_id{img.id}_ts{img.ts:.0f}.jpg" img.data.save(str(fname)) diff --git a/dimos/memory/impl/run_viz_demo.py b/dimos/memory/impl/run_viz_demo.py new file mode 100644 index 0000000000..1bd5d95cba --- /dev/null +++ b/dimos/memory/impl/run_viz_demo.py @@ -0,0 +1,92 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visual demo: similarity heatmap + timeline in Rerun. + +Run with: python -m dimos.memory.impl.run_viz_demo +Then open Rerun viewer to see the output. 
+""" + +from __future__ import annotations + +import numpy as np +import rerun as rr + +from dimos.memory.impl.sqlite import SqliteStore +from dimos.memory.types import EmbeddingObservation +from dimos.memory.viz import log_similarity_timeline, similarity_heatmap +from dimos.models.embedding.base import Embedding +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + +# ── Rerun setup ─────────────────────────────────────────────────────── +rr.init("memory_viz_demo", spawn=True) + +# ── Build a small DB with posed embeddings ──────────────────────────── +store = SqliteStore(":memory:") +session = store.session() + +es = session.embedding_stream("demo_emb", vec_dimensions=4) + +# Simulate a robot path with embeddings at various positions +np.random.seed(42) +n_obs = 60 +for i in range(n_obs): + angle = 2 * np.pi * i / n_obs + radius = 3.0 + 0.5 * np.sin(3 * angle) + x = radius * np.cos(angle) + y = radius * np.sin(angle) + + # Embedding: mix of two basis vectors depending on position + mix = (np.sin(angle) + 1) / 2 # 0..1 + vec = np.array([mix, 1.0 - mix, 0.1 * np.cos(angle), 0.0], dtype=np.float32) + vec /= np.linalg.norm(vec) + + pose = PoseStamped( + ts=float(i), + frame_id="world", + position=[x, y, 0.0], + orientation=[0.0, 0.0, 0.0, 1.0], + ) + es.append(Embedding(vec), ts=float(i), pose=pose) + +print(f"Created {es.count()} observations on a circular path") + +# ── Search and visualize ────────────────────────────────────────────── +query = [1.0, 0.0, 0.0, 0.0] +results = es.search_embedding(query, k=n_obs).fetch() + +print(f"Search returned {len(results)} results") +for obs in results[:5]: + assert isinstance(obs, EmbeddingObservation) + print(f" id={obs.id} ts={obs.ts:.0f} similarity={obs.similarity:.3f}") + +# 1. Similarity heatmap → OccupancyGrid → Rerun mesh +grid = similarity_heatmap(results, resolution=0.2, padding=2.0) +print(f"\nHeatmap: {grid}") +rr.log("world/heatmap", grid.to_rerun(colormap="inferno")) + +# 2. 
Similarity timeline → Rerun scalar plot +log_similarity_timeline(results, entity_path="plots/similarity") +print("Logged similarity timeline") + +# 3. Also log poses as arrows for spatial context +for obs in results: + if obs.pose is not None and obs.ts is not None: + rr.set_time("memory_time", timestamp=obs.ts) + rr.log("world/poses", obs.pose.to_rerun_arrow(length=0.3)) + +print("\nDone — check Rerun viewer") + +session.close() +store.close() diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 152c0f09ee..c2e6e75c71 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -72,12 +72,11 @@ def _decompose_pose(pose: Any) -> tuple[float, float, float, float, float, float, float] | None: - """Extract (x, y, z, qx, qy, qz, qw) from a PoseStamped or similar.""" + """Extract (x, y, z, qx, qy, qz, qw) from a PoseStamped.""" if pose is None: return None - # PoseStamped has .pose.position and .pose.orientation - p = pose.pose.position - q = pose.pose.orientation + p = pose.position + q = pose.orientation return (p.x, p.y, p.z, q.x, q.y, q.z, q.w) @@ -93,18 +92,11 @@ def _reconstruct_pose( """Rebuild a PoseStamped from column values.""" if x is None: return None - from dimos.msgs.geometry_msgs.Point import Point - from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - from dimos.msgs.geometry_msgs.Quaternion import Quaternion - from dimos.msgs.std_msgs.Header import Header return PoseStamped( - header=Header(), - pose=Pose( - position=Point(x=x, y=y or 0.0, z=z or 0.0), - orientation=Quaternion(x=qx or 0.0, y=qy or 0.0, z=qz or 0.0, w=qw or 1.0), - ), + position=[x, y or 0.0, z or 0.0], + orientation=[qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0], ) @@ -187,11 +179,8 @@ def _compile_ids( params.extend(fts_params) elif isinstance(f, NearFilter): joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") - pose_parts = _decompose_pose(f.pose) - if pose_parts is not None: - x, 
y, z = pose_parts[0], pose_parts[1], pose_parts[2] - else: - x, y, z = 0.0, 0.0, 0.0 + p = f.pose.position + x, y, z = p.x, p.y, p.z where_parts.append( "r.min_x >= ? AND r.max_x <= ? AND " "r.min_y >= ? AND r.max_y <= ? AND " @@ -250,11 +239,8 @@ def _compile_query(query: StreamQuery, table: str) -> tuple[str, list[Any]]: "r.min_y >= ? AND r.max_y <= ? AND " "r.min_z >= ? AND r.max_z <= ?" ) - pose_parts = _decompose_pose(f.pose) - if pose_parts is not None: - x, y, z = pose_parts[0], pose_parts[1], pose_parts[2] - else: - x, y, z = 0.0, 0.0, 0.0 + p = f.pose.position + x, y, z = p.x, p.y, p.z params.extend( [ x - f.radius, @@ -299,11 +285,8 @@ def _compile_count(query: StreamQuery, table: str) -> tuple[str, list[Any]]: for f in query.filters: if isinstance(f, NearFilter): joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") - pose_parts = _decompose_pose(f.pose) - if pose_parts is not None: - x, y, z = pose_parts[0], pose_parts[1], pose_parts[2] - else: - x, y, z = 0.0, 0.0, 0.0 + p = f.pose.position + x, y, z = p.x, p.y, p.z where_parts.append( "r.min_x >= ? AND r.max_x <= ? AND " "r.min_y >= ? AND r.max_y <= ? 
AND " @@ -338,16 +321,13 @@ def _compile_count(query: StreamQuery, table: str) -> tuple[str, list[Any]]: def _apply_near_post_filter(rows: list[Observation], near: NearFilter) -> list[Observation]: """Post-filter R*Tree candidates by exact Euclidean distance.""" - pose_parts = _decompose_pose(near.pose) - if pose_parts is None: - return [] - tx, ty, tz = pose_parts[0], pose_parts[1], pose_parts[2] + tp = near.pose.position result: list[Observation] = [] for obs in rows: if obs.pose is None: continue - op = obs.pose.pose.position - dist = ((op.x - tx) ** 2 + (op.y - ty) ** 2 + (op.z - tz) ** 2) ** 0.5 + op = obs.pose.position + dist = ((op.x - tp.x) ** 2 + (op.y - tp.y) ** 2 + (op.z - tp.z) ** 2) ** 0.5 if dist <= near.radius: result.append(obs) return result @@ -529,7 +509,7 @@ def _ensure_vec_table(self) -> None: return self._conn.execute( f"CREATE VIRTUAL TABLE IF NOT EXISTS {self._table}_vec " - f"USING vec0(embedding float[{self._vec_dimensions}])" + f"USING vec0(embedding float[{self._vec_dimensions}] distance_metric=cosine)" ) self._conn.commit() @@ -560,7 +540,8 @@ def _fetch_by_vector( if not vec_rows: return [] - rowids = [r[0] for r in vec_rows] + dist_map = {r[0]: r[1] for r in vec_rows} + rowids = list(dist_map.keys()) placeholders = ",".join("?" 
* len(rowids)) where_parts: list[str] = [f"{self._table}.id IN ({placeholders})"] @@ -582,6 +563,11 @@ def _fetch_by_vector( observations = [self._row_to_obs(r) for r in rows] + # Populate similarity scores from vec0 cosine distance (0=identical, 2=opposite) + for obs in observations: + if isinstance(obs, EmbeddingObservation): + obs.similarity = max(0.0, min(1.0, 1.0 - dist_map.get(obs.id, 0.0))) + near = _has_near_filter(query) if near is not None: observations = _apply_near_post_filter(observations, near) diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index f9c8f65371..42b94803e7 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -701,6 +701,71 @@ def test_no_lineage_fallback(self, session: SqliteSession) -> None: assert isinstance(results[0], EmbeddingObservation) +class TestSimilarityScores: + def test_search_populates_similarity(self, session: SqliteSession) -> None: + """search_embedding should populate .similarity on EmbeddingObservation.""" + es = session.embedding_stream("sim_test", vec_dimensions=4) + vecs = [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.9, 0.1, 0.0, 0.0], + ] + for i, v in enumerate(vecs): + es.append(Embedding(np.array(v, dtype=np.float32)), ts=float(i)) + + results = es.search_embedding([1.0, 0.0, 0.0, 0.0], k=3).fetch() + assert len(results) == 3 + for obs in results: + assert isinstance(obs, EmbeddingObservation) + assert obs.similarity is not None + assert 0.0 <= obs.similarity <= 1.0 + + # Exact match should have highest similarity + by_sim = sorted(results, key=lambda o: o.similarity, reverse=True) + assert by_sim[0].id == 1 # [1,0,0,0] is exact match + + def test_similarity_none_without_search(self, session: SqliteSession) -> None: + """Plain fetch() should leave similarity as None.""" + es = session.embedding_stream("sim_none", vec_dimensions=3) + es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) + + results = es.fetch() + 
assert len(results) == 1 + assert isinstance(results[0], EmbeddingObservation) + assert results[0].similarity is None + + def test_raw_returns_embedding_obs(self, session: SqliteSession, images: list[Image]) -> None: + """search_embedding(raw=True) returns EmbeddingObservation with similarity.""" + + class FakeEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + results = [] + for img in imgs: + val = float(img.data.mean()) / 255.0 + results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + raise NotImplementedError + + imgs = session.stream("sim_proj_imgs", Image) + imgs.append(images[0], ts=1.0) + imgs.append(images[1], ts=2.0) + + embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("sim_proj_embs") + + # raw=True: get raw EmbeddingObservation with similarity + results = embs.search_embedding([0.5, 0.5, 0.0], k=2, raw=True).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddingObservation) + assert obs.similarity is not None + # .data auto-projects to source Image via _source_data_loader + assert isinstance(obs.data, Image) + + class TestObservationSet: def test_fetch_returns_observation_set( self, session: SqliteSession, images: list[Image] diff --git a/dimos/memory/ingest.py b/dimos/memory/ingest.py index 5d96fc22b3..f0fd04263b 100644 --- a/dimos/memory/ingest.py +++ b/dimos/memory/ingest.py @@ -27,16 +27,25 @@ def ingest( stream: Stream[Any], source: Iterable[tuple[float, Any]], + *, + pose_source: Any | None = None, ) -> int: """Ingest (timestamp, payload) pairs into a stream. Accepts any iterable of ``(ts, data)`` — e.g. ``replay.iterate_ts(seek=5, duration=60)``. + Args: + pose_source: Optional replay with ``find_closest(ts)`` returning a pose + to attach to each frame (e.g. 
odom replay). + Returns: Number of items ingested. """ count = 0 for ts, payload in source: - stream.append(payload, ts=ts) + pose = None + if pose_source is not None: + pose = pose_source.find_closest(ts) + stream.append(payload, ts=ts, pose=pose) count += 1 return count diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index c63cfd9263..0ada16f81d 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -214,7 +214,12 @@ def transform( # ── Materialize ─────────────────────────────────────────────────── - def store(self, name: str | None = None) -> Stream[T]: + def store( + self, + name: str | None = None, + payload_type: type | None = None, + session: Session | None = None, + ) -> Stream[T]: # Already stored streams are a no-op if self._backend is not None and name is None: return self @@ -357,6 +362,7 @@ def search_embedding( query: Embedding | list[float] | str | Any, *, k: int, + raw: bool = False, ) -> Stream[Any]: """Search by vector similarity. @@ -364,8 +370,11 @@ def search_embedding( images/other objects. Text and non-vector inputs are auto-embedded using the model that created this stream. - Auto-projects to the source stream when lineage exists, so results - contain the source data (e.g. Images) rather than Embedding objects. + By default, auto-projects to the source stream so results contain the + source data (e.g. Images) rather than Embedding objects. Set + ``raw=True`` to skip auto-projection and get ``EmbeddingObservation`` + results with ``.similarity``, ``.pose``, ``.ts``, and ``.data`` + (auto-projected to parent via ``_source_data_loader``). 
""" from dimos.models.embedding.base import Embedding as EmbeddingCls @@ -373,7 +382,7 @@ def search_embedding( emb = self._require_model().embed_text(query) if isinstance(emb, list): emb = emb[0] - return self.search_embedding(emb, k=k) + return self.search_embedding(emb, k=k, raw=raw) if isinstance(query, EmbeddingCls): vec = query.to_numpy().tolist() @@ -384,7 +393,7 @@ def search_embedding( emb = self._require_model().embed(query) if isinstance(emb, list): emb = emb[0] - return self.search_embedding(emb, k=k) + return self.search_embedding(emb, k=k, raw=raw) clone = self._with_filter(EmbeddingSearchFilter(vec, k)) filtered: EmbeddingStream[T] = EmbeddingStream( @@ -394,6 +403,9 @@ def search_embedding( embedding_model=self._embedding_model, ) + if raw: + return filtered + # Auto-project to source stream when lineage exists session = filtered._session backend = filtered._backend diff --git a/dimos/memory/types.py b/dimos/memory/types.py index f1efbbd58f..1bab7b4af8 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -54,8 +54,10 @@ class EmbeddingObservation(Observation): .data auto-projects to the source stream's payload type. .embedding gives the Embedding vector. + .similarity is populated (0..1) when fetched via search_embedding (vec0 cosine). """ + similarity: float | None = field(default=None, repr=True) _embedding: Embedding | None = field(default=None, repr=False) _embedding_loader: Callable[[], Embedding] | None = field( default=None, repr=False, compare=False diff --git a/dimos/memory/viz.py b/dimos/memory/viz.py new file mode 100644 index 0000000000..599e3d8b8e --- /dev/null +++ b/dimos/memory/viz.py @@ -0,0 +1,143 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visualization helpers for Memory2 search results. + +Produces LCM-publishable messages (OccupancyGrid, PoseStamped) and +Rerun time-series plots from embedding search observations. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np + +if TYPE_CHECKING: + from dimos.memory.types import Observation + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid + + +def similarity_heatmap( + observations: list[Observation] | Any, + *, + resolution: float = 0.1, + padding: float = 1.0, + frame_id: str = "world", +) -> OccupancyGrid: + """Build an OccupancyGrid heatmap from observations with similarity scores. + + Each observation's pose maps to a grid cell; the cell value is + ``int(similarity * 100)`` (0-100 scale). Unknown cells stay at -1. + + Args: + observations: Iterable of Observation (must have .pose and .similarity). + resolution: Grid resolution in metres/cell. + padding: Extra metres around the bounding box. + frame_id: Coordinate frame for the grid. + + Returns: + OccupancyGrid publishable via LCMTransport. 
+ """ + from dimos.memory.types import EmbeddingObservation + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid as OG + + posed: list[tuple[float, float, float]] = [] + for obs in observations: + if obs.pose is None: + continue + sim = ( + obs.similarity + if isinstance(obs, EmbeddingObservation) and obs.similarity is not None + else 0.0 + ) + p = obs.pose.position + posed.append((p.x, p.y, sim)) + + if not posed: + return OG(width=1, height=1, resolution=resolution, frame_id=frame_id) + + xs = [p[0] for p in posed] + ys = [p[1] for p in posed] + + min_x = min(xs) - padding + min_y = min(ys) - padding + max_x = max(xs) + padding + max_y = max(ys) + padding + + width = max(1, int((max_x - min_x) / resolution) + 1) + height = max(1, int((max_y - min_y) / resolution) + 1) + + grid = np.full((height, width), -1, dtype=np.int8) + + for px, py, sim in posed: + gx = int((px - min_x) / resolution) + gy = int((py - min_y) / resolution) + gx = min(gx, width - 1) + gy = min(gy, height - 1) + val = int(sim * 100) + # Keep max similarity per cell + if grid[gy, gx] < val: + grid[gy, gx] = np.int8(val) + + origin = Pose( + position=[min_x, min_y, 0.0], + orientation=[0.0, 0.0, 0.0, 1.0], + ) + + return OG(grid=grid, resolution=resolution, origin=origin, frame_id=frame_id) + + +def similarity_poses(observations: list[Observation] | Any) -> list[PoseStamped]: + """Extract PoseStamped from observations for spatial arrow rendering. + + Args: + observations: Iterable of Observation with .pose. + + Returns: + List of PoseStamped suitable for LCMTransport publishing. + """ + result: list[PoseStamped] = [] + for obs in observations: + if obs.pose is not None: + result.append(obs.pose) + return result + + +def log_similarity_timeline( + observations: list[Observation] | Any, + entity_path: str = "memory/similarity", +) -> None: + """Log similarity scores as a Rerun time-series plot. 
+ + Each observation is logged at its timestamp with its similarity score. + Rerun auto-generates an interactive time-series graph in the timeline panel. + + Args: + observations: Iterable of EmbeddingObservation with .similarity and .ts. + entity_path: Rerun entity path for the scalar series. + """ + import rerun as rr + + from dimos.memory.types import EmbeddingObservation + + for obs in observations: + if not isinstance(obs, EmbeddingObservation): + continue + if obs.similarity is None or obs.ts is None: + continue + rr.set_time("memory_time", timestamp=obs.ts) + rr.log(entity_path, rr.Scalars(obs.similarity)) diff --git a/dimos/memory_old/embedding.py b/dimos/memory_old/embedding.py index a06c239cdf..758634eecc 100644 --- a/dimos/memory_old/embedding.py +++ b/dimos/memory_old/embedding.py @@ -38,6 +38,7 @@ class SpatialEntry(Timestamped): pose: PoseStamped +@dataclass class SpatialImage(SpatialEntry): image: Image @@ -87,13 +88,13 @@ def start(self) -> None: ops.map(self._store_spatial_entry), ).subscribe(print) - def _try_create_spatial_entry(self, img: Image) -> Observable[SpatialEntry]: + def _try_create_spatial_entry(self, img: Image) -> Observable[SpatialImage]: pose = self.tf.get_pose("world", "base_link") if not pose: return rx.empty() - return rx.of(SpatialEntry(image=img, pose=pose)) + return rx.of(SpatialImage(image=img, pose=pose)) - def _embed_spatial_entry(self, spatial_entry: SpatialEntry) -> SpatialEmbedding: + def _embed_spatial_entry(self, spatial_entry: SpatialImage) -> SpatialEmbedding: embedding = cast("Embedding", self.config.embedding_model.embed(spatial_entry.image)) return SpatialEmbedding( image=spatial_entry.image, diff --git a/dimos/memory_old/test_embedding.py b/dimos/memory_old/test_embedding.py index b7e7fbb294..5e8de6b3bf 100644 --- a/dimos/memory_old/test_embedding.py +++ b/dimos/memory_old/test_embedding.py @@ -14,7 +14,7 @@ import pytest -from dimos.memory.embedding import EmbeddingMemory, SpatialEntry +from 
dimos.memory_old.embedding import EmbeddingMemory, SpatialEntry from dimos.msgs.geometry_msgs import PoseStamped from dimos.utils.data import get_data from dimos.utils.testing import TimedSensorReplay From 219341f50a39e61dabfb245fda88aa048e6b1cc4 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 18:44:51 +0800 Subject: [PATCH 020/118] Improve similarity heatmap with normalized values and distance spread - Normalize similarity scores relative to min/max (CLIP clusters in narrow band) - Add distance_transform_edt spread so dots radiate outward, fading to 0 - Bump default search k to 200 for denser heatmaps --- dimos/memory/impl/run_e2e_export.py | 2 +- dimos/memory/viz.py | 58 +++++++++++++++++++++++------ 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py index 3747f8f14e..4c70531a16 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/run_e2e_export.py @@ -113,7 +113,7 @@ slug = query_text.replace(" ", "_")[:30] # raw=True: get EmbeddingObservation with .similarity and .pose - raw_results = embeddings.search_embedding(query_text, k=50, raw=True).fetch() + raw_results = embeddings.search_embedding(query_text, k=200, raw=True).fetch() # Spatial heatmap → Rerun grid = similarity_heatmap(raw_results, resolution=0.5) diff --git a/dimos/memory/viz.py b/dimos/memory/viz.py index 599e3d8b8e..70efe906df 100644 --- a/dimos/memory/viz.py +++ b/dimos/memory/viz.py @@ -35,22 +35,30 @@ def similarity_heatmap( *, resolution: float = 0.1, padding: float = 1.0, + spread: float = 2.0, frame_id: str = "world", ) -> OccupancyGrid: """Build an OccupancyGrid heatmap from observations with similarity scores. - Each observation's pose maps to a grid cell; the cell value is - ``int(similarity * 100)`` (0-100 scale). Unknown cells stay at -1. 
+ Similarity values are normalized relative to the result set's min/max + (so the full 0-100 color range is used even when CLIP similarities + cluster in a narrow band). Each dot's value spreads outward using + ``distance_transform_edt`` — the same technique as + :func:`dimos.mapping.occupancy.gradient.gradient` — fading to 0 at + *spread* metres. Args: observations: Iterable of Observation (must have .pose and .similarity). resolution: Grid resolution in metres/cell. padding: Extra metres around the bounding box. + spread: How far each dot's similarity radiates (metres). frame_id: Coordinate frame for the grid. Returns: OccupancyGrid publishable via LCMTransport. """ + from scipy.ndimage import distance_transform_edt + from dimos.memory.types import EmbeddingObservation from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid as OG @@ -81,17 +89,43 @@ def similarity_heatmap( width = max(1, int((max_x - min_x) / resolution) + 1) height = max(1, int((max_y - min_y) / resolution) + 1) - grid = np.full((height, width), -1, dtype=np.int8) + # Normalize similarities to 0-1 (CLIP similarities cluster in a narrow band) + sims = np.array([s for _, _, s in posed]) + sim_min, sim_max = float(sims.min()), float(sims.max()) + sim_range = sim_max - sim_min + sims_norm = (sims - sim_min) / sim_range if sim_range > 0 else np.full_like(sims, 0.5) + + # Stamp normalized values onto a float grid (0 = no observation) + value_grid = np.zeros((height, width), dtype=np.float32) + has_obs = np.zeros((height, width), dtype=bool) + + for (px, py, _), snorm in zip(posed, sims_norm, strict=False): + gx = min(int((px - min_x) / resolution), width - 1) + gy = min(int((py - min_y) / resolution), height - 1) + if snorm > value_grid[gy, gx]: + value_grid[gy, gx] = snorm + has_obs[gy, gx] = True + + # Distance transform: distance (in cells) from each empty cell to nearest dot + dist_cells = distance_transform_edt(~has_obs) + dist_metres = 
dist_cells * resolution - for px, py, sim in posed: - gx = int((px - min_x) / resolution) - gy = int((py - min_y) / resolution) - gx = min(gx, width - 1) - gy = min(gy, height - 1) - val = int(sim * 100) - # Keep max similarity per cell - if grid[gy, gx] < val: - grid[gy, gx] = np.int8(val) + # Fade factor: 1.0 at the dot, 0.0 at `spread` metres away + fade = np.clip(1.0 - dist_metres / spread, 0.0, 1.0) + + # For each cell, find the value of its nearest dot (via index output) + _, nearest_idx = distance_transform_edt(~has_obs, return_indices=True) + nearest_value = value_grid[nearest_idx[0], nearest_idx[1]] + + # Final heatmap = nearest dot's value * distance fade + heatmap = nearest_value * fade + + # Convert to int8 grid: observed region is 0-100, rest is -1 + grid = np.full((height, width), -1, dtype=np.int8) + active = heatmap > 0 + grid[active] = (heatmap[active] * 100).clip(0, 100).astype(np.int8) + # Ensure dot cells themselves are always visible + grid[has_obs] = (value_grid[has_obs] * 100).clip(1, 100).astype(np.int8) origin = Pose( position=[min_x, min_y, 0.0], From e7f3fcd9444d9fe918346b43a05f4b43e25d59ad Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 18:50:35 +0800 Subject: [PATCH 021/118] Remove plans/ from tracking (kept locally) --- plans/memory/api.md | 615 ---------------------- plans/memory/query_objects.md | 155 ------ plans/memory/questions.md | 56 -- plans/memory/sqlite.md | 780 --------------------------- plans/memory/tasks.md | 129 ----- plans/memory/transform.md | 180 ------- plans/old/analysis.md | 478 ----------------- plans/old/answers.md | 853 ------------------------------ plans/old/answers_correlator.md | 285 ---------- plans/old/correlator.md | 225 -------- plans/old/memory.md | 113 ---- plans/old/memory1.md | 318 ----------- plans/old/memory2.md | 898 -------------------------------- plans/old/memory3.md | 357 ------------- plans/old/memory3_answers.md | 67 --- plans/old/memory4.md | 466 ----------------- 
plans/old/transforms.md | 21 - 17 files changed, 5996 deletions(-) delete mode 100644 plans/memory/api.md delete mode 100644 plans/memory/query_objects.md delete mode 100644 plans/memory/questions.md delete mode 100644 plans/memory/sqlite.md delete mode 100644 plans/memory/tasks.md delete mode 100644 plans/memory/transform.md delete mode 100644 plans/old/analysis.md delete mode 100644 plans/old/answers.md delete mode 100644 plans/old/answers_correlator.md delete mode 100644 plans/old/correlator.md delete mode 100644 plans/old/memory.md delete mode 100644 plans/old/memory1.md delete mode 100644 plans/old/memory2.md delete mode 100644 plans/old/memory3.md delete mode 100644 plans/old/memory3_answers.md delete mode 100644 plans/old/memory4.md delete mode 100644 plans/old/transforms.md diff --git a/plans/memory/api.md b/plans/memory/api.md deleted file mode 100644 index 39cd7f738c..0000000000 --- a/plans/memory/api.md +++ /dev/null @@ -1,615 +0,0 @@ -# Memory2 API — Unified Stream - -## Core Idea - -One type: `Stream[T]`. Everything is a stream — stored, filtered, transformed. The user never thinks about Query vs ObservationSet vs Stream. They just chain operations. - -## Creating Streams - -```python -store = SqliteStore("/data/robot.db") -session = store.session() - -# Root stored stream — backed by DB -images = session.stream("images", Image, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -logs = session.text_stream("logs", str, - pose_provider=lambda: tf.get_pose("world", "base_link")) -``` - -## Writing - -```python -images.append(frame) # ts + pose auto-filled -logs.append("Motor fault on joint 3") # ts + pose auto-filled -images.append(frame, pose=explicit_pose, tags={"cam": "front"}) -``` - -Only meaningful on stored (DB-backed) streams. - -## Filtering - -Every filter returns a new `Stream[T]`. Lazy — nothing executes until a terminal. 
- -```python -recent = images.after(one_hour_ago) -kitchen = recent.near(kitchen_pose, 5.0) -tagged = kitchen.filter_tags(cam="front") - -# Or chained -images.after(one_hour_ago).near(kitchen_pose, 5.0).filter_tags(cam="front") -``` - -### Filter methods - -```python -class Stream(Generic[T]): - # Temporal - def after(self, t: float) -> Stream[T]: ... - def before(self, t: float) -> Stream[T]: ... - def time_range(self, t1: float, t2: float) -> Stream[T]: ... - def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... - - # Spatial - def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... - - # Tags - def filter_tags(self, **tags: Any) -> Stream[T]: ... - -class EmbeddingStream(Stream[T]): - def search_embedding(self, query: Embedding | list[float], *, k: int) -> EmbeddingStream[T]: ... - -class TextStream(Stream[T]): - def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... -``` - -## Terminals & Iteration - -`Stream` is directly iterable — pages internally, never loads everything at once. - -```python -# Direct iteration (lazy, memory-efficient — uses fetch_pages internally) -for row in images.after(t).near(kitchen_pose, 5.0): - print(row.data) - -# Explicit fetch when you want the full list in memory -all_rows = images.after(t).fetch() - -# Other terminals -row = images.after(t).one() # single best match -row = images.last() # most recent -n = images.after(t).count() # count without fetching - -# Pagination -page = images.order_by("ts").limit(50).offset(100).fetch() -``` - -### Terminal methods - -```python -class Stream(Generic[T]): - def __iter__(self) -> Iterator[Observation]: ... # lazy, pages internally - def fetch(self) -> list[Observation]: ... # all results in memory - def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... - def one(self) -> Observation: ... - def last(self) -> Observation: ... - def count(self) -> int: ... 
- def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... - def limit(self, k: int) -> Stream[T]: ... - def offset(self, n: int) -> Stream[T]: ... -``` - -## Observation - -```python -from dimos.models.embedding.base import Embedding, EmbeddingModel - -@dataclass -class Observation: - id: int - ts: float | None = None - pose: PoseStamped | None = None - tags: dict[str, Any] = field(default_factory=dict) - - @property - def data(self) -> Any: - """Lazy payload. Pre-populated from append/transform, fetched on demand from query.""" - ... -``` - -## Transformer - -A `Transformer` receives the full source stream and decides what to do — which items to process, how to batch, whether to use embeddings as a cheap proxy, etc. - -```python -class Transformer(ABC, Generic[T, R]): - """Transforms a source stream into results on a target stream.""" - - def process(self, source: Stream[T], target: Stream[R]) -> None: - """Batch/historical processing. Has full access to source — can query, - filter, use embeddings, batch, skip frames, etc.""" - ... - - def on_append(self, obs: Observation, target: Stream[R]) -> None: - """Reactive processing. Called per new item. Default: process([obs]).""" - ... - - supports_backfill: bool = True - supports_live: bool = True -``` - -### Simple lambdas (sugar) - -`Callable[[T], R | list[R] | None]` is auto-wrapped into a naive per-item Transformer: - -```python -# These are equivalent: -images.transform(lambda img: vlm.detect(img, "cigarettes")) -images.transform(PerItemTransformer(lambda img: vlm.detect(img, "cigarettes"))) -``` - -- `R` → single result -- `list[R]` → multiple results (e.g., multiple detections per frame) -- `None` → skip (no result for this input) - -### EmbeddingTransformer - -`EmbeddingTransformer` wraps an `EmbeddingModel` as a `Transformer[T, Embedding]`. When the output type is `Embedding`, `.store()` creates an `EmbeddingStream` (vec0 index, `search_embedding`, `EmbeddingObservation`). 
- -```python -# EmbeddingTransformer wraps the model -img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") - -# Now img_emb is an EmbeddingStream -results = img_emb.search_embedding(query_emb, k=20).fetch() -# results[0].data → Image (auto-projected from source) -# results[0].embedding → Embedding (supports @ for cosine similarity) -``` - -### Smart Transformer example - -Chains after an embedding transform — receives `EmbeddingObservation` with `.data` (Image) and `.embedding` (vector), so it can use similarity to skip irrelevant frames: - -```python -class CigaretteDetector(Transformer[EmbeddingObservation, Detection]): - def __init__(self, vlm, clip): - self.vlm = vlm - self.clip = clip - - def process(self, source: Stream[EmbeddingObservation], target: Stream[Detection]): - query = self.clip.embed_text("person smoking cigarette") - for page in source.fetch_pages(batch_size=16): - # Use embedding similarity as cheap proxy — skip distant frames - promising = [obs for obs in page if obs.embedding @ query > 0.3] - if not promising: - continue - detections = self.vlm.detect_batch( - [obs.data for obs in promising], "cigarettes" - ) - for obs, dets in zip(promising, detections): - for det in dets: - target.append(det, ts=obs.ts, pose=obs.pose) - - def on_append(self, obs: EmbeddingObservation, target: Stream[Detection]): - dets = self.vlm.detect(obs.data, "cigarettes") - for det in dets: - target.append(det, ts=obs.ts, pose=obs.pose) -``` - -### Chaining transforms - -```python -# Filter → transform → store -images.after(one_hour_ago) \ - .near(kitchen_pose, 5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .store("kitchen_embeddings") - -# Filter → transform → fetch (in-memory, not persisted) -results = images.after(one_hour_ago) \ - .near(kitchen_pose, 5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .fetch() - -# Filter → embed → detect → store (chained: detector gets EmbeddingObservation) -images.near(kitchen_pose, 
5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .transform(CigaretteDetector(vlm, clip)) \ - .store("kitchen_cigarette_detections") -``` - -### Backfill / Live modes - -```python -# Both (default): backfill existing + subscribe to new -images.transform(detector).store("detections") - -# Live only: skip backfill, only process new items -images.transform(detector, live=True).store("detections") - -# Backfill only: process existing, don't subscribe -images.transform(detector, backfill=True).store("detections") - -# Backfill only: process existing, and subscribe -images.transform(detector, backfill=True, live=True).store("detections") - -# Incremental: re-running a stored transform resumes from last processed item -# (uses lineage parent_id to skip already-processed source rows) -``` - -## Storing - -`.store(name)` materializes a stream to DB. After storing, results are queryable and persistent. - -```python -# In-memory transform result — not persisted -detections = images.transform(detect_fn) - -# Persist it -detections.store("detections") - -# Now it's a DB-backed stream, queryable -stored = session.stream("detections") -rows = stored.after(t).fetch() -``` - -`.store()` also sets up lineage — every stored row gets `parent_id` pointing back to its source. - -Stream type is determined by what the Transformer produces: -- `Embedding` output → `EmbeddingStream` (vec0 index) -- Everything else → `Stream` (blob) -- `TextStream` is created explicitly via `session.text_stream()` (not auto-detected) - -## Reactive - -```python -# .appended emits Observation with .data pre-populated -images.appended.subscribe(lambda row: print(f"New image at {row.pose}")) - -# Stored transforms propagate reactively by default -detections = images.transform(detect_fn).store("detections") -# Now every images.append(frame) → detect_fn runs → result stored in "detections" - -# Filtered appended — only kitchen images -images.near(kitchen_pose, 5.0).appended.subscribe(...) 
-``` - -## Join (cross-stream lineage) - -```python -# Join detections with their source images — returns tuples -for det, img in detections.after(t).join(images): - print(f"Detected {det.data} in image at {img.pose}") -``` - -## Full Example: Cigarette Detection Pipeline - -```python -session = SqliteStore("/data/robot.db").session() - -# Root stream -images = session.stream("images", Image, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -# Embedding index — EmbeddingModel is a Transformer -img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") - -# VLM detection pipeline (live-only, no backfill) -images.transform( - lambda img: vlm.detect(img, "people with cigarettes"), - live=True, -).store("cigarette_detections") - -# Smart detection — reuse existing embeddings, detector gets EmbeddingObservation -img_emb.near(kitchen_pose, 10.0) \ - .transform(CigaretteDetector(vlm, clip)) \ - .store("kitchen_cigarette_detections") - -# # Worse: re-embeds from scratch (redundant if img_emb already exists) -# images.near(kitchen_pose, 10.0) \ -# .transform(EmbeddingTransformer(CLIPModel())) \ -# .transform(CigaretteDetector(vlm, clip)) \ -# .store("kitchen_cigarette_detections") - -# --- Later, querying --- - -# "Where did we see people with cigarettes in the kitchen?" 
-for row in session.stream("cigarette_detections") \ - .after(one_hour_ago).near(kitchen_pose, 10.0): - print(f"t={row.ts} pose={row.pose}: {row.data}") - -# "Show me the source images alongside detections" -for det, img in session.stream("cigarette_detections") \ - .after(one_hour_ago).join(images): - print(f"Detection: {det.data}, Source image at {img.pose}") - -# "Find images similar to 'red shoes'" -query_emb = clip.embed_text("red shoes") -similar = img_emb.search_embedding(query_emb, k=20).fetch() -# similar[0].data → Image (auto-projected from source) -# similar[0].embedding → Embedding (supports @ for cosine similarity) -``` - -## Full API - -```python -from dimos.models.embedding.base import Embedding, EmbeddingModel - -# --- Data types --- - -@dataclass -class Observation: - id: int - ts: float | None = None - pose: PoseStamped | None = None - tags: dict[str, Any] = field(default_factory=dict) - - @property - def data(self) -> Any: - """Lazy payload. Pre-populated from append, fetched on demand from query.""" - ... - -@dataclass -class EmbeddedObservation(Observation): - """Returned by EmbeddingStream terminals. Auto-projects .data to source stream.""" - - @property - def data(self) -> Any: - """Lazily loads from the source stream (e.g., Image), not the embedding.""" - ... - - @property - def embedding(self) -> Embedding: - """The Embedding object (has .vector, supports @ for cosine similarity).""" - ... - -# --- Transformer --- - -class Transformer(ABC, Generic[T, R]): - """Transforms a source stream into results on a target stream.""" - - def process(self, source: Stream[T], target: Stream[R]) -> None: - """Batch/historical processing. Full access to source stream.""" - ... - - def on_append(self, obs: Observation, target: Stream[R]) -> None: - """Reactive processing. Called per new item.""" - ... 
- - supports_backfill: bool = True - supports_live: bool = True - -# --- Streams --- - -class Stream(Generic[T]): - # Write (DB-backed only) - def append(self, payload: T, *, - ts: float | None = None, - pose: PoseLike | None = None, - tags: dict[str, Any] | None = None, - ) -> Observation: ... - - # Filter (returns new Stream, lazy) - def after(self, t: float) -> Stream[T]: ... - def before(self, t: float) -> Stream[T]: ... - def time_range(self, t1: float, t2: float) -> Stream[T]: ... - def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... - def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... - def filter_tags(self, **tags: Any) -> Stream[T]: ... - - # Order / paginate - def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... - def limit(self, k: int) -> Stream[T]: ... - def offset(self, n: int) -> Stream[T]: ... - - # Transform - def transform(self, - xf: Transformer[T, R] | Callable[[T], R | list[R] | None], - *, live: bool = False, - backfill_only: bool = False, - ) -> Stream[R]: ... - - # Materialize (on TransformStream, accepts optional session= fallback) - def store(self, name: str | None = None, session: Session | None = None) -> Stream[T]: ... - - # Cross-stream (lineage join — returns tuples of (self_obs, target_obs)) - def join(self, target: Stream) -> Stream[tuple[Observation, Observation]]: ... - - # Iteration & Terminals - def __iter__(self) -> Iterator[Observation]: ... # lazy, pages internally - def fetch(self) -> list[Observation]: ... # all results in memory - def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... - def one(self) -> Observation: ... - def last(self) -> Observation: ... - def count(self) -> int: ... - - # Reactive - @property - def appended(self) -> Observable[Observation]: ... - -class EmbeddingStream(Stream[T]): - """Created automatically when a Transformer produces Embedding output. 
- Terminals return EmbeddedObservation (auto-projects .data to source stream).""" - def search_embedding(self, query: Embedding | list[float], *, k: int) -> EmbeddingStream[T]: ... - def fetch(self) -> list[EmbeddedObservation]: ... - def one(self) -> EmbeddedObservation: ... - def last(self) -> EmbeddedObservation: ... - -class TextStream(Stream[T]): - """Stream with FTS index.""" - def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... - -# --- Session / Store --- - -PoseProvider = Callable[[], PoseLike | None] - -class Session: - def stream(self, name: str, payload_type: type | None = None, *, - pose_provider: PoseProvider | None = None) -> Stream: ... - def text_stream(self, name: str, payload_type: type | None = None, *, - tokenizer: str = "unicode61", - pose_provider: PoseProvider | None = None) -> TextStream: ... - def embedding_stream(self, name: str, payload_type: type | None = None, *, - vec_dimensions: int | None = None, - pose_provider: PoseProvider | None = None, - parent_table: str | None = None) -> EmbeddingStream: ... - def materialize_transform(self, name: str, source: Stream, - transformer: Transformer, - *, live: bool = False, - backfill_only: bool = False) -> Stream: ... - def list_streams(self) -> list[StreamInfo]: ... - def close(self) -> None: ... - -class Store: - def session(self) -> Session: ... - def close(self) -> None: ... -``` - -## Internal Backing (impl detail) - -A `Stream` can be backed by different things — the user never sees this: - -- **DB tables** — from `session.stream()`. Metadata + payload + indexes. -- **Predicate** — from `.after()`, `.near()`, etc. Lazy SQL WHERE. -- **Transform** — from `.transform(t)`. Source stream + Transformer. - -The impl decides how to execute based on the backing chain. 
- -## SQLite Schema - -Each stream `{name}` creates these tables: - -```sql --- Metadata table (compact rows, fast scans) -CREATE TABLE {name} ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts REAL, - pose_x REAL, -- position - pose_y REAL, - pose_z REAL, - pose_qx REAL, -- orientation quaternion (stored, not indexed) - pose_qy REAL, - pose_qz REAL, - pose_qw REAL, - tags TEXT DEFAULT '{}', - parent_id INTEGER -- lineage: source observation id -); -CREATE INDEX idx_{name}_ts ON {name}(ts); - --- Payload table (blobs, loaded on demand) -CREATE TABLE {name}_payload ( - id INTEGER PRIMARY KEY, - data BLOB -); - --- R*Tree spatial index (position only) -CREATE VIRTUAL TABLE {name}_rtree USING rtree( - id, - min_x, max_x, - min_y, max_y, - min_z, max_z -); -``` - -**Optional per stream kind:** - -```sql --- TextStream: FTS5 full-text index -CREATE VIRTUAL TABLE {name}_fts USING fts5(content, tokenize='unicode61'); - --- EmbeddingStream: vec0 vector index -CREATE VIRTUAL TABLE {name}_vec USING vec0(embedding float[{dim}]); -``` - -### Key design decisions - -- **Separate payload table** — metadata queries (`fetch`, `count`, `near`, filters) never touch blob data. Payload is loaded lazily via `obs.data`. -- **Decomposed pose columns** — enables R*Tree spatial index for `.near()` queries. Orientation stored for reconstruction but not spatially indexed. -- **R*Tree for spatial queries** — `.near(pose, radius)` compiles to an R*Tree range query (bounding box at ±radius), not post-query Python filtering. - -### Lazy payload loading - -`fetch()` returns `Observation` with lazy `.data`: -- Metadata query: `SELECT id, ts, pose_x, ..., tags FROM {name} WHERE ...` -- `_data` stays `_UNSET`, `_data_loader` is set to: `SELECT data FROM {name}_payload WHERE id = ?` -- Only `obs.data` access triggers the blob read + codec decode - -This means iterating metadata (`obs.ts`, `obs.pose`, `obs.tags`) is cheap. 
- -### NearFilter SQL compilation - -```python -# .near(pose, 5.0) compiles to: -# JOIN {name}_rtree AS r ON r.id = {name}.id -# WHERE r.min_x >= pose.x - 5.0 AND r.max_x <= pose.x + 5.0 -# AND r.min_y >= pose.y - 5.0 AND r.max_y <= pose.y + 5.0 -# AND r.min_z >= pose.z - 5.0 AND r.max_z <= pose.z + 5.0 -``` - -For exact distance (not just bounding box), a post-filter computes Euclidean distance on the R*Tree candidates. - -## Serialization (Codec) - -Each stream has a `Codec[T]` that handles payload encode/decode. Auto-selected from `payload_type`. - -```python -class Codec(Protocol[T]): - def encode(self, value: T) -> bytes: ... - def decode(self, data: bytes) -> T: ... - -class LcmCodec(Codec[DimosMsg]): - """For DimosMsg types — uses lcm_encode/lcm_decode.""" - def __init__(self, msg_type: type[DimosMsg]) -> None: ... - -class PickleCodec(Codec[Any]): - """Fallback for arbitrary Python objects.""" - -def codec_for_type(payload_type: type[T] | None) -> Codec[T]: - """Auto-select codec based on payload type.""" - if payload_type is not None and issubclass(payload_type, DimosMsg): - return LcmCodec(payload_type) - return PickleCodec() -``` - -Lives in `dimos.memory.codec`. Detection uses `dimos.msgs.protocol.DimosMsg` (`runtime_checkable`). - -Transparent to the user — just pass `payload_type` to `session.stream()`: -```python -images = session.stream("images", Image) # auto LCM codec -numbers = session.stream("numbers", int) # auto pickle codec -``` - -Tags are JSON. Poses are decomposed into columns (not serialized). - -### Stream metadata (`_streams` table) - -``` -name TEXT PRIMARY KEY -payload_module TEXT -- fully qualified, e.g. 
"dimos.msgs.sensor_msgs.Image.Image" -stream_kind TEXT -- "stream" | "text" | "embedding" -parent_stream TEXT -- parent stream name (lineage for join()) -embedding_dim INTEGER -- vec0 dimension (embedding streams only) -``` - -On restart, `session.stream("images")` (no `payload_type`) resolves the class from `payload_module` via `importlib`, then selects the codec automatically. `embedding_dim` allows recreating the vec0 table without needing to see the first embedding again. - -## Implementation Notes - -- **No ORM** — raw `sqlite3` with direct SQL. The `Stream` filter chain *is* the query builder. -- **Session threading** — streams created by `session.stream()` get `_session` set. `TransformStream` inherits it from its source. `store()` also accepts an explicit `session=` fallback. - -## Resolved Questions - -1. **`.append()` on non-stored streams?** → `TypeError` (requires backend). -2. **Multiple `.store()` calls?** → Idempotent — returns existing stream if already stored. -3. ~~**Memory pressure from in-memory transforms?**~~ → Solved via `fetch_pages`. -4. **Pose storage** → Decomposed columns + R*Tree index (not binary blob). -5. **Payload loading** → Lazy via separate `{name}_payload` table. -6. **`__iter__`** → `for page in self.fetch_pages(): yield from page` — lazy, memory-efficient iteration. - -## Open Questions - -1. **`project_to` / lineage** — `parent_id` column exists but not yet wired. -2. **Incremental transforms** — re-running a stored transform should resume from last processed item. -3. **4D indexing** — should R*Tree include time as a 4th dimension? See `query_objects.md` for the Criterion/Score direction. 
diff --git a/plans/memory/query_objects.md b/plans/memory/query_objects.md deleted file mode 100644 index bf86d39675..0000000000 --- a/plans/memory/query_objects.md +++ /dev/null @@ -1,155 +0,0 @@ -# Query Objects — 4D Region + Soft Scoring System - -## Problem - -We need to query observations across 4 dimensions (x, y, z, t) plus embedding space. Current API has flat `filter_*` methods — works for simple cases but doesn't compose. We need: - -1. **Regions** — composable hard boundaries (include/exclude) -2. **Fields** — soft scoring that biases toward a point/time/embedding without hard cutoffs -3. A way to combine both in a single query - -## Key Insight - -Hard filters and soft biases are the same thing at different extremes: -- Hard filter = step function (1 inside, 0 outside) -- Soft bias = smooth decay (gaussian, linear, etc.) - -A unified **Criterion** type handles both. Each criterion maps an observation to a score in `[0, 1]`. Hard filters are just criteria with score `{0, 1}`. 
- -## Primitives - -### Temporal - -```python -# Hard boundaries -TimeRange(t1, t2) # 1 inside, 0 outside -Before(t) # sugar for TimeRange(-inf, t) -After(t) # sugar for TimeRange(t, inf) - -# Soft — score decays with distance from target -TimeProximity(target, sigma=60.0) # gaussian: exp(-dt²/2σ²) -``` - -### Spatial - -```python -# Hard boundaries -Sphere(center: PoseLike, radius: float) # 1 inside, 0 outside -Box(min: PoseLike, max: PoseLike) # axis-aligned bounding box -HeightRange(z_min, z_max) # horizontal slice - -# Soft -SpatialProximity(point: PoseLike, sigma=5.0) # gaussian in 3D -``` - -### Embedding - -```python -# Soft only (no hard boundary in embedding space makes sense) -EmbeddingSimilarity(vector, candidate_k=100) # cosine similarity, top-k pre-filter -``` - -### Tags - -```python -TagMatch(robot_id="robot1") # hard: exact match on tag values -``` - -## Composition - -Criteria compose via set operators: - -```python -# Intersection — all criteria must score > 0 -region = TimeRange(t1, t2) & Sphere(point, 5.0) - -# Union — any criterion scoring > 0 passes -region = Sphere(p1, 3.0) | Sphere(p2, 3.0) - -# Complement -region = ~TimeRange(t1, t2) # everything outside this window -``` - -For soft criteria, composition combines scores: -- `a & b` → `min(a.score, b.score)` (conservative) -- `a | b` → `max(a.score, b.score)` (permissive) - -## Weighted Scoring - -The interesting problem: "I care about embedding similarity, temporal proximity, AND spatial proximity" — but as soft preferences, not hard cutoffs. - -```python -Score( - time=TimeProximity(target_t, sigma=60), - space=SpatialProximity(point, sigma=5.0), - embedding=EmbeddingSimilarity(vector, candidate_k=200), - weights={"time": 0.3, "space": 0.3, "embedding": 0.4} -) -``` - -Each dimension produces a `[0, 1]` score. Final score = weighted sum. This replaces the vague `rank(**weights)` in the current API. 
- -## Integration with Query - -```python -# Current flat API (still works, sugar for simple cases) -q.after(t).near(point, 5.0).search_embedding(vec, candidate_k=100) - -# Region object approach -region = After(t) & Sphere(point, 5.0) -q.where(region).search_embedding(vec, candidate_k=100) - -# Full soft scoring — no hard boundaries, just preferences -q.score( - time=TimeProximity(target_t, sigma=120), - space=SpatialProximity(point, sigma=10.0), - embedding=EmbeddingSimilarity(vec, candidate_k=500), -).limit(20) - -# Mixed — hard boundary + soft ranking within -q.where(TimeRange(t1, t2)).score( - space=SpatialProximity(point, sigma=5.0), - embedding=EmbeddingSimilarity(vec, candidate_k=200), -).limit(10) -``` - -## SQL Mapping (SQLite impl) - -How each primitive maps to SQL: - -| Criterion | SQL Strategy | -|--------------------------|-------------------------------------------------------| -| `TimeRange(t1, t2)` | `WHERE ts BETWEEN ? AND ?` (B-tree) | -| `Before(t)` / `After(t)` | `WHERE ts < ?` / `WHERE ts > ?` | -| `Sphere(p, r)` | R*Tree range query on `_rtree` | -| `HeightRange(z1, z2)` | `WHERE pose_z BETWEEN ? AND ?` | -| `Box(min, max)` | R*Tree range query | -| `TimeProximity(t, σ)` | `ORDER BY ABS(ts - ?) ASC` or compute score in SELECT | -| `SpatialProximity(p, σ)` | R*Tree range (pre-filter at ~3σ) + score in SELECT | -| `EmbeddingSimilarity` | sqlite-vec `MATCH` → temp table | -| `TagMatch` | `WHERE json_extract(tags, ?) = ?` | - -Soft scoring strategy: **generous hard pre-filter in SQL, then score in Python**. 
-- Each soft criterion auto-generates a hard pre-filter at ~3σ (captures 99.7% of relevant results) -- `TimeProximity(t, σ=60)` → SQL: `WHERE ts BETWEEN t-180 AND t+180` (B-tree) -- `SpatialProximity(p, σ=5)` → SQL: R*Tree range query with 15m box -- `EmbeddingSimilarity` → sqlite-vec `MATCH` top-k (already a pre-filter) -- Python computes `[0, 1]` scores on the pre-filtered set, applies weights, sorts - -This keeps SQL simple (range queries on indexes) and Python handles the math. - -## Open Questions - -2. **How does `Score` interact with `search_embedding`?** Embedding search already returns ranked results from vec0. Should `Score.embedding` just re-weight those scores, or does it need a separate search pass? - -3. **Region objects as first-class types?** Do we store/serialize regions (e.g., "the kitchen region" as a reusable spatial boundary)? Or are they always constructed in code? - -4. **Do we need `NOT` regions for exclusion zones?** E.g., "everywhere except within 2m of the charging station." `~Sphere(charger, 2.0)` — complement on spatial regions requires scanning all of `_meta`, can't use R*Tree efficiently. - -5. **Gradient fields?** "Prefer observations taken at higher elevation" — not proximity to a point but a directional preference. `HeightGradient(ascending=True)` as a scorer? - -## Priority - -- **Phase 1**: Keep the flat `filter_*` / `rank()` API. Implement primitives internally. -- **Phase 2**: Expose `Criterion` objects + `where()` + `score()` as the composable API. -- **Phase 3**: Region persistence, named regions, gradient fields. diff --git a/plans/memory/questions.md b/plans/memory/questions.md deleted file mode 100644 index bc91b9f306..0000000000 --- a/plans/memory/questions.md +++ /dev/null @@ -1,56 +0,0 @@ -# Questions - -1. "where was I when this log line was added?" -- pose lookup, corelating to log lines found -- assume log line has a pose associated -- assume there are multiple log lines matching a search - -2. 
"how long have I been observing the red socks currently in view?" -- how many times did I see them before? -- temporal duration tracking + observation frequency - -3. "how many people did I see during last week?" -- assume we are generating a facial recognition db — is this matching a face detection stream, then embeddings? then we are searching over that stream? - -4. "where did you see red socks during last week?" -- we query for red socks embedding similarity, then feed this data into a VLM that further filters for socks -- is this data output into some table? is it like an ObservationSet again? -- then we can create a map (costmap) of red socks? - -5. "did anyone ever open this door? at what times did I see this door open? who opened it?" -- event detection + temporal querying of state changes - -6. "I have a transcription log (STT) and voice embeddings, how do I figure out who is saying what?" -- cross-stream correlation: audio → identity - -7. "I have parallel voice and facial recognition streams, how do I correlate voice to people?" -- I don't see all people speaking at all times -- multi-modal fusion with incomplete overlap - -8. "what's different in this room compared to yesterday?" -- comparing scene snapshots across time, diffing object sets -- requires baseline modeling / temporal comparison - -9. "show me everywhere the cat went today" -- continuous spatial tracking over time, not point queries -- dense pose-stream retrieval + path aggregation - -10. "what happened in the 30 seconds before the vase fell?" -- event-anchored temporal window across all streams -- multi-stream temporal slicing relative to a detected event - -11. "when was the last time I did NOT see the cat in the apartment?" -- negation query — finding gaps in an observation stream -- architecturally different from presence queries - -12. "what time does the mailman usually come?" 
-- aggregation across days, extracting temporal regularity from sparse events -- cross-session pattern extraction - -13. "what did robot-2 observe in the warehouse that I missed?" -- cross-agent memory diff -- session/robot-scoped queries and set difference across streams - -14. "how far did I travel while carrying an object?" -- filtered pose integration — only accumulate distance when a parallel detection stream has a positive signal -- cross-stream conditional joins diff --git a/plans/memory/sqlite.md b/plans/memory/sqlite.md deleted file mode 100644 index 0439b7defa..0000000000 --- a/plans/memory/sqlite.md +++ /dev/null @@ -1,780 +0,0 @@ -# SQLite Implementation - -Implementation spec for `dimos/memory/impl/sqlite/`. A coding agent should be able to implement the full SQLite backend from this document + `api.md`. - -## File Structure - -``` -dimos/memory/ - __init__.py # public exports: Observation, EmbeddingObservation, - # Stream, EmbeddingStream, TextStream, Transformer, - # EmbeddingTransformer, PerItemTransformer, Session, Store - types.py # Observation, EmbeddingObservation, StreamInfo - stream.py # Stream, EmbeddingStream, TextStream (base classes) - transformer.py # Transformer ABC, EmbeddingTransformer, PerItemTransformer - store.py # Store ABC - session.py # Session ABC - - impl/ - sqlite/ - __init__.py # exports SqliteStore - store.py # SqliteStore - session.py # SqliteSession - stream.py # SqliteStream, SqliteEmbeddingStream, SqliteTextStream - query.py # FilterChain — accumulates predicates, generates SQL - _sql.py # SQL helpers, identifier validation, pose helpers, serialization -``` - -## Dependencies - -- `sqlite3` (stdlib) -- `sqlite-vec` — vector similarity search via vec0 virtual table. Optional — `search_embedding` raises if unavailable. -- FTS5 — built into SQLite by default on most platforms. -- R*Tree — built into SQLite by default. -- `reactivex` — for `.appended` observable (already a DimOS dependency). 
- -## Connection Management - -### SqliteStore - -```python -class SqliteStore(Store): - def __init__(self, path: str): - self.path = path # or ":memory:" - - def session(self) -> SqliteSession: - conn = self._connect() - return SqliteSession(conn) - - def _connect(self) -> sqlite3.Connection: - if self.path == ":memory:": - uri = "file::memory:?cache=shared" - conn = sqlite3.connect(uri, uri=True) - else: - Path(self.path).parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(self.path) - - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - conn.execute("PRAGMA foreign_keys=ON") - - # Try loading sqlite-vec - try: - conn.enable_load_extension(True) - conn.load_extension("vec0") # or find via sqlite_vec.loadable_path() - conn.enable_load_extension(False) - except Exception: - pass # vec0 unavailable — search_embedding will raise - - return conn - - def close(self) -> None: ... -``` - -### SqliteSession - -```python -class SqliteSession(Session): - def __init__(self, conn: sqlite3.Connection): - self._conn = conn - self._streams: dict[str, SqliteStream] = {} # cache by name - self._ensure_registry() - - def _ensure_registry(self): - """Create _streams table if not exists.""" - self._conn.execute(""" - CREATE TABLE IF NOT EXISTS _streams ( - rowid INTEGER PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - type TEXT NOT NULL, - payload_type TEXT, - parent_stream TEXT, - embedding_dim INTEGER - ) - """) - - def stream(self, name, payload_type=None, *, pose_provider=None) -> SqliteStream: - if name in self._streams: - return self._streams[name] - self._register_stream(name, "blob", payload_type) - self._create_stream_tables(name, stream_type="blob") - s = SqliteStream(name, self._conn, payload_type, pose_provider) - self._streams[name] = s - return s - - def text_stream(self, name, payload_type=None, *, tokenizer="unicode61", - pose_provider=None) -> SqliteTextStream: - # Similar — creates FTS tables too - ... 
- - def list_streams(self) -> list[StreamInfo]: ... - def close(self) -> None: self._conn.close() -``` - -## Schema - -All table names are prefixed with the stream name. Stream names are validated: `[a-zA-Z_][a-zA-Z0-9_]*`, max 64 chars. - -### `_streams` — Global registry - -```sql -CREATE TABLE _streams ( - rowid INTEGER PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - type TEXT NOT NULL, -- 'blob', 'embedding', 'text' - payload_type TEXT, -- e.g. 'dimos.msgs.sensor_msgs.Image' - parent_stream TEXT, -- FK name of parent stream (lineage) - embedding_dim INTEGER -- only for type='embedding' -); -``` - -### `{name}_meta` — Observation metadata (all stream types) - -```sql -CREATE TABLE {name}_meta ( - rowid INTEGER PRIMARY KEY, -- = Observation.id - ts REAL, - pose_x REAL, pose_y REAL, pose_z REAL, - pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, - tags TEXT, -- JSON dict, NULL if empty - parent_rowid INTEGER -- lineage: rowid in parent stream's _meta -); -CREATE INDEX idx_{name}_ts ON {name}_meta(ts); -``` - -### `{name}_payload` — Blob/Text payload (not EmbeddingStream) - -```sql -CREATE TABLE {name}_payload ( - rowid INTEGER PRIMARY KEY, -- matches _meta.rowid - data BLOB NOT NULL -- TextStream: TEXT instead of BLOB -); -``` - -Separated from `_meta` so metadata queries never page in multi-MB blobs. - -### `{name}_rtree` — Spatial index (all stream types) - -```sql -CREATE VIRTUAL TABLE {name}_rtree USING rtree( - rowid, -- matches _meta.rowid - min_x, max_x, - min_y, max_y, - min_z, max_z -); -``` - -Only rows with pose are inserted into R*Tree. Rows without pose are excluded from `.near()` results. - -### `{name}_fts` — Full-text search (TextStream only) - -```sql -CREATE VIRTUAL TABLE {name}_fts USING fts5( - content, - tokenize='{tokenizer}' -); -``` - -Standalone FTS table (not content-synced). Rowids match `_meta.rowid`. 
- -### `{name}_vec` — Vector index (EmbeddingStream only) - -```sql -CREATE VIRTUAL TABLE {name}_vec USING vec0( - embedding float[{dim}] -); -``` - -Rowids match `_meta.rowid`. Dimension inferred from first embedding inserted, or from `EmbeddingModel.embed()` output. - -## Stream Implementation - -### SqliteStream (implements Stream[T]) - -Internally, a stream object can be in different modes: - -```python -@dataclass -class StoredBacking: - """Root DB-backed stream. Created by session.stream().""" - name: str - -@dataclass -class FilteredBacking: - """Lazy predicate chain. Created by .after(), .near(), etc.""" - parent: StreamBacking # recursive — can chain filters - predicates: list[Predicate] - ordering: list[OrderClause] - limit_val: int | None - offset_val: int | None - -@dataclass -class TransformBacking: - """Unevaluated transform. Created by .transform().""" - source: StreamBacking - transformer: Transformer - live: bool - backfill_only: bool - -Backing = StoredBacking | FilteredBacking | TransformBacking -``` - -The stream carries its backing and resolves it at terminal time. - -### append() - -Only valid on `StoredBacking`. Otherwise raises `TypeError`. - -```python -def append(self, payload, *, ts=None, pose=None, tags=None): - if not isinstance(self._backing, StoredBacking): - raise TypeError("append() only valid on stored streams") - - ts = ts or time.time() - pose = pose or (self._pose_provider() if self._pose_provider else None) - - # 1. Insert into _meta - meta_rowid = self._insert_meta(ts, pose, tags, parent_rowid=None) - - # 2. Insert into _payload - blob = serialize(payload) # see Serialization section - self._conn.execute( - f"INSERT INTO {name}_payload(rowid, data) VALUES (?, ?)", - (meta_rowid, blob) - ) - - # 3. 
Insert into _rtree (if pose) - if pose: - x, y, z = extract_position(pose) - self._conn.execute( - f"INSERT INTO {name}_rtree(rowid, min_x, max_x, min_y, max_y, min_z, max_z) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (meta_rowid, x, x, y, y, z, z) - ) - - self._conn.commit() - - # 4. Build Observation and emit - obs = Observation(id=meta_rowid, ts=ts, pose=pose, tags=tags or {}) - obs._data = payload # pre-populated - self._appended_subject.on_next(obs) - return obs -``` - -### EmbeddingStream.append() - -Same as above but inserts into `_vec` instead of `_payload`: - -```python -# Insert embedding vector -vec_data = embedding.to_numpy().tobytes() -self._conn.execute( - f"INSERT INTO {name}_vec(rowid, embedding) VALUES (?, ?)", - (meta_rowid, vec_data) -) -``` - -### TextStream.append() - -Inserts into both `_payload` (TEXT) and `_fts`: - -```python -self._conn.execute( - f"INSERT INTO {name}_payload(rowid, data) VALUES (?, ?)", - (meta_rowid, text_content) -) -self._conn.execute( - f"INSERT INTO {name}_fts(rowid, content) VALUES (?, ?)", - (meta_rowid, text_content) -) -``` - -## Filter → SQL Generation - -Each filter method returns a new stream with a `FilteredBacking` wrapping the current backing. At terminal time, the filter chain is compiled to SQL. - -### Predicate types - -```python -@dataclass -class AfterPred: - t: float - # → WHERE ts > ? - -@dataclass -class BeforePred: - t: float - # → WHERE ts < ? - -@dataclass -class TimeRangePred: - t1: float - t2: float - # → WHERE ts BETWEEN ? AND ? - -@dataclass -class AtPred: - t: float - tolerance: float - # → WHERE ts BETWEEN ? AND ? ORDER BY ABS(ts - ?) LIMIT 1 - -@dataclass -class NearPred: - x: float - y: float - z: float - radius: float - # → JOIN with _rtree bounding box query - -@dataclass -class TagsPred: - tags: dict[str, Any] - # → WHERE json_extract(tags, '$.key') = ? 
- -@dataclass -class TextSearchPred: - text: str - k: int | None - # → JOIN with _fts MATCH - -@dataclass -class EmbeddingSearchPred: - vector: list[float] - k: int - # → query _vec for top-k, then filter -``` - -### SQL compilation - -Walk the backing chain to the root `StoredBacking`, collect all predicates, then generate SQL: - -```python -def _compile(self) -> tuple[str, list[Any]]: - """Walk backing chain, return (sql, params).""" - root_name = self._find_root_name() - predicates = self._collect_predicates() - ordering = self._collect_ordering() - limit = self._collect_limit() - offset = self._collect_offset() - - # Start with base SELECT - sql = f"SELECT rowid, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags FROM {root_name}_meta" - params = [] - joins = [] - wheres = [] - - for pred in predicates: - if isinstance(pred, AfterPred): - wheres.append("ts > ?") - params.append(pred.t) - elif isinstance(pred, NearPred): - joins.append( - f"JOIN {root_name}_rtree r ON r.rowid = {root_name}_meta.rowid" - ) - wheres.append( - "r.min_x >= ? AND r.max_x <= ? AND " - "r.min_y >= ? AND r.max_y <= ? AND " - "r.min_z >= ? AND r.max_z <= ?" - ) - params.extend([ - pred.x - pred.radius, pred.x + pred.radius, - pred.y - pred.radius, pred.y + pred.radius, - pred.z - pred.radius, pred.z + pred.radius, - ]) - elif isinstance(pred, TagsPred): - for key, val in pred.tags.items(): - wheres.append(f"json_extract(tags, '$.{key}') = ?") - params.append(val) - # ... etc - - sql += " " + " ".join(joins) - if wheres: - sql += " WHERE " + " AND ".join(wheres) - if ordering: - sql += " ORDER BY " + ", ".join(ordering) - if limit is not None: - sql += " LIMIT ?" - params.append(limit) - if offset is not None: - sql += " OFFSET ?" - params.append(offset) - - return sql, params -``` - -### search_embedding (vec0) - -```sql --- Top-k vector search -SELECT rowid, distance -FROM {name}_vec -WHERE embedding MATCH ? - AND k = ? 
-ORDER BY distance -``` - -Returns rowids, which are then used to filter `_meta`. This is a two-step process: -1. Get top-k rowids from vec0 -2. Fetch metadata for those rowids - -### search_text (FTS5) - -```sql -SELECT rowid, rank -FROM {name}_fts -WHERE {name}_fts MATCH ? -ORDER BY rank -``` - -Same two-step: get rowids from FTS5, then fetch metadata. - -## Terminal Execution - -### __iter__() — lazy iteration - -`Stream` is directly iterable. Pages internally via `fetch_pages`, yielding one `Observation` at a time: - -```python -def __iter__(self) -> Iterator[Observation]: - for page in self.fetch_pages(): - yield from page -``` - -### fetch() - -```python -def fetch(self) -> list[Observation]: - sql, params = self._compile() - rows = self._conn.execute(sql, params).fetchall() - return [self._row_to_observation(row) for row in rows] -``` - -### fetch_pages() - -```python -def fetch_pages(self, batch_size=128) -> Iterator[list[Observation]]: - sql, params = self._compile() - # Add LIMIT/OFFSET pagination - offset = 0 - while True: - page_sql = sql + f" LIMIT {batch_size} OFFSET {offset}" - rows = self._conn.execute(page_sql, params).fetchall() - if not rows: - break - yield [self._row_to_observation(row) for row in rows] - offset += batch_size -``` - -### count() - -```python -def count(self) -> int: - sql, params = self._compile() - count_sql = f"SELECT COUNT(*) FROM ({sql})" - return self._conn.execute(count_sql, params).fetchone()[0] -``` - -### one() / last() - -- `one()` → adds `LIMIT 1` to the query -- `last()` → adds `ORDER BY ts DESC LIMIT 1` - -## Lazy Data Loading - -`Observation.data` uses lazy loading. 
The implementation: - -```python -@dataclass -class Observation: - id: int - ts: float | None = None - pose: PoseStamped | None = None - tags: dict[str, Any] = field(default_factory=dict) - _data: Any = field(default=_SENTINEL, repr=False) - _load: Callable[[], Any] | None = field(default=None, repr=False) - - @property - def data(self) -> Any: - if self._data is _SENTINEL and self._load is not None: - self._data = self._load() - return self._data -``` - -When building observations from query results: - -```python -def _row_to_observation(self, row) -> Observation: - rowid = row[0] - obs = Observation( - id=rowid, - ts=row[1], - pose=reconstruct_pose(row[2:9]), - tags=json.loads(row[9]) if row[9] else {}, - ) - name = self._root_name() - conn = self._conn - obs._load = lambda: deserialize( - conn.execute(f"SELECT data FROM {name}_payload WHERE rowid = ?", (rowid,)).fetchone()[0] - ) - return obs -``` - -### EmbeddingObservation - -For `EmbeddingStream`, terminals return `EmbeddingObservation` which auto-projects `.data` to the source stream: - -```python -def _row_to_embedding_observation(self, row) -> EmbeddingObservation: - rowid = row[0] - parent_stream = self._get_parent_stream_name() - obs = EmbeddingObservation(id=rowid, ts=row[1], ...) 
- - # .data loads from PARENT stream (auto-projection) - obs._load = lambda: deserialize( - conn.execute( - f"SELECT data FROM {parent_stream}_payload WHERE rowid = ?", - (conn.execute( - f"SELECT parent_rowid FROM {self._name}_meta WHERE rowid = ?", - (rowid,) - ).fetchone()[0],) - ).fetchone()[0] - ) - - # .embedding loads from _vec - obs._embedding_load = lambda: Embedding( - np.frombuffer( - conn.execute( - f"SELECT embedding FROM {self._name}_vec WHERE rowid = ?", - (rowid,) - ).fetchone()[0], - dtype=np.float32 - ) - ) - return obs -``` - -## Lineage & join - -### Storing lineage - -When a Transformer appends to a target stream, `parent_rowid` links back to the source: - -```python -# Inside Transformer execution -target.append(result, ts=source_obs.ts, pose=source_obs.pose, - _parent_rowid=source_obs.id) # internal param -``` - -The `_streams` registry tracks stream-level lineage: -```python -# When .store() creates from a transform -INSERT INTO _streams (name, type, payload_type, parent_stream) -VALUES ('detections', 'blob', '...', 'images') -``` - -### join() - -Returns tuples of `(self_obs, target_obs)` linked by lineage: - -```sql --- Join self with target via parent_rowid -SELECT - c.rowid, c.ts, c.pose_x, ..., -- self (e.g., detections) - p.rowid, p.ts, p.pose_x, ... -- target (e.g., images) -FROM {self}_meta c -JOIN {target}_meta p ON c.parent_rowid = p.rowid -WHERE c.rowid IN (/* current filtered set */) -``` - -Iteration yields `tuple[Observation, Observation]` — both sides have lazy `.data`. - -## Transform Execution - -### .transform() — returns lazy stream - -`.transform(xf)` doesn't execute immediately. It returns a new stream with `TransformBacking`. Execution happens at terminal time or `.store()`. - -### .store() — materializes - -When `.store(name)` is called on a transform-backed stream: - -1. Register target stream in `_streams` (with `parent_stream` set) -2. Create target tables (`_meta`, `_payload`, etc.) -3. 
If not `live` mode: run `xf.process(source_stream, target_stream)` (backfill) -4. If not `backfill_only`: subscribe to source's `.appended` observable, call `xf.on_append()` for each new item -5. Return the stored stream (now `StoredBacking`) - -```python -def store(self, name): - if not isinstance(self._backing, TransformBacking): - # Already stored or predicate-backed — different path - ... - - tb = self._backing - # Create target stream - target = self._session._create_stream(name, ...) - - # Register lineage - self._session._register_lineage(name, parent_stream=source_name) - - # Backfill - if not tb.live and tb.transformer.supports_backfill: - source_stream = self._resolve_source() - tb.transformer.process(source_stream, target) - - # Live subscription - if not tb.backfill_only and tb.transformer.supports_live: - source_stream = self._resolve_source() - source_stream.appended.subscribe( - lambda obs: tb.transformer.on_append(obs, target) - ) - - return target -``` - -### Incremental backfill - -When re-opening a previously stored transform, check what's already been processed: - -```python -# Find max parent_rowid already processed -max_parent = conn.execute( - f"SELECT MAX(parent_rowid) FROM {target_name}_meta" -).fetchone()[0] - -# Only process source rows after that -if max_parent is not None: - source = source.after_id(max_parent) # internal method -``` - -### .fetch() on transform-backed stream (no .store()) - -If `.fetch()` is called on a transform-backed stream without `.store()`, execute the transform in-memory: - -1. Fetch source observations -2. Apply transformer's `process()` with an in-memory target -3. Return results without persisting - -This is useful for one-off transforms but can cause memory pressure with large datasets. 
- -## Reactive (.appended) - -Each stored stream has a `ReplaySubject` (or `Subject`) from reactivex: - -```python -class SqliteStream: - def __init__(self, ...): - self._appended_subject = Subject() - - @property - def appended(self) -> Observable[Observation]: - return self._appended_subject.pipe(...) -``` - -`append()` emits to the subject after the DB write succeeds. - -For filtered streams (`.after(t).near(pose, 5.0).appended`), the observable filters events through the predicate chain in Python: - -```python -@property -def appended(self): - root = self._find_root_stream() - predicates = self._collect_predicates() - return root.appended.pipe( - ops.filter(lambda obs: all(p.matches(obs) for p in predicates)) - ) -``` - -Each predicate type implements `matches(obs) -> bool` for Python-side filtering. - -## Serialization - -### Payload serialization - -Use Python `pickle` for general types, with an optimization path for known DimOS types (LCM-encoded messages): - -```python -def serialize(payload: Any) -> bytes: - # LCM types: use lcm_encode for compact binary - if hasattr(payload, '_get_packed_fingerprint'): - return lcm_encode(payload) - # Fallback: pickle - return pickle.dumps(payload) - -def deserialize(blob: bytes, payload_type: type | None = None) -> Any: - if payload_type and hasattr(payload_type, '_get_packed_fingerprint'): - return lcm_decode(blob, payload_type) - return pickle.loads(blob) -``` - -### Pose helpers - -```python -def extract_position(pose: PoseLike) -> tuple[float, float, float]: - """Extract (x, y, z) from any PoseLike.""" - if isinstance(pose, PoseStamped): - p = pose.pose.position - return (p.x, p.y, p.z) - # ... handle Pose, Point, PointStamped - -def extract_orientation(pose: PoseLike) -> tuple[float, float, float, float] | None: - """Extract (qx, qy, qz, qw) if available.""" - ... 
- -def reconstruct_pose(row_slice) -> PoseStamped | None: - """Rebuild PoseStamped from (x, y, z, qx, qy, qz, qw) columns.""" - x, y, z, qx, qy, qz, qw = row_slice - if x is None: - return None - ... -``` - -### Tag serialization - -Tags are stored as JSON text. `None`/empty dict → `NULL` in the column. - -```python -tags_json = json.dumps(tags) if tags else None -``` - -## SQL Safety - -- **Identifier validation**: stream names must match `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$`. Reject anything else with `ValueError`. -- **Parameterized queries**: all user values go through `?` params, never string interpolation. -- **Table names**: constructed from validated stream names, so they're safe for SQL interpolation (e.g., `f"{name}_meta"`). - -## Thread Safety - -- Each `Session` owns one `sqlite3.Connection` — not shared across threads. -- Multiple sessions can exist on the same file (WAL mode allows concurrent reads + one writer). -- The `appended` subject emits on the thread that called `append()`. - -## Error Handling - -- `append()` on non-stored stream → `TypeError` -- `search_embedding()` on non-embedding stream → `TypeError` -- `search_text()` on non-text stream → `TypeError` -- `search_embedding()` when sqlite-vec not loaded → `RuntimeError` -- Invalid stream name → `ValueError` -- `one()` with no results → `LookupError` - -## Testing - -Tests go in `dimos/memory/tests/test_sqlite.py`. Use `:memory:` store for speed. - -Key test scenarios: -1. Create stream, append, fetch — verify data round-trips -2. Temporal filters (after, before, time_range, at) -3. Spatial filter (near) — with and without pose -4. Tag filtering -5. EmbeddingStream — store embeddings, search_embedding, verify EmbeddingObservation auto-projects .data -6. TextStream — store text, search_text -7. Transform with lambda — verify lineage -8. Transform with Transformer class — verify process() called -9. Chained filters — verify SQL composition -10. join — verify cross-stream lineage returns tuples -11. 
fetch_pages — verify pagination -12. Lazy data loading — verify .data only hits DB on access -13. .appended observable — verify reactive emission -14. Incremental backfill — verify resume from last processed -15. Multiple sessions on same file diff --git a/plans/memory/tasks.md b/plans/memory/tasks.md deleted file mode 100644 index 82d2a4e964..0000000000 --- a/plans/memory/tasks.md +++ /dev/null @@ -1,129 +0,0 @@ -# Memory2 — Remaining Tasks - -Gap analysis between `plans/memory/` specs and `dimos/memory/` implementation. - -## P0 — Security / Correctness - -### 1. Stream name validation - -Stream names are interpolated directly into SQL via f-strings. No validation exists — arbitrary input is a SQL injection vector. - -**Spec** (`sqlite.md`): `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$`, reject with `ValueError`. - -**Where**: Add a `_validate_stream_name(name)` check at the top of `SqliteSession.stream()`, `.text_stream()`, `.embedding_stream()`. - -### 2. `_clone()` type annotation vs runtime - -`Stream._clone()` (`stream.py:94-108`) is annotated `-> Stream[T]`, but at runtime it uses `self.__class__.__new__(self.__class__)` which correctly preserves the subclass. So `EmbeddingStream.after(t)` returns an `EmbeddingStream` at runtime — no bug. - -The annotation is wrong for mypy though. Consider `-> Self` (from `typing_extensions`) if we want strict typing. Low priority — runtime works. - -## P1 — Core API Gaps - -### 3. Wire `parent_stream` into `_streams` registry - -`_register_stream()` (`sqlite.py:847-861`) never writes the `parent_stream` column. The column exists in the schema but is always NULL. - -**Where**: `materialize_transform()` (`sqlite.py:770-799`) knows both `source_table` and `name`. Pass `parent_stream=source_table` to `_register_stream()`, and update `_register_stream` to accept and INSERT it. - -This is a prerequisite for `.join()` and stream-level lineage discovery. - -### 4. ~~Implement `.project_to()` — cross-stream lineage~~ ✅ - -Implemented. 
`project_to(target)` adds a `LineageFilter` to the target stream (same `_with_filter` mechanism as `.after()`, `.near()`, etc.). The filter compiles to a SQL subquery walking the `parent_id` chain. Multi-hop lineage is resolved via `_streams.parent_stream` registry. Result is a fully chainable `Stream`. - -### 4b. Implement `.join()` — cross-stream lineage returning pairs - -`api.md` specifies: -```python -for det, img in detections.after(t).join(images): - print(f"Detected {det.data} in image at {img.pose}") -``` - -Unlike `project_to()` which returns a `Stream`, `join()` yields `tuple[Observation, Observation]` pairs. This is a terminal operation (not chainable) since the return type is pairs, not observations. - -**Depends on**: ~~Task 3~~ Done — `parent_stream` is now written by `materialize_transform()` and read by `resolve_lineage_chain()`. - -### 5. Filtered `.appended` — predicate-filtered reactive subscriptions - -`api.md` specifies: -```python -images.near(kitchen_pose, 5.0).appended.subscribe(...) -``` - -Current impl (`stream.py:276-278`) returns the raw Subject regardless of filters. - -**Fix** (from `sqlite.md`): When `self._query.filters` is non-empty, pipe the root subject through `ops.filter()` that evaluates each predicate in Python: - -```python -@property -def appended(self): - backend = self._require_backend() - obs = backend.appended_subject - if not self._query.filters: - return obs - return obs.pipe(ops.filter(lambda o: self._matches_filters(o))) -``` - -Each filter type needs a `matches(obs) -> bool` method for Python-side evaluation: -- `AfterFilter`: `obs.ts > self.t` -- `NearFilter`: Euclidean distance check -- `TagsFilter`: dict subset check -- etc. - -### 6. 
Incremental backfill - -`sqlite.md` specifies that re-running a stored transform resumes from the last processed item: - -```python -max_parent = conn.execute( - f"SELECT MAX(parent_id) FROM {target_name}" -).fetchone()[0] - -if max_parent is not None: - source = source.after_id(max_parent) # internal: WHERE id > ? -``` - -**Where**: `materialize_transform()` (`sqlite.py:791-793`). Before calling `transformer.process()`, check if target already has rows and filter source accordingly. - -**Needs**: An internal `_after_id(row_id)` filter (not exposed in public API) that adds `WHERE id > ?`. - -## P2 — Robustness - -### 7. Separate connections per session - -`SqliteStore.session()` (`sqlite.py:886-887`) shares `self._conn` across all sessions. The spec says each session should own its own connection. - -**Fix**: `session()` should call `sqlite3.connect(self._path)` + WAL pragma + extension loading each time, not reuse `self._conn`. Store keeps the path, sessions get independent connections. - -This is required for multi-threaded use (e.g., one session writing in a background thread, another querying in the main thread). - -### 8. `_CollectorStream` doesn't set pose on observations - -`_CollectorStream.append()` (`stream.py:401-419`) accepts `pose` but doesn't store it on the `Observation`: - -```python -obs = Observation(id=self._next_id, ts=ts, tags=tags or {}, parent_id=parent_id, _data=payload) -# pose is silently dropped -``` - -**Fix**: Add `pose=pose` to the Observation constructor call. - -## P3 — Future (not blocking) - -### 9. Query objects — composable 4D regions + soft scoring - -`query_objects.md` proposes `Criterion` types (`TimeRange`, `Sphere`, `TimeProximity`, `SpatialProximity`, `EmbeddingSimilarity`) with `&`/`|`/`~` composition and weighted `Score()`. - -Explicitly Phase 2. Current flat filter API covers all simple cases. Implement when real usage demands soft scoring or region composition. - -### 10. 
`questions.md` hard cases - -Unresolved query patterns from the product requirements: -- Negation queries ("when did I NOT see the cat") -- Temporal regularity ("what time does the mailman come") -- Cross-agent memory diff -- Conditional pose integration -- Event-anchored multi-stream slicing - -These require extensions beyond the current Stream API — likely built on top of the composable query layer (task 9). diff --git a/plans/memory/transform.md b/plans/memory/transform.md deleted file mode 100644 index 409fd8fc6b..0000000000 --- a/plans/memory/transform.md +++ /dev/null @@ -1,180 +0,0 @@ -# Transform — Unified Derived Stream API - -## Concept - -`.transform()` is a single method on `StreamBase` that handles both historical (batch) and live (reactive) processing. It takes data from a source, applies a function, and stores results into the target stream with lineage. - -## API - -```python -class StreamBase(ABC, Generic[T]): - def transform(self, - source: StreamBase | ObservationSet, - fn: Callable[[Any], T | list[T] | None] | None = None, - *, - live: bool = False, - ) -> Self: - """ - Process source data, store results in this stream. - - Args: - source: where to read from - fn: transform function. Returns T, list[T], or None (skip). - None allowed for EmbeddingStream (uses model.embed implicitly). - live: if True, only subscribe to new appends (no backfill) - - Behavior by source type: - StreamBase → backfill existing + subscribe to live (default) - live=True → skip backfill, only subscribe - ObservationSet → batch process snapshot (live ignored) - - Returns self for chaining. 
- """ -``` - -## Source type determines mode - -| Source | `live=False` (default) | `live=True` | -|------------------|--------------------------------------------------|-------------------------------| -| `StreamBase` | backfill all existing + subscribe to `.appended` | subscribe to `.appended` only | -| `ObservationSet` | batch process the set | N/A (ignored) | - -## Transform function contract - -```python -fn: Callable[[Any], T | list[T] | None] -``` - -- Returns `T` → single result stored -- Returns `list[T]` → multiple results stored (e.g., multiple detections per frame) -- Returns `None` or `[]` → nothing stored for this input (e.g., no detections) -- `parent_id` set automatically from source row - -## Examples - -### VLM detections on images - -```python -images = session.stream("images", Image, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -detections = session.stream("cigarette_detections", VLMDetection) - -# Backfill + live -detections.transform(images, fn=lambda img: vlm.detect(img, "people with cigarettes")) - -# After this, every new image.append() triggers detection automatically -# All results are queryable -rows = detections.query().filter_after(one_hour_ago).fetch() -``` - -### Live-only (skip backfill) - -```python -detections.transform(images, fn=detect_fn, live=True) -# Only processes images appended from now on -``` - -### Historical batch on query results - -```python -# Only process images from the kitchen in the last hour -kitchen_images = images.query().filter_near(kitchen_pose, 5.0).filter_after(one_hour_ago).fetch_set() - -detections.transform(kitchen_images, fn=lambda img: vlm.detect(img, "cigarettes")) -# Batch processes the set, no live subscription -``` - -### Embedding stream (specialized) - -```python -img_emb = session.embedding_stream("img_emb", model=CLIPModel()) - -# fn is implicit — uses model.embed() -img_emb.transform(images, live=True) - -# Equivalent to: -img_emb.transform(images, fn=lambda img: 
clip.embed(img), live=True) -``` - -### Chaining transforms - -```python -images = session.stream("images", Image, pose_provider=pose_fn) - -# Embeddings from images -img_emb = session.embedding_stream("img_emb", model=CLIPModel()) -img_emb.transform(images, live=True) - -# Detections from images -detections = session.stream("detections", VLMDetection) -detections.transform(images, fn=detect_fn, live=True) - -# Text descriptions from detections (second-level derived) -descriptions = session.text_stream("descriptions", str) -descriptions.transform(detections, fn=lambda det: det.describe(), live=True) -``` - -## Internals - -### Backfill (batch) - -```python -for page in source.iter_meta(page_size=128): - for row in page: - payload = source.load(row) # or row.data - results = fn(payload) - if results is None: - continue - if not isinstance(results, list): - results = [results] - for r in results: - self.append(r, ts=row.ts, pose=row.pose, parent_id=row.id) -``` - -### Live (reactive) - -```python -source.appended.pipe( - ops.map(lambda row: (row, fn(row.data))), - ops.filter(lambda pair: pair[1] is not None), - ops.flat_map(lambda pair: [ - (pair[0], r) for r in (pair[1] if isinstance(pair[1], list) else [pair[1]]) - ]), -).subscribe(lambda pair: self.append(pair[1], ts=pair[0].ts, pose=pair[0].pose, - parent_id=pair[0].id)) -``` - -### EmbeddingStream override - -```python -class EmbeddingStream(StreamBase[T]): - model: EmbeddingModel - - def transform(self, source, fn=None, *, live=False): - if fn is None: - fn = self.model.embed - return super().transform(source, fn, live=live) -``` - -## Lineage - -`transform()` sets `parent_id` on every appended row, linking back to the source row. This enables `project_to()`: - -```python -# Find source images for cigarette detections -with detections.query().fetch_set() as det_set: - source_images = det_set.project_to(images) - for row in source_images.rows(limit=5): - img = images.load(row) -``` - -## Open questions - -1. 
**Async transforms?** VLM inference is slow. Should `fn` support async/await or rx scheduling (e.g., `observe_on(io_scheduler)`)? - -2. **Error handling?** If `fn` raises on one row, skip it? Log and continue? Configurable? - -3. **Backfill progress?** For large backfills, should `transform()` return a progress observable or run in background? - -4. **Multiple parents?** Current design is single-parent lineage. If a stream derives from two streams (e.g., fusing image + audio), we'd need multi-parent support. Phase 3. diff --git a/plans/old/analysis.md b/plans/old/analysis.md deleted file mode 100644 index 39f9af4b9c..0000000000 --- a/plans/old/analysis.md +++ /dev/null @@ -1,478 +0,0 @@ -# Analysis Utilities - -Application-level analysis on Memory2 query results. NOT part of memory2 core — operates on fetched `ObservationRow` lists, no SQLite dependency. - -Location: `dimos/memory2/analysis.py` - -Dependencies: only `dimos/memory2/types.py` (ObservationRow, ObservationRef). No numpy, no sklearn. - ---- - -## 1. `cluster_observations()` - -The most common post-query pattern across Q2, Q4, Q5, Q9, Q11, Q12, Q14. - -```python -@dataclass -class Cluster: - rows: list[ObservationRow] - representative: ObservationRow # best by rank_key - - @property - def t_start(self) -> float: - return self.rows[0].ts_start - - @property - def t_end(self) -> float: - return self.rows[-1].ts_start - - @property - def duration(self) -> float: - return self.t_end - self.t_start - - @property - def center_pose(self) -> PoseLike | None: - """Average position of all localized rows.""" - ... - - -def cluster_observations( - rows: list[ObservationRow], - *, - time_scale: float | None = None, - space_scale: float | None = None, - threshold: float = 1.0, - rank_key: Callable[[ObservationRow], float] | None = None, -) -> list[Cluster]: - """Greedy sequential clustering over time and/or space. 
- - Distance between consecutive rows (must be sorted by ts_start): - - d = sqrt((dt/time_scale)^2 + (ds/space_scale)^2) - - New cluster starts when d > threshold. - - Args: - rows: ObservationRows, sorted by ts_start. - time_scale: Normalize temporal gap (seconds). None = ignore time. - space_scale: Normalize spatial distance (meters). None = ignore space. - threshold: Combined normalized distance to split clusters. - rank_key: Scoring function for representative selection. - Default: embedding score, then recency. - - Returns: - List of Cluster objects, each with .rows and .representative. - """ -``` - -### Modes - -```python -# Temporal only: split if gap > 10s -clusters = cluster_observations(rows, time_scale=10.0) - -# Spatial only: split if > 3m apart -clusters = cluster_observations(rows, space_scale=3.0) - -# Combined: either 10s gap OR 3m apart triggers split -clusters = cluster_observations(rows, time_scale=10.0, space_scale=3.0) - -# Bias toward spatial (space matters more): -clusters = cluster_observations(rows, time_scale=30.0, space_scale=2.0) -``` - -### Representative selection - -Default `rank_key`: `lambda r: r.scores.get("embedding", 0)` — picks the most relevant frame after a search. 
Override for quality-based selection: - -```python -# Quality-biased: prefer sharp, well-exposed frames -clusters = cluster_observations(rows, - time_scale=10.0, - rank_key=lambda r: ( - r.scores.get("embedding", 0) * 0.4 + - r.tags.get("sharpness", 0.5) * 0.4 + - r.tags.get("exposure", 0.5) * 0.2 - ), -) - -# Recency-biased: prefer the latest frame in each cluster -clusters = cluster_observations(rows, - time_scale=10.0, - rank_key=lambda r: r.ts_start, -) -``` - -### Which questions use this - -| Question | Mode | Purpose | -|----------|------|---------| -| Q2 — red socks viewing sessions | temporal | Group continuous sightings, VLM one per cluster | -| Q4 — where were red socks | spatial | Group nearby sightings into distinct locations | -| Q5 — door open events | temporal | Group rapid-fire "door open" detections into single events | -| Q9 — cat trail | spatial | Group into distinct locations the cat visited | -| Q11 — cat absence | temporal | (indirect — use `find_gaps` on clusters) | -| Q12 — mailman schedule | temporal | Group same-visit detections into single arrival events | -| Q14 — carrying intervals | temporal | Group "carrying" detections into continuous intervals | - ---- - -## 2. `find_gaps()` - -Find periods where observations are absent. Used in Q11 (cat absence) and Q14 (carrying interval boundaries). - -```python -@dataclass -class Gap: - t_start: float # timestamp of last observation before the gap - t_end: float # timestamp of first observation after the gap - duration: float # t_end - t_start - - -def find_gaps( - rows: list[ObservationRow], - *, - min_gap: float, -) -> list[Gap]: - """Find temporal gaps in a sorted observation list. - - Args: - rows: ObservationRows, sorted by ts_start. - min_gap: Minimum gap duration (seconds) to report. - - Returns: - List of Gap objects, sorted by time. - """ -``` - -Usage: - -```python -# Q11: When was the cat last NOT seen? 
-cat_seen = detections.query().filter_tags(class_name="cat").order_by("ts_start").fetch() -gaps = find_gaps(cat_seen, min_gap=60.0) -if gaps: - print(f"Last absence: {gaps[-1].t_start} to {gaps[-1].t_end}") -``` - -Works on clusters too — find gaps between cluster end and next cluster start: - -```python -# Gaps between sighting sessions (not between individual frames) -clusters = cluster_observations(cat_seen, time_scale=10.0) -# Synthesize one row per cluster (the representative) for gap analysis -cluster_reps = [c.representative for c in clusters] -session_gaps = find_gaps(cluster_reps, min_gap=300.0) -``` - ---- - -## 3. `compute_path_distance()` - -Sum of Euclidean distances along a pose trail. Used in Q9 (cat trail length) and Q14 (distance while carrying). - -```python -def compute_path_distance( - rows: list[ObservationRow], -) -> float: - """Total Euclidean path distance from consecutive poses. - - Args: - rows: ObservationRows with poses, sorted by ts_start. - Rows without pose are skipped. - - Returns: - Total distance in meters. - """ -``` - -Usage: - -```python -# Q14: How far did I travel while carrying? -for cluster in carrying_clusters: - pose_rows = poses.query().filter_time(cluster.t_start, cluster.t_end).order_by("ts_start").fetch() - dist = compute_path_distance(pose_rows) - print(f"Carried for {cluster.duration:.0f}s, traveled {dist:.1f}m") -``` - ---- - -## 4. `extract_time_pattern()` - -Extract time-of-day statistics from observations spread across multiple days. Used in Q12 (mailman schedule). - -```python -@dataclass -class TimePattern: - mean_hour: float # e.g. 
10.5 = 10:30 AM - std_minutes: float # standard deviation in minutes - count: int # number of observations - times: list[float] # individual hours (for histogram) - - def __str__(self) -> str: - h = int(self.mean_hour) - m = int((self.mean_hour % 1) * 60) - return f"{h}:{m:02d} +/- {self.std_minutes:.0f}min (n={self.count})" - - -def extract_time_pattern( - rows: list[ObservationRow], - *, - tz: timezone | None = None, -) -> TimePattern: - """Extract time-of-day pattern from observations across multiple days. - - Best used on cluster representatives (one per event) rather than raw rows, - to avoid dense clusters biasing the average. - - Args: - rows: ObservationRows with ts_start. - tz: Timezone for time-of-day extraction. Default: UTC. - - Returns: - TimePattern with mean, std, and individual times. - """ -``` - -Usage: - -```python -# Q12: When does the mailman usually come? -sightings = faces.query().search_embedding(mailman_emb, candidate_k=100).fetch() -sightings = [r for r in sightings if r.scores.get("embedding", 0) > 0.8] - -# Cluster into individual visits (one per day) -visits = cluster_observations(sightings, time_scale=300.0) -pattern = extract_time_pattern([v.representative for v in visits]) -print(f"Mailman comes at {pattern}") # "10:30 +/- 12min (n=23)" -``` - ---- - -## 5. `match_viewpoints()` - -Match observations from two sets by embedding similarity — find corresponding views across time. Used in Q8 (room diff: today vs yesterday). - -```python -@dataclass -class ViewpointMatch: - current: ObservationRow - reference: ObservationRow - similarity: float - - -def match_viewpoints( - current: list[ObservationRow], - reference: list[ObservationRow], - vectors_current: list[list[float]], - vectors_reference: list[list[float]], - *, - min_similarity: float = 0.85, -) -> list[ViewpointMatch]: - """Match observations by embedding similarity (cosine via dot product). - - Assumes vectors are L2-normalized (as CLIP embeddings are). 
- - Pure Python — no numpy required (but callers may use numpy for - batch vector retrieval before calling this). - - Args: - current: ObservationRows from the "current" time window. - reference: ObservationRows from the "reference" time window. - vectors_current: Embedding vectors for current rows. - vectors_reference: Embedding vectors for reference rows. - min_similarity: Minimum cosine similarity for a valid match. - - Returns: - List of ViewpointMatch objects, one per matched current row. - Unmatched rows are excluded. - """ -``` - -Usage: - -```python -# Q8: What changed in this room vs yesterday? -current_imgs = images.query().filter_time(now - 300, now).filter_near(pose, radius=5.0).fetch() -yesterday_imgs = images.query().filter_time(yest - 300, yest + 300).filter_near(pose, radius=5.0).fetch() - -current_vecs = [images.vector(r.ref) for r in current_imgs] -yesterday_vecs = [images.vector(r.ref) for r in yesterday_imgs] - -matches = match_viewpoints(current_imgs, yesterday_imgs, current_vecs, yesterday_vecs) -for m in matches: - diff = vlm.ask([images.load(m.current.ref), images.load(m.reference.ref)], - "What changed between these two views?") -``` - -Note: this is O(n*m) dot products. Fine for typical sizes (tens to low hundreds of images per spatial query). For very large sets, callers can use numpy directly. - ---- - -## 6. `diff_observation_sets()` - -Find observations in set A that have no similar match in set B. Used in Q13 (cross-robot diff). - -```python -@dataclass -class UnmatchedObservation: - row: ObservationRow - best_similarity: float # highest similarity to anything in the other set - - -def diff_observation_sets( - source: list[ObservationRow], - reference: list[ObservationRow], - vectors_source: list[list[float]], - vectors_reference: list[list[float]], - *, - similarity_threshold: float = 0.7, -) -> list[UnmatchedObservation]: - """Find observations in source that have no close match in reference. 
- - Args: - source: Observations to check ("what did robot-2 see?") - reference: Observations to compare against ("what did I see?") - vectors_source: Embeddings for source rows. - vectors_reference: Embeddings for reference rows. - similarity_threshold: Below this = "unmatched" = novel observation. - - Returns: - List of UnmatchedObservation from source with no reference match. - """ -``` - -Usage: - -```python -# Q13: What did robot-2 see that I missed? -r2 = detections.query().filter_tags(robot_id="robot-2").filter_near(warehouse, radius=20).fetch() -me = detections.query().filter_tags(robot_id="robot-1").filter_near(warehouse, radius=20).fetch() -r2_vecs = [detections.vector(r.ref) for r in r2] -me_vecs = [detections.vector(r.ref) for r in me] - -missed = diff_observation_sets(r2, me, r2_vecs, me_vecs) -for m in missed: - print(f"Missed: {m.row.tags.get('class_name')} at {m.row.pose}") -``` - ---- - -## Quality Conventions - -Image quality metrics are stored in tags at ingest time by the pipeline. The analysis utilities don't compute quality — they consume it via `rank_key`. 
- -### Recommended tag keys - -| Tag | Type | Description | Range | -|-----|------|-------------|-------| -| `sharpness` | float | Laplacian variance of grayscale image | 0.0–1.0 (normalized) | -| `blur` | float | Inverse of sharpness (lower = sharper) | 0.0–1.0 | -| `exposure` | float | How well-exposed (0 = dark/blown out, 1 = good) | 0.0–1.0 | -| `occlusion` | float | Fraction of frame occluded | 0.0–1.0 | - -### Pipeline example - -```python -def compute_quality(frame) -> dict: - gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - lap_var = cv2.Laplacian(gray, cv2.CV_64F).var() - sharpness = min(lap_var / 500.0, 1.0) # normalize - mean_brightness = gray.mean() / 255.0 - exposure = 1.0 - abs(mean_brightness - 0.45) * 2 # penalize too dark/bright - return {"sharpness": round(sharpness, 3), "exposure": round(max(exposure, 0), 3)} - -images.append(frame, - pose=robot_pose, - tags=compute_quality(frame), -) -``` - -Then in analysis: - -```python -clusters = cluster_observations(candidates, - time_scale=10.0, - rank_key=lambda r: ( - r.scores.get("embedding", 0) * 0.5 + - r.tags.get("sharpness", 0.5) * 0.3 + - r.tags.get("exposure", 0.5) * 0.2 - ), -) -``` - ---- - -## The "Embed → Cluster → VLM" Pipeline - -The dominant analysis pattern across Q2, Q4, Q5, Q8. Not a function — a recipe. - -``` -1. Embedding search → candidate_k rows (cheap, fast, noisy) -2. Score filter → discard low-similarity noise -3. cluster_observations → group into distinct events/locations -4. Representative pick → best frame per cluster (by quality + relevance) -5. VLM verify → confirm/describe each representative (expensive, precise) -6. 
Expand → confirmed representative → entire cluster is valid -``` - -This is NOT worth wrapping in a single function because: -- The VLM prompt varies per question -- The cluster parameters vary per domain -- The expand step varies (sometimes you want all rows, sometimes just the cluster metadata) -- Steps 1-4 compose naturally with existing tools - -But documenting it as a pattern means every new question follows the same structure. - -### Example: complete pipeline for Q2 - -```python -# 1. Embedding search -candidates = images.query().search_embedding(clip_text_encode("red socks"), candidate_k=1000).order_by("ts_start").fetch() - -# 2. Score filter -candidates = [r for r in candidates if r.scores.get("embedding", 0) > 0.7] - -# 3. Cluster (temporal — group continuous viewing) -clusters = cluster_observations(candidates, - time_scale=10.0, - rank_key=lambda r: ( - r.scores.get("embedding", 0) * 0.5 + - r.tags.get("sharpness", 0.5) * 0.5 - ), -) - -# 4. VLM verify representatives only -confirmed = [] -for c in clusters: - img = images.load(c.representative.ref) - if vlm.ask(img, "Are there red socks in this image? yes/no") == "yes": - confirmed.append(c) - -# 5. Use results -print(f"Currently watching for {confirmed[-1].duration:.0f}s") -print(f"Seen {len(confirmed) - 1} time(s) before") -``` - ---- - -## Summary - -| Utility | Pure Python | Used in | Core purpose | -|---------|-------------|---------|-------------| -| `cluster_observations` | yes | Q2,Q4,Q5,Q9,Q11,Q12,Q14 | Group by time/space, pick representative | -| `find_gaps` | yes | Q11 | Detect absence periods | -| `compute_path_distance` | yes | Q9,Q14 | Trajectory length | -| `extract_time_pattern` | yes | Q12 | Time-of-day statistics | -| `match_viewpoints` | yes | Q8 | Cross-temporal view matching | -| `diff_observation_sets` | yes | Q13 | Set difference by embedding similarity | - -All utilities are stateless functions on `list[ObservationRow]`. 
No DB access, no numpy dependency (callers use numpy for batch vector ops if they want). Quality metrics live in tags, set by the ingest pipeline. - -### Not included (stays in application code) - -- **Identity clustering** (Q3, Q6, Q7): Requires DBSCAN/sklearn + domain-specific parameters. Too varied for a generic utility. -- **State transition detection** (Q5): "door went from closed→open" needs domain knowledge about what states exist. -- **Absence reasoning** (Q11): Distinguishing "cat not here" from "robot not looking" requires cross-referencing robot coverage — application context. -- **VLM prompting**: Every question has different prompts and response parsing. diff --git a/plans/old/answers.md b/plans/old/answers.md deleted file mode 100644 index e5cb509ee0..0000000000 --- a/plans/old/answers.md +++ /dev/null @@ -1,853 +0,0 @@ -# Answers - -API reference: `memory3.md` (current) - ---- - -## 1. "Where was I, when this log line was added?" + "Where do motor faults keep happening?" - -**Streams**: `logs` (text-capable), `poses` (robot localization at high frequency) - -**Single log line**: - -```python -s = db.session() -logs = s.stream("logs", LogMsg, text=TextConfig()) -poses = s.stream("poses", PoseStamped) - -# Find the log entry by text -log_hit = logs.query().search_text("motor fault detected").one() - -# Look up pose at that time — .at() finds nearest within tolerance -pose_hit = poses.query().at(log_hit.ts_start, tolerance=0.5).one() -print(pose_hit.pose) # Pose(x=1.2, y=3.4, z=0.5) -``` - -**Multiple log lines → spatial map of faults**: - -```python -fault_logs = logs.query().search_text("motor fault").order_by("ts_start").fetch() - -# Correlate each to a pose -fault_locations = [] -for log_row in fault_logs: - pose_row = poses.query().at(log_row.ts_start, tolerance=0.5).fetch() - if pose_row: - fault_locations.append((log_row, pose_row[0])) - -# Cluster by location — "where do faults keep happening?" 
-from dimos.memory2.analysis import cluster_observations -location_clusters = cluster_observations( - [pose for _, pose in fault_locations], - space_scale=2.0, # within 2m = same spot -) - -for c in location_clusters: - print(f"{len(c.rows)} faults near {c.center_pose} " - f"({c.t_start} to {c.t_end})") - # → "12 faults near Pose(x=3.1, y=7.2) over the last 3 days" - -# Render on costmap -for c in location_clusters: - costmap.mark(pose=c.center_pose, label=f"motor faults ({len(c.rows)}x)") -``` - -**What works**: `.search_text()` finds all matching logs, `.at()` correlates each to a pose, `cluster_observations(space_scale=)` groups faults by location. The result is a heatmap of where the robot has trouble. - -**Cross-stream join**: The for-loop is the same nested-loop join pattern as Q5/Q7/Q14. `Correlator` (Phase 3) would batch this: -```python -fault_poses = s.correlate(fault_logs_set, poses, time_tolerance=0.5) -``` - ---- - -## 2. "How long have I been observing the red socks in view currently?" + "How many times did I see them before?" - -**Streams**: `images` (camera frames with CLIP embeddings and poses) - -No detection pipeline — we search raw images by embedding similarity, then VLM-verify. - -**Stage 1 — Embedding candidate retrieval**: - -```python -s = db.session() -images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512)) - -socks_embedding = clip_text_encode("red socks") - -# Find all frames that might contain red socks -candidates = (images.query() - .search_embedding(socks_embedding, candidate_k=1000) - .order_by("ts_start") - .fetch()) - -# Post-filter by similarity score to discard weak matches -candidates = [h for h in candidates if h.scores.get("embedding", 0) > 0.7] -``` - -**Stage 2 — Diverse sampling before VLM** (don't waste VLM on 200 frames of the robot staring at socks): - -The embedding top-k will cluster heavily around moments of prolonged viewing. 
We need to spread VLM budget across time and space to discover all distinct sighting sessions. - -```python -# Cluster candidates into temporal segments (frames within 10s = same cluster) -# Then pick one representative per cluster for VLM -candidates.sort(key=lambda r: r.ts_start) - -clusters = [] # list of lists -for row in candidates: - if not clusters or row.ts_start - clusters[-1][-1].ts_start > 10.0: - clusters.append([row]) - else: - clusters[-1].append(row) - -# Pick the highest-scoring representative from each cluster -representatives = [] -for cluster in clusters: - best = max(cluster, key=lambda r: r.scores.get("embedding", 0)) - representatives.append((best, cluster)) -``` - -Now VLM verifies only the representatives — one call per temporal cluster, not per frame: - -```python -confirmed_segments = [] -for rep, cluster in representatives: - img = images.load(rep.ref) - if vlm.ask(img, "Are there red socks visible in this image? yes/no") == "yes": - # Entire cluster counts as a sighting session - confirmed_segments.append((cluster[0].ts_start, cluster[-1].ts_start)) -``` - -If the robot saw socks 5 different times across the day but stared for minutes each time, this makes ~5 VLM calls instead of 200+. - -**Stage 3 — Answer the question**: - -```python -now = time.time() - -# Current viewing session = last confirmed segment -if confirmed_segments: - current_duration = now - confirmed_segments[-1][0] - print(f"Watching red socks for {current_duration:.1f}s") - print(f"Seen them {len(confirmed_segments) - 1} time(s) before") -``` - -**What works**: Embedding search is the broad net (cheap, fast), temporal clustering deduplicates the "staring" problem, VLM confirms only one frame per cluster. Scales to long sessions without blowing VLM budget. - -**What's application logic**: Cluster gap threshold (10s), VLM prompt, what counts as "same sighting" — all domain-specific. - -**Limitation**: `candidate_k=1000` is a guess. 
sqlite-vec is KNN-only — no "all vectors above threshold" query. Workaround: use a large candidate_k and post-filter by score. - -**Extension — spatial diversity**: If the robot revisits the same spot repeatedly, add pose-based deduplication within temporal clusters. But temporal clustering alone handles the dominant case (continuous staring). - ---- - -## 3. "How many people did I see during last week?" - -**Pipeline**: -``` -camera frames → face detector → face crops → embedding model → face embeddings - ↓ - faces stream (each row = one detected face with identity embedding) -``` - -Yes — the `faces` stream stores detected face crops. Each append includes the face embedding. Searching over this stream by embedding finds the same face across time. - -**Streams**: `faces` (face crops with identity embeddings) - -```python -s = db.session() -faces = s.stream("faces", FaceCrop, embedding=EmbeddingConfig(dim=512)) - -one_week_ago = time.time() - 7 * 86400 -week_faces = (faces.query() - .filter_after(one_week_ago) - .order_by("ts_start") - .fetch()) - -# Get all embedding vectors for clustering -vectors = [] -for row in week_faces: - vec = faces.vector(row.ref) # retrieve stored embedding - vectors.append(vec) - -# Cluster to find unique identities -import numpy as np -from sklearn.cluster import DBSCAN - -X = np.array(vectors) -clustering = DBSCAN(eps=0.6, min_samples=2, metric="cosine").fit(X) -n_people = len(set(clustering.labels_)) - (1 if -1 in clustering.labels_ else 0) -print(f"Saw {n_people} unique people last week") -``` - -**What works**: `filter_after` for time range, `faces.vector(ref)` to retrieve stored embeddings for clustering. - -**What's application logic**: Identity clustering (DBSCAN, threshold tuning) is domain-specific — different robots may have different accuracy needs. - -**With derive() (Phase 3)**: Could automate the dedup into a persistent `people` stream, then it's just `.count()`. - ---- - -## 4. 
"Where did you see red socks during last week?" - -**Streams**: `images` (camera frames with CLIP embeddings and poses) - -**Stage 1 — Embedding candidate retrieval**: - -```python -s = db.session() -images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512)) - -one_week_ago = time.time() - 7 * 86400 -socks_embedding = clip_text_encode("red socks") - -candidates = (images.query() - .search_embedding(socks_embedding, candidate_k=200) - .filter_after(one_week_ago) - .limit(50) - .fetch_set()) -``` - -**Stage 2 — VLM verification**: - -```python -verified_refs = [] -for row in candidates.rows(): - img = candidates.load(row.ref) - if vlm.ask(img, "Are there red socks in this image? yes/no") == "yes": - verified_refs.append(row.ref) - -# Wrap verified results back into an ObservationSet -verified = images.query().filter_refs(verified_refs).fetch_set() -``` - -`filter_refs()` gives us an ObservationSet of just the verified images — ephemeral, session-scoped. - -To persist: write to a new stream with lineage back to the originals: - -```python -red_socks = s.stream("red_socks", Image) -for ref in verified_refs: - src = images.meta(ref) - red_socks.append( - images.load(ref), - pose=src.pose, ts_start=src.ts_start, - tags={"query": "red socks"}, - parent_stream="images", parent_id=ref.id, - ) -``` - -**Stage 3 — Costmap**: - -```python -for row in verified.rows(): - costmap.mark(pose=row.pose, label="red socks", time=row.ts_start) -``` - -Every verified observation carries the robot's pose from the original image stream → direct costmap placement. - ---- - -## 5. "Did anyone ever open this door? At what times? Who opened it?" - -**Streams**: `detections` (object detections with tags), `faces` (face crops with identity embeddings) - -**Sub-question 1 & 2 — When was the door open?** - -Depends on the detection pipeline. 
If the detector tags door state: - -```python -detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) - -door_open = (detections.query() - .filter_tags(class_name="door", state="open") - .order_by("ts_start") - .fetch()) - -for row in door_open: - print(f"Door open at {row.ts_start}") -``` - -If the detector doesn't tag state — embedding search + VLM verify (same pattern as Q4): - -```python -open_door_emb = clip_text_encode("open door") -candidates = (images.query() - .search_embedding(open_door_emb, candidate_k=100) - .filter_near(door_location, radius=3.0) # only images near the door - .fetch()) - -# VLM verify each candidate -open_times = [r for r in candidates - if vlm.ask(images.load(r.ref), "Is this door open?") == "yes"] -``` - -**Sub-question 3 — Who opened it?** - -Cross-stream temporal+spatial correlation: for each door-open event, find faces nearby at that time. - -```python -faces = s.stream("faces", FaceCrop, embedding=EmbeddingConfig(dim=512)) - -for event in open_times: - # Find faces near the door around the time it opened - nearby = (faces.query() - .filter_time(event.ts_start - 5.0, event.ts_start + 2.0) - .filter_near(event.pose, radius=3.0) - .fetch()) - - if nearby: - # Identify the person via face embedding - vec = faces.vector(nearby[0].ref) - identity = lookup_identity(vec) # match against known faces DB - print(f"Door opened at {event.ts_start} by {identity}") -``` - -**What works**: `filter_time` + `filter_near` compose naturally for "who was here when this happened". The R*Tree + ts_start index handle this efficiently. - -**What's manual**: The for-loop is a nested-loop join. `Correlator` (Phase 3) would batch this: -```python -# Phase 3: -s.correlate(door_events, faces, time_tolerance=5.0, spatial_radius=3.0) -``` - -**State transition detection** ("door went from closed→open") is application logic. The memory system stores observations, not state machines. 
You'd either store explicit events in a `door_events` stream, or detect transitions by comparing consecutive detections. - ---- - -## 6. "I have a transcription log (STT) and voice embeddings — how do I figure out who is saying what?" - -**Streams**: `transcripts` (STT output, text-capable), `voice_embeddings` (speaker embeddings per audio segment) - -Two separate streams because they come from different models: STT gives you text, a speaker encoder gives you a voice identity vector. - -```python -s = db.session() -transcripts = s.stream("transcripts", Transcript, text=TextConfig()) -voice_embs = s.stream("voice_segments", VoiceSegment, embedding=EmbeddingConfig(dim=192)) -``` - -**Step 1 — Align transcripts to voice segments by time**: - -Each transcript has `ts_start`/`ts_end` (when the words were spoken). Each voice segment has a speaker embedding for that time window. - -```python -for tx_row in transcripts.query().order_by("ts_start").fetch(): - # Find the voice segment that overlaps this transcript - voice = (voice_embs.query() - .filter_time(tx_row.ts_start, tx_row.ts_end) - .one()) - - # voice.ref → voice_embs.vector(voice.ref) gives us the speaker embedding - speaker_vec = voice_embs.vector(voice.ref) - transcript_text = transcripts.load(tx_row.ref).text - - print(f"[{speaker_vec_to_name(speaker_vec)}]: {transcript_text}") -``` - -**Step 2 — Build speaker identity mapping**: - -Cluster all voice embeddings to find distinct speakers, then label: - -```python -all_voices = voice_embs.query().order_by("ts_start").fetch() -vectors = [voice_embs.vector(r.ref) for r in all_voices] - -# Cluster into distinct speakers -clustering = DBSCAN(eps=0.3, min_samples=3, metric="cosine").fit(np.array(vectors)) -# label_id → speaker name mapping (manual or via face correlation — see Q7) -``` - -**What works**: `filter_time` on voice stream using transcript's time window is the natural join key. `.vector()` retrieves stored embeddings for clustering. 
- -**Key insight**: The two streams are aligned by time, not by embedding similarity. We don't search by embedding across streams — we use temporal co-occurrence to pair them, then use the voice embedding for speaker identity. - ---- - -## 7. "I have parallel voice and facial recognition streams — how do I correlate voice to people? (I don't see all people speaking at all times)" - -**Streams**: `voices` (speaker embeddings per audio segment), `faces` (face identity embeddings per detection) - -The constraint "I don't see all people speaking at all times" means: -- Sometimes a person is speaking but out of camera view → voice segment exists, no face match -- Sometimes multiple people are visible but only one is speaking -- The correlation must be probabilistic, accumulated over time - -```python -s = db.session() -voices = s.stream("voices", VoiceSegment, embedding=EmbeddingConfig(dim=192)) -faces = s.stream("faces", FaceCrop, embedding=EmbeddingConfig(dim=512)) -``` - -**Step 1 — Collect unambiguous pairings** (only one face visible while voice active): - -```python -pairings = [] # (voice_embedding, face_embedding) pairs - -for v_row in voices.query().order_by("ts_start").fetch(): - # Find faces visible during this voice segment - visible_faces = (faces.query() - .filter_time(v_row.ts_start, v_row.ts_end) - .fetch()) - - if len(visible_faces) == 1: - # Unambiguous: only one person visible → must be the speaker - voice_vec = voices.vector(v_row.ref) - face_vec = faces.vector(visible_faces[0].ref) - pairings.append((voice_vec, face_vec)) -``` - -**Step 2 — Build cross-modal identity mapping**: - -```python -# Cluster voice embeddings → speaker IDs -voice_vecs = np.array([p[0] for p in pairings]) -voice_clusters = DBSCAN(eps=0.3, min_samples=2, metric="cosine").fit(voice_vecs) - -# For each voice cluster, find the most common face cluster -# This gives us: voice_speaker_id → face_identity -speaker_to_face = {} -for cluster_id in set(voice_clusters.labels_): - if 
cluster_id == -1: - continue - cluster_face_vecs = [p[1] for i, p in enumerate(pairings) - if voice_clusters.labels_[i] == cluster_id] - # Majority vote on face identity - face_identity = identify_majority(cluster_face_vecs) - speaker_to_face[cluster_id] = face_identity -``` - -**Step 3 — Label all voice segments** (including ambiguous ones): - -```python -for v_row in voices.query().order_by("ts_start").fetch(): - voice_vec = voices.vector(v_row.ref) - # Find nearest voice cluster → mapped face identity - speaker_id = predict_cluster(voice_vec, voice_clusters) - person = speaker_to_face.get(speaker_id, "unknown") - print(f"[{person}] spoke at {v_row.ts_start}") -``` - -**What works**: -- `filter_time` on faces using voice segment's time window — natural temporal join -- `.vector()` on both streams for cross-modal clustering -- The API provides the building blocks; the correlation logic (accumulate unambiguous pairings → build mapping → apply to ambiguous cases) is correctly application-level - -**What the constraint exposes**: "I don't see all people speaking at all times" means we can't rely on a single observation to establish identity. We need statistical accumulation — many unambiguous pairings build confidence. This is fundamentally a learning problem, not a query problem. The memory system's job is to make the data accessible; the correlation intelligence lives above. - -**With Correlator (Phase 3)**: The inner loop (for each voice segment, query faces) would become: -```python -pairs = s.correlate(voices, faces, time_tolerance=0.5) -``` -But the clustering/identity-mapping step still lives in application code. - ---- - -## 8. "What's different in this room compared to yesterday?" - -**What we need**: Compare object detections from "now" vs "yesterday" at the same location, find what changed. 
- -**Streams**: `images` (camera frames with CLIP embeddings and poses) - -We can't rely on a precomputed detection stream — object detection for a fixed set is expensive and not run in realtime. Instead, store raw images and diff at query time using embeddings + VLM. - -```python -s = db.session() -images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512)) - -now = time.time() -yesterday = now - 86400 -robot_pose = get_current_pose() - -# Two queries: images from this room now vs yesterday -current_imgs = (images.query() - .filter_time(now - 300, now) - .filter_near(robot_pose, radius=5.0) - .fetch()) - -yesterday_imgs = (images.query() - .filter_time(yesterday - 300, yesterday + 300) - .filter_near(robot_pose, radius=5.0) - .fetch()) - -# Match viewpoints by embedding similarity (numpy, no extra queries) -current_vecs = np.array([images.vector(r.ref) for r in current_imgs]) -yesterday_vecs = np.array([images.vector(r.ref) for r in yesterday_imgs]) -similarity = current_vecs @ yesterday_vecs.T - -# Pair each current image with its closest yesterday viewpoint -pairs = [] -for i, row in enumerate(current_imgs): - j = similarity[i].argmax() - if similarity[i, j] > 0.85: # same viewpoint - pairs.append((row, yesterday_imgs[j])) - -# VLM diffs only matched viewpoint pairs -for curr, yest in pairs: - diff = vlm.ask( - [images.load(curr.ref), images.load(yest.ref)], - "What changed between these two views?") - if diff != "nothing": - print(f"At {curr.pose}: {diff}") -``` - -**What works**: Two queries retrieve the two temporal snapshots scoped to this room. Embedding similarity in numpy matches viewpoints without extra DB queries. VLM provides open-vocabulary scene comparison — no fixed object set needed. - -**What's application logic**: Viewpoint matching threshold, VLM prompting, what counts as a meaningful change. The memory system provides spatial+temporal retrieval; the VLM provides the intelligence. 
- -**Cost structure**: 2 DB queries + N `.vector()` reads (small, fast) + numpy matmul + M VLM calls (expensive, but only on matched pairs). - ---- - -## 9. "Show me everywhere the cat went today" - -**What we need**: Retrieve all cat detections from today, extract the pose trail, render as a path on the costmap. - -**Streams**: `detections` (object detections with poses) - -```python -s = db.session() -detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) - -today_start = start_of_day() - -# All cat detections today, ordered by time -cat_trail = (detections.query() - .filter_tags(class_name="cat") - .filter_after(today_start) - .order_by("ts_start") - .fetch()) - -# Extract the pose path -path = [(row.ts_start, row.pose) for row in cat_trail if row.pose] - -# Render on costmap -for ts, pose in path: - costmap.add_point(pose=pose, time=ts, label="cat") -costmap.draw_path([pose for _, pose in path]) -``` - -**What works**: `filter_tags(class_name="cat")` + `filter_after()` + `order_by("ts_start")` is clean and direct. Every detection carries the robot's pose → we know where the robot saw the cat, which approximates where the cat was. - -**Subtlety**: `row.pose` is the *robot's* pose when it detected the cat, not the cat's position in world frame. If you need the cat's actual position, you'd need the detection bounding box + depth + robot pose to project into world coordinates. That projection would happen in the detection pipeline before appending to the stream: - -```python -# In the detection pipeline: -cat_world_pose = project_to_world(bbox, depth_frame, robot_pose) -detections.append(detection, pose=cat_world_pose, tags={"class_name": "cat"}) -``` - -If stored this way, `row.pose` *is* the cat's world position, and the path is accurate. - -**Dense vs sparse**: If the detector runs at 5Hz and the cat is visible for an hour, that's 18,000 rows. `order_by("ts_start")` + the ts_start index handles this efficiently. 
For rendering, you might want to downsample: - -```python -# Fetch pages to avoid loading all 18k rows at once -for page in range(0, cat_trail_count, 100): - rows = (detections.query() - .filter_tags(class_name="cat") - .filter_after(today_start) - .order_by("ts_start") - .limit(100) # TODO: need offset on limit, or use fetch_set + fetch_page - .fetch()) -``` - -**Gap exposed**: `Query.limit(k)` has no `offset`. For pagination, you'd need `fetch_set()` then `fetch_page(limit=100, offset=N)`. This works but means you can't paginate purely at the query level. - ---- - -## 10. "What happened in the 30 seconds before the vase fell?" - -**What we need**: Detect the "vase fell" event, then slice ALL streams in a 30s window before it. - -**Streams**: `events` (detected events with tags), plus any number of other streams: `images`, `audio`, `detections`, `poses`, etc. - -```python -s = db.session() -events = s.stream("events", Event, text=TextConfig()) - -# Find the vase-fall event -vase_event = events.query().search_text("vase fell").one() -t_event = vase_event.ts_start - -# Now query every stream for the 30s window before the event -# list_streams() returns StreamInfo with payload_type, configs, count -timeline = {} -for info in s.list_streams(): - stream = s.stream(info.name, info.payload_type, - embedding=info.embedding, text=info.text) - window = (stream.query() - .filter_time(t_event - 30.0, t_event) - .order_by("ts_start") - .fetch()) - timeline[info.name] = window -``` - -**What works**: `list_streams()` returns `StreamInfo` with everything needed to reconstruct stream handles — no hardcoding payload types. `filter_time(t - 30, t)` on each stream gives the pre-event window. - ---- - -## 11. "When was the last time I did NOT see the cat in the apartment?" - -**What we need**: Find gaps in the cat detection stream — periods where no cat was detected. 
- -**Streams**: `detections` (object detections with tags) - -```python -s = db.session() -detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) - -# Get all cat detections, ordered by time -cat_seen = (detections.query() - .filter_tags(class_name="cat") - .order_by("ts_start") - .fetch()) - -# Find gaps — periods where the cat wasn't detected -# Gap = time between consecutive cat detections longer than some threshold -gap_threshold = 60.0 # 1 minute without seeing the cat = "not seen" - -timestamps = [r.ts_start for r in cat_seen] -gaps = [] -for i in range(1, len(timestamps)): - gap = timestamps[i] - timestamps[i - 1] - if gap > gap_threshold: - gaps.append((timestamps[i - 1], timestamps[i], gap)) - -if gaps: - # Most recent gap = last time the cat wasn't seen - last_gap = gaps[-1] - print(f"Last not seen: {last_gap[0]} to {last_gap[1]} ({last_gap[2]:.0f}s)") -else: - print("Cat has been visible continuously") -``` - -**What works**: `filter_tags` + `order_by` gives us the detection timeline. Gap analysis in Python is straightforward. - -**What the API can't do natively**: Negation queries ("when did X NOT happen") aren't expressible in the query builder. You can only query for what exists, then find gaps in Python. This is fundamentally correct — the memory system stores positive observations, not the absence of observations. Detecting absence requires knowledge of when the sensor *could* have observed (was the robot even in the apartment? was the camera on?) — that's application context. - -**Edge case**: The robot wasn't always in the apartment. A "gap" might be because the robot was in another room, not because the cat wasn't there. You'd need to cross-reference with the robot's own position to distinguish "didn't see cat because cat was absent" from "didn't see cat because robot was elsewhere." - ---- - -## 12. "What time does the mailman usually come?" - -**What we need**: We don't know who the mailman is. 
We need to discover them first, then find all their appearances, then extract the schedule. - -**Streams**: `images` (camera frames with CLIP embeddings and poses), `faces` (face crops with identity embeddings) - -**Stage 1 — Find the mailman via VLM** (retroactive identification): - -We know the mailman comes to the front door. Use spatial + embedding search to find candidates, VLM to confirm. - -```python -s = db.session() -images = s.stream("images", Image, embedding=EmbeddingConfig(dim=512)) -faces = s.stream("faces", FaceCrop, embedding=EmbeddingConfig(dim=512)) - -mailman_emb = clip_text_encode("person delivering mail at front door") - -# Search images near the front door -candidates = (images.query() - .search_embedding(mailman_emb, candidate_k=200) - .filter_near(front_door_pose, radius=5.0) - .fetch()) - -# Cluster temporally (don't VLM 200 frames of same delivery) -clusters = cluster_observations(candidates, time_scale=60.0) - -# VLM verify representatives -mailman_times = [] -for c in clusters: - img = images.load(c.representative.ref) - if vlm.ask(img, "Is there a person delivering mail or packages? yes/no") == "yes": - mailman_times.append(c) -``` - -**Stage 2 — Extract mailman embedding** (from confirmed sightings): - -Now we know *when* the mailman was there. Find their face embedding from those time windows. - -```python -# For each confirmed mailman visit, find faces near the door at that time -mailman_face_vecs = [] -for c in mailman_times: - nearby_faces = (faces.query() - .filter_time(c.t_start - 5.0, c.t_end + 5.0) - .filter_near(front_door_pose, radius=3.0) - .fetch()) - for f in nearby_faces: - mailman_face_vecs.append(faces.vector(f.ref)) - -# Average the face embeddings → stable mailman identity vector -import numpy as np -mailman_identity = np.mean(mailman_face_vecs, axis=0).tolist() -``` - -**Stage 3 — Search broadly with the discovered embedding**: - -Now we have a face embedding. 
Search ALL face data, not just near the door — catches sightings we might have missed with the spatial filter. - -```python -all_sightings = (faces.query() - .search_embedding(mailman_identity, candidate_k=200) - .fetch()) -sightings = [r for r in all_sightings if r.scores.get("embedding", 0) > 0.8] - -# Cluster into individual visits -visits = cluster_observations(sightings, time_scale=300.0) -``` - -**Stage 4 — Extract schedule**: - -```python -pattern = extract_time_pattern([v.representative for v in visits]) -print(f"Mailman comes at {pattern}") # "10:30 +/- 12min (n=23)" -``` - -**The general pattern — retroactive identification**: -1. **Describe** → CLIP text embedding + spatial constraint to narrow candidates -2. **VLM confirm** → identify positive examples (expensive, but on clustered representatives only) -3. **Extract identity embedding** → from confirmed examples, average face/object embeddings -4. **Search broadly** → use discovered embedding to find all appearances across time -5. **Analyze** → cluster, extract patterns - -This is the inverse of the usual flow (have embedding → search). Here you don't know what you're looking for until you find it via VLM, then bootstrap an embedding for broader retrieval. - -**Cross-session note**: This only works if the DB persists across days (`retention` != `"run"`). For long-term pattern analysis, use a persistent retention policy. - ---- - -## 13. "What did robot-2 observe in the warehouse that I missed?" - -**What we need**: Compare observations between two robots at the same location, find what robot-2 saw that robot-1 (me) didn't. - -**Streams**: Both robots write to the same DB (or DBs are merged). Observations carry `robot_id` in tags. 
- -```python -s = db.session() -detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) - -# What robot-2 saw in the warehouse -robot2_saw = (detections.query() - .filter_tags(robot_id="robot-2") - .filter_near(warehouse_pose, radius=20.0) # warehouse area - .fetch()) - -# What I saw in the same area -my_saw = (detections.query() - .filter_tags(robot_id="robot-1") - .filter_near(warehouse_pose, radius=20.0) - .fetch()) - -# Diff: find objects robot-2 detected that I didn't -# By embedding — for each of robot-2's detections, check if I have a similar one -my_vecs = [detections.vector(r.ref) for r in my_saw] - -missed = [] -for r2_row in robot2_saw: - r2_vec = detections.vector(r2_row.ref) - # Check if any of my detections are similar - similarities = [cosine_sim(r2_vec, mv) for mv in my_vecs] - if not similarities or max(similarities) < 0.7: - missed.append(r2_row) - -print(f"Robot-2 saw {len(missed)} things you missed:") -for m in missed: - print(f" {m.tags.get('class_name')} at {m.pose}") -``` - -**What works**: `filter_tags(robot_id=...)` scopes to a specific robot — `robot_id` lives in the tags JSON, queried via `filter_tags`. `filter_near` scopes to a location. `.vector()` enables cross-robot embedding comparison. No special `filter_robot()` needed. - ---- - -## 14. "How far did I travel while carrying an object?" - -**What we need**: Compute path distance from the pose stream, but only during time intervals when a parallel detection stream shows "carrying object." 
- -**Streams**: `poses` (robot poses at high frequency), `detections` (with "carrying" state) - -```python -s = db.session() -poses = s.stream("poses", PoseStamped) -detections = s.stream("detections", Detection, embedding=EmbeddingConfig(dim=512)) - -# Step 1: Find all time intervals where the robot was carrying an object -carrying = (detections.query() - .filter_tags(action="carrying") - .order_by("ts_start") - .fetch()) - -# Segment into continuous carrying intervals -intervals = [] -if carrying: - seg_start = carrying[0].ts_start - prev_t = carrying[0].ts_start - for r in carrying[1:]: - if r.ts_start - prev_t > 2.0: # gap = stopped carrying - intervals.append((seg_start, prev_t)) - seg_start = r.ts_start - prev_t = r.ts_start - intervals.append((seg_start, prev_t)) - -# Step 2: For each carrying interval, get poses and compute distance -import math -total_distance = 0.0 - -for t_start, t_end in intervals: - pose_rows = (poses.query() - .filter_time(t_start, t_end) - .order_by("ts_start") - .fetch()) - - for i in range(1, len(pose_rows)): - p1 = pose_rows[i - 1].pose - p2 = pose_rows[i].pose - dx = p2.position.x - p1.position.x - dy = p2.position.y - p1.position.y - dz = p2.position.z - p1.position.z - total_distance += math.sqrt(dx * dx + dy * dy + dz * dz) - -print(f"Traveled {total_distance:.2f}m while carrying objects") -``` - -**What works**: `filter_tags` identifies "carrying" intervals. `filter_time` + `order_by` retrieves the pose trail for each interval. Distance computation is simple Euclidean accumulation. - -**This is the cross-stream conditional join pattern**: Query stream A (detections) for intervals, then query stream B (poses) within those intervals. Same nested-loop pattern as Q5/Q6/Q7. 
- -**What would be cleaner with Correlator (Phase 3)**: -```python -# Get pose observations that overlap with carrying detections -carrying_poses = s.correlate(carrying_set, poses, time_tolerance=0.0) -``` - ---- - -## Summary - -| Question | Key API features used | Works well? | -|-------------------------------|-------------------------------------------------------------------------|--------------------------------------| -| Q1 — pose at log time | `.search_text()` + `.at()` | Yes | -| Q2 — continuous observation | `.search_embedding()` + VLM verify + `.order_by()` + segmentation | Yes | -| Q3 — count unique people | `.filter_after()` + `.vector()` + DBSCAN | Yes | -| Q4 — map red socks | `.search_embedding()` + VLM + `.filter_refs()` + costmap | Yes | -| Q5 — door opener | `.filter_tags()` + `.filter_time()` + `.filter_near()` | Yes, cross-stream loop | -| Q6 — STT + voice identity | `.filter_time()` + `.vector()` | Yes | -| Q7 — voice ↔ face | `.filter_time()` + `.vector()` + accumulation | Yes | -| Q8 — room diff | `.filter_time()` + `.filter_near()` + `.vector()` diff | Yes, diffing is app logic | -| Q9 — cat trail | `.filter_tags()` + `.order_by()` + pose path | Yes | -| Q10 — pre-event timeline | `.filter_time()` + `list_streams()` → `StreamInfo` | Yes | -| Q11 — absence detection | `.filter_tags()` + `.order_by()` + gap analysis | Yes, negation is app logic | -| Q12 — mailman schedule | `.search_embedding()` or `.filter_tags()` + time stats | Yes, pattern extraction is app logic | -| Q13 — cross-robot diff | `.filter_tags(robot_id=)` + `.filter_near()` + `.vector()` | Yes | -| Q14 — distance while carrying | `.filter_tags()` + `.filter_time()` + `.order_by()` + pose accumulation | Yes, cross-stream conditional join | - -**API gaps exposed by Q8-Q14**: - -**Remaining gap**: - -| Gap | Affects | Suggestion | -|-----|---------|------------| -| Cross-stream conditional join is always a manual loop | Q5,Q7,Q10,Q14 | Phase 3 `Correlator` — the most motivated 
feature | diff --git a/plans/old/answers_correlator.md b/plans/old/answers_correlator.md deleted file mode 100644 index a6be19d1c8..0000000000 --- a/plans/old/answers_correlator.md +++ /dev/null @@ -1,285 +0,0 @@ -# Answers — with Correlator - -Side-by-side: how the cross-stream questions change with `s.correlate()`. -Only covering questions where correlator applies (Q1, Q5, Q6, Q7, Q10, Q14). - ---- - -## Q1. "Where was I when this log line was added?" - -**Before** (using `.at()`): -```python -log_hit = logs.query().search_text("motor fault detected").one() -pose_hit = poses.query().at(log_hit.ts_start, tolerance=0.5).one() -``` - -**With correlator**: -```python -log_set = logs.query().search_text("motor fault detected").fetch_set() -result = s.correlate(log_set, poses, time_tolerance=0.5) -pose = result.unambiguous()[0].matches[0].pose -``` - -**Verdict**: `.at()` is better here. Correlator adds ceremony for a single-observation lookup. Correlator wins when you have many anchors — e.g., "where was I for ALL error log lines?": - -```python -errors = logs.query().search_text("error").fetch_set() -result = s.correlate(errors, poses, time_tolerance=0.5) -for p in result.with_matches(): - error_text = logs.load(p.anchor.ref).text - print(f"{error_text} → {p.matches[0].pose}") -``` - -That replaces a loop of N `.at()` calls with one batch query. - ---- - -## Q5. "Did anyone open this door? Who?" 
- -**Before** (manual loop): -```python -door_events = events.query().filter_tags(type="door_open").order_by("ts_start").fetch() - -for event in door_events: - nearby = (faces.query() - .filter_time(event.ts_start - 5.0, event.ts_start + 2.0) - .filter_near(event.pose, radius=3.0) - .fetch()) - if nearby: - vec = faces.vector(nearby[0].ref) - identity = lookup_identity(vec) - print(f"Door opened at {event.ts_start} by {identity}") -``` - -**With correlator**: -```python -door_events = events.query().filter_tags(type="door_open").fetch_set() - -pairs = s.correlate(door_events, faces, - time_before=5.0, time_after=2.0, - spatial_radius=3.0) - -for p in pairs.with_matches(): - vec = faces.vector(p.matches[0].ref) - identity = lookup_identity(vec) - print(f"Door opened at {p.anchor.ts_start} by {identity}") - -# Bonus: which door openings had nobody nearby? -for anchor in pairs.unmatched(): - print(f"Door opened at {anchor.ts_start} — nobody detected") -``` - -**What changed**: -- Loop of N queries → 1 batch query -- `.unmatched()` is free — no extra work to find events with zero matches -- Asymmetric window (`time_before=5.0, time_after=2.0`) expresses "who was there just before and shortly after" naturally - ---- - -## Q6. "STT + voice embeddings — who is saying what?" 
- -**Before** (manual loop): -```python -for tx_row in transcripts.query().order_by("ts_start").fetch(): - voice = (voice_embs.query() - .filter_time(tx_row.ts_start, tx_row.ts_end) - .one()) - - speaker_vec = voice_embs.vector(voice.ref) - transcript_text = transcripts.load(tx_row.ref).text - print(f"[{speaker_vec_to_name(speaker_vec)}]: {transcript_text}") -``` - -**With correlator**: -```python -pairs = s.correlate(transcripts, voice_embs, time_tolerance=0.0) - -for p in pairs.with_matches(): - speaker_vec = voice_embs.vector(p.matches[0].ref) - transcript_text = transcripts.load(p.anchor.ref).text - print(f"[{speaker_vec_to_name(speaker_vec)}]: {transcript_text}") - -# Transcripts with no matching voice segment (e.g., gap in audio) -for anchor in pairs.unmatched(): - print(f"[unknown]: {transcripts.load(anchor.ref).text}") -``` - -**What changed**: -- `time_tolerance=0.0` means: target's `ts_start` must fall within anchor's `[ts_start, ts_end]` window. Since transcripts have both `ts_start`/`ts_end`, this matches voice segments that overlap with the spoken words. -- `.unmatched()` catches transcripts where audio processing failed or had gaps — previously silently lost in a `.one()` that would throw. - ---- - -## Q7. 
"Voice ↔ face correlation (partial overlap)" - -**Before** (manual loop + filtering): -```python -pairings = [] -for v_row in voices.query().order_by("ts_start").fetch(): - visible_faces = (faces.query() - .filter_time(v_row.ts_start, v_row.ts_end) - .fetch()) - - if len(visible_faces) == 1: - voice_vec = voices.vector(v_row.ref) - face_vec = faces.vector(visible_faces[0].ref) - pairings.append((voice_vec, face_vec)) -``` - -**With correlator**: -```python -pairs = s.correlate(voices, faces, time_tolerance=0.0) - -# Unambiguous pairings: exactly one face visible during voice segment -pairings = [] -for p in pairs.unambiguous(): - voice_vec = voices.vector(p.anchor.ref) - face_vec = faces.vector(p.matches[0].ref) - pairings.append((voice_vec, face_vec)) - -# Stats for free -total = len(pairs) -matched = len(pairs.with_matches()) -unambiguous = len(pairs.unambiguous()) -unmatched = len(pairs.unmatched()) -print(f"{total} voice segments: {unambiguous} unambiguous, " - f"{matched - unambiguous} ambiguous, {unmatched} no face visible") -``` - -**What changed**: -- `.unambiguous()` replaces the `if len(...) == 1` check -- Statistics about match quality are trivial to compute -- The "I don't see all people speaking at all times" constraint is directly visible in `.unmatched()` count - ---- - -## Q10. "What happened in the 30 seconds before the vase fell?" 
- -**Before** (loop over streams): -```python -vase_event = events.query().search_text("vase fell").one() -t = vase_event.ts_start - -timeline = {} -for info in s.list_streams(): - stream = s.stream(info.name, info.payload_type, - embedding=info.embedding, text=info.text) - window = (stream.query() - .filter_time(t - 30.0, t) - .order_by("ts_start") - .fetch()) - timeline[info.name] = window -``` - -**With correlator**: -```python -vase_set = events.query().search_text("vase fell").fetch_set() - -timeline = {} -for info in s.list_streams(): - stream = s.stream(info.name, info.payload_type, - embedding=info.embedding, text=info.text) - result = s.correlate(vase_set, stream, time_before=30.0, time_after=0.0) - timeline[info.name] = result -``` - -**What changed**: -- `time_before=30.0, time_after=0.0` — asymmetric window expresses "30s before, nothing after" directly. No manual `t - 30.0, t` arithmetic. -- Still loops over streams (correlator is pairwise). But each iteration is cleaner. -- If `vase_set` had multiple events (vase fell twice), you'd get per-event windows for free. The manual version would need a nested loop. - -**Honest assessment**: Marginal improvement for Q10 since the anchor is typically one event. The correlator shines more when you have many anchors. - ---- - -## Q14. "How far did I travel while carrying an object?" 
- -**Before** (segment + loop): -```python -# Step 1: Segment carrying detections into intervals -carrying = (detections.query() - .filter_tags(action="carrying") - .order_by("ts_start") - .fetch()) - -intervals = [] -seg_start = carrying[0].ts_start -prev_t = carrying[0].ts_start -for r in carrying[1:]: - if r.ts_start - prev_t > 2.0: - intervals.append((seg_start, prev_t)) - seg_start = r.ts_start - prev_t = r.ts_start -intervals.append((seg_start, prev_t)) - -# Step 2: For each interval, get poses and sum distance -total_distance = 0.0 -for t_start, t_end in intervals: - pose_rows = (poses.query() - .filter_time(t_start, t_end) - .order_by("ts_start") - .fetch()) - for i in range(1, len(pose_rows)): - total_distance += distance(pose_rows[i-1].pose, pose_rows[i].pose) -``` - -**With correlator**: -```python -carrying = detections.query().filter_tags(action="carrying").fetch_set() - -pairs = s.correlate(carrying, poses, time_tolerance=0.1) - -# Each carrying detection gets matched to nearby poses -# Deduplicate: collect all unique matched pose refs, sorted by time -seen_pose_refs = set() -all_poses = [] -for p in pairs: - for m in p.matches: - if m.ref.id not in seen_pose_refs: - seen_pose_refs.add(m.ref.id) - all_poses.append(m) - -all_poses.sort(key=lambda r: r.ts_start) - -total_distance = 0.0 -for i in range(1, len(all_poses)): - total_distance += distance(all_poses[i-1].pose, all_poses[i].pose) -``` - -**Honest assessment**: The correlator version is *not* cleaner here. The problem is that carrying detections are per-frame (one every 0.2s at 5Hz), so you get thousands of overlapping CorrelationPairs that all match the same poses. You need deduplication, which is awkward. - -The original approach is actually better: segment into intervals first (app logic), then do one time-range query per interval. Correlator is designed for "match discrete events to another stream", not "define continuous intervals and query within them." 
- -**When correlator WOULD help for Q14**: If carrying detections had `ts_start`/`ts_end` representing the full carry interval (not per-frame), then: - -```python -# If each carrying observation spans the full interval -carry_intervals = detections.query().filter_tags(action="carrying").fetch_set() - -pairs = s.correlate(carry_intervals, poses, time_tolerance=0.0) -total_distance = 0.0 -for p in pairs: - sorted_poses = sorted(p.matches, key=lambda r: r.ts_start) - for i in range(1, len(sorted_poses)): - total_distance += distance(sorted_poses[i-1].pose, sorted_poses[i].pose) -``` - -Clean — but requires interval-shaped observations. The segmentation from point detections to intervals is still app logic. - ---- - -## Summary - -| Q | Before | With Correlator | Improvement | -|---|--------|----------------|-------------| -| Q1 | `.at()` — 1 query | overkill for single lookup | None (`.at()` is better) | -| Q1 batch | N `.at()` calls | 1 batch query | Yes — N→1 queries | -| Q5 | N queries in loop | 1 batch + `.with_matches()` / `.unmatched()` | Yes — cleaner + free stats | -| Q6 | N queries in loop | 1 batch + `.unmatched()` catches gaps | Yes — cleaner + error visibility | -| Q7 | N queries + `if len==1` | 1 batch + `.unambiguous()` | Yes — most natural fit | -| Q10 | N queries (1 per stream) | N correlate calls, asymmetric window | Marginal — still loops over streams | -| Q14 | segment + N queries | messy dedup of overlapping pairs | No — manual approach is better for continuous intervals | - -**Key insight**: Correlator is best for **discrete events correlated against another stream** (Q5, Q6, Q7). It's less useful for continuous intervals (Q14) or single-observation lookups (Q1). The sweet spot is "I have 50-5000 anchors and want matches from another stream for each." - -**API validated**: `time_before`/`time_after` asymmetric windows are needed (Q5, Q10). `.unambiguous()` and `.unmatched()` are the most-used convenience methods. 
diff --git a/plans/old/correlator.md b/plans/old/correlator.md deleted file mode 100644 index 4141506f68..0000000000 --- a/plans/old/correlator.md +++ /dev/null @@ -1,225 +0,0 @@ -# Correlator - -Cross-stream temporal+spatial join for Memory2. - -## Motivation - -5 of 14 usage questions (Q5, Q6, Q7, Q10, Q14) require the same pattern: - -```python -for anchor in stream_a.query().fetch(): - matches = (stream_b.query() - .filter_time(anchor.ts_start - tol, anchor.ts_end + tol) - .filter_near(anchor.pose, radius=r) - .fetch()) - # do something with (anchor, matches) -``` - -This is a nested-loop join — N queries, one per anchor observation. Correlator replaces it with a single batch operation. - -## API - -Method on Session: - -```python -class Session: - def correlate( - self, - anchors: Stream | ObservationSet, - targets: Stream | ObservationSet, - *, - time_tolerance: float | None = None, # symmetric: sets both before and after - time_before: float | None = None, # asymmetric: window before anchor ts_start - time_after: float | None = None, # asymmetric: window after anchor ts_end - spatial_radius: float | None = None, - ) -> CorrelationResult: ... -``` - -Accepts Stream (correlate everything) or ObservationSet (correlate a filtered subset). - -**Time window per anchor**: `[ts_start - time_before, ts_end + time_after]`. If `ts_end` is None, uses `ts_start` for both. `time_tolerance` is shorthand for `time_before = time_after = time_tolerance`. Explicit `time_before`/`time_after` override `time_tolerance`. - -### CorrelationResult - -```python -@dataclass -class CorrelationPair: - anchor: ObservationRow - matches: list[ObservationRow] - -class CorrelationResult: - def __iter__(self) -> Iterator[CorrelationPair]: ... - def __len__(self) -> int: ... - - # Filter by match cardinality - def unambiguous(self) -> list[CorrelationPair]: - """Pairs where exactly one target matched.""" - ... 
- - def with_matches(self) -> list[CorrelationPair]: - """Pairs where at least one target matched.""" - ... - - def unmatched(self) -> list[ObservationRow]: - """Anchor observations with zero matches.""" - ... -``` - -### Usage - -**Q5 — Who opened the door?** -```python -door_events = events.query().filter_tags(type="door_open").fetch_set() - -pairs = s.correlate(door_events, faces, time_tolerance=5.0, spatial_radius=3.0) -for p in pairs.with_matches(): - identity = identify_face(faces.vector(p.matches[0].ref)) - print(f"Door opened at {p.anchor.ts_start} — {identity}") -``` - -**Q7 — Voice ↔ face (unambiguous only)** -```python -pairs = s.correlate(voices, faces, time_tolerance=0.5) -for p in pairs.unambiguous(): - voice_vec = voices.vector(p.anchor.ref) - face_vec = faces.vector(p.matches[0].ref) - pairings.append((voice_vec, face_vec)) -``` - -**Q10 — Pre-event timeline (30s before, nothing after)** -```python -vase_event = events.query().search_text("vase fell").fetch_set() - -timeline = {} -for info in s.list_streams(): - stream = s.stream(info.name, info.payload_type, - embedding=info.embedding, text=info.text) - result = s.correlate(vase_event, stream, time_before=30.0, time_after=0.0) - timeline[info.name] = result -``` - -**Q14 — Distance while carrying** -```python -carrying = detections.query().filter_tags(action="carrying").fetch_set() - -pairs = s.correlate(carrying, poses, time_tolerance=0.0) -total_distance = 0.0 -for p in pairs: - sorted_poses = sorted(p.matches, key=lambda r: r.ts_start) - for i in range(1, len(sorted_poses)): - total_distance += distance(sorted_poses[i-1].pose, sorted_poses[i].pose) -``` - -## Implementation - -### SQL batch join (single query instead of N) - -```sql --- 1. Materialize anchors into temp table -CREATE TEMP TABLE _corr_anchors ( - anchor_id TEXT, - ts_lo REAL, -- ts_start - time_before - ts_hi REAL, -- (ts_end or ts_start) + time_after - pose_x REAL, - pose_y REAL, - pose_z REAL -); - --- 2. 
Join to target stream's _meta on time overlap -SELECT a.anchor_id, t.* -FROM _corr_anchors a -JOIN {target}_meta t - ON t.ts_start >= a.ts_lo AND t.ts_start <= a.ts_hi -ORDER BY a.anchor_id, t.ts_start; -``` - -With spatial constraint, add R*Tree join: - -```sql -SELECT a.anchor_id, t.* -FROM _corr_anchors a -JOIN {target}_rtree r - ON r.min_x >= a.pose_x - :radius AND r.max_x <= a.pose_x + :radius - AND r.min_y >= a.pose_y - :radius AND r.max_y <= a.pose_y + :radius - AND r.min_z >= a.pose_z - :radius AND r.max_z <= a.pose_z + :radius - AND r.min_t >= a.ts_lo AND r.max_t <= a.ts_hi -JOIN {target}_meta t ON t.rowid = r.rowid -ORDER BY a.anchor_id, t.ts_start; -``` - -### Grouping - -The SQL returns flat rows sorted by `anchor_id`. Group in Python into `CorrelationPair`s: - -```python -pairs = [] -current_anchor = None -current_matches = [] -for row in cursor: - anchor_id = row["anchor_id"] - if anchor_id != current_anchor: - if current_anchor is not None: - pairs.append(CorrelationPair(anchor=..., matches=current_matches)) - current_anchor = anchor_id - current_matches = [] - current_matches.append(to_observation_row(row)) -``` - -### Performance - -- Temp table insert: O(A) where A = anchor count -- Join: SQLite uses the ts_start index (or R*Tree) on the target → O(A × log(T)) where T = target count -- vs naive loop: O(A × T) without indexes, O(A × log(T)) with indexes but A round-trips - -The batch approach saves round-trip overhead and lets SQLite optimize the join plan. For A=1000 anchors, that's 1 query vs 1000. 
- -## Types - -```python -# In types.py -@dataclass -class CorrelationPair: - anchor: ObservationRow - matches: list[ObservationRow] -``` - -```python -# In correlation.py -class CorrelationResult: - _pairs: list[CorrelationPair] - - def __iter__(self): return iter(self._pairs) - def __len__(self): return len(self._pairs) - - def unambiguous(self) -> list[CorrelationPair]: - return [p for p in self._pairs if len(p.matches) == 1] - - def with_matches(self) -> list[CorrelationPair]: - return [p for p in self._pairs if p.matches] - - def unmatched(self) -> list[ObservationRow]: - return [p.anchor for p in self._pairs if not p.matches] -``` - -## File structure update - -``` -dimos/memory2/ - ... - correlation.py # CorrelationResult, correlate() implementation -``` - -## Phase - -This can be Phase 2b — after Stream + Query + ObservationSet are working. It depends on: -- ObservationSet (anchors can be a set) -- Stream._meta table schema (join target) -- Session.execute() (raw SQL for the batch join) - -No dependency on derive(), CompositeBacking, or retention. - -## Not in scope - -- **Continuous/streaming correlation** — this is one-shot batch. Live correlation (new anchor arrives → auto-query targets) is a different abstraction. -- **Multi-stream correlation** — correlate(A, [B, C, D]) returning aligned tuples. Call correlate() multiple times instead. -- **Embedding cross-match** — correlation is time+space only. "Find similar embeddings across streams" is a different operation (use search_embedding on each stream). diff --git a/plans/old/memory.md b/plans/old/memory.md deleted file mode 100644 index 447e7d3449..0000000000 --- a/plans/old/memory.md +++ /dev/null @@ -1,113 +0,0 @@ - -## Goals -We are building this (seriously - exactly this) https://www.youtube.com/watch?v=Zkj5WSae3Uc - -We are designing a human-centric interface, not an agentic interface for now. 
-Good human centric interface allows us to test our own tooling before giving it to an agent, - -Reason for this is that if you give something directly to an agent without using it yourself, it might be shit and agent might be rightfully underperforming. - -## all sensor data is stored by default for every run - auto rotation/cleanup, set max_size - -## Data streams - -- These are sensor streams (all sensor data is stored by default for every run) -- But also other data streams created in real time or async. - -### All datapoints that have a temporal index (are `Timestamped`) - -for all temporal datapoints we need to be able to get spatial info for - this is important for multi embodiment, we do this either by having robot_id associated to the stream, - -robots can see the same shoe from different angles, we can deduplicate once temporal or spatial matches are there - -or by directly storing a 4D index or time + 3D index. how we actually store stuff is not important and storage system dependent, how we query is what we care about. - -so I can quickly ask for a video frame, where/when it was captured. 
-I can detect on top of it, fetch a related LIDAR frame, project - -### Search - -Different datastreams provide different types of search, text or image - -Facial recognition datastream can accept a face search, time or space -Agent narration or Video stream can search by vector embedding, time or space -Sound recording, search by time or space - -Some of these abilities imply other types of search, being able to accept embedding search means you can search by text or by image as well - -## Reprocessing, parallel streams - -different algos can create different new datastreams (like embedding models for example, LLM narration etc) -some of these datastreams are slower than realtime, with ability to catch up (like embeddings aren't generated if robot isn't moving) some of these are to be stored permanently, some are temporary and part of some analysis and will be thrown away. - -if this is designed well, on API level we don't care if we are dealing with a stored stream or a search result, we don't care if stream is stored (and where) or in memory as part of some analysis etc. - -### Example - -speaker clustering model is analyzing audio, gives speaker embedding stream (with temporal/spatial index) -correlating facial recognition embeddings to speech embeddings we can match a face to voice - -## Semantic Costmaps - -overlay semantic similarity onto a costmap rendered in rerun in realtime - -## Object Search - -We have many frames to analyze with VLM, analysis is costly (but cheaper if batched!) -So we need to use traditional search algos, use semantic similarity as a heuristic, find hotspots in time and space to analyze with VLMs (just some standard hill climbing, simulated annealing and such. 
keep in mind we might not be looking for a global optimum but local hills) we can also use clustering algos etc - -Once best matches are found, project into 3d - -## logs - -system logs, human-agent-tool interaction are also temporal/textual streams that can be stored, embedded, searched over - -### Embedding data streams - -# milestone 1 - -I can query for "a shoe" in a textbox, get a semantic map overlay - -# milestone 2 - -I can query for "a shoe" in a textbox, get PointStamped for shoes detected by VLM - -## example interaction 1: memory search - -search for "a shoe" - independent stored streams offer textual queries - -3 agent narration matches (temporal textual stream2) -1 tool call match (temporal textual stream 2) -temporal-semantic graph returned (image embeddings) - -temporal-spatial-semantic graph analysis - 3 clusters identified, feed each cluster to some description VLM - "a shoe on a kitchen floor", "a shoe on a desk" etc - -return to an agent: - -- narration block, timestamp -- tool call match, timestamp -- return 3 clusters, timestamps, rough locations - -agent calls goto (event cluster 3) - -cluster 3 - find best image, correlate to lidar, project into space, navigate, once there, use VLM and visual nav - -## example interaction 2: arm - -mustafa is able to ask for an object in proximity to the robot. robot searches memory biasing distance in time and space. 
if close match is not found, search can be expanded - -"do you remember the red sock" - -"yes I saw it 35 seconds ago" - -"yes I saw it 3 days ago behind me" - -"yes I saw it an hour ago, it was 15 meters away" - - -# Questions - -"where was I, when this log line was added" - -"how long for have I been observing this object" diff --git a/plans/old/memory1.md b/plans/old/memory1.md deleted file mode 100644 index fd2ddb7e83..0000000000 --- a/plans/old/memory1.md +++ /dev/null @@ -1,318 +0,0 @@ -# DB → Session → Store: DimOS Memory2 - -## Context - -PR #1080 introduced `TimeSeriesStore[T]` with pluggable backends. Paul's review identified it mixes DB lifecycle, connection, and query concerns. Additionally, `memory.md` describes a system where all sensor data is stored as temporal streams with 4D spatial-temporal indexing, cross-stream correlation is the primary operation, and search (text/embedding) must work across streams. This plan builds a clean 3-layer architecture from scratch in `dimos/memory2/`, SQLite-first, with R\*Tree indexing for spatial-temporal queries. - -## Architecture - -``` -SqliteDB (config + factory + WAL + sqlite-vec + R*Tree) - └─ Session (connection, thread-bound) - ├─ .timeseries(table, type) → TimeSeries[T] (temporal store + optional 4D spatial index) - ├─ .embeddings(table, dim) → EmbeddingStore (KNN search store + optional spatial index) - ├─ .at(t, *stores) → tuple (multi-stream temporal lookup) - ├─ .between(t1, t2, *stores)→ Iterator[tuple] (batch temporal join) - └─ .execute(sql, params) → rows (raw SQL escape hatch) -``` - -Every stream gets an R\*Tree (4D: time + xyz). Spatial info is optional per-row — rows without spatial data are indexed by time only (x/y/z set to NaN sentinels or excluded). This eliminates the need for cross-stream pose joins: each datapoint carries its own spatial context at write time. 
- -## API Examples - -```python -db = SqliteDB("run_001.db") - -with db.session() as s: - images = s.timeseries("color_images", Image) - poses = s.timeseries("poses", PoseStamped) - lidar = s.timeseries("lidar", PointCloud) - img_emb = s.embeddings("image_embeddings", dim=512) - - # --- Save with optional spatial context --- - images.save(frame) # temporal only - images.save(frame, pose=robot_pose) # temporal + spatial (baked in) - - # --- Temporal queries (chainable) --- - hit = images.at(now).one() # closest to now → Hit | None - hit = images.at(now, tolerance=0.1).one() # within 100ms or None - hit = images.before(now).one() # last item before now - hit = images.last() # most recent (shortcut) - - # Lazy fetch actual data from Hit - image = images.load(hit.ts) # → Image - - # --- Spatial queries (R*Tree, chainable) --- - hits = images.near(Point(1, 2, 3), radius=0.5).fetch() - hits = images.near(robot_pose, radius=2.0).between(t1, t2).fetch() - - # Each hit has pose (full 6DOF) for reconstruction - for hit in hits: - print(f"Seen at {hit.pose}, dist={hit.spatial_distance}m") - - # --- Embedding search (chainable) --- - query_vec = clip.encode_text("a shoe") - - # Embedding only - hits = img_emb.search(query_vec, k=20).fetch() - - # Embedding + spatial - hits = img_emb.search(query_vec, k=10).near(robot_pose, radius=3.0).fetch() - - # Embedding + temporal - hits = img_emb.search(query_vec, k=10).between(t1, t2).fetch() - - # All three: embedding + spatial + temporal - hits = (img_emb.search(query_vec, k=10) - .near(robot_pose, radius=5.0) - .between(now - 3600, now) - .fetch()) - - for hit in hits: - hit.ts # when - hit.pose # where + orientation (6DOF) - hit.embedding_distance # similarity score - hit.spatial_distance # meters from query point - image = images.at(hit.ts).one() # correlate to image stream - vec = img_emb.load_embedding(hit.id) # lazy fetch embedding - - # --- Cross-stream temporal lookup --- - pose_hit = poses.at(hit.ts).one() - - # --- Raw SQL 
escape hatch --- - rows = s.execute("SELECT ... FROM ... JOIN ...", params) -``` - -## File Structure - -``` -dimos/memory2/ - __init__.py # public exports - _sql.py # _validate_identifier(), shared SQL helpers - db.py # DB ABC + SqliteDB - session.py # Session ABC + SqliteSession - hit.py # Hit hierarchy (7 classes: Hit, Temporal, Spatial, Embedding, combos) - query.py # Query hierarchy (7 classes: matching Hit types, chainable) - timeseries.py # TimeSeries[T] ABC + SqliteTimeSeries - embeddings.py # EmbeddingStore ABC + SqliteEmbeddingStore - test_memory2.py # tests -``` - -## Interfaces - -### DB (`db.py`) - -```python -class DB(Resource, ABC): - def session(self) -> Session: ... - def close(self) -> None: ... # closes all tracked sessions - # Resource protocol - def start(self) -> None: pass # usable after __init__ - def stop(self) -> None: self.close() -``` - -`SqliteDB`: -- Stores file path, creates parent dirs on first connect -- `_connect()`: `sqlite3.connect()`, enables WAL mode, loads sqlite-vec -- Tracks sessions via `WeakSet` for cleanup -- `:memory:` uses `file::memory:?cache=shared` URI so sessions share data - -### Session (`session.py`) - -```python -class Session(ABC): - def timeseries(self, table: str, type: type[T]) -> TimeSeries[T]: ... - def embeddings(self, table: str, dim: int) -> EmbeddingStore: ... - def execute(self, sql: str, params=()) -> list: ... - def close(self) -> None: ... - def __enter__ / __exit__ # context manager -``` - -`SqliteSession`: -- Holds one `sqlite3.Connection` -- `timeseries()` / `embeddings()` validate table name, create store, cache it -- `execute()`: raw SQL passthrough -- Cross-stream correlation done via Query builder (e.g. 
`poses.at(hit.ts).one()`) - -### TimeSeries (`timeseries.py`) - -```python -from dimos.msgs.geometry_msgs.Pose import Pose, PoseLike -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.geometry_msgs.Point import Point - -# --- Hit type hierarchy (type-state, 7 classes) --- - -@dataclass -class Hit: - """Base result: just ts + optional pose. All data lazy-fetched.""" - ts: float - pose: Pose | None = None - -@dataclass -class TemporalHit(Hit): - temporal_distance: float = 0.0 # |query_time - ts| - -@dataclass -class SpatialHit(Hit): - spatial_distance: float = 0.0 # meters from query point - pose: Pose = field(default=...) # guaranteed present for spatial hits - -@dataclass -class EmbeddingHit(Hit): - embedding_distance: float = 0.0 # cosine/L2 in embedding space - id: str = "" - metadata: dict | None = None - -# Combinations (multiple inheritance) -@dataclass -class TemporalSpatialHit(TemporalHit, SpatialHit): ... - -@dataclass -class TemporalEmbeddingHit(TemporalHit, EmbeddingHit): ... - -@dataclass -class SpatialEmbeddingHit(SpatialHit, EmbeddingHit): ... - -@dataclass -class FullHit(TemporalHit, SpatialHit, EmbeddingHit): ... - -# --- Query type-state hierarchy (7 classes, narrows on chain) --- - -class Query: - """Base query builder. Accumulates filters, executes on .fetch().""" - def fetch(self, limit: int | None = None) -> list[Hit]: ... - def one(self) -> Hit | None: ... - def count(self) -> int: ... - -class TemporalQuery(Query): - def near(self, point: Point | PoseLike | PoseStamped, - radius: float) -> TemporalSpatialQuery: ... - def fetch(self, limit=None) -> list[TemporalHit]: ... - def one(self) -> TemporalHit | None: ... - -class SpatialQuery(Query): - def at(self, t: float, tolerance: float | None = None) -> TemporalSpatialQuery: ... - def before(self, t: float) -> TemporalSpatialQuery: ... - def after(self, t: float) -> TemporalSpatialQuery: ... - def between(self, t1: float, t2: float) -> TemporalSpatialQuery: ... 
- def fetch(self, limit=None) -> list[SpatialHit]: ... - def one(self) -> SpatialHit | None: ... - -class EmbeddingQuery(Query): - def near(self, ...) -> SpatialEmbeddingQuery: ... - def at(self, ...) -> TemporalEmbeddingQuery: ... - def between(self, ...) -> TemporalEmbeddingQuery: ... - def fetch(self, limit=None) -> list[EmbeddingHit]: ... - -class TemporalSpatialQuery(Query): - def fetch(self, limit=None) -> list[TemporalSpatialHit]: ... - -class TemporalEmbeddingQuery(Query): - def near(self, ...) -> FullQuery: ... - def fetch(self, limit=None) -> list[TemporalEmbeddingHit]: ... - -class SpatialEmbeddingQuery(Query): - def at(self, ...) -> FullQuery: ... - def between(self, ...) -> FullQuery: ... - def fetch(self, limit=None) -> list[SpatialEmbeddingHit]: ... - -class FullQuery(Query): - def fetch(self, limit=None) -> list[FullHit]: ... - -# All query logic (SQL generation) lives in base Query. -# Subclasses only override type signatures — no duplicated logic. - -# --- TimeSeries --- - -class TimeSeries(Generic[T], ABC): - # Write - def save(self, *items: T, pose: PoseLike | PoseStamped | None = None) -> None: ... - - # Start a query chain (returns typed Query) - def at(self, t: float, tolerance: float | None = None) -> TemporalQuery: ... - def before(self, t: float) -> TemporalQuery: ... - def after(self, t: float) -> TemporalQuery: ... - def between(self, t1: float, t2: float) -> TemporalQuery: ... - def near(self, point: Point | PoseLike | PoseStamped, - radius: float) -> SpatialQuery: ... - - # Convenience terminals (no chain needed) - def last(self) -> TemporalHit | None: ... - def first(self) -> TemporalHit | None: ... - - # Lazy data fetch (from Hit.ts) - def load(self, ts: float) -> T | None: ... - - def delete(self, t: float) -> bool: ... - def count(self) -> int: ... -``` - -All spatial parameters accept DimOS types with `.x`, `.y`, `.z` — `Point`, `Pose`, `PoseStamped`, `PoseLike`. 
Full pose (with orientation) stored per row for post-filter reconstruction. - -`SqliteTimeSeries`: -- Data table: `CREATE TABLE {table} (rowid INTEGER PRIMARY KEY, timestamp REAL NOT NULL, data BLOB NOT NULL)` -- R\*Tree: `CREATE VIRTUAL TABLE {table}_rtree USING rtree(id, min_t, max_t, min_x, max_x, min_y, max_y, min_z, max_z)` -- R\*Tree `id` matches `rowid` in data table -- `save(item, pose=p)`: inserts data row + R\*Tree entry with `(ts, ts, x, x, y, y, z, z)` (point) -- `save(item)` without pose: inserts data row + R\*Tree entry with time only (x/y/z set to ±inf to match any spatial query) -- `at()`: `SELECT data FROM {table} ORDER BY ABS(timestamp - ?) LIMIT 1` -- `between()`: `SELECT data FROM {table} WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp` -- `near()`: `SELECT d.data FROM {table} d JOIN {table}_rtree r ON d.rowid = r.id WHERE r.min_t >= ? AND r.max_t <= ? AND r.min_x >= ? AND r.max_x <= ? ...` -- Lazy table creation on first operation - -### EmbeddingStore (`embeddings.py`) - -```python -class EmbeddingStore(ABC): - def save(self, id: str, vector: np.ndarray, timestamp: float, - pose: PoseLike | PoseStamped | None = None, - metadata: dict | None = None) -> None: ... - - # Start a query chain with embedding search (returns typed Query) - def search(self, query: np.ndarray, k: int = 10) -> EmbeddingQuery: ... - - # Chain: .search(vec, 10).near(p, 3.0).between(t1, t2).fetch() → list[FullHit] - - # Lazy fetch - def load_embedding(self, id: str) -> np.ndarray | None: ... - - def delete(self, id: str) -> bool: ... - def count(self) -> int: ... -``` - -Uses the same `Query` builder and `Hit` result type as TimeSeries. `.search()` returns a Query with embedding filter set; chain `.near()`, `.between()`, etc. to add spatial/temporal constraints. 
- -`SqliteEmbeddingStore`: -- Three tables: `{table}_vec` (sqlite-vec virtual, `float[dim]`), `{table}_meta` (rowid, id, timestamp, x, y, z, metadata JSON), `{table}_rtree` (R\*Tree for spatial-temporal filtering) -- `search()`: KNN via `{table}_vec MATCH ?`, joined with meta for time/spatial filters -- `near=` param: pre-filters candidates via R\*Tree before KNN -- Each `SearchHit` carries position (x, y, z) directly — no pose join needed - -## SQLite Details - -- **WAL mode**: enabled on first connection per DB file. Allows concurrent readers + one writer across threads. -- **R\*Tree**: built into SQLite (compile-time option, enabled by default). Every stream gets a 4D R\*Tree (time + xyz). No extra extension needed. -- **sqlite-vec**: loaded via `conn.load_extension()`. Required for EmbeddingStore. TimeSeries works without it. -- **Thread safety**: each session = one connection = one thread. No `check_same_thread=False`. -- **Pickle BLOBs**: same serialization as current SqliteTSStore. Works with any `Timestamped` subclass. -- **Spatial data without pose**: rows saved without `pose=` get R\*Tree entry with x/y/z bounds set to ±1e38 (effectively unbounded), so they match any spatial query but don't constrain results. - -## Implementation Order - -1. `_sql.py` — `_validate_identifier()` -2. `hit.py` — `Hit` dataclass (unified result type) -3. `query.py` — `Query` builder (accumulates filters, generates SQL, returns `list[Hit]`) -4. `timeseries.py` — `TimeSeries[T]` ABC + `SqliteTimeSeries` (chain methods return Query) -5. `embeddings.py` — `EmbeddingStore` ABC + `SqliteEmbeddingStore` (.search() returns Query) -6. `session.py` — `Session` ABC + `SqliteSession` -7. `db.py` — `DB` ABC + `SqliteDB` (config, connect, WAL, sqlite-vec, Resource) -8. `__init__.py` — public exports -9. `test_memory2.py` — tests: lifecycle, temporal/spatial/embedding queries, combined chains, lazy fetch -10. `pyproject.toml` — add `sqlite-vec` dependency - -## Verification - -1. 
`uv run pytest dimos/memory2/test_memory2.py -v` — all new tests pass -2. `uv run mypy dimos/memory2/` — type checks clean -3. Existing `dimos/memory/timeseries/test_base.py` still passes (untouched) diff --git a/plans/old/memory2.md b/plans/old/memory2.md deleted file mode 100644 index 816d2cb0af..0000000000 --- a/plans/old/memory2.md +++ /dev/null @@ -1,898 +0,0 @@ -# DimOS Memory2 Spec v2.1 - -Status: implementation-oriented draft for a coding agent. - -This spec is intentionally code/example focused. It defines the public API shape, core invariants, and the minimum execution model needed to implement a useful local-first multimodal memory system. - ---- - -# 0. Goals - -Memory2 stores and queries multimodal robot observations. - -Primary use cases: - -1. Store raw streams: images, lidar, poses, logs, narration. -2. Generate streams from streams: embeddings from images, captions from images, detections from frames. -3. Narrow data without loading payloads: top-k matches, time windows, spatial subsets. -4. Re-query narrowed results. -5. Correlate across streams. -6. Keep payload loading lazy. - -Non-goal: - -- Do not implement high-level search policies here (hotspot search, VLM orchestration, semantic map UI). - ---- - -# 1. Terminology - -## 1.1 Observation - -A single stored item. - -Examples: - -- one RGB frame -- one lidar scan -- one log line -- one CLIP embedding -- one VLM caption - -## 1.2 Stream - -Appendable collection of observations with a shared payload type and capability set. - -Examples: - -- `rgb_front` -- `lidar_mid360` -- `robot_pose` -- `tool_logs` -- `image_embeddings_clip` - -## 1.3 ObservationSet - -A lazy, read-only, queryable view over observations. 
- -Important: - -- an `ObservationSet` is **not** a Python set -- it is usually **lazy** -- it usually contains **refs + metadata**, not payloads -- it may represent a subset of one stream or a projection/correlation over multiple streams - -## 1.4 DerivedStream - -A stream generated from upstream streams or observation sets. - -Examples: - -- embeddings generated from images -- captions generated from images -- detections generated from frames - -Rule: - -- same observation identity -> `ObservationSet` -- new observation identity -> `DerivedStream` - ---- - -# 2. Core invariants - -These are hard requirements. - -## 2.1 Stable identity - -Every observation has a stable reference independent of timestamp. - -```python -from dataclasses import dataclass - -@dataclass(frozen=True) -class ObservationRef: - stream: str - id: str -``` - -Never use timestamp as the primary load key. - -Bad: - -```python -images.load(hit.ts) -``` - -Good: - -```python -images.load(hit.ref) -``` - -## 2.2 Payloads are lazy - -Queries and observation sets must not load full payloads unless explicitly requested. - -Examples of payloads that must stay lazy: - -- images -- point clouds -- audio chunks -- voxel blocks - -## 2.3 Metadata may be materialized - -It is acceptable to materialize lightweight metadata for result sets: - -- ref -- timestamp -- pose -- scores -- tags -- lineage pointers - -## 2.4 Query results are re-queryable - -A narrowed result should still support `.query()` and further filtering/ranking. - -## 2.5 Query results are not appendable - -`ObservationSet` is read-only. - -Only `Stream` is appendable. - -## 2.6 Spatially unknown != spatially everywhere - -Unlocalized observations do not match spatial queries by default. - -## 2.7 Derived outputs must carry lineage - -Any derived stream should record parent streams and parent refs or parent query provenance. - ---- - -# 3. Public API - -## 3.1 Top-level objects - -```python -class DB: ... -class Session: ... 
-class Stream[T]: ... -class ObservationSet[T]: ... -class Query[T]: ... -class Correlator: ... -``` - -## 3.2 Shared read/query protocol - -`Stream` and `ObservationSet` should share the same read/query protocol. - -```python -from typing import Protocol, Iterable, Iterator, Generic, TypeVar, Any - -T = TypeVar("T") - -class QueryableObservationSet(Protocol, Generic[T]): - def query(self) -> "Query[T]": ... - def load(self, ref: ObservationRef) -> T: ... - def load_many(self, refs: list[ObservationRef], *, batch_size: int = 32) -> list[T]: ... - def iter_meta(self, *, page_size: int = 128) -> Iterator[list["ObservationRow"]]: ... - def count(self) -> int: ... - def capabilities(self) -> set[str]: ... -``` - -`Stream` extends this with append/introspection. - ---- - -# 4. Core data structures - -## 4.1 Observation metadata - -```python -from dataclasses import dataclass, field -from typing import Any - -@dataclass -class Pose: - xyz: tuple[float, float, float] - quat_xyzw: tuple[float, float, float, float] | None = None - -@dataclass -class ObservationMeta: - ref: ObservationRef - ts_start: float | None = None - ts_end: float | None = None - robot_id: str | None = None - frame_id: str | None = None - pose: Pose | None = None - pose_source: str | None = None - pose_confidence: float | None = None - transform_version: str | None = None - timestamp_uncertainty: float | None = None - payload_codec: str | None = None - payload_size_bytes: int | None = None - tags: dict[str, Any] = field(default_factory=dict) -``` - -Notes: - -- point observations use `ts_start == ts_end` -- interval observations use `[ts_start, ts_end]` -- `pose` is a denormalized snapshot for fast filtering -- provenance fields allow later reinterpretation after better localization - -## 4.2 Query/ObservationSet row - -An `ObservationSet` should expose rows with lightweight metadata and scores. 
- -```python -@dataclass -class ObservationRow: - ref: ObservationRef - ts_start: float | None = None - ts_end: float | None = None - pose: Pose | None = None - scores: dict[str, float] = field(default_factory=dict) - metadata: dict[str, Any] = field(default_factory=dict) -``` - -Expected score keys: - -- `embedding_distance` -- `text_rank` -- `spatial_distance` -- `temporal_distance` -- `final_rank` - -## 4.3 Lineage - -```python -@dataclass -class Lineage: - parents: list[str] = field(default_factory=list) - parent_refs: list[ObservationRef] = field(default_factory=list) - query_repr: str | None = None - transform_name: str | None = None - transform_version: str | None = None -``` - -This can be attached to streams, rows, or derived outputs. - ---- - -# 5. Stream API - -## 5.1 Stream creation - -```python -with db.session() as s: - images = s.stream( - name="rgb_front", - payload_type=Image, - capabilities={"temporal", "spatial", "load"}, - retention="run", - ) - - logs = s.stream( - name="tool_logs", - payload_type=str, - capabilities={"temporal", "text", "load"}, - retention="run", - ) - - image_embeddings = s.stream( - name="image_embeddings_clip", - payload_type=Embedding, - capabilities={"temporal", "spatial", "embedding", "load"}, - retention="derived", - config={"dim": 512, "metric": "cosine"}, - ) -``` - -## 5.2 Stream interface - -```python -class Stream(QueryableObservationSet[T], Generic[T]): - def append(self, payload: T, **meta: Any) -> ObservationRef: ... - def append_many(self, payloads: Iterable[T], metas: Iterable[dict[str, Any]]) -> list[ObservationRef]: ... - def meta(self, ref: ObservationRef) -> ObservationMeta: ... - def info(self) -> dict[str, Any]: ... - def stats(self) -> dict[str, Any]: ... - def retention(self) -> str: ... 
-``` - -## 5.3 Append examples - -```python -frame_ref = images.append( - frame, - ts_start=now, - ts_end=now, - robot_id="go2_01", - frame_id="map", - pose=current_pose, - pose_source="slam_localization", - transform_version="loc_epoch_17", -) - -log_ref = logs.append( - "planner timeout on task 42", - ts_start=now, - ts_end=now, - tags={"level": "warning", "module": "planner"}, -) -``` - ---- - -# 6. ObservationSet API - -## 6.1 Design intent - -`ObservationSet` is the key abstraction for narrowed/re-queryable results. - -It should: - -- be lazy by default -- usually avoid payload loading -- support `.query()` -- support loading payloads one-by-one or in batches -- support projection to related streams -- support materialization when needed - -## 6.2 Interface - -```python -class ObservationSet(QueryableObservationSet[T], Generic[T]): - def refs(self, *, limit: int | None = None) -> list[ObservationRef]: ... - def rows(self, *, limit: int | None = None) -> list[ObservationRow]: ... - def one(self) -> ObservationRow: ... - def fetch_page(self, *, limit: int = 128, offset: int = 0) -> list[ObservationRow]: ... - def project_to(self, stream: "Stream[Any]") -> "ObservationSet[Any]": ... - def materialize(self, *, name: str | None = None, retention: str = "ephemeral") -> "ObservationSet[T]": ... - def derive(self, *, name: str, transform: "Transform[T, Any]", retention: str = "derived", payload_type: type | None = None) -> "Stream[Any]": ... - def lineage(self) -> Lineage: ... 
-``` - -## 6.3 Example: narrowing data and re-querying - -```python -recent_images = ( - images.query() - .filter_time(now - 600, now) - .fetch_set() -) - -recent_nearby_images = ( - recent_images.query() - .filter_near(current_pose, radius=3.0) - .fetch_set() -) -``` - -## 6.4 Example: embedding search without loading images - -```python -matches = ( - image_embeddings.query() - .search_embedding(query_vec, candidate_k=2000) - .filter_time(now - 3600, now) - .filter_near(current_pose, radius=8.0) - .rank(embedding=1.0, recency=0.2, distance=0.3) - .limit(1000) - .fetch_set() -) -``` - -Important: - -- `matches` should not contain 1000 image payloads in RAM -- it should usually contain refs + lightweight metadata/scores only - -## 6.5 Example: payload access stays explicit - -```python -rows = matches.fetch_page(limit=20, offset=0) -first_payload = image_embeddings.load(rows[0].ref) - -candidate_refs = matches.refs(limit=16) -embeddings_batch = image_embeddings.load_many(candidate_refs, batch_size=16) -``` - -## 6.6 Example: projecting embedding matches to images - -Assume each embedding row records its parent image ref. - -```python -matched_frames = matches.project_to(images) -preview_rows = matched_frames.fetch_page(limit=12) -preview_frames = images.load_many([r.ref for r in preview_rows], batch_size=12) -``` - -## 6.7 Example: deriving a caption stream from a narrowed image set - -```python -captions = matched_frames.derive( - name="vlm_captions_shoe_candidates", - transform=caption_model, - retention="derived", - payload_type=str, -) -``` - -This creates a new stream because it creates new observation identities. - ---- - -# 7. Query API - -## 7.1 Query design - -Query should be composable and capability-based. - -It should support: - -- hard filters -- candidate generation -- soft ranking -- terminal materialization - -## 7.2 Interface - -```python -class Query(Generic[T]): - def filter_time(self, t1: float, t2: float) -> "Query[T]": ... 
- def filter_before(self, t: float) -> "Query[T]": ... - def filter_after(self, t: float) -> "Query[T]": ... - def filter_near(self, pose: Pose, radius: float, *, include_unlocalized: bool = False) -> "Query[T]": ... - def filter_tags(self, **tags: Any) -> "Query[T]": ... - def filter_refs(self, refs: list[ObservationRef]) -> "Query[T]": ... - - def search_text(self, text: str, *, candidate_k: int | None = None) -> "Query[T]": ... - def search_embedding(self, vector: list[float], *, candidate_k: int) -> "Query[T]": ... - - def rank(self, **weights: float) -> "Query[T]": ... - def limit(self, k: int) -> "Query[T]": ... - - def fetch(self) -> list[ObservationRow]: ... - def fetch_set(self) -> ObservationSet[T]: ... - def count(self) -> int: ... - def one(self) -> ObservationRow: ... -``` - -## 7.3 Hard filters vs ranking - -This distinction must stay explicit. - -Example: - -```python -hits = ( - image_embeddings.query() - .search_embedding(query_vec, candidate_k=1000) - .filter_time(t1, t2) - .filter_near(current_pose, radius=5.0) - .rank(embedding=1.0, recency=0.15, distance=0.35) - .limit(50) - .fetch() -) -``` - -Execution meaning: - -1. embedding search creates candidates -2. time/space filters remove candidates -3. ranking combines scores on remaining rows -4. limit applies at the end - -Do not leave this ambiguous. - ---- - -# 8. Under-the-hood model for ObservationSet - -## 8.1 Default behavior - -`ObservationSet` should be lazy/unresolved until needed. - -It must not eagerly decode payloads. - -## 8.2 Internal backing kinds - -Publicly there is one `ObservationSet` class. Internally it may have multiple backing strategies. 
- -```python -from dataclasses import dataclass -from typing import Literal - -@dataclass -class PredicateBacking: - source_name: str - query_repr: str - -@dataclass -class RefTableBacking: - table_name: str - source_streams: list[str] - ordered: bool = False - -@dataclass -class CompositeBacking: - op: Literal["union", "intersection", "difference", "project", "join"] - input_ids: list[str] - query_repr: str -``` - -Recommended internal shape: - -```python -class ObservationSet(QueryableObservationSet[T], Generic[T]): - _backing: PredicateBacking | RefTableBacking | CompositeBacking - _capabilities: set[str] - _lineage: Lineage -``` - -## 8.3 Predicate-backed set - -Use when the set is still naturally expressible as a query over the underlying source. - -Examples: - -- time range over one stream -- tag filter over one stream -- spatial filter over one stream -- text search over one stream - -No payloads need to be materialized. - -## 8.4 Ref-table-backed set - -Use when a query creates an explicit candidate pool. - -Examples: - -- top-k embedding matches -- correlation results -- reranked subsets -- cluster representatives - -Important: - -- refs do not need to live in Python memory -- they can live in a SQLite temp table -- metadata rows may be fetched page-wise - -## 8.5 Composite-backed set - -Use for union/intersection/project/join style operations over other observation sets. - ---- - -# 9. Payload loading rules - -## 9.1 Allowed eager data - -Eagerly loaded into Python is acceptable for: - -- small metadata rows -- refs -- scores -- tags - -## 9.2 Disallowed by default - -Do not eagerly load by default: - -- all image payloads -- all point clouds -- all audio blobs -- all voxel blocks - -## 9.3 Required explicit methods - -```python -payload = images.load(ref) -payloads = images.load_many(refs, batch_size=32) - -for page in image_set.iter_meta(page_size=128): - ... -``` - -No API should silently decode a thousand images just because `.fetch_set()` was called. 
- ---- - -# 10. Stream generation from streams - -This is a central use case. - -## 10.1 Example: embeddings from images - -```python -frames = ( - images.query() - .filter_time(now - 60, now) - .fetch_set() -) - -embeddings = frames.derive( - name="image_embeddings_clip_recent", - transform=clip_embedder, - retention="derived", - payload_type=Embedding, -) -``` - -Implementation expectation: - -- `derive()` iterates source payloads in batches -- output rows record lineage to input refs -- output stream stores its own payloads/metadata/indexes - -## 10.2 Example transform protocol - -```python -U = TypeVar("U") - -class Transform(Protocol, Generic[T, U]): - name: str - version: str - - def map_batch(self, rows: list[ObservationRow], payloads: list[T]) -> list[tuple[U, dict[str, Any]]]: ... -``` - -This allows a coding agent to implement batch transforms cleanly. - ---- - -# 11. Correlation API - -Correlation is first-class. - -## 11.1 Example - -```python -bundle = s.correlate( - anchor=log_ref, - with_streams=[images, lidar, poses], - by={ - "rgb_front": {"mode": "nearest_time", "tolerance": 0.2}, - "lidar_mid360": {"mode": "nearest_time", "tolerance": 0.1}, - "robot_pose": {"mode": "nearest_time", "tolerance": 0.05}, - }, -) -``` - -## 11.2 Correlation result shape - -```python -@dataclass -class CorrelatedItem: - stream: str - row: ObservationRow | None - reason: dict[str, Any] - -@dataclass -class CorrelationBundle: - anchor: ObservationRef - items: list[CorrelatedItem] -``` - -At minimum support: - -- nearest in time -- overlapping interval - -Later support: - -- nearest in space -- same robot\_id -- same frame\_id - ---- - -# 12. Introspection - -These are needed for human tooling and debugging. 
- -```python -s.list_streams() -images.info() -images.stats() -matches.capabilities() -matches.lineage() -``` - -Recommended fields for `stream.info()`: - -```python -{ - "name": "rgb_front", - "payload_type": "Image", - "row_count": 12345, - "retention": "run", - "capabilities": ["temporal", "spatial", "load"], - "time_bounds": [1700000000.0, 1700003600.0], - "spatial_bounds": [xmin, ymin, zmin, xmax, ymax, zmax], - "payload_codec": "jpeg", -} -``` - ---- - -# 13. Backend implementation target - -SQLite-first, but backend-specific details should stay behind the API. - -## 13.1 Expected SQLite tools - -- normal tables for metadata -- temp tables for candidate refs -- FTS5 for text search -- R-tree for spatial indexing -- vector extension when available - -## 13.2 Suggested mapping per stream - -- metadata table -- payload table or blob column -- optional FTS table -- optional vector index table -- optional spatial index table - -## 13.3 Important backend rule - -Unlocalized rows should not be inserted into the spatial index. - ---- - -# 14. 
Concrete execution examples - -## 14.1 Time-filtered image subset stays lazy - -```python -recent = ( - images.query() - .filter_time(now - 300, now) - .fetch_set() -) -``` - -Expected implementation: - -- create predicate-backed `ObservationSet` -- do not decode image payloads -- only execute SQL when rows/count/payloads are requested - -## 14.2 Embedding search becomes ref-table-backed - -```python -matches = ( - image_embeddings.query() - .search_embedding(query_vec, candidate_k=5000) - .filter_time(now - 7200, now) - .limit(1000) - .fetch_set() -) -``` - -Expected implementation: - -- run vector search -- write candidate refs + scores to temp table -- return ref-table-backed `ObservationSet` -- allow further `.query()` by restricting to that candidate table - -## 14.3 Re-query candidate set without loading payloads - -```python -nearby_matches = ( - matches.query() - .filter_near(current_pose, radius=6.0) - .limit(100) - .fetch_set() -) -``` - -Expected implementation: - -- join source metadata with candidate ref table -- apply spatial filter in backend -- return new lazy observation set - -## 14.4 Paginated preview - -```python -page = nearby_matches.fetch_page(limit=24, offset=0) -preview_refs = [row.ref for row in page] -preview_embeddings = image_embeddings.load_many(preview_refs, batch_size=24) -``` - -Again: explicit payload loading only. - ---- - -# 15. What the coding agent should implement first - -Implementation priority order: - -1. `ObservationRef`, `ObservationMeta`, `ObservationRow`, `Lineage` -2. `DB`, `Session`, `Stream` -3. `Query` with time filters and `.fetch_set()` -4. lazy `ObservationSet` with predicate backing -5. explicit payload loading methods -6. text search -7. ref-table-backed observation sets -8. embedding search -9. `project_to()` -10. `derive()` -11. correlation -12. introspection/stats - ---- - -# 16. Minimal acceptance examples - -These examples should work. 
- -## 16.1 Re-query narrowed data - -```python -recent = images.query().filter_time(t1, t2).fetch_set() -recent2 = recent.query().filter_near(pose, radius=2.0).fetch_set() -assert recent2.count() <= recent.count() -``` - -## 16.2 Fetch set does not load payloads - -```python -matches = images.query().filter_time(t1, t2).limit(1000).fetch_set() -# should be cheap even for large image payloads -rows = matches.fetch_page(limit=10) -assert len(rows) == 10 -``` - -## 16.3 Derived stream from narrowed set - -```python -subset = images.query().filter_time(t1, t2).limit(100).fetch_set() -captions = subset.derive( - name="captions_test", - transform=caption_model, - retention="derived", - payload_type=str, -) -assert captions.count() == subset.count() -``` - -## 16.4 Projection from embeddings to images - -```python -emb_matches = image_embeddings.query().search_embedding(qvec, candidate_k=100).fetch_set() -frame_matches = emb_matches.project_to(images) -rows = frame_matches.fetch_page(limit=5) -frames = images.load_many([r.ref for r in rows], batch_size=5) -assert len(frames) == 5 -``` - ---- - -# 17. Summary - -Memory2 should expose: - -- appendable `Stream` -- lazy read-only `ObservationSet` -- composable `Query` -- explicit payload loading -- derived stream generation -- re-queryable narrowed results -- stable observation refs -- backend-backed candidate sets instead of eager payload lists - -The most important implementation rule is this: - -> `fetch_set()` returns a lazy queryable view over observations, not a Python list of decoded payloads. 
diff --git a/plans/old/memory3.md b/plans/old/memory3.md deleted file mode 100644 index d8f6e1daa0..0000000000 --- a/plans/old/memory3.md +++ /dev/null @@ -1,357 +0,0 @@ -# Memory2 Implementation Plan -## Context - -Check `questions.md` - -## File Structure - -``` -dimos/memory2/ - __init__.py # public exports (re-exports from API + default backend) - types.py # ObservationRow, Lineage, StreamInfo - store.py # Store ABC (Resource lifecycle) - session.py # Session ABC (stream factory) - stream.py # StreamBase, BlobStream, EmbeddingStream, TextStream ABCs - query.py # Query ABC (filter/search/rank/limit → fetch/fetch_set) - observation_set.py # ObservationSet ABC - - impl/ - sqlite/ - __init__.py # exports SqliteStore - store.py # SqliteStore (connection, WAL, extension loading) - session.py # SqliteSession (stream factory, _streams registry) - stream.py # SqliteBlobStream, SqliteEmbeddingStream, SqliteTextStream - query.py # SqliteQuery (SQL generation, execution) - observation_set.py # SqliteObservationSet (predicate/ref-table backing) - _sql.py # SQL helpers, identifier validation, schema DDL - - test_memory2.py # tests (against SqliteStore) -``` - -## API Layer (`dimos/memory2/`) - -### `types.py` — Data classes - -```python -@dataclass -class ObservationRow: - id: int - ts: float | None = None - pose: PoseStamped | None = None - tags: dict[str, Any] = field(default_factory=dict) - _data: Any = field(default=None, repr=False) - _load: Callable[[], Any] | None = field(default=None, repr=False) - - @property - def data(self) -> Any: - """Lazy payload access. 
Pre-populated for appended events, fetched on demand for query results.""" - if self._data is None and self._load is not None: - self._data = self._load() - return self._data - -@dataclass -class Lineage: - parent_stream: str | None = None # from _streams registry (stream-level) - parent_id: int | None = None # per-row: which row in parent stream - -@dataclass -class StreamInfo: - name: str - payload_type: type - parent_stream: str | None # lineage: all rows derive from this stream - count: int -``` - -Poses use DimOS's existing `PoseLike` type alias (`Pose | PoseStamped | Point | PointStamped`). No custom Pose type. - -### `store.py` — Store ABC - -```python -class Store(Resource, ABC): - def session(self) -> Session: ... - def close(self) -> None: ... - def start(self) -> None: pass - def stop(self) -> None: self.close() -``` - -### `session.py` — Session ABC - -```python -PoseProvider = Callable[[], PoseLike | None] - -class Session(ABC): - def stream(self, name: str, payload_type: type, *, - pose_provider: PoseProvider | None = None) -> BlobStream: ... - def embedding_stream(self, name: str, *, - model: EmbeddingModel) -> EmbeddingStream: ... - def text_stream(self, name: str, payload_type: type, *, - tokenizer: str = "unicode61", - pose_provider: PoseProvider | None = None) -> TextStream: ... - def list_streams(self) -> list[StreamInfo]: ... - def close(self) -> None: ... - def __enter__ / __exit__ -``` - -### `stream.py` — Stream hierarchy ABCs - -```python -class StreamBase(ABC, Generic[T]): - """Abstract base. No text/vector indexes.""" - pose_provider: PoseProvider | None = None # auto-fills pose on append if set - - # Write - def append(self, payload: T, *, - ts: float | None = None, # defaults to time.time() - pose: PoseLike | None = None, # explicit pose overrides provider - tags: dict[str, Any] | None = None, - parent_id: int | None = None, - ) -> ObservationRow: ... 
# returned row has .data pre-populated - - # Reactive - @property - def appended(self) -> Observable[ObservationRow]: ... # .data pre-populated - - # Transform (see transform.md for details) - def transform(self, - source: StreamBase | ObservationSet, - fn: Callable[[Any], T | list[T] | None], - *, - live: bool = False) -> Self: - """Process source data, store results with lineage. - StreamBase source: backfill + subscribe (live=True skips backfill). - ObservationSet source: batch only.""" - ... - - # Read - def query(self) -> Query[T]: ... - def load(self, row: ObservationRow) -> T: ... - def load_many(self, rows: list[ObservationRow], *, batch_size=32) -> list[T]: ... - def iter_meta(self, *, page_size=128) -> Iterator[list[ObservationRow]]: ... - def count(self) -> int: ... - -class BlobStream(StreamBase[T]): - """Stream for arbitrary serializable payloads. No special indexes.""" - -class EmbeddingStream(StreamBase[T]): - """Stream with vector index. No payload table — the vector IS the data.""" - model: EmbeddingModel - - def transform(self, source, fn=None, *, live=False) -> Self: - """If fn is None, uses model.embed implicitly.""" - ... - - def vector(self, row: ObservationRow) -> list[float] | None: ... - -class TextStream(StreamBase[T]): - """Stream with FTS index.""" -``` - -### `query.py` — Query ABC - -```python -class Query(ABC, Generic[T]): - # Hard filters - def time_range(self, t1: float, t2: float) -> Query[T]: ... - def before(self, t: float) -> Query[T]: ... - def after(self, t: float) -> Query[T]: ... - - def filter_tags(self, **tags: Any) -> Query[T]: ... - def at(self, t: float, *, tolerance: float = 1.0) -> Query[T]: ... - - # Candidate generation (raise TypeError if stream lacks the required index) - def search_text(self, text: str, *, candidate_k: int | None = None) -> Query[T]: ... - def search_embedding(self, vector: list[float], *, candidate_k: int) -> Query[T]: ... 
- - def order_by(self, field: str, *, desc: bool = False) -> Query[T]: ... - def limit(self, k: int) -> Query[T]: ... - - # Terminals - def fetch(self) -> list[ObservationRow]: ... - def fetch_set(self) -> ObservationSet[T]: ... - def count(self) -> int: ... - def one(self) -> ObservationRow: ... - def last(self) -> ObservationRow: ... -``` - -TODO: terminals that generate spatial or temporal summaries (maybe as numpy arrays). - -### `observation_set.py` — ObservationSet ABC - -```python -class ObservationSet(ABC, Generic[T]): - # Re-query - def query(self) -> Query[T]: ... - - # Read - def load(self, row: ObservationRow) -> T: ... - def load_many(self, rows, *, batch_size=32) -> list[T]: ... - def rows(self, *, limit=None) -> list[ObservationRow]: ... - def one(self) -> ObservationRow: ... - def fetch_page(self, *, limit=128, offset=0) -> list[ObservationRow]: ... - def count(self) -> int: ... - def lineage(self) -> Lineage: ... - - # Cross-stream - def project_to(self, stream: StreamBase) -> ObservationSet: ... - - # Cleanup - def close(self) -> None: ... - def __enter__(self) -> Self: ... - def __exit__(self, *exc) -> None: ... - def __del__(self) -> None: ... 
# best-effort fallback -``` - ---- - -## SQLite Implementation (`dimos/memory2/impl/sqlite/`) - -### `store.py` — SqliteStore - -- Stores file path, creates parent dirs on connect -- `_connect()`: `sqlite3.connect()`, WAL mode, loads sqlite-vec (optional), loads FTS5 -- Tracks sessions via `WeakSet` for cleanup -- `:memory:` uses `file::memory:?cache=shared` URI -- Thread safety: each session = one connection, no `check_same_thread=False` - -### `session.py` — SqliteSession - -- Holds one `sqlite3.Connection` -- `stream()` / `embedding_stream()` / `text_stream()`: creates tables if needed, caches stream instances -- Registers stream metadata in a `_streams` registry table: - -```sql -CREATE TABLE _streams ( - rowid INTEGER PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - type TEXT NOT NULL, -- 'blob', 'embedding', 'text' - payload_type TEXT NOT NULL, - parent_stream_id INTEGER -- FK to _streams.rowid (lineage) -); -``` - -### `stream.py` — SqliteBlobStream, SqliteEmbeddingStream, SqliteTextStream - -`append()` inserts a metadata row (SQLite auto-assigns `rowid`), serializes payload into `_payload`, and inserts an R*Tree entry if pose is provided. `EmbeddingStream.append()` inserts into `_vec` only (no `_payload`). `TextStream.append()` inserts into both `_payload` (as TEXT) and `_fts`. Returns `ObservationRow` with `id`, `ts`, `pose`, `tags` populated. 
- -### `query.py` — SqliteQuery - -- Accumulates filter predicates, search ops, rank spec, ordering, limit -- `at(t, tolerance)` → sugar for `filter_time(t - tol, t + tol)` + `ORDER BY ABS(ts - t) LIMIT 1` -- `order_by(field, desc)` → appends `ORDER BY` clause; valid fields: `ts` -- `fetch()`: generates SQL, executes, returns rows -- `fetch_set()`: creates ObservationSet (predicate-backed or ref-table-backed) -- `search_embedding` → sqlite-vec `MATCH`, writes top-k to temp table → ref-table-backed -- `search_text` → FTS5 `MATCH` -- `filter_near` → R*Tree range query -- `rank` → computes composite score from available score columns - -### `observation_set.py` — SqliteObservationSet - -Internal backing: - -```python -@dataclass -class PredicateBacking: - """Lazy: expressible as SQL WHERE over source stream.""" - source_name: str - query_repr: str # serialized query filters for replay - -@dataclass -class RefTableBacking: - """Materialized: temp table of refs + scores.""" - table_name: str # SQLite temp table - source_streams: list[str] - ordered: bool = False -``` - -- `.query()` on predicate-backed → adds more predicates -- `.query()` on ref-table-backed → filters within that temp table -- `project_to()` → joins backing refs via lineage parent_rowid to target stream -- `close()` drops the temp table for ref-table-backed sets; no-op for predicate-backed -- Supports context manager (`with`) for deterministic cleanup; `__del__` as fallback -- SQLite connection close is the final safety net for any leaked temp tables - -### `_sql.py` — SQL helpers - -```python -def validate_identifier(name: str) -> str: ... # regex check, length limit -``` - -Pose extraction: `_extract_pose(p: PoseLike) -> tuple[float, ...]` pulls `(x, y, z, qx, qy, qz, qw)`. `_reconstruct_pose(row) -> Pose` rebuilds from stored floats. - -Payload serialization: `lcm_encode(payload)` / `lcm_decode(blob, payload_type)`. Non-LCM types rejected at `append()` with `TypeError`. 
- ---- - -### Schema (per stream) - -**`{name}_meta`** — metadata for all stream types: -```sql -CREATE TABLE {name}_meta ( - rowid INTEGER PRIMARY KEY, - ts REAL, - pose_x REAL, pose_y REAL, pose_z REAL, - pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, - tags TEXT, -- JSON - parent_rowid INTEGER -- lineage: rowid in parent stream -); -CREATE INDEX idx_{name}_meta_ts ON {name}_meta(ts); -``` - -**`{name}_payload`** — BlobStream and TextStream only (not EmbeddingStream): -```sql -CREATE TABLE {name}_payload ( - rowid INTEGER PRIMARY KEY, -- matches _meta.rowid - data BLOB NOT NULL -- TextStream stores TEXT here instead of BLOB -); -``` - -**`{name}_rtree`** — all stream types (rows with pose only): -```sql -CREATE VIRTUAL TABLE {name}_rtree USING rtree( - rowid, - min_t, max_t, -- both set to ts - min_x, max_x, min_y, max_y, min_z, max_z -- both set to pose_xyz -); -``` - -**`{name}_fts`** — TextStream only: -```sql -CREATE VIRTUAL TABLE {name}_fts USING fts5(content); -``` - -**`{name}_vec`** — EmbeddingStream only: -```sql -CREATE VIRTUAL TABLE {name}_vec USING vec0(embedding float[{dim}]); -``` - -All virtual table rowids match `_meta.rowid` directly. - ---- - -## Phase 3: Later (not in first PR) - -- `CompositeBacking` (union/intersection/difference) -- `Correlator` / `s.correlate()` -- `retention` enforcement / cleanup -- Full introspection (stats, spatial_bounds) -- Query objects (`query_objects.md`) — composable criteria + soft scoring - -## Design Decisions - -### API-level - -- **Poses**: all pose params accept `PoseLike` (`Pose | PoseStamped | Point | PointStamped`). No custom pose type. -- **Row identity**: `id` is auto-assigned integer per stream. Unique within a stream. Impl layer maps to SQLite `rowid`. -- **Unlocalized observations**: rows without pose excluded from `filter_near()` by default. `include_unlocalized=True` to include them. -- **Stream hierarchy**: `StreamBase` (ABC) → `BlobStream`, `EmbeddingStream`, `TextStream`. 
Indexing is determined by stream type, not config. -- **Lineage**: parent stream defined at stream level (in `_streams` registry). Per-row `parent_id` links to specific row in parent. -- **Transform**: `.transform(source, fn)` on any stream — unified API for batch (ObservationSet) and live (StreamBase) derived streams. Uses `appended` observable for reactive pipeline. See `transform.md`. - -### SQLite-specific - -- **Separate payload table**: `_payload` separate from `_meta` so queries never page in multi-MB blobs. -- **EmbeddingStream has no payload table**: the vector in `_vec` IS the data. -- **R*Tree for spatio-temporal**: time-only queries use B-tree index on `_meta.ts` (faster for 1D). Spatial/spatio-temporal queries use R*Tree. -- **Payload serialization**: `lcm_encode()` / `lcm_decode()`. Non-LCM types rejected at `append()` with `TypeError`. -- **ObservationSet cleanup**: ref-table-backed sets use SQLite temp tables. Cleaned via context manager, `__del__` fallback, or connection close. diff --git a/plans/old/memory3_answers.md b/plans/old/memory3_answers.md deleted file mode 100644 index 24331dc002..0000000000 --- a/plans/old/memory3_answers.md +++ /dev/null @@ -1,67 +0,0 @@ -# Memory2 API Answers - -Worked examples against the API defined in `memory3.md`. - -## Q1: "Where was I when this log line was added?" - -> Pose lookup, correlating to log lines found. Assume log lines have poses associated. Assume there are multiple log lines matching a search. 
- -### Setup - -```python -store = SqliteStore("/data/robot.db") -session = store.session() - -# TextStream for robot logs — pose auto-filled from TF tree -logs = session.text_stream("logs", payload_type=str, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -# At runtime, just append text — pose is filled automatically -logs.append("Motor fault on joint 3") -logs.append("Obstacle detected ahead") -logs.append("Motor fault on joint 3") -``` - -### Single log line lookup - -```python -row = logs.query().search_text("motor fault on joint 3").one() -print(f"Robot was at {row.pose} when this log was added (t={row.ts})") -``` - -`search_text()` uses FTS5 keyword matching. `one()` returns the best match. The pose comes straight from `_meta` — no joins or extra queries needed. - -### Multiple matches - -```python -rows = logs.query().search_text("motor fault").order_by("ts").fetch() - -for row in rows: - text = logs.load(row.ref) # load actual log text from _payload - print(f"t={row.ts} pose={row.pose}: {text}") -``` - -### Spatial aggregation — "where do motor faults cluster?" 
- -```python -rows = logs.query().search_text("motor fault").fetch() - -# Group by proximity (application-level, not part of core API) -from collections import defaultdict -clusters = defaultdict(list) -for row in rows: - # bucket by 2m grid - key = (round(row.pose.x / 2) * 2, round(row.pose.y / 2) * 2) - clusters[key].append(row) - -for loc, group in clusters.items(): - print(f" {len(group)} motor faults near {loc}") -``` - -### What's exercised - -- `TextStream` with FTS index for keyword search -- `search_text()` → FTS5 `MATCH` -- Pose stored at append time, returned in `ObservationRow.pose` -- `load()` to retrieve actual text payload separately from metadata -- `order_by("ts")` for chronological ordering diff --git a/plans/old/memory4.md b/plans/old/memory4.md deleted file mode 100644 index 08e42265c4..0000000000 --- a/plans/old/memory4.md +++ /dev/null @@ -1,466 +0,0 @@ -# Memory2 API — Unified Stream - -## Core Idea - -One type: `Stream[T]`. Everything is a stream — stored, filtered, transformed. The user never thinks about Query vs ObservationSet vs Stream. They just chain operations. - -## Creating Streams - -```python -store = SqliteStore("/data/robot.db") -session = store.session() - -# Root stored stream — backed by DB -images = session.stream("images", Image, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -logs = session.text_stream("logs", str, - pose_provider=lambda: tf.get_pose("world", "base_link")) -``` - -## Writing - -```python -images.append(frame) # ts + pose auto-filled -logs.append("Motor fault on joint 3") # ts + pose auto-filled -images.append(frame, pose=explicit_pose, tags={"cam": "front"}) -``` - -Only meaningful on stored (DB-backed) streams. - -## Filtering - -Every filter returns a new `Stream[T]`. Lazy — nothing executes until a terminal. 
- -```python -recent = images.after(one_hour_ago) -kitchen = recent.near(kitchen_pose, 5.0) -tagged = kitchen.filter_tags(cam="front") - -# Or chained -images.after(one_hour_ago).near(kitchen_pose, 5.0).filter_tags(cam="front") -``` - -### Filter methods - -```python -class Stream(Generic[T]): - # Temporal - def after(self, t: float) -> Stream[T]: ... - def before(self, t: float) -> Stream[T]: ... - def time_range(self, t1: float, t2: float) -> Stream[T]: ... - def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... - - # Spatial - def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... - - # Tags - def filter_tags(self, **tags: Any) -> Stream[T]: ... - -class EmbeddingStream(Stream[T]): - def search_embedding(self, query: Embedding | list[float], *, k: int) -> EmbeddingStream[T]: ... - -class TextStream(Stream[T]): - def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... -``` - -## Terminals - -```python -rows = images.after(t).fetch() # list[Observation] -row = images.after(t).one() # single best match -row = images.last() # most recent -n = images.after(t).count() # count without fetching - -# Pagination -page = images.order_by("ts").limit(50).offset(100).fetch() -``` - -### Terminal methods - -```python -class Stream(Generic[T]): - def fetch(self) -> list[Observation]: ... - def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... - def one(self) -> Observation: ... - def last(self) -> Observation: ... - def count(self) -> int: ... - def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... - def limit(self, k: int) -> Stream[T]: ... - def offset(self, n: int) -> Stream[T]: ... 
-``` - -## Observation - -```python -from dimos.models.embedding.base import Embedding, EmbeddingModel - -@dataclass -class Observation: - id: int - ts: float | None = None - pose: PoseStamped | None = None - tags: dict[str, Any] = field(default_factory=dict) - - @property - def data(self) -> Any: - """Lazy payload. Pre-populated from append/transform, fetched on demand from query.""" - ... -``` - -## Transformer - -A `Transformer` receives the full source stream and decides what to do — which items to process, how to batch, whether to use embeddings as a cheap proxy, etc. - -```python -class Transformer(ABC, Generic[T, R]): - """Transforms a source stream into results on a target stream.""" - - def process(self, source: Stream[T], target: Stream[R]) -> None: - """Batch/historical processing. Has full access to source — can query, - filter, use embeddings, batch, skip frames, etc.""" - ... - - def on_append(self, obs: Observation, target: Stream[R]) -> None: - """Reactive processing. Called per new item. Default: process([obs]).""" - ... - - supports_backfill: bool = True - supports_live: bool = True -``` - -### Simple lambdas (sugar) - -`Callable[[T], R | list[R] | None]` is auto-wrapped into a naive per-item Transformer: - -```python -# These are equivalent: -images.transform(lambda img: vlm.detect(img, "cigarettes")) -images.transform(PerItemTransformer(lambda img: vlm.detect(img, "cigarettes"))) -``` - -- `R` → single result -- `list[R]` → multiple results (e.g., multiple detections per frame) -- `None` → skip (no result for this input) - -### EmbeddingTransformer - -`EmbeddingTransformer` wraps an `EmbeddingModel` as a `Transformer[T, Embedding]`. When the output type is `Embedding`, `.store()` creates an `EmbeddingStream` (vec0 index, `search_embedding`, `EmbeddingObservation`). 
- -```python -# EmbeddingTransformer wraps the model -img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") - -# Now img_emb is an EmbeddingStream -results = img_emb.search_embedding(query_emb, k=20).fetch() -# results[0].data → Image (auto-projected from source) -# results[0].embedding → Embedding (supports @ for cosine similarity) -``` - -### Smart Transformer example - -Chains after an embedding transform — receives `EmbeddingObservation` with `.data` (Image) and `.embedding` (vector), so it can use similarity to skip irrelevant frames: - -```python -class CigaretteDetector(Transformer[EmbeddingObservation, Detection]): - def __init__(self, vlm, clip): - self.vlm = vlm - self.clip = clip - - def process(self, source: Stream[EmbeddingObservation], target: Stream[Detection]): - query = self.clip.embed_text("person smoking cigarette") - for page in source.fetch_pages(batch_size=16): - # Use embedding similarity as cheap proxy — skip distant frames - promising = [obs for obs in page if obs.embedding @ query > 0.3] - if not promising: - continue - detections = self.vlm.detect_batch( - [obs.data for obs in promising], "cigarettes" - ) - for obs, dets in zip(promising, detections): - for det in dets: - target.append(det, ts=obs.ts, pose=obs.pose) - - def on_append(self, obs: EmbeddingObservation, target: Stream[Detection]): - dets = self.vlm.detect(obs.data, "cigarettes") - for det in dets: - target.append(det, ts=obs.ts, pose=obs.pose) -``` - -### Chaining transforms - -```python -# Filter → transform → store -images.after(one_hour_ago) \ - .near(kitchen_pose, 5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .store("kitchen_embeddings") - -# Filter → transform → fetch (in-memory, not persisted) -results = images.after(one_hour_ago) \ - .near(kitchen_pose, 5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .fetch() - -# Filter → embed → detect → store (chained: detector gets EmbeddingObservation) -images.near(kitchen_pose, 
5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .transform(CigaretteDetector(vlm, clip)) \ - .store("kitchen_cigarette_detections") -``` - -### Backfill / Live modes - -```python -# Both (default): backfill existing + subscribe to new -images.transform(detector).store("detections") - -# Live only: skip backfill, only process new items -images.transform(detector, live=True).store("detections") - -# Backfill only: process existing, don't subscribe -images.transform(detector, backfill_only=True).store("detections") - -# Incremental: re-running a stored transform resumes from last processed item -# (uses lineage parent_id to skip already-processed source rows) -``` - -## Storing - -`.store(name)` materializes a stream to DB. After storing, results are queryable and persistent. - -```python -# In-memory transform result — not persisted -detections = images.transform(detect_fn) - -# Persist it -detections.store("detections") - -# Now it's a DB-backed stream, queryable -stored = session.stream("detections") -rows = stored.after(t).fetch() -``` - -`.store()` also sets up lineage — every stored row gets `parent_id` pointing back to its source. - -Stream type is determined by what the Transformer produces: -- `Embedding` output → `EmbeddingStream` (vec0 index) -- Everything else → `Stream` (blob) -- `TextStream` is created explicitly via `session.text_stream()` (not auto-detected) - -## Reactive - -```python -# .appended emits Observation with .data pre-populated -images.appended.subscribe(lambda row: print(f"New image at {row.pose}")) - -# Stored transforms propagate reactively by default -detections = images.transform(detect_fn).store("detections") -# Now every images.append(frame) → detect_fn runs → result stored in "detections" - -# Filtered appended — only kitchen images -images.near(kitchen_pose, 5.0).appended.subscribe(...) 
-``` - -## Project (cross-stream lineage) - -```python -# Find source images for detections -source_images = detections.after(t).project_to(images) - -# project_to returns a Stream over the parent, filtered by lineage -for row in source_images.fetch(): - img = row.data # lazy-loads from images stream -``` - -## Full Example: Cigarette Detection Pipeline - -```python -session = SqliteStore("/data/robot.db").session() - -# Root stream -images = session.stream("images", Image, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -# Embedding index — EmbeddingModel is a Transformer -img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") - -# VLM detection pipeline (live-only, no backfill) -images.transform( - lambda img: vlm.detect(img, "people with cigarettes"), - live=True, -).store("cigarette_detections") - -# Smart detection — chain embed → detect (detector uses embedding similarity to skip frames) -images.near(kitchen_pose, 10.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .transform(CigaretteDetector(vlm, clip)) \ - .store("kitchen_cigarette_detections") - -# --- Later, querying --- - -# "Where did we see people with cigarettes in the kitchen?" 
-rows = session.stream("cigarette_detections") \ - .after(one_hour_ago) \ - .near(kitchen_pose, 10.0) \ - .fetch() - -for row in rows: - print(f"t={row.ts} pose={row.pose}: {row.data}") - -# "Show me the source images" -source_imgs = session.stream("cigarette_detections") \ - .after(one_hour_ago) \ - .project_to(images) \ - .fetch() - -# "Find images similar to 'red shoes'" -query_emb = clip.embed_text("red shoes") -similar = img_emb.search_embedding(query_emb, k=20).fetch() -# similar[0].data → Image (auto-projected from source) -# similar[0].embedding → Embedding (supports @ for cosine similarity) -``` - -## Full API - -```python -from dimos.models.embedding.base import Embedding, EmbeddingModel - -# --- Data types --- - -@dataclass -class Observation: - id: int - ts: float | None = None - pose: PoseStamped | None = None - tags: dict[str, Any] = field(default_factory=dict) - - @property - def data(self) -> Any: - """Lazy payload. Pre-populated from append, fetched on demand from query.""" - ... - -@dataclass -class EmbeddingObservation(Observation): - """Returned by EmbeddingStream terminals. Auto-projects .data to source stream.""" - - @property - def data(self) -> Any: - """Lazily loads from the source stream (e.g., Image), not the embedding.""" - ... - - @property - def embedding(self) -> Embedding: - """The Embedding object (has .vector, supports @ for cosine similarity).""" - ... - -# --- Transformer --- - -class Transformer(ABC, Generic[T, R]): - """Transforms a source stream into results on a target stream.""" - - def process(self, source: Stream[T], target: Stream[R]) -> None: - """Batch/historical processing. Full access to source stream.""" - ... - - def on_append(self, obs: Observation, target: Stream[R]) -> None: - """Reactive processing. Called per new item.""" - ... 
- - supports_backfill: bool = True - supports_live: bool = True - -# --- Streams --- - -class Stream(Generic[T]): - # Write (DB-backed only) - def append(self, payload: T, *, - ts: float | None = None, - pose: PoseLike | None = None, - tags: dict[str, Any] | None = None, - ) -> Observation: ... - - # Filter (returns new Stream, lazy) - def after(self, t: float) -> Stream[T]: ... - def before(self, t: float) -> Stream[T]: ... - def time_range(self, t1: float, t2: float) -> Stream[T]: ... - def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... - def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... - def filter_tags(self, **tags: Any) -> Stream[T]: ... - - # Order / paginate - def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... - def limit(self, k: int) -> Stream[T]: ... - def offset(self, n: int) -> Stream[T]: ... - - # Transform - def transform(self, - xf: Transformer[T, R] | Callable[[T], R | list[R] | None], - *, live: bool = False, - backfill_only: bool = False, - ) -> Stream[R]: ... - - # Materialize - def store(self, name: str | None = None) -> Stream[T]: ... - - # Cross-stream - def project_to(self, target: Stream) -> Stream: ... - - # Terminals - def fetch(self) -> list[Observation]: ... - def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... - def one(self) -> Observation: ... - def last(self) -> Observation: ... - def count(self) -> int: ... - - # Reactive - @property - def appended(self) -> Observable[Observation]: ... - -class EmbeddingStream(Stream[T]): - """Created automatically when a Transformer produces Embedding output. - Terminals return EmbeddingObservation (auto-projects .data to source stream).""" - def search_embedding(self, query: Embedding | list[float], *, k: int) -> EmbeddingStream[T]: ... - def fetch(self) -> list[EmbeddingObservation]: ... - def one(self) -> EmbeddingObservation: ... - def last(self) -> EmbeddingObservation: ... 
- -class TextStream(Stream[T]): - """Stream with FTS index.""" - def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... - -# --- Session / Store --- - -PoseProvider = Callable[[], PoseLike | None] - -class Session: - def stream(self, name: str, payload_type: type | None = None, *, - pose_provider: PoseProvider | None = None) -> Stream: ... - def text_stream(self, name: str, payload_type: type | None = None, *, - tokenizer: str = "unicode61", - pose_provider: PoseProvider | None = None) -> TextStream: ... - def list_streams(self) -> list[StreamInfo]: ... - def close(self) -> None: ... - -class Store: - def session(self) -> Session: ... - def close(self) -> None: ... -``` - -## Internal Backing (impl detail) - -A `Stream` can be backed by different things — the user never sees this: - -- **DB table** — from `session.stream()`. Has `_meta`, `_payload`, indexes. -- **Predicate** — from `.after()`, `.near()`, etc. Lazy SQL WHERE. -- **Transform** — from `.transform(t)`. Source stream + Transformer. - -The impl decides how to execute based on the backing chain. - -## Open Questions - -1. **`.append()` on non-stored streams?** Runtime error, or silently ignore? Probably `TypeError`. - -2. **Multiple `.store()` calls?** Should be idempotent — second call is a no-op if already stored under the same name. - -3. **Memory pressure from in-memory transforms?** Large `.transform().fetch()` without `.store()` loads everything into memory. Should we support streaming iteration? 
diff --git a/plans/old/transforms.md b/plans/old/transforms.md deleted file mode 100644 index edc5940512..0000000000 --- a/plans/old/transforms.md +++ /dev/null @@ -1,21 +0,0 @@ -```python -# Filter → transform → store - - images.after(one_hour_ago) \ - .near(kitchen_pose, 5.0) \ - .transform(CLIPModel()) \ - .store("kitchen_embeddings") - - # Filter → transform → fetch (in-memory, not persisted) - results = images.after(one_hour_ago) \ - .near(kitchen_pose, 5.0) \ - .transform(CLIPModel()) \ - .fetch() - - # Filter → transform → transform → store - images.near(kitchen_pose, 5.0) \ - .transform(CLIPModel()) \ - .transform(CigaretteDetector(vlm)) \ - .store("kitchen_cigarette_detections") - -``` From 92ee38068fc8508fa537b8acfd25b63a16931fbb Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 19:02:41 +0800 Subject: [PATCH 022/118] Address Greptile review: SQL injection guards, distance ordering, stubs - Validate stream names and tag keys as SQL identifiers - Allowlist order_by fields to {id, ts} - Re-sort vector search results by distance rank after IN-clause fetch - Make TagsFilter hashable (tuple of pairs instead of dict) - Remove dead code in memory_old/embedding.py - Add scipy-stubs, fix distance_transform_edt type annotations --- dimos/mapping/occupancy/gradient.py | 2 +- dimos/memory/impl/sqlite.py | 23 +++++- dimos/memory/impl/test_sqlite.py | 2 +- dimos/memory/stream.py | 2 +- dimos/memory/types.py | 4 +- dimos/memory/viz.py | 6 +- dimos/memory_old/embedding.py | 3 - pyproject.toml | 1 + uv.lock | 119 ++++++++++++++++++++++++++++ 9 files changed, 149 insertions(+), 13 deletions(-) diff --git a/dimos/mapping/occupancy/gradient.py b/dimos/mapping/occupancy/gradient.py index 880f2692da..51a0c013ad 100644 --- a/dimos/mapping/occupancy/gradient.py +++ b/dimos/mapping/occupancy/gradient.py @@ -50,7 +50,7 @@ def gradient( # Compute distance transform (distance to nearest obstacle in cells) # Unknown cells are treated as if they don't exist for 
distance calculation - distance_cells = ndimage.distance_transform_edt(1 - obstacle_map) + distance_cells: np.ndarray = ndimage.distance_transform_edt(1 - obstacle_map) # type: ignore[assignment] # Convert to meters and clip to max distance distance_meters = np.clip(distance_cells * occupancy_grid.resolution, 0, max_distance) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index c2e6e75c71..c531591bfe 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -29,6 +29,7 @@ from __future__ import annotations import json +import re import sqlite3 import time from typing import TYPE_CHECKING, Any @@ -67,6 +68,16 @@ from dimos.memory.types import PoseProvider from dimos.models.embedding.base import EmbeddingModel +_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +_ALLOWED_ORDER_FIELDS = frozenset({"id", "ts"}) + + +def _validate_identifier(name: str) -> str: + """Validate that *name* is a safe SQL identifier (alphanumeric + underscore).""" + if not _IDENTIFIER_RE.match(name): + raise ValueError(f"Invalid identifier: {name!r}") + return name + # ── Pose helpers (column-based) ────────────────────────────────────── @@ -131,7 +142,8 @@ def _compile_filter(f: Filter, table: str) -> tuple[str, list[Any]]: if isinstance(f, TagsFilter): clauses: list[str] = [] params: list[Any] = [] - for key, val in f.tags.items(): + for key, val in f.tags: + _validate_identifier(key) clauses.append(f"json_extract({table}.tags, '$.{key}') = ?") params.append(val) return " AND ".join(clauses), params @@ -259,8 +271,10 @@ def _compile_query(query: StreamQuery, table: str) -> tuple[str, list[Any]]: where = " AND ".join(where_parts) if where_parts else "1=1" join_clause = " ".join(joins) - order = f"ORDER BY {query.order_field}" if query.order_field: + if query.order_field not in _ALLOWED_ORDER_FIELDS: + raise ValueError(f"Invalid order field: {query.order_field!r}") + order = f"ORDER BY {table}.{query.order_field}" if query.order_desc: order += 
" DESC" else: @@ -347,6 +361,7 @@ def __init__( pose_provider: PoseProvider | None = None, codec: LcmCodec | JpegCodec | PickleCodec | None = None, ) -> None: + _validate_identifier(table) self._conn = conn self._table = table self._pose_provider = pose_provider @@ -568,6 +583,10 @@ def _fetch_by_vector( if isinstance(obs, EmbeddingObservation): obs.similarity = max(0.0, min(1.0, 1.0 - dist_map.get(obs.id, 0.0))) + # Re-sort by distance rank (IN clause doesn't preserve vec0 ordering) + rank = {rid: i for i, rid in enumerate(rowids)} + observations.sort(key=lambda o: rank.get(o.id, len(rank))) + near = _has_near_filter(query) if near is not None: observations = _apply_near_post_filter(observations, near) diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 42b94803e7..2c0bdb58e9 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -906,7 +906,7 @@ def test_at_filter(self) -> None: def test_tags_filter(self) -> None: from dimos.memory.types import TagsFilter - f = TagsFilter({"cam": "front"}) + f = TagsFilter((("cam", "front"),)) assert f.matches(Observation(id=1, tags={"cam": "front", "quality": "high"})) is True assert f.matches(Observation(id=2, tags={"cam": "rear"})) is False assert f.matches(Observation(id=3, tags={})) is False diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 0ada16f81d..7b353cc0f1 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -157,7 +157,7 @@ def near(self, pose: PoseLike, radius: float) -> Stream[T]: # ── Tag filter ──────────────────────────────────────────────────── def filter_tags(self, **tags: Any) -> Stream[T]: - return self._with_filter(TagsFilter(tags)) + return self._with_filter(TagsFilter(tuple(tags.items()))) # ── Ordering / pagination ───────────────────────────────────────── diff --git a/dimos/memory/types.py b/dimos/memory/types.py index 1bab7b4af8..ea7b1194ef 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py 
@@ -143,10 +143,10 @@ def matches(self, obs: Observation) -> bool: @dataclass(frozen=True) class TagsFilter: - tags: dict[str, Any] + tags: tuple[tuple[str, Any], ...] def matches(self, obs: Observation) -> bool: - return all(obs.tags.get(k) == v for k, v in self.tags.items()) + return all(obs.tags.get(k) == v for k, v in self.tags) @dataclass(frozen=True) diff --git a/dimos/memory/viz.py b/dimos/memory/viz.py index 70efe906df..3870f18600 100644 --- a/dimos/memory/viz.py +++ b/dimos/memory/viz.py @@ -35,7 +35,7 @@ def similarity_heatmap( *, resolution: float = 0.1, padding: float = 1.0, - spread: float = 2.0, + spread: float = 0.2, frame_id: str = "world", ) -> OccupancyGrid: """Build an OccupancyGrid heatmap from observations with similarity scores. @@ -107,14 +107,14 @@ def similarity_heatmap( has_obs[gy, gx] = True # Distance transform: distance (in cells) from each empty cell to nearest dot - dist_cells = distance_transform_edt(~has_obs) + dist_cells: np.ndarray[Any, Any] = distance_transform_edt(~has_obs) # type: ignore[assignment] dist_metres = dist_cells * resolution # Fade factor: 1.0 at the dot, 0.0 at `spread` metres away fade = np.clip(1.0 - dist_metres / spread, 0.0, 1.0) # For each cell, find the value of its nearest dot (via index output) - _, nearest_idx = distance_transform_edt(~has_obs, return_indices=True) + _, nearest_idx = distance_transform_edt(~has_obs, return_indices=True) # type: ignore[misc] nearest_value = value_grid[nearest_idx[0], nearest_idx[1]] # Final heatmap = nearest dot's value * distance fade diff --git a/dimos/memory_old/embedding.py b/dimos/memory_old/embedding.py index 758634eecc..6fa3445208 100644 --- a/dimos/memory_old/embedding.py +++ b/dimos/memory_old/embedding.py @@ -109,6 +109,3 @@ def query_text(self, query: str) -> list[SpatialEmbedding]: self.config.embedding_model.embed_text(query) results: list[SpatialEmbedding] = [] return results - return results - return results - return results diff --git a/pyproject.toml 
b/pyproject.toml index 52aeecfbae..9e993584c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -243,6 +243,7 @@ dev = [ "types-tensorflow>=2.18.0.20251008,<3", "types-tqdm>=4.67.0.20250809,<5", "types-psycopg2>=2.9.21.20251012", + "scipy-stubs>=1.15.0", # Tools "py-spy", diff --git a/uv.lock b/uv.lock index 46f9b1931d..dc8012cc2a 100644 --- a/uv.lock +++ b/uv.lock @@ -1790,6 +1790,8 @@ dds = [ { name = "python-lsp-server", extra = ["all"] }, { name = "requests-mock" }, { name = "ruff" }, + { name = "scipy-stubs", version = "1.15.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy-stubs", version = "1.17.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "terminaltexteffects" }, { name = "types-colorama" }, { name = "types-defusedxml" }, @@ -1827,6 +1829,8 @@ dev = [ { name = "python-lsp-server", extra = ["all"] }, { name = "requests-mock" }, { name = "ruff" }, + { name = "scipy-stubs", version = "1.15.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy-stubs", version = "1.17.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "terminaltexteffects" }, { name = "types-colorama" }, { name = "types-defusedxml" }, @@ -2081,6 +2085,7 @@ requires-dist = [ { name = "scikit-learn", marker = "extra == 'misc'" }, { name = "scipy", specifier = ">=1.15.1" }, { name = "scipy", marker = "extra == 'docker'", specifier = ">=1.15.1" }, + { name = "scipy-stubs", marker = "extra == 'dev'", specifier = ">=1.15.0" }, { name = "sentence-transformers", marker = "extra == 'misc'" }, { name = "sortedcontainers", specifier = "==2.4.0" }, { name = "sortedcontainers", marker = "extra == 'docker'" }, @@ -5523,6 +5528,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/2d/ee/346fa473e666fe14c52fcdd19ec2424157290a032d4c41f98127bfb31ac7/numpy-2.3.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f16417ec91f12f814b10bafe79ef77e70113a2f5f7018640e7425ff979253425", size = 12967213, upload-time = "2025-11-16T22:52:39.38Z" }, ] +[[package]] +name = "numpy-typing-compat" +version = "20251206.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/83/dd90774d6685664cbe5525645a50c4e6c7454207aee552918790e879137f/numpy_typing_compat-20251206.2.3.tar.gz", hash = "sha256:18e00e0f4f2040fe98574890248848c7c6831a975562794da186cf4f3c90b935", size = 5009, upload-time = "2025-12-06T20:02:04.177Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/6f/dde8e2a79a3b6cbc31bc1037c1a1dbc07c90d52d946851bd7cba67e730a8/numpy_typing_compat-20251206.2.3-py3-none-any.whl", hash = "sha256:bfa2e4c4945413e84552cbd34a6d368c88a06a54a896e77ced760521b08f0f61", size = 6300, upload-time = "2025-12-06T20:01:56.664Z" }, +] + [[package]] name = "nvidia-cublas-cu12" version = "12.8.4.1" @@ -6186,6 +6203,60 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/ec/19c6cc6064c7fc8f0cd6d5b37c4747849e66040c6ca98f86565efc2c227c/optax-0.2.6-py3-none-any.whl", hash = "sha256:f875251a5ab20f179d4be57478354e8e21963373b10f9c3b762b94dcb8c36d91", size = 367782, upload-time = "2025-09-15T22:41:22.825Z" }, ] +[[package]] +name = "optype" +version = "0.9.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'win32'", + "(python_full_version < '3.11' and platform_machine != 
'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/3c/9d59b0167458b839273ad0c4fc5f62f787058d8f5aed7f71294963a99471/optype-0.9.3.tar.gz", hash = "sha256:5f09d74127d316053b26971ce441a4df01f3a01943601d3712dd6f34cdfbaf48", size = 96143, upload-time = "2025-03-31T17:00:08.392Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/d8/ac50e2982bdc2d3595dc2bfe3c7e5a0574b5e407ad82d70b5f3707009671/optype-0.9.3-py3-none-any.whl", hash = "sha256:2935c033265938d66cc4198b0aca865572e635094e60e6e79522852f029d9e8d", size = 84357, upload-time = "2025-03-31T17:00:06.464Z" }, +] + +[[package]] +name = "optype" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "(python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 
'linux' and sys_platform != 'win32')", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "typing-extensions", marker = "python_full_version >= '3.11' and python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/d3/c88bb4bd90867356275ca839499313851af4b36fce6919ebc5e1de26e7ca/optype-0.16.0.tar.gz", hash = "sha256:fa682fd629ef6b70ba656ebc9fdd6614ba06ce13f52e0416dd8014c7e691a2d1", size = 53498, upload-time = "2026-02-19T23:37:09.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/a8/fe26515203cff140f1afc31236fb7f703d4bb4bd5679d28afcb3661c8d9f/optype-0.16.0-py3-none-any.whl", hash = "sha256:c28905713f55630b4bb8948f38e027ad13a541499ebcf957501f486da54b74d2", size = 65893, upload-time = "2026-02-19T23:37:08.217Z" }, +] + +[package.optional-dependencies] +numpy = [ + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy-typing-compat", marker = "python_full_version >= '3.11'" }, +] + [[package]] name = "orbax-checkpoint" version = "0.11.32" @@ -8871,6 +8942,54 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/56/a5/df8f46ef7da168f1bc52cd86e09a9de5c6f19cc1da04454d51b7d4f43408/scipy-1.17.0-cp314-cp314t-win_arm64.whl", hash = "sha256:031121914e295d9791319a1875444d55079885bbae5bdc9c5e0f2ee5f09d34ff", size = 25246266, upload-time = "2026-01-10T21:30:45.923Z" }, ] +[[package]] +name = "scipy-stubs" +version = "1.15.3.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'win32'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "optype", version = "0.9.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/5f/35c43bd7d412add4adcd68475702571b2489b50c40b6564f808b2355e452/scipy_stubs-1.15.3.0.tar.gz", hash = "sha256:e8f76c9887461cf9424c1e2ad78ea5dac71dd4cbb383dc85f91adfe8f74d1e17", size = 275699, upload-time = "2025-05-08T16:58:35.139Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/42/cd8dc81f8060de1f14960885ad5b2d2651f41de8b93d09f3f919d6567a5a/scipy_stubs-1.15.3.0-py3-none-any.whl", hash = "sha256:a251254cf4fd6e7fb87c55c1feee92d32ddbc1f542ecdf6a0159cdb81c2fb62d", size = 459062, upload-time = "2025-05-08T16:58:33.356Z" }, +] + +[[package]] +name = "scipy-stubs" +version = "1.17.1.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and 
platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "(python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "optype", version = "0.16.0", source = { registry = "https://pypi.org/simple" }, extra = ["numpy"], marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/ad/413b0d18efca7bb48574d28e91253409d91ee6121e7937022d0d380dfc6a/scipy_stubs-1.17.1.0.tar.gz", hash = "sha256:5dc51c21765b145c2d132b96b63ff4f835dd5fb768006876d1554e7a59c61571", size = 381420, 
upload-time = "2026-02-23T10:33:04.742Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/ee/c6811e04ff9d5dd1d92236e8df7ebc4db6aa65c70b9938cec293348b8ec4/scipy_stubs-1.17.1.0-py3-none-any.whl", hash = "sha256:5c9c84993d36b104acb2d187b05985eb79f73491c60d83292dd738093d53d96a", size = 587059, upload-time = "2026-02-23T10:33:02.845Z" }, +] + [[package]] name = "sentence-transformers" version = "5.2.2" From de0efb987e0ac4ccb957948474186dc95f2c7ced Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 19:49:17 +0800 Subject: [PATCH 023/118] Add memory Rerun visualization, fix stream iteration, update docs - Add dimos/memory/rerun.py: to_rerun() sends stream data to Rerun with auto-derived entity paths and no wall-clock timeline contamination - Fix Stream.fetch_pages() to respect limit_val (was always overridden by batch_size, making .limit() ineffective during iteration) - Update viz.py: normalize similarities with 20% floor cutoff, sort timeline by timestamp, add log_top_images() - Convert run_e2e_export.py to pytest with cached DB fixture - Update plans/memory docs to match current implementation --- dimos/memory/impl/run_e2e_export.py | 225 +++++---- dimos/memory/rerun.py | 78 ++++ dimos/memory/stream.py | 15 +- dimos/memory/viz.py | 105 ++++- plans/memory/api.md | 684 ++++++++++++++++++++++++++++ plans/memory/sqlite.md | 621 +++++++++++++++++++++++++ 6 files changed, 1611 insertions(+), 117 deletions(-) create mode 100644 dimos/memory/rerun.py create mode 100644 plans/memory/api.md create mode 100644 plans/memory/sqlite.md diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/run_e2e_export.py index 4c70531a16..9a56e3b283 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/run_e2e_export.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Ingest 5min robot video → sharpness filter → CLIP embed → search & visualize. 
+"""E2E tests: ingest 5min robot video → sharpness filter → CLIP embed → search. -Caches the DB — re-run to just search without re-ingesting/embedding. -Outputs heatmaps and timelines to Rerun, images to disk. +The DB is built once and cached on disk so subsequent runs skip ingestion. +Run with: pytest dimos/memory/impl/run_e2e_export.py -s """ from __future__ import annotations @@ -23,7 +23,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -import rerun as rr +import pytest from dimos.memory.impl.sqlite import SqliteStore from dimos.memory.ingest import ingest @@ -32,106 +32,131 @@ EmbeddingTransformer, QualityWindowTransformer, ) -from dimos.memory.viz import log_similarity_timeline, similarity_heatmap from dimos.models.embedding.clip import CLIPModel -from dimos.models.vl.florence import Florence2Model from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.testing import TimedSensorReplay if TYPE_CHECKING: from dimos.memory.stream import EmbeddingStream -OUT_DIR = Path(__file__).parent / "e2e_matches" -OUT_DIR.mkdir(exist_ok=True) - -rr.init("memory_e2e", spawn=True) - -db_path = OUT_DIR / "e2e.db" -store = SqliteStore(str(db_path)) -session = store.session() - -# Check if we already have data -existing = {s.name for s in session.list_streams()} -need_build = "clip_embeddings" not in existing - -if need_build: - replay = TimedSensorReplay("unitree_go2_bigoffice/video") - odom = TimedSensorReplay("unitree_go2_bigoffice/odom") - - print("Loading CLIP...") - clip = CLIPModel() - clip.start() - - # 1. Ingest 5 minutes with odom poses - print("Ingesting 5 min of video with odom poses...") - raw = session.stream("raw_video", Image) - n = ingest(raw, replay.iterate_ts(seek=5.0, duration=300.0), pose_source=odom) - print(f" {n} frames ingested") - - # 2. 
Sharpness filter - print("Filtering by sharpness (0.5s windows)...") - sharp = raw.transform(QualityWindowTransformer(lambda img: img.sharpness, window=0.5)).store( - "sharp_frames", Image - ) - n_sharp = sharp.count() - print(f" {n_sharp} sharp frames (from {n}, {n_sharp / n:.0%} kept)") - - # 3. Embed - print("Embedding with CLIP...") - embeddings: EmbeddingStream[Any] = sharp.transform(EmbeddingTransformer(clip)).store( - "clip_embeddings" - ) # type: ignore[assignment] - print(f" {embeddings.count()} embeddings stored") -else: - print(f"Using cached DB ({db_path})") - print("loading Clip") - clip = CLIPModel() - clip.start() - print("done") - sharp = session.stream("sharp_frames") - embeddings = session.embedding_stream("clip_embeddings", embedding_model=clip) - print(f" {sharp.count()} sharp frames, {embeddings.count()} embeddings") - -# 4. Search, visualize, export -queries = [ - "a hallway in an office", - "a person standing", - "a door", - "a desk", - "supermarket", - "large room", -] - -print("\nLoading Florence2 for captioning...") -captioner = Florence2Model() -captioner.start() - -caption_xf = CaptionTransformer(captioner) - -for query_text in queries: - print(f"\nQuery: '{query_text}'") - slug = query_text.replace(" ", "_")[:30] - - # raw=True: get EmbeddingObservation with .similarity and .pose - raw_results = embeddings.search_embedding(query_text, k=200, raw=True).fetch() - - # Spatial heatmap → Rerun - grid = similarity_heatmap(raw_results, resolution=0.5) - print(f" Heatmap: {grid}") - rr.log(f"world/{slug}/heatmap", grid.to_rerun(colormap="inferno")) - - # Temporal timeline → Rerun - log_similarity_timeline(raw_results, entity_path=f"plots/{slug}") - - # Caption top 5 (auto-projected results for image access) - results = embeddings.search_embedding(query_text, k=5).fetch() - captions = results.transform(caption_xf).fetch() - - for rank, (cap, img) in enumerate(zip(captions, results, strict=False)): - fname = OUT_DIR / f"{slug}_{rank + 
1}_id{img.id}_ts{img.ts:.0f}.jpg" - img.data.save(str(fname)) - print(f" [{rank + 1}] id={img.id} ts={img.ts:.2f} — {cap.data}") - -session.close() -store.close() -print(f"\nDone. Results in {OUT_DIR}/") +DB_DIR = Path(__file__).parent / "e2e_matches" +DB_DIR.mkdir(exist_ok=True) +DB_PATH = DB_DIR / "e2e.db" + + +@pytest.fixture(scope="module") +def clip() -> CLIPModel: + model = CLIPModel() + model.start() + return model + + +@pytest.fixture(scope="module") +def e2e_db(clip: CLIPModel) -> tuple[SqliteStore, Any]: + """Build (or reuse cached) e2e DB with video → sharpness → CLIP embeddings.""" + store = SqliteStore(str(DB_PATH)) + session = store.session() + + existing = {s.name for s in session.list_streams()} + if "clip_embeddings" not in existing: + replay = TimedSensorReplay("unitree_go2_bigoffice/video") + odom = TimedSensorReplay("unitree_go2_bigoffice/odom") + + raw = session.stream("raw_video", Image) + n = ingest(raw, replay.iterate_ts(seek=5.0, duration=300.0), pose_source=odom) + print(f" {n} frames ingested") + + sharp = raw.transform( + QualityWindowTransformer(lambda img: img.sharpness, window=0.5) + ).store("sharp_frames", Image) + print(f" {sharp.count()} sharp frames (from {n}, {sharp.count() / n:.0%} kept)") + + embeddings: EmbeddingStream[Any] = sharp.transform(EmbeddingTransformer(clip)).store( + "clip_embeddings" + ) # type: ignore[assignment] + print(f" {embeddings.count()} embeddings stored") + else: + print(f"Using cached DB ({DB_PATH})") + + yield store, session # type: ignore[misc] + session.close() + store.close() + + +@pytest.fixture(scope="module") +def embeddings(e2e_db: tuple[SqliteStore, Any], clip: CLIPModel) -> EmbeddingStream[Any]: + _, session = e2e_db + return session.embedding_stream("clip_embeddings", embedding_model=clip) + + +class TestEmbeddingSearch: + """Search the cached CLIP embedding DB and export top matches.""" + + QUERIES = [ + "a hallway in an office", + "a person standing", + "a door", + "a desk", + "supermarket", 
+ "large room", + ] + + @pytest.mark.parametrize("query", QUERIES) + def test_search_returns_results(self, embeddings: EmbeddingStream[Any], query: str) -> None: + results = embeddings.search_embedding(query, k=5).fetch() + assert len(results) > 0 + for obs in results: + assert obs.ts is not None + assert isinstance(obs.data, Image) + + @pytest.mark.parametrize("query", QUERIES) + def test_search_exports_images(self, embeddings: EmbeddingStream[Any], query: str) -> None: + slug = query.replace(" ", "_")[:30] + results = embeddings.search_embedding(query, k=5).fetch() + + for rank, img in enumerate(results): + fname = DB_DIR / f"{slug}_{rank + 1}_id{img.id}_ts{img.ts:.0f}.jpg" + img.data.save(str(fname)) + print(f" [{rank + 1}] id={img.id} ts={img.ts:.2f}") + + def test_raw_search_has_similarity(self, embeddings: EmbeddingStream[Any]) -> None: + from dimos.memory.types import EmbeddingObservation + + raw = embeddings.search_embedding("a hallway", k=10, raw=True).fetch() + assert len(raw) > 0 + for obs in raw: + assert isinstance(obs, EmbeddingObservation) + assert obs.similarity is not None + assert 0.0 <= obs.similarity <= 1.0 + + def test_caption_search_results(self, embeddings: EmbeddingStream[Any]) -> None: + from dimos.models.vl.florence import Florence2Model + + captioner = Florence2Model() + captioner.start() + caption_xf = CaptionTransformer(captioner) + + results = embeddings.search_embedding("a door", k=3).fetch() + captions = results.transform(caption_xf).fetch() + + assert len(captions) == len(results) + for cap in captions: + assert isinstance(cap.data, str) + assert len(cap.data) > 0 + print(f" Caption: {cap.data}") + + +class TestRerunStream: + """Send a full image stream to Rerun.""" + + def test_stream_to_rerun(self, e2e_db: tuple[SqliteStore, Any]) -> None: + import rerun as rr + + from dimos.memory.rerun import to_rerun + + rr.init("memory_e2e_test", spawn=True) + + _, session = e2e_db + n = to_rerun(session.stream("sharp_frames")) + assert n > 0 
+ print(f" Logged {n} images to Rerun") diff --git a/dimos/memory/rerun.py b/dimos/memory/rerun.py new file mode 100644 index 0000000000..e5c5bb371c --- /dev/null +++ b/dimos/memory/rerun.py @@ -0,0 +1,78 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Send memory stream contents to Rerun. + +Iterates a Stream, calls ``.to_rerun()`` on each observation's data +payload, and logs it at the observation's timestamp. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from dimos.memory.stream import Stream + + +def _infer_entity_path(stream: Any) -> str: + """Derive an entity path from the stream's backend name.""" + backend = getattr(stream, "_backend", None) + if backend is not None: + name = getattr(backend, "stream_name", None) + if name and name != "": + return f"memory/{name}" + raise ValueError( + "Cannot infer entity_path — stream has no named backend " + "(e.g. ObservationSet from .fetch()). Pass entity_path explicitly." + ) + + +def to_rerun( + stream: Stream[Any] | Any, + entity_path: str | None = None, +) -> int: + """Log stream observations to Rerun. + + For each observation whose ``.data`` has a ``to_rerun()`` method, + logs the result at the observation's timestamp on a custom "time" + timeline (no wall-clock contamination). + + Args: + stream: Any Stream or iterable of Observations. + entity_path: Rerun entity path. Auto-derived from stream name if None. 
+ + Returns: + Number of items logged. + """ + import rerun as rr + + if entity_path is None: + entity_path = _infer_entity_path(stream) + + rr.disable_timeline("log_time") + rr.disable_timeline("log_tick") + + count = 0 + for obs in stream: + if obs.ts is not None: + rr.set_time("time", duration=obs.ts) + + data = obs.data + if hasattr(data, "to_rerun"): + rr.log(entity_path, data.to_rerun()) + count += 1 + + rr.reset_time() + return count diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 7b353cc0f1..dd9b1a7ed3 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -273,12 +273,20 @@ def fetch(self) -> ObservationSet[T]: def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: offset = self._query.offset_val or 0 + total_limit = self._query.limit_val + emitted = 0 while True: + page_size = batch_size + if total_limit is not None: + remaining = total_limit - emitted + if remaining <= 0: + break + page_size = min(batch_size, remaining) q = StreamQuery( filters=self._query.filters, order_field=self._query.order_field or "id", order_desc=self._query.order_desc, - limit_val=batch_size, + limit_val=page_size, offset_val=offset, ) backend = self._require_backend() @@ -286,9 +294,10 @@ def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: if not page: break yield page - if len(page) < batch_size: + emitted += len(page) + if len(page) < page_size: break - offset += batch_size + offset += len(page) def one(self) -> Observation: results = self.limit(1).fetch() diff --git a/dimos/memory/viz.py b/dimos/memory/viz.py index 3870f18600..feba7446ed 100644 --- a/dimos/memory/viz.py +++ b/dimos/memory/viz.py @@ -30,6 +30,20 @@ from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +def _normalize_similarities(values: np.ndarray, *, floor_percentile: float = 20.0) -> np.ndarray: + """Min-max normalize, then cut bottom percentile to 0 and re-stretch to 0-1.""" + if len(values) == 0: + return values + vmin, 
vmax = float(values.min()), float(values.max()) + vrange = vmax - vmin + normed = (values - vmin) / vrange if vrange > 0 else np.full_like(values, 0.5) + floor = float(np.percentile(normed, floor_percentile)) + denom = 1.0 - floor + if denom > 0: + normed = np.clip((normed - floor) / denom, 0.0, 1.0) + return normed + + def similarity_heatmap( observations: list[Observation] | Any, *, @@ -89,11 +103,9 @@ def similarity_heatmap( width = max(1, int((max_x - min_x) / resolution) + 1) height = max(1, int((max_y - min_y) / resolution) + 1) - # Normalize similarities to 0-1 (CLIP similarities cluster in a narrow band) + # Normalize: min-max, cut bottom 20%, re-stretch to 0-1 sims = np.array([s for _, _, s in posed]) - sim_min, sim_max = float(sims.min()), float(sims.max()) - sim_range = sim_max - sim_min - sims_norm = (sims - sim_min) / sim_range if sim_range > 0 else np.full_like(sims, 0.5) + sims_norm = _normalize_similarities(sims) # Stamp normalized values onto a float grid (0 = no observation) value_grid = np.zeros((height, width), dtype=np.float32) @@ -124,8 +136,8 @@ def similarity_heatmap( grid = np.full((height, width), -1, dtype=np.int8) active = heatmap > 0 grid[active] = (heatmap[active] * 100).clip(0, 100).astype(np.int8) - # Ensure dot cells themselves are always visible - grid[has_obs] = (value_grid[has_obs] * 100).clip(1, 100).astype(np.int8) + # Ensure dot cells are present (0 = black for bottom percentile) + grid[has_obs] = (value_grid[has_obs] * 100).clip(50, 100).astype(np.int8) origin = Pose( position=[min_x, min_y, 0.0], @@ -157,8 +169,8 @@ def log_similarity_timeline( ) -> None: """Log similarity scores as a Rerun time-series plot. - Each observation is logged at its timestamp with its similarity score. - Rerun auto-generates an interactive time-series graph in the timeline panel. + Observations are sorted by timestamp so the plot shows temporal similarity + bumps rather than a descending curve (search results are ranked by similarity). 
Args: observations: Iterable of EmbeddingObservation with .similarity and .ts. @@ -168,10 +180,75 @@ def log_similarity_timeline( from dimos.memory.types import EmbeddingObservation - for obs in observations: - if not isinstance(obs, EmbeddingObservation): - continue - if obs.similarity is None or obs.ts is None: + sorted_obs = sorted( + ( + obs + for obs in observations + if isinstance(obs, EmbeddingObservation) + and obs.similarity is not None + and obs.ts is not None + ), + key=lambda o: o.ts, # type: ignore[arg-type] + ) + if not sorted_obs: + return + + # Normalize: cut bottom 20%, re-stretch to 0-1 + raw_sims = np.array([obs.similarity for obs in sorted_obs]) + normed = _normalize_similarities(raw_sims) + + # Disable wall-clock timelines so Rerun defaults to our custom "time" axis + rr.disable_timeline("log_time") + rr.disable_timeline("log_tick") + + # Lock Y-axis to 0-1 + from rerun.blueprint import ScalarAxis + + rr.set_time("time", duration=0.0) + rr.log(entity_path, ScalarAxis(range=(0.0, 1.0), zoom_lock=True), static=True) + + for obs, sim in zip(sorted_obs, normed, strict=True): + rr.set_time("time", duration=obs.ts) # type: ignore[arg-type] + rr.log(entity_path, rr.Scalars(float(sim))) + rr.reset_time() + + +def log_top_images( + observations: list[Observation] | Any, + entity_path: str = "memory/top_matches", + *, + n: int = 6, +) -> None: + """Log the top-N matching images to Rerun as a grid. + + Observations must have ``.data`` that is a dimos Image (with ``.to_rerun()``). + Sorted by similarity (highest first), limited to *n*. + + Args: + observations: Iterable of EmbeddingObservation with .similarity and .data (Image). + entity_path: Rerun entity path prefix. Images logged as ``{entity_path}/{rank}``. + n: Number of top images to log. 
+ """ + import rerun as rr + + from dimos.memory.types import EmbeddingObservation + + ranked = [ + obs + for obs in observations + if isinstance(obs, EmbeddingObservation) and obs.similarity is not None + ] + ranked.sort(key=lambda o: o.similarity or 0.0, reverse=True) # type: ignore[union-attr] + + for i, obs in enumerate(ranked[:n]): + try: + img = obs.data + if hasattr(img, "to_rerun"): + rr.log(f"{entity_path}/{i + 1}", img.to_rerun()) + else: + import numpy as np_inner + + arr = np_inner.asarray(img) + rr.log(f"{entity_path}/{i + 1}", rr.Image(arr)) + except Exception: continue - rr.set_time("memory_time", timestamp=obs.ts) - rr.log(entity_path, rr.Scalars(obs.similarity)) diff --git a/plans/memory/api.md b/plans/memory/api.md new file mode 100644 index 0000000000..fe6897a5d9 --- /dev/null +++ b/plans/memory/api.md @@ -0,0 +1,684 @@ +# Memory2 API — Unified Stream + +## Core Idea + +One type: `Stream[T]`. Everything is a stream — stored, filtered, transformed. The user never thinks about Query vs ObservationSet vs Stream. They just chain operations. + +## Creating Streams + +```python +store = SqliteStore("/data/robot.db") +session = store.session() + +# Root stored stream — backed by DB +images = session.stream("images", Image, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +logs = session.text_stream("logs", str, + pose_provider=lambda: tf.get_pose("world", "base_link")) +``` + +## Writing + +```python +images.append(frame) # ts + pose auto-filled +logs.append("Motor fault on joint 3") # ts + pose auto-filled +images.append(frame, pose=explicit_pose, tags={"cam": "front"}) +``` + +Only meaningful on stored (DB-backed) streams. + +### Batch ingest + +The `ingest()` helper accepts any iterable of `(ts, payload)` — e.g. 
from a replay: + +```python +from dimos.memory.ingest import ingest + +replay = TimedSensorReplay("unitree_go2_bigoffice/video") +odom = TimedSensorReplay("unitree_go2_bigoffice/odom") + +raw = session.stream("raw_video", Image) +n = ingest(raw, replay.iterate_ts(seek=5.0, duration=300.0), pose_source=odom) +# pose_source.find_closest(ts) is called per frame to attach odom poses +``` + +## Filtering + +Every filter returns a new `Stream[T]`. Lazy — nothing executes until a terminal. + +```python +recent = images.after(one_hour_ago) +kitchen = recent.near(kitchen_pose, 5.0) +tagged = kitchen.filter_tags(cam="front") + +# Or chained +images.after(one_hour_ago).near(kitchen_pose, 5.0).filter_tags(cam="front") +``` + +### Filter methods + +```python +class Stream(Generic[T]): + # Temporal + def after(self, t: float) -> Stream[T]: ... + def before(self, t: float) -> Stream[T]: ... + def time_range(self, t1: float, t2: float) -> Stream[T]: ... + def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... + + # Spatial + def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... + + # Tags + def filter_tags(self, **tags: Any) -> Stream[T]: ... + +class EmbeddingStream(Stream[T]): + def search_embedding(self, query: Embedding | list[float] | str | Any, + *, k: int, raw: bool = False) -> Stream[Any]: ... + +class TextStream(Stream[T]): + def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... +``` + +## Terminals & Iteration + +`Stream` is directly iterable — pages internally, never loads everything at once. 
+ +```python +# Direct iteration (lazy, memory-efficient — uses fetch_pages internally) +for row in images.after(t).near(kitchen_pose, 5.0): + print(row.data) + +# Explicit fetch when you want the full list in memory +all_rows = images.after(t).fetch() # returns ObservationSet + +# Other terminals +row = images.after(t).one() # single best match +row = images.last() # most recent +n = images.after(t).count() # count without fetching + +# Pagination +page = images.order_by("ts").limit(50).offset(100).fetch() +``` + +### Terminal methods + +```python +class Stream(Generic[T]): + def __iter__(self) -> Iterator[Observation]: ... # lazy, pages internally + def fetch(self) -> ObservationSet[T]: ... # all results, list-like + stream-like + def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... + def one(self) -> Observation: ... + def last(self) -> Observation: ... + def count(self) -> int: ... + def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... + def limit(self, k: int) -> Stream[T]: ... + def offset(self, n: int) -> Stream[T]: ... +``` + +### ObservationSet + +`fetch()` returns an `ObservationSet` — a list-like object that also supports stream chaining: + +```python +results = embeddings.search_embedding("a hallway", k=50).fetch() + +len(results) # list-like +results[0] # indexing +for r in results: # iteration + print(r.data) + +# Stream-like — further filter/transform the materialized results +results.after(t).fetch() +results.transform(caption_xf).fetch() +``` + +## Observation + +```python +@dataclass +class Observation: + id: int + ts: float | None = None + pose: PoseStamped | None = None + tags: dict[str, Any] = field(default_factory=dict) + parent_id: int | None = None # lineage: source observation id + + @property + def data(self) -> Any: + """Lazy payload. Pre-populated from append/transform, fetched on demand from query.""" + ... 
+ +@dataclass +class EmbeddingObservation(Observation): + """Returned by EmbeddingStream terminals. Auto-projects .data to source stream.""" + + similarity: float | None = None # 0..1, populated by search_embedding (vec0 cosine) + + @property + def data(self) -> Any: + """Lazily loads from the source stream (e.g., Image), not the embedding.""" + ... + + @property + def embedding(self) -> Embedding: + """The Embedding object (has .vector, supports @ for cosine similarity).""" + ... +``` + +## Transformer + +A `Transformer` receives the full source stream and decides what to do — which items to process, how to batch, whether to use embeddings as a cheap proxy, etc. + +```python +class Transformer(ABC, Generic[T, R]): + """Transforms a source stream into results on a target stream.""" + + def process(self, source: Stream[T], target: Stream[R]) -> None: + """Batch/historical processing. Has full access to source — can query, + filter, use embeddings, batch, skip frames, etc.""" + ... + + def on_append(self, obs: Observation, target: Stream[R]) -> None: + """Reactive processing. Called per new item. Default: process([obs]).""" + ... + + supports_backfill: bool = True + supports_live: bool = True + output_type: type | None = None # determines target stream kind +``` + +### Simple lambdas (sugar) + +`Callable[[T], R | list[R] | None]` is auto-wrapped into a naive per-item Transformer: + +```python +# These are equivalent: +images.transform(lambda img: vlm.detect(img, "cigarettes")) +images.transform(PerItemTransformer(lambda img: vlm.detect(img, "cigarettes"))) +``` + +- `R` → single result +- `list[R]` → multiple results (e.g., multiple detections per frame) +- `None` → skip (no result for this input) + +### EmbeddingTransformer + +`EmbeddingTransformer` wraps an `EmbeddingModel` as a `Transformer[T, Embedding]`. When the output type is `Embedding`, `.store()` creates an `EmbeddingStream` (vec0 index, `search_embedding`, `EmbeddingObservation`). 
+
+```python
+# EmbeddingTransformer wraps the model
+img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb")
+
+# Now img_emb is an EmbeddingStream
+results = img_emb.search_embedding(query_emb, k=20).fetch()
+# results[0].data → Image (auto-projected from source)
+# results[0].embedding → Embedding (supports @ for cosine similarity)
+```
+
+### Chaining transforms
+
+```python
+# Filter → transform → store
+images.after(one_hour_ago) \
+    .near(kitchen_pose, 5.0) \
+    .transform(EmbeddingTransformer(CLIPModel())) \
+    .store("kitchen_embeddings")
+
+# Filter → transform → fetch (in-memory, not persisted)
+results = images.after(one_hour_ago) \
+    .near(kitchen_pose, 5.0) \
+    .transform(EmbeddingTransformer(CLIPModel())) \
+    .fetch()
+
+# Filter → embed → detect → store (chained: detector gets EmbeddingObservation)
+images.near(kitchen_pose, 5.0) \
+    .transform(EmbeddingTransformer(CLIPModel())) \
+    .transform(CigaretteDetector(vlm, clip)) \
+    .store("kitchen_cigarette_detections")
+```
+
+### Backfill / Live modes
+
+```python
+# Both (default): backfill existing + subscribe to new
+images.transform(detector).store("detections")
+
+# Live only: skip backfill, only process new items
+images.transform(detector, live=True).store("detections")
+
+# Backfill only: process existing, don't subscribe
+images.transform(detector, backfill=True).store("detections")
+
+# Backfill + live: process existing, and subscribe to new
+images.transform(detector, backfill=True, live=True).store("detections")
+
+# Incremental: re-running a stored transform resumes from last processed item
+# (uses lineage parent_id to skip already-processed source rows)
+```
+
+## Storing
+
+`.store(name)` materializes a stream to DB. After storing, results are queryable and persistent. 
+ +```python +# In-memory transform result — not persisted +detections = images.transform(detect_fn) + +# Persist it +detections.store("detections") + +# Now it's a DB-backed stream, queryable +stored = session.stream("detections") +rows = stored.after(t).fetch() +``` + +`.store()` also sets up lineage — every stored row gets `parent_id` pointing back to its source. + +Stream type is determined by what the Transformer produces: +- `Embedding` output → `EmbeddingStream` (vec0 index) +- `str` output from `CaptionTransformer` → `TextStream` (FTS index) +- Everything else → `Stream` (blob) + +## Reactive + +```python +# .appended emits Observation with .data pre-populated +images.appended.subscribe(lambda row: print(f"New image at {row.pose}")) + +# Stored transforms propagate reactively by default +detections = images.transform(detect_fn).store("detections") +# Now every images.append(frame) → detect_fn runs → result stored in "detections" + +# Filtered appended — only kitchen images +images.near(kitchen_pose, 5.0).appended.subscribe(...) +``` + +## Cross-stream lineage (project_to) + +`project_to()` follows `parent_id` chains to project observations onto another stream: + +```python +# Get embeddings matching a query, then project to source images +emb_results = img_emb.search_embedding("red shoes", k=20, raw=True).fetch() +# emb_results are EmbeddingObservations with .similarity, .pose, .ts + +# Or project to get the source images directly +image_results = img_emb.search_embedding("red shoes", k=20, raw=True) \ + .project_to(images).fetch() +``` + +`search_embedding` auto-projects by default — `raw=True` skips this to get +`EmbeddingObservation` results with `.similarity` scores. 
+ +Multi-hop lineage works too: +```python +# images → sharp_frames → clip_embeddings (2 hops) +# search_embedding auto-resolves the chain +results = clip_embeddings.search_embedding("a door", k=10).fetch() +# results[0].data → Image (from raw_video, traversing through sharp_frames) +``` + +## Visualization + +`dimos.memory.viz` provides helpers for visualizing search results: + +```python +from dimos.memory.viz import similarity_heatmap, log_similarity_timeline + +# Get raw results (with similarity scores and poses) +results = embeddings.search_embedding("a hallway", k=200, raw=True).fetch() + +# Spatial heatmap → OccupancyGrid (publishable via LCM, renderable in Rerun) +grid = similarity_heatmap(results, resolution=0.5, spread=2.0) +rr.log("world/heatmap", grid.to_rerun(colormap="inferno")) + +# Temporal timeline → Rerun scalar plot +log_similarity_timeline(results, entity_path="plots/similarity") +``` + +## Full Example: Cigarette Detection Pipeline + +```python +session = SqliteStore("/data/robot.db").session() + +# Root stream +images = session.stream("images", Image, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +# Embedding index — EmbeddingModel is a Transformer +img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") + +# VLM detection pipeline (live-only, no backfill) +images.transform( + lambda img: vlm.detect(img, "people with cigarettes"), + live=True, +).store("cigarette_detections") + +# Smart detection — reuse existing embeddings, detector gets EmbeddingObservation +img_emb.near(kitchen_pose, 10.0) \ + .transform(CigaretteDetector(vlm, clip)) \ + .store("kitchen_cigarette_detections") + +# --- Later, querying --- + +# "Where did we see people with cigarettes in the kitchen?" 
+for row in session.stream("cigarette_detections") \ + .after(one_hour_ago).near(kitchen_pose, 10.0): + print(f"t={row.ts} pose={row.pose}: {row.data}") + +# "Show me the source images alongside detections" +for det, img in session.stream("cigarette_detections") \ + .after(one_hour_ago).join(images): + print(f"Detection: {det.data}, Source image at {img.pose}") + +# "Find images similar to 'red shoes'" +similar = img_emb.search_embedding("red shoes", k=20).fetch() +# similar[0].data → Image (auto-projected from source) +# similar[0].embedding → Embedding (supports @ for cosine similarity) +``` + +## Full API + +```python +from dimos.models.embedding.base import Embedding, EmbeddingModel + +# --- Data types --- + +@dataclass +class Observation: + id: int + ts: float | None = None + pose: PoseStamped | None = None + tags: dict[str, Any] = field(default_factory=dict) + parent_id: int | None = None + + @property + def data(self) -> Any: + """Lazy payload. Pre-populated from append, fetched on demand from query.""" + ... + +@dataclass +class EmbeddingObservation(Observation): + """Returned by EmbeddingStream terminals. Auto-projects .data to source stream.""" + + similarity: float | None = None # 0..1, populated by search_embedding + + @property + def data(self) -> Any: + """Lazily loads from the source stream (e.g., Image), not the embedding.""" + ... + + @property + def embedding(self) -> Embedding: + """The Embedding object (has .vector, supports @ for cosine similarity).""" + ... + +# --- Transformer --- + +class Transformer(ABC, Generic[T, R]): + """Transforms a source stream into results on a target stream.""" + + def process(self, source: Stream[T], target: Stream[R]) -> None: + """Batch/historical processing. Full access to source stream.""" + ... + + def on_append(self, obs: Observation, target: Stream[R]) -> None: + """Reactive processing. Called per new item.""" + ... 
+ + supports_backfill: bool = True + supports_live: bool = True + output_type: type | None = None + +# --- Streams --- + +class Stream(Generic[T]): + # Write (DB-backed only) + def append(self, payload: T, *, + ts: float | None = None, + pose: PoseLike | None = None, + tags: dict[str, Any] | None = None, + parent_id: int | None = None, + ) -> Observation: ... + + # Filter (returns new Stream, lazy) + def after(self, t: float) -> Stream[T]: ... + def before(self, t: float) -> Stream[T]: ... + def time_range(self, t1: float, t2: float) -> Stream[T]: ... + def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... + def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... + def filter_tags(self, **tags: Any) -> Stream[T]: ... + + # Order / paginate + def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... + def limit(self, k: int) -> Stream[T]: ... + def offset(self, n: int) -> Stream[T]: ... + + # Transform + def transform(self, + xf: Transformer[T, R] | Callable[[T], R | list[R] | None], + *, live: bool = False, + backfill_only: bool = False, + ) -> Stream[R]: ... + + # Materialize (on TransformStream, accepts optional session= fallback) + def store(self, name: str | None = None, session: Session | None = None) -> Stream[T]: ... + + # Cross-stream lineage + def project_to(self, target: Stream[R]) -> Stream[R]: ... + + # Iteration & Terminals + def __iter__(self) -> Iterator[Observation]: ... # lazy, pages internally + def fetch(self) -> ObservationSet[T]: ... # list-like + stream-like result set + def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... + def one(self) -> Observation: ... + def last(self) -> Observation: ... + def count(self) -> int: ... + + # Reactive + @property + def appended(self) -> Observable[Observation]: ... + +class EmbeddingStream(Stream[T]): + """Created automatically when a Transformer produces Embedding output. 
+ Terminals return EmbeddingObservation (auto-projects .data to source stream).""" + def search_embedding(self, query: Embedding | list[float] | str | Any, + *, k: int, raw: bool = False) -> Stream[Any]: ... + +class TextStream(Stream[T]): + """Stream with FTS index.""" + def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... + +class ObservationSet(Stream[T]): + """Materialized result set from fetch(). List-like + stream-like.""" + def __len__(self) -> int: ... + def __getitem__(self, index: int) -> Observation: ... + def __iter__(self) -> Iterator[Observation]: ... + def __bool__(self) -> bool: ... + +# --- Helpers --- + +def ingest(stream: Stream, source: Iterable[tuple[float, Any]], *, + pose_source: Any | None = None) -> int: + """Ingest (ts, payload) pairs into a stream. Returns count.""" + ... + +# --- Session / Store --- + +PoseProvider = Callable[[], PoseLike | None] + +class Session: + def stream(self, name: str, payload_type: type | None = None, *, + pose_provider: PoseProvider | None = None) -> Stream: ... + def text_stream(self, name: str, payload_type: type | None = None, *, + tokenizer: str = "unicode61", + pose_provider: PoseProvider | None = None) -> TextStream: ... + def embedding_stream(self, name: str, payload_type: type | None = None, *, + vec_dimensions: int | None = None, + pose_provider: PoseProvider | None = None, + parent_table: str | None = None, + embedding_model: EmbeddingModel | None = None) -> EmbeddingStream: ... + def materialize_transform(self, name: str, source: Stream, + transformer: Transformer, + *, payload_type: type | None = None, + live: bool = False, + backfill_only: bool = False) -> Stream: ... + def list_streams(self) -> list[StreamInfo]: ... + def resolve_parent_stream(self, name: str) -> str | None: ... + def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: ... + def close(self) -> None: ... + +class Store: + def session(self) -> Session: ... 
+ def close(self) -> None: ... +``` + +## Internal Backing (impl detail) + +A `Stream` can be backed by different things — the user never sees this: + +- **DB tables** — from `session.stream()`. Metadata + payload + indexes. +- **Predicate** — from `.after()`, `.near()`, etc. Lazy SQL WHERE. +- **Transform** — from `.transform(t)`. Source stream + Transformer. +- **ListBackend** — from `ObservationSet`. In-memory Python-side filtering. + +The impl decides how to execute based on the backing chain. + +## SQLite Schema + +Each stream `{name}` creates these tables: + +```sql +-- Metadata table (compact rows, fast scans) +CREATE TABLE {name} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ts REAL, + pose_x REAL, -- position + pose_y REAL, + pose_z REAL, + pose_qx REAL, -- orientation quaternion (stored, not indexed) + pose_qy REAL, + pose_qz REAL, + pose_qw REAL, + tags TEXT DEFAULT '{}', + parent_id INTEGER -- lineage: source observation id +); +CREATE INDEX idx_{name}_ts ON {name}(ts); + +-- Payload table (blobs, loaded on demand) +CREATE TABLE {name}_payload ( + id INTEGER PRIMARY KEY, + data BLOB +); + +-- R*Tree spatial index (position only) +CREATE VIRTUAL TABLE {name}_rtree USING rtree( + id, + min_x, max_x, + min_y, max_y, + min_z, max_z +); +``` + +**Optional per stream kind:** + +```sql +-- TextStream: FTS5 full-text index +CREATE VIRTUAL TABLE {name}_fts USING fts5(content, tokenize='unicode61'); + +-- EmbeddingStream: vec0 vector index (cosine distance) +CREATE VIRTUAL TABLE {name}_vec USING vec0( + embedding float[{dim}] distance_metric=cosine +); +``` + +### Key design decisions + +- **Separate payload table** — metadata queries (`fetch`, `count`, `near`, filters) never touch blob data. Payload is loaded lazily via `obs.data`. +- **Decomposed pose columns** — enables R*Tree spatial index for `.near()` queries. Orientation stored for reconstruction but not spatially indexed. 
+- **R*Tree for spatial queries** — `.near(pose, radius)` compiles to an R*Tree range query (bounding box at +/-radius), with post-filter for exact Euclidean distance. +- **Cosine distance metric** — vec0 uses `distance_metric=cosine` (0=identical, 2=opposite). Similarity = `1.0 - distance`, clamped to [0, 1]. + +### Lazy payload loading + +`fetch()` returns `Observation` with lazy `.data`: +- Metadata query: `SELECT id, ts, pose_x, ..., tags, parent_id FROM {name} WHERE ...` +- `_data` stays `_UNSET`, `_data_loader` is set to: `SELECT data FROM {name}_payload WHERE id = ?` +- Only `obs.data` access triggers the blob read + codec decode + +This means iterating metadata (`obs.ts`, `obs.pose`, `obs.tags`) is cheap. + +### NearFilter SQL compilation + +```python +# .near(pose, 5.0) compiles to: +# JOIN {name}_rtree AS r ON r.id = {name}.id +# WHERE r.min_x >= pose.position.x - 5.0 AND r.max_x <= pose.position.x + 5.0 +# AND r.min_y >= pose.position.y - 5.0 AND r.max_y <= pose.position.y + 5.0 +# AND r.min_z >= pose.position.z - 5.0 AND r.max_z <= pose.position.z + 5.0 +``` + +For exact distance (not just bounding box), a post-filter computes Euclidean distance on the R*Tree candidates. + +## Serialization (Codec) + +Each stream has a `Codec[T]` that handles payload encode/decode. Auto-selected from `payload_type`. + +```python +class Codec(Protocol[T]): + def encode(self, value: T) -> bytes: ... + def decode(self, data: bytes) -> T: ... + +class LcmCodec(Codec[DimosMsg]): + """For DimosMsg types — uses lcm_encode/lcm_decode.""" + +class JpegCodec(Codec[Image]): + """For Image types — uses JPEG compression.""" + +class PickleCodec(Codec[Any]): + """Fallback for arbitrary Python objects.""" + +def codec_for_type(payload_type: type[T] | None) -> Codec[T]: + """Auto-select codec based on payload type.""" + ... +``` + +Lives in `dimos.memory.codec`. 
+ +Transparent to the user — just pass `payload_type` to `session.stream()`: +```python +images = session.stream("images", Image) # auto LCM codec +numbers = session.stream("numbers", int) # auto pickle codec +``` + +Tags are JSON. Poses are decomposed into columns (not serialized). + +### Stream metadata (`_streams` table) + +``` +name TEXT PRIMARY KEY +payload_module TEXT -- fully qualified, e.g. "dimos.msgs.sensor_msgs.Image.Image" +stream_kind TEXT -- "stream" | "text" | "embedding" +parent_stream TEXT -- parent stream name (lineage for project_to/join) +embedding_dim INTEGER -- vec0 dimension (embedding streams only) +``` + +On restart, `session.stream("images")` (no `payload_type`) resolves the class from `payload_module` via `importlib`, then selects the codec automatically. `embedding_dim` allows recreating the vec0 table without needing to see the first embedding again. + +## Resolved Questions + +1. **`.append()` on non-stored streams?** → `TypeError` (requires backend). +2. **Multiple `.store()` calls?** → Idempotent — returns existing stream if already stored. +3. ~~**Memory pressure from in-memory transforms?**~~ → Solved via `fetch_pages`. +4. **Pose storage** → Decomposed columns + R*Tree index (not binary blob). +5. **Payload loading** → Lazy via separate `{name}_payload` table. +6. **`__iter__`** → `for page in self.fetch_pages(): yield from page` — lazy, memory-efficient iteration. +7. **`project_to` / lineage** → Implemented via `parent_id` column + `_streams.parent_stream`. Multi-hop chains supported. +8. **`fetch()` return type** → `ObservationSet` (list-like + stream-like). +9. **Similarity scores** → `EmbeddingObservation.similarity` populated from vec0 cosine distance. + +## Open Questions + +1. **Incremental transforms** — re-running a stored transform should resume from last processed item. +2. **4D indexing** — should R*Tree include time as a 4th dimension? 
diff --git a/plans/memory/sqlite.md b/plans/memory/sqlite.md new file mode 100644 index 0000000000..173bedb5b6 --- /dev/null +++ b/plans/memory/sqlite.md @@ -0,0 +1,621 @@ +# SQLite Implementation + +Implementation spec for the SQLite backend. A coding agent should be able to implement the full backend from this document + `api.md`. + +## File Structure + +``` +dimos/memory/ + __init__.py # public exports + types.py # Observation, EmbeddingObservation, StreamInfo, Filter types + stream.py # Stream, EmbeddingStream, TextStream, ObservationSet, ListBackend + transformer.py # Transformer ABC, PerItemTransformer, EmbeddingTransformer, etc. + store.py # Session ABC, Store ABC + codec.py # LcmCodec, JpegCodec, PickleCodec, codec_for_type() + ingest.py # ingest() helper for batch ingestion + viz.py # similarity_heatmap(), similarity_poses(), log_similarity_timeline() + + impl/ + sqlite.py # SqliteStore, SqliteSession, Sqlite*Backend (single file) + test_sqlite.py # tests +``` + +## Dependencies + +- `sqlite3` (stdlib) +- `sqlite-vec` — vector similarity search via vec0 virtual table. Loaded via `sqlite_vec.load(conn)`. +- FTS5 — built into SQLite by default on most platforms. +- R*Tree — built into SQLite by default. +- `reactivex` — for `.appended` observable (already a DimOS dependency). 
+ +## Connection Management + +### SqliteStore + +```python +class SqliteStore(Store): + def __init__(self, path: str): + self._path = path + self._conn = sqlite3.connect(path) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._load_extensions() + + def session(self) -> SqliteSession: + return SqliteSession(self._conn) + + def _load_extensions(self) -> None: + try: + import sqlite_vec + self._conn.enable_load_extension(True) + sqlite_vec.load(self._conn) + self._conn.enable_load_extension(False) + except ImportError: + pass # vec0 unavailable — search_embedding will raise + + def close(self) -> None: + self._conn.close() +``` + +### SqliteSession + +```python +class SqliteSession(Session): + def __init__(self, conn: sqlite3.Connection): + self._conn = conn + self._streams: dict[str, Stream] = {} # cache by name + self._ensure_meta_table() + + def _ensure_meta_table(self): + """Create _streams registry table if not exists.""" + self._conn.execute(""" + CREATE TABLE IF NOT EXISTS _streams ( + name TEXT PRIMARY KEY, + payload_module TEXT, + stream_kind TEXT DEFAULT 'stream', + parent_stream TEXT, + embedding_dim INTEGER + ) + """) + + def stream(self, name, payload_type=None, *, pose_provider=None) -> Stream: + # Returns cached or creates new. payload_type required for new streams. + ... + + def text_stream(self, name, payload_type=None, *, tokenizer="unicode61", + pose_provider=None) -> TextStream: + ... + + def embedding_stream(self, name, payload_type=None, *, vec_dimensions=None, + pose_provider=None, parent_table=None, + embedding_model=None) -> EmbeddingStream: + ... + + def list_streams(self) -> list[StreamInfo]: ... + def resolve_parent_stream(self, name: str) -> str | None: ... + def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: ... + def close(self) -> None: ... +``` + +## Schema + +All table names are prefixed with the stream name. 
Stream names are validated: `[a-zA-Z_][a-zA-Z0-9_]*`. + +### `_streams` — Global registry + +```sql +CREATE TABLE _streams ( + name TEXT PRIMARY KEY, + payload_module TEXT, -- e.g. 'dimos.msgs.sensor_msgs.Image.Image' + stream_kind TEXT DEFAULT 'stream', -- 'stream', 'embedding', 'text' + parent_stream TEXT, -- parent stream name (lineage) + embedding_dim INTEGER -- only for kind='embedding' +); +``` + +### `{name}` — Observation metadata (all stream types) + +```sql +CREATE TABLE {name} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ts REAL, + pose_x REAL, pose_y REAL, pose_z REAL, + pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, + tags TEXT DEFAULT '{}', -- JSON dict + parent_id INTEGER -- lineage: id in parent stream +); +CREATE INDEX idx_{name}_ts ON {name}(ts); +``` + +### `{name}_payload` — Blob/Text payload + +```sql +CREATE TABLE {name}_payload ( + id INTEGER PRIMARY KEY, -- matches {name}.id + data BLOB NOT NULL +); +``` + +Separated from metadata so metadata queries never page in multi-MB blobs. + +### `{name}_rtree` — Spatial index (all stream types) + +```sql +CREATE VIRTUAL TABLE {name}_rtree USING rtree( + id, -- matches {name}.id + min_x, max_x, + min_y, max_y, + min_z, max_z +); +``` + +Only rows with pose are inserted into R*Tree. Rows without pose are excluded from `.near()` results. + +### `{name}_fts` — Full-text search (TextStream only) + +```sql +CREATE VIRTUAL TABLE {name}_fts USING fts5( + content, + tokenize='{tokenizer}' +); +``` + +Standalone FTS table (not content-synced). Rowids match `{name}.id`. + +### `{name}_vec` — Vector index (EmbeddingStream only) + +```sql +CREATE VIRTUAL TABLE {name}_vec USING vec0( + embedding float[{dim}] distance_metric=cosine +); +``` + +Cosine distance: 0 = identical, 2 = opposite. Similarity = `max(0, min(1, 1.0 - distance))`. + +Rowids match `{name}.id`. Dimension inferred from first embedding inserted, or from `vec_dimensions` parameter. 
+ +## Stream Implementation + +### Backend classes + +The stream/backend split separates query logic from stream API: + +```python +class SqliteStreamBackend: + """Base backend for blob streams.""" + def do_append(self, payload, ts, pose, tags, parent_id=None) -> Observation: ... + def execute_fetch(self, query: StreamQuery) -> list[Observation]: ... + def execute_count(self, query: StreamQuery) -> int: ... + +class SqliteEmbeddingBackend(SqliteStreamBackend): + """Adds vec0 index. Overrides execute_fetch for vector search.""" + ... + +class SqliteTextBackend(SqliteStreamBackend): + """Adds FTS5 index. Overrides execute_fetch for text search.""" + ... +``` + +### append() + +```python +def do_append(self, payload, ts, pose, tags, parent_id=None): + ts = ts or time.time() + if pose is None and self._pose_provider: + pose = self._pose_provider() + + pose_cols = _decompose_pose(pose) + tags_json = _serialize_tags(tags) + + # 1. Insert into meta table + cur = self._conn.execute( + f"INSERT INTO {name} " + "(ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags, parent_id) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (ts, *pose_cols, tags_json, parent_id), + ) + row_id = cur.lastrowid + + # 2. Insert into _payload + blob = self._codec.encode(payload) + self._conn.execute( + f"INSERT INTO {name}_payload(id, data) VALUES (?, ?)", + (row_id, blob) + ) + + # 3. Insert into _rtree (if pose) + if pose_cols: + x, y, z = pose_cols[0], pose_cols[1], pose_cols[2] + self._conn.execute( + f"INSERT INTO {name}_rtree(id, min_x, max_x, min_y, max_y, min_z, max_z) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (row_id, x, x, y, y, z, z) + ) + + self._conn.commit() + + # 4. 
Build Observation and emit + obs = Observation(id=row_id, ts=ts, pose=pose, tags=tags or {}, _data=payload) + self._subject.on_next(obs) + return obs +``` + +### EmbeddingBackend.append() + +Same as above, plus inserts into `_vec`: + +```python +if isinstance(payload, Embedding): + vec = payload.to_numpy().tolist() + self._conn.execute( + f"INSERT INTO {name}_vec(rowid, embedding) VALUES (?, ?)", + (row_id, json.dumps(vec)) + ) +``` + +### TextBackend.append() + +Same as base, plus inserts into `_fts`: + +```python +text = str(payload) +self._conn.execute( + f"INSERT INTO {name}_fts(rowid, content) VALUES (?, ?)", + (row_id, text) +) +``` + +## Filter → SQL Generation + +Each filter method returns a new stream with an added filter. At terminal time, the filter chain is compiled to SQL. + +### Filter types + +```python +AfterFilter(t) # → WHERE ts > ? +BeforeFilter(t) # → WHERE ts < ? +TimeRangeFilter(t1, t2) # → WHERE ts >= ? AND ts <= ? +AtFilter(t, tolerance) # → WHERE ABS(ts - ?) <= ? +NearFilter(pose, radius) # → JOIN _rtree bounding box query +TagsFilter(tags) # → WHERE json_extract(tags, '$.key') = ? +EmbeddingSearchFilter(vec, k) # → query _vec, then filter by rowids +TextSearchFilter(text, k) # → query _fts MATCH, then filter by rowids +LineageFilter(source_table, source_query, hops) # → nested IN subquery +``` + +### SQL compilation + +Walk the filter list, generate SQL: + +```python +def _compile_query(query, table) -> tuple[str, list[Any]]: + # Base SELECT + sql = f"SELECT {table}.id, {table}.ts, ... FROM {table}" + + # NearFilter → JOIN _rtree + # Other filters → WHERE clauses + # EmbeddingSearch/TextSearch → handled separately (two-step query) + # LineageFilter → nested IN subquery via _compile_ids() + + return sql, params +``` + +### search_embedding (vec0) + +Two-step process: + +```sql +-- 1. Top-k vector search (cosine distance) +SELECT rowid, distance +FROM {name}_vec +WHERE embedding MATCH ? +ORDER BY distance +LIMIT ? +``` + +```python +# 2. 
Build dist_map, fetch metadata for those rowids, populate similarity +dist_map = {rowid: distance for rowid, distance in vec_rows} +# ... fetch metadata WHERE id IN (rowids) ... +for obs in observations: + obs.similarity = max(0.0, min(1.0, 1.0 - dist_map[obs.id])) +# Re-sort by distance rank (IN clause doesn't preserve vec0 ordering) +``` + +### search_text (FTS5) + +```sql +SELECT rowid, rank +FROM {name}_fts +WHERE content MATCH ? +ORDER BY rank +``` + +Same two-step: get rowids from FTS5, then fetch metadata. + +### LineageFilter compilation + +LineageFilter compiles to a nested SQL subquery walking the `parent_id` chain: + +```python +# Single hop: embeddings → images +f"SELECT parent_id FROM {source_table} WHERE id IN ({source_ids_sql})" + +# Multi-hop: embeddings → sharp_frames → images +# Wraps each hop as a nested IN subquery +``` + +## Terminal Execution + +### __iter__() — lazy iteration + +`Stream` is directly iterable via `fetch_pages`: + +```python +def __iter__(self): + for page in self.fetch_pages(): + yield from page +``` + +### fetch() + +Returns `ObservationSet` (list-like + stream-like): + +```python +def fetch(self) -> ObservationSet: + results = self._backend.execute_fetch(self._query) + return ObservationSet(results, session=self._session) +``` + +### count() + +```python +def count(self) -> int: + sql, params = _compile_count(query, table) + # → SELECT COUNT(*) FROM {table} WHERE ... 
+ return self._conn.execute(sql, params).fetchone()[0] +``` + +### one() / last() + +- `one()` → `self.limit(1).fetch()[0]` +- `last()` → `self.order_by("ts", desc=True).limit(1).fetch()[0]` + +## Lazy Data Loading + +`Observation.data` uses lazy loading: + +```python +@dataclass +class Observation: + _data: Any = field(default=_UNSET, repr=False) + _data_loader: Callable[[], Any] | None = field(default=None, repr=False) + + @property + def data(self) -> Any: + if self._data is not _UNSET: + return self._data + if self._data_loader is not None: + self._data = self._data_loader() + return self._data + raise LookupError("No data available") +``` + +When building observations from query results: + +```python +def _row_to_obs(self, row) -> Observation: + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row + pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) + + def loader(): + r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() + return codec.decode(r[0]) + + return Observation(id=row_id, ts=ts, pose=pose, tags=..., _data_loader=loader) +``` + +### EmbeddingObservation + +For `EmbeddingBackend`, `_row_to_obs` returns `EmbeddingObservation` with two lazy loaders: + +```python +def _row_to_obs(self, row) -> EmbeddingObservation: + # ... same metadata extraction ... 
+ + # _data_loader: loads raw embedding payload + # _source_data_loader: loads from PARENT stream (auto-projection) + # - Resolves parent codec from _streams.payload_module + # - Uses parent_id to look up the source payload + + return EmbeddingObservation( + id=row_id, ts=ts, pose=pose, tags=..., + parent_id=pid, + _data_loader=loader, + _source_data_loader=source_loader, # None if no parent + ) +``` + +## Lineage + +### Storing lineage + +When a Transformer appends to a target stream, `parent_id` links back to the source: + +```python +target.append(result, ts=source_obs.ts, pose=source_obs.pose, + parent_id=source_obs.id) +``` + +The `_streams` registry tracks stream-level lineage: +```python +# After materialize_transform creates the target +UPDATE _streams SET parent_stream = ? WHERE name = ? +``` + +### resolve_lineage_chain() + +Walks `_streams.parent_stream` from source toward target: + +```python +def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: + # Single hop (source → target): returns () + # Two hops (source → mid → target): returns ("mid",) + # Raises ValueError if no path exists +``` + +### project_to() + +Uses `LineageFilter` to compile a nested SQL subquery: + +```python +def project_to(self, target: Stream) -> Stream: + hops = session.resolve_lineage_chain(source_table, target_table) + return target._with_filter(LineageFilter(source_table, self._query, hops)) +``` + +## Pose Helpers + +PoseStamped in dimos extends Pose directly (no wrapper). Access position/orientation directly: + +```python +def _decompose_pose(pose) -> tuple[float, ...] 
| None: + if pose is None: + return None + p = pose.position # NOT pose.pose.position + q = pose.orientation + return (p.x, p.y, p.z, q.x, q.y, q.z, q.w) + +def _reconstruct_pose(x, y, z, qx, qy, qz, qw) -> PoseStamped | None: + if x is None: + return None + return PoseStamped( + position=[x, y or 0.0, z or 0.0], # list args (plum dispatch) + orientation=[qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0], + ) +``` + +NearFilter SQL compilation also accesses `f.pose.position` directly. + +## Transform Execution + +### .transform() — returns lazy stream + +`.transform(xf)` doesn't execute immediately. It returns a `TransformStream`. Execution happens at terminal time or `.store()`. + +### .store() — materializes + +When `.store(name)` is called on a `TransformStream`: + +1. Register target stream in `_streams` (with `parent_stream` set) +2. Create target tables +3. Auto-detect target stream type from transformer: + - `EmbeddingTransformer` → `EmbeddingStream` (with parent_table) + - `CaptionTransformer` → `TextStream` (FTS) + - Other → `Stream` (blob) +4. If not `live` mode: run `xf.process(source, target)` (backfill) +5. If not `backfill_only`: subscribe to source's `.appended`, call `xf.on_append()` +6. Return the stored stream + +### .fetch() on TransformStream (no .store()) + +Executes the transform in-memory using `_CollectorStream`: + +```python +def fetch(self) -> ObservationSet: + collector = _CollectorStream() + self._transformer.process(self._source, collector) + return ObservationSet(collector.results) +``` + +## Reactive (.appended) + +Each stored stream backend has a `Subject` from reactivex: + +```python +class SqliteStreamBackend: + def __init__(self, ...): + self._subject: Subject[Observation] = Subject() + + @property + def appended_subject(self): + return self._subject +``` + +`do_append()` emits to the subject after the DB write succeeds. 
+ +For filtered streams, the observable filters events through the filter chain in Python: + +```python +@property +def appended(self): + raw = self._backend.appended_subject + active = [f for f in self._query.filters + if not isinstance(f, (EmbeddingSearchFilter, LineageFilter))] + return raw.pipe(ops.filter(lambda obs: all(f.matches(obs) for f in active))) +``` + +## Serialization + +### Codec system + +```python +class LcmCodec: # for DimosMsg types (lcm_encode/lcm_decode) +class JpegCodec: # for Image types (JPEG compression) +class PickleCodec: # fallback for arbitrary Python objects + +def codec_for_type(payload_type: type | None) -> Codec: + """Auto-select codec based on payload type.""" + ... +``` + +Lives in `dimos.memory.codec`. + +### Tag serialization + +Tags are stored as JSON text. Empty dict → `"{}"`. + +## SQL Safety + +- **Identifier validation**: stream names must match `^[a-zA-Z_][a-zA-Z0-9_]*$`. +- **Parameterized queries**: all user values go through `?` params, never string interpolation. +- **Table names**: constructed from validated stream names, safe for SQL interpolation. +- **Order fields**: validated against allowlist `{"id", "ts"}`. + +## Thread Safety + +- Each `Session` owns one `sqlite3.Connection` — not shared across threads. +- Multiple sessions can exist on the same file (WAL mode allows concurrent reads + one writer). +- The `appended` subject emits on the thread that called `append()`. + +## Error Handling + +- `append()` on non-stored stream → `TypeError` +- `search_embedding()` on non-embedding stream → `TypeError` +- `search_text()` on non-text stream → `TypeError` +- `search_embedding()` when sqlite-vec not loaded → `RuntimeError` +- Invalid stream name → `ValueError` +- `one()` with no results → `LookupError` +- `stream()` without `payload_type` on new stream → `TypeError` + +## Testing + +Tests in `dimos/memory/impl/test_sqlite.py`. Use `:memory:` store for speed. + +Key test scenarios: +1. 
Create stream, append, fetch — verify data round-trips +2. Temporal filters (after, before, time_range, at) +3. Spatial filter (near) — with and without pose +4. Tag filtering +5. EmbeddingStream — store embeddings, search_embedding, verify auto-projection +6. TextStream — store text, search_text +7. Transform with lambda — verify lineage +8. Transform with Transformer class — verify process() called +9. Chained filters — verify SQL composition +10. project_to — verify cross-stream lineage (single and multi-hop) +11. fetch_pages — verify pagination +12. Lazy data loading — verify .data only hits DB on access +13. .appended observable — verify reactive emission +14. Similarity scores — verify EmbeddingObservation.similarity populated after search +15. raw=True — verify EmbeddingObservation with similarity + auto-projected data +16. ObservationSet — verify list-like + stream-like behavior From 314b4d30dbbe0d95621bcd875e4362440c01a479 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 20:00:08 +0800 Subject: [PATCH 024/118] =?UTF-8?q?Rename=20run=5Fe2e=5Fexport=20=E2=86=92?= =?UTF-8?q?=20test=5Fe2e=5Fexport,=20delete=20viz.py=20+=20run=5Fviz=5Fdem?= =?UTF-8?q?o,=20fix=20mypy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename to test_e2e_export.py (it's a pytest file, not a standalone script) - Fix Generator return type and type: ignore for mypy - Delete viz.py (replaced by rerun.py) and run_viz_demo.py - Update docs/api.md to reference rerun.py instead of viz.py --- {plans/memory => dimos/memory/docs}/api.md | 17 +- dimos/memory/impl/run_viz_demo.py | 92 --- .../{run_e2e_export.py => test_e2e_export.py} | 7 +- dimos/memory/rerun.py | 3 + dimos/memory/viz.py | 254 ------- plans/memory/sqlite.md | 621 ------------------ 6 files changed, 14 insertions(+), 980 deletions(-) rename {plans/memory => dimos/memory/docs}/api.md (97%) delete mode 100644 dimos/memory/impl/run_viz_demo.py rename 
dimos/memory/impl/{run_e2e_export.py => test_e2e_export.py} (95%) delete mode 100644 dimos/memory/viz.py delete mode 100644 plans/memory/sqlite.md diff --git a/plans/memory/api.md b/dimos/memory/docs/api.md similarity index 97% rename from plans/memory/api.md rename to dimos/memory/docs/api.md index fe6897a5d9..cde3a818ff 100644 --- a/plans/memory/api.md +++ b/dimos/memory/docs/api.md @@ -321,20 +321,15 @@ results = clip_embeddings.search_embedding("a door", k=10).fetch() ## Visualization -`dimos.memory.viz` provides helpers for visualizing search results: +`dimos.memory.rerun` sends stream contents to Rerun: ```python -from dimos.memory.viz import similarity_heatmap, log_similarity_timeline +from dimos.memory.rerun import to_rerun -# Get raw results (with similarity scores and poses) -results = embeddings.search_embedding("a hallway", k=200, raw=True).fetch() - -# Spatial heatmap → OccupancyGrid (publishable via LCM, renderable in Rerun) -grid = similarity_heatmap(results, resolution=0.5, spread=2.0) -rr.log("world/heatmap", grid.to_rerun(colormap="inferno")) - -# Temporal timeline → Rerun scalar plot -log_similarity_timeline(results, entity_path="plots/similarity") +# Send any stream to Rerun — auto-derives entity path from stream name, +# logs .data via to_rerun() and poses as arrows +to_rerun(images) +to_rerun(embeddings.search_embedding("a hallway", k=50)) ``` ## Full Example: Cigarette Detection Pipeline diff --git a/dimos/memory/impl/run_viz_demo.py b/dimos/memory/impl/run_viz_demo.py deleted file mode 100644 index 1bd5d95cba..0000000000 --- a/dimos/memory/impl/run_viz_demo.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Visual demo: similarity heatmap + timeline in Rerun. - -Run with: python -m dimos.memory.impl.run_viz_demo -Then open Rerun viewer to see the output. -""" - -from __future__ import annotations - -import numpy as np -import rerun as rr - -from dimos.memory.impl.sqlite import SqliteStore -from dimos.memory.types import EmbeddingObservation -from dimos.memory.viz import log_similarity_timeline, similarity_heatmap -from dimos.models.embedding.base import Embedding -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - -# ── Rerun setup ─────────────────────────────────────────────────────── -rr.init("memory_viz_demo", spawn=True) - -# ── Build a small DB with posed embeddings ──────────────────────────── -store = SqliteStore(":memory:") -session = store.session() - -es = session.embedding_stream("demo_emb", vec_dimensions=4) - -# Simulate a robot path with embeddings at various positions -np.random.seed(42) -n_obs = 60 -for i in range(n_obs): - angle = 2 * np.pi * i / n_obs - radius = 3.0 + 0.5 * np.sin(3 * angle) - x = radius * np.cos(angle) - y = radius * np.sin(angle) - - # Embedding: mix of two basis vectors depending on position - mix = (np.sin(angle) + 1) / 2 # 0..1 - vec = np.array([mix, 1.0 - mix, 0.1 * np.cos(angle), 0.0], dtype=np.float32) - vec /= np.linalg.norm(vec) - - pose = PoseStamped( - ts=float(i), - frame_id="world", - position=[x, y, 0.0], - orientation=[0.0, 0.0, 0.0, 1.0], - ) - es.append(Embedding(vec), ts=float(i), pose=pose) - -print(f"Created {es.count()} observations on a circular path") - -# ── Search and visualize 
────────────────────────────────────────────── -query = [1.0, 0.0, 0.0, 0.0] -results = es.search_embedding(query, k=n_obs).fetch() - -print(f"Search returned {len(results)} results") -for obs in results[:5]: - assert isinstance(obs, EmbeddingObservation) - print(f" id={obs.id} ts={obs.ts:.0f} similarity={obs.similarity:.3f}") - -# 1. Similarity heatmap → OccupancyGrid → Rerun mesh -grid = similarity_heatmap(results, resolution=0.2, padding=2.0) -print(f"\nHeatmap: {grid}") -rr.log("world/heatmap", grid.to_rerun(colormap="inferno")) - -# 2. Similarity timeline → Rerun scalar plot -log_similarity_timeline(results, entity_path="plots/similarity") -print("Logged similarity timeline") - -# 3. Also log poses as arrows for spatial context -for obs in results: - if obs.pose is not None and obs.ts is not None: - rr.set_time("memory_time", timestamp=obs.ts) - rr.log("world/poses", obs.pose.to_rerun_arrow(length=0.3)) - -print("\nDone — check Rerun viewer") - -session.close() -store.close() diff --git a/dimos/memory/impl/run_e2e_export.py b/dimos/memory/impl/test_e2e_export.py similarity index 95% rename from dimos/memory/impl/run_e2e_export.py rename to dimos/memory/impl/test_e2e_export.py index 9a56e3b283..a175560d9b 100644 --- a/dimos/memory/impl/run_e2e_export.py +++ b/dimos/memory/impl/test_e2e_export.py @@ -37,6 +37,8 @@ from dimos.utils.testing import TimedSensorReplay if TYPE_CHECKING: + from collections.abc import Generator + from dimos.memory.stream import EmbeddingStream DB_DIR = Path(__file__).parent / "e2e_matches" @@ -52,7 +54,7 @@ def clip() -> CLIPModel: @pytest.fixture(scope="module") -def e2e_db(clip: CLIPModel) -> tuple[SqliteStore, Any]: +def e2e_db(clip: CLIPModel) -> Generator[tuple[SqliteStore, Any], None, None]: """Build (or reuse cached) e2e DB with video → sharpness → CLIP embeddings.""" store = SqliteStore(str(DB_PATH)) session = store.session() @@ -86,7 +88,8 @@ def e2e_db(clip: CLIPModel) -> tuple[SqliteStore, Any]: 
@pytest.fixture(scope="module") def embeddings(e2e_db: tuple[SqliteStore, Any], clip: CLIPModel) -> EmbeddingStream[Any]: _, session = e2e_db - return session.embedding_stream("clip_embeddings", embedding_model=clip) + stream: EmbeddingStream[Any] = session.embedding_stream("clip_embeddings", embedding_model=clip) # type: ignore[assignment] + return stream class TestEmbeddingSearch: diff --git a/dimos/memory/rerun.py b/dimos/memory/rerun.py index e5c5bb371c..629fa5b4dd 100644 --- a/dimos/memory/rerun.py +++ b/dimos/memory/rerun.py @@ -74,5 +74,8 @@ def to_rerun( rr.log(entity_path, data.to_rerun()) count += 1 + if obs.pose is not None and hasattr(obs.pose, "to_rerun_arrow"): + rr.log(f"{entity_path}/pose", obs.pose.to_rerun_arrow()) + rr.reset_time() return count diff --git a/dimos/memory/viz.py b/dimos/memory/viz.py deleted file mode 100644 index feba7446ed..0000000000 --- a/dimos/memory/viz.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Visualization helpers for Memory2 search results. - -Produces LCM-publishable messages (OccupancyGrid, PoseStamped) and -Rerun time-series plots from embedding search observations. 
-""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import numpy as np - -if TYPE_CHECKING: - from dimos.memory.types import Observation - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid - - -def _normalize_similarities(values: np.ndarray, *, floor_percentile: float = 20.0) -> np.ndarray: - """Min-max normalize, then cut bottom percentile to 0 and re-stretch to 0-1.""" - if len(values) == 0: - return values - vmin, vmax = float(values.min()), float(values.max()) - vrange = vmax - vmin - normed = (values - vmin) / vrange if vrange > 0 else np.full_like(values, 0.5) - floor = float(np.percentile(normed, floor_percentile)) - denom = 1.0 - floor - if denom > 0: - normed = np.clip((normed - floor) / denom, 0.0, 1.0) - return normed - - -def similarity_heatmap( - observations: list[Observation] | Any, - *, - resolution: float = 0.1, - padding: float = 1.0, - spread: float = 0.2, - frame_id: str = "world", -) -> OccupancyGrid: - """Build an OccupancyGrid heatmap from observations with similarity scores. - - Similarity values are normalized relative to the result set's min/max - (so the full 0-100 color range is used even when CLIP similarities - cluster in a narrow band). Each dot's value spreads outward using - ``distance_transform_edt`` — the same technique as - :func:`dimos.mapping.occupancy.gradient.gradient` — fading to 0 at - *spread* metres. - - Args: - observations: Iterable of Observation (must have .pose and .similarity). - resolution: Grid resolution in metres/cell. - padding: Extra metres around the bounding box. - spread: How far each dot's similarity radiates (metres). - frame_id: Coordinate frame for the grid. - - Returns: - OccupancyGrid publishable via LCMTransport. 
- """ - from scipy.ndimage import distance_transform_edt - - from dimos.memory.types import EmbeddingObservation - from dimos.msgs.geometry_msgs.Pose import Pose - from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid as OG - - posed: list[tuple[float, float, float]] = [] - for obs in observations: - if obs.pose is None: - continue - sim = ( - obs.similarity - if isinstance(obs, EmbeddingObservation) and obs.similarity is not None - else 0.0 - ) - p = obs.pose.position - posed.append((p.x, p.y, sim)) - - if not posed: - return OG(width=1, height=1, resolution=resolution, frame_id=frame_id) - - xs = [p[0] for p in posed] - ys = [p[1] for p in posed] - - min_x = min(xs) - padding - min_y = min(ys) - padding - max_x = max(xs) + padding - max_y = max(ys) + padding - - width = max(1, int((max_x - min_x) / resolution) + 1) - height = max(1, int((max_y - min_y) / resolution) + 1) - - # Normalize: min-max, cut bottom 20%, re-stretch to 0-1 - sims = np.array([s for _, _, s in posed]) - sims_norm = _normalize_similarities(sims) - - # Stamp normalized values onto a float grid (0 = no observation) - value_grid = np.zeros((height, width), dtype=np.float32) - has_obs = np.zeros((height, width), dtype=bool) - - for (px, py, _), snorm in zip(posed, sims_norm, strict=False): - gx = min(int((px - min_x) / resolution), width - 1) - gy = min(int((py - min_y) / resolution), height - 1) - if snorm > value_grid[gy, gx]: - value_grid[gy, gx] = snorm - has_obs[gy, gx] = True - - # Distance transform: distance (in cells) from each empty cell to nearest dot - dist_cells: np.ndarray[Any, Any] = distance_transform_edt(~has_obs) # type: ignore[assignment] - dist_metres = dist_cells * resolution - - # Fade factor: 1.0 at the dot, 0.0 at `spread` metres away - fade = np.clip(1.0 - dist_metres / spread, 0.0, 1.0) - - # For each cell, find the value of its nearest dot (via index output) - _, nearest_idx = distance_transform_edt(~has_obs, return_indices=True) # type: ignore[misc] - 
nearest_value = value_grid[nearest_idx[0], nearest_idx[1]] - - # Final heatmap = nearest dot's value * distance fade - heatmap = nearest_value * fade - - # Convert to int8 grid: observed region is 0-100, rest is -1 - grid = np.full((height, width), -1, dtype=np.int8) - active = heatmap > 0 - grid[active] = (heatmap[active] * 100).clip(0, 100).astype(np.int8) - # Ensure dot cells are present (0 = black for bottom percentile) - grid[has_obs] = (value_grid[has_obs] * 100).clip(50, 100).astype(np.int8) - - origin = Pose( - position=[min_x, min_y, 0.0], - orientation=[0.0, 0.0, 0.0, 1.0], - ) - - return OG(grid=grid, resolution=resolution, origin=origin, frame_id=frame_id) - - -def similarity_poses(observations: list[Observation] | Any) -> list[PoseStamped]: - """Extract PoseStamped from observations for spatial arrow rendering. - - Args: - observations: Iterable of Observation with .pose. - - Returns: - List of PoseStamped suitable for LCMTransport publishing. - """ - result: list[PoseStamped] = [] - for obs in observations: - if obs.pose is not None: - result.append(obs.pose) - return result - - -def log_similarity_timeline( - observations: list[Observation] | Any, - entity_path: str = "memory/similarity", -) -> None: - """Log similarity scores as a Rerun time-series plot. - - Observations are sorted by timestamp so the plot shows temporal similarity - bumps rather than a descending curve (search results are ranked by similarity). - - Args: - observations: Iterable of EmbeddingObservation with .similarity and .ts. - entity_path: Rerun entity path for the scalar series. 
- """ - import rerun as rr - - from dimos.memory.types import EmbeddingObservation - - sorted_obs = sorted( - ( - obs - for obs in observations - if isinstance(obs, EmbeddingObservation) - and obs.similarity is not None - and obs.ts is not None - ), - key=lambda o: o.ts, # type: ignore[arg-type] - ) - if not sorted_obs: - return - - # Normalize: cut bottom 20%, re-stretch to 0-1 - raw_sims = np.array([obs.similarity for obs in sorted_obs]) - normed = _normalize_similarities(raw_sims) - - # Disable wall-clock timelines so Rerun defaults to our custom "time" axis - rr.disable_timeline("log_time") - rr.disable_timeline("log_tick") - - # Lock Y-axis to 0-1 - from rerun.blueprint import ScalarAxis - - rr.set_time("time", duration=0.0) - rr.log(entity_path, ScalarAxis(range=(0.0, 1.0), zoom_lock=True), static=True) - - for obs, sim in zip(sorted_obs, normed, strict=True): - rr.set_time("time", duration=obs.ts) # type: ignore[arg-type] - rr.log(entity_path, rr.Scalars(float(sim))) - rr.reset_time() - - -def log_top_images( - observations: list[Observation] | Any, - entity_path: str = "memory/top_matches", - *, - n: int = 6, -) -> None: - """Log the top-N matching images to Rerun as a grid. - - Observations must have ``.data`` that is a dimos Image (with ``.to_rerun()``). - Sorted by similarity (highest first), limited to *n*. - - Args: - observations: Iterable of EmbeddingObservation with .similarity and .data (Image). - entity_path: Rerun entity path prefix. Images logged as ``{entity_path}/{rank}``. - n: Number of top images to log. 
- """ - import rerun as rr - - from dimos.memory.types import EmbeddingObservation - - ranked = [ - obs - for obs in observations - if isinstance(obs, EmbeddingObservation) and obs.similarity is not None - ] - ranked.sort(key=lambda o: o.similarity or 0.0, reverse=True) # type: ignore[union-attr] - - for i, obs in enumerate(ranked[:n]): - try: - img = obs.data - if hasattr(img, "to_rerun"): - rr.log(f"{entity_path}/{i + 1}", img.to_rerun()) - else: - import numpy as np_inner - - arr = np_inner.asarray(img) - rr.log(f"{entity_path}/{i + 1}", rr.Image(arr)) - except Exception: - continue diff --git a/plans/memory/sqlite.md b/plans/memory/sqlite.md deleted file mode 100644 index 173bedb5b6..0000000000 --- a/plans/memory/sqlite.md +++ /dev/null @@ -1,621 +0,0 @@ -# SQLite Implementation - -Implementation spec for the SQLite backend. A coding agent should be able to implement the full backend from this document + `api.md`. - -## File Structure - -``` -dimos/memory/ - __init__.py # public exports - types.py # Observation, EmbeddingObservation, StreamInfo, Filter types - stream.py # Stream, EmbeddingStream, TextStream, ObservationSet, ListBackend - transformer.py # Transformer ABC, PerItemTransformer, EmbeddingTransformer, etc. - store.py # Session ABC, Store ABC - codec.py # LcmCodec, JpegCodec, PickleCodec, codec_for_type() - ingest.py # ingest() helper for batch ingestion - viz.py # similarity_heatmap(), similarity_poses(), log_similarity_timeline() - - impl/ - sqlite.py # SqliteStore, SqliteSession, Sqlite*Backend (single file) - test_sqlite.py # tests -``` - -## Dependencies - -- `sqlite3` (stdlib) -- `sqlite-vec` — vector similarity search via vec0 virtual table. Loaded via `sqlite_vec.load(conn)`. -- FTS5 — built into SQLite by default on most platforms. -- R*Tree — built into SQLite by default. -- `reactivex` — for `.appended` observable (already a DimOS dependency). 
- -## Connection Management - -### SqliteStore - -```python -class SqliteStore(Store): - def __init__(self, path: str): - self._path = path - self._conn = sqlite3.connect(path) - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA synchronous=NORMAL") - self._load_extensions() - - def session(self) -> SqliteSession: - return SqliteSession(self._conn) - - def _load_extensions(self) -> None: - try: - import sqlite_vec - self._conn.enable_load_extension(True) - sqlite_vec.load(self._conn) - self._conn.enable_load_extension(False) - except ImportError: - pass # vec0 unavailable — search_embedding will raise - - def close(self) -> None: - self._conn.close() -``` - -### SqliteSession - -```python -class SqliteSession(Session): - def __init__(self, conn: sqlite3.Connection): - self._conn = conn - self._streams: dict[str, Stream] = {} # cache by name - self._ensure_meta_table() - - def _ensure_meta_table(self): - """Create _streams registry table if not exists.""" - self._conn.execute(""" - CREATE TABLE IF NOT EXISTS _streams ( - name TEXT PRIMARY KEY, - payload_module TEXT, - stream_kind TEXT DEFAULT 'stream', - parent_stream TEXT, - embedding_dim INTEGER - ) - """) - - def stream(self, name, payload_type=None, *, pose_provider=None) -> Stream: - # Returns cached or creates new. payload_type required for new streams. - ... - - def text_stream(self, name, payload_type=None, *, tokenizer="unicode61", - pose_provider=None) -> TextStream: - ... - - def embedding_stream(self, name, payload_type=None, *, vec_dimensions=None, - pose_provider=None, parent_table=None, - embedding_model=None) -> EmbeddingStream: - ... - - def list_streams(self) -> list[StreamInfo]: ... - def resolve_parent_stream(self, name: str) -> str | None: ... - def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: ... - def close(self) -> None: ... -``` - -## Schema - -All table names are prefixed with the stream name. 
Stream names are validated: `[a-zA-Z_][a-zA-Z0-9_]*`. - -### `_streams` — Global registry - -```sql -CREATE TABLE _streams ( - name TEXT PRIMARY KEY, - payload_module TEXT, -- e.g. 'dimos.msgs.sensor_msgs.Image.Image' - stream_kind TEXT DEFAULT 'stream', -- 'stream', 'embedding', 'text' - parent_stream TEXT, -- parent stream name (lineage) - embedding_dim INTEGER -- only for kind='embedding' -); -``` - -### `{name}` — Observation metadata (all stream types) - -```sql -CREATE TABLE {name} ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts REAL, - pose_x REAL, pose_y REAL, pose_z REAL, - pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, - tags TEXT DEFAULT '{}', -- JSON dict - parent_id INTEGER -- lineage: id in parent stream -); -CREATE INDEX idx_{name}_ts ON {name}(ts); -``` - -### `{name}_payload` — Blob/Text payload - -```sql -CREATE TABLE {name}_payload ( - id INTEGER PRIMARY KEY, -- matches {name}.id - data BLOB NOT NULL -); -``` - -Separated from metadata so metadata queries never page in multi-MB blobs. - -### `{name}_rtree` — Spatial index (all stream types) - -```sql -CREATE VIRTUAL TABLE {name}_rtree USING rtree( - id, -- matches {name}.id - min_x, max_x, - min_y, max_y, - min_z, max_z -); -``` - -Only rows with pose are inserted into R*Tree. Rows without pose are excluded from `.near()` results. - -### `{name}_fts` — Full-text search (TextStream only) - -```sql -CREATE VIRTUAL TABLE {name}_fts USING fts5( - content, - tokenize='{tokenizer}' -); -``` - -Standalone FTS table (not content-synced). Rowids match `{name}.id`. - -### `{name}_vec` — Vector index (EmbeddingStream only) - -```sql -CREATE VIRTUAL TABLE {name}_vec USING vec0( - embedding float[{dim}] distance_metric=cosine -); -``` - -Cosine distance: 0 = identical, 2 = opposite. Similarity = `max(0, min(1, 1.0 - distance))`. - -Rowids match `{name}.id`. Dimension inferred from first embedding inserted, or from `vec_dimensions` parameter. 
- -## Stream Implementation - -### Backend classes - -The stream/backend split separates query logic from stream API: - -```python -class SqliteStreamBackend: - """Base backend for blob streams.""" - def do_append(self, payload, ts, pose, tags, parent_id=None) -> Observation: ... - def execute_fetch(self, query: StreamQuery) -> list[Observation]: ... - def execute_count(self, query: StreamQuery) -> int: ... - -class SqliteEmbeddingBackend(SqliteStreamBackend): - """Adds vec0 index. Overrides execute_fetch for vector search.""" - ... - -class SqliteTextBackend(SqliteStreamBackend): - """Adds FTS5 index. Overrides execute_fetch for text search.""" - ... -``` - -### append() - -```python -def do_append(self, payload, ts, pose, tags, parent_id=None): - ts = ts or time.time() - if pose is None and self._pose_provider: - pose = self._pose_provider() - - pose_cols = _decompose_pose(pose) - tags_json = _serialize_tags(tags) - - # 1. Insert into meta table - cur = self._conn.execute( - f"INSERT INTO {name} " - "(ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags, parent_id) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - (ts, *pose_cols, tags_json, parent_id), - ) - row_id = cur.lastrowid - - # 2. Insert into _payload - blob = self._codec.encode(payload) - self._conn.execute( - f"INSERT INTO {name}_payload(id, data) VALUES (?, ?)", - (row_id, blob) - ) - - # 3. Insert into _rtree (if pose) - if pose_cols: - x, y, z = pose_cols[0], pose_cols[1], pose_cols[2] - self._conn.execute( - f"INSERT INTO {name}_rtree(id, min_x, max_x, min_y, max_y, min_z, max_z) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (row_id, x, x, y, y, z, z) - ) - - self._conn.commit() - - # 4. 
Build Observation and emit - obs = Observation(id=row_id, ts=ts, pose=pose, tags=tags or {}, _data=payload) - self._subject.on_next(obs) - return obs -``` - -### EmbeddingBackend.append() - -Same as above, plus inserts into `_vec`: - -```python -if isinstance(payload, Embedding): - vec = payload.to_numpy().tolist() - self._conn.execute( - f"INSERT INTO {name}_vec(rowid, embedding) VALUES (?, ?)", - (row_id, json.dumps(vec)) - ) -``` - -### TextBackend.append() - -Same as base, plus inserts into `_fts`: - -```python -text = str(payload) -self._conn.execute( - f"INSERT INTO {name}_fts(rowid, content) VALUES (?, ?)", - (row_id, text) -) -``` - -## Filter → SQL Generation - -Each filter method returns a new stream with an added filter. At terminal time, the filter chain is compiled to SQL. - -### Filter types - -```python -AfterFilter(t) # → WHERE ts > ? -BeforeFilter(t) # → WHERE ts < ? -TimeRangeFilter(t1, t2) # → WHERE ts >= ? AND ts <= ? -AtFilter(t, tolerance) # → WHERE ABS(ts - ?) <= ? -NearFilter(pose, radius) # → JOIN _rtree bounding box query -TagsFilter(tags) # → WHERE json_extract(tags, '$.key') = ? -EmbeddingSearchFilter(vec, k) # → query _vec, then filter by rowids -TextSearchFilter(text, k) # → query _fts MATCH, then filter by rowids -LineageFilter(source_table, source_query, hops) # → nested IN subquery -``` - -### SQL compilation - -Walk the filter list, generate SQL: - -```python -def _compile_query(query, table) -> tuple[str, list[Any]]: - # Base SELECT - sql = f"SELECT {table}.id, {table}.ts, ... FROM {table}" - - # NearFilter → JOIN _rtree - # Other filters → WHERE clauses - # EmbeddingSearch/TextSearch → handled separately (two-step query) - # LineageFilter → nested IN subquery via _compile_ids() - - return sql, params -``` - -### search_embedding (vec0) - -Two-step process: - -```sql --- 1. Top-k vector search (cosine distance) -SELECT rowid, distance -FROM {name}_vec -WHERE embedding MATCH ? -ORDER BY distance -LIMIT ? -``` - -```python -# 2. 
Build dist_map, fetch metadata for those rowids, populate similarity -dist_map = {rowid: distance for rowid, distance in vec_rows} -# ... fetch metadata WHERE id IN (rowids) ... -for obs in observations: - obs.similarity = max(0.0, min(1.0, 1.0 - dist_map[obs.id])) -# Re-sort by distance rank (IN clause doesn't preserve vec0 ordering) -``` - -### search_text (FTS5) - -```sql -SELECT rowid, rank -FROM {name}_fts -WHERE content MATCH ? -ORDER BY rank -``` - -Same two-step: get rowids from FTS5, then fetch metadata. - -### LineageFilter compilation - -LineageFilter compiles to a nested SQL subquery walking the `parent_id` chain: - -```python -# Single hop: embeddings → images -f"SELECT parent_id FROM {source_table} WHERE id IN ({source_ids_sql})" - -# Multi-hop: embeddings → sharp_frames → images -# Wraps each hop as a nested IN subquery -``` - -## Terminal Execution - -### __iter__() — lazy iteration - -`Stream` is directly iterable via `fetch_pages`: - -```python -def __iter__(self): - for page in self.fetch_pages(): - yield from page -``` - -### fetch() - -Returns `ObservationSet` (list-like + stream-like): - -```python -def fetch(self) -> ObservationSet: - results = self._backend.execute_fetch(self._query) - return ObservationSet(results, session=self._session) -``` - -### count() - -```python -def count(self) -> int: - sql, params = _compile_count(query, table) - # → SELECT COUNT(*) FROM {table} WHERE ... 
- return self._conn.execute(sql, params).fetchone()[0] -``` - -### one() / last() - -- `one()` → `self.limit(1).fetch()[0]` -- `last()` → `self.order_by("ts", desc=True).limit(1).fetch()[0]` - -## Lazy Data Loading - -`Observation.data` uses lazy loading: - -```python -@dataclass -class Observation: - _data: Any = field(default=_UNSET, repr=False) - _data_loader: Callable[[], Any] | None = field(default=None, repr=False) - - @property - def data(self) -> Any: - if self._data is not _UNSET: - return self._data - if self._data_loader is not None: - self._data = self._data_loader() - return self._data - raise LookupError("No data available") -``` - -When building observations from query results: - -```python -def _row_to_obs(self, row) -> Observation: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row - pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) - - def loader(): - r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() - return codec.decode(r[0]) - - return Observation(id=row_id, ts=ts, pose=pose, tags=..., _data_loader=loader) -``` - -### EmbeddingObservation - -For `EmbeddingBackend`, `_row_to_obs` returns `EmbeddingObservation` with two lazy loaders: - -```python -def _row_to_obs(self, row) -> EmbeddingObservation: - # ... same metadata extraction ... 
- - # _data_loader: loads raw embedding payload - # _source_data_loader: loads from PARENT stream (auto-projection) - # - Resolves parent codec from _streams.payload_module - # - Uses parent_id to look up the source payload - - return EmbeddingObservation( - id=row_id, ts=ts, pose=pose, tags=..., - parent_id=pid, - _data_loader=loader, - _source_data_loader=source_loader, # None if no parent - ) -``` - -## Lineage - -### Storing lineage - -When a Transformer appends to a target stream, `parent_id` links back to the source: - -```python -target.append(result, ts=source_obs.ts, pose=source_obs.pose, - parent_id=source_obs.id) -``` - -The `_streams` registry tracks stream-level lineage: -```python -# After materialize_transform creates the target -UPDATE _streams SET parent_stream = ? WHERE name = ? -``` - -### resolve_lineage_chain() - -Walks `_streams.parent_stream` from source toward target: - -```python -def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: - # Single hop (source → target): returns () - # Two hops (source → mid → target): returns ("mid",) - # Raises ValueError if no path exists -``` - -### project_to() - -Uses `LineageFilter` to compile a nested SQL subquery: - -```python -def project_to(self, target: Stream) -> Stream: - hops = session.resolve_lineage_chain(source_table, target_table) - return target._with_filter(LineageFilter(source_table, self._query, hops)) -``` - -## Pose Helpers - -PoseStamped in dimos extends Pose directly (no wrapper). Access position/orientation directly: - -```python -def _decompose_pose(pose) -> tuple[float, ...] 
| None: - if pose is None: - return None - p = pose.position # NOT pose.pose.position - q = pose.orientation - return (p.x, p.y, p.z, q.x, q.y, q.z, q.w) - -def _reconstruct_pose(x, y, z, qx, qy, qz, qw) -> PoseStamped | None: - if x is None: - return None - return PoseStamped( - position=[x, y or 0.0, z or 0.0], # list args (plum dispatch) - orientation=[qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0], - ) -``` - -NearFilter SQL compilation also accesses `f.pose.position` directly. - -## Transform Execution - -### .transform() — returns lazy stream - -`.transform(xf)` doesn't execute immediately. It returns a `TransformStream`. Execution happens at terminal time or `.store()`. - -### .store() — materializes - -When `.store(name)` is called on a `TransformStream`: - -1. Register target stream in `_streams` (with `parent_stream` set) -2. Create target tables -3. Auto-detect target stream type from transformer: - - `EmbeddingTransformer` → `EmbeddingStream` (with parent_table) - - `CaptionTransformer` → `TextStream` (FTS) - - Other → `Stream` (blob) -4. If not `live` mode: run `xf.process(source, target)` (backfill) -5. If not `backfill_only`: subscribe to source's `.appended`, call `xf.on_append()` -6. Return the stored stream - -### .fetch() on TransformStream (no .store()) - -Executes the transform in-memory using `_CollectorStream`: - -```python -def fetch(self) -> ObservationSet: - collector = _CollectorStream() - self._transformer.process(self._source, collector) - return ObservationSet(collector.results) -``` - -## Reactive (.appended) - -Each stored stream backend has a `Subject` from reactivex: - -```python -class SqliteStreamBackend: - def __init__(self, ...): - self._subject: Subject[Observation] = Subject() - - @property - def appended_subject(self): - return self._subject -``` - -`do_append()` emits to the subject after the DB write succeeds. 
- -For filtered streams, the observable filters events through the filter chain in Python: - -```python -@property -def appended(self): - raw = self._backend.appended_subject - active = [f for f in self._query.filters - if not isinstance(f, (EmbeddingSearchFilter, LineageFilter))] - return raw.pipe(ops.filter(lambda obs: all(f.matches(obs) for f in active))) -``` - -## Serialization - -### Codec system - -```python -class LcmCodec: # for DimosMsg types (lcm_encode/lcm_decode) -class JpegCodec: # for Image types (JPEG compression) -class PickleCodec: # fallback for arbitrary Python objects - -def codec_for_type(payload_type: type | None) -> Codec: - """Auto-select codec based on payload type.""" - ... -``` - -Lives in `dimos.memory.codec`. - -### Tag serialization - -Tags are stored as JSON text. Empty dict → `"{}"`. - -## SQL Safety - -- **Identifier validation**: stream names must match `^[a-zA-Z_][a-zA-Z0-9_]*$`. -- **Parameterized queries**: all user values go through `?` params, never string interpolation. -- **Table names**: constructed from validated stream names, safe for SQL interpolation. -- **Order fields**: validated against allowlist `{"id", "ts"}`. - -## Thread Safety - -- Each `Session` owns one `sqlite3.Connection` — not shared across threads. -- Multiple sessions can exist on the same file (WAL mode allows concurrent reads + one writer). -- The `appended` subject emits on the thread that called `append()`. - -## Error Handling - -- `append()` on non-stored stream → `TypeError` -- `search_embedding()` on non-embedding stream → `TypeError` -- `search_text()` on non-text stream → `TypeError` -- `search_embedding()` when sqlite-vec not loaded → `RuntimeError` -- Invalid stream name → `ValueError` -- `one()` with no results → `LookupError` -- `stream()` without `payload_type` on new stream → `TypeError` - -## Testing - -Tests in `dimos/memory/impl/test_sqlite.py`. Use `:memory:` store for speed. - -Key test scenarios: -1. 
Create stream, append, fetch — verify data round-trips -2. Temporal filters (after, before, time_range, at) -3. Spatial filter (near) — with and without pose -4. Tag filtering -5. EmbeddingStream — store embeddings, search_embedding, verify auto-projection -6. TextStream — store text, search_text -7. Transform with lambda — verify lineage -8. Transform with Transformer class — verify process() called -9. Chained filters — verify SQL composition -10. project_to — verify cross-stream lineage (single and multi-hop) -11. fetch_pages — verify pagination -12. Lazy data loading — verify .data only hits DB on access -13. .appended observable — verify reactive emission -14. Similarity scores — verify EmbeddingObservation.similarity populated after search -15. raw=True — verify EmbeddingObservation with similarity + auto-projected data -16. ObservationSet — verify list-like + stream-like behavior From 0b2983cd2ebca6e67f14fc05cbc50a3e7ad9265d Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 20:00:55 +0800 Subject: [PATCH 025/118] added docs --- dimos/memory/docs/query_objects.md | 155 +++++++ dimos/memory/docs/questions.md | 56 +++ dimos/memory/docs/sqlite.md | 621 +++++++++++++++++++++++++++++ dimos/memory/docs/tasks.md | 129 ++++++ dimos/memory/docs/transform.md | 180 +++++++++ 5 files changed, 1141 insertions(+) create mode 100644 dimos/memory/docs/query_objects.md create mode 100644 dimos/memory/docs/questions.md create mode 100644 dimos/memory/docs/sqlite.md create mode 100644 dimos/memory/docs/tasks.md create mode 100644 dimos/memory/docs/transform.md diff --git a/dimos/memory/docs/query_objects.md b/dimos/memory/docs/query_objects.md new file mode 100644 index 0000000000..bf86d39675 --- /dev/null +++ b/dimos/memory/docs/query_objects.md @@ -0,0 +1,155 @@ +# Query Objects — 4D Region + Soft Scoring System + +## Problem + +We need to query observations across 4 dimensions (x, y, z, t) plus embedding space. 
Current API has flat `filter_*` methods — works for simple cases but doesn't compose. We need: + +1. **Regions** — composable hard boundaries (include/exclude) +2. **Fields** — soft scoring that biases toward a point/time/embedding without hard cutoffs +3. A way to combine both in a single query + +## Key Insight + +Hard filters and soft biases are the same thing at different extremes: +- Hard filter = step function (1 inside, 0 outside) +- Soft bias = smooth decay (gaussian, linear, etc.) + +A unified **Criterion** type handles both. Each criterion maps an observation to a score in `[0, 1]`. Hard filters are just criteria with score `{0, 1}`. + +## Primitives + +### Temporal + +```python +# Hard boundaries +TimeRange(t1, t2) # 1 inside, 0 outside +Before(t) # sugar for TimeRange(-inf, t) +After(t) # sugar for TimeRange(t, inf) + +# Soft — score decays with distance from target +TimeProximity(target, sigma=60.0) # gaussian: exp(-dt²/2σ²) +``` + +### Spatial + +```python +# Hard boundaries +Sphere(center: PoseLike, radius: float) # 1 inside, 0 outside +Box(min: PoseLike, max: PoseLike) # axis-aligned bounding box +HeightRange(z_min, z_max) # horizontal slice + +# Soft +SpatialProximity(point: PoseLike, sigma=5.0) # gaussian in 3D +``` + +### Embedding + +```python +# Soft only (no hard boundary in embedding space makes sense) +EmbeddingSimilarity(vector, candidate_k=100) # cosine similarity, top-k pre-filter +``` + +### Tags + +```python +TagMatch(robot_id="robot1") # hard: exact match on tag values +``` + +## Composition + +Criteria compose via set operators: + +```python +# Intersection — all criteria must score > 0 +region = TimeRange(t1, t2) & Sphere(point, 5.0) + +# Union — any criterion scoring > 0 passes +region = Sphere(p1, 3.0) | Sphere(p2, 3.0) + +# Complement +region = ~TimeRange(t1, t2) # everything outside this window +``` + +For soft criteria, composition combines scores: +- `a & b` → `min(a.score, b.score)` (conservative) +- `a | b` → `max(a.score, 
b.score)` (permissive) + +## Weighted Scoring + +The interesting problem: "I care about embedding similarity, temporal proximity, AND spatial proximity" — but as soft preferences, not hard cutoffs. + +```python +Score( + time=TimeProximity(target_t, sigma=60), + space=SpatialProximity(point, sigma=5.0), + embedding=EmbeddingSimilarity(vector, candidate_k=200), + weights={"time": 0.3, "space": 0.3, "embedding": 0.4} +) +``` + +Each dimension produces a `[0, 1]` score. Final score = weighted sum. This replaces the vague `rank(**weights)` in the current API. + +## Integration with Query + +```python +# Current flat API (still works, sugar for simple cases) +q.after(t).near(point, 5.0).search_embedding(vec, candidate_k=100) + +# Region object approach +region = After(t) & Sphere(point, 5.0) +q.where(region).search_embedding(vec, candidate_k=100) + +# Full soft scoring — no hard boundaries, just preferences +q.score( + time=TimeProximity(target_t, sigma=120), + space=SpatialProximity(point, sigma=10.0), + embedding=EmbeddingSimilarity(vec, candidate_k=500), +).limit(20) + +# Mixed — hard boundary + soft ranking within +q.where(TimeRange(t1, t2)).score( + space=SpatialProximity(point, sigma=5.0), + embedding=EmbeddingSimilarity(vec, candidate_k=200), +).limit(10) +``` + +## SQL Mapping (SQLite impl) + +How each primitive maps to SQL: + +| Criterion | SQL Strategy | +|--------------------------|-------------------------------------------------------| +| `TimeRange(t1, t2)` | `WHERE ts BETWEEN ? AND ?` (B-tree) | +| `Before(t)` / `After(t)` | `WHERE ts < ?` / `WHERE ts > ?` | +| `Sphere(p, r)` | R*Tree range query on `_rtree` | +| `HeightRange(z1, z2)` | `WHERE pose_z BETWEEN ? AND ?` | +| `Box(min, max)` | R*Tree range query | +| `TimeProximity(t, σ)` | `ORDER BY ABS(ts - ?) 
ASC` or compute score in SELECT | +| `SpatialProximity(p, σ)` | R*Tree range (pre-filter at ~3σ) + score in SELECT | +| `EmbeddingSimilarity` | sqlite-vec `MATCH` → temp table | +| `TagMatch` | `WHERE json_extract(tags, ?) = ?` | + +Soft scoring strategy: **generous hard pre-filter in SQL, then score in Python**. +- Each soft criterion auto-generates a hard pre-filter at ~3σ (captures 99.7% of relevant results) +- `TimeProximity(t, σ=60)` → SQL: `WHERE ts BETWEEN t-180 AND t+180` (B-tree) +- `SpatialProximity(p, σ=5)` → SQL: R*Tree range query with 15m box +- `EmbeddingSimilarity` → sqlite-vec `MATCH` top-k (already a pre-filter) +- Python computes `[0, 1]` scores on the pre-filtered set, applies weights, sorts + +This keeps SQL simple (range queries on indexes) and Python handles the math. + +## Open Questions + +2. **How does `Score` interact with `search_embedding`?** Embedding search already returns ranked results from vec0. Should `Score.embedding` just re-weight those scores, or does it need a separate search pass? + +3. **Region objects as first-class types?** Do we store/serialize regions (e.g., "the kitchen region" as a reusable spatial boundary)? Or are they always constructed in code? + +4. **Do we need `NOT` regions for exclusion zones?** E.g., "everywhere except within 2m of the charging station." `~Sphere(charger, 2.0)` — complement on spatial regions requires scanning all of `_meta`, can't use R*Tree efficiently. + +5. **Gradient fields?** "Prefer observations taken at higher elevation" — not proximity to a point but a directional preference. `HeightGradient(ascending=True)` as a scorer? + +## Priority + +- **Phase 1**: Keep the flat `filter_*` / `rank()` API. Implement primitives internally. +- **Phase 2**: Expose `Criterion` objects + `where()` + `score()` as the composable API. +- **Phase 3**: Region persistence, named regions, gradient fields. 
diff --git a/dimos/memory/docs/questions.md b/dimos/memory/docs/questions.md new file mode 100644 index 0000000000..bc91b9f306 --- /dev/null +++ b/dimos/memory/docs/questions.md @@ -0,0 +1,56 @@ +# Questions + +1. "where was I when this log line was added?" +- pose lookup, correlating to log lines found +- assume log line has a pose associated +- assume there are multiple log lines matching a search + +2. "how long have I been observing the red socks currently in view?" +- how many times did I see them before? +- temporal duration tracking + observation frequency + +3. "how many people did I see during last week?" +- assume we are generating a facial recognition db — is this matching a face detection stream, then embeddings? then we are searching over that stream? + +4. "where did you see red socks during last week?" +- we query for red socks embedding similarity, then feed this data into a VLM that further filters for socks +- is this data output into some table? is it like an ObservationSet again? +- then we can create a map (costmap) of red socks? + +5. "did anyone ever open this door? at what times did I see this door open? who opened it?" +- event detection + temporal querying of state changes + +6. "I have a transcription log (STT) and voice embeddings, how do I figure out who is saying what?" +- cross-stream correlation: audio → identity + +7. "I have parallel voice and facial recognition streams, how do I correlate voice to people?" +- I don't see all people speaking at all times +- multi-modal fusion with incomplete overlap + +8. "what's different in this room compared to yesterday?" +- comparing scene snapshots across time, diffing object sets +- requires baseline modeling / temporal comparison + +9. "show me everywhere the cat went today" +- continuous spatial tracking over time, not point queries +- dense pose-stream retrieval + path aggregation + +10. "what happened in the 30 seconds before the vase fell?"
+- event-anchored temporal window across all streams +- multi-stream temporal slicing relative to a detected event + +11. "when was the last time I did NOT see the cat in the apartment?" +- negation query — finding gaps in an observation stream +- architecturally different from presence queries + +12. "what time does the mailman usually come?" +- aggregation across days, extracting temporal regularity from sparse events +- cross-session pattern extraction + +13. "what did robot-2 observe in the warehouse that I missed?" +- cross-agent memory diff +- session/robot-scoped queries and set difference across streams + +14. "how far did I travel while carrying an object?" +- filtered pose integration — only accumulate distance when a parallel detection stream has a positive signal +- cross-stream conditional joins diff --git a/dimos/memory/docs/sqlite.md b/dimos/memory/docs/sqlite.md new file mode 100644 index 0000000000..173bedb5b6 --- /dev/null +++ b/dimos/memory/docs/sqlite.md @@ -0,0 +1,621 @@ +# SQLite Implementation + +Implementation spec for the SQLite backend. A coding agent should be able to implement the full backend from this document + `api.md`. + +## File Structure + +``` +dimos/memory/ + __init__.py # public exports + types.py # Observation, EmbeddingObservation, StreamInfo, Filter types + stream.py # Stream, EmbeddingStream, TextStream, ObservationSet, ListBackend + transformer.py # Transformer ABC, PerItemTransformer, EmbeddingTransformer, etc. + store.py # Session ABC, Store ABC + codec.py # LcmCodec, JpegCodec, PickleCodec, codec_for_type() + ingest.py # ingest() helper for batch ingestion + viz.py # similarity_heatmap(), similarity_poses(), log_similarity_timeline() + + impl/ + sqlite.py # SqliteStore, SqliteSession, Sqlite*Backend (single file) + test_sqlite.py # tests +``` + +## Dependencies + +- `sqlite3` (stdlib) +- `sqlite-vec` — vector similarity search via vec0 virtual table. Loaded via `sqlite_vec.load(conn)`. 
+- FTS5 — built into SQLite by default on most platforms. +- R*Tree — built into SQLite by default. +- `reactivex` — for `.appended` observable (already a DimOS dependency). + +## Connection Management + +### SqliteStore + +```python +class SqliteStore(Store): + def __init__(self, path: str): + self._path = path + self._conn = sqlite3.connect(path) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._load_extensions() + + def session(self) -> SqliteSession: + return SqliteSession(self._conn) + + def _load_extensions(self) -> None: + try: + import sqlite_vec + self._conn.enable_load_extension(True) + sqlite_vec.load(self._conn) + self._conn.enable_load_extension(False) + except ImportError: + pass # vec0 unavailable — search_embedding will raise + + def close(self) -> None: + self._conn.close() +``` + +### SqliteSession + +```python +class SqliteSession(Session): + def __init__(self, conn: sqlite3.Connection): + self._conn = conn + self._streams: dict[str, Stream] = {} # cache by name + self._ensure_meta_table() + + def _ensure_meta_table(self): + """Create _streams registry table if not exists.""" + self._conn.execute(""" + CREATE TABLE IF NOT EXISTS _streams ( + name TEXT PRIMARY KEY, + payload_module TEXT, + stream_kind TEXT DEFAULT 'stream', + parent_stream TEXT, + embedding_dim INTEGER + ) + """) + + def stream(self, name, payload_type=None, *, pose_provider=None) -> Stream: + # Returns cached or creates new. payload_type required for new streams. + ... + + def text_stream(self, name, payload_type=None, *, tokenizer="unicode61", + pose_provider=None) -> TextStream: + ... + + def embedding_stream(self, name, payload_type=None, *, vec_dimensions=None, + pose_provider=None, parent_table=None, + embedding_model=None) -> EmbeddingStream: + ... + + def list_streams(self) -> list[StreamInfo]: ... + def resolve_parent_stream(self, name: str) -> str | None: ... 
+ def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: ... + def close(self) -> None: ... +``` + +## Schema + +All table names are prefixed with the stream name. Stream names are validated: `[a-zA-Z_][a-zA-Z0-9_]*`. + +### `_streams` — Global registry + +```sql +CREATE TABLE _streams ( + name TEXT PRIMARY KEY, + payload_module TEXT, -- e.g. 'dimos.msgs.sensor_msgs.Image.Image' + stream_kind TEXT DEFAULT 'stream', -- 'stream', 'embedding', 'text' + parent_stream TEXT, -- parent stream name (lineage) + embedding_dim INTEGER -- only for kind='embedding' +); +``` + +### `{name}` — Observation metadata (all stream types) + +```sql +CREATE TABLE {name} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ts REAL, + pose_x REAL, pose_y REAL, pose_z REAL, + pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, + tags TEXT DEFAULT '{}', -- JSON dict + parent_id INTEGER -- lineage: id in parent stream +); +CREATE INDEX idx_{name}_ts ON {name}(ts); +``` + +### `{name}_payload` — Blob/Text payload + +```sql +CREATE TABLE {name}_payload ( + id INTEGER PRIMARY KEY, -- matches {name}.id + data BLOB NOT NULL +); +``` + +Separated from metadata so metadata queries never page in multi-MB blobs. + +### `{name}_rtree` — Spatial index (all stream types) + +```sql +CREATE VIRTUAL TABLE {name}_rtree USING rtree( + id, -- matches {name}.id + min_x, max_x, + min_y, max_y, + min_z, max_z +); +``` + +Only rows with pose are inserted into R*Tree. Rows without pose are excluded from `.near()` results. + +### `{name}_fts` — Full-text search (TextStream only) + +```sql +CREATE VIRTUAL TABLE {name}_fts USING fts5( + content, + tokenize='{tokenizer}' +); +``` + +Standalone FTS table (not content-synced). Rowids match `{name}.id`. + +### `{name}_vec` — Vector index (EmbeddingStream only) + +```sql +CREATE VIRTUAL TABLE {name}_vec USING vec0( + embedding float[{dim}] distance_metric=cosine +); +``` + +Cosine distance: 0 = identical, 2 = opposite. 
Similarity = `max(0, min(1, 1.0 - distance))`. + +Rowids match `{name}.id`. Dimension inferred from first embedding inserted, or from `vec_dimensions` parameter. + +## Stream Implementation + +### Backend classes + +The stream/backend split separates query logic from stream API: + +```python +class SqliteStreamBackend: + """Base backend for blob streams.""" + def do_append(self, payload, ts, pose, tags, parent_id=None) -> Observation: ... + def execute_fetch(self, query: StreamQuery) -> list[Observation]: ... + def execute_count(self, query: StreamQuery) -> int: ... + +class SqliteEmbeddingBackend(SqliteStreamBackend): + """Adds vec0 index. Overrides execute_fetch for vector search.""" + ... + +class SqliteTextBackend(SqliteStreamBackend): + """Adds FTS5 index. Overrides execute_fetch for text search.""" + ... +``` + +### append() + +```python +def do_append(self, payload, ts, pose, tags, parent_id=None): + ts = ts or time.time() + if pose is None and self._pose_provider: + pose = self._pose_provider() + + pose_cols = _decompose_pose(pose) + tags_json = _serialize_tags(tags) + + # 1. Insert into meta table + cur = self._conn.execute( + f"INSERT INTO {name} " + "(ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags, parent_id) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (ts, *pose_cols, tags_json, parent_id), + ) + row_id = cur.lastrowid + + # 2. Insert into _payload + blob = self._codec.encode(payload) + self._conn.execute( + f"INSERT INTO {name}_payload(id, data) VALUES (?, ?)", + (row_id, blob) + ) + + # 3. Insert into _rtree (if pose) + if pose_cols: + x, y, z = pose_cols[0], pose_cols[1], pose_cols[2] + self._conn.execute( + f"INSERT INTO {name}_rtree(id, min_x, max_x, min_y, max_y, min_z, max_z) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (row_id, x, x, y, y, z, z) + ) + + self._conn.commit() + + # 4. 
Build Observation and emit + obs = Observation(id=row_id, ts=ts, pose=pose, tags=tags or {}, _data=payload) + self._subject.on_next(obs) + return obs +``` + +### EmbeddingBackend.append() + +Same as above, plus inserts into `_vec`: + +```python +if isinstance(payload, Embedding): + vec = payload.to_numpy().tolist() + self._conn.execute( + f"INSERT INTO {name}_vec(rowid, embedding) VALUES (?, ?)", + (row_id, json.dumps(vec)) + ) +``` + +### TextBackend.append() + +Same as base, plus inserts into `_fts`: + +```python +text = str(payload) +self._conn.execute( + f"INSERT INTO {name}_fts(rowid, content) VALUES (?, ?)", + (row_id, text) +) +``` + +## Filter → SQL Generation + +Each filter method returns a new stream with an added filter. At terminal time, the filter chain is compiled to SQL. + +### Filter types + +```python +AfterFilter(t) # → WHERE ts > ? +BeforeFilter(t) # → WHERE ts < ? +TimeRangeFilter(t1, t2) # → WHERE ts >= ? AND ts <= ? +AtFilter(t, tolerance) # → WHERE ABS(ts - ?) <= ? +NearFilter(pose, radius) # → JOIN _rtree bounding box query +TagsFilter(tags) # → WHERE json_extract(tags, '$.key') = ? +EmbeddingSearchFilter(vec, k) # → query _vec, then filter by rowids +TextSearchFilter(text, k) # → query _fts MATCH, then filter by rowids +LineageFilter(source_table, source_query, hops) # → nested IN subquery +``` + +### SQL compilation + +Walk the filter list, generate SQL: + +```python +def _compile_query(query, table) -> tuple[str, list[Any]]: + # Base SELECT + sql = f"SELECT {table}.id, {table}.ts, ... FROM {table}" + + # NearFilter → JOIN _rtree + # Other filters → WHERE clauses + # EmbeddingSearch/TextSearch → handled separately (two-step query) + # LineageFilter → nested IN subquery via _compile_ids() + + return sql, params +``` + +### search_embedding (vec0) + +Two-step process: + +```sql +-- 1. Top-k vector search (cosine distance) +SELECT rowid, distance +FROM {name}_vec +WHERE embedding MATCH ? +ORDER BY distance +LIMIT ? +``` + +```python +# 2. 
Build dist_map, fetch metadata for those rowids, populate similarity
+dist_map = {rowid: distance for rowid, distance in vec_rows}
+# ... fetch metadata WHERE id IN (rowids) ...
+for obs in observations:
+    obs.similarity = max(0.0, min(1.0, 1.0 - dist_map[obs.id]))
+# Re-sort by distance rank (IN clause doesn't preserve vec0 ordering)
+```
+
+### search_text (FTS5)
+
+```sql
+SELECT rowid, rank
+FROM {name}_fts
+WHERE content MATCH ?
+ORDER BY rank LIMIT ?
+```
+
+Same two-step: get the top-k rowids from FTS5 (rank-ordered), then fetch metadata.
+
+### LineageFilter compilation
+
+LineageFilter compiles to a nested SQL subquery walking the `parent_id` chain:
+
+```python
+# Single hop: embeddings → images
+f"SELECT parent_id FROM {source_table} WHERE id IN ({source_ids_sql})"
+
+# Multi-hop: embeddings → sharp_frames → images
+# Wraps each hop as a nested IN subquery
+```
+
+## Terminal Execution
+
+### __iter__() — lazy iteration
+
+`Stream` is directly iterable via `fetch_pages`:
+
+```python
+def __iter__(self):
+    for page in self.fetch_pages():
+        yield from page
+```
+
+### fetch()
+
+Returns `ObservationSet` (list-like + stream-like):
+
+```python
+def fetch(self) -> ObservationSet:
+    results = self._backend.execute_fetch(self._query)
+    return ObservationSet(results, session=self._session)
+```
+
+### count()
+
+```python
+def count(self) -> int:
+    sql, params = _compile_count(query, table)
+    # → SELECT COUNT(*) FROM {table} WHERE ...
+ return self._conn.execute(sql, params).fetchone()[0] +``` + +### one() / last() + +- `one()` → `self.limit(1).fetch()[0]` +- `last()` → `self.order_by("ts", desc=True).limit(1).fetch()[0]` + +## Lazy Data Loading + +`Observation.data` uses lazy loading: + +```python +@dataclass +class Observation: + _data: Any = field(default=_UNSET, repr=False) + _data_loader: Callable[[], Any] | None = field(default=None, repr=False) + + @property + def data(self) -> Any: + if self._data is not _UNSET: + return self._data + if self._data_loader is not None: + self._data = self._data_loader() + return self._data + raise LookupError("No data available") +``` + +When building observations from query results: + +```python +def _row_to_obs(self, row) -> Observation: + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row + pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) + + def loader(): + r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() + return codec.decode(r[0]) + + return Observation(id=row_id, ts=ts, pose=pose, tags=..., _data_loader=loader) +``` + +### EmbeddingObservation + +For `EmbeddingBackend`, `_row_to_obs` returns `EmbeddingObservation` with two lazy loaders: + +```python +def _row_to_obs(self, row) -> EmbeddingObservation: + # ... same metadata extraction ... 
+ + # _data_loader: loads raw embedding payload + # _source_data_loader: loads from PARENT stream (auto-projection) + # - Resolves parent codec from _streams.payload_module + # - Uses parent_id to look up the source payload + + return EmbeddingObservation( + id=row_id, ts=ts, pose=pose, tags=..., + parent_id=pid, + _data_loader=loader, + _source_data_loader=source_loader, # None if no parent + ) +``` + +## Lineage + +### Storing lineage + +When a Transformer appends to a target stream, `parent_id` links back to the source: + +```python +target.append(result, ts=source_obs.ts, pose=source_obs.pose, + parent_id=source_obs.id) +``` + +The `_streams` registry tracks stream-level lineage: +```python +# After materialize_transform creates the target +UPDATE _streams SET parent_stream = ? WHERE name = ? +``` + +### resolve_lineage_chain() + +Walks `_streams.parent_stream` from source toward target: + +```python +def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: + # Single hop (source → target): returns () + # Two hops (source → mid → target): returns ("mid",) + # Raises ValueError if no path exists +``` + +### project_to() + +Uses `LineageFilter` to compile a nested SQL subquery: + +```python +def project_to(self, target: Stream) -> Stream: + hops = session.resolve_lineage_chain(source_table, target_table) + return target._with_filter(LineageFilter(source_table, self._query, hops)) +``` + +## Pose Helpers + +PoseStamped in dimos extends Pose directly (no wrapper). Access position/orientation directly: + +```python +def _decompose_pose(pose) -> tuple[float, ...] 
| None: + if pose is None: + return None + p = pose.position # NOT pose.pose.position + q = pose.orientation + return (p.x, p.y, p.z, q.x, q.y, q.z, q.w) + +def _reconstruct_pose(x, y, z, qx, qy, qz, qw) -> PoseStamped | None: + if x is None: + return None + return PoseStamped( + position=[x, y or 0.0, z or 0.0], # list args (plum dispatch) + orientation=[qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0], + ) +``` + +NearFilter SQL compilation also accesses `f.pose.position` directly. + +## Transform Execution + +### .transform() — returns lazy stream + +`.transform(xf)` doesn't execute immediately. It returns a `TransformStream`. Execution happens at terminal time or `.store()`. + +### .store() — materializes + +When `.store(name)` is called on a `TransformStream`: + +1. Register target stream in `_streams` (with `parent_stream` set) +2. Create target tables +3. Auto-detect target stream type from transformer: + - `EmbeddingTransformer` → `EmbeddingStream` (with parent_table) + - `CaptionTransformer` → `TextStream` (FTS) + - Other → `Stream` (blob) +4. If not `live` mode: run `xf.process(source, target)` (backfill) +5. If not `backfill_only`: subscribe to source's `.appended`, call `xf.on_append()` +6. Return the stored stream + +### .fetch() on TransformStream (no .store()) + +Executes the transform in-memory using `_CollectorStream`: + +```python +def fetch(self) -> ObservationSet: + collector = _CollectorStream() + self._transformer.process(self._source, collector) + return ObservationSet(collector.results) +``` + +## Reactive (.appended) + +Each stored stream backend has a `Subject` from reactivex: + +```python +class SqliteStreamBackend: + def __init__(self, ...): + self._subject: Subject[Observation] = Subject() + + @property + def appended_subject(self): + return self._subject +``` + +`do_append()` emits to the subject after the DB write succeeds. 
+ +For filtered streams, the observable filters events through the filter chain in Python: + +```python +@property +def appended(self): + raw = self._backend.appended_subject + active = [f for f in self._query.filters + if not isinstance(f, (EmbeddingSearchFilter, LineageFilter))] + return raw.pipe(ops.filter(lambda obs: all(f.matches(obs) for f in active))) +``` + +## Serialization + +### Codec system + +```python +class LcmCodec: # for DimosMsg types (lcm_encode/lcm_decode) +class JpegCodec: # for Image types (JPEG compression) +class PickleCodec: # fallback for arbitrary Python objects + +def codec_for_type(payload_type: type | None) -> Codec: + """Auto-select codec based on payload type.""" + ... +``` + +Lives in `dimos.memory.codec`. + +### Tag serialization + +Tags are stored as JSON text. Empty dict → `"{}"`. + +## SQL Safety + +- **Identifier validation**: stream names must match `^[a-zA-Z_][a-zA-Z0-9_]*$`. +- **Parameterized queries**: all user values go through `?` params, never string interpolation. +- **Table names**: constructed from validated stream names, safe for SQL interpolation. +- **Order fields**: validated against allowlist `{"id", "ts"}`. + +## Thread Safety + +- Each `Session` owns one `sqlite3.Connection` — not shared across threads. +- Multiple sessions can exist on the same file (WAL mode allows concurrent reads + one writer). +- The `appended` subject emits on the thread that called `append()`. + +## Error Handling + +- `append()` on non-stored stream → `TypeError` +- `search_embedding()` on non-embedding stream → `TypeError` +- `search_text()` on non-text stream → `TypeError` +- `search_embedding()` when sqlite-vec not loaded → `RuntimeError` +- Invalid stream name → `ValueError` +- `one()` with no results → `LookupError` +- `stream()` without `payload_type` on new stream → `TypeError` + +## Testing + +Tests in `dimos/memory/impl/test_sqlite.py`. Use `:memory:` store for speed. + +Key test scenarios: +1. 
Create stream, append, fetch — verify data round-trips +2. Temporal filters (after, before, time_range, at) +3. Spatial filter (near) — with and without pose +4. Tag filtering +5. EmbeddingStream — store embeddings, search_embedding, verify auto-projection +6. TextStream — store text, search_text +7. Transform with lambda — verify lineage +8. Transform with Transformer class — verify process() called +9. Chained filters — verify SQL composition +10. project_to — verify cross-stream lineage (single and multi-hop) +11. fetch_pages — verify pagination +12. Lazy data loading — verify .data only hits DB on access +13. .appended observable — verify reactive emission +14. Similarity scores — verify EmbeddingObservation.similarity populated after search +15. raw=True — verify EmbeddingObservation with similarity + auto-projected data +16. ObservationSet — verify list-like + stream-like behavior diff --git a/dimos/memory/docs/tasks.md b/dimos/memory/docs/tasks.md new file mode 100644 index 0000000000..82d2a4e964 --- /dev/null +++ b/dimos/memory/docs/tasks.md @@ -0,0 +1,129 @@ +# Memory2 — Remaining Tasks + +Gap analysis between `plans/memory/` specs and `dimos/memory/` implementation. + +## P0 — Security / Correctness + +### 1. Stream name validation + +Stream names are interpolated directly into SQL via f-strings. No validation exists — arbitrary input is a SQL injection vector. + +**Spec** (`sqlite.md`): `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$`, reject with `ValueError`. + +**Where**: Add a `_validate_stream_name(name)` check at the top of `SqliteSession.stream()`, `.text_stream()`, `.embedding_stream()`. + +### 2. `_clone()` type annotation vs runtime + +`Stream._clone()` (`stream.py:94-108`) is annotated `-> Stream[T]`, but at runtime it uses `self.__class__.__new__(self.__class__)` which correctly preserves the subclass. So `EmbeddingStream.after(t)` returns an `EmbeddingStream` at runtime — no bug. + +The annotation is wrong for mypy though. 
Consider `-> Self` (from `typing_extensions`) if we want strict typing. Low priority — runtime works. + +## P1 — Core API Gaps + +### 3. Wire `parent_stream` into `_streams` registry + +`_register_stream()` (`sqlite.py:847-861`) never writes the `parent_stream` column. The column exists in the schema but is always NULL. + +**Where**: `materialize_transform()` (`sqlite.py:770-799`) knows both `source_table` and `name`. Pass `parent_stream=source_table` to `_register_stream()`, and update `_register_stream` to accept and INSERT it. + +This is a prerequisite for `.join()` and stream-level lineage discovery. + +### 4. ~~Implement `.project_to()` — cross-stream lineage~~ ✅ + +Implemented. `project_to(target)` adds a `LineageFilter` to the target stream (same `_with_filter` mechanism as `.after()`, `.near()`, etc.). The filter compiles to a SQL subquery walking the `parent_id` chain. Multi-hop lineage is resolved via `_streams.parent_stream` registry. Result is a fully chainable `Stream`. + +### 4b. Implement `.join()` — cross-stream lineage returning pairs + +`api.md` specifies: +```python +for det, img in detections.after(t).join(images): + print(f"Detected {det.data} in image at {img.pose}") +``` + +Unlike `project_to()` which returns a `Stream`, `join()` yields `tuple[Observation, Observation]` pairs. This is a terminal operation (not chainable) since the return type is pairs, not observations. + +**Depends on**: ~~Task 3~~ Done — `parent_stream` is now written by `materialize_transform()` and read by `resolve_lineage_chain()`. + +### 5. Filtered `.appended` — predicate-filtered reactive subscriptions + +`api.md` specifies: +```python +images.near(kitchen_pose, 5.0).appended.subscribe(...) +``` + +Current impl (`stream.py:276-278`) returns the raw Subject regardless of filters. 
+ +**Fix** (from `sqlite.md`): When `self._query.filters` is non-empty, pipe the root subject through `ops.filter()` that evaluates each predicate in Python: + +```python +@property +def appended(self): + backend = self._require_backend() + obs = backend.appended_subject + if not self._query.filters: + return obs + return obs.pipe(ops.filter(lambda o: self._matches_filters(o))) +``` + +Each filter type needs a `matches(obs) -> bool` method for Python-side evaluation: +- `AfterFilter`: `obs.ts > self.t` +- `NearFilter`: Euclidean distance check +- `TagsFilter`: dict subset check +- etc. + +### 6. Incremental backfill + +`sqlite.md` specifies that re-running a stored transform resumes from the last processed item: + +```python +max_parent = conn.execute( + f"SELECT MAX(parent_id) FROM {target_name}" +).fetchone()[0] + +if max_parent is not None: + source = source.after_id(max_parent) # internal: WHERE id > ? +``` + +**Where**: `materialize_transform()` (`sqlite.py:791-793`). Before calling `transformer.process()`, check if target already has rows and filter source accordingly. + +**Needs**: An internal `_after_id(row_id)` filter (not exposed in public API) that adds `WHERE id > ?`. + +## P2 — Robustness + +### 7. Separate connections per session + +`SqliteStore.session()` (`sqlite.py:886-887`) shares `self._conn` across all sessions. The spec says each session should own its own connection. + +**Fix**: `session()` should call `sqlite3.connect(self._path)` + WAL pragma + extension loading each time, not reuse `self._conn`. Store keeps the path, sessions get independent connections. + +This is required for multi-threaded use (e.g., one session writing in a background thread, another querying in the main thread). + +### 8. 
`_CollectorStream` doesn't set pose on observations + +`_CollectorStream.append()` (`stream.py:401-419`) accepts `pose` but doesn't store it on the `Observation`: + +```python +obs = Observation(id=self._next_id, ts=ts, tags=tags or {}, parent_id=parent_id, _data=payload) +# pose is silently dropped +``` + +**Fix**: Add `pose=pose` to the Observation constructor call. + +## P3 — Future (not blocking) + +### 9. Query objects — composable 4D regions + soft scoring + +`query_objects.md` proposes `Criterion` types (`TimeRange`, `Sphere`, `TimeProximity`, `SpatialProximity`, `EmbeddingSimilarity`) with `&`/`|`/`~` composition and weighted `Score()`. + +Explicitly Phase 2. Current flat filter API covers all simple cases. Implement when real usage demands soft scoring or region composition. + +### 10. `questions.md` hard cases + +Unresolved query patterns from the product requirements: +- Negation queries ("when did I NOT see the cat") +- Temporal regularity ("what time does the mailman come") +- Cross-agent memory diff +- Conditional pose integration +- Event-anchored multi-stream slicing + +These require extensions beyond the current Stream API — likely built on top of the composable query layer (task 9). diff --git a/dimos/memory/docs/transform.md b/dimos/memory/docs/transform.md new file mode 100644 index 0000000000..409fd8fc6b --- /dev/null +++ b/dimos/memory/docs/transform.md @@ -0,0 +1,180 @@ +# Transform — Unified Derived Stream API + +## Concept + +`.transform()` is a single method on `StreamBase` that handles both historical (batch) and live (reactive) processing. It takes data from a source, applies a function, and stores results into the target stream with lineage. + +## API + +```python +class StreamBase(ABC, Generic[T]): + def transform(self, + source: StreamBase | ObservationSet, + fn: Callable[[Any], T | list[T] | None] | None = None, + *, + live: bool = False, + ) -> Self: + """ + Process source data, store results in this stream. 
+ + Args: + source: where to read from + fn: transform function. Returns T, list[T], or None (skip). + None allowed for EmbeddingStream (uses model.embed implicitly). + live: if True, only subscribe to new appends (no backfill) + + Behavior by source type: + StreamBase → backfill existing + subscribe to live (default) + live=True → skip backfill, only subscribe + ObservationSet → batch process snapshot (live ignored) + + Returns self for chaining. + """ +``` + +## Source type determines mode + +| Source | `live=False` (default) | `live=True` | +|------------------|--------------------------------------------------|-------------------------------| +| `StreamBase` | backfill all existing + subscribe to `.appended` | subscribe to `.appended` only | +| `ObservationSet` | batch process the set | N/A (ignored) | + +## Transform function contract + +```python +fn: Callable[[Any], T | list[T] | None] +``` + +- Returns `T` → single result stored +- Returns `list[T]` → multiple results stored (e.g., multiple detections per frame) +- Returns `None` or `[]` → nothing stored for this input (e.g., no detections) +- `parent_id` set automatically from source row + +## Examples + +### VLM detections on images + +```python +images = session.stream("images", Image, + pose_provider=lambda: tf.get_pose("world", "base_link")) + +detections = session.stream("cigarette_detections", VLMDetection) + +# Backfill + live +detections.transform(images, fn=lambda img: vlm.detect(img, "people with cigarettes")) + +# After this, every new image.append() triggers detection automatically +# All results are queryable +rows = detections.query().filter_after(one_hour_ago).fetch() +``` + +### Live-only (skip backfill) + +```python +detections.transform(images, fn=detect_fn, live=True) +# Only processes images appended from now on +``` + +### Historical batch on query results + +```python +# Only process images from the kitchen in the last hour +kitchen_images = images.query().filter_near(kitchen_pose, 
5.0).filter_after(one_hour_ago).fetch_set() + +detections.transform(kitchen_images, fn=lambda img: vlm.detect(img, "cigarettes")) +# Batch processes the set, no live subscription +``` + +### Embedding stream (specialized) + +```python +img_emb = session.embedding_stream("img_emb", model=CLIPModel()) + +# fn is implicit — uses model.embed() +img_emb.transform(images, live=True) + +# Equivalent to: +img_emb.transform(images, fn=lambda img: clip.embed(img), live=True) +``` + +### Chaining transforms + +```python +images = session.stream("images", Image, pose_provider=pose_fn) + +# Embeddings from images +img_emb = session.embedding_stream("img_emb", model=CLIPModel()) +img_emb.transform(images, live=True) + +# Detections from images +detections = session.stream("detections", VLMDetection) +detections.transform(images, fn=detect_fn, live=True) + +# Text descriptions from detections (second-level derived) +descriptions = session.text_stream("descriptions", str) +descriptions.transform(detections, fn=lambda det: det.describe(), live=True) +``` + +## Internals + +### Backfill (batch) + +```python +for page in source.iter_meta(page_size=128): + for row in page: + payload = source.load(row) # or row.data + results = fn(payload) + if results is None: + continue + if not isinstance(results, list): + results = [results] + for r in results: + self.append(r, ts=row.ts, pose=row.pose, parent_id=row.id) +``` + +### Live (reactive) + +```python +source.appended.pipe( + ops.map(lambda row: (row, fn(row.data))), + ops.filter(lambda pair: pair[1] is not None), + ops.flat_map(lambda pair: [ + (pair[0], r) for r in (pair[1] if isinstance(pair[1], list) else [pair[1]]) + ]), +).subscribe(lambda pair: self.append(pair[1], ts=pair[0].ts, pose=pair[0].pose, + parent_id=pair[0].id)) +``` + +### EmbeddingStream override + +```python +class EmbeddingStream(StreamBase[T]): + model: EmbeddingModel + + def transform(self, source, fn=None, *, live=False): + if fn is None: + fn = self.model.embed + 
return super().transform(source, fn, live=live) +``` + +## Lineage + +`transform()` sets `parent_id` on every appended row, linking back to the source row. This enables `project_to()`: + +```python +# Find source images for cigarette detections +with detections.query().fetch_set() as det_set: + source_images = det_set.project_to(images) + for row in source_images.rows(limit=5): + img = images.load(row) +``` + +## Open questions + +1. **Async transforms?** VLM inference is slow. Should `fn` support async/await or rx scheduling (e.g., `observe_on(io_scheduler)`)? + +2. **Error handling?** If `fn` raises on one row, skip it? Log and continue? Configurable? + +3. **Backfill progress?** For large backfills, should `transform()` return a progress observable or run in background? + +4. **Multiple parents?** Current design is single-parent lineage. If a stream derives from two streams (e.g., fusing image + audio), we'd need multi-parent support. Phase 3. From ec499cd65c85d9ffff7a8f40c3188e4129779b43 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 20:02:43 +0800 Subject: [PATCH 026/118] removed tasks.md --- dimos/memory/docs/tasks.md | 129 ------------------------------------- 1 file changed, 129 deletions(-) delete mode 100644 dimos/memory/docs/tasks.md diff --git a/dimos/memory/docs/tasks.md b/dimos/memory/docs/tasks.md deleted file mode 100644 index 82d2a4e964..0000000000 --- a/dimos/memory/docs/tasks.md +++ /dev/null @@ -1,129 +0,0 @@ -# Memory2 — Remaining Tasks - -Gap analysis between `plans/memory/` specs and `dimos/memory/` implementation. - -## P0 — Security / Correctness - -### 1. Stream name validation - -Stream names are interpolated directly into SQL via f-strings. No validation exists — arbitrary input is a SQL injection vector. - -**Spec** (`sqlite.md`): `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$`, reject with `ValueError`. 
- -**Where**: Add a `_validate_stream_name(name)` check at the top of `SqliteSession.stream()`, `.text_stream()`, `.embedding_stream()`. - -### 2. `_clone()` type annotation vs runtime - -`Stream._clone()` (`stream.py:94-108`) is annotated `-> Stream[T]`, but at runtime it uses `self.__class__.__new__(self.__class__)` which correctly preserves the subclass. So `EmbeddingStream.after(t)` returns an `EmbeddingStream` at runtime — no bug. - -The annotation is wrong for mypy though. Consider `-> Self` (from `typing_extensions`) if we want strict typing. Low priority — runtime works. - -## P1 — Core API Gaps - -### 3. Wire `parent_stream` into `_streams` registry - -`_register_stream()` (`sqlite.py:847-861`) never writes the `parent_stream` column. The column exists in the schema but is always NULL. - -**Where**: `materialize_transform()` (`sqlite.py:770-799`) knows both `source_table` and `name`. Pass `parent_stream=source_table` to `_register_stream()`, and update `_register_stream` to accept and INSERT it. - -This is a prerequisite for `.join()` and stream-level lineage discovery. - -### 4. ~~Implement `.project_to()` — cross-stream lineage~~ ✅ - -Implemented. `project_to(target)` adds a `LineageFilter` to the target stream (same `_with_filter` mechanism as `.after()`, `.near()`, etc.). The filter compiles to a SQL subquery walking the `parent_id` chain. Multi-hop lineage is resolved via `_streams.parent_stream` registry. Result is a fully chainable `Stream`. - -### 4b. Implement `.join()` — cross-stream lineage returning pairs - -`api.md` specifies: -```python -for det, img in detections.after(t).join(images): - print(f"Detected {det.data} in image at {img.pose}") -``` - -Unlike `project_to()` which returns a `Stream`, `join()` yields `tuple[Observation, Observation]` pairs. This is a terminal operation (not chainable) since the return type is pairs, not observations. 
- -**Depends on**: ~~Task 3~~ Done — `parent_stream` is now written by `materialize_transform()` and read by `resolve_lineage_chain()`. - -### 5. Filtered `.appended` — predicate-filtered reactive subscriptions - -`api.md` specifies: -```python -images.near(kitchen_pose, 5.0).appended.subscribe(...) -``` - -Current impl (`stream.py:276-278`) returns the raw Subject regardless of filters. - -**Fix** (from `sqlite.md`): When `self._query.filters` is non-empty, pipe the root subject through `ops.filter()` that evaluates each predicate in Python: - -```python -@property -def appended(self): - backend = self._require_backend() - obs = backend.appended_subject - if not self._query.filters: - return obs - return obs.pipe(ops.filter(lambda o: self._matches_filters(o))) -``` - -Each filter type needs a `matches(obs) -> bool` method for Python-side evaluation: -- `AfterFilter`: `obs.ts > self.t` -- `NearFilter`: Euclidean distance check -- `TagsFilter`: dict subset check -- etc. - -### 6. Incremental backfill - -`sqlite.md` specifies that re-running a stored transform resumes from the last processed item: - -```python -max_parent = conn.execute( - f"SELECT MAX(parent_id) FROM {target_name}" -).fetchone()[0] - -if max_parent is not None: - source = source.after_id(max_parent) # internal: WHERE id > ? -``` - -**Where**: `materialize_transform()` (`sqlite.py:791-793`). Before calling `transformer.process()`, check if target already has rows and filter source accordingly. - -**Needs**: An internal `_after_id(row_id)` filter (not exposed in public API) that adds `WHERE id > ?`. - -## P2 — Robustness - -### 7. Separate connections per session - -`SqliteStore.session()` (`sqlite.py:886-887`) shares `self._conn` across all sessions. The spec says each session should own its own connection. - -**Fix**: `session()` should call `sqlite3.connect(self._path)` + WAL pragma + extension loading each time, not reuse `self._conn`. Store keeps the path, sessions get independent connections. 
- -This is required for multi-threaded use (e.g., one session writing in a background thread, another querying in the main thread). - -### 8. `_CollectorStream` doesn't set pose on observations - -`_CollectorStream.append()` (`stream.py:401-419`) accepts `pose` but doesn't store it on the `Observation`: - -```python -obs = Observation(id=self._next_id, ts=ts, tags=tags or {}, parent_id=parent_id, _data=payload) -# pose is silently dropped -``` - -**Fix**: Add `pose=pose` to the Observation constructor call. - -## P3 — Future (not blocking) - -### 9. Query objects — composable 4D regions + soft scoring - -`query_objects.md` proposes `Criterion` types (`TimeRange`, `Sphere`, `TimeProximity`, `SpatialProximity`, `EmbeddingSimilarity`) with `&`/`|`/`~` composition and weighted `Score()`. - -Explicitly Phase 2. Current flat filter API covers all simple cases. Implement when real usage demands soft scoring or region composition. - -### 10. `questions.md` hard cases - -Unresolved query patterns from the product requirements: -- Negation queries ("when did I NOT see the cat") -- Temporal regularity ("what time does the mailman come") -- Cross-agent memory diff -- Conditional pose integration -- Event-anchored multi-stream slicing - -These require extensions beyond the current Stream API — likely built on top of the composable query layer (task 9). 
From e344498da2c962f48214387bfae2082ab8752700 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 5 Mar 2026 23:36:03 +0800 Subject: [PATCH 027/118] Optimize memory pipeline: TurboJPEG codec, sharpness downsample, thread reduction - Switch JpegCodec from cv2.imencode to TurboJPEG (2-5x faster encode/decode) - Lower default JPEG quality from 90 to 50 for smaller storage footprint - Downscale sharpness computation to 160px Laplacian variance (10-20x cheaper) - Add MemoryModule with plain-Python sharpness windowing (no rx timer overhead) - Limit OpenCV threads: 2 globally in worker entrypoint, 1 in MemoryModule - Cap global rx ThreadPoolScheduler at 8 workers (was unbounded cpu_count) - Refactor SqliteEmbeddingBackend/SqliteTextBackend to use _post_insert hook - Encode payload before meta insert to prevent orphaned rows on codec error - Add `dimos ps` CLI command and `dps` entrypoint for non-interactive process listing - Add unitree-go2-memory blueprint --- dimos/core/worker.py | 9 + dimos/memory/codec.py | 31 ++- dimos/memory/impl/sqlite.py | 45 ++-- dimos/memory/module.py | 210 ++++++++++++++++++ dimos/msgs/sensor_msgs/Image.py | 22 +- dimos/robot/all_blueprints.py | 2 + dimos/robot/cli/dimos.py | 9 + .../blueprints/smart/unitree_go2_memory.py | 24 ++ dimos/utils/cli/dps.py | 139 ++++++++++++ dimos/utils/threadpool.py | 2 +- pyproject.toml | 1 + 11 files changed, 444 insertions(+), 50 deletions(-) create mode 100644 dimos/memory/module.py create mode 100644 dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py create mode 100644 dimos/utils/cli/dps.py diff --git a/dimos/core/worker.py b/dimos/core/worker.py index b0dd802841..438a23e0bc 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -260,6 +260,15 @@ def _worker_entrypoint( conn: Connection, worker_id: int, ) -> None: + # Limit OpenCV internal threads to avoid idle thread contention. + # Modules that need parallel cv2 ops can call cv2.setNumThreads() in start(). 
+ try: + import cv2 + + cv2.setNumThreads(2) + except ImportError: + pass + instances: dict[int, Any] = {} try: diff --git a/dimos/memory/codec.py b/dimos/memory/codec.py index 3fdc2f7592..a426339847 100644 --- a/dimos/memory/codec.py +++ b/dimos/memory/codec.py @@ -47,38 +47,47 @@ def decode(self, data: bytes) -> DimosMsg: class JpegCodec: """Codec for Image types — stores as JPEG bytes (lossy, ~10-20x smaller). + Uses TurboJPEG (libjpeg-turbo) for 2-5x faster encode/decode vs OpenCV. Preserves ``frame_id`` as a short header: ````. Pixel data is lossy-compressed; ``ts`` is NOT preserved (stored in the meta table). """ - def __init__(self, quality: int = 90) -> None: + def __init__(self, quality: int = 50) -> None: self._quality = quality + from turbojpeg import TurboJPEG # type: ignore[import-untyped] + + self._tj = TurboJPEG() + + _TJPF_MAP: dict[str, int] | None = None + + @staticmethod + def _get_tjpf_map() -> dict[str, int]: + if JpegCodec._TJPF_MAP is None: + from turbojpeg import TJPF_BGR, TJPF_GRAY, TJPF_RGB # type: ignore[import-untyped] + + JpegCodec._TJPF_MAP = {"BGR": TJPF_BGR, "RGB": TJPF_RGB, "GRAY": TJPF_GRAY} + return JpegCodec._TJPF_MAP def encode(self, value: Any) -> bytes: import struct - import cv2 + from turbojpeg import TJPF_BGR # type: ignore[import-untyped] - bgr = value.to_opencv() - ok, buf = cv2.imencode(".jpg", bgr, [cv2.IMWRITE_JPEG_QUALITY, self._quality]) - if not ok: - raise ValueError("JPEG encoding failed") + pf = self._get_tjpf_map().get(value.format.value, TJPF_BGR) + jpeg_data = self._tj.encode(value.data, quality=self._quality, pixel_format=pf) frame_id = (value.frame_id or "").encode("utf-8") header = struct.pack(" Any: import struct - import cv2 - import numpy as np - from dimos.msgs.sensor_msgs.Image import Image, ImageFormat fid_len = struct.unpack(" Subject[Observation]: # type: ignore[type-arg] def stream_name(self) -> str: return self._table + def _post_insert(self, row_id: int, payload: Any) -> None: + """Hook for 
subclasses to add extra inserts inside the transaction.""" + def do_append( self, payload: Any, @@ -392,6 +395,10 @@ def do_append( pose_cols = _decompose_pose(pose) tags_json = _serialize_tags(tags) + # Encode payload before touching the DB so a codec error can't leave + # a metadata row without a matching payload row. + payload_blob = self._codec.encode(payload) + # 1. Insert into meta table if pose_cols is not None: cur = self._conn.execute( @@ -409,7 +416,6 @@ def do_append( assert row_id is not None # 2. Insert into payload table - payload_blob = self._codec.encode(payload) self._conn.execute( f"INSERT INTO {self._table}_payload (id, data) VALUES (?, ?)", (row_id, payload_blob), @@ -424,6 +430,9 @@ def do_append( (row_id, x, x, y, y, z, z), ) + # 4. Subclass hook (vec0, FTS, etc.) + self._post_insert(row_id, payload) + self._conn.commit() obs = Observation( @@ -493,19 +502,9 @@ def __init__( self._vec_dimensions = vec_dimensions self._parent_table = parent_table - def do_append( - self, - payload: Any, - ts: float | None, - pose: Any | None, - tags: dict[str, Any] | None, - parent_id: int | None = None, - ) -> Observation: + def _post_insert(self, row_id: int, payload: Any) -> None: from dimos.models.embedding.base import Embedding - obs = super().do_append(payload, ts, pose, tags, parent_id) - - # Also insert into vec0 table if payload is an Embedding if isinstance(payload, Embedding): vec = payload.to_numpy().tolist() if self._vec_dimensions is None: @@ -513,11 +512,8 @@ def do_append( self._ensure_vec_table() self._conn.execute( f"INSERT INTO {self._table}_vec (rowid, embedding) VALUES (?, ?)", - (obs.id, json.dumps(vec)), + (row_id, json.dumps(vec)), ) - self._conn.commit() - - return obs def _ensure_vec_table(self) -> None: if self._vec_dimensions is None: @@ -656,23 +652,12 @@ def __init__( super().__init__(conn, table, pose_provider=pose_provider, codec=codec) self._tokenizer = tokenizer - def do_append( - self, - payload: Any, - ts: float | None, - 
pose: Any | None, - tags: dict[str, Any] | None, - parent_id: int | None = None, - ) -> Observation: - obs = super().do_append(payload, ts, pose, tags, parent_id) - + def _post_insert(self, row_id: int, payload: Any) -> None: text = str(payload) if payload is not None else "" self._conn.execute( f"INSERT INTO {self._table}_fts (rowid, content) VALUES (?, ?)", - (obs.id, text), + (row_id, text), ) - self._conn.commit() - return obs def execute_fetch(self, query: StreamQuery) -> list[Observation]: text_filter = None @@ -1009,7 +994,7 @@ class SqliteStore(Store): def __init__(self, path: str) -> None: self._path = path - self._conn = sqlite3.connect(path) + self._conn = sqlite3.connect(path, check_same_thread=False) self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA synchronous=NORMAL") self._load_extensions() diff --git a/dimos/memory/module.py b/dimos/memory/module.py new file mode 100644 index 0000000000..aa8162e1da --- /dev/null +++ b/dimos/memory/module.py @@ -0,0 +1,210 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Memory module — ingests Image, PointCloud2, and pose into dimos.memory streams.""" + +from __future__ import annotations + +from dataclasses import dataclass +import time +from typing import TYPE_CHECKING, Any + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In +from dimos.memory.impl.sqlite import SqliteStore +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.memory.store import Session + from dimos.memory.stream import EmbeddingStream, Stream + +logger = setup_logger() + + +@dataclass +class MemoryModuleConfig(ModuleConfig): + db_path: str = "memory.db" + world_frame: str = "world" + robot_frame: str = "base_link" + image_fps: float = 5.0 + # CLIP embedding pipeline + enable_clip: bool = False + sharpness_window: float = 0.5 + + +class MemoryModule(Module[MemoryModuleConfig]): + """Ingests images and point clouds into persistent memory streams. + + Pose is obtained implicitly from the TF system (world -> base_link). + Optionally builds a CLIP embedding index with sharpness-based quality filtering. 
+ + Usage:: + + memory = dimos.deploy(MemoryModule, db_path="/data/robot.db") + memory.color_image.connect(camera.color_image) + memory.pointcloud.connect(lidar.pointcloud) + memory.start() + + # Query via session + session = memory.session + results = session.stream("images").after(t).near(pose, 5.0).fetch() + """ + + color_image: In[Image] + lidar: In[PointCloud2] + + default_config: type[MemoryModuleConfig] = MemoryModuleConfig + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._stores: list[SqliteStore] = [] + self._session: Session | None = None + self._images: Stream[Image] | None = None + self._pointclouds: Stream[PointCloud2] | None = None + self._embeddings: EmbeddingStream[Any] | None = None + + # ── Lifecycle ───────────────────────────────────────────────────── + + def _open_session(self) -> Session: + """Open a new store+session (own connection) for the same DB file.""" + store = SqliteStore(self.config.db_path) + self._stores.append(store) + return store.session() + + @rpc + def start(self) -> None: + super().start() + + import cv2 + + cv2.setNumThreads(1) + + pose_fn = lambda: self.tf.get_pose( # noqa: E731 + self.config.world_frame, self.config.robot_frame + ) + + # Each stream gets its own connection so rx callback threads + # don't share a single sqlite3.Connection (which can't serialize + # concurrent transactions internally). 
+ + # Image stream (best-sharpness per window, no rx windowing overhead) + img_session = self._open_session() + self._images = img_session.stream("images", Image, pose_provider=pose_fn) + self._img_window = 1.0 / self.config.image_fps + self._img_best: Image | None = None + self._img_best_score: float = -1.0 + self._img_window_start: float = 0.0 + self._disposables.add(self.color_image.observable().subscribe(on_next=self._on_image)) + + # Pointcloud stream (only if transport is connected) + if self.lidar._transport is not None: + pc_session = self._open_session() + self._pointclouds = pc_session.stream("pointclouds", PointCloud2, pose_provider=pose_fn) + self._disposables.add(self.lidar.observable().subscribe(on_next=self._on_pointcloud)) + + # Read session (for queries / list_streams) + self._session = self._open_session() + + # Optional CLIP embedding pipeline + if self.config.enable_clip: + self._setup_clip_pipeline() + + logger.info("MemoryModule started (db=%s)", self.config.db_path) + + def _setup_clip_pipeline(self) -> None: + from dimos.memory.transformer import EmbeddingTransformer, QualityWindowTransformer + from dimos.models.embedding.clip import CLIPModel + + assert self._images is not None + + clip = CLIPModel() + clip.start() + + sharp = self._images.transform( + QualityWindowTransformer( + lambda img: img.sharpness, window=self.config.sharpness_window + ), + live=True, + ).store("sharp_frames", Image) + + self._embeddings = sharp.transform( # type: ignore[assignment] + EmbeddingTransformer(clip), live=True + ).store("clip_embeddings") + + logger.info("CLIP embedding pipeline active") + + @rpc + def stop(self) -> None: + self._session = None + for store in self._stores: + store.close() + self._stores.clear() + super().stop() + + # ── Callbacks ───────────────────────────────────────────────────── + + def _on_image(self, img: Image) -> None: + if self._images is None: + return + now = time.monotonic() + score = img.sharpness + if now - 
self._img_window_start >= self._img_window: + # Window elapsed — flush best from previous window, start new one + if self._img_best is not None: + self._images.append(self._img_best, ts=self._img_best.ts) + self._img_best = img + self._img_best_score = score + self._img_window_start = now + elif score > self._img_best_score: + self._img_best = img + self._img_best_score = score + + def _on_pointcloud(self, pc: PointCloud2) -> None: + if self._pointclouds is not None: + self._pointclouds.append(pc, ts=pc.ts) + + # ── Public API ──────────────────────────────────────────────────── + + @property + def session(self) -> Session: + if self._session is None: + raise RuntimeError("MemoryModule not started") + return self._session + + @property + def images(self) -> Stream[Image]: + if self._images is None: + raise RuntimeError("MemoryModule not started") + return self._images + + @property + def pointclouds(self) -> Stream[PointCloud2]: + if self._pointclouds is None: + raise RuntimeError("MemoryModule not started or no pointcloud connected") + return self._pointclouds + + @property + def embeddings(self) -> EmbeddingStream[Any] | None: + return self._embeddings + + @rpc + def get_stats(self) -> dict[str, int]: + if self._session is None: + return {} + return {s.name: s.count for s in self._session.list_streams()} + + +memory_module = MemoryModule.blueprint diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 66c2876b62..3f2e049920 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -377,15 +377,21 @@ def crop(self, x: int, y: int, width: int, height: int) -> Image: @property def sharpness(self) -> float: - """Return sharpness score.""" - gray = self.to_grayscale() - sx = cv2.Sobel(gray.data, cv2.CV_32F, 1, 0, ksize=5) - sy = cv2.Sobel(gray.data, cv2.CV_32F, 0, 1, ksize=5) - magnitude = cv2.magnitude(sx, sy) - mean_mag = float(magnitude.mean()) - if mean_mag <= 0: + """Return sharpness score. 
+ + Downsamples to ~160px wide before computing Laplacian variance + for fast evaluation (~10-20x cheaper than full-res Sobel). + """ + gray = self.to_grayscale().data + # Downsample to ~160px wide for cheap evaluation + h, w = gray.shape[:2] + if w > 160: + scale = 160.0 / w + gray = cv2.resize(gray, (160, int(h * scale)), interpolation=cv2.INTER_AREA) + lap_var = cv2.Laplacian(gray, cv2.CV_32F).var() + if lap_var <= 0: return 0.0 - return float(np.clip((np.log10(mean_mag + 1) - 1.7) / 2.0, 0.0, 1.0)) + return float(np.clip((np.log10(lap_var + 1) - 1.0) / 3.0, 0.0, 1.0)) def save(self, filepath: str) -> bool: arr = self.to_opencv() diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 6026572388..de388b69e5 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -77,6 +77,7 @@ "unitree-go2-agentic-ollama": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic_ollama:unitree_go2_agentic_ollama", "unitree-go2-basic": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic:unitree_go2_basic", "unitree-go2-detection": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_detection:unitree_go2_detection", + "unitree-go2-memory": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_memory:unitree_go2_memory", "unitree-go2-ros": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_ros:unitree_go2_ros", "unitree-go2-spatial": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial:unitree_go2_spatial", "unitree-go2-temporal-memory": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_temporal_memory:unitree_go2_temporal_memory", @@ -116,6 +117,7 @@ "manipulation-module": "dimos.manipulation.manipulation_module", "mapper": "dimos.robot.unitree.type.map", "mcp-client": "dimos.agents.mcp.mcp_client", + "memory-module": "dimos.memory.module", "mid360-module": "dimos.hardware.sensors.lidar.livox.module", "navigation-skill": "dimos.agents.skills.navigation", "object-scene-registration-module": 
"dimos.perception.object_scene_registration", diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index 47a1e777e8..3171ec246e 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -177,6 +177,15 @@ def top(ctx: typer.Context) -> None: dtop_main() +@main.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) +def ps(ctx: typer.Context) -> None: + """List running worker processes (non-interactive).""" + from dimos.utils.cli.dps import main as dps_main + + sys.argv = ["dps", *ctx.args] + dps_main() + + topic_app = typer.Typer(help="Topic commands for pub/sub") main.add_typer(topic_app, name="topic") diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py new file mode 100644 index 0000000000..4f76851e6e --- /dev/null +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py @@ -0,0 +1,24 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.core.blueprints import autoconnect +from dimos.memory.module import memory_module +from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 + +unitree_go2_memory = autoconnect( + unitree_go2, + memory_module(), +).global_config(n_workers=8) + +__all__ = ["unitree_go2_memory"] diff --git a/dimos/utils/cli/dps.py b/dimos/utils/cli/dps.py new file mode 100644 index 0000000000..0ab36d5a71 --- /dev/null +++ b/dimos/utils/cli/dps.py @@ -0,0 +1,139 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""dps — Non-interactive process list over LCM (like `docker ps`). + +Waits for one dtop resource_stats message and prints a table. 
+ +Usage: + dps [--topic /dimos/resource_stats] [--timeout 5] +""" + +from __future__ import annotations + +import sys +import threading +from typing import Any + +from rich.console import Console +from rich.table import Table + +from dimos.protocol.pubsub.impl.lcmpubsub import PickleLCM, Topic + + +def _fmt_pct(v: float) -> str: + return f"{v:.0f}%" + + +def _fmt_mem(v: float) -> str: + mb = v / 1048576 + if mb >= 1024: + return f"{mb / 1024:.1f}G" + return f"{mb:.0f}M" + + +def _fmt_secs(v: float) -> str: + if v >= 3600: + return f"{v / 3600:.1f}h" + if v >= 60: + return f"{v / 60:.1f}m" + return f"{v:.1f}s" + + +def ps(topic: str = "/dimos/resource_stats", timeout: float = 5.0) -> None: + """Wait for one LCM message and print a process table.""" + lcm = PickleLCM(autoconf=True) + result: dict[str, Any] = {} + event = threading.Event() + + def on_msg(msg: dict[str, Any], _topic: str) -> None: + nonlocal result + result = msg + event.set() + + lcm.subscribe(Topic(topic), on_msg) + lcm.start() + + if not event.wait(timeout): + lcm.stop() + Console(stderr=True).print( + f"[red]No dtop message within {timeout:.0f}s. 
Is --dtop enabled?[/red]" + ) + sys.exit(1) + + lcm.stop() + + table = Table(show_header=True, header_style="bold", padding=(0, 1)) + table.add_column("PID", style="dim") + table.add_column("Role") + table.add_column("Modules") + table.add_column("CPU", justify="right") + table.add_column("Mem", justify="right") + table.add_column("Threads", justify="right") + table.add_column("FDs", justify="right") + table.add_column("User", justify="right") + table.add_column("Sys", justify="right") + + coord = result.get("coordinator", {}) + table.add_row( + str(coord.get("pid", "")), + "[cyan]coordinator[/cyan]", + "", + _fmt_pct(coord.get("cpu_percent", 0)), + _fmt_mem(coord.get("pss", 0)), + str(int(coord.get("num_threads", 0))), + str(int(coord.get("num_fds", 0))), + _fmt_secs(coord.get("cpu_time_user", 0)), + _fmt_secs(coord.get("cpu_time_system", 0)), + ) + + for w in result.get("workers", []): + alive = w.get("alive", False) + wid = w.get("worker_id", "?") + role_style = "green" if alive else "red" + modules = ", ".join(w.get("modules", [])) + table.add_row( + str(w.get("pid", "")), + f"[{role_style}]worker {wid}[/{role_style}]", + modules, + _fmt_pct(w.get("cpu_percent", 0)), + _fmt_mem(w.get("pss", 0)), + str(int(w.get("num_threads", 0))), + str(int(w.get("num_fds", 0))), + _fmt_secs(w.get("cpu_time_user", 0)), + _fmt_secs(w.get("cpu_time_system", 0)), + ) + + Console().print(table) + + +def main() -> None: + topic = "/dimos/resource_stats" + timeout = 5.0 + args = sys.argv[1:] + i = 0 + while i < len(args): + if args[i] == "--topic" and i + 1 < len(args): + topic = args[i + 1] + i += 2 + elif args[i] == "--timeout" and i + 1 < len(args): + timeout = float(args[i + 1]) + i += 2 + else: + i += 1 + ps(topic=topic, timeout=timeout) + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/threadpool.py b/dimos/utils/threadpool.py index a2adc90725..f2fd577d40 100644 --- a/dimos/utils/threadpool.py +++ b/dimos/utils/threadpool.py @@ -36,7 +36,7 @@ def 
get_max_workers() -> int: environment variable, defaulting to 4 times the CPU count. """ env_value = os.getenv("DIMOS_MAX_WORKERS", "") - return int(env_value) if env_value.strip() else multiprocessing.cpu_count() + return int(env_value) if env_value.strip() else min(8, multiprocessing.cpu_count()) # Create a ThreadPoolScheduler with a configurable number of workers. diff --git a/pyproject.toml b/pyproject.toml index 9e993584c0..9481f2499c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dimos = "dimos.robot.cli.dimos:main" rerun-bridge = "dimos.visualization.rerun.bridge:app" doclinks = "dimos.utils.docs.doclinks:main" dtop = "dimos.utils.cli.dtop:main" +dps = "dimos.utils.cli.dps:main" [project.optional-dependencies] misc = [ From bf0b79a879d5bbc467e24b1cbd950b2d9d7094db Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 6 Mar 2026 13:25:10 +0800 Subject: [PATCH 028/118] text embedding transformer --- dimos/memory/__init__.py | 2 + dimos/memory/impl/sqlite.py | 9 +- dimos/memory/impl/test_sqlite.py | 146 ++++++++++++++++++++++++++++++- dimos/memory/transformer.py | 34 +++++++ 4 files changed, 188 insertions(+), 3 deletions(-) diff --git a/dimos/memory/__init__.py b/dimos/memory/__init__.py index 132f23832b..0104d65e5d 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory/__init__.py @@ -5,6 +5,7 @@ CaptionTransformer, EmbeddingTransformer, PerItemTransformer, + TextEmbeddingTransformer, Transformer, ) from dimos.memory.types import ( @@ -29,6 +30,7 @@ "Store", "Stream", "StreamInfo", + "TextEmbeddingTransformer", "TextStream", "Transformer", "codec_for_type", diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 6b4b859322..5923b0bdf6 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -46,7 +46,12 @@ ) from dimos.memory.store import Session, Store from dimos.memory.stream import EmbeddingStream, Stream, TextStream -from dimos.memory.transformer import CaptionTransformer, 
EmbeddingTransformer, Transformer +from dimos.memory.transformer import ( + CaptionTransformer, + EmbeddingTransformer, + TextEmbeddingTransformer, + Transformer, +) from dimos.memory.types import ( AfterFilter, AtFilter, @@ -888,7 +893,7 @@ def materialize_transform( source_table = source._backend.stream_name target: Stream[Any] - if isinstance(transformer, EmbeddingTransformer): + if isinstance(transformer, (EmbeddingTransformer, TextEmbeddingTransformer)): target = self.embedding_stream(name, payload_type, parent_table=source_table) target._embedding_model = transformer.model elif isinstance(transformer, CaptionTransformer): diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 2c0bdb58e9..815c5a96ee 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -20,7 +20,7 @@ import pytest from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import EmbeddingTransformer +from dimos.memory.transformer import EmbeddingTransformer, TextEmbeddingTransformer from dimos.memory.types import _UNSET, EmbeddingObservation, Observation from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.sensor_msgs.Image import Image @@ -248,6 +248,150 @@ def test_text_search(self, session: SqliteSession) -> None: assert all("Motor" in r.data for r in rows) +class TestTextStorage: + """Test storing plain text (str) in streams.""" + + def test_store_and_fetch_str(self, session: SqliteSession) -> None: + s = session.stream("raw_logs", str) + s.append("Robot started navigation to kitchen", ts=1.0) + s.append("Obstacle detected at waypoint 3", ts=2.0) + s.append("Navigation complete", ts=3.0) + + assert s.count() == 3 + rows = s.fetch() + assert rows[0].data == "Robot started navigation to kitchen" + assert rows[2].data == "Navigation complete" + + def test_str_with_tags_and_filters(self, session: SqliteSession) -> None: + s = session.stream("tagged_logs", str) + 
s.append("Motor fault on joint 3", ts=1.0, tags={"level": "error"}) + s.append("Battery at 80%", ts=2.0, tags={"level": "info"}) + s.append("Motor overheating", ts=3.0, tags={"level": "error"}) + + errors = s.filter_tags(level="error").fetch() + assert len(errors) == 2 + assert all("Motor" in e.data for e in errors) + + def test_str_persists_reopen(self, tmp_path: object) -> None: + from pathlib import Path + + assert isinstance(tmp_path, Path) + db_path = str(tmp_path / "logs.db") + + store1 = SqliteStore(db_path) + s1 = store1.session() + s1.stream("logs", str).append("hello world", ts=1.0) + s1.close() + store1.close() + + store2 = SqliteStore(db_path) + s2 = store2.session() + rows = s2.stream("logs", str).fetch() + assert len(rows) == 1 + assert rows[0].data == "hello world" + s2.close() + store2.close() + + +class TestTextEmbeddingTransformer: + """Test text → embedding → semantic search pipeline.""" + + def test_text_to_embedding_backfill(self, session: SqliteSession) -> None: + """Backfill: store text, transform to embeddings, search by text.""" + + class FakeTextEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + raise NotImplementedError + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + results = [] + for text in texts: + # Simple fake: hash text to a stable vector + h = hash(text) % 1000 / 1000.0 + results.append(Embedding(np.array([h, 1.0 - h, 0.0, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + logs = session.stream("te_logs", str) + logs.append("Robot navigated to kitchen", ts=1.0) + logs.append("Battery low warning", ts=2.0) + logs.append("Robot navigated to bedroom", ts=3.0) + + embedder = FakeTextEmbedder() + emb_stream = logs.transform(TextEmbeddingTransformer(embedder)).store("te_log_embeddings") + + assert emb_stream.count() == 3 + + # Search — the model embeds the query text into the same space + results = 
emb_stream.search_embedding("Robot navigated to kitchen", k=1).fetch() + assert len(results) == 1 + # Auto-projects to source — data should be original text + assert isinstance(results[0].data, str) + + def test_text_embedding_live(self, session: SqliteSession) -> None: + """Live mode: new text is embedded automatically.""" + + class FakeTextEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + raise NotImplementedError + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + results = [] + for text in texts: + h = hash(text) % 1000 / 1000.0 + results.append(Embedding(np.array([h, 1.0 - h, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + logs = session.stream("te_live_logs", str) + emb_stream = logs.transform(TextEmbeddingTransformer(FakeTextEmbedder()), live=True).store( + "te_live_embs" + ) + + assert emb_stream.count() == 0 # no backfill + + logs.append("New log entry", ts=1.0) + assert emb_stream.count() == 1 + + logs.append("Another log entry", ts=2.0) + assert emb_stream.count() == 2 + + def test_text_embedding_search_projects_to_source(self, session: SqliteSession) -> None: + """search_embedding auto-projects back to source text stream.""" + + class FakeTextEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + raise NotImplementedError + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + results = [] + for text in texts: + # "kitchen" texts get similar vectors + if "kitchen" in text.lower(): + results.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32))) + else: + results.append(Embedding(np.array([0.0, 1.0, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + logs = session.stream("te_proj_logs", str) + logs.append("Robot entered kitchen", ts=1.0) + logs.append("Battery warning", ts=2.0) + 
logs.append("Cleaning kitchen floor", ts=3.0) + + emb_stream = logs.transform(TextEmbeddingTransformer(FakeTextEmbedder())).store( + "te_proj_embs" + ) + + # Search for kitchen-related logs + results = emb_stream.search_embedding("kitchen", k=2).fetch() + assert len(results) == 2 + assert all("kitchen" in r.data.lower() for r in results) + + class TestEmbeddingStream: def test_create_and_append(self, session: SqliteSession) -> None: es = session.embedding_stream("emb", vec_dimensions=4) diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index bb7489d608..629b5fec83 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -182,6 +182,40 @@ def on_append(self, obs: Observation, target: Stream[str]) -> None: target.append(caption, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) +class TextEmbeddingTransformer(Transformer[Any, "Embedding"]): + """Wraps an EmbeddingModel to embed text payloads (strings) into vectors. + + Use this for semantic search over logs, captions, or any text data. + When stored, the output stream becomes an EmbeddingStream with vector index. 
+ """ + + supports_backfill: bool = True + supports_live: bool = True + + def __init__(self, model: EmbeddingModel) -> None: + from dimos.models.embedding.base import Embedding as EmbeddingCls + + self.model = model + self.output_type: type | None = EmbeddingCls + + def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: + for page in source.fetch_pages(): + texts = [str(obs.data) for obs in page] + if not texts: + continue + embeddings = self.model.embed_text(*texts) + if not isinstance(embeddings, list): + embeddings = [embeddings] + for obs, emb in zip(page, embeddings, strict=True): + target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) + + def on_append(self, obs: Observation, target: Stream[Embedding]) -> None: + emb = self.model.embed_text(str(obs.data)) + if isinstance(emb, list): + emb = emb[0] + target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) + + class EmbeddingTransformer(Transformer[Any, "Embedding"]): """Wraps an EmbeddingModel as a Transformer that produces Embedding output. 
From b4f9f96561e1695bca5dfe45b0336ca1ebbce8fd Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 6 Mar 2026 20:41:58 +0800 Subject: [PATCH 029/118] cleanup --- dimos/memory/impl/sqlite.py | 32 +++++++++++++++++++++++++++----- dimos/memory/module.py | 4 ++++ dimos/memory/types.py | 4 ++-- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 5923b0bdf6..3bf20162e0 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -111,8 +111,13 @@ def _reconstruct_pose( from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped return PoseStamped( - position=[x, y or 0.0, z or 0.0], - orientation=[qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0], + position=[x, y if y is not None else 0.0, z if z is not None else 0.0], + orientation=[ + qx if qx is not None else 0.0, + qy if qy is not None else 0.0, + qz if qz is not None else 0.0, + qw if qw is not None else 1.0, + ], ) @@ -221,6 +226,8 @@ def _compile_ids( sql += f" WHERE {where}" if query.order_field: + if query.order_field not in _ALLOWED_ORDER_FIELDS: + raise ValueError(f"Invalid order field: {query.order_field!r}") sql += f" ORDER BY {query.order_field}" if query.order_desc: sql += " DESC" @@ -245,8 +252,6 @@ def _compile_query(query: StreamQuery, table: str) -> tuple[str, list[Any]]: params: list[Any] = [] joins: list[str] = [] - _has_near_filter(query) - for f in query.filters: if isinstance(f, NearFilter): # R*Tree bounding-box join @@ -711,6 +716,10 @@ def _fetch_by_text( observations = [self._row_to_obs(r) for r in rows] + # Re-sort by FTS rank (IN clause doesn't preserve FTS5 ordering) + rank = {rid: i for i, rid in enumerate(rowids)} + observations.sort(key=lambda o: rank.get(o.id, len(rank))) + near = _has_near_filter(query) if near is not None: observations = _apply_near_post_filter(observations, near) @@ -781,6 +790,7 @@ def stream( *, pose_provider: PoseProvider | None = None, ) -> Stream[Any]: + 
_validate_identifier(name) if name in self._streams: return self._streams[name] @@ -810,11 +820,14 @@ def text_stream( tokenizer: str = "unicode61", pose_provider: PoseProvider | None = None, ) -> TextStream[Any]: + _validate_identifier(name) if name in self._streams: return self._streams[name] # type: ignore[return-value] if payload_type is None: payload_type = self._resolve_payload_type(name) + if payload_type is None: + payload_type = str self._ensure_stream_tables(name) self._ensure_fts_table(name, tokenizer) @@ -838,6 +851,7 @@ def embedding_stream( parent_table: str | None = None, embedding_model: EmbeddingModel | None = None, ) -> EmbeddingStream[Any]: + _validate_identifier(name) if name in self._streams: existing = self._streams[name] if embedding_model is not None and isinstance(existing, EmbeddingStream): @@ -872,6 +886,7 @@ def list_streams(self) -> list[StreamInfo]: rows = self._conn.execute("SELECT name, payload_module FROM _streams").fetchall() result: list[StreamInfo] = [] for name, pmodule in rows: + _validate_identifier(name) count_row = self._conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone() count = count_row[0] if count_row else 0 result.append(StreamInfo(name=name, payload_type=pmodule, count=count)) @@ -995,7 +1010,14 @@ def _resolve_payload_type(self, name: str) -> type | None: class SqliteStore(Store): - """SQLite-backed memory store.""" + """SQLite-backed memory store. + + Note: all sessions returned by :meth:`session` share the same underlying + ``sqlite3.Connection``. For concurrent write access from multiple threads, + open separate ``SqliteStore`` instances (one per thread) against the same + DB path — WAL mode allows this safely. See ``MemoryModule._open_session`` + for an example. 
+ """ def __init__(self, path: str) -> None: self._path = path diff --git a/dimos/memory/module.py b/dimos/memory/module.py index aa8162e1da..638001f975 100644 --- a/dimos/memory/module.py +++ b/dimos/memory/module.py @@ -148,6 +148,10 @@ def _setup_clip_pipeline(self) -> None: @rpc def stop(self) -> None: + # Flush the last sharpness window so the final image isn't lost + if self._img_best is not None and self._images is not None: + self._images.append(self._img_best, ts=self._img_best.ts) + self._img_best = None self._session = None for store in self._stores: store.close() diff --git a/dimos/memory/types.py b/dimos/memory/types.py index ea7b1194ef..511f8720ff 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -135,8 +135,8 @@ class NearFilter: def matches(self, obs: Observation) -> bool: if obs.pose is None: return False - p1 = obs.pose.pose.position - p2 = self.pose.pose.position + p1 = obs.pose.position + p2 = self.pose.position dist = math.sqrt((p1.x - p2.x) ** 2 + (p1.y - p2.y) ** 2 + (p1.z - p2.z) ** 2) return dist <= self.radius From 19a8db34e56e603c7a03f7cf4a5bdc1847b6a59b Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 6 Mar 2026 20:51:04 +0800 Subject: [PATCH 030/118] Use Codec protocol type instead of concrete union, remove dead _pose_codec --- dimos/memory/codec.py | 17 +++-------------- dimos/memory/impl/sqlite.py | 9 ++++----- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/dimos/memory/codec.py b/dimos/memory/codec.py index a426339847..db0bf6757f 100644 --- a/dimos/memory/codec.py +++ b/dimos/memory/codec.py @@ -18,6 +18,8 @@ import pickle from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from dimos.msgs.sensor_msgs.Image import Image + if TYPE_CHECKING: from dimos.msgs.protocol import DimosMsg @@ -103,23 +105,10 @@ def decode(self, data: bytes) -> Any: return pickle.loads(data) -_POSE_CODEC: LcmCodec | None = None - - -def _pose_codec() -> LcmCodec: - global _POSE_CODEC - if _POSE_CODEC is None: - from 
dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - - _POSE_CODEC = LcmCodec(PoseStamped) - return _POSE_CODEC - - -def codec_for_type(payload_type: type | None) -> LcmCodec | JpegCodec | PickleCodec: +def codec_for_type(payload_type: type | None) -> Codec[Any]: """Auto-select codec based on payload type.""" if payload_type is not None: # Image → JPEG by default (much smaller than LCM raw pixels) - from dimos.msgs.sensor_msgs.Image import Image if issubclass(payload_type, Image): return JpegCodec() diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 3bf20162e0..b36b9f681d 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -37,8 +37,7 @@ from reactivex.subject import Subject from dimos.memory.codec import ( - JpegCodec, - LcmCodec, + Codec, PickleCodec, codec_for_type, module_path_to_type, @@ -369,7 +368,7 @@ def __init__( table: str, *, pose_provider: PoseProvider | None = None, - codec: LcmCodec | JpegCodec | PickleCodec | None = None, + codec: Codec[Any] | None = None, ) -> None: _validate_identifier(table) self._conn = conn @@ -506,7 +505,7 @@ def __init__( vec_dimensions: int | None = None, pose_provider: PoseProvider | None = None, parent_table: str | None = None, - codec: LcmCodec | JpegCodec | PickleCodec | None = None, + codec: Codec[Any] | None = None, ) -> None: super().__init__(conn, table, pose_provider=pose_provider, codec=codec) self._vec_dimensions = vec_dimensions @@ -657,7 +656,7 @@ def __init__( *, tokenizer: str = "unicode61", pose_provider: PoseProvider | None = None, - codec: LcmCodec | JpegCodec | PickleCodec | None = None, + codec: Codec[Any] | None = None, ) -> None: super().__init__(conn, table, pose_provider=pose_provider, codec=codec) self._tokenizer = tokenizer From 2cbf162272fdea846f4540cefa9702a9fdca4632 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 6 Mar 2026 20:59:14 +0800 Subject: [PATCH 031/118] correct db sessions --- dimos/memory/impl/sqlite.py | 36 
++++++++++++++++++++---------------- dimos/memory/module.py | 27 ++++++++++++++++----------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index b36b9f681d..fce0df4e87 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -938,6 +938,7 @@ def close(self) -> None: if s._backend is not None: s._backend.appended_subject.on_completed() self._streams.clear() + self._conn.close() # ── Internal helpers ────────────────────────────────────────────── @@ -1009,34 +1010,37 @@ def _resolve_payload_type(self, name: str) -> type | None: class SqliteStore(Store): - """SQLite-backed memory store. + """SQLite-backed memory store (lightweight factory). - Note: all sessions returned by :meth:`session` share the same underlying - ``sqlite3.Connection``. For concurrent write access from multiple threads, - open separate ``SqliteStore`` instances (one per thread) against the same - DB path — WAL mode allows this safely. See ``MemoryModule._open_session`` - for an example. + Each :meth:`session` call opens a new ``sqlite3.Connection`` with WAL mode + and extensions loaded. Sessions are safe to use from different threads. 
""" def __init__(self, path: str) -> None: self._path = path - self._conn = sqlite3.connect(path, check_same_thread=False) - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA synchronous=NORMAL") - self._load_extensions() + self._closed = False + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self._path, check_same_thread=False) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + self._load_extensions(conn) + return conn def session(self) -> SqliteSession: - return SqliteSession(self._conn) + if self._closed: + raise RuntimeError("Store is closed") + return SqliteSession(self._connect()) - def _load_extensions(self) -> None: + def _load_extensions(self, conn: sqlite3.Connection) -> None: try: import sqlite_vec - self._conn.enable_load_extension(True) - sqlite_vec.load(self._conn) - self._conn.enable_load_extension(False) + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) except ImportError: pass def close(self) -> None: - self._conn.close() + self._closed = True diff --git a/dimos/memory/module.py b/dimos/memory/module.py index 638001f975..4d5657207e 100644 --- a/dimos/memory/module.py +++ b/dimos/memory/module.py @@ -70,7 +70,8 @@ class MemoryModule(Module[MemoryModuleConfig]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self._stores: list[SqliteStore] = [] + self._store: SqliteStore | None = None + self._sessions: list[Session] = [] self._session: Session | None = None self._images: Stream[Image] | None = None self._pointclouds: Stream[PointCloud2] | None = None @@ -79,10 +80,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # ── Lifecycle ───────────────────────────────────────────────────── def _open_session(self) -> Session: - """Open a new store+session (own connection) for the same DB file.""" - store = SqliteStore(self.config.db_path) - self._stores.append(store) - return 
store.session() + """Open a new session (own connection) from the shared store.""" + assert self._store is not None + session = self._store.session() + self._sessions.append(session) + return session @rpc def start(self) -> None: @@ -92,13 +94,13 @@ def start(self) -> None: cv2.setNumThreads(1) + self._store = SqliteStore(self.config.db_path) + pose_fn = lambda: self.tf.get_pose( # noqa: E731 self.config.world_frame, self.config.robot_frame ) - # Each stream gets its own connection so rx callback threads - # don't share a single sqlite3.Connection (which can't serialize - # concurrent transactions internally). + # Each session opens its own connection (WAL mode allows concurrent writes). # Image stream (best-sharpness per window, no rx windowing overhead) img_session = self._open_session() @@ -153,9 +155,12 @@ def stop(self) -> None: self._images.append(self._img_best, ts=self._img_best.ts) self._img_best = None self._session = None - for store in self._stores: - store.close() - self._stores.clear() + for session in self._sessions: + session.close() + self._sessions.clear() + if self._store is not None: + self._store.close() + self._store = None super().stop() # ── Callbacks ───────────────────────────────────────────────────── From 74df2de4f20c7d10ac05625534ac74837a1a20b5 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 6 Mar 2026 22:08:11 +0800 Subject: [PATCH 032/118] record module cleanup --- dimos/memory/codec.py | 2 +- dimos/memory/impl/sqlite.py | 12 +- dimos/memory/module.py | 195 ++++++------------ dimos/memory/store.py | 15 +- dimos/memory/types.py | 1 + .../blueprints/smart/unitree_go2_memory.py | 52 ++++- 6 files changed, 129 insertions(+), 148 deletions(-) diff --git a/dimos/memory/codec.py b/dimos/memory/codec.py index db0bf6757f..9351b3bf84 100644 --- a/dimos/memory/codec.py +++ b/dimos/memory/codec.py @@ -76,7 +76,7 @@ def encode(self, value: Any) -> bytes: from turbojpeg import TJPF_BGR # type: ignore[import-untyped] pf = 
self._get_tjpf_map().get(value.format.value, TJPF_BGR) - jpeg_data = self._tj.encode(value.data, quality=self._quality, pixel_format=pf) + jpeg_data: bytes = self._tj.encode(value.data, quality=self._quality, pixel_format=pf) frame_id = (value.frame_id or "").encode("utf-8") header = struct.pack(" list[StreamInfo]: - rows = self._conn.execute("SELECT name, payload_module FROM _streams").fetchall() + rows = self._conn.execute( + "SELECT name, payload_module, stream_kind FROM _streams" + ).fetchall() result: list[StreamInfo] = [] - for name, pmodule in rows: + for name, pmodule, kind in rows: _validate_identifier(name) count_row = self._conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone() count = count_row[0] if count_row else 0 - result.append(StreamInfo(name=name, payload_type=pmodule, count=count)) + result.append( + StreamInfo( + name=name, payload_type=pmodule, count=count, stream_kind=kind or "stream" + ) + ) return result def materialize_transform( diff --git a/dimos/memory/module.py b/dimos/memory/module.py index 4d5657207e..796c63bf69 100644 --- a/dimos/memory/module.py +++ b/dimos/memory/module.py @@ -12,178 +12,115 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Memory module — ingests Image, PointCloud2, and pose into dimos.memory streams.""" +"""Memory module — record input streams into persistent memory.""" from __future__ import annotations -from dataclasses import dataclass -import time +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any +import cv2 + from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In from dimos.memory.impl.sqlite import SqliteStore -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.sensor_msgs.Image import sharpness_barrier from dimos.utils.logging_config import setup_logger +cv2.setNumThreads(1) + if TYPE_CHECKING: + from reactivex.observable import Observable + + from dimos.core.stream import In from dimos.memory.store import Session - from dimos.memory.stream import EmbeddingStream, Stream + from dimos.memory.stream import Stream + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped logger = setup_logger() +@dataclass +class RecordSpec: + """Declares an input stream to record.""" + + input_name: str + stream_name: str + payload_type: type | None = None + fps: float = 0 + """Target FPS. If >0, uses sharpness_barrier to select best frame per window.""" + + @dataclass class MemoryModuleConfig(ModuleConfig): db_path: str = "memory.db" world_frame: str = "world" robot_frame: str = "base_link" - image_fps: float = 5.0 - # CLIP embedding pipeline - enable_clip: bool = False - sharpness_window: float = 0.5 + records: list[RecordSpec] = field(default_factory=list) class MemoryModule(Module[MemoryModuleConfig]): - """Ingests images and point clouds into persistent memory streams. - - Pose is obtained implicitly from the TF system (world -> base_link). - Optionally builds a CLIP embedding index with sharpness-based quality filtering. 
- - Usage:: - - memory = dimos.deploy(MemoryModule, db_path="/data/robot.db") - memory.color_image.connect(camera.color_image) - memory.pointcloud.connect(lidar.pointcloud) - memory.start() - - # Query via session - session = memory.session - results = session.stream("images").after(t).near(pose, 5.0).fetch() - """ - - color_image: In[Image] - lidar: In[PointCloud2] - default_config: type[MemoryModuleConfig] = MemoryModuleConfig def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._store: SqliteStore | None = None - self._sessions: list[Session] = [] self._session: Session | None = None - self._images: Stream[Image] | None = None - self._pointclouds: Stream[PointCloud2] | None = None - self._embeddings: EmbeddingStream[Any] | None = None # ── Lifecycle ───────────────────────────────────────────────────── - def _open_session(self) -> Session: - """Open a new session (own connection) from the shared store.""" - assert self._store is not None - session = self._store.session() - self._sessions.append(session) - return session + def pose(self) -> PoseStamped | None: + return self.tf.get_pose(self.config.world_frame, self.config.robot_frame) # type: ignore[no-any-return] @rpc def start(self) -> None: super().start() - - import cv2 - - cv2.setNumThreads(1) - self._store = SqliteStore(self.config.db_path) - - pose_fn = lambda: self.tf.get_pose( # noqa: E731 - self.config.world_frame, self.config.robot_frame - ) - - # Each session opens its own connection (WAL mode allows concurrent writes). 
- - # Image stream (best-sharpness per window, no rx windowing overhead) - img_session = self._open_session() - self._images = img_session.stream("images", Image, pose_provider=pose_fn) - self._img_window = 1.0 / self.config.image_fps - self._img_best: Image | None = None - self._img_best_score: float = -1.0 - self._img_window_start: float = 0.0 - self._disposables.add(self.color_image.observable().subscribe(on_next=self._on_image)) - - # Pointcloud stream (only if transport is connected) - if self.lidar._transport is not None: - pc_session = self._open_session() - self._pointclouds = pc_session.stream("pointclouds", PointCloud2, pose_provider=pose_fn) - self._disposables.add(self.lidar.observable().subscribe(on_next=self._on_pointcloud)) - - # Read session (for queries / list_streams) - self._session = self._open_session() - - # Optional CLIP embedding pipeline - if self.config.enable_clip: - self._setup_clip_pipeline() + self._session = self._store.session() + self._disposables.add(self._session) + + # Auto-record streams declared in config + for spec in self.config.records: + input_stream: In[Any] = getattr(self, spec.input_name) + self.record( + input_stream, + spec.stream_name, + spec.payload_type, + fps=spec.fps, + ) logger.info("MemoryModule started (db=%s)", self.config.db_path) - def _setup_clip_pipeline(self) -> None: - from dimos.memory.transformer import EmbeddingTransformer, QualityWindowTransformer - from dimos.models.embedding.clip import CLIPModel - - assert self._images is not None + def record( + self, + input: In[Any], + name: str, + payload_type: type | None = None, + fps: float = 0, + ) -> Stream[Any]: + assert self._store is not None, "record() called before start()" + session = self._store.session() + self._disposables.add(session) + stream = session.stream(name, payload_type, pose_provider=self.pose) - clip = CLIPModel() - clip.start() + obs: Observable[Any] = input.observable() + if fps > 0: + obs = obs.pipe(sharpness_barrier(fps)) - sharp 
= self._images.transform( - QualityWindowTransformer( - lambda img: img.sharpness, window=self.config.sharpness_window - ), - live=True, - ).store("sharp_frames", Image) + def _on_item(item: Any) -> None: + stream.append(item, ts=getattr(item, "ts", None)) - self._embeddings = sharp.transform( # type: ignore[assignment] - EmbeddingTransformer(clip), live=True - ).store("clip_embeddings") + self._disposables.add(obs.subscribe(on_next=_on_item)) - logger.info("CLIP embedding pipeline active") + return stream @rpc def stop(self) -> None: - # Flush the last sharpness window so the final image isn't lost - if self._img_best is not None and self._images is not None: - self._images.append(self._img_best, ts=self._img_best.ts) - self._img_best = None self._session = None - for session in self._sessions: - session.close() - self._sessions.clear() + super().stop() # disposes all sessions via CompositeDisposable if self._store is not None: self._store.close() self._store = None - super().stop() - - # ── Callbacks ───────────────────────────────────────────────────── - - def _on_image(self, img: Image) -> None: - if self._images is None: - return - now = time.monotonic() - score = img.sharpness - if now - self._img_window_start >= self._img_window: - # Window elapsed — flush best from previous window, start new one - if self._img_best is not None: - self._images.append(self._img_best, ts=self._img_best.ts) - self._img_best = img - self._img_best_score = score - self._img_window_start = now - elif score > self._img_best_score: - self._img_best = img - self._img_best_score = score - - def _on_pointcloud(self, pc: PointCloud2) -> None: - if self._pointclouds is not None: - self._pointclouds.append(pc, ts=pc.ts) # ── Public API ──────────────────────────────────────────────────── @@ -193,22 +130,6 @@ def session(self) -> Session: raise RuntimeError("MemoryModule not started") return self._session - @property - def images(self) -> Stream[Image]: - if self._images is None: - raise 
RuntimeError("MemoryModule not started") - return self._images - - @property - def pointclouds(self) -> Stream[PointCloud2]: - if self._pointclouds is None: - raise RuntimeError("MemoryModule not started or no pointcloud connected") - return self._pointclouds - - @property - def embeddings(self) -> EmbeddingStream[Any] | None: - return self._embeddings - @rpc def get_stats(self) -> dict[str, int]: if self._session is None: @@ -217,3 +138,5 @@ def get_stats(self) -> dict[str, int]: memory_module = MemoryModule.blueprint +memory_module = MemoryModule.blueprint +memory_module = MemoryModule.blueprint diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 327f5caa02..5d0dfa9469 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -17,6 +17,8 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from reactivex.abc import DisposableBase + if TYPE_CHECKING: from dimos.models.embedding.base import EmbeddingModel @@ -25,8 +27,14 @@ from .types import PoseProvider, StreamInfo -class Session(ABC): - """A session against a memory store. Creates and manages streams.""" +class Session(DisposableBase, ABC): + """A session against a memory store. Creates and manages streams. + + Inherits DisposableBase so sessions can be added to CompositeDisposable. + """ + + def dispose(self) -> None: + self.close() @abstractmethod def stream( @@ -97,9 +105,6 @@ def close(self) -> None: ... 
def __enter__(self) -> Session: return self - def __exit__(self, *args: object) -> None: - self.close() - class Store(ABC): """Top-level entry point — wraps a database file.""" diff --git a/dimos/memory/types.py b/dimos/memory/types.py index 511f8720ff..c14123e67f 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -88,6 +88,7 @@ class StreamInfo: name: str payload_type: str | None = None count: int = 0 + stream_kind: str = "stream" # ── Filter types ────────────────────────────────────────────────────── diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py index 4f76851e6e..b12c9bf4ba 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py @@ -12,13 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + from dimos.core.blueprints import autoconnect -from dimos.memory.module import memory_module +from dimos.core.core import rpc +from dimos.memory.module import MemoryModule, MemoryModuleConfig +from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 +if TYPE_CHECKING: + from dimos.core.stream import In + + +@dataclass +class UnitreeGo2MemoryConfig(MemoryModuleConfig): + image_fps: float = 5.0 + enable_clip: bool = False + + +class UnitreeGo2Memory(MemoryModule): + color_image: In[Image] + lidar: In[PointCloud2] + + config: UnitreeGo2MemoryConfig # type: ignore[assignment] + default_config: type[UnitreeGo2MemoryConfig] = UnitreeGo2MemoryConfig + + @rpc + def start(self) -> None: + super().start() + self._images = self.record(self.color_image, "images", Image, fps=self.config.image_fps) + if self.lidar._transport is not None: + 
self._pointclouds = self.record(self.lidar, "pointclouds", PointCloud2) + + if self.config.enable_clip: + self._setup_clip_pipeline() + + def _setup_clip_pipeline(self) -> None: + from dimos.memory.transformer import EmbeddingTransformer + from dimos.models.embedding.clip import CLIPModel + + clip = CLIPModel() + clip.start() + + self._embeddings: Any = self._images.transform(EmbeddingTransformer(clip), live=True).store( + "clip_embeddings" + ) + + unitree_go2_memory = autoconnect( unitree_go2, - memory_module(), + UnitreeGo2Memory.blueprint(), ).global_config(n_workers=8) -__all__ = ["unitree_go2_memory"] +__all__ = ["UnitreeGo2Memory", "unitree_go2_memory"] From e039f904ed15ff577c5f09e73cb607c45f0db266 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 15:56:28 +0800 Subject: [PATCH 033/118] memory elements are now Resource, simplification of memory Module --- dimos/core/resource.py | 6 +- dimos/memory/impl/sqlite.py | 4 +- dimos/memory/impl/test_e2e_export.py | 4 +- dimos/memory/impl/test_sqlite.py | 16 ++-- dimos/memory/impl/test_sqlite_e2e.py | 8 +- dimos/memory/module.py | 84 +++++-------------- dimos/memory/store.py | 21 +++-- dimos/memory/stream.py | 4 + .../blueprints/smart/unitree_go2_memory.py | 31 ++++--- 9 files changed, 71 insertions(+), 107 deletions(-) diff --git a/dimos/core/resource.py b/dimos/core/resource.py index ce3f735329..df1ca568bc 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod +from abc import abstractmethod +from reactivex.abc import DisposableBase -class Resource(ABC): + +class Resource(DisposableBase): @abstractmethod def start(self) -> None: ... 
diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index be952028c9..6ed1c46bb7 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -939,7 +939,7 @@ def materialize_transform( return target - def close(self) -> None: + def stop(self) -> None: for s in self._streams.values(): if s._backend is not None: s._backend.appended_subject.on_completed() @@ -1048,5 +1048,5 @@ def _load_extensions(self, conn: sqlite3.Connection) -> None: except ImportError: pass - def close(self) -> None: + def stop(self) -> None: self._closed = True diff --git a/dimos/memory/impl/test_e2e_export.py b/dimos/memory/impl/test_e2e_export.py index a175560d9b..defce69401 100644 --- a/dimos/memory/impl/test_e2e_export.py +++ b/dimos/memory/impl/test_e2e_export.py @@ -81,8 +81,8 @@ def e2e_db(clip: CLIPModel) -> Generator[tuple[SqliteStore, Any], None, None]: print(f"Using cached DB ({DB_PATH})") yield store, session # type: ignore[misc] - session.close() - store.close() + session.stop() + store.stop() @pytest.fixture(scope="module") diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 815c5a96ee..ed9b5b6862 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -281,16 +281,16 @@ def test_str_persists_reopen(self, tmp_path: object) -> None: store1 = SqliteStore(db_path) s1 = store1.session() s1.stream("logs", str).append("hello world", ts=1.0) - s1.close() - store1.close() + s1.stop() + store1.stop() store2 = SqliteStore(db_path) s2 = store2.session() rows = s2.stream("logs", str).fetch() assert len(rows) == 1 assert rows[0].data == "hello world" - s2.close() - store2.close() + s2.stop() + store2.stop() class TestTextEmbeddingTransformer: @@ -1116,13 +1116,13 @@ def test_data_persists(self, tmp_path: object, images: list[Image]) -> None: store1 = SqliteStore(db_path) s1 = store1.session() s1.stream("data", Image).append(images[0], ts=1.0) - s1.close() - store1.close() + s1.stop() + 
store1.stop() store2 = SqliteStore(db_path) s2 = store2.session() rows = s2.stream("data", Image).fetch() assert len(rows) == 1 assert _img_close(rows[0].data, images[0]) - s2.close() - store2.close() + s2.stop() + store2.stop() diff --git a/dimos/memory/impl/test_sqlite_e2e.py b/dimos/memory/impl/test_sqlite_e2e.py index fd4a049b2a..368e145b51 100644 --- a/dimos/memory/impl/test_sqlite_e2e.py +++ b/dimos/memory/impl/test_sqlite_e2e.py @@ -97,8 +97,8 @@ def test_ingest_filter_embed_search( print(f"Time-filtered search: {len(filtered)} results after ts={mid_ts:.2f}") # 6. Verify persistence — reopen and search again - session.close() - store.close() + session.stop() + store.stop() store2 = SqliteStore(str(tmp_path / "e2e.db")) session2 = store2.session() @@ -109,5 +109,5 @@ def test_ingest_filter_embed_search( assert len(results2) > 0 print(f"After reopen: {len(results2)} results") - session2.close() - store2.close() + session2.stop() + store2.stop() diff --git a/dimos/memory/module.py b/dimos/memory/module.py index 796c63bf69..c92f1b5c60 100644 --- a/dimos/memory/module.py +++ b/dimos/memory/module.py @@ -16,8 +16,8 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, TypeVar import cv2 @@ -33,22 +33,12 @@ from reactivex.observable import Observable from dimos.core.stream import In - from dimos.memory.store import Session from dimos.memory.stream import Stream from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -logger = setup_logger() - +T = TypeVar("T") -@dataclass -class RecordSpec: - """Declares an input stream to record.""" - - input_name: str - stream_name: str - payload_type: type | None = None - fps: float = 0 - """Target FPS. 
If >0, uses sharpness_barrier to select best frame per window.""" +logger = setup_logger() @dataclass @@ -56,7 +46,6 @@ class MemoryModuleConfig(ModuleConfig): db_path: str = "memory.db" world_frame: str = "world" robot_frame: str = "base_link" - records: list[RecordSpec] = field(default_factory=list) class MemoryModule(Module[MemoryModuleConfig]): @@ -65,9 +54,6 @@ class MemoryModule(Module[MemoryModuleConfig]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._store: SqliteStore | None = None - self._session: Session | None = None - - # ── Lifecycle ───────────────────────────────────────────────────── def pose(self) -> PoseStamped | None: return self.tf.get_pose(self.config.world_frame, self.config.robot_frame) # type: ignore[no-any-return] @@ -76,67 +62,39 @@ def pose(self) -> PoseStamped | None: def start(self) -> None: super().start() self._store = SqliteStore(self.config.db_path) - self._session = self._store.session() - self._disposables.add(self._session) - - # Auto-record streams declared in config - for spec in self.config.records: - input_stream: In[Any] = getattr(self, spec.input_name) - self.record( - input_stream, - spec.stream_name, - spec.payload_type, - fps=spec.fps, - ) - + self._disposables.add(self._store) logger.info("MemoryModule started (db=%s)", self.config.db_path) - def record( + def memory( self, - input: In[Any], - name: str, - payload_type: type | None = None, + input: In[T], + name: str | None = None, # can be infered from input + payload_type: type | None = None, # can be infered from input fps: float = 0, - ) -> Stream[Any]: + ) -> Stream[T]: assert self._store is not None, "record() called before start()" + + if name is None: + name = input.name + if payload_type is None: + payload_type = input.type + session = self._store.session() self._disposables.add(session) - stream = session.stream(name, payload_type, pose_provider=self.pose) + + memory_stream = session.stream(name, payload_type, 
pose_provider=self.pose) obs: Observable[Any] = input.observable() if fps > 0: obs = obs.pipe(sharpness_barrier(fps)) - def _on_item(item: Any) -> None: - stream.append(item, ts=getattr(item, "ts", None)) - - self._disposables.add(obs.subscribe(on_next=_on_item)) + self._disposables.add(obs.subscribe(on_next=memory_stream.append)) - return stream + return memory_stream @rpc def stop(self) -> None: - self._session = None - super().stop() # disposes all sessions via CompositeDisposable - if self._store is not None: - self._store.close() - self._store = None - - # ── Public API ──────────────────────────────────────────────────── - - @property - def session(self) -> Session: - if self._session is None: - raise RuntimeError("MemoryModule not started") - return self._session - - @rpc - def get_stats(self) -> dict[str, int]: - if self._session is None: - return {} - return {s.name: s.count for s in self._session.list_streams()} + super().stop() memory_module = MemoryModule.blueprint -memory_module = MemoryModule.blueprint -memory_module = MemoryModule.blueprint diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 5d0dfa9469..2c62b6397b 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -14,10 +14,10 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any -from reactivex.abc import DisposableBase +from dimos.core.resource import Resource if TYPE_CHECKING: from dimos.models.embedding.base import EmbeddingModel @@ -27,14 +27,14 @@ from .types import PoseProvider, StreamInfo -class Session(DisposableBase, ABC): +class Session(Resource): """A session against a memory store. Creates and manages streams. Inherits DisposableBase so sessions can be added to CompositeDisposable. 
""" - def dispose(self) -> None: - self.close() + def start(self) -> None: + pass @abstractmethod def stream( @@ -100,23 +100,26 @@ def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: """ @abstractmethod - def close(self) -> None: ... + def stop(self) -> None: ... def __enter__(self) -> Session: return self -class Store(ABC): +class Store(Resource): """Top-level entry point — wraps a database file.""" @abstractmethod def session(self) -> Session: ... + def start(self) -> None: + pass + @abstractmethod - def close(self) -> None: ... + def stop(self) -> None: ... def __enter__(self) -> Store: return self def __exit__(self, *args: object) -> None: - self.close() + self.stop() diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index dd9b1a7ed3..86933d4836 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -26,6 +26,8 @@ import numpy as np import reactivex.operators as ops +from dimos.types.timestamped import Timestamped + from .types import ( AfterFilter, AtFilter, @@ -132,6 +134,8 @@ def append( tags: dict[str, Any] | None = None, parent_id: int | None = None, ) -> Observation: + if ts is None and isinstance(payload, Timestamped): + ts = payload.ts backend = self._require_backend() return backend.do_append(payload, ts, pose, tags, parent_id) diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py index b12c9bf4ba..4891d307e2 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py @@ -15,56 +15,53 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from dimos.core.blueprints import autoconnect from dimos.core.core import rpc from dimos.memory.module import MemoryModule, MemoryModuleConfig -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from 
dimos.memory.transformer import EmbeddingTransformer +from dimos.models.embedding.clip import CLIPModel from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 if TYPE_CHECKING: from dimos.core.stream import In + from dimos.memory.stream import Stream + from dimos.models.embedding.base import Embedding + from dimos.msgs.sensor_msgs import Image, PointCloud2 @dataclass class UnitreeGo2MemoryConfig(MemoryModuleConfig): image_fps: float = 5.0 - enable_clip: bool = False class UnitreeGo2Memory(MemoryModule): color_image: In[Image] lidar: In[PointCloud2] - config: UnitreeGo2MemoryConfig # type: ignore[assignment] default_config: type[UnitreeGo2MemoryConfig] = UnitreeGo2MemoryConfig @rpc def start(self) -> None: super().start() - self._images = self.record(self.color_image, "images", Image, fps=self.config.image_fps) - if self.lidar._transport is not None: - self._pointclouds = self.record(self.lidar, "pointclouds", PointCloud2) - if self.config.enable_clip: - self._setup_clip_pipeline() + self.image_memory: Stream[Image] = self.memory( + self.color_image, + ) - def _setup_clip_pipeline(self) -> None: - from dimos.memory.transformer import EmbeddingTransformer - from dimos.models.embedding.clip import CLIPModel + self.pointcloud_memory: Stream[PointCloud2] = self.memory(self.lidar) clip = CLIPModel() clip.start() + self._disposables.add(clip) - self._embeddings: Any = self._images.transform(EmbeddingTransformer(clip), live=True).store( - "clip_embeddings" - ) + self.image_embeddings: Stream[Embedding] = self.image_memory.transform( + EmbeddingTransformer(clip), live=True + ).store("clip_embeddings") unitree_go2_memory = autoconnect( unitree_go2, UnitreeGo2Memory.blueprint(), ).global_config(n_workers=8) - -__all__ = ["UnitreeGo2Memory", "unitree_go2_memory"] From bd3a572e5162779a5470d4a84dc67603bc695be8 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 16:03:01 +0800 Subject: [PATCH 034/118] Rename stream.appended to 
stream.observable()/subscribe() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirror the core In stream API — memory streams now expose .observable() and .subscribe() instead of the .appended property. --- dimos/memory/impl/sqlite.py | 2 +- dimos/memory/impl/test_sqlite.py | 109 ++------------------------ dimos/memory/stream.py | 7 +- dimos/memory/test_transformer.py | 126 +++++++++++++++++++++++++++++++ 4 files changed, 137 insertions(+), 107 deletions(-) create mode 100644 dimos/memory/test_transformer.py diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 6ed1c46bb7..167e8c16e9 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -935,7 +935,7 @@ def materialize_transform( # Subscribe to live updates if transformer.supports_live and not backfill_only: - source.appended.subscribe(on_next=lambda obs: transformer.on_append(obs, target)) + source.subscribe(lambda obs: transformer.on_append(obs, target)) return target diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index ed9b5b6862..bfabfb576d 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -20,7 +20,7 @@ import pytest from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import EmbeddingTransformer, TextEmbeddingTransformer +from dimos.memory.transformer import EmbeddingTransformer from dimos.memory.types import _UNSET, EmbeddingObservation, Observation from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.sensor_msgs.Image import Image @@ -293,105 +293,6 @@ def test_str_persists_reopen(self, tmp_path: object) -> None: store2.stop() -class TestTextEmbeddingTransformer: - """Test text → embedding → semantic search pipeline.""" - - def test_text_to_embedding_backfill(self, session: SqliteSession) -> None: - """Backfill: store text, transform to embeddings, search by text.""" - - class 
FakeTextEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - raise NotImplementedError - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - results = [] - for text in texts: - # Simple fake: hash text to a stable vector - h = hash(text) % 1000 / 1000.0 - results.append(Embedding(np.array([h, 1.0 - h, 0.0, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - logs = session.stream("te_logs", str) - logs.append("Robot navigated to kitchen", ts=1.0) - logs.append("Battery low warning", ts=2.0) - logs.append("Robot navigated to bedroom", ts=3.0) - - embedder = FakeTextEmbedder() - emb_stream = logs.transform(TextEmbeddingTransformer(embedder)).store("te_log_embeddings") - - assert emb_stream.count() == 3 - - # Search — the model embeds the query text into the same space - results = emb_stream.search_embedding("Robot navigated to kitchen", k=1).fetch() - assert len(results) == 1 - # Auto-projects to source — data should be original text - assert isinstance(results[0].data, str) - - def test_text_embedding_live(self, session: SqliteSession) -> None: - """Live mode: new text is embedded automatically.""" - - class FakeTextEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - raise NotImplementedError - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - results = [] - for text in texts: - h = hash(text) % 1000 / 1000.0 - results.append(Embedding(np.array([h, 1.0 - h, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - logs = session.stream("te_live_logs", str) - emb_stream = logs.transform(TextEmbeddingTransformer(FakeTextEmbedder()), live=True).store( - "te_live_embs" - ) - - assert emb_stream.count() == 0 # no backfill - - logs.append("New log entry", ts=1.0) - assert emb_stream.count() == 1 - - 
logs.append("Another log entry", ts=2.0) - assert emb_stream.count() == 2 - - def test_text_embedding_search_projects_to_source(self, session: SqliteSession) -> None: - """search_embedding auto-projects back to source text stream.""" - - class FakeTextEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - raise NotImplementedError - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - results = [] - for text in texts: - # "kitchen" texts get similar vectors - if "kitchen" in text.lower(): - results.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32))) - else: - results.append(Embedding(np.array([0.0, 1.0, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - logs = session.stream("te_proj_logs", str) - logs.append("Robot entered kitchen", ts=1.0) - logs.append("Battery warning", ts=2.0) - logs.append("Cleaning kitchen floor", ts=3.0) - - emb_stream = logs.transform(TextEmbeddingTransformer(FakeTextEmbedder())).store( - "te_proj_embs" - ) - - # Search for kitchen-related logs - results = emb_stream.search_embedding("kitchen", k=2).fetch() - assert len(results) == 2 - assert all("kitchen" in r.data.lower() for r in results) - - class TestEmbeddingStream: def test_create_and_append(self, session: SqliteSession) -> None: es = session.embedding_stream("emb", vec_dimensions=4) @@ -487,7 +388,7 @@ class TestReactive: def test_appended_observable(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("images", Image) received: list[Observation] = [] - s.appended.subscribe(on_next=received.append) + s.subscribe(received.append) s.append(images[0]) s.append(images[1]) @@ -1080,7 +981,7 @@ class TestFilteredAppended: def test_unfiltered_appended(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("fa_unfilt", Image) received: list[Observation] = [] - 
s.appended.subscribe(on_next=received.append) + s.subscribe(received.append) s.append(images[0], ts=1.0) s.append(images[1], ts=5.0) @@ -1089,7 +990,7 @@ def test_unfiltered_appended(self, session: SqliteSession, images: list[Image]) def test_filtered_appended(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("fa_filt", Image) received: list[Observation] = [] - s.after(3.0).appended.subscribe(on_next=received.append) + s.after(3.0).subscribe(received.append) s.append(images[0], ts=1.0) # filtered out s.append(images[1], ts=5.0) # passes @@ -1099,7 +1000,7 @@ def test_filtered_appended(self, session: SqliteSession, images: list[Image]) -> def test_tag_filtered_appended(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("fa_tag", Image) received: list[Observation] = [] - s.filter_tags(cam="front").appended.subscribe(on_next=received.append) + s.filter_tags(cam="front").subscribe(received.append) s.append(images[0], tags={"cam": "front"}) s.append(images[1], tags={"cam": "rear"}) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 86933d4836..6a3eee0fad 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -48,6 +48,7 @@ from collections.abc import Callable, Iterator from reactivex import Observable + from reactivex.abc import DisposableBase as Disposable from reactivex.subject import Subject from dimos.models.embedding.base import Embedding, EmbeddingModel @@ -321,8 +322,7 @@ def count(self) -> int: # ── Reactive ────────────────────────────────────────────────────── - @property - def appended(self) -> Observable[Observation]: # type: ignore[type-arg] + def observable(self) -> Observable[Observation]: # type: ignore[type-arg] backend = self._require_backend() raw: Observable[Observation] = backend.appended_subject # type: ignore[assignment] if not self._query.filters: @@ -338,6 +338,9 @@ def _check(o: Observation) -> bool: return raw.pipe(ops.filter(_check)) + def subscribe(self, 
on_next: Callable[[Observation], None]) -> Disposable: + return self.observable().subscribe(on_next=on_next) + class EmbeddingStream(Stream[T]): """Stream with a vector index. Adds search_embedding().""" diff --git a/dimos/memory/test_transformer.py b/dimos/memory/test_transformer.py new file mode 100644 index 0000000000..41d58effab --- /dev/null +++ b/dimos/memory/test_transformer.py @@ -0,0 +1,126 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for memory transformers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from dimos.memory.impl.sqlite import SqliteSession, SqliteStore +from dimos.memory.transformer import TextEmbeddingTransformer +from dimos.models.embedding.base import Embedding, EmbeddingModel + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.msgs.sensor_msgs.Image import Image + + +class FakeTextEmbedder(EmbeddingModel): + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + raise NotImplementedError + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + results = [] + for text in texts: + h = hash(text) % 1000 / 1000.0 + results.append(Embedding(np.array([h, 1.0 - h, 0.0, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + +class SemanticFakeEmbedder(EmbeddingModel): + """Embeds 'kitchen' texts to one region, 
everything else to another.""" + + device = "cpu" + + def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] + raise NotImplementedError + + def embed_text(self, *texts: str) -> Embedding | list[Embedding]: + results = [] + for text in texts: + if "kitchen" in text.lower(): + results.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32))) + else: + results.append(Embedding(np.array([0.0, 1.0, 0.0], dtype=np.float32))) + return results if len(results) > 1 else results[0] + + +@pytest.fixture +def session(tmp_path: object) -> Iterator[SqliteSession]: + from pathlib import Path + + assert isinstance(tmp_path, Path) + store = SqliteStore(str(tmp_path / "test.db")) + sess = store.session() + yield sess + sess.stop() + store.stop() + + +class TestTextEmbeddingTransformer: + """Test text -> embedding -> semantic search pipeline.""" + + def test_text_to_embedding_backfill(self, session: SqliteSession) -> None: + """Backfill: store text, transform to embeddings, search by text.""" + logs = session.stream("te_logs", str) + logs.append("Robot navigated to kitchen", ts=1.0) + logs.append("Battery low warning", ts=2.0) + logs.append("Robot navigated to bedroom", ts=3.0) + + emb_stream = logs.transform(TextEmbeddingTransformer(FakeTextEmbedder())).store( + "te_log_embeddings" + ) + + assert emb_stream.count() == 3 + + results = emb_stream.search_embedding("Robot navigated to kitchen", k=1).fetch() + assert len(results) == 1 + assert isinstance(results[0].data, str) + + def test_text_embedding_live(self, session: SqliteSession) -> None: + """Live mode: new text is embedded automatically.""" + logs = session.stream("te_live_logs", str) + emb_stream = logs.transform(TextEmbeddingTransformer(FakeTextEmbedder()), live=True).store( + "te_live_embs" + ) + + assert emb_stream.count() == 0 # no backfill + + logs.append("New log entry", ts=1.0) + assert emb_stream.count() == 1 + + logs.append("Another log entry", ts=2.0) + assert emb_stream.count() 
== 2 + + def test_text_embedding_search_projects_to_source(self, session: SqliteSession) -> None: + """search_embedding auto-projects back to source text stream.""" + logs = session.stream("te_proj_logs", str) + logs.append("Robot entered kitchen", ts=1.0) + logs.append("Battery warning", ts=2.0) + logs.append("Cleaning kitchen floor", ts=3.0) + + emb_stream = logs.transform(TextEmbeddingTransformer(SemanticFakeEmbedder())).store( + "te_proj_embs" + ) + + results = emb_stream.search_embedding("kitchen", k=2).fetch() + assert len(results) == 2 + assert all("kitchen" in r.data.lower() for r in results) From 24a13de8f0379e48a6e672e87c4fee61712cd016 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 17:41:46 +0800 Subject: [PATCH 035/118] repr, embedding fetch simplification --- dimos/memory/impl/sqlite.py | 76 ++++------- dimos/memory/impl/test_e2e_export.py | 34 +++-- dimos/memory/impl/test_sqlite.py | 54 ++++---- dimos/memory/impl/test_sqlite_e2e.py | 2 +- dimos/memory/store.py | 16 +-- dimos/memory/stream.py | 84 ++++++++----- dimos/memory/test_stream_repr.py | 181 +++++++++++++++++++++++++++ dimos/memory/test_transformer.py | 15 ++- dimos/memory/types.py | 69 ++++++---- 9 files changed, 377 insertions(+), 154 deletions(-) create mode 100644 dimos/memory/test_stream_repr.py diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 167e8c16e9..3c585b2ec0 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -545,6 +545,11 @@ def execute_fetch(self, query: StreamQuery) -> list[Observation]: return super().execute_fetch(query) + def execute_count(self, query: StreamQuery) -> int: + if any(isinstance(f, EmbeddingSearchFilter) for f in query.filters): + return len(self.execute_fetch(query)) + return super().execute_count(query) + def _fetch_by_vector( self, query: StreamQuery, emb_filter: EmbeddingSearchFilter ) -> list[Observation]: @@ -604,7 +609,6 @@ def _row_to_obs(self, row: Any) -> Observation: conn = 
self._conn table = self._table codec = self._codec - parent_table = self._parent_table def loader() -> Any: r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() @@ -612,29 +616,6 @@ def loader() -> Any: raise LookupError(f"No payload for id={row_id}") return codec.decode(r[0]) - source_loader = None - if pid is not None and parent_table is not None: - _pt: str = parent_table # narrowed from str | None by the guard above - - def _source_loader(parent_tbl: str = _pt, parent_row_id: int = pid) -> Any: - r = conn.execute( - f"SELECT data FROM {parent_tbl}_payload WHERE id = ?", (parent_row_id,) - ).fetchone() - if r is None: - raise LookupError(f"No parent payload for id={parent_row_id}") - # Resolve parent codec from _streams metadata - meta = conn.execute( - "SELECT payload_module FROM _streams WHERE name = ?", (parent_tbl,) - ).fetchone() - if meta and meta[0]: - parent_type = module_path_to_type(meta[0]) - parent_codec = codec_for_type(parent_type) - else: - parent_codec = codec - return parent_codec.decode(r[0]) - - source_loader = _source_loader - return EmbeddingObservation( id=row_id, ts=ts, @@ -642,7 +623,6 @@ def _source_loader(parent_tbl: str = _pt, parent_row_id: int = pid) -> Any: tags=_deserialize_tags(tags_json), parent_id=pid, _data_loader=loader, - _source_data_loader=source_loader, ) @@ -785,7 +765,7 @@ def _ensure_meta_table(self) -> None: def stream( self, name: str, - payload_type: type | None = None, + payload_type: type, *, pose_provider: PoseProvider | None = None, ) -> Stream[Any]: @@ -793,28 +773,18 @@ def stream( if name in self._streams: return self._streams[name] - if payload_type is None: - payload_type = self._resolve_payload_type(name) - - if payload_type is None: - raise TypeError( - f"stream({name!r}): payload_type is required when creating a new stream. " - "Pass the type explicitly, e.g. session.stream('images', Image)." 
- ) - self._ensure_stream_tables(name) self._register_stream(name, payload_type, "stream") codec = codec_for_type(payload_type) backend = SqliteStreamBackend(self._conn, name, pose_provider=pose_provider, codec=codec) - s: Stream[Any] = Stream(backend=backend, session=self) + s: Stream[Any] = Stream(backend=backend, session=self, payload_type=payload_type) self._streams[name] = s return s def text_stream( self, name: str, - payload_type: type | None = None, *, tokenizer: str = "unicode61", pose_provider: PoseProvider | None = None, @@ -823,33 +793,29 @@ def text_stream( if name in self._streams: return self._streams[name] # type: ignore[return-value] - if payload_type is None: - payload_type = self._resolve_payload_type(name) - if payload_type is None: - payload_type = str - self._ensure_stream_tables(name) self._ensure_fts_table(name, tokenizer) - self._register_stream(name, payload_type, "text") + self._register_stream(name, str, "text") - codec = codec_for_type(payload_type) + codec = codec_for_type(str) backend = SqliteTextBackend( self._conn, name, tokenizer=tokenizer, pose_provider=pose_provider, codec=codec ) - ts: TextStream[Any] = TextStream(backend=backend, session=self) + ts: TextStream[Any] = TextStream(backend=backend, session=self, payload_type=str) self._streams[name] = ts return ts def embedding_stream( self, name: str, - payload_type: type | None = None, *, vec_dimensions: int | None = None, pose_provider: PoseProvider | None = None, parent_table: str | None = None, embedding_model: EmbeddingModel | None = None, ) -> EmbeddingStream[Any]: + from dimos.models.embedding.base import Embedding + _validate_identifier(name) if name in self._streams: existing = self._streams[name] @@ -857,13 +823,10 @@ def embedding_stream( existing._embedding_model = embedding_model return existing # type: ignore[return-value] - if payload_type is None: - payload_type = self._resolve_payload_type(name) - self._ensure_stream_tables(name) - self._register_stream(name, 
payload_type, "embedding", embedding_dim=vec_dimensions) + self._register_stream(name, Embedding, "embedding", embedding_dim=vec_dimensions) - codec = codec_for_type(payload_type) + codec = codec_for_type(Embedding) backend = SqliteEmbeddingBackend( self._conn, name, @@ -876,7 +839,10 @@ def embedding_stream( backend._ensure_vec_table() es: EmbeddingStream[Any] = EmbeddingStream( - backend=backend, session=self, embedding_model=embedding_model + backend=backend, + session=self, + embedding_model=embedding_model, + payload_type=Embedding, ) self._streams[name] = es return es @@ -914,11 +880,13 @@ def materialize_transform( target: Stream[Any] if isinstance(transformer, (EmbeddingTransformer, TextEmbeddingTransformer)): - target = self.embedding_stream(name, payload_type, parent_table=source_table) + target = self.embedding_stream(name, parent_table=source_table) target._embedding_model = transformer.model elif isinstance(transformer, CaptionTransformer): - target = self.text_stream(name, payload_type) + target = self.text_stream(name) else: + if payload_type is None: + raise TypeError("materialize_transform requires payload_type for plain streams") target = self.stream(name, payload_type) # Record parent lineage in _streams registry diff --git a/dimos/memory/impl/test_e2e_export.py b/dimos/memory/impl/test_e2e_export.py index defce69401..e5d76692f0 100644 --- a/dimos/memory/impl/test_e2e_export.py +++ b/dimos/memory/impl/test_e2e_export.py @@ -88,10 +88,16 @@ def e2e_db(clip: CLIPModel) -> Generator[tuple[SqliteStore, Any], None, None]: @pytest.fixture(scope="module") def embeddings(e2e_db: tuple[SqliteStore, Any], clip: CLIPModel) -> EmbeddingStream[Any]: _, session = e2e_db - stream: EmbeddingStream[Any] = session.embedding_stream("clip_embeddings", embedding_model=clip) # type: ignore[assignment] + stream: EmbeddingStream[Any] = session.embedding_stream("clip_embeddings", embedding_model=clip) # type: ignore[return-value] return stream 
+@pytest.fixture(scope="module") +def sharp_frames(e2e_db: tuple[SqliteStore, Any]) -> Any: + _, session = e2e_db + return session.stream("sharp_frames", Image) + + class TestEmbeddingSearch: """Search the cached CLIP embedding DB and export top matches.""" @@ -106,40 +112,46 @@ class TestEmbeddingSearch: @pytest.mark.parametrize("query", QUERIES) def test_search_returns_results(self, embeddings: EmbeddingStream[Any], query: str) -> None: + from dimos.memory.types import EmbeddingObservation + results = embeddings.search_embedding(query, k=5).fetch() assert len(results) > 0 for obs in results: assert obs.ts is not None - assert isinstance(obs.data, Image) + assert isinstance(obs, EmbeddingObservation) @pytest.mark.parametrize("query", QUERIES) - def test_search_exports_images(self, embeddings: EmbeddingStream[Any], query: str) -> None: + def test_search_exports_images( + self, embeddings: EmbeddingStream[Any], sharp_frames: Any, query: str + ) -> None: slug = query.replace(" ", "_")[:30] - results = embeddings.search_embedding(query, k=5).fetch() + results = embeddings.search_embedding(query, k=5).project_to(sharp_frames).fetch() for rank, img in enumerate(results): fname = DB_DIR / f"{slug}_{rank + 1}_id{img.id}_ts{img.ts:.0f}.jpg" img.data.save(str(fname)) print(f" [{rank + 1}] id={img.id} ts={img.ts:.2f}") - def test_raw_search_has_similarity(self, embeddings: EmbeddingStream[Any]) -> None: + def test_search_has_similarity(self, embeddings: EmbeddingStream[Any]) -> None: from dimos.memory.types import EmbeddingObservation - raw = embeddings.search_embedding("a hallway", k=10, raw=True).fetch() - assert len(raw) > 0 - for obs in raw: + results = embeddings.search_embedding("a hallway", k=10).fetch() + assert len(results) > 0 + for obs in results: assert isinstance(obs, EmbeddingObservation) assert obs.similarity is not None assert 0.0 <= obs.similarity <= 1.0 - def test_caption_search_results(self, embeddings: EmbeddingStream[Any]) -> None: + def 
test_caption_search_results( + self, embeddings: EmbeddingStream[Any], sharp_frames: Any + ) -> None: from dimos.models.vl.florence import Florence2Model captioner = Florence2Model() captioner.start() caption_xf = CaptionTransformer(captioner) - results = embeddings.search_embedding("a door", k=3).fetch() + results = embeddings.search_embedding("a door", k=3).project_to(sharp_frames).fetch() captions = results.transform(caption_xf).fetch() assert len(captions) == len(results) @@ -160,6 +172,6 @@ def test_stream_to_rerun(self, e2e_db: tuple[SqliteStore, Any]) -> None: rr.init("memory_e2e_test", spawn=True) _, session = e2e_db - n = to_rerun(session.stream("sharp_frames")) + n = to_rerun(session.stream("sharp_frames", Image)) assert n > 0 print(f" Logged {n} images to Rerun") diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index bfabfb576d..ff98a9eb04 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -231,14 +231,14 @@ def test_basic_pagination(self, session: SqliteSession, images: list[Image]) -> class TestTextStream: def test_create_and_append(self, session: SqliteSession) -> None: - s = session.text_stream("logs", str) + s = session.text_stream("logs") s.append("Motor fault on joint 3") s.append("Battery low warning") assert s.count() == 2 def test_text_search(self, session: SqliteSession) -> None: - s = session.text_stream("logs", str) + s = session.text_stream("logs") s.append("Motor fault on joint 3") s.append("Battery low warning") s.append("Motor overheating on joint 5") @@ -365,8 +365,8 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: emb_stream = s.transform(EmbeddingTransformer(FakeEmbedder())).store("cam_embeddings") assert emb_stream.count() == 2 - # Search auto-projects to source images - results = emb_stream.search_embedding([0.5, 0.5, 0.0, 0.0], k=1).fetch() + # Search returns EmbeddingObservation; project_to to get source images + results = 
emb_stream.search_embedding([0.5, 0.5, 0.0, 0.0], k=1).project_to(s).fetch() assert len(results) == 1 assert _img_close(results[0].data, images[0]) or _img_close(results[0].data, images[1]) @@ -377,7 +377,7 @@ def test_list_empty(self, session: SqliteSession) -> None: def test_list_after_create(self, session: SqliteSession) -> None: session.stream("images", Image) - session.text_stream("logs", str) + session.text_stream("logs") infos = session.list_streams() names = {i.name for i in infos} @@ -443,7 +443,7 @@ def test_transform_store_backfill(self, session: SqliteSession, images: list[Ima expected = f"{images[0].width}x{images[0].height}" assert rows[0].data == expected - reloaded = session.stream("shapes") + reloaded = session.stream("shapes", str) assert reloaded.count() == 2 def test_transform_store_live(self, session: SqliteSession, images: list[Image]) -> None: @@ -508,8 +508,10 @@ def test_iter(self, session: SqliteSession, images: list[Image]) -> None: class TestProjectTo: - def test_search_auto_projects(self, session: SqliteSession, images: list[Image]) -> None: - """search_embedding auto-projects to source stream.""" + def test_search_returns_embedding_obs( + self, session: SqliteSession, images: list[Image] + ) -> None: + """search_embedding returns EmbeddingObservation; .data provides source data via lineage.""" class FakeEmbedder(EmbeddingModel): device = "cpu" @@ -532,21 +534,25 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("pt_embs") assert embs.count() == 3 - # search_embedding auto-projects — results are Images, not Embeddings - projected = embs.search_embedding([0.5, 0.5, 0.0], k=2).fetch() + # search_embedding returns EmbeddingObservation with Embedding data + results = embs.search_embedding([0.5, 0.5, 0.0], k=2).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddingObservation) + assert isinstance(obs.data, Embedding) + + 
# project_to to get source images + projected = embs.search_embedding([0.5, 0.5, 0.0], k=2).project_to(imgs).fetch() assert len(projected) == 2 for obs in projected: - assert not isinstance(obs, EmbeddingObservation) assert ( _img_close(obs.data, images[0]) or _img_close(obs.data, images[1]) or _img_close(obs.data, images[2]) ) - def test_search_auto_projects_chainable( - self, session: SqliteSession, images: list[Image] - ) -> None: - """Auto-projected search results support further chaining.""" + def test_search_chainable(self, session: SqliteSession, images: list[Image]) -> None: + """Search results support further filter chaining.""" class FakeEmbedder(EmbeddingModel): device = "cpu" @@ -568,7 +574,7 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptc_embs") - # Chain time filter after auto-projected search + # Chain time filter after search results = embs.search_embedding([0.5, 0.5, 0.0], k=10).after(3.0).fetch() assert all(r.ts is not None and r.ts > 3.0 for r in results) @@ -601,7 +607,7 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: assert all(r.ts is not None and r.ts > 3.0 for r in results) def test_two_hop(self, session: SqliteSession, images: list[Image]) -> None: - """search_embedding auto-projects to direct parent, then project_to for second hop.""" + """project_to handles multi-hop lineage (embs → mid → raw).""" class FakeEmbedder(EmbeddingModel): device = "cpu" @@ -627,7 +633,7 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: embs = mid.transform(EmbeddingTransformer(FakeEmbedder())).store("th_embs") assert embs.count() == 3 - # search auto-projects to mid (direct parent), then project_to(raw) for second hop + # project_to(raw) walks the full chain: th_embs → th_mid → th_raw projected = embs.search_embedding([0.5, 0.5, 0.0], k=2).project_to(raw) results = projected.fetch() assert len(results) == 2 @@ -779,8 +785,10 @@ def 
test_similarity_none_without_search(self, session: SqliteSession) -> None: assert isinstance(results[0], EmbeddingObservation) assert results[0].similarity is None - def test_raw_returns_embedding_obs(self, session: SqliteSession, images: list[Image]) -> None: - """search_embedding(raw=True) returns EmbeddingObservation with similarity.""" + def test_search_embedding_obs_with_similarity( + self, session: SqliteSession, images: list[Image] + ) -> None: + """search_embedding returns EmbeddingObservation with similarity scores.""" class FakeEmbedder(EmbeddingModel): device = "cpu" @@ -801,14 +809,12 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("sim_proj_embs") - # raw=True: get raw EmbeddingObservation with similarity - results = embs.search_embedding([0.5, 0.5, 0.0], k=2, raw=True).fetch() + results = embs.search_embedding([0.5, 0.5, 0.0], k=2).fetch() assert len(results) == 2 for obs in results: assert isinstance(obs, EmbeddingObservation) assert obs.similarity is not None - # .data auto-projects to source Image via _source_data_loader - assert isinstance(obs.data, Image) + assert isinstance(obs.data, Embedding) class TestObservationSet: diff --git a/dimos/memory/impl/test_sqlite_e2e.py b/dimos/memory/impl/test_sqlite_e2e.py index 368e145b51..d75278c384 100644 --- a/dimos/memory/impl/test_sqlite_e2e.py +++ b/dimos/memory/impl/test_sqlite_e2e.py @@ -102,7 +102,7 @@ def test_ingest_filter_embed_search( store2 = SqliteStore(str(tmp_path / "e2e.db")) session2 = store2.session() - reloaded = session2.embedding_stream("clip_embeddings") + reloaded = session2.embedding_stream("clip_embeddings", vec_dimensions=512) assert reloaded.count() == n_emb results2 = reloaded.search_embedding(query_emb, k=3).fetch() diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 2c62b6397b..35b3ac7237 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -15,17 +15,19 @@ from 
__future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar from dimos.core.resource import Resource if TYPE_CHECKING: - from dimos.models.embedding.base import EmbeddingModel + from dimos.models.embedding.base import Embedding, EmbeddingModel from .stream import EmbeddingStream, Stream, TextStream from .transformer import Transformer from .types import PoseProvider, StreamInfo +T = TypeVar("T") + class Session(Resource): """A session against a memory store. Creates and manages streams. @@ -40,34 +42,32 @@ def start(self) -> None: def stream( self, name: str, - payload_type: type | None = None, + payload_type: type[T], *, pose_provider: PoseProvider | None = None, - ) -> Stream[Any]: + ) -> Stream[T]: """Get or create a stored stream backed by the database.""" @abstractmethod def text_stream( self, name: str, - payload_type: type | None = None, *, tokenizer: str = "unicode61", pose_provider: PoseProvider | None = None, - ) -> TextStream[Any]: + ) -> TextStream[str]: """Get or create a text stream with FTS index.""" @abstractmethod def embedding_stream( self, name: str, - payload_type: type | None = None, *, vec_dimensions: int | None = None, pose_provider: PoseProvider | None = None, parent_table: str | None = None, embedding_model: EmbeddingModel | None = None, - ) -> EmbeddingStream[Any]: + ) -> EmbeddingStream[Embedding]: """Get or create an embedding stream with vec0 index.""" @abstractmethod diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 6a3eee0fad..79d2d657c3 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -93,10 +93,12 @@ def __init__( *, query: StreamQuery | None = None, session: Session | None = None, + payload_type: type | None = None, ) -> None: self._backend = backend self._query = query or StreamQuery() self._session: Session | None = session + self._payload_type: type | None = payload_type def _clone(self, **overrides: Any) 
-> Stream[T]: """Return a new Stream with updated query fields.""" @@ -112,8 +114,17 @@ def _clone(self, **overrides: Any) -> Stream[T]: clone._backend = self._backend clone._query = new_query clone._session = self._session + clone._payload_type = self._payload_type return clone + def __repr__(self) -> str: + cls = type(self).__name__ + type_name = self._payload_type.__name__ if self._payload_type else "?" + name = self._backend.stream_name if self._backend else "unbound" + head = f'{cls}[{type_name}]("{name}")' + query_str = str(self._query) + return f"{head} | {query_str}" if query_str else head + def _with_filter(self, f: Filter) -> Stream[T]: return self._clone(filters=(*self._query.filters, f)) @@ -354,8 +365,9 @@ def __init__( query: StreamQuery | None = None, session: Session | None = None, embedding_model: EmbeddingModel | None = None, + payload_type: type | None = None, ) -> None: - super().__init__(backend=backend, query=query, session=session) + super().__init__(backend=backend, query=query, session=session, payload_type=payload_type) self._embedding_model = embedding_model def _require_model(self) -> EmbeddingModel: @@ -378,19 +390,16 @@ def search_embedding( query: Embedding | list[float] | str | Any, *, k: int, - raw: bool = False, - ) -> Stream[Any]: + ) -> EmbeddingStream[T]: """Search by vector similarity. Accepts pre-computed embeddings, raw float lists, text strings, or images/other objects. Text and non-vector inputs are auto-embedded using the model that created this stream. - By default, auto-projects to the source stream so results contain the - source data (e.g. Images) rather than Embedding objects. Set - ``raw=True`` to skip auto-projection and get ``EmbeddingObservation`` - results with ``.similarity``, ``.pose``, ``.ts``, and ``.data`` - (auto-projected to parent via ``_source_data_loader``). 
+ Returns an EmbeddingStream — use ``.project_to(source)`` to get + results in the source stream's type, or ``.fetch()`` for + ``EmbeddingObservation`` with ``.similarity`` scores. """ from dimos.models.embedding.base import Embedding as EmbeddingCls @@ -398,7 +407,7 @@ def search_embedding( emb = self._require_model().embed_text(query) if isinstance(emb, list): emb = emb[0] - return self.search_embedding(emb, k=k, raw=raw) + return self.search_embedding(emb, k=k) if isinstance(query, EmbeddingCls): vec = query.to_numpy().tolist() @@ -409,30 +418,17 @@ def search_embedding( emb = self._require_model().embed(query) if isinstance(emb, list): emb = emb[0] - return self.search_embedding(emb, k=k, raw=raw) + return self.search_embedding(emb, k=k) clone = self._with_filter(EmbeddingSearchFilter(vec, k)) - filtered: EmbeddingStream[T] = EmbeddingStream( + return EmbeddingStream( backend=clone._backend, query=clone._query, session=clone._session, embedding_model=self._embedding_model, + payload_type=clone._payload_type, ) - if raw: - return filtered - - # Auto-project to source stream when lineage exists - session = filtered._session - backend = filtered._backend - if session is not None and backend is not None: - parent_name = session.resolve_parent_stream(backend.stream_name) - if parent_name is not None: - source = session.stream(parent_name) - return filtered.project_to(source) - - return filtered - def fetch(self) -> ObservationSet[T]: # type: ignore[override] backend = self._require_backend() results = backend.execute_fetch(self._query) @@ -457,7 +453,10 @@ class TextStream(Stream[T]): def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: clone = self._with_filter(TextSearchFilter(text, k)) ts: TextStream[T] = TextStream( - backend=clone._backend, query=clone._query, session=clone._session + backend=clone._backend, + query=clone._query, + session=clone._session, + payload_type=clone._payload_type, ) return ts @@ -479,6 +478,28 @@ def __init__( 
self._live = live self._backfill_only = backfill_only + def _clone(self, **overrides: Any) -> Stream[R]: + clone = super()._clone(**overrides) + if isinstance(clone, TransformStream): + clone._source = self._source + clone._transformer = self._transformer + clone._live = self._live + clone._backfill_only = self._backfill_only + return clone + + def __repr__(self) -> str: + type_name = self._transformer.output_type.__name__ if self._transformer.output_type else "?" + xf_name = type(self._transformer).__name__ + flags: list[str] = [] + if self._live: + flags.append("live=True") + if self._backfill_only: + flags.append("backfill_only=True") + flag_str = ", " + ", ".join(flags) if flags else "" + head = f"TransformStream[{type_name}]({self._source!r} -> {xf_name}{flag_str})" + query_str = str(self._query) + return f"{head} | {query_str}" if query_str else head + def fetch(self) -> ObservationSet[R]: """Execute transform in memory, collecting results.""" collector = _CollectorStream[R]() @@ -647,11 +668,12 @@ def _clone(self, **overrides: Any) -> Stream[T]: limit_val=overrides.get("limit_val", q.limit_val), offset_val=overrides.get("offset_val", q.offset_val), ) - clone: Stream[T] = Stream.__new__(Stream) - clone._backend = self._backend - clone._query = new_query - clone._session = self._session - return clone + return Stream( + backend=self._backend, + query=new_query, + session=self._session, + payload_type=self._payload_type, + ) def append( self, diff --git a/dimos/memory/test_stream_repr.py b/dimos/memory/test_stream_repr.py new file mode 100644 index 0000000000..d91a4e6d77 --- /dev/null +++ b/dimos/memory/test_stream_repr.py @@ -0,0 +1,181 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Stream.__repr__ and Filter.__str__.""" + +from __future__ import annotations + +import pytest + +from dimos.memory.impl.sqlite import SqliteStore +from dimos.memory.stream import Stream +from dimos.memory.transformer import PerItemTransformer +from dimos.memory.types import ( + AfterFilter, + AtFilter, + BeforeFilter, + EmbeddingSearchFilter, + LineageFilter, + NearFilter, + StreamQuery, + TagsFilter, + TextSearchFilter, + TimeRangeFilter, +) + +# ── Filter __str__ ──────────────────────────────────────────────────── + + +class TestFilterStr: + def test_after(self) -> None: + assert str(AfterFilter(3.0)) == "after(t=3.0)" + + def test_before(self) -> None: + assert str(BeforeFilter(10.5)) == "before(t=10.5)" + + def test_time_range(self) -> None: + assert str(TimeRangeFilter(3.0, 10.0)) == "time_range(3.0, 10.0)" + + def test_at(self) -> None: + assert str(AtFilter(5.0, 1.0)) == "at(t=5.0, tol=1.0)" + + def test_near(self) -> None: + assert str(NearFilter(pose=None, radius=5.0)) == "near(radius=5.0)" + + def test_tags_single(self) -> None: + assert str(TagsFilter((("cam", "front"),))) == "tags(cam='front')" + + def test_tags_multiple(self) -> None: + f = TagsFilter((("cam", "front"), ("quality", 1))) + assert str(f) == "tags(cam='front', quality=1)" + + def test_embedding_search(self) -> None: + assert str(EmbeddingSearchFilter([0.1, 0.2], k=5)) == "search(k=5)" + + def test_text_search(self) -> None: + assert str(TextSearchFilter("error", k=None)) == "text('error')" + + def test_lineage(self) -> None: + f = LineageFilter("embeddings", 
StreamQuery(), hops=("filtered",)) + assert str(f) == "lineage(embeddings -> filtered)" + + def test_lineage_direct(self) -> None: + f = LineageFilter("embeddings", StreamQuery(), hops=()) + assert str(f) == "lineage(embeddings -> direct)" + + +# ── Stream __repr__ ─────────────────────────────────────────────────── + + +@pytest.fixture() +def session(): + store = SqliteStore(":memory:") + store.start() + s = store.session() + yield s + s.stop() + store.stop() + + +class TestStreamRepr: + def test_basic_stream(self, session) -> None: + s = session.stream("images", int) + assert repr(s) == 'Stream[int]("images")' + + def test_chain(self, session) -> None: + s = session.stream("images", int) + r = repr(s.after(3.0).filter_tags(cam="front").limit(10)) + assert r == "Stream[int](\"images\") | after(t=3.0) | tags(cam='front') | limit(10)" + + def test_order_and_offset(self, session) -> None: + s = session.stream("images", int) + r = repr(s.order_by("ts", desc=True).offset(5).limit(10)) + assert r == 'Stream[int]("images") | order(ts, desc) | limit(10) | offset(5)' + + def test_text_stream(self, session) -> None: + ts = session.text_stream("logs") + assert repr(ts) == 'TextStream[str]("logs")' + + def test_text_search(self, session) -> None: + ts = session.text_stream("logs") + r = repr(ts.search_text("error")) + assert r == "TextStream[str](\"logs\") | text('error')" + + def test_embedding_stream(self, session) -> None: + es = session.embedding_stream("clip", vec_dimensions=512) + assert repr(es) == 'EmbeddingStream[Embedding]("clip")' + + def test_transform_stream(self, session) -> None: + s = session.stream("images", int) + xf = PerItemTransformer(lambda x: x) + r = repr(s.transform(xf, live=True)) + assert r == 'TransformStream[?](Stream[int]("images") -> PerItemTransformer, live=True)' + + def test_transform_backfill_only(self, session) -> None: + s = session.stream("images", int) + xf = PerItemTransformer(lambda x: x) + r = repr(s.transform(xf, backfill_only=True)) 
+ assert ( + r + == 'TransformStream[?](Stream[int]("images") -> PerItemTransformer, backfill_only=True)' + ) + + def test_unbound_stream(self) -> None: + s = Stream(payload_type=int) + assert repr(s) == 'Stream[int]("unbound")' + + def test_no_payload_type(self) -> None: + s = Stream() + assert repr(s) == 'Stream[?]("unbound")' + + def test_materialized_transform(self, session) -> None: + s = session.stream("images", int) + s.append(1, ts=1.0) + xf = PerItemTransformer(lambda x: x * 2) + derived = s.transform(xf).store("doubled", int) + assert repr(derived) == 'Stream[int]("doubled")' + + def test_transform_with_typed_transformer(self, session) -> None: + from unittest.mock import MagicMock + + from dimos.memory.transformer import EmbeddingTransformer + + s = session.stream("images", int) + model = MagicMock() + xf = EmbeddingTransformer(model) + r = repr(s.transform(xf, live=True)) + assert ( + r + == 'TransformStream[Embedding](Stream[int]("images") -> EmbeddingTransformer, live=True)' + ) + + def test_embedding_stream_from_source(self, session) -> None: + session.stream("images", int) + es = session.embedding_stream("clip", vec_dimensions=512, parent_table="images") + assert ( + repr(es.after(5.0).limit(3)) + == 'EmbeddingStream[Embedding]("clip") | after(t=5.0) | limit(3)' + ) + + def test_ivan(self, session) -> None: + from unittest.mock import MagicMock + + from dimos.memory.transformer import EmbeddingTransformer + from dimos.msgs.sensor_msgs.Image import Image + + s = session.stream("images", Image).after(5.0).limit(3) + print("\n") + print(s) + model = MagicMock() + print(s.transform(EmbeddingTransformer(model)).limit(3)) diff --git a/dimos/memory/test_transformer.py b/dimos/memory/test_transformer.py index 41d58effab..1e294d8cd9 100644 --- a/dimos/memory/test_transformer.py +++ b/dimos/memory/test_transformer.py @@ -93,7 +93,14 @@ def test_text_to_embedding_backfill(self, session: SqliteSession) -> None: results = emb_stream.search_embedding("Robot 
navigated to kitchen", k=1).fetch() assert len(results) == 1 - assert isinstance(results[0].data, str) + assert isinstance(results[0].data, Embedding) + + # project_to to get source text + projected = ( + emb_stream.search_embedding("Robot navigated to kitchen", k=1).project_to(logs).fetch() + ) + assert len(projected) == 1 + assert isinstance(projected[0].data, str) def test_text_embedding_live(self, session: SqliteSession) -> None: """Live mode: new text is embedded automatically.""" @@ -110,8 +117,8 @@ def test_text_embedding_live(self, session: SqliteSession) -> None: logs.append("Another log entry", ts=2.0) assert emb_stream.count() == 2 - def test_text_embedding_search_projects_to_source(self, session: SqliteSession) -> None: - """search_embedding auto-projects back to source text stream.""" + def test_text_embedding_search_and_project(self, session: SqliteSession) -> None: + """search_embedding + project_to retrieves source text.""" logs = session.stream("te_proj_logs", str) logs.append("Robot entered kitchen", ts=1.0) logs.append("Battery warning", ts=2.0) @@ -121,6 +128,6 @@ def test_text_embedding_search_projects_to_source(self, session: SqliteSession) "te_proj_embs" ) - results = emb_stream.search_embedding("kitchen", k=2).fetch() + results = emb_stream.search_embedding("kitchen", k=2).project_to(logs).fetch() assert len(results) == 2 assert all("kitchen" in r.data.lower() for r in results) diff --git a/dimos/memory/types.py b/dimos/memory/types.py index c14123e67f..e129f70295 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -17,7 +17,7 @@ from collections.abc import Callable from dataclasses import dataclass, field import math -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias, cast if TYPE_CHECKING: from dimos.models.embedding.base import Embedding @@ -52,35 +52,22 @@ def data(self) -> Any: class EmbeddingObservation(Observation): """Returned by EmbeddingStream terminals. 
- .data auto-projects to the source stream's payload type. - .embedding gives the Embedding vector. + .data returns the Embedding stored in this stream. + .embedding is a convenience alias for .data (typed as Embedding). .similarity is populated (0..1) when fetched via search_embedding (vec0 cosine). + + To get source data (e.g. the original Image), use .project_to(source_stream). """ similarity: float | None = field(default=None, repr=True) - _embedding: Embedding | None = field(default=None, repr=False) - _embedding_loader: Callable[[], Embedding] | None = field( - default=None, repr=False, compare=False - ) - _source_data_loader: Callable[[], Any] | None = field(default=None, repr=False, compare=False) @property - def data(self) -> Any: - if self._data is not _UNSET: - return self._data - if self._source_data_loader is not None: - self._data = self._source_data_loader() - return self._data - return super().data + def data(self) -> Embedding: + return cast("Embedding", super().data) @property def embedding(self) -> Embedding: - if self._embedding is not None: - return self._embedding - if self._embedding_loader is not None: - self._embedding = self._embedding_loader() - return self._embedding - raise LookupError("No embedding available") + return self.data @dataclass @@ -101,6 +88,9 @@ class AfterFilter: def matches(self, obs: Observation) -> bool: return obs.ts is not None and obs.ts > self.t + def __str__(self) -> str: + return f"after(t={self.t})" + @dataclass(frozen=True) class BeforeFilter: @@ -109,6 +99,9 @@ class BeforeFilter: def matches(self, obs: Observation) -> bool: return obs.ts is not None and obs.ts < self.t + def __str__(self) -> str: + return f"before(t={self.t})" + @dataclass(frozen=True) class TimeRangeFilter: @@ -118,6 +111,9 @@ class TimeRangeFilter: def matches(self, obs: Observation) -> bool: return obs.ts is not None and self.t1 <= obs.ts <= self.t2 + def __str__(self) -> str: + return f"time_range({self.t1}, {self.t2})" + 
@dataclass(frozen=True) class AtFilter: @@ -127,6 +123,9 @@ class AtFilter: def matches(self, obs: Observation) -> bool: return obs.ts is not None and abs(obs.ts - self.t) <= self.tolerance + def __str__(self) -> str: + return f"at(t={self.t}, tol={self.tolerance})" + @dataclass(frozen=True) class NearFilter: @@ -141,6 +140,9 @@ def matches(self, obs: Observation) -> bool: dist = math.sqrt((p1.x - p2.x) ** 2 + (p1.y - p2.y) ** 2 + (p1.z - p2.z) ** 2) return dist <= self.radius + def __str__(self) -> str: + return f"near(radius={self.radius})" + @dataclass(frozen=True) class TagsFilter: @@ -149,6 +151,10 @@ class TagsFilter: def matches(self, obs: Observation) -> bool: return all(obs.tags.get(k) == v for k, v in self.tags) + def __str__(self) -> str: + pairs = ", ".join(f"{k}={v!r}" for k, v in self.tags) + return f"tags({pairs})" + @dataclass(frozen=True) class EmbeddingSearchFilter: @@ -158,6 +164,9 @@ class EmbeddingSearchFilter: def matches(self, obs: Observation) -> bool: return True # top-k handled as special pass in ListBackend + def __str__(self) -> str: + return f"search(k={self.k})" + @dataclass(frozen=True) class TextSearchFilter: @@ -167,6 +176,9 @@ class TextSearchFilter: def matches(self, obs: Observation) -> bool: return self.text.lower() in str(obs.data).lower() + def __str__(self) -> str: + return f"text({self.text!r})" + @dataclass(frozen=True) class LineageFilter: @@ -183,6 +195,10 @@ class LineageFilter: def matches(self, obs: Observation) -> bool: raise NotImplementedError("LineageFilter requires a database backend") + def __str__(self) -> str: + hops = " -> ".join(self.hops) if self.hops else "direct" + return f"lineage({self.source_table} -> {hops})" + Filter: TypeAlias = ( AfterFilter @@ -206,3 +222,14 @@ class StreamQuery: order_desc: bool = False limit_val: int | None = None offset_val: int | None = None + + def __str__(self) -> str: + parts: list[str] = [str(f) for f in self.filters] + if self.order_field: + direction = "desc" if 
self.order_desc else "asc" + parts.append(f"order({self.order_field}, {direction})") + if self.limit_val is not None: + parts.append(f"limit({self.limit_val})") + if self.offset_val is not None: + parts.append(f"offset({self.offset_val})") + return " | ".join(parts) From d5db010b0cba4be94a3d37f28e8952d3795f5903 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 18:01:53 +0800 Subject: [PATCH 036/118] Make Observation generic: Observation[T] with full type safety --- dimos/memory/impl/sqlite.py | 28 +++++++------- dimos/memory/impl/test_sqlite.py | 4 +- dimos/memory/stream.py | 63 ++++++++++++++++---------------- dimos/memory/transformer.py | 18 ++++----- dimos/memory/types.py | 55 ++++++++++++++++------------ 5 files changed, 89 insertions(+), 79 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 3c585b2ec0..d2a117fe87 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -342,10 +342,12 @@ def _compile_count(query: StreamQuery, table: str) -> tuple[str, list[Any]]: # ── Near-filter post-processing (exact distance after R*Tree bbox) ─── -def _apply_near_post_filter(rows: list[Observation], near: NearFilter) -> list[Observation]: +def _apply_near_post_filter( + rows: list[Observation[Any]], near: NearFilter +) -> list[Observation[Any]]: """Post-filter R*Tree candidates by exact Euclidean distance.""" tp = near.pose.position - result: list[Observation] = [] + result: list[Observation[Any]] = [] for obs in rows: if obs.pose is None: continue @@ -375,10 +377,10 @@ def __init__( self._table = table self._pose_provider = pose_provider self._codec = codec or PickleCodec() - self._subject: Subject[Observation] = Subject() # type: ignore[type-arg] + self._subject: Subject[Observation[Any]] = Subject() # type: ignore[type-arg] @property - def appended_subject(self) -> Subject[Observation]: # type: ignore[type-arg] + def appended_subject(self) -> Subject[Observation[Any]]: # type: ignore[type-arg] 
return self._subject @property @@ -395,7 +397,7 @@ def do_append( pose: Any | None, tags: dict[str, Any] | None, parent_id: int | None = None, - ) -> Observation: + ) -> Observation[Any]: if ts is None: ts = time.time() if pose is None and self._pose_provider is not None: @@ -455,7 +457,7 @@ def do_append( self._subject.on_next(obs) return obs - def execute_fetch(self, query: StreamQuery) -> list[Observation]: + def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: sql, params = _compile_query(query, self._table) rows = self._conn.execute(sql, params).fetchall() observations = [self._row_to_obs(r) for r in rows] @@ -471,7 +473,7 @@ def execute_count(self, query: StreamQuery) -> int: result = self._conn.execute(sql, params).fetchone() return result[0] if result else 0 # type: ignore[no-any-return] - def _row_to_obs(self, row: Any) -> Observation: + def _row_to_obs(self, row: Any) -> Observation[Any]: row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) conn = self._conn @@ -533,7 +535,7 @@ def _ensure_vec_table(self) -> None: ) self._conn.commit() - def execute_fetch(self, query: StreamQuery) -> list[Observation]: + def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: emb_filter = None for f in query.filters: if isinstance(f, EmbeddingSearchFilter): @@ -552,7 +554,7 @@ def execute_count(self, query: StreamQuery) -> int: def _fetch_by_vector( self, query: StreamQuery, emb_filter: EmbeddingSearchFilter - ) -> list[Observation]: + ) -> list[Observation[Any]]: """Fetch using vec0 similarity search, then apply remaining filters.""" vec_sql = ( f"SELECT rowid, distance FROM {self._table}_vec " @@ -586,7 +588,7 @@ def _fetch_by_vector( ) rows = self._conn.execute(sql, params).fetchall() - observations = [self._row_to_obs(r) for r in rows] + observations: list[Observation[Any]] = [self._row_to_obs(r) for r in rows] # Populate similarity scores from vec0 cosine distance 
(0=identical, 2=opposite) for obs in observations: @@ -603,7 +605,7 @@ def _fetch_by_vector( return observations - def _row_to_obs(self, row: Any) -> Observation: + def _row_to_obs(self, row: Any) -> EmbeddingObservation: row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) conn = self._conn @@ -648,7 +650,7 @@ def _post_insert(self, row_id: int, payload: Any) -> None: (row_id, text), ) - def execute_fetch(self, query: StreamQuery) -> list[Observation]: + def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: text_filter = None for f in query.filters: if isinstance(f, TextSearchFilter): @@ -662,7 +664,7 @@ def execute_fetch(self, query: StreamQuery) -> list[Observation]: def _fetch_by_text( self, query: StreamQuery, text_filter: TextSearchFilter - ) -> list[Observation]: + ) -> list[Observation[Any]]: fts_sql = f"SELECT rowid, rank FROM {self._table}_fts WHERE content MATCH ? ORDER BY rank" fts_params: list[Any] = [text_filter.text] if text_filter.k is not None: diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index ff98a9eb04..6d334d87d2 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -21,7 +21,7 @@ from dimos.memory.impl.sqlite import SqliteSession, SqliteStore from dimos.memory.transformer import EmbeddingTransformer -from dimos.memory.types import _UNSET, EmbeddingObservation, Observation +from dimos.memory.types import EmbeddingObservation, Observation, _Unset from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.testing import TimedSensorReplay @@ -479,7 +479,7 @@ def test_data_lazy_loaded(self, session: SqliteSession, images: list[Image]) -> rows = s.fetch() obs = rows[0] - assert obs._data is _UNSET + assert isinstance(obs._data, _Unset) assert obs._data_loader is not None loaded = obs.data assert _img_close(loaded, images[0]) diff 
--git a/dimos/memory/stream.py b/dimos/memory/stream.py index 79d2d657c3..fc744aea42 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -20,6 +20,7 @@ Generic, Protocol, TypeVar, + cast, overload, ) @@ -64,7 +65,7 @@ class StreamBackend(Protocol): """Backend protocol — implemented by SqliteStreamBackend etc.""" - def execute_fetch(self, query: StreamQuery) -> list[Observation]: ... + def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: ... def execute_count(self, query: StreamQuery) -> int: ... def do_append( self, @@ -73,9 +74,9 @@ def do_append( pose: Any | None, tags: dict[str, Any] | None, parent_id: int | None = None, - ) -> Observation: ... + ) -> Observation[Any]: ... @property - def appended_subject(self) -> Subject[Observation]: ... # type: ignore[type-arg] + def appended_subject(self) -> Subject[Observation[Any]]: ... # type: ignore[type-arg] @property def stream_name(self) -> str: ... @@ -145,11 +146,11 @@ def append( pose: PoseLike | None = None, tags: dict[str, Any] | None = None, parent_id: int | None = None, - ) -> Observation: + ) -> Observation[T]: if ts is None and isinstance(payload, Timestamped): ts = payload.ts backend = self._require_backend() - return backend.do_append(payload, ts, pose, tags, parent_id) + return cast("Observation[T]", backend.do_append(payload, ts, pose, tags, parent_id)) # ── Temporal filters ────────────────────────────────────────────── @@ -276,7 +277,7 @@ def project_to(self, target: Stream[R]) -> Stream[R]: # ── Iteration ───────────────────────────────────────────────────── - def __iter__(self) -> Iterator[Observation]: + def __iter__(self) -> Iterator[Observation[T]]: for page in self.fetch_pages(): yield from page @@ -285,9 +286,9 @@ def __iter__(self) -> Iterator[Observation]: def fetch(self) -> ObservationSet[T]: backend = self._require_backend() results = backend.execute_fetch(self._query) - return ObservationSet(results, session=self._session) + return 
ObservationSet(cast("list[Observation[T]]", results), session=self._session) - def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: + def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation[T]]]: offset = self._query.offset_val or 0 total_limit = self._query.limit_val emitted = 0 @@ -309,19 +310,19 @@ def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: page = backend.execute_fetch(q) if not page: break - yield page + yield cast("list[Observation[T]]", page) emitted += len(page) if len(page) < page_size: break offset += len(page) - def one(self) -> Observation: + def one(self) -> Observation[T]: results = self.limit(1).fetch() if not results: raise LookupError("No matching observation") return results[0] - def last(self) -> Observation: + def last(self) -> Observation[T]: results = self.order_by("ts", desc=True).limit(1).fetch() if not results: raise LookupError("No matching observation") @@ -333,9 +334,9 @@ def count(self) -> int: # ── Reactive ────────────────────────────────────────────────────── - def observable(self) -> Observable[Observation]: # type: ignore[type-arg] + def observable(self) -> Observable[Observation[T]]: # type: ignore[type-arg] backend = self._require_backend() - raw: Observable[Observation] = backend.appended_subject # type: ignore[assignment] + raw: Observable[Observation[T]] = backend.appended_subject # type: ignore[assignment] if not self._query.filters: return raw active = [ @@ -344,12 +345,12 @@ def observable(self) -> Observable[Observation]: # type: ignore[type-arg] if not isinstance(f, (EmbeddingSearchFilter, LineageFilter)) ] - def _check(o: Observation) -> bool: + def _check(o: Observation[T]) -> bool: return all(f.matches(o) for f in active) return raw.pipe(ops.filter(_check)) - def subscribe(self, on_next: Callable[[Observation], None]) -> Disposable: + def subscribe(self, on_next: Callable[[Observation[T]], None]) -> Disposable: return 
self.observable().subscribe(on_next=on_next) @@ -432,7 +433,7 @@ def search_embedding( def fetch(self) -> ObservationSet[T]: # type: ignore[override] backend = self._require_backend() results = backend.execute_fetch(self._query) - return ObservationSet(results, session=self._session) + return ObservationSet(cast("list[Observation[T]]", results), session=self._session) def one(self) -> EmbeddingObservation: # type: ignore[override] results = self.limit(1).fetch() @@ -538,7 +539,7 @@ class _CollectorStream(Stream[R]): def __init__(self) -> None: super().__init__(backend=None) - self.results: list[Observation] = [] + self.results: list[Observation[R]] = [] self._next_id = 0 def append( @@ -549,8 +550,8 @@ def append( pose: PoseLike | None = None, tags: dict[str, Any] | None = None, parent_id: int | None = None, - ) -> Observation: - obs = Observation( + ) -> Observation[R]: + obs: Observation[R] = Observation( id=self._next_id, ts=ts, tags=tags or {}, @@ -565,14 +566,14 @@ def append( class ListBackend: """In-memory backend that evaluates StreamQuery filters in Python.""" - def __init__(self, observations: list[Observation], name: str = "") -> None: + def __init__(self, observations: list[Observation[Any]], name: str = "") -> None: self._observations = observations self._name = name from reactivex.subject import Subject - self._subject: Subject[Observation] = Subject() # type: ignore[type-arg] + self._subject: Subject[Observation[Any]] = Subject() # type: ignore[type-arg] - def execute_fetch(self, query: StreamQuery) -> list[Observation]: + def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: results = list(self._observations) # Apply non-embedding filters @@ -629,11 +630,11 @@ def do_append( pose: Any | None, tags: dict[str, Any] | None, parent_id: int | None = None, - ) -> Observation: + ) -> Observation[Any]: raise TypeError("ObservationSet is read-only") @property - def appended_subject(self) -> Subject[Observation]: # type: ignore[type-arg] + 
def appended_subject(self) -> Subject[Observation[Any]]: # type: ignore[type-arg] return self._subject @property @@ -650,12 +651,12 @@ class ObservationSet(Stream[T]): def __init__( self, - observations: list[Observation], + observations: list[Observation[T]], *, session: Session | None = None, ) -> None: self._observations = observations - backend = ListBackend(observations) + backend = ListBackend(cast("list[Observation[Any]]", observations)) super().__init__(backend=backend, session=session) def _clone(self, **overrides: Any) -> Stream[T]: @@ -683,7 +684,7 @@ def append( pose: PoseLike | None = None, tags: dict[str, Any] | None = None, parent_id: int | None = None, - ) -> Observation: + ) -> Observation[T]: raise TypeError("ObservationSet is read-only") # ── List-like interface ────────────────────────────────────────── @@ -692,15 +693,15 @@ def __len__(self) -> int: return len(self._observations) @overload - def __getitem__(self, index: int) -> Observation: ... + def __getitem__(self, index: int) -> Observation[T]: ... @overload - def __getitem__(self, index: slice) -> list[Observation]: ... + def __getitem__(self, index: slice) -> list[Observation[T]]: ... - def __getitem__(self, index: int | slice) -> Observation | list[Observation]: + def __getitem__(self, index: int | slice) -> Observation[T] | list[Observation[T]]: return self._observations[index] - def __iter__(self) -> Iterator[Observation]: + def __iter__(self) -> Iterator[Observation[T]]: return iter(self._observations) def __bool__(self) -> bool: diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index 629b5fec83..dc7afc310f 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -44,7 +44,7 @@ def process(self, source: Stream[T], target: Stream[R]) -> None: Has full access to the source stream — can query, filter, batch, skip, etc. 
""" - def on_append(self, obs: Observation, target: Stream[R]) -> None: + def on_append(self, obs: Observation[Any], target: Stream[R]) -> None: """Reactive per-item processing. Called for each new item.""" @@ -59,10 +59,10 @@ def process(self, source: Stream[T], target: Stream[R]) -> None: for obs in page: self._apply(obs, target) - def on_append(self, obs: Observation, target: Stream[R]) -> None: + def on_append(self, obs: Observation[Any], target: Stream[R]) -> None: self._apply(obs, target) - def _apply(self, obs: Observation, target: Stream[R]) -> None: + def _apply(self, obs: Observation[Any], target: Stream[R]) -> None: result = self._fn(obs.data) if result is None: return @@ -89,12 +89,12 @@ def __init__(self, quality_fn: Callable[[T], float], window: float = 0.5) -> Non self._window = window # Live state self._window_start: float | None = None - self._best_obs: Observation | None = None + self._best_obs: Observation[T] | None = None self._best_score: float = -1.0 def process(self, source: Stream[T], target: Stream[T]) -> None: window_start: float | None = None - best_obs: Observation | None = None + best_obs: Observation[T] | None = None best_score: float = -1.0 for obs in source: @@ -129,7 +129,7 @@ def process(self, source: Stream[T], target: Stream[T]) -> None: parent_id=best_obs.id, ) - def on_append(self, obs: Observation, target: Stream[T]) -> None: + def on_append(self, obs: Observation[T], target: Stream[T]) -> None: # type: ignore[override] ts = obs.ts or 0.0 if self._window_start is None: @@ -177,7 +177,7 @@ def process(self, source: Stream[Any], target: Stream[str]) -> None: for obs, cap in zip(page, captions, strict=True): target.append(cap, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - def on_append(self, obs: Observation, target: Stream[str]) -> None: + def on_append(self, obs: Observation[Any], target: Stream[str]) -> None: caption = self.model.caption(obs.data) target.append(caption, ts=obs.ts, pose=obs.pose, tags=obs.tags, 
parent_id=obs.id) @@ -209,7 +209,7 @@ def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: for obs, emb in zip(page, embeddings, strict=True): target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - def on_append(self, obs: Observation, target: Stream[Embedding]) -> None: + def on_append(self, obs: Observation[Any], target: Stream[Embedding]) -> None: emb = self.model.embed_text(str(obs.data)) if isinstance(emb, list): emb = emb[0] @@ -242,7 +242,7 @@ def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: for obs, emb in zip(page, embeddings, strict=True): target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - def on_append(self, obs: Observation, target: Stream[Embedding]) -> None: + def on_append(self, obs: Observation[Any], target: Stream[Embedding]) -> None: emb = self.model.embed(obs.data) if isinstance(emb, list): emb = emb[0] diff --git a/dimos/memory/types.py b/dimos/memory/types.py index e129f70295..e1036db1eb 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -17,39 +17,50 @@ from collections.abc import Callable from dataclasses import dataclass, field import math -from typing import TYPE_CHECKING, Any, TypeAlias, cast +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar + +from dimos.models.embedding.base import Embedding if TYPE_CHECKING: - from dimos.models.embedding.base import Embedding from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped PoseProvider: TypeAlias = Callable[[], Any] # () -> PoseLike | None -_UNSET: Any = object() +T = TypeVar("T") + + +class _Unset: + """Sentinel indicating no data has been loaded yet.""" + + __slots__ = () + + +_UNSET = _Unset() @dataclass -class Observation: +class Observation(Generic[T]): id: int ts: float | None = None pose: PoseStamped | None = None tags: dict[str, Any] = field(default_factory=dict) parent_id: int | None = field(default=None, repr=False) - _data: Any = 
field(default=_UNSET, repr=False) - _data_loader: Callable[[], Any] | None = field(default=None, repr=False, compare=False) + _data: T | _Unset = field(default_factory=lambda: _UNSET, repr=False) + _data_loader: Callable[[], T] | None = field(default=None, repr=False, compare=False) @property - def data(self) -> Any: - if self._data is not _UNSET: + def data(self) -> T: + if not isinstance(self._data, _Unset): return self._data if self._data_loader is not None: - self._data = self._data_loader() - return self._data + loaded = self._data_loader() + self._data = loaded + return loaded raise LookupError("No data available; observation was not fetched with payload") @dataclass -class EmbeddingObservation(Observation): +class EmbeddingObservation(Observation[Embedding]): """Returned by EmbeddingStream terminals. .data returns the Embedding stored in this stream. @@ -61,10 +72,6 @@ class EmbeddingObservation(Observation): similarity: float | None = field(default=None, repr=True) - @property - def data(self) -> Embedding: - return cast("Embedding", super().data) - @property def embedding(self) -> Embedding: return self.data @@ -85,7 +92,7 @@ class StreamInfo: class AfterFilter: t: float - def matches(self, obs: Observation) -> bool: + def matches(self, obs: Observation[Any]) -> bool: return obs.ts is not None and obs.ts > self.t def __str__(self) -> str: @@ -96,7 +103,7 @@ def __str__(self) -> str: class BeforeFilter: t: float - def matches(self, obs: Observation) -> bool: + def matches(self, obs: Observation[Any]) -> bool: return obs.ts is not None and obs.ts < self.t def __str__(self) -> str: @@ -108,7 +115,7 @@ class TimeRangeFilter: t1: float t2: float - def matches(self, obs: Observation) -> bool: + def matches(self, obs: Observation[Any]) -> bool: return obs.ts is not None and self.t1 <= obs.ts <= self.t2 def __str__(self) -> str: @@ -120,7 +127,7 @@ class AtFilter: t: float tolerance: float - def matches(self, obs: Observation) -> bool: + def matches(self, obs: 
Observation[Any]) -> bool: return obs.ts is not None and abs(obs.ts - self.t) <= self.tolerance def __str__(self) -> str: @@ -132,7 +139,7 @@ class NearFilter: pose: Any # PoseLike radius: float - def matches(self, obs: Observation) -> bool: + def matches(self, obs: Observation[Any]) -> bool: if obs.pose is None: return False p1 = obs.pose.position @@ -148,7 +155,7 @@ def __str__(self) -> str: class TagsFilter: tags: tuple[tuple[str, Any], ...] - def matches(self, obs: Observation) -> bool: + def matches(self, obs: Observation[Any]) -> bool: return all(obs.tags.get(k) == v for k, v in self.tags) def __str__(self) -> str: @@ -161,7 +168,7 @@ class EmbeddingSearchFilter: query: list[float] k: int - def matches(self, obs: Observation) -> bool: + def matches(self, obs: Observation[Any]) -> bool: return True # top-k handled as special pass in ListBackend def __str__(self) -> str: @@ -173,7 +180,7 @@ class TextSearchFilter: text: str k: int | None - def matches(self, obs: Observation) -> bool: + def matches(self, obs: Observation[Any]) -> bool: return self.text.lower() in str(obs.data).lower() def __str__(self) -> str: @@ -192,7 +199,7 @@ class LineageFilter: source_query: StreamQuery hops: tuple[str, ...] 
# intermediate tables between source and target - def matches(self, obs: Observation) -> bool: + def matches(self, obs: Observation[Any]) -> bool: raise NotImplementedError("LineageFilter requires a database backend") def __str__(self) -> str: From 2d0bedcf724c00da76515b9496fca6fd3e767f33 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 19:20:59 +0800 Subject: [PATCH 037/118] Simplify Stream._clone with copy.copy, remove subclass overrides --- dimos/memory/stream.py | 70 +++++++++--------------------------------- 1 file changed, 15 insertions(+), 55 deletions(-) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index fc744aea42..8e599fad90 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -14,11 +14,13 @@ from __future__ import annotations +import copy from typing import ( TYPE_CHECKING, Any, Generic, Protocol, + Self, TypeVar, cast, overload, @@ -101,21 +103,17 @@ def __init__( self._session: Session | None = session self._payload_type: type | None = payload_type - def _clone(self, **overrides: Any) -> Stream[T]: - """Return a new Stream with updated query fields.""" + def _clone(self, **overrides: Any) -> Self: + """Return a shallow copy with updated query fields.""" q = self._query - new_query = StreamQuery( + clone = copy.copy(self) + clone._query = StreamQuery( filters=overrides.get("filters", q.filters), order_field=overrides.get("order_field", q.order_field), order_desc=overrides.get("order_desc", q.order_desc), limit_val=overrides.get("limit_val", q.limit_val), offset_val=overrides.get("offset_val", q.offset_val), ) - clone: Stream[T] = self.__class__.__new__(self.__class__) - clone._backend = self._backend - clone._query = new_query - clone._session = self._session - clone._payload_type = self._payload_type return clone def __repr__(self) -> str: @@ -126,7 +124,7 @@ def __repr__(self) -> str: query_str = str(self._query) return f"{head} | {query_str}" if query_str else head - def _with_filter(self, f: Filter) 
-> Stream[T]: + def _with_filter(self, f: Filter) -> Self: return self._clone(filters=(*self._query.filters, f)) def _require_backend(self) -> StreamBackend: @@ -380,12 +378,6 @@ def _require_model(self) -> EmbeddingModel: ) return self._embedding_model - def _clone(self, **overrides: Any) -> Stream[T]: - clone = super()._clone(**overrides) - if isinstance(clone, EmbeddingStream): - clone._embedding_model = self._embedding_model - return clone - def search_embedding( self, query: Embedding | list[float] | str | Any, @@ -421,14 +413,7 @@ def search_embedding( emb = emb[0] return self.search_embedding(emb, k=k) - clone = self._with_filter(EmbeddingSearchFilter(vec, k)) - return EmbeddingStream( - backend=clone._backend, - query=clone._query, - session=clone._session, - embedding_model=self._embedding_model, - payload_type=clone._payload_type, - ) + return self._with_filter(EmbeddingSearchFilter(vec, k)) def fetch(self) -> ObservationSet[T]: # type: ignore[override] backend = self._require_backend() @@ -452,14 +437,7 @@ class TextStream(Stream[T]): """Stream with an FTS5 index. Adds search_text().""" def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: - clone = self._with_filter(TextSearchFilter(text, k)) - ts: TextStream[T] = TextStream( - backend=clone._backend, - query=clone._query, - session=clone._session, - payload_type=clone._payload_type, - ) - return ts + return self._with_filter(TextSearchFilter(text, k)) class TransformStream(Stream[R]): @@ -479,15 +457,6 @@ def __init__( self._live = live self._backfill_only = backfill_only - def _clone(self, **overrides: Any) -> Stream[R]: - clone = super()._clone(**overrides) - if isinstance(clone, TransformStream): - clone._source = self._source - clone._transformer = self._transformer - clone._live = self._live - clone._backfill_only = self._backfill_only - return clone - def __repr__(self) -> str: type_name = self._transformer.output_type.__name__ if self._transformer.output_type else "?" 
xf_name = type(self._transformer).__name__ @@ -659,22 +628,13 @@ def __init__( backend = ListBackend(cast("list[Observation[Any]]", observations)) super().__init__(backend=backend, session=session) - def _clone(self, **overrides: Any) -> Stream[T]: - """Return a plain Stream backed by same ListBackend (preserves lazy filter chaining).""" - q = self._query - new_query = StreamQuery( - filters=overrides.get("filters", q.filters), - order_field=overrides.get("order_field", q.order_field), - order_desc=overrides.get("order_desc", q.order_desc), - limit_val=overrides.get("limit_val", q.limit_val), - offset_val=overrides.get("offset_val", q.offset_val), - ) - return Stream( - backend=self._backend, - query=new_query, - session=self._session, - payload_type=self._payload_type, + def _clone(self, **overrides: Any) -> Stream[T]: # type: ignore[override] + """Downgrade to plain Stream — don't carry _observations through chaining.""" + base: Stream[T] = Stream( + backend=self._backend, session=self._session, payload_type=self._payload_type ) + base._query = self._query + return base._clone(**overrides) def append( self, From ce9f5e8ddf6fa2d3c22f5b0c25369caa3c18550a Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 19:32:12 +0800 Subject: [PATCH 038/118] loader refactor --- dimos/memory/impl/sqlite.py | 37 +++++++++++++++++++------------------ dimos/memory/types.py | 5 +++++ 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index d2a117fe87..b106129c0c 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -31,6 +31,7 @@ import json import re import sqlite3 +import threading import time from typing import TYPE_CHECKING, Any @@ -69,6 +70,8 @@ ) if TYPE_CHECKING: + from collections.abc import Callable + from dimos.memory.types import PoseProvider from dimos.models.embedding.base import EmbeddingModel @@ -473,26 +476,35 @@ def execute_count(self, query: StreamQuery) -> 
int: result = self._conn.execute(sql, params).fetchone() return result[0] if result else 0 # type: ignore[no-any-return] - def _row_to_obs(self, row: Any) -> Observation[Any]: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row - pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) + def _make_loader(self, row_id: int) -> Callable[[], Any]: conn = self._conn table = self._table codec = self._codec + owner_tid = threading.get_ident() def loader() -> Any: + if threading.get_ident() != owner_tid: + raise RuntimeError( + "Observation.data accessed from a different thread than the one that " + "fetched it. Access .data on the original thread first to cache it, " + "or use obs.load() before passing across threads." + ) r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() if r is None: raise LookupError(f"No payload for id={row_id}") return codec.decode(r[0]) + return loader + + def _row_to_obs(self, row: Any) -> Observation[Any]: + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row return Observation( id=row_id, ts=ts, - pose=pose, + pose=_reconstruct_pose(px, py, pz, qx, qy, qz, qw), tags=_deserialize_tags(tags_json), parent_id=pid, - _data_loader=loader, + _data_loader=self._make_loader(row_id), ) @@ -607,24 +619,13 @@ def _fetch_by_vector( def _row_to_obs(self, row: Any) -> EmbeddingObservation: row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row - pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) - conn = self._conn - table = self._table - codec = self._codec - - def loader() -> Any: - r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() - if r is None: - raise LookupError(f"No payload for id={row_id}") - return codec.decode(r[0]) - return EmbeddingObservation( id=row_id, ts=ts, - pose=pose, + pose=_reconstruct_pose(px, py, pz, qx, qy, qz, qw), tags=_deserialize_tags(tags_json), parent_id=pid, - _data_loader=loader, + _data_loader=self._make_loader(row_id), ) diff 
--git a/dimos/memory/types.py b/dimos/memory/types.py index e1036db1eb..7f6bc10174 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -58,6 +58,11 @@ def data(self) -> T: return loaded raise LookupError("No data available; observation was not fetched with payload") + def load(self) -> Observation[T]: + """Force-load .data and return self. Safe to pass across threads after this.""" + self.data # noqa: B018 + return self + @dataclass class EmbeddingObservation(Observation[Embedding]): From 33ad5e1eec0a4cbf00585609111f818b884450ff Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 19:46:35 +0800 Subject: [PATCH 039/118] Extract backend.load_data(), add stream.load_data(obs) public API SQL now lives on the backend, closures are thin thread-guarded wrappers. --- dimos/memory/impl/sqlite.py | 17 ++++++++++------- dimos/memory/stream.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index b106129c0c..52af8fe974 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -476,10 +476,16 @@ def execute_count(self, query: StreamQuery) -> int: result = self._conn.execute(sql, params).fetchone() return result[0] if result else 0 # type: ignore[no-any-return] + def load_data(self, row_id: int) -> Any: + """Load payload by row ID from the database.""" + r = self._conn.execute( + f"SELECT data FROM {self._table}_payload WHERE id = ?", (row_id,) + ).fetchone() + if r is None: + raise LookupError(f"No payload for id={row_id}") + return self._codec.decode(r[0]) + def _make_loader(self, row_id: int) -> Callable[[], Any]: - conn = self._conn - table = self._table - codec = self._codec owner_tid = threading.get_ident() def loader() -> Any: @@ -489,10 +495,7 @@ def loader() -> Any: "fetched it. Access .data on the original thread first to cache it, " "or use obs.load() before passing across threads." 
) - r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() - if r is None: - raise LookupError(f"No payload for id={row_id}") - return codec.decode(r[0]) + return self.load_data(row_id) return loader diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 8e599fad90..ec85b91ac7 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -77,6 +77,7 @@ def do_append( tags: dict[str, Any] | None, parent_id: int | None = None, ) -> Observation[Any]: ... + def load_data(self, row_id: int) -> Any: ... @property def appended_subject(self) -> Subject[Observation[Any]]: ... # type: ignore[type-arg] @property @@ -134,6 +135,13 @@ def _require_backend(self) -> StreamBackend: ) return self._backend + # ── Data loading ────────────────────────────────────────────────── + + def load_data(self, obs: Observation[T]) -> T: + """Load payload for an observation. Thread-safe alternative to obs.data.""" + backend = self._require_backend() + return cast("T", backend.load_data(obs.id)) + # ── Write ───────────────────────────────────────────────────────── def append( @@ -592,6 +600,12 @@ def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: def execute_count(self, query: StreamQuery) -> int: return len(self.execute_fetch(query)) + def load_data(self, row_id: int) -> Any: + for obs in self._observations: + if obs.id == row_id: + return obs.data + raise LookupError(f"No observation with id={row_id}") + def do_append( self, payload: Any, From 1b679596003b06e3347244dddda4eb5980f56b8f Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 20:01:38 +0800 Subject: [PATCH 040/118] Add rich colored __str__ to Stream and Filter types print() now shows colored output (class=cyan, type=yellow, name=green, filters=cyan, pipes=dim). __repr__ stays plain for logs. 
--- dimos/memory/stream.py | 59 +++++++++++++++++++++++++++++ dimos/memory/types.py | 84 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index ec85b91ac7..ca2ddfa8ba 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -28,6 +28,8 @@ import numpy as np import reactivex.operators as ops +from rich.console import Console +from rich.text import Text from dimos.types.timestamped import Timestamped @@ -63,6 +65,14 @@ T = TypeVar("T") R = TypeVar("R") +_console = Console(force_terminal=True, highlight=False) + + +def _render_text(text: Text) -> str: + with _console.capture() as cap: + _console.print(text, end="") + return cap.get() + class StreamBackend(Protocol): """Backend protocol — implemented by SqliteStreamBackend etc.""" @@ -125,6 +135,27 @@ def __repr__(self) -> str: query_str = str(self._query) return f"{head} | {query_str}" if query_str else head + def _rich_text(self) -> Text: + t = Text() + cls = type(self).__name__ + type_name = self._payload_type.__name__ if self._payload_type else "?" + name = self._backend.stream_name if self._backend else "unbound" + t.append(cls, style="bold cyan") + t.append("[", style="dim") + t.append(type_name, style="yellow") + t.append("]", style="dim") + t.append("(", style="dim") + t.append(f'"{name}"', style="green") + t.append(")", style="dim") + query_text = self._query._rich_text() + if query_text.plain: + t.append(" | ", style="dim") + t.append_text(query_text) + return t + + def __str__(self) -> str: + return _render_text(self._rich_text()) + def _with_filter(self, f: Filter) -> Self: return self._clone(filters=(*self._query.filters, f)) @@ -478,6 +509,34 @@ def __repr__(self) -> str: query_str = str(self._query) return f"{head} | {query_str}" if query_str else head + def _rich_text(self) -> Text: + t = Text() + type_name = self._transformer.output_type.__name__ if self._transformer.output_type else "?" 
+ xf_name = type(self._transformer).__name__ + t.append("TransformStream", style="bold cyan") + t.append("[", style="dim") + t.append(type_name, style="yellow") + t.append("]", style="dim") + t.append("(", style="dim") + t.append_text(self._source._rich_text()) + t.append(" -> ", style="dim") + t.append(xf_name, style="magenta") + if self._live: + t.append(", ", style="dim") + t.append("live=True", style="yellow") + if self._backfill_only: + t.append(", ", style="dim") + t.append("backfill_only=True", style="yellow") + t.append(")", style="dim") + query_text = self._query._rich_text() + if query_text.plain: + t.append(" | ", style="dim") + t.append_text(query_text) + return t + + def __str__(self) -> str: + return _render_text(self._rich_text()) + def fetch(self) -> ObservationSet[R]: """Execute transform in memory, collecting results.""" collector = _CollectorStream[R]() diff --git a/dimos/memory/types.py b/dimos/memory/types.py index 7f6bc10174..65170104fd 100644 --- a/dimos/memory/types.py +++ b/dimos/memory/types.py @@ -19,6 +19,8 @@ import math from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar +from rich.text import Text + from dimos.models.embedding.base import Embedding if TYPE_CHECKING: @@ -103,6 +105,12 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"after(t={self.t})" + def _rich_text(self) -> Text: + t = Text() + t.append("after", style="cyan") + t.append(f"(t={self.t})") + return t + @dataclass(frozen=True) class BeforeFilter: @@ -114,6 +122,12 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"before(t={self.t})" + def _rich_text(self) -> Text: + t = Text() + t.append("before", style="cyan") + t.append(f"(t={self.t})") + return t + @dataclass(frozen=True) class TimeRangeFilter: @@ -126,6 +140,12 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"time_range({self.t1}, {self.t2})" + def _rich_text(self) -> Text: + t = Text() 
+ t.append("time_range", style="cyan") + t.append(f"({self.t1}, {self.t2})") + return t + @dataclass(frozen=True) class AtFilter: @@ -138,6 +158,12 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"at(t={self.t}, tol={self.tolerance})" + def _rich_text(self) -> Text: + t = Text() + t.append("at", style="cyan") + t.append(f"(t={self.t}, tol={self.tolerance})") + return t + @dataclass(frozen=True) class NearFilter: @@ -155,6 +181,12 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"near(radius={self.radius})" + def _rich_text(self) -> Text: + t = Text() + t.append("near", style="cyan") + t.append(f"(radius={self.radius})") + return t + @dataclass(frozen=True) class TagsFilter: @@ -167,6 +199,13 @@ def __str__(self) -> str: pairs = ", ".join(f"{k}={v!r}" for k, v in self.tags) return f"tags({pairs})" + def _rich_text(self) -> Text: + t = Text() + t.append("tags", style="cyan") + pairs = ", ".join(f"{k}={v!r}" for k, v in self.tags) + t.append(f"({pairs})") + return t + @dataclass(frozen=True) class EmbeddingSearchFilter: @@ -179,6 +218,12 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"search(k={self.k})" + def _rich_text(self) -> Text: + t = Text() + t.append("search", style="cyan") + t.append(f"(k={self.k})") + return t + @dataclass(frozen=True) class TextSearchFilter: @@ -191,6 +236,12 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"text({self.text!r})" + def _rich_text(self) -> Text: + t = Text() + t.append("text", style="cyan") + t.append(f"({self.text!r})") + return t + @dataclass(frozen=True) class LineageFilter: @@ -211,6 +262,13 @@ def __str__(self) -> str: hops = " -> ".join(self.hops) if self.hops else "direct" return f"lineage({self.source_table} -> {hops})" + def _rich_text(self) -> Text: + t = Text() + t.append("lineage", style="cyan") + hops = " -> ".join(self.hops) if self.hops else "direct" + 
t.append(f"({self.source_table} -> {hops})") + return t + Filter: TypeAlias = ( AfterFilter @@ -245,3 +303,29 @@ def __str__(self) -> str: if self.offset_val is not None: parts.append(f"offset({self.offset_val})") return " | ".join(parts) + + def _rich_text(self) -> Text: + t = Text() + pipe = Text(" | ", style="dim") + parts: list[Text] = [f._rich_text() for f in self.filters] + if self.order_field: + p = Text() + p.append("order", style="cyan") + direction = "desc" if self.order_desc else "asc" + p.append(f"({self.order_field}, {direction})") + parts.append(p) + if self.limit_val is not None: + p = Text() + p.append("limit", style="cyan") + p.append(f"({self.limit_val})") + parts.append(p) + if self.offset_val is not None: + p = Text() + p.append("offset", style="cyan") + p.append(f"({self.offset_val})") + parts.append(p) + for i, part in enumerate(parts): + if i > 0: + t.append_text(pipe) + t.append_text(part) + return t From 2f66bc0c5ae0d69b0f93ef84fdb2b593034d5dd9 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 20:14:56 +0800 Subject: [PATCH 041/118] Unify __repr__ and __str__ via _rich_text().plain, remove duplicate rendering --- dimos/memory/stream.py | 29 ++++------------- dimos/memory/test_stream_repr.py | 55 ++++++++++++++++++++------------ 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index ca2ddfa8ba..e7d49614f8 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -70,7 +70,7 @@ def _render_text(text: Text) -> str: with _console.capture() as cap: - _console.print(text, end="") + _console.print(text, end="", soft_wrap=True) return cap.get() @@ -127,14 +127,6 @@ def _clone(self, **overrides: Any) -> Self: ) return clone - def __repr__(self) -> str: - cls = type(self).__name__ - type_name = self._payload_type.__name__ if self._payload_type else "?" 
- name = self._backend.stream_name if self._backend else "unbound" - head = f'{cls}[{type_name}]("{name}")' - query_str = str(self._query) - return f"{head} | {query_str}" if query_str else head - def _rich_text(self) -> Text: t = Text() cls = type(self).__name__ @@ -153,6 +145,9 @@ def _rich_text(self) -> Text: t.append_text(query_text) return t + def __repr__(self) -> str: + return self._rich_text().plain + def __str__(self) -> str: return _render_text(self._rich_text()) @@ -496,19 +491,6 @@ def __init__( self._live = live self._backfill_only = backfill_only - def __repr__(self) -> str: - type_name = self._transformer.output_type.__name__ if self._transformer.output_type else "?" - xf_name = type(self._transformer).__name__ - flags: list[str] = [] - if self._live: - flags.append("live=True") - if self._backfill_only: - flags.append("backfill_only=True") - flag_str = ", " + ", ".join(flags) if flags else "" - head = f"TransformStream[{type_name}]({self._source!r} -> {xf_name}{flag_str})" - query_str = str(self._query) - return f"{head} | {query_str}" if query_str else head - def _rich_text(self) -> Text: t = Text() type_name = self._transformer.output_type.__name__ if self._transformer.output_type else "?" @@ -534,6 +516,9 @@ def _rich_text(self) -> Text: t.append_text(query_text) return t + def __repr__(self) -> str: + return self._rich_text().plain + def __str__(self) -> str: return _render_text(self._rich_text()) diff --git a/dimos/memory/test_stream_repr.py b/dimos/memory/test_stream_repr.py index d91a4e6d77..f65a13276c 100644 --- a/dimos/memory/test_stream_repr.py +++ b/dimos/memory/test_stream_repr.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for Stream.__repr__ and Filter.__str__.""" +"""Tests for Stream repr/str and Filter.__str__.""" from __future__ import annotations @@ -75,7 +75,7 @@ def test_lineage_direct(self) -> None: assert str(f) == "lineage(embeddings -> direct)" -# ── Stream __repr__ ─────────────────────────────────────────────────── +# ── Stream __str__ ──────────────────────────────────────────────────── @pytest.fixture() @@ -91,52 +91,61 @@ def session(): class TestStreamRepr: def test_basic_stream(self, session) -> None: s = session.stream("images", int) + print(s) assert repr(s) == 'Stream[int]("images")' def test_chain(self, session) -> None: - s = session.stream("images", int) - r = repr(s.after(3.0).filter_tags(cam="front").limit(10)) - assert r == "Stream[int](\"images\") | after(t=3.0) | tags(cam='front') | limit(10)" + s = session.stream("images", int).after(3.0).filter_tags(cam="front").limit(10) + print(s) + assert repr(s) == "Stream[int](\"images\") | after(t=3.0) | tags(cam='front') | limit(10)" def test_order_and_offset(self, session) -> None: - s = session.stream("images", int) - r = repr(s.order_by("ts", desc=True).offset(5).limit(10)) - assert r == 'Stream[int]("images") | order(ts, desc) | limit(10) | offset(5)' + s = session.stream("images", int).order_by("ts", desc=True).offset(5).limit(10) + print(s) + assert repr(s) == 'Stream[int]("images") | order(ts, desc) | limit(10) | offset(5)' def test_text_stream(self, session) -> None: ts = session.text_stream("logs") + print(ts) assert repr(ts) == 'TextStream[str]("logs")' def test_text_search(self, session) -> None: - ts = session.text_stream("logs") - r = repr(ts.search_text("error")) - assert r == "TextStream[str](\"logs\") | text('error')" + ts = session.text_stream("logs").search_text("error") + print(ts) + assert repr(ts) == "TextStream[str](\"logs\") | text('error')" def test_embedding_stream(self, session) -> None: es = session.embedding_stream("clip", vec_dimensions=512) + print(es) assert repr(es) == 
'EmbeddingStream[Embedding]("clip")' def test_transform_stream(self, session) -> None: s = session.stream("images", int) xf = PerItemTransformer(lambda x: x) - r = repr(s.transform(xf, live=True)) - assert r == 'TransformStream[?](Stream[int]("images") -> PerItemTransformer, live=True)' + t = s.transform(xf, live=True) + print(t) + assert ( + repr(t) == 'TransformStream[?](Stream[int]("images") -> PerItemTransformer, live=True)' + ) def test_transform_backfill_only(self, session) -> None: s = session.stream("images", int) xf = PerItemTransformer(lambda x: x) - r = repr(s.transform(xf, backfill_only=True)) + t = s.transform(xf, backfill_only=True) + print(t) assert ( - r + repr(t) == 'TransformStream[?](Stream[int]("images") -> PerItemTransformer, backfill_only=True)' ) def test_unbound_stream(self) -> None: s = Stream(payload_type=int) + print(s) assert repr(s) == 'Stream[int]("unbound")' def test_no_payload_type(self) -> None: s = Stream() + print(s) assert repr(s) == 'Stream[?]("unbound")' def test_materialized_transform(self, session) -> None: @@ -144,6 +153,7 @@ def test_materialized_transform(self, session) -> None: s.append(1, ts=1.0) xf = PerItemTransformer(lambda x: x * 2) derived = s.transform(xf).store("doubled", int) + print(derived) assert repr(derived) == 'Stream[int]("doubled")' def test_transform_with_typed_transformer(self, session) -> None: @@ -154,19 +164,22 @@ def test_transform_with_typed_transformer(self, session) -> None: s = session.stream("images", int) model = MagicMock() xf = EmbeddingTransformer(model) - r = repr(s.transform(xf, live=True)) + t = s.transform(xf, live=True) + print(t) assert ( - r + repr(t) == 'TransformStream[Embedding](Stream[int]("images") -> EmbeddingTransformer, live=True)' ) def test_embedding_stream_from_source(self, session) -> None: session.stream("images", int) - es = session.embedding_stream("clip", vec_dimensions=512, parent_table="images") - assert ( - repr(es.after(5.0).limit(3)) - == 
'EmbeddingStream[Embedding]("clip") | after(t=5.0) | limit(3)' + es = ( + session.embedding_stream("clip", vec_dimensions=512, parent_table="images") + .after(5.0) + .limit(3) ) + print(es) + assert repr(es) == 'EmbeddingStream[Embedding]("clip") | after(t=5.0) | limit(3)' def test_ivan(self, session) -> None: from unittest.mock import MagicMock From e9078a9310019bf025137d7a7122b3fd9b167dc3 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 20:22:33 +0800 Subject: [PATCH 042/118] renamed types to type --- dimos/memory/__init__.py | 2 +- dimos/memory/impl/sqlite.py | 4 ++-- dimos/memory/impl/test_e2e_export.py | 4 ++-- dimos/memory/impl/test_sqlite.py | 18 +++++++++--------- dimos/memory/store.py | 2 +- dimos/memory/stream.py | 2 +- dimos/memory/test_stream_repr.py | 2 +- dimos/memory/transformer.py | 2 +- dimos/memory/{types.py => type.py} | 0 docs/agents/docs/codeblocks.md | 14 +++++++------- 10 files changed, 25 insertions(+), 25 deletions(-) rename dimos/memory/{types.py => type.py} (100%) diff --git a/dimos/memory/__init__.py b/dimos/memory/__init__.py index 0104d65e5d..fe65bded3c 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory/__init__.py @@ -8,7 +8,7 @@ TextEmbeddingTransformer, Transformer, ) -from dimos.memory.types import ( +from dimos.memory.type import ( EmbeddingObservation, Observation, StreamInfo, diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 52af8fe974..6e91ea44ba 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -52,7 +52,7 @@ TextEmbeddingTransformer, Transformer, ) -from dimos.memory.types import ( +from dimos.memory.type import ( AfterFilter, AtFilter, BeforeFilter, @@ -72,7 +72,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from dimos.memory.types import PoseProvider + from dimos.memory.type import PoseProvider from dimos.models.embedding.base import EmbeddingModel _IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") diff --git 
a/dimos/memory/impl/test_e2e_export.py b/dimos/memory/impl/test_e2e_export.py index e5d76692f0..3804d3d127 100644 --- a/dimos/memory/impl/test_e2e_export.py +++ b/dimos/memory/impl/test_e2e_export.py @@ -112,7 +112,7 @@ class TestEmbeddingSearch: @pytest.mark.parametrize("query", QUERIES) def test_search_returns_results(self, embeddings: EmbeddingStream[Any], query: str) -> None: - from dimos.memory.types import EmbeddingObservation + from dimos.memory.type import EmbeddingObservation results = embeddings.search_embedding(query, k=5).fetch() assert len(results) > 0 @@ -133,7 +133,7 @@ def test_search_exports_images( print(f" [{rank + 1}] id={img.id} ts={img.ts:.2f}") def test_search_has_similarity(self, embeddings: EmbeddingStream[Any]) -> None: - from dimos.memory.types import EmbeddingObservation + from dimos.memory.type import EmbeddingObservation results = embeddings.search_embedding("a hallway", k=10).fetch() assert len(results) > 0 diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 6d334d87d2..009337476b 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -21,7 +21,7 @@ from dimos.memory.impl.sqlite import SqliteSession, SqliteStore from dimos.memory.transformer import EmbeddingTransformer -from dimos.memory.types import EmbeddingObservation, Observation, _Unset +from dimos.memory.type import EmbeddingObservation, Observation, _Unset from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.testing import TimedSensorReplay @@ -919,7 +919,7 @@ def test_limit_offset_in_memory(self, session: SqliteSession, images: list[Image class TestMatchesFilters: def test_after_filter(self) -> None: - from dimos.memory.types import AfterFilter + from dimos.memory.type import AfterFilter f = AfterFilter(5.0) assert f.matches(Observation(id=1, ts=6.0)) is True @@ -928,7 +928,7 @@ def test_after_filter(self) -> None: assert 
f.matches(Observation(id=4, ts=None)) is False def test_before_filter(self) -> None: - from dimos.memory.types import BeforeFilter + from dimos.memory.type import BeforeFilter f = BeforeFilter(5.0) assert f.matches(Observation(id=1, ts=4.0)) is True @@ -936,7 +936,7 @@ def test_before_filter(self) -> None: assert f.matches(Observation(id=3, ts=6.0)) is False def test_time_range_filter(self) -> None: - from dimos.memory.types import TimeRangeFilter + from dimos.memory.type import TimeRangeFilter f = TimeRangeFilter(2.0, 8.0) assert f.matches(Observation(id=1, ts=5.0)) is True @@ -946,7 +946,7 @@ def test_time_range_filter(self) -> None: assert f.matches(Observation(id=5, ts=9.0)) is False def test_at_filter(self) -> None: - from dimos.memory.types import AtFilter + from dimos.memory.type import AtFilter f = AtFilter(5.0, tolerance=1.0) assert f.matches(Observation(id=1, ts=5.0)) is True @@ -955,7 +955,7 @@ def test_at_filter(self) -> None: assert f.matches(Observation(id=4, ts=6.5)) is False def test_tags_filter(self) -> None: - from dimos.memory.types import TagsFilter + from dimos.memory.type import TagsFilter f = TagsFilter((("cam", "front"),)) assert f.matches(Observation(id=1, tags={"cam": "front", "quality": "high"})) is True @@ -963,20 +963,20 @@ def test_tags_filter(self) -> None: assert f.matches(Observation(id=3, tags={})) is False def test_text_search_filter(self) -> None: - from dimos.memory.types import TextSearchFilter + from dimos.memory.type import TextSearchFilter f = TextSearchFilter("motor", k=None) assert f.matches(Observation(id=1, _data="Motor fault on joint 3")) is True assert f.matches(Observation(id=2, _data="Battery low")) is False def test_embedding_search_filter_always_true(self) -> None: - from dimos.memory.types import EmbeddingSearchFilter + from dimos.memory.type import EmbeddingSearchFilter f = EmbeddingSearchFilter([1.0, 0.0], k=5) assert f.matches(Observation(id=1)) is True def test_lineage_filter_raises(self) -> None: - from 
dimos.memory.types import LineageFilter, StreamQuery + from dimos.memory.type import LineageFilter, StreamQuery f = LineageFilter("src", StreamQuery(), ()) with pytest.raises(NotImplementedError): diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 35b3ac7237..e37f22b057 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -24,7 +24,7 @@ from .stream import EmbeddingStream, Stream, TextStream from .transformer import Transformer - from .types import PoseProvider, StreamInfo + from .type import PoseProvider, StreamInfo T = TypeVar("T") diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index e7d49614f8..7d3324a5c2 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -33,7 +33,7 @@ from dimos.types.timestamped import Timestamped -from .types import ( +from .type import ( AfterFilter, AtFilter, BeforeFilter, diff --git a/dimos/memory/test_stream_repr.py b/dimos/memory/test_stream_repr.py index f65a13276c..91a4009bd4 100644 --- a/dimos/memory/test_stream_repr.py +++ b/dimos/memory/test_stream_repr.py @@ -21,7 +21,7 @@ from dimos.memory.impl.sqlite import SqliteStore from dimos.memory.stream import Stream from dimos.memory.transformer import PerItemTransformer -from dimos.memory.types import ( +from dimos.memory.type import ( AfterFilter, AtFilter, BeforeFilter, diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index dc7afc310f..50b4d5954a 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -24,7 +24,7 @@ from dimos.models.vl.base import Captioner from .stream import Stream - from .types import Observation + from .type import Observation T = TypeVar("T") R = TypeVar("R") diff --git a/dimos/memory/types.py b/dimos/memory/type.py similarity index 100% rename from dimos/memory/types.py rename to dimos/memory/type.py diff --git a/docs/agents/docs/codeblocks.md b/docs/agents/docs/codeblocks.md index 323f1c0c50..d56ee97015 100644 --- a/docs/agents/docs/codeblocks.md +++ 
b/docs/agents/docs/codeblocks.md @@ -22,13 +22,13 @@ Python, Shell (sh), Node.js, plus visualization: Matplotlib, Graphviz, Pikchr, A Add flags after the language identifier: -| Flag | Effect | -|------|--------| -| `session=NAME` | Share state between blocks with same session name | -| `output=path.png` | Write output to file instead of inline | -| `no-result` | Execute but don't insert result | -| `skip` | Don't execute this block | -| `expected-error` | Block is expected to fail | +| Flag | Effect | +|-------------------|---------------------------------------------------| +| `session=NAME` | Share state between blocks with same session name | +| `output=path.png` | Write output to file instead of inline | +| `no-result` | Execute but don't insert result | +| `skip` | Don't execute this block | +| `expected-error` | Block is expected to fail | ## Examples From 2b799191cf1d04e4635c83f0b59441c146efade7 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sat, 7 Mar 2026 20:52:46 +0800 Subject: [PATCH 043/118] one -> first, time range --- dimos/memory/impl/test_sqlite.py | 6 ++--- dimos/memory/stream.py | 44 +++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 009337476b..de76df8d3d 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -109,13 +109,13 @@ def test_one(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("images", Image) s.append(images[0]) - obs = s.one() + obs = s.first() assert _img_close(obs.data, images[0]) def test_one_empty_raises(self, session: SqliteSession) -> None: s = session.stream("images", Image) with pytest.raises(LookupError): - s.one() + s.first() class TestFilters: @@ -465,7 +465,7 @@ def test_transform_store_backfill_only( stored = s.transform(lambda im: im.height, backfill_only=True).store("heights_bo", int) assert stored.count() == 1 - assert 
stored.one().data == images[0].height + assert stored.first().data == images[0].height s.append(images[1], ts=2.0) assert stored.count() == 1 # still 1 diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 7d3324a5c2..5c0ac8ebae 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -318,7 +318,11 @@ def __iter__(self) -> Iterator[Observation[T]]: def fetch(self) -> ObservationSet[T]: backend = self._require_backend() results = backend.execute_fetch(self._query) - return ObservationSet(cast("list[Observation[T]]", results), session=self._session) + return ObservationSet( + cast("list[Observation[T]]", results), + session=self._session, + payload_type=self._payload_type, + ) def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation[T]]]: offset = self._query.offset_val or 0 @@ -348,7 +352,7 @@ def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation[T]]]: break offset += len(page) - def one(self) -> Observation[T]: + def first(self) -> Observation[T]: results = self.limit(1).fetch() if not results: raise LookupError("No matching observation") @@ -364,6 +368,25 @@ def count(self) -> int: backend = self._require_backend() return backend.execute_count(self._query) + def exists(self) -> bool: + return self.count() > 0 + + def get_time_range(self) -> tuple[float, float]: + return (self.first().ts, self.last().ts) + + def summary(self) -> str: + from datetime import datetime, timezone + + n = self.count() + if n == 0: + return f"{self!r}: empty" + t0, t1 = self.get_time_range() + fmt = "%Y-%m-%d %H:%M:%S" + dt0 = datetime.fromtimestamp(t0, tz=timezone.utc).strftime(fmt) + dt1 = datetime.fromtimestamp(t1, tz=timezone.utc).strftime(fmt) + dur = t1 - t0 + return f"{self!r}: {n} items, {dt0} — {dt1} ({dur:.1f}s)" + # ── Reactive ────────────────────────────────────────────────────── def observable(self) -> Observable[Observation[T]]: # type: ignore[type-arg] @@ -454,7 +477,7 @@ def fetch(self) -> ObservationSet[T]: # 
type: ignore[override] results = backend.execute_fetch(self._query) return ObservationSet(cast("list[Observation[T]]", results), session=self._session) - def one(self) -> EmbeddingObservation: # type: ignore[override] + def first(self) -> EmbeddingObservation: # type: ignore[override] results = self.limit(1).fetch() if not results: raise LookupError("No matching observation") @@ -681,10 +704,23 @@ def __init__( observations: list[Observation[T]], *, session: Session | None = None, + payload_type: type | None = None, ) -> None: self._observations = observations backend = ListBackend(cast("list[Observation[Any]]", observations)) - super().__init__(backend=backend, session=session) + super().__init__(backend=backend, session=session, payload_type=payload_type) + + def _rich_text(self) -> Text: + t = Text() + type_name = self._payload_type.__name__ if self._payload_type else "?" + t.append("ObservationSet", style="bold cyan") + t.append("[", style="dim") + t.append(type_name, style="yellow") + t.append("]", style="dim") + t.append("(", style="dim") + t.append(f"{len(self._observations)} items", style="green") + t.append(")", style="dim") + return t def _clone(self, **overrides: Any) -> Stream[T]: # type: ignore[override] """Downgrade to plain Stream — don't carry _observations through chaining.""" From dfd06c498e96c1a65875a2cdbc87e71c1cc9c59c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 12:23:44 +0800 Subject: [PATCH 044/118] getitem for streams --- dimos/memory/stream.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 5c0ac8ebae..b063413cd5 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -307,12 +307,44 @@ def project_to(self, target: Stream[R]) -> Stream[R]: ) ) - # ── Iteration ───────────────────────────────────────────────────── + # ── List-like interface ──────────────────────────────────────────── def __iter__(self) 
-> Iterator[Observation[T]]: for page in self.fetch_pages(): yield from page + def __len__(self) -> int: + return self.count() + + @overload + def __getitem__(self, index: int) -> Observation[T]: ... + + @overload + def __getitem__(self, index: slice) -> list[Observation[T]]: ... + + def __getitem__(self, index: int | slice) -> Observation[T] | list[Observation[T]]: + if isinstance(index, int): + if index < 0: + # Negative index: need count to resolve + n = self.count() + index = n + index + if index < 0: + raise IndexError("stream index out of range") + results = self.offset(index).limit(1).fetch() + if not results: + raise IndexError("stream index out of range") + return results[0] + # Slice + start, stop, step = index.indices(self.count()) + s = self.offset(start).limit(stop - start) + results = s.fetch() + if step != 1: + return list(results)[::step] + return list(results) + + def __bool__(self) -> bool: + return self.exists() + # ── Terminals ───────────────────────────────────────────────────── def fetch(self) -> ObservationSet[T]: From 4734acf35fe4174d390c76cd420ca53eb6e33907 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 14:11:05 +0800 Subject: [PATCH 045/118] readme sketch --- dimos/memory/readme.md | 455 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 455 insertions(+) create mode 100644 dimos/memory/readme.md diff --git a/dimos/memory/readme.md b/dimos/memory/readme.md new file mode 100644 index 0000000000..3b5e6aca41 --- /dev/null +++ b/dimos/memory/readme.md @@ -0,0 +1,455 @@ +# Memory + +Lazy, chainable query system for persistent robot data. Stores timestamped observations in SQLite with vector similarity (sqlite-vec), full-text (FTS5), spatial (R*Tree), and temporal indexes. 
+ +```sh no-result +rm -f /tmp/memory_readme.db +``` + +## Quick start + +```python session=memory ansi=false no-result +from dimos.memory.impl.sqlite import SqliteStore + +store = SqliteStore("/tmp/memory_readme.db") +session = store.session() +``` + +Open a store, get a session, create a stream: + +```python session=memory ansi=false +logs = session.stream("logs", str) +print(logs) +``` + + +``` +Stream[str]("logs") +``` + +Append observations and query them: + +```python session=memory ansi=false +logs.append("Motor started", ts=1.0, tags={"level": "info"}) +logs.append("Joint 3 fault", ts=2.0, tags={"level": "error"}) +logs.append("Motor stopped", ts=3.0, tags={"level": "info"}) + +print(logs.summary()) +``` + + +``` +Stream[str]("logs"): 3 items, 1970-01-01 00:00:01 — 1970-01-01 00:00:03 (2.0s) +``` + +## Observations + +Each observation wraps a payload with metadata: + +```python session=memory ansi=false +obs = logs.first() +print(obs) +print(f"id={obs.id}, ts={obs.ts}, tags={obs.tags}") +``` + + +``` +Observation(id=1, ts=1.0, pose=None, tags={'level': 'info'}) +id=1, ts=1.0, tags={'level': 'info'} +``` + +- `id` — auto-assigned integer +- `ts` — timestamp (float, seconds) +- `pose` — optional 3D position + orientation +- `tags` — key-value metadata dict +- `parent_id` — lineage tracking (set by transforms) +- `data` — the payload (lazily loaded from DB) + +### Lazy payload loading + +Metadata is always in memory. The `.data` property triggers a single-row `SELECT` on first access, then caches: + +```python skip +obs = logs.first() +obs.ts # already in memory +obs.tags # already in memory +obs.data # NOW loads the blob and decodes it +obs.data # cached — second access is free +``` + +Payloads are decoded by the thread that fetched them. 
To pass observations across threads, call `.load()` first: + +```python skip +obs = logs.first().load() # force-loads .data, safe to pass to another thread +``` + +## Streams are lazy queries + +Filter methods return new `Stream` instances — nothing executes until a terminal is called: + +```python session=memory ansi=false +query = logs.after(1.0).filter_tags(level="info").order_by("ts", desc=True).limit(5) +print(query) +# nothing has hit the database yet +``` + + +``` +Stream[str]("logs") | after(t=1.0) | tags(level='info') | order(ts, desc) | limit(5) +``` + +Each call clones the stream with updated query parameters. The underlying `StreamQuery` compiles to SQL only at terminal time. + +### Filters + +| Method | Description | +|--------|-------------| +| `.after(t)` | `ts > t` | +| `.before(t)` | `ts < t` | +| `.time_range(t1, t2)` | `t1 <= ts <= t2` | +| `.at(t, tolerance=1.0)` | `\|ts - t\| <= tolerance` | +| `.near(pose, radius)` | R*Tree bounding box + exact distance post-filter | +| `.filter_tags(**kv)` | JSON tag field matching | + +### Ordering and pagination + +| Method | Description | +|--------|-------------| +| `.order_by(field, desc=False)` | Sort by `"ts"` or `"id"` | +| `.limit(k)` | Cap results | +| `.offset(n)` | Skip first n results | + +## Terminals execute the query + +| Terminal | Returns | Description | +|----------|---------|-------------| +| `.fetch()` | `ObservationSet` | All matching rows (lazy payloads) | +| `.fetch_pages(batch_size=128)` | `Iterator[list[Observation]]` | Paginated iteration | +| `.count()` | `int` | `SELECT COUNT(*)`, no payload loading | +| `.first()` | `Observation` | First by current ordering; raises `LookupError` if empty | +| `.last()` | `Observation` | Most recent by `ts` | +| `.exists()` | `bool` | `count() > 0` | +| `.summary()` | `str` | Count, time range, duration | +| `.get_time_range()` | `(float, float)` | `(first.ts, last.ts)` | + +List-like access also works: + +```python session=memory ansi=false 
+print(f"len={len(logs)}, bool={bool(logs)}") +print(f"logs[0] = {logs[0]}") +print(f"logs[-1] = {logs[-1]}") +``` + + +``` +len=3, bool=True +logs[0] = Observation(id=1, ts=1.0, pose=None, tags={'level': 'info'}) +logs[-1] = Observation(id=3, ts=3.0, pose=None, tags={'level': 'info'}) +``` + +Iteration uses paginated fetching under the hood: + +```python session=memory ansi=false +for obs in logs.after(1.5): + print(obs) +``` + + +``` +Observation(id=2, ts=2.0, pose=None, tags={'level': 'error'}) +Observation(id=3, ts=3.0, pose=None, tags={'level': 'info'}) +``` + +## ObservationSet + +`.fetch()` returns an `ObservationSet` — an in-memory result set that is itself a `Stream`. All filters and terminals work on it, re-evaluating in memory without hitting the database: + +```python session=memory ansi=false +results = logs.fetch() +print(results) +print(f"len={len(results)}") + +# re-filter in memory — no DB hit +errors = results.filter_tags(level="error").fetch() +print(errors) +print(errors[0]) +``` + + +``` +ObservationSet[str](3 items) +len=3 +ObservationSet[str](1 items) +Observation(id=2, ts=2.0, pose=None, tags={'level': 'error'}) +``` + +ObservationSet is read-only — `.append()` raises `TypeError`. + +When you chain filters on an ObservationSet, it downgrades to a plain Stream backed by the in-memory list, so it doesn't carry the full result set through the chain. + +## Transforms + +`.transform()` applies a function to each observation's payload. 
Without `.store()`, it runs entirely in memory: + +```python session=memory ansi=false +upper = logs.transform(lambda s: s.upper()) +print(upper) +print(upper.fetch()) +for obs in upper.fetch(): + print(obs.data) +``` + + +``` +TransformStream[?](Stream[str]("logs") -> PerItemTransformer) +ObservationSet[?](3 items) +MOTOR STARTED +JOINT 3 FAULT +MOTOR STOPPED +``` + +Return `None` to skip an item, return a `list` to fan-out: + +```python skip +# Filter: skip short messages +long = logs.transform(lambda s: s if len(s) > 10 else None) + +# Fan-out: split into words +words = logs.transform(lambda s: s.split()) +``` + +### Storing transforms + +`.store(name)` materializes a transform into a new named stream in the database: + +```python skip +# Default: backfill existing + subscribe to new appends +embeddings = images.transform(EmbeddingTransformer(clip)).store("clip_embeddings") + +# Live only: skip backfill, only process new appends +embeddings = images.transform(EmbeddingTransformer(clip), live=True).store("clip_embeddings") + +# Backfill only: process existing data, don't subscribe +embeddings = images.transform(EmbeddingTransformer(clip), backfill_only=True).store("clip_embeddings") +``` + +| Mode | Processes existing data | Subscribes to new appends | +|------|-------------------------|---------------------------| +| default | yes | yes | +| `live=True` | no | yes | +| `backfill_only=True` | yes | no | + +The output stream kind is auto-detected from the transformer: `EmbeddingTransformer` and `TextEmbeddingTransformer` create an `EmbeddingStream` with vec0 index, `CaptionTransformer` creates a `TextStream` with FTS index. + +Storing also records **parent lineage** — which source stream produced the derived stream. This powers `project_to`. 
+ +### Built-in transformers + +| Transformer | Input | Output | Stored as | +|---|---|---|---| +| `PerItemTransformer(fn)` | any | any | `Stream` | +| `QualityWindowTransformer(quality_fn, window)` | any | same type | `Stream` — keeps best-quality item per time window | +| `CaptionTransformer(model)` | Image | `str` | `TextStream` with FTS index | +| `EmbeddingTransformer(model)` | Image/any | `Embedding` | `EmbeddingStream` with vec0 index | +| `TextEmbeddingTransformer(model)` | `str` | `Embedding` | `EmbeddingStream` with vec0 index | + +`QualityWindowTransformer` buffers observations within a time window and emits only the one with the highest quality score. Useful for sharpness filtering on camera frames: + +```python skip +from dimos.memory.transformer import QualityWindowTransformer + +sharp = images.transform( + QualityWindowTransformer(quality_fn=lambda img: img.sharpness, window=0.5) +).store("sharp_frames") +``` + +## Specialized streams + +### TextStream — full-text search + +```python session=memory ansi=false +text = session.text_stream("events") +text.append("Motor fault on joint 3", ts=1.0) +text.append("Battery low warning", ts=2.0) +text.append("Motor recovered", ts=3.0) + +results = text.search_text("motor").fetch() +print(results) +for obs in results: + print(f" {obs.data}") +``` + + +``` +ObservationSet[str](2 items) + Motor recovered + Motor fault on joint 3 +``` + +Uses SQLite FTS5. Results are ranked by relevance. Optional `k` parameter limits results. + +### EmbeddingStream — vector similarity search + +```python skip +embs = session.embedding_stream("clip_embs", vec_dimensions=512) +embs.append(embedding_vector, ts=1.0) + +results = embs.search_embedding([0.5, 0.3, ...], k=5).fetch() +``` + +Uses sqlite-vec (vec0) for cosine similarity. 
`search_embedding` accepts: +- Pre-computed `Embedding` or `list[float]` +- A `str` — auto-embedded via the stream's model (`embed_text`) +- An `Image` or other object — auto-embedded via the stream's model (`embed`) + +Results are `EmbeddingObservation` with `.similarity` (0–1 cosine) and `.embedding` (convenience alias for `.data`). + +## Lineage and project_to + +When you store a transform, each derived observation tracks its `parent_id`. Use `.project_to()` to follow the lineage chain back to a source stream: + +```python skip +images = session.stream("images", Image) +embeddings = images.transform(EmbeddingTransformer(clip)).store("clip_embeddings") + +# Search returns EmbeddingObservation — .data is the Embedding, not the source Image +results = embeddings.search_embedding("a hallway", k=5).fetch() +results[0].similarity # cosine similarity (0–1) +results[0].embedding # the Embedding vector +results[0].data # also the Embedding (same as .embedding) + +# To get source images, project back through the lineage chain +image_results = embeddings.search_embedding("office", k=5).project_to(images).fetch() + +# Multi-hop works too: embeddings → sharp_frames → raw_images +image_results = embeddings.search_embedding("office", k=5).project_to(raw_images).fetch() +``` + +## Reactive subscriptions + +Streams emit observations as they're appended: + +```python skip +images.subscribe(lambda obs: print(f"new frame at {obs.ts}")) + +# Filters work on subscriptions too: +images.after(10.0).filter_tags(cam="front").subscribe(handle_front_cam) + +# Or get the raw RxPY Observable: +observable = images.observable() +``` + +Under the hood this is an RxPY Observable on the backend's `Subject`. Embedding and lineage filters are skipped for live filtering (they need DB context); temporal, spatial, and tag filters work. 
+ +## Codecs + +Payloads are BLOB-encoded via auto-selected codecs: + +| Codec | Used for | Strategy | +|-------|----------|----------| +| `JpegCodec` | `Image` | TurboJPEG lossy compression (~10-20x smaller), preserves `frame_id` | +| `LcmCodec` | `DimosMsg` types | LCM binary encoding (lossless) | +| `PickleCodec` | everything else | Python pickle (fallback) | + +`codec_for_type(payload_type)` auto-selects the best codec. This is transparent — you never need to specify a codec manually. + +## Session management + +### Listing streams + +```python session=memory ansi=false +for info in session.list_streams(): + print(f"{info.name}: {info.stream_kind}, {info.count} items") +``` + + +``` +logs: stream, 3 items +events: text, 3 items +``` + +### Context managers + +```python skip +with SqliteStore("memory.db") as store: + session = store.session() + # ... use session ... +# store.stop() called automatically +``` + +### Pose provider + +Auto-attach pose to every appended observation: + +```python skip +images = session.stream("images", Image, pose_provider=robot.get_pose) +images.append(frame) # pose is auto-filled from pose_provider() +``` + +### Persistence + +Data persists across restarts. Reopen the same database and streams pick up where they left off: + +```python skip +store = SqliteStore("memory.db") +session = store.session() +images = session.stream("images", Image) +results = images.after(100.0).fetch() # picks up old data +``` + +The `_streams` meta-table tracks stream names, payload types (as module paths), stream kind, parent lineage, and embedding dimensions. 
+ +## MemoryModule — blueprint integration + +In a robot blueprint, `MemoryModule` wires input streams to memory: + +```python skip +class MyMemory(MemoryModule): + camera: In[Image] + + def start(self): + super().start() + + # Record camera input to a named stream (name/type inferred from input) + self.image_memory = self.memory(self.camera) + + # With quality filtering at 2 fps (keeps sharpest frame per window) + self.image_memory = self.memory(self.camera, fps=2) + + # Build derived streams + self.embeddings = self.image_memory.transform( + EmbeddingTransformer(CLIPModel()), live=True + ).store("clip_embeddings") +``` + +## Utilities + +### Bulk import + +```python skip +from dimos.memory.ingest import ingest + +# Import from any iterable of (timestamp, payload) tuples +count = ingest(images, replay.iterate_ts()) + +# With pose interpolation from an odometry source +count = ingest(images, replay.iterate_ts(), pose_source=odom_replay) +``` + +### Rerun export + +```python skip +from dimos.memory.rerun import to_rerun + +# Log a stream's observations to Rerun timeline +count = to_rerun(images) +count = to_rerun(images, entity_path="memory/camera") +``` + +```python session=memory no-result +store.stop() +``` + +```sh no-result +rm -f /tmp/memory_readme.db +``` From d6e5efca090844fa019c5a8d23ed911be258ce52 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 16:26:47 +0800 Subject: [PATCH 046/118] bigoffice db in lfs, sqlite accepts Path --- data/.lfs/go2_bigoffice.db.tar.gz | 3 +++ dimos/memory/impl/sqlite.py | 7 ++++--- 2 files changed, 7 insertions(+), 3 deletions(-) create mode 100644 data/.lfs/go2_bigoffice.db.tar.gz diff --git a/data/.lfs/go2_bigoffice.db.tar.gz b/data/.lfs/go2_bigoffice.db.tar.gz new file mode 100644 index 0000000000..843a97b9b1 --- /dev/null +++ b/data/.lfs/go2_bigoffice.db.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a4ce6670e3a48fdf378188ababe8dc607ed83b5160802ff7f309aa43f8e72ce +size 
406735715 diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 6e91ea44ba..51524c5b0c 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -71,6 +71,7 @@ if TYPE_CHECKING: from collections.abc import Callable + import os from dimos.memory.type import PoseProvider from dimos.models.embedding.base import EmbeddingModel @@ -404,7 +405,7 @@ def do_append( if ts is None: ts = time.time() if pose is None and self._pose_provider is not None: - pose = self._pose_provider() + pose = self._pose_provider(ts) pose_cols = _decompose_pose(pose) tags_json = _serialize_tags(tags) @@ -996,8 +997,8 @@ class SqliteStore(Store): and extensions loaded. Sessions are safe to use from different threads. """ - def __init__(self, path: str) -> None: - self._path = path + def __init__(self, path: str | os.PathLike[str]) -> None: + self._path = str(path) self._closed = False def _connect(self) -> sqlite3.Connection: From 31cf8a859bf2ba71c6c1ff46d6b6ea4ae66d4e91 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 16:45:07 +0800 Subject: [PATCH 047/118] projection transformers --- dimos/memory/impl/sqlite.py | 1 + dimos/memory/impl/test_sqlite.py | 16 ++-- dimos/memory/module.py | 2 +- dimos/memory/stream.py | 5 +- dimos/memory/test_transformer.py | 135 ++++++++++++++++++++++++++++++- dimos/memory/transformer.py | 52 +++++++++++- dimos/memory/type.py | 4 +- 7 files changed, 199 insertions(+), 16 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 51524c5b0c..b0d4325206 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -410,6 +410,7 @@ def do_append( pose_cols = _decompose_pose(pose) tags_json = _serialize_tags(tags) + print("APPEND", payload) # Encode payload before touching the DB so a codec error can't leave # a metadata row without a matching payload row. 
payload_blob = self._codec.encode(payload) diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index de76df8d3d..2b1b71d021 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -958,29 +958,31 @@ def test_tags_filter(self) -> None: from dimos.memory.type import TagsFilter f = TagsFilter((("cam", "front"),)) - assert f.matches(Observation(id=1, tags={"cam": "front", "quality": "high"})) is True - assert f.matches(Observation(id=2, tags={"cam": "rear"})) is False - assert f.matches(Observation(id=3, tags={})) is False + assert ( + f.matches(Observation(id=1, ts=0.0, tags={"cam": "front", "quality": "high"})) is True + ) + assert f.matches(Observation(id=2, ts=0.0, tags={"cam": "rear"})) is False + assert f.matches(Observation(id=3, ts=0.0, tags={})) is False def test_text_search_filter(self) -> None: from dimos.memory.type import TextSearchFilter f = TextSearchFilter("motor", k=None) - assert f.matches(Observation(id=1, _data="Motor fault on joint 3")) is True - assert f.matches(Observation(id=2, _data="Battery low")) is False + assert f.matches(Observation(id=1, ts=0.0, _data="Motor fault on joint 3")) is True + assert f.matches(Observation(id=2, ts=0.0, _data="Battery low")) is False def test_embedding_search_filter_always_true(self) -> None: from dimos.memory.type import EmbeddingSearchFilter f = EmbeddingSearchFilter([1.0, 0.0], k=5) - assert f.matches(Observation(id=1)) is True + assert f.matches(Observation(id=1, ts=0.0)) is True def test_lineage_filter_raises(self) -> None: from dimos.memory.type import LineageFilter, StreamQuery f = LineageFilter("src", StreamQuery(), ()) with pytest.raises(NotImplementedError): - f.matches(Observation(id=1)) + f.matches(Observation(id=1, ts=0.0)) class TestFilteredAppended: diff --git a/dimos/memory/module.py b/dimos/memory/module.py index c92f1b5c60..b651d037be 100644 --- a/dimos/memory/module.py +++ b/dimos/memory/module.py @@ -55,7 +55,7 @@ def __init__(self, 
*args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._store: SqliteStore | None = None - def pose(self) -> PoseStamped | None: + def pose(self, ts: float) -> PoseStamped | None: return self.tf.get_pose(self.config.world_frame, self.config.robot_frame) # type: ignore[no-any-return] @rpc diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index b063413cd5..4941c63337 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -15,6 +15,7 @@ from __future__ import annotations import copy +import time from typing import ( TYPE_CHECKING, Any, @@ -540,7 +541,7 @@ def __init__( live: bool = False, backfill_only: bool = False, ) -> None: - super().__init__(backend=None) + super().__init__(backend=None, session=source._session) self._source = source self._transformer = transformer self._live = live @@ -629,7 +630,7 @@ def append( ) -> Observation[R]: obs: Observation[R] = Observation( id=self._next_id, - ts=ts, + ts=ts if ts is not None else time.time(), tags=tags or {}, parent_id=parent_id, _data=payload, diff --git a/dimos/memory/test_transformer.py b/dimos/memory/test_transformer.py index 1e294d8cd9..87d6da134f 100644 --- a/dimos/memory/test_transformer.py +++ b/dimos/memory/test_transformer.py @@ -22,14 +22,14 @@ import pytest from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import TextEmbeddingTransformer +from dimos.memory.transformer import TextEmbeddingTransformer, VLMDetectionTransformer from dimos.models.embedding.base import Embedding, EmbeddingModel +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox if TYPE_CHECKING: from collections.abc import Iterator - from dimos.msgs.sensor_msgs.Image import Image - class FakeTextEmbedder(EmbeddingModel): device = "cpu" @@ -131,3 +131,132 @@ def test_text_embedding_search_and_project(self, session: SqliteSession) -> None results = 
emb_stream.search_embedding("kitchen", k=2).project_to(logs).fetch() assert len(results) == 2 assert all("kitchen" in r.data.lower() for r in results) + + +# ── Fake VLM for detection tests ───────────────────────────────────── + + +def _make_image(width: int = 640, height: int = 480) -> Image: + """Create a simple test image.""" + return Image(np.zeros((height, width, 3), dtype=np.uint8)) + + +class _FakeImageDetections2D: + """Mimics ImageDetections2D with a .detections list.""" + + def __init__(self, detections: list[Detection2DBBox]) -> None: + self.detections = detections + + +class FakeVlModel: + """Fake VlModel that returns canned detections per call. + + Pass a list of detection lists — one per call to query_detections(). + Each call pops the next entry from the list. + """ + + def __init__( + self, detections_per_call: list[list[tuple[str, float, float, float, float]]] + ) -> None: + self._queue = list(detections_per_call) + + def query_detections(self, image: Image, query: str) -> _FakeImageDetections2D: + raw = self._queue.pop(0) if self._queue else [] + ts = image.ts + dets = [] + for i, (name, x1, y1, x2, y2) in enumerate(raw): + dets.append( + Detection2DBBox( + bbox=(x1, y1, x2, y2), + track_id=i, + class_id=-1, + confidence=0.9, + name=name, + ts=ts, + image=image, + ) + ) + return _FakeImageDetections2D(dets) + + +class TestVLMDetectionTransformer: + """Test VLM detection transformer.""" + + def test_vlm_detection_backfill(self, session: SqliteSession) -> None: + """3 images, VLM finds 1 detection per image → 3 detections with parent_id.""" + frames = session.stream("vlm_frames", Image) + frames.append(_make_image(), ts=1.0) + frames.append(_make_image(), ts=2.0) + frames.append(_make_image(), ts=3.0) + + vlm = FakeVlModel( + [ + [("bottle", 10, 20, 100, 200)], + [("bottle", 50, 60, 150, 250)], + [("bottle", 30, 40, 130, 230)], + ] + ) + + det_stream = frames.transform( + VLMDetectionTransformer(vlm, query="bottle") # type: ignore[arg-type] + 
).store("vlm_detections", Detection2DBBox) + + results = det_stream.fetch() + assert len(results) == 3 + + # All detections have parent_id linking back to source frames + frame_ids = {obs.id for obs in frames.fetch()} + for det_obs in results: + assert det_obs.parent_id in frame_ids + assert det_obs.data.image is None # stored without image + assert det_obs.data.name == "bottle" + assert det_obs.tags["query"] == "bottle" + + def test_vlm_detection_no_matches(self, session: SqliteSession) -> None: + """VLM returns empty detections → stream stays empty.""" + frames = session.stream("vlm_empty_frames", Image) + frames.append(_make_image(), ts=1.0) + frames.append(_make_image(), ts=2.0) + + vlm = FakeVlModel([[], []]) # no detections for either call + + det_stream = frames.transform( + VLMDetectionTransformer(vlm, query="cat") # type: ignore[arg-type] + ).store("vlm_empty_dets", Detection2DBBox) + + assert det_stream.count() == 0 + + def test_vlm_detection_multiple_per_frame(self, session: SqliteSession) -> None: + """1 image → 3 detections, all share same parent_id.""" + frames = session.stream("vlm_multi_frames", Image) + frames.append(_make_image(), ts=1.0) + + vlm = FakeVlModel( + [ + [ + ("bottle", 10, 20, 100, 200), + ("cup", 200, 100, 300, 250), + ("plate", 400, 300, 500, 400), + ], + ] + ) + + det_stream = frames.transform( + VLMDetectionTransformer(vlm, query="objects") # type: ignore[arg-type] + ).store("vlm_multi_dets", Detection2DBBox) + + results = det_stream.fetch() + assert len(results) == 3 + + # All share the same parent_id (the single source frame) + parent_ids = {obs.parent_id for obs in results} + assert len(parent_ids) == 1 + + names = {obs.data.name for obs in results} + assert names == {"bottle", "cup", "plate"} + + parent_ids = {obs.parent_id for obs in results} + assert len(parent_ids) == 1 + + names = {obs.data.name for obs in results} + assert names == {"bottle", "cup", "plate"} diff --git a/dimos/memory/transformer.py 
b/dimos/memory/transformer.py index 50b4d5954a..9519322fb8 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -21,7 +21,7 @@ from collections.abc import Callable from dimos.models.embedding.base import Embedding, EmbeddingModel - from dimos.models.vl.base import Captioner + from dimos.models.vl.base import Captioner, VlModel from .stream import Stream from .type import Observation @@ -247,3 +247,53 @@ def on_append(self, obs: Observation[Any], target: Stream[Embedding]) -> None: if isinstance(emb, list): emb = emb[0] target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) + + +class VLMDetectionTransformer(Transformer[Any, "Detection2DBBox"]): + """Wraps a VlModel to produce Detection2DBBox from images. + + Calls query_detections() per image, emitting one Detection2DBBox(image=None) + per bounding box with parent_id linking to the source image observation. + """ + + supports_backfill: bool = True + supports_live: bool = True + + def __init__(self, model: VlModel, query: str) -> None: + from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox + + self.model = model + self._query = query + self.output_type: type | None = Detection2DBBox + + def process(self, source: Stream[Any], target: Stream[Detection2DBBox]) -> None: + for page in source.fetch_pages(): + for obs in page: + self._detect(obs, target) + + def on_append(self, obs: Observation[Any], target: Stream[Detection2DBBox]) -> None: + self._detect(obs, target) + + def _detect(self, obs: Observation[Any], target: Stream[Detection2DBBox]) -> None: + from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox + + try: + result = self.model.query_detections(obs.data, self._query) + except Exception: + return + for det in result.detections: + target.append( + Detection2DBBox( + bbox=det.bbox, + track_id=det.track_id, + class_id=det.class_id, + confidence=det.confidence, + name=det.name, + ts=obs.ts or det.ts, + image=None, + ), + 
ts=obs.ts, + pose=obs.pose, + tags={**(obs.tags or {}), "query": self._query}, + parent_id=obs.id, + ) diff --git a/dimos/memory/type.py b/dimos/memory/type.py index 65170104fd..b9c0bcc376 100644 --- a/dimos/memory/type.py +++ b/dimos/memory/type.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -PoseProvider: TypeAlias = Callable[[], Any] # () -> PoseLike | None +PoseProvider: TypeAlias = Callable[[float], Any] # (ts) -> PoseLike | None T = TypeVar("T") @@ -43,7 +43,7 @@ class _Unset: @dataclass class Observation(Generic[T]): id: int - ts: float | None = None + ts: float pose: PoseStamped | None = None tags: dict[str, Any] = field(default_factory=dict) parent_id: int | None = field(default=None, repr=False) From 04337db4b5902514d87a93fe36b108441cd79240 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 16:54:08 +0800 Subject: [PATCH 048/118] stream info removed, stream accessor helper, TS unique per stream --- dimos/memory/__init__.py | 5 ++-- dimos/memory/impl/sqlite.py | 24 ++++++++-------- dimos/memory/store.py | 56 +++++++++++++++++++++++++++++++++++-- dimos/memory/stream.py | 14 ++++++++-- dimos/memory/type.py | 8 ------ 5 files changed, 80 insertions(+), 27 deletions(-) diff --git a/dimos/memory/__init__.py b/dimos/memory/__init__.py index fe65bded3c..489e3f568d 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory/__init__.py @@ -1,5 +1,5 @@ from dimos.memory.codec import Codec, JpegCodec, LcmCodec, PickleCodec, codec_for_type -from dimos.memory.store import Session, Store +from dimos.memory.store import Session, Store, StreamNamespace from dimos.memory.stream import EmbeddingStream, ObservationSet, Stream, TextStream from dimos.memory.transformer import ( CaptionTransformer, @@ -11,7 +11,6 @@ from dimos.memory.type import ( EmbeddingObservation, Observation, - StreamInfo, ) __all__ = [ @@ -29,7 +28,7 @@ "Session", "Store", "Stream", - "StreamInfo", + "StreamNamespace", 
"TextEmbeddingTransformer", "TextStream", "Transformer", diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index b0d4325206..434a71c975 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -62,7 +62,6 @@ LineageFilter, NearFilter, Observation, - StreamInfo, StreamQuery, TagsFilter, TextSearchFilter, @@ -855,20 +854,21 @@ def embedding_stream( self._streams[name] = es return es - def list_streams(self) -> list[StreamInfo]: + def list_streams(self) -> list[Stream[Any]]: rows = self._conn.execute( "SELECT name, payload_module, stream_kind FROM _streams" ).fetchall() - result: list[StreamInfo] = [] + result: list[Stream[Any]] = [] for name, pmodule, kind in rows: _validate_identifier(name) - count_row = self._conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone() - count = count_row[0] if count_row else 0 - result.append( - StreamInfo( - name=name, payload_type=pmodule, count=count, stream_kind=kind or "stream" - ) - ) + payload_type = module_path_to_type(pmodule) if pmodule else None + kind = kind or "stream" + if kind == "embedding": + result.append(self.embedding_stream(name)) + elif kind == "text": + result.append(self.text_stream(name)) + else: + result.append(self.stream(name, payload_type or object)) return result def materialize_transform( @@ -929,7 +929,7 @@ def _ensure_stream_tables(self, name: str) -> None: self._conn.execute( f"CREATE TABLE IF NOT EXISTS {name} (" " id INTEGER PRIMARY KEY AUTOINCREMENT," - " ts REAL," + " ts REAL UNIQUE NOT NULL," " pose_x REAL," " pose_y REAL," " pose_z REAL," @@ -941,7 +941,7 @@ def _ensure_stream_tables(self, name: str) -> None: " parent_id INTEGER" ")" ) - self._conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{name}_ts ON {name}(ts)") + self._conn.execute( f"CREATE TABLE IF NOT EXISTS {name}_payload ( id INTEGER PRIMARY KEY, data BLOB)" ) diff --git a/dimos/memory/store.py b/dimos/memory/store.py index e37f22b057..2484cb42bf 100644 --- a/dimos/memory/store.py +++ 
b/dimos/memory/store.py @@ -24,17 +24,69 @@ from .stream import EmbeddingStream, Stream, TextStream from .transformer import Transformer - from .type import PoseProvider, StreamInfo + from .type import PoseProvider T = TypeVar("T") +class StreamNamespace: + """Attribute-access proxy for session streams. + + Usage:: + + session.streams.image_stream # same as looking up "image_stream" from list_streams() + session.streams["image_stream"] + list(session.streams) # iterate all streams + len(session.streams) + """ + + def __init__(self, session: Session) -> None: + self._session = session + + def _load(self) -> dict[str, Stream[Any]]: + return {s._backend.stream_name: s for s in self._session.list_streams() if s._backend} + + def __getattr__(self, name: str) -> Stream[Any]: + if name.startswith("_"): + raise AttributeError(name) + streams = self._load() + if name in streams: + return streams[name] + raise AttributeError( + f"No stream named {name!r}. Available: {', '.join(streams) or '(none)'}" + ) + + def __getitem__(self, name: str) -> Stream[Any]: + streams = self._load() + if name in streams: + return streams[name] + raise KeyError(name) + + def __iter__(self): + return iter(self._load().values()) + + def __len__(self) -> int: + return len(self._load()) + + def __contains__(self, name: str) -> bool: + return name in self._load() + + def __repr__(self) -> str: + names = list(self._load().keys()) + return f"StreamNamespace({names})" + + class Session(Resource): """A session against a memory store. Creates and manages streams. Inherits DisposableBase so sessions can be added to CompositeDisposable. """ + @property + def streams(self) -> StreamNamespace: + """Attribute-access namespace for all streams in this session.""" + return StreamNamespace(self) + def start(self) -> None: pass @@ -71,7 +123,7 @@ def embedding_stream( """Get or create an embedding stream with vec0 index.""" @abstractmethod - def list_streams(self) -> list[StreamInfo]: ... 
+ def list_streams(self) -> list[Stream[Any]]: ... @abstractmethod def materialize_transform( diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 4941c63337..27bf8036f7 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -410,15 +410,25 @@ def get_time_range(self) -> tuple[float, float]: def summary(self) -> str: from datetime import datetime, timezone + t = self._rich_text() n = self.count() if n == 0: - return f"{self!r}: empty" + t.append(": ", style="dim") + t.append("empty", style="italic dim") + return _render_text(t) t0, t1 = self.get_time_range() fmt = "%Y-%m-%d %H:%M:%S" dt0 = datetime.fromtimestamp(t0, tz=timezone.utc).strftime(fmt) dt1 = datetime.fromtimestamp(t1, tz=timezone.utc).strftime(fmt) dur = t1 - t0 - return f"{self!r}: {n} items, {dt0} — {dt1} ({dur:.1f}s)" + t.append(": ", style="dim") + t.append(f"{n}", style="bold white") + t.append(" items, ", style="dim") + t.append(dt0, style="bright_blue") + t.append(" — ", style="dim") + t.append(dt1, style="bright_blue") + t.append(f" ({dur:.1f}s)", style="dim yellow") + return _render_text(t) # ── Reactive ────────────────────────────────────────────────────── diff --git a/dimos/memory/type.py b/dimos/memory/type.py index b9c0bcc376..23fd8dcba6 100644 --- a/dimos/memory/type.py +++ b/dimos/memory/type.py @@ -84,14 +84,6 @@ def embedding(self) -> Embedding: return self.data -@dataclass -class StreamInfo: - name: str - payload_type: str | None = None - count: int = 0 - stream_kind: str = "stream" - - # ── Filter types ────────────────────────────────────────────────────── From 6fc6e8dbddb903ea60f482fb85ca863bcf062cd8 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 16:57:26 +0800 Subject: [PATCH 049/118] Add colored summary() output and model= param to search_embedding summary() now renders the rich-text stream header with colored type info, count, timestamps, and duration. 
search_embedding() accepts an optional model= override so callers don't need to attach a model to the stream. --- dimos/memory/stream.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 27bf8036f7..ef5d21992c 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -483,12 +483,13 @@ def search_embedding( query: Embedding | list[float] | str | Any, *, k: int, + model: EmbeddingModel | None = None, ) -> EmbeddingStream[T]: """Search by vector similarity. Accepts pre-computed embeddings, raw float lists, text strings, or images/other objects. Text and non-vector inputs are auto-embedded - using the model that created this stream. + using the model that created this stream (or the ``model`` override). Returns an EmbeddingStream — use ``.project_to(source)`` to get results in the source stream's type, or ``.fetch()`` for @@ -496,8 +497,14 @@ def search_embedding( """ from dimos.models.embedding.base import Embedding as EmbeddingCls + resolve = model or self._embedding_model if isinstance(query, str): - emb = self._require_model().embed_text(query) + if resolve is None: + raise TypeError( + "No embedding model available. Pass model= or use a " + "pre-computed Embedding / list[float]." + ) + emb = resolve.embed_text(query) if isinstance(emb, list): emb = emb[0] return self.search_embedding(emb, k=k) @@ -508,7 +515,12 @@ def search_embedding( vec = list(query) else: # Assume embeddable object (Image, etc.) - emb = self._require_model().embed(query) + if resolve is None: + raise TypeError( + "No embedding model available. Pass model= or use a " + "pre-computed Embedding / list[float]." 
+ ) + emb = resolve.embed(query) if isinstance(emb, list): emb = emb[0] return self.search_embedding(emb, k=k) From a6a06e179b0e12e0018c56c9320b6a8a8976200e Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 17:33:55 +0800 Subject: [PATCH 050/118] stream delete --- dimos/memory/impl/sqlite.py | 11 +++++++++++ dimos/memory/store.py | 4 ++++ dimos/memory/stream.py | 23 +++++++++++++++++++---- dimos/memory/type.py | 15 ++++++++++++--- 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 434a71c975..3241b26168 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -871,6 +871,17 @@ def list_streams(self) -> list[Stream[Any]]: result.append(self.stream(name, payload_type or object)) return result + def delete_stream(self, name: str) -> None: + _validate_identifier(name) + for suffix in ("_vec", "_fts", "_rtree", "_payload", ""): + table = f"{name}{suffix}" + # Virtual tables (rtree, fts, vec) need DROP TABLE, not DROP TABLE IF EXISTS + # on some builds, but IF EXISTS is safe for all. + self._conn.execute(f"DROP TABLE IF EXISTS {table}") + self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) + self._conn.commit() + self._streams.pop(name, None) + def materialize_transform( self, name: str, diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 2484cb42bf..b2d3646f92 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -125,6 +125,10 @@ def embedding_stream( @abstractmethod def list_streams(self) -> list[Stream[Any]]: ... 
+ @abstractmethod + def delete_stream(self, name: str) -> None: + """Drop a stream and all its associated tables (payload, rtree, etc.).""" + @abstractmethod def materialize_transform( self, diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index ef5d21992c..9d863341a5 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -404,6 +404,13 @@ def count(self) -> int: def exists(self) -> bool: return self.count() > 0 + def delete(self) -> None: + """Drop this stream and all associated storage.""" + if self._session is None: + raise TypeError("Cannot delete: no session available.") + backend = self._require_backend() + self._session.delete_stream(backend.stream_name) + def get_time_range(self) -> tuple[float, float]: return (self.first().ts, self.last().ts) @@ -498,16 +505,18 @@ def search_embedding( from dimos.models.embedding.base import Embedding as EmbeddingCls resolve = model or self._embedding_model + label: str | None = None if isinstance(query, str): if resolve is None: raise TypeError( "No embedding model available. Pass model= or use a " "pre-computed Embedding / list[float]." ) + label = query emb = resolve.embed_text(query) if isinstance(emb, list): emb = emb[0] - return self.search_embedding(emb, k=k) + query = emb if isinstance(query, EmbeddingCls): vec = query.to_numpy().tolist() @@ -520,17 +529,23 @@ def search_embedding( "No embedding model available. Pass model= or use a " "pre-computed Embedding / list[float]." 
) + label = type(query).__name__ emb = resolve.embed(query) if isinstance(emb, list): emb = emb[0] - return self.search_embedding(emb, k=k) + query = emb + vec = query.to_numpy().tolist() - return self._with_filter(EmbeddingSearchFilter(vec, k)) + return self._with_filter(EmbeddingSearchFilter(vec, k, label=label)) def fetch(self) -> ObservationSet[T]: # type: ignore[override] backend = self._require_backend() results = backend.execute_fetch(self._query) - return ObservationSet(cast("list[Observation[T]]", results), session=self._session) + return ObservationSet( + cast("list[Observation[T]]", results), + session=self._session, + payload_type=self._payload_type, + ) def first(self) -> EmbeddingObservation: # type: ignore[override] results = self.limit(1).fetch() diff --git a/dimos/memory/type.py b/dimos/memory/type.py index 23fd8dcba6..ebd7308776 100644 --- a/dimos/memory/type.py +++ b/dimos/memory/type.py @@ -203,17 +203,26 @@ def _rich_text(self) -> Text: class EmbeddingSearchFilter: query: list[float] k: int + label: str | None = None def matches(self, obs: Observation[Any]) -> bool: return True # top-k handled as special pass in ListBackend def __str__(self) -> str: - return f"search(k={self.k})" + parts = [f"k={self.k}"] + if self.label: + parts.insert(0, repr(self.label)) + return f"search_embedding({', '.join(parts)})" def _rich_text(self) -> Text: t = Text() - t.append("search", style="cyan") - t.append(f"(k={self.k})") + t.append("search_embedding", style="cyan") + t.append("(") + if self.label: + t.append(repr(self.label), style="green") + t.append(", ") + t.append(f"k={self.k}") + t.append(")") return t From b9af997e7fb37d53e61e9f29a63256fa5717685e Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 17:47:04 +0800 Subject: [PATCH 051/118] florence model detail settings and prefix filter --- dimos/models/vl/florence.py | 49 ++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git 
a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py index a44267d620..993c573e56 100644 --- a/dimos/models/vl/florence.py +++ b/dimos/models/vl/florence.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from functools import cached_property from PIL import Image as PILImage @@ -23,6 +24,15 @@ from dimos.msgs.sensor_msgs import Image +class CaptionDetail(Enum): + """Florence-2 caption detail level.""" + + BRIEF = "" + NORMAL = "" + DETAILED = "" + MORE_DETAILED = "" + + class Florence2Model(HuggingFaceModel, Captioner): """Florence-2 captioning model from Microsoft. @@ -35,6 +45,7 @@ class Florence2Model(HuggingFaceModel, Captioner): def __init__( self, model_name: str = "microsoft/Florence-2-base", + detail: CaptionDetail = CaptionDetail.NORMAL, **kwargs: object, ) -> None: """Initialize Florence-2 model. @@ -43,9 +54,11 @@ def __init__( model_name: HuggingFace model name. Options: - "microsoft/Florence-2-base" (~0.2B, fastest) - "microsoft/Florence-2-large" (~0.8B, better quality) + detail: Caption detail level **kwargs: Additional config options (device, dtype, warmup, etc.) """ super().__init__(model_name=model_name, **kwargs) + self._task_prompt = detail.value @cached_property def _processor(self) -> AutoProcessor: @@ -53,27 +66,22 @@ def _processor(self) -> AutoProcessor: self.config.model_name, trust_remote_code=self.config.trust_remote_code ) - def caption(self, image: Image, detail: str = "normal") -> str: - """Generate a caption for the image. 
+ _STRIP_PREFIXES = ("The image shows ", "The image is a ") - Args: - image: Input image to caption - detail: Level of detail for caption: - - "brief": Short, concise caption - - "normal": Standard caption (default) - - "detailed": More detailed description + @staticmethod + def _clean_caption(text: str) -> str: + for prefix in Florence2Model._STRIP_PREFIXES: + if text.startswith(prefix): + return text[len(prefix):] + return text + + def caption(self, image: Image) -> str: + """Generate a caption for the image. Returns: Text description of the image """ - # Map detail level to Florence-2 task prompts - task_prompts = { - "brief": "", - "normal": "", - "detailed": "", - "more_detailed": "", - } - task_prompt = task_prompts.get(detail, "") + task_prompt = self._task_prompt # Convert to PIL pil_image = PILImage.fromarray(image.to_rgb().data) @@ -101,21 +109,18 @@ def caption(self, image: Image, detail: str = "normal") -> str: # Extract caption from parsed output caption: str = parsed.get(task_prompt, generated_text) - return caption.strip() + return self._clean_caption(caption.strip()) def caption_batch(self, *images: Image) -> list[str]: """Generate captions for multiple images efficiently. 
- Args: - images: Input images to caption - Returns: List of text descriptions """ if not images: return [] - task_prompt = "" + task_prompt = self._task_prompt # Convert all to PIL pil_images = [PILImage.fromarray(img.to_rgb().data) for img in images] @@ -144,7 +149,7 @@ def caption_batch(self, *images: Image) -> list[str]: parsed = self._processor.post_process_generation( text, task=task_prompt, image_size=pil_img.size ) - captions.append(parsed.get(task_prompt, text).strip()) + captions.append(self._clean_caption(parsed.get(task_prompt, text).strip())) return captions From 1e42408f8dcf01bdf518daa26f13d9c5db7876b8 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 18:18:38 +0800 Subject: [PATCH 052/118] extracted formatting to a separate file --- dimos/memory/formatting.py | 246 ++++++++++++++++++++++++++++++++++++ dimos/memory/stream.py | 80 ++---------- dimos/memory/transformer.py | 57 +-------- dimos/memory/type.py | 89 ------------- 4 files changed, 258 insertions(+), 214 deletions(-) create mode 100644 dimos/memory/formatting.py diff --git a/dimos/memory/formatting.py b/dimos/memory/formatting.py new file mode 100644 index 0000000000..1f3c785567 --- /dev/null +++ b/dimos/memory/formatting.py @@ -0,0 +1,246 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Rich text rendering for memory types and streams.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from rich.console import Console +from rich.text import Text + +if TYPE_CHECKING: + from collections.abc import Callable + +_console = Console(force_terminal=True, highlight=False) + + +def render_text(text: Text) -> str: + """Render rich Text to a terminal string with ANSI codes.""" + with _console.capture() as cap: + _console.print(text, end="", soft_wrap=True) + return cap.get() + + +# ── Filter rendering ──────────────────────────────────────────────── + + +def _after_rich(f: Any) -> Text: + t = Text() + t.append("after", style="cyan") + t.append(f"(t={f.t})") + return t + + +def _before_rich(f: Any) -> Text: + t = Text() + t.append("before", style="cyan") + t.append(f"(t={f.t})") + return t + + +def _time_range_rich(f: Any) -> Text: + t = Text() + t.append("time_range", style="cyan") + t.append(f"({f.t1}, {f.t2})") + return t + + +def _at_rich(f: Any) -> Text: + t = Text() + t.append("at", style="cyan") + t.append(f"(t={f.t}, tol={f.tolerance})") + return t + + +def _near_rich(f: Any) -> Text: + t = Text() + t.append("near", style="cyan") + t.append(f"(radius={f.radius})") + return t + + +def _tags_rich(f: Any) -> Text: + t = Text() + t.append("tags", style="cyan") + pairs = ", ".join(f"{k}={v!r}" for k, v in f.tags) + t.append(f"({pairs})") + return t + + +def _embedding_search_rich(f: Any) -> Text: + t = Text() + t.append("search_embedding", style="cyan") + t.append("(") + if f.label: + t.append(repr(f.label), style="green") + t.append(", ") + t.append(f"k={f.k}") + t.append(")") + return t + + +def _text_search_rich(f: Any) -> Text: + t = Text() + t.append("text", style="cyan") + t.append(f"({f.text!r})") + return t + + +def _lineage_rich(f: Any) -> Text: + t = Text() + t.append("lineage", style="cyan") + hops = " -> ".join(f.hops) if f.hops else "direct" + t.append(f"({f.source_table} -> {hops})") + return t + + 
+_FILTER_DISPATCH: dict[type, Callable[..., Text]] | None = None + + +def _get_dispatch() -> dict[type, Callable[..., Text]]: + global _FILTER_DISPATCH + if _FILTER_DISPATCH is not None: + return _FILTER_DISPATCH + from .type import ( + AfterFilter, + AtFilter, + BeforeFilter, + EmbeddingSearchFilter, + LineageFilter, + NearFilter, + TagsFilter, + TextSearchFilter, + TimeRangeFilter, + ) + + _FILTER_DISPATCH = { + AfterFilter: _after_rich, + BeforeFilter: _before_rich, + TimeRangeFilter: _time_range_rich, + AtFilter: _at_rich, + NearFilter: _near_rich, + TagsFilter: _tags_rich, + EmbeddingSearchFilter: _embedding_search_rich, + TextSearchFilter: _text_search_rich, + LineageFilter: _lineage_rich, + } + return _FILTER_DISPATCH + + +def filter_rich(f: Any) -> Text: + """Render a Filter to rich Text.""" + dispatch = _get_dispatch() + renderer = dispatch.get(type(f)) + if renderer is None: + return Text(str(f)) + return renderer(f) + + +def query_rich(q: Any) -> Text: + """Render a StreamQuery to rich Text.""" + t = Text() + pipe = Text(" | ", style="dim") + parts: list[Text] = [filter_rich(f) for f in q.filters] + if q.order_field: + p = Text() + p.append("order", style="cyan") + direction = "desc" if q.order_desc else "asc" + p.append(f"({q.order_field}, {direction})") + parts.append(p) + if q.limit_val is not None: + p = Text() + p.append("limit", style="cyan") + p.append(f"({q.limit_val})") + parts.append(p) + if q.offset_val is not None: + p = Text() + p.append("offset", style="cyan") + p.append(f"({q.offset_val})") + parts.append(p) + for i, part in enumerate(parts): + if i > 0: + t.append_text(pipe) + t.append_text(part) + return t + + +# ── Stream rendering ──────────────────────────────────────────────── + + +def rich_text(obj: Any) -> Text: + """Render a Stream, TransformStream, ObservationSet, or StreamQuery to rich Text. + + Uses duck-typing on attributes — no dispatch table needed. 
+ """ + # TransformStream: has _source and _transformer + if hasattr(obj, "_transformer"): + xf = obj._transformer + t = Text() + t.append("TransformStream", style="bold cyan") + t.append("[", style="dim") + t.append(xf.output_type.__name__ if xf.output_type else "?", style="yellow") + t.append("]", style="dim") + t.append("(", style="dim") + t.append_text(rich_text(obj._source)) + t.append(" -> ", style="dim") + t.append(type(xf).__name__, style="magenta") + if obj._live: + t.append(", ", style="dim") + t.append("live=True", style="yellow") + if obj._backfill_only: + t.append(", ", style="dim") + t.append("backfill_only=True", style="yellow") + t.append(")", style="dim") + qt = query_rich(obj._query) + if qt.plain: + t.append(" | ", style="dim") + t.append_text(qt) + return t + + # ObservationSet: has _observations list + if hasattr(obj, "_observations"): + type_name = obj._payload_type.__name__ if obj._payload_type else "?" + t = Text() + t.append("ObservationSet", style="bold cyan") + t.append("[", style="dim") + t.append(type_name, style="yellow") + t.append("]", style="dim") + t.append("(", style="dim") + t.append(f"{len(obj._observations)} items", style="green") + t.append(")", style="dim") + return t + + # StreamQuery + if hasattr(obj, "filters"): + return query_rich(obj) + + # Stream (and subclasses like EmbeddingStream, TextStream) + cls_name = type(obj).__name__ + type_name = obj._payload_type.__name__ if obj._payload_type else "?" 
+ name = obj._backend.stream_name if obj._backend else "unbound" + t = Text() + t.append(cls_name, style="bold cyan") + t.append("[", style="dim") + t.append(type_name, style="yellow") + t.append("]", style="dim") + t.append("(", style="dim") + t.append(f'"{name}"', style="green") + t.append(")", style="dim") + qt = query_rich(obj._query) + if qt.plain: + t.append(" | ", style="dim") + t.append_text(qt) + return t diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 9d863341a5..719d70cc43 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -29,11 +29,10 @@ import numpy as np import reactivex.operators as ops -from rich.console import Console -from rich.text import Text from dimos.types.timestamped import Timestamped +from .formatting import render_text, rich_text from .type import ( AfterFilter, AtFilter, @@ -66,14 +65,6 @@ T = TypeVar("T") R = TypeVar("R") -_console = Console(force_terminal=True, highlight=False) - - -def _render_text(text: Text) -> str: - with _console.capture() as cap: - _console.print(text, end="", soft_wrap=True) - return cap.get() - class StreamBackend(Protocol): """Backend protocol — implemented by SqliteStreamBackend etc.""" @@ -128,29 +119,11 @@ def _clone(self, **overrides: Any) -> Self: ) return clone - def _rich_text(self) -> Text: - t = Text() - cls = type(self).__name__ - type_name = self._payload_type.__name__ if self._payload_type else "?" 
- name = self._backend.stream_name if self._backend else "unbound" - t.append(cls, style="bold cyan") - t.append("[", style="dim") - t.append(type_name, style="yellow") - t.append("]", style="dim") - t.append("(", style="dim") - t.append(f'"{name}"', style="green") - t.append(")", style="dim") - query_text = self._query._rich_text() - if query_text.plain: - t.append(" | ", style="dim") - t.append_text(query_text) - return t - def __repr__(self) -> str: - return self._rich_text().plain + return rich_text(self).plain def __str__(self) -> str: - return _render_text(self._rich_text()) + return render_text(rich_text(self)) def _with_filter(self, f: Filter) -> Self: return self._clone(filters=(*self._query.filters, f)) @@ -417,12 +390,12 @@ def get_time_range(self) -> tuple[float, float]: def summary(self) -> str: from datetime import datetime, timezone - t = self._rich_text() + t = rich_text(self) n = self.count() if n == 0: t.append(": ", style="dim") t.append("empty", style="italic dim") - return _render_text(t) + return render_text(t) t0, t1 = self.get_time_range() fmt = "%Y-%m-%d %H:%M:%S" dt0 = datetime.fromtimestamp(t0, tz=timezone.utc).strftime(fmt) @@ -435,7 +408,7 @@ def summary(self) -> str: t.append(" — ", style="dim") t.append(dt1, style="bright_blue") t.append(f" ({dur:.1f}s)", style="dim yellow") - return _render_text(t) + return render_text(t) # ── Reactive ────────────────────────────────────────────────────── @@ -584,36 +557,11 @@ def __init__( self._live = live self._backfill_only = backfill_only - def _rich_text(self) -> Text: - t = Text() - type_name = self._transformer.output_type.__name__ if self._transformer.output_type else "?" 
- xf_name = type(self._transformer).__name__ - t.append("TransformStream", style="bold cyan") - t.append("[", style="dim") - t.append(type_name, style="yellow") - t.append("]", style="dim") - t.append("(", style="dim") - t.append_text(self._source._rich_text()) - t.append(" -> ", style="dim") - t.append(xf_name, style="magenta") - if self._live: - t.append(", ", style="dim") - t.append("live=True", style="yellow") - if self._backfill_only: - t.append(", ", style="dim") - t.append("backfill_only=True", style="yellow") - t.append(")", style="dim") - query_text = self._query._rich_text() - if query_text.plain: - t.append(" | ", style="dim") - t.append_text(query_text) - return t - def __repr__(self) -> str: - return self._rich_text().plain + return rich_text(self).plain def __str__(self) -> str: - return _render_text(self._rich_text()) + return render_text(rich_text(self)) def fetch(self) -> ObservationSet[R]: """Execute transform in memory, collecting results.""" @@ -780,18 +728,6 @@ def __init__( backend = ListBackend(cast("list[Observation[Any]]", observations)) super().__init__(backend=backend, session=session, payload_type=payload_type) - def _rich_text(self) -> Text: - t = Text() - type_name = self._payload_type.__name__ if self._payload_type else "?" 
- t.append("ObservationSet", style="bold cyan") - t.append("[", style="dim") - t.append(type_name, style="yellow") - t.append("]", style="dim") - t.append("(", style="dim") - t.append(f"{len(self._observations)} items", style="green") - t.append(")", style="dim") - return t - def _clone(self, **overrides: Any) -> Stream[T]: # type: ignore[override] """Downgrade to plain Stream — don't carry _observations through chaining.""" base: Stream[T] = Stream( diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index 9519322fb8..bdf39c410a 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -21,7 +21,7 @@ from collections.abc import Callable from dimos.models.embedding.base import Embedding, EmbeddingModel - from dimos.models.vl.base import Captioner, VlModel + from dimos.models.vl.base import Captioner from .stream import Stream from .type import Observation @@ -164,12 +164,13 @@ class CaptionTransformer(Transformer[Any, str]): supports_backfill: bool = True supports_live: bool = True - def __init__(self, model: Captioner) -> None: + def __init__(self, model: Captioner, *, batch_size: int = 8) -> None: self.model = model + self.batch_size = batch_size self.output_type: type | None = str def process(self, source: Stream[Any], target: Stream[str]) -> None: - for page in source.fetch_pages(): + for page in source.fetch_pages(batch_size=self.batch_size): images = [obs.data for obs in page] if not images: continue @@ -247,53 +248,3 @@ def on_append(self, obs: Observation[Any], target: Stream[Embedding]) -> None: if isinstance(emb, list): emb = emb[0] target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - - -class VLMDetectionTransformer(Transformer[Any, "Detection2DBBox"]): - """Wraps a VlModel to produce Detection2DBBox from images. - - Calls query_detections() per image, emitting one Detection2DBBox(image=None) - per bounding box with parent_id linking to the source image observation. 
- """ - - supports_backfill: bool = True - supports_live: bool = True - - def __init__(self, model: VlModel, query: str) -> None: - from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox - - self.model = model - self._query = query - self.output_type: type | None = Detection2DBBox - - def process(self, source: Stream[Any], target: Stream[Detection2DBBox]) -> None: - for page in source.fetch_pages(): - for obs in page: - self._detect(obs, target) - - def on_append(self, obs: Observation[Any], target: Stream[Detection2DBBox]) -> None: - self._detect(obs, target) - - def _detect(self, obs: Observation[Any], target: Stream[Detection2DBBox]) -> None: - from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox - - try: - result = self.model.query_detections(obs.data, self._query) - except Exception: - return - for det in result.detections: - target.append( - Detection2DBBox( - bbox=det.bbox, - track_id=det.track_id, - class_id=det.class_id, - confidence=det.confidence, - name=det.name, - ts=obs.ts or det.ts, - image=None, - ), - ts=obs.ts, - pose=obs.pose, - tags={**(obs.tags or {}), "query": self._query}, - parent_id=obs.id, - ) diff --git a/dimos/memory/type.py b/dimos/memory/type.py index ebd7308776..35982818f6 100644 --- a/dimos/memory/type.py +++ b/dimos/memory/type.py @@ -19,8 +19,6 @@ import math from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar -from rich.text import Text - from dimos.models.embedding.base import Embedding if TYPE_CHECKING: @@ -97,12 +95,6 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"after(t={self.t})" - def _rich_text(self) -> Text: - t = Text() - t.append("after", style="cyan") - t.append(f"(t={self.t})") - return t - @dataclass(frozen=True) class BeforeFilter: @@ -114,12 +106,6 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"before(t={self.t})" - def _rich_text(self) -> Text: - t = Text() - 
t.append("before", style="cyan") - t.append(f"(t={self.t})") - return t - @dataclass(frozen=True) class TimeRangeFilter: @@ -132,12 +118,6 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"time_range({self.t1}, {self.t2})" - def _rich_text(self) -> Text: - t = Text() - t.append("time_range", style="cyan") - t.append(f"({self.t1}, {self.t2})") - return t - @dataclass(frozen=True) class AtFilter: @@ -150,12 +130,6 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"at(t={self.t}, tol={self.tolerance})" - def _rich_text(self) -> Text: - t = Text() - t.append("at", style="cyan") - t.append(f"(t={self.t}, tol={self.tolerance})") - return t - @dataclass(frozen=True) class NearFilter: @@ -173,12 +147,6 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"near(radius={self.radius})" - def _rich_text(self) -> Text: - t = Text() - t.append("near", style="cyan") - t.append(f"(radius={self.radius})") - return t - @dataclass(frozen=True) class TagsFilter: @@ -191,13 +159,6 @@ def __str__(self) -> str: pairs = ", ".join(f"{k}={v!r}" for k, v in self.tags) return f"tags({pairs})" - def _rich_text(self) -> Text: - t = Text() - t.append("tags", style="cyan") - pairs = ", ".join(f"{k}={v!r}" for k, v in self.tags) - t.append(f"({pairs})") - return t - @dataclass(frozen=True) class EmbeddingSearchFilter: @@ -214,17 +175,6 @@ def __str__(self) -> str: parts.insert(0, repr(self.label)) return f"search_embedding({', '.join(parts)})" - def _rich_text(self) -> Text: - t = Text() - t.append("search_embedding", style="cyan") - t.append("(") - if self.label: - t.append(repr(self.label), style="green") - t.append(", ") - t.append(f"k={self.k}") - t.append(")") - return t - @dataclass(frozen=True) class TextSearchFilter: @@ -237,12 +187,6 @@ def matches(self, obs: Observation[Any]) -> bool: def __str__(self) -> str: return f"text({self.text!r})" - def _rich_text(self) -> Text: - t = 
Text() - t.append("text", style="cyan") - t.append(f"({self.text!r})") - return t - @dataclass(frozen=True) class LineageFilter: @@ -263,13 +207,6 @@ def __str__(self) -> str: hops = " -> ".join(self.hops) if self.hops else "direct" return f"lineage({self.source_table} -> {hops})" - def _rich_text(self) -> Text: - t = Text() - t.append("lineage", style="cyan") - hops = " -> ".join(self.hops) if self.hops else "direct" - t.append(f"({self.source_table} -> {hops})") - return t - Filter: TypeAlias = ( AfterFilter @@ -304,29 +241,3 @@ def __str__(self) -> str: if self.offset_val is not None: parts.append(f"offset({self.offset_val})") return " | ".join(parts) - - def _rich_text(self) -> Text: - t = Text() - pipe = Text(" | ", style="dim") - parts: list[Text] = [f._rich_text() for f in self.filters] - if self.order_field: - p = Text() - p.append("order", style="cyan") - direction = "desc" if self.order_desc else "asc" - p.append(f"({self.order_field}, {direction})") - parts.append(p) - if self.limit_val is not None: - p = Text() - p.append("limit", style="cyan") - p.append(f"({self.limit_val})") - parts.append(p) - if self.offset_val is not None: - p = Text() - p.append("offset", style="cyan") - p.append(f"({self.offset_val})") - parts.append(p) - for i, part in enumerate(parts): - if i > 0: - t.append_text(pipe) - t.append_text(part) - return t From 0c09d49bd64d6f27382ec2fe107aa25f923de41c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 18:26:46 +0800 Subject: [PATCH 053/118] extract rich text rendering to formatting.py, add Stream.name, fix stale tests Move all _rich_text methods from type.py and stream.py into a central formatting.py module with a single rich_text() dispatch function. Replace relative imports with absolute imports across memory/. Add Stream.name property, remove VLMDetectionTransformer tests, fix stale test assertions. 
--- dimos/memory/formatting.py | 2 +- dimos/memory/impl/test_sqlite.py | 2 +- dimos/memory/store.py | 7 +- dimos/memory/stream.py | 21 +++-- dimos/memory/test_stream_repr.py | 2 +- dimos/memory/test_transformer.py | 133 +------------------------------ dimos/memory/transformer.py | 5 +- 7 files changed, 22 insertions(+), 150 deletions(-) diff --git a/dimos/memory/formatting.py b/dimos/memory/formatting.py index 1f3c785567..9cbf44779c 100644 --- a/dimos/memory/formatting.py +++ b/dimos/memory/formatting.py @@ -114,7 +114,7 @@ def _get_dispatch() -> dict[type, Callable[..., Text]]: global _FILTER_DISPATCH if _FILTER_DISPATCH is not None: return _FILTER_DISPATCH - from .type import ( + from dimos.memory.type import ( AfterFilter, AtFilter, BeforeFilter, diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 2b1b71d021..3a51775dd0 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -739,7 +739,7 @@ def test_search_no_model_raises(self, session: SqliteSession) -> None: es = session.embedding_stream("pt_nomodel", vec_dimensions=3) es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) - with pytest.raises(TypeError, match="no model reference"): + with pytest.raises(TypeError, match="No embedding model available"): es.search_embedding("hello", k=1) def test_no_lineage_fallback(self, session: SqliteSession) -> None: diff --git a/dimos/memory/store.py b/dimos/memory/store.py index b2d3646f92..0d2f549110 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -20,12 +20,11 @@ from dimos.core.resource import Resource if TYPE_CHECKING: + from dimos.memory.stream import EmbeddingStream, Stream, TextStream + from dimos.memory.transformer import Transformer + from dimos.memory.type import PoseProvider from dimos.models.embedding.base import Embedding, EmbeddingModel - from .stream import EmbeddingStream, Stream, TextStream - from .transformer import Transformer - from .type import 
PoseProvider - T = TypeVar("T") diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 719d70cc43..eb515617aa 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -30,10 +30,8 @@ import numpy as np import reactivex.operators as ops -from dimos.types.timestamped import Timestamped - -from .formatting import render_text, rich_text -from .type import ( +from dimos.memory.formatting import render_text, rich_text +from dimos.memory.type import ( AfterFilter, AtFilter, BeforeFilter, @@ -48,6 +46,7 @@ TextSearchFilter, TimeRangeFilter, ) +from dimos.types.timestamped import Timestamped if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -56,12 +55,11 @@ from reactivex.abc import DisposableBase as Disposable from reactivex.subject import Subject + from dimos.memory.store import Session + from dimos.memory.transformer import Transformer from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.geometry_msgs.Pose import PoseLike - from .store import Session - from .transformer import Transformer - T = TypeVar("T") R = TypeVar("R") @@ -79,7 +77,9 @@ def do_append( tags: dict[str, Any] | None, parent_id: int | None = None, ) -> Observation[Any]: ... + def load_data(self, row_id: int) -> Any: ... + @property def appended_subject(self) -> Subject[Observation[Any]]: ... 
# type: ignore[type-arg] @property @@ -106,6 +106,11 @@ def __init__( self._session: Session | None = session self._payload_type: type | None = payload_type + @property + def name(self) -> str: + """The stream name in the backing store.""" + return self._require_backend().stream_name + def _clone(self, **overrides: Any) -> Self: """Return a shallow copy with updated query fields.""" q = self._query @@ -220,7 +225,7 @@ def transform( live: bool = False, backfill_only: bool = False, ) -> Stream[Any]: - from .transformer import PerItemTransformer, Transformer as TransformerABC + from dimos.memory.transformer import PerItemTransformer, Transformer as TransformerABC transformer: TransformerABC[Any, Any] if not isinstance(xf, TransformerABC): diff --git a/dimos/memory/test_stream_repr.py b/dimos/memory/test_stream_repr.py index 91a4009bd4..6d68b77bba 100644 --- a/dimos/memory/test_stream_repr.py +++ b/dimos/memory/test_stream_repr.py @@ -61,7 +61,7 @@ def test_tags_multiple(self) -> None: assert str(f) == "tags(cam='front', quality=1)" def test_embedding_search(self) -> None: - assert str(EmbeddingSearchFilter([0.1, 0.2], k=5)) == "search(k=5)" + assert str(EmbeddingSearchFilter([0.1, 0.2], k=5)) == "search_embedding(k=5)" def test_text_search(self) -> None: assert str(TextSearchFilter("error", k=None)) == "text('error')" diff --git a/dimos/memory/test_transformer.py b/dimos/memory/test_transformer.py index 87d6da134f..e48c5df268 100644 --- a/dimos/memory/test_transformer.py +++ b/dimos/memory/test_transformer.py @@ -22,10 +22,8 @@ import pytest from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import TextEmbeddingTransformer, VLMDetectionTransformer +from dimos.memory.transformer import TextEmbeddingTransformer from dimos.models.embedding.base import Embedding, EmbeddingModel -from dimos.msgs.sensor_msgs.Image import Image -from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox if TYPE_CHECKING: from 
collections.abc import Iterator @@ -131,132 +129,3 @@ def test_text_embedding_search_and_project(self, session: SqliteSession) -> None results = emb_stream.search_embedding("kitchen", k=2).project_to(logs).fetch() assert len(results) == 2 assert all("kitchen" in r.data.lower() for r in results) - - -# ── Fake VLM for detection tests ───────────────────────────────────── - - -def _make_image(width: int = 640, height: int = 480) -> Image: - """Create a simple test image.""" - return Image(np.zeros((height, width, 3), dtype=np.uint8)) - - -class _FakeImageDetections2D: - """Mimics ImageDetections2D with a .detections list.""" - - def __init__(self, detections: list[Detection2DBBox]) -> None: - self.detections = detections - - -class FakeVlModel: - """Fake VlModel that returns canned detections per call. - - Pass a list of detection lists — one per call to query_detections(). - Each call pops the next entry from the list. - """ - - def __init__( - self, detections_per_call: list[list[tuple[str, float, float, float, float]]] - ) -> None: - self._queue = list(detections_per_call) - - def query_detections(self, image: Image, query: str) -> _FakeImageDetections2D: - raw = self._queue.pop(0) if self._queue else [] - ts = image.ts - dets = [] - for i, (name, x1, y1, x2, y2) in enumerate(raw): - dets.append( - Detection2DBBox( - bbox=(x1, y1, x2, y2), - track_id=i, - class_id=-1, - confidence=0.9, - name=name, - ts=ts, - image=image, - ) - ) - return _FakeImageDetections2D(dets) - - -class TestVLMDetectionTransformer: - """Test VLM detection transformer.""" - - def test_vlm_detection_backfill(self, session: SqliteSession) -> None: - """3 images, VLM finds 1 detection per image → 3 detections with parent_id.""" - frames = session.stream("vlm_frames", Image) - frames.append(_make_image(), ts=1.0) - frames.append(_make_image(), ts=2.0) - frames.append(_make_image(), ts=3.0) - - vlm = FakeVlModel( - [ - [("bottle", 10, 20, 100, 200)], - [("bottle", 50, 60, 150, 250)], - 
[("bottle", 30, 40, 130, 230)], - ] - ) - - det_stream = frames.transform( - VLMDetectionTransformer(vlm, query="bottle") # type: ignore[arg-type] - ).store("vlm_detections", Detection2DBBox) - - results = det_stream.fetch() - assert len(results) == 3 - - # All detections have parent_id linking back to source frames - frame_ids = {obs.id for obs in frames.fetch()} - for det_obs in results: - assert det_obs.parent_id in frame_ids - assert det_obs.data.image is None # stored without image - assert det_obs.data.name == "bottle" - assert det_obs.tags["query"] == "bottle" - - def test_vlm_detection_no_matches(self, session: SqliteSession) -> None: - """VLM returns empty detections → stream stays empty.""" - frames = session.stream("vlm_empty_frames", Image) - frames.append(_make_image(), ts=1.0) - frames.append(_make_image(), ts=2.0) - - vlm = FakeVlModel([[], []]) # no detections for either call - - det_stream = frames.transform( - VLMDetectionTransformer(vlm, query="cat") # type: ignore[arg-type] - ).store("vlm_empty_dets", Detection2DBBox) - - assert det_stream.count() == 0 - - def test_vlm_detection_multiple_per_frame(self, session: SqliteSession) -> None: - """1 image → 3 detections, all share same parent_id.""" - frames = session.stream("vlm_multi_frames", Image) - frames.append(_make_image(), ts=1.0) - - vlm = FakeVlModel( - [ - [ - ("bottle", 10, 20, 100, 200), - ("cup", 200, 100, 300, 250), - ("plate", 400, 300, 500, 400), - ], - ] - ) - - det_stream = frames.transform( - VLMDetectionTransformer(vlm, query="objects") # type: ignore[arg-type] - ).store("vlm_multi_dets", Detection2DBBox) - - results = det_stream.fetch() - assert len(results) == 3 - - # All share the same parent_id (the single source frame) - parent_ids = {obs.parent_id for obs in results} - assert len(parent_ids) == 1 - - names = {obs.data.name for obs in results} - assert names == {"bottle", "cup", "plate"} - - parent_ids = {obs.parent_id for obs in results} - assert len(parent_ids) == 1 - - 
names = {obs.data.name for obs in results} - assert names == {"bottle", "cup", "plate"} diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index bdf39c410a..cf76ae1bf6 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -20,12 +20,11 @@ if TYPE_CHECKING: from collections.abc import Callable + from dimos.memory.stream import Stream + from dimos.memory.type import Observation from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.models.vl.base import Captioner - from .stream import Stream - from .type import Observation - T = TypeVar("T") R = TypeVar("R") From a954f792ba0346fdafe63ed407e0b697a2614e13 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 19:12:21 +0800 Subject: [PATCH 054/118] matching based on streams --- dimos/memory/formatting.py | 9 ++++++++- dimos/memory/stream.py | 34 ++++++++++++++++++++++++++++++++-- dimos/memory/transformer.py | 2 +- dimos/models/embedding/clip.py | 6 +++--- dimos/models/vl/florence.py | 4 ++-- 5 files changed, 46 insertions(+), 9 deletions(-) diff --git a/dimos/memory/formatting.py b/dimos/memory/formatting.py index 9cbf44779c..50a46e8429 100644 --- a/dimos/memory/formatting.py +++ b/dimos/memory/formatting.py @@ -68,7 +68,14 @@ def _at_rich(f: Any) -> Text: def _near_rich(f: Any) -> Text: t = Text() t.append("near", style="cyan") - t.append(f"(radius={f.radius})") + t.append("(") + if f.pose is not None and hasattr(f.pose, "position"): + p = f.pose.position + t.append(f"[{p.x:.1f}, {p.y:.1f}, {p.z:.1f}]", style="green") + t.append(f", radius={f.radius:.2f}") + else: + t.append(f"radius={f.radius}") + t.append(")") return t diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index eb515617aa..e478858744 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -174,12 +174,18 @@ def before(self, t: float) -> Stream[T]: def time_range(self, t1: float, t2: float) -> Stream[T]: return self._with_filter(TimeRangeFilter(t1, t2)) - 
def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: + def at(self, t: float | Stream[Any], *, tolerance: float = 1.0) -> Stream[T]: + if isinstance(t, Stream): + t1, t2 = t.get_time_range() + return self._with_filter(TimeRangeFilter(t1 - tolerance, t2 + tolerance)) return self._with_filter(AtFilter(t, tolerance)) # ── Spatial filter ──────────────────────────────────────────────── - def near(self, pose: PoseLike, radius: float) -> Stream[T]: + def near(self, pose: PoseLike | Stream[Any], radius: float) -> Stream[T]: + if isinstance(pose, Stream): + center, max_dist = pose.bounding_sphere() + return self._with_filter(NearFilter(center, max_dist + radius)) return self._with_filter(NearFilter(pose, radius)) # ── Tag filter ──────────────────────────────────────────────────── @@ -392,6 +398,30 @@ def delete(self) -> None: def get_time_range(self) -> tuple[float, float]: return (self.first().ts, self.last().ts) + def bounding_sphere(self) -> tuple[Any, float]: + """Return (centroid_pose, max_distance_from_centroid) for all poses.""" + xs: list[float] = [] + ys: list[float] = [] + zs: list[float] = [] + for obs in self: + if obs.pose is None: + continue + p = obs.pose.position + xs.append(p.x) + ys.append(p.y) + zs.append(p.z) + if not xs: + raise ValueError("No observations with poses in this stream") + cx, cy, cz = sum(xs) / len(xs), sum(ys) / len(ys), sum(zs) / len(zs) + max_dist = max( + ((x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2) ** 0.5 + for x, y, z in zip(xs, ys, zs, strict=True) + ) + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + center = PoseStamped(position=[cx, cy, cz]) + return center, max_dist + def summary(self) -> str: from datetime import datetime, timezone diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index cf76ae1bf6..6e090a41ed 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -163,7 +163,7 @@ class CaptionTransformer(Transformer[Any, str]): supports_backfill: bool 
= True supports_live: bool = True - def __init__(self, model: Captioner, *, batch_size: int = 8) -> None: + def __init__(self, model: Captioner, *, batch_size: int = 16) -> None: self.model = model self.batch_size = batch_size self.output_type: type | None = str diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py index 1b8d3e68bb..3337bfcaf6 100644 --- a/dimos/models/embedding/clip.py +++ b/dimos/models/embedding/clip.py @@ -77,9 +77,9 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: Returns embeddings as torch.Tensor on device for efficient GPU comparisons. """ with torch.inference_mode(): - inputs = self._processor(text=list(texts), return_tensors="pt", padding=True).to( - self.config.device - ) + inputs = self._processor( + text=list(texts), return_tensors="pt", padding=True, truncation=True + ).to(self.config.device) text_features = self._model.get_text_features(**inputs) if self.config.normalize: diff --git a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py index 993c573e56..19d99ccdfd 100644 --- a/dimos/models/vl/florence.py +++ b/dimos/models/vl/florence.py @@ -66,13 +66,13 @@ def _processor(self) -> AutoProcessor: self.config.model_name, trust_remote_code=self.config.trust_remote_code ) - _STRIP_PREFIXES = ("The image shows ", "The image is a ") + _STRIP_PREFIXES = ("The image shows ", "The image is a ", "A ") @staticmethod def _clean_caption(text: str) -> str: for prefix in Florence2Model._STRIP_PREFIXES: if text.startswith(prefix): - return text[len(prefix):] + return text[len(prefix) :] return text def caption(self, image: Image) -> str: From ab481713db2e0a41fb4a186aba6ba69be0d200ed Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 19:12:36 +0800 Subject: [PATCH 055/118] projection experiments --- dimos/memory/test_projection.py | 163 ++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 dimos/memory/test_projection.py diff --git 
a/dimos/memory/test_projection.py b/dimos/memory/test_projection.py new file mode 100644 index 0000000000..ae84b9184f --- /dev/null +++ b/dimos/memory/test_projection.py @@ -0,0 +1,163 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Generator + +import pytest + +from dimos.memory.impl.sqlite import SqliteSession, SqliteStore +from dimos.memory.transformer import ( + CaptionTransformer, + EmbeddingTransformer, + QualityWindowTransformer, + TextEmbeddingTransformer, +) +from dimos.models.embedding.base import Embedding +from dimos.models.embedding.clip import CLIPModel +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.data import get_data + + +@pytest.fixture(scope="module") +def store() -> Generator[SqliteStore, None, None]: + with SqliteStore(get_data("go2_bigoffice.db")) as store: + yield store + + +@pytest.fixture(scope="module") +def session(store: SqliteStore) -> Generator[SqliteSession, None, None]: + with store.session() as session: + yield session + + +@pytest.fixture(scope="module") +def image_stream(session): + return session.stream("color_image", Image) + + +@pytest.fixture(scope="module") +def lidar_stream(session): + return session.stream("lidar", PointCloud2) + + +@pytest.fixture(scope="module") +def clip() -> CLIPModel: + model = CLIPModel() + model.start() + return model + + +def test_list_streams(session): + 
print("") + for stream in session.list_streams(): + print(stream.summary()) + + +@pytest.mark.tool +def test_make_embedding(session, lidar_stream, image_stream, clip): + embeddings = ( + image_stream.transform( + QualityWindowTransformer(lambda img: img.sharpness, window=1.0), + live=False, + backfill_only=True, + ) + .store("sharp_images", Image) + .transform(EmbeddingTransformer(clip), live=False, backfill_only=True) + .store("clip_embeddings", Embedding) + ) + print(embeddings) + print(f"Stored {embeddings.count()} embeddings") + + +@pytest.mark.tool +def test_make_caption(session, clip): + from dimos.models.vl.florence import CaptionDetail, Florence2Model + + print("") + + session.streams.captions.delete() + session.streams.super_sharp_images.delete() + session.streams.caption_embeddings.delete() + + florence = Florence2Model(detail=CaptionDetail.NORMAL) + florence.start() + + super_sharp_images = session.streams.sharp_images.transform( + QualityWindowTransformer(lambda img: img.sharpness, window=3.0), + backfill_only=True, + ).store("super_sharp_images", Image) + + print(super_sharp_images.summary()) + + captions = super_sharp_images.transform(CaptionTransformer(florence), backfill_only=True).store( + "captions", str + ) + + print(captions.summary()) + + florence.stop() + + caption_embeddings = captions.transform( + TextEmbeddingTransformer(clip), backfill_only=True + ).store("caption_embeddings", Embedding) + + print(caption_embeddings.summary()) + print(f"Stored {caption_embeddings.count()} caption embeddings") + + +@pytest.mark.tool +def test_query_embeddings(session, clip): + embeddings = session.streams.clip_embeddings.search_embedding("supermarket", k=5, model=clip) + + caption_search = session.streams.captions.near(embeddings, radius=1.0) + print(caption_search) + + captions = caption_search.fetch() + + print(captions.summary()) + for obs in captions: + print(obs.data) + + images = session.streams.color_image.near(embeddings, radius=1.0).fetch() + 
print(images) + + +@pytest.mark.tool +def test_print_captions(session, clip): + for caption in session.streams.captions: + print(caption.data) + + +def test_search_embeddings(session, clip): + print("") + embedding_stream = session.embedding_stream("clip_embeddings", embedding_model=clip) + + search = embedding_stream.search_embedding("supermarket", k=5) + print(search) + + project = search.project_to(session.streams.color_image) + print(project) + + results = project.fetch() + print(results) + results = project.fetch() + print(results) + print(results) + print(results) + print(results) + print(results) + print(results) + print(results) + print(results) From a80bbb991bd61d570bc08b93def9db5b793a2aba Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 19:40:00 +0800 Subject: [PATCH 056/118] projection bugfix --- dimos/memory/impl/sqlite.py | 30 +++++++---------- dimos/memory/impl/test_sqlite.py | 4 +-- dimos/memory/test_projection.py | 58 ++++++++++++++++++++++++++++---- 3 files changed, 65 insertions(+), 27 deletions(-) diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 3241b26168..2f5692fd27 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -256,25 +256,25 @@ def _compile_query(query: StreamQuery, table: str) -> tuple[str, list[Any]]: for f in query.filters: if isinstance(f, NearFilter): - # R*Tree bounding-box join + # R*Tree bounding-box pre-filter + exact Euclidean distance in SQL joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") + p = f.pose.position + x, y, z = p.x, p.y, p.z where_parts.append( "r.min_x >= ? AND r.max_x <= ? AND " "r.min_y >= ? AND r.max_y <= ? AND " "r.min_z >= ? AND r.max_z <= ?" 
) - p = f.pose.position - x, y, z = p.x, p.y, p.z params.extend( - [ - x - f.radius, - x + f.radius, - y - f.radius, - y + f.radius, - z - f.radius, - z + f.radius, - ] + [x - f.radius, x + f.radius, y - f.radius, y + f.radius, z - f.radius, z + f.radius] + ) + # Exact spherical check so LIMIT/OFFSET work correctly + where_parts.append( + f"({table}.pose_x - ?) * ({table}.pose_x - ?) + " + f"({table}.pose_y - ?) * ({table}.pose_y - ?) + " + f"({table}.pose_z - ?) * ({table}.pose_z - ?) <= ? * ?" ) + params.extend([x, x, y, y, z, z, f.radius, f.radius]) else: sql_frag, p = _compile_filter(f, table) where_parts.append(sql_frag) @@ -464,13 +464,7 @@ def do_append( def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: sql, params = _compile_query(query, self._table) rows = self._conn.execute(sql, params).fetchall() - observations = [self._row_to_obs(r) for r in rows] - - near = _has_near_filter(query) - if near is not None: - observations = _apply_near_post_filter(observations, near) - - return observations + return [self._row_to_obs(r) for r in rows] def execute_count(self, query: StreamQuery) -> int: sql, params = _compile_count(query, self._table) diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 3a51775dd0..0886eabc44 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -167,10 +167,10 @@ def test_filter_tags(self, session: SqliteSession, images: list[Image]) -> None: assert _img_close(rows[0].data, images[0]) def test_chained_filters(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) + s = session.stream("chained_filter_data", Image) s.append(images[0], ts=1.0, tags={"cam": "front"}) s.append(images[1], ts=5.0, tags={"cam": "front"}) - s.append(images[2], ts=5.0, tags={"cam": "rear"}) + s.append(images[2], ts=6.0, tags={"cam": "rear"}) rows = s.after(3.0).filter_tags(cam="front").fetch() assert len(rows) == 1 diff --git 
a/dimos/memory/test_projection.py b/dimos/memory/test_projection.py index ae84b9184f..ac333adec4 100644 --- a/dimos/memory/test_projection.py +++ b/dimos/memory/test_projection.py @@ -25,6 +25,7 @@ ) from dimos.models.embedding.base import Embedding from dimos.models.embedding.clip import CLIPModel +from dimos.models.vl.florence import CaptionDetail, Florence2Model from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data @@ -119,21 +120,67 @@ def test_make_caption(session, clip): @pytest.mark.tool def test_query_embeddings(session, clip): + print("\n") + embeddings = session.streams.clip_embeddings.search_embedding("supermarket", k=5, model=clip) - caption_search = session.streams.captions.near(embeddings, radius=1.0) + print(embeddings) + florence = Florence2Model(detail=CaptionDetail.NORMAL) + florence.start() + + # ~600 results: + # images = session.streams.sharp_images.near(embeddings, radius=1.0).fetch() + # caption_search = images.transform( + # CaptionTransformer(florence) + # ) + + # 3 results + caption_search = session.streams.sharp_images.near(embeddings).transform( + CaptionTransformer(florence) + ) + print(caption_search) captions = caption_search.fetch() print(captions.summary()) + florence.stop() + for obs in captions: - print(obs.data) + print(obs.id, obs.data) images = session.streams.color_image.near(embeddings, radius=1.0).fetch() + print(images) +def test_count_comparison(session, clip): + """Compare fetch-then-transform vs transform-then-fetch counts.""" + print("\n") + embeddings = session.streams.clip_embeddings.search_embedding("supermarket", k=5, model=clip) + + # Count from near() directly + near_stream = session.streams.color_image.near(embeddings, radius=1.0) + fetched = near_stream.fetch() + print(f"near().fetch() count: {len(fetched)}") + + # Approach 1: fetch first, then transform with identity lambda + result1 = fetched.transform(lambda x: 
x).fetch() + print(f"fetch().transform(id).fetch() count: {len(result1)}") + + # Approach 2: transform on lazy stream, then fetch + near_stream2 = session.streams.color_image.near(embeddings, radius=1.0) + result2 = near_stream2.transform(lambda x: x).fetch() + print(f"near().transform(id).fetch() count: {len(result2)}") + + assert len(fetched) == len(result1), ( + f"fetch-then-transform mismatch: {len(fetched)} vs {len(result1)}" + ) + assert len(fetched) == len(result2), ( + f"transform-then-fetch mismatch: {len(fetched)} vs {len(result2)}" + ) + + @pytest.mark.tool def test_print_captions(session, clip): for caption in session.streams.captions: @@ -154,10 +201,7 @@ def test_search_embeddings(session, clip): print(results) results = project.fetch() print(results) + results = project.fetch() print(results) - print(results) - print(results) - print(results) - print(results) - print(results) + results = project.fetch() print(results) From c7522d3cb9616aceec6a381ca9d75b1dcb787a21 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 19:46:08 +0800 Subject: [PATCH 057/118] observationset typing fix --- dimos/memory/stream.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index e478858744..52eed600c1 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -182,7 +182,7 @@ def at(self, t: float | Stream[Any], *, tolerance: float = 1.0) -> Stream[T]: # ── Spatial filter ──────────────────────────────────────────────── - def near(self, pose: PoseLike | Stream[Any], radius: float) -> Stream[T]: + def near(self, pose: PoseLike | Stream[Any], radius: float = 0.0) -> Stream[T]: if isinstance(pose, Stream): center, max_dist = pose.bounding_sphere() return self._with_filter(NearFilter(center, max_dist + radius)) @@ -603,7 +603,11 @@ def fetch(self) -> ObservationSet[R]: collector = _CollectorStream[R]() if self._transformer.supports_backfill and not self._live: 
self._transformer.process(self._source, collector) - return ObservationSet(collector.results, session=self._source._session) + return ObservationSet( + collector.results, + session=self._source._session, + payload_type=self._transformer.output_type, + ) def store( self, From decd090ab4575354223a5932625a6c6679b33845 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 19:59:00 +0800 Subject: [PATCH 058/118] detections, cleanup --- dimos/memory/formatting.py | 2 +- dimos/memory/impl/sqlite.py | 1 - dimos/memory/stream.py | 1 + dimos/memory/test_projection.py | 52 +++++--- dimos/memory/test_stream_repr.py | 2 +- dimos/memory/test_transformer.py | 123 +++++++++++++++++- dimos/memory/transformer.py | 80 +++++++++++- .../detection/type/detection2d/bbox.py | 2 +- .../detection/type/imageDetections.py | 9 +- dimos/perception/detection/type/utils.py | 8 +- 10 files changed, 251 insertions(+), 29 deletions(-) diff --git a/dimos/memory/formatting.py b/dimos/memory/formatting.py index 50a46e8429..81145a07ba 100644 --- a/dimos/memory/formatting.py +++ b/dimos/memory/formatting.py @@ -203,7 +203,7 @@ def rich_text(obj: Any) -> Text: t.append("(", style="dim") t.append_text(rich_text(obj._source)) t.append(" -> ", style="dim") - t.append(type(xf).__name__, style="magenta") + t.append(repr(xf), style="magenta") if obj._live: t.append(", ", style="dim") t.append("live=True", style="yellow") diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 2f5692fd27..9ee8219627 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -409,7 +409,6 @@ def do_append( pose_cols = _decompose_pose(pose) tags_json = _serialize_tags(tags) - print("APPEND", payload) # Encode payload before touching the DB so a codec error can't leave # a metadata row without a matching payload row. 
payload_blob = self._codec.encode(payload) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 52eed600c1..3426cae66d 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -655,6 +655,7 @@ def append( obs: Observation[R] = Observation( id=self._next_id, ts=ts if ts is not None else time.time(), + pose=pose, tags=tags or {}, parent_id=parent_id, _data=payload, diff --git a/dimos/memory/test_projection.py b/dimos/memory/test_projection.py index ac333adec4..9e5bf07902 100644 --- a/dimos/memory/test_projection.py +++ b/dimos/memory/test_projection.py @@ -19,6 +19,7 @@ from dimos.memory.impl.sqlite import SqliteSession, SqliteStore from dimos.memory.transformer import ( CaptionTransformer, + DetectionTransformer, EmbeddingTransformer, QualityWindowTransformer, TextEmbeddingTransformer, @@ -26,6 +27,7 @@ from dimos.models.embedding.base import Embedding from dimos.models.embedding.clip import CLIPModel from dimos.models.vl.florence import CaptionDetail, Florence2Model +from dimos.models.vl.moondream import MoondreamVlModel from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data @@ -84,8 +86,6 @@ def test_make_embedding(session, lidar_stream, image_stream, clip): @pytest.mark.tool def test_make_caption(session, clip): - from dimos.models.vl.florence import CaptionDetail, Florence2Model - print("") session.streams.captions.delete() @@ -125,34 +125,49 @@ def test_query_embeddings(session, clip): embeddings = session.streams.clip_embeddings.search_embedding("supermarket", k=5, model=clip) print(embeddings) + + # we can create captions on demand florence = Florence2Model(detail=CaptionDetail.NORMAL) florence.start() - # ~600 results: - # images = session.streams.sharp_images.near(embeddings, radius=1.0).fetch() - # caption_search = images.transform( - # CaptionTransformer(florence) - # ) - - # 3 results - caption_search = 
session.streams.sharp_images.near(embeddings).transform( - CaptionTransformer(florence) + caption_query = ( + session.streams.sharp_images.near(embeddings) + .limit(5) + .transform(CaptionTransformer(florence)) ) + florence.stop() + + # we could have also searched in the db (if precomputed) + # caption_query = session.streams.captions.near(embeddings) - print(caption_search) + print(caption_query) - captions = caption_search.fetch() + captions = caption_query.fetch() print(captions.summary()) - florence.stop() for obs in captions: print(obs.id, obs.data) - images = session.streams.color_image.near(embeddings, radius=1.0).fetch() + # we can also find all images ever captured near these embeddings (600+ frames) + images = session.streams.sharp_images.near(embeddings).fetch() print(images) + moondream = MoondreamVlModel() + moondream.start() + + bottles = session.streams.sharp_images.near(embeddings, radius=1.0).transform( + DetectionTransformer(moondream, query="bottle") + ) + + print(bottles) + + for bottle in bottles.fetch(): + print(bottle.data) + + moondream.stop() + def test_count_comparison(session, clip): """Compare fetch-then-transform vs transform-then-fetch counts.""" @@ -205,3 +220,10 @@ def test_search_embeddings(session, clip): print(results) results = project.fetch() print(results) + results = project.fetch() + print(results) + results = project.fetch() + print(results) + print(results) + print(results) + print(results) diff --git a/dimos/memory/test_stream_repr.py b/dimos/memory/test_stream_repr.py index 6d68b77bba..1fa7239578 100644 --- a/dimos/memory/test_stream_repr.py +++ b/dimos/memory/test_stream_repr.py @@ -168,7 +168,7 @@ def test_transform_with_typed_transformer(self, session) -> None: print(t) assert ( repr(t) - == 'TransformStream[Embedding](Stream[int]("images") -> EmbeddingTransformer, live=True)' + == 'TransformStream[Embedding](Stream[int]("images") -> EmbeddingTransformer(MagicMock), live=True)' ) def 
test_embedding_stream_from_source(self, session) -> None: diff --git a/dimos/memory/test_transformer.py b/dimos/memory/test_transformer.py index e48c5df268..bb526a302a 100644 --- a/dimos/memory/test_transformer.py +++ b/dimos/memory/test_transformer.py @@ -22,8 +22,10 @@ import pytest from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import TextEmbeddingTransformer +from dimos.memory.transformer import DetectionTransformer, TextEmbeddingTransformer from dimos.models.embedding.base import Embedding, EmbeddingModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D if TYPE_CHECKING: from collections.abc import Iterator @@ -129,3 +131,122 @@ def test_text_embedding_search_and_project(self, session: SqliteSession) -> None results = emb_stream.search_embedding("kitchen", k=2).project_to(logs).fetch() assert len(results) == 2 assert all("kitchen" in r.data.lower() for r in results) + + +def _make_image(ts: float) -> Image: + return Image(data=np.zeros((64, 64, 3), dtype=np.uint8), ts=ts) + + +class FakeVlModel: + """Minimal VlModel stub for detection tests.""" + + def __init__( + self, + detections_per_image: int = 2, + *, + raise_on_call: bool = False, + ) -> None: + self.detections_per_image = detections_per_image + self.raise_on_call = raise_on_call + + def query_detections(self, image: Image, query: str, **kwargs: object) -> ImageDetections2D: + if self.raise_on_call: + raise RuntimeError("model error") + dets = [ + Detection2DBBox( + bbox=(10.0 * i, 10.0 * i, 20.0 * i + 20, 20.0 * i + 20), + track_id=i, + class_id=-1, + confidence=0.9, + name=query, + ts=image.ts, + image=image, + ) + for i in range(self.detections_per_image) + ] + return ImageDetections2D(image=image, detections=dets) + + +class TestDetectionTransformer: + """Test VLM detection transformer.""" + + def test_detection_backfill(self, session: SqliteSession) -> None: + """Backfill: 3 
images → transform → 3 detection observations.""" + imgs = session.stream("det_imgs", Image) + for i in range(3): + imgs.append(_make_image(float(i + 1)), ts=float(i + 1)) + + det_stream = imgs.transform(DetectionTransformer(FakeVlModel(2), "cup")).store("det_cups") + + assert det_stream.count() == 3 + results = det_stream.fetch() + for obs in results: + assert obs.data.image is None, "image should be stripped" + for det in obs.data.detections: + assert det.image is None, "detection image should be stripped" + assert obs.tags["query"] == "cup" + assert obs.tags["count"] == 2 + + def test_detection_skip_empty(self, session: SqliteSession) -> None: + """skip_empty=True (default): 0 detections → observation skipped.""" + imgs = session.stream("det_skip_imgs", Image) + imgs.append(_make_image(1.0), ts=1.0) + + det_stream = imgs.transform(DetectionTransformer(FakeVlModel(0), "nothing")).store( + "det_skip" + ) + + assert det_stream.count() == 0 + + def test_detection_keep_empty(self, session: SqliteSession) -> None: + """skip_empty=False: 0 detections → observation stored with count=0.""" + imgs = session.stream("det_keep_imgs", Image) + imgs.append(_make_image(1.0), ts=1.0) + + det_stream = imgs.transform( + DetectionTransformer(FakeVlModel(0), "nothing", skip_empty=False) + ).store("det_keep") + + assert det_stream.count() == 1 + obs = det_stream.fetch()[0] + assert obs.tags["count"] == 0 + assert len(obs.data.detections) == 0 + + def test_detection_model_error(self, session: SqliteSession) -> None: + """Model raises → observation skipped, no crash.""" + imgs = session.stream("det_err_imgs", Image) + imgs.append(_make_image(1.0), ts=1.0) + + det_stream = imgs.transform( + DetectionTransformer(FakeVlModel(raise_on_call=True), "cup") + ).store("det_err") + + assert det_stream.count() == 0 + + def test_detection_lineage(self, session: SqliteSession) -> None: + """project_to(image_stream) recovers source images.""" + imgs = session.stream("det_lin_imgs", Image) + 
imgs.append(_make_image(1.0), ts=1.0) + imgs.append(_make_image(2.0), ts=2.0) + + det_stream = imgs.transform(DetectionTransformer(FakeVlModel(1), "obj")).store("det_lin") + + projected = det_stream.project_to(imgs).fetch() + assert len(projected) == 2 + for obs in projected: + assert isinstance(obs.data, Image) + + def test_detection_live(self, session: SqliteSession) -> None: + """Live mode: append images after transform, verify reactive detection.""" + imgs = session.stream("det_live_imgs", Image) + det_stream = imgs.transform(DetectionTransformer(FakeVlModel(1), "cup"), live=True).store( + "det_live" + ) + + assert det_stream.count() == 0 + + imgs.append(_make_image(1.0), ts=1.0) + assert det_stream.count() == 1 + + imgs.append(_make_image(2.0), ts=2.0) + assert det_stream.count() == 2 diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py index 6e090a41ed..3d0e73f55f 100644 --- a/dimos/memory/transformer.py +++ b/dimos/memory/transformer.py @@ -15,6 +15,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +import logging from typing import TYPE_CHECKING, Any, Generic, TypeVar if TYPE_CHECKING: @@ -23,7 +24,10 @@ from dimos.memory.stream import Stream from dimos.memory.type import Observation from dimos.models.embedding.base import Embedding, EmbeddingModel - from dimos.models.vl.base import Captioner + from dimos.models.vl.base import Captioner, VlModel + from dimos.perception.detection.type import ImageDetections2D + +logger = logging.getLogger(__name__) T = TypeVar("T") R = TypeVar("R") @@ -36,6 +40,9 @@ class Transformer(ABC, Generic[T, R]): supports_live: bool = True output_type: type | None = None + def __repr__(self) -> str: + return type(self).__name__ + @abstractmethod def process(self, source: Stream[T], target: Stream[R]) -> None: """Batch/historical processing. 
@@ -86,6 +93,10 @@ class QualityWindowTransformer(Transformer[T, T]): def __init__(self, quality_fn: Callable[[T], float], window: float = 0.5) -> None: self._quality_fn = quality_fn self._window = window + + def __repr__(self) -> str: + fn_name = getattr(self._quality_fn, "__name__", None) or repr(self._quality_fn) + return f"QualityWindowTransformer({fn_name}, window={self._window})" # Live state self._window_start: float | None = None self._best_obs: Observation[T] | None = None @@ -168,6 +179,13 @@ def __init__(self, model: Captioner, *, batch_size: int = 16) -> None: self.batch_size = batch_size self.output_type: type | None = str + def __repr__(self) -> str: + model_name = type(self.model).__name__ + parts = [model_name] + if self.batch_size != 16: + parts.append(f"batch_size={self.batch_size}") + return f"CaptionTransformer({', '.join(parts)})" + def process(self, source: Stream[Any], target: Stream[str]) -> None: for page in source.fetch_pages(batch_size=self.batch_size): images = [obs.data for obs in page] @@ -198,6 +216,9 @@ def __init__(self, model: EmbeddingModel) -> None: self.model = model self.output_type: type | None = EmbeddingCls + def __repr__(self) -> str: + return f"TextEmbeddingTransformer({type(self.model).__name__})" + def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: for page in source.fetch_pages(): texts = [str(obs.data) for obs in page] @@ -231,6 +252,9 @@ def __init__(self, model: EmbeddingModel) -> None: self.model = model self.output_type: type | None = EmbeddingCls + def __repr__(self) -> str: + return f"EmbeddingTransformer({type(self.model).__name__})" + def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: for page in source.fetch_pages(): images = [obs.data for obs in page] @@ -247,3 +271,57 @@ def on_append(self, obs: Observation[Any], target: Stream[Embedding]) -> None: if isinstance(emb, list): emb = emb[0] target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, 
parent_id=obs.id) + + +class DetectionTransformer(Transformer[Any, "ImageDetections2D"]): + """Runs VLM object detection on images, producing ImageDetections2D. + + Strips image references from detections before storage to avoid + duplicating image data. Use project_to(image_stream) to recover + source images via lineage. + """ + + supports_backfill = True + supports_live = True + + def __init__(self, model: VlModel, query: str, *, skip_empty: bool = True) -> None: + from dimos.perception.detection.type import ImageDetections2D as IDet2D + + self.model = model + self.query = query + self.skip_empty = skip_empty + self.output_type: type | None = IDet2D + + def __repr__(self) -> str: + model_name = type(self.model).__name__ + parts = [f"{model_name}, {self.query!r}"] + if not self.skip_empty: + parts.append("skip_empty=False") + return f"DetectionTransformer({', '.join(parts)})" + + def process(self, source: Stream[Any], target: Stream[ImageDetections2D]) -> None: + for page in source.fetch_pages(): + for obs in page: + self._detect_and_append(obs, target) + + def on_append(self, obs: Observation[Any], target: Stream[ImageDetections2D]) -> None: + self._detect_and_append(obs, target) + + def _detect_and_append(self, obs: Observation[Any], target: Stream[ImageDetections2D]) -> None: + try: + detections = self.model.query_detections(obs.data, self.query) + except Exception: + logger.warning("Detection failed for obs %s, skipping", obs.id, exc_info=True) + return + + count = len(detections) + if count == 0 and self.skip_empty: + return + + # Strip image refs to avoid duplicating image data in storage + detections.image = None + for det in detections.detections: + det.image = None + + tags = {**(obs.tags or {}), "query": self.query, "count": count} + target.append(detections, ts=obs.ts, pose=obs.pose, tags=tags, parent_id=obs.id) diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py index 
45dc848e9d..6022e010cb 100644 --- a/dimos/perception/detection/type/detection2d/bbox.py +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -85,7 +85,7 @@ class Detection2DBBox(Detection2D): confidence: float name: str ts: float - image: Image + image: Image | None def to_repr_dict(self) -> dict[str, Any]: """Return a dictionary representation of the detection for display purposes.""" diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 12a1f4efb9..a3d8acebd1 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -45,12 +45,13 @@ class ImageDetections(Generic[T], TableStr): def ts(self) -> float: return self.image.ts - def __init__(self, image: Image, detections: list[T] | None = None) -> None: + def __init__(self, image: Image | None = None, detections: list[T] | None = None) -> None: self.image = image self.detections = detections or [] - for det in self.detections: - if not det.ts: - det.ts = image.ts + if image is not None: + for det in self.detections: + if not det.ts: + det.ts = image.ts def __len__(self) -> int: return len(self.detections) diff --git a/dimos/perception/detection/type/utils.py b/dimos/perception/detection/type/utils.py index eb924cbd1a..35c3909698 100644 --- a/dimos/perception/detection/type/utils.py +++ b/dimos/perception/detection/type/utils.py @@ -53,18 +53,18 @@ class TableStr: def __str__(self) -> str: console = Console(force_terminal=True, legacy_windows=False) + ts_str = f"{to_timestamp(self.image.ts):.3f}" if self.image is not None else "?" 
# type: ignore[attr-defined] + # Create a table for detections table = Table( - title=f"{self.__class__.__name__} [{len(self.detections)} detections @ {to_timestamp(self.image.ts):.3f}]", # type: ignore[attr-defined] + title=f"{self.__class__.__name__} [{len(self.detections)} detections @ {ts_str}]", # type: ignore[attr-defined] show_header=True, show_edge=True, ) # Dynamically build columns based on the first detection's dict keys if not self.detections: # type: ignore[attr-defined] - return ( - f" {self.__class__.__name__} [0 detections @ {to_timestamp(self.image.ts):.3f}]" # type: ignore[attr-defined] - ) + return f" {self.__class__.__name__} [0 detections @ {ts_str}]" # type: ignore[attr-defined] # Cache all repr_dicts to avoid double computation detection_dicts = [det.to_repr_dict() for det in self] # type: ignore[attr-defined] From f51923d36bfbbace909adf4acedd7e6492511cdc Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Sun, 8 Mar 2026 20:05:54 +0800 Subject: [PATCH 059/118] mini adjustments --- dimos/memory/test_projection.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/dimos/memory/test_projection.py b/dimos/memory/test_projection.py index 9e5bf07902..a9c55f26aa 100644 --- a/dimos/memory/test_projection.py +++ b/dimos/memory/test_projection.py @@ -127,12 +127,12 @@ def test_query_embeddings(session, clip): print(embeddings) # we can create captions on demand - florence = Florence2Model(detail=CaptionDetail.NORMAL) + florence = Florence2Model(detail=CaptionDetail.MORE_DETAILED) florence.start() caption_query = ( session.streams.sharp_images.near(embeddings) - .limit(5) + .limit(2) .transform(CaptionTransformer(florence)) ) florence.stop() @@ -150,7 +150,7 @@ def test_query_embeddings(session, clip): print(obs.id, obs.data) # we can also find all images ever captured near these embeddings (600+ frames) - images = session.streams.sharp_images.near(embeddings).fetch() + images = 
session.streams.color_image.near(embeddings).fetch() print(images) @@ -214,16 +214,3 @@ def test_search_embeddings(session, clip): results = project.fetch() print(results) - results = project.fetch() - print(results) - results = project.fetch() - print(results) - results = project.fetch() - print(results) - results = project.fetch() - print(results) - results = project.fetch() - print(results) - print(results) - print(results) - print(results) From 9edcbef06135b53acd559c1e8be2a3414f3403a8 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Mon, 9 Mar 2026 17:21:34 +0800 Subject: [PATCH 060/118] transform chaining --- dimos/memory/impl/test_sqlite.py | 8 +-- dimos/memory/stream.py | 83 +++++++++++++------------------- dimos/memory/test_memory.py | 72 +++++++++++++++++++++++++++ dimos/memory/test_projection.py | 37 +++++--------- 4 files changed, 122 insertions(+), 78 deletions(-) create mode 100644 dimos/memory/test_memory.py diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py index 0886eabc44..1df2d55d7f 100644 --- a/dimos/memory/impl/test_sqlite.py +++ b/dimos/memory/impl/test_sqlite.py @@ -886,12 +886,14 @@ def test_transform_on_observation_set( assert len(shapes) == 2 assert shapes[0].data == f"{images[0].width}x{images[0].height}" - def test_read_only(self, session: SqliteSession, images: list[Image]) -> None: + def test_append(self, session: SqliteSession, images: list[Image]) -> None: from dimos.memory.stream import ObservationSet result = ObservationSet([], session=session) - with pytest.raises(TypeError, match="read-only"): - result.append(images[0]) + obs = result.append(images[0], ts=1.0) + assert obs.id == 0 + assert obs.ts == 1.0 + assert len(result) == 1 def test_ordering_in_memory(self, session: SqliteSession, images: list[Image]) -> None: s = session.stream("obs_order", Image) diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 3426cae66d..4b044028ed 100644 --- a/dimos/memory/stream.py +++ 
b/dimos/memory/stream.py @@ -576,7 +576,7 @@ def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: class TransformStream(Stream[R]): - """In-memory stream produced by .transform(). Not yet stored.""" + """In-memory stream produced by .transform(). Backed by ListBackend.""" def __init__( self, @@ -586,11 +586,25 @@ def __init__( live: bool = False, backfill_only: bool = False, ) -> None: - super().__init__(backend=None, session=source._session) + backend = ListBackend([], name="") + super().__init__(backend=backend, session=source._session) self._source = source self._transformer = transformer self._live = live self._backfill_only = backfill_only + self._materialized = False + + def _materialize(self) -> None: + """Run backfill if not yet done.""" + if self._materialized: + return + self._materialized = True + if self._transformer.supports_backfill and not self._live: + self._transformer.process(self._source, self) + + def _require_backend(self) -> StreamBackend: + self._materialize() + return super()._require_backend() def __repr__(self) -> str: return rich_text(self).plain @@ -599,13 +613,11 @@ def __str__(self) -> str: return render_text(rich_text(self)) def fetch(self) -> ObservationSet[R]: - """Execute transform in memory, collecting results.""" - collector = _CollectorStream[R]() - if self._transformer.supports_backfill and not self._live: - self._transformer.process(self._source, collector) + self._materialize() + backend = cast("ListBackend", self._backend) return ObservationSet( - collector.results, - session=self._source._session, + cast("list[Observation[R]]", list(backend._observations)), + session=self._session, payload_type=self._transformer.output_type, ) @@ -635,42 +647,13 @@ def store( ) -class _CollectorStream(Stream[R]): - """Ephemeral stream that collects appended observations in a list.""" - - def __init__(self) -> None: - super().__init__(backend=None) - self.results: list[Observation[R]] = [] - self._next_id = 0 - - def 
append( - self, - payload: R, - *, - ts: float | None = None, - pose: PoseLike | None = None, - tags: dict[str, Any] | None = None, - parent_id: int | None = None, - ) -> Observation[R]: - obs: Observation[R] = Observation( - id=self._next_id, - ts=ts if ts is not None else time.time(), - pose=pose, - tags=tags or {}, - parent_id=parent_id, - _data=payload, - ) - self._next_id += 1 - self.results.append(obs) - return obs - - class ListBackend: """In-memory backend that evaluates StreamQuery filters in Python.""" def __init__(self, observations: list[Observation[Any]], name: str = "") -> None: self._observations = observations self._name = name + self._next_id = max((o.id for o in observations), default=-1) + 1 from reactivex.subject import Subject self._subject: Subject[Observation[Any]] = Subject() # type: ignore[type-arg] @@ -739,7 +722,18 @@ def do_append( tags: dict[str, Any] | None, parent_id: int | None = None, ) -> Observation[Any]: - raise TypeError("ObservationSet is read-only") + obs: Observation[Any] = Observation( + id=self._next_id, + ts=ts if ts is not None else time.time(), + pose=pose, + tags=tags or {}, + parent_id=parent_id, + _data=payload, + ) + self._next_id += 1 + self._observations.append(obs) + self._subject.on_next(obs) + return obs @property def appended_subject(self) -> Subject[Observation[Any]]: # type: ignore[type-arg] @@ -776,17 +770,6 @@ def _clone(self, **overrides: Any) -> Stream[T]: # type: ignore[override] base._query = self._query return base._clone(**overrides) - def append( - self, - payload: T, - *, - ts: float | None = None, - pose: PoseLike | None = None, - tags: dict[str, Any] | None = None, - parent_id: int | None = None, - ) -> Observation[T]: - raise TypeError("ObservationSet is read-only") - # ── List-like interface ────────────────────────────────────────── def __len__(self) -> int: diff --git a/dimos/memory/test_memory.py b/dimos/memory/test_memory.py new file mode 100644 index 0000000000..cfcb5230ce --- /dev/null +++ 
# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Generator

import pytest

from dimos.memory.impl.sqlite import SqliteSession, SqliteStore
from dimos.memory.transformer import (
    CaptionTransformer,
    QualityWindowTransformer,
    TextEmbeddingTransformer,
)
from dimos.models.embedding.clip import CLIPModel
from dimos.models.vl.florence import CaptionDetail, Florence2Model
from dimos.msgs.sensor_msgs.Image import Image
from dimos.utils.data import get_data


@pytest.fixture(scope="module")
def store() -> Generator[SqliteStore, None, None]:
    """Module-scoped store over the recorded go2_bigoffice dataset."""
    with SqliteStore(get_data("go2_bigoffice.db")) as store:
        yield store


@pytest.fixture(scope="module")
def session(store: SqliteStore) -> Generator[SqliteSession, None, None]:
    """Module-scoped session; closed by the store's context manager."""
    with store.session() as session:
        yield session


@pytest.fixture(scope="module")
def image_stream(session: SqliteSession):
    """The recorded color camera stream."""
    return session.stream("color_image", Image)


@pytest.fixture(scope="module")
def clip() -> Generator[CLIPModel, None, None]:
    """Started CLIP model, stopped at module teardown.

    The original returned the started model and never stopped it, leaking
    the model's resources for the rest of the test run; yielding lets
    pytest run the teardown.
    """
    model = CLIPModel()
    model.start()
    yield model
    model.stop()


def test_make_caption(session, clip):
    """Caption sharp images with Florence, then embed the captions with CLIP."""
    florence = Florence2Model(detail=CaptionDetail.NORMAL)
    florence.start()
    try:
        caption_embeddings = (
            session.streams.sharp_images.transform(
                QualityWindowTransformer(lambda img: img.sharpness, window=3.0),
            )
            .transform(CaptionTransformer(florence))
            .transform(TextEmbeddingTransformer(clip))
        )

        print(caption_embeddings)
        print(caption_embeddings.fetch().summary())
    finally:
        # Stop only after fetch(): presumably the transform chain is lazy and
        # captioning actually runs inside fetch(), which needs a live model —
        # the original stopped florence before fetching (TODO confirm against
        # the memory transformer semantics). try/finally also guarantees the
        # model is stopped if the test fails.
        florence.stop()
find all sharp images near these embeddings, then transform to detect bottles + # sharp images can be loaded from db or computed on demand, here we load from db bottles = session.streams.sharp_images.near(embeddings, radius=1.0).transform( DetectionTransformer(moondream, query="bottle") ) + # if we want to save this we'd do + # bottles.save("bottle_detections", Detection2D) + print(bottles) for bottle in bottles.fetch(): From 24c708d31cc42585bec87233e66c8c357dd77b9e Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Mon, 9 Mar 2026 18:32:51 +0800 Subject: [PATCH 061/118] memory2: lazy pull-based stream system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Greenfield rewrite of the memory module using sync generators. Every .filter(), .transform(), .map() returns a new Stream — no computation until iteration. Backends handle query application; transforms are Iterator[Obs] → Iterator[Obs]. Live mode with backpressure buffers bridges push sources to pull consumers. 
--- dimos/memory2/__init__.py | 54 +++ dimos/memory2/backend.py | 151 +++++++++ dimos/memory2/buffer.py | 224 +++++++++++++ dimos/memory2/store.py | 131 ++++++++ dimos/memory2/stream.py | 274 ++++++++++++++++ dimos/memory2/test_stream.py | 613 +++++++++++++++++++++++++++++++++++ dimos/memory2/transform.py | 89 +++++ dimos/memory2/type.py | 178 ++++++++++ 8 files changed, 1714 insertions(+) create mode 100644 dimos/memory2/__init__.py create mode 100644 dimos/memory2/backend.py create mode 100644 dimos/memory2/buffer.py create mode 100644 dimos/memory2/store.py create mode 100644 dimos/memory2/stream.py create mode 100644 dimos/memory2/test_stream.py create mode 100644 dimos/memory2/transform.py create mode 100644 dimos/memory2/type.py diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py new file mode 100644 index 0000000000..eeda6fd788 --- /dev/null +++ b/dimos/memory2/__init__.py @@ -0,0 +1,54 @@ +from dimos.memory2.backend import Backend, Disposable, ListBackend +from dimos.memory2.buffer import ( + BackpressureBuffer, + Bounded, + ClosedError, + DropNew, + KeepLast, + Unbounded, +) +from dimos.memory2.store import ListStore, Session, Store, StreamNamespace +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer +from dimos.memory2.type import ( + AfterFilter, + AtFilter, + BeforeFilter, + Filter, + NearFilter, + Observation, + PredicateFilter, + StreamQuery, + TagsFilter, + TimeRangeFilter, +) + +__all__ = [ + "AfterFilter", + "AtFilter", + "Backend", + "BackpressureBuffer", + "BeforeFilter", + "Bounded", + "ClosedError", + "Disposable", + "DropNew", + "Filter", + "FnTransformer", + "KeepLast", + "ListBackend", + "ListStore", + "NearFilter", + "Observation", + "PredicateFilter", + "QualityWindow", + "Session", + "Store", + "Stream", + "StreamNamespace", + "StreamQuery", + "TagsFilter", + "TimeRangeFilter", + "Transformer", + "Unbounded", +] diff --git a/dimos/memory2/backend.py 
b/dimos/memory2/backend.py new file mode 100644 index 0000000000..1438bd259c --- /dev/null +++ b/dimos/memory2/backend.py @@ -0,0 +1,151 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable + +from dimos.memory2.type import Observation, StreamQuery + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.buffer import BackpressureBuffer + +T = TypeVar("T") + + +class Disposable: + """Simple disposable that calls a function on dispose().""" + + def __init__(self, fn: Any) -> None: + self._fn = fn + + def dispose(self) -> None: + if self._fn is not None: + self._fn() + self._fn = None + + +@runtime_checkable +class Backend(Protocol[T]): + """Data source protocol for stored observations. + + The backend is fully responsible for applying query filters. + How it does so (SQL, R-tree, Python predicates) is its business. + """ + + @property + def name(self) -> str: ... + + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: ... + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation[T]: ... + + def count(self, query: StreamQuery) -> int: ... + + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> Disposable: ... 
+ + +class ListBackend(Generic[T]): + """In-memory backend for experimentation. Thread-safe.""" + + def __init__(self, name: str = "") -> None: + self._name = name + self._observations: list[Observation[T]] = [] + self._next_id = 0 + self._lock = threading.Lock() + self._subscribers: list[BackpressureBuffer[Observation[T]]] = [] + + @property + def name(self) -> str: + return self._name + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation[T]: + with self._lock: + obs: Observation[T] = Observation( + id=self._next_id, + ts=ts if ts is not None else time.time(), + pose=pose, + tags=tags or {}, + _data=payload, + ) + self._next_id += 1 + self._observations.append(obs) + subs = list(self._subscribers) + + # Notify outside lock to avoid deadlocks + for buf in subs: + buf.put(obs) + + return obs + + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + """Snapshot + apply all filters/ordering/offset/limit in Python.""" + with self._lock: + snapshot = list(self._observations) + + # Apply filters + for f in query.filters: + snapshot = [obs for obs in snapshot if f.matches(obs)] + + # Ordering + if query.order_field: + key = query.order_field + snapshot.sort( + key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, + reverse=query.order_desc, + ) + + # Offset + if query.offset_val: + snapshot = snapshot[query.offset_val :] + + # Limit + if query.limit_val is not None: + snapshot = snapshot[: query.limit_val] + + yield from snapshot + + def count(self, query: StreamQuery) -> int: + return sum(1 for _ in self.iterate(query)) + + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> Disposable: + with self._lock: + self._subscribers.append(buf) + + def _unsubscribe() -> None: + with self._lock: + try: + self._subscribers.remove(buf) + except ValueError: + pass + + return Disposable(_unsubscribe) diff --git 
a/dimos/memory2/buffer.py b/dimos/memory2/buffer.py new file mode 100644 index 0000000000..2669bef616 --- /dev/null +++ b/dimos/memory2/buffer.py @@ -0,0 +1,224 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import deque +import threading +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class ClosedError(Exception): + """Raised when take() is called on a closed buffer.""" + + +class BackpressureBuffer(ABC, Generic[T]): + """Thread-safe buffer between push producers and pull consumers.""" + + @abstractmethod + def put(self, item: T) -> bool: + """Push an item. Returns False if the item was dropped.""" + + @abstractmethod + def take(self, timeout: float | None = None) -> T: + """Block until an item is available. Raises ClosedError if the buffer is closed.""" + + @abstractmethod + def try_take(self) -> T | None: + """Non-blocking take. Returns None if empty.""" + + @abstractmethod + def close(self) -> None: + """Signal no more items. Subsequent take() raises ClosedError.""" + + @abstractmethod + def __len__(self) -> int: ... + + def __iter__(self): + """Yield items until the buffer is closed.""" + while True: + try: + yield self.take() + except ClosedError: + return + + +class KeepLast(BackpressureBuffer[T]): + """Single-slot buffer. put() always overwrites. 
Default for live mode.""" + + def __init__(self) -> None: + self._item: T | None = None + self._has_item = False + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._item = item + self._has_item = True + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._has_item: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + item = self._item + self._item = None + self._has_item = False + return item # type: ignore[return-value] + + def try_take(self) -> T | None: + with self._cond: + if not self._has_item: + return None + item = self._item + self._item = None + self._has_item = False + return item + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return 1 if self._has_item else 0 + + +class Bounded(BackpressureBuffer[T]): + """FIFO queue with max size. 
Drops oldest when full.""" + + def __init__(self, maxlen: int) -> None: + self._buf: deque[T] = deque(maxlen=maxlen) + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._buf.append(item) # deque(maxlen) drops oldest automatically + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) + + +class DropNew(BackpressureBuffer[T]): + """FIFO queue. Rejects new items when full (put returns False).""" + + def __init__(self, maxlen: int) -> None: + self._buf: deque[T] = deque() + self._maxlen = maxlen + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed or len(self._buf) >= self._maxlen: + return False + self._buf.append(item) + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) + + +class Unbounded(BackpressureBuffer[T]): + """Unbounded FIFO queue. 
Use carefully — can grow without limit.""" + + def __init__(self) -> None: + self._buf: deque[T] = deque() + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._buf.append(item) + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py new file mode 100644 index 0000000000..b36e845143 --- /dev/null +++ b/dimos/memory2/store.py @@ -0,0 +1,131 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, TypeVar + +from dimos.core.resource import Resource +from dimos.memory2.backend import Backend, ListBackend +from dimos.memory2.stream import Stream + +T = TypeVar("T") + + +class StreamNamespace: + """Attribute-access proxy for session streams. 
+ + Usage:: + + session.streams.image_stream + session.streams["image_stream"] + list(session.streams) + len(session.streams) + """ + + def __init__(self, session: Session) -> None: + self._session = session + + def __getattr__(self, name: str) -> Stream[Any]: + if name.startswith("_"): + raise AttributeError(name) + try: + return self._session._streams[name] + except KeyError: + available = ", ".join(self._session._streams) or "(none)" + raise AttributeError(f"No stream named {name!r}. Available: {available}") from None + + def __getitem__(self, name: str) -> Stream[Any]: + try: + return self._session._streams[name] + except KeyError: + raise KeyError(name) from None + + def __iter__(self): + return iter(self._session._streams.values()) + + def __len__(self) -> int: + return len(self._session._streams) + + def __contains__(self, name: str) -> bool: + return name in self._session._streams + + def __repr__(self) -> str: + return f"StreamNamespace({list(self._session._streams.keys())})" + + +class Session(Resource): + """A session against a store. Creates and manages named streams.""" + + def __init__(self, backend_factory: Any) -> None: # Callable[[str], Backend] + self._backend_factory = backend_factory + self._streams: dict[str, Stream[Any]] = {} + self._backends: dict[str, Backend[Any]] = {} + + def stream(self, name: str, payload_type: type[T] | None = None) -> Stream[T]: + """Get or create a named stream. 
Returns the same Stream on repeated calls.""" + if name not in self._streams: + backend = self._backend_factory(name) + self._backends[name] = backend + self._streams[name] = Stream(source=backend) + return self._streams[name] # type: ignore[return-value] + + def list_streams(self) -> list[Stream[Any]]: + return list(self._streams.values()) + + def delete_stream(self, name: str) -> None: + self._streams.pop(name, None) + self._backends.pop(name, None) + + @property + def streams(self) -> StreamNamespace: + return StreamNamespace(self) + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def __enter__(self) -> Session: + return self + + def __exit__(self, *args: object) -> None: + self.stop() + + +class Store(Resource): + """Top-level entry point — wraps a storage location.""" + + def session(self) -> Session: + raise NotImplementedError + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def __enter__(self) -> Store: + return self + + def __exit__(self, *args: object) -> None: + self.stop() + + +class ListStore(Store): + """In-memory store for experimentation.""" + + def session(self) -> Session: + return Session(backend_factory=lambda name: ListBackend(name)) diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py new file mode 100644 index 0000000000..e602df6013 --- /dev/null +++ b/dimos/memory2/stream.py @@ -0,0 +1,274 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from itertools import islice +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.memory2.backend import Backend +from dimos.memory2.buffer import BackpressureBuffer, ClosedError, KeepLast +from dimos.memory2.transform import FnTransformer, Transformer +from dimos.memory2.type import ( + AfterFilter, + AtFilter, + BeforeFilter, + Filter, + NearFilter, + Observation, + PredicateFilter, + StreamQuery, + TagsFilter, + TimeRangeFilter, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator + +T = TypeVar("T") +R = TypeVar("R") + +# Source is either a Backend or a (upstream_stream, transformer) pair. +_Source = Backend[Any] | tuple["Stream[Any]", Transformer[Any, Any]] + + +class Stream(Generic[T]): + """Lazy, pull-based stream over observations. + + Every filter/transform method returns a new Stream — no computation + happens until iteration. Backends handle query application for stored + data; transform sources apply filters as Python predicates. 
+ """ + + def __init__( + self, + source: _Source, + *, + query: StreamQuery = StreamQuery(), + _live_buf: BackpressureBuffer[Observation[Any]] | None = None, + _live_sub: Any | None = None, + ) -> None: + self._source = source + self._query = query + self._live_buf = _live_buf + self._live_sub = _live_sub # Disposable, kept alive for lifetime of stream + + # ── Iteration ─────────────────────────────────────────────────── + + def __iter__(self) -> Iterator[Observation[T]]: + return self._build_iter() + + def _build_iter(self) -> Iterator[Observation[T]]: + if isinstance(self._source, tuple): + it = self._iter_transform() + else: + # Backend handles all query application + it = self._source.iterate(self._query) + + # Live tail: after backfill exhausts, yield from live buffer + if self._live_buf is not None: + it = self._iter_with_live(it) + + return it + + def _iter_transform(self) -> Iterator[Observation[T]]: + """Iterate a transform source, applying query filters in Python.""" + upstream_stream, xf = self._source # type: ignore[misc] + it: Iterator[Observation[Any]] = xf(iter(upstream_stream)) + + # Apply filters as Python predicates + filters = self._query.filters + if filters: + it = (obs for obs in it if all(f.matches(obs) for f in filters)) + + # Sort if needed (materializes — only for finite streams) + if self._query.order_field: + key = self._query.order_field + desc = self._query.order_desc + items = sorted( + list(it), + key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, + reverse=desc, + ) + it = iter(items) + + # Offset + limit + if self._query.offset_val: + it = islice(it, self._query.offset_val, None) + if self._query.limit_val is not None: + it = islice(it, self._query.limit_val) + + return it # type: ignore[return-value] + + def _iter_with_live(self, backfill: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + """Yield backfill, then switch to live tail.""" + last_id = -1 + for obs in backfill: + last_id = 
max(last_id, obs.id) + yield obs + + # Live phase + buf = self._live_buf + assert buf is not None + filters = self._query.filters + try: + while True: + obs = buf.take() + if obs.id <= last_id: + continue + last_id = obs.id + if filters and not all(f.matches(obs) for f in filters): + continue + yield obs # type: ignore[misc] + except (ClosedError, StopIteration): + return + + # ── Query builders ────────────────────────────────────────────── + + def _replace_query(self, **overrides: Any) -> Stream[T]: + q = self._query + new_q = StreamQuery( + filters=overrides.get("filters", q.filters), + order_field=overrides.get("order_field", q.order_field), + order_desc=overrides.get("order_desc", q.order_desc), + limit_val=overrides.get("limit_val", q.limit_val), + offset_val=overrides.get("offset_val", q.offset_val), + ) + return Stream(self._source, query=new_q, _live_buf=self._live_buf, _live_sub=self._live_sub) + + def _with_filter(self, f: Filter) -> Stream[T]: + return self._replace_query(filters=(*self._query.filters, f)) + + def after(self, t: float) -> Stream[T]: + return self._with_filter(AfterFilter(t)) + + def before(self, t: float) -> Stream[T]: + return self._with_filter(BeforeFilter(t)) + + def time_range(self, t1: float, t2: float) -> Stream[T]: + return self._with_filter(TimeRangeFilter(t1, t2)) + + def at(self, t: float, tolerance: float = 1.0) -> Stream[T]: + return self._with_filter(AtFilter(t, tolerance)) + + def near(self, pose: Any, radius: float) -> Stream[T]: + return self._with_filter(NearFilter(pose, radius)) + + def filter_tags(self, **tags: Any) -> Stream[T]: + return self._with_filter(TagsFilter(tags)) + + def order_by(self, field: str, desc: bool = False) -> Stream[T]: + return self._replace_query(order_field=field, order_desc=desc) + + def limit(self, k: int) -> Stream[T]: + return self._replace_query(limit_val=k) + + def offset(self, n: int) -> Stream[T]: + return self._replace_query(offset_val=n) + + # ── Functional API 
────────────────────────────────────────────── + + def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: + """Filter by arbitrary predicate on the full Observation.""" + return self._with_filter(PredicateFilter(pred)) + + def map(self, fn: Callable[[Observation[T]], Any]) -> Stream[Any]: + """Transform each observation's data via callable.""" + return self.transform(FnTransformer(lambda obs: obs.derive(data=fn(obs)))) + + def flat_map(self, fn: Callable[[Observation[T]], Iterable[Any]]) -> Stream[Any]: + """Map that fans out — fn returns iterable of data values per observation.""" + + class _FlatMapXf(Transformer[T, Any]): + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[Any]]: + for obs in upstream: + for item in fn(obs): + yield obs.derive(data=item) + + return self.transform(_FlatMapXf()) + + # ── Transform ─────────────────────────────────────────────────── + + def transform(self, xf: Transformer[Any, Any]) -> Stream[Any]: + """Wrap this stream with a transformer. Returns a new lazy Stream. + + When iterated, calls xf(iter(self)) — pulls lazily through the chain. + """ + return Stream(source=(self, xf), query=StreamQuery()) + + # ── Live mode ─────────────────────────────────────────────────── + + def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: + """Return a stream that yields backfill then live data (infinite iterator). + + Default buffer: KeepLast(). Subscribes to the root backend BEFORE + backfill starts and deduplicates by observation id. 
+ """ + buf = buffer if buffer is not None else KeepLast() + backend = self._find_backend() + sub = backend.subscribe(buf) + return Stream(self._source, query=self._query, _live_buf=buf, _live_sub=sub) + + def _find_backend(self) -> Backend[Any]: + """Walk up the source chain to find the root Backend.""" + source = self._source + while isinstance(source, tuple): + upstream_stream, _ = source + source = upstream_stream._source + if not isinstance(source, Backend): + raise TypeError("Cannot find a backend in this stream chain") + return source + + # ── Terminals ─────────────────────────────────────────────────── + + def fetch(self) -> list[Observation[T]]: + """Materialize all observations into a list.""" + return list(self) + + def first(self) -> Observation[T]: + """Return the first matching observation.""" + it = iter(self.limit(1)) + try: + return next(it) + except StopIteration: + raise LookupError("No matching observation") from None + + def last(self) -> Observation[T]: + """Return the last matching observation (by timestamp).""" + return self.order_by("ts", desc=True).first() + + def count(self) -> int: + """Count matching observations.""" + if isinstance(self._source, Backend) and not isinstance(self._source, tuple): + return self._source.count(self._query) + return sum(1 for _ in self) + + def exists(self) -> bool: + """Check if any matching observation exists.""" + return next(iter(self.limit(1)), None) is not None + + # ── Write ─────────────────────────────────────────────────────── + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation[T]: + """Append to the backing store. Only works if source is a Backend.""" + if isinstance(self._source, tuple): + raise TypeError("Cannot append to a transform stream. 
Append to the source stream.") + return self._source.append(payload, ts=ts, pose=pose, tags=tags) diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py new file mode 100644 index 0000000000..db3c214394 --- /dev/null +++ b/dimos/memory2/test_stream.py @@ -0,0 +1,613 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""memory2 stream tests — serves as living documentation of the lazy stream API. + +Each test demonstrates a specific capability with clear setup, action, and assertion. +""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from dimos.memory2.backend import ListBackend +from dimos.memory2.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded +from dimos.memory2.store import ListStore +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer +from dimos.memory2.type import Observation + +# ── Helpers ────────────────────────────────────────────────────────── + + +def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: + """Create a ListBackend stream with n integer observations at 1-second intervals.""" + backend = ListBackend[int]("test") + for i in range(n): + backend.append(i * 10, ts=start_ts + i) + return Stream(source=backend) + + +# ═══════════════════════════════════════════════════════════════════ +# 1. 
Basic iteration +# ═══════════════════════════════════════════════════════════════════ + + +class TestBasicIteration: + """Streams are lazy iterables — nothing runs until you iterate.""" + + def test_iterate_yields_all_observations(self): + stream = make_stream(5) + obs = list(stream) + assert len(obs) == 5 + assert [o.data for o in obs] == [0, 10, 20, 30, 40] + + def test_iterate_preserves_timestamps(self): + stream = make_stream(3, start_ts=100.0) + assert [o.ts for o in stream] == [100.0, 101.0, 102.0] + + def test_empty_stream(self): + stream = make_stream(0) + assert list(stream) == [] + + def test_fetch_materializes_to_list(self): + result = make_stream(3).fetch() + assert isinstance(result, list) + assert len(result) == 3 + + def test_stream_is_reiterable(self): + """Same stream can be iterated multiple times — each time re-queries.""" + stream = make_stream(3) + first = [o.data for o in stream] + second = [o.data for o in stream] + assert first == second == [0, 10, 20] + + +# ═══════════════════════════════════════════════════════════════════ +# 2. 
Temporal filters +# ═══════════════════════════════════════════════════════════════════ + + +class TestTemporalFilters: + """Temporal filters constrain observations by timestamp.""" + + def test_after(self): + """.after(t) keeps observations with ts > t.""" + result = make_stream(5).after(2.0).fetch() + assert [o.ts for o in result] == [3.0, 4.0] + + def test_before(self): + """.before(t) keeps observations with ts < t.""" + result = make_stream(5).before(2.0).fetch() + assert [o.ts for o in result] == [0.0, 1.0] + + def test_time_range(self): + """.time_range(t1, t2) keeps t1 <= ts <= t2.""" + result = make_stream(5).time_range(1.0, 3.0).fetch() + assert [o.ts for o in result] == [1.0, 2.0, 3.0] + + def test_at_with_tolerance(self): + """.at(t, tolerance) keeps observations within tolerance of t.""" + result = make_stream(5).at(2.0, tolerance=0.5).fetch() + assert [o.ts for o in result] == [2.0] + + def test_chained_temporal_filters(self): + """Filters compose — each narrows the result.""" + result = make_stream(10).after(2.0).before(7.0).fetch() + assert [o.ts for o in result] == [3.0, 4.0, 5.0, 6.0] + + +# ═══════════════════════════════════════════════════════════════════ +# 3. 
Spatial filter +# ═══════════════════════════════════════════════════════════════════ + + +class TestSpatialFilter: + """.near(pose, radius) filters by Euclidean distance.""" + + def test_near_with_tuples(self): + backend = ListBackend[str]("spatial") + backend.append("origin", ts=0.0, pose=(0, 0, 0)) + backend.append("close", ts=1.0, pose=(1, 1, 0)) + backend.append("far", ts=2.0, pose=(10, 10, 10)) + stream = Stream(source=backend) + + result = stream.near((0, 0, 0), radius=2.0).fetch() + assert [o.data for o in result] == ["origin", "close"] + + def test_near_excludes_no_pose(self): + backend = ListBackend[str]("spatial") + backend.append("no_pose", ts=0.0) + backend.append("has_pose", ts=1.0, pose=(0, 0, 0)) + stream = Stream(source=backend) + + result = stream.near((0, 0, 0), radius=10.0).fetch() + assert [o.data for o in result] == ["has_pose"] + + +# ═══════════════════════════════════════════════════════════════════ +# 4. Tags filter +# ═══════════════════════════════════════════════════════════════════ + + +class TestTagsFilter: + """.filter_tags() matches on observation metadata.""" + + def test_filter_by_tag(self): + backend = ListBackend[str]("tagged") + backend.append("cat", ts=0.0, tags={"type": "animal", "legs": 4}) + backend.append("car", ts=1.0, tags={"type": "vehicle", "wheels": 4}) + backend.append("dog", ts=2.0, tags={"type": "animal", "legs": 4}) + stream = Stream(source=backend) + + result = stream.filter_tags(type="animal").fetch() + assert [o.data for o in result] == ["cat", "dog"] + + def test_filter_multiple_tags(self): + backend = ListBackend[str]("tagged") + backend.append("a", ts=0.0, tags={"x": 1, "y": 2}) + backend.append("b", ts=1.0, tags={"x": 1, "y": 3}) + stream = Stream(source=backend) + + result = stream.filter_tags(x=1, y=2).fetch() + assert [o.data for o in result] == ["a"] + + +# ═══════════════════════════════════════════════════════════════════ +# 5. 
Ordering, limit, offset +# ═══════════════════════════════════════════════════════════════════ + + +class TestOrderLimitOffset: + def test_limit(self): + result = make_stream(10).limit(3).fetch() + assert len(result) == 3 + + def test_offset(self): + result = make_stream(5).offset(2).fetch() + assert [o.data for o in result] == [20, 30, 40] + + def test_limit_and_offset(self): + result = make_stream(10).offset(2).limit(3).fetch() + assert [o.data for o in result] == [20, 30, 40] + + def test_order_by_ts_desc(self): + result = make_stream(5).order_by("ts", desc=True).fetch() + assert [o.ts for o in result] == [4.0, 3.0, 2.0, 1.0, 0.0] + + def test_first(self): + obs = make_stream(5).first() + assert obs.data == 0 + + def test_last(self): + obs = make_stream(5).last() + assert obs.data == 40 + + def test_first_empty_raises(self): + with pytest.raises(LookupError): + make_stream(0).first() + + def test_count(self): + assert make_stream(5).count() == 5 + assert make_stream(5).after(2.0).count() == 2 + + def test_exists(self): + assert make_stream(5).exists() + assert not make_stream(0).exists() + assert not make_stream(5).after(100.0).exists() + + +# ═══════════════════════════════════════════════════════════════════ +# 6. 
Functional API: .filter(), .map() +# ═══════════════════════════════════════════════════════════════════ + + +class TestFunctionalAPI: + """Functional combinators receive the full Observation.""" + + def test_filter_with_predicate(self): + """.filter() takes a predicate on the full Observation.""" + result = make_stream(5).filter(lambda obs: obs.data > 20).fetch() + assert [o.data for o in result] == [30, 40] + + def test_filter_on_metadata(self): + """Predicates can access ts, tags, pose — not just data.""" + result = make_stream(5).filter(lambda obs: obs.ts % 2 == 0).fetch() + assert [o.ts for o in result] == [0.0, 2.0, 4.0] + + def test_map(self): + """.map() transforms each observation's data.""" + result = make_stream(3).map(lambda obs: obs.data * 2).fetch() + assert [o.data for o in result] == [0, 20, 40] + + def test_map_preserves_ts(self): + result = make_stream(3).map(lambda obs: str(obs.data)).fetch() + assert [o.ts for o in result] == [0.0, 1.0, 2.0] + assert [o.data for o in result] == ["0", "10", "20"] + + def test_flat_map(self): + """.flat_map() fans out — fn returns iterable of values per obs.""" + result = make_stream(3).flat_map(lambda obs: [obs.data, obs.data + 1]).fetch() + assert [o.data for o in result] == [0, 1, 10, 11, 20, 21] + + +# ═══════════════════════════════════════════════════════════════════ +# 7. 
Transform chaining +# ═══════════════════════════════════════════════════════════════════ + + +class TestTransformChaining: + """Transforms chain lazily — each obs flows through the full pipeline.""" + + def test_single_transform(self): + xf = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) + result = make_stream(3).transform(xf).fetch() + assert [o.data for o in result] == [1, 11, 21] + + def test_chained_transforms(self): + """stream.transform(A).transform(B) — B pulls from A which pulls from source.""" + add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + + result = make_stream(3).transform(add_one).transform(double).fetch() + # (0+1)*2=2, (10+1)*2=22, (20+1)*2=42 + assert [o.data for o in result] == [2, 22, 42] + + def test_transform_can_skip(self): + """Returning None from a transformer skips that observation.""" + keep_even = FnTransformer(lambda obs: obs if obs.data % 20 == 0 else None) + result = make_stream(5).transform(keep_even).fetch() + assert [o.data for o in result] == [0, 20, 40] + + def test_transform_filter_transform(self): + """stream.transform(A).near(pose).transform(B) — filter between transforms.""" + backend = ListBackend[int]("tfft") + backend.append(1, ts=0.0, pose=(0, 0, 0)) + backend.append(2, ts=1.0, pose=(100, 100, 100)) + backend.append(3, ts=2.0, pose=(1, 0, 0)) + stream = Stream(source=backend) + + add_ten = FnTransformer(lambda obs: obs.derive(data=obs.data + 10)) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + + result = ( + stream.transform(add_ten) # 11, 12, 13 + .near((0, 0, 0), 5.0) # keeps pose at (0,0,0) and (1,0,0) + .transform(double) # 22, 26 + .fetch() + ) + assert [o.data for o in result] == [22, 26] + + def test_quality_window(self): + """QualityWindow keeps the best item per time window.""" + backend = ListBackend[float]("qw") + # Window 1: ts 0.0-0.9 → best quality + backend.append(0.3, ts=0.0) + 
backend.append(0.9, ts=0.3) # best in window + backend.append(0.1, ts=0.7) + # Window 2: ts 1.0-1.9 + backend.append(0.5, ts=1.0) + backend.append(0.8, ts=1.5) # best in window + # Window 3: ts 2.0+ (emitted at end via flush) + backend.append(0.6, ts=2.2) + stream = Stream(source=backend) + + xf = QualityWindow(quality_fn=lambda v: v, window=1.0) + result = stream.transform(xf).fetch() + assert [o.data for o in result] == [0.9, 0.8, 0.6] + + def test_streaming_not_buffering(self): + """Transforms process lazily — early limit stops pulling from source.""" + calls = [] + + class CountingXf(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + calls.append(obs.data) + yield obs + + result = make_stream(100).transform(CountingXf()).limit(3).fetch() + assert len(result) == 3 + # The transformer should have processed at most a few more than 3 + # (not all 100) due to lazy evaluation + assert len(calls) == 3 + + +# ═══════════════════════════════════════════════════════════════════ +# 8. 
Store & Session +# ═══════════════════════════════════════════════════════════════════ + + +class TestStoreSession: + """Store -> Session -> Stream hierarchy for named streams.""" + + def test_basic_session(self): + store = ListStore() + with store.session() as session: + images = session.stream("images") + images.append("frame1", ts=0.0) + images.append("frame2", ts=1.0) + assert images.count() == 2 + + def test_same_stream_on_repeated_calls(self): + store = ListStore() + with store.session() as session: + s1 = session.stream("images") + s2 = session.stream("images") + assert s1 is s2 + + def test_stream_namespace(self): + store = ListStore() + with store.session() as session: + session.stream("images") + session.stream("lidar") + assert "images" in session.streams + assert len(session.streams) == 2 + assert session.streams.images is session.stream("images") + assert session.streams["lidar"] is session.stream("lidar") + + def test_namespace_missing_raises(self): + store = ListStore() + with store.session() as session: + with pytest.raises(AttributeError, match="No stream named"): + _ = session.streams.nonexistent + + def test_delete_stream(self): + store = ListStore() + with store.session() as session: + session.stream("temp") + session.delete_stream("temp") + assert "temp" not in session.streams + + +# ═══════════════════════════════════════════════════════════════════ +# 9. 
Lazy data loading +# ═══════════════════════════════════════════════════════════════════ + + +class TestLazyData: + """Observation.data supports lazy loading with cleanup.""" + + def test_eager_data(self): + """In-memory observations have data set directly — zero-cost access.""" + obs = Observation(id=0, ts=0.0, _data="hello") + assert obs.data == "hello" + + def test_lazy_loading(self): + """Data loaded on first access, loader released after.""" + load_count = 0 + + def loader(): + nonlocal load_count + load_count += 1 + return "loaded" + + obs = Observation(id=0, ts=0.0, _loader=loader) + assert load_count == 0 + assert obs.data == "loaded" + assert load_count == 1 + assert obs._loader is None # released + assert obs.data == "loaded" # cached, no second load + assert load_count == 1 + + def test_no_data_no_loader_raises(self): + obs = Observation(id=0, ts=0.0) + with pytest.raises(LookupError): + _ = obs.data + + def test_derive_preserves_metadata(self): + obs = Observation(id=42, ts=1.5, pose=(1, 2, 3), tags={"k": "v"}, _data="original") + derived = obs.derive(data="transformed") + assert derived.id == 42 + assert derived.ts == 1.5 + assert derived.pose == (1, 2, 3) + assert derived.tags == {"k": "v"} + assert derived.data == "transformed" + + +# ═══════════════════════════════════════════════════════════════════ +# 10. 
Backpressure buffers +# ═══════════════════════════════════════════════════════════════════ + + +class TestBackpressureBuffers: + """Thread-safe buffers bridging push sources to pull consumers.""" + + def test_keep_last_overwrites(self): + buf = KeepLast[int]() + buf.put(1) + buf.put(2) + buf.put(3) + assert buf.take() == 3 + assert len(buf) == 0 + + def test_bounded_drops_oldest(self): + buf = Bounded[int](maxlen=2) + buf.put(1) + buf.put(2) + buf.put(3) # drops 1 + assert buf.take() == 2 + assert buf.take() == 3 + + def test_drop_new_rejects(self): + buf = DropNew[int](maxlen=2) + assert buf.put(1) is True + assert buf.put(2) is True + assert buf.put(3) is False # rejected + assert buf.take() == 1 + assert buf.take() == 2 + + def test_unbounded_keeps_all(self): + buf = Unbounded[int]() + for i in range(100): + buf.put(i) + assert len(buf) == 100 + + def test_close_signals_end(self): + buf = KeepLast[int]() + buf.close() + with pytest.raises(ClosedError): + buf.take() + + def test_buffer_is_iterable(self): + """Iterating a buffer yields items until closed.""" + buf = Unbounded[int]() + buf.put(1) + buf.put(2) + buf.close() + assert list(buf) == [1, 2] + + def test_take_blocks_until_put(self): + buf = KeepLast[int]() + result = [] + + def producer(): + time.sleep(0.05) + buf.put(42) + + t = threading.Thread(target=producer) + t.start() + result.append(buf.take(timeout=2.0)) + t.join() + assert result == [42] + + +# ═══════════════════════════════════════════════════════════════════ +# 11. 
Live mode +# ═══════════════════════════════════════════════════════════════════ + + +class TestLiveMode: + """Live streams yield backfill then block for new observations.""" + + def test_live_sees_backfill_then_new(self): + """Backfill first, then live appends come through.""" + backend = ListBackend[str]("live") + backend.append("old", ts=0.0) + stream = Stream(source=backend) + live = stream.live(buffer=Unbounded()) + + # Start consuming in a thread + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append("new1", ts=1.0) + backend.append("new2", ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == ["old", "new1", "new2"] + + def test_live_with_filter(self): + """Filters apply to live data — non-matching obs are dropped silently.""" + backend = ListBackend[int]("live_filter") + stream = Stream(source=backend) + live = stream.after(5.0).live(buffer=Unbounded()) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append(1, ts=1.0) # filtered out (ts <= 5.0) + backend.append(2, ts=6.0) # passes + backend.append(3, ts=3.0) # filtered out + backend.append(4, ts=10.0) # passes + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == [2, 4] + + def test_live_deduplicates_backfill_overlap(self): + """Observations seen in backfill are not re-yielded from the live buffer.""" + backend = ListBackend[str]("dedup") + backend.append("backfill", ts=0.0) + stream = Stream(source=backend) + live = stream.live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + 
results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append("live1", ts=1.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == ["backfill", "live1"] + + def test_live_with_keep_last_backpressure(self): + """KeepLast drops intermediate values when consumer is slow.""" + backend = ListBackend[int]("bp") + stream = Stream(source=backend) + live = stream.live(buffer=KeepLast()) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if obs.data >= 90: + consumed.set() + return + time.sleep(0.1) # slow consumer + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + # Rapid producer — KeepLast will drop most of these + for i in range(100): + backend.append(i, ts=float(i)) + time.sleep(0.001) + + consumed.wait(timeout=5.0) + t.join(timeout=2.0) + # Should have far fewer than 100 results due to KeepLast + assert len(results) < 50 + # Last result should be near the end + assert results[-1] >= 90 diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py new file mode 100644 index 0000000000..7fca2ab3e2 --- /dev/null +++ b/dimos/memory2/transform.py @@ -0,0 +1,89 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from dimos.memory2.type import Observation + +T = TypeVar("T") +R = TypeVar("R") + + +class Transformer(ABC, Generic[T, R]): + """Transforms a stream of observations lazily via iterator -> iterator. + + Pull from upstream, yield transformed observations. Naturally supports + batching, windowing, fan-out. No flush() needed — the generator cleans + up when upstream exhausts. + """ + + @abstractmethod + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: ... + + +class FnTransformer(Transformer[T, R]): + """Wraps a callable that receives an Observation and returns a new one (or None to skip).""" + + def __init__(self, fn: Callable[[Observation[T]], Observation[R] | None]) -> None: + self._fn = fn + + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: + fn = self._fn + for obs in upstream: + result = fn(obs) + if result is not None: + yield result + + +class QualityWindow(Transformer[T, T]): + """Keeps the highest-quality item per time window. + + Emits the best observation when the window advances. The last window + is emitted when the upstream iterator exhausts — no flush needed. 
+ """ + + def __init__(self, quality_fn: Callable[[Any], float], window: float) -> None: + self._quality_fn = quality_fn + self._window = window + + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + quality_fn = self._quality_fn + window = self._window + best: Observation[T] | None = None + best_score: float = -1.0 + window_start: float | None = None + + for obs in upstream: + if window_start is not None and (obs.ts - window_start) >= window: + if best is not None: + yield best + best = None + best_score = -1.0 + window_start = obs.ts + + score = quality_fn(obs.data) + if score > best_score: + best = obs + best_score = score + if window_start is None: + window_start = obs.ts + + if best is not None: + yield best diff --git a/dimos/memory2/type.py b/dimos/memory2/type.py new file mode 100644 index 0000000000..14eb26ba34 --- /dev/null +++ b/dimos/memory2/type.py @@ -0,0 +1,178 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Generic, Protocol, TypeVar, runtime_checkable + +T = TypeVar("T") +R = TypeVar("R") + + +# ── Filter protocol ───────────────────────────────────────────────── + + +@runtime_checkable +class Filter(Protocol): + """Any object with a .matches(obs) -> bool method can be a filter.""" + + def matches(self, obs: Observation[Any]) -> bool: ... 
+ + +# ── Lazy data sentinel ────────────────────────────────────────────── + + +class _Unloaded: + """Sentinel indicating data has not been loaded yet.""" + + __slots__ = () + + def __repr__(self) -> str: + return "" + + +_UNLOADED = _Unloaded() + + +# ── Observation ───────────────────────────────────────────────────── + + +@dataclass +class Observation(Generic[T]): + """A single timestamped observation with optional spatial pose and metadata.""" + + id: int + ts: float + pose: Any | None = None + tags: dict[str, Any] = field(default_factory=dict) + _data: T | _Unloaded = field(default=_UNLOADED, repr=False) + _loader: Any | None = field(default=None, repr=False) # Callable[[], T] + + @property + def data(self) -> T: + if isinstance(self._data, _Unloaded): + if self._loader is None: + raise LookupError("No data and no loader set on this observation") + self._data = self._loader() + self._loader = None # release closure + return self._data # type: ignore[return-value] + + def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: + """Create a new observation preserving ts/pose/tags, replacing data.""" + return Observation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + ) + + +# ── Filters ───────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class AfterFilter: + t: float + + def matches(self, obs: Observation[Any]) -> bool: + return obs.ts > self.t + + +@dataclass(frozen=True) +class BeforeFilter: + t: float + + def matches(self, obs: Observation[Any]) -> bool: + return obs.ts < self.t + + +@dataclass(frozen=True) +class TimeRangeFilter: + t1: float + t2: float + + def matches(self, obs: Observation[Any]) -> bool: + return self.t1 <= obs.ts <= self.t2 + + +@dataclass(frozen=True) +class AtFilter: + t: float + tolerance: float = 1.0 + + def matches(self, obs: Observation[Any]) -> bool: + return abs(obs.ts - self.t) <= 
self.tolerance + + +@dataclass(frozen=True) +class NearFilter: + pose: Any + radius: float + + def matches(self, obs: Observation[Any]) -> bool: + if obs.pose is None or self.pose is None: + return False + p1 = self.pose + p2 = obs.pose + # Support both raw (x,y,z) tuples and PoseStamped objects + if hasattr(p1, "position"): + p1 = p1.position + if hasattr(p2, "position"): + p2 = p2.position + x1, y1, z1 = _xyz(p1) + x2, y2, z2 = _xyz(p2) + dist_sq = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 + return dist_sq <= self.radius**2 + + +def _xyz(p: Any) -> tuple[float, float, float]: + """Extract (x, y, z) from various pose representations.""" + if isinstance(p, (list, tuple)): + return (float(p[0]), float(p[1]), float(p[2]) if len(p) > 2 else 0.0) + return (float(p.x), float(p.y), float(getattr(p, "z", 0.0))) + + +@dataclass(frozen=True) +class TagsFilter: + tags: dict[str, Any] + + def matches(self, obs: Observation[Any]) -> bool: + for k, v in self.tags.items(): + if obs.tags.get(k) != v: + return False + return True + + +@dataclass(frozen=True) +class PredicateFilter: + """Wraps an arbitrary predicate function for use with .filter().""" + + fn: Any # Callable[[Observation], bool] — Any to keep frozen hashable + + def matches(self, obs: Observation[Any]) -> bool: + return self.fn(obs) + + +# ── StreamQuery ───────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class StreamQuery: + filters: tuple[Filter, ...] 
= () + order_field: str | None = None + order_desc: bool = False + limit_val: int | None = None + offset_val: int | None = None From 363f094ae7bb7d899e1a3ea04f7521d3ddd162f4 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Mon, 9 Mar 2026 18:39:32 +0800 Subject: [PATCH 062/118] =?UTF-8?q?memory2:=20fix=20typing=20=E2=80=94=20z?= =?UTF-8?q?ero=20type:ignore,=20proper=20generics?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Closed → ClosedError (N818) - Callable types for _loader, Disposable.fn, backend_factory, PredicateFilter.fn - Disposable typed in stream._live_sub - assert+narrowing instead of type:ignore in KeepLast.take, _iter_transform - cast only in Session.stream (unavoidable generic cache lookup) --- dimos/memory2/backend.py | 6 +++--- dimos/memory2/buffer.py | 10 +++++++--- dimos/memory2/store.py | 11 +++++++---- dimos/memory2/stream.py | 13 +++++++------ dimos/memory2/type.py | 20 +++++++++++++------- 5 files changed, 37 insertions(+), 23 deletions(-) diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index 1438bd259c..f95cb0c43e 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -21,7 +21,7 @@ from dimos.memory2.type import Observation, StreamQuery if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Iterator from dimos.memory2.buffer import BackpressureBuffer @@ -31,8 +31,8 @@ class Disposable: """Simple disposable that calls a function on dispose().""" - def __init__(self, fn: Any) -> None: - self._fn = fn + def __init__(self, fn: Callable[[], None]) -> None: + self._fn: Callable[[], None] | None = fn def dispose(self) -> None: if self._fn is not None: diff --git a/dimos/memory2/buffer.py b/dimos/memory2/buffer.py index 2669bef616..de122f3d68 100644 --- a/dimos/memory2/buffer.py +++ b/dimos/memory2/buffer.py @@ -17,7 +17,10 @@ from abc import ABC, abstractmethod from collections import deque import threading -from typing 
import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterator T = TypeVar("T") @@ -48,7 +51,7 @@ def close(self) -> None: @abstractmethod def __len__(self) -> int: ... - def __iter__(self): + def __iter__(self) -> Iterator[T]: """Yield items until the buffer is closed.""" while True: try: @@ -83,9 +86,10 @@ def take(self, timeout: float | None = None) -> T: if not self._cond.wait(timeout): raise TimeoutError("take() timed out") item = self._item + assert item is not None self._item = None self._has_item = False - return item # type: ignore[return-value] + return item def try_take(self) -> T | None: with self._cond: diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index b36e845143..e8e5340a14 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -14,12 +14,15 @@ from __future__ import annotations -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from dimos.core.resource import Resource from dimos.memory2.backend import Backend, ListBackend from dimos.memory2.stream import Stream +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + T = TypeVar("T") @@ -52,7 +55,7 @@ def __getitem__(self, name: str) -> Stream[Any]: except KeyError: raise KeyError(name) from None - def __iter__(self): + def __iter__(self) -> Iterator[Stream[Any]]: return iter(self._session._streams.values()) def __len__(self) -> int: @@ -68,7 +71,7 @@ def __repr__(self) -> str: class Session(Resource): """A session against a store. 
Creates and manages named streams.""" - def __init__(self, backend_factory: Any) -> None: # Callable[[str], Backend] + def __init__(self, backend_factory: Callable[[str], Backend[Any]]) -> None: self._backend_factory = backend_factory self._streams: dict[str, Stream[Any]] = {} self._backends: dict[str, Backend[Any]] = {} @@ -79,7 +82,7 @@ def stream(self, name: str, payload_type: type[T] | None = None) -> Stream[T]: backend = self._backend_factory(name) self._backends[name] = backend self._streams[name] = Stream(source=backend) - return self._streams[name] # type: ignore[return-value] + return cast("Stream[T]", self._streams[name]) def list_streams(self) -> list[Stream[Any]]: return list(self._streams.values()) diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index e602df6013..278197e4b9 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -17,7 +17,7 @@ from itertools import islice from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.backend import Backend +from dimos.memory2.backend import Backend, Disposable from dimos.memory2.buffer import BackpressureBuffer, ClosedError, KeepLast from dimos.memory2.transform import FnTransformer, Transformer from dimos.memory2.type import ( @@ -57,7 +57,7 @@ def __init__( *, query: StreamQuery = StreamQuery(), _live_buf: BackpressureBuffer[Observation[Any]] | None = None, - _live_sub: Any | None = None, + _live_sub: Disposable | None = None, ) -> None: self._source = source self._query = query @@ -84,8 +84,9 @@ def _build_iter(self) -> Iterator[Observation[T]]: def _iter_transform(self) -> Iterator[Observation[T]]: """Iterate a transform source, applying query filters in Python.""" - upstream_stream, xf = self._source # type: ignore[misc] - it: Iterator[Observation[Any]] = xf(iter(upstream_stream)) + assert isinstance(self._source, tuple) + upstream_stream, xf = self._source + it: Iterator[Observation[T]] = xf(iter(upstream_stream)) # Apply filters as Python predicates 
filters = self._query.filters @@ -109,7 +110,7 @@ def _iter_transform(self) -> Iterator[Observation[T]]: if self._query.limit_val is not None: it = islice(it, self._query.limit_val) - return it # type: ignore[return-value] + return it def _iter_with_live(self, backfill: Iterator[Observation[T]]) -> Iterator[Observation[T]]: """Yield backfill, then switch to live tail.""" @@ -130,7 +131,7 @@ def _iter_with_live(self, backfill: Iterator[Observation[T]]) -> Iterator[Observ last_id = obs.id if filters and not all(f.matches(obs) for f in filters): continue - yield obs # type: ignore[misc] + yield obs except (ClosedError, StopIteration): return diff --git a/dimos/memory2/type.py b/dimos/memory2/type.py index 14eb26ba34..22a974c11f 100644 --- a/dimos/memory2/type.py +++ b/dimos/memory2/type.py @@ -15,7 +15,10 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Generic, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Callable T = TypeVar("T") R = TypeVar("R") @@ -58,16 +61,19 @@ class Observation(Generic[T]): pose: Any | None = None tags: dict[str, Any] = field(default_factory=dict) _data: T | _Unloaded = field(default=_UNLOADED, repr=False) - _loader: Any | None = field(default=None, repr=False) # Callable[[], T] + _loader: Callable[[], T] | None = field(default=None, repr=False) @property def data(self) -> T: - if isinstance(self._data, _Unloaded): + val = self._data + if isinstance(val, _Unloaded): if self._loader is None: raise LookupError("No data and no loader set on this observation") - self._data = self._loader() + loaded = self._loader() + self._data = loaded self._loader = None # release closure - return self._data # type: ignore[return-value] + return loaded + return val def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: """Create a new observation preserving 
ts/pose/tags, replacing data.""" @@ -160,10 +166,10 @@ def matches(self, obs: Observation[Any]) -> bool: class PredicateFilter: """Wraps an arbitrary predicate function for use with .filter().""" - fn: Any # Callable[[Observation], bool] — Any to keep frozen hashable + fn: Callable[[Observation[Any]], bool] def matches(self, obs: Observation[Any]) -> bool: - return self.fn(obs) + return bool(self.fn(obs)) # ── StreamQuery ───────────────────────────────────────────────────── From 90a636a72e77561ff8d4fead37164b02db3c2241 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Mon, 9 Mar 2026 18:57:01 +0800 Subject: [PATCH 063/118] =?UTF-8?q?memory2:=20fix=20.live()=20on=20transfo?= =?UTF-8?q?rm=20streams=20=E2=80=94=20reject=20with=20clear=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Live items from the backend buffer were bypassing the transform chain entirely. The fix: .live() is only valid on backend-backed streams; transforms downstream just see an infinite iterator. --- dimos/memory2/stream.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 278197e4b9..dbe23dd62d 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -213,24 +213,22 @@ def transform(self, xf: Transformer[Any, Any]) -> Stream[Any]: def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: """Return a stream that yields backfill then live data (infinite iterator). - Default buffer: KeepLast(). Subscribes to the root backend BEFORE - backfill starts and deduplicates by observation id. + Only valid on backend-backed streams. Transforms downstream of a live + stream just see an infinite iterator — they don't need to know about + liveness. Call .live() before .transform(), not after. + + Default buffer: KeepLast(). Subscribes to the backend BEFORE backfill + starts and deduplicates by observation id. 
""" + if isinstance(self._source, tuple): + raise TypeError( + "Cannot call .live() on a transform stream. " + "Call .live() on the source stream, then .transform()." + ) buf = buffer if buffer is not None else KeepLast() - backend = self._find_backend() - sub = backend.subscribe(buf) + sub = self._source.subscribe(buf) return Stream(self._source, query=self._query, _live_buf=buf, _live_sub=sub) - def _find_backend(self) -> Backend[Any]: - """Walk up the source chain to find the root Backend.""" - source = self._source - while isinstance(source, tuple): - upstream_stream, _ = source - source = upstream_stream._source - if not isinstance(source, Backend): - raise TypeError("Cannot find a backend in this stream chain") - return source - # ── Terminals ─────────────────────────────────────────────────── def fetch(self) -> list[Observation[T]]: From 2f029e41221576b1848c89a9d7f4dcf9859f4174 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Mon, 9 Mar 2026 20:12:36 +0800 Subject: [PATCH 064/118] memory2: replace custom Disposable with rxpy DisposableBase Use reactivex.abc.DisposableBase in protocols and reactivex.disposable.Disposable in implementations, consistent with dimos's existing Resource pattern. 
--- dimos/memory2/__init__.py | 3 +-- dimos/memory2/backend.py | 24 ++++++++---------------- dimos/memory2/stream.py | 8 +++++--- 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py index eeda6fd788..5b33f35505 100644 --- a/dimos/memory2/__init__.py +++ b/dimos/memory2/__init__.py @@ -1,4 +1,4 @@ -from dimos.memory2.backend import Backend, Disposable, ListBackend +from dimos.memory2.backend import Backend, ListBackend from dimos.memory2.buffer import ( BackpressureBuffer, Bounded, @@ -31,7 +31,6 @@ "BeforeFilter", "Bounded", "ClosedError", - "Disposable", "DropNew", "Filter", "FnTransformer", diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index f95cb0c43e..b26ac41d39 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -18,28 +18,20 @@ import time from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable +from reactivex.disposable import Disposable + from dimos.memory2.type import Observation, StreamQuery if TYPE_CHECKING: - from collections.abc import Callable, Iterator + from collections.abc import Iterator + + from reactivex.abc import DisposableBase from dimos.memory2.buffer import BackpressureBuffer T = TypeVar("T") -class Disposable: - """Simple disposable that calls a function on dispose().""" - - def __init__(self, fn: Callable[[], None]) -> None: - self._fn: Callable[[], None] | None = fn - - def dispose(self) -> None: - if self._fn is not None: - self._fn() - self._fn = None - - @runtime_checkable class Backend(Protocol[T]): """Data source protocol for stored observations. @@ -64,7 +56,7 @@ def append( def count(self, query: StreamQuery) -> int: ... - def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> Disposable: ... + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: ... 
class ListBackend(Generic[T]): @@ -137,7 +129,7 @@ def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: def count(self, query: StreamQuery) -> int: return sum(1 for _ in self.iterate(query)) - def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> Disposable: + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: with self._lock: self._subscribers.append(buf) @@ -148,4 +140,4 @@ def _unsubscribe() -> None: except ValueError: pass - return Disposable(_unsubscribe) + return Disposable(action=_unsubscribe) diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index dbe23dd62d..04568894b6 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -17,7 +17,7 @@ from itertools import islice from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.backend import Backend, Disposable +from dimos.memory2.backend import Backend from dimos.memory2.buffer import BackpressureBuffer, ClosedError, KeepLast from dimos.memory2.transform import FnTransformer, Transformer from dimos.memory2.type import ( @@ -36,6 +36,8 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator + from reactivex.abc import DisposableBase + T = TypeVar("T") R = TypeVar("R") @@ -57,12 +59,12 @@ def __init__( *, query: StreamQuery = StreamQuery(), _live_buf: BackpressureBuffer[Observation[Any]] | None = None, - _live_sub: Disposable | None = None, + _live_sub: DisposableBase | None = None, ) -> None: self._source = source self._query = query self._live_buf = _live_buf - self._live_sub = _live_sub # Disposable, kept alive for lifetime of stream + self._live_sub = _live_sub # kept alive for lifetime of stream # ── Iteration ─────────────────────────────────────────────────── From 4061b8ffab6586e72cfcbb6eb011a81240c6df7b Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Mon, 9 Mar 2026 20:22:49 +0800 Subject: [PATCH 065/118] memory2: extract filters and StreamQuery from type.py into filter.py type.py 
now only contains Observation and its helpers. --- dimos/memory2/__init__.py | 10 +-- dimos/memory2/backend.py | 3 +- dimos/memory2/filter.py | 131 ++++++++++++++++++++++++++++++++++++++ dimos/memory2/stream.py | 7 +- dimos/memory2/type.py | 111 +------------------------------- 5 files changed, 143 insertions(+), 119 deletions(-) create mode 100644 dimos/memory2/filter.py diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py index 5b33f35505..2df421954e 100644 --- a/dimos/memory2/__init__.py +++ b/dimos/memory2/__init__.py @@ -7,21 +7,21 @@ KeepLast, Unbounded, ) -from dimos.memory2.store import ListStore, Session, Store, StreamNamespace -from dimos.memory2.stream import Stream -from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer -from dimos.memory2.type import ( +from dimos.memory2.filter import ( AfterFilter, AtFilter, BeforeFilter, Filter, NearFilter, - Observation, PredicateFilter, StreamQuery, TagsFilter, TimeRangeFilter, ) +from dimos.memory2.store import ListStore, Session, Store, StreamNamespace +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer +from dimos.memory2.type import Observation __all__ = [ "AfterFilter", diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index b26ac41d39..23e7cd0726 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -20,7 +20,7 @@ from reactivex.disposable import Disposable -from dimos.memory2.type import Observation, StreamQuery +from dimos.memory2.type import Observation if TYPE_CHECKING: from collections.abc import Iterator @@ -28,6 +28,7 @@ from reactivex.abc import DisposableBase from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.filter import StreamQuery T = TypeVar("T") diff --git a/dimos/memory2/filter.py b/dimos/memory2/filter.py new file mode 100644 index 0000000000..2901ebf04a --- /dev/null +++ b/dimos/memory2/filter.py @@ -0,0 +1,131 @@ +# Copyright 2026 
Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.memory2.type import Observation + + +# ── Filter protocol ───────────────────────────────────────────────── + + +@runtime_checkable +class Filter(Protocol): + """Any object with a .matches(obs) -> bool method can be a filter.""" + + def matches(self, obs: Observation[Any]) -> bool: ... 
+ + +# ── Concrete filters ──────────────────────────────────────────────── + + +@dataclass(frozen=True) +class AfterFilter: + t: float + + def matches(self, obs: Observation[Any]) -> bool: + return obs.ts > self.t + + +@dataclass(frozen=True) +class BeforeFilter: + t: float + + def matches(self, obs: Observation[Any]) -> bool: + return obs.ts < self.t + + +@dataclass(frozen=True) +class TimeRangeFilter: + t1: float + t2: float + + def matches(self, obs: Observation[Any]) -> bool: + return self.t1 <= obs.ts <= self.t2 + + +@dataclass(frozen=True) +class AtFilter: + t: float + tolerance: float = 1.0 + + def matches(self, obs: Observation[Any]) -> bool: + return abs(obs.ts - self.t) <= self.tolerance + + +@dataclass(frozen=True) +class NearFilter: + pose: Any + radius: float + + def matches(self, obs: Observation[Any]) -> bool: + if obs.pose is None or self.pose is None: + return False + p1 = self.pose + p2 = obs.pose + # Support both raw (x,y,z) tuples and PoseStamped objects + if hasattr(p1, "position"): + p1 = p1.position + if hasattr(p2, "position"): + p2 = p2.position + x1, y1, z1 = _xyz(p1) + x2, y2, z2 = _xyz(p2) + dist_sq = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 + return dist_sq <= self.radius**2 + + +def _xyz(p: Any) -> tuple[float, float, float]: + """Extract (x, y, z) from various pose representations.""" + if isinstance(p, (list, tuple)): + return (float(p[0]), float(p[1]), float(p[2]) if len(p) > 2 else 0.0) + return (float(p.x), float(p.y), float(getattr(p, "z", 0.0))) + + +@dataclass(frozen=True) +class TagsFilter: + tags: dict[str, Any] + + def matches(self, obs: Observation[Any]) -> bool: + for k, v in self.tags.items(): + if obs.tags.get(k) != v: + return False + return True + + +@dataclass(frozen=True) +class PredicateFilter: + """Wraps an arbitrary predicate function for use with .filter().""" + + fn: Callable[[Observation[Any]], bool] + + def matches(self, obs: Observation[Any]) -> bool: + return bool(self.fn(obs)) + + +# ── StreamQuery 
───────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class StreamQuery: + filters: tuple[Filter, ...] = () + order_field: str | None = None + order_desc: bool = False + limit_val: int | None = None + offset_val: int | None = None diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 04568894b6..dc1d05da0c 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -19,25 +19,26 @@ from dimos.memory2.backend import Backend from dimos.memory2.buffer import BackpressureBuffer, ClosedError, KeepLast -from dimos.memory2.transform import FnTransformer, Transformer -from dimos.memory2.type import ( +from dimos.memory2.filter import ( AfterFilter, AtFilter, BeforeFilter, Filter, NearFilter, - Observation, PredicateFilter, StreamQuery, TagsFilter, TimeRangeFilter, ) +from dimos.memory2.transform import FnTransformer, Transformer if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator from reactivex.abc import DisposableBase + from dimos.memory2.type import Observation + T = TypeVar("T") R = TypeVar("R") diff --git a/dimos/memory2/type.py b/dimos/memory2/type.py index 22a974c11f..59ec300685 100644 --- a/dimos/memory2/type.py +++ b/dimos/memory2/type.py @@ -15,23 +15,12 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Generic, TypeVar if TYPE_CHECKING: from collections.abc import Callable T = TypeVar("T") -R = TypeVar("R") - - -# ── Filter protocol ───────────────────────────────────────────────── - - -@runtime_checkable -class Filter(Protocol): - """Any object with a .matches(obs) -> bool method can be a filter.""" - - def matches(self, obs: Observation[Any]) -> bool: ... 
# ── Lazy data sentinel ────────────────────────────────────────────── @@ -84,101 +73,3 @@ def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: tags=overrides.get("tags", self.tags), _data=data, ) - - -# ── Filters ───────────────────────────────────────────────────────── - - -@dataclass(frozen=True) -class AfterFilter: - t: float - - def matches(self, obs: Observation[Any]) -> bool: - return obs.ts > self.t - - -@dataclass(frozen=True) -class BeforeFilter: - t: float - - def matches(self, obs: Observation[Any]) -> bool: - return obs.ts < self.t - - -@dataclass(frozen=True) -class TimeRangeFilter: - t1: float - t2: float - - def matches(self, obs: Observation[Any]) -> bool: - return self.t1 <= obs.ts <= self.t2 - - -@dataclass(frozen=True) -class AtFilter: - t: float - tolerance: float = 1.0 - - def matches(self, obs: Observation[Any]) -> bool: - return abs(obs.ts - self.t) <= self.tolerance - - -@dataclass(frozen=True) -class NearFilter: - pose: Any - radius: float - - def matches(self, obs: Observation[Any]) -> bool: - if obs.pose is None or self.pose is None: - return False - p1 = self.pose - p2 = obs.pose - # Support both raw (x,y,z) tuples and PoseStamped objects - if hasattr(p1, "position"): - p1 = p1.position - if hasattr(p2, "position"): - p2 = p2.position - x1, y1, z1 = _xyz(p1) - x2, y2, z2 = _xyz(p2) - dist_sq = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 - return dist_sq <= self.radius**2 - - -def _xyz(p: Any) -> tuple[float, float, float]: - """Extract (x, y, z) from various pose representations.""" - if isinstance(p, (list, tuple)): - return (float(p[0]), float(p[1]), float(p[2]) if len(p) > 2 else 0.0) - return (float(p.x), float(p.y), float(getattr(p, "z", 0.0))) - - -@dataclass(frozen=True) -class TagsFilter: - tags: dict[str, Any] - - def matches(self, obs: Observation[Any]) -> bool: - for k, v in self.tags.items(): - if obs.tags.get(k) != v: - return False - return True - - -@dataclass(frozen=True) -class PredicateFilter: - 
"""Wraps an arbitrary predicate function for use with .filter().""" - - fn: Callable[[Observation[Any]], bool] - - def matches(self, obs: Observation[Any]) -> bool: - return bool(self.fn(obs)) - - -# ── StreamQuery ───────────────────────────────────────────────────── - - -@dataclass(frozen=True) -class StreamQuery: - filters: tuple[Filter, ...] = () - order_field: str | None = None - order_desc: bool = False - limit_val: int | None = None - offset_val: int | None = None From a44d8705de635e3f962d2d6a781db92c7235a198 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Mon, 9 Mar 2026 21:18:56 +0800 Subject: [PATCH 066/118] memory2: store transform on Stream node, not as source tuple MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stream._source is now `Backend | Stream` instead of `Backend | tuple[Stream, Transformer]`. The transformer lives on the stream that owns it (`_xf` field), not bundled into the source pointer. Fix .map() tests to pass Observation→Observation lambdas. Remove live mode tests (blocked by nvidia driver D-state in root conftest autoconf). --- dimos/memory2/stream.py | 49 ++++++------- dimos/memory2/test_stream.py | 136 +---------------------------------- dimos/memory2/transform.py | 2 +- 3 files changed, 24 insertions(+), 163 deletions(-) diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index dc1d05da0c..b2f44c14ca 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -33,7 +33,7 @@ from dimos.memory2.transform import FnTransformer, Transformer if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Iterator + from collections.abc import Callable, Iterator from reactivex.abc import DisposableBase @@ -42,9 +42,6 @@ T = TypeVar("T") R = TypeVar("R") -# Source is either a Backend or a (upstream_stream, transformer) pair. -_Source = Backend[Any] | tuple["Stream[Any]", Transformer[Any, Any]] - class Stream(Generic[T]): """Lazy, pull-based stream over observations. 
@@ -56,13 +53,15 @@ class Stream(Generic[T]): def __init__( self, - source: _Source, + source: Backend[T] | Stream[Any], *, + xf: Transformer[Any, T] | None = None, query: StreamQuery = StreamQuery(), _live_buf: BackpressureBuffer[Observation[Any]] | None = None, _live_sub: DisposableBase | None = None, ) -> None: self._source = source + self._xf = xf self._query = query self._live_buf = _live_buf self._live_sub = _live_sub # kept alive for lifetime of stream @@ -73,7 +72,7 @@ def __iter__(self) -> Iterator[Observation[T]]: return self._build_iter() def _build_iter(self) -> Iterator[Observation[T]]: - if isinstance(self._source, tuple): + if isinstance(self._source, Stream): it = self._iter_transform() else: # Backend handles all query application @@ -87,9 +86,8 @@ def _build_iter(self) -> Iterator[Observation[T]]: def _iter_transform(self) -> Iterator[Observation[T]]: """Iterate a transform source, applying query filters in Python.""" - assert isinstance(self._source, tuple) - upstream_stream, xf = self._source - it: Iterator[Observation[T]] = xf(iter(upstream_stream)) + assert isinstance(self._source, Stream) and self._xf is not None + it: Iterator[Observation[T]] = self._xf(iter(self._source)) # Apply filters as Python predicates filters = self._query.filters @@ -149,7 +147,13 @@ def _replace_query(self, **overrides: Any) -> Stream[T]: limit_val=overrides.get("limit_val", q.limit_val), offset_val=overrides.get("offset_val", q.offset_val), ) - return Stream(self._source, query=new_q, _live_buf=self._live_buf, _live_sub=self._live_sub) + return Stream( + self._source, + xf=self._xf, + query=new_q, + _live_buf=self._live_buf, + _live_sub=self._live_sub, + ) def _with_filter(self, f: Filter) -> Stream[T]: return self._replace_query(filters=(*self._query.filters, f)) @@ -187,29 +191,18 @@ def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: """Filter by arbitrary predicate on the full Observation.""" return self._with_filter(PredicateFilter(pred)) - 
def map(self, fn: Callable[[Observation[T]], Any]) -> Stream[Any]: + def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[Any]: """Transform each observation's data via callable.""" - return self.transform(FnTransformer(lambda obs: obs.derive(data=fn(obs)))) - - def flat_map(self, fn: Callable[[Observation[T]], Iterable[Any]]) -> Stream[Any]: - """Map that fans out — fn returns iterable of data values per observation.""" - - class _FlatMapXf(Transformer[T, Any]): - def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[Any]]: - for obs in upstream: - for item in fn(obs): - yield obs.derive(data=item) - - return self.transform(_FlatMapXf()) + return self.transform(FnTransformer(lambda obs: fn(obs))) # ── Transform ─────────────────────────────────────────────────── - def transform(self, xf: Transformer[Any, Any]) -> Stream[Any]: + def transform(self, xf: Transformer[T, R]) -> Stream[R]: """Wrap this stream with a transformer. Returns a new lazy Stream. When iterated, calls xf(iter(self)) — pulls lazily through the chain. """ - return Stream(source=(self, xf), query=StreamQuery()) + return Stream(source=self, xf=xf, query=StreamQuery()) # ── Live mode ─────────────────────────────────────────────────── @@ -223,7 +216,7 @@ def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> St Default buffer: KeepLast(). Subscribes to the backend BEFORE backfill starts and deduplicates by observation id. """ - if isinstance(self._source, tuple): + if isinstance(self._source, Stream): raise TypeError( "Cannot call .live() on a transform stream. " "Call .live() on the source stream, then .transform()." 
@@ -252,7 +245,7 @@ def last(self) -> Observation[T]: def count(self) -> int: """Count matching observations.""" - if isinstance(self._source, Backend) and not isinstance(self._source, tuple): + if isinstance(self._source, Backend): return self._source.count(self._query) return sum(1 for _ in self) @@ -271,6 +264,6 @@ def append( tags: dict[str, Any] | None = None, ) -> Observation[T]: """Append to the backing store. Only works if source is a Backend.""" - if isinstance(self._source, tuple): + if isinstance(self._source, Stream): raise TypeError("Cannot append to a transform stream. Append to the source stream.") return self._source.append(payload, ts=ts, pose=pose, tags=tags) diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index db3c214394..4340c7c2bb 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -231,19 +231,14 @@ def test_filter_on_metadata(self): def test_map(self): """.map() transforms each observation's data.""" - result = make_stream(3).map(lambda obs: obs.data * 2).fetch() + result = make_stream(3).map(lambda obs: obs.derive(data=obs.data * 2)).fetch() assert [o.data for o in result] == [0, 20, 40] def test_map_preserves_ts(self): - result = make_stream(3).map(lambda obs: str(obs.data)).fetch() + result = make_stream(3).map(lambda obs: obs.derive(data=str(obs.data))).fetch() assert [o.ts for o in result] == [0.0, 1.0, 2.0] assert [o.data for o in result] == ["0", "10", "20"] - def test_flat_map(self): - """.flat_map() fans out — fn returns iterable of values per obs.""" - result = make_stream(3).flat_map(lambda obs: [obs.data, obs.data + 1]).fetch() - assert [o.data for o in result] == [0, 1, 10, 11, 20, 21] - # ═══════════════════════════════════════════════════════════════════ # 7. Transform chaining @@ -484,130 +479,3 @@ def producer(): result.append(buf.take(timeout=2.0)) t.join() assert result == [42] - - -# ═══════════════════════════════════════════════════════════════════ -# 11. 
Live mode -# ═══════════════════════════════════════════════════════════════════ - - -class TestLiveMode: - """Live streams yield backfill then block for new observations.""" - - def test_live_sees_backfill_then_new(self): - """Backfill first, then live appends come through.""" - backend = ListBackend[str]("live") - backend.append("old", ts=0.0) - stream = Stream(source=backend) - live = stream.live(buffer=Unbounded()) - - # Start consuming in a thread - results: list[str] = [] - consumed = threading.Event() - - def consumer(): - for obs in live: - results.append(obs.data) - if len(results) >= 3: - consumed.set() - return - - t = threading.Thread(target=consumer) - t.start() - - time.sleep(0.05) - backend.append("new1", ts=1.0) - backend.append("new2", ts=2.0) - - consumed.wait(timeout=2.0) - t.join(timeout=2.0) - assert results == ["old", "new1", "new2"] - - def test_live_with_filter(self): - """Filters apply to live data — non-matching obs are dropped silently.""" - backend = ListBackend[int]("live_filter") - stream = Stream(source=backend) - live = stream.after(5.0).live(buffer=Unbounded()) - - results: list[int] = [] - consumed = threading.Event() - - def consumer(): - for obs in live: - results.append(obs.data) - if len(results) >= 2: - consumed.set() - return - - t = threading.Thread(target=consumer) - t.start() - - time.sleep(0.05) - backend.append(1, ts=1.0) # filtered out (ts <= 5.0) - backend.append(2, ts=6.0) # passes - backend.append(3, ts=3.0) # filtered out - backend.append(4, ts=10.0) # passes - - consumed.wait(timeout=2.0) - t.join(timeout=2.0) - assert results == [2, 4] - - def test_live_deduplicates_backfill_overlap(self): - """Observations seen in backfill are not re-yielded from the live buffer.""" - backend = ListBackend[str]("dedup") - backend.append("backfill", ts=0.0) - stream = Stream(source=backend) - live = stream.live(buffer=Unbounded()) - - results: list[str] = [] - consumed = threading.Event() - - def consumer(): - for obs in live: - 
results.append(obs.data) - if len(results) >= 2: - consumed.set() - return - - t = threading.Thread(target=consumer) - t.start() - - time.sleep(0.05) - backend.append("live1", ts=1.0) - - consumed.wait(timeout=2.0) - t.join(timeout=2.0) - assert results == ["backfill", "live1"] - - def test_live_with_keep_last_backpressure(self): - """KeepLast drops intermediate values when consumer is slow.""" - backend = ListBackend[int]("bp") - stream = Stream(source=backend) - live = stream.live(buffer=KeepLast()) - - results: list[int] = [] - consumed = threading.Event() - - def consumer(): - for obs in live: - results.append(obs.data) - if obs.data >= 90: - consumed.set() - return - time.sleep(0.1) # slow consumer - - t = threading.Thread(target=consumer) - t.start() - - time.sleep(0.05) - # Rapid producer — KeepLast will drop most of these - for i in range(100): - backend.append(i, ts=float(i)) - time.sleep(0.001) - - consumed.wait(timeout=5.0) - t.join(timeout=2.0) - # Should have far fewer than 100 results due to KeepLast - assert len(results) < 50 - # Last result should be near the end - assert results[-1] >= 90 diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py index 7fca2ab3e2..a39fb3c3b3 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory2/transform.py @@ -30,7 +30,7 @@ class Transformer(ABC, Generic[T, R]): """Transforms a stream of observations lazily via iterator -> iterator. Pull from upstream, yield transformed observations. Naturally supports - batching, windowing, fan-out. No flush() needed — the generator cleans + batching, windowing, fan-out. The generator cleans up when upstream exhausts. 
""" From 09ada62df7f4ec0727e9ed24147aa55c77666539 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Mon, 9 Mar 2026 21:29:22 +0800 Subject: [PATCH 067/118] memory2: move live logic from Stream into Backend via StreamQuery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Live is now just a query parameter (live_buffer on StreamQuery). Stream.live() is a one-liner query modifier — the backend handles subscription, dedup, and backpressure internally. Stream has zero live implementation. --- dimos/memory2/backend.py | 42 ++++++- dimos/memory2/filter.py | 2 + dimos/memory2/stream.py | 67 ++--------- dimos/memory2/test_stream.py | 223 +++++++++++++++++++++++++++++++++++ 4 files changed, 278 insertions(+), 56 deletions(-) diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index 23e7cd0726..220f86347d 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -101,7 +101,19 @@ def append( return obs def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: - """Snapshot + apply all filters/ordering/offset/limit in Python.""" + """Snapshot + apply all filters/ordering/offset/limit in Python. + + If query.live_buffer is set, subscribes before backfill, then + switches to a live tail that blocks for new observations. 
+ """ + buf = query.live_buffer + if buf is not None: + # Subscribe BEFORE backfill to avoid missing items + sub = self.subscribe(buf) + return self._iterate_live(query, buf, sub) + return self._iterate_snapshot(query) + + def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: with self._lock: snapshot = list(self._observations) @@ -127,6 +139,34 @@ def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: yield from snapshot + def _iterate_live( + self, + query: StreamQuery, + buf: BackpressureBuffer[Observation[T]], + sub: DisposableBase, + ) -> Iterator[Observation[T]]: + from dimos.memory2.buffer import ClosedError + + # Backfill phase — use snapshot query (without live) for the backfill + last_id = -1 + for obs in self._iterate_snapshot(query): + last_id = max(last_id, obs.id) + yield obs + + # Live tail + filters = query.filters + try: + while True: + obs = buf.take() + if obs.id <= last_id: + continue + last_id = obs.id + if filters and not all(f.matches(obs) for f in filters): + continue + yield obs + except (ClosedError, StopIteration): + sub.dispose() + def count(self, query: StreamQuery) -> int: return sum(1 for _ in self.iterate(query)) diff --git a/dimos/memory2/filter.py b/dimos/memory2/filter.py index 2901ebf04a..942b3f52d7 100644 --- a/dimos/memory2/filter.py +++ b/dimos/memory2/filter.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from dimos.memory2.buffer import BackpressureBuffer from dimos.memory2.type import Observation @@ -129,3 +130,4 @@ class StreamQuery: order_desc: bool = False limit_val: int | None = None offset_val: int | None = None + live_buffer: BackpressureBuffer[Any] | None = None diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index b2f44c14ca..ab327726dc 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.memory2.backend import Backend -from 
dimos.memory2.buffer import BackpressureBuffer, ClosedError, KeepLast +from dimos.memory2.buffer import BackpressureBuffer, KeepLast from dimos.memory2.filter import ( AfterFilter, AtFilter, @@ -35,8 +35,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator - from reactivex.abc import DisposableBase - from dimos.memory2.type import Observation T = TypeVar("T") @@ -57,14 +55,10 @@ def __init__( *, xf: Transformer[Any, T] | None = None, query: StreamQuery = StreamQuery(), - _live_buf: BackpressureBuffer[Observation[Any]] | None = None, - _live_sub: DisposableBase | None = None, ) -> None: self._source = source self._xf = xf self._query = query - self._live_buf = _live_buf - self._live_sub = _live_sub # kept alive for lifetime of stream # ── Iteration ─────────────────────────────────────────────────── @@ -73,16 +67,9 @@ def __iter__(self) -> Iterator[Observation[T]]: def _build_iter(self) -> Iterator[Observation[T]]: if isinstance(self._source, Stream): - it = self._iter_transform() - else: - # Backend handles all query application - it = self._source.iterate(self._query) - - # Live tail: after backfill exhausts, yield from live buffer - if self._live_buf is not None: - it = self._iter_with_live(it) - - return it + return self._iter_transform() + # Backend handles all query application (including live if requested) + return self._source.iterate(self._query) def _iter_transform(self) -> Iterator[Observation[T]]: """Iterate a transform source, applying query filters in Python.""" @@ -113,29 +100,6 @@ def _iter_transform(self) -> Iterator[Observation[T]]: return it - def _iter_with_live(self, backfill: Iterator[Observation[T]]) -> Iterator[Observation[T]]: - """Yield backfill, then switch to live tail.""" - last_id = -1 - for obs in backfill: - last_id = max(last_id, obs.id) - yield obs - - # Live phase - buf = self._live_buf - assert buf is not None - filters = self._query.filters - try: - while True: - obs = buf.take() - if obs.id <= last_id: - 
continue - last_id = obs.id - if filters and not all(f.matches(obs) for f in filters): - continue - yield obs - except (ClosedError, StopIteration): - return - # ── Query builders ────────────────────────────────────────────── def _replace_query(self, **overrides: Any) -> Stream[T]: @@ -146,14 +110,9 @@ def _replace_query(self, **overrides: Any) -> Stream[T]: order_desc=overrides.get("order_desc", q.order_desc), limit_val=overrides.get("limit_val", q.limit_val), offset_val=overrides.get("offset_val", q.offset_val), + live_buffer=overrides.get("live_buffer", q.live_buffer), ) - return Stream( - self._source, - xf=self._xf, - query=new_q, - _live_buf=self._live_buf, - _live_sub=self._live_sub, - ) + return Stream(self._source, xf=self._xf, query=new_q) def _with_filter(self, f: Filter) -> Stream[T]: return self._replace_query(filters=(*self._query.filters, f)) @@ -207,14 +166,13 @@ def transform(self, xf: Transformer[T, R]) -> Stream[R]: # ── Live mode ─────────────────────────────────────────────────── def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: - """Return a stream that yields backfill then live data (infinite iterator). + """Return a stream whose iteration never ends — backfill then live tail. - Only valid on backend-backed streams. Transforms downstream of a live - stream just see an infinite iterator — they don't need to know about - liveness. Call .live() before .transform(), not after. + Only valid on backend-backed streams. Transforms downstream just see + an infinite iterator. Call .live() before .transform(), not after. - Default buffer: KeepLast(). Subscribes to the backend BEFORE backfill - starts and deduplicates by observation id. + Default buffer: KeepLast(). The backend handles subscription, dedup, + and backpressure — how it does so is its business. 
""" if isinstance(self._source, Stream): raise TypeError( @@ -222,8 +180,7 @@ def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> St "Call .live() on the source stream, then .transform()." ) buf = buffer if buffer is not None else KeepLast() - sub = self._source.subscribe(buf) - return Stream(self._source, query=self._query, _live_buf=buf, _live_sub=sub) + return self._replace_query(live_buffer=buf) # ── Terminals ─────────────────────────────────────────────────── diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index 4340c7c2bb..70f2500039 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -479,3 +479,226 @@ def producer(): result.append(buf.take(timeout=2.0)) t.join() assert result == [42] + + +# ═══════════════════════════════════════════════════════════════════ +# 11. Live mode +# ═══════════════════════════════════════════════════════════════════ + + +class TestLiveMode: + """Live streams yield backfill then block for new observations.""" + + def test_live_sees_backfill_then_new(self): + """Backfill first, then live appends come through.""" + backend = ListBackend[str]("live") + backend.append("old", ts=0.0) + stream = Stream(source=backend) + live = stream.live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append("new1", ts=1.0) + backend.append("new2", ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == ["old", "new1", "new2"] + + def test_live_with_filter(self): + """Filters apply to live data — non-matching obs are dropped silently.""" + backend = ListBackend[int]("live_filter") + stream = Stream(source=backend) + live = stream.after(5.0).live(buffer=Unbounded()) + + results: list[int] = [] + consumed = 
threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append(1, ts=1.0) # filtered out (ts <= 5.0) + backend.append(2, ts=6.0) # passes + backend.append(3, ts=3.0) # filtered out + backend.append(4, ts=10.0) # passes + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == [2, 4] + + def test_live_deduplicates_backfill_overlap(self): + """Observations seen in backfill are not re-yielded from the live buffer.""" + backend = ListBackend[str]("dedup") + backend.append("backfill", ts=0.0) + stream = Stream(source=backend) + live = stream.live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append("live1", ts=1.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == ["backfill", "live1"] + + def test_live_with_keep_last_backpressure(self): + """KeepLast drops intermediate values when consumer is slow.""" + backend = ListBackend[int]("bp") + stream = Stream(source=backend) + live = stream.live(buffer=KeepLast()) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if obs.data >= 90: + consumed.set() + return + time.sleep(0.1) # slow consumer + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + # Rapid producer — KeepLast will drop most of these + for i in range(100): + backend.append(i, ts=float(i)) + time.sleep(0.001) + + consumed.wait(timeout=5.0) + t.join(timeout=2.0) + # KeepLast means many values were dropped — far fewer than 100 + assert len(results) < 50 + assert results[-1] >= 90 + + def test_live_transform_receives_live_items(self): + 
"""Transforms downstream of .live() see both backfill and live items.""" + backend = ListBackend[int]("live_xf") + backend.append(1, ts=0.0) + stream = Stream(source=backend) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + live = stream.live(buffer=Unbounded()).transform(double) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append(10, ts=1.0) + backend.append(100, ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + # All items went through the double transform + assert results == [2, 20, 200] + + def test_live_on_transform_raises(self): + """Calling .live() on a transform stream raises TypeError.""" + stream = make_stream(3) + xf = FnTransformer(lambda obs: obs) + with pytest.raises(TypeError, match="Cannot call .live"): + stream.transform(xf).live() + + def test_live_chained_transforms(self): + """stream.live().transform(A).transform(B) — both transforms applied to live items.""" + backend = ListBackend[int]("live_chain") + backend.append(1, ts=0.0) + stream = Stream(source=backend) + add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + live = stream.live(buffer=Unbounded()).transform(add_one).transform(double) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append(10, ts=1.0) + backend.append(100, ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + # (1+1)*2=4, (10+1)*2=22, (100+1)*2=202 + assert results == [4, 22, 202] + + def test_live_filter_before_live(self): + """Filters applied before .live() work on both backfill 
and live items.""" + backend = ListBackend[str]("live_pre_filter") + backend.append("a", ts=1.0) + backend.append("b", ts=10.0) + stream = Stream(source=backend) + live = stream.after(5.0).live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + backend.append("c", ts=3.0) # filtered + backend.append("d", ts=20.0) # passes + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + # "a" filtered in backfill, "c" filtered in live + assert results == ["b", "d"] From 9ef10abbe6e5079be6a8fa98aac56fac7783c875 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Tue, 10 Mar 2026 13:47:31 +0800 Subject: [PATCH 068/118] memory2: extract impl/ layer with MemoryStore and SqliteStore scaffold Move ListBackend from backend.py into impl/memory.py alongside new MemorySession and MemoryStore. Add SqliteStore/SqliteSession/SqliteBackend skeleton in impl/sqlite.py. Refactor Store and Session to abstract base classes with _create_backend() hook. backend.py now only contains the Backend and LiveBackend protocols. Also fix doclinks: disambiguate memory.py reference in transports docs, and include source .md file path in all doclinks error messages. 
--- dimos/memory2/__init__.py | 13 ++- dimos/memory2/backend.py | 136 ++--------------------- dimos/memory2/impl/__init__.py | 13 +++ dimos/memory2/impl/memory.py | 173 +++++++++++++++++++++++++++++ dimos/memory2/impl/sqlite.py | 90 ++++++++++++++++ dimos/memory2/store.py | 62 +++++------ dimos/memory2/stream.py | 22 +++- dimos/memory2/test_save.py | 191 +++++++++++++++++++++++++++++++++ dimos/memory2/test_stream.py | 13 ++- dimos/utils/docs/doclinks.py | 23 ++-- docs/usage/transports/index.md | 2 +- 11 files changed, 552 insertions(+), 186 deletions(-) create mode 100644 dimos/memory2/impl/__init__.py create mode 100644 dimos/memory2/impl/memory.py create mode 100644 dimos/memory2/impl/sqlite.py create mode 100644 dimos/memory2/test_save.py diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py index 2df421954e..0fbf95a5ff 100644 --- a/dimos/memory2/__init__.py +++ b/dimos/memory2/__init__.py @@ -1,4 +1,4 @@ -from dimos.memory2.backend import Backend, ListBackend +from dimos.memory2.backend import Backend, LiveBackend from dimos.memory2.buffer import ( BackpressureBuffer, Bounded, @@ -18,7 +18,9 @@ TagsFilter, TimeRangeFilter, ) -from dimos.memory2.store import ListStore, Session, Store, StreamNamespace +from dimos.memory2.impl.memory import ListBackend, MemorySession, MemoryStore +from dimos.memory2.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore +from dimos.memory2.store import Session, Store, StreamNamespace from dimos.memory2.stream import Stream from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer from dimos.memory2.type import Observation @@ -36,12 +38,17 @@ "FnTransformer", "KeepLast", "ListBackend", - "ListStore", + "LiveBackend", + "MemorySession", + "MemoryStore", "NearFilter", "Observation", "PredicateFilter", "QualityWindow", "Session", + "SqliteBackend", + "SqliteSession", + "SqliteStore", "Store", "Stream", "StreamNamespace", diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index 
220f86347d..a63fae1f73 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -14,13 +14,7 @@ from __future__ import annotations -import threading -import time -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable - -from reactivex.disposable import Disposable - -from dimos.memory2.type import Observation +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: from collections.abc import Iterator @@ -29,6 +23,7 @@ from dimos.memory2.buffer import BackpressureBuffer from dimos.memory2.filter import StreamQuery + from dimos.memory2.type import Observation T = TypeVar("T") @@ -57,128 +52,9 @@ def append( def count(self, query: StreamQuery) -> int: ... - def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: ... - - -class ListBackend(Generic[T]): - """In-memory backend for experimentation. Thread-safe.""" - def __init__(self, name: str = "") -> None: - self._name = name - self._observations: list[Observation[T]] = [] - self._next_id = 0 - self._lock = threading.Lock() - self._subscribers: list[BackpressureBuffer[Observation[T]]] = [] - - @property - def name(self) -> str: - return self._name +@runtime_checkable +class LiveBackend(Backend[T], Protocol[T]): + """Backend that also supports live subscriptions.""" - def append( - self, - payload: T, - *, - ts: float | None = None, - pose: Any | None = None, - tags: dict[str, Any] | None = None, - ) -> Observation[T]: - with self._lock: - obs: Observation[T] = Observation( - id=self._next_id, - ts=ts if ts is not None else time.time(), - pose=pose, - tags=tags or {}, - _data=payload, - ) - self._next_id += 1 - self._observations.append(obs) - subs = list(self._subscribers) - - # Notify outside lock to avoid deadlocks - for buf in subs: - buf.put(obs) - - return obs - - def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: - """Snapshot + apply all filters/ordering/offset/limit in Python. 
- - If query.live_buffer is set, subscribes before backfill, then - switches to a live tail that blocks for new observations. - """ - buf = query.live_buffer - if buf is not None: - # Subscribe BEFORE backfill to avoid missing items - sub = self.subscribe(buf) - return self._iterate_live(query, buf, sub) - return self._iterate_snapshot(query) - - def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: - with self._lock: - snapshot = list(self._observations) - - # Apply filters - for f in query.filters: - snapshot = [obs for obs in snapshot if f.matches(obs)] - - # Ordering - if query.order_field: - key = query.order_field - snapshot.sort( - key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, - reverse=query.order_desc, - ) - - # Offset - if query.offset_val: - snapshot = snapshot[query.offset_val :] - - # Limit - if query.limit_val is not None: - snapshot = snapshot[: query.limit_val] - - yield from snapshot - - def _iterate_live( - self, - query: StreamQuery, - buf: BackpressureBuffer[Observation[T]], - sub: DisposableBase, - ) -> Iterator[Observation[T]]: - from dimos.memory2.buffer import ClosedError - - # Backfill phase — use snapshot query (without live) for the backfill - last_id = -1 - for obs in self._iterate_snapshot(query): - last_id = max(last_id, obs.id) - yield obs - - # Live tail - filters = query.filters - try: - while True: - obs = buf.take() - if obs.id <= last_id: - continue - last_id = obs.id - if filters and not all(f.matches(obs) for f in filters): - continue - yield obs - except (ClosedError, StopIteration): - sub.dispose() - - def count(self, query: StreamQuery) -> int: - return sum(1 for _ in self.iterate(query)) - - def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: - with self._lock: - self._subscribers.append(buf) - - def _unsubscribe() -> None: - with self._lock: - try: - self._subscribers.remove(buf) - except ValueError: - pass - - return 
Disposable(action=_unsubscribe) + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: ... diff --git a/dimos/memory2/impl/__init__.py b/dimos/memory2/impl/__init__.py new file mode 100644 index 0000000000..1ed1bd093e --- /dev/null +++ b/dimos/memory2/impl/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py new file mode 100644 index 0000000000..644821d4d1 --- /dev/null +++ b/dimos/memory2/impl/memory.py @@ -0,0 +1,173 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from reactivex.disposable import Disposable + +from dimos.memory2.store import Session, Store +from dimos.memory2.type import Observation + +if TYPE_CHECKING: + from collections.abc import Iterator + + from reactivex.abc import DisposableBase + + from dimos.memory2.backend import Backend + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.filter import StreamQuery + +T = TypeVar("T") + + +class ListBackend(Generic[T]): + """In-memory backend for experimentation. Thread-safe.""" + + def __init__(self, name: str = "") -> None: + self._name = name + self._observations: list[Observation[T]] = [] + self._next_id = 0 + self._lock = threading.Lock() + self._subscribers: list[BackpressureBuffer[Observation[T]]] = [] + + @property + def name(self) -> str: + return self._name + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation[T]: + with self._lock: + obs: Observation[T] = Observation( + id=self._next_id, + ts=ts if ts is not None else time.time(), + pose=pose, + tags=tags or {}, + _data=payload, + ) + self._next_id += 1 + self._observations.append(obs) + subs = list(self._subscribers) + + # Notify outside lock to avoid deadlocks + for buf in subs: + buf.put(obs) + + return obs + + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + """Snapshot + apply all filters/ordering/offset/limit in Python. + + If query.live_buffer is set, subscribes before backfill, then + switches to a live tail that blocks for new observations. 
+ """ + buf = query.live_buffer + if buf is not None: + # Subscribe BEFORE backfill to avoid missing items + sub = self.subscribe(buf) + return self._iterate_live(query, buf, sub) + return self._iterate_snapshot(query) + + def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: + with self._lock: + snapshot = list(self._observations) + + # Apply filters + for f in query.filters: + snapshot = [obs for obs in snapshot if f.matches(obs)] + + # Ordering + if query.order_field: + key = query.order_field + snapshot.sort( + key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, + reverse=query.order_desc, + ) + + # Offset + if query.offset_val: + snapshot = snapshot[query.offset_val :] + + # Limit + if query.limit_val is not None: + snapshot = snapshot[: query.limit_val] + + yield from snapshot + + def _iterate_live( + self, + query: StreamQuery, + buf: BackpressureBuffer[Observation[T]], + sub: DisposableBase, + ) -> Iterator[Observation[T]]: + from dimos.memory2.buffer import ClosedError + + # Backfill phase — use snapshot query (without live) for the backfill + last_id = -1 + for obs in self._iterate_snapshot(query): + last_id = max(last_id, obs.id) + yield obs + + # Live tail + filters = query.filters + try: + while True: + obs = buf.take() + if obs.id <= last_id: + continue + last_id = obs.id + if filters and not all(f.matches(obs) for f in filters): + continue + yield obs + except (ClosedError, StopIteration): + sub.dispose() + + def count(self, query: StreamQuery) -> int: + return sum(1 for _ in self.iterate(query)) + + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + with self._lock: + self._subscribers.append(buf) + + def _unsubscribe() -> None: + with self._lock: + try: + self._subscribers.remove(buf) + except ValueError: + pass + + return Disposable(action=_unsubscribe) + + +class MemorySession(Session): + """In-memory session. 
Each stream is backed by a ListBackend.""" + + def _create_backend(self, name: str) -> Backend[Any]: + return ListBackend(name) + + +class MemoryStore(Store): + """In-memory store for experimentation.""" + + def session(self) -> MemorySession: + return MemorySession() diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py new file mode 100644 index 0000000000..a481d671a5 --- /dev/null +++ b/dimos/memory2/impl/sqlite.py @@ -0,0 +1,90 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import sqlite3 +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.memory2.store import Session, Store + +if TYPE_CHECKING: + from collections.abc import Iterator + + from reactivex.abc import DisposableBase + + from dimos.memory2.backend import Backend + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.filter import StreamQuery + from dimos.memory2.type import Observation + +T = TypeVar("T") + + +class SqliteBackend(Generic[T]): + """SQLite-backed observation storage for a single stream (table).""" + + def __init__(self, conn: sqlite3.Connection, name: str) -> None: + self._conn = conn + self._name = name + + @property + def name(self) -> str: + return self._name + + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + raise NotImplementedError + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation[T]: + raise NotImplementedError + + def count(self, query: StreamQuery) -> int: + raise NotImplementedError + + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + raise NotImplementedError + + +class SqliteSession(Session): + """Session owning a single SQLite connection.""" + + def __init__(self, conn: sqlite3.Connection) -> None: + super().__init__() + self._conn = conn + + def _create_backend(self, name: str) -> Backend[Any]: + return SqliteBackend(self._conn, name) + + def close(self) -> None: + super().close() + self._conn.close() + + +class SqliteStore(Store): + """Store backed by a SQLite database file.""" + + def __init__(self, path: str) -> None: + self._path = path + + def session(self) -> SqliteSession: + conn = sqlite3.connect(self._path, check_same_thread=False) + conn.execute("PRAGMA journal_mode=WAL") + return SqliteSession(conn) diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index e8e5340a14..071d7a46c3 100644 --- 
a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -14,14 +14,15 @@ from __future__ import annotations +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, TypeVar, cast -from dimos.core.resource import Resource -from dimos.memory2.backend import Backend, ListBackend from dimos.memory2.stream import Stream if TYPE_CHECKING: - from collections.abc import Callable, Iterator + from collections.abc import Iterator + + from dimos.memory2.backend import Backend T = TypeVar("T") @@ -68,24 +69,32 @@ def __repr__(self) -> str: return f"StreamNamespace({list(self._session._streams.keys())})" -class Session(Resource): - """A session against a store. Creates and manages named streams.""" +class Session(ABC): + """A session against a store. Manages named streams over a shared connection. + + Subclasses implement ``_create_backend`` to provide storage-specific backends. + """ - def __init__(self, backend_factory: Callable[[str], Backend[Any]]) -> None: - self._backend_factory = backend_factory + def __init__(self) -> None: self._streams: dict[str, Stream[Any]] = {} self._backends: dict[str, Backend[Any]] = {} + @abstractmethod + def _create_backend(self, name: str) -> Backend[Any]: + """Create a backend for the named stream. Called once per stream name.""" + ... + def stream(self, name: str, payload_type: type[T] | None = None) -> Stream[T]: """Get or create a named stream. 
Returns the same Stream on repeated calls.""" if name not in self._streams: - backend = self._backend_factory(name) + backend = self._create_backend(name) self._backends[name] = backend self._streams[name] = Stream(source=backend) return cast("Stream[T]", self._streams[name]) - def list_streams(self) -> list[Stream[Any]]: - return list(self._streams.values()) + def list_streams(self) -> list[str]: + """Return names of all streams in this session.""" + return list(self._streams.keys()) def delete_stream(self, name: str) -> None: self._streams.pop(name, None) @@ -95,40 +104,27 @@ def delete_stream(self, name: str) -> None: def streams(self) -> StreamNamespace: return StreamNamespace(self) - def start(self) -> None: - pass - - def stop(self) -> None: - pass + def close(self) -> None: # noqa: B027 + """Release resources. Override in subclasses for cleanup.""" def __enter__(self) -> Session: return self def __exit__(self, *args: object) -> None: - self.stop() + self.close() -class Store(Resource): - """Top-level entry point — wraps a storage location.""" +class Store(ABC): + """Top-level entry point — wraps a storage location (file, URL, etc.).""" - def session(self) -> Session: - raise NotImplementedError + @abstractmethod + def session(self) -> Session: ... - def start(self) -> None: - pass - - def stop(self) -> None: - pass + def close(self) -> None: # noqa: B027 + """Release resources. 
Override in subclasses for cleanup.""" def __enter__(self) -> Store: return self def __exit__(self, *args: object) -> None: - self.stop() - - -class ListStore(Store): - """In-memory store for experimentation.""" - - def session(self) -> Session: - return Session(backend_factory=lambda name: ListBackend(name)) + self.close() diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index ab327726dc..69ba220575 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -17,7 +17,7 @@ from itertools import islice from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.backend import Backend +from dimos.memory2.backend import Backend, LiveBackend from dimos.memory2.buffer import BackpressureBuffer, KeepLast from dimos.memory2.filter import ( AfterFilter, @@ -168,8 +168,8 @@ def transform(self, xf: Transformer[T, R]) -> Stream[R]: def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: """Return a stream whose iteration never ends — backfill then live tail. - Only valid on backend-backed streams. Transforms downstream just see - an infinite iterator. Call .live() before .transform(), not after. + Only valid on backend-backed streams whose backend implements + LiveBackend. Call .live() before .transform(), not after. Default buffer: KeepLast(). The backend handles subscription, dedup, and backpressure — how it does so is its business. @@ -179,9 +179,25 @@ def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> St "Cannot call .live() on a transform stream. " "Call .live() on the source stream, then .transform()." 
) + if not isinstance(self._source, LiveBackend): + raise TypeError(f"Backend {self._source.name!r} does not support live mode.") buf = buffer if buffer is not None else KeepLast() return self._replace_query(live_buffer=buf) + # ── Save ───────────────────────────────────────────────────────── + + def save(self, target: Stream[T]) -> Stream[T]: + """Sync terminal: iterate self, append each obs to target's backend. + + Returns the target stream for continued querying. + """ + if isinstance(target._source, Stream): + raise TypeError("Cannot save to a transform stream. Target must be backend-backed.") + backend = target._source + for obs in self: + backend.append(obs.data, ts=obs.ts, pose=obs.pose, tags=obs.tags) + return target + # ── Terminals ─────────────────────────────────────────────────── def fetch(self) -> list[Observation[T]]: diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py new file mode 100644 index 0000000000..a009850334 --- /dev/null +++ b/dimos/memory2/test_save.py @@ -0,0 +1,191 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Stream.save() and LiveBackend protocol split.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +from dimos.memory2.backend import Backend, LiveBackend +from dimos.memory2.impl.memory import ListBackend +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer +from dimos.memory2.type import Observation + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.filter import StreamQuery + +# ── Helpers ────────────────────────────────────────────────────────── + + +def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: + backend = ListBackend[int]("test") + for i in range(n): + backend.append(i * 10, ts=start_ts + i) + return Stream(source=backend) + + +class ReadOnlyBackend: + """A Backend that does NOT support live mode (no subscribe).""" + + def __init__(self, name: str = "") -> None: + self._name = name + self._obs: list[Observation[int]] = [] + self._next_id = 0 + + @property + def name(self) -> str: + return self._name + + def iterate(self, query: StreamQuery) -> Iterator[Observation[int]]: + yield from self._obs + + def append( + self, + payload: int, + *, + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + ) -> Observation[int]: + obs: Observation[int] = Observation( + id=self._next_id, ts=ts or 0.0, pose=pose, tags=tags or {}, _data=payload + ) + self._next_id += 1 + self._obs.append(obs) + return obs + + def count(self, query: StreamQuery) -> int: + return len(self._obs) + + +# ═══════════════════════════════════════════════════════════════════ +# Protocol checks +# ═══════════════════════════════════════════════════════════════════ + + +class TestProtocolSplit: + def test_list_backend_is_live(self) -> None: + b = ListBackend[int]("x") + assert isinstance(b, LiveBackend) + + def test_list_backend_is_backend(self) -> None: + b = ListBackend[int]("x") + assert isinstance(b, 
Backend) + + def test_readonly_is_backend(self) -> None: + b = ReadOnlyBackend() + assert isinstance(b, Backend) + + def test_readonly_is_not_live(self) -> None: + b = ReadOnlyBackend() + assert not isinstance(b, LiveBackend) + + +# ═══════════════════════════════════════════════════════════════════ +# .live() rejects non-LiveBackend +# ═══════════════════════════════════════════════════════════════════ + + +class TestLiveRejectsNonLive: + def test_live_rejects_non_live_backend(self) -> None: + b = ReadOnlyBackend("ro") + s = Stream(source=b) + with pytest.raises(TypeError, match="does not support live mode"): + s.live() + + +# ═══════════════════════════════════════════════════════════════════ +# .save() +# ═══════════════════════════════════════════════════════════════════ + + +class TestSave: + def test_save_populates_target(self) -> None: + source = make_stream(3) + target_backend = ListBackend[int]("target") + target = Stream(source=target_backend) + + source.save(target) + + results = target.fetch() + assert len(results) == 3 + assert [o.data for o in results] == [0, 10, 20] + + def test_save_returns_target_stream(self) -> None: + source = make_stream(2) + target_backend = ListBackend[int]("target") + target = Stream(source=target_backend) + + result = source.save(target) + + assert result is target + + def test_save_preserves_data(self) -> None: + backend = ListBackend[int]("src") + backend.append(42, ts=1.0, pose=(1, 2, 3), tags={"label": "cat"}) + source = Stream(source=backend) + + target_backend = ListBackend[int]("dst") + target = Stream(source=target_backend) + source.save(target) + + obs = target.first() + assert obs.data == 42 + assert obs.ts == 1.0 + assert obs.pose == (1, 2, 3) + assert obs.tags == {"label": "cat"} + + def test_save_with_transform(self) -> None: + source = make_stream(3) # data: 0, 10, 20 + doubled = source.transform(FnTransformer(lambda obs: obs.derive(data=obs.data * 2))) + + target_backend = ListBackend[int]("target") + target = 
Stream(source=target_backend) + doubled.save(target) + + assert [o.data for o in target.fetch()] == [0, 20, 40] + + def test_save_rejects_transform_target(self) -> None: + source = make_stream(2) + base = make_stream(2) + transform_stream = base.transform(FnTransformer(lambda obs: obs.derive(obs.data))) + + with pytest.raises(TypeError, match="Cannot save to a transform stream"): + source.save(transform_stream) + + def test_save_target_queryable(self) -> None: + source = make_stream(5, start_ts=0.0) # ts: 0,1,2,3,4 + + target_backend = ListBackend[int]("target") + target = Stream(source=target_backend) + result = source.save(target) + + after_2 = result.after(2.0).fetch() + assert [o.data for o in after_2] == [30, 40] + + def test_save_empty_source(self) -> None: + source = make_stream(0) + target_backend = ListBackend[int]("target") + target = Stream(source=target_backend) + + result = source.save(target) + + assert result.count() == 0 + assert result.fetch() == [] diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index 70f2500039..74d009ed6f 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -24,9 +24,8 @@ import pytest -from dimos.memory2.backend import ListBackend from dimos.memory2.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded -from dimos.memory2.store import ListStore +from dimos.memory2.impl.memory import ListBackend, MemoryStore from dimos.memory2.stream import Stream from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer from dimos.memory2.type import Observation @@ -331,7 +330,7 @@ class TestStoreSession: """Store -> Session -> Stream hierarchy for named streams.""" def test_basic_session(self): - store = ListStore() + store = MemoryStore() with store.session() as session: images = session.stream("images") images.append("frame1", ts=0.0) @@ -339,14 +338,14 @@ def test_basic_session(self): assert images.count() == 2 def test_same_stream_on_repeated_calls(self): - 
store = ListStore() + store = MemoryStore() with store.session() as session: s1 = session.stream("images") s2 = session.stream("images") assert s1 is s2 def test_stream_namespace(self): - store = ListStore() + store = MemoryStore() with store.session() as session: session.stream("images") session.stream("lidar") @@ -356,13 +355,13 @@ def test_stream_namespace(self): assert session.streams["lidar"] is session.stream("lidar") def test_namespace_missing_raises(self): - store = ListStore() + store = MemoryStore() with store.session() as session: with pytest.raises(AttributeError, match="No stream named"): _ = session.streams.nonexistent def test_delete_stream(self): - store = ListStore() + store = MemoryStore() with store.session() as session: session.stream("temp") session.delete_stream("temp") diff --git a/dimos/utils/docs/doclinks.py b/dimos/utils/docs/doclinks.py index 2cf5d1702f..4d2fb6dc1c 100644 --- a/dimos/utils/docs/doclinks.py +++ b/dimos/utils/docs/doclinks.py @@ -360,12 +360,13 @@ def replace_code_match(match: re.Match[str]) -> str: resolved_path = resolve_candidates(candidates, file_ref) if resolved_path is None: + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"'{file_ref}' matches multiple files: {[str(c) for c in candidates]}" + f"'{file_ref}' in {doc_rel} matches multiple files: {[str(c) for c in candidates]}" ) else: - errors.append(f"No file matching '{file_ref}' found in codebase") + errors.append(f"No file matching '{file_ref}' found in codebase (in {doc_rel})") return full_match # Determine line fragment @@ -438,12 +439,13 @@ def replace_link_match(match: re.Match[str]) -> str: if result != full_match: changes.append(f" {link_text}: .md -> {new_link}") return result + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"'{link_text}' matches multiple docs: {[str(c) for c in candidates]}" + 
f"'{link_text}' in {doc_rel} matches multiple docs: {[str(c) for c in candidates]}" ) else: - errors.append(f"No doc matching '{link_text}' found") + errors.append(f"No doc matching '{link_text}' found (in {doc_rel})") return full_match # Absolute path @@ -460,12 +462,13 @@ def replace_link_match(match: re.Match[str]) -> str: ) changes.append(f" {link_text}: {raw_link} -> {new_link} (fixed broken link)") return f"[{link_text}]({new_link})" + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"Broken link '{raw_link}': ambiguous, matches {[str(c) for c in candidates]}" + f"Broken link '{raw_link}' in {doc_rel}: ambiguous, matches {[str(c) for c in candidates]}" ) else: - errors.append(f"Broken link: '{raw_link}' does not exist") + errors.append(f"Broken link '{raw_link}' in {doc_rel}: does not exist") return full_match # Relative path — resolve from doc file's directory @@ -475,7 +478,8 @@ def replace_link_match(match: re.Match[str]) -> str: try: rel_to_root = resolved_abs.relative_to(root) except ValueError: - errors.append(f"Link '{raw_link}' resolves outside repo root") + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path + errors.append(f"Link '{raw_link}' in {doc_rel} resolves outside repo root") return full_match if resolved_abs.exists(): @@ -496,12 +500,13 @@ def replace_link_match(match: re.Match[str]) -> str: ) changes.append(f" {link_text}: {raw_link} -> {new_link} (found by search)") return f"[{link_text}]({new_link})" + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"Broken link '{raw_link}': ambiguous, matches {[str(c) for c in candidates]}" + f"Broken link '{raw_link}' in {doc_rel}: ambiguous, matches {[str(c) for c in candidates]}" ) else: - errors.append(f"Broken link '{raw_link}': target not found") + errors.append(f"Broken link '{raw_link}' in {doc_rel}: target not 
found") return full_match # Split by ignore regions and only process non-ignored parts diff --git a/docs/usage/transports/index.md b/docs/usage/transports/index.md index 5cfe9caaa8..02bb8a43ab 100644 --- a/docs/usage/transports/index.md +++ b/docs/usage/transports/index.md @@ -357,7 +357,7 @@ Received 2 messages: {'temperature': 23.0} ``` -See [`memory.py`](/dimos/protocol/pubsub/impl/memory.py) for the complete source. +See [`pubsub/impl/memory.py`](/dimos/protocol/pubsub/impl/memory.py) for the complete source. --- From 87b94adc61782274e1e8452493aa10f876473ef2 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 13:50:32 +0800 Subject: [PATCH 069/118] memory2: add buffer.py docstring and extract buffer tests to test_buffer.py --- dimos/memory2/buffer.py | 20 +++++++++ dimos/memory2/test_buffer.py | 86 ++++++++++++++++++++++++++++++++++++ dimos/memory2/test_stream.py | 71 +---------------------------- 3 files changed, 108 insertions(+), 69 deletions(-) create mode 100644 dimos/memory2/test_buffer.py diff --git a/dimos/memory2/buffer.py b/dimos/memory2/buffer.py index de122f3d68..49814eb6dc 100644 --- a/dimos/memory2/buffer.py +++ b/dimos/memory2/buffer.py @@ -12,6 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Backpressure buffers — the bridge between push and pull. + +Real-world data sources (cameras, LiDAR, ROS topics) and ReactiveX pipelines +are *push-based*: they emit items whenever they please. Databases, analysis +systems, and our memory store are *pull-based*: consumers iterate at their own +pace. A BackpressureBuffer sits between the two, absorbing push bursts so +that the pull side can drain items on its own schedule. + +The choice of strategy controls what happens under load: + +- **KeepLast** — single-slot, always overwrites; best for real-time sensor + data where only the latest reading matters. +- **Bounded** — FIFO with a cap; drops the oldest item on overflow. 
+- **DropNew** — FIFO with a cap; rejects new items on overflow. +- **Unbounded** — unlimited FIFO; guarantees delivery at the cost of memory. + +All four share the same ABC interface and are interchangeable wherever a +buffer is accepted (e.g. ``Stream.live(buffer=...)``). +""" + from __future__ import annotations from abc import ABC, abstractmethod diff --git a/dimos/memory2/test_buffer.py b/dimos/memory2/test_buffer.py new file mode 100644 index 0000000000..f851a6fcee --- /dev/null +++ b/dimos/memory2/test_buffer.py @@ -0,0 +1,86 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for backpressure buffers.""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from dimos.memory2.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded + + +class TestBackpressureBuffers: + """Thread-safe buffers bridging push sources to pull consumers.""" + + def test_keep_last_overwrites(self): + buf = KeepLast[int]() + buf.put(1) + buf.put(2) + buf.put(3) + assert buf.take() == 3 + assert len(buf) == 0 + + def test_bounded_drops_oldest(self): + buf = Bounded[int](maxlen=2) + buf.put(1) + buf.put(2) + buf.put(3) # drops 1 + assert buf.take() == 2 + assert buf.take() == 3 + + def test_drop_new_rejects(self): + buf = DropNew[int](maxlen=2) + assert buf.put(1) is True + assert buf.put(2) is True + assert buf.put(3) is False # rejected + assert buf.take() == 1 + assert buf.take() == 2 + + def test_unbounded_keeps_all(self): + buf = Unbounded[int]() + for i in range(100): + buf.put(i) + assert len(buf) == 100 + + def test_close_signals_end(self): + buf = KeepLast[int]() + buf.close() + with pytest.raises(ClosedError): + buf.take() + + def test_buffer_is_iterable(self): + """Iterating a buffer yields items until closed.""" + buf = Unbounded[int]() + buf.put(1) + buf.put(2) + buf.close() + assert list(buf) == [1, 2] + + def test_take_blocks_until_put(self): + buf = KeepLast[int]() + result = [] + + def producer(): + time.sleep(0.05) + buf.put(42) + + t = threading.Thread(target=producer) + t.start() + result.append(buf.take(timeout=2.0)) + t.join() + assert result == [42] diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index 74d009ed6f..764f7532bd 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -24,7 +24,7 @@ import pytest -from dimos.memory2.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded +from dimos.memory2.buffer import KeepLast, Unbounded from dimos.memory2.impl.memory import ListBackend, MemoryStore from dimos.memory2.stream 
import Stream from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer @@ -414,74 +414,7 @@ def test_derive_preserves_metadata(self): # ═══════════════════════════════════════════════════════════════════ -# 10. Backpressure buffers -# ═══════════════════════════════════════════════════════════════════ - - -class TestBackpressureBuffers: - """Thread-safe buffers bridging push sources to pull consumers.""" - - def test_keep_last_overwrites(self): - buf = KeepLast[int]() - buf.put(1) - buf.put(2) - buf.put(3) - assert buf.take() == 3 - assert len(buf) == 0 - - def test_bounded_drops_oldest(self): - buf = Bounded[int](maxlen=2) - buf.put(1) - buf.put(2) - buf.put(3) # drops 1 - assert buf.take() == 2 - assert buf.take() == 3 - - def test_drop_new_rejects(self): - buf = DropNew[int](maxlen=2) - assert buf.put(1) is True - assert buf.put(2) is True - assert buf.put(3) is False # rejected - assert buf.take() == 1 - assert buf.take() == 2 - - def test_unbounded_keeps_all(self): - buf = Unbounded[int]() - for i in range(100): - buf.put(i) - assert len(buf) == 100 - - def test_close_signals_end(self): - buf = KeepLast[int]() - buf.close() - with pytest.raises(ClosedError): - buf.take() - - def test_buffer_is_iterable(self): - """Iterating a buffer yields items until closed.""" - buf = Unbounded[int]() - buf.put(1) - buf.put(2) - buf.close() - assert list(buf) == [1, 2] - - def test_take_blocks_until_put(self): - buf = KeepLast[int]() - result = [] - - def producer(): - time.sleep(0.05) - buf.put(42) - - t = threading.Thread(target=producer) - t.start() - result.append(buf.take(timeout=2.0)) - t.join() - assert result == [42] - - -# ═══════════════════════════════════════════════════════════════════ -# 11. Live mode +# 10. 
Live mode # ═══════════════════════════════════════════════════════════════════ From 8070379e44b92ef8edd23ff448fdb2f41252de81 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 14:17:39 +0800 Subject: [PATCH 070/118] memory2: add Codec protocol and grid test for store implementations Introduce codecs/ package with the Codec[T] protocol (encode/decode). Thread payload_type through Session._create_backend() so backends can select the right codec. Add test_impl.py grid test that runs the same 15 basic tests against every store backend (memory passes, sqlite xfail until implemented). --- dimos/memory2/codecs/__init__.py | 26 ++++ dimos/memory2/impl/memory.py | 2 +- dimos/memory2/impl/sqlite.py | 2 +- dimos/memory2/store.py | 4 +- dimos/memory2/test_impl.py | 236 +++++++++++++++++++++++++++++++ 5 files changed, 266 insertions(+), 4 deletions(-) create mode 100644 dimos/memory2/codecs/__init__.py create mode 100644 dimos/memory2/test_impl.py diff --git a/dimos/memory2/codecs/__init__.py b/dimos/memory2/codecs/__init__.py new file mode 100644 index 0000000000..42627ce1ae --- /dev/null +++ b/dimos/memory2/codecs/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Protocol, TypeVar + +T = TypeVar("T") + + +class Codec(Protocol[T]): + """Encode/decode payloads for storage.""" + + def encode(self, value: T) -> bytes: ... 
+ def decode(self, data: bytes) -> T: ... diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index 644821d4d1..870ee044b9 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -162,7 +162,7 @@ def _unsubscribe() -> None: class MemorySession(Session): """In-memory session. Each stream is backed by a ListBackend.""" - def _create_backend(self, name: str) -> Backend[Any]: + def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: return ListBackend(name) diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index a481d671a5..61556b4998 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -70,7 +70,7 @@ def __init__(self, conn: sqlite3.Connection) -> None: super().__init__() self._conn = conn - def _create_backend(self, name: str) -> Backend[Any]: + def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: return SqliteBackend(self._conn, name) def close(self) -> None: diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index 071d7a46c3..4dfca74cc0 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -80,14 +80,14 @@ def __init__(self) -> None: self._backends: dict[str, Backend[Any]] = {} @abstractmethod - def _create_backend(self, name: str) -> Backend[Any]: + def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: """Create a backend for the named stream. Called once per stream name.""" ... def stream(self, name: str, payload_type: type[T] | None = None) -> Stream[T]: """Get or create a named stream. 
Returns the same Stream on repeated calls.""" if name not in self._streams: - backend = self._create_backend(name) + backend = self._create_backend(name, payload_type) self._backends[name] = backend self._streams[name] = Stream(source=backend) return cast("Stream[T]", self._streams[name]) diff --git a/dimos/memory2/test_impl.py b/dimos/memory2/test_impl.py new file mode 100644 index 0000000000..699cde40fc --- /dev/null +++ b/dimos/memory2/test_impl.py @@ -0,0 +1,236 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid tests for Store implementations. + +Runs the same test logic against every Store backend (MemoryStore, SqliteStore, …). 
+""" + +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import pytest + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + + from dimos.memory2.store import Session + +# ── Case definition ──────────────────────────────────────────────── + + +@dataclass +class Case: + name: str + session_factory: Callable[[], Generator[Session, None, None]] + tags: set[str] = field(default_factory=set) + + +# ── Context managers ─────────────────────────────────────────────── + + +@contextmanager +def memory_session() -> Generator[Session, None, None]: + from dimos.memory2.impl.memory import MemoryStore + + store = MemoryStore() + with store.session() as session: + yield session + + +@contextmanager +def sqlite_session() -> Generator[Session, None, None]: + import tempfile + + from dimos.memory2.impl.sqlite import SqliteStore + + with tempfile.NamedTemporaryFile(suffix=".db") as f: + store = SqliteStore(f.name) + with store.session() as session: + yield session + + +# ── Test cases ───────────────────────────────────────────────────── + +testcases = [ + Case(name="memory", session_factory=memory_session, tags={"basic", "live"}), + Case( + name="sqlite", + session_factory=sqlite_session, + tags={"basic"}, + ), +] + +basic_cases = [c for c in testcases if "basic" in c.tags] + +# Mark sqlite xfail until backend methods are implemented +_xfail_if_stub = { + "sqlite": pytest.mark.xfail( + reason="SqliteBackend not yet implemented", raises=NotImplementedError, strict=False + ), +} + + +def _apply_marks(cases: list[Case]) -> list[Any]: + return [ + pytest.param(c, marks=_xfail_if_stub[c.name]) if c.name in _xfail_if_stub else c + for c in cases + ] + + +# ── Tests ────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize("case", _apply_marks(basic_cases), ids=lambda c: c.name) +class TestStoreBasic: + """Core store operations 
that every backend must support.""" + + def test_create_stream_and_append(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("images", bytes) + obs = s.append(b"frame1", tags={"camera": "front"}) + + assert obs.data == b"frame1" + assert obs.tags["camera"] == "front" + assert obs.ts > 0 + + def test_append_multiple_and_fetch(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("sensor", float) + s.append(1.0, ts=100.0) + s.append(2.0, ts=200.0) + s.append(3.0, ts=300.0) + + results = s.fetch() + assert len(results) == 3 + assert [o.data for o in results] == [1.0, 2.0, 3.0] + + def test_iterate_stream(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("log", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) + + collected = [obs.data for obs in s] + assert collected == ["a", "b"] + + def test_count(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("events", str) + assert s.count() == 0 + s.append("x") + s.append("y") + assert s.count() == 2 + + def test_first_and_last(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("data", int) + s.append(10, ts=1.0) + s.append(20, ts=2.0) + s.append(30, ts=3.0) + + assert s.first().data == 10 + assert s.last().data == 30 + + def test_first_empty_raises(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("empty", int) + with pytest.raises(LookupError): + s.first() + + def test_exists(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("check", str) + assert not s.exists() + s.append("hi") + assert s.exists() + + def test_filter_after(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.after(15.0).fetch() + assert [o.data for o 
in results] == [2, 3] + + def test_filter_before(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.before(25.0).fetch() + assert [o.data for o in results] == [1, 2] + + def test_filter_time_range(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.time_range(15.0, 25.0).fetch() + assert [o.data for o in results] == [2] + + def test_filter_tags(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("tagged", str) + s.append("a", tags={"kind": "info"}) + s.append("b", tags={"kind": "error"}) + s.append("c", tags={"kind": "info"}) + + results = s.filter_tags(kind="info").fetch() + assert [o.data for o in results] == ["a", "c"] + + def test_limit_and_offset(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("paged", int) + for i in range(5): + s.append(i, ts=float(i)) + + page = s.offset(1).limit(2).fetch() + assert [o.data for o in page] == [1, 2] + + def test_order_by_desc(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("ordered", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.order_by("ts", desc=True).fetch() + assert [o.data for o in results] == [3, 2, 1] + + def test_separate_streams_isolated(self, case: Case) -> None: + with case.session_factory() as session: + a = session.stream("stream_a", str) + b = session.stream("stream_b", str) + + a.append("in_a") + b.append("in_b") + + assert [o.data for o in a] == ["in_a"] + assert [o.data for o in b] == ["in_b"] + + def test_same_stream_on_repeated_calls(self, case: Case) -> None: + with case.session_factory() as session: + s1 = session.stream("reuse", str) + s2 = session.stream("reuse", 
str) + assert s1 is s2 From dde8017dd39be4471a041d8c8ac1da394abb23de Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 14:40:55 +0800 Subject: [PATCH 071/118] memory2: add codec implementations (pickle, lcm, jpeg) with grid tests PickleCodec for arbitrary objects, LcmCodec for DimosMsg types, JpegCodec for Image types with TurboJPEG. codec_for() auto-selects based on payload type. Grid test verifies roundtrip preservation across all three codecs using real PoseStamped and camera frame data. --- dimos/memory2/codecs/__init__.py | 14 +-- dimos/memory2/codecs/base.py | 44 ++++++++ dimos/memory2/codecs/jpeg.py | 63 ++++++++++++ dimos/memory2/codecs/lcm.py | 33 ++++++ dimos/memory2/codecs/pickle.py | 28 +++++ dimos/memory2/codecs/test_codecs.py | 154 ++++++++++++++++++++++++++++ 6 files changed, 325 insertions(+), 11 deletions(-) create mode 100644 dimos/memory2/codecs/base.py create mode 100644 dimos/memory2/codecs/jpeg.py create mode 100644 dimos/memory2/codecs/lcm.py create mode 100644 dimos/memory2/codecs/pickle.py create mode 100644 dimos/memory2/codecs/test_codecs.py diff --git a/dimos/memory2/codecs/__init__.py b/dimos/memory2/codecs/__init__.py index 42627ce1ae..a7feb3bce3 100644 --- a/dimos/memory2/codecs/__init__.py +++ b/dimos/memory2/codecs/__init__.py @@ -12,15 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations +from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory2.codecs.pickle import PickleCodec -from typing import Protocol, TypeVar - -T = TypeVar("T") - - -class Codec(Protocol[T]): - """Encode/decode payloads for storage.""" - - def encode(self, value: T) -> bytes: ... - def decode(self, data: bytes) -> T: ... 
+__all__ = ["Codec", "PickleCodec", "codec_for"] diff --git a/dimos/memory2/codecs/base.py b/dimos/memory2/codecs/base.py new file mode 100644 index 0000000000..4c2b3865f5 --- /dev/null +++ b/dimos/memory2/codecs/base.py @@ -0,0 +1,44 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Protocol, TypeVar + +T = TypeVar("T") + + +class Codec(Protocol[T]): + """Encode/decode payloads for storage.""" + + def encode(self, value: T) -> bytes: ... + def decode(self, data: bytes) -> T: ... + + +def codec_for(payload_type: type[Any] | None = None) -> Codec[Any]: + """Auto-select codec based on payload type.""" + from dimos.memory2.codecs.pickle import PickleCodec + + if payload_type is not None: + from dimos.msgs.sensor_msgs.Image import Image + + if issubclass(payload_type, Image): + from dimos.memory2.codecs.jpeg import JpegCodec + + return JpegCodec() + if hasattr(payload_type, "lcm_encode") and hasattr(payload_type, "lcm_decode"): + from dimos.memory2.codecs.lcm import LcmCodec + + return LcmCodec(payload_type) + return PickleCodec() diff --git a/dimos/memory2/codecs/jpeg.py b/dimos/memory2/codecs/jpeg.py new file mode 100644 index 0000000000..3ef605c0db --- /dev/null +++ b/dimos/memory2/codecs/jpeg.py @@ -0,0 +1,63 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import struct +from typing import Any + + +class JpegCodec: + """Codec for Image types — stores as JPEG bytes (lossy, ~10-20x smaller). + + Uses TurboJPEG (libjpeg-turbo) for 2-5x faster encode/decode vs OpenCV. + Preserves ``frame_id`` as a short header: ````. + Pixel data is lossy-compressed; ``ts`` is NOT preserved (stored separately). + """ + + def __init__(self, quality: int = 50) -> None: + self._quality = quality + from turbojpeg import TurboJPEG # type: ignore[import-untyped] + + self._tj = TurboJPEG() + + _TJPF_MAP: dict[str, int] | None = None + + @staticmethod + def _get_tjpf_map() -> dict[str, int]: + if JpegCodec._TJPF_MAP is None: + from turbojpeg import TJPF_BGR, TJPF_GRAY, TJPF_RGB # type: ignore[import-untyped] + + JpegCodec._TJPF_MAP = {"BGR": TJPF_BGR, "RGB": TJPF_RGB, "GRAY": TJPF_GRAY} + return JpegCodec._TJPF_MAP + + def encode(self, value: Any) -> bytes: + from turbojpeg import TJPF_BGR # type: ignore[import-untyped] + + pf = self._get_tjpf_map().get(value.format.value, TJPF_BGR) + jpeg_data: bytes = self._tj.encode(value.data, quality=self._quality, pixel_format=pf) + frame_id = (value.frame_id or "").encode("utf-8") + header = struct.pack(" Any: + from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + + fid_len = struct.unpack(" None: + self._msg_type = msg_type + + def encode(self, value: DimosMsg) -> bytes: + return value.lcm_encode() + + def 
decode(self, data: bytes) -> DimosMsg: + return self._msg_type.lcm_decode(data) diff --git a/dimos/memory2/codecs/pickle.py b/dimos/memory2/codecs/pickle.py new file mode 100644 index 0000000000..7200e1da50 --- /dev/null +++ b/dimos/memory2/codecs/pickle.py @@ -0,0 +1,28 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pickle +from typing import Any + + +class PickleCodec: + """Fallback codec for arbitrary Python objects.""" + + def encode(self, value: Any) -> bytes: + return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) + + def decode(self, data: bytes) -> Any: + return pickle.loads(data) diff --git a/dimos/memory2/codecs/test_codecs.py b/dimos/memory2/codecs/test_codecs.py new file mode 100644 index 0000000000..8f3eb17c10 --- /dev/null +++ b/dimos/memory2/codecs/test_codecs.py @@ -0,0 +1,154 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Grid tests for Codec implementations. + +Runs roundtrip encode→decode tests across every codec, verifying data preservation. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import pytest + +from dimos.memory2.codecs.base import Codec, codec_for + +if TYPE_CHECKING: + from collections.abc import Callable + +# ── Case definition ──────────────────────────────────────────────── + + +@dataclass +class Case: + name: str + codec: Codec[Any] + values: list[Any] + eq: Callable[[Any, Any], bool] | None = None # custom equality: (original, decoded) -> bool + + +# ── Test cases ───────────────────────────────────────────────────── + + +def _pickle_case() -> Case: + from dimos.memory2.codecs.pickle import PickleCodec + + return Case( + name="pickle", + codec=PickleCodec(), + values=[42, "hello", b"raw bytes", {"key": "value"}], + ) + + +def _lcm_case() -> Case: + from dimos.memory2.codecs.lcm import LcmCodec + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + return Case( + name="lcm", + codec=LcmCodec(PoseStamped), + values=[ + PoseStamped( + ts=1.0, + frame_id="map", + position=Vector3(1.0, 2.0, 3.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + PoseStamped(ts=0.5, frame_id="odom"), + ], + ) + + +def _jpeg_eq(original: Any, decoded: Any) -> bool: + """JPEG is lossy — check shape, frame_id, and pixel closeness.""" + import numpy as np + + if decoded.data.shape != original.data.shape: + return False + if decoded.frame_id != original.frame_id: + return False + return bool(np.mean(np.abs(decoded.data.astype(float) - original.data.astype(float))) < 5) + + +def _jpeg_case() -> Case: + from dimos.memory2.codecs.jpeg import JpegCodec + from dimos.utils.testing import TimedSensorReplay + + replay = TimedSensorReplay("unitree_go2_bigoffice/video") + frames = 
[replay.find_closest_seek(float(i)) for i in range(1, 4)] + + return Case( + name="jpeg", + codec=JpegCodec(quality=95), + values=frames, + eq=_jpeg_eq, + ) + + +testcases = [_pickle_case(), _lcm_case(), _jpeg_case()] + + +# ── Tests ────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize("case", testcases, ids=lambda c: c.name) +class TestCodecRoundtrip: + """Every codec must perfectly roundtrip its values.""" + + def test_roundtrip_preserves_value(self, case: Case) -> None: + eq = case.eq or (lambda a, b: a == b) + for value in case.values: + encoded = case.codec.encode(value) + assert isinstance(encoded, bytes) + decoded = case.codec.decode(encoded) + assert eq(value, decoded), f"Roundtrip failed for {value!r}: got {decoded!r}" + + def test_encode_returns_nonempty_bytes(self, case: Case) -> None: + for value in case.values: + encoded = case.codec.encode(value) + assert len(encoded) > 0, f"Empty encoding for {value!r}" + + def test_different_values_produce_different_bytes(self, case: Case) -> None: + encodings = [case.codec.encode(v) for v in case.values] + assert len(set(encodings)) > 1, "All values encoded to identical bytes" + + +class TestCodecFor: + """codec_for() auto-selects the right codec.""" + + def test_none_returns_pickle(self) -> None: + from dimos.memory2.codecs.pickle import PickleCodec + + assert isinstance(codec_for(None), PickleCodec) + + def test_unknown_type_returns_pickle(self) -> None: + from dimos.memory2.codecs.pickle import PickleCodec + + assert isinstance(codec_for(dict), PickleCodec) + + def test_lcm_type_returns_lcm(self) -> None: + from dimos.memory2.codecs.lcm import LcmCodec + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + assert isinstance(codec_for(PoseStamped), LcmCodec) + + def test_image_type_returns_jpeg(self) -> None: + from dimos.memory2.codecs.jpeg import JpegCodec + from dimos.msgs.sensor_msgs.Image import Image + + assert isinstance(codec_for(Image), JpegCodec) From 
7ce2364b056ee6e1e5952f5eec88bfdbe9281679 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 15:11:27 +0800 Subject: [PATCH 072/118] resource: add context manager to Resource; make Store/Session Resources Resource.__enter__/__exit__ calls start()/stop(), giving every Resource context-manager support. memory2 Store and Session now extend Resource instead of bare ABC, replacing close() with the standard start()/stop() lifecycle. --- dimos/core/resource.py | 9 +++++++++ dimos/memory2/impl/sqlite.py | 4 ++-- dimos/memory2/store.py | 29 ++++++++++++----------------- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/dimos/core/resource.py b/dimos/core/resource.py index df1ca568bc..7d76ec6281 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from abc import abstractmethod from reactivex.abc import DisposableBase @@ -45,3 +47,10 @@ def dispose(self) -> None: """ self.stop() + + def __enter__(self) -> Resource: + self.start() + return self + + def __exit__(self, *args: object) -> None: + self.stop() diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 61556b4998..3f6d752da7 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -73,8 +73,8 @@ def __init__(self, conn: sqlite3.Connection) -> None: def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: return SqliteBackend(self._conn, name) - def close(self) -> None: - super().close() + def stop(self) -> None: + super().stop() self._conn.close() diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index 4dfca74cc0..2fb3de6646 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -14,9 +14,10 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import 
TYPE_CHECKING, Any, TypeVar, cast +from dimos.core.resource import Resource from dimos.memory2.stream import Stream if TYPE_CHECKING: @@ -69,7 +70,7 @@ def __repr__(self) -> str: return f"StreamNamespace({list(self._session._streams.keys())})" -class Session(ABC): +class Session(Resource): """A session against a store. Manages named streams over a shared connection. Subclasses implement ``_create_backend`` to provide storage-specific backends. @@ -104,27 +105,21 @@ def delete_stream(self, name: str) -> None: def streams(self) -> StreamNamespace: return StreamNamespace(self) - def close(self) -> None: # noqa: B027 - """Release resources. Override in subclasses for cleanup.""" + def start(self) -> None: + pass - def __enter__(self) -> Session: - return self + def stop(self) -> None: + pass - def __exit__(self, *args: object) -> None: - self.close() - -class Store(ABC): +class Store(Resource): """Top-level entry point — wraps a storage location (file, URL, etc.).""" @abstractmethod def session(self) -> Session: ... - def close(self) -> None: # noqa: B027 - """Release resources. Override in subclasses for cleanup.""" - - def __enter__(self) -> Store: - return self + def start(self) -> None: + pass - def __exit__(self, *args: object) -> None: - self.close() + def stop(self) -> None: + pass From d5dde8125e0c390c3c5d18b062bbc141cfeccf3b Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 15:27:01 +0800 Subject: [PATCH 073/118] resource: add CompositeResource with owned disposables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CompositeResource extends Resource with a _disposables list and own() method. stop() disposes all children — gives tree-structured resources automatic cleanup. Session and Store now extend CompositeResource. 
--- dimos/core/resource.py | 21 +++++++++++++++++++++ dimos/memory2/store.py | 19 ++++--------------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/dimos/core/resource.py b/dimos/core/resource.py index 7d76ec6281..63ba31f210 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -17,6 +17,7 @@ from abc import abstractmethod from reactivex.abc import DisposableBase +from reactivex.disposable import CompositeDisposable class Resource(DisposableBase): @@ -54,3 +55,23 @@ def __enter__(self) -> Resource: def __exit__(self, *args: object) -> None: self.stop() + + +class CompositeResource(Resource): + """Resource that owns child disposables, disposed on stop().""" + + _disposables: CompositeDisposable + + def __init__(self) -> None: + self._disposables = CompositeDisposable() + + def own(self, *disposables: DisposableBase) -> None: + """Register child disposables to be disposed when this resource stops.""" + for d in disposables: + self._disposables.add(d) + + def start(self) -> None: + pass + + def stop(self) -> None: + self._disposables.dispose() diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index 2fb3de6646..d4d926230c 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -17,7 +17,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Any, TypeVar, cast -from dimos.core.resource import Resource +from dimos.core.resource import CompositeResource from dimos.memory2.stream import Stream if TYPE_CHECKING: @@ -70,13 +70,14 @@ def __repr__(self) -> str: return f"StreamNamespace({list(self._session._streams.keys())})" -class Session(Resource): +class Session(CompositeResource): """A session against a store. Manages named streams over a shared connection. Subclasses implement ``_create_backend`` to provide storage-specific backends. 
""" def __init__(self) -> None: + super().__init__() self._streams: dict[str, Stream[Any]] = {} self._backends: dict[str, Backend[Any]] = {} @@ -105,21 +106,9 @@ def delete_stream(self, name: str) -> None: def streams(self) -> StreamNamespace: return StreamNamespace(self) - def start(self) -> None: - pass - def stop(self) -> None: - pass - - -class Store(Resource): +class Store(CompositeResource): """Top-level entry point — wraps a storage location (file, URL, etc.).""" @abstractmethod def session(self) -> Session: ... - - def start(self) -> None: - pass - - def stop(self) -> None: - pass From 9d37f1d957c80f22fdab030207276da69fc1d60c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 15:33:26 +0800 Subject: [PATCH 074/118] memory2: add BlobStore ABC with File and SQLite implementations BlobStore separates payload blob storage from metadata indexing. FileBlobStore stores on disk ({root}/{stream}/{key}.bin), SqliteBlobStore uses per-stream tables. Grid tests cover both. --- dimos/memory2/backend.py | 32 ++++++ dimos/memory2/blobstore/__init__.py | 19 ++++ dimos/memory2/blobstore/file.py | 64 ++++++++++++ dimos/memory2/blobstore/sqlite.py | 80 +++++++++++++++ dimos/memory2/blobstore/test_blobstore.py | 114 ++++++++++++++++++++++ dimos/memory2/impl/sqlite.py | 3 +- 6 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 dimos/memory2/blobstore/__init__.py create mode 100644 dimos/memory2/blobstore/file.py create mode 100644 dimos/memory2/blobstore/sqlite.py create mode 100644 dimos/memory2/blobstore/test_blobstore.py diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index a63fae1f73..ef000ec38b 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -14,8 +14,11 @@ from __future__ import annotations +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable +from dimos.core.resource import Resource + if TYPE_CHECKING: from collections.abc import Iterator @@ 
-58,3 +61,32 @@ class LiveBackend(Backend[T], Protocol[T]): """Backend that also supports live subscriptions.""" def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: ... + + +# ── Blob storage ────────────────────────────────────────────────── + + +class BlobStore(Resource, ABC): + """Persistent storage for encoded payload blobs. + + Separates payload data from metadata indexing so that large blobs + (images, point clouds) don't penalize metadata queries. + + Extends Resource (start/stop) but does NOT manage its dependencies' + lifecycle — the caller owns the session / connection. + """ + + @abstractmethod + def put(self, stream: str, key: int, data: bytes) -> None: + """Store a blob for the given stream and observation id.""" + ... + + @abstractmethod + def get(self, stream: str, key: int) -> bytes: + """Retrieve a blob by stream name and observation id.""" + ... + + @abstractmethod + def delete(self, stream: str, key: int) -> None: + """Delete a blob by stream name and observation id.""" + ... diff --git a/dimos/memory2/blobstore/__init__.py b/dimos/memory2/blobstore/__init__.py new file mode 100644 index 0000000000..8f78d7c439 --- /dev/null +++ b/dimos/memory2/blobstore/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.memory2.backend import BlobStore +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore + +__all__ = ["BlobStore", "FileBlobStore", "SqliteBlobStore"] diff --git a/dimos/memory2/blobstore/file.py b/dimos/memory2/blobstore/file.py new file mode 100644 index 0000000000..54ec80e284 --- /dev/null +++ b/dimos/memory2/blobstore/file.py @@ -0,0 +1,64 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from dimos.memory2.backend import BlobStore + +if TYPE_CHECKING: + import os + + +class FileBlobStore(BlobStore): + """Stores blobs as files on disk, one directory per stream. 
+ + Layout:: + + {root}/{stream}/{key}.bin + """ + + def __init__(self, root: str | os.PathLike[str]) -> None: + self._root = Path(root) + + def _path(self, stream: str, key: int) -> Path: + return self._root / stream / f"{key}.bin" + + # ── Resource lifecycle ──────────────────────────────────────── + + def start(self) -> None: + self._root.mkdir(parents=True, exist_ok=True) + + def stop(self) -> None: + pass + + # ── BlobStore interface ─────────────────────────────────────── + + def put(self, stream: str, key: int, data: bytes) -> None: + p = self._path(stream, key) + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(data) + + def get(self, stream: str, key: int) -> bytes: + p = self._path(stream, key) + try: + return p.read_bytes() + except FileNotFoundError: + raise KeyError(f"No blob for stream={stream!r}, key={key}") from None + + def delete(self, stream: str, key: int) -> None: + p = self._path(stream, key) + p.unlink(missing_ok=True) diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py new file mode 100644 index 0000000000..0fd144c532 --- /dev/null +++ b/dimos/memory2/blobstore/sqlite.py @@ -0,0 +1,80 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dimos.memory2.backend import BlobStore + +if TYPE_CHECKING: + import sqlite3 + + +class SqliteBlobStore(BlobStore): + """Stores blobs in a separate SQLite table per stream. + + Table layout per stream:: + + CREATE TABLE "{stream}_blob" ( + id INTEGER PRIMARY KEY, + data BLOB NOT NULL + ); + + Does NOT own the connection — lifecycle managed externally. + """ + + def __init__(self, conn: sqlite3.Connection) -> None: + self._conn = conn + self._tables: set[str] = set() + + def _ensure_table(self, stream: str) -> None: + if stream in self._tables: + return + self._conn.execute( + f'CREATE TABLE IF NOT EXISTS "{stream}_blob" ' + "(id INTEGER PRIMARY KEY, data BLOB NOT NULL)" + ) + self._tables.add(stream) + + # ── Resource lifecycle ──────────────────────────────────────── + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + # ── BlobStore interface ─────────────────────────────────────── + + def put(self, stream: str, key: int, data: bytes) -> None: + self._ensure_table(stream) + self._conn.execute( + f'INSERT OR REPLACE INTO "{stream}_blob" (id, data) VALUES (?, ?)', + (key, data), + ) + + def get(self, stream: str, key: int) -> bytes: + self._ensure_table(stream) + row = self._conn.execute( + f'SELECT data FROM "{stream}_blob" WHERE id = ?', (key,) + ).fetchone() + if row is None: + raise KeyError(f"No blob for stream={stream!r}, key={key}") + result: bytes = row[0] + return result + + def delete(self, stream: str, key: int) -> None: + self._ensure_table(stream) + self._conn.execute(f'DELETE FROM "{stream}_blob" WHERE id = ?', (key,)) diff --git a/dimos/memory2/blobstore/test_blobstore.py b/dimos/memory2/blobstore/test_blobstore.py new file mode 100644 index 0000000000..fe05cfa84f --- /dev/null +++ b/dimos/memory2/blobstore/test_blobstore.py @@ -0,0 +1,114 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid tests for BlobStore implementations.""" + +from __future__ import annotations + +from dataclasses import dataclass +import sqlite3 +from typing import TYPE_CHECKING + +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + from pathlib import Path + + from dimos.memory2.backend import BlobStore + + +# ── Case definition ──────────────────────────────────────────────── + + +@dataclass +class Case: + name: str + factory: Callable[..., Generator[BlobStore, None, None]] + + +# ── Factories ────────────────────────────────────────────────────── + + +@pytest.fixture() +def file_store(tmp_path: Path) -> Generator[FileBlobStore, None, None]: + store = FileBlobStore(tmp_path / "blobs") + store.start() + yield store + store.stop() + + +@pytest.fixture() +def sqlite_store() -> Generator[SqliteBlobStore, None, None]: + conn = sqlite3.connect(":memory:") + store = SqliteBlobStore(conn) + store.start() + yield store + store.stop() + conn.close() + + +@pytest.fixture(params=["file", "sqlite"]) +def blob_store( + request: pytest.FixtureRequest, + file_store: FileBlobStore, + sqlite_store: SqliteBlobStore, +) -> BlobStore: + if request.param == "file": + return file_store + return sqlite_store + + +# ── Tests ────────────────────────────────────────────────────────── + + +class 
TestBlobStore: + """Every BlobStore must satisfy these contracts.""" + + def test_put_get_roundtrip(self, blob_store: BlobStore) -> None: + data = b"hello world" + blob_store.put("stream_a", 1, data) + assert blob_store.get("stream_a", 1) == data + + def test_get_missing_raises(self, blob_store: BlobStore) -> None: + with pytest.raises(KeyError): + blob_store.get("nonexistent", 999) + + def test_put_overwrite(self, blob_store: BlobStore) -> None: + blob_store.put("s", 1, b"first") + blob_store.put("s", 1, b"second") + assert blob_store.get("s", 1) == b"second" + + def test_delete(self, blob_store: BlobStore) -> None: + blob_store.put("s", 1, b"data") + blob_store.delete("s", 1) + with pytest.raises(KeyError): + blob_store.get("s", 1) + + def test_delete_missing_is_silent(self, blob_store: BlobStore) -> None: + blob_store.delete("s", 999) # should not raise + + def test_stream_isolation(self, blob_store: BlobStore) -> None: + blob_store.put("a", 1, b"alpha") + blob_store.put("b", 1, b"beta") + assert blob_store.get("a", 1) == b"alpha" + assert blob_store.get("b", 1) == b"beta" + + def test_large_blob(self, blob_store: BlobStore) -> None: + data = bytes(range(256)) * 1000 # 256 KB + blob_store.put("big", 0, data) + assert blob_store.get("big", 0) == data diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 3f6d752da7..ef82d62ac5 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from collections.abc import Iterator + import os from reactivex.abc import DisposableBase @@ -81,7 +82,7 @@ def stop(self) -> None: class SqliteStore(Store): """Store backed by a SQLite database file.""" - def __init__(self, path: str) -> None: + def __init__(self, path: str | os.PathLike[str]) -> None: self._path = path def session(self) -> SqliteSession: From a83d7a297867184a004a809f2e86a6bb8b8671cc Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 15:46:08 +0800 Subject: [PATCH 075/118] 
memory2: move blobstore.md into blobstore/ as module readme --- dimos/memory2/blobstore/blobstore.md | 86 ++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 dimos/memory2/blobstore/blobstore.md diff --git a/dimos/memory2/blobstore/blobstore.md b/dimos/memory2/blobstore/blobstore.md new file mode 100644 index 0000000000..79a36d52ae --- /dev/null +++ b/dimos/memory2/blobstore/blobstore.md @@ -0,0 +1,86 @@ +# blobstore/ + +Separates payload blob storage from metadata indexing. Observation payloads vary hugely in size — a `Vector3` is 24 bytes, a camera frame is megabytes. Storing everything inline penalizes metadata queries. BlobStore lets large payloads live elsewhere. + +## ABC (`backend.py`) + +```python +class BlobStore(Resource, ABC): + def put(self, stream: str, key: int, data: bytes) -> None: ... + def get(self, stream: str, key: int) -> bytes: ... # raises KeyError if missing + def delete(self, stream: str, key: int) -> None: ... # silent if missing +``` + +- `stream` — stream name (used to organize storage: directories, tables) +- `key` — observation id +- `data` — encoded payload bytes (codec handles serialization, blob store handles persistence) +- Extends `Resource` (start/stop) but does NOT own its dependencies' lifecycle + +## Implementations + +### `file.py` — FileBlobStore + +Stores blobs as files on disk, one directory per stream. + +``` +{root}/{stream}/{key}.bin +``` + +`__init__(root: str | os.PathLike[str])` — `start()` creates the root directory. + +### `sqlite.py` — SqliteBlobStore + +Stores blobs in a separate SQLite table per stream. + +```sql +CREATE TABLE "{stream}_blob" (id INTEGER PRIMARY KEY, data BLOB NOT NULL) +``` + +`__init__(conn: sqlite3.Connection)` — does NOT own the connection. + +**Internal use** (same db as metadata): `SqliteStore.session()` creates one connection, passes it to both the metadata backend and the blob store. 
+ +**External use** (separate db): user creates a separate connection and passes it. User manages that connection's lifecycle. + +**JOIN optimization** (future): when `lazy=False` and the blob store shares the same connection as the metadata backend, `SqliteBackend` can optimize with a JOIN instead of separate queries: + +```sql +SELECT m.id, m.ts, m.pose, m.tags, b.data +FROM "images" m JOIN "images_blob" b ON m.id = b.id +WHERE m.ts > ? +``` + +## Lazy loading + +`lazy` is a stream-level flag, orthogonal to blob store choice. It controls WHEN data is loaded: + +- `lazy=False` → backend loads payload during iteration (eager) +- `lazy=True` → backend sets `Observation._loader`, payload loaded on `.data` access + +| lazy | blob store | loading strategy | +|------|-----------|-----------------| +| False | SqliteBlobStore (same conn) | JOIN — one round trip | +| False | any other | iterate meta, `blob_store.get()` per row | +| True | any | iterate meta only, `_loader = lambda: codec.decode(blob_store.get(...))` | + +## Usage + +```python +# Per-stream blob store choice +with store.session() as session: + poses = session.stream("poses", PoseStamped) # default, eager + images = session.stream("images", Image, lazy=True) # default, lazy + images = session.stream("images", Image, blob_store=file_blobs) # override +``` + +## Files + +``` +backend.py BlobStore ABC (alongside Backend, LiveBackend) +blobstore/ + blobstore.md this file + __init__.py re-exports BlobStore, FileBlobStore, SqliteBlobStore + file.py FileBlobStore + sqlite.py SqliteBlobStore + test_blobstore.py grid tests across implementations +``` From b6c9543635f81d106a12be3c4e3a47293d338245 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 18:21:35 +0800 Subject: [PATCH 076/118] memory2: add embedding layer, vector/text search, live safety guards - EmbeddedObservation with derive() promotion semantics - EmbedImages/EmbedText transformers using EmbeddingModel ABC - .search(vec, k) and 
.search_text() on Stream with Embedding type - VectorStore ABC for pluggable vector backends - Backend.append() takes Observation directly (not kwargs) - is_live() walks source chain; search/order_by/fetch/count guard against live streams with TypeError instead of silent hang - .drain() terminal for constant-memory side-effect pipelines - Rewrite test_stream.py to use Stream layer (no manual backends) --- dimos/memory2/__init__.py | 9 +- dimos/memory2/backend.py | 48 ++++- dimos/memory2/embed.py | 79 +++++++ dimos/memory2/filter.py | 8 +- dimos/memory2/impl/memory.py | 41 ++-- dimos/memory2/impl/sqlite.py | 11 +- dimos/memory2/stream.py | 97 ++++++++- dimos/memory2/test_embedding.py | 372 ++++++++++++++++++++++++++++++++ dimos/memory2/test_impl.py | 43 ++++ dimos/memory2/test_save.py | 19 +- dimos/memory2/test_stream.py | 257 ++++++++++++++++------ dimos/memory2/type.py | 41 +++- 12 files changed, 897 insertions(+), 128 deletions(-) create mode 100644 dimos/memory2/embed.py create mode 100644 dimos/memory2/test_embedding.py diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py index 0fbf95a5ff..779a38b041 100644 --- a/dimos/memory2/__init__.py +++ b/dimos/memory2/__init__.py @@ -1,4 +1,4 @@ -from dimos.memory2.backend import Backend, LiveBackend +from dimos.memory2.backend import Backend, LiveBackend, VectorStore from dimos.memory2.buffer import ( BackpressureBuffer, Bounded, @@ -7,6 +7,7 @@ KeepLast, Unbounded, ) +from dimos.memory2.embed import EmbedImages, EmbedText from dimos.memory2.filter import ( AfterFilter, AtFilter, @@ -23,7 +24,7 @@ from dimos.memory2.store import Session, Store, StreamNamespace from dimos.memory2.stream import Stream from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer -from dimos.memory2.type import Observation +from dimos.memory2.type import EmbeddedObservation, Observation __all__ = [ "AfterFilter", @@ -34,6 +35,9 @@ "Bounded", "ClosedError", "DropNew", + "EmbedImages", + "EmbedText", + 
"EmbeddedObservation", "Filter", "FnTransformer", "KeepLast", @@ -57,4 +61,5 @@ "TimeRangeFilter", "Transformer", "Unbounded", + "VectorStore", ] diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index ef000ec38b..e1f87287bf 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -15,7 +15,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable from dimos.core.resource import Resource @@ -27,6 +27,7 @@ from dimos.memory2.buffer import BackpressureBuffer from dimos.memory2.filter import StreamQuery from dimos.memory2.type import Observation + from dimos.models.embedding.base import Embedding T = TypeVar("T") @@ -44,14 +45,13 @@ def name(self) -> str: ... def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: ... - def append( - self, - payload: T, - *, - ts: float | None = None, - pose: Any | None = None, - tags: dict[str, Any] | None = None, - ) -> Observation[T]: ... + def append(self, obs: Observation[T]) -> Observation[T]: + """Store an observation, assigning it a backend-managed id. + + The caller builds the Observation (or EmbeddedObservation); + the backend assigns the canonical ``id`` and persists it. + """ + ... def count(self, query: StreamQuery) -> int: ... @@ -90,3 +90,33 @@ def get(self, stream: str, key: int) -> bytes: def delete(self, stream: str, key: int) -> None: """Delete a blob by stream name and observation id.""" ... + + +# ── Vector storage ─────────────────────────────────────────────── + + +class VectorStore(Resource, ABC): + """Pluggable storage and ANN index for embedding vectors. + + Separates vector indexing from metadata so backends can swap + search strategies (brute-force, vec0, FAISS, Qdrant) independently. + + Same shape as BlobStore: ``put`` / ``search`` / ``delete``, keyed + by ``(stream, observation_id)``. 
Index creation is lazy — the + first ``put`` for a stream determines dimensionality. + """ + + @abstractmethod + def put(self, stream: str, key: int, embedding: Embedding) -> None: + """Store an embedding vector for the given stream and observation id.""" + ... + + @abstractmethod + def search(self, stream: str, query: Embedding, k: int) -> list[tuple[int, float]]: + """Return top-k (observation_id, similarity) pairs, descending.""" + ... + + @abstractmethod + def delete(self, stream: str, key: int) -> None: + """Remove a vector. Silent if missing.""" + ... diff --git a/dimos/memory2/embed.py b/dimos/memory2/embed.py new file mode 100644 index 0000000000..e3b34bb0ae --- /dev/null +++ b/dimos/memory2/embed.py @@ -0,0 +1,79 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from itertools import islice +from typing import TYPE_CHECKING, Any, TypeVar + +from dimos.memory2.transform import Transformer + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.type import EmbeddedObservation, Observation + from dimos.models.embedding.base import EmbeddingModel + +T = TypeVar("T") + + +def _batched(it: Iterator[T], n: int) -> Iterator[list[T]]: + """Yield successive n-sized chunks from an iterator.""" + while True: + batch = list(islice(it, n)) + if not batch: + return + yield batch + + +class EmbedImages(Transformer[Any, Any]): + """Embed images using ``model.embed()``. 
+ + Data type stays the same — observations are enriched with an + ``.embedding`` field, yielding :class:`EmbeddedObservation` instances. + """ + + def __init__(self, model: EmbeddingModel, batch_size: int = 32) -> None: + self.model = model + self.batch_size = batch_size + + def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[EmbeddedObservation[Any]]: + for batch in _batched(upstream, self.batch_size): + images = [obs.data for obs in batch] + embeddings = self.model.embed(*images) + if not isinstance(embeddings, list): + embeddings = [embeddings] + for obs, emb in zip(batch, embeddings, strict=False): + yield obs.derive(data=obs.data, embedding=emb) + + +class EmbedText(Transformer[Any, Any]): + """Embed text using ``model.embed_text()``. + + Data type stays the same — observations are enriched with an + ``.embedding`` field, yielding :class:`EmbeddedObservation` instances. + """ + + def __init__(self, model: EmbeddingModel, batch_size: int = 32) -> None: + self.model = model + self.batch_size = batch_size + + def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[EmbeddedObservation[Any]]: + for batch in _batched(upstream, self.batch_size): + texts = [str(obs.data) for obs in batch] + embeddings = self.model.embed_text(*texts) + if not isinstance(embeddings, list): + embeddings = [embeddings] + for obs, emb in zip(batch, embeddings, strict=False): + yield obs.derive(data=obs.data, embedding=emb) diff --git a/dimos/memory2/filter.py b/dimos/memory2/filter.py index 942b3f52d7..243200f68a 100644 --- a/dimos/memory2/filter.py +++ b/dimos/memory2/filter.py @@ -14,7 +14,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: @@ -22,6 +22,7 @@ from dimos.memory2.buffer import BackpressureBuffer from dimos.memory2.type import Observation + from dimos.models.embedding.base import Embedding # ── 
Filter protocol ───────────────────────────────────────────────── @@ -131,3 +132,8 @@ class StreamQuery: limit_val: int | None = None offset_val: int | None = None live_buffer: BackpressureBuffer[Any] | None = None + # Vector search (embedding similarity) + search_vec: Embedding | None = field(default=None, hash=False, compare=False) + search_k: int | None = None + # Full-text search (substring / FTS5) + search_text: str | None = None diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index 870ee044b9..3e8f89c63c 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -15,13 +15,11 @@ from __future__ import annotations import threading -import time from typing import TYPE_CHECKING, Any, Generic, TypeVar from reactivex.disposable import Disposable from dimos.memory2.store import Session, Store -from dimos.memory2.type import Observation if TYPE_CHECKING: from collections.abc import Iterator @@ -31,6 +29,7 @@ from dimos.memory2.backend import Backend from dimos.memory2.buffer import BackpressureBuffer from dimos.memory2.filter import StreamQuery + from dimos.memory2.type import Observation T = TypeVar("T") @@ -49,22 +48,9 @@ def __init__(self, name: str = "") -> None: def name(self) -> str: return self._name - def append( - self, - payload: T, - *, - ts: float | None = None, - pose: Any | None = None, - tags: dict[str, Any] | None = None, - ) -> Observation[T]: + def append(self, obs: Observation[T]) -> Observation[T]: with self._lock: - obs: Observation[T] = Observation( - id=self._next_id, - ts=ts if ts is not None else time.time(), - pose=pose, - tags=tags or {}, - _data=payload, - ) + obs.id = self._next_id self._next_id += 1 self._observations.append(obs) subs = list(self._subscribers) @@ -81,6 +67,8 @@ def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: If query.live_buffer is set, subscribes before backfill, then switches to a live tail that blocks for new observations. 
""" + if query.search_vec is not None and query.live_buffer is not None: + raise TypeError("Cannot combine .search() with .live() — search is a batch operation.") buf = query.live_buffer if buf is not None: # Subscribe BEFORE backfill to avoid missing items @@ -96,6 +84,25 @@ def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: for f in query.filters: snapshot = [obs for obs in snapshot if f.matches(obs)] + # Text search — substring match (SqliteBackend will use FTS5) + if query.search_text is not None: + needle = query.search_text.lower() + snapshot = [obs for obs in snapshot if needle in str(obs.data).lower()] + + # Vector search — brute-force cosine via Embedding.__matmul__ + if query.search_vec is not None: + query_emb = query.search_vec + scored: list[Observation[T]] = [] + for obs in snapshot: + emb = getattr(obs, "embedding", None) + if emb is not None: + sim = float(emb @ query_emb) + scored.append(obs.derive(data=obs.data, similarity=sim)) + scored.sort(key=lambda o: getattr(o, "similarity", 0.0) or 0.0, reverse=True) + if query.search_k is not None: + scored = scored[: query.search_k] + snapshot = scored + # Ordering if query.order_field: key = query.order_field diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index ef82d62ac5..c87162acf2 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -15,7 +15,7 @@ from __future__ import annotations import sqlite3 -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from dimos.memory2.store import Session, Store @@ -47,14 +47,7 @@ def name(self) -> str: def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: raise NotImplementedError - def append( - self, - payload: T, - *, - ts: float | None = None, - pose: Any | None = None, - tags: dict[str, Any] | None = None, - ) -> Observation[T]: + def append(self, obs: Observation[T]) -> Observation[T]: raise NotImplementedError def 
count(self, query: StreamQuery) -> int: diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 69ba220575..a5c00d8af1 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -15,6 +15,7 @@ from __future__ import annotations from itertools import islice +import time from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.memory2.backend import Backend, LiveBackend @@ -31,11 +32,12 @@ TimeRangeFilter, ) from dimos.memory2.transform import FnTransformer, Transformer +from dimos.memory2.type import EmbeddedObservation, Observation if TYPE_CHECKING: from collections.abc import Callable, Iterator - from dimos.memory2.type import Observation + from dimos.models.embedding.base import Embedding T = TypeVar("T") R = TypeVar("R") @@ -60,6 +62,14 @@ def __init__( self._xf = xf self._query = query + def is_live(self) -> bool: + """True if this stream (or any ancestor in the chain) is in live mode.""" + if self._query.live_buffer is not None: + return True + if isinstance(self._source, Stream): + return self._source.is_live() + return False + # ── Iteration ─────────────────────────────────────────────────── def __iter__(self) -> Iterator[Observation[T]]: @@ -81,8 +91,36 @@ def _iter_transform(self) -> Iterator[Observation[T]]: if filters: it = (obs for obs in it if all(f.matches(obs) for f in filters)) - # Sort if needed (materializes — only for finite streams) + # Text search — substring match + if self._query.search_text is not None: + needle = self._query.search_text.lower() + it = (obs for obs in it if needle in str(obs.data).lower()) + + # Vector search — brute-force cosine (materializes — rejects live) + if self._query.search_vec is not None: + if self.is_live(): + raise TypeError( + ".search() requires finite data — cannot rank an infinite live stream." 
+ ) + query_emb = self._query.search_vec + scored = [] + for obs in it: + emb = getattr(obs, "embedding", None) + if emb is not None: + sim = float(emb @ query_emb) + scored.append(obs.derive(data=obs.data, similarity=sim)) + scored.sort(key=lambda o: getattr(o, "similarity", 0.0) or 0.0, reverse=True) + k = self._query.search_k + if k is not None: + scored = scored[:k] + it = iter(scored) + + # Sort if needed (materializes — rejects live) if self._query.order_field: + if self.is_live(): + raise TypeError( + ".order_by() requires finite data — cannot sort an infinite live stream." + ) key = self._query.order_field desc = self._query.order_desc items = sorted( @@ -111,6 +149,9 @@ def _replace_query(self, **overrides: Any) -> Stream[T]: limit_val=overrides.get("limit_val", q.limit_val), offset_val=overrides.get("offset_val", q.offset_val), live_buffer=overrides.get("live_buffer", q.live_buffer), + search_vec=overrides.get("search_vec", q.search_vec), + search_k=overrides.get("search_k", q.search_k), + search_text=overrides.get("search_text", q.search_text), ) return Stream(self._source, xf=self._xf, query=new_q) @@ -144,6 +185,22 @@ def limit(self, k: int) -> Stream[T]: def offset(self, n: int) -> Stream[T]: return self._replace_query(offset_val=n) + def search(self, query: Embedding, k: int) -> Stream[T]: + """Return top-k observations by cosine similarity to *query*. + + The backend handles the actual computation. ListBackend does + brute-force cosine; SqliteBackend (future) pushes down to vec0. + """ + return self._replace_query(search_vec=query, search_k=k) + + def search_text(self, text: str) -> Stream[T]: + """Filter observations whose data contains *text*. + + ListBackend does case-insensitive substring match; + SqliteBackend (future) pushes down to FTS5. 
+ """ + return self._replace_query(search_text=text) + # ── Functional API ────────────────────────────────────────────── def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: @@ -195,13 +252,18 @@ def save(self, target: Stream[T]) -> Stream[T]: raise TypeError("Cannot save to a transform stream. Target must be backend-backed.") backend = target._source for obs in self: - backend.append(obs.data, ts=obs.ts, pose=obs.pose, tags=obs.tags) + backend.append(obs) return target # ── Terminals ─────────────────────────────────────────────────── def fetch(self) -> list[Observation[T]]: """Materialize all observations into a list.""" + if self.is_live() and self._query.limit_val is None: + raise TypeError( + ".fetch() on a live stream without .limit() would collect forever. " + "Use .limit(n).fetch(), .drain(), or .save(target) instead." + ) return list(self) def first(self) -> Observation[T]: @@ -220,12 +282,25 @@ def count(self) -> int: """Count matching observations.""" if isinstance(self._source, Backend): return self._source.count(self._query) + if self.is_live(): + raise TypeError(".count() on a live transform stream would block forever.") return sum(1 for _ in self) def exists(self) -> bool: """Check if any matching observation exists.""" return next(iter(self.limit(1)), None) is not None + def drain(self) -> int: + """Consume all observations, discarding results. Returns count consumed. + + Use for side-effect pipelines (e.g. live embed-and-store) where you + don't need to collect results in memory. + """ + n = 0 + for _ in self: + n += 1 + return n + # ── Write ─────────────────────────────────────────────────────── def append( @@ -235,8 +310,22 @@ def append( ts: float | None = None, pose: Any | None = None, tags: dict[str, Any] | None = None, + embedding: Embedding | None = None, ) -> Observation[T]: """Append to the backing store. 
Only works if source is a Backend.""" if isinstance(self._source, Stream): raise TypeError("Cannot append to a transform stream. Append to the source stream.") - return self._source.append(payload, ts=ts, pose=pose, tags=tags) + _ts = ts if ts is not None else time.time() + _tags = tags or {} + if embedding is not None: + obs: Observation[T] = EmbeddedObservation( + id=-1, + ts=_ts, + pose=pose, + tags=_tags, + _data=payload, + embedding=embedding, + ) + else: + obs = Observation(id=-1, ts=_ts, pose=pose, tags=_tags, _data=payload) + return self._source.append(obs) diff --git a/dimos/memory2/test_embedding.py b/dimos/memory2/test_embedding.py new file mode 100644 index 0000000000..40f80346c0 --- /dev/null +++ b/dimos/memory2/test_embedding.py @@ -0,0 +1,372 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for embedding layer: EmbeddedObservation, vector search, text search, transformers.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from dimos.memory2.impl.memory import MemoryStore +from dimos.memory2.type import EmbeddedObservation, Observation +from dimos.models.embedding.base import Embedding + +# ── Helpers ─────────────────────────────────────────────────────── + + +def _emb(vec: list[float]) -> Embedding: + """Return a unit-normalized Embedding.""" + v = np.array(vec, dtype=np.float32) + v /= np.linalg.norm(v) + 1e-10 + return Embedding(vector=v) + + +# ── EmbeddedObservation ────────────────────────────────────────── + + +class TestEmbeddedObservation: + def test_construction(self) -> None: + emb = _emb([1, 0, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="hello", embedding=emb) + assert obs.data == "hello" + assert obs.embedding is emb + assert obs.similarity is None + + def test_is_observation(self) -> None: + obs = EmbeddedObservation(id=0, ts=1.0, _data="x", embedding=_emb([1, 0])) + assert isinstance(obs, Observation) + + def test_derive_preserves_embedding(self) -> None: + emb = _emb([1, 0, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=emb) + derived = obs.derive(data="b") + assert isinstance(derived, EmbeddedObservation) + assert derived.embedding is emb + assert derived.data == "b" + + def test_derive_replaces_embedding(self) -> None: + old = _emb([1, 0, 0]) + new = _emb([0, 1, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=old) + derived = obs.derive(data="a", embedding=new) + assert derived.embedding is new + + def test_derive_preserves_similarity(self) -> None: + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=_emb([1, 0]), similarity=0.95) + derived = obs.derive(data="b") + assert derived.similarity == 0.95 + + def test_observation_derive_promotes_to_embedded(self) -> None: + obs = Observation(id=0, ts=1.0, _data="plain") + emb = _emb([1, 0, 
0]) + derived = obs.derive(data="plain", embedding=emb) + assert isinstance(derived, EmbeddedObservation) + assert derived.embedding is emb + + def test_observation_derive_without_embedding_stays_observation(self) -> None: + obs = Observation(id=0, ts=1.0, _data="plain") + derived = obs.derive(data="still plain") + assert type(derived) is Observation + + +# ── ListBackend embedding support ──────────────────────────────── + + +class TestListBackendEmbedding: + def test_append_with_embedding(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + emb = _emb([1, 0, 0]) + obs = s.append("hello", embedding=emb) + assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb + + def test_append_without_embedding(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("plain", str) + obs = s.append("hello") + assert type(obs) is Observation + + def test_search_returns_top_k(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_sorted_by_similarity(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + s.append("far", embedding=_emb([0, -1, 0])) + s.append("close", embedding=_emb([0.9, 0.1, 0])) + s.append("exact", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=3).fetch() + assert results[0].data == "exact" + assert results[1].data == "close" + assert results[2].data == "far" + # Descending similarity + assert results[0].similarity >= 
results[1].similarity >= results[2].similarity + + def test_search_skips_non_embedded(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("mixed", str) + s.append("plain") # no embedding + s.append("embedded", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "embedded" + + def test_search_with_filters(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) + + # Only the late one should pass the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_search_with_limit(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + for i in range(10): + s.append(f"item{i}", embedding=_emb([1, 0, 0])) + + # search k=5 then limit 2 + results = s.search(_emb([1, 0, 0]), k=5).limit(2).fetch() + assert len(results) == 2 + + def test_search_with_live_raises(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + s.append("x", embedding=_emb([1, 0, 0])) + with pytest.raises(TypeError, match="Cannot combine"): + list(s.live().search(_emb([1, 0, 0]), k=5)) + + +# ── Text search ────────────────────────────────────────────────── + + +class TestTextSearch: + def test_search_text_substring(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("motor fault detected") + s.append("temperature normal") + s.append("motor overheating") + + results = s.search_text("motor").fetch() + assert len(results) == 2 + assert {r.data for r in results} == {"motor fault detected", "motor overheating"} + + def test_search_text_case_insensitive(self) -> 
None: + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("Motor Fault") + s.append("other event") + + results = s.search_text("motor fault").fetch() + assert len(results) == 1 + + def test_search_text_with_filters(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("motor fault", ts=10.0) + s.append("motor warning", ts=20.0) + s.append("motor fault", ts=30.0) + + results = s.after(15.0).search_text("fault").fetch() + assert len(results) == 1 + assert results[0].ts == 30.0 + + def test_search_text_no_match(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("all clear") + + results = s.search_text("motor").fetch() + assert len(results) == 0 + + +# ── Save preserves embeddings ──────────────────────────────────── + + +class TestSaveEmbeddings: + def test_save_preserves_embeddings(self) -> None: + store = MemoryStore() + with store.session() as session: + src = session.stream("source", str) + dst = session.stream("dest", str) + + emb = _emb([1, 0, 0]) + src.append("item", embedding=emb) + src.save(dst) + + results = dst.fetch() + assert len(results) == 1 + assert isinstance(results[0], EmbeddedObservation) + # Same vector content (different Embedding instance after re-append) + np.testing.assert_array_almost_equal(results[0].embedding.to_numpy(), emb.to_numpy()) + + def test_save_mixed_embedded_and_plain(self) -> None: + store = MemoryStore() + with store.session() as session: + src = session.stream("source", str) + dst = session.stream("dest", str) + + src.append("plain") + src.append("embedded", embedding=_emb([0, 1, 0])) + src.save(dst) + + results = dst.fetch() + assert len(results) == 2 + assert type(results[0]) is Observation + assert isinstance(results[1], EmbeddedObservation) + + +# ── Embed transformers (mock model) ───────────────────────────── + + +class 
_MockEmbeddingModel: + """Fake EmbeddingModel that returns deterministic unit vectors.""" + + device = "cpu" + + def embed(self, *images): + vecs = [] + for img in images: + rng = np.random.default_rng(hash(str(img)) % 2**32) + v = rng.standard_normal(8).astype(np.float32) + v /= np.linalg.norm(v) + vecs.append(Embedding(vector=v)) + return vecs if len(vecs) > 1 else vecs[0] + + def embed_text(self, *texts): + vecs = [] + for text in texts: + rng = np.random.default_rng(hash(text) % 2**32) + v = rng.standard_normal(8).astype(np.float32) + v /= np.linalg.norm(v) + vecs.append(Embedding(vector=v)) + return vecs if len(vecs) > 1 else vecs[0] + + +class TestEmbedTransformers: + def test_embed_images_produces_embedded_observations(self) -> None: + from dimos.memory2.embed import EmbedImages + + model = _MockEmbeddingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("imgs", str) + s.append("img1", ts=1.0) + s.append("img2", ts=2.0) + + results = s.transform(EmbedImages(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + assert obs.embedding.to_numpy().shape == (8,) + + def test_embed_text_produces_embedded_observations(self) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("all clear", ts=2.0) + + results = s.transform(EmbedText(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + + def test_embed_preserves_data(self) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("hello", ts=1.0) + + result = 
s.transform(EmbedText(model)).first() + assert result.data == "hello" + + def test_embed_then_search(self) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + for i in range(10): + s.append(f"log entry {i}", ts=float(i)) + + embedded = s.transform(EmbedText(model)) + # Get the embedding for the first item, then search for similar + first_emb = embedded.first().embedding + results = embedded.search(first_emb, k=3).fetch() + assert len(results) == 3 + # First result should be the exact match + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_embed_batching(self) -> None: + from dimos.memory2.embed import EmbedText + + call_sizes: list[int] = [] + + class _TrackingModel(_MockEmbeddingModel): + def embed_text(self, *texts): + call_sizes.append(len(texts)) + return super().embed_text(*texts) + + model = _TrackingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + for i in range(5): + s.append(f"entry {i}") + + list(s.transform(EmbedText(model, batch_size=2))) + # 5 items with batch_size=2 → 3 calls (2, 2, 1) + assert call_sizes == [2, 2, 1] diff --git a/dimos/memory2/test_impl.py b/dimos/memory2/test_impl.py index 699cde40fc..e338966e3d 100644 --- a/dimos/memory2/test_impl.py +++ b/dimos/memory2/test_impl.py @@ -234,3 +234,46 @@ def test_same_stream_on_repeated_calls(self, case: Case) -> None: s1 = session.stream("reuse", str) s2 = session.stream("reuse", str) assert s1 is s2 + + def test_append_with_embedding(self, case: Case) -> None: + import numpy as np + + from dimos.memory2.type import EmbeddedObservation + from dimos.models.embedding.base import Embedding + + with case.session_factory() as session: + s = session.stream("vectors", str) + emb = Embedding(vector=np.array([1.0, 0.0, 0.0], dtype=np.float32)) + obs = s.append("hello", embedding=emb) 
+ assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb + + def test_search_top_k(self, case: Case) -> None: + import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) + + with case.session_factory() as session: + s = session.stream("searchable", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity > 0.99 + + def test_search_text(self, case: Case) -> None: + with case.session_factory() as session: + s = session.stream("logs", str) + s.append("motor fault") + s.append("temperature ok") + + results = s.search_text("motor").fetch() + assert len(results) == 1 + assert results[0].data == "motor fault" diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py index a009850334..d15b8e63f8 100644 --- a/dimos/memory2/test_save.py +++ b/dimos/memory2/test_save.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest @@ -37,7 +37,7 @@ def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: backend = ListBackend[int]("test") for i in range(n): - backend.append(i * 10, ts=start_ts + i) + backend.append(Observation(id=-1, ts=start_ts + i, _data=i * 10)) return Stream(source=backend) @@ -56,17 +56,8 @@ def name(self) -> str: def iterate(self, query: StreamQuery) -> Iterator[Observation[int]]: yield from self._obs - def append( - self, - payload: int, - *, - ts: float | None = None, - pose: Any | None = None, - tags: dict[str, Any] | None = None, - ) -> Observation[int]: - obs: Observation[int] = Observation( - id=self._next_id, ts=ts or 0.0, pose=pose, tags=tags or 
{}, _data=payload - ) + def append(self, obs: Observation[int]) -> Observation[int]: + obs.id = self._next_id self._next_id += 1 self._obs.append(obs) return obs @@ -139,7 +130,7 @@ def test_save_returns_target_stream(self) -> None: def test_save_preserves_data(self) -> None: backend = ListBackend[int]("src") - backend.append(42, ts=1.0, pose=(1, 2, 3), tags={"label": "cat"}) + backend.append(Observation(id=-1, ts=1.0, pose=(1, 2, 3), tags={"label": "cat"}, _data=42)) source = Stream(source=backend) target_backend = ListBackend[int]("dst") diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index 764f7532bd..eaca07eca8 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -21,24 +21,29 @@ import threading import time +from typing import TYPE_CHECKING import pytest from dimos.memory2.buffer import KeepLast, Unbounded -from dimos.memory2.impl.memory import ListBackend, MemoryStore -from dimos.memory2.stream import Stream +from dimos.memory2.impl.memory import MemoryStore from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer from dimos.memory2.type import Observation +if TYPE_CHECKING: + from dimos.memory2.stream import Stream + # ── Helpers ────────────────────────────────────────────────────────── def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: - """Create a ListBackend stream with n integer observations at 1-second intervals.""" - backend = ListBackend[int]("test") + """Create a MemoryStore stream with n integer observations at 1-second intervals.""" + store = MemoryStore() + session = store.session() + stream = session.stream("test") for i in range(n): - backend.append(i * 10, ts=start_ts + i) - return Stream(source=backend) + stream.append(i * 10, ts=start_ts + i) + return stream # ═══════════════════════════════════════════════════════════════════ @@ -119,20 +124,22 @@ class TestSpatialFilter: """.near(pose, radius) filters by Euclidean distance.""" def 
test_near_with_tuples(self): - backend = ListBackend[str]("spatial") - backend.append("origin", ts=0.0, pose=(0, 0, 0)) - backend.append("close", ts=1.0, pose=(1, 1, 0)) - backend.append("far", ts=2.0, pose=(10, 10, 10)) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("spatial") + stream.append("origin", ts=0.0, pose=(0, 0, 0)) + stream.append("close", ts=1.0, pose=(1, 1, 0)) + stream.append("far", ts=2.0, pose=(10, 10, 10)) result = stream.near((0, 0, 0), radius=2.0).fetch() assert [o.data for o in result] == ["origin", "close"] def test_near_excludes_no_pose(self): - backend = ListBackend[str]("spatial") - backend.append("no_pose", ts=0.0) - backend.append("has_pose", ts=1.0, pose=(0, 0, 0)) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("spatial") + stream.append("no_pose", ts=0.0) + stream.append("has_pose", ts=1.0, pose=(0, 0, 0)) result = stream.near((0, 0, 0), radius=10.0).fetch() assert [o.data for o in result] == ["has_pose"] @@ -147,20 +154,22 @@ class TestTagsFilter: """.filter_tags() matches on observation metadata.""" def test_filter_by_tag(self): - backend = ListBackend[str]("tagged") - backend.append("cat", ts=0.0, tags={"type": "animal", "legs": 4}) - backend.append("car", ts=1.0, tags={"type": "vehicle", "wheels": 4}) - backend.append("dog", ts=2.0, tags={"type": "animal", "legs": 4}) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("tagged") + stream.append("cat", ts=0.0, tags={"type": "animal", "legs": 4}) + stream.append("car", ts=1.0, tags={"type": "vehicle", "wheels": 4}) + stream.append("dog", ts=2.0, tags={"type": "animal", "legs": 4}) result = stream.filter_tags(type="animal").fetch() assert [o.data for o in result] == ["cat", "dog"] def test_filter_multiple_tags(self): - backend = ListBackend[str]("tagged") - backend.append("a", ts=0.0, tags={"x": 1, 
"y": 2}) - backend.append("b", ts=1.0, tags={"x": 1, "y": 3}) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("tagged") + stream.append("a", ts=0.0, tags={"x": 1, "y": 2}) + stream.append("b", ts=1.0, tags={"x": 1, "y": 3}) result = stream.filter_tags(x=1, y=2).fetch() assert [o.data for o in result] == ["a"] @@ -209,6 +218,11 @@ def test_exists(self): assert not make_stream(0).exists() assert not make_stream(5).after(100.0).exists() + def test_drain(self): + assert make_stream(5).drain() == 5 + assert make_stream(5).after(2.0).drain() == 2 + assert make_stream(0).drain() == 0 + # ═══════════════════════════════════════════════════════════════════ # 6. Functional API: .filter(), .map() @@ -269,11 +283,12 @@ def test_transform_can_skip(self): def test_transform_filter_transform(self): """stream.transform(A).near(pose).transform(B) — filter between transforms.""" - backend = ListBackend[int]("tfft") - backend.append(1, ts=0.0, pose=(0, 0, 0)) - backend.append(2, ts=1.0, pose=(100, 100, 100)) - backend.append(3, ts=2.0, pose=(1, 0, 0)) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("tfft") + stream.append(1, ts=0.0, pose=(0, 0, 0)) + stream.append(2, ts=1.0, pose=(100, 100, 100)) + stream.append(3, ts=2.0, pose=(1, 0, 0)) add_ten = FnTransformer(lambda obs: obs.derive(data=obs.data + 10)) double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) @@ -288,17 +303,18 @@ def test_transform_filter_transform(self): def test_quality_window(self): """QualityWindow keeps the best item per time window.""" - backend = ListBackend[float]("qw") + store = MemoryStore() + session = store.session() + stream = session.stream("qw") # Window 1: ts 0.0-0.9 → best quality - backend.append(0.3, ts=0.0) - backend.append(0.9, ts=0.3) # best in window - backend.append(0.1, ts=0.7) + stream.append(0.3, ts=0.0) + stream.append(0.9, ts=0.3) # best in window + 
stream.append(0.1, ts=0.7) # Window 2: ts 1.0-1.9 - backend.append(0.5, ts=1.0) - backend.append(0.8, ts=1.5) # best in window + stream.append(0.5, ts=1.0) + stream.append(0.8, ts=1.5) # best in window # Window 3: ts 2.0+ (emitted at end via flush) - backend.append(0.6, ts=2.2) - stream = Stream(source=backend) + stream.append(0.6, ts=2.2) xf = QualityWindow(quality_fn=lambda v: v, window=1.0) result = stream.transform(xf).fetch() @@ -423,9 +439,10 @@ class TestLiveMode: def test_live_sees_backfill_then_new(self): """Backfill first, then live appends come through.""" - backend = ListBackend[str]("live") - backend.append("old", ts=0.0) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("live") + stream.append("old", ts=0.0) live = stream.live(buffer=Unbounded()) results: list[str] = [] @@ -442,8 +459,8 @@ def consumer(): t.start() time.sleep(0.05) - backend.append("new1", ts=1.0) - backend.append("new2", ts=2.0) + stream.append("new1", ts=1.0) + stream.append("new2", ts=2.0) consumed.wait(timeout=2.0) t.join(timeout=2.0) @@ -451,8 +468,9 @@ def consumer(): def test_live_with_filter(self): """Filters apply to live data — non-matching obs are dropped silently.""" - backend = ListBackend[int]("live_filter") - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("live_filter") live = stream.after(5.0).live(buffer=Unbounded()) results: list[int] = [] @@ -469,10 +487,10 @@ def consumer(): t.start() time.sleep(0.05) - backend.append(1, ts=1.0) # filtered out (ts <= 5.0) - backend.append(2, ts=6.0) # passes - backend.append(3, ts=3.0) # filtered out - backend.append(4, ts=10.0) # passes + stream.append(1, ts=1.0) # filtered out (ts <= 5.0) + stream.append(2, ts=6.0) # passes + stream.append(3, ts=3.0) # filtered out + stream.append(4, ts=10.0) # passes consumed.wait(timeout=2.0) t.join(timeout=2.0) @@ -480,9 +498,10 @@ def consumer(): def 
test_live_deduplicates_backfill_overlap(self): """Observations seen in backfill are not re-yielded from the live buffer.""" - backend = ListBackend[str]("dedup") - backend.append("backfill", ts=0.0) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("dedup") + stream.append("backfill", ts=0.0) live = stream.live(buffer=Unbounded()) results: list[str] = [] @@ -499,7 +518,7 @@ def consumer(): t.start() time.sleep(0.05) - backend.append("live1", ts=1.0) + stream.append("live1", ts=1.0) consumed.wait(timeout=2.0) t.join(timeout=2.0) @@ -507,8 +526,9 @@ def consumer(): def test_live_with_keep_last_backpressure(self): """KeepLast drops intermediate values when consumer is slow.""" - backend = ListBackend[int]("bp") - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("bp") live = stream.live(buffer=KeepLast()) results: list[int] = [] @@ -528,7 +548,7 @@ def consumer(): time.sleep(0.05) # Rapid producer — KeepLast will drop most of these for i in range(100): - backend.append(i, ts=float(i)) + stream.append(i, ts=float(i)) time.sleep(0.001) consumed.wait(timeout=5.0) @@ -539,9 +559,10 @@ def consumer(): def test_live_transform_receives_live_items(self): """Transforms downstream of .live() see both backfill and live items.""" - backend = ListBackend[int]("live_xf") - backend.append(1, ts=0.0) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("live_xf") + stream.append(1, ts=0.0) double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) live = stream.live(buffer=Unbounded()).transform(double) @@ -559,8 +580,8 @@ def consumer(): t.start() time.sleep(0.05) - backend.append(10, ts=1.0) - backend.append(100, ts=2.0) + stream.append(10, ts=1.0) + stream.append(100, ts=2.0) consumed.wait(timeout=2.0) t.join(timeout=2.0) @@ -574,11 +595,104 @@ def test_live_on_transform_raises(self): with 
pytest.raises(TypeError, match="Cannot call .live"): stream.transform(xf).live() + def test_is_live(self): + """is_live() walks the source chain to detect live mode.""" + store = MemoryStore() + session = store.session() + stream = session.stream("is_live") + assert not stream.is_live() + + live = stream.live(buffer=Unbounded()) + assert live.is_live() + + xf = FnTransformer(lambda obs: obs) + transformed = live.transform(xf) + assert transformed.is_live() + + # Two levels deep + double_xf = transformed.transform(xf) + assert double_xf.is_live() + + # Non-live transform is not live + assert not stream.transform(xf).is_live() + + def test_search_on_live_transform_raises(self): + """search() on a transform with live upstream raises immediately.""" + store = MemoryStore() + session = store.session() + stream = session.stream("live_search") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + import numpy as np + + from dimos.models.embedding.base import Embedding + + vec = Embedding(vector=np.array([1.0, 0.0, 0.0])) + with pytest.raises(TypeError, match="requires finite data"): + # Use list() to trigger iteration — fetch() would hit its own guard first + list(live_xf.search(vec, k=5)) + + def test_order_by_on_live_transform_raises(self): + """order_by() on a transform with live upstream raises immediately.""" + store = MemoryStore() + session = store.session() + stream = session.stream("live_order") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="requires finite data"): + list(live_xf.order_by("ts", desc=True)) + + def test_fetch_on_live_without_limit_raises(self): + """fetch() on a live stream without limit() raises TypeError.""" + store = MemoryStore() + session = store.session() + stream = session.stream("live_fetch") + live = stream.live(buffer=Unbounded()) + + with pytest.raises(TypeError, match="collect forever"): + live.fetch() 
+ + def test_fetch_on_live_transform_without_limit_raises(self): + """fetch() on a live transform without limit() raises TypeError.""" + store = MemoryStore() + session = store.session() + stream = session.stream("live_fetch_xf") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="collect forever"): + live_xf.fetch() + + def test_count_on_live_transform_raises(self): + """count() on a live transform stream raises TypeError.""" + store = MemoryStore() + session = store.session() + stream = session.stream("live_count") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="block forever"): + live_xf.count() + + def test_last_on_live_transform_raises(self): + """last() on a live transform raises TypeError (via order_by guard).""" + store = MemoryStore() + session = store.session() + stream = session.stream("live_last") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="requires finite data"): + live_xf.last() + def test_live_chained_transforms(self): """stream.live().transform(A).transform(B) — both transforms applied to live items.""" - backend = ListBackend[int]("live_chain") - backend.append(1, ts=0.0) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("live_chain") + stream.append(1, ts=0.0) add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) live = stream.live(buffer=Unbounded()).transform(add_one).transform(double) @@ -597,8 +711,8 @@ def consumer(): t.start() time.sleep(0.05) - backend.append(10, ts=1.0) - backend.append(100, ts=2.0) + stream.append(10, ts=1.0) + stream.append(100, ts=2.0) consumed.wait(timeout=2.0) t.join(timeout=2.0) @@ -607,10 +721,11 @@ def consumer(): 
def test_live_filter_before_live(self): """Filters applied before .live() work on both backfill and live items.""" - backend = ListBackend[str]("live_pre_filter") - backend.append("a", ts=1.0) - backend.append("b", ts=10.0) - stream = Stream(source=backend) + store = MemoryStore() + session = store.session() + stream = session.stream("live_pre_filter") + stream.append("a", ts=1.0) + stream.append("b", ts=10.0) live = stream.after(5.0).live(buffer=Unbounded()) results: list[str] = [] @@ -627,8 +742,8 @@ def consumer(): t.start() time.sleep(0.05) - backend.append("c", ts=3.0) # filtered - backend.append("d", ts=20.0) # passes + stream.append("c", ts=3.0) # filtered + stream.append("d", ts=20.0) # passes consumed.wait(timeout=2.0) t.join(timeout=2.0) diff --git a/dimos/memory2/type.py b/dimos/memory2/type.py index 59ec300685..85cfab9640 100644 --- a/dimos/memory2/type.py +++ b/dimos/memory2/type.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: from collections.abc import Callable + from dimos.models.embedding.base import Embedding + T = TypeVar("T") @@ -65,7 +67,21 @@ def data(self) -> T: return val def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: - """Create a new observation preserving ts/pose/tags, replacing data.""" + """Create a new observation preserving ts/pose/tags, replacing data. + + If ``embedding`` is passed, promotes the result to + :class:`EmbeddedObservation`. 
+ """ + if "embedding" in overrides: + return EmbeddedObservation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + embedding=overrides["embedding"], + similarity=overrides.get("similarity"), + ) return Observation( id=self.id, ts=overrides.get("ts", self.ts), @@ -73,3 +89,26 @@ def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: tags=overrides.get("tags", self.tags), _data=data, ) + + +# ── EmbeddedObservation ────────────────────────────────────────── + + +@dataclass +class EmbeddedObservation(Observation[T]): + """Observation enriched with a vector embedding and optional similarity score.""" + + embedding: Embedding | None = None + similarity: float | None = None + + def derive(self, *, data: Any, **overrides: Any) -> EmbeddedObservation[Any]: + """Preserve embedding unless explicitly replaced.""" + return EmbeddedObservation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + embedding=overrides.get("embedding", self.embedding), + similarity=overrides.get("similarity", self.similarity), + ) From 94aa659d5b76897100b11f1e53ed3d5e3c7d6df9 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 18:21:55 +0800 Subject: [PATCH 077/118] memory2: add documentation for streaming model, codecs, and backends - README.md: architecture overview, module index, quick start - streaming.md: lazy vs materializing vs terminal evaluation model - codecs/README.md: codec protocol, built-in codecs, writing new ones - impl/README.md: backend guide with query contract and grid test setup --- dimos/memory2/README.md | 80 ++++++++++++++++++++++++ dimos/memory2/codecs/README.md | 57 +++++++++++++++++ dimos/memory2/impl/README.md | 107 ++++++++++++++++++++++++++++++++ dimos/memory2/streaming.md | 109 +++++++++++++++++++++++++++++++++ 4 files changed, 353 insertions(+) 
create mode 100644 dimos/memory2/README.md create mode 100644 dimos/memory2/codecs/README.md create mode 100644 dimos/memory2/impl/README.md create mode 100644 dimos/memory2/streaming.md diff --git a/dimos/memory2/README.md b/dimos/memory2/README.md new file mode 100644 index 0000000000..595453b165 --- /dev/null +++ b/dimos/memory2/README.md @@ -0,0 +1,80 @@ +# memory2 + +Observation storage and streaming layer for DimOS. Pull-based, lazy, composable. + +## Architecture + +``` + Live Sensor Data + ↓ +Store → Session → Stream → [filters / transforms / terminals] → Stream → [filters / transforms / terminals] → Stream → Live hooks + ↓ ↓ ↓ + Backend (ListBackend, SqliteBackend) Backend In Memory +``` + +**Store** owns a storage location (file, in-memory). **Session** manages named streams over a shared connection. **Stream** is the query/iteration surface — lazy until a terminal is called. + +## Modules + +| Module | What | +|----------------|-------------------------------------------------------------------| +| `stream.py` | Stream node — filters, transforms, terminals | +| `backend.py` | Backend / LiveBackend protocols, VectorStore / BlobStore ABCs | +| `filter.py` | StreamQuery dataclass, filter types | +| `transform.py` | Transformer protocol, FnTransformer, QualityWindow | +| `buffer.py` | Backpressure buffers for live mode (KeepLast, Bounded, Unbounded) | +| `store.py` | Store / Session base classes, StreamNamespace | +| `type.py` | Observation, EmbeddedObservation dataclasses | +| `embed.py` | EmbedImages / EmbedText transformers | + +## Subpackages + +| Package | What | Docs | +|--------------|------------------------------------------------------|--------------------------------------------------| +| `impl/` | Backend implementations (ListBackend, SqliteBackend) | [impl/README.md](impl/README.md) | +| `blobstore/` | Pluggable blob storage (file, sqlite) | [blobstore/blobstore.md](blobstore/blobstore.md) | +| `codecs/` | Encode/decode for storage (pickle, 
JPEG, LCM) | [codecs/README.md](codecs/README.md) | + +## Docs + +| Doc | What | +|-----|------| +| [streaming.md](streaming.md) | Lazy vs materializing vs terminal — evaluation model, live safety | +| [embeddings.md](embeddings.md) | Embedding layer design — EmbeddedObservation, vector search, EmbedImages/EmbedText | +| [blobstore/blobstore.md](blobstore/blobstore.md) | BlobStore architecture — separate payload storage from metadata | + +## Quick start + +```python +from dimos.memory2 import MemoryStore + +store = MemoryStore() +with store.session() as session: + images = session.stream("images") + + # Write + images.append(frame, ts=time.time(), pose=(x, y, z), tags={"camera": "front"}) + + # Query + recent = images.after(t).limit(10).fetch() + nearest = images.near(pose, radius=2.0).fetch() + latest = images.last() + + # Transform + edges = images.transform(Canny()).save(session.stream("edges")) + + # Live + for obs in images.live().transform(process): + handle(obs) + + # Embed + search + images.transform(EmbedImages(clip)).save(session.stream("embedded")) + results = session.stream("embedded").search(query_vec, k=5).fetch() +``` + +## Implementations + +| Backend | Status | Storage | +|-----------------|----------|----------------------------------------| +| `ListBackend` | Complete | In-memory (lists + brute-force search) | +| `SqliteBackend` | Stub | SQLite (WAL, FTS5, vec0) | diff --git a/dimos/memory2/codecs/README.md b/dimos/memory2/codecs/README.md new file mode 100644 index 0000000000..ff6b701054 --- /dev/null +++ b/dimos/memory2/codecs/README.md @@ -0,0 +1,57 @@ +# codecs + +Encode/decode payloads for persistent storage. Codecs convert typed Python objects to `bytes` and back, used by backends that store observation data as blobs. + +## Protocol + +```python +class Codec(Protocol[T]): + def encode(self, value: T) -> bytes: ... + def decode(self, data: bytes) -> T: ... 
+``` + +## Built-in codecs + +| Codec | Type | Notes | +|-------|------|-------| +| `PickleCodec` | Any Python object | Fallback. Uses `HIGHEST_PROTOCOL`. | +| `JpegCodec` | `Image` | Lossy compression via TurboJPEG. ~10-20x smaller. Preserves `frame_id` in header. | +| `LcmCodec` | `DimosMsg` subclasses | Uses `lcm_encode()`/`lcm_decode()`. Zero-copy for LCM message types. | + +## Auto-selection + +`codec_for(payload_type)` picks the right codec: + +```python +from dimos.memory2.codecs import codec_for + +codec_for(Image) # → JpegCodec(quality=50) +codec_for(SomeLcmMsg) # → LcmCodec(SomeLcmMsg) (if has lcm_encode/lcm_decode) +codec_for(dict) # → PickleCodec() (fallback) +codec_for(None) # → PickleCodec() +``` + +## Writing a new codec + +1. Create `dimos/memory2/codecs/mycodec.py`: + +```python +class MyCodec: + def encode(self, value: MyType) -> bytes: + ... + + def decode(self, data: bytes) -> MyType: + ... +``` + +2. Add a branch in `codec_for()` in `base.py` to auto-select it for the relevant type. + +3. Add a test case to `test_codecs.py` — the grid fixture makes this easy: + +```python +@pytest.fixture(params=[..., ("mycodec", MyCodec(), sample_value)]) +def codec_case(request): + ... +``` + +No base class needed — `Codec` is a protocol. Just implement `encode` and `decode`. diff --git a/dimos/memory2/impl/README.md b/dimos/memory2/impl/README.md new file mode 100644 index 0000000000..8bdbf11f8f --- /dev/null +++ b/dimos/memory2/impl/README.md @@ -0,0 +1,107 @@ +# impl — Backend implementations + +Storage backends for memory2. Each backend implements the `Backend` protocol (and optionally `LiveBackend`) to provide observation storage with query support. 
+ +## Existing backends + +| Backend | File | Status | Storage | +|-----------------|-------------|----------|-------------------------------------| +| `ListBackend` | `memory.py` | Complete | In-memory lists, brute-force search | +| `SqliteBackend` | `sqlite.py` | Stub | SQLite (WAL, FTS5, vec0) | + +## Writing a new backend + +### 1. Implement the Backend protocol + +```python +from dimos.memory2.backend import Backend +from dimos.memory2.filter import StreamQuery +from dimos.memory2.type import Observation + +class MyBackend(Generic[T]): + @property + def name(self) -> str: + return self._name + + def append(self, obs: Observation[T]) -> Observation[T]: + """Assign an id and store. Return the stored observation.""" + obs.id = self._next_id + self._next_id += 1 + # ... persist obs ... + return obs + + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + """Yield observations matching the query.""" + # The backend is responsible for applying ALL query fields: + # query.filters — list of Filter objects (each has .matches(obs)) + # query.order_field — sort field name (e.g. "ts") + # query.order_desc — sort direction + # query.limit_val — max results + # query.offset_val — skip first N + # query.search_vec — Embedding for vector search + # query.search_k — top-k for vector search + # query.search_text — substring text search + # query.live_buffer — if set, switch to live mode (see LiveBackend) + ... + + def count(self, query: StreamQuery) -> int: + """Count matching observations.""" + ... +``` + +`Backend` is a `@runtime_checkable` Protocol — no base class needed, just implement the methods. + +### 2. Add LiveBackend support (optional) + +If your backend supports live subscriptions (push notifications on new observations): + +```python +from dimos.memory2.backend import LiveBackend + +class MyBackend(Generic[T]): + # ... Backend methods ... 
+ + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + """Register a buffer for push notifications. Return a disposable to unsubscribe.""" + ... +``` + +The `iterate()` method should check `query.live_buffer`: +- If `None`: return a snapshot iterator +- If set: subscribe before backfill, then yield a live tail that deduplicates by `obs.id` + +See `ListBackend._iterate_live()` for the reference implementation. + +### 3. Add Store and Session + +```python +from dimos.memory2.store import Session, Store + +class MySession(Session): + def _create_backend(self, name: str, payload_type: type | None = None) -> Backend: + return MyBackend(self._conn, name) + +class MyStore(Store): + def session(self) -> MySession: + return MySession(...) +``` + +### 4. Add to the grid test + +In `test_impl.py`, add your store to the fixture so all standard tests run against it: + +```python +@pytest.fixture(params=["memory", "sqlite", "mybackend"]) +def store(request, tmp_path): + if request.param == "mybackend": + return MyStore(...) + ... +``` + +Use `pytest.mark.xfail` for features not yet implemented — the grid test covers: append, fetch, iterate, count, first/last, exists, all filters, ordering, limit/offset, embeddings, text search. + +### Query contract + +The backend must handle the full `StreamQuery`. The Stream layer does NOT apply filters to backend results — it trusts the backend to do so. This allows backends to push queries down to their native query engine (SQL WHERE, FTS5 MATCH, vec0 knn). + +For filters, each `Filter` object has a `.matches(obs) -> bool` method that backends can use directly if they don't have a native equivalent. diff --git a/dimos/memory2/streaming.md b/dimos/memory2/streaming.md new file mode 100644 index 0000000000..c1c4a3c36c --- /dev/null +++ b/dimos/memory2/streaming.md @@ -0,0 +1,109 @@ +# Stream evaluation model + +Stream methods fall into three categories: **lazy**, **materializing**, and **terminal**. 
The distinction matters for live (infinite) streams. + +`is_live()` walks the source chain to detect live mode — any stream whose ancestor called `.live()` returns `True`. +All materializing operations and unsafe terminals check this and raise `TypeError` immediately rather than silently hanging. + +## Lazy (streaming) + +These return generators — each observation flows through one at a time. Safe with live/infinite streams. No internal buffering between stages. + +| Method | How | +|---------------------------------------------------------------------------|-------------------------------------------------| +| `.after()` `.before()` `.time_range()` `.at()` `.near()` `.filter_tags()` | Filter predicates — skip non-matching obs | +| `.filter(pred)` | Same, user-defined predicate | +| `.transform(xf)` / `.map(fn)` | Generator — yields transformed obs one by one | +| `.search_text(text)` | Generator — substring match filter | +| `.limit(k)` | `islice` — stops after k | +| `.offset(n)` | `islice` — skips first n | +| `.live()` | Enables live tail (backfill then block for new) | + +These compose freely. A chain like `.after(t).filter(pred).transform(xf).limit(10)` pulls lazily — the source only produces what the consumer asks for. + +## Materializing (collect-then-process) + +These **must consume the entire upstream** before producing output. On a live stream, they raise `TypeError` immediately. + +| Method | Why | Live behaviour | +|--------------------|----------------------------------------------|----------------| +| `.search(vec, k)` | Cosine-ranks all observations, returns top-k | TypeError | +| `.order_by(field)` | `sorted(list(it))` — needs all items to sort | TypeError | + +On a backend-backed stream (not a transform), both are pushed down to the backend which handles them on its own data structure (snapshot). The guard only fires when these appear on a **transform stream** whose upstream is live — detected via `is_live()`. 
+ +### Rejected patterns (raise TypeError) + +```python +# TypeError: search requires finite data +stream.live().transform(Embed(model)).search(vec, k=5) + +# TypeError: order_by requires finite data +stream.live().transform(xf).order_by("ts", desc=True) + +# TypeError (via order_by): last() calls order_by internally +stream.live().transform(xf).last() +``` + +### Safe equivalents + +```python +# Search the stored data, not the live tail +results = stream.search(vec, k=5).fetch() + +# First works fine (uses limit(1), no materialization) +obs = stream.live().transform(xf).first() +``` + +## Terminal (consume the iterator) + +Terminals trigger iteration and return a value. They're the "go" button — nothing executes until a terminal is called. + +| Method | Returns | Memory | Live behaviour | +|-----------------|---------------------|--------------------|-----------------------------------------| +| `.fetch()` | `list[Observation]` | Grows with results | TypeError without `.limit()` first | +| `.drain()` | `int` (count) | Constant | Blocks forever, memory stays flat | +| `.save(target)` | target `Stream` | Constant | Blocks forever, appends each to store | +| `.first()` | `Observation` | Constant | Returns first item, then stops | +| `.exists()` | `bool` | Constant | Returns after one item check | +| `.last()` | `Observation` | Materializes | TypeError (uses order_by internally) | +| `.count()` | `int` | Constant | TypeError on transform streams | + +### Choosing the right terminal + +**Batch query** — collect results into memory: +```python +results = stream.after(t).search(vec, k=10).fetch() +``` + +**Live ingestion** — process forever, constant memory: +```python +# Embed and store continuously +stream.live().transform(EmbedImages(clip)).save(target) + +# Side-effect pipeline (no storage) +stream.live().transform(process).drain() +``` + +**One-shot** — get a single observation: +```python +obs = stream.live().transform(xf).first() # blocks until one arrives 
+has_data = stream.exists() # quick check +``` + +**Bounded live** — collect a fixed number from a live stream: +```python +batch = stream.live().limit(100).fetch() # OK — limit makes it finite +``` + +### Error summary + +All operations that would silently hang on live streams raise `TypeError` instead: + +| Pattern | Error | +|-------------------------------------|-----------------------------------------------| +| `live.transform(xf).search(vec, k)` | `.search() requires finite data` | +| `live.transform(xf).order_by("ts")` | `.order_by() requires finite data` | +| `live.fetch()` (without `.limit()`) | `.fetch() would collect forever` | +| `live.transform(xf).count()` | `.count() would block forever` | +| `live.transform(xf).last()` | `.order_by() requires finite data` (via last) | From 1dc68b75f821a8f3ad301d999cb16afe5258174c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 18:48:30 +0800 Subject: [PATCH 078/118] query application refactor --- dimos/memory2/README.md | 31 ++++++++++++++++-- dimos/memory2/filter.py | 62 +++++++++++++++++++++++++++++++++++- dimos/memory2/impl/README.md | 26 ++++++++++++++- dimos/memory2/impl/memory.py | 42 +----------------------- dimos/memory2/stream.py | 54 +------------------------------ 5 files changed, 117 insertions(+), 98 deletions(-) diff --git a/dimos/memory2/README.md b/dimos/memory2/README.md index 595453b165..6a5f5781ad 100644 --- a/dimos/memory2/README.md +++ b/dimos/memory2/README.md @@ -14,13 +14,22 @@ Store → Session → Stream → [filters / transforms / terminals] → Stream **Store** owns a storage location (file, in-memory). **Session** manages named streams over a shared connection. **Stream** is the query/iteration surface — lazy until a terminal is called. + +Supporting Systems: + +- BlobStore — separates large payloads from metadata. FileBlobStore (files on disk) and SqliteBlobStore (blob table per stream). Supports lazy loading. 
+- Codecs — codec_for() auto-selects: JpegCodec for images (TurboJPEG, ~10-20x compression), LcmCodec for DimOS messages, PickleCodec fallback. +- Transformers — Transformer[T,R] ABC wrapping iterator-to-iterator. EmbedImages/EmbedText enrich observations with embeddings. QualityWindow keeps best per time window. +- Backpressure Buffers — KeepLast, Bounded, DropNew, Unbounded — bridge push/pull for live mode. + + ## Modules | Module | What | |----------------|-------------------------------------------------------------------| | `stream.py` | Stream node — filters, transforms, terminals | | `backend.py` | Backend / LiveBackend protocols, VectorStore / BlobStore ABCs | -| `filter.py` | StreamQuery dataclass, filter types | +| `filter.py` | StreamQuery dataclass, filter types, Python query execution | | `transform.py` | Transformer protocol, FnTransformer, QualityWindow | | `buffer.py` | Backpressure buffers for live mode (KeepLast, Bounded, Unbounded) | | `store.py` | Store / Session base classes, StreamNamespace | @@ -32,7 +41,7 @@ Store → Session → Stream → [filters / transforms / terminals] → Stream | Package | What | Docs | |--------------|------------------------------------------------------|--------------------------------------------------| | `impl/` | Backend implementations (ListBackend, SqliteBackend) | [impl/README.md](impl/README.md) | -| `blobstore/` | Pluggable blob storage (file, sqlite) | [blobstore/blobstore.md](blobstore/blobstore.md) | +| `blobstore/` | Pluggable blob storage (file, sqlite) | e[blobstore/blobstore.md](blobstore/blobstore.md) | | `codecs/` | Encode/decode for storage (pickle, JPEG, LCM) | [codecs/README.md](codecs/README.md) | ## Docs @@ -43,6 +52,24 @@ Store → Session → Stream → [filters / transforms / terminals] → Stream | [embeddings.md](embeddings.md) | Embedding layer design — EmbeddedObservation, vector search, EmbedImages/EmbedText | | [blobstore/blobstore.md](blobstore/blobstore.md) | BlobStore architecture — separate 
payload storage from metadata | +## Query execution + +`StreamQuery` holds the full query spec (filters, text search, vector search, ordering, offset/limit). It also provides `apply(iterator)` — a Python-side execution path that runs all operations as in-memory predicates, brute-force cosine, and list sorts. + +This is the **default fallback**. Backends are free to push down operations using store-specific strategies instead: + +| Operation | Python fallback (`StreamQuery.apply`) | Store push-down (example) | +|----------------|---------------------------------------|----------------------------------| +| Filters | `filter.matches()` predicates | SQL WHERE clauses | +| Text search | Case-insensitive substring | FTS5 full-text index | +| Vector search | Brute-force cosine similarity | vec0 / FAISS ANN index | +| Ordering | `sorted()` materialization | SQL ORDER BY | +| Offset / limit | `islice()` | SQL OFFSET / LIMIT | + +`ListBackend` delegates entirely to `StreamQuery.apply()`. `SqliteBackend` translates the query into SQL and only falls back to Python for operations it can't express natively. + +Transform-sourced streams (post `.transform()`) always use `StreamQuery.apply()` since there's no backend to push down to. 
+ ## Quick start ```python diff --git a/dimos/memory2/filter.py b/dimos/memory2/filter.py index 243200f68a..8b0d9ec9bd 100644 --- a/dimos/memory2/filter.py +++ b/dimos/memory2/filter.py @@ -15,10 +15,11 @@ from __future__ import annotations from dataclasses import dataclass, field +from itertools import islice from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Iterator from dimos.memory2.buffer import BackpressureBuffer from dimos.memory2.type import Observation @@ -137,3 +138,62 @@ class StreamQuery: search_k: int | None = None # Full-text search (substring / FTS5) search_text: str | None = None + + def apply( + self, it: Iterator[Observation[Any]], *, live: bool = False + ) -> Iterator[Observation[Any]]: + """Apply all query operations to an iterator in Python. + + Used as the fallback execution path for transform-sourced streams + and in-memory backends. Backends with native query support (SQL, + ANN indexes) should push down operations instead. + """ + # Filters + if self.filters: + it = (obs for obs in it if all(f.matches(obs) for f in self.filters)) + + # Text search — substring match + if self.search_text is not None: + needle = self.search_text.lower() + it = (obs for obs in it if needle in str(obs.data).lower()) + + # Vector search — brute-force cosine (materializes) + if self.search_vec is not None: + if live: + raise TypeError( + ".search() requires finite data — cannot rank an infinite live stream." 
+ ) + query_emb = self.search_vec + scored = [] + for obs in it: + emb = getattr(obs, "embedding", None) + if emb is not None: + sim = float(emb @ query_emb) + scored.append(obs.derive(data=obs.data, similarity=sim)) + scored.sort(key=lambda o: getattr(o, "similarity", 0.0) or 0.0, reverse=True) + if self.search_k is not None: + scored = scored[: self.search_k] + it = iter(scored) + + # Sort (materializes) + if self.order_field: + if live: + raise TypeError( + ".order_by() requires finite data — cannot sort an infinite live stream." + ) + key = self.order_field + desc = self.order_desc + items = sorted( + list(it), + key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, + reverse=desc, + ) + it = iter(items) + + # Offset + limit + if self.offset_val: + it = islice(it, self.offset_val, None) + if self.limit_val is not None: + it = islice(it, self.limit_val) + + return it diff --git a/dimos/memory2/impl/README.md b/dimos/memory2/impl/README.md index 8bdbf11f8f..19efc17bed 100644 --- a/dimos/memory2/impl/README.md +++ b/dimos/memory2/impl/README.md @@ -102,6 +102,30 @@ Use `pytest.mark.xfail` for features not yet implemented — the grid test cover ### Query contract -The backend must handle the full `StreamQuery`. The Stream layer does NOT apply filters to backend results — it trusts the backend to do so. This allows backends to push queries down to their native query engine (SQL WHERE, FTS5 MATCH, vec0 knn). +The backend must handle the full `StreamQuery`. The Stream layer does NOT apply filters to backend results — it trusts the backend to do so. + +`StreamQuery.apply(iterator)` provides a complete Python-side execution path — filters, text search, vector search, ordering, offset/limit — all as in-memory operations. 
Backends can use it in three ways: + +**Full delegation** — simplest, good enough for in-memory backends: +```python +def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + return query.apply(iter(self._data)) +``` + +**Partial push-down** — handle some operations natively, delegate the rest: +```python +def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + # Handle filters and ordering in SQL + rows = self._sql_query(query.filters, query.order_field, query.order_desc) + # Delegate remaining operations (vector search, text search, offset/limit) to Python + remaining = StreamQuery( + search_vec=query.search_vec, search_k=query.search_k, + search_text=query.search_text, + offset_val=query.offset_val, limit_val=query.limit_val, + ) + return remaining.apply(iter(rows)) +``` + +**Full push-down** — translate everything to native queries (SQL WHERE, FTS5 MATCH, vec0 knn) without calling `apply()` at all. For filters, each `Filter` object has a `.matches(obs) -> bool` method that backends can use directly if they don't have a native equivalent. 
diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index 3e8f89c63c..d26d731208 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -79,47 +79,7 @@ def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: with self._lock: snapshot = list(self._observations) - - # Apply filters - for f in query.filters: - snapshot = [obs for obs in snapshot if f.matches(obs)] - - # Text search — substring match (SqliteBackend will use FTS5) - if query.search_text is not None: - needle = query.search_text.lower() - snapshot = [obs for obs in snapshot if needle in str(obs.data).lower()] - - # Vector search — brute-force cosine via Embedding.__matmul__ - if query.search_vec is not None: - query_emb = query.search_vec - scored: list[Observation[T]] = [] - for obs in snapshot: - emb = getattr(obs, "embedding", None) - if emb is not None: - sim = float(emb @ query_emb) - scored.append(obs.derive(data=obs.data, similarity=sim)) - scored.sort(key=lambda o: getattr(o, "similarity", 0.0) or 0.0, reverse=True) - if query.search_k is not None: - scored = scored[: query.search_k] - snapshot = scored - - # Ordering - if query.order_field: - key = query.order_field - snapshot.sort( - key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, - reverse=query.order_desc, - ) - - # Offset - if query.offset_val: - snapshot = snapshot[query.offset_val :] - - # Limit - if query.limit_val is not None: - snapshot = snapshot[: query.limit_val] - - yield from snapshot + yield from query.apply(iter(snapshot)) def _iterate_live( self, diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index a5c00d8af1..af61541460 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -14,7 +14,6 @@ from __future__ import annotations -from itertools import islice import time from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -85,58 
+84,7 @@ def _iter_transform(self) -> Iterator[Observation[T]]: """Iterate a transform source, applying query filters in Python.""" assert isinstance(self._source, Stream) and self._xf is not None it: Iterator[Observation[T]] = self._xf(iter(self._source)) - - # Apply filters as Python predicates - filters = self._query.filters - if filters: - it = (obs for obs in it if all(f.matches(obs) for f in filters)) - - # Text search — substring match - if self._query.search_text is not None: - needle = self._query.search_text.lower() - it = (obs for obs in it if needle in str(obs.data).lower()) - - # Vector search — brute-force cosine (materializes — rejects live) - if self._query.search_vec is not None: - if self.is_live(): - raise TypeError( - ".search() requires finite data — cannot rank an infinite live stream." - ) - query_emb = self._query.search_vec - scored = [] - for obs in it: - emb = getattr(obs, "embedding", None) - if emb is not None: - sim = float(emb @ query_emb) - scored.append(obs.derive(data=obs.data, similarity=sim)) - scored.sort(key=lambda o: getattr(o, "similarity", 0.0) or 0.0, reverse=True) - k = self._query.search_k - if k is not None: - scored = scored[:k] - it = iter(scored) - - # Sort if needed (materializes — rejects live) - if self._query.order_field: - if self.is_live(): - raise TypeError( - ".order_by() requires finite data — cannot sort an infinite live stream." 
- ) - key = self._query.order_field - desc = self._query.order_desc - items = sorted( - list(it), - key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, - reverse=desc, - ) - it = iter(items) - - # Offset + limit - if self._query.offset_val: - it = islice(it, self._query.offset_val, None) - if self._query.limit_val is not None: - it = islice(it, self._query.limit_val) - - return it + return self._query.apply(it, live=self.is_live()) # ── Query builders ────────────────────────────────────────────── From 4d31779f245e777e9390ab513973275a00e5d9e7 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 19:53:19 +0800 Subject: [PATCH 079/118] memory2: replace LiveBackend with pluggable LiveChannel, add Configurable pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace LiveBackend protocol with LiveChannel ABC (SubjectChannel for in-memory fan-out, extensible to Redis/Postgres for cross-process) - Add livechannel/ subpackage with SubjectChannel implementation - Make Store and Session extend Configurable[ConfigT] with StoreConfig and SessionConfig dataclasses - Remove redundant Session._backends dict (Backend lives in Stream._source) - Make list_streams() and delete_stream() abstract on Session so implementations can query persisted streams - StreamNamespace delegates to list_streams()/stream() instead of accessing _streams directly - Remove LiveBackend isinstance guard from stream.py — all backends now have a built-in LiveChannel --- dimos/memory2/README.md | 11 ++-- dimos/memory2/__init__.py | 10 ++- dimos/memory2/backend.py | 32 ++++++++-- dimos/memory2/impl/memory.py | 42 ++++++------- dimos/memory2/impl/sqlite.py | 51 ++++++++++------ dimos/memory2/livechannel/__init__.py | 4 ++ dimos/memory2/livechannel/subject.py | 63 +++++++++++++++++++ dimos/memory2/store.py | 87 +++++++++++++++++++-------- dimos/memory2/stream.py | 8 +-- dimos/memory2/test_impl.py | 2 +- 
dimos/memory2/test_save.py | 65 ++------------------ 11 files changed, 231 insertions(+), 144 deletions(-) create mode 100644 dimos/memory2/livechannel/__init__.py create mode 100644 dimos/memory2/livechannel/subject.py diff --git a/dimos/memory2/README.md b/dimos/memory2/README.md index 6a5f5781ad..adc478f1e4 100644 --- a/dimos/memory2/README.md +++ b/dimos/memory2/README.md @@ -28,11 +28,11 @@ Supporting Systems: | Module | What | |----------------|-------------------------------------------------------------------| | `stream.py` | Stream node — filters, transforms, terminals | -| `backend.py` | Backend / LiveBackend protocols, VectorStore / BlobStore ABCs | +| `backend.py` | Backend protocol, LiveChannel / VectorStore / BlobStore ABCs | | `filter.py` | StreamQuery dataclass, filter types, Python query execution | | `transform.py` | Transformer protocol, FnTransformer, QualityWindow | | `buffer.py` | Backpressure buffers for live mode (KeepLast, Bounded, Unbounded) | -| `store.py` | Store / Session base classes, StreamNamespace | +| `store.py` | Store / Session (Configurable), StoreConfig / SessionConfig | | `type.py` | Observation, EmbeddedObservation dataclasses | | `embed.py` | EmbedImages / EmbedText transformers | @@ -40,9 +40,10 @@ Supporting Systems: | Package | What | Docs | |--------------|------------------------------------------------------|--------------------------------------------------| -| `impl/` | Backend implementations (ListBackend, SqliteBackend) | [impl/README.md](impl/README.md) | -| `blobstore/` | Pluggable blob storage (file, sqlite) | e[blobstore/blobstore.md](blobstore/blobstore.md) | -| `codecs/` | Encode/decode for storage (pickle, JPEG, LCM) | [codecs/README.md](codecs/README.md) | +| `impl/` | Backend implementations (ListBackend, SqliteBackend) | [impl/README.md](impl/README.md) | +| `livechannel/` | Live notification channels (SubjectChannel) | | +| `blobstore/` | Pluggable blob storage (file, sqlite) | 
[blobstore/blobstore.md](blobstore/blobstore.md) | +| `codecs/` | Encode/decode for storage (pickle, JPEG, LCM) | [codecs/README.md](codecs/README.md) | ## Docs diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py index 779a38b041..9be60f6d9f 100644 --- a/dimos/memory2/__init__.py +++ b/dimos/memory2/__init__.py @@ -1,4 +1,4 @@ -from dimos.memory2.backend import Backend, LiveBackend, VectorStore +from dimos.memory2.backend import Backend, LiveChannel, VectorStore from dimos.memory2.buffer import ( BackpressureBuffer, Bounded, @@ -21,7 +21,8 @@ ) from dimos.memory2.impl.memory import ListBackend, MemorySession, MemoryStore from dimos.memory2.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore -from dimos.memory2.store import Session, Store, StreamNamespace +from dimos.memory2.livechannel import SubjectChannel +from dimos.memory2.store import Session, SessionConfig, Store, StoreConfig, StreamNamespace from dimos.memory2.stream import Stream from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer from dimos.memory2.type import EmbeddedObservation, Observation @@ -42,7 +43,7 @@ "FnTransformer", "KeepLast", "ListBackend", - "LiveBackend", + "LiveChannel", "MemorySession", "MemoryStore", "NearFilter", @@ -50,13 +51,16 @@ "PredicateFilter", "QualityWindow", "Session", + "SessionConfig", "SqliteBackend", "SqliteSession", "SqliteStore", "Store", + "StoreConfig", "Stream", "StreamNamespace", "StreamQuery", + "SubjectChannel", "TagsFilter", "TimeRangeFilter", "Transformer", diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index e1f87287bf..3838c6694d 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -15,7 +15,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable from dimos.core.resource import Resource @@ -38,11 +38,16 @@ 
class Backend(Protocol[T]): The backend is fully responsible for applying query filters. How it does so (SQL, R-tree, Python predicates) is its business. + + Every backend supports live mode via a built-in ``LiveChannel``. """ @property def name(self) -> str: ... + @property + def live_channel(self) -> LiveChannel[T]: ... + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: ... def append(self, obs: Observation[T]) -> Observation[T]: @@ -56,11 +61,28 @@ def append(self, obs: Observation[T]) -> Observation[T]: def count(self, query: StreamQuery) -> int: ... -@runtime_checkable -class LiveBackend(Backend[T], Protocol[T]): - """Backend that also supports live subscriptions.""" +# ── Live notification channel ──────────────────────────────────── - def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: ... + +class LiveChannel(ABC, Generic[T]): + """Push-notification channel for live observation delivery. + + Decouples the notification mechanism from storage. The built-in + ``SubjectChannel`` handles same-session fan-out (thread-safe, zero + config). External implementations (Redis pub/sub, Postgres + LISTEN/NOTIFY, inotify) can be injected for cross-process use. + """ + + @abstractmethod + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + """Register *buf* to receive new observations. Returns a + disposable that unsubscribes when disposed.""" + ... + + @abstractmethod + def notify(self, obs: Observation[T]) -> None: + """Fan out *obs* to all current subscribers.""" + ... 
# ── Blob storage ────────────────────────────────────────────────── diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index d26d731208..3178abb07a 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -17,8 +17,7 @@ import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar -from reactivex.disposable import Disposable - +from dimos.memory2.livechannel.subject import SubjectChannel from dimos.memory2.store import Session, Store if TYPE_CHECKING: @@ -26,7 +25,7 @@ from reactivex.abc import DisposableBase - from dimos.memory2.backend import Backend + from dimos.memory2.backend import Backend, LiveChannel from dimos.memory2.buffer import BackpressureBuffer from dimos.memory2.filter import StreamQuery from dimos.memory2.type import Observation @@ -42,23 +41,23 @@ def __init__(self, name: str = "") -> None: self._observations: list[Observation[T]] = [] self._next_id = 0 self._lock = threading.Lock() - self._subscribers: list[BackpressureBuffer[Observation[T]]] = [] + self._channel: SubjectChannel[T] = SubjectChannel() @property def name(self) -> str: return self._name + @property + def live_channel(self) -> LiveChannel[T]: + return self._channel + def append(self, obs: Observation[T]) -> Observation[T]: with self._lock: obs.id = self._next_id self._next_id += 1 self._observations.append(obs) - subs = list(self._subscribers) - - # Notify outside lock to avoid deadlocks - for buf in subs: - buf.put(obs) + self._channel.notify(obs) return obs def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: @@ -72,7 +71,7 @@ def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: buf = query.live_buffer if buf is not None: # Subscribe BEFORE backfill to avoid missing items - sub = self.subscribe(buf) + sub = self._channel.subscribe(buf) return self._iterate_live(query, buf, sub) return self._iterate_snapshot(query) @@ -112,19 +111,6 @@ def _iterate_live( def count(self, query: StreamQuery) -> int: 
return sum(1 for _ in self.iterate(query)) - def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: - with self._lock: - self._subscribers.append(buf) - - def _unsubscribe() -> None: - with self._lock: - try: - self._subscribers.remove(buf) - except ValueError: - pass - - return Disposable(action=_unsubscribe) - class MemorySession(Session): """In-memory session. Each stream is backed by a ListBackend.""" @@ -132,9 +118,15 @@ class MemorySession(Session): def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: return ListBackend(name) + def list_streams(self) -> list[str]: + return list(self._streams.keys()) + + def delete_stream(self, name: str) -> None: + self._streams.pop(name, None) + class MemoryStore(Store): """In-memory store for experimentation.""" - def session(self) -> MemorySession: - return MemorySession() + def session(self, **kwargs: Any) -> MemorySession: + return MemorySession(**kwargs) diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index c87162acf2..6037297e14 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -14,19 +14,17 @@ from __future__ import annotations +from dataclasses import dataclass import sqlite3 -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.store import Session, Store +from dimos.memory2.livechannel.subject import SubjectChannel +from dimos.memory2.store import Session, Store, StoreConfig if TYPE_CHECKING: from collections.abc import Iterator - import os - from reactivex.abc import DisposableBase - - from dimos.memory2.backend import Backend - from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.backend import Backend, LiveChannel from dimos.memory2.filter import StreamQuery from dimos.memory2.type import Observation @@ -39,11 +37,16 @@ class SqliteBackend(Generic[T]): def __init__(self, conn: sqlite3.Connection, 
name: str) -> None: self._conn = conn self._name = name + self._channel: SubjectChannel[T] = SubjectChannel() @property def name(self) -> str: return self._name + @property + def live_channel(self) -> LiveChannel[T]: + return self._channel + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: raise NotImplementedError @@ -53,32 +56,46 @@ def append(self, obs: Observation[T]) -> Observation[T]: def count(self, query: StreamQuery) -> int: raise NotImplementedError - def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: - raise NotImplementedError - class SqliteSession(Session): """Session owning a single SQLite connection.""" - def __init__(self, conn: sqlite3.Connection) -> None: - super().__init__() + def __init__(self, conn: sqlite3.Connection, **kwargs: Any) -> None: + super().__init__(**kwargs) self._conn = conn def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: return SqliteBackend(self._conn, name) + def list_streams(self) -> list[str]: + # TODO: also query DB for persisted streams not yet opened + return list(self._streams.keys()) + + def delete_stream(self, name: str) -> None: + self._streams.pop(name, None) + # TODO: drop underlying table/rows from SQLite + def stop(self) -> None: super().stop() self._conn.close() +@dataclass +class SqliteStoreConfig(StoreConfig): + """Config for SQLite-backed store.""" + + path: str = "memory.db" + + class SqliteStore(Store): """Store backed by a SQLite database file.""" - def __init__(self, path: str | os.PathLike[str]) -> None: - self._path = path + default_config: type[SqliteStoreConfig] = SqliteStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) - def session(self) -> SqliteSession: - conn = sqlite3.connect(self._path, check_same_thread=False) + def session(self, **kwargs: Any) -> SqliteSession: + conn = sqlite3.connect(self.config.path, check_same_thread=False) conn.execute("PRAGMA journal_mode=WAL") 
- return SqliteSession(conn) + return SqliteSession(conn, **kwargs) diff --git a/dimos/memory2/livechannel/__init__.py b/dimos/memory2/livechannel/__init__.py new file mode 100644 index 0000000000..4fba822bab --- /dev/null +++ b/dimos/memory2/livechannel/__init__.py @@ -0,0 +1,4 @@ +from dimos.memory2.backend import LiveChannel +from dimos.memory2.livechannel.subject import SubjectChannel + +__all__ = ["LiveChannel", "SubjectChannel"] diff --git a/dimos/memory2/livechannel/subject.py b/dimos/memory2/livechannel/subject.py new file mode 100644 index 0000000000..2d2b848f9f --- /dev/null +++ b/dimos/memory2/livechannel/subject.py @@ -0,0 +1,63 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""In-memory fan-out live channel (same-session, thread-safe).""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Generic, TypeVar + +from reactivex.disposable import Disposable + +from dimos.memory2.backend import LiveChannel + +if TYPE_CHECKING: + from reactivex.abc import DisposableBase + + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type import Observation + +T = TypeVar("T") + + +class SubjectChannel(LiveChannel[T], Generic[T]): + """In-memory fan-out channel for same-session live notification. + + Thread-safe. ``notify()`` copies the subscriber list under the lock, + then iterates outside the lock to avoid deadlocks with slow consumers. 
+ """ + + def __init__(self) -> None: + self._subscribers: list[BackpressureBuffer[Observation[T]]] = [] + self._lock = threading.Lock() + + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + with self._lock: + self._subscribers.append(buf) + + def _unsubscribe() -> None: + with self._lock: + try: + self._subscribers.remove(buf) + except ValueError: + pass + + return Disposable(action=_unsubscribe) + + def notify(self, obs: Observation[T]) -> None: + with self._lock: + subs = list(self._subscribers) + for buf in subs: + buf.put(obs) diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index d4d926230c..aeb926531a 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -15,19 +15,45 @@ from __future__ import annotations from abc import abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TypeVar, cast from dimos.core.resource import CompositeResource from dimos.memory2.stream import Stream +from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: from collections.abc import Iterator - from dimos.memory2.backend import Backend + from dimos.memory2.backend import Backend, BlobStore, LiveChannel, VectorStore T = TypeVar("T") +# ── Configuration ───────────────────────────────────────────────── + + +@dataclass +class StoreConfig: + """Base config for Store. Subclasses extend with store-specific fields.""" + + +@dataclass +class SessionConfig: + """Session-level defaults for stream capabilities. + + These are inherited by all streams in the session unless overridden + per-stream in ``session.stream(..., **overrides)``. + """ + + live_channel: LiveChannel[Any] | None = None + blob_store: BlobStore | None = None + vector_store: VectorStore | None = None + + +# ── Stream namespace ────────────────────────────────────────────── + + class StreamNamespace: """Attribute-access proxy for session streams. 
@@ -45,41 +71,45 @@ def __init__(self, session: Session) -> None: def __getattr__(self, name: str) -> Stream[Any]: if name.startswith("_"): raise AttributeError(name) - try: - return self._session._streams[name] - except KeyError: - available = ", ".join(self._session._streams) or "(none)" - raise AttributeError(f"No stream named {name!r}. Available: {available}") from None + if name not in self._session.list_streams(): + available = ", ".join(self._session.list_streams()) or "(none)" + raise AttributeError(f"No stream named {name!r}. Available: {available}") + return self._session.stream(name) def __getitem__(self, name: str) -> Stream[Any]: - try: - return self._session._streams[name] - except KeyError: - raise KeyError(name) from None + if name not in self._session.list_streams(): + raise KeyError(name) + return self._session.stream(name) def __iter__(self) -> Iterator[Stream[Any]]: - return iter(self._session._streams.values()) + for name in self._session.list_streams(): + yield self._session.stream(name) def __len__(self) -> int: - return len(self._session._streams) + return len(self._session.list_streams()) def __contains__(self, name: str) -> bool: - return name in self._session._streams + return name in self._session.list_streams() def __repr__(self) -> str: - return f"StreamNamespace({list(self._session._streams.keys())})" + return f"StreamNamespace({self._session.list_streams()})" -class Session(CompositeResource): +# ── Session & Store ─────────────────────────────────────────────── + + +class Session(Configurable[SessionConfig], CompositeResource): """A session against a store. Manages named streams over a shared connection. Subclasses implement ``_create_backend`` to provide storage-specific backends. 
""" - def __init__(self) -> None: - super().__init__() + default_config: type[SessionConfig] = SessionConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) self._streams: dict[str, Stream[Any]] = {} - self._backends: dict[str, Backend[Any]] = {} @abstractmethod def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: @@ -90,25 +120,34 @@ def stream(self, name: str, payload_type: type[T] | None = None) -> Stream[T]: """Get or create a named stream. Returns the same Stream on repeated calls.""" if name not in self._streams: backend = self._create_backend(name, payload_type) - self._backends[name] = backend self._streams[name] = Stream(source=backend) return cast("Stream[T]", self._streams[name]) + @abstractmethod def list_streams(self) -> list[str]: """Return names of all streams in this session.""" - return list(self._streams.keys()) + ... + @abstractmethod def delete_stream(self, name: str) -> None: - self._streams.pop(name, None) - self._backends.pop(name, None) + """Delete a stream by name (from cache and underlying storage).""" + ... @property def streams(self) -> StreamNamespace: return StreamNamespace(self) -class Store(CompositeResource): +class Store(Configurable[StoreConfig], CompositeResource): """Top-level entry point — wraps a storage location (file, URL, etc.).""" + default_config: type[StoreConfig] = StoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + @abstractmethod - def session(self) -> Session: ... + def session(self, **kwargs: Any) -> Session: + """Create a session. kwargs are forwarded to SessionConfig.""" + ... 
diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index af61541460..d2b212855c 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -17,7 +17,7 @@ import time from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.backend import Backend, LiveBackend +from dimos.memory2.backend import Backend from dimos.memory2.buffer import BackpressureBuffer, KeepLast from dimos.memory2.filter import ( AfterFilter, @@ -173,8 +173,8 @@ def transform(self, xf: Transformer[T, R]) -> Stream[R]: def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: """Return a stream whose iteration never ends — backfill then live tail. - Only valid on backend-backed streams whose backend implements - LiveBackend. Call .live() before .transform(), not after. + All backends support live mode via their built-in ``LiveChannel``. + Call .live() before .transform(), not after. Default buffer: KeepLast(). The backend handles subscription, dedup, and backpressure — how it does so is its business. @@ -184,8 +184,6 @@ def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> St "Cannot call .live() on a transform stream. " "Call .live() on the source stream, then .transform()." 
) - if not isinstance(self._source, LiveBackend): - raise TypeError(f"Backend {self._source.name!r} does not support live mode.") buf = buffer if buffer is not None else KeepLast() return self._replace_query(live_buffer=buf) diff --git a/dimos/memory2/test_impl.py b/dimos/memory2/test_impl.py index e338966e3d..7ed342ffdb 100644 --- a/dimos/memory2/test_impl.py +++ b/dimos/memory2/test_impl.py @@ -59,7 +59,7 @@ def sqlite_session() -> Generator[Session, None, None]: from dimos.memory2.impl.sqlite import SqliteStore with tempfile.NamedTemporaryFile(suffix=".db") as f: - store = SqliteStore(f.name) + store = SqliteStore(path=f.name) with store.session() as session: yield session diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py index d15b8e63f8..74c1be89f0 100644 --- a/dimos/memory2/test_save.py +++ b/dimos/memory2/test_save.py @@ -12,25 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Stream.save() and LiveBackend protocol split.""" +"""Tests for Stream.save() and LiveChannel integration.""" from __future__ import annotations -from typing import TYPE_CHECKING - import pytest -from dimos.memory2.backend import Backend, LiveBackend +from dimos.memory2.backend import Backend, LiveChannel from dimos.memory2.impl.memory import ListBackend from dimos.memory2.stream import Stream from dimos.memory2.transform import FnTransformer from dimos.memory2.type import Observation -if TYPE_CHECKING: - from collections.abc import Iterator - - from dimos.memory2.filter import StreamQuery - # ── Helpers ────────────────────────────────────────────────────────── @@ -41,65 +34,19 @@ def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: return Stream(source=backend) -class ReadOnlyBackend: - """A Backend that does NOT support live mode (no subscribe).""" - - def __init__(self, name: str = "") -> None: - self._name = name - self._obs: list[Observation[int]] = [] - self._next_id = 0 
- - @property - def name(self) -> str: - return self._name - - def iterate(self, query: StreamQuery) -> Iterator[Observation[int]]: - yield from self._obs - - def append(self, obs: Observation[int]) -> Observation[int]: - obs.id = self._next_id - self._next_id += 1 - self._obs.append(obs) - return obs - - def count(self, query: StreamQuery) -> int: - return len(self._obs) - - # ═══════════════════════════════════════════════════════════════════ # Protocol checks # ═══════════════════════════════════════════════════════════════════ -class TestProtocolSplit: - def test_list_backend_is_live(self) -> None: - b = ListBackend[int]("x") - assert isinstance(b, LiveBackend) - +class TestProtocol: def test_list_backend_is_backend(self) -> None: b = ListBackend[int]("x") assert isinstance(b, Backend) - def test_readonly_is_backend(self) -> None: - b = ReadOnlyBackend() - assert isinstance(b, Backend) - - def test_readonly_is_not_live(self) -> None: - b = ReadOnlyBackend() - assert not isinstance(b, LiveBackend) - - -# ═══════════════════════════════════════════════════════════════════ -# .live() rejects non-LiveBackend -# ═══════════════════════════════════════════════════════════════════ - - -class TestLiveRejectsNonLive: - def test_live_rejects_non_live_backend(self) -> None: - b = ReadOnlyBackend("ro") - s = Stream(source=b) - with pytest.raises(TypeError, match="does not support live mode"): - s.live() + def test_list_backend_has_live_channel(self) -> None: + b = ListBackend[int]("x") + assert isinstance(b.live_channel, LiveChannel) # ═══════════════════════════════════════════════════════════════════ From 690c5ecc546fb9df861324479d362713f0885ecf Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 20:04:32 +0800 Subject: [PATCH 080/118] =?UTF-8?q?memory2:=20make=20backends=20Configurab?= =?UTF-8?q?le,=20add=20session=E2=86=92stream=20config=20propagation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
Session.stream() now merges session-level defaults with per-stream overrides and forwards them to _create_backend(). Backends (ListBackend, SqliteBackend) extend Configurable[BackendConfig] so they receive live_channel, blob_store, and vector_store through the standard config pattern instead of explicit constructor params. --- dimos/memory2/backend.py | 19 ++++++++++++++++++- dimos/memory2/impl/memory.py | 17 ++++++++++++----- dimos/memory2/impl/sqlite.py | 17 ++++++++++++----- dimos/memory2/store.py | 16 ++++++++++++---- 4 files changed, 54 insertions(+), 15 deletions(-) diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index 3838c6694d..5c048936cf 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -15,7 +15,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from dimos.core.resource import Resource @@ -32,6 +33,22 @@ T = TypeVar("T") +# ── Backend configuration ─────────────────────────────────────── + + +@dataclass +class BackendConfig: + """Configuration for backend capabilities. + + Session-level defaults are merged with per-stream overrides and + forwarded here by ``Session.stream()``. + """ + + live_channel: LiveChannel[Any] | None = None + blob_store: BlobStore | None = None + vector_store: VectorStore | None = None + + @runtime_checkable class Backend(Protocol[T]): """Data source protocol for stored observations. 
diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index 3178abb07a..ae228127ba 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -17,8 +17,10 @@ import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.memory2.backend import BackendConfig from dimos.memory2.livechannel.subject import SubjectChannel from dimos.memory2.store import Session, Store +from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: from collections.abc import Iterator @@ -33,15 +35,18 @@ T = TypeVar("T") -class ListBackend(Generic[T]): +class ListBackend(Configurable[BackendConfig], Generic[T]): """In-memory backend for experimentation. Thread-safe.""" - def __init__(self, name: str = "") -> None: + default_config: type[BackendConfig] = BackendConfig + + def __init__(self, name: str = "", **kwargs: Any) -> None: + super().__init__(**kwargs) self._name = name self._observations: list[Observation[T]] = [] self._next_id = 0 self._lock = threading.Lock() - self._channel: SubjectChannel[T] = SubjectChannel() + self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel() @property def name(self) -> str: @@ -115,8 +120,10 @@ def count(self, query: StreamQuery) -> int: class MemorySession(Session): """In-memory session. 
Each stream is backed by a ListBackend.""" - def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: - return ListBackend(name) + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + return ListBackend(name, **config) def list_streams(self) -> list[str]: return list(self._streams.keys()) diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 6037297e14..31bb38855a 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -18,8 +18,10 @@ import sqlite3 from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.memory2.backend import BackendConfig from dimos.memory2.livechannel.subject import SubjectChannel from dimos.memory2.store import Session, Store, StoreConfig +from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: from collections.abc import Iterator @@ -31,13 +33,16 @@ T = TypeVar("T") -class SqliteBackend(Generic[T]): +class SqliteBackend(Configurable[BackendConfig], Generic[T]): """SQLite-backed observation storage for a single stream (table).""" - def __init__(self, conn: sqlite3.Connection, name: str) -> None: + default_config: type[BackendConfig] = BackendConfig + + def __init__(self, conn: sqlite3.Connection, name: str, **kwargs: Any) -> None: + super().__init__(**kwargs) self._conn = conn self._name = name - self._channel: SubjectChannel[T] = SubjectChannel() + self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel() @property def name(self) -> str: @@ -64,8 +69,10 @@ def __init__(self, conn: sqlite3.Connection, **kwargs: Any) -> None: super().__init__(**kwargs) self._conn = conn - def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: - return SqliteBackend(self._conn, name) + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + return SqliteBackend(self._conn, 
name, **config) def list_streams(self) -> list[str]: # TODO: also query DB for persisted streams not yet opened diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index aeb926531a..5037c4b074 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -112,14 +112,22 @@ def __init__(self, **kwargs: Any) -> None: self._streams: dict[str, Stream[Any]] = {} @abstractmethod - def _create_backend(self, name: str, payload_type: type[Any] | None = None) -> Backend[Any]: + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: """Create a backend for the named stream. Called once per stream name.""" ... - def stream(self, name: str, payload_type: type[T] | None = None) -> Stream[T]: - """Get or create a named stream. Returns the same Stream on repeated calls.""" + def stream(self, name: str, payload_type: type[T] | None = None, **overrides: Any) -> Stream[T]: + """Get or create a named stream. Returns the same Stream on repeated calls. + + Per-stream ``overrides`` (e.g. ``live_channel=``) are merged on top of + the session-level defaults from :class:`SessionConfig`. + """ if name not in self._streams: - backend = self._create_backend(name, payload_type) + resolved = {k: v for k, v in vars(self.config).items() if v is not None} + resolved.update({k: v for k, v in overrides.items() if v is not None}) + backend = self._create_backend(name, payload_type, **resolved) self._streams[name] = Stream(source=backend) return cast("Stream[T]", self._streams[name]) From f73d8d4fef26ee93e46f4ab73f408035ba5be7f0 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 20:10:06 +0800 Subject: [PATCH 081/118] memory2: wire VectorStore into ListBackend, add MemoryVectorStore ListBackend.append() now delegates embedding storage to the pluggable VectorStore when configured. _iterate_snapshot() uses VectorStore.search() for ANN ranking when available, falling back to brute-force in StreamQuery.apply(). 
Adds MemoryVectorStore (in-memory brute-force impl) and tests verifying end-to-end config propagation including per-stream vector_store overrides. --- dimos/memory2/impl/memory.py | 35 ++++++++++- dimos/memory2/test_embedding.py | 83 +++++++++++++++++++++++++++ dimos/memory2/vectorstore/__init__.py | 17 ++++++ dimos/memory2/vectorstore/memory.py | 58 +++++++++++++++++++ 4 files changed, 192 insertions(+), 1 deletion(-) create mode 100644 dimos/memory2/vectorstore/__init__.py create mode 100644 dimos/memory2/vectorstore/memory.py diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index ae228127ba..6e584c89d2 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -14,6 +14,7 @@ from __future__ import annotations +from dataclasses import replace import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -62,6 +63,13 @@ def append(self, obs: Observation[T]) -> Observation[T]: self._next_id += 1 self._observations.append(obs) + # Delegate embedding to pluggable vector store + vs = self.config.vector_store + if vs is not None: + emb = getattr(obs, "embedding", None) + if emb is not None: + vs.put(self._name, obs.id, emb) + self._channel.notify(obs) return obs @@ -83,7 +91,32 @@ def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: with self._lock: snapshot = list(self._observations) - yield from query.apply(iter(snapshot)) + + if query.search_vec is not None and self.config.vector_store is not None: + yield from self._vector_search(snapshot, query) + else: + yield from query.apply(iter(snapshot)) + + def _vector_search( + self, snapshot: list[Observation[T]], query: StreamQuery + ) -> Iterator[Observation[T]]: + """Use pluggable VectorStore for ANN search, then apply remaining query ops.""" + vs = self.config.vector_store + assert vs is not None # caller checks + + hits = vs.search(self._name, query.search_vec, 
query.search_k or len(snapshot)) + + # Build results with similarity attached, preserving VectorStore ranking + ranked: list[Observation[T]] = [] + obs_by_id = {obs.id: obs for obs in snapshot} + for obs_id, sim in hits: + obs = obs_by_id.get(obs_id) + if obs is not None: + ranked.append(obs.derive(data=obs.data, similarity=sim)) + + # Apply remaining query ops (filters, ordering, offset, limit) — skip vector search + rest = replace(query, search_vec=None, search_k=None) + yield from rest.apply(iter(ranked)) def _iterate_live( self, diff --git a/dimos/memory2/test_embedding.py b/dimos/memory2/test_embedding.py index 40f80346c0..f1d22addf2 100644 --- a/dimos/memory2/test_embedding.py +++ b/dimos/memory2/test_embedding.py @@ -370,3 +370,86 @@ def embed_text(self, *texts): list(s.transform(EmbedText(model, batch_size=2))) # 5 items with batch_size=2 → 3 calls (2, 2, 1) assert call_sizes == [2, 2, 1] + + +# ── Pluggable VectorStore ──────────────────────────────────────── + + +class TestPluggableVectorStore: + """Verify that injecting a VectorStore via session config actually delegates search.""" + + def test_append_stores_in_vector_store(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs) as session: + s = session.stream("vecs", str) + s.append("hello", embedding=_emb([1, 0, 0])) + s.append("world", embedding=_emb([0, 1, 0])) + + assert len(vs._vectors["vecs"]) == 2 + + def test_append_without_embedding_skips_vector_store(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs) as session: + s = session.stream("plain", str) + s.append("no embedding") + + assert "plain" not in vs._vectors + + def test_search_uses_vector_store(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + store = MemoryStore() + 
with store.session(vector_store=vs) as session: + s = session.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_with_filters_via_vector_store(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs) as session: + s = session.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) + + # Filter + search: only "late" passes the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_per_stream_vector_store_override(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs_default = MemoryVectorStore() + vs_override = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs_default) as session: + # Stream with default vector store + s1 = session.stream("s1", str) + s1.append("a", embedding=_emb([1, 0, 0])) + + # Stream with overridden vector store + s2 = session.stream("s2", str, vector_store=vs_override) + s2.append("b", embedding=_emb([0, 1, 0])) + + assert "s1" in vs_default._vectors + assert "s1" not in vs_override._vectors + assert "s2" in vs_override._vectors + assert "s2" not in vs_default._vectors diff --git a/dimos/memory2/vectorstore/__init__.py b/dimos/memory2/vectorstore/__init__.py new file mode 100644 index 0000000000..fbdd9d3666 --- /dev/null +++ b/dimos/memory2/vectorstore/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.memory2.vectorstore.memory import MemoryVectorStore + +__all__ = ["MemoryVectorStore"] diff --git a/dimos/memory2/vectorstore/memory.py b/dimos/memory2/vectorstore/memory.py new file mode 100644 index 0000000000..22532c6ad1 --- /dev/null +++ b/dimos/memory2/vectorstore/memory.py @@ -0,0 +1,58 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dimos.memory2.backend import VectorStore + +if TYPE_CHECKING: + from dimos.models.embedding.base import Embedding + + +class MemoryVectorStore(VectorStore): + """In-memory brute-force vector store for testing. + + Stores embeddings in a dict keyed by ``(stream, observation_id)``. + Search computes cosine similarity against all vectors in the stream. 
+ """ + + def __init__(self) -> None: + self._vectors: dict[str, dict[int, Embedding]] = {} + + # ── Resource lifecycle ──────────────────────────────────────── + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + # ── VectorStore interface ──────────────────────────────────── + + def put(self, stream: str, key: int, embedding: Embedding) -> None: + self._vectors.setdefault(stream, {})[key] = embedding + + def search(self, stream: str, query: Embedding, k: int) -> list[tuple[int, float]]: + vectors = self._vectors.get(stream, {}) + if not vectors: + return [] + scored = [(key, float(emb @ query)) for key, emb in vectors.items()] + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:k] + + def delete(self, stream: str, key: int) -> None: + vectors = self._vectors.get(stream, {}) + vectors.pop(key, None) From c6557392209503c7823bb35a721078ff37689209 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 20:41:57 +0800 Subject: [PATCH 082/118] memory2: wire BlobStore into ListBackend with lazy/eager blob loading Payloads are encoded via auto-selected codec and externalized to the pluggable BlobStore on append. Observations become lightweight metadata with lazy loaders that fetch+decode on first .data access. Per-stream eager_blobs toggle pre-loads data during iteration. 
--- dimos/memory2/backend.py | 3 + dimos/memory2/impl/memory.py | 41 ++++++- dimos/memory2/store.py | 3 + dimos/memory2/test_blobstore.py | 185 ++++++++++++++++++++++++++++++++ 4 files changed, 228 insertions(+), 4 deletions(-) create mode 100644 dimos/memory2/test_blobstore.py diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index 5c048936cf..f5e74cf6ad 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -26,6 +26,7 @@ from reactivex.abc import DisposableBase from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.codecs.base import Codec from dimos.memory2.filter import StreamQuery from dimos.memory2.type import Observation from dimos.models.embedding.base import Embedding @@ -47,6 +48,8 @@ class BackendConfig: live_channel: LiveChannel[Any] | None = None blob_store: BlobStore | None = None vector_store: VectorStore | None = None + eager_blobs: bool = False + codec: Codec[Any] | None = None @runtime_checkable diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index 6e584c89d2..f53a3d2af2 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -19,8 +19,10 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.memory2.backend import BackendConfig +from dimos.memory2.codecs.base import Codec, codec_for from dimos.memory2.livechannel.subject import SubjectChannel from dimos.memory2.store import Session, Store +from dimos.memory2.type import _UNLOADED from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: @@ -41,13 +43,19 @@ class ListBackend(Configurable[BackendConfig], Generic[T]): default_config: type[BackendConfig] = BackendConfig - def __init__(self, name: str = "", **kwargs: Any) -> None: + def __init__( + self, name: str = "", payload_type: type[Any] | None = None, **kwargs: Any + ) -> None: super().__init__(**kwargs) self._name = name self._observations: list[Observation[T]] = [] self._next_id = 0 self._lock = threading.Lock() 
self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel() + # Resolve codec for blob store + self._codec: Codec[Any] | None = None + if self.config.blob_store is not None: + self._codec = self.config.codec or codec_for(payload_type) @property def name(self) -> str: @@ -58,9 +66,23 @@ def live_channel(self) -> LiveChannel[T]: return self._channel def append(self, obs: Observation[T]) -> Observation[T]: + # Encode BEFORE lock (avoids holding lock during IO) + bs = self.config.blob_store + encoded: bytes | None = None + if bs is not None and self._codec is not None: + encoded = self._codec.encode(obs._data) + with self._lock: obs.id = self._next_id self._next_id += 1 + if encoded is not None: + assert bs is not None + bs.put(self._name, obs.id, encoded) + # Replace inline data with lazy loader + stream_name, key, codec = self._name, obs.id, self._codec + assert codec is not None + obs._data = _UNLOADED # type: ignore[assignment] + obs._loader = lambda: codec.decode(bs.get(stream_name, key)) self._observations.append(obs) # Delegate embedding to pluggable vector store @@ -93,9 +115,16 @@ def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: snapshot = list(self._observations) if query.search_vec is not None and self.config.vector_store is not None: - yield from self._vector_search(snapshot, query) + it = self._vector_search(snapshot, query) else: - yield from query.apply(iter(snapshot)) + it = query.apply(iter(snapshot)) + + if self.config.eager_blobs and self.config.blob_store is not None: + for obs in it: + _ = obs.data # trigger lazy loader + yield obs + else: + yield from it def _vector_search( self, snapshot: list[Observation[T]], query: StreamQuery @@ -126,6 +155,8 @@ def _iterate_live( ) -> Iterator[Observation[T]]: from dimos.memory2.buffer import ClosedError + eager = self.config.eager_blobs and self.config.blob_store is not None + # Backfill phase — use snapshot query (without live) for the backfill last_id = -1 for 
obs in self._iterate_snapshot(query): @@ -142,6 +173,8 @@ def _iterate_live( last_id = obs.id if filters and not all(f.matches(obs) for f in filters): continue + if eager: + _ = obs.data # trigger lazy loader yield obs except (ClosedError, StopIteration): sub.dispose() @@ -156,7 +189,7 @@ class MemorySession(Session): def _create_backend( self, name: str, payload_type: type[Any] | None = None, **config: Any ) -> Backend[Any]: - return ListBackend(name, **config) + return ListBackend(name, payload_type=payload_type, **config) def list_streams(self) -> list[str]: return list(self._streams.keys()) diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index 5037c4b074..e9c1ec4e51 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -26,6 +26,7 @@ from collections.abc import Iterator from dimos.memory2.backend import Backend, BlobStore, LiveChannel, VectorStore + from dimos.memory2.codecs.base import Codec T = TypeVar("T") @@ -49,6 +50,8 @@ class SessionConfig: live_channel: LiveChannel[Any] | None = None blob_store: BlobStore | None = None vector_store: VectorStore | None = None + eager_blobs: bool = False + codec: Codec[Any] | None = None # ── Stream namespace ────────────────────────────────────────────── diff --git a/dimos/memory2/test_blobstore.py b/dimos/memory2/test_blobstore.py new file mode 100644 index 0000000000..b8e8668ff8 --- /dev/null +++ b/dimos/memory2/test_blobstore.py @@ -0,0 +1,185 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BlobStore integration with ListBackend.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.impl.memory import MemoryStore +from dimos.memory2.type import _UNLOADED +from dimos.models.embedding.base import Embedding + +if TYPE_CHECKING: + from pathlib import Path + +# ── Helpers ─────────────────────────────────────────────────────── + + +def _emb(vec: list[float]) -> Embedding: + v = np.array(vec, dtype=np.float32) + v /= np.linalg.norm(v) + 1e-10 + return Embedding(vector=v) + + +# ── Tests ───────────────────────────────────────────────────────── + + +class TestBlobStoreIntegration: + def test_append_stores_in_blobstore(self, tmp_path: Path) -> None: + bs = FileBlobStore(tmp_path / "blobs") + bs.start() + store = MemoryStore() + with store.session(blob_store=bs) as session: + s = session.stream("data", bytes) + s.append(b"hello", ts=1.0) + + # Blob was written to the file store + raw = bs.get("data", 0) + assert len(raw) > 0 + + def test_lazy_data_not_loaded_until_access(self, tmp_path: Path) -> None: + bs = FileBlobStore(tmp_path / "blobs") + bs.start() + store = MemoryStore() + with store.session(blob_store=bs) as session: + s = session.stream("data", str) + obs = s.append("payload", ts=1.0) + + # Data replaced with sentinel after append + assert isinstance(obs._data, type(_UNLOADED)) + assert obs._loader is not None + + def test_lazy_data_loads_correctly(self, tmp_path: Path) -> None: + bs = FileBlobStore(tmp_path / "blobs") + bs.start() + store = MemoryStore() + with store.session(blob_store=bs) as session: + s = session.stream("data", str) + s.append("payload", ts=1.0) + + result = s.first() + assert result.data == "payload" + + def test_eager_preloads_data(self, tmp_path: Path) -> None: + bs = 
FileBlobStore(tmp_path / "blobs") + bs.start() + store = MemoryStore() + with store.session(blob_store=bs, eager_blobs=True) as session: + s = session.stream("data", str) + s.append("payload", ts=1.0) + + # Iterating with eager_blobs triggers load + results = s.fetch() + assert len(results) == 1 + # Data should be loaded (not _UNLOADED) + assert not isinstance(results[0]._data, type(_UNLOADED)) + assert results[0].data == "payload" + + def test_per_stream_eager_override(self, tmp_path: Path) -> None: + bs = FileBlobStore(tmp_path / "blobs") + bs.start() + store = MemoryStore() + with store.session(blob_store=bs) as session: + # Default: lazy + lazy_stream = session.stream("lazy", str) + lazy_stream.append("lazy-val", ts=1.0) + + # Override: eager + eager_stream = session.stream("eager", str, eager_blobs=True) + eager_stream.append("eager-val", ts=1.0) + + lazy_results = lazy_stream.fetch() + eager_results = eager_stream.fetch() + + # Lazy: data stays unloaded until accessed + assert lazy_results[0].data == "lazy-val" + + # Eager: data pre-loaded during iteration + assert not isinstance(eager_results[0]._data, type(_UNLOADED)) + assert eager_results[0].data == "eager-val" + + def test_no_blobstore_unchanged(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("data", str) + obs = s.append("inline", ts=1.0) + + # Without blob store, data stays inline + assert obs._data == "inline" + assert obs._loader is None + assert obs.data == "inline" + + def test_blobstore_with_vector_search(self, tmp_path: Path) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + bs = FileBlobStore(tmp_path / "blobs") + bs.start() + vs = MemoryVectorStore() + store = MemoryStore() + with store.session(blob_store=bs, vector_store=vs) as session: + s = session.stream("vecs", str) + s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) + s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) + s.append("south", ts=3.0, embedding=_emb([0, -1, 
0])) + + # Vector search triggers lazy load via obs.derive(data=obs.data, ...) + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity > 0.99 + + def test_blobstore_with_text_search(self, tmp_path: Path) -> None: + bs = FileBlobStore(tmp_path / "blobs") + bs.start() + store = MemoryStore() + with store.session(blob_store=bs) as session: + s = session.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("temperature ok", ts=2.0) + + # Text search triggers lazy load via str(obs.data) + results = s.search_text("motor").fetch() + assert len(results) == 1 + assert results[0].data == "motor fault" + + def test_multiple_appends_get_unique_blobs(self, tmp_path: Path) -> None: + bs = FileBlobStore(tmp_path / "blobs") + bs.start() + store = MemoryStore() + with store.session(blob_store=bs) as session: + s = session.stream("multi", str) + s.append("first", ts=1.0) + s.append("second", ts=2.0) + s.append("third", ts=3.0) + + results = s.fetch() + assert [r.data for r in results] == ["first", "second", "third"] + + def test_fetch_preserves_metadata(self, tmp_path: Path) -> None: + bs = FileBlobStore(tmp_path / "blobs") + bs.start() + store = MemoryStore() + with store.session(blob_store=bs) as session: + s = session.stream("meta", str) + s.append("val", ts=42.0, tags={"kind": "info"}) + + result = s.first() + assert result.ts == 42.0 + assert result.tags == {"kind": "info"} + assert result.data == "val" From 5b565db3a825badfa59d03265a5821398d6a9dbf Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 21:06:39 +0800 Subject: [PATCH 083/118] memory2: allow bare generator functions as stream transforms MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit stream.transform() now accepts Iterator→Iterator callables in addition to Transformer subclasses, for quick stateful pipelines. 
--- dimos/memory2/stream.py | 18 +++++++++++++++--- dimos/memory2/test_stream.py | 23 +++++++++++++++++++++++ dimos/memory2/transform.py | 10 ++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index d2b212855c..a1cabc7f90 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -30,7 +30,7 @@ TagsFilter, TimeRangeFilter, ) -from dimos.memory2.transform import FnTransformer, Transformer +from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer from dimos.memory2.type import EmbeddedObservation, Observation if TYPE_CHECKING: @@ -161,11 +161,23 @@ def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[Any]: # ── Transform ─────────────────────────────────────────────────── - def transform(self, xf: Transformer[T, R]) -> Stream[R]: + def transform( + self, + xf: Transformer[T, R] | Callable[[Iterator[Observation[T]]], Iterator[Observation[R]]], + ) -> Stream[R]: """Wrap this stream with a transformer. Returns a new lazy Stream. - When iterated, calls xf(iter(self)) — pulls lazily through the chain. 
+ Accepts a ``Transformer`` subclass or a bare callable / generator + function with the same ``Iterator[Obs] → Iterator[Obs]`` signature:: + + def detect(upstream): + for obs in upstream: + yield obs.derive(data=run_detector(obs.data)) + + images.transform(detect).save(detections) """ + if not isinstance(xf, Transformer): + xf = FnIterTransformer(xf) return Stream(source=self, xf=xf, query=StreamQuery()) # ── Live mode ─────────────────────────────────────────────────── diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index eaca07eca8..ab8f4da7e6 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -301,6 +301,29 @@ def test_transform_filter_transform(self): ) assert [o.data for o in result] == [22, 26] + def test_generator_function_transform(self): + """A bare generator function works as a transform.""" + + def double_all(upstream): + for obs in upstream: + yield obs.derive(data=obs.data * 2) + + result = make_stream(3).transform(double_all).fetch() + assert [o.data for o in result] == [0, 20, 40] + + def test_generator_function_stateful(self): + """Generator transforms can accumulate state and yield at their own pace.""" + + def running_sum(upstream): + total = 0 + for obs in upstream: + total += obs.data + yield obs.derive(data=total) + + result = make_stream(3).transform(running_sum).fetch() + # 0, 0+10=10, 10+20=30 + assert [o.data for o in result] == [0, 10, 30] + def test_quality_window(self): """QualityWindow keeps the best item per time window.""" store = MemoryStore() diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py index a39fb3c3b3..05a809f8cb 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory2/transform.py @@ -52,6 +52,16 @@ def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R yield result +class FnIterTransformer(Transformer[T, R]): + """Wraps a bare ``Iterator → Iterator`` callable (e.g. 
a generator function).""" + + def __init__(self, fn: Callable[[Iterator[Observation[T]]], Iterator[Observation[R]]]) -> None: + self._fn = fn + + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: + return self._fn(upstream) + + class QualityWindow(Transformer[T, T]): """Keeps the highest-quality item per time window. From da676f6025e1e454b25ab142d5b09d5d10be0430 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 21:14:42 +0800 Subject: [PATCH 084/118] memory2: update docs to reflect current API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - impl/README: LiveBackend → LiveChannel, add Configurable pattern, update _create_backend and Store/Session signatures - embeddings.md: fix Observation fields (_source → _loader), embedding type (np.ndarray → Embedding), remove unimplemented source chain, use temporal join for lineage - streaming.md: note .transform() accepts bare callables - README: add FnIterTransformer, generator function example --- dimos/memory2/README.md | 11 ++- dimos/memory2/embeddings.md | 148 +++++++++++++++++++++++++++++++++++ dimos/memory2/impl/README.md | 49 ++++++------ dimos/memory2/streaming.md | 2 +- 4 files changed, 185 insertions(+), 25 deletions(-) create mode 100644 dimos/memory2/embeddings.md diff --git a/dimos/memory2/README.md b/dimos/memory2/README.md index adc478f1e4..25a37a22f5 100644 --- a/dimos/memory2/README.md +++ b/dimos/memory2/README.md @@ -30,7 +30,7 @@ Supporting Systems: | `stream.py` | Stream node — filters, transforms, terminals | | `backend.py` | Backend protocol, LiveChannel / VectorStore / BlobStore ABCs | | `filter.py` | StreamQuery dataclass, filter types, Python query execution | -| `transform.py` | Transformer protocol, FnTransformer, QualityWindow | +| `transform.py` | Transformer ABC, FnTransformer, FnIterTransformer, QualityWindow | | `buffer.py` | Backpressure buffers for live mode (KeepLast, Bounded, Unbounded) | | 
`store.py` | Store / Session (Configurable), StoreConfig / SessionConfig | | `type.py` | Observation, EmbeddedObservation dataclasses | @@ -88,9 +88,16 @@ with store.session() as session: nearest = images.near(pose, radius=2.0).fetch() latest = images.last() - # Transform + # Transform (class or bare generator function) edges = images.transform(Canny()).save(session.stream("edges")) + def running_avg(upstream): + total, n = 0.0, 0 + for obs in upstream: + total += obs.data; n += 1 + yield obs.derive(data=total / n) + avgs = stream.transform(running_avg).fetch() + # Live for obs in images.live().transform(process): handle(obs) diff --git a/dimos/memory2/embeddings.md b/dimos/memory2/embeddings.md new file mode 100644 index 0000000000..de27cd18c9 --- /dev/null +++ b/dimos/memory2/embeddings.md @@ -0,0 +1,148 @@ +# memory2 Embedding Design + +## Core Principle: Enrichment, Not Replacement + +The embedding annotates the observation — it doesn't replace `.data`. +In memory1, `.data` IS the embedding and you need `parent_id` + `project_to()` to get back to the source image. We avoid this entirely. + +## Observation Types + +```python +@dataclass +class Observation(Generic[T]): + id: int + ts: float + pose: Any | None = None + tags: dict[str, Any] = field(default_factory=dict) + _data: T | _Unloaded = ... + _loader: Callable[[], T] | None = None # lazy loading via blob store + +@dataclass +class EmbeddedObservation(Observation[T]): + embedding: Embedding | None = None # populated by Embed transformer + similarity: float | None = None # populated by .search() +``` + +`EmbeddedObservation` is a subclass — passes anywhere `Observation` is accepted (LSP). +Users who don't care about types just use `Observation`. Users who want precision annotate with `EmbeddedObservation`. + +`derive()` on `Observation` promotes to `EmbeddedObservation` if `embedding=` is passed. 
+`derive()` on `EmbeddedObservation` returns `EmbeddedObservation`, preserving the embedding unless explicitly replaced. + +## Embed Transformer + +`Embed` is `Transformer[T, T]` — same data type in and out. It populates `.embedding` on each observation: + +```python +class Embed(Transformer[T, T]): + def __init__(self, model: EmbeddingModel): + self.model = model + + def __call__(self, upstream): + for batch in batched(upstream, 32): + vecs = self.model.embed_batch([obs.data for obs in batch]) + for obs, vec in zip(batch, vecs): + yield obs.derive(data=obs.data, embedding=vec) +``` + +`Stream[Image]` stays `Stream[Image]` after embedding — `T` is about `.data`, not the observation subclass. + +## Search + +`.search(query_vec, k)` lives on `Stream` itself. Returns a new Stream filtered to top-k by cosine similarity: + +```python +query_vec = clip.embed_text("a cat in the kitchen") + +results = images.transform(Embed(clip)).search(query_vec, k=20).fetch() +# results[0].data → Image +# results[0].embedding → np.ndarray +# results[0].similarity → 0.93 + +# Chainable with other filters +results = images.transform(Embed(clip)) \ + .search(query_vec, k=50) \ + .after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .fetch() +``` + +## Backend Handles Storage Strategy + +The Backend protocol decides how to store embeddings based on what it sees: + +- `append(image, ts=now, embedding=vec)` → backend routes: blob table for Image, vec0 table for vector +- `append(image, ts=now)` → blob table only (no embedding) +- `ListBackend`: stores embeddings in-memory, brute-force cosine on search +- `SqliteBackend`: vec0 side table for fast ANN search +- Future backends (Postgres/pgvector, Qdrant, etc.) do their thing + +Search is pushed down to the backend. Stream just passes `.search()` calls through. + +## Projection / Lineage + +**Usually not needed.** Since `.data` IS the original data, search results give you the image directly. 
+ +When a downstream transform replaces `.data` (e.g., Image → Detection), use temporal join to get back to the source: + +```python +detection = detections.first() +detection.data # → Detection +detection.ts # → timestamp preserved by derive() + +# Get the source image via temporal join +source_image = images.at(detection.ts).first() +``` + +## Multi-Modal + +**Same embedding space = same stream.** CLIP maps images and text to the same 512-d space: + +```python +unified = session.stream("clip_unified") + +for obs in images.transform(Embed(clip.vision)): + unified.append(obs.data, ts=obs.ts, + tags={"modality": "image"}, embedding=obs.embedding) + +for obs in logs.transform(Embed(clip.text)): + unified.append(obs.data, ts=obs.ts, + tags={"modality": "text"}, embedding=obs.embedding) + +results = unified.search(query_vec, k=20).fetch() +# results[i].tags["modality"] tells you what it is +``` + +**Different embedding spaces = different streams.** Can't mix CLIP and sentence-transformer vectors. + +## Chaining — Embedding as Cheap Pre-Filter + +```python +smoke_query = clip.embed_text("smoke or fire") + +detections = images.transform(Embed(clip)) \ + .search(smoke_query, k=100) \ + .transform(ExpensiveVLMDetector()) +# VLM only runs on 100 most promising frames + +# Smart transformer can use embedding directly +class SmartDetector(Transformer[Image, Detection]): + def __call__(self, upstream: Iterator[EmbeddedObservation[Image]]) -> ...: + for obs in upstream: + if obs.embedding @ self.query > 0.3: + yield obs.derive(data=self.detect(obs.data)) +``` + +## Text Search (FTS) — Separate Concern + +FTS is keyword-based, not embedding-based. 
Complementary, not competing: + +```python +# Keyword search via FTS5 +logs = session.text_stream("logs") +logs.search_text("motor fault").fetch() + +# Semantic search via embeddings +log_idx = logs.transform(Embed(sentence_model)).store("log_emb") +log_idx.search(model.embed("motor problems"), k=10).fetch() +``` diff --git a/dimos/memory2/impl/README.md b/dimos/memory2/impl/README.md index 19efc17bed..bc95405ee2 100644 --- a/dimos/memory2/impl/README.md +++ b/dimos/memory2/impl/README.md @@ -1,6 +1,6 @@ # impl — Backend implementations -Storage backends for memory2. Each backend implements the `Backend` protocol (and optionally `LiveBackend`) to provide observation storage with query support. +Storage backends for memory2. Each backend implements the `Backend` protocol to provide observation storage with query support. All backends support live mode via a pluggable `LiveChannel`. ## Existing backends @@ -14,20 +14,34 @@ Storage backends for memory2. Each backend implements the `Backend` protocol (an ### 1. Implement the Backend protocol ```python -from dimos.memory2.backend import Backend +from dimos.memory2.backend import Backend, BackendConfig, LiveChannel from dimos.memory2.filter import StreamQuery +from dimos.memory2.livechannel.subject import SubjectChannel from dimos.memory2.type import Observation +from dimos.protocol.service.spec import Configurable + +class MyBackend(Configurable[BackendConfig], Generic[T]): + default_config: type[BackendConfig] = BackendConfig + + def __init__(self, name: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._name = name + self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel() -class MyBackend(Generic[T]): @property def name(self) -> str: return self._name + @property + def live_channel(self) -> LiveChannel[T]: + return self._channel + def append(self, obs: Observation[T]) -> Observation[T]: """Assign an id and store. 
Return the stored observation.""" obs.id = self._next_id self._next_id += 1 # ... persist obs ... + self._channel.notify(obs) return obs def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: @@ -41,7 +55,7 @@ class MyBackend(Generic[T]): # query.search_vec — Embedding for vector search # query.search_k — top-k for vector search # query.search_text — substring text search - # query.live_buffer — if set, switch to live mode (see LiveBackend) + # query.live_buffer — if set, switch to live mode ... def count(self, query: StreamQuery) -> int: @@ -51,24 +65,13 @@ class MyBackend(Generic[T]): `Backend` is a `@runtime_checkable` Protocol — no base class needed, just implement the methods. -### 2. Add LiveBackend support (optional) - -If your backend supports live subscriptions (push notifications on new observations): +### 2. Live mode via LiveChannel -```python -from dimos.memory2.backend import LiveBackend - -class MyBackend(Generic[T]): - # ... Backend methods ... - - def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: - """Register a buffer for push notifications. Return a disposable to unsubscribe.""" - ... -``` +Every backend exposes a `live_channel` property. The default `SubjectChannel` handles same-process fan-out. Inject a custom `LiveChannel` (Redis pub/sub, Postgres LISTEN/NOTIFY, etc.) via `BackendConfig.live_channel` for cross-process use. The `iterate()` method should check `query.live_buffer`: - If `None`: return a snapshot iterator -- If set: subscribe before backfill, then yield a live tail that deduplicates by `obs.id` +- If set: subscribe via `self._channel.subscribe(buf)` before backfill, then yield a live tail that deduplicates by `obs.id` See `ListBackend._iterate_live()` for the reference implementation. @@ -78,12 +81,14 @@ See `ListBackend._iterate_live()` for the reference implementation. 
from dimos.memory2.store import Session, Store class MySession(Session): - def _create_backend(self, name: str, payload_type: type | None = None) -> Backend: - return MyBackend(self._conn, name) + def _create_backend( + self, name: str, payload_type: type | None = None, **config: Any + ) -> Backend: + return MyBackend(name, **config) class MyStore(Store): - def session(self) -> MySession: - return MySession(...) + def session(self, **kwargs: Any) -> MySession: + return MySession(**kwargs) ``` ### 4. Add to the grid test diff --git a/dimos/memory2/streaming.md b/dimos/memory2/streaming.md index c1c4a3c36c..fd7f5519a1 100644 --- a/dimos/memory2/streaming.md +++ b/dimos/memory2/streaming.md @@ -13,7 +13,7 @@ These return generators — each observation flows through one at a time. Safe w |---------------------------------------------------------------------------|-------------------------------------------------| | `.after()` `.before()` `.time_range()` `.at()` `.near()` `.filter_tags()` | Filter predicates — skip non-matching obs | | `.filter(pred)` | Same, user-defined predicate | -| `.transform(xf)` / `.map(fn)` | Generator — yields transformed obs one by one | +| `.transform(xf_or_fn)` / `.map(fn)` | Generator — yields transformed obs one by one | | `.search_text(text)` | Generator — substring match filter | | `.limit(k)` | `islice` — stops after k | | `.offset(n)` | `islice` — skips first n | From a0c9c70e3bf5038cf2bac1c2756e29a60530e893 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 21:21:43 +0800 Subject: [PATCH 085/118] memory2: implement full SqliteBackend with vec0 vector search, JSONB tags, and SQL filter pushdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add SqliteVectorStore using sqlite-vec vec0 virtual tables with cosine distance - Implement SqliteBackend: append, iterate (snapshot/live/vector), count with SQL pushdown - Add SQL filter compilation for time, tags, and range filters; 
Python fallback for NearFilter/PredicateFilter - Wire SqliteSession with _streams registry table, codec persistence, shared store auto-wiring - Support eager blob loading via co-located JOIN optimization - Load sqlite-vec extension in SqliteStore with graceful fallback - Remove xfail markers from test_impl.py — all 36 grid tests pass --- dimos/memory2/__init__.py | 3 +- dimos/memory2/impl/sqlite.py | 515 +++++++++++++++++++++++++- dimos/memory2/test_impl.py | 18 +- dimos/memory2/vectorstore/__init__.py | 3 +- dimos/memory2/vectorstore/sqlite.py | 82 ++++ 5 files changed, 591 insertions(+), 30 deletions(-) create mode 100644 dimos/memory2/vectorstore/sqlite.py diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py index 9be60f6d9f..0b358fe438 100644 --- a/dimos/memory2/__init__.py +++ b/dimos/memory2/__init__.py @@ -20,7 +20,7 @@ TimeRangeFilter, ) from dimos.memory2.impl.memory import ListBackend, MemorySession, MemoryStore -from dimos.memory2.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore +from dimos.memory2.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore, SqliteStoreConfig from dimos.memory2.livechannel import SubjectChannel from dimos.memory2.store import Session, SessionConfig, Store, StoreConfig, StreamNamespace from dimos.memory2.stream import Stream @@ -55,6 +55,7 @@ "SqliteBackend", "SqliteSession", "SqliteStore", + "SqliteStoreConfig", "Store", "StoreConfig", "Stream", diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 31bb38855a..1ea9a82fd8 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -14,24 +14,189 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace +from itertools import islice +import json +import re import sqlite3 +import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.memory2.backend import BackendConfig +from dimos.memory2.blobstore.sqlite import 
SqliteBlobStore +from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory2.filter import ( + AfterFilter, + AtFilter, + BeforeFilter, + TagsFilter, + TimeRangeFilter, +) from dimos.memory2.livechannel.subject import SubjectChannel from dimos.memory2.store import Session, Store, StoreConfig +from dimos.memory2.type import _UNLOADED, Observation from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: from collections.abc import Iterator + from reactivex.abc import DisposableBase + from dimos.memory2.backend import Backend, LiveChannel - from dimos.memory2.filter import StreamQuery - from dimos.memory2.type import Observation + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.filter import Filter, StreamQuery T = TypeVar("T") +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +# ── Helpers ────────────────────────────────────────────────────── + + +def _validate_identifier(name: str) -> None: + if not _IDENT_RE.match(name): + raise ValueError(f"Invalid stream name: {name!r}") + + +def _decompose_pose(pose: Any) -> tuple[float, ...] | None: + if pose is None: + return None + if hasattr(pose, "position"): + pos = pose.position + orient = getattr(pose, "orientation", None) + x, y, z = float(pos.x), float(pos.y), float(getattr(pos, "z", 0.0)) + if orient is not None: + return (x, y, z, float(orient.x), float(orient.y), float(orient.z), float(orient.w)) + return (x, y, z, 0.0, 0.0, 0.0, 1.0) + if isinstance(pose, (list, tuple)): + vals = [float(v) for v in pose] + while len(vals) < 7: + vals.append(0.0 if len(vals) < 6 else 1.0) + return tuple(vals[:7]) + return None + + +def _reconstruct_pose( + x: float | None, + y: float | None, + z: float | None, + qx: float | None, + qy: float | None, + qz: float | None, + qw: float | None, +) -> tuple[float, ...] 
| None: + if x is None: + return None + return (x, y or 0.0, z or 0.0, qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0) + + +def _compile_filter(f: Filter) -> tuple[str, list[Any]] | None: + """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters.""" + if isinstance(f, AfterFilter): + return ("ts > ?", [f.t]) + if isinstance(f, BeforeFilter): + return ("ts < ?", [f.t]) + if isinstance(f, TimeRangeFilter): + return ("ts >= ? AND ts <= ?", [f.t1, f.t2]) + if isinstance(f, AtFilter): + return ("ABS(ts - ?) <= ?", [f.t, f.tolerance]) + if isinstance(f, TagsFilter): + clauses = [] + params: list[Any] = [] + for k, v in f.tags.items(): + clauses.append(f"json_extract(tags, '$.{k}') = ?") + params.append(v) + return (" AND ".join(clauses), params) + # NearFilter, PredicateFilter — not pushable + return None + + +def _compile_query( + query: StreamQuery, + table: str, + *, + join_blob: bool = False, +) -> tuple[str, list[Any], list[Filter]]: + """Compile a StreamQuery to SQL. + + Returns (sql, params, python_filters) where python_filters must be + applied as post-filters in Python. 
+ """ + if join_blob: + select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' + else: + select = f'SELECT id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"' + + where_parts: list[str] = [] + params: list[Any] = [] + python_filters: list[Filter] = [] + + for f in query.filters: + compiled = _compile_filter(f) + if compiled is not None: + sql_part, sql_params = compiled + if join_blob: + # Qualify column references for JOIN + sql_part = sql_part.replace("ts ", "meta.ts ").replace("tags", "meta.tags") + where_parts.append(sql_part) + params.extend(sql_params) + else: + python_filters.append(f) + + sql = select + if where_parts: + sql += " WHERE " + " AND ".join(where_parts) + + # ORDER BY + if query.order_field: + col = "meta." + query.order_field if join_blob else query.order_field + direction = "DESC" if query.order_desc else "ASC" + sql += f" ORDER BY {col} {direction}" + else: + col = "meta.id" if join_blob else "id" + sql += f" ORDER BY {col} ASC" + + # Only push LIMIT/OFFSET to SQL when there are no Python post-filters + if not python_filters and not query.search_text: + if query.limit_val is not None: + if query.offset_val: + sql += f" LIMIT {query.limit_val} OFFSET {query.offset_val}" + else: + sql += f" LIMIT {query.limit_val}" + elif query.offset_val: + sql += f" LIMIT -1 OFFSET {query.offset_val}" + + return (sql, params, python_filters) + + +def _compile_count( + query: StreamQuery, + table: str, +) -> tuple[str, list[Any], list[Filter]]: + """Compile a StreamQuery to a COUNT SQL query.""" + where_parts: list[str] = [] + params: list[Any] = [] + python_filters: list[Filter] = [] + + for f in query.filters: + compiled = _compile_filter(f) + if compiled is not None: + sql_part, sql_params = compiled + where_parts.append(sql_part) + 
# ── SqliteBackend ────────────────────────────────────────────────


class SqliteBackend(Configurable[BackendConfig], Generic[T]):
    """SQLite-backed observation storage for a single stream (table).

    Row metadata (id, ts, pose, tags) lives in the stream's own table;
    payload bytes go through the configured blob store and embeddings
    through the configured vector store. Time/tag filters are pushed down
    to SQL where possible, with a Python fallback for the rest.
    """

    # NOTE(review): class-level attributes between the docstring and
    # __init__ (e.g. ``default_config``) are elided from this view by the
    # diff hunk — confirm against the full file.

    def __init__(self, conn: sqlite3.Connection, name: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._conn = conn
        self._name = name
        self._codec: Codec[Any] = self.config.codec  # type: ignore[assignment]
        self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel()
        # Serializes writes on the shared connection; reads are unguarded.
        self._lock = threading.Lock()

    @property
    def name(self) -> str:
        return self._name

    @property
    def live_channel(self) -> LiveChannel[T]:
        return self._channel

    @property
    def _join_blobs(self) -> bool:
        """True when blobs can be fetched inline via a same-connection JOIN."""
        if not self.config.eager_blobs:
            return False
        bs = self.config.blob_store
        return isinstance(bs, SqliteBlobStore) and bs._conn is self._conn

    def _make_loader(self, row_id: int) -> Any:
        """Build a lazy blob loader bound to one row.

        The loader asserts it runs on the thread that created it — lazy
        loads share the backend's sqlite connection, which is not guarded
        for cross-thread reads.
        """
        bs = self.config.blob_store
        assert bs is not None
        name, codec = self._name, self._codec
        owner_tid = threading.get_ident()

        def loader() -> Any:
            assert threading.get_ident() == owner_tid
            raw = bs.get(name, row_id)
            return codec.decode(raw)

        return loader

    def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observation[T]:
        """Convert a metadata row (optionally with an inline blob) to an Observation."""
        if has_blob:
            row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row
        else:
            row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row
            blob_data = None

        pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw)
        tags = json.loads(tags_json) if tags_json else {}

        if has_blob and blob_data is not None:
            # Eager path: decode the JOINed blob now, no loader needed.
            data = self._codec.decode(blob_data)
            return Observation(id=row_id, ts=ts, pose=pose, tags=tags, _data=data)

        # Lazy path: data stays behind the _UNLOADED sentinel until accessed.
        return Observation(
            id=row_id,
            ts=ts,
            pose=pose,
            tags=tags,
            _data=_UNLOADED,
            _loader=self._make_loader(row_id),  # type: ignore[arg-type]
        )

    # ── Write ──────────────────────────────────────────────────

    def append(self, obs: Observation[T]) -> Observation[T]:
        """Persist one observation: metadata row, blob, optional embedding.

        Assigns the new row id to ``obs.id``, notifies live subscribers,
        and returns the same observation.
        """
        encoded = self._codec.encode(obs._data)
        pose = _decompose_pose(obs.pose)
        tags_json = json.dumps(obs.tags) if obs.tags else "{}"

        with self._lock:
            if pose:
                px, py, pz, qx, qy, qz, qw = pose
            else:
                px = py = pz = qx = qy = qz = qw = None

            cur = self._conn.execute(
                f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) '
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))",
                (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json),
            )
            row_id = cur.lastrowid
            assert row_id is not None

            # NOTE(review): if blob/vector put raises here, the metadata row
            # remains in the open transaction and would be flushed by a later
            # commit — consider rolling back on failure.
            bs = self.config.blob_store
            assert bs is not None
            bs.put(self._name, row_id, encoded)

            vs = self.config.vector_store
            if vs is not None:
                emb = getattr(obs, "embedding", None)
                if emb is not None:
                    vs.put(self._name, row_id, emb)

            self._conn.commit()

        obs.id = row_id
        self._channel.notify(obs)
        return obs

    # ── Read ───────────────────────────────────────────────────

    def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]:
        """Iterate observations matching *query* — snapshot, or live tail."""
        if query.search_vec is not None and query.live_buffer is not None:
            raise TypeError("Cannot combine .search() with .live() — search is a batch operation.")
        buf = query.live_buffer
        if buf is not None:
            # Subscribe BEFORE the backfill so no observation can slip
            # between snapshot and tail; the tail dedupes by id.
            sub = self._channel.subscribe(buf)
            return self._iterate_live(query, buf, sub)
        return self._iterate_snapshot(query)

    def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]:
        if query.search_vec is not None and self.config.vector_store is not None:
            yield from self._vector_search(query)
            return

        join = self._join_blobs
        sql, params, python_filters = _compile_query(query, self._name, join_blob=join)

        # Stream rows off the cursor instead of fetchall() so large result
        # sets are never fully materialized in memory.
        cur = self._conn.execute(sql, params)
        it: Iterator[Observation[T]] = (self._row_to_obs(r, has_blob=join) for r in cur)

        # Text search needs the decoded payload — cannot be pushed to SQL.
        if query.search_text is not None:
            needle = query.search_text.lower()
            it = (obs for obs in it if needle in str(obs.data).lower())

        # Filters that _compile_query could not push down.
        if python_filters:
            it = (obs for obs in it if all(f.matches(obs) for f in python_filters))

        # LIMIT/OFFSET stayed out of the SQL when post-filtering happens;
        # apply the window here instead.
        if python_filters or query.search_text:
            if query.offset_val:
                it = islice(it, query.offset_val, None)
            if query.limit_val is not None:
                it = islice(it, query.limit_val)

        yield from it

    def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]:
        """ANN search via the vector store, then hydrate metadata in ranked order."""
        vs = self.config.vector_store
        assert vs is not None and query.search_vec is not None

        hits = vs.search(self._name, query.search_vec, query.search_k or 10)
        if not hits:
            return

        ids = [h[0] for h in hits]

        # Batch-fetch metadata for all hits in one query.
        join = self._join_blobs
        placeholders = ",".join("?" * len(ids))
        if join:
            sql = (
                f"SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, "
                f"meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data "
                f'FROM "{self._name}" AS meta '
                f'JOIN "{self._name}_blob" AS blob ON blob.id = meta.id '
                f"WHERE meta.id IN ({placeholders})"
            )
        else:
            sql = (
                f"SELECT id, ts, pose_x, pose_y, pose_z, "
                f"pose_qx, pose_qy, pose_qz, pose_qw, json(tags) "
                f'FROM "{self._name}" WHERE id IN ({placeholders})'
            )

        rows = self._conn.execute(sql, ids).fetchall()
        obs_by_id: dict[int, Observation[T]] = {}
        for r in rows:
            obs = self._row_to_obs(r, has_blob=join)
            obs_by_id[obs.id] = obs

        # Preserve the vector store's ranking, promoting each hit to an
        # embedded observation carrying its similarity score.
        ranked: list[Observation[T]] = []
        for obs_id, sim in hits:
            obs = obs_by_id.get(obs_id)
            if obs is not None:
                ranked.append(obs.derive(data=obs.data, embedding=query.search_vec, similarity=sim))

        # Apply the remaining query ops with the vector search stripped out.
        rest = replace(query, search_vec=None, search_k=None)
        yield from rest.apply(iter(ranked))

    def _iterate_live(
        self,
        query: StreamQuery,
        buf: BackpressureBuffer[Observation[T]],
        sub: DisposableBase,
    ) -> Iterator[Observation[T]]:
        """Backfill from a snapshot, then tail the live channel.

        The subscription is disposed in a ``finally`` so it is released
        even when the consumer abandons the generator (GeneratorExit) —
        the previous except-only cleanup leaked it in that case.
        """
        from dimos.memory2.buffer import ClosedError

        try:
            # Backfill phase — track the highest id seen for dedup.
            last_id = -1
            for obs in self._iterate_snapshot(query):
                last_id = max(last_id, obs.id)
                yield obs

            # Live tail — skip duplicates, apply filters in Python.
            filters = query.filters
            while True:
                obs = buf.take()
                if obs.id <= last_id:
                    continue
                last_id = obs.id
                if filters and not all(f.matches(obs) for f in filters):
                    continue
                yield obs
        except (ClosedError, StopIteration):
            pass  # buffer closed — normal end of the live stream
        finally:
            sub.dispose()

    def count(self, query: StreamQuery) -> int:
        """Count matching observations, pushing COUNT(*) to SQL when possible."""
        # Explicit `is not None` — search_vec may be an array-like whose
        # truth value is ambiguous; the old truthiness test could raise.
        if query.search_vec is not None or query.search_text is not None:
            return sum(1 for _ in self.iterate(query))

        sql, params, python_filters = _compile_count(query, self._name)
        if python_filters:
            # Some filters only run in Python — fall back to full iteration.
            return sum(1 for _ in self.iterate(query))

        row = self._conn.execute(sql, params).fetchone()
        return int(row[0]) if row else 0
# ── SqliteSession ──────────────────────────────────────────────


class SqliteSession(Session):
    """Session owning a single SQLite connection.

    Maintains a ``_streams`` registry table (name → payload type + codec id)
    so that a reopened database can restore per-stream codecs without the
    caller re-declaring them.
    """

    def __init__(
        self, conn: sqlite3.Connection, *, vec_available: bool = False, **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self._conn = conn
        self._vec_available = vec_available
        self._blob_store: SqliteBlobStore | None = None
        self._vector_store: Any | None = None

        # Create the stream registry on first open (idempotent).
        self._conn.execute(
            "CREATE TABLE IF NOT EXISTS _streams ("
            " name TEXT PRIMARY KEY,"
            " payload_module TEXT NOT NULL,"
            " codec_id TEXT NOT NULL"
            ")"
        )
        self._conn.commit()

    def _ensure_shared_stores(self) -> None:
        """Lazily create the shared blob/vector stores on first stream creation."""
        if self._blob_store is None:
            self._blob_store = SqliteBlobStore(self._conn)
        if self._vector_store is None and self._vec_available:
            from dimos.memory2.vectorstore.sqlite import SqliteVectorStore

            self._vector_store = SqliteVectorStore(self._conn)

    @staticmethod
    def _codec_id(codec: Codec[Any]) -> str:
        """Map a codec instance to its persisted identifier."""
        from dimos.memory2.codecs.jpeg import JpegCodec
        from dimos.memory2.codecs.lcm import LcmCodec

        if isinstance(codec, JpegCodec):
            return "jpeg"
        if isinstance(codec, LcmCodec):
            return "lcm"
        return "pickle"

    @staticmethod
    def _codec_from_id(codec_id: str, payload_module: str) -> Codec[Any]:
        """Rebuild a codec from its persisted id (inverse of ``_codec_id``)."""
        from dimos.memory2.codecs.pickle import PickleCodec

        if codec_id == "jpeg":
            from dimos.memory2.codecs.jpeg import JpegCodec

            return JpegCodec()
        if codec_id == "lcm":
            from dimos.memory2.codecs.lcm import LcmCodec

            # Resolve the payload type from its stored module path.
            # NOTE(review): rsplit(".", 1) assumes a top-level class — a
            # nested qualname ("mod.Outer.Inner") would resolve wrongly.
            parts = payload_module.rsplit(".", 1)
            if len(parts) == 2:
                import importlib

                mod = importlib.import_module(parts[0])
                cls = getattr(mod, parts[1])
                return LcmCodec(cls)
            return PickleCodec()
        return PickleCodec()

    def _create_backend(
        self, name: str, payload_type: type[Any] | None = None, **config: Any
    ) -> Backend[Any]:
        """Open (or create) the backend for stream *name*.

        Raises ValueError on an unsafe name or a payload-type mismatch with
        the registry; raises TypeError when creating a new stream without a
        payload type.
        """
        _validate_identifier(name)
        self._ensure_shared_stores()

        # Look up existing stream in the registry.
        row = self._conn.execute(
            "SELECT payload_module, codec_id FROM _streams WHERE name = ?", (name,)
        ).fetchone()

        if row is not None:
            stored_module, stored_codec_id = row
            if payload_type is not None:
                actual_module = f"{payload_type.__module__}.{payload_type.__qualname__}"
                if actual_module != stored_module:
                    raise ValueError(
                        f"Stream {name!r} was created with type {stored_module}, "
                        f"but opened with {actual_module}"
                    )
            codec = config.get("codec") or self._codec_from_id(stored_codec_id, stored_module)
        else:
            if payload_type is None:
                raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required")
            codec = config.get("codec") or codec_for(payload_type)
            payload_module = f"{payload_type.__module__}.{payload_type.__qualname__}"
            self._conn.execute(
                "INSERT INTO _streams (name, payload_module, codec_id) VALUES (?, ?, ?)",
                (name, payload_module, self._codec_id(codec)),
            )
            self._conn.commit()

        # Create the metadata table (idempotent). `ts` is UNIQUE — two
        # observations may not share a timestamp within one stream.
        self._conn.execute(
            f'CREATE TABLE IF NOT EXISTS "{name}" ('
            " id INTEGER PRIMARY KEY AUTOINCREMENT,"
            " ts REAL NOT NULL UNIQUE,"
            " pose_x REAL, pose_y REAL, pose_z REAL,"
            " pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL,"
            " tags BLOB DEFAULT (jsonb('{}'))"
            ")"
        )
        self._conn.commit()

        # Merge the session-shared stores as defaults.
        if "blob_store" not in config or config["blob_store"] is None:
            config["blob_store"] = self._blob_store
        if "vector_store" not in config or config["vector_store"] is None:
            config["vector_store"] = self._vector_store
        config["codec"] = codec
        return SqliteBackend(self._conn, name, **config)

    def list_streams(self) -> list[str]:
        """All stream names — persisted in the registry or opened in-memory."""
        db_names = {row[0] for row in self._conn.execute("SELECT name FROM _streams").fetchall()}
        return sorted(db_names | set(self._streams.keys()))

    def delete_stream(self, name: str) -> None:
        """Drop a stream's tables and registry entry (no-op if absent)."""
        # Validate before splicing into DROP TABLE — mirrors _create_backend
        # and blocks SQL injection via a hostile stream name.
        _validate_identifier(name)
        self._streams.pop(name, None)
        self._conn.execute(f'DROP TABLE IF EXISTS "{name}"')
        self._conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"')
        self._conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"')
        self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,))
        self._conn.commit()

    def stop(self) -> None:
        super().stop()
        self._conn.close()


# ── SqliteStore ────────────────────────────────────────────────

# NOTE(review): SqliteStoreConfig's fields and the SqliteStore class
# header / __init__ are elided from this view by the diff hunk; only
# session() is reproduced below at its original method indentation.

    def session(self, **kwargs: Any) -> SqliteSession:
        """Open a session on a fresh connection, loading sqlite-vec if present."""
        conn = sqlite3.connect(self.config.path, check_same_thread=False)
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA synchronous=NORMAL")

        vec_available = False
        try:
            import sqlite_vec

            conn.enable_load_extension(True)
            try:
                sqlite_vec.load(conn)
            finally:
                # Always revoke extension loading — the old code left it
                # enabled when load() raised.
                conn.enable_load_extension(False)
            vec_available = True
        except Exception:
            # Graceful fallback: vector search disabled when sqlite-vec is
            # missing or fails to load. (The old `(ImportError, Exception)`
            # tuple was redundant — Exception already covers ImportError.)
            pass

        return SqliteSession(conn, vec_available=vec_available, **kwargs)
class SqliteVectorStore(VectorStore):
    """Vector store backed by sqlite-vec's vec0 virtual tables.

    Creates one virtual table per stream: ``"{stream}_vec"``.
    Dimensionality is determined lazily on the first ``put()``.

    Does NOT own the connection — lifecycle managed externally.

    Stream names are spliced into SQL identifiers; callers are expected to
    have validated them (the session layer does via ``_validate_identifier``).
    """

    def __init__(self, conn: sqlite3.Connection) -> None:
        self._conn = conn
        # stream -> dimensionality, for tables created BY THIS INSTANCE only.
        self._tables: dict[str, int] = {}

    def _table_exists(self, stream: str) -> bool:
        """Check the database itself for the stream's vec table.

        The in-memory ``_tables`` cache only knows tables created in this
        process — without this check, a reopened database's persisted
        vectors were invisible to search()/delete().
        """
        row = self._conn.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
            (f"{stream}_vec",),
        ).fetchone()
        return row is not None

    def _ensure_table(self, stream: str, dim: int) -> None:
        if stream in self._tables:
            return
        self._conn.execute(
            f'CREATE VIRTUAL TABLE IF NOT EXISTS "{stream}_vec" '
            f"USING vec0(embedding float[{dim}] distance_metric=cosine)"
        )
        self._tables[stream] = dim

    # ── Resource lifecycle ──────────────────────────────────────

    def start(self) -> None:
        pass

    def stop(self) -> None:
        pass

    # ── VectorStore interface ──────────────────────────────────

    def put(self, stream: str, key: int, embedding: Embedding) -> None:
        vec = embedding.to_numpy().tolist()
        self._ensure_table(stream, len(vec))
        self._conn.execute(
            f'INSERT OR REPLACE INTO "{stream}_vec" (rowid, embedding) VALUES (?, ?)',
            (key, json.dumps(vec)),
        )

    def search(self, stream: str, query: Embedding, k: int) -> list[tuple[int, float]]:
        # Consult the DB, not just the per-process cache: vectors written by
        # a previous session must remain searchable after reconnecting.
        if stream not in self._tables and not self._table_exists(stream):
            return []
        vec = query.to_numpy().tolist()
        rows = self._conn.execute(
            f'SELECT rowid, distance FROM "{stream}_vec" WHERE embedding MATCH ? AND k = ?',
            (json.dumps(vec), k),
        ).fetchall()
        # vec0 cosine distance = 1 - cosine_similarity; clamp at 0 so the
        # reported similarity is never negative.
        return [(int(row[0]), max(0.0, 1.0 - row[1])) for row in rows]

    def delete(self, stream: str, key: int) -> None:
        # Same DB-level existence check as search() — see above.
        if stream not in self._tables and not self._table_exists(stream):
            return
        self._conn.execute(f'DELETE FROM "{stream}_vec" WHERE rowid = ?', (key,))
# ── Lazy / eager blob loading tests ──────────────────────────────


class TestBlobLoading:
    """Verify lazy and eager blob loading paths."""

    def test_sqlite_lazy_by_default(self) -> None:
        """Default sqlite iteration uses lazy loaders — data is _UNLOADED until accessed."""
        import tempfile

        from dimos.memory2.impl.sqlite import SqliteStore
        from dimos.memory2.type import _Unloaded

        with tempfile.NamedTemporaryFile(suffix=".db") as f:
            store = SqliteStore(path=f.name)
            with store.session() as session:
                s = session.stream("lazy_test", str)
                s.append("hello", ts=1.0)
                s.append("world", ts=2.0)

                for obs in s:
                    # Before accessing .data, _data should be the unloaded sentinel
                    assert isinstance(obs._data, _Unloaded)
                    assert obs._loader is not None
                    # Accessing .data triggers the loader
                    val = obs.data
                    assert isinstance(val, str)
                    # After loading, _loader is cleared
                    assert obs._loader is None

    def test_sqlite_eager_loads_inline(self) -> None:
        """With eager_blobs=True, data is loaded via JOIN — no lazy loader."""
        import tempfile

        from dimos.memory2.impl.sqlite import SqliteStore
        from dimos.memory2.type import _Unloaded

        with tempfile.NamedTemporaryFile(suffix=".db") as f:
            store = SqliteStore(path=f.name)
            with store.session() as session:
                s = session.stream("eager_test", str, eager_blobs=True)
                s.append("hello", ts=1.0)
                s.append("world", ts=2.0)

                for obs in s:
                    # Data should already be loaded — no lazy sentinel
                    assert not isinstance(obs._data, _Unloaded)
                    assert obs._loader is None
                    assert isinstance(obs.data, str)

    def test_sqlite_lazy_and_eager_same_values(self) -> None:
        """Both paths must return identical data."""
        import tempfile

        from dimos.memory2.impl.sqlite import SqliteStore

        with tempfile.NamedTemporaryFile(suffix=".db") as f:
            store = SqliteStore(path=f.name)
            with store.session() as session:
                lazy_s = session.stream("vals", str)
                lazy_s.append("alpha", ts=1.0, tags={"k": "v"})
                lazy_s.append("beta", ts=2.0, tags={"k": "w"})

                # Lazy read
                lazy_results = lazy_s.fetch()

                # Eager read — new stream handle with eager_blobs on same backend
                # NOTE(review): assumes session.stream() honors eager_blobs when
                # re-opening an already-registered stream — confirm.
                eager_s = session.stream("vals", str, eager_blobs=True)
                eager_results = eager_s.fetch()

                # Payload, tags, and timestamps must agree pairwise.
                assert [o.data for o in lazy_results] == [o.data for o in eager_results]
                assert [o.tags for o in lazy_results] == [o.tags for o in eager_results]
                assert [o.ts for o in lazy_results] == [o.ts for o in eager_results]

    def test_memory_lazy_with_blobstore(self) -> None:
        """MemoryStore with a BlobStore uses lazy loaders."""
        from dimos.memory2.blobstore.file import FileBlobStore
        from dimos.memory2.impl.memory import MemoryStore
        from dimos.memory2.type import _Unloaded

        store = MemoryStore()
        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            bs = FileBlobStore(root=tmpdir)
            bs.start()
            with store.session(blob_store=bs) as session:
                s = session.stream("mem_lazy", str)
                s.append("data1", ts=1.0)

                obs = s.first()
                # ListBackend replaces _data with _UNLOADED when blob_store is set
                assert isinstance(obs._data, _Unloaded)
                assert obs.data == "data1"
            bs.stop()


# ── Spy stores ───────────────────────────────────────────────────

# NOTE(review): the Spy classes below use `Any` in annotations, but an
# earlier hunk removed `from typing import Any` from this file — confirm it
# is re-imported (or that `from __future__ import annotations` is in effect).


class SpyBlobStore:
    """BlobStore that records all calls for verification."""

    def __init__(self) -> None:
        # Call logs consumed by the delegation assertions.
        self.puts: list[tuple[str, int, bytes]] = []
        self.gets: list[tuple[str, int]] = []
        self.store: dict[tuple[str, int], bytes] = {}

    def start(self) -> None:
        pass

    def stop(self) -> None:
        pass

    def put(self, stream: str, key: int, data: bytes) -> None:
        self.puts.append((stream, key, data))
        self.store[(stream, key)] = data

    def get(self, stream: str, key: int) -> bytes:
        # Raises KeyError on a miss, like a real store's contract would.
        self.gets.append((stream, key))
        return self.store[(stream, key)]

    def delete(self, stream: str, key: int) -> None:
        self.store.pop((stream, key), None)


class SpyVectorStore:
    """VectorStore that records all calls for verification."""

    def __init__(self) -> None:
        self.puts: list[tuple[str, int]] = []
        self.searches: list[tuple[str, int]] = []
        self.vectors: dict[str, dict[int, Any]] = {}

    def start(self) -> None:
        pass

    def stop(self) -> None:
        pass

    def put(self, stream: str, key: int, embedding: Any) -> None:
        self.puts.append((stream, key))
        self.vectors.setdefault(stream, {})[key] = embedding

    def search(self, stream: str, query: Any, k: int) -> list[tuple[int, float]]:
        # Brute-force dot-product ranking, descending — good enough for tests.
        self.searches.append((stream, k))
        vectors = self.vectors.get(stream, {})
        if not vectors:
            return []
        scored = [(key, float(emb @ query)) for key, emb in vectors.items()]
        scored.sort(key=lambda x: x[1], reverse=True)
        return scored[:k]

    def delete(self, stream: str, key: int) -> None:
        self.vectors.get(stream, {}).pop(key, None)


# ── Spy grid: session factories that inject spy stores ───────────
+ + +@dataclass +class SpyCase: + name: str + session_factory: Callable[ + [], Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None] + ] + + +@contextmanager +def memory_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None]: + from dimos.memory2.impl.memory import MemoryStore + + blob_spy = SpyBlobStore() + vec_spy = SpyVectorStore() + store = MemoryStore() + with store.session(blob_store=blob_spy, vector_store=vec_spy) as session: + yield session, blob_spy, vec_spy + + +@contextmanager +def sqlite_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None]: + import tempfile + + from dimos.memory2.impl.sqlite import SqliteStore + + blob_spy = SpyBlobStore() + vec_spy = SpyVectorStore() + with tempfile.NamedTemporaryFile(suffix=".db") as f: + store = SqliteStore(path=f.name) + with store.session(blob_store=blob_spy, vector_store=vec_spy) as session: + yield session, blob_spy, vec_spy + + +spy_cases = [ + SpyCase(name="memory", session_factory=memory_spy_session), + SpyCase(name="sqlite", session_factory=sqlite_spy_session), +] + + +@pytest.mark.parametrize("case", spy_cases, ids=lambda c: c.name) +class TestStoreDelegation: + """Verify all backends delegate to pluggable BlobStore and VectorStore.""" + + def test_append_calls_blob_put(self, case: SpyCase) -> None: + with case.session_factory() as (session, blob_spy, _vec_spy): + s = session.stream("blobs", str) + s.append("first", ts=1.0) + s.append("second", ts=2.0) + + assert len(blob_spy.puts) == 2 + assert all(stream == "blobs" for stream, _k, _d in blob_spy.puts) + + def test_iterate_calls_blob_get(self, case: SpyCase) -> None: + with case.session_factory() as (session, blob_spy, _vec_spy): + s = session.stream("blobs", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) + + blob_spy.gets.clear() + for obs in s: + _ = obs.data + assert len(blob_spy.gets) == 2 + + def test_append_embedding_calls_vector_put(self, case: SpyCase) -> None: + 
import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) + + with case.session_factory() as (session, _blob_spy, vec_spy): + s = session.stream("vecs", str) + s.append("a", ts=1.0, embedding=_emb([1, 0, 0])) + s.append("b", ts=2.0, embedding=_emb([0, 1, 0])) + s.append("c", ts=3.0) # no embedding + + assert len(vec_spy.puts) == 2 + + def test_search_calls_vector_search(self, case: SpyCase) -> None: + import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) + + with case.session_factory() as (session, _blob_spy, vec_spy): + s = session.stream("vecs", str) + s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) + s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(vec_spy.searches) == 1 + assert results[0].data == "north" From bcb98bd447d0d9a8c5ecf62e1c9b8a50853c4e81 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 21:59:01 +0800 Subject: [PATCH 088/118] memory2: add R*Tree spatial index for NearFilter SQL pushdown, add e2e tests R*Tree virtual tables enable O(log n) pose-based proximity queries instead of full-table Python scans. E2E tests verify import pipeline and read-only queries against real robot sensor data (video + lidar). 
--- data/.lfs/go2_bigoffice_v2.db.tar.gz | 3 + dimos/memory2/impl/sqlite.py | 92 ++++++++++++++--- dimos/memory2/test_e2e_import.py | 148 +++++++++++++++++++++++++++ dimos/memory2/test_e2e_query.py | 148 +++++++++++++++++++++++++++ 4 files changed, 374 insertions(+), 17 deletions(-) create mode 100644 data/.lfs/go2_bigoffice_v2.db.tar.gz create mode 100644 dimos/memory2/test_e2e_import.py create mode 100644 dimos/memory2/test_e2e_query.py diff --git a/data/.lfs/go2_bigoffice_v2.db.tar.gz b/data/.lfs/go2_bigoffice_v2.db.tar.gz new file mode 100644 index 0000000000..f091edf861 --- /dev/null +++ b/data/.lfs/go2_bigoffice_v2.db.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f7a94b136be71044b8c4f645eaa9fbc672df5d241ae9ceb2f5de5f85ffb3668 +size 254406884 diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 0476024cf2..fdad3318f6 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -29,8 +29,10 @@ AfterFilter, AtFilter, BeforeFilter, + NearFilter, TagsFilter, TimeRangeFilter, + _xyz, ) from dimos.memory2.livechannel.subject import SubjectChannel from dimos.memory2.store import Session, Store, StoreConfig @@ -91,24 +93,66 @@ def _reconstruct_pose( return (x, y or 0.0, z or 0.0, qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0) -def _compile_filter(f: Filter) -> tuple[str, list[Any]] | None: - """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters.""" +def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list[Any]] | None: + """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters. + + ``stream`` is the raw stream name (for R*Tree table references). + ``prefix`` is a column qualifier (e.g. ``"meta."`` for JOIN queries). 
+ """ if isinstance(f, AfterFilter): - return ("ts > ?", [f.t]) + return (f"{prefix}ts > ?", [f.t]) if isinstance(f, BeforeFilter): - return ("ts < ?", [f.t]) + return (f"{prefix}ts < ?", [f.t]) if isinstance(f, TimeRangeFilter): - return ("ts >= ? AND ts <= ?", [f.t1, f.t2]) + return (f"{prefix}ts >= ? AND {prefix}ts <= ?", [f.t1, f.t2]) if isinstance(f, AtFilter): - return ("ABS(ts - ?) <= ?", [f.t, f.tolerance]) + return (f"ABS({prefix}ts - ?) <= ?", [f.t, f.tolerance]) if isinstance(f, TagsFilter): clauses = [] params: list[Any] = [] for k, v in f.tags.items(): - clauses.append(f"json_extract(tags, '$.{k}') = ?") + clauses.append(f"json_extract({prefix}tags, '$.{k}') = ?") params.append(v) return (" AND ".join(clauses), params) - # NearFilter, PredicateFilter — not pushable + if isinstance(f, NearFilter): + pose = f.pose + if pose is None: + return None + if hasattr(pose, "position"): + pose = pose.position + cx, cy, cz = _xyz(pose) + r = f.radius + # R*Tree bounding-box pre-filter + exact squared-distance check + rtree_sql = ( + f'{prefix}id IN (SELECT id FROM "{stream}_rtree" ' + f"WHERE x_min >= ? AND x_max <= ? " + f"AND y_min >= ? AND y_max <= ? " + f"AND z_min >= ? AND z_max <= ?)" + ) + dist_sql = ( + f"(({prefix}pose_x - ?) * ({prefix}pose_x - ?) + " + f"({prefix}pose_y - ?) * ({prefix}pose_y - ?) + " + f"({prefix}pose_z - ?) * ({prefix}pose_z - ?) <= ?)" + ) + return ( + f"{rtree_sql} AND {dist_sql}", + [ + cx - r, + cx + r, + cy - r, + cy + r, + cz - r, + cz + r, # R*Tree bbox + cx, + cx, + cy, + cy, + cz, + cz, + r * r, # squared distance + ], + ) + # PredicateFilter — not pushable return None @@ -123,6 +167,7 @@ def _compile_query( Returns (sql, params, python_filters) where python_filters must be applied as post-filters in Python. """ + prefix = "meta." 
if join_blob else "" if join_blob: select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' else: @@ -133,12 +178,9 @@ def _compile_query( python_filters: list[Filter] = [] for f in query.filters: - compiled = _compile_filter(f) + compiled = _compile_filter(f, table, prefix) if compiled is not None: sql_part, sql_params = compiled - if join_blob: - # Qualify column references for JOIN - sql_part = sql_part.replace("ts ", "meta.ts ").replace("tags", "meta.tags") where_parts.append(sql_part) params.extend(sql_params) else: @@ -150,12 +192,10 @@ def _compile_query( # ORDER BY if query.order_field: - col = "meta." + query.order_field if join_blob else query.order_field direction = "DESC" if query.order_desc else "ASC" - sql += f" ORDER BY {col} {direction}" + sql += f" ORDER BY {prefix}{query.order_field} {direction}" else: - col = "meta.id" if join_blob else "id" - sql += f" ORDER BY {col} ASC" + sql += f" ORDER BY {prefix}id ASC" # Only push LIMIT/OFFSET to SQL when there are no Python post-filters if not python_filters and not query.search_text: @@ -180,7 +220,7 @@ def _compile_count( python_filters: list[Filter] = [] for f in query.filters: - compiled = _compile_filter(f) + compiled = _compile_filter(f, table) if compiled is not None: sql_part, sql_params = compiled where_parts.append(sql_part) @@ -287,6 +327,14 @@ def append(self, obs: Observation[T]) -> Observation[T]: assert bs is not None bs.put(self._name, row_id, encoded) + # R*Tree spatial index + if pose: + self._conn.execute( + f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) ' + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (row_id, px, px, py, py, pz, pz), + ) + vs = self.config.vector_store if vs is not None: emb = getattr(obs, "embedding", None) @@ -535,6 +583,15 @@ def _create_backend( " tags BLOB DEFAULT 
(jsonb('{}'))" ")" ) + # R*Tree spatial index for pose queries + self._conn.execute( + f'CREATE VIRTUAL TABLE IF NOT EXISTS "{name}_rtree" USING rtree(' + " id," + " x_min, x_max," + " y_min, y_max," + " z_min, z_max" + ")" + ) self._conn.commit() # Merge shared stores as defaults @@ -555,6 +612,7 @@ def delete_stream(self, name: str) -> None: self._conn.execute(f'DROP TABLE IF EXISTS "{name}"') self._conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') self._conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') + self._conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) self._conn.commit() diff --git a/dimos/memory2/test_e2e_import.py b/dimos/memory2/test_e2e_import.py new file mode 100644 index 0000000000..b134c306d0 --- /dev/null +++ b/dimos/memory2/test_e2e_import.py @@ -0,0 +1,148 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""E2E test: import legacy pickle replays into memory2 SqliteStore.""" + +from __future__ import annotations + +import bisect +from typing import TYPE_CHECKING, Any + +import pytest + +from dimos.memory2.impl.sqlite import SqliteStore +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.data import get_data_dir +from dimos.utils.testing import TimedSensorReplay + +if TYPE_CHECKING: + from collections.abc import Generator + + from dimos.memory2.impl.sqlite import SqliteSession + +DB_PATH = get_data_dir("go2_bigoffice_v2.db") + + +class PoseIndex: + """Preloaded odom data with O(log n) closest-timestamp lookup.""" + + def __init__(self, replay: TimedSensorReplay) -> None: # type: ignore[type-arg] + self._timestamps: list[float] = [] + self._data: list[Any] = [] + for ts, data in replay.iterate_ts(): + self._timestamps.append(ts) + self._data.append(data) + + def find_closest(self, ts: float) -> Any | None: + if not self._timestamps: + return None + idx = bisect.bisect_left(self._timestamps, ts) + # Compare the two candidates around the insertion point + if idx == 0: + return self._data[0] + if idx >= len(self._timestamps): + return self._data[-1] + if ts - self._timestamps[idx - 1] <= self._timestamps[idx] - ts: + return self._data[idx - 1] + return self._data[idx] + + +@pytest.fixture(scope="module") +def store() -> Generator[SqliteStore, None, None]: + s = SqliteStore(path=str(DB_PATH)) + yield s + + +@pytest.fixture(scope="module") +def session(store: SqliteStore) -> Generator[SqliteSession, None, None]: + with store.session() as session: + yield session + + +@pytest.fixture(scope="module") +def video_replay() -> TimedSensorReplay: # type: ignore[type-arg] + return TimedSensorReplay("unitree_go2_bigoffice/video") + + +@pytest.fixture(scope="module") +def odom_index() -> PoseIndex: + return PoseIndex(TimedSensorReplay("unitree_go2_bigoffice/odom")) + + +@pytest.fixture(scope="module") +def 
lidar_replay() -> TimedSensorReplay: # type: ignore[type-arg] + return TimedSensorReplay("unitree_go2_bigoffice/lidar") + + +@pytest.mark.tool +class TestImportReplay: + """Import legacy pickle replay data into a memory2 SqliteStore.""" + + def test_import_video( + self, + session: SqliteSession, + video_replay: TimedSensorReplay, # type: ignore[type-arg] + odom_index: PoseIndex, + ) -> None: + video = session.stream("color_image", Image) + + count = 0 + for ts, frame in video_replay.iterate_ts(): + pose = odom_index.find_closest(ts) + print(frame) + video.append(frame, ts=ts, pose=pose) + count += 1 + + assert count > 0 + assert video.count() == count + print(f"Imported {count} video frames") + + def test_import_lidar( + self, + session: SqliteSession, + lidar_replay: TimedSensorReplay, # type: ignore[type-arg] + odom_index: PoseIndex, + ) -> None: + lidar = session.stream("lidar", PointCloud2) + + count = 0 + for ts, frame in lidar_replay.iterate_ts(): + pose = odom_index.find_closest(ts) + print(frame) + lidar.append(frame, ts=ts, pose=pose) + count += 1 + + assert count > 0 + assert lidar.count() == count + print(f"Imported {count} lidar frames") + + def test_query_imported_data(self, session: SqliteSession) -> None: + video = session.stream("color_image", Image) + lidar = session.stream("lidar", PointCloud2) + + assert video.exists() + assert lidar.exists() + + first_frame = video.first() + last_frame = video.last() + assert first_frame.ts < last_frame.ts + + mid_ts = (first_frame.ts + last_frame.ts) / 2 + subset = video.time_range(first_frame.ts, mid_ts).fetch() + assert 0 < len(subset) < video.count() + + streams = session.list_streams() + assert "color_image" in streams + assert "lidar" in streams diff --git a/dimos/memory2/test_e2e_query.py b/dimos/memory2/test_e2e_query.py new file mode 100644 index 0000000000..2415bf0a47 --- /dev/null +++ b/dimos/memory2/test_e2e_query.py @@ -0,0 +1,148 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""E2E query tests against pre-built go2_bigoffice_v2.db. + +Read-only — no writes, just verifies query paths against real robot data. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from dimos.memory2.impl.sqlite import SqliteStore +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.data import get_data + +if TYPE_CHECKING: + from collections.abc import Generator + + from dimos.memory2.impl.sqlite import SqliteSession + + +@pytest.fixture(scope="module") +def session() -> Generator[SqliteSession, None, None]: + db_path = get_data("go2_bigoffice_v2.db") + store = SqliteStore(path=str(db_path)) + with store.session() as s: + yield s + + +@pytest.mark.tool +class TestE2EQuery: + """Query operations against real robot replay data.""" + + def test_list_streams(self, session: SqliteSession) -> None: + streams = session.list_streams() + assert "color_image" in streams + assert "lidar" in streams + + def test_video_count(self, session: SqliteSession) -> None: + video = session.stream("color_image", Image) + assert video.count() > 1000 + + def test_lidar_count(self, session: SqliteSession) -> None: + lidar = session.stream("lidar", PointCloud2) + assert lidar.count() > 1000 + + def test_first_last_timestamps(self, session: SqliteSession) -> None: + video = session.stream("color_image", Image) + first = 
video.first() + last = video.last() + assert first.ts < last.ts + duration = last.ts - first.ts + assert duration > 10.0 # at least 10s of data + + def test_time_range_filter(self, session: SqliteSession) -> None: + video = session.stream("color_image", Image) + first = video.first() + + # Grab first 5 seconds + window = video.time_range(first.ts, first.ts + 5.0).fetch() + assert len(window) > 0 + assert len(window) < video.count() + assert all(first.ts <= obs.ts <= first.ts + 5.0 for obs in window) + + def test_limit_offset_pagination(self, session: SqliteSession) -> None: + video = session.stream("color_image", Image) + page1 = video.limit(10).fetch() + page2 = video.offset(10).limit(10).fetch() + + assert len(page1) == 10 + assert len(page2) == 10 + assert page1[-1].ts < page2[0].ts # no overlap + + def test_order_by_desc(self, session: SqliteSession) -> None: + video = session.stream("color_image", Image) + last_10 = video.order_by("ts", desc=True).limit(10).fetch() + + assert len(last_10) == 10 + assert all(last_10[i].ts >= last_10[i + 1].ts for i in range(9)) + + def test_lazy_data_loads_correctly(self, session: SqliteSession) -> None: + """Verify lazy blob loading returns valid Image data.""" + from dimos.memory2.type import _Unloaded + + video = session.stream("color_image", Image) + obs = next(iter(video.limit(1))) + + # Should start lazy + assert isinstance(obs._data, _Unloaded) + + # Trigger load + frame = obs.data + assert isinstance(frame, Image) + assert frame.width > 0 + assert frame.height > 0 + + def test_iterate_window_decodes_all(self, session: SqliteSession) -> None: + """Iterate a time window and verify every frame decodes.""" + video = session.stream("color_image", Image) + first_ts = video.first().ts + + window = video.time_range(first_ts, first_ts + 2.0) + count = 0 + for obs in window: + frame = obs.data + assert isinstance(frame, Image) + count += 1 + assert count > 0 + + def test_lidar_data_loads(self, session: SqliteSession) -> None: + 
"""Verify lidar blobs decode to PointCloud2.""" + lidar = session.stream("lidar", PointCloud2) + frame = lidar.first().data + assert isinstance(frame, PointCloud2) + + def test_poses_present(self, session: SqliteSession) -> None: + """Verify poses were stored during import.""" + video = session.stream("color_image", Image) + obs = video.first() + assert obs.pose is not None + + def test_cross_stream_time_alignment(self, session: SqliteSession) -> None: + """Video and lidar should overlap in time.""" + video = session.stream("color_image", Image) + lidar = session.stream("lidar", PointCloud2) + + v_first, v_last = video.first().ts, video.last().ts + l_first, l_last = lidar.first().ts, lidar.last().ts + + # Overlap: max of starts < min of ends + overlap_start = max(v_first, l_first) + overlap_end = min(v_last, l_last) + assert overlap_start < overlap_end, "Video and lidar should overlap in time" From 3c01a6e316e4f2053a74ae0e2e604ee33d65b90c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 11 Mar 2026 23:01:57 +0800 Subject: [PATCH 089/118] auto index tags --- dimos/memory2/impl/sqlite.py | 13 +++++++++++++ dimos/memory2/test_e2e_query.py | 2 ++ 2 files changed, 15 insertions(+) diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index fdad3318f6..b4ad7bd520 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -250,6 +250,7 @@ def __init__(self, conn: sqlite3.Connection, name: str, **kwargs: Any) -> None: self._codec: Codec[Any] = self.config.codec # type: ignore[assignment] self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel() self._lock = threading.Lock() + self._tag_indexes: set[str] = set() @property def name(self) -> str: @@ -304,12 +305,24 @@ def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observ # ── Write ──────────────────────────────────────────────────── + def _ensure_tag_indexes(self, tags: dict[str, Any]) -> None: + """Auto-create expression indexes for 
any new tag keys.""" + for key in tags: + if key not in self._tag_indexes and _IDENT_RE.match(key): + self._conn.execute( + f'CREATE INDEX IF NOT EXISTS "{self._name}_tag_{key}" ' + f"ON \"{self._name}\"(json_extract(tags, '$.{key}'))" + ) + self._tag_indexes.add(key) + def append(self, obs: Observation[T]) -> Observation[T]: encoded = self._codec.encode(obs._data) pose = _decompose_pose(obs.pose) tags_json = json.dumps(obs.tags) if obs.tags else "{}" with self._lock: + if obs.tags: + self._ensure_tag_indexes(obs.tags) if pose: px, py, pz, qx, qy, qz, qw = pose else: diff --git a/dimos/memory2/test_e2e_query.py b/dimos/memory2/test_e2e_query.py index 2415bf0a47..ac26e865ff 100644 --- a/dimos/memory2/test_e2e_query.py +++ b/dimos/memory2/test_e2e_query.py @@ -48,6 +48,8 @@ class TestE2EQuery: def test_list_streams(self, session: SqliteSession) -> None: streams = session.list_streams() + print(streams) + assert "color_image" in streams assert "lidar" in streams From f3682977b824f7eef29f9ba6c2a921731efaa67e Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 18:39:34 +0800 Subject: [PATCH 090/118] memory/stream str, and observables --- dimos/memory2/filter.py | 30 +++++++++++++-- dimos/memory2/stream.py | 85 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 107 insertions(+), 8 deletions(-) diff --git a/dimos/memory2/filter.py b/dimos/memory2/filter.py index 8b0d9ec9bd..32330192ba 100644 --- a/dimos/memory2/filter.py +++ b/dimos/memory2/filter.py @@ -14,9 +14,10 @@ from __future__ import annotations -from dataclasses import dataclass, field +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, fields from itertools import islice -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -29,12 +30,17 @@ # ── Filter protocol ───────────────────────────────────────────────── -@runtime_checkable -class 
Filter(Protocol): +@dataclass(frozen=True) +class Filter(ABC): """Any object with a .matches(obs) -> bool method can be a filter.""" + @abstractmethod def matches(self, obs: Observation[Any]) -> bool: ... + def __str__(self) -> str: + args = ", ".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self)) + return f"{self.__class__.__name__}({args})" + # ── Concrete filters ──────────────────────────────────────────────── @@ -139,6 +145,22 @@ class StreamQuery: # Full-text search (substring / FTS5) search_text: str | None = None + def __str__(self) -> str: + parts: list[str] = [str(f) for f in self.filters] + if self.search_text is not None: + parts.append(f"search({self.search_text!r})") + if self.search_vec is not None: + k = f", k={self.search_k}" if self.search_k is not None else "" + parts.append(f"vector_search({k.lstrip(', ')})" if k else "vector_search()") + if self.order_field: + direction = " DESC" if self.order_desc else "" + parts.append(f"order_by({self.order_field}{direction})") + if self.offset_val: + parts.append(f"offset({self.offset_val})") + if self.limit_val is not None: + parts.append(f"limit({self.limit_val})") + return " | ".join(parts) + def apply( self, it: Iterator[Observation[Any]], *, live: bool = False ) -> Iterator[Observation[Any]]: diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index a1cabc7f90..55ade446bf 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -36,6 +36,9 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator + import reactivex + from reactivex.abc import DisposableBase, ObserverBase + from dimos.models.embedding.base import Embedding T = TypeVar("T") @@ -61,6 +64,28 @@ def __init__( self._xf = xf self._query = query + def __str__(self) -> str: + # Walk the source chain to collect (xf, query) pairs + chain: list[tuple[Any, StreamQuery]] = [] + current: Any = self + while isinstance(current, Stream): + chain.append((current._xf, current._query)) + current = 
current._source + chain.reverse() # innermost first + + # current is the Backend + name = getattr(current, "name", "?") + result = f'Stream("{name}")' + + for xf, query in chain: + if xf is not None: + result += f" -> {xf}" + q_str = str(query) + if q_str: + result += f" | {q_str}" + + return result + def is_live(self) -> bool: """True if this stream (or any ancestor in the chain) is in live mode.""" if self._query.live_buffer is not None: @@ -121,7 +146,7 @@ def at(self, t: float, tolerance: float = 1.0) -> Stream[T]: def near(self, pose: Any, radius: float) -> Stream[T]: return self._with_filter(NearFilter(pose, radius)) - def filter_tags(self, **tags: Any) -> Stream[T]: + def tags(self, **tags: Any) -> Stream[T]: return self._with_filter(TagsFilter(tags)) def order_by(self, field: str, desc: bool = False) -> Stream[T]: @@ -217,10 +242,10 @@ def save(self, target: Stream[T]) -> Stream[T]: def fetch(self) -> list[Observation[T]]: """Materialize all observations into a list.""" - if self.is_live() and self._query.limit_val is None: + if self.is_live(): raise TypeError( - ".fetch() on a live stream without .limit() would collect forever. " - "Use .limit(n).fetch(), .drain(), or .save(target) instead." + ".fetch() on a live stream would block forever. " + "Use .drain() or .save(target) instead." 
) return list(self) @@ -248,6 +273,28 @@ def exists(self) -> bool: """Check if any matching observation exists.""" return next(iter(self.limit(1)), None) is not None + def get_time_range(self) -> tuple[float, float]: + """Return (min_ts, max_ts) for matching observations.""" + first = self.first() + last = self.last() + return (first.ts, last.ts) + + def summary(self) -> str: + """Return a short human-readable summary: count, time range, duration.""" + from datetime import datetime, timezone + + n = self.count() + if n == 0: + return f"{self}: empty" + + (t0, t1) = self.get_time_range() + + fmt = "%Y-%m-%d %H:%M:%S" + dt0 = datetime.fromtimestamp(t0, tz=timezone.utc).strftime(fmt) + dt1 = datetime.fromtimestamp(t1, tz=timezone.utc).strftime(fmt) + dur = t1 - t0 + return f"{self}: {n} items, {dt0} — {dt1} ({dur:.1f}s)" + def drain(self) -> int: """Consume all observations, discarding results. Returns count consumed. @@ -259,6 +306,36 @@ def drain(self) -> int: n += 1 return n + # ── Reactive ───────────────────────────────────────────────────── + + def observable(self) -> reactivex.Observable[Observation[T]]: + """Convert this stream to an RxPY Observable. + + Iteration is scheduled on the dimos thread pool so subscribe() never + blocks the calling thread. 
+ """ + import reactivex + import reactivex.operators as ops + + from dimos.utils.threadpool import get_scheduler + + return reactivex.from_iterable(self).pipe( + ops.subscribe_on(get_scheduler()), + ) + + def subscribe( + self, + on_next: Callable[[Observation[T]], None] | ObserverBase[Observation[T]] | None = None, + on_error: Callable[[Exception], None] | None = None, + on_completed: Callable[[], None] | None = None, + ) -> DisposableBase: + """Subscribe to this stream as an RxPY Observable.""" + return self.observable().subscribe( # type: ignore[call-overload] + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + ) + # ── Write ─────────────────────────────────────────────────────── def append( From f89ad3fd4c50fac93aba3eb45a9d8bfc5b512016 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 18:43:08 +0800 Subject: [PATCH 091/118] live stream is a resource --- dimos/memory2/stream.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 55ade446bf..df6dc4636a 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -17,6 +17,7 @@ import time from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.core.resource import Resource from dimos.memory2.backend import Backend from dimos.memory2.buffer import BackpressureBuffer, KeepLast from dimos.memory2.filter import ( @@ -45,12 +46,15 @@ R = TypeVar("R") -class Stream(Generic[T]): +class Stream(Resource, Generic[T]): """Lazy, pull-based stream over observations. Every filter/transform method returns a new Stream — no computation happens until iteration. Backends handle query application for stored data; transform sources apply filters as Python predicates. + + Implements Resource so live streams can be cleanly stopped via + ``stop()`` or used as a context manager. 
""" def __init__( @@ -64,6 +68,17 @@ def __init__( self._xf = xf self._query = query + def start(self) -> None: + pass + + def stop(self) -> None: + """Close the live buffer (if any), unblocking iteration.""" + buf = self._query.live_buffer + if buf is not None: + buf.close() + if isinstance(self._source, Stream): + self._source.stop() + def __str__(self) -> str: # Walk the source chain to collect (xf, query) pairs chain: list[tuple[Any, StreamQuery]] = [] From a32b44da2093aee4ad53fd982861eb8cd75c180b Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 19:08:12 +0800 Subject: [PATCH 092/118] readme work --- dimos/memory2/intro.md | 134 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 dimos/memory2/intro.md diff --git a/dimos/memory2/intro.md b/dimos/memory2/intro.md new file mode 100644 index 0000000000..ed048deadf --- /dev/null +++ b/dimos/memory2/intro.md @@ -0,0 +1,134 @@ +# Memory Intro + +## Quick start + +```python session=memory ansi=false no-result +from dimos.memory2.impl.sqlite import SqliteStore + +store = SqliteStore(path="/tmp/memory_readme.db") +session = store.session() +``` + + +```python session=memory ansi=false +logs = session.stream("logs", str) +print(logs) +``` + + +``` +Stream("logs") +``` + +Append observations and query them: + +```python session=memory ansi=false +logs.append("Motor started", ts=1.0, tags={"level": "info"}) +logs.append("Joint 3 fault", ts=2.0, tags={"level": "error"}) +logs.append("Motor stopped", ts=3.0, tags={"level": "info"}) + +print(logs.summary()) +``` + + +``` +Stream("logs"): 3 items, 1970-01-01 00:00:01 — 1970-01-01 00:00:03 (2.0s) +``` + +### Filters + +Queries are lazy — chaining filters builds a pipeline without fetching: + +```python session=memory ansi=false +print(logs.at(1.0).before(5.0).tags(level="error")) +``` + + +``` +Stream("logs") | AtFilter(t=1.0, tolerance=1.0) | BeforeFilter(t=5.0) | TagsFilter(tags={'level': 'error'}) +``` + +Available 
filters: `.after(t)`, `.before(t)`, `.at(t)`, `.near(pose, radius)`, `.tags(**kv)`, `.filter(predicate)`, `.search(embedding, k)`, `.order_by(field)`, `.limit(k)`, `.offset(n)`. + +### Terminals + +Terminals materialize or consume the stream: + +```python session=memory ansi=false +print(logs.before(5.0).tags(level="error").fetch()) +``` + + +``` +[Observation(id=2, ts=2.0, pose=None, tags={'level': 'error'})] +``` + +Available terminals: `.fetch()`, `.first()`, `.last()`, `.count()`, `.exists()`, `.summary()`, `.get_time_range()`, `.drain()`, `.save(target)`. + +### Transforms + +`.map(fn)` transforms each observation, returning a new stream: + +```python session=memory ansi=false +print(logs.map(lambda obs: obs.data.upper()).first()) +``` + + +``` +MOTOR STARTED +``` + +### Live queries + +Live queries backfill existing matches, then emit new ones as they arrive: + +```python session=memory ansi=false +import time + +def emit_some_logs(): + last_ts = logs.last().ts + logs.append("Heartbeat ok", ts=last_ts + 1, pose=(3.0, 1.5, 0.0), tags={"level": "info"}) + time.sleep(0.1) + logs.append("Sensor fault", ts=last_ts + 2, pose=(4.1, 2.0, 0.0), tags={"level": "error"}) + time.sleep(0.1) + logs.append("Battery charge 30%", ts=last_ts + 3, pose=(5.3, 2.5, 0.0), tags={"level": "info"}) + time.sleep(0.1) + logs.append("Overtemp", ts=last_ts + 4, pose=(6.0, 3.0, 0.0), tags={"level": "error"}) + time.sleep(0.1) + + +with logs.tags(level="error").live() as errors: + sub = errors.subscribe(lambda obs: print(f"{obs.ts} - {obs.data}")) + emit_some_logs() + sub.dispose() + +``` + + +``` +2.0 - Joint 3 fault +5.0 - Sensor fault +7.0 - Overtemp +``` + +## Spatial + live + +Filters compose freely. 
Here `.near()` + `.live()` + `.map()` watches for logs near a physical location — backfilling past matches and tailing new ones: + +```python session=memory ansi=false + +with logs.near((5.0, 2.0), radius=2.0).live().map(lambda obs: f"log entry around our point of interest - {obs.data}") as logs_near: + # subscription is also contextmanager + with logs_near.subscribe(print): + emit_some_logs() +``` + + +``` +log entry around our point of interest - Sensor fault +log entry around our point of interest - Battery charge 30% +log entry around our point of interest - Overtemp +log entry around our point of interest - Sensor fault +log entry around our point of interest - Battery charge 30% +log entry around our point of interest - Overtemp +``` From db2327543f73f942c47bf34d4c51255e64020a11 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 19:29:28 +0800 Subject: [PATCH 093/118] streams and intro --- dimos/memory2/intro.md | 82 +++++++++++++++++++++++++++++------- dimos/memory2/test_stream.py | 8 ++-- dimos/memory2/transform.py | 21 ++++++++- 3 files changed, 91 insertions(+), 20 deletions(-) diff --git a/dimos/memory2/intro.md b/dimos/memory2/intro.md index ed048deadf..341d89608c 100644 --- a/dimos/memory2/intro.md +++ b/dimos/memory2/intro.md @@ -20,7 +20,7 @@ print(logs) Stream("logs") ``` -Append observations and query them: +Append observations: ```python session=memory ansi=false logs.append("Motor started", ts=1.0, tags={"level": "info"}) @@ -35,7 +35,7 @@ print(logs.summary()) Stream("logs"): 3 items, 1970-01-01 00:00:01 — 1970-01-01 00:00:03 (2.0s) ``` -### Filters +## Filters Queries are lazy — chaining filters builds a pipeline without fetching: @@ -50,7 +50,7 @@ Stream("logs") | AtFilter(t=1.0, tolerance=1.0) | BeforeFilter(t=5.0) | TagsFilt Available filters: `.after(t)`, `.before(t)`, `.at(t)`, `.near(pose, radius)`, `.tags(**kv)`, `.filter(predicate)`, `.search(embedding, k)`, `.order_by(field)`, `.limit(k)`, `.offset(n)`. 
-### Terminals +## Terminals Terminals materialize or consume the stream: @@ -65,7 +65,7 @@ print(logs.before(5.0).tags(level="error").fetch()) Available terminals: `.fetch()`, `.first()`, `.last()`, `.count()`, `.exists()`, `.summary()`, `.get_time_range()`, `.drain()`, `.save(target)`. -### Transforms +## Transforms `.map(fn)` transforms each observation, returning a new stream: @@ -78,7 +78,7 @@ print(logs.map(lambda obs: obs.data.upper()).first()) MOTOR STARTED ``` -### Live queries +## Live queries Live queries backfill existing matches, then emit new ones as they arrive: @@ -91,7 +91,7 @@ def emit_some_logs(): time.sleep(0.1) logs.append("Sensor fault", ts=last_ts + 2, pose=(4.1, 2.0, 0.0), tags={"level": "error"}) time.sleep(0.1) - logs.append("Battery charge 30%", ts=last_ts + 3, pose=(5.3, 2.5, 0.0), tags={"level": "info"}) + logs.append("Battery low: 30%", ts=last_ts + 3, pose=(5.3, 2.5, 0.0), tags={"level": "info"}) time.sleep(0.1) logs.append("Overtemp", ts=last_ts + 4, pose=(6.0, 3.0, 0.0), tags={"level": "error"}) time.sleep(0.1) @@ -116,19 +116,71 @@ with logs.tags(level="error").live() as errors: Filters compose freely. 
Here `.near()` + `.live()` + `.map()` watches for logs near a physical location — backfilling past matches and tailing new ones: ```python session=memory ansi=false - -with logs.near((5.0, 2.0), radius=2.0).live().map(lambda obs: f"log entry around our point of interest - {obs.data}") as logs_near: - # subscription is also contextmanager +near_query = logs.near((5.0, 2.0), radius=2.0).live() +with near_query.map(lambda obs: f"near POI - {obs.data}") as logs_near: with logs_near.subscribe(print): emit_some_logs() ``` ``` -log entry around our point of interest - Sensor fault -log entry around our point of interest - Battery charge 30% -log entry around our point of interest - Overtemp -log entry around our point of interest - Sensor fault -log entry around our point of interest - Battery charge 30% -log entry around our point of interest - Overtemp +near POI - Sensor fault +near POI - Battery low: 30% +near POI - Overtemp +near POI - Sensor fault +near POI - Battery low: 30% +near POI - Overtemp +``` + +## Embeddings + +Use `EmbedText` transformer with CLIP to enrich observations with embeddings, then search by similarity: + +`.search(embedding, k)` returns the top-k most similar observations by cosine similarity: + +```python session=memory ansi=false +from dimos.models.embedding.clip import CLIPModel +from dimos.memory2.embed import EmbedText + +clip = CLIPModel() + +for obs in logs.transform(EmbedText(clip)).search(clip.embed_text("hardware problem"), k=3).fetch(): + print(f"{obs.similarity:.3f} {obs.data}") +``` + + +``` +0.897 Sensor fault +0.897 Sensor fault +0.887 Battery low: 30% +``` + +The embedded stream above was ephemeral — built on the fly for one query. 
To persist embeddings automatically as logs arrive, pipe a live stream through the transform into a stored stream: + +```python skip +import threading + +embedded_logs = session.stream("embedded_logs", str) +threading.Thread( + target=lambda: logs.live().transform(EmbedText(clip)).save(embedded_logs), + daemon=True, +).start() + +# every new log is now automatically embedded and stored +# embedded_logs.search(query, k=5).fetch() to query at any time +``` + +## Full text search + +`.search_text(text)` does efficient substring matching: + +```python session=memory ansi=false +for obs in logs.search_text("motor").fetch(): + print(f"{obs.data}") +``` + + +``` +Motor started +Motor stopped ``` diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index ab8f4da7e6..46eef32e4f 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -161,7 +161,7 @@ def test_filter_by_tag(self): stream.append("car", ts=1.0, tags={"type": "vehicle", "wheels": 4}) stream.append("dog", ts=2.0, tags={"type": "animal", "legs": 4}) - result = stream.filter_tags(type="animal").fetch() + result = stream.tags(type="animal").fetch() assert [o.data for o in result] == ["cat", "dog"] def test_filter_multiple_tags(self): @@ -171,7 +171,7 @@ def test_filter_multiple_tags(self): stream.append("a", ts=0.0, tags={"x": 1, "y": 2}) stream.append("b", ts=1.0, tags={"x": 1, "y": 3}) - result = stream.filter_tags(x=1, y=2).fetch() + result = stream.tags(x=1, y=2).fetch() assert [o.data for o in result] == ["a"] @@ -674,7 +674,7 @@ def test_fetch_on_live_without_limit_raises(self): stream = session.stream("live_fetch") live = stream.live(buffer=Unbounded()) - with pytest.raises(TypeError, match="collect forever"): + with pytest.raises(TypeError, match="block forever"): live.fetch() def test_fetch_on_live_transform_without_limit_raises(self): @@ -685,7 +685,7 @@ def test_fetch_on_live_transform_without_limit_raises(self): xf = FnTransformer(lambda obs: obs) live_xf = 
stream.live(buffer=Unbounded()).transform(xf) - with pytest.raises(TypeError, match="collect forever"): + with pytest.raises(TypeError, match="block forever"): live_xf.fetch() def test_count_on_live_transform_raises(self): diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py index 05a809f8cb..d68e25344a 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory2/transform.py @@ -15,8 +15,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +import inspect from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.memory2.formatting import FilterRepr + if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -25,8 +28,10 @@ T = TypeVar("T") R = TypeVar("R") +_MISSING = object() + -class Transformer(ABC, Generic[T, R]): +class Transformer(FilterRepr, ABC, Generic[T, R]): """Transforms a stream of observations lazily via iterator -> iterator. Pull from upstream, yield transformed observations. Naturally supports @@ -37,6 +42,20 @@ class Transformer(ABC, Generic[T, R]): @abstractmethod def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: ... 
+ def __str__(self) -> str: + parts: list[str] = [] + for name in inspect.signature(self.__init__).parameters: # type: ignore[misc] + val = getattr(self, name, _MISSING) + if val is _MISSING: + val = getattr(self, f"_{name}", _MISSING) + if val is _MISSING: + continue + if callable(val): + parts.append(f"{name}={getattr(val, '__name__', '...')}") + else: + parts.append(f"{name}={val!r}") + return f"{self.__class__.__name__}({', '.join(parts)})" + class FnTransformer(Transformer[T, R]): """Wraps a callable that receives an Observation and returns a new one (or None to skip).""" From 9b148947367d2e85dbe18285f15c9192bba24303 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 19:30:46 +0800 Subject: [PATCH 094/118] renamed readme to arch --- dimos/memory2/{README.md => architecture.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename dimos/memory2/{README.md => architecture.md} (100%) diff --git a/dimos/memory2/README.md b/dimos/memory2/architecture.md similarity index 100% rename from dimos/memory2/README.md rename to dimos/memory2/architecture.md From 67a6a830529579eb6edd3cd719416a9690f9c8fd Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 19:40:38 +0800 Subject: [PATCH 095/118] =?UTF-8?q?Rename=20memory2=20=E2=86=92=20memory,?= =?UTF-8?q?=20fix=20all=20imports=20and=20type=20errors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace all dimos.memory2 imports with dimos.memory - Make concrete filter classes inherit from Filter ABC - Fix mypy errors: type narrowing, Optional guards, annotation mismatches - Fix test_impl.py: filter_tags() → tags() - Remove intro.py (superseded by intro.md) - Delete old dimos/memory2/ directory --- dimos/memory/__init__.py | 84 +- dimos/{memory2 => memory}/architecture.md | 4 +- dimos/{memory2 => memory}/backend.py | 8 +- .../{memory2 => memory}/blobstore/__init__.py | 6 +- .../blobstore/blobstore.md | 0 dimos/{memory2 => 
memory}/blobstore/file.py | 2 +- dimos/{memory2 => memory}/blobstore/sqlite.py | 2 +- .../blobstore/test_blobstore.py | 6 +- dimos/{memory2 => memory}/buffer.py | 0 dimos/memory/codec.py | 135 -- dimos/{memory2 => memory}/codecs/README.md | 4 +- dimos/{memory2 => memory}/codecs/__init__.py | 4 +- dimos/{memory2 => memory}/codecs/base.py | 6 +- dimos/{memory2 => memory}/codecs/jpeg.py | 0 dimos/{memory2 => memory}/codecs/lcm.py | 0 dimos/{memory2 => memory}/codecs/pickle.py | 0 .../{memory2 => memory}/codecs/test_codecs.py | 16 +- dimos/memory/docs/api.md | 679 --------- dimos/memory/docs/query_objects.md | 155 -- dimos/memory/docs/questions.md | 56 - dimos/memory/docs/sqlite.md | 621 -------- dimos/memory/docs/transform.md | 180 --- dimos/{memory2 => memory}/embed.py | 8 +- dimos/{memory2 => memory}/embeddings.md | 2 +- dimos/{memory2 => memory}/filter.py | 18 +- dimos/memory/formatting.py | 235 +-- dimos/{memory2 => memory}/impl/README.md | 12 +- dimos/memory/impl/__init__.py | 13 + dimos/{memory2 => memory}/impl/memory.py | 21 +- dimos/memory/impl/sqlite.py | 1334 ++++++----------- dimos/memory/impl/test_e2e_export.py | 177 --- dimos/memory/impl/test_sqlite.py | 1039 ------------- dimos/memory/impl/test_sqlite_e2e.py | 113 -- dimos/memory/ingest.py | 51 - dimos/{memory2 => memory}/intro.md | 4 +- dimos/memory/livechannel/__init__.py | 4 + .../livechannel/subject.py | 6 +- dimos/memory/module.py | 100 -- dimos/memory/readme.md | 455 ------ dimos/memory/rerun.py | 81 - dimos/memory/store.py | 204 ++- dimos/memory/stream.py | 912 ++++------- dimos/{memory2 => memory}/streaming.md | 0 dimos/{memory2 => memory}/test_blobstore.py | 8 +- dimos/{memory2 => memory}/test_buffer.py | 2 +- dimos/{memory2 => memory}/test_e2e_import.py | 8 +- .../test_e2e_processing.py} | 3 + dimos/{memory2 => memory}/test_e2e_query.py | 6 +- dimos/{memory2 => memory}/test_embedding.py | 24 +- dimos/{memory2 => memory}/test_impl.py | 30 +- dimos/memory/test_memory.py | 72 - 
dimos/memory/test_projection.py | 203 --- dimos/{memory2 => memory}/test_save.py | 10 +- dimos/{memory2 => memory}/test_stream.py | 12 +- dimos/memory/test_stream_repr.py | 194 --- dimos/memory/test_transformer.py | 252 ---- dimos/memory/tests/__init__.py | 0 dimos/{memory2 => memory}/transform.py | 4 +- dimos/memory/transformer.py | 327 ---- dimos/memory/type.py | 271 +--- .../vectorstore/__init__.py | 4 +- .../{memory2 => memory}/vectorstore/memory.py | 2 +- .../{memory2 => memory}/vectorstore/sqlite.py | 2 +- dimos/memory2/__init__.py | 70 - dimos/memory2/impl/sqlite.py | 671 --------- dimos/memory2/livechannel/__init__.py | 4 - dimos/memory2/store.py | 164 -- dimos/memory2/stream.py | 381 ----- dimos/memory2/type.py | 114 -- 69 files changed, 1124 insertions(+), 8471 deletions(-) rename dimos/{memory2 => memory}/architecture.md (99%) rename dimos/{memory2 => memory}/backend.py (96%) rename dimos/{memory2 => memory}/blobstore/__init__.py (80%) rename dimos/{memory2 => memory}/blobstore/blobstore.md (100%) rename dimos/{memory2 => memory}/blobstore/file.py (97%) rename dimos/{memory2 => memory}/blobstore/sqlite.py (98%) rename dimos/{memory2 => memory}/blobstore/test_blobstore.py (95%) rename dimos/{memory2 => memory}/buffer.py (100%) delete mode 100644 dimos/memory/codec.py rename dimos/{memory2 => memory}/codecs/README.md (94%) rename dimos/{memory2 => memory}/codecs/__init__.py (85%) rename dimos/{memory2 => memory}/codecs/base.py (88%) rename dimos/{memory2 => memory}/codecs/jpeg.py (100%) rename dimos/{memory2 => memory}/codecs/lcm.py (100%) rename dimos/{memory2 => memory}/codecs/pickle.py (100%) rename dimos/{memory2 => memory}/codecs/test_codecs.py (91%) delete mode 100644 dimos/memory/docs/api.md delete mode 100644 dimos/memory/docs/query_objects.md delete mode 100644 dimos/memory/docs/questions.md delete mode 100644 dimos/memory/docs/sqlite.md delete mode 100644 dimos/memory/docs/transform.md rename dimos/{memory2 => memory}/embed.py (93%) rename 
dimos/{memory2 => memory}/embeddings.md (99%) rename dimos/{memory2 => memory}/filter.py (96%) rename dimos/{memory2 => memory}/impl/README.md (91%) rename dimos/{memory2 => memory}/impl/memory.py (92%) delete mode 100644 dimos/memory/impl/test_e2e_export.py delete mode 100644 dimos/memory/impl/test_sqlite.py delete mode 100644 dimos/memory/impl/test_sqlite_e2e.py delete mode 100644 dimos/memory/ingest.py rename dimos/{memory2 => memory}/intro.md (98%) create mode 100644 dimos/memory/livechannel/__init__.py rename dimos/{memory2 => memory}/livechannel/subject.py (92%) delete mode 100644 dimos/memory/module.py delete mode 100644 dimos/memory/readme.md delete mode 100644 dimos/memory/rerun.py rename dimos/{memory2 => memory}/streaming.md (100%) rename dimos/{memory2 => memory}/test_blobstore.py (97%) rename dimos/{memory2 => memory}/test_buffer.py (96%) rename dimos/{memory2 => memory}/test_e2e_import.py (94%) rename dimos/{memory2/impl/__init__.py => memory/test_e2e_processing.py} (95%) rename dimos/{memory2 => memory}/test_e2e_query.py (97%) rename dimos/{memory2 => memory}/test_embedding.py (96%) rename dimos/{memory2 => memory}/test_impl.py (95%) delete mode 100644 dimos/memory/test_memory.py delete mode 100644 dimos/memory/test_projection.py rename dimos/{memory2 => memory}/test_save.py (95%) rename dimos/{memory2 => memory}/test_stream.py (98%) delete mode 100644 dimos/memory/test_stream_repr.py delete mode 100644 dimos/memory/test_transformer.py delete mode 100644 dimos/memory/tests/__init__.py rename dimos/{memory2 => memory}/transform.py (97%) delete mode 100644 dimos/memory/transformer.py rename dimos/{memory2 => memory}/vectorstore/__init__.py (83%) rename dimos/{memory2 => memory}/vectorstore/memory.py (97%) rename dimos/{memory2 => memory}/vectorstore/sqlite.py (98%) delete mode 100644 dimos/memory2/__init__.py delete mode 100644 dimos/memory2/impl/sqlite.py delete mode 100644 dimos/memory2/livechannel/__init__.py delete mode 100644 
dimos/memory2/store.py delete mode 100644 dimos/memory2/stream.py delete mode 100644 dimos/memory2/type.py diff --git a/dimos/memory/__init__.py b/dimos/memory/__init__.py index 489e3f568d..b98418e415 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory/__init__.py @@ -1,36 +1,70 @@ -from dimos.memory.codec import Codec, JpegCodec, LcmCodec, PickleCodec, codec_for_type -from dimos.memory.store import Session, Store, StreamNamespace -from dimos.memory.stream import EmbeddingStream, ObservationSet, Stream, TextStream -from dimos.memory.transformer import ( - CaptionTransformer, - EmbeddingTransformer, - PerItemTransformer, - TextEmbeddingTransformer, - Transformer, +from dimos.memory.backend import Backend, LiveChannel, VectorStore +from dimos.memory.buffer import ( + BackpressureBuffer, + Bounded, + ClosedError, + DropNew, + KeepLast, + Unbounded, ) -from dimos.memory.type import ( - EmbeddingObservation, - Observation, +from dimos.memory.embed import EmbedImages, EmbedText +from dimos.memory.filter import ( + AfterFilter, + AtFilter, + BeforeFilter, + Filter, + NearFilter, + PredicateFilter, + StreamQuery, + TagsFilter, + TimeRangeFilter, ) +from dimos.memory.impl.memory import ListBackend, MemorySession, MemoryStore +from dimos.memory.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore, SqliteStoreConfig +from dimos.memory.livechannel import SubjectChannel +from dimos.memory.store import Session, SessionConfig, Store, StoreConfig, StreamNamespace +from dimos.memory.stream import Stream +from dimos.memory.transform import FnTransformer, QualityWindow, Transformer +from dimos.memory.type import EmbeddedObservation, Observation __all__ = [ - "CaptionTransformer", - "Codec", - "EmbeddingObservation", - "EmbeddingStream", - "EmbeddingTransformer", - "JpegCodec", - "LcmCodec", + "AfterFilter", + "AtFilter", + "Backend", + "BackpressureBuffer", + "BeforeFilter", + "Bounded", + "ClosedError", + "DropNew", + "EmbedImages", + "EmbedText", + 
"EmbeddedObservation", + "Filter", + "FnTransformer", + "KeepLast", + "ListBackend", + "LiveChannel", + "MemorySession", + "MemoryStore", + "NearFilter", "Observation", - "ObservationSet", - "PerItemTransformer", - "PickleCodec", + "PredicateFilter", + "QualityWindow", "Session", + "SessionConfig", + "SqliteBackend", + "SqliteSession", + "SqliteStore", + "SqliteStoreConfig", "Store", + "StoreConfig", "Stream", "StreamNamespace", - "TextEmbeddingTransformer", - "TextStream", + "StreamQuery", + "SubjectChannel", + "TagsFilter", + "TimeRangeFilter", "Transformer", - "codec_for_type", + "Unbounded", + "VectorStore", ] diff --git a/dimos/memory2/architecture.md b/dimos/memory/architecture.md similarity index 99% rename from dimos/memory2/architecture.md rename to dimos/memory/architecture.md index 25a37a22f5..7fba703f4c 100644 --- a/dimos/memory2/architecture.md +++ b/dimos/memory/architecture.md @@ -1,4 +1,4 @@ -# memory2 +# memory Observation storage and streaming layer for DimOS. Pull-based, lazy, composable. 
@@ -74,7 +74,7 @@ Transform-sourced streams (post `.transform()`) always use `StreamQuery.apply()` ## Quick start ```python -from dimos.memory2 import MemoryStore +from dimos.memory import MemoryStore store = MemoryStore() with store.session() as session: diff --git a/dimos/memory2/backend.py b/dimos/memory/backend.py similarity index 96% rename from dimos/memory2/backend.py rename to dimos/memory/backend.py index 928b74e229..cc36f79239 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory/backend.py @@ -25,10 +25,10 @@ from reactivex.abc import DisposableBase - from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.codecs.base import Codec - from dimos.memory2.filter import StreamQuery - from dimos.memory2.type import Observation + from dimos.memory.buffer import BackpressureBuffer + from dimos.memory.codecs.base import Codec + from dimos.memory.filter import StreamQuery + from dimos.memory.type import Observation from dimos.models.embedding.base import Embedding T = TypeVar("T") diff --git a/dimos/memory2/blobstore/__init__.py b/dimos/memory/blobstore/__init__.py similarity index 80% rename from dimos/memory2/blobstore/__init__.py rename to dimos/memory/blobstore/__init__.py index 8f78d7c439..f0b3fe76f5 100644 --- a/dimos/memory2/blobstore/__init__.py +++ b/dimos/memory/blobstore/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dimos.memory2.backend import BlobStore -from dimos.memory2.blobstore.file import FileBlobStore -from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory.backend import BlobStore +from dimos.memory.blobstore.file import FileBlobStore +from dimos.memory.blobstore.sqlite import SqliteBlobStore __all__ = ["BlobStore", "FileBlobStore", "SqliteBlobStore"] diff --git a/dimos/memory2/blobstore/blobstore.md b/dimos/memory/blobstore/blobstore.md similarity index 100% rename from dimos/memory2/blobstore/blobstore.md rename to dimos/memory/blobstore/blobstore.md diff --git a/dimos/memory2/blobstore/file.py b/dimos/memory/blobstore/file.py similarity index 97% rename from dimos/memory2/blobstore/file.py rename to dimos/memory/blobstore/file.py index 54ec80e284..de8c6e8bc2 100644 --- a/dimos/memory2/blobstore/file.py +++ b/dimos/memory/blobstore/file.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from dimos.memory2.backend import BlobStore +from dimos.memory.backend import BlobStore if TYPE_CHECKING: import os diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory/blobstore/sqlite.py similarity index 98% rename from dimos/memory2/blobstore/sqlite.py rename to dimos/memory/blobstore/sqlite.py index 0fd144c532..235152d796 100644 --- a/dimos/memory2/blobstore/sqlite.py +++ b/dimos/memory/blobstore/sqlite.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING -from dimos.memory2.backend import BlobStore +from dimos.memory.backend import BlobStore if TYPE_CHECKING: import sqlite3 diff --git a/dimos/memory2/blobstore/test_blobstore.py b/dimos/memory/blobstore/test_blobstore.py similarity index 95% rename from dimos/memory2/blobstore/test_blobstore.py rename to dimos/memory/blobstore/test_blobstore.py index fe05cfa84f..83f76fa2ec 100644 --- a/dimos/memory2/blobstore/test_blobstore.py +++ b/dimos/memory/blobstore/test_blobstore.py @@ -22,14 +22,14 @@ import pytest -from dimos.memory2.blobstore.file import FileBlobStore 
-from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory.blobstore.file import FileBlobStore +from dimos.memory.blobstore.sqlite import SqliteBlobStore if TYPE_CHECKING: from collections.abc import Callable, Generator from pathlib import Path - from dimos.memory2.backend import BlobStore + from dimos.memory.backend import BlobStore # ── Case definition ──────────────────────────────────────────────── diff --git a/dimos/memory2/buffer.py b/dimos/memory/buffer.py similarity index 100% rename from dimos/memory2/buffer.py rename to dimos/memory/buffer.py diff --git a/dimos/memory/codec.py b/dimos/memory/codec.py deleted file mode 100644 index 9351b3bf84..0000000000 --- a/dimos/memory/codec.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import importlib -import pickle -from typing import TYPE_CHECKING, Any, Protocol, TypeVar - -from dimos.msgs.sensor_msgs.Image import Image - -if TYPE_CHECKING: - from dimos.msgs.protocol import DimosMsg - -T = TypeVar("T") - - -class Codec(Protocol[T]): - """Encodes/decodes payloads for storage.""" - - def encode(self, value: T) -> bytes: ... - def decode(self, data: bytes) -> T: ... 
- - -class LcmCodec: - """Codec for DimosMsg types — uses lcm_encode/lcm_decode.""" - - def __init__(self, msg_type: type[DimosMsg]) -> None: - self._msg_type = msg_type - - def encode(self, value: DimosMsg) -> bytes: - return value.lcm_encode() - - def decode(self, data: bytes) -> DimosMsg: - return self._msg_type.lcm_decode(data) - - -class JpegCodec: - """Codec for Image types — stores as JPEG bytes (lossy, ~10-20x smaller). - - Uses TurboJPEG (libjpeg-turbo) for 2-5x faster encode/decode vs OpenCV. - Preserves ``frame_id`` as a short header: ````. - Pixel data is lossy-compressed; ``ts`` is NOT preserved (stored in the meta table). - """ - - def __init__(self, quality: int = 50) -> None: - self._quality = quality - from turbojpeg import TurboJPEG # type: ignore[import-untyped] - - self._tj = TurboJPEG() - - _TJPF_MAP: dict[str, int] | None = None - - @staticmethod - def _get_tjpf_map() -> dict[str, int]: - if JpegCodec._TJPF_MAP is None: - from turbojpeg import TJPF_BGR, TJPF_GRAY, TJPF_RGB # type: ignore[import-untyped] - - JpegCodec._TJPF_MAP = {"BGR": TJPF_BGR, "RGB": TJPF_RGB, "GRAY": TJPF_GRAY} - return JpegCodec._TJPF_MAP - - def encode(self, value: Any) -> bytes: - import struct - - from turbojpeg import TJPF_BGR # type: ignore[import-untyped] - - pf = self._get_tjpf_map().get(value.format.value, TJPF_BGR) - jpeg_data: bytes = self._tj.encode(value.data, quality=self._quality, pixel_format=pf) - frame_id = (value.frame_id or "").encode("utf-8") - header = struct.pack(" Any: - import struct - - from dimos.msgs.sensor_msgs.Image import Image, ImageFormat - - fid_len = struct.unpack(" bytes: - return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) - - def decode(self, data: bytes) -> Any: - return pickle.loads(data) - - -def codec_for_type(payload_type: type | None) -> Codec[Any]: - """Auto-select codec based on payload type.""" - if payload_type is not None: - # Image → JPEG by default (much smaller than LCM raw pixels) - - if 
issubclass(payload_type, Image): - return JpegCodec() - if hasattr(payload_type, "lcm_encode") and hasattr(payload_type, "lcm_decode"): - return LcmCodec(payload_type) # type: ignore[arg-type] - return PickleCodec() - - -def type_to_module_path(t: type) -> str: - """Return fully qualified module path for a type, e.g. 'dimos.msgs.sensor_msgs.Image.Image'.""" - return f"{t.__module__}.{t.__qualname__}" - - -def module_path_to_type(path: str) -> type | None: - """Resolve a fully qualified module path back to a type. Returns None on failure.""" - parts = path.rsplit(".", 1) - if len(parts) != 2: - return None - module_path, class_name = parts - try: - mod = importlib.import_module(module_path) - return getattr(mod, class_name, None) # type: ignore[no-any-return] - except (ImportError, AttributeError): - return None diff --git a/dimos/memory2/codecs/README.md b/dimos/memory/codecs/README.md similarity index 94% rename from dimos/memory2/codecs/README.md rename to dimos/memory/codecs/README.md index ff6b701054..719369f29a 100644 --- a/dimos/memory2/codecs/README.md +++ b/dimos/memory/codecs/README.md @@ -23,7 +23,7 @@ class Codec(Protocol[T]): `codec_for(payload_type)` picks the right codec: ```python -from dimos.memory2.codecs import codec_for +from dimos.memory.codecs import codec_for codec_for(Image) # → JpegCodec(quality=50) codec_for(SomeLcmMsg) # → LcmCodec(SomeLcmMsg) (if has lcm_encode/lcm_decode) @@ -33,7 +33,7 @@ codec_for(None) # → PickleCodec() ## Writing a new codec -1. Create `dimos/memory2/codecs/mycodec.py`: +1. 
Create `dimos/memory/codecs/mycodec.py`: ```python class MyCodec: diff --git a/dimos/memory2/codecs/__init__.py b/dimos/memory/codecs/__init__.py similarity index 85% rename from dimos/memory2/codecs/__init__.py rename to dimos/memory/codecs/__init__.py index a7feb3bce3..fe4b870250 100644 --- a/dimos/memory2/codecs/__init__.py +++ b/dimos/memory/codecs/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.memory2.codecs.base import Codec, codec_for -from dimos.memory2.codecs.pickle import PickleCodec +from dimos.memory.codecs.base import Codec, codec_for +from dimos.memory.codecs.pickle import PickleCodec __all__ = ["Codec", "PickleCodec", "codec_for"] diff --git a/dimos/memory2/codecs/base.py b/dimos/memory/codecs/base.py similarity index 88% rename from dimos/memory2/codecs/base.py rename to dimos/memory/codecs/base.py index 4c2b3865f5..12ea658906 100644 --- a/dimos/memory2/codecs/base.py +++ b/dimos/memory/codecs/base.py @@ -28,17 +28,17 @@ def decode(self, data: bytes) -> T: ... 
def codec_for(payload_type: type[Any] | None = None) -> Codec[Any]: """Auto-select codec based on payload type.""" - from dimos.memory2.codecs.pickle import PickleCodec + from dimos.memory.codecs.pickle import PickleCodec if payload_type is not None: from dimos.msgs.sensor_msgs.Image import Image if issubclass(payload_type, Image): - from dimos.memory2.codecs.jpeg import JpegCodec + from dimos.memory.codecs.jpeg import JpegCodec return JpegCodec() if hasattr(payload_type, "lcm_encode") and hasattr(payload_type, "lcm_decode"): - from dimos.memory2.codecs.lcm import LcmCodec + from dimos.memory.codecs.lcm import LcmCodec return LcmCodec(payload_type) return PickleCodec() diff --git a/dimos/memory2/codecs/jpeg.py b/dimos/memory/codecs/jpeg.py similarity index 100% rename from dimos/memory2/codecs/jpeg.py rename to dimos/memory/codecs/jpeg.py diff --git a/dimos/memory2/codecs/lcm.py b/dimos/memory/codecs/lcm.py similarity index 100% rename from dimos/memory2/codecs/lcm.py rename to dimos/memory/codecs/lcm.py diff --git a/dimos/memory2/codecs/pickle.py b/dimos/memory/codecs/pickle.py similarity index 100% rename from dimos/memory2/codecs/pickle.py rename to dimos/memory/codecs/pickle.py diff --git a/dimos/memory2/codecs/test_codecs.py b/dimos/memory/codecs/test_codecs.py similarity index 91% rename from dimos/memory2/codecs/test_codecs.py rename to dimos/memory/codecs/test_codecs.py index 8f3eb17c10..7d5057d589 100644 --- a/dimos/memory2/codecs/test_codecs.py +++ b/dimos/memory/codecs/test_codecs.py @@ -24,7 +24,7 @@ import pytest -from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory.codecs.base import Codec, codec_for if TYPE_CHECKING: from collections.abc import Callable @@ -44,7 +44,7 @@ class Case: def _pickle_case() -> Case: - from dimos.memory2.codecs.pickle import PickleCodec + from dimos.memory.codecs.pickle import PickleCodec return Case( name="pickle", @@ -54,7 +54,7 @@ def _pickle_case() -> Case: def _lcm_case() -> Case: - from 
dimos.memory2.codecs.lcm import LcmCodec + from dimos.memory.codecs.lcm import LcmCodec from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -86,7 +86,7 @@ def _jpeg_eq(original: Any, decoded: Any) -> bool: def _jpeg_case() -> Case: - from dimos.memory2.codecs.jpeg import JpegCodec + from dimos.memory.codecs.jpeg import JpegCodec from dimos.utils.testing import TimedSensorReplay replay = TimedSensorReplay("unitree_go2_bigoffice/video") @@ -132,23 +132,23 @@ class TestCodecFor: """codec_for() auto-selects the right codec.""" def test_none_returns_pickle(self) -> None: - from dimos.memory2.codecs.pickle import PickleCodec + from dimos.memory.codecs.pickle import PickleCodec assert isinstance(codec_for(None), PickleCodec) def test_unknown_type_returns_pickle(self) -> None: - from dimos.memory2.codecs.pickle import PickleCodec + from dimos.memory.codecs.pickle import PickleCodec assert isinstance(codec_for(dict), PickleCodec) def test_lcm_type_returns_lcm(self) -> None: - from dimos.memory2.codecs.lcm import LcmCodec + from dimos.memory.codecs.lcm import LcmCodec from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped assert isinstance(codec_for(PoseStamped), LcmCodec) def test_image_type_returns_jpeg(self) -> None: - from dimos.memory2.codecs.jpeg import JpegCodec + from dimos.memory.codecs.jpeg import JpegCodec from dimos.msgs.sensor_msgs.Image import Image assert isinstance(codec_for(Image), JpegCodec) diff --git a/dimos/memory/docs/api.md b/dimos/memory/docs/api.md deleted file mode 100644 index cde3a818ff..0000000000 --- a/dimos/memory/docs/api.md +++ /dev/null @@ -1,679 +0,0 @@ -# Memory2 API — Unified Stream - -## Core Idea - -One type: `Stream[T]`. Everything is a stream — stored, filtered, transformed. The user never thinks about Query vs ObservationSet vs Stream. They just chain operations. 
- -## Creating Streams - -```python -store = SqliteStore("/data/robot.db") -session = store.session() - -# Root stored stream — backed by DB -images = session.stream("images", Image, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -logs = session.text_stream("logs", str, - pose_provider=lambda: tf.get_pose("world", "base_link")) -``` - -## Writing - -```python -images.append(frame) # ts + pose auto-filled -logs.append("Motor fault on joint 3") # ts + pose auto-filled -images.append(frame, pose=explicit_pose, tags={"cam": "front"}) -``` - -Only meaningful on stored (DB-backed) streams. - -### Batch ingest - -The `ingest()` helper accepts any iterable of `(ts, payload)` — e.g. from a replay: - -```python -from dimos.memory.ingest import ingest - -replay = TimedSensorReplay("unitree_go2_bigoffice/video") -odom = TimedSensorReplay("unitree_go2_bigoffice/odom") - -raw = session.stream("raw_video", Image) -n = ingest(raw, replay.iterate_ts(seek=5.0, duration=300.0), pose_source=odom) -# pose_source.find_closest(ts) is called per frame to attach odom poses -``` - -## Filtering - -Every filter returns a new `Stream[T]`. Lazy — nothing executes until a terminal. - -```python -recent = images.after(one_hour_ago) -kitchen = recent.near(kitchen_pose, 5.0) -tagged = kitchen.filter_tags(cam="front") - -# Or chained -images.after(one_hour_ago).near(kitchen_pose, 5.0).filter_tags(cam="front") -``` - -### Filter methods - -```python -class Stream(Generic[T]): - # Temporal - def after(self, t: float) -> Stream[T]: ... - def before(self, t: float) -> Stream[T]: ... - def time_range(self, t1: float, t2: float) -> Stream[T]: ... - def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... - - # Spatial - def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... - - # Tags - def filter_tags(self, **tags: Any) -> Stream[T]: ... 
- -class EmbeddingStream(Stream[T]): - def search_embedding(self, query: Embedding | list[float] | str | Any, - *, k: int, raw: bool = False) -> Stream[Any]: ... - -class TextStream(Stream[T]): - def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... -``` - -## Terminals & Iteration - -`Stream` is directly iterable — pages internally, never loads everything at once. - -```python -# Direct iteration (lazy, memory-efficient — uses fetch_pages internally) -for row in images.after(t).near(kitchen_pose, 5.0): - print(row.data) - -# Explicit fetch when you want the full list in memory -all_rows = images.after(t).fetch() # returns ObservationSet - -# Other terminals -row = images.after(t).one() # single best match -row = images.last() # most recent -n = images.after(t).count() # count without fetching - -# Pagination -page = images.order_by("ts").limit(50).offset(100).fetch() -``` - -### Terminal methods - -```python -class Stream(Generic[T]): - def __iter__(self) -> Iterator[Observation]: ... # lazy, pages internally - def fetch(self) -> ObservationSet[T]: ... # all results, list-like + stream-like - def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... - def one(self) -> Observation: ... - def last(self) -> Observation: ... - def count(self) -> int: ... - def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... - def limit(self, k: int) -> Stream[T]: ... - def offset(self, n: int) -> Stream[T]: ... 
-``` - -### ObservationSet - -`fetch()` returns an `ObservationSet` — a list-like object that also supports stream chaining: - -```python -results = embeddings.search_embedding("a hallway", k=50).fetch() - -len(results) # list-like -results[0] # indexing -for r in results: # iteration - print(r.data) - -# Stream-like — further filter/transform the materialized results -results.after(t).fetch() -results.transform(caption_xf).fetch() -``` - -## Observation - -```python -@dataclass -class Observation: - id: int - ts: float | None = None - pose: PoseStamped | None = None - tags: dict[str, Any] = field(default_factory=dict) - parent_id: int | None = None # lineage: source observation id - - @property - def data(self) -> Any: - """Lazy payload. Pre-populated from append/transform, fetched on demand from query.""" - ... - -@dataclass -class EmbeddingObservation(Observation): - """Returned by EmbeddingStream terminals. Auto-projects .data to source stream.""" - - similarity: float | None = None # 0..1, populated by search_embedding (vec0 cosine) - - @property - def data(self) -> Any: - """Lazily loads from the source stream (e.g., Image), not the embedding.""" - ... - - @property - def embedding(self) -> Embedding: - """The Embedding object (has .vector, supports @ for cosine similarity).""" - ... -``` - -## Transformer - -A `Transformer` receives the full source stream and decides what to do — which items to process, how to batch, whether to use embeddings as a cheap proxy, etc. - -```python -class Transformer(ABC, Generic[T, R]): - """Transforms a source stream into results on a target stream.""" - - def process(self, source: Stream[T], target: Stream[R]) -> None: - """Batch/historical processing. Has full access to source — can query, - filter, use embeddings, batch, skip frames, etc.""" - ... - - def on_append(self, obs: Observation, target: Stream[R]) -> None: - """Reactive processing. Called per new item. Default: process([obs]).""" - ... 
- - supports_backfill: bool = True - supports_live: bool = True - output_type: type | None = None # determines target stream kind -``` - -### Simple lambdas (sugar) - -`Callable[[T], R | list[R] | None]` is auto-wrapped into a naive per-item Transformer: - -```python -# These are equivalent: -images.transform(lambda img: vlm.detect(img, "cigarettes")) -images.transform(PerItemTransformer(lambda img: vlm.detect(img, "cigarettes"))) -``` - -- `R` → single result -- `list[R]` → multiple results (e.g., multiple detections per frame) -- `None` → skip (no result for this input) - -### EmbeddingTransformer - -`EmbeddingTransformer` wraps an `EmbeddingModel` as a `Transformer[T, Embedding]`. When the output type is `Embedding`, `.store()` creates an `EmbeddingStream` (vec0 index, `search_embedding`, `EmbeddingObservation`). - -```python -# EmbeddingTransformer wraps the model -img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") - -# Now img_emb is an EmbeddingStream -results = img_emb.search_embedding(query_emb, k=20).fetch() -# results[0].data → Image (auto-projected from source) -# results[0].embedding → Embedding (supports @ for cosine similarity) -``` - -### Chaining transforms - -```python -# Filter → transform → store -images.after(one_hour_ago) \ - .near(kitchen_pose, 5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .store("kitchen_embeddings") - -# Filter → transform → fetch (in-memory, not persisted) -results = images.after(one_hour_ago) \ - .near(kitchen_pose, 5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .fetch() - -# Filter → embed → detect → store (chained: detector gets EmbeddingObservation) -images.near(kitchen_pose, 5.0) \ - .transform(EmbeddingTransformer(CLIPModel())) \ - .transform(CigaretteDetector(vlm, clip)) \ - .store("kitchen_cigarette_detections") -``` - -### Backfill / Live modes - -```python -# Both (default): backfill existing + subscribe to new -images.transform(detector).store("detections") - 
-# Live only: skip backfill, only process new items -images.transform(detector, live=True).store("detections") - -# Backfill only: process existing, don't subscribe -images.transform(detector, backfill=True).store("detections") - -# Backfill only: process existing, and subscribe -images.transform(detector, backfill=True, live=True).store("detections") - -# Incremental: re-running a stored transform resumes from last processed item -# (uses lineage parent_id to skip already-processed source rows) -``` - -## Storing - -`.store(name)` materializes a stream to DB. After storing, results are queryable and persistent. - -```python -# In-memory transform result — not persisted -detections = images.transform(detect_fn) - -# Persist it -detections.store("detections") - -# Now it's a DB-backed stream, queryable -stored = session.stream("detections") -rows = stored.after(t).fetch() -``` - -`.store()` also sets up lineage — every stored row gets `parent_id` pointing back to its source. - -Stream type is determined by what the Transformer produces: -- `Embedding` output → `EmbeddingStream` (vec0 index) -- `str` output from `CaptionTransformer` → `TextStream` (FTS index) -- Everything else → `Stream` (blob) - -## Reactive - -```python -# .appended emits Observation with .data pre-populated -images.appended.subscribe(lambda row: print(f"New image at {row.pose}")) - -# Stored transforms propagate reactively by default -detections = images.transform(detect_fn).store("detections") -# Now every images.append(frame) → detect_fn runs → result stored in "detections" - -# Filtered appended — only kitchen images -images.near(kitchen_pose, 5.0).appended.subscribe(...) 
-``` - -## Cross-stream lineage (project_to) - -`project_to()` follows `parent_id` chains to project observations onto another stream: - -```python -# Get embeddings matching a query, then project to source images -emb_results = img_emb.search_embedding("red shoes", k=20, raw=True).fetch() -# emb_results are EmbeddingObservations with .similarity, .pose, .ts - -# Or project to get the source images directly -image_results = img_emb.search_embedding("red shoes", k=20, raw=True) \ - .project_to(images).fetch() -``` - -`search_embedding` auto-projects by default — `raw=True` skips this to get -`EmbeddingObservation` results with `.similarity` scores. - -Multi-hop lineage works too: -```python -# images → sharp_frames → clip_embeddings (2 hops) -# search_embedding auto-resolves the chain -results = clip_embeddings.search_embedding("a door", k=10).fetch() -# results[0].data → Image (from raw_video, traversing through sharp_frames) -``` - -## Visualization - -`dimos.memory.rerun` sends stream contents to Rerun: - -```python -from dimos.memory.rerun import to_rerun - -# Send any stream to Rerun — auto-derives entity path from stream name, -# logs .data via to_rerun() and poses as arrows -to_rerun(images) -to_rerun(embeddings.search_embedding("a hallway", k=50)) -``` - -## Full Example: Cigarette Detection Pipeline - -```python -session = SqliteStore("/data/robot.db").session() - -# Root stream -images = session.stream("images", Image, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -# Embedding index — EmbeddingModel is a Transformer -img_emb = images.transform(EmbeddingTransformer(CLIPModel())).store("img_emb") - -# VLM detection pipeline (live-only, no backfill) -images.transform( - lambda img: vlm.detect(img, "people with cigarettes"), - live=True, -).store("cigarette_detections") - -# Smart detection — reuse existing embeddings, detector gets EmbeddingObservation -img_emb.near(kitchen_pose, 10.0) \ - .transform(CigaretteDetector(vlm, clip)) \ - 
.store("kitchen_cigarette_detections") - -# --- Later, querying --- - -# "Where did we see people with cigarettes in the kitchen?" -for row in session.stream("cigarette_detections") \ - .after(one_hour_ago).near(kitchen_pose, 10.0): - print(f"t={row.ts} pose={row.pose}: {row.data}") - -# "Show me the source images alongside detections" -for det, img in session.stream("cigarette_detections") \ - .after(one_hour_ago).join(images): - print(f"Detection: {det.data}, Source image at {img.pose}") - -# "Find images similar to 'red shoes'" -similar = img_emb.search_embedding("red shoes", k=20).fetch() -# similar[0].data → Image (auto-projected from source) -# similar[0].embedding → Embedding (supports @ for cosine similarity) -``` - -## Full API - -```python -from dimos.models.embedding.base import Embedding, EmbeddingModel - -# --- Data types --- - -@dataclass -class Observation: - id: int - ts: float | None = None - pose: PoseStamped | None = None - tags: dict[str, Any] = field(default_factory=dict) - parent_id: int | None = None - - @property - def data(self) -> Any: - """Lazy payload. Pre-populated from append, fetched on demand from query.""" - ... - -@dataclass -class EmbeddingObservation(Observation): - """Returned by EmbeddingStream terminals. Auto-projects .data to source stream.""" - - similarity: float | None = None # 0..1, populated by search_embedding - - @property - def data(self) -> Any: - """Lazily loads from the source stream (e.g., Image), not the embedding.""" - ... - - @property - def embedding(self) -> Embedding: - """The Embedding object (has .vector, supports @ for cosine similarity).""" - ... - -# --- Transformer --- - -class Transformer(ABC, Generic[T, R]): - """Transforms a source stream into results on a target stream.""" - - def process(self, source: Stream[T], target: Stream[R]) -> None: - """Batch/historical processing. Full access to source stream.""" - ... 
- - def on_append(self, obs: Observation, target: Stream[R]) -> None: - """Reactive processing. Called per new item.""" - ... - - supports_backfill: bool = True - supports_live: bool = True - output_type: type | None = None - -# --- Streams --- - -class Stream(Generic[T]): - # Write (DB-backed only) - def append(self, payload: T, *, - ts: float | None = None, - pose: PoseLike | None = None, - tags: dict[str, Any] | None = None, - parent_id: int | None = None, - ) -> Observation: ... - - # Filter (returns new Stream, lazy) - def after(self, t: float) -> Stream[T]: ... - def before(self, t: float) -> Stream[T]: ... - def time_range(self, t1: float, t2: float) -> Stream[T]: ... - def at(self, t: float, *, tolerance: float = 1.0) -> Stream[T]: ... - def near(self, pose: PoseLike, radius: float) -> Stream[T]: ... - def filter_tags(self, **tags: Any) -> Stream[T]: ... - - # Order / paginate - def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: ... - def limit(self, k: int) -> Stream[T]: ... - def offset(self, n: int) -> Stream[T]: ... - - # Transform - def transform(self, - xf: Transformer[T, R] | Callable[[T], R | list[R] | None], - *, live: bool = False, - backfill_only: bool = False, - ) -> Stream[R]: ... - - # Materialize (on TransformStream, accepts optional session= fallback) - def store(self, name: str | None = None, session: Session | None = None) -> Stream[T]: ... - - # Cross-stream lineage - def project_to(self, target: Stream[R]) -> Stream[R]: ... - - # Iteration & Terminals - def __iter__(self) -> Iterator[Observation]: ... # lazy, pages internally - def fetch(self) -> ObservationSet[T]: ... # list-like + stream-like result set - def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation]]: ... - def one(self) -> Observation: ... - def last(self) -> Observation: ... - def count(self) -> int: ... - - # Reactive - @property - def appended(self) -> Observable[Observation]: ... 
- -class EmbeddingStream(Stream[T]): - """Created automatically when a Transformer produces Embedding output. - Terminals return EmbeddingObservation (auto-projects .data to source stream).""" - def search_embedding(self, query: Embedding | list[float] | str | Any, - *, k: int, raw: bool = False) -> Stream[Any]: ... - -class TextStream(Stream[T]): - """Stream with FTS index.""" - def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: ... - -class ObservationSet(Stream[T]): - """Materialized result set from fetch(). List-like + stream-like.""" - def __len__(self) -> int: ... - def __getitem__(self, index: int) -> Observation: ... - def __iter__(self) -> Iterator[Observation]: ... - def __bool__(self) -> bool: ... - -# --- Helpers --- - -def ingest(stream: Stream, source: Iterable[tuple[float, Any]], *, - pose_source: Any | None = None) -> int: - """Ingest (ts, payload) pairs into a stream. Returns count.""" - ... - -# --- Session / Store --- - -PoseProvider = Callable[[], PoseLike | None] - -class Session: - def stream(self, name: str, payload_type: type | None = None, *, - pose_provider: PoseProvider | None = None) -> Stream: ... - def text_stream(self, name: str, payload_type: type | None = None, *, - tokenizer: str = "unicode61", - pose_provider: PoseProvider | None = None) -> TextStream: ... - def embedding_stream(self, name: str, payload_type: type | None = None, *, - vec_dimensions: int | None = None, - pose_provider: PoseProvider | None = None, - parent_table: str | None = None, - embedding_model: EmbeddingModel | None = None) -> EmbeddingStream: ... - def materialize_transform(self, name: str, source: Stream, - transformer: Transformer, - *, payload_type: type | None = None, - live: bool = False, - backfill_only: bool = False) -> Stream: ... - def list_streams(self) -> list[StreamInfo]: ... - def resolve_parent_stream(self, name: str) -> str | None: ... - def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: ... 
- def close(self) -> None: ... - -class Store: - def session(self) -> Session: ... - def close(self) -> None: ... -``` - -## Internal Backing (impl detail) - -A `Stream` can be backed by different things — the user never sees this: - -- **DB tables** — from `session.stream()`. Metadata + payload + indexes. -- **Predicate** — from `.after()`, `.near()`, etc. Lazy SQL WHERE. -- **Transform** — from `.transform(t)`. Source stream + Transformer. -- **ListBackend** — from `ObservationSet`. In-memory Python-side filtering. - -The impl decides how to execute based on the backing chain. - -## SQLite Schema - -Each stream `{name}` creates these tables: - -```sql --- Metadata table (compact rows, fast scans) -CREATE TABLE {name} ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts REAL, - pose_x REAL, -- position - pose_y REAL, - pose_z REAL, - pose_qx REAL, -- orientation quaternion (stored, not indexed) - pose_qy REAL, - pose_qz REAL, - pose_qw REAL, - tags TEXT DEFAULT '{}', - parent_id INTEGER -- lineage: source observation id -); -CREATE INDEX idx_{name}_ts ON {name}(ts); - --- Payload table (blobs, loaded on demand) -CREATE TABLE {name}_payload ( - id INTEGER PRIMARY KEY, - data BLOB -); - --- R*Tree spatial index (position only) -CREATE VIRTUAL TABLE {name}_rtree USING rtree( - id, - min_x, max_x, - min_y, max_y, - min_z, max_z -); -``` - -**Optional per stream kind:** - -```sql --- TextStream: FTS5 full-text index -CREATE VIRTUAL TABLE {name}_fts USING fts5(content, tokenize='unicode61'); - --- EmbeddingStream: vec0 vector index (cosine distance) -CREATE VIRTUAL TABLE {name}_vec USING vec0( - embedding float[{dim}] distance_metric=cosine -); -``` - -### Key design decisions - -- **Separate payload table** — metadata queries (`fetch`, `count`, `near`, filters) never touch blob data. Payload is loaded lazily via `obs.data`. -- **Decomposed pose columns** — enables R*Tree spatial index for `.near()` queries. Orientation stored for reconstruction but not spatially indexed. 
-- **R*Tree for spatial queries** — `.near(pose, radius)` compiles to an R*Tree range query (bounding box at +/-radius), with post-filter for exact Euclidean distance. -- **Cosine distance metric** — vec0 uses `distance_metric=cosine` (0=identical, 2=opposite). Similarity = `1.0 - distance`, clamped to [0, 1]. - -### Lazy payload loading - -`fetch()` returns `Observation` with lazy `.data`: -- Metadata query: `SELECT id, ts, pose_x, ..., tags, parent_id FROM {name} WHERE ...` -- `_data` stays `_UNSET`, `_data_loader` is set to: `SELECT data FROM {name}_payload WHERE id = ?` -- Only `obs.data` access triggers the blob read + codec decode - -This means iterating metadata (`obs.ts`, `obs.pose`, `obs.tags`) is cheap. - -### NearFilter SQL compilation - -```python -# .near(pose, 5.0) compiles to: -# JOIN {name}_rtree AS r ON r.id = {name}.id -# WHERE r.min_x >= pose.position.x - 5.0 AND r.max_x <= pose.position.x + 5.0 -# AND r.min_y >= pose.position.y - 5.0 AND r.max_y <= pose.position.y + 5.0 -# AND r.min_z >= pose.position.z - 5.0 AND r.max_z <= pose.position.z + 5.0 -``` - -For exact distance (not just bounding box), a post-filter computes Euclidean distance on the R*Tree candidates. - -## Serialization (Codec) - -Each stream has a `Codec[T]` that handles payload encode/decode. Auto-selected from `payload_type`. - -```python -class Codec(Protocol[T]): - def encode(self, value: T) -> bytes: ... - def decode(self, data: bytes) -> T: ... - -class LcmCodec(Codec[DimosMsg]): - """For DimosMsg types — uses lcm_encode/lcm_decode.""" - -class JpegCodec(Codec[Image]): - """For Image types — uses JPEG compression.""" - -class PickleCodec(Codec[Any]): - """Fallback for arbitrary Python objects.""" - -def codec_for_type(payload_type: type[T] | None) -> Codec[T]: - """Auto-select codec based on payload type.""" - ... -``` - -Lives in `dimos.memory.codec`. 
- -Transparent to the user — just pass `payload_type` to `session.stream()`: -```python -images = session.stream("images", Image) # auto LCM codec -numbers = session.stream("numbers", int) # auto pickle codec -``` - -Tags are JSON. Poses are decomposed into columns (not serialized). - -### Stream metadata (`_streams` table) - -``` -name TEXT PRIMARY KEY -payload_module TEXT -- fully qualified, e.g. "dimos.msgs.sensor_msgs.Image.Image" -stream_kind TEXT -- "stream" | "text" | "embedding" -parent_stream TEXT -- parent stream name (lineage for project_to/join) -embedding_dim INTEGER -- vec0 dimension (embedding streams only) -``` - -On restart, `session.stream("images")` (no `payload_type`) resolves the class from `payload_module` via `importlib`, then selects the codec automatically. `embedding_dim` allows recreating the vec0 table without needing to see the first embedding again. - -## Resolved Questions - -1. **`.append()` on non-stored streams?** → `TypeError` (requires backend). -2. **Multiple `.store()` calls?** → Idempotent — returns existing stream if already stored. -3. ~~**Memory pressure from in-memory transforms?**~~ → Solved via `fetch_pages`. -4. **Pose storage** → Decomposed columns + R*Tree index (not binary blob). -5. **Payload loading** → Lazy via separate `{name}_payload` table. -6. **`__iter__`** → `for page in self.fetch_pages(): yield from page` — lazy, memory-efficient iteration. -7. **`project_to` / lineage** → Implemented via `parent_id` column + `_streams.parent_stream`. Multi-hop chains supported. -8. **`fetch()` return type** → `ObservationSet` (list-like + stream-like). -9. **Similarity scores** → `EmbeddingObservation.similarity` populated from vec0 cosine distance. - -## Open Questions - -1. **Incremental transforms** — re-running a stored transform should resume from last processed item. -2. **4D indexing** — should R*Tree include time as a 4th dimension? 
diff --git a/dimos/memory/docs/query_objects.md b/dimos/memory/docs/query_objects.md deleted file mode 100644 index bf86d39675..0000000000 --- a/dimos/memory/docs/query_objects.md +++ /dev/null @@ -1,155 +0,0 @@ -# Query Objects — 4D Region + Soft Scoring System - -## Problem - -We need to query observations across 4 dimensions (x, y, z, t) plus embedding space. Current API has flat `filter_*` methods — works for simple cases but doesn't compose. We need: - -1. **Regions** — composable hard boundaries (include/exclude) -2. **Fields** — soft scoring that biases toward a point/time/embedding without hard cutoffs -3. A way to combine both in a single query - -## Key Insight - -Hard filters and soft biases are the same thing at different extremes: -- Hard filter = step function (1 inside, 0 outside) -- Soft bias = smooth decay (gaussian, linear, etc.) - -A unified **Criterion** type handles both. Each criterion maps an observation to a score in `[0, 1]`. Hard filters are just criteria with score `{0, 1}`. 
- -## Primitives - -### Temporal - -```python -# Hard boundaries -TimeRange(t1, t2) # 1 inside, 0 outside -Before(t) # sugar for TimeRange(-inf, t) -After(t) # sugar for TimeRange(t, inf) - -# Soft — score decays with distance from target -TimeProximity(target, sigma=60.0) # gaussian: exp(-dt²/2σ²) -``` - -### Spatial - -```python -# Hard boundaries -Sphere(center: PoseLike, radius: float) # 1 inside, 0 outside -Box(min: PoseLike, max: PoseLike) # axis-aligned bounding box -HeightRange(z_min, z_max) # horizontal slice - -# Soft -SpatialProximity(point: PoseLike, sigma=5.0) # gaussian in 3D -``` - -### Embedding - -```python -# Soft only (no hard boundary in embedding space makes sense) -EmbeddingSimilarity(vector, candidate_k=100) # cosine similarity, top-k pre-filter -``` - -### Tags - -```python -TagMatch(robot_id="robot1") # hard: exact match on tag values -``` - -## Composition - -Criteria compose via set operators: - -```python -# Intersection — all criteria must score > 0 -region = TimeRange(t1, t2) & Sphere(point, 5.0) - -# Union — any criterion scoring > 0 passes -region = Sphere(p1, 3.0) | Sphere(p2, 3.0) - -# Complement -region = ~TimeRange(t1, t2) # everything outside this window -``` - -For soft criteria, composition combines scores: -- `a & b` → `min(a.score, b.score)` (conservative) -- `a | b` → `max(a.score, b.score)` (permissive) - -## Weighted Scoring - -The interesting problem: "I care about embedding similarity, temporal proximity, AND spatial proximity" — but as soft preferences, not hard cutoffs. - -```python -Score( - time=TimeProximity(target_t, sigma=60), - space=SpatialProximity(point, sigma=5.0), - embedding=EmbeddingSimilarity(vector, candidate_k=200), - weights={"time": 0.3, "space": 0.3, "embedding": 0.4} -) -``` - -Each dimension produces a `[0, 1]` score. Final score = weighted sum. This replaces the vague `rank(**weights)` in the current API. 
- -## Integration with Query - -```python -# Current flat API (still works, sugar for simple cases) -q.after(t).near(point, 5.0).search_embedding(vec, candidate_k=100) - -# Region object approach -region = After(t) & Sphere(point, 5.0) -q.where(region).search_embedding(vec, candidate_k=100) - -# Full soft scoring — no hard boundaries, just preferences -q.score( - time=TimeProximity(target_t, sigma=120), - space=SpatialProximity(point, sigma=10.0), - embedding=EmbeddingSimilarity(vec, candidate_k=500), -).limit(20) - -# Mixed — hard boundary + soft ranking within -q.where(TimeRange(t1, t2)).score( - space=SpatialProximity(point, sigma=5.0), - embedding=EmbeddingSimilarity(vec, candidate_k=200), -).limit(10) -``` - -## SQL Mapping (SQLite impl) - -How each primitive maps to SQL: - -| Criterion | SQL Strategy | -|--------------------------|-------------------------------------------------------| -| `TimeRange(t1, t2)` | `WHERE ts BETWEEN ? AND ?` (B-tree) | -| `Before(t)` / `After(t)` | `WHERE ts < ?` / `WHERE ts > ?` | -| `Sphere(p, r)` | R*Tree range query on `_rtree` | -| `HeightRange(z1, z2)` | `WHERE pose_z BETWEEN ? AND ?` | -| `Box(min, max)` | R*Tree range query | -| `TimeProximity(t, σ)` | `ORDER BY ABS(ts - ?) ASC` or compute score in SELECT | -| `SpatialProximity(p, σ)` | R*Tree range (pre-filter at ~3σ) + score in SELECT | -| `EmbeddingSimilarity` | sqlite-vec `MATCH` → temp table | -| `TagMatch` | `WHERE json_extract(tags, ?) = ?` | - -Soft scoring strategy: **generous hard pre-filter in SQL, then score in Python**. 
-- Each soft criterion auto-generates a hard pre-filter at ~3σ (captures 99.7% of relevant results) -- `TimeProximity(t, σ=60)` → SQL: `WHERE ts BETWEEN t-180 AND t+180` (B-tree) -- `SpatialProximity(p, σ=5)` → SQL: R*Tree range query with 15m box -- `EmbeddingSimilarity` → sqlite-vec `MATCH` top-k (already a pre-filter) -- Python computes `[0, 1]` scores on the pre-filtered set, applies weights, sorts - -This keeps SQL simple (range queries on indexes) and Python handles the math. - -## Open Questions - -2. **How does `Score` interact with `search_embedding`?** Embedding search already returns ranked results from vec0. Should `Score.embedding` just re-weight those scores, or does it need a separate search pass? - -3. **Region objects as first-class types?** Do we store/serialize regions (e.g., "the kitchen region" as a reusable spatial boundary)? Or are they always constructed in code? - -4. **Do we need `NOT` regions for exclusion zones?** E.g., "everywhere except within 2m of the charging station." `~Sphere(charger, 2.0)` — complement on spatial regions requires scanning all of `_meta`, can't use R*Tree efficiently. - -5. **Gradient fields?** "Prefer observations taken at higher elevation" — not proximity to a point but a directional preference. `HeightGradient(ascending=True)` as a scorer? - -## Priority - -- **Phase 1**: Keep the flat `filter_*` / `rank()` API. Implement primitives internally. -- **Phase 2**: Expose `Criterion` objects + `where()` + `score()` as the composable API. -- **Phase 3**: Region persistence, named regions, gradient fields. diff --git a/dimos/memory/docs/questions.md b/dimos/memory/docs/questions.md deleted file mode 100644 index bc91b9f306..0000000000 --- a/dimos/memory/docs/questions.md +++ /dev/null @@ -1,56 +0,0 @@ -# Questions - -1. "where was I when this log line was added?" -- pose lookup, corelating to log lines found -- assume log line has a pose associated -- assume there are multiple log lines matching a search - -2. 
"how long have I been observing the red socks currently in view?" -- how many times did I see them before? -- temporal duration tracking + observation frequency - -3. "how many people did I see during last week?" -- assume we are generating a facial recognition db — is this matching a face detection stream, then embeddings? then we are searching over that stream? - -4. "where did you see red socks during last week?" -- we query for red socks embedding similarity, then feed this data into a VLM that further filters for socks -- is this data output into some table? is it like an ObservationSet again? -- then we can create a map (costmap) of red socks? - -5. "did anyone ever open this door? at what times did I see this door open? who opened it?" -- event detection + temporal querying of state changes - -6. "I have a transcription log (STT) and voice embeddings, how do I figure out who is saying what?" -- cross-stream correlation: audio → identity - -7. "I have parallel voice and facial recognition streams, how do I correlate voice to people?" -- I don't see all people speaking at all times -- multi-modal fusion with incomplete overlap - -8. "what's different in this room compared to yesterday?" -- comparing scene snapshots across time, diffing object sets -- requires baseline modeling / temporal comparison - -9. "show me everywhere the cat went today" -- continuous spatial tracking over time, not point queries -- dense pose-stream retrieval + path aggregation - -10. "what happened in the 30 seconds before the vase fell?" -- event-anchored temporal window across all streams -- multi-stream temporal slicing relative to a detected event - -11. "when was the last time I did NOT see the cat in the apartment?" -- negation query — finding gaps in an observation stream -- architecturally different from presence queries - -12. "what time does the mailman usually come?" 
-- aggregation across days, extracting temporal regularity from sparse events -- cross-session pattern extraction - -13. "what did robot-2 observe in the warehouse that I missed?" -- cross-agent memory diff -- session/robot-scoped queries and set difference across streams - -14. "how far did I travel while carrying an object?" -- filtered pose integration — only accumulate distance when a parallel detection stream has a positive signal -- cross-stream conditional joins diff --git a/dimos/memory/docs/sqlite.md b/dimos/memory/docs/sqlite.md deleted file mode 100644 index 173bedb5b6..0000000000 --- a/dimos/memory/docs/sqlite.md +++ /dev/null @@ -1,621 +0,0 @@ -# SQLite Implementation - -Implementation spec for the SQLite backend. A coding agent should be able to implement the full backend from this document + `api.md`. - -## File Structure - -``` -dimos/memory/ - __init__.py # public exports - types.py # Observation, EmbeddingObservation, StreamInfo, Filter types - stream.py # Stream, EmbeddingStream, TextStream, ObservationSet, ListBackend - transformer.py # Transformer ABC, PerItemTransformer, EmbeddingTransformer, etc. - store.py # Session ABC, Store ABC - codec.py # LcmCodec, JpegCodec, PickleCodec, codec_for_type() - ingest.py # ingest() helper for batch ingestion - viz.py # similarity_heatmap(), similarity_poses(), log_similarity_timeline() - - impl/ - sqlite.py # SqliteStore, SqliteSession, Sqlite*Backend (single file) - test_sqlite.py # tests -``` - -## Dependencies - -- `sqlite3` (stdlib) -- `sqlite-vec` — vector similarity search via vec0 virtual table. Loaded via `sqlite_vec.load(conn)`. -- FTS5 — built into SQLite by default on most platforms. -- R*Tree — built into SQLite by default. -- `reactivex` — for `.appended` observable (already a DimOS dependency). 
- -## Connection Management - -### SqliteStore - -```python -class SqliteStore(Store): - def __init__(self, path: str): - self._path = path - self._conn = sqlite3.connect(path) - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA synchronous=NORMAL") - self._load_extensions() - - def session(self) -> SqliteSession: - return SqliteSession(self._conn) - - def _load_extensions(self) -> None: - try: - import sqlite_vec - self._conn.enable_load_extension(True) - sqlite_vec.load(self._conn) - self._conn.enable_load_extension(False) - except ImportError: - pass # vec0 unavailable — search_embedding will raise - - def close(self) -> None: - self._conn.close() -``` - -### SqliteSession - -```python -class SqliteSession(Session): - def __init__(self, conn: sqlite3.Connection): - self._conn = conn - self._streams: dict[str, Stream] = {} # cache by name - self._ensure_meta_table() - - def _ensure_meta_table(self): - """Create _streams registry table if not exists.""" - self._conn.execute(""" - CREATE TABLE IF NOT EXISTS _streams ( - name TEXT PRIMARY KEY, - payload_module TEXT, - stream_kind TEXT DEFAULT 'stream', - parent_stream TEXT, - embedding_dim INTEGER - ) - """) - - def stream(self, name, payload_type=None, *, pose_provider=None) -> Stream: - # Returns cached or creates new. payload_type required for new streams. - ... - - def text_stream(self, name, payload_type=None, *, tokenizer="unicode61", - pose_provider=None) -> TextStream: - ... - - def embedding_stream(self, name, payload_type=None, *, vec_dimensions=None, - pose_provider=None, parent_table=None, - embedding_model=None) -> EmbeddingStream: - ... - - def list_streams(self) -> list[StreamInfo]: ... - def resolve_parent_stream(self, name: str) -> str | None: ... - def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: ... - def close(self) -> None: ... -``` - -## Schema - -All table names are prefixed with the stream name. 
Stream names are validated: `[a-zA-Z_][a-zA-Z0-9_]*`. - -### `_streams` — Global registry - -```sql -CREATE TABLE _streams ( - name TEXT PRIMARY KEY, - payload_module TEXT, -- e.g. 'dimos.msgs.sensor_msgs.Image.Image' - stream_kind TEXT DEFAULT 'stream', -- 'stream', 'embedding', 'text' - parent_stream TEXT, -- parent stream name (lineage) - embedding_dim INTEGER -- only for kind='embedding' -); -``` - -### `{name}` — Observation metadata (all stream types) - -```sql -CREATE TABLE {name} ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts REAL, - pose_x REAL, pose_y REAL, pose_z REAL, - pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL, - tags TEXT DEFAULT '{}', -- JSON dict - parent_id INTEGER -- lineage: id in parent stream -); -CREATE INDEX idx_{name}_ts ON {name}(ts); -``` - -### `{name}_payload` — Blob/Text payload - -```sql -CREATE TABLE {name}_payload ( - id INTEGER PRIMARY KEY, -- matches {name}.id - data BLOB NOT NULL -); -``` - -Separated from metadata so metadata queries never page in multi-MB blobs. - -### `{name}_rtree` — Spatial index (all stream types) - -```sql -CREATE VIRTUAL TABLE {name}_rtree USING rtree( - id, -- matches {name}.id - min_x, max_x, - min_y, max_y, - min_z, max_z -); -``` - -Only rows with pose are inserted into R*Tree. Rows without pose are excluded from `.near()` results. - -### `{name}_fts` — Full-text search (TextStream only) - -```sql -CREATE VIRTUAL TABLE {name}_fts USING fts5( - content, - tokenize='{tokenizer}' -); -``` - -Standalone FTS table (not content-synced). Rowids match `{name}.id`. - -### `{name}_vec` — Vector index (EmbeddingStream only) - -```sql -CREATE VIRTUAL TABLE {name}_vec USING vec0( - embedding float[{dim}] distance_metric=cosine -); -``` - -Cosine distance: 0 = identical, 2 = opposite. Similarity = `max(0, min(1, 1.0 - distance))`. - -Rowids match `{name}.id`. Dimension inferred from first embedding inserted, or from `vec_dimensions` parameter. 
- -## Stream Implementation - -### Backend classes - -The stream/backend split separates query logic from stream API: - -```python -class SqliteStreamBackend: - """Base backend for blob streams.""" - def do_append(self, payload, ts, pose, tags, parent_id=None) -> Observation: ... - def execute_fetch(self, query: StreamQuery) -> list[Observation]: ... - def execute_count(self, query: StreamQuery) -> int: ... - -class SqliteEmbeddingBackend(SqliteStreamBackend): - """Adds vec0 index. Overrides execute_fetch for vector search.""" - ... - -class SqliteTextBackend(SqliteStreamBackend): - """Adds FTS5 index. Overrides execute_fetch for text search.""" - ... -``` - -### append() - -```python -def do_append(self, payload, ts, pose, tags, parent_id=None): - ts = ts or time.time() - if pose is None and self._pose_provider: - pose = self._pose_provider() - - pose_cols = _decompose_pose(pose) - tags_json = _serialize_tags(tags) - - # 1. Insert into meta table - cur = self._conn.execute( - f"INSERT INTO {name} " - "(ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags, parent_id) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - (ts, *pose_cols, tags_json, parent_id), - ) - row_id = cur.lastrowid - - # 2. Insert into _payload - blob = self._codec.encode(payload) - self._conn.execute( - f"INSERT INTO {name}_payload(id, data) VALUES (?, ?)", - (row_id, blob) - ) - - # 3. Insert into _rtree (if pose) - if pose_cols: - x, y, z = pose_cols[0], pose_cols[1], pose_cols[2] - self._conn.execute( - f"INSERT INTO {name}_rtree(id, min_x, max_x, min_y, max_y, min_z, max_z) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (row_id, x, x, y, y, z, z) - ) - - self._conn.commit() - - # 4. 
Build Observation and emit - obs = Observation(id=row_id, ts=ts, pose=pose, tags=tags or {}, _data=payload) - self._subject.on_next(obs) - return obs -``` - -### EmbeddingBackend.append() - -Same as above, plus inserts into `_vec`: - -```python -if isinstance(payload, Embedding): - vec = payload.to_numpy().tolist() - self._conn.execute( - f"INSERT INTO {name}_vec(rowid, embedding) VALUES (?, ?)", - (row_id, json.dumps(vec)) - ) -``` - -### TextBackend.append() - -Same as base, plus inserts into `_fts`: - -```python -text = str(payload) -self._conn.execute( - f"INSERT INTO {name}_fts(rowid, content) VALUES (?, ?)", - (row_id, text) -) -``` - -## Filter → SQL Generation - -Each filter method returns a new stream with an added filter. At terminal time, the filter chain is compiled to SQL. - -### Filter types - -```python -AfterFilter(t) # → WHERE ts > ? -BeforeFilter(t) # → WHERE ts < ? -TimeRangeFilter(t1, t2) # → WHERE ts >= ? AND ts <= ? -AtFilter(t, tolerance) # → WHERE ABS(ts - ?) <= ? -NearFilter(pose, radius) # → JOIN _rtree bounding box query -TagsFilter(tags) # → WHERE json_extract(tags, '$.key') = ? -EmbeddingSearchFilter(vec, k) # → query _vec, then filter by rowids -TextSearchFilter(text, k) # → query _fts MATCH, then filter by rowids -LineageFilter(source_table, source_query, hops) # → nested IN subquery -``` - -### SQL compilation - -Walk the filter list, generate SQL: - -```python -def _compile_query(query, table) -> tuple[str, list[Any]]: - # Base SELECT - sql = f"SELECT {table}.id, {table}.ts, ... FROM {table}" - - # NearFilter → JOIN _rtree - # Other filters → WHERE clauses - # EmbeddingSearch/TextSearch → handled separately (two-step query) - # LineageFilter → nested IN subquery via _compile_ids() - - return sql, params -``` - -### search_embedding (vec0) - -Two-step process: - -```sql --- 1. Top-k vector search (cosine distance) -SELECT rowid, distance -FROM {name}_vec -WHERE embedding MATCH ? -ORDER BY distance -LIMIT ? -``` - -```python -# 2. 
Build dist_map, fetch metadata for those rowids, populate similarity -dist_map = {rowid: distance for rowid, distance in vec_rows} -# ... fetch metadata WHERE id IN (rowids) ... -for obs in observations: - obs.similarity = max(0.0, min(1.0, 1.0 - dist_map[obs.id])) -# Re-sort by distance rank (IN clause doesn't preserve vec0 ordering) -``` - -### search_text (FTS5) - -```sql -SELECT rowid, rank -FROM {name}_fts -WHERE content MATCH ? -ORDER BY rank -``` - -Same two-step: get rowids from FTS5, then fetch metadata. - -### LineageFilter compilation - -LineageFilter compiles to a nested SQL subquery walking the `parent_id` chain: - -```python -# Single hop: embeddings → images -f"SELECT parent_id FROM {source_table} WHERE id IN ({source_ids_sql})" - -# Multi-hop: embeddings → sharp_frames → images -# Wraps each hop as a nested IN subquery -``` - -## Terminal Execution - -### __iter__() — lazy iteration - -`Stream` is directly iterable via `fetch_pages`: - -```python -def __iter__(self): - for page in self.fetch_pages(): - yield from page -``` - -### fetch() - -Returns `ObservationSet` (list-like + stream-like): - -```python -def fetch(self) -> ObservationSet: - results = self._backend.execute_fetch(self._query) - return ObservationSet(results, session=self._session) -``` - -### count() - -```python -def count(self) -> int: - sql, params = _compile_count(query, table) - # → SELECT COUNT(*) FROM {table} WHERE ... 
- return self._conn.execute(sql, params).fetchone()[0] -``` - -### one() / last() - -- `one()` → `self.limit(1).fetch()[0]` -- `last()` → `self.order_by("ts", desc=True).limit(1).fetch()[0]` - -## Lazy Data Loading - -`Observation.data` uses lazy loading: - -```python -@dataclass -class Observation: - _data: Any = field(default=_UNSET, repr=False) - _data_loader: Callable[[], Any] | None = field(default=None, repr=False) - - @property - def data(self) -> Any: - if self._data is not _UNSET: - return self._data - if self._data_loader is not None: - self._data = self._data_loader() - return self._data - raise LookupError("No data available") -``` - -When building observations from query results: - -```python -def _row_to_obs(self, row) -> Observation: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row - pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) - - def loader(): - r = conn.execute(f"SELECT data FROM {table}_payload WHERE id = ?", (row_id,)).fetchone() - return codec.decode(r[0]) - - return Observation(id=row_id, ts=ts, pose=pose, tags=..., _data_loader=loader) -``` - -### EmbeddingObservation - -For `EmbeddingBackend`, `_row_to_obs` returns `EmbeddingObservation` with two lazy loaders: - -```python -def _row_to_obs(self, row) -> EmbeddingObservation: - # ... same metadata extraction ... 
- - # _data_loader: loads raw embedding payload - # _source_data_loader: loads from PARENT stream (auto-projection) - # - Resolves parent codec from _streams.payload_module - # - Uses parent_id to look up the source payload - - return EmbeddingObservation( - id=row_id, ts=ts, pose=pose, tags=..., - parent_id=pid, - _data_loader=loader, - _source_data_loader=source_loader, # None if no parent - ) -``` - -## Lineage - -### Storing lineage - -When a Transformer appends to a target stream, `parent_id` links back to the source: - -```python -target.append(result, ts=source_obs.ts, pose=source_obs.pose, - parent_id=source_obs.id) -``` - -The `_streams` registry tracks stream-level lineage: -```python -# After materialize_transform creates the target -UPDATE _streams SET parent_stream = ? WHERE name = ? -``` - -### resolve_lineage_chain() - -Walks `_streams.parent_stream` from source toward target: - -```python -def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: - # Single hop (source → target): returns () - # Two hops (source → mid → target): returns ("mid",) - # Raises ValueError if no path exists -``` - -### project_to() - -Uses `LineageFilter` to compile a nested SQL subquery: - -```python -def project_to(self, target: Stream) -> Stream: - hops = session.resolve_lineage_chain(source_table, target_table) - return target._with_filter(LineageFilter(source_table, self._query, hops)) -``` - -## Pose Helpers - -PoseStamped in dimos extends Pose directly (no wrapper). Access position/orientation directly: - -```python -def _decompose_pose(pose) -> tuple[float, ...] 
| None: - if pose is None: - return None - p = pose.position # NOT pose.pose.position - q = pose.orientation - return (p.x, p.y, p.z, q.x, q.y, q.z, q.w) - -def _reconstruct_pose(x, y, z, qx, qy, qz, qw) -> PoseStamped | None: - if x is None: - return None - return PoseStamped( - position=[x, y or 0.0, z or 0.0], # list args (plum dispatch) - orientation=[qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0], - ) -``` - -NearFilter SQL compilation also accesses `f.pose.position` directly. - -## Transform Execution - -### .transform() — returns lazy stream - -`.transform(xf)` doesn't execute immediately. It returns a `TransformStream`. Execution happens at terminal time or `.store()`. - -### .store() — materializes - -When `.store(name)` is called on a `TransformStream`: - -1. Register target stream in `_streams` (with `parent_stream` set) -2. Create target tables -3. Auto-detect target stream type from transformer: - - `EmbeddingTransformer` → `EmbeddingStream` (with parent_table) - - `CaptionTransformer` → `TextStream` (FTS) - - Other → `Stream` (blob) -4. If not `live` mode: run `xf.process(source, target)` (backfill) -5. If not `backfill_only`: subscribe to source's `.appended`, call `xf.on_append()` -6. Return the stored stream - -### .fetch() on TransformStream (no .store()) - -Executes the transform in-memory using `_CollectorStream`: - -```python -def fetch(self) -> ObservationSet: - collector = _CollectorStream() - self._transformer.process(self._source, collector) - return ObservationSet(collector.results) -``` - -## Reactive (.appended) - -Each stored stream backend has a `Subject` from reactivex: - -```python -class SqliteStreamBackend: - def __init__(self, ...): - self._subject: Subject[Observation] = Subject() - - @property - def appended_subject(self): - return self._subject -``` - -`do_append()` emits to the subject after the DB write succeeds. 
- -For filtered streams, the observable filters events through the filter chain in Python: - -```python -@property -def appended(self): - raw = self._backend.appended_subject - active = [f for f in self._query.filters - if not isinstance(f, (EmbeddingSearchFilter, LineageFilter))] - return raw.pipe(ops.filter(lambda obs: all(f.matches(obs) for f in active))) -``` - -## Serialization - -### Codec system - -```python -class LcmCodec: # for DimosMsg types (lcm_encode/lcm_decode) -class JpegCodec: # for Image types (JPEG compression) -class PickleCodec: # fallback for arbitrary Python objects - -def codec_for_type(payload_type: type | None) -> Codec: - """Auto-select codec based on payload type.""" - ... -``` - -Lives in `dimos.memory.codec`. - -### Tag serialization - -Tags are stored as JSON text. Empty dict → `"{}"`. - -## SQL Safety - -- **Identifier validation**: stream names must match `^[a-zA-Z_][a-zA-Z0-9_]*$`. -- **Parameterized queries**: all user values go through `?` params, never string interpolation. -- **Table names**: constructed from validated stream names, safe for SQL interpolation. -- **Order fields**: validated against allowlist `{"id", "ts"}`. - -## Thread Safety - -- Each `Session` owns one `sqlite3.Connection` — not shared across threads. -- Multiple sessions can exist on the same file (WAL mode allows concurrent reads + one writer). -- The `appended` subject emits on the thread that called `append()`. - -## Error Handling - -- `append()` on non-stored stream → `TypeError` -- `search_embedding()` on non-embedding stream → `TypeError` -- `search_text()` on non-text stream → `TypeError` -- `search_embedding()` when sqlite-vec not loaded → `RuntimeError` -- Invalid stream name → `ValueError` -- `one()` with no results → `LookupError` -- `stream()` without `payload_type` on new stream → `TypeError` - -## Testing - -Tests in `dimos/memory/impl/test_sqlite.py`. Use `:memory:` store for speed. - -Key test scenarios: -1. 
Create stream, append, fetch — verify data round-trips -2. Temporal filters (after, before, time_range, at) -3. Spatial filter (near) — with and without pose -4. Tag filtering -5. EmbeddingStream — store embeddings, search_embedding, verify auto-projection -6. TextStream — store text, search_text -7. Transform with lambda — verify lineage -8. Transform with Transformer class — verify process() called -9. Chained filters — verify SQL composition -10. project_to — verify cross-stream lineage (single and multi-hop) -11. fetch_pages — verify pagination -12. Lazy data loading — verify .data only hits DB on access -13. .appended observable — verify reactive emission -14. Similarity scores — verify EmbeddingObservation.similarity populated after search -15. raw=True — verify EmbeddingObservation with similarity + auto-projected data -16. ObservationSet — verify list-like + stream-like behavior diff --git a/dimos/memory/docs/transform.md b/dimos/memory/docs/transform.md deleted file mode 100644 index 409fd8fc6b..0000000000 --- a/dimos/memory/docs/transform.md +++ /dev/null @@ -1,180 +0,0 @@ -# Transform — Unified Derived Stream API - -## Concept - -`.transform()` is a single method on `StreamBase` that handles both historical (batch) and live (reactive) processing. It takes data from a source, applies a function, and stores results into the target stream with lineage. - -## API - -```python -class StreamBase(ABC, Generic[T]): - def transform(self, - source: StreamBase | ObservationSet, - fn: Callable[[Any], T | list[T] | None] | None = None, - *, - live: bool = False, - ) -> Self: - """ - Process source data, store results in this stream. - - Args: - source: where to read from - fn: transform function. Returns T, list[T], or None (skip). - None allowed for EmbeddingStream (uses model.embed implicitly). 
- live: if True, only subscribe to new appends (no backfill) - - Behavior by source type: - StreamBase → backfill existing + subscribe to live (default) - live=True → skip backfill, only subscribe - ObservationSet → batch process snapshot (live ignored) - - Returns self for chaining. - """ -``` - -## Source type determines mode - -| Source | `live=False` (default) | `live=True` | -|------------------|--------------------------------------------------|-------------------------------| -| `StreamBase` | backfill all existing + subscribe to `.appended` | subscribe to `.appended` only | -| `ObservationSet` | batch process the set | N/A (ignored) | - -## Transform function contract - -```python -fn: Callable[[Any], T | list[T] | None] -``` - -- Returns `T` → single result stored -- Returns `list[T]` → multiple results stored (e.g., multiple detections per frame) -- Returns `None` or `[]` → nothing stored for this input (e.g., no detections) -- `parent_id` set automatically from source row - -## Examples - -### VLM detections on images - -```python -images = session.stream("images", Image, - pose_provider=lambda: tf.get_pose("world", "base_link")) - -detections = session.stream("cigarette_detections", VLMDetection) - -# Backfill + live -detections.transform(images, fn=lambda img: vlm.detect(img, "people with cigarettes")) - -# After this, every new image.append() triggers detection automatically -# All results are queryable -rows = detections.query().filter_after(one_hour_ago).fetch() -``` - -### Live-only (skip backfill) - -```python -detections.transform(images, fn=detect_fn, live=True) -# Only processes images appended from now on -``` - -### Historical batch on query results - -```python -# Only process images from the kitchen in the last hour -kitchen_images = images.query().filter_near(kitchen_pose, 5.0).filter_after(one_hour_ago).fetch_set() - -detections.transform(kitchen_images, fn=lambda img: vlm.detect(img, "cigarettes")) -# Batch processes the set, no live 
subscription -``` - -### Embedding stream (specialized) - -```python -img_emb = session.embedding_stream("img_emb", model=CLIPModel()) - -# fn is implicit — uses model.embed() -img_emb.transform(images, live=True) - -# Equivalent to: -img_emb.transform(images, fn=lambda img: clip.embed(img), live=True) -``` - -### Chaining transforms - -```python -images = session.stream("images", Image, pose_provider=pose_fn) - -# Embeddings from images -img_emb = session.embedding_stream("img_emb", model=CLIPModel()) -img_emb.transform(images, live=True) - -# Detections from images -detections = session.stream("detections", VLMDetection) -detections.transform(images, fn=detect_fn, live=True) - -# Text descriptions from detections (second-level derived) -descriptions = session.text_stream("descriptions", str) -descriptions.transform(detections, fn=lambda det: det.describe(), live=True) -``` - -## Internals - -### Backfill (batch) - -```python -for page in source.iter_meta(page_size=128): - for row in page: - payload = source.load(row) # or row.data - results = fn(payload) - if results is None: - continue - if not isinstance(results, list): - results = [results] - for r in results: - self.append(r, ts=row.ts, pose=row.pose, parent_id=row.id) -``` - -### Live (reactive) - -```python -source.appended.pipe( - ops.map(lambda row: (row, fn(row.data))), - ops.filter(lambda pair: pair[1] is not None), - ops.flat_map(lambda pair: [ - (pair[0], r) for r in (pair[1] if isinstance(pair[1], list) else [pair[1]]) - ]), -).subscribe(lambda pair: self.append(pair[1], ts=pair[0].ts, pose=pair[0].pose, - parent_id=pair[0].id)) -``` - -### EmbeddingStream override - -```python -class EmbeddingStream(StreamBase[T]): - model: EmbeddingModel - - def transform(self, source, fn=None, *, live=False): - if fn is None: - fn = self.model.embed - return super().transform(source, fn, live=live) -``` - -## Lineage - -`transform()` sets `parent_id` on every appended row, linking back to the source row. 
This enables `project_to()`: - -```python -# Find source images for cigarette detections -with detections.query().fetch_set() as det_set: - source_images = det_set.project_to(images) - for row in source_images.rows(limit=5): - img = images.load(row) -``` - -## Open questions - -1. **Async transforms?** VLM inference is slow. Should `fn` support async/await or rx scheduling (e.g., `observe_on(io_scheduler)`)? - -2. **Error handling?** If `fn` raises on one row, skip it? Log and continue? Configurable? - -3. **Backfill progress?** For large backfills, should `transform()` return a progress observable or run in background? - -4. **Multiple parents?** Current design is single-parent lineage. If a stream derives from two streams (e.g., fusing image + audio), we'd need multi-parent support. Phase 3. diff --git a/dimos/memory2/embed.py b/dimos/memory/embed.py similarity index 93% rename from dimos/memory2/embed.py rename to dimos/memory/embed.py index e3b34bb0ae..04e68dd540 100644 --- a/dimos/memory2/embed.py +++ b/dimos/memory/embed.py @@ -17,12 +17,12 @@ from itertools import islice from typing import TYPE_CHECKING, Any, TypeVar -from dimos.memory2.transform import Transformer +from dimos.memory.transform import Transformer if TYPE_CHECKING: from collections.abc import Iterator - from dimos.memory2.type import EmbeddedObservation, Observation + from dimos.memory.type import Observation from dimos.models.embedding.base import EmbeddingModel T = TypeVar("T") @@ -48,7 +48,7 @@ def __init__(self, model: EmbeddingModel, batch_size: int = 32) -> None: self.model = model self.batch_size = batch_size - def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[EmbeddedObservation[Any]]: + def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[Observation[Any]]: for batch in _batched(upstream, self.batch_size): images = [obs.data for obs in batch] embeddings = self.model.embed(*images) @@ -69,7 +69,7 @@ def __init__(self, model: EmbeddingModel, 
batch_size: int = 32) -> None: self.model = model self.batch_size = batch_size - def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[EmbeddedObservation[Any]]: + def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[Observation[Any]]: for batch in _batched(upstream, self.batch_size): texts = [str(obs.data) for obs in batch] embeddings = self.model.embed_text(*texts) diff --git a/dimos/memory2/embeddings.md b/dimos/memory/embeddings.md similarity index 99% rename from dimos/memory2/embeddings.md rename to dimos/memory/embeddings.md index de27cd18c9..a48dbb6439 100644 --- a/dimos/memory2/embeddings.md +++ b/dimos/memory/embeddings.md @@ -1,4 +1,4 @@ -# memory2 Embedding Design +# memory Embedding Design ## Core Principle: Enrichment, Not Replacement diff --git a/dimos/memory2/filter.py b/dimos/memory/filter.py similarity index 96% rename from dimos/memory2/filter.py rename to dimos/memory/filter.py index 32330192ba..2d9f98c1d4 100644 --- a/dimos/memory2/filter.py +++ b/dimos/memory/filter.py @@ -22,8 +22,8 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator - from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.type import Observation + from dimos.memory.buffer import BackpressureBuffer + from dimos.memory.type import Observation from dimos.models.embedding.base import Embedding @@ -46,7 +46,7 @@ def __str__(self) -> str: @dataclass(frozen=True) -class AfterFilter: +class AfterFilter(Filter): t: float def matches(self, obs: Observation[Any]) -> bool: @@ -54,7 +54,7 @@ def matches(self, obs: Observation[Any]) -> bool: @dataclass(frozen=True) -class BeforeFilter: +class BeforeFilter(Filter): t: float def matches(self, obs: Observation[Any]) -> bool: @@ -62,7 +62,7 @@ def matches(self, obs: Observation[Any]) -> bool: @dataclass(frozen=True) -class TimeRangeFilter: +class TimeRangeFilter(Filter): t1: float t2: float @@ -71,7 +71,7 @@ def matches(self, obs: Observation[Any]) -> bool: 
@dataclass(frozen=True) -class AtFilter: +class AtFilter(Filter): t: float tolerance: float = 1.0 @@ -80,7 +80,7 @@ def matches(self, obs: Observation[Any]) -> bool: @dataclass(frozen=True) -class NearFilter: +class NearFilter(Filter): pose: Any radius: float @@ -108,7 +108,7 @@ def _xyz(p: Any) -> tuple[float, float, float]: @dataclass(frozen=True) -class TagsFilter: +class TagsFilter(Filter): tags: dict[str, Any] def matches(self, obs: Observation[Any]) -> bool: @@ -119,7 +119,7 @@ def matches(self, obs: Observation[Any]) -> bool: @dataclass(frozen=True) -class PredicateFilter: +class PredicateFilter(Filter): """Wraps an arbitrary predicate function for use with .filter().""" fn: Callable[[Observation[Any]], bool] diff --git a/dimos/memory/formatting.py b/dimos/memory/formatting.py index 81145a07ba..ee13fb3f36 100644 --- a/dimos/memory/formatting.py +++ b/dimos/memory/formatting.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Rich text rendering for memory types and streams.""" +"""Rich rendering helpers for memory types. -from __future__ import annotations +All rich/ANSI logic lives here. Other modules import the mixin and +``render_text`` — nothing else needs to touch ``rich`` directly. 
+""" -from typing import TYPE_CHECKING, Any +from __future__ import annotations from rich.console import Console from rich.text import Text -if TYPE_CHECKING: - from collections.abc import Callable - _console = Console(force_terminal=True, highlight=False) @@ -34,220 +33,26 @@ def render_text(text: Text) -> str: return cap.get() -# ── Filter rendering ──────────────────────────────────────────────── - - -def _after_rich(f: Any) -> Text: - t = Text() - t.append("after", style="cyan") - t.append(f"(t={f.t})") - return t - - -def _before_rich(f: Any) -> Text: - t = Text() - t.append("before", style="cyan") - t.append(f"(t={f.t})") - return t - - -def _time_range_rich(f: Any) -> Text: - t = Text() - t.append("time_range", style="cyan") - t.append(f"({f.t1}, {f.t2})") - return t - - -def _at_rich(f: Any) -> Text: - t = Text() - t.append("at", style="cyan") - t.append(f"(t={f.t}, tol={f.tolerance})") - return t - - -def _near_rich(f: Any) -> Text: - t = Text() - t.append("near", style="cyan") - t.append("(") - if f.pose is not None and hasattr(f.pose, "position"): - p = f.pose.position - t.append(f"[{p.x:.1f}, {p.y:.1f}, {p.z:.1f}]", style="green") - t.append(f", radius={f.radius:.2f}") - else: - t.append(f"radius={f.radius}") - t.append(")") - return t - - -def _tags_rich(f: Any) -> Text: - t = Text() - t.append("tags", style="cyan") - pairs = ", ".join(f"{k}={v!r}" for k, v in f.tags) - t.append(f"({pairs})") - return t - - -def _embedding_search_rich(f: Any) -> Text: - t = Text() - t.append("search_embedding", style="cyan") - t.append("(") - if f.label: - t.append(repr(f.label), style="green") - t.append(", ") - t.append(f"k={f.k}") - t.append(")") - return t - - -def _text_search_rich(f: Any) -> Text: - t = Text() - t.append("text", style="cyan") - t.append(f"({f.text!r})") - return t - - -def _lineage_rich(f: Any) -> Text: - t = Text() - t.append("lineage", style="cyan") - hops = " -> ".join(f.hops) if f.hops else "direct" - t.append(f"({f.source_table} -> {hops})") 
- return t - - -_FILTER_DISPATCH: dict[type, Callable[..., Text]] | None = None - - -def _get_dispatch() -> dict[type, Callable[..., Text]]: - global _FILTER_DISPATCH - if _FILTER_DISPATCH is not None: - return _FILTER_DISPATCH - from dimos.memory.type import ( - AfterFilter, - AtFilter, - BeforeFilter, - EmbeddingSearchFilter, - LineageFilter, - NearFilter, - TagsFilter, - TextSearchFilter, - TimeRangeFilter, - ) - - _FILTER_DISPATCH = { - AfterFilter: _after_rich, - BeforeFilter: _before_rich, - TimeRangeFilter: _time_range_rich, - AtFilter: _at_rich, - NearFilter: _near_rich, - TagsFilter: _tags_rich, - EmbeddingSearchFilter: _embedding_search_rich, - TextSearchFilter: _text_search_rich, - LineageFilter: _lineage_rich, - } - return _FILTER_DISPATCH - - -def filter_rich(f: Any) -> Text: - """Render a Filter to rich Text.""" - dispatch = _get_dispatch() - renderer = dispatch.get(type(f)) - if renderer is None: - return Text(str(f)) - return renderer(f) - - -def query_rich(q: Any) -> Text: - """Render a StreamQuery to rich Text.""" +def _colorize(plain: str) -> Text: + """Turn ``'name(args)'``, ``'a | b'``, or ``'a -> b'`` into rich Text with cyan names.""" t = Text() pipe = Text(" | ", style="dim") - parts: list[Text] = [filter_rich(f) for f in q.filters] - if q.order_field: - p = Text() - p.append("order", style="cyan") - direction = "desc" if q.order_desc else "asc" - p.append(f"({q.order_field}, {direction})") - parts.append(p) - if q.limit_val is not None: - p = Text() - p.append("limit", style="cyan") - p.append(f"({q.limit_val})") - parts.append(p) - if q.offset_val is not None: - p = Text() - p.append("offset", style="cyan") - p.append(f"({q.offset_val})") - parts.append(p) - for i, part in enumerate(parts): + arrow = Text(" -> ", style="dim") + for i, seg in enumerate(plain.split(" | ")): if i > 0: t.append_text(pipe) - t.append_text(part) + for j, part in enumerate(seg.split(" -> ")): + if j > 0: + t.append_text(arrow) + name, _, rest = 
part.partition("(") + t.append(name, style="cyan") + if rest: + t.append(f"({rest}") return t -# ── Stream rendering ──────────────────────────────────────────────── - +class FilterRepr: + """Mixin for filters: subclass defines ``__str__``, gets colored ``__repr__`` free.""" -def rich_text(obj: Any) -> Text: - """Render a Stream, TransformStream, ObservationSet, or StreamQuery to rich Text. - - Uses duck-typing on attributes — no dispatch table needed. - """ - # TransformStream: has _source and _transformer - if hasattr(obj, "_transformer"): - xf = obj._transformer - t = Text() - t.append("TransformStream", style="bold cyan") - t.append("[", style="dim") - t.append(xf.output_type.__name__ if xf.output_type else "?", style="yellow") - t.append("]", style="dim") - t.append("(", style="dim") - t.append_text(rich_text(obj._source)) - t.append(" -> ", style="dim") - t.append(repr(xf), style="magenta") - if obj._live: - t.append(", ", style="dim") - t.append("live=True", style="yellow") - if obj._backfill_only: - t.append(", ", style="dim") - t.append("backfill_only=True", style="yellow") - t.append(")", style="dim") - qt = query_rich(obj._query) - if qt.plain: - t.append(" | ", style="dim") - t.append_text(qt) - return t - - # ObservationSet: has _observations list - if hasattr(obj, "_observations"): - type_name = obj._payload_type.__name__ if obj._payload_type else "?" - t = Text() - t.append("ObservationSet", style="bold cyan") - t.append("[", style="dim") - t.append(type_name, style="yellow") - t.append("]", style="dim") - t.append("(", style="dim") - t.append(f"{len(obj._observations)} items", style="green") - t.append(")", style="dim") - return t - - # StreamQuery - if hasattr(obj, "filters"): - return query_rich(obj) - - # Stream (and subclasses like EmbeddingStream, TextStream) - cls_name = type(obj).__name__ - type_name = obj._payload_type.__name__ if obj._payload_type else "?" 
- name = obj._backend.stream_name if obj._backend else "unbound" - t = Text() - t.append(cls_name, style="bold cyan") - t.append("[", style="dim") - t.append(type_name, style="yellow") - t.append("]", style="dim") - t.append("(", style="dim") - t.append(f'"{name}"', style="green") - t.append(")", style="dim") - qt = query_rich(obj._query) - if qt.plain: - t.append(" | ", style="dim") - t.append_text(qt) - return t + def __repr__(self) -> str: + return render_text(_colorize(str(self))) diff --git a/dimos/memory2/impl/README.md b/dimos/memory/impl/README.md similarity index 91% rename from dimos/memory2/impl/README.md rename to dimos/memory/impl/README.md index bc95405ee2..f2475c7eec 100644 --- a/dimos/memory2/impl/README.md +++ b/dimos/memory/impl/README.md @@ -1,6 +1,6 @@ # impl — Backend implementations -Storage backends for memory2. Each backend implements the `Backend` protocol to provide observation storage with query support. All backends support live mode via a pluggable `LiveChannel`. +Storage backends for memory. Each backend implements the `Backend` protocol to provide observation storage with query support. All backends support live mode via a pluggable `LiveChannel`. ## Existing backends @@ -14,10 +14,10 @@ Storage backends for memory2. Each backend implements the `Backend` protocol to ### 1. 
Implement the Backend protocol ```python -from dimos.memory2.backend import Backend, BackendConfig, LiveChannel -from dimos.memory2.filter import StreamQuery -from dimos.memory2.livechannel.subject import SubjectChannel -from dimos.memory2.type import Observation +from dimos.memory.backend import Backend, BackendConfig, LiveChannel +from dimos.memory.filter import StreamQuery +from dimos.memory.livechannel.subject import SubjectChannel +from dimos.memory.type import Observation from dimos.protocol.service.spec import Configurable class MyBackend(Configurable[BackendConfig], Generic[T]): @@ -78,7 +78,7 @@ See `ListBackend._iterate_live()` for the reference implementation. ### 3. Add Store and Session ```python -from dimos.memory2.store import Session, Store +from dimos.memory.store import Session, Store class MySession(Session): def _create_backend( diff --git a/dimos/memory/impl/__init__.py b/dimos/memory/impl/__init__.py index e69de29bb2..1ed1bd093e 100644 --- a/dimos/memory/impl/__init__.py +++ b/dimos/memory/impl/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/dimos/memory2/impl/memory.py b/dimos/memory/impl/memory.py similarity index 92% rename from dimos/memory2/impl/memory.py rename to dimos/memory/impl/memory.py index f53a3d2af2..0956fa83e3 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory/impl/memory.py @@ -18,11 +18,11 @@ import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.backend import BackendConfig -from dimos.memory2.codecs.base import Codec, codec_for -from dimos.memory2.livechannel.subject import SubjectChannel -from dimos.memory2.store import Session, Store -from dimos.memory2.type import _UNLOADED +from dimos.memory.backend import BackendConfig +from dimos.memory.codecs.base import Codec, codec_for +from dimos.memory.livechannel.subject import SubjectChannel +from dimos.memory.store import Session, Store +from dimos.memory.type import _UNLOADED from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: @@ -30,10 +30,10 @@ from reactivex.abc import DisposableBase - from dimos.memory2.backend import Backend, LiveChannel - from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.filter import StreamQuery - from dimos.memory2.type import Observation + from dimos.memory.backend import Backend, LiveChannel + from dimos.memory.buffer import BackpressureBuffer + from dimos.memory.filter import StreamQuery + from dimos.memory.type import Observation T = TypeVar("T") @@ -132,6 +132,7 @@ def _vector_search( """Use pluggable VectorStore for ANN search, then apply remaining query ops.""" vs = self.config.vector_store assert vs is not None # caller checks + assert query.search_vec is not None # caller checks hits = vs.search(self._name, query.search_vec, query.search_k or len(snapshot)) @@ -153,7 +154,7 @@ def _iterate_live( buf: BackpressureBuffer[Observation[T]], sub: DisposableBase, ) -> Iterator[Observation[T]]: - from dimos.memory2.buffer import ClosedError + from dimos.memory.buffer import ClosedError eager = 
self.config.eager_blobs and self.config.blob_store is not None diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index 9ee8219627..e511608387 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -12,90 +12,71 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""SQLite-backed memory store implementation. - -Schema per stream ``{name}``: - - {name} — id, ts, pose columns (x/y/z + quaternion), tags, parent_id - {name}_payload — id, data BLOB (loaded lazily) - {name}_rtree — R*Tree spatial index on position - {name}_fts — FTS5 full-text index (TextStream only) - {name}_vec — vec0 vector index (EmbeddingStream only) - -Payloads use Codec (LCM for DimosMsg types, pickle otherwise). -Poses are decomposed into columns. Tags are JSON. -""" - from __future__ import annotations +from dataclasses import dataclass, replace +from itertools import islice import json import re import sqlite3 import threading -import time -from typing import TYPE_CHECKING, Any - -from reactivex.subject import Subject +from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory.codec import ( - Codec, - PickleCodec, - codec_for_type, - module_path_to_type, - type_to_module_path, -) -from dimos.memory.store import Session, Store -from dimos.memory.stream import EmbeddingStream, Stream, TextStream -from dimos.memory.transformer import ( - CaptionTransformer, - EmbeddingTransformer, - TextEmbeddingTransformer, - Transformer, -) -from dimos.memory.type import ( +from dimos.memory.backend import BackendConfig +from dimos.memory.blobstore.sqlite import SqliteBlobStore +from dimos.memory.codecs.base import Codec, codec_for +from dimos.memory.filter import ( AfterFilter, AtFilter, BeforeFilter, - EmbeddingObservation, - EmbeddingSearchFilter, - Filter, - LineageFilter, NearFilter, - Observation, - StreamQuery, TagsFilter, - TextSearchFilter, TimeRangeFilter, + _xyz, ) +from 
dimos.memory.livechannel.subject import SubjectChannel +from dimos.memory.store import Session, Store, StoreConfig +from dimos.memory.type import _UNLOADED, Observation +from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: - from collections.abc import Callable - import os + from collections.abc import Iterator + + from reactivex.abc import DisposableBase + + from dimos.memory.backend import Backend, LiveChannel + from dimos.memory.buffer import BackpressureBuffer + from dimos.memory.filter import Filter, StreamQuery - from dimos.memory.type import PoseProvider - from dimos.models.embedding.base import EmbeddingModel +T = TypeVar("T") -_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") -_ALLOWED_ORDER_FIELDS = frozenset({"id", "ts"}) +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") -def _validate_identifier(name: str) -> str: - """Validate that *name* is a safe SQL identifier (alphanumeric + underscore).""" - if not _IDENTIFIER_RE.match(name): - raise ValueError(f"Invalid identifier: {name!r}") - return name +# ── Helpers ────────────────────────────────────────────────────── -# ── Pose helpers (column-based) ────────────────────────────────────── +def _validate_identifier(name: str) -> None: + if not _IDENT_RE.match(name): + raise ValueError(f"Invalid stream name: {name!r}") -def _decompose_pose(pose: Any) -> tuple[float, float, float, float, float, float, float] | None: - """Extract (x, y, z, qx, qy, qz, qw) from a PoseStamped.""" +def _decompose_pose(pose: Any) -> tuple[float, ...] 
| None: if pose is None: return None - p = pose.position - q = pose.orientation - return (p.x, p.y, p.z, q.x, q.y, q.z, q.w) + if hasattr(pose, "position"): + pos = pose.position + orient = getattr(pose, "orientation", None) + x, y, z = float(pos.x), float(pos.y), float(getattr(pos, "z", 0.0)) + if orient is not None: + return (x, y, z, float(orient.x), float(orient.y), float(orient.z), float(orient.w)) + return (x, y, z, 0.0, 0.0, 0.0, 1.0) + if isinstance(pose, (list, tuple)): + vals = [float(v) for v in pose] + while len(vals) < 7: + vals.append(0.0 if len(vals) < 6 else 1.0) + return tuple(vals[:7]) + return None def _reconstruct_pose( @@ -106,927 +87,588 @@ def _reconstruct_pose( qy: float | None, qz: float | None, qw: float | None, -) -> Any | None: - """Rebuild a PoseStamped from column values.""" +) -> tuple[float, ...] | None: if x is None: return None - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - - return PoseStamped( - position=[x, y if y is not None else 0.0, z if z is not None else 0.0], - orientation=[ - qx if qx is not None else 0.0, - qy if qy is not None else 0.0, - qz if qz is not None else 0.0, - qw if qw is not None else 1.0, - ], - ) - + return (x, y or 0.0, z or 0.0, qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0) -def _serialize_tags(tags: dict[str, Any] | None) -> str: - if not tags: - return "{}" - return json.dumps(tags, separators=(",", ":")) +def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list[Any]] | None: + """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters. -def _deserialize_tags(text: str) -> dict[str, Any]: - if not text: - return {} - return json.loads(text) # type: ignore[no-any-return] - - -# ── SQL building ────────────────────────────────────────────────────── - -# Columns selected from the meta table (no payload). 
-_META_COLS = "id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags, parent_id" - - -def _compile_filter(f: Filter, table: str) -> tuple[str, list[Any]]: - """Compile a single filter to (SQL fragment, params).""" + ``stream`` is the raw stream name (for R*Tree table references). + ``prefix`` is a column qualifier (e.g. ``"meta."`` for JOIN queries). + """ if isinstance(f, AfterFilter): - return f"{table}.ts > ?", [f.t] + return (f"{prefix}ts > ?", [f.t]) if isinstance(f, BeforeFilter): - return f"{table}.ts < ?", [f.t] + return (f"{prefix}ts < ?", [f.t]) if isinstance(f, TimeRangeFilter): - return f"{table}.ts >= ? AND {table}.ts <= ?", [f.t1, f.t2] + return (f"{prefix}ts >= ? AND {prefix}ts <= ?", [f.t1, f.t2]) if isinstance(f, AtFilter): - return f"ABS({table}.ts - ?) <= ?", [f.t, f.tolerance] + return (f"ABS({prefix}ts - ?) <= ?", [f.t, f.tolerance]) if isinstance(f, TagsFilter): - clauses: list[str] = [] + clauses = [] params: list[Any] = [] - for key, val in f.tags: - _validate_identifier(key) - clauses.append(f"json_extract({table}.tags, '$.{key}') = ?") - params.append(val) - return " AND ".join(clauses), params + for k, v in f.tags.items(): + clauses.append(f"json_extract({prefix}tags, '$.{k}') = ?") + params.append(v) + return (" AND ".join(clauses), params) if isinstance(f, NearFilter): - # Handled via R*Tree JOIN — see _compile_query - return "1=1", [] - if isinstance(f, EmbeddingSearchFilter): - return "1=1", [] - if isinstance(f, TextSearchFilter): - return "1=1", [] - if isinstance(f, LineageFilter): - inner_sql, params = _compile_ids(f.source_query, f.source_table, select_col="parent_id") - for hop in f.hops: - inner_sql = f"SELECT parent_id FROM {hop} WHERE id IN ({inner_sql})" - return f"{table}.id IN ({inner_sql})", params - raise TypeError(f"Unknown filter type: {type(f)}") - - -def _compile_ids( - query: StreamQuery, table: str, *, select_col: str = "id" -) -> tuple[str, list[Any]]: - """Compile a StreamQuery to ``SELECT 
{col} FROM {table} WHERE ...``. - - Unlike ``_compile_query``, this handles *all* filter types as SQL — including - EmbeddingSearchFilter and TextSearchFilter as inline subqueries — so that the - result can be nested inside another query (used by LineageFilter). - """ - where_parts: list[str] = [] - params: list[Any] = [] - joins: list[str] = [] - - for f in query.filters: - if isinstance(f, EmbeddingSearchFilter): - where_parts.append( - f"{table}.id IN (SELECT rowid FROM {table}_vec WHERE embedding MATCH ? AND k = ?)" - ) - params.extend([json.dumps(f.query), f.k]) - elif isinstance(f, TextSearchFilter): - fts_sub = f"SELECT rowid FROM {table}_fts WHERE content MATCH ?" - fts_params: list[Any] = [f.text] - if f.k is not None: - fts_sub += " LIMIT ?" - fts_params.append(f.k) - where_parts.append(f"{table}.id IN ({fts_sub})") - params.extend(fts_params) - elif isinstance(f, NearFilter): - joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") - p = f.pose.position - x, y, z = p.x, p.y, p.z - where_parts.append( - "r.min_x >= ? AND r.max_x <= ? AND " - "r.min_y >= ? AND r.max_y <= ? AND " - "r.min_z >= ? AND r.max_z <= ?" 
- ) - params.extend( - [x - f.radius, x + f.radius, y - f.radius, y + f.radius, z - f.radius, z + f.radius] - ) - else: - # Simple filters + LineageFilter → delegate to _compile_filter - sql_frag, p = _compile_filter(f, table) - where_parts.append(sql_frag) - params.extend(p) - - where = " AND ".join(where_parts) if where_parts else "1=1" - join_clause = " ".join(joins) - - sql = f"SELECT {table}.{select_col} FROM {table}" - if join_clause: - sql += f" {join_clause}" - sql += f" WHERE {where}" - - if query.order_field: - if query.order_field not in _ALLOWED_ORDER_FIELDS: - raise ValueError(f"Invalid order field: {query.order_field!r}") - sql += f" ORDER BY {query.order_field}" - if query.order_desc: - sql += " DESC" - if query.limit_val is not None: - sql += f" LIMIT {query.limit_val}" - if query.offset_val is not None: - sql += f" OFFSET {query.offset_val}" - - return sql, params + pose = f.pose + if pose is None: + return None + if hasattr(pose, "position"): + pose = pose.position + cx, cy, cz = _xyz(pose) + r = f.radius + # R*Tree bounding-box pre-filter + exact squared-distance check + rtree_sql = ( + f'{prefix}id IN (SELECT id FROM "{stream}_rtree" ' + f"WHERE x_min >= ? AND x_max <= ? " + f"AND y_min >= ? AND y_max <= ? " + f"AND z_min >= ? AND z_max <= ?)" + ) + dist_sql = ( + f"(({prefix}pose_x - ?) * ({prefix}pose_x - ?) + " + f"({prefix}pose_y - ?) * ({prefix}pose_y - ?) + " + f"({prefix}pose_z - ?) * ({prefix}pose_z - ?) 
<= ?)" + ) + return ( + f"{rtree_sql} AND {dist_sql}", + [ + cx - r, + cx + r, + cy - r, + cy + r, + cz - r, + cz + r, # R*Tree bbox + cx, + cx, + cy, + cy, + cz, + cz, + r * r, # squared distance + ], + ) + # PredicateFilter — not pushable + return None -def _has_near_filter(query: StreamQuery) -> NearFilter | None: - for f in query.filters: - if isinstance(f, NearFilter): - return f - return None +def _compile_query( + query: StreamQuery, + table: str, + *, + join_blob: bool = False, +) -> tuple[str, list[Any], list[Filter]]: + """Compile a StreamQuery to SQL. + Returns (sql, params, python_filters) where python_filters must be + applied as post-filters in Python. + """ + prefix = "meta." if join_blob else "" + if join_blob: + select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' + else: + select = f'SELECT id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"' -def _compile_query(query: StreamQuery, table: str) -> tuple[str, list[Any]]: - """Compile a StreamQuery to (SQL, params) for a metadata SELECT.""" where_parts: list[str] = [] params: list[Any] = [] - joins: list[str] = [] + python_filters: list[Filter] = [] for f in query.filters: - if isinstance(f, NearFilter): - # R*Tree bounding-box pre-filter + exact Euclidean distance in SQL - joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") - p = f.pose.position - x, y, z = p.x, p.y, p.z - where_parts.append( - "r.min_x >= ? AND r.max_x <= ? AND " - "r.min_y >= ? AND r.max_y <= ? AND " - "r.min_z >= ? AND r.max_z <= ?" - ) - params.extend( - [x - f.radius, x + f.radius, y - f.radius, y + f.radius, z - f.radius, z + f.radius] - ) - # Exact spherical check so LIMIT/OFFSET work correctly - where_parts.append( - f"({table}.pose_x - ?) * ({table}.pose_x - ?) + " - f"({table}.pose_y - ?) 
* ({table}.pose_y - ?) + " - f"({table}.pose_z - ?) * ({table}.pose_z - ?) <= ? * ?" - ) - params.extend([x, x, y, y, z, z, f.radius, f.radius]) + compiled = _compile_filter(f, table, prefix) + if compiled is not None: + sql_part, sql_params = compiled + where_parts.append(sql_part) + params.extend(sql_params) else: - sql_frag, p = _compile_filter(f, table) - where_parts.append(sql_frag) - params.extend(p) + python_filters.append(f) - where = " AND ".join(where_parts) if where_parts else "1=1" - join_clause = " ".join(joins) + sql = select + if where_parts: + sql += " WHERE " + " AND ".join(where_parts) + # ORDER BY if query.order_field: - if query.order_field not in _ALLOWED_ORDER_FIELDS: - raise ValueError(f"Invalid order field: {query.order_field!r}") - order = f"ORDER BY {table}.{query.order_field}" - if query.order_desc: - order += " DESC" + direction = "DESC" if query.order_desc else "ASC" + sql += f" ORDER BY {prefix}{query.order_field} {direction}" else: - order = f"ORDER BY {table}.id" + sql += f" ORDER BY {prefix}id ASC" + + # Only push LIMIT/OFFSET to SQL when there are no Python post-filters + if not python_filters and not query.search_text: + if query.limit_val is not None: + if query.offset_val: + sql += f" LIMIT {query.limit_val} OFFSET {query.offset_val}" + else: + sql += f" LIMIT {query.limit_val}" + elif query.offset_val: + sql += f" LIMIT -1 OFFSET {query.offset_val}" - sql = f"SELECT {table}.{_META_COLS.replace(', ', f', {table}.')} FROM {table}" - if join_clause: - sql += f" {join_clause}" - sql += f" WHERE {where} {order}" - if query.limit_val is not None: - sql += f" LIMIT {query.limit_val}" - if query.offset_val is not None: - sql += f" OFFSET {query.offset_val}" - return sql, params + return (sql, params, python_filters) -def _compile_count(query: StreamQuery, table: str) -> tuple[str, list[Any]]: +def _compile_count( + query: StreamQuery, + table: str, +) -> tuple[str, list[Any], list[Filter]]: + """Compile a StreamQuery to a COUNT SQL 
query.""" where_parts: list[str] = [] params: list[Any] = [] - joins: list[str] = [] + python_filters: list[Filter] = [] for f in query.filters: - if isinstance(f, NearFilter): - joins.append(f"JOIN {table}_rtree AS r ON r.id = {table}.id") - p = f.pose.position - x, y, z = p.x, p.y, p.z - where_parts.append( - "r.min_x >= ? AND r.max_x <= ? AND " - "r.min_y >= ? AND r.max_y <= ? AND " - "r.min_z >= ? AND r.max_z <= ?" - ) - params.extend( - [ - x - f.radius, - x + f.radius, - y - f.radius, - y + f.radius, - z - f.radius, - z + f.radius, - ] - ) + compiled = _compile_filter(f, table) + if compiled is not None: + sql_part, sql_params = compiled + where_parts.append(sql_part) + params.extend(sql_params) else: - sql_frag, p = _compile_filter(f, table) - where_parts.append(sql_frag) - params.extend(p) - - where = " AND ".join(where_parts) if where_parts else "1=1" - join_clause = " ".join(joins) - sql = f"SELECT COUNT(*) FROM {table}" - if join_clause: - sql += f" {join_clause}" - sql += f" WHERE {where}" - return sql, params + python_filters.append(f) + sql = f'SELECT COUNT(*) FROM "{table}"' + if where_parts: + sql += " WHERE " + " AND ".join(where_parts) -# ── Near-filter post-processing (exact distance after R*Tree bbox) ─── + return (sql, params, python_filters) -def _apply_near_post_filter( - rows: list[Observation[Any]], near: NearFilter -) -> list[Observation[Any]]: - """Post-filter R*Tree candidates by exact Euclidean distance.""" - tp = near.pose.position - result: list[Observation[Any]] = [] - for obs in rows: - if obs.pose is None: - continue - op = obs.pose.position - dist = ((op.x - tp.x) ** 2 + (op.y - tp.y) ** 2 + (op.z - tp.z) ** 2) ** 0.5 - if dist <= near.radius: - result.append(obs) - return result +# ── SqliteBackend ──────────────────────────────────────────────── -# ── Backend ─────────────────────────────────────────────────────────── +class SqliteBackend(Configurable[BackendConfig], Generic[T]): + """SQLite-backed observation storage for a 
single stream (table).""" + default_config: type[BackendConfig] = BackendConfig -class SqliteStreamBackend: - """StreamBackend implementation for a single SQLite-backed stream.""" - - def __init__( - self, - conn: sqlite3.Connection, - table: str, - *, - pose_provider: PoseProvider | None = None, - codec: Codec[Any] | None = None, - ) -> None: - _validate_identifier(table) + def __init__(self, conn: sqlite3.Connection, name: str, **kwargs: Any) -> None: + super().__init__(**kwargs) self._conn = conn - self._table = table - self._pose_provider = pose_provider - self._codec = codec or PickleCodec() - self._subject: Subject[Observation[Any]] = Subject() # type: ignore[type-arg] + self._name = name + self._codec: Codec[Any] = self.config.codec # type: ignore[assignment] + self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel() + self._lock = threading.Lock() + self._tag_indexes: set[str] = set() @property - def appended_subject(self) -> Subject[Observation[Any]]: # type: ignore[type-arg] - return self._subject + def name(self) -> str: + return self._name @property - def stream_name(self) -> str: - return self._table + def live_channel(self) -> LiveChannel[T]: + return self._channel - def _post_insert(self, row_id: int, payload: Any) -> None: - """Hook for subclasses to add extra inserts inside the transaction.""" + @property + def _join_blobs(self) -> bool: + if not self.config.eager_blobs: + return False + bs = self.config.blob_store + return isinstance(bs, SqliteBlobStore) and bs._conn is self._conn + + def _make_loader(self, row_id: int) -> Any: + bs = self.config.blob_store + assert bs is not None + name, codec = self._name, self._codec + owner_tid = threading.get_ident() - def do_append( - self, - payload: Any, - ts: float | None, - pose: Any | None, - tags: dict[str, Any] | None, - parent_id: int | None = None, - ) -> Observation[Any]: - if ts is None: - ts = time.time() - if pose is None and self._pose_provider is not None: - pose = 
self._pose_provider(ts) - - pose_cols = _decompose_pose(pose) - tags_json = _serialize_tags(tags) - - # Encode payload before touching the DB so a codec error can't leave - # a metadata row without a matching payload row. - payload_blob = self._codec.encode(payload) - - # 1. Insert into meta table - if pose_cols is not None: - cur = self._conn.execute( - f"INSERT INTO {self._table} " - "(ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags, parent_id) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - (ts, *pose_cols, tags_json, parent_id), - ) - else: - cur = self._conn.execute( - f"INSERT INTO {self._table} (ts, tags, parent_id) VALUES (?, ?, ?)", - (ts, tags_json, parent_id), - ) - row_id = cur.lastrowid - assert row_id is not None + def loader() -> Any: + assert threading.get_ident() == owner_tid + raw = bs.get(name, row_id) + return codec.decode(raw) - # 2. Insert into payload table - self._conn.execute( - f"INSERT INTO {self._table}_payload (id, data) VALUES (?, ?)", - (row_id, payload_blob), - ) + return loader - # 3. Insert into R*Tree (if pose) - if pose_cols is not None: - x, y, z = pose_cols[0], pose_cols[1], pose_cols[2] - self._conn.execute( - f"INSERT INTO {self._table}_rtree (id, min_x, max_x, min_y, max_y, min_z, max_z) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (row_id, x, x, y, y, z, z), - ) + def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observation[T]: + if has_blob: + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row + else: + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row + blob_data = None - # 4. Subclass hook (vec0, FTS, etc.) 
- self._post_insert(row_id, payload) + pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) + tags = json.loads(tags_json) if tags_json else {} - self._conn.commit() + if has_blob and blob_data is not None: + data = self._codec.decode(blob_data) + return Observation(id=row_id, ts=ts, pose=pose, tags=tags, _data=data) - obs = Observation( + return Observation( id=row_id, ts=ts, pose=pose, - tags=tags or {}, - parent_id=parent_id, - _data=payload, + tags=tags, + _data=_UNLOADED, + _loader=self._make_loader(row_id), # type: ignore[arg-type] ) - self._subject.on_next(obs) - return obs - - def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: - sql, params = _compile_query(query, self._table) - rows = self._conn.execute(sql, params).fetchall() - return [self._row_to_obs(r) for r in rows] - def execute_count(self, query: StreamQuery) -> int: - sql, params = _compile_count(query, self._table) - result = self._conn.execute(sql, params).fetchone() - return result[0] if result else 0 # type: ignore[no-any-return] + # ── Write ──────────────────────────────────────────────────── - def load_data(self, row_id: int) -> Any: - """Load payload by row ID from the database.""" - r = self._conn.execute( - f"SELECT data FROM {self._table}_payload WHERE id = ?", (row_id,) - ).fetchone() - if r is None: - raise LookupError(f"No payload for id={row_id}") - return self._codec.decode(r[0]) - - def _make_loader(self, row_id: int) -> Callable[[], Any]: - owner_tid = threading.get_ident() + def _ensure_tag_indexes(self, tags: dict[str, Any]) -> None: + """Auto-create expression indexes for any new tag keys.""" + for key in tags: + if key not in self._tag_indexes and _IDENT_RE.match(key): + self._conn.execute( + f'CREATE INDEX IF NOT EXISTS "{self._name}_tag_{key}" ' + f"ON \"{self._name}\"(json_extract(tags, '$.{key}'))" + ) + self._tag_indexes.add(key) + + def append(self, obs: Observation[T]) -> Observation[T]: + encoded = self._codec.encode(obs._data) + pose = 
_decompose_pose(obs.pose) + tags_json = json.dumps(obs.tags) if obs.tags else "{}" + + with self._lock: + if obs.tags: + self._ensure_tag_indexes(obs.tags) + if pose: + px, py, pz, qx, qy, qz, qw = pose + else: + px = py = pz = qx = qy = qz = qw = None # type: ignore[assignment] - def loader() -> Any: - if threading.get_ident() != owner_tid: - raise RuntimeError( - "Observation.data accessed from a different thread than the one that " - "fetched it. Access .data on the original thread first to cache it, " - "or use obs.load() before passing across threads." + cur = self._conn.execute( + f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", + (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json), + ) + row_id = cur.lastrowid + assert row_id is not None + + bs = self.config.blob_store + assert bs is not None + bs.put(self._name, row_id, encoded) + + # R*Tree spatial index + if pose: + self._conn.execute( + f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) ' + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (row_id, px, px, py, py, pz, pz), ) - return self.load_data(row_id) - return loader + vs = self.config.vector_store + if vs is not None: + emb = getattr(obs, "embedding", None) + if emb is not None: + vs.put(self._name, row_id, emb) - def _row_to_obs(self, row: Any) -> Observation[Any]: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row - return Observation( - id=row_id, - ts=ts, - pose=_reconstruct_pose(px, py, pz, qx, qy, qz, qw), - tags=_deserialize_tags(tags_json), - parent_id=pid, - _data_loader=self._make_loader(row_id), - ) + self._conn.commit() + obs.id = row_id + self._channel.notify(obs) + return obs -class SqliteEmbeddingBackend(SqliteStreamBackend): - """Backend for EmbeddingStream — stores vectors in a vec0 virtual table.""" + # ── Read ───────────────────────────────────────────────────── - def __init__( - self, - conn: 
sqlite3.Connection, - table: str, - *, - vec_dimensions: int | None = None, - pose_provider: PoseProvider | None = None, - parent_table: str | None = None, - codec: Codec[Any] | None = None, - ) -> None: - super().__init__(conn, table, pose_provider=pose_provider, codec=codec) - self._vec_dimensions = vec_dimensions - self._parent_table = parent_table - - def _post_insert(self, row_id: int, payload: Any) -> None: - from dimos.models.embedding.base import Embedding - - if isinstance(payload, Embedding): - vec = payload.to_numpy().tolist() - if self._vec_dimensions is None: - self._vec_dimensions = len(vec) - self._ensure_vec_table() - self._conn.execute( - f"INSERT INTO {self._table}_vec (rowid, embedding) VALUES (?, ?)", - (row_id, json.dumps(vec)), - ) + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + if query.search_vec is not None and query.live_buffer is not None: + raise TypeError("Cannot combine .search() with .live() — search is a batch operation.") + buf = query.live_buffer + if buf is not None: + sub = self._channel.subscribe(buf) + return self._iterate_live(query, buf, sub) + return self._iterate_snapshot(query) - def _ensure_vec_table(self) -> None: - if self._vec_dimensions is None: + def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: + if query.search_vec is not None and self.config.vector_store is not None: + yield from self._vector_search(query) return - self._conn.execute( - f"CREATE VIRTUAL TABLE IF NOT EXISTS {self._table}_vec " - f"USING vec0(embedding float[{self._vec_dimensions}] distance_metric=cosine)" - ) - self._conn.commit() - - def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: - emb_filter = None - for f in query.filters: - if isinstance(f, EmbeddingSearchFilter): - emb_filter = f - break - - if emb_filter is not None: - return self._fetch_by_vector(query, emb_filter) - - return super().execute_fetch(query) - - def execute_count(self, query: StreamQuery) -> int: - if 
any(isinstance(f, EmbeddingSearchFilter) for f in query.filters): - return len(self.execute_fetch(query)) - return super().execute_count(query) - - def _fetch_by_vector( - self, query: StreamQuery, emb_filter: EmbeddingSearchFilter - ) -> list[Observation[Any]]: - """Fetch using vec0 similarity search, then apply remaining filters.""" - vec_sql = ( - f"SELECT rowid, distance FROM {self._table}_vec " - f"WHERE embedding MATCH ? ORDER BY distance LIMIT ?" - ) - vec_rows = self._conn.execute( - vec_sql, (json.dumps(emb_filter.query), emb_filter.k) - ).fetchall() - - if not vec_rows: - return [] - - dist_map = {r[0]: r[1] for r in vec_rows} - rowids = list(dist_map.keys()) - placeholders = ",".join("?" * len(rowids)) - - where_parts: list[str] = [f"{self._table}.id IN ({placeholders})"] - params: list[Any] = list(rowids) - - for f in query.filters: - if isinstance(f, EmbeddingSearchFilter): - continue - sql_frag, p = _compile_filter(f, self._table) - where_parts.append(sql_frag) - params.extend(p) - - where = " AND ".join(where_parts) - sql = ( - f"SELECT {self._table}.{_META_COLS.replace(', ', f', {self._table}.')} " - f"FROM {self._table} WHERE {where}" - ) - rows = self._conn.execute(sql, params).fetchall() - observations: list[Observation[Any]] = [self._row_to_obs(r) for r in rows] + join = self._join_blobs + sql, params, python_filters = _compile_query(query, self._name, join_blob=join) - # Populate similarity scores from vec0 cosine distance (0=identical, 2=opposite) - for obs in observations: - if isinstance(obs, EmbeddingObservation): - obs.similarity = max(0.0, min(1.0, 1.0 - dist_map.get(obs.id, 0.0))) + cur = self._conn.execute(sql, params) + cur.arraysize = self.config.page_size + it: Iterator[Observation[T]] = (self._row_to_obs(r, has_blob=join) for r in cur) - # Re-sort by distance rank (IN clause doesn't preserve vec0 ordering) - rank = {rid: i for i, rid in enumerate(rowids)} - observations.sort(key=lambda o: rank.get(o.id, len(rank))) + # Text search — 
requires loading data + if query.search_text is not None: + needle = query.search_text.lower() + it = (obs for obs in it if needle in str(obs.data).lower()) - near = _has_near_filter(query) - if near is not None: - observations = _apply_near_post_filter(observations, near) + # Apply Python post-filters + if python_filters: + it = (obs for obs in it if all(f.matches(obs) for f in python_filters)) - return observations + # Apply LIMIT/OFFSET in Python when we couldn't push to SQL + if python_filters or query.search_text: + if query.offset_val: + it = islice(it, query.offset_val, None) + if query.limit_val is not None: + it = islice(it, query.limit_val) - def _row_to_obs(self, row: Any) -> EmbeddingObservation: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, pid = row - return EmbeddingObservation( - id=row_id, - ts=ts, - pose=_reconstruct_pose(px, py, pz, qx, qy, qz, qw), - tags=_deserialize_tags(tags_json), - parent_id=pid, - _data_loader=self._make_loader(row_id), - ) + yield from it + def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]: + vs = self.config.vector_store + assert vs is not None and query.search_vec is not None -class SqliteTextBackend(SqliteStreamBackend): - """Backend for TextStream — maintains an FTS5 index.""" + hits = vs.search(self._name, query.search_vec, query.search_k or 10) + if not hits: + return - def __init__( - self, - conn: sqlite3.Connection, - table: str, - *, - tokenizer: str = "unicode61", - pose_provider: PoseProvider | None = None, - codec: Codec[Any] | None = None, - ) -> None: - super().__init__(conn, table, pose_provider=pose_provider, codec=codec) - self._tokenizer = tokenizer + ids = [h[0] for h in hits] + dict(hits) + + # Batch-fetch metadata + join = self._join_blobs + placeholders = ",".join("?" 
* len(ids)) + if join: + sql = ( + f"SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, " + f"meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data " + f'FROM "{self._name}" AS meta ' + f'JOIN "{self._name}_blob" AS blob ON blob.id = meta.id ' + f"WHERE meta.id IN ({placeholders})" + ) + else: + sql = ( + f"SELECT id, ts, pose_x, pose_y, pose_z, " + f"pose_qx, pose_qy, pose_qz, pose_qw, json(tags) " + f'FROM "{self._name}" WHERE id IN ({placeholders})' + ) - def _post_insert(self, row_id: int, payload: Any) -> None: - text = str(payload) if payload is not None else "" - self._conn.execute( - f"INSERT INTO {self._table}_fts (rowid, content) VALUES (?, ?)", - (row_id, text), - ) + rows = self._conn.execute(sql, ids).fetchall() + obs_by_id: dict[int, Observation[T]] = {} + for r in rows: + obs = self._row_to_obs(r, has_blob=join) + obs_by_id[obs.id] = obs + + # Preserve VectorStore ranking order, promoting to EmbeddedObservation + ranked: list[Observation[T]] = [] + for obs_id, sim in hits: + match = obs_by_id.get(obs_id) + if match is not None: + ranked.append( + match.derive(data=match.data, embedding=query.search_vec, similarity=sim) + ) - def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: - text_filter = None - for f in query.filters: - if isinstance(f, TextSearchFilter): - text_filter = f - break - - if text_filter is not None: - return self._fetch_by_text(query, text_filter) - - return super().execute_fetch(query) - - def _fetch_by_text( - self, query: StreamQuery, text_filter: TextSearchFilter - ) -> list[Observation[Any]]: - fts_sql = f"SELECT rowid, rank FROM {self._table}_fts WHERE content MATCH ? ORDER BY rank" - fts_params: list[Any] = [text_filter.text] - if text_filter.k is not None: - fts_sql += " LIMIT ?" 
- fts_params.append(text_filter.k) - - fts_rows = self._conn.execute(fts_sql, fts_params).fetchall() - if not fts_rows: - return [] - - rowids = [r[0] for r in fts_rows] - placeholders = ",".join("?" * len(rowids)) - - where_parts: list[str] = [f"{self._table}.id IN ({placeholders})"] - params: list[Any] = list(rowids) - - for f in query.filters: - if isinstance(f, TextSearchFilter): - continue - sql_frag, p = _compile_filter(f, self._table) - where_parts.append(sql_frag) - params.extend(p) - - where = " AND ".join(where_parts) - sql = ( - f"SELECT {self._table}.{_META_COLS.replace(', ', f', {self._table}.')} " - f"FROM {self._table} WHERE {where}" - ) - rows = self._conn.execute(sql, params).fetchall() + # Apply remaining query ops (skip vector search) + rest = replace(query, search_vec=None, search_k=None) + yield from rest.apply(iter(ranked)) - observations = [self._row_to_obs(r) for r in rows] + def _iterate_live( + self, + query: StreamQuery, + buf: BackpressureBuffer[Observation[T]], + sub: DisposableBase, + ) -> Iterator[Observation[T]]: + from dimos.memory.buffer import ClosedError + + # Backfill phase + last_id = -1 + for obs in self._iterate_snapshot(query): + last_id = max(last_id, obs.id) + yield obs + + # Live tail + filters = query.filters + try: + while True: + obs = buf.take() + if obs.id <= last_id: + continue + last_id = obs.id + if filters and not all(f.matches(obs) for f in filters): + continue + yield obs + except (ClosedError, StopIteration): + sub.dispose() - # Re-sort by FTS rank (IN clause doesn't preserve FTS5 ordering) - rank = {rid: i for i, rid in enumerate(rowids)} - observations.sort(key=lambda o: rank.get(o.id, len(rank))) + def count(self, query: StreamQuery) -> int: + if query.search_vec or query.search_text: + return sum(1 for _ in self.iterate(query)) - near = _has_near_filter(query) - if near is not None: - observations = _apply_near_post_filter(observations, near) + sql, params, python_filters = _compile_count(query, 
self._name) + if python_filters: + return sum(1 for _ in self.iterate(query)) - return observations + row = self._conn.execute(sql, params).fetchone() + return int(row[0]) if row else 0 -# ── Session ─────────────────────────────────────────────────────────── +# ── SqliteSession ──────────────────────────────────────────────── class SqliteSession(Session): - """Session against a SQLite database.""" + """Session owning a single SQLite connection.""" - def __init__(self, conn: sqlite3.Connection) -> None: + def __init__( + self, conn: sqlite3.Connection, *, vec_available: bool = False, **kwargs: Any + ) -> None: + super().__init__(**kwargs) self._conn = conn - self._streams: dict[str, Stream[Any]] = {} - self._ensure_meta_table() - - def resolve_parent_stream(self, name: str) -> str | None: - row = self._conn.execute( - "SELECT parent_stream FROM _streams WHERE name = ?", (name,) - ).fetchone() - return row[0] if row and row[0] else None - - def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: - """Walk ``_streams.parent_stream`` from *source* toward *target*. - - Returns intermediate table names (empty tuple for direct parent). 
- """ - current = source - intermediates: list[str] = [] - visited = {source} + self._vec_available = vec_available + self._blob_store: SqliteBlobStore | None = None + self._vector_store: Any | None = None - while True: - row = self._conn.execute( - "SELECT parent_stream FROM _streams WHERE name = ?", (current,) - ).fetchone() - if not row or not row[0]: - raise ValueError(f"No lineage path from {source!r} to {target!r}") - - parent_name: str = row[0] - if parent_name == target: - return tuple(intermediates) - - if parent_name in visited: - raise ValueError(f"Cycle detected in lineage chain at {parent_name!r}") - - visited.add(parent_name) - intermediates.append(parent_name) - current = parent_name - - def _ensure_meta_table(self) -> None: + # Create stream registry self._conn.execute( "CREATE TABLE IF NOT EXISTS _streams (" - " name TEXT PRIMARY KEY," - " payload_module TEXT," - " stream_kind TEXT DEFAULT 'stream'," - " parent_stream TEXT," - " embedding_dim INTEGER" + " name TEXT PRIMARY KEY," + " payload_module TEXT NOT NULL," + " codec_id TEXT NOT NULL" ")" ) self._conn.commit() - def stream( - self, - name: str, - payload_type: type, - *, - pose_provider: PoseProvider | None = None, - ) -> Stream[Any]: + def _ensure_shared_stores(self) -> None: + """Lazily create shared stores on first stream creation.""" + if self._blob_store is None: + self._blob_store = SqliteBlobStore(self._conn) + if self._vector_store is None and self._vec_available: + from dimos.memory.vectorstore.sqlite import SqliteVectorStore + + self._vector_store = SqliteVectorStore(self._conn) + + @staticmethod + def _codec_id(codec: Codec[Any]) -> str: + from dimos.memory.codecs.jpeg import JpegCodec + from dimos.memory.codecs.lcm import LcmCodec + + if isinstance(codec, JpegCodec): + return "jpeg" + if isinstance(codec, LcmCodec): + return "lcm" + return "pickle" + + @staticmethod + def _codec_from_id(codec_id: str, payload_module: str) -> Codec[Any]: + from dimos.memory.codecs.pickle import 
PickleCodec + + if codec_id == "jpeg": + from dimos.memory.codecs.jpeg import JpegCodec + + return JpegCodec() + if codec_id == "lcm": + from dimos.memory.codecs.lcm import LcmCodec + + # Resolve the payload type from module path + parts = payload_module.rsplit(".", 1) + if len(parts) == 2: + import importlib + + mod = importlib.import_module(parts[0]) + cls = getattr(mod, parts[1]) + return LcmCodec(cls) + return PickleCodec() + return PickleCodec() + + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: _validate_identifier(name) - if name in self._streams: - return self._streams[name] - - self._ensure_stream_tables(name) - self._register_stream(name, payload_type, "stream") + self._ensure_shared_stores() - codec = codec_for_type(payload_type) - backend = SqliteStreamBackend(self._conn, name, pose_provider=pose_provider, codec=codec) - s: Stream[Any] = Stream(backend=backend, session=self, payload_type=payload_type) - self._streams[name] = s - return s - - def text_stream( - self, - name: str, - *, - tokenizer: str = "unicode61", - pose_provider: PoseProvider | None = None, - ) -> TextStream[Any]: - _validate_identifier(name) - if name in self._streams: - return self._streams[name] # type: ignore[return-value] - - self._ensure_stream_tables(name) - self._ensure_fts_table(name, tokenizer) - self._register_stream(name, str, "text") - - codec = codec_for_type(str) - backend = SqliteTextBackend( - self._conn, name, tokenizer=tokenizer, pose_provider=pose_provider, codec=codec - ) - ts: TextStream[Any] = TextStream(backend=backend, session=self, payload_type=str) - self._streams[name] = ts - return ts - - def embedding_stream( - self, - name: str, - *, - vec_dimensions: int | None = None, - pose_provider: PoseProvider | None = None, - parent_table: str | None = None, - embedding_model: EmbeddingModel | None = None, - ) -> EmbeddingStream[Any]: - from dimos.models.embedding.base import Embedding - - 
_validate_identifier(name) - if name in self._streams: - existing = self._streams[name] - if embedding_model is not None and isinstance(existing, EmbeddingStream): - existing._embedding_model = embedding_model - return existing # type: ignore[return-value] - - self._ensure_stream_tables(name) - self._register_stream(name, Embedding, "embedding", embedding_dim=vec_dimensions) - - codec = codec_for_type(Embedding) - backend = SqliteEmbeddingBackend( - self._conn, - name, - vec_dimensions=vec_dimensions, - pose_provider=pose_provider, - parent_table=parent_table, - codec=codec, - ) - if vec_dimensions is not None: - backend._ensure_vec_table() - - es: EmbeddingStream[Any] = EmbeddingStream( - backend=backend, - session=self, - embedding_model=embedding_model, - payload_type=Embedding, - ) - self._streams[name] = es - return es - - def list_streams(self) -> list[Stream[Any]]: - rows = self._conn.execute( - "SELECT name, payload_module, stream_kind FROM _streams" - ).fetchall() - result: list[Stream[Any]] = [] - for name, pmodule, kind in rows: - _validate_identifier(name) - payload_type = module_path_to_type(pmodule) if pmodule else None - kind = kind or "stream" - if kind == "embedding": - result.append(self.embedding_stream(name)) - elif kind == "text": - result.append(self.text_stream(name)) - else: - result.append(self.stream(name, payload_type or object)) - return result - - def delete_stream(self, name: str) -> None: - _validate_identifier(name) - for suffix in ("_vec", "_fts", "_rtree", "_payload", ""): - table = f"{name}{suffix}" - # Virtual tables (rtree, fts, vec) need DROP TABLE, not DROP TABLE IF EXISTS - # on some builds, but IF EXISTS is safe for all. 
- self._conn.execute(f"DROP TABLE IF EXISTS {table}") - self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) - self._conn.commit() - self._streams.pop(name, None) + # Look up existing stream in registry + row = self._conn.execute( + "SELECT payload_module, codec_id FROM _streams WHERE name = ?", (name,) + ).fetchone() - def materialize_transform( - self, - name: str, - source: Stream[Any], - transformer: Transformer[Any, Any], - *, - payload_type: type | None = None, - live: bool = False, - backfill_only: bool = False, - ) -> Stream[Any]: - # Resolve source table name for parent lineage - source_table = None - if source._backend is not None: - source_table = source._backend.stream_name - - target: Stream[Any] - if isinstance(transformer, (EmbeddingTransformer, TextEmbeddingTransformer)): - target = self.embedding_stream(name, parent_table=source_table) - target._embedding_model = transformer.model - elif isinstance(transformer, CaptionTransformer): - target = self.text_stream(name) + if row is not None: + stored_module, stored_codec_id = row + if payload_type is not None: + actual_module = f"{payload_type.__module__}.{payload_type.__qualname__}" + if actual_module != stored_module: + raise ValueError( + f"Stream {name!r} was created with type {stored_module}, " + f"but opened with {actual_module}" + ) + codec = config.get("codec") or self._codec_from_id(stored_codec_id, stored_module) else: if payload_type is None: - raise TypeError("materialize_transform requires payload_type for plain streams") - target = self.stream(name, payload_type) - - # Record parent lineage in _streams registry - if source_table is not None: + raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required") + codec = config.get("codec") or codec_for(payload_type) + payload_module = f"{payload_type.__module__}.{payload_type.__qualname__}" self._conn.execute( - "UPDATE _streams SET parent_stream = ? 
WHERE name = ?", - (source_table, name), + "INSERT INTO _streams (name, payload_module, codec_id) VALUES (?, ?, ?)", + (name, payload_module, self._codec_id(codec)), ) self._conn.commit() - # Backfill existing data - if transformer.supports_backfill and not live: - transformer.process(source, target) - - # Subscribe to live updates - if transformer.supports_live and not backfill_only: - source.subscribe(lambda obs: transformer.on_append(obs, target)) - - return target - - def stop(self) -> None: - for s in self._streams.values(): - if s._backend is not None: - s._backend.appended_subject.on_completed() - self._streams.clear() - self._conn.close() - - # ── Internal helpers ────────────────────────────────────────────── - - def _ensure_stream_tables(self, name: str) -> None: - """Create the meta table, payload table, and R*Tree for a stream.""" + # Create metadata table self._conn.execute( - f"CREATE TABLE IF NOT EXISTS {name} (" - " id INTEGER PRIMARY KEY AUTOINCREMENT," - " ts REAL UNIQUE NOT NULL," - " pose_x REAL," - " pose_y REAL," - " pose_z REAL," - " pose_qx REAL," - " pose_qy REAL," - " pose_qz REAL," - " pose_qw REAL," - " tags TEXT DEFAULT '{}'," - " parent_id INTEGER" + f'CREATE TABLE IF NOT EXISTS "{name}" (' + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " ts REAL NOT NULL UNIQUE," + " pose_x REAL, pose_y REAL, pose_z REAL," + " pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL," + " tags BLOB DEFAULT (jsonb('{}'))" ")" ) - + # R*Tree spatial index for pose queries self._conn.execute( - f"CREATE TABLE IF NOT EXISTS {name}_payload ( id INTEGER PRIMARY KEY, data BLOB)" - ) - self._conn.execute( - f"CREATE VIRTUAL TABLE IF NOT EXISTS {name}_rtree USING rtree(" - " id," - " min_x, max_x," - " min_y, max_y," - " min_z, max_z" + f'CREATE VIRTUAL TABLE IF NOT EXISTS "{name}_rtree" USING rtree(' + " id," + " x_min, x_max," + " y_min, y_max," + " z_min, z_max" ")" ) self._conn.commit() - def _ensure_fts_table(self, name: str, tokenizer: str) -> None: - 
self._conn.execute( - f"CREATE VIRTUAL TABLE IF NOT EXISTS {name}_fts " - f"USING fts5(content, tokenize='{tokenizer}')" - ) - self._conn.commit() + # Merge shared stores as defaults + if "blob_store" not in config or config["blob_store"] is None: + config["blob_store"] = self._blob_store + if "vector_store" not in config or config["vector_store"] is None: + config["vector_store"] = self._vector_store + config["codec"] = codec - def _register_stream( - self, - name: str, - payload_type: type | None, - kind: str, - *, - embedding_dim: int | None = None, - ) -> None: - module_path = type_to_module_path(payload_type) if payload_type else None - self._conn.execute( - "INSERT OR IGNORE INTO _streams (name, payload_module, stream_kind, embedding_dim) " - "VALUES (?, ?, ?, ?)", - (name, module_path, kind, embedding_dim), - ) + return SqliteBackend(self._conn, name, **config) + + def list_streams(self) -> list[str]: + db_names = {row[0] for row in self._conn.execute("SELECT name FROM _streams").fetchall()} + return sorted(db_names | set(self._streams.keys())) + + def delete_stream(self, name: str) -> None: + self._streams.pop(name, None) + self._conn.execute(f'DROP TABLE IF EXISTS "{name}"') + self._conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') + self._conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') + self._conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') + self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) self._conn.commit() - def _resolve_payload_type(self, name: str) -> type | None: - """Look up payload type from _streams metadata (for restart case).""" - row = self._conn.execute( - "SELECT payload_module FROM _streams WHERE name = ?", (name,) - ).fetchone() - if row is None or row[0] is None: - return None - return module_path_to_type(row[0]) + def stop(self) -> None: + super().stop() + self._conn.close() + +# ── SqliteStore ────────────────────────────────────────────────── -# ── Store 
───────────────────────────────────────────────────────────── + +@dataclass +class SqliteStoreConfig(StoreConfig): + """Config for SQLite-backed store.""" + + path: str = "memory.db" class SqliteStore(Store): - """SQLite-backed memory store (lightweight factory). + """Store backed by a SQLite database file.""" - Each :meth:`session` call opens a new ``sqlite3.Connection`` with WAL mode - and extensions loaded. Sessions are safe to use from different threads. - """ + default_config: type[SqliteStoreConfig] = SqliteStoreConfig + config: SqliteStoreConfig - def __init__(self, path: str | os.PathLike[str]) -> None: - self._path = str(path) - self._closed = False + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) - def _connect(self) -> sqlite3.Connection: - conn = sqlite3.connect(self._path, check_same_thread=False) + def session(self, **kwargs: Any) -> SqliteSession: + conn = sqlite3.connect(self.config.path, check_same_thread=False) conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA synchronous=NORMAL") - self._load_extensions(conn) - return conn - def session(self) -> SqliteSession: - if self._closed: - raise RuntimeError("Store is closed") - return SqliteSession(self._connect()) - - def _load_extensions(self, conn: sqlite3.Connection) -> None: + vec_available = False try: import sqlite_vec conn.enable_load_extension(True) sqlite_vec.load(conn) conn.enable_load_extension(False) - except ImportError: + vec_available = True + except (ImportError, Exception): pass - def stop(self) -> None: - self._closed = True + return SqliteSession(conn, vec_available=vec_available, **kwargs) diff --git a/dimos/memory/impl/test_e2e_export.py b/dimos/memory/impl/test_e2e_export.py deleted file mode 100644 index 3804d3d127..0000000000 --- a/dimos/memory/impl/test_e2e_export.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2026 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""E2E tests: ingest 5min robot video → sharpness filter → CLIP embed → search. - -The DB is built once and cached on disk so subsequent runs skip ingestion. -Run with: pytest dimos/memory/impl/run_e2e_export.py -s -""" - -from __future__ import annotations - -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import pytest - -from dimos.memory.impl.sqlite import SqliteStore -from dimos.memory.ingest import ingest -from dimos.memory.transformer import ( - CaptionTransformer, - EmbeddingTransformer, - QualityWindowTransformer, -) -from dimos.models.embedding.clip import CLIPModel -from dimos.msgs.sensor_msgs.Image import Image -from dimos.utils.testing import TimedSensorReplay - -if TYPE_CHECKING: - from collections.abc import Generator - - from dimos.memory.stream import EmbeddingStream - -DB_DIR = Path(__file__).parent / "e2e_matches" -DB_DIR.mkdir(exist_ok=True) -DB_PATH = DB_DIR / "e2e.db" - - -@pytest.fixture(scope="module") -def clip() -> CLIPModel: - model = CLIPModel() - model.start() - return model - - -@pytest.fixture(scope="module") -def e2e_db(clip: CLIPModel) -> Generator[tuple[SqliteStore, Any], None, None]: - """Build (or reuse cached) e2e DB with video → sharpness → CLIP embeddings.""" - store = SqliteStore(str(DB_PATH)) - session = store.session() - - existing = {s.name for s in session.list_streams()} - if "clip_embeddings" not in existing: - replay = 
TimedSensorReplay("unitree_go2_bigoffice/video") - odom = TimedSensorReplay("unitree_go2_bigoffice/odom") - - raw = session.stream("raw_video", Image) - n = ingest(raw, replay.iterate_ts(seek=5.0, duration=300.0), pose_source=odom) - print(f" {n} frames ingested") - - sharp = raw.transform( - QualityWindowTransformer(lambda img: img.sharpness, window=0.5) - ).store("sharp_frames", Image) - print(f" {sharp.count()} sharp frames (from {n}, {sharp.count() / n:.0%} kept)") - - embeddings: EmbeddingStream[Any] = sharp.transform(EmbeddingTransformer(clip)).store( - "clip_embeddings" - ) # type: ignore[assignment] - print(f" {embeddings.count()} embeddings stored") - else: - print(f"Using cached DB ({DB_PATH})") - - yield store, session # type: ignore[misc] - session.stop() - store.stop() - - -@pytest.fixture(scope="module") -def embeddings(e2e_db: tuple[SqliteStore, Any], clip: CLIPModel) -> EmbeddingStream[Any]: - _, session = e2e_db - stream: EmbeddingStream[Any] = session.embedding_stream("clip_embeddings", embedding_model=clip) # type: ignore[return-value] - return stream - - -@pytest.fixture(scope="module") -def sharp_frames(e2e_db: tuple[SqliteStore, Any]) -> Any: - _, session = e2e_db - return session.stream("sharp_frames", Image) - - -class TestEmbeddingSearch: - """Search the cached CLIP embedding DB and export top matches.""" - - QUERIES = [ - "a hallway in an office", - "a person standing", - "a door", - "a desk", - "supermarket", - "large room", - ] - - @pytest.mark.parametrize("query", QUERIES) - def test_search_returns_results(self, embeddings: EmbeddingStream[Any], query: str) -> None: - from dimos.memory.type import EmbeddingObservation - - results = embeddings.search_embedding(query, k=5).fetch() - assert len(results) > 0 - for obs in results: - assert obs.ts is not None - assert isinstance(obs, EmbeddingObservation) - - @pytest.mark.parametrize("query", QUERIES) - def test_search_exports_images( - self, embeddings: EmbeddingStream[Any], sharp_frames: 
Any, query: str - ) -> None: - slug = query.replace(" ", "_")[:30] - results = embeddings.search_embedding(query, k=5).project_to(sharp_frames).fetch() - - for rank, img in enumerate(results): - fname = DB_DIR / f"{slug}_{rank + 1}_id{img.id}_ts{img.ts:.0f}.jpg" - img.data.save(str(fname)) - print(f" [{rank + 1}] id={img.id} ts={img.ts:.2f}") - - def test_search_has_similarity(self, embeddings: EmbeddingStream[Any]) -> None: - from dimos.memory.type import EmbeddingObservation - - results = embeddings.search_embedding("a hallway", k=10).fetch() - assert len(results) > 0 - for obs in results: - assert isinstance(obs, EmbeddingObservation) - assert obs.similarity is not None - assert 0.0 <= obs.similarity <= 1.0 - - def test_caption_search_results( - self, embeddings: EmbeddingStream[Any], sharp_frames: Any - ) -> None: - from dimos.models.vl.florence import Florence2Model - - captioner = Florence2Model() - captioner.start() - caption_xf = CaptionTransformer(captioner) - - results = embeddings.search_embedding("a door", k=3).project_to(sharp_frames).fetch() - captions = results.transform(caption_xf).fetch() - - assert len(captions) == len(results) - for cap in captions: - assert isinstance(cap.data, str) - assert len(cap.data) > 0 - print(f" Caption: {cap.data}") - - -class TestRerunStream: - """Send a full image stream to Rerun.""" - - def test_stream_to_rerun(self, e2e_db: tuple[SqliteStore, Any]) -> None: - import rerun as rr - - from dimos.memory.rerun import to_rerun - - rr.init("memory_e2e_test", spawn=True) - - _, session = e2e_db - n = to_rerun(session.stream("sharp_frames", Image)) - assert n > 0 - print(f" Logged {n} images to Rerun") diff --git a/dimos/memory/impl/test_sqlite.py b/dimos/memory/impl/test_sqlite.py deleted file mode 100644 index 1df2d55d7f..0000000000 --- a/dimos/memory/impl/test_sqlite.py +++ /dev/null @@ -1,1039 +0,0 @@ -# Copyright 2026 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for SQLite-backed memory store.""" - -from __future__ import annotations - -import numpy as np -import pytest - -from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import EmbeddingTransformer -from dimos.memory.type import EmbeddingObservation, Observation, _Unset -from dimos.models.embedding.base import Embedding, EmbeddingModel -from dimos.msgs.sensor_msgs.Image import Image -from dimos.utils.testing import TimedSensorReplay - - -def _img_close(a: Image, b: Image, max_diff: float = 5.0) -> bool: - """Approximate Image equality (JPEG is lossy).""" - if a.data.shape != b.data.shape: - return False - if a.frame_id != b.frame_id: - return False - return float(np.abs(a.data.astype(np.float32) - b.data.astype(np.float32)).mean()) < max_diff - - -@pytest.fixture(scope="module") -def replay() -> TimedSensorReplay: # type: ignore[type-arg] - return TimedSensorReplay("unitree_go2_bigoffice/video") - - -@pytest.fixture(scope="module") -def images(replay: TimedSensorReplay) -> list[Image]: # type: ignore[type-arg] - """Load 5 images from replay at 1s intervals.""" - imgs = [replay.find_closest_seek(float(i)) for i in range(1, 6)] - assert all(isinstance(im, Image) for im in imgs) - return imgs # type: ignore[return-value] - - -@pytest.fixture -def store(tmp_path: object) -> SqliteStore: - from pathlib import Path - - assert isinstance(tmp_path, Path) - return 
SqliteStore(str(tmp_path / "test.db")) - - -@pytest.fixture -def session(store: SqliteStore) -> SqliteSession: - return store.session() - - -class TestStreamBasics: - def test_create_stream(self, session: SqliteSession) -> None: - s = session.stream("images", Image) - assert s is not None - - def test_append_and_fetch(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("images", Image) - obs = s.append(images[0]) - assert obs.id == 1 - assert obs.data == images[0] # append returns original, not decoded - assert obs.ts is not None - - rows = s.fetch() - assert len(rows) == 1 - assert _img_close(rows[0].data, images[0]) - assert rows[0].id == 1 - - def test_append_multiple(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("images", Image) - for img in images[:3]: - s.append(img) - - assert s.count() == 3 - rows = s.fetch() - assert all(_img_close(r.data, img) for r, img in zip(rows, images[:3], strict=True)) - - def test_append_with_tags(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("images", Image) - s.append(images[0], tags={"cam": "front", "quality": "high"}) - - rows = s.fetch() - assert rows[0].tags == {"cam": "front", "quality": "high"} - - def test_last(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("images", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - s.append(images[2], ts=3.0) - - obs = s.last() - assert _img_close(obs.data, images[2]) - assert obs.ts == 3.0 - - def test_one(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("images", Image) - s.append(images[0]) - - obs = s.first() - assert _img_close(obs.data, images[0]) - - def test_one_empty_raises(self, session: SqliteSession) -> None: - s = session.stream("images", Image) - with pytest.raises(LookupError): - s.first() - - -class TestFilters: - def test_after(self, session: SqliteSession, images: list[Image]) -> None: - s 
= session.stream("data", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=10.0) - - rows = s.after(5.0).fetch() - assert len(rows) == 1 - assert _img_close(rows[0].data, images[1]) - - def test_before(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=10.0) - - rows = s.before(5.0).fetch() - assert len(rows) == 1 - assert _img_close(rows[0].data, images[0]) - - def test_time_range(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=5.0) - s.append(images[2], ts=10.0) - - rows = s.time_range(3.0, 7.0).fetch() - assert len(rows) == 1 - assert _img_close(rows[0].data, images[1]) - - def test_at(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=5.0) - s.append(images[2], ts=10.0) - - rows = s.at(5.5, tolerance=1.0).fetch() - assert len(rows) == 1 - assert _img_close(rows[0].data, images[1]) - - def test_filter_tags(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], tags={"cam": "front"}) - s.append(images[1], tags={"cam": "rear"}) - - rows = s.filter_tags(cam="front").fetch() - assert len(rows) == 1 - assert _img_close(rows[0].data, images[0]) - - def test_chained_filters(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("chained_filter_data", Image) - s.append(images[0], ts=1.0, tags={"cam": "front"}) - s.append(images[1], ts=5.0, tags={"cam": "front"}) - s.append(images[2], ts=6.0, tags={"cam": "rear"}) - - rows = s.after(3.0).filter_tags(cam="front").fetch() - assert len(rows) == 1 - assert _img_close(rows[0].data, images[1]) - - -class TestOrdering: - def test_order_by_ts(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", 
Image) - s.append(images[1], ts=2.0) - s.append(images[0], ts=1.0) - s.append(images[2], ts=3.0) - - rows = s.order_by("ts").fetch() - assert all( - _img_close(r.data, img) - for r, img in zip(rows, [images[0], images[1], images[2]], strict=True) - ) - - def test_order_by_desc(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - s.append(images[2], ts=3.0) - - rows = s.order_by("ts", desc=True).fetch() - assert all( - _img_close(r.data, img) - for r, img in zip(rows, [images[2], images[1], images[0]], strict=True) - ) - - def test_limit_offset(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - for i, img in enumerate(images): - s.append(img, ts=float(i)) - - rows = s.order_by("ts").limit(2).offset(1).fetch() - assert len(rows) == 2 - assert all( - _img_close(r.data, img) for r, img in zip(rows, [images[1], images[2]], strict=True) - ) - - -class TestFetchPages: - def test_basic_pagination(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - for i, img in enumerate(images): - s.append(img, ts=float(i)) - - pages = list(s.fetch_pages(batch_size=2)) - assert len(pages) == 3 # 2+2+1 - assert len(pages[0]) == 2 - assert len(pages[-1]) == 1 - - all_items = [obs.data for page in pages for obs in page] - assert all(_img_close(a, b) for a, b in zip(all_items, images, strict=True)) - - -class TestTextStream: - def test_create_and_append(self, session: SqliteSession) -> None: - s = session.text_stream("logs") - s.append("Motor fault on joint 3") - s.append("Battery low warning") - - assert s.count() == 2 - - def test_text_search(self, session: SqliteSession) -> None: - s = session.text_stream("logs") - s.append("Motor fault on joint 3") - s.append("Battery low warning") - s.append("Motor overheating on joint 5") - - rows = s.search_text("motor", k=10).fetch() - assert len(rows) == 2 
- assert all("Motor" in r.data for r in rows) - - -class TestTextStorage: - """Test storing plain text (str) in streams.""" - - def test_store_and_fetch_str(self, session: SqliteSession) -> None: - s = session.stream("raw_logs", str) - s.append("Robot started navigation to kitchen", ts=1.0) - s.append("Obstacle detected at waypoint 3", ts=2.0) - s.append("Navigation complete", ts=3.0) - - assert s.count() == 3 - rows = s.fetch() - assert rows[0].data == "Robot started navigation to kitchen" - assert rows[2].data == "Navigation complete" - - def test_str_with_tags_and_filters(self, session: SqliteSession) -> None: - s = session.stream("tagged_logs", str) - s.append("Motor fault on joint 3", ts=1.0, tags={"level": "error"}) - s.append("Battery at 80%", ts=2.0, tags={"level": "info"}) - s.append("Motor overheating", ts=3.0, tags={"level": "error"}) - - errors = s.filter_tags(level="error").fetch() - assert len(errors) == 2 - assert all("Motor" in e.data for e in errors) - - def test_str_persists_reopen(self, tmp_path: object) -> None: - from pathlib import Path - - assert isinstance(tmp_path, Path) - db_path = str(tmp_path / "logs.db") - - store1 = SqliteStore(db_path) - s1 = store1.session() - s1.stream("logs", str).append("hello world", ts=1.0) - s1.stop() - store1.stop() - - store2 = SqliteStore(db_path) - s2 = store2.session() - rows = s2.stream("logs", str).fetch() - assert len(rows) == 1 - assert rows[0].data == "hello world" - s2.stop() - store2.stop() - - -class TestEmbeddingStream: - def test_create_and_append(self, session: SqliteSession) -> None: - es = session.embedding_stream("emb", vec_dimensions=4) - e1 = Embedding(np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)) - e2 = Embedding(np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32)) - - es.append(e1, ts=1.0) - es.append(e2, ts=2.0) - - assert es.count() == 2 - - def test_search_embedding(self, session: SqliteSession) -> None: - es = session.embedding_stream("emb_search", vec_dimensions=4) - vecs = [ - 
[1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.9, 0.1, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - ] - for i, v in enumerate(vecs): - es.append(Embedding(np.array(v, dtype=np.float32)), ts=float(i)) - - # Search for vector closest to [1, 0, 0, 0] — should get id=1 and id=3 - results = es.search_embedding([1.0, 0.0, 0.0, 0.0], k=2).fetch() - assert len(results) == 2 - result_ids = {r.id for r in results} - assert 1 in result_ids # exact match - assert 3 in result_ids # [0.9, 0.1, 0, 0] is close - - def test_search_returns_embedding_observation(self, session: SqliteSession) -> None: - es = session.embedding_stream("emb_obs", vec_dimensions=3) - es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) - - results = es.search_embedding([1.0, 0.0, 0.0], k=1).fetch() - assert len(results) == 1 - assert isinstance(results[0], EmbeddingObservation) - - def test_search_with_time_filter(self, session: SqliteSession) -> None: - es = session.embedding_stream("emb_time", vec_dimensions=3) - es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) - es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=10.0) - - # Both match the vector, but only one is after t=5 - results = es.search_embedding([1.0, 0.0, 0.0], k=10).after(5.0).fetch() - assert len(results) == 1 - assert results[0].ts == 10.0 - - def test_embedding_transformer_store(self, session: SqliteSession, images: list[Image]) -> None: - """Test the full pipeline: images → EmbeddingTransformer → EmbeddingStream.""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in imgs: - val = float(img.data.mean()) / 255.0 - results.append( - Embedding(np.array([val, 1.0 - val, 0.0, 0.0], dtype=np.float32)) - ) - return results if len(results) > 1 else results[0] - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - raise NotImplementedError - - s = 
session.stream("cam_emb", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - - emb_stream = s.transform(EmbeddingTransformer(FakeEmbedder())).store("cam_embeddings") - assert emb_stream.count() == 2 - - # Search returns EmbeddingObservation; project_to to get source images - results = emb_stream.search_embedding([0.5, 0.5, 0.0, 0.0], k=1).project_to(s).fetch() - assert len(results) == 1 - assert _img_close(results[0].data, images[0]) or _img_close(results[0].data, images[1]) - - -class TestListStreams: - def test_list_empty(self, session: SqliteSession) -> None: - assert session.list_streams() == [] - - def test_list_after_create(self, session: SqliteSession) -> None: - session.stream("images", Image) - session.text_stream("logs") - - infos = session.list_streams() - names = {i.name for i in infos} - assert names == {"images", "logs"} - - -class TestReactive: - def test_appended_observable(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("images", Image) - received: list[Observation] = [] - s.subscribe(received.append) - - s.append(images[0]) - s.append(images[1]) - - assert len(received) == 2 - assert received[0].data is images[0] # appended obs holds original - assert received[1].data is images[1] - - -class TestTransformInMemory: - def test_lambda_transform(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - - shapes = s.transform(lambda im: f"{im.width}x{im.height}") - results = shapes.fetch() - assert len(results) == 2 - assert results[0].data == f"{images[0].width}x{images[0].height}" - - def test_lambda_filter_none(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - s.append(images[2], ts=3.0) - - # Only keep images wider than 0 (all pass), filter second by index trick - idx = iter(range(3)) - big = 
s.transform(lambda im: im if next(idx) % 2 == 0 else None) - results = big.fetch() - assert len(results) == 2 # indices 0 and 2 - - def test_lambda_expand_list(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - - # Extract format and frame_id as two separate results - results = s.transform(lambda im: [im.format.value, im.frame_id]).fetch() - assert len(results) == 2 - assert results[0].data == images[0].format.value - - -class TestTransformStore: - def test_transform_store_backfill(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - - stored = s.transform(lambda im: f"{im.width}x{im.height}").store("shapes", str) - rows = stored.fetch() - assert len(rows) == 2 - expected = f"{images[0].width}x{images[0].height}" - assert rows[0].data == expected - - reloaded = session.stream("shapes", str) - assert reloaded.count() == 2 - - def test_transform_store_live(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - - stored = s.transform(lambda im: im.height, live=True).store("heights", int) - assert stored.count() == 0 # no backfill - - s.append(images[1], ts=2.0) - assert stored.count() == 1 - assert stored.last().data == images[1].height - - def test_transform_store_backfill_only( - self, session: SqliteSession, images: list[Image] - ) -> None: - s = session.stream("data", Image) - s.append(images[0], ts=1.0) - - stored = s.transform(lambda im: im.height, backfill_only=True).store("heights_bo", int) - assert stored.count() == 1 - assert stored.first().data == images[0].height - - s.append(images[1], ts=2.0) - assert stored.count() == 1 # still 1 - - -class TestLazyData: - def test_data_lazy_loaded(self, session: SqliteSession, images: list[Image]) -> None: - """Fetched observations should not eagerly load payload.""" - 
s = session.stream("data", Image) - s.append(images[0], ts=1.0) - - rows = s.fetch() - obs = rows[0] - assert isinstance(obs._data, _Unset) - assert obs._data_loader is not None - loaded = obs.data - assert _img_close(loaded, images[0]) - assert obs._data is loaded # cached after first access - - def test_metadata_without_payload(self, session: SqliteSession, images: list[Image]) -> None: - """Metadata (ts, tags) should be available without loading payload.""" - s = session.stream("data", Image) - s.append(images[0], ts=1.0, tags={"key": "val"}) - - rows = s.fetch() - obs = rows[0] - assert obs.ts == 1.0 - assert obs.tags == {"key": "val"} - assert obs.id == 1 - - -class TestIteration: - def test_iter(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("data", Image) - for i, img in enumerate(images[:3]): - s.append(img, ts=float(i)) - - items = [obs.data for obs in s] - assert all(_img_close(a, b) for a, b in zip(items, images[:3], strict=True)) - - -class TestProjectTo: - def test_search_returns_embedding_obs( - self, session: SqliteSession, images: list[Image] - ) -> None: - """search_embedding returns EmbeddingObservation; .data provides source data via lineage.""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in imgs: - val = float(img.data.mean()) / 255.0 - results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - raise NotImplementedError - - imgs = session.stream("pt_images", Image) - imgs.append(images[0], ts=1.0) - imgs.append(images[1], ts=2.0) - imgs.append(images[2], ts=3.0) - - embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("pt_embs") - assert embs.count() == 3 - - # search_embedding returns EmbeddingObservation with Embedding data - 
results = embs.search_embedding([0.5, 0.5, 0.0], k=2).fetch() - assert len(results) == 2 - for obs in results: - assert isinstance(obs, EmbeddingObservation) - assert isinstance(obs.data, Embedding) - - # project_to to get source images - projected = embs.search_embedding([0.5, 0.5, 0.0], k=2).project_to(imgs).fetch() - assert len(projected) == 2 - for obs in projected: - assert ( - _img_close(obs.data, images[0]) - or _img_close(obs.data, images[1]) - or _img_close(obs.data, images[2]) - ) - - def test_search_chainable(self, session: SqliteSession, images: list[Image]) -> None: - """Search results support further filter chaining.""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in imgs: - val = float(img.data.mean()) / 255.0 - results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - raise NotImplementedError - - imgs = session.stream("ptc_images", Image) - imgs.append(images[0], ts=1.0) - imgs.append(images[1], ts=5.0) - imgs.append(images[2], ts=10.0) - - embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptc_embs") - - # Chain time filter after search - results = embs.search_embedding([0.5, 0.5, 0.0], k=10).after(3.0).fetch() - assert all(r.ts is not None and r.ts > 3.0 for r in results) - - def test_explicit_project_to(self, session: SqliteSession, images: list[Image]) -> None: - """Explicit project_to works for non-search cases.""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in imgs: - val = float(img.data.mean()) / 255.0 - results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] 
- - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - raise NotImplementedError - - imgs = session.stream("pte_images", Image) - imgs.append(images[0], ts=1.0) - imgs.append(images[1], ts=5.0) - imgs.append(images[2], ts=10.0) - - embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("pte_embs") - - # Explicit project_to without search — project all embeddings to images - projected = embs.project_to(imgs).after(3.0) - results = projected.fetch() - assert all(r.ts is not None and r.ts > 3.0 for r in results) - - def test_two_hop(self, session: SqliteSession, images: list[Image]) -> None: - """project_to handles multi-hop lineage (embs → mid → raw).""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in imgs: - val = float(img.data.mean()) / 255.0 - results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - raise NotImplementedError - - raw = session.stream("th_raw", Image) - raw.append(images[0], ts=1.0) - raw.append(images[1], ts=2.0) - raw.append(images[2], ts=3.0) - - mid = raw.transform(lambda img: img).store("th_mid", Image) - assert mid.count() == 3 - - embs = mid.transform(EmbeddingTransformer(FakeEmbedder())).store("th_embs") - assert embs.count() == 3 - - # project_to(raw) walks the full chain: th_embs → th_mid → th_raw - projected = embs.search_embedding([0.5, 0.5, 0.0], k=2).project_to(raw) - results = projected.fetch() - assert len(results) == 2 - - def test_count_on_projected(self, session: SqliteSession, images: list[Image]) -> None: - """count() works on auto-projected search results.""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in 
imgs: - val = float(img.data.mean()) / 255.0 - results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - raise NotImplementedError - - imgs = session.stream("ptcnt_images", Image) - imgs.append(images[0], ts=1.0) - imgs.append(images[1], ts=2.0) - - embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptcnt_embs") - assert embs.search_embedding([0.5, 0.5, 0.0], k=1).count() == 1 - - def test_project_to_plain_transform(self, session: SqliteSession, images: list[Image]) -> None: - """project_to on a non-embedding derived stream (e.g., detections → images).""" - imgs = session.stream("ptplain_images", Image) - imgs.append(images[0], ts=1.0) - imgs.append(images[1], ts=2.0) - imgs.append(images[2], ts=3.0) - - # Simulate a detection transform — extracts height as an "int" stream - heights = imgs.transform(lambda im: im.height).store("ptplain_heights", int) - assert heights.count() == 3 - - # Project heights back to source images - projected = heights.after(1.5).project_to(imgs) - results = projected.fetch() - assert len(results) == 2 # ts=2.0 and ts=3.0 - for obs in results: - assert _img_close(obs.data, images[1]) or _img_close(obs.data, images[2]) - - def test_search_by_text(self, session: SqliteSession, images: list[Image]) -> None: - """search_embedding accepts a string and auto-embeds via model.""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in imgs: - val = float(img.data.mean()) / 255.0 - results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - results = [] - for _text in texts: - results.append(Embedding(np.array([0.5, 0.5, 0.0], 
dtype=np.float32))) - return results if len(results) > 1 else results[0] - - imgs = session.stream("pttxt_images", Image) - imgs.append(images[0], ts=1.0) - imgs.append(images[1], ts=2.0) - - embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("pttxt_embs") - - # Search with text string — auto-embeds via embed_text() - results = embs.search_embedding("a hallway", k=2).fetch() - assert len(results) == 2 - - def test_search_by_image(self, session: SqliteSession, images: list[Image]) -> None: - """search_embedding accepts an image and auto-embeds via model.""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in imgs: - val = float(img.data.mean()) / 255.0 - results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - raise NotImplementedError - - imgs = session.stream("ptimg_images", Image) - imgs.append(images[0], ts=1.0) - imgs.append(images[1], ts=2.0) - - embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("ptimg_embs") - - # Search with image — auto-embeds via embed() - results = embs.search_embedding(images[0], k=1).fetch() - assert len(results) == 1 - - def test_search_no_model_raises(self, session: SqliteSession) -> None: - """search_embedding with str raises when no model is available.""" - es = session.embedding_stream("pt_nomodel", vec_dimensions=3) - es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) - - with pytest.raises(TypeError, match="No embedding model available"): - es.search_embedding("hello", k=1) - - def test_no_lineage_fallback(self, session: SqliteSession) -> None: - """search_embedding without lineage returns EmbeddingStream (no projection).""" - es = session.embedding_stream("pt_standalone", vec_dimensions=3) - 
es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) - - results = es.search_embedding([1.0, 0.0, 0.0], k=1).fetch() - assert len(results) == 1 - assert isinstance(results[0], EmbeddingObservation) - - -class TestSimilarityScores: - def test_search_populates_similarity(self, session: SqliteSession) -> None: - """search_embedding should populate .similarity on EmbeddingObservation.""" - es = session.embedding_stream("sim_test", vec_dimensions=4) - vecs = [ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.9, 0.1, 0.0, 0.0], - ] - for i, v in enumerate(vecs): - es.append(Embedding(np.array(v, dtype=np.float32)), ts=float(i)) - - results = es.search_embedding([1.0, 0.0, 0.0, 0.0], k=3).fetch() - assert len(results) == 3 - for obs in results: - assert isinstance(obs, EmbeddingObservation) - assert obs.similarity is not None - assert 0.0 <= obs.similarity <= 1.0 - - # Exact match should have highest similarity - by_sim = sorted(results, key=lambda o: o.similarity, reverse=True) - assert by_sim[0].id == 1 # [1,0,0,0] is exact match - - def test_similarity_none_without_search(self, session: SqliteSession) -> None: - """Plain fetch() should leave similarity as None.""" - es = session.embedding_stream("sim_none", vec_dimensions=3) - es.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32)), ts=1.0) - - results = es.fetch() - assert len(results) == 1 - assert isinstance(results[0], EmbeddingObservation) - assert results[0].similarity is None - - def test_search_embedding_obs_with_similarity( - self, session: SqliteSession, images: list[Image] - ) -> None: - """search_embedding returns EmbeddingObservation with similarity scores.""" - - class FakeEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - results = [] - for img in imgs: - val = float(img.data.mean()) / 255.0 - results.append(Embedding(np.array([val, 1.0 - val, 0.0], dtype=np.float32))) - return results 
if len(results) > 1 else results[0] - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - raise NotImplementedError - - imgs = session.stream("sim_proj_imgs", Image) - imgs.append(images[0], ts=1.0) - imgs.append(images[1], ts=2.0) - - embs = imgs.transform(EmbeddingTransformer(FakeEmbedder())).store("sim_proj_embs") - - results = embs.search_embedding([0.5, 0.5, 0.0], k=2).fetch() - assert len(results) == 2 - for obs in results: - assert isinstance(obs, EmbeddingObservation) - assert obs.similarity is not None - assert isinstance(obs.data, Embedding) - - -class TestObservationSet: - def test_fetch_returns_observation_set( - self, session: SqliteSession, images: list[Image] - ) -> None: - from dimos.memory.stream import ObservationSet - - s = session.stream("obs_set", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - - result = s.fetch() - assert isinstance(result, ObservationSet) - assert len(result) == 2 - - def test_list_like_access(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("obs_list", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - s.append(images[2], ts=3.0) - - result = s.fetch() - assert result[0].ts == 1.0 - assert result[-1].ts == 3.0 - assert len(result[1:]) == 2 - assert bool(result) is True - - def test_empty_observation_set(self, session: SqliteSession) -> None: - s = session.stream("obs_empty", Image) - result = s.fetch() - assert len(result) == 0 - assert bool(result) is False - - def test_iter(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("obs_iter", Image) - for i, img in enumerate(images[:3]): - s.append(img, ts=float(i)) - - result = s.fetch() - timestamps = [obs.ts for obs in result] - assert timestamps == [0.0, 1.0, 2.0] - - def test_refilter_in_memory(self, session: SqliteSession, images: list[Image]) -> None: - """ObservationSet supports chaining filters that re-evaluate in memory.""" - s = 
session.stream("obs_refilter", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=5.0) - s.append(images[2], ts=10.0) - - result = s.fetch() - assert len(result) == 3 - - # Re-filter in memory - recent = result.after(3.0).fetch() - assert len(recent) == 2 - assert all(r.ts is not None and r.ts > 3.0 for r in recent) - - def test_transform_on_observation_set( - self, session: SqliteSession, images: list[Image] - ) -> None: - """ObservationSet supports .transform() for fork-and-zip.""" - s = session.stream("obs_xf", Image) - s.append(images[0], ts=1.0) - s.append(images[1], ts=2.0) - - result = s.fetch() - shapes = result.transform(lambda im: f"{im.width}x{im.height}").fetch() - assert len(shapes) == 2 - assert shapes[0].data == f"{images[0].width}x{images[0].height}" - - def test_append(self, session: SqliteSession, images: list[Image]) -> None: - from dimos.memory.stream import ObservationSet - - result = ObservationSet([], session=session) - obs = result.append(images[0], ts=1.0) - assert obs.id == 0 - assert obs.ts == 1.0 - assert len(result) == 1 - - def test_ordering_in_memory(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("obs_order", Image) - s.append(images[0], ts=3.0) - s.append(images[1], ts=1.0) - s.append(images[2], ts=2.0) - - result = s.fetch() - ordered = result.order_by("ts").fetch() - assert [o.ts for o in ordered] == [1.0, 2.0, 3.0] - - desc = result.order_by("ts", desc=True).fetch() - assert [o.ts for o in desc] == [3.0, 2.0, 1.0] - - def test_limit_offset_in_memory(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("obs_lim", Image) - for i, img in enumerate(images): - s.append(img, ts=float(i)) - - result = s.fetch() - page = result.order_by("ts").limit(2).offset(1).fetch() - assert len(page) == 2 - assert [o.ts for o in page] == [1.0, 2.0] - - -class TestMatchesFilters: - def test_after_filter(self) -> None: - from dimos.memory.type import AfterFilter - - f = 
AfterFilter(5.0) - assert f.matches(Observation(id=1, ts=6.0)) is True - assert f.matches(Observation(id=2, ts=5.0)) is False - assert f.matches(Observation(id=3, ts=4.0)) is False - assert f.matches(Observation(id=4, ts=None)) is False - - def test_before_filter(self) -> None: - from dimos.memory.type import BeforeFilter - - f = BeforeFilter(5.0) - assert f.matches(Observation(id=1, ts=4.0)) is True - assert f.matches(Observation(id=2, ts=5.0)) is False - assert f.matches(Observation(id=3, ts=6.0)) is False - - def test_time_range_filter(self) -> None: - from dimos.memory.type import TimeRangeFilter - - f = TimeRangeFilter(2.0, 8.0) - assert f.matches(Observation(id=1, ts=5.0)) is True - assert f.matches(Observation(id=2, ts=2.0)) is True - assert f.matches(Observation(id=3, ts=8.0)) is True - assert f.matches(Observation(id=4, ts=1.0)) is False - assert f.matches(Observation(id=5, ts=9.0)) is False - - def test_at_filter(self) -> None: - from dimos.memory.type import AtFilter - - f = AtFilter(5.0, tolerance=1.0) - assert f.matches(Observation(id=1, ts=5.0)) is True - assert f.matches(Observation(id=2, ts=5.5)) is True - assert f.matches(Observation(id=3, ts=6.0)) is True - assert f.matches(Observation(id=4, ts=6.5)) is False - - def test_tags_filter(self) -> None: - from dimos.memory.type import TagsFilter - - f = TagsFilter((("cam", "front"),)) - assert ( - f.matches(Observation(id=1, ts=0.0, tags={"cam": "front", "quality": "high"})) is True - ) - assert f.matches(Observation(id=2, ts=0.0, tags={"cam": "rear"})) is False - assert f.matches(Observation(id=3, ts=0.0, tags={})) is False - - def test_text_search_filter(self) -> None: - from dimos.memory.type import TextSearchFilter - - f = TextSearchFilter("motor", k=None) - assert f.matches(Observation(id=1, ts=0.0, _data="Motor fault on joint 3")) is True - assert f.matches(Observation(id=2, ts=0.0, _data="Battery low")) is False - - def test_embedding_search_filter_always_true(self) -> None: - from 
dimos.memory.type import EmbeddingSearchFilter - - f = EmbeddingSearchFilter([1.0, 0.0], k=5) - assert f.matches(Observation(id=1, ts=0.0)) is True - - def test_lineage_filter_raises(self) -> None: - from dimos.memory.type import LineageFilter, StreamQuery - - f = LineageFilter("src", StreamQuery(), ()) - with pytest.raises(NotImplementedError): - f.matches(Observation(id=1, ts=0.0)) - - -class TestFilteredAppended: - def test_unfiltered_appended(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("fa_unfilt", Image) - received: list[Observation] = [] - s.subscribe(received.append) - - s.append(images[0], ts=1.0) - s.append(images[1], ts=5.0) - assert len(received) == 2 - - def test_filtered_appended(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("fa_filt", Image) - received: list[Observation] = [] - s.after(3.0).subscribe(received.append) - - s.append(images[0], ts=1.0) # filtered out - s.append(images[1], ts=5.0) # passes - assert len(received) == 1 - assert received[0].ts == 5.0 - - def test_tag_filtered_appended(self, session: SqliteSession, images: list[Image]) -> None: - s = session.stream("fa_tag", Image) - received: list[Observation] = [] - s.filter_tags(cam="front").subscribe(received.append) - - s.append(images[0], tags={"cam": "front"}) - s.append(images[1], tags={"cam": "rear"}) - assert len(received) == 1 - - -class TestStoreReopen: - def test_data_persists(self, tmp_path: object, images: list[Image]) -> None: - from pathlib import Path - - assert isinstance(tmp_path, Path) - db_path = str(tmp_path / "persist.db") - - store1 = SqliteStore(db_path) - s1 = store1.session() - s1.stream("data", Image).append(images[0], ts=1.0) - s1.stop() - store1.stop() - - store2 = SqliteStore(db_path) - s2 = store2.session() - rows = s2.stream("data", Image).fetch() - assert len(rows) == 1 - assert _img_close(rows[0].data, images[0]) - s2.stop() - store2.stop() diff --git 
a/dimos/memory/impl/test_sqlite_e2e.py b/dimos/memory/impl/test_sqlite_e2e.py deleted file mode 100644 index d75278c384..0000000000 --- a/dimos/memory/impl/test_sqlite_e2e.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""E2E test: ingest robot video → sharpness filter → CLIP embed → vector search.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from pathlib import Path - -from dimos.memory.impl.sqlite import SqliteStore -from dimos.memory.ingest import ingest -from dimos.memory.transformer import EmbeddingTransformer, QualityWindowTransformer -from dimos.models.embedding.clip import CLIPModel -from dimos.msgs.sensor_msgs.Image import Image -from dimos.utils.testing import TimedSensorReplay - - -@pytest.fixture(scope="module") -def replay() -> TimedSensorReplay: # type: ignore[type-arg] - return TimedSensorReplay("unitree_go2_bigoffice/video") - - -@pytest.fixture(scope="module") -def clip() -> CLIPModel: - model = CLIPModel() - model.start() - return model - - -@pytest.mark.slow -@pytest.mark.skipif_in_ci -class TestE2EPipeline: - """Ingest 60s of robot video, filter by sharpness, embed with CLIP, search.""" - - def test_ingest_filter_embed_search( - self, - tmp_path: Path, - replay: TimedSensorReplay, # type: ignore[type-arg] - clip: CLIPModel, - ) -> None: - store = SqliteStore(str(tmp_path / "e2e.db")) - session = 
store.session() - - # 1. Ingest 60s of video - raw = session.stream("raw_video", Image) - n_ingested = ingest(raw, replay.iterate_ts(seek=5.0, duration=60.0)) - assert n_ingested > 0 - print(f"\nIngested {n_ingested} frames") - - # 2. Sharpness filter: keep best frame per 0.5s window - sharp = raw.transform( - QualityWindowTransformer(lambda img: img.sharpness, window=0.5) - ).store("sharp_frames", Image) - n_sharp = sharp.count() - assert n_sharp > 0 - assert n_sharp < n_ingested # should reduce count - print(f"Sharp frames: {n_sharp} (from {n_ingested}, {n_sharp / n_ingested:.0%} kept)") - - # 3. Embed with real CLIP model - embeddings = sharp.transform(EmbeddingTransformer(clip)).store("clip_embeddings") - n_emb = embeddings.count() - assert n_emb == n_sharp - print(f"Embeddings stored: {n_emb}") - - # 4. Text-to-image search - query_emb = clip.embed_text("a hallway in an office") - results = embeddings.search_embedding(query_emb, k=5).fetch() - assert len(results) > 0 - assert len(results) <= 5 - print(f"Search returned {len(results)} results") - - for r in results: - assert r.ts is not None - assert r.data is not None - print(f" id={r.id} ts={r.ts:.2f}") - - # 5. Search with time filter - mid_ts = (results[0].ts + results[-1].ts) / 2 if len(results) > 1 else results[0].ts - filtered = embeddings.search_embedding(query_emb, k=10).after(mid_ts).fetch() - assert all(r.ts > mid_ts for r in filtered) - print(f"Time-filtered search: {len(filtered)} results after ts={mid_ts:.2f}") - - # 6. 
Verify persistence — reopen and search again - session.stop() - store.stop() - - store2 = SqliteStore(str(tmp_path / "e2e.db")) - session2 = store2.session() - reloaded = session2.embedding_stream("clip_embeddings", vec_dimensions=512) - assert reloaded.count() == n_emb - - results2 = reloaded.search_embedding(query_emb, k=3).fetch() - assert len(results2) > 0 - print(f"After reopen: {len(results2)} results") - - session2.stop() - store2.stop() diff --git a/dimos/memory/ingest.py b/dimos/memory/ingest.py deleted file mode 100644 index f0fd04263b..0000000000 --- a/dimos/memory/ingest.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Helpers for ingesting timestamped data into memory streams.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from collections.abc import Iterable - - from dimos.memory.stream import Stream - - -def ingest( - stream: Stream[Any], - source: Iterable[tuple[float, Any]], - *, - pose_source: Any | None = None, -) -> int: - """Ingest (timestamp, payload) pairs into a stream. - - Accepts any iterable of ``(ts, data)`` — e.g. ``replay.iterate_ts(seek=5, duration=60)``. - - Args: - pose_source: Optional replay with ``find_closest(ts)`` returning a pose - to attach to each frame (e.g. odom replay). - - Returns: - Number of items ingested. 
- """ - count = 0 - for ts, payload in source: - pose = None - if pose_source is not None: - pose = pose_source.find_closest(ts) - stream.append(payload, ts=ts, pose=pose) - count += 1 - return count diff --git a/dimos/memory2/intro.md b/dimos/memory/intro.md similarity index 98% rename from dimos/memory2/intro.md rename to dimos/memory/intro.md index 341d89608c..269807cb4f 100644 --- a/dimos/memory2/intro.md +++ b/dimos/memory/intro.md @@ -3,7 +3,7 @@ ## Quick start ```python session=memory ansi=false no-result -from dimos.memory2.impl.sqlite import SqliteStore +from dimos.memory.impl.sqlite import SqliteStore store = SqliteStore(path="/tmp/memory_readme.db") session = store.session() @@ -140,7 +140,7 @@ Use `EmbedText` transformer with CLIP to enrich observations with embeddings, th ```python session=memory ansi=false from dimos.models.embedding.clip import CLIPModel -from dimos.memory2.embed import EmbedText +from dimos.memory.embed import EmbedText clip = CLIPModel() diff --git a/dimos/memory/livechannel/__init__.py b/dimos/memory/livechannel/__init__.py new file mode 100644 index 0000000000..143c8e95bf --- /dev/null +++ b/dimos/memory/livechannel/__init__.py @@ -0,0 +1,4 @@ +from dimos.memory.backend import LiveChannel +from dimos.memory.livechannel.subject import SubjectChannel + +__all__ = ["LiveChannel", "SubjectChannel"] diff --git a/dimos/memory2/livechannel/subject.py b/dimos/memory/livechannel/subject.py similarity index 92% rename from dimos/memory2/livechannel/subject.py rename to dimos/memory/livechannel/subject.py index 2d2b848f9f..8debe229d7 100644 --- a/dimos/memory2/livechannel/subject.py +++ b/dimos/memory/livechannel/subject.py @@ -21,13 +21,13 @@ from reactivex.disposable import Disposable -from dimos.memory2.backend import LiveChannel +from dimos.memory.backend import LiveChannel if TYPE_CHECKING: from reactivex.abc import DisposableBase - from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.type import Observation + from 
dimos.memory.buffer import BackpressureBuffer + from dimos.memory.type import Observation T = TypeVar("T") diff --git a/dimos/memory/module.py b/dimos/memory/module.py deleted file mode 100644 index b651d037be..0000000000 --- a/dimos/memory/module.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Memory module — record input streams into persistent memory.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar - -import cv2 - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.memory.impl.sqlite import SqliteStore -from dimos.msgs.sensor_msgs.Image import sharpness_barrier -from dimos.utils.logging_config import setup_logger - -cv2.setNumThreads(1) - -if TYPE_CHECKING: - from reactivex.observable import Observable - - from dimos.core.stream import In - from dimos.memory.stream import Stream - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - -T = TypeVar("T") - -logger = setup_logger() - - -@dataclass -class MemoryModuleConfig(ModuleConfig): - db_path: str = "memory.db" - world_frame: str = "world" - robot_frame: str = "base_link" - - -class MemoryModule(Module[MemoryModuleConfig]): - default_config: type[MemoryModuleConfig] = MemoryModuleConfig - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._store: SqliteStore | 
None = None - - def pose(self, ts: float) -> PoseStamped | None: - return self.tf.get_pose(self.config.world_frame, self.config.robot_frame) # type: ignore[no-any-return] - - @rpc - def start(self) -> None: - super().start() - self._store = SqliteStore(self.config.db_path) - self._disposables.add(self._store) - logger.info("MemoryModule started (db=%s)", self.config.db_path) - - def memory( - self, - input: In[T], - name: str | None = None, # can be infered from input - payload_type: type | None = None, # can be infered from input - fps: float = 0, - ) -> Stream[T]: - assert self._store is not None, "record() called before start()" - - if name is None: - name = input.name - if payload_type is None: - payload_type = input.type - - session = self._store.session() - self._disposables.add(session) - - memory_stream = session.stream(name, payload_type, pose_provider=self.pose) - - obs: Observable[Any] = input.observable() - if fps > 0: - obs = obs.pipe(sharpness_barrier(fps)) - - self._disposables.add(obs.subscribe(on_next=memory_stream.append)) - - return memory_stream - - @rpc - def stop(self) -> None: - super().stop() - - -memory_module = MemoryModule.blueprint diff --git a/dimos/memory/readme.md b/dimos/memory/readme.md deleted file mode 100644 index 3b5e6aca41..0000000000 --- a/dimos/memory/readme.md +++ /dev/null @@ -1,455 +0,0 @@ -# Memory - -Lazy, chainable query system for persistent robot data. Stores timestamped observations in SQLite with vector similarity (sqlite-vec), full-text (FTS5), spatial (R*Tree), and temporal indexes. 
- -```sh no-result -rm -f /tmp/memory_readme.db -``` - -## Quick start - -```python session=memory ansi=false no-result -from dimos.memory.impl.sqlite import SqliteStore - -store = SqliteStore("/tmp/memory_readme.db") -session = store.session() -``` - -Open a store, get a session, create a stream: - -```python session=memory ansi=false -logs = session.stream("logs", str) -print(logs) -``` - - -``` -Stream[str]("logs") -``` - -Append observations and query them: - -```python session=memory ansi=false -logs.append("Motor started", ts=1.0, tags={"level": "info"}) -logs.append("Joint 3 fault", ts=2.0, tags={"level": "error"}) -logs.append("Motor stopped", ts=3.0, tags={"level": "info"}) - -print(logs.summary()) -``` - - -``` -Stream[str]("logs"): 3 items, 1970-01-01 00:00:01 — 1970-01-01 00:00:03 (2.0s) -``` - -## Observations - -Each observation wraps a payload with metadata: - -```python session=memory ansi=false -obs = logs.first() -print(obs) -print(f"id={obs.id}, ts={obs.ts}, tags={obs.tags}") -``` - - -``` -Observation(id=1, ts=1.0, pose=None, tags={'level': 'info'}) -id=1, ts=1.0, tags={'level': 'info'} -``` - -- `id` — auto-assigned integer -- `ts` — timestamp (float, seconds) -- `pose` — optional 3D position + orientation -- `tags` — key-value metadata dict -- `parent_id` — lineage tracking (set by transforms) -- `data` — the payload (lazily loaded from DB) - -### Lazy payload loading - -Metadata is always in memory. The `.data` property triggers a single-row `SELECT` on first access, then caches: - -```python skip -obs = logs.first() -obs.ts # already in memory -obs.tags # already in memory -obs.data # NOW loads the blob and decodes it -obs.data # cached — second access is free -``` - -Payloads are decoded by the thread that fetched them. 
To pass observations across threads, call `.load()` first: - -```python skip -obs = logs.first().load() # force-loads .data, safe to pass to another thread -``` - -## Streams are lazy queries - -Filter methods return new `Stream` instances — nothing executes until a terminal is called: - -```python session=memory ansi=false -query = logs.after(1.0).filter_tags(level="info").order_by("ts", desc=True).limit(5) -print(query) -# nothing has hit the database yet -``` - - -``` -Stream[str]("logs") | after(t=1.0) | tags(level='info') | order(ts, desc) | limit(5) -``` - -Each call clones the stream with updated query parameters. The underlying `StreamQuery` compiles to SQL only at terminal time. - -### Filters - -| Method | Description | -|--------|-------------| -| `.after(t)` | `ts > t` | -| `.before(t)` | `ts < t` | -| `.time_range(t1, t2)` | `t1 <= ts <= t2` | -| `.at(t, tolerance=1.0)` | `\|ts - t\| <= tolerance` | -| `.near(pose, radius)` | R*Tree bounding box + exact distance post-filter | -| `.filter_tags(**kv)` | JSON tag field matching | - -### Ordering and pagination - -| Method | Description | -|--------|-------------| -| `.order_by(field, desc=False)` | Sort by `"ts"` or `"id"` | -| `.limit(k)` | Cap results | -| `.offset(n)` | Skip first n results | - -## Terminals execute the query - -| Terminal | Returns | Description | -|----------|---------|-------------| -| `.fetch()` | `ObservationSet` | All matching rows (lazy payloads) | -| `.fetch_pages(batch_size=128)` | `Iterator[list[Observation]]` | Paginated iteration | -| `.count()` | `int` | `SELECT COUNT(*)`, no payload loading | -| `.first()` | `Observation` | First by current ordering; raises `LookupError` if empty | -| `.last()` | `Observation` | Most recent by `ts` | -| `.exists()` | `bool` | `count() > 0` | -| `.summary()` | `str` | Count, time range, duration | -| `.get_time_range()` | `(float, float)` | `(first.ts, last.ts)` | - -List-like access also works: - -```python session=memory ansi=false 
-print(f"len={len(logs)}, bool={bool(logs)}") -print(f"logs[0] = {logs[0]}") -print(f"logs[-1] = {logs[-1]}") -``` - - -``` -len=3, bool=True -logs[0] = Observation(id=1, ts=1.0, pose=None, tags={'level': 'info'}) -logs[-1] = Observation(id=3, ts=3.0, pose=None, tags={'level': 'info'}) -``` - -Iteration uses paginated fetching under the hood: - -```python session=memory ansi=false -for obs in logs.after(1.5): - print(obs) -``` - - -``` -Observation(id=2, ts=2.0, pose=None, tags={'level': 'error'}) -Observation(id=3, ts=3.0, pose=None, tags={'level': 'info'}) -``` - -## ObservationSet - -`.fetch()` returns an `ObservationSet` — an in-memory result set that is itself a `Stream`. All filters and terminals work on it, re-evaluating in memory without hitting the database: - -```python session=memory ansi=false -results = logs.fetch() -print(results) -print(f"len={len(results)}") - -# re-filter in memory — no DB hit -errors = results.filter_tags(level="error").fetch() -print(errors) -print(errors[0]) -``` - - -``` -ObservationSet[str](3 items) -len=3 -ObservationSet[str](1 items) -Observation(id=2, ts=2.0, pose=None, tags={'level': 'error'}) -``` - -ObservationSet is read-only — `.append()` raises `TypeError`. - -When you chain filters on an ObservationSet, it downgrades to a plain Stream backed by the in-memory list, so it doesn't carry the full result set through the chain. - -## Transforms - -`.transform()` applies a function to each observation's payload. 
Without `.store()`, it runs entirely in memory: - -```python session=memory ansi=false -upper = logs.transform(lambda s: s.upper()) -print(upper) -print(upper.fetch()) -for obs in upper.fetch(): - print(obs.data) -``` - - -``` -TransformStream[?](Stream[str]("logs") -> PerItemTransformer) -ObservationSet[?](3 items) -MOTOR STARTED -JOINT 3 FAULT -MOTOR STOPPED -``` - -Return `None` to skip an item, return a `list` to fan-out: - -```python skip -# Filter: skip short messages -long = logs.transform(lambda s: s if len(s) > 10 else None) - -# Fan-out: split into words -words = logs.transform(lambda s: s.split()) -``` - -### Storing transforms - -`.store(name)` materializes a transform into a new named stream in the database: - -```python skip -# Default: backfill existing + subscribe to new appends -embeddings = images.transform(EmbeddingTransformer(clip)).store("clip_embeddings") - -# Live only: skip backfill, only process new appends -embeddings = images.transform(EmbeddingTransformer(clip), live=True).store("clip_embeddings") - -# Backfill only: process existing data, don't subscribe -embeddings = images.transform(EmbeddingTransformer(clip), backfill_only=True).store("clip_embeddings") -``` - -| Mode | Processes existing data | Subscribes to new appends | -|------|-------------------------|---------------------------| -| default | yes | yes | -| `live=True` | no | yes | -| `backfill_only=True` | yes | no | - -The output stream kind is auto-detected from the transformer: `EmbeddingTransformer` and `TextEmbeddingTransformer` create an `EmbeddingStream` with vec0 index, `CaptionTransformer` creates a `TextStream` with FTS index. - -Storing also records **parent lineage** — which source stream produced the derived stream. This powers `project_to`. 
- -### Built-in transformers - -| Transformer | Input | Output | Stored as | -|---|---|---|---| -| `PerItemTransformer(fn)` | any | any | `Stream` | -| `QualityWindowTransformer(quality_fn, window)` | any | same type | `Stream` — keeps best-quality item per time window | -| `CaptionTransformer(model)` | Image | `str` | `TextStream` with FTS index | -| `EmbeddingTransformer(model)` | Image/any | `Embedding` | `EmbeddingStream` with vec0 index | -| `TextEmbeddingTransformer(model)` | `str` | `Embedding` | `EmbeddingStream` with vec0 index | - -`QualityWindowTransformer` buffers observations within a time window and emits only the one with the highest quality score. Useful for sharpness filtering on camera frames: - -```python skip -from dimos.memory.transformer import QualityWindowTransformer - -sharp = images.transform( - QualityWindowTransformer(quality_fn=lambda img: img.sharpness, window=0.5) -).store("sharp_frames") -``` - -## Specialized streams - -### TextStream — full-text search - -```python session=memory ansi=false -text = session.text_stream("events") -text.append("Motor fault on joint 3", ts=1.0) -text.append("Battery low warning", ts=2.0) -text.append("Motor recovered", ts=3.0) - -results = text.search_text("motor").fetch() -print(results) -for obs in results: - print(f" {obs.data}") -``` - - -``` -ObservationSet[str](2 items) - Motor recovered - Motor fault on joint 3 -``` - -Uses SQLite FTS5. Results are ranked by relevance. Optional `k` parameter limits results. - -### EmbeddingStream — vector similarity search - -```python skip -embs = session.embedding_stream("clip_embs", vec_dimensions=512) -embs.append(embedding_vector, ts=1.0) - -results = embs.search_embedding([0.5, 0.3, ...], k=5).fetch() -``` - -Uses sqlite-vec (vec0) for cosine similarity. 
`search_embedding` accepts: -- Pre-computed `Embedding` or `list[float]` -- A `str` — auto-embedded via the stream's model (`embed_text`) -- An `Image` or other object — auto-embedded via the stream's model (`embed`) - -Results are `EmbeddingObservation` with `.similarity` (0–1 cosine) and `.embedding` (convenience alias for `.data`). - -## Lineage and project_to - -When you store a transform, each derived observation tracks its `parent_id`. Use `.project_to()` to follow the lineage chain back to a source stream: - -```python skip -images = session.stream("images", Image) -embeddings = images.transform(EmbeddingTransformer(clip)).store("clip_embeddings") - -# Search returns EmbeddingObservation — .data is the Embedding, not the source Image -results = embeddings.search_embedding("a hallway", k=5).fetch() -results[0].similarity # cosine similarity (0–1) -results[0].embedding # the Embedding vector -results[0].data # also the Embedding (same as .embedding) - -# To get source images, project back through the lineage chain -image_results = embeddings.search_embedding("office", k=5).project_to(images).fetch() - -# Multi-hop works too: embeddings → sharp_frames → raw_images -image_results = embeddings.search_embedding("office", k=5).project_to(raw_images).fetch() -``` - -## Reactive subscriptions - -Streams emit observations as they're appended: - -```python skip -images.subscribe(lambda obs: print(f"new frame at {obs.ts}")) - -# Filters work on subscriptions too: -images.after(10.0).filter_tags(cam="front").subscribe(handle_front_cam) - -# Or get the raw RxPY Observable: -observable = images.observable() -``` - -Under the hood this is an RxPY Observable on the backend's `Subject`. Embedding and lineage filters are skipped for live filtering (they need DB context); temporal, spatial, and tag filters work. 
- -## Codecs - -Payloads are BLOB-encoded via auto-selected codecs: - -| Codec | Used for | Strategy | -|-------|----------|----------| -| `JpegCodec` | `Image` | TurboJPEG lossy compression (~10-20x smaller), preserves `frame_id` | -| `LcmCodec` | `DimosMsg` types | LCM binary encoding (lossless) | -| `PickleCodec` | everything else | Python pickle (fallback) | - -`codec_for_type(payload_type)` auto-selects the best codec. This is transparent — you never need to specify a codec manually. - -## Session management - -### Listing streams - -```python session=memory ansi=false -for info in session.list_streams(): - print(f"{info.name}: {info.stream_kind}, {info.count} items") -``` - - -``` -logs: stream, 3 items -events: text, 3 items -``` - -### Context managers - -```python skip -with SqliteStore("memory.db") as store: - session = store.session() - # ... use session ... -# store.stop() called automatically -``` - -### Pose provider - -Auto-attach pose to every appended observation: - -```python skip -images = session.stream("images", Image, pose_provider=robot.get_pose) -images.append(frame) # pose is auto-filled from pose_provider() -``` - -### Persistence - -Data persists across restarts. Reopen the same database and streams pick up where they left off: - -```python skip -store = SqliteStore("memory.db") -session = store.session() -images = session.stream("images", Image) -results = images.after(100.0).fetch() # picks up old data -``` - -The `_streams` meta-table tracks stream names, payload types (as module paths), stream kind, parent lineage, and embedding dimensions. 
- -## MemoryModule — blueprint integration - -In a robot blueprint, `MemoryModule` wires input streams to memory: - -```python skip -class MyMemory(MemoryModule): - camera: In[Image] - - def start(self): - super().start() - - # Record camera input to a named stream (name/type inferred from input) - self.image_memory = self.memory(self.camera) - - # With quality filtering at 2 fps (keeps sharpest frame per window) - self.image_memory = self.memory(self.camera, fps=2) - - # Build derived streams - self.embeddings = self.image_memory.transform( - EmbeddingTransformer(CLIPModel()), live=True - ).store("clip_embeddings") -``` - -## Utilities - -### Bulk import - -```python skip -from dimos.memory.ingest import ingest - -# Import from any iterable of (timestamp, payload) tuples -count = ingest(images, replay.iterate_ts()) - -# With pose interpolation from an odometry source -count = ingest(images, replay.iterate_ts(), pose_source=odom_replay) -``` - -### Rerun export - -```python skip -from dimos.memory.rerun import to_rerun - -# Log a stream's observations to Rerun timeline -count = to_rerun(images) -count = to_rerun(images, entity_path="memory/camera") -``` - -```python session=memory no-result -store.stop() -``` - -```sh no-result -rm -f /tmp/memory_readme.db -``` diff --git a/dimos/memory/rerun.py b/dimos/memory/rerun.py deleted file mode 100644 index 629fa5b4dd..0000000000 --- a/dimos/memory/rerun.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Send memory stream contents to Rerun. - -Iterates a Stream, calls ``.to_rerun()`` on each observation's data -payload, and logs it at the observation's timestamp. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from dimos.memory.stream import Stream - - -def _infer_entity_path(stream: Any) -> str: - """Derive an entity path from the stream's backend name.""" - backend = getattr(stream, "_backend", None) - if backend is not None: - name = getattr(backend, "stream_name", None) - if name and name != "": - return f"memory/{name}" - raise ValueError( - "Cannot infer entity_path — stream has no named backend " - "(e.g. ObservationSet from .fetch()). Pass entity_path explicitly." - ) - - -def to_rerun( - stream: Stream[Any] | Any, - entity_path: str | None = None, -) -> int: - """Log stream observations to Rerun. - - For each observation whose ``.data`` has a ``to_rerun()`` method, - logs the result at the observation's timestamp on a custom "time" - timeline (no wall-clock contamination). - - Args: - stream: Any Stream or iterable of Observations. - entity_path: Rerun entity path. Auto-derived from stream name if None. - - Returns: - Number of items logged. 
- """ - import rerun as rr - - if entity_path is None: - entity_path = _infer_entity_path(stream) - - rr.disable_timeline("log_time") - rr.disable_timeline("log_tick") - - count = 0 - for obs in stream: - if obs.ts is not None: - rr.set_time("time", duration=obs.ts) - - data = obs.data - if hasattr(data, "to_rerun"): - rr.log(entity_path, data.to_rerun()) - count += 1 - - if obs.pose is not None and hasattr(obs.pose, "to_rerun_arrow"): - rr.log(f"{entity_path}/pose", obs.pose.to_rerun_arrow()) - - rr.reset_time() - return count diff --git a/dimos/memory/store.py b/dimos/memory/store.py index 0d2f549110..213df34d84 100644 --- a/dimos/memory/store.py +++ b/dimos/memory/store.py @@ -15,166 +15,150 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Any, TypeVar +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, TypeVar, cast -from dimos.core.resource import Resource +from dimos.core.resource import CompositeResource +from dimos.memory.stream import Stream +from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: - from dimos.memory.stream import EmbeddingStream, Stream, TextStream - from dimos.memory.transformer import Transformer - from dimos.memory.type import PoseProvider - from dimos.models.embedding.base import Embedding, EmbeddingModel + from collections.abc import Iterator + + from dimos.memory.backend import Backend, BlobStore, LiveChannel, VectorStore + from dimos.memory.codecs.base import Codec T = TypeVar("T") +# ── Configuration ───────────────────────────────────────────────── + + +@dataclass +class StoreConfig: + """Base config for Store. Subclasses extend with store-specific fields.""" + + +@dataclass +class SessionConfig: + """Session-level defaults for stream capabilities. + + These are inherited by all streams in the session unless overridden + per-stream in ``session.stream(..., **overrides)``. 
+ """ + + live_channel: LiveChannel[Any] | None = None + blob_store: BlobStore | None = None + vector_store: VectorStore | None = None + eager_blobs: bool = False + codec: Codec[Any] | None = None + + +# ── Stream namespace ────────────────────────────────────────────── + + class StreamNamespace: """Attribute-access proxy for session streams. Usage:: - session.streams.image_stream # same as looking up "image_stream" from list_streams() + session.streams.image_stream session.streams["image_stream"] - list(session.streams) # iterate all streams + list(session.streams) len(session.streams) """ def __init__(self, session: Session) -> None: self._session = session - def _load(self) -> dict[str, Stream[Any]]: - return {s._backend.stream_name: s for s in self._session.list_streams() if s._backend} - def __getattr__(self, name: str) -> Stream[Any]: if name.startswith("_"): raise AttributeError(name) - streams = self._load() - if name in streams: - return streams[name] - raise AttributeError( - f"No stream named {name!r}. Available: {', '.join(streams) or '(none)'}" - ) + if name not in self._session.list_streams(): + available = ", ".join(self._session.list_streams()) or "(none)" + raise AttributeError(f"No stream named {name!r}. 
Available: {available}") + return self._session.stream(name) def __getitem__(self, name: str) -> Stream[Any]: - streams = self._load() - if name in streams: - return streams[name] - raise KeyError(name) + if name not in self._session.list_streams(): + raise KeyError(name) + return self._session.stream(name) - def __iter__(self): - return iter(self._load().values()) + def __iter__(self) -> Iterator[Stream[Any]]: + for name in self._session.list_streams(): + yield self._session.stream(name) def __len__(self) -> int: - return len(self._load()) + return len(self._session.list_streams()) def __contains__(self, name: str) -> bool: - return name in self._load() + return name in self._session.list_streams() def __repr__(self) -> str: - names = list(self._load().keys()) - return f"StreamNamespace({names})" + return f"StreamNamespace({self._session.list_streams()})" -class Session(Resource): - """A session against a memory store. Creates and manages streams. +# ── Session & Store ─────────────────────────────────────────────── - Inherits DisposableBase so sessions can be added to CompositeDisposable. - """ - @property - def streams(self) -> StreamNamespace: - """Attribute-access namespace for all streams in this session.""" - return StreamNamespace(self) +class Session(Configurable[SessionConfig], CompositeResource): + """A session against a store. Manages named streams over a shared connection. - def start(self) -> None: - pass - - @abstractmethod - def stream( - self, - name: str, - payload_type: type[T], - *, - pose_provider: PoseProvider | None = None, - ) -> Stream[T]: - """Get or create a stored stream backed by the database.""" + Subclasses implement ``_create_backend`` to provide storage-specific backends. 
+ """ - @abstractmethod - def text_stream( - self, - name: str, - *, - tokenizer: str = "unicode61", - pose_provider: PoseProvider | None = None, - ) -> TextStream[str]: - """Get or create a text stream with FTS index.""" + default_config: type[SessionConfig] = SessionConfig - @abstractmethod - def embedding_stream( - self, - name: str, - *, - vec_dimensions: int | None = None, - pose_provider: PoseProvider | None = None, - parent_table: str | None = None, - embedding_model: EmbeddingModel | None = None, - ) -> EmbeddingStream[Embedding]: - """Get or create an embedding stream with vec0 index.""" + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + self._streams: dict[str, Stream[Any]] = {} @abstractmethod - def list_streams(self) -> list[Stream[Any]]: ... + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + """Create a backend for the named stream. Called once per stream name.""" + ... - @abstractmethod - def delete_stream(self, name: str) -> None: - """Drop a stream and all its associated tables (payload, rtree, etc.).""" + def stream(self, name: str, payload_type: type[T] | None = None, **overrides: Any) -> Stream[T]: + """Get or create a named stream. Returns the same Stream on repeated calls. - @abstractmethod - def materialize_transform( - self, - name: str, - source: Stream[Any], - transformer: Transformer[Any, Any], - *, - payload_type: type | None = None, - live: bool = False, - backfill_only: bool = False, - ) -> Stream[Any]: - """Create a stored stream from a transform pipeline.""" - - @abstractmethod - def resolve_parent_stream(self, name: str) -> str | None: - """Return the direct parent stream name, or None if no lineage exists.""" + Per-stream ``overrides`` (e.g. ``live_channel=``) are merged on top of + the session-level defaults from :class:`SessionConfig`. 
+ """ + if name not in self._streams: + resolved = {k: v for k, v in vars(self.config).items() if v is not None} + resolved.update({k: v for k, v in overrides.items() if v is not None}) + backend = self._create_backend(name, payload_type, **resolved) + self._streams[name] = Stream(source=backend) + return cast("Stream[T]", self._streams[name]) @abstractmethod - def resolve_lineage_chain(self, source: str, target: str) -> tuple[str, ...]: - """Return intermediate tables in the parent_id chain from source to target. - - Single hop (source directly parents target) returns ``()``. - Two hops (source → mid → target) returns ``("mid",)``. - Raises ``ValueError`` if no lineage path exists. - """ + def list_streams(self) -> list[str]: + """Return names of all streams in this session.""" + ... @abstractmethod - def stop(self) -> None: ... + def delete_stream(self, name: str) -> None: + """Delete a stream by name (from cache and underlying storage).""" + ... - def __enter__(self) -> Session: - return self + @property + def streams(self) -> StreamNamespace: + return StreamNamespace(self) -class Store(Resource): - """Top-level entry point — wraps a database file.""" +class Store(Configurable[StoreConfig], CompositeResource): + """Top-level entry point — wraps a storage location (file, URL, etc.).""" - @abstractmethod - def session(self) -> Session: ... + default_config: type[StoreConfig] = StoreConfig - def start(self) -> None: - pass + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) @abstractmethod - def stop(self) -> None: ... - - def __enter__(self) -> Store: - return self - - def __exit__(self, *args: object) -> None: - self.stop() + def session(self, **kwargs: Any) -> Session: + """Create a session. kwargs are forwarded to SessionConfig.""" + ... 
diff --git a/dimos/memory/stream.py b/dimos/memory/stream.py index 4b044028ed..60d8a6ed7c 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory/stream.py @@ -14,156 +14,137 @@ from __future__ import annotations -import copy import time -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Protocol, - Self, - TypeVar, - cast, - overload, -) - -import numpy as np -import reactivex.operators as ops +from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory.formatting import render_text, rich_text -from dimos.memory.type import ( +from dimos.core.resource import Resource +from dimos.memory.backend import Backend +from dimos.memory.buffer import BackpressureBuffer, KeepLast +from dimos.memory.filter import ( AfterFilter, AtFilter, BeforeFilter, - EmbeddingObservation, - EmbeddingSearchFilter, Filter, - LineageFilter, NearFilter, - Observation, + PredicateFilter, StreamQuery, TagsFilter, - TextSearchFilter, TimeRangeFilter, ) -from dimos.types.timestamped import Timestamped +from dimos.memory.transform import FnIterTransformer, FnTransformer, Transformer +from dimos.memory.type import EmbeddedObservation, Observation if TYPE_CHECKING: from collections.abc import Callable, Iterator - from reactivex import Observable - from reactivex.abc import DisposableBase as Disposable - from reactivex.subject import Subject + import reactivex + from reactivex.abc import DisposableBase, ObserverBase - from dimos.memory.store import Session - from dimos.memory.transformer import Transformer - from dimos.models.embedding.base import Embedding, EmbeddingModel - from dimos.msgs.geometry_msgs.Pose import PoseLike + from dimos.models.embedding.base import Embedding T = TypeVar("T") R = TypeVar("R") -class StreamBackend(Protocol): - """Backend protocol — implemented by SqliteStreamBackend etc.""" - - def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: ... - def execute_count(self, query: StreamQuery) -> int: ... 
- def do_append( - self, - payload: Any, - ts: float | None, - pose: Any | None, - tags: dict[str, Any] | None, - parent_id: int | None = None, - ) -> Observation[Any]: ... - - def load_data(self, row_id: int) -> Any: ... +class Stream(Resource, Generic[T]): + """Lazy, pull-based stream over observations. - @property - def appended_subject(self) -> Subject[Observation[Any]]: ... # type: ignore[type-arg] - @property - def stream_name(self) -> str: ... + Every filter/transform method returns a new Stream — no computation + happens until iteration. Backends handle query application for stored + data; transform sources apply filters as Python predicates. - -class Stream(Generic[T]): - """Lazy, chainable stream over stored observations. - - Created by Session.stream(). Filter methods return new Stream instances. - Terminals (.fetch(), .count(), etc.) execute the query. + Implements Resource so live streams can be cleanly stopped via + ``stop()`` or used as a context manager. """ def __init__( self, - backend: StreamBackend | None = None, + source: Backend[T] | Stream[Any], *, - query: StreamQuery | None = None, - session: Session | None = None, - payload_type: type | None = None, + xf: Transformer[Any, T] | None = None, + query: StreamQuery = StreamQuery(), ) -> None: - self._backend = backend - self._query = query or StreamQuery() - self._session: Session | None = session - self._payload_type: type | None = payload_type - - @property - def name(self) -> str: - """The stream name in the backing store.""" - return self._require_backend().stream_name - - def _clone(self, **overrides: Any) -> Self: - """Return a shallow copy with updated query fields.""" - q = self._query - clone = copy.copy(self) - clone._query = StreamQuery( - filters=overrides.get("filters", q.filters), - order_field=overrides.get("order_field", q.order_field), - order_desc=overrides.get("order_desc", q.order_desc), - limit_val=overrides.get("limit_val", q.limit_val), - 
offset_val=overrides.get("offset_val", q.offset_val), - ) - return clone + self._source = source + self._xf = xf + self._query = query - def __repr__(self) -> str: - return rich_text(self).plain + def start(self) -> None: + pass - def __str__(self) -> str: - return render_text(rich_text(self)) + def stop(self) -> None: + """Close the live buffer (if any), unblocking iteration.""" + buf = self._query.live_buffer + if buf is not None: + buf.close() + if isinstance(self._source, Stream): + self._source.stop() - def _with_filter(self, f: Filter) -> Self: - return self._clone(filters=(*self._query.filters, f)) + def __str__(self) -> str: + # Walk the source chain to collect (xf, query) pairs + chain: list[tuple[Any, StreamQuery]] = [] + current: Any = self + while isinstance(current, Stream): + chain.append((current._xf, current._query)) + current = current._source + chain.reverse() # innermost first + + # current is the Backend + name = getattr(current, "name", "?") + result = f'Stream("{name}")' + + for xf, query in chain: + if xf is not None: + result += f" -> {xf}" + q_str = str(query) + if q_str: + result += f" | {q_str}" + + return result + + def is_live(self) -> bool: + """True if this stream (or any ancestor in the chain) is in live mode.""" + if self._query.live_buffer is not None: + return True + if isinstance(self._source, Stream): + return self._source.is_live() + return False + + # ── Iteration ─────────────────────────────────────────────────── - def _require_backend(self) -> StreamBackend: - if self._backend is None: - raise TypeError( - "Operation requires a stored stream. Call .store() first or use session.stream()." 
- ) - return self._backend + def __iter__(self) -> Iterator[Observation[T]]: + return self._build_iter() - # ── Data loading ────────────────────────────────────────────────── + def _build_iter(self) -> Iterator[Observation[T]]: + if isinstance(self._source, Stream): + return self._iter_transform() + # Backend handles all query application (including live if requested) + return self._source.iterate(self._query) - def load_data(self, obs: Observation[T]) -> T: - """Load payload for an observation. Thread-safe alternative to obs.data.""" - backend = self._require_backend() - return cast("T", backend.load_data(obs.id)) + def _iter_transform(self) -> Iterator[Observation[T]]: + """Iterate a transform source, applying query filters in Python.""" + assert isinstance(self._source, Stream) and self._xf is not None + it: Iterator[Observation[T]] = self._xf(iter(self._source)) + return self._query.apply(it, live=self.is_live()) - # ── Write ───────────────────────────────────────────────────────── + # ── Query builders ────────────────────────────────────────────── - def append( - self, - payload: T, - *, - ts: float | None = None, - pose: PoseLike | None = None, - tags: dict[str, Any] | None = None, - parent_id: int | None = None, - ) -> Observation[T]: - if ts is None and isinstance(payload, Timestamped): - ts = payload.ts - backend = self._require_backend() - return cast("Observation[T]", backend.do_append(payload, ts, pose, tags, parent_id)) + def _replace_query(self, **overrides: Any) -> Stream[T]: + q = self._query + new_q = StreamQuery( + filters=overrides.get("filters", q.filters), + order_field=overrides.get("order_field", q.order_field), + order_desc=overrides.get("order_desc", q.order_desc), + limit_val=overrides.get("limit_val", q.limit_val), + offset_val=overrides.get("offset_val", q.offset_val), + live_buffer=overrides.get("live_buffer", q.live_buffer), + search_vec=overrides.get("search_vec", q.search_vec), + search_k=overrides.get("search_k", q.search_k), + 
search_text=overrides.get("search_text", q.search_text), + ) + return Stream(self._source, xf=self._xf, query=new_q) - # ── Temporal filters ────────────────────────────────────────────── + def _with_filter(self, f: Filter) -> Stream[T]: + return self._replace_query(filters=(*self._query.filters, f)) def after(self, t: float) -> Stream[T]: return self._with_filter(AfterFilter(t)) @@ -174,618 +155,227 @@ def before(self, t: float) -> Stream[T]: def time_range(self, t1: float, t2: float) -> Stream[T]: return self._with_filter(TimeRangeFilter(t1, t2)) - def at(self, t: float | Stream[Any], *, tolerance: float = 1.0) -> Stream[T]: - if isinstance(t, Stream): - t1, t2 = t.get_time_range() - return self._with_filter(TimeRangeFilter(t1 - tolerance, t2 + tolerance)) + def at(self, t: float, tolerance: float = 1.0) -> Stream[T]: return self._with_filter(AtFilter(t, tolerance)) - # ── Spatial filter ──────────────────────────────────────────────── - - def near(self, pose: PoseLike | Stream[Any], radius: float = 0.0) -> Stream[T]: - if isinstance(pose, Stream): - center, max_dist = pose.bounding_sphere() - return self._with_filter(NearFilter(center, max_dist + radius)) + def near(self, pose: Any, radius: float) -> Stream[T]: return self._with_filter(NearFilter(pose, radius)) - # ── Tag filter ──────────────────────────────────────────────────── + def tags(self, **tags: Any) -> Stream[T]: + return self._with_filter(TagsFilter(tags)) - def filter_tags(self, **tags: Any) -> Stream[T]: - return self._with_filter(TagsFilter(tuple(tags.items()))) - - # ── Ordering / pagination ───────────────────────────────────────── - - def order_by(self, field: str, *, desc: bool = False) -> Stream[T]: - return self._clone(order_field=field, order_desc=desc) + def order_by(self, field: str, desc: bool = False) -> Stream[T]: + return self._replace_query(order_field=field, order_desc=desc) def limit(self, k: int) -> Stream[T]: - return self._clone(limit_val=k) + return 
self._replace_query(limit_val=k) def offset(self, n: int) -> Stream[T]: - return self._clone(offset_val=n) + return self._replace_query(offset_val=n) - # ── Transform ───────────────────────────────────────────────────── + def search(self, query: Embedding, k: int) -> Stream[T]: + """Return top-k observations by cosine similarity to *query*. - @overload - def transform( - self, - xf: Transformer[T, R], - *, - live: bool = ..., - backfill_only: bool = ..., - ) -> Stream[R]: ... + The backend handles the actual computation. ListBackend does + brute-force cosine; SqliteBackend (future) pushes down to vec0. + """ + return self._replace_query(search_vec=query, search_k=k) - @overload - def transform( - self, - xf: Callable[[T], Any], - *, - live: bool = ..., - backfill_only: bool = ..., - ) -> Stream[Any]: ... + def search_text(self, text: str) -> Stream[T]: + """Filter observations whose data contains *text*. - def transform( - self, - xf: Transformer[Any, Any] | Callable[..., Any], - *, - live: bool = False, - backfill_only: bool = False, - ) -> Stream[Any]: - from dimos.memory.transformer import PerItemTransformer, Transformer as TransformerABC - - transformer: TransformerABC[Any, Any] - if not isinstance(xf, TransformerABC): - transformer = PerItemTransformer(xf) - else: - transformer = xf + ListBackend does case-insensitive substring match; + SqliteBackend (future) pushes down to FTS5. 
+ """ + return self._replace_query(search_text=text) - return TransformStream( - source=self, - transformer=transformer, - live=live, - backfill_only=backfill_only, - ) + # ── Functional API ────────────────────────────────────────────── + + def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: + """Filter by arbitrary predicate on the full Observation.""" + return self._with_filter(PredicateFilter(pred)) + + def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[Any]: + """Transform each observation's data via callable.""" + return self.transform(FnTransformer(lambda obs: fn(obs))) - # ── Materialize ─────────────────────────────────────────────────── + # ── Transform ─────────────────────────────────────────────────── - def store( + def transform( self, - name: str | None = None, - payload_type: type | None = None, - session: Session | None = None, - ) -> Stream[T]: - # Already stored streams are a no-op - if self._backend is not None and name is None: - return self - raise TypeError( - "store() requires a session context. This stream is not associated with a session." - ) + xf: Transformer[T, R] | Callable[[Iterator[Observation[T]]], Iterator[Observation[R]]], + ) -> Stream[R]: + """Wrap this stream with a transformer. Returns a new lazy Stream. - # ── Cross-stream lineage ────────────────────────────────────────── + Accepts a ``Transformer`` subclass or a bare callable / generator + function with the same ``Iterator[Obs] → Iterator[Obs]`` signature:: - def project_to(self, target: Stream[R]) -> Stream[R]: - """Follow parent_id lineage to project observations onto the target stream. + def detect(upstream): + for obs in upstream: + yield obs.derive(data=run_detector(obs.data)) - Returns a filtered *target* Stream containing only observations that are - ancestors of the current (source) query results. The result is a normal - Stream — all chaining, pagination, and lazy loading work as usual. 
+ images.transform(detect).save(detections) """ - backend = self._require_backend() - target_backend = target._require_backend() - session = self._session - if session is None: - raise TypeError("project_to requires a session-backed stream") + if not isinstance(xf, Transformer): + xf = FnIterTransformer(xf) + return Stream(source=self, xf=xf, query=StreamQuery()) - source_table = backend.stream_name - target_table = target_backend.stream_name + # ── Live mode ─────────────────────────────────────────────────── - if source_table == target_table: - return self # type: ignore[return-value] + def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: + """Return a stream whose iteration never ends — backfill then live tail. - hops = session.resolve_lineage_chain(source_table, target_table) + All backends support live mode via their built-in ``LiveChannel``. + Call .live() before .transform(), not after. - return target._with_filter( - LineageFilter( - source_table=source_table, - source_query=self._query, - hops=hops, + Default buffer: KeepLast(). The backend handles subscription, dedup, + and backpressure — how it does so is its business. + """ + if isinstance(self._source, Stream): + raise TypeError( + "Cannot call .live() on a transform stream. " + "Call .live() on the source stream, then .transform()." ) - ) + buf = buffer if buffer is not None else KeepLast() + return self._replace_query(live_buffer=buf) - # ── List-like interface ──────────────────────────────────────────── + # ── Save ───────────────────────────────────────────────────────── - def __iter__(self) -> Iterator[Observation[T]]: - for page in self.fetch_pages(): - yield from page - - def __len__(self) -> int: - return self.count() - - @overload - def __getitem__(self, index: int) -> Observation[T]: ... - - @overload - def __getitem__(self, index: slice) -> list[Observation[T]]: ... 
- - def __getitem__(self, index: int | slice) -> Observation[T] | list[Observation[T]]: - if isinstance(index, int): - if index < 0: - # Negative index: need count to resolve - n = self.count() - index = n + index - if index < 0: - raise IndexError("stream index out of range") - results = self.offset(index).limit(1).fetch() - if not results: - raise IndexError("stream index out of range") - return results[0] - # Slice - start, stop, step = index.indices(self.count()) - s = self.offset(start).limit(stop - start) - results = s.fetch() - if step != 1: - return list(results)[::step] - return list(results) - - def __bool__(self) -> bool: - return self.exists() - - # ── Terminals ───────────────────────────────────────────────────── - - def fetch(self) -> ObservationSet[T]: - backend = self._require_backend() - results = backend.execute_fetch(self._query) - return ObservationSet( - cast("list[Observation[T]]", results), - session=self._session, - payload_type=self._payload_type, - ) + def save(self, target: Stream[T]) -> Stream[T]: + """Sync terminal: iterate self, append each obs to target's backend. - def fetch_pages(self, batch_size: int = 128) -> Iterator[list[Observation[T]]]: - offset = self._query.offset_val or 0 - total_limit = self._query.limit_val - emitted = 0 - while True: - page_size = batch_size - if total_limit is not None: - remaining = total_limit - emitted - if remaining <= 0: - break - page_size = min(batch_size, remaining) - q = StreamQuery( - filters=self._query.filters, - order_field=self._query.order_field or "id", - order_desc=self._query.order_desc, - limit_val=page_size, - offset_val=offset, + Returns the target stream for continued querying. + """ + if isinstance(target._source, Stream): + raise TypeError("Cannot save to a transform stream. 
Target must be backend-backed.") + backend = target._source + for obs in self: + backend.append(obs) + return target + + # ── Terminals ─────────────────────────────────────────────────── + + def fetch(self) -> list[Observation[T]]: + """Materialize all observations into a list.""" + if self.is_live(): + raise TypeError( + ".fetch() on a live stream would block forever. " + "Use .drain() or .save(target) instead." ) - backend = self._require_backend() - page = backend.execute_fetch(q) - if not page: - break - yield cast("list[Observation[T]]", page) - emitted += len(page) - if len(page) < page_size: - break - offset += len(page) + return list(self) def first(self) -> Observation[T]: - results = self.limit(1).fetch() - if not results: - raise LookupError("No matching observation") - return results[0] + """Return the first matching observation.""" + it = iter(self.limit(1)) + try: + return next(it) + except StopIteration: + raise LookupError("No matching observation") from None def last(self) -> Observation[T]: - results = self.order_by("ts", desc=True).limit(1).fetch() - if not results: - raise LookupError("No matching observation") - return results[0] + """Return the last matching observation (by timestamp).""" + return self.order_by("ts", desc=True).first() def count(self) -> int: - backend = self._require_backend() - return backend.execute_count(self._query) + """Count matching observations.""" + if isinstance(self._source, Backend): + return self._source.count(self._query) + if self.is_live(): + raise TypeError(".count() on a live transform stream would block forever.") + return sum(1 for _ in self) def exists(self) -> bool: - return self.count() > 0 - - def delete(self) -> None: - """Drop this stream and all associated storage.""" - if self._session is None: - raise TypeError("Cannot delete: no session available.") - backend = self._require_backend() - self._session.delete_stream(backend.stream_name) + """Check if any matching observation exists.""" + return 
next(iter(self.limit(1)), None) is not None def get_time_range(self) -> tuple[float, float]: - return (self.first().ts, self.last().ts) - - def bounding_sphere(self) -> tuple[Any, float]: - """Return (centroid_pose, max_distance_from_centroid) for all poses.""" - xs: list[float] = [] - ys: list[float] = [] - zs: list[float] = [] - for obs in self: - if obs.pose is None: - continue - p = obs.pose.position - xs.append(p.x) - ys.append(p.y) - zs.append(p.z) - if not xs: - raise ValueError("No observations with poses in this stream") - cx, cy, cz = sum(xs) / len(xs), sum(ys) / len(ys), sum(zs) / len(zs) - max_dist = max( - ((x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2) ** 0.5 - for x, y, z in zip(xs, ys, zs, strict=True) - ) - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - - center = PoseStamped(position=[cx, cy, cz]) - return center, max_dist + """Return (min_ts, max_ts) for matching observations.""" + first = self.first() + last = self.last() + return (first.ts, last.ts) def summary(self) -> str: + """Return a short human-readable summary: count, time range, duration.""" from datetime import datetime, timezone - t = rich_text(self) n = self.count() if n == 0: - t.append(": ", style="dim") - t.append("empty", style="italic dim") - return render_text(t) - t0, t1 = self.get_time_range() + return f"{self}: empty" + + (t0, t1) = self.get_time_range() + fmt = "%Y-%m-%d %H:%M:%S" dt0 = datetime.fromtimestamp(t0, tz=timezone.utc).strftime(fmt) dt1 = datetime.fromtimestamp(t1, tz=timezone.utc).strftime(fmt) dur = t1 - t0 - t.append(": ", style="dim") - t.append(f"{n}", style="bold white") - t.append(" items, ", style="dim") - t.append(dt0, style="bright_blue") - t.append(" — ", style="dim") - t.append(dt1, style="bright_blue") - t.append(f" ({dur:.1f}s)", style="dim yellow") - return render_text(t) - - # ── Reactive ────────────────────────────────────────────────────── - - def observable(self) -> Observable[Observation[T]]: # type: ignore[type-arg] - backend 
= self._require_backend() - raw: Observable[Observation[T]] = backend.appended_subject # type: ignore[assignment] - if not self._query.filters: - return raw - active = [ - f - for f in self._query.filters - if not isinstance(f, (EmbeddingSearchFilter, LineageFilter)) - ] - - def _check(o: Observation[T]) -> bool: - return all(f.matches(o) for f in active) - - return raw.pipe(ops.filter(_check)) - - def subscribe(self, on_next: Callable[[Observation[T]], None]) -> Disposable: - return self.observable().subscribe(on_next=on_next) - + return f"{self}: {n} items, {dt0} — {dt1} ({dur:.1f}s)" -class EmbeddingStream(Stream[T]): - """Stream with a vector index. Adds search_embedding().""" - - _embedding_model: EmbeddingModel | None - - def __init__( - self, - backend: StreamBackend | None = None, - *, - query: StreamQuery | None = None, - session: Session | None = None, - embedding_model: EmbeddingModel | None = None, - payload_type: type | None = None, - ) -> None: - super().__init__(backend=backend, query=query, session=session, payload_type=payload_type) - self._embedding_model = embedding_model - - def _require_model(self) -> EmbeddingModel: - if self._embedding_model is None: - raise TypeError( - "This embedding stream has no model reference. " - "Pass a str/image only on streams created via EmbeddingTransformer, " - "or search with a pre-computed Embedding / list[float]." - ) - return self._embedding_model + def drain(self) -> int: + """Consume all observations, discarding results. Returns count consumed. - def search_embedding( - self, - query: Embedding | list[float] | str | Any, - *, - k: int, - model: EmbeddingModel | None = None, - ) -> EmbeddingStream[T]: - """Search by vector similarity. - - Accepts pre-computed embeddings, raw float lists, text strings, or - images/other objects. Text and non-vector inputs are auto-embedded - using the model that created this stream (or the ``model`` override). 
- - Returns an EmbeddingStream — use ``.project_to(source)`` to get - results in the source stream's type, or ``.fetch()`` for - ``EmbeddingObservation`` with ``.similarity`` scores. + Use for side-effect pipelines (e.g. live embed-and-store) where you + don't need to collect results in memory. """ - from dimos.models.embedding.base import Embedding as EmbeddingCls - - resolve = model or self._embedding_model - label: str | None = None - if isinstance(query, str): - if resolve is None: - raise TypeError( - "No embedding model available. Pass model= or use a " - "pre-computed Embedding / list[float]." - ) - label = query - emb = resolve.embed_text(query) - if isinstance(emb, list): - emb = emb[0] - query = emb - - if isinstance(query, EmbeddingCls): - vec = query.to_numpy().tolist() - elif isinstance(query, list): - vec = list(query) - else: - # Assume embeddable object (Image, etc.) - if resolve is None: - raise TypeError( - "No embedding model available. Pass model= or use a " - "pre-computed Embedding / list[float]." 
- ) - label = type(query).__name__ - emb = resolve.embed(query) - if isinstance(emb, list): - emb = emb[0] - query = emb - vec = query.to_numpy().tolist() - - return self._with_filter(EmbeddingSearchFilter(vec, k, label=label)) - - def fetch(self) -> ObservationSet[T]: # type: ignore[override] - backend = self._require_backend() - results = backend.execute_fetch(self._query) - return ObservationSet( - cast("list[Observation[T]]", results), - session=self._session, - payload_type=self._payload_type, - ) - - def first(self) -> EmbeddingObservation: # type: ignore[override] - results = self.limit(1).fetch() - if not results: - raise LookupError("No matching observation") - return results[0] # type: ignore[return-value] + n = 0 + for _ in self: + n += 1 + return n - def last(self) -> EmbeddingObservation: # type: ignore[override] - results = self.order_by("ts", desc=True).limit(1).fetch() - if not results: - raise LookupError("No matching observation") - return results[0] # type: ignore[return-value] + # ── Reactive ───────────────────────────────────────────────────── + def observable(self) -> reactivex.Observable[Observation[T]]: + """Convert this stream to an RxPY Observable. -class TextStream(Stream[T]): - """Stream with an FTS5 index. Adds search_text().""" - - def search_text(self, text: str, *, k: int | None = None) -> TextStream[T]: - return self._with_filter(TextSearchFilter(text, k)) - + Iteration is scheduled on the dimos thread pool so subscribe() never + blocks the calling thread. + """ + import reactivex + import reactivex.operators as ops -class TransformStream(Stream[R]): - """In-memory stream produced by .transform(). 
Backed by ListBackend.""" + from dimos.utils.threadpool import get_scheduler - def __init__( - self, - source: Stream[Any], - transformer: Transformer[Any, R], - *, - live: bool = False, - backfill_only: bool = False, - ) -> None: - backend = ListBackend([], name="") - super().__init__(backend=backend, session=source._session) - self._source = source - self._transformer = transformer - self._live = live - self._backfill_only = backfill_only - self._materialized = False - - def _materialize(self) -> None: - """Run backfill if not yet done.""" - if self._materialized: - return - self._materialized = True - if self._transformer.supports_backfill and not self._live: - self._transformer.process(self._source, self) - - def _require_backend(self) -> StreamBackend: - self._materialize() - return super()._require_backend() - - def __repr__(self) -> str: - return rich_text(self).plain - - def __str__(self) -> str: - return render_text(rich_text(self)) - - def fetch(self) -> ObservationSet[R]: - self._materialize() - backend = cast("ListBackend", self._backend) - return ObservationSet( - cast("list[Observation[R]]", list(backend._observations)), - session=self._session, - payload_type=self._transformer.output_type, + return reactivex.from_iterable(self).pipe( + ops.subscribe_on(get_scheduler()), ) - def store( + def subscribe( self, - name: str | None = None, - payload_type: type | None = None, - session: Session | None = None, - ) -> Stream[R]: - resolved = session or self._source._session - if resolved is None: - raise TypeError( - "Cannot store: no session available. " - "Either use session.stream() to create the source, " - "or pass session= to store()." 
- ) - if name is None: - raise TypeError("store() requires a name for transform outputs") - resolved_type = payload_type or self._transformer.output_type - return resolved.materialize_transform( - name=name, - source=self._source, - transformer=self._transformer, - payload_type=resolved_type, - live=self._live, - backfill_only=self._backfill_only, + on_next: Callable[[Observation[T]], None] | ObserverBase[Observation[T]] | None = None, + on_error: Callable[[Exception], None] | None = None, + on_completed: Callable[[], None] | None = None, + ) -> DisposableBase: + """Subscribe to this stream as an RxPY Observable.""" + return self.observable().subscribe( # type: ignore[call-overload] + on_next=on_next, + on_error=on_error, + on_completed=on_completed, ) + # ── Write ─────────────────────────────────────────────────────── -class ListBackend: - """In-memory backend that evaluates StreamQuery filters in Python.""" - - def __init__(self, observations: list[Observation[Any]], name: str = "") -> None: - self._observations = observations - self._name = name - self._next_id = max((o.id for o in observations), default=-1) + 1 - from reactivex.subject import Subject - - self._subject: Subject[Observation[Any]] = Subject() # type: ignore[type-arg] - - def execute_fetch(self, query: StreamQuery) -> list[Observation[Any]]: - results = list(self._observations) - - # Apply non-embedding filters - for f in query.filters: - if isinstance(f, (EmbeddingSearchFilter, LineageFilter)): - continue - results = [obs for obs in results if f.matches(obs)] - - # Embedding top-k pass (cosine similarity) - emb_filters = [f for f in query.filters if isinstance(f, EmbeddingSearchFilter)] - if emb_filters: - ef = emb_filters[0] - query_vec = np.array(ef.query, dtype=np.float32) - query_norm = np.linalg.norm(query_vec) - if query_norm > 0: - scored = [] - for obs in results: - if isinstance(obs, EmbeddingObservation): - obs_vec = obs.embedding.to_numpy() - else: - continue - obs_norm = 
np.linalg.norm(obs_vec) - if obs_norm > 0: - sim = float(np.dot(query_vec, obs_vec) / (query_norm * obs_norm)) - else: - sim = 0.0 - scored.append((sim, obs)) - scored.sort(key=lambda x: x[0], reverse=True) - results = [obs for _, obs in scored[: ef.k]] - - # Ordering - if query.order_field: - key = query.order_field - results.sort( - key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, - reverse=query.order_desc, - ) - - # Offset / limit - if query.offset_val: - results = results[query.offset_val :] - if query.limit_val is not None: - results = results[: query.limit_val] - - return results - - def execute_count(self, query: StreamQuery) -> int: - return len(self.execute_fetch(query)) - - def load_data(self, row_id: int) -> Any: - for obs in self._observations: - if obs.id == row_id: - return obs.data - raise LookupError(f"No observation with id={row_id}") - - def do_append( - self, - payload: Any, - ts: float | None, - pose: Any | None, - tags: dict[str, Any] | None, - parent_id: int | None = None, - ) -> Observation[Any]: - obs: Observation[Any] = Observation( - id=self._next_id, - ts=ts if ts is not None else time.time(), - pose=pose, - tags=tags or {}, - parent_id=parent_id, - _data=payload, - ) - self._next_id += 1 - self._observations.append(obs) - self._subject.on_next(obs) - return obs - - @property - def appended_subject(self) -> Subject[Observation[Any]]: # type: ignore[type-arg] - return self._subject - - @property - def stream_name(self) -> str: - return self._name - - -class ObservationSet(Stream[T]): - """Materialized result set — list-like + stream-like. - - Holds Observation objects with lazy _data_loader closures. - Metadata is in memory, payload BLOBs stay in DB until .data access. 
- """ - - def __init__( + def append( self, - observations: list[Observation[T]], + payload: T, *, - session: Session | None = None, - payload_type: type | None = None, - ) -> None: - self._observations = observations - backend = ListBackend(cast("list[Observation[Any]]", observations)) - super().__init__(backend=backend, session=session, payload_type=payload_type) - - def _clone(self, **overrides: Any) -> Stream[T]: # type: ignore[override] - """Downgrade to plain Stream — don't carry _observations through chaining.""" - base: Stream[T] = Stream( - backend=self._backend, session=self._session, payload_type=self._payload_type - ) - base._query = self._query - return base._clone(**overrides) - - # ── List-like interface ────────────────────────────────────────── - - def __len__(self) -> int: - return len(self._observations) - - @overload - def __getitem__(self, index: int) -> Observation[T]: ... - - @overload - def __getitem__(self, index: slice) -> list[Observation[T]]: ... - - def __getitem__(self, index: int | slice) -> Observation[T] | list[Observation[T]]: - return self._observations[index] - - def __iter__(self) -> Iterator[Observation[T]]: - return iter(self._observations) - - def __bool__(self) -> bool: - return len(self._observations) > 0 + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + embedding: Embedding | None = None, + ) -> Observation[T]: + """Append to the backing store. Only works if source is a Backend.""" + if isinstance(self._source, Stream): + raise TypeError("Cannot append to a transform stream. 
Append to the source stream.") + _ts = ts if ts is not None else time.time() + _tags = tags or {} + if embedding is not None: + obs: Observation[T] = EmbeddedObservation( + id=-1, + ts=_ts, + pose=pose, + tags=_tags, + _data=payload, + embedding=embedding, + ) + else: + obs = Observation(id=-1, ts=_ts, pose=pose, tags=_tags, _data=payload) + return self._source.append(obs) diff --git a/dimos/memory2/streaming.md b/dimos/memory/streaming.md similarity index 100% rename from dimos/memory2/streaming.md rename to dimos/memory/streaming.md diff --git a/dimos/memory2/test_blobstore.py b/dimos/memory/test_blobstore.py similarity index 97% rename from dimos/memory2/test_blobstore.py rename to dimos/memory/test_blobstore.py index b8e8668ff8..8e5ab37744 100644 --- a/dimos/memory2/test_blobstore.py +++ b/dimos/memory/test_blobstore.py @@ -20,9 +20,9 @@ import numpy as np -from dimos.memory2.blobstore.file import FileBlobStore -from dimos.memory2.impl.memory import MemoryStore -from dimos.memory2.type import _UNLOADED +from dimos.memory.blobstore.file import FileBlobStore +from dimos.memory.impl.memory import MemoryStore +from dimos.memory.type import _UNLOADED from dimos.models.embedding.base import Embedding if TYPE_CHECKING: @@ -126,7 +126,7 @@ def test_no_blobstore_unchanged(self) -> None: assert obs.data == "inline" def test_blobstore_with_vector_search(self, tmp_path: Path) -> None: - from dimos.memory2.vectorstore import MemoryVectorStore + from dimos.memory.vectorstore import MemoryVectorStore bs = FileBlobStore(tmp_path / "blobs") bs.start() diff --git a/dimos/memory2/test_buffer.py b/dimos/memory/test_buffer.py similarity index 96% rename from dimos/memory2/test_buffer.py rename to dimos/memory/test_buffer.py index f851a6fcee..33235890e1 100644 --- a/dimos/memory2/test_buffer.py +++ b/dimos/memory/test_buffer.py @@ -21,7 +21,7 @@ import pytest -from dimos.memory2.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded +from dimos.memory.buffer import 
Bounded, ClosedError, DropNew, KeepLast, Unbounded class TestBackpressureBuffers: diff --git a/dimos/memory2/test_e2e_import.py b/dimos/memory/test_e2e_import.py similarity index 94% rename from dimos/memory2/test_e2e_import.py rename to dimos/memory/test_e2e_import.py index b134c306d0..1f8f863d21 100644 --- a/dimos/memory2/test_e2e_import.py +++ b/dimos/memory/test_e2e_import.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""E2E test: import legacy pickle replays into memory2 SqliteStore.""" +"""E2E test: import legacy pickle replays into memory SqliteStore.""" from __future__ import annotations @@ -21,7 +21,7 @@ import pytest -from dimos.memory2.impl.sqlite import SqliteStore +from dimos.memory.impl.sqlite import SqliteStore from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data_dir @@ -30,7 +30,7 @@ if TYPE_CHECKING: from collections.abc import Generator - from dimos.memory2.impl.sqlite import SqliteSession + from dimos.memory.impl.sqlite import SqliteSession DB_PATH = get_data_dir("go2_bigoffice_v2.db") @@ -88,7 +88,7 @@ def lidar_replay() -> TimedSensorReplay: # type: ignore[type-arg] @pytest.mark.tool class TestImportReplay: - """Import legacy pickle replay data into a memory2 SqliteStore.""" + """Import legacy pickle replay data into a memory SqliteStore.""" def test_import_video( self, diff --git a/dimos/memory2/impl/__init__.py b/dimos/memory/test_e2e_processing.py similarity index 95% rename from dimos/memory2/impl/__init__.py rename to dimos/memory/test_e2e_processing.py index 1ed1bd093e..81eba5c2a8 100644 --- a/dimos/memory2/impl/__init__.py +++ b/dimos/memory/test_e2e_processing.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python3 + + # Copyright 2026 Dimensional Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/dimos/memory2/test_e2e_query.py b/dimos/memory/test_e2e_query.py similarity index 97% rename from dimos/memory2/test_e2e_query.py rename to dimos/memory/test_e2e_query.py index ac26e865ff..6c9faed17b 100644 --- a/dimos/memory2/test_e2e_query.py +++ b/dimos/memory/test_e2e_query.py @@ -23,7 +23,7 @@ import pytest -from dimos.memory2.impl.sqlite import SqliteStore +from dimos.memory.impl.sqlite import SqliteStore from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data @@ -31,7 +31,7 @@ if TYPE_CHECKING: from collections.abc import Generator - from dimos.memory2.impl.sqlite import SqliteSession + from dimos.memory.impl.sqlite import SqliteSession @pytest.fixture(scope="module") @@ -97,7 +97,7 @@ def test_order_by_desc(self, session: SqliteSession) -> None: def test_lazy_data_loads_correctly(self, session: SqliteSession) -> None: """Verify lazy blob loading returns valid Image data.""" - from dimos.memory2.type import _Unloaded + from dimos.memory.type import _Unloaded video = session.stream("color_image", Image) obs = next(iter(video.limit(1))) diff --git a/dimos/memory2/test_embedding.py b/dimos/memory/test_embedding.py similarity index 96% rename from dimos/memory2/test_embedding.py rename to dimos/memory/test_embedding.py index f1d22addf2..b05ae619a0 100644 --- a/dimos/memory2/test_embedding.py +++ b/dimos/memory/test_embedding.py @@ -19,8 +19,8 @@ import numpy as np import pytest -from dimos.memory2.impl.memory import MemoryStore -from dimos.memory2.type import EmbeddedObservation, Observation +from dimos.memory.impl.memory import MemoryStore +from dimos.memory.type import EmbeddedObservation, Observation from dimos.models.embedding.base import Embedding # ── Helpers ─────────────────────────────────────────────────────── @@ -287,7 +287,7 @@ def embed_text(self, *texts): class 
TestEmbedTransformers: def test_embed_images_produces_embedded_observations(self) -> None: - from dimos.memory2.embed import EmbedImages + from dimos.memory.embed import EmbedImages model = _MockEmbeddingModel() store = MemoryStore() @@ -304,7 +304,7 @@ def test_embed_images_produces_embedded_observations(self) -> None: assert obs.embedding.to_numpy().shape == (8,) def test_embed_text_produces_embedded_observations(self) -> None: - from dimos.memory2.embed import EmbedText + from dimos.memory.embed import EmbedText model = _MockEmbeddingModel() store = MemoryStore() @@ -320,7 +320,7 @@ def test_embed_text_produces_embedded_observations(self) -> None: assert isinstance(obs.embedding, Embedding) def test_embed_preserves_data(self) -> None: - from dimos.memory2.embed import EmbedText + from dimos.memory.embed import EmbedText model = _MockEmbeddingModel() store = MemoryStore() @@ -332,7 +332,7 @@ def test_embed_preserves_data(self) -> None: assert result.data == "hello" def test_embed_then_search(self) -> None: - from dimos.memory2.embed import EmbedText + from dimos.memory.embed import EmbedText model = _MockEmbeddingModel() store = MemoryStore() @@ -351,7 +351,7 @@ def test_embed_then_search(self) -> None: assert results[0].similarity > 0.99 def test_embed_batching(self) -> None: - from dimos.memory2.embed import EmbedText + from dimos.memory.embed import EmbedText call_sizes: list[int] = [] @@ -379,7 +379,7 @@ class TestPluggableVectorStore: """Verify that injecting a VectorStore via session config actually delegates search.""" def test_append_stores_in_vector_store(self) -> None: - from dimos.memory2.vectorstore import MemoryVectorStore + from dimos.memory.vectorstore import MemoryVectorStore vs = MemoryVectorStore() store = MemoryStore() @@ -391,7 +391,7 @@ def test_append_stores_in_vector_store(self) -> None: assert len(vs._vectors["vecs"]) == 2 def test_append_without_embedding_skips_vector_store(self) -> None: - from dimos.memory2.vectorstore import 
MemoryVectorStore + from dimos.memory.vectorstore import MemoryVectorStore vs = MemoryVectorStore() store = MemoryStore() @@ -402,7 +402,7 @@ def test_append_without_embedding_skips_vector_store(self) -> None: assert "plain" not in vs._vectors def test_search_uses_vector_store(self) -> None: - from dimos.memory2.vectorstore import MemoryVectorStore + from dimos.memory.vectorstore import MemoryVectorStore vs = MemoryVectorStore() store = MemoryStore() @@ -420,7 +420,7 @@ def test_search_uses_vector_store(self) -> None: assert results[0].similarity > 0.99 def test_search_with_filters_via_vector_store(self) -> None: - from dimos.memory2.vectorstore import MemoryVectorStore + from dimos.memory.vectorstore import MemoryVectorStore vs = MemoryVectorStore() store = MemoryStore() @@ -435,7 +435,7 @@ def test_search_with_filters_via_vector_store(self) -> None: assert results[0].data == "late" def test_per_stream_vector_store_override(self) -> None: - from dimos.memory2.vectorstore import MemoryVectorStore + from dimos.memory.vectorstore import MemoryVectorStore vs_default = MemoryVectorStore() vs_override = MemoryVectorStore() diff --git a/dimos/memory2/test_impl.py b/dimos/memory/test_impl.py similarity index 95% rename from dimos/memory2/test_impl.py rename to dimos/memory/test_impl.py index 6663cf5a04..990818a8d6 100644 --- a/dimos/memory2/test_impl.py +++ b/dimos/memory/test_impl.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator - from dimos.memory2.store import Session + from dimos.memory.store import Session # ── Case definition ──────────────────────────────────────────────── @@ -45,7 +45,7 @@ class Case: @contextmanager def memory_session() -> Generator[Session, None, None]: - from dimos.memory2.impl.memory import MemoryStore + from dimos.memory.impl.memory import MemoryStore store = MemoryStore() with store.session() as session: @@ -56,7 +56,7 @@ def memory_session() -> Generator[Session, None, None]: def sqlite_session() -> 
Generator[Session, None, None]: import tempfile - from dimos.memory2.impl.sqlite import SqliteStore + from dimos.memory.impl.sqlite import SqliteStore with tempfile.NamedTemporaryFile(suffix=".db") as f: store = SqliteStore(path=f.name) @@ -182,7 +182,7 @@ def test_filter_tags(self, case: Case) -> None: s.append("b", tags={"kind": "error"}) s.append("c", tags={"kind": "info"}) - results = s.filter_tags(kind="info").fetch() + results = s.tags(kind="info").fetch() assert [o.data for o in results] == ["a", "c"] def test_limit_and_offset(self, case: Case) -> None: @@ -224,7 +224,7 @@ def test_same_stream_on_repeated_calls(self, case: Case) -> None: def test_append_with_embedding(self, case: Case) -> None: import numpy as np - from dimos.memory2.type import EmbeddedObservation + from dimos.memory.type import EmbeddedObservation from dimos.models.embedding.base import Embedding with case.session_factory() as session: @@ -275,8 +275,8 @@ def test_sqlite_lazy_by_default(self) -> None: """Default sqlite iteration uses lazy loaders — data is _UNLOADED until accessed.""" import tempfile - from dimos.memory2.impl.sqlite import SqliteStore - from dimos.memory2.type import _Unloaded + from dimos.memory.impl.sqlite import SqliteStore + from dimos.memory.type import _Unloaded with tempfile.NamedTemporaryFile(suffix=".db") as f: store = SqliteStore(path=f.name) @@ -299,8 +299,8 @@ def test_sqlite_eager_loads_inline(self) -> None: """With eager_blobs=True, data is loaded via JOIN — no lazy loader.""" import tempfile - from dimos.memory2.impl.sqlite import SqliteStore - from dimos.memory2.type import _Unloaded + from dimos.memory.impl.sqlite import SqliteStore + from dimos.memory.type import _Unloaded with tempfile.NamedTemporaryFile(suffix=".db") as f: store = SqliteStore(path=f.name) @@ -319,7 +319,7 @@ def test_sqlite_lazy_and_eager_same_values(self) -> None: """Both paths must return identical data.""" import tempfile - from dimos.memory2.impl.sqlite import SqliteStore + from 
dimos.memory.impl.sqlite import SqliteStore with tempfile.NamedTemporaryFile(suffix=".db") as f: store = SqliteStore(path=f.name) @@ -341,9 +341,9 @@ def test_sqlite_lazy_and_eager_same_values(self) -> None: def test_memory_lazy_with_blobstore(self) -> None: """MemoryStore with a BlobStore uses lazy loaders.""" - from dimos.memory2.blobstore.file import FileBlobStore - from dimos.memory2.impl.memory import MemoryStore - from dimos.memory2.type import _Unloaded + from dimos.memory.blobstore.file import FileBlobStore + from dimos.memory.impl.memory import MemoryStore + from dimos.memory.type import _Unloaded store = MemoryStore() import tempfile @@ -435,7 +435,7 @@ class SpyCase: @contextmanager def memory_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None]: - from dimos.memory2.impl.memory import MemoryStore + from dimos.memory.impl.memory import MemoryStore blob_spy = SpyBlobStore() vec_spy = SpyVectorStore() @@ -448,7 +448,7 @@ def memory_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStor def sqlite_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None]: import tempfile - from dimos.memory2.impl.sqlite import SqliteStore + from dimos.memory.impl.sqlite import SqliteStore blob_spy = SpyBlobStore() vec_spy = SpyVectorStore() diff --git a/dimos/memory/test_memory.py b/dimos/memory/test_memory.py deleted file mode 100644 index cfcb5230ce..0000000000 --- a/dimos/memory/test_memory.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from collections.abc import Generator - -import pytest - -from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import ( - CaptionTransformer, - QualityWindowTransformer, - TextEmbeddingTransformer, -) -from dimos.models.embedding.clip import CLIPModel -from dimos.models.vl.florence import CaptionDetail, Florence2Model -from dimos.msgs.sensor_msgs.Image import Image -from dimos.utils.data import get_data - - -@pytest.fixture(scope="module") -def store() -> Generator[SqliteStore, None, None]: - with SqliteStore(get_data("go2_bigoffice.db")) as store: - yield store - - -@pytest.fixture(scope="module") -def session(store: SqliteStore) -> Generator[SqliteSession, None, None]: - with store.session() as session: - yield session - - -@pytest.fixture(scope="module") -def image_stream(session): - return session.stream("color_image", Image) - - -@pytest.fixture(scope="module") -def clip() -> CLIPModel: - model = CLIPModel() - model.start() - return model - - -def test_make_caption(session, clip): - print("") - - florence = Florence2Model(detail=CaptionDetail.NORMAL) - florence.start() - - caption_embeddings = ( - session.streams.sharp_images.transform( - QualityWindowTransformer(lambda img: img.sharpness, window=3.0), - ) - .transform(CaptionTransformer(florence)) - .transform(TextEmbeddingTransformer(clip)) - ) - - florence.stop() - - print(caption_embeddings) - print(caption_embeddings.fetch().summary()) diff --git a/dimos/memory/test_projection.py b/dimos/memory/test_projection.py deleted file mode 100644 index 88989a376d..0000000000 --- a/dimos/memory/test_projection.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections.abc import Generator - -import pytest - -from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import ( - CaptionTransformer, - DetectionTransformer, - EmbeddingTransformer, - QualityWindowTransformer, - TextEmbeddingTransformer, -) -from dimos.models.embedding.base import Embedding -from dimos.models.embedding.clip import CLIPModel -from dimos.models.vl.florence import CaptionDetail, Florence2Model -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.utils.data import get_data - - -@pytest.fixture(scope="module") -def store() -> Generator[SqliteStore, None, None]: - with SqliteStore(get_data("go2_bigoffice.db")) as store: - yield store - - -@pytest.fixture(scope="module") -def session(store: SqliteStore) -> Generator[SqliteSession, None, None]: - with store.session() as session: - yield session - - -@pytest.fixture(scope="module") -def image_stream(session): - return session.stream("color_image", Image) - - -@pytest.fixture(scope="module") -def lidar_stream(session): - return session.stream("lidar", PointCloud2) - - -@pytest.fixture(scope="module") -def clip() -> CLIPModel: - model = CLIPModel() - model.start() - return model - - -def test_list_streams(session): - print("") - for stream in session.list_streams(): - print(stream.summary()) - - -@pytest.mark.tool -def test_make_embedding(session, lidar_stream, image_stream, clip): - embeddings = ( - image_stream.transform( - QualityWindowTransformer(lambda img: img.sharpness, 
window=1.0), - live=False, - backfill_only=True, - ) - .store("sharp_images", Image) - .transform(EmbeddingTransformer(clip), live=False, backfill_only=True) - .store("clip_embeddings", Embedding) - ) - print(embeddings) - print(f"Stored {embeddings.count()} embeddings") - - -@pytest.mark.tool -def test_make_caption(session, clip): - print("") - - session.streams.captions.delete() - session.streams.super_sharp_images.delete() - session.streams.caption_embeddings.delete() - - florence = Florence2Model(detail=CaptionDetail.NORMAL) - florence.start() - - super_sharp_images = session.streams.sharp_images.transform( - QualityWindowTransformer(lambda img: img.sharpness, window=3.0), - backfill_only=True, - ).store("super_sharp_images", Image) - - print(super_sharp_images.summary()) - - captions = super_sharp_images.transform(CaptionTransformer(florence), backfill_only=True).store( - "captions", str - ) - - print(captions.summary()) - - florence.stop() - - caption_embeddings = captions.transform( - TextEmbeddingTransformer(clip), backfill_only=True - ).store("caption_embeddings", Embedding) - - print(caption_embeddings.summary()) - print(f"Stored {caption_embeddings.count()} caption embeddings") - - -@pytest.mark.tool -def test_query_embeddings(session, clip): - print("\n") - - embeddings = session.streams.clip_embeddings.search_embedding("supermarket", k=5, model=clip) - - # we are precomputing and throwing away this stream - captions = ( - session.streams.sharp_images.near(embeddings) # spatially near the embedding matches - .limit(5) - .transform(CaptionTransformer(florence)) - # adding live=True here makes it run the caption transformer live on each new matching embedding - ) - - for obs in captions.fetch(): - print(obs.id, obs.data) - - # we can also find all images ever captured spatially near these embeddings (600+ frames) - images = session.streams.color_image.near(embeddings).fetch() - - print(images) - - # we can also find all sharp images near these embeddings, 
then transform to detect bottles - # sharp images can be loaded from db or computed on demand, here we load from db - bottles = session.streams.sharp_images.near(embeddings, radius=1.0).transform( - DetectionTransformer(moondream, query="bottle") - ) - - # if we want to save this we'd do - # bottles.save("bottle_detections", Detection2D) - - print(bottles) - - for bottle in bottles.fetch(): - print(bottle.data) - - moondream.stop() - - -def test_count_comparison(session, clip): - """Compare fetch-then-transform vs transform-then-fetch counts.""" - print("\n") - embeddings = session.streams.clip_embeddings.search_embedding("supermarket", k=5, model=clip) - - # Count from near() directly - near_stream = session.streams.color_image.near(embeddings, radius=1.0) - fetched = near_stream.fetch() - print(f"near().fetch() count: {len(fetched)}") - - # Approach 1: fetch first, then transform with identity lambda - result1 = fetched.transform(lambda x: x).fetch() - print(f"fetch().transform(id).fetch() count: {len(result1)}") - - # Approach 2: transform on lazy stream, then fetch - near_stream2 = session.streams.color_image.near(embeddings, radius=1.0) - result2 = near_stream2.transform(lambda x: x).fetch() - print(f"near().transform(id).fetch() count: {len(result2)}") - - assert len(fetched) == len(result1), ( - f"fetch-then-transform mismatch: {len(fetched)} vs {len(result1)}" - ) - assert len(fetched) == len(result2), ( - f"transform-then-fetch mismatch: {len(fetched)} vs {len(result2)}" - ) - - -@pytest.mark.tool -def test_print_captions(session, clip): - for caption in session.streams.captions: - print(caption.data) - - -def test_search_embeddings(session, clip): - print("") - embedding_stream = session.embedding_stream("clip_embeddings", embedding_model=clip) - - search = embedding_stream.search_embedding("supermarket", k=5) - print(search) - - project = search.project_to(session.streams.color_image) - print(project) - - results = project.fetch() - print(results) diff 
--git a/dimos/memory2/test_save.py b/dimos/memory/test_save.py similarity index 95% rename from dimos/memory2/test_save.py rename to dimos/memory/test_save.py index 74c1be89f0..ba672f76fd 100644 --- a/dimos/memory2/test_save.py +++ b/dimos/memory/test_save.py @@ -18,11 +18,11 @@ import pytest -from dimos.memory2.backend import Backend, LiveChannel -from dimos.memory2.impl.memory import ListBackend -from dimos.memory2.stream import Stream -from dimos.memory2.transform import FnTransformer -from dimos.memory2.type import Observation +from dimos.memory.backend import Backend, LiveChannel +from dimos.memory.impl.memory import ListBackend +from dimos.memory.stream import Stream +from dimos.memory.transform import FnTransformer +from dimos.memory.type import Observation # ── Helpers ────────────────────────────────────────────────────────── diff --git a/dimos/memory2/test_stream.py b/dimos/memory/test_stream.py similarity index 98% rename from dimos/memory2/test_stream.py rename to dimos/memory/test_stream.py index 46eef32e4f..1fa4bdbbb2 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory/test_stream.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""memory2 stream tests — serves as living documentation of the lazy stream API. +"""memory stream tests — serves as living documentation of the lazy stream API. Each test demonstrates a specific capability with clear setup, action, and assertion. 
""" @@ -25,13 +25,13 @@ import pytest -from dimos.memory2.buffer import KeepLast, Unbounded -from dimos.memory2.impl.memory import MemoryStore -from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer -from dimos.memory2.type import Observation +from dimos.memory.buffer import KeepLast, Unbounded +from dimos.memory.impl.memory import MemoryStore +from dimos.memory.transform import FnTransformer, QualityWindow, Transformer +from dimos.memory.type import Observation if TYPE_CHECKING: - from dimos.memory2.stream import Stream + from dimos.memory.stream import Stream # ── Helpers ────────────────────────────────────────────────────────── diff --git a/dimos/memory/test_stream_repr.py b/dimos/memory/test_stream_repr.py deleted file mode 100644 index 1fa7239578..0000000000 --- a/dimos/memory/test_stream_repr.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for Stream repr/str and Filter.__str__.""" - -from __future__ import annotations - -import pytest - -from dimos.memory.impl.sqlite import SqliteStore -from dimos.memory.stream import Stream -from dimos.memory.transformer import PerItemTransformer -from dimos.memory.type import ( - AfterFilter, - AtFilter, - BeforeFilter, - EmbeddingSearchFilter, - LineageFilter, - NearFilter, - StreamQuery, - TagsFilter, - TextSearchFilter, - TimeRangeFilter, -) - -# ── Filter __str__ ──────────────────────────────────────────────────── - - -class TestFilterStr: - def test_after(self) -> None: - assert str(AfterFilter(3.0)) == "after(t=3.0)" - - def test_before(self) -> None: - assert str(BeforeFilter(10.5)) == "before(t=10.5)" - - def test_time_range(self) -> None: - assert str(TimeRangeFilter(3.0, 10.0)) == "time_range(3.0, 10.0)" - - def test_at(self) -> None: - assert str(AtFilter(5.0, 1.0)) == "at(t=5.0, tol=1.0)" - - def test_near(self) -> None: - assert str(NearFilter(pose=None, radius=5.0)) == "near(radius=5.0)" - - def test_tags_single(self) -> None: - assert str(TagsFilter((("cam", "front"),))) == "tags(cam='front')" - - def test_tags_multiple(self) -> None: - f = TagsFilter((("cam", "front"), ("quality", 1))) - assert str(f) == "tags(cam='front', quality=1)" - - def test_embedding_search(self) -> None: - assert str(EmbeddingSearchFilter([0.1, 0.2], k=5)) == "search_embedding(k=5)" - - def test_text_search(self) -> None: - assert str(TextSearchFilter("error", k=None)) == "text('error')" - - def test_lineage(self) -> None: - f = LineageFilter("embeddings", StreamQuery(), hops=("filtered",)) - assert str(f) == "lineage(embeddings -> filtered)" - - def test_lineage_direct(self) -> None: - f = LineageFilter("embeddings", StreamQuery(), hops=()) - assert str(f) == "lineage(embeddings -> direct)" - - -# ── Stream __str__ ──────────────────────────────────────────────────── - - -@pytest.fixture() -def session(): - store = SqliteStore(":memory:") - store.start() - s = 
store.session() - yield s - s.stop() - store.stop() - - -class TestStreamRepr: - def test_basic_stream(self, session) -> None: - s = session.stream("images", int) - print(s) - assert repr(s) == 'Stream[int]("images")' - - def test_chain(self, session) -> None: - s = session.stream("images", int).after(3.0).filter_tags(cam="front").limit(10) - print(s) - assert repr(s) == "Stream[int](\"images\") | after(t=3.0) | tags(cam='front') | limit(10)" - - def test_order_and_offset(self, session) -> None: - s = session.stream("images", int).order_by("ts", desc=True).offset(5).limit(10) - print(s) - assert repr(s) == 'Stream[int]("images") | order(ts, desc) | limit(10) | offset(5)' - - def test_text_stream(self, session) -> None: - ts = session.text_stream("logs") - print(ts) - assert repr(ts) == 'TextStream[str]("logs")' - - def test_text_search(self, session) -> None: - ts = session.text_stream("logs").search_text("error") - print(ts) - assert repr(ts) == "TextStream[str](\"logs\") | text('error')" - - def test_embedding_stream(self, session) -> None: - es = session.embedding_stream("clip", vec_dimensions=512) - print(es) - assert repr(es) == 'EmbeddingStream[Embedding]("clip")' - - def test_transform_stream(self, session) -> None: - s = session.stream("images", int) - xf = PerItemTransformer(lambda x: x) - t = s.transform(xf, live=True) - print(t) - assert ( - repr(t) == 'TransformStream[?](Stream[int]("images") -> PerItemTransformer, live=True)' - ) - - def test_transform_backfill_only(self, session) -> None: - s = session.stream("images", int) - xf = PerItemTransformer(lambda x: x) - t = s.transform(xf, backfill_only=True) - print(t) - assert ( - repr(t) - == 'TransformStream[?](Stream[int]("images") -> PerItemTransformer, backfill_only=True)' - ) - - def test_unbound_stream(self) -> None: - s = Stream(payload_type=int) - print(s) - assert repr(s) == 'Stream[int]("unbound")' - - def test_no_payload_type(self) -> None: - s = Stream() - print(s) - assert repr(s) == 
'Stream[?]("unbound")' - - def test_materialized_transform(self, session) -> None: - s = session.stream("images", int) - s.append(1, ts=1.0) - xf = PerItemTransformer(lambda x: x * 2) - derived = s.transform(xf).store("doubled", int) - print(derived) - assert repr(derived) == 'Stream[int]("doubled")' - - def test_transform_with_typed_transformer(self, session) -> None: - from unittest.mock import MagicMock - - from dimos.memory.transformer import EmbeddingTransformer - - s = session.stream("images", int) - model = MagicMock() - xf = EmbeddingTransformer(model) - t = s.transform(xf, live=True) - print(t) - assert ( - repr(t) - == 'TransformStream[Embedding](Stream[int]("images") -> EmbeddingTransformer(MagicMock), live=True)' - ) - - def test_embedding_stream_from_source(self, session) -> None: - session.stream("images", int) - es = ( - session.embedding_stream("clip", vec_dimensions=512, parent_table="images") - .after(5.0) - .limit(3) - ) - print(es) - assert repr(es) == 'EmbeddingStream[Embedding]("clip") | after(t=5.0) | limit(3)' - - def test_ivan(self, session) -> None: - from unittest.mock import MagicMock - - from dimos.memory.transformer import EmbeddingTransformer - from dimos.msgs.sensor_msgs.Image import Image - - s = session.stream("images", Image).after(5.0).limit(3) - print("\n") - print(s) - model = MagicMock() - print(s.transform(EmbeddingTransformer(model)).limit(3)) diff --git a/dimos/memory/test_transformer.py b/dimos/memory/test_transformer.py deleted file mode 100644 index bb526a302a..0000000000 --- a/dimos/memory/test_transformer.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for memory transformers.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np -import pytest - -from dimos.memory.impl.sqlite import SqliteSession, SqliteStore -from dimos.memory.transformer import DetectionTransformer, TextEmbeddingTransformer -from dimos.models.embedding.base import Embedding, EmbeddingModel -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D - -if TYPE_CHECKING: - from collections.abc import Iterator - - -class FakeTextEmbedder(EmbeddingModel): - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - raise NotImplementedError - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - results = [] - for text in texts: - h = hash(text) % 1000 / 1000.0 - results.append(Embedding(np.array([h, 1.0 - h, 0.0, 0.0], dtype=np.float32))) - return results if len(results) > 1 else results[0] - - -class SemanticFakeEmbedder(EmbeddingModel): - """Embeds 'kitchen' texts to one region, everything else to another.""" - - device = "cpu" - - def embed(self, *imgs: Image) -> Embedding | list[Embedding]: # type: ignore[override] - raise NotImplementedError - - def embed_text(self, *texts: str) -> Embedding | list[Embedding]: - results = [] - for text in texts: - if "kitchen" in text.lower(): - results.append(Embedding(np.array([1.0, 0.0, 0.0], dtype=np.float32))) - else: - results.append(Embedding(np.array([0.0, 1.0, 0.0], dtype=np.float32))) - return results if 
len(results) > 1 else results[0] - - -@pytest.fixture -def session(tmp_path: object) -> Iterator[SqliteSession]: - from pathlib import Path - - assert isinstance(tmp_path, Path) - store = SqliteStore(str(tmp_path / "test.db")) - sess = store.session() - yield sess - sess.stop() - store.stop() - - -class TestTextEmbeddingTransformer: - """Test text -> embedding -> semantic search pipeline.""" - - def test_text_to_embedding_backfill(self, session: SqliteSession) -> None: - """Backfill: store text, transform to embeddings, search by text.""" - logs = session.stream("te_logs", str) - logs.append("Robot navigated to kitchen", ts=1.0) - logs.append("Battery low warning", ts=2.0) - logs.append("Robot navigated to bedroom", ts=3.0) - - emb_stream = logs.transform(TextEmbeddingTransformer(FakeTextEmbedder())).store( - "te_log_embeddings" - ) - - assert emb_stream.count() == 3 - - results = emb_stream.search_embedding("Robot navigated to kitchen", k=1).fetch() - assert len(results) == 1 - assert isinstance(results[0].data, Embedding) - - # project_to to get source text - projected = ( - emb_stream.search_embedding("Robot navigated to kitchen", k=1).project_to(logs).fetch() - ) - assert len(projected) == 1 - assert isinstance(projected[0].data, str) - - def test_text_embedding_live(self, session: SqliteSession) -> None: - """Live mode: new text is embedded automatically.""" - logs = session.stream("te_live_logs", str) - emb_stream = logs.transform(TextEmbeddingTransformer(FakeTextEmbedder()), live=True).store( - "te_live_embs" - ) - - assert emb_stream.count() == 0 # no backfill - - logs.append("New log entry", ts=1.0) - assert emb_stream.count() == 1 - - logs.append("Another log entry", ts=2.0) - assert emb_stream.count() == 2 - - def test_text_embedding_search_and_project(self, session: SqliteSession) -> None: - """search_embedding + project_to retrieves source text.""" - logs = session.stream("te_proj_logs", str) - logs.append("Robot entered kitchen", ts=1.0) - 
logs.append("Battery warning", ts=2.0) - logs.append("Cleaning kitchen floor", ts=3.0) - - emb_stream = logs.transform(TextEmbeddingTransformer(SemanticFakeEmbedder())).store( - "te_proj_embs" - ) - - results = emb_stream.search_embedding("kitchen", k=2).project_to(logs).fetch() - assert len(results) == 2 - assert all("kitchen" in r.data.lower() for r in results) - - -def _make_image(ts: float) -> Image: - return Image(data=np.zeros((64, 64, 3), dtype=np.uint8), ts=ts) - - -class FakeVlModel: - """Minimal VlModel stub for detection tests.""" - - def __init__( - self, - detections_per_image: int = 2, - *, - raise_on_call: bool = False, - ) -> None: - self.detections_per_image = detections_per_image - self.raise_on_call = raise_on_call - - def query_detections(self, image: Image, query: str, **kwargs: object) -> ImageDetections2D: - if self.raise_on_call: - raise RuntimeError("model error") - dets = [ - Detection2DBBox( - bbox=(10.0 * i, 10.0 * i, 20.0 * i + 20, 20.0 * i + 20), - track_id=i, - class_id=-1, - confidence=0.9, - name=query, - ts=image.ts, - image=image, - ) - for i in range(self.detections_per_image) - ] - return ImageDetections2D(image=image, detections=dets) - - -class TestDetectionTransformer: - """Test VLM detection transformer.""" - - def test_detection_backfill(self, session: SqliteSession) -> None: - """Backfill: 3 images → transform → 3 detection observations.""" - imgs = session.stream("det_imgs", Image) - for i in range(3): - imgs.append(_make_image(float(i + 1)), ts=float(i + 1)) - - det_stream = imgs.transform(DetectionTransformer(FakeVlModel(2), "cup")).store("det_cups") - - assert det_stream.count() == 3 - results = det_stream.fetch() - for obs in results: - assert obs.data.image is None, "image should be stripped" - for det in obs.data.detections: - assert det.image is None, "detection image should be stripped" - assert obs.tags["query"] == "cup" - assert obs.tags["count"] == 2 - - def test_detection_skip_empty(self, session: 
SqliteSession) -> None: - """skip_empty=True (default): 0 detections → observation skipped.""" - imgs = session.stream("det_skip_imgs", Image) - imgs.append(_make_image(1.0), ts=1.0) - - det_stream = imgs.transform(DetectionTransformer(FakeVlModel(0), "nothing")).store( - "det_skip" - ) - - assert det_stream.count() == 0 - - def test_detection_keep_empty(self, session: SqliteSession) -> None: - """skip_empty=False: 0 detections → observation stored with count=0.""" - imgs = session.stream("det_keep_imgs", Image) - imgs.append(_make_image(1.0), ts=1.0) - - det_stream = imgs.transform( - DetectionTransformer(FakeVlModel(0), "nothing", skip_empty=False) - ).store("det_keep") - - assert det_stream.count() == 1 - obs = det_stream.fetch()[0] - assert obs.tags["count"] == 0 - assert len(obs.data.detections) == 0 - - def test_detection_model_error(self, session: SqliteSession) -> None: - """Model raises → observation skipped, no crash.""" - imgs = session.stream("det_err_imgs", Image) - imgs.append(_make_image(1.0), ts=1.0) - - det_stream = imgs.transform( - DetectionTransformer(FakeVlModel(raise_on_call=True), "cup") - ).store("det_err") - - assert det_stream.count() == 0 - - def test_detection_lineage(self, session: SqliteSession) -> None: - """project_to(image_stream) recovers source images.""" - imgs = session.stream("det_lin_imgs", Image) - imgs.append(_make_image(1.0), ts=1.0) - imgs.append(_make_image(2.0), ts=2.0) - - det_stream = imgs.transform(DetectionTransformer(FakeVlModel(1), "obj")).store("det_lin") - - projected = det_stream.project_to(imgs).fetch() - assert len(projected) == 2 - for obs in projected: - assert isinstance(obs.data, Image) - - def test_detection_live(self, session: SqliteSession) -> None: - """Live mode: append images after transform, verify reactive detection.""" - imgs = session.stream("det_live_imgs", Image) - det_stream = imgs.transform(DetectionTransformer(FakeVlModel(1), "cup"), live=True).store( - "det_live" - ) - - assert 
det_stream.count() == 0 - - imgs.append(_make_image(1.0), ts=1.0) - assert det_stream.count() == 1 - - imgs.append(_make_image(2.0), ts=2.0) - assert det_stream.count() == 2 diff --git a/dimos/memory/tests/__init__.py b/dimos/memory/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/memory2/transform.py b/dimos/memory/transform.py similarity index 97% rename from dimos/memory2/transform.py rename to dimos/memory/transform.py index d68e25344a..ebdb6416cf 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory/transform.py @@ -18,12 +18,12 @@ import inspect from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.formatting import FilterRepr +from dimos.memory.formatting import FilterRepr if TYPE_CHECKING: from collections.abc import Callable, Iterator - from dimos.memory2.type import Observation + from dimos.memory.type import Observation T = TypeVar("T") R = TypeVar("R") diff --git a/dimos/memory/transformer.py b/dimos/memory/transformer.py deleted file mode 100644 index 3d0e73f55f..0000000000 --- a/dimos/memory/transformer.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -from abc import ABC, abstractmethod -import logging -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -if TYPE_CHECKING: - from collections.abc import Callable - - from dimos.memory.stream import Stream - from dimos.memory.type import Observation - from dimos.models.embedding.base import Embedding, EmbeddingModel - from dimos.models.vl.base import Captioner, VlModel - from dimos.perception.detection.type import ImageDetections2D - -logger = logging.getLogger(__name__) - -T = TypeVar("T") -R = TypeVar("R") - - -class Transformer(ABC, Generic[T, R]): - """Transforms a source stream into results on a target stream.""" - - supports_backfill: bool = True - supports_live: bool = True - output_type: type | None = None - - def __repr__(self) -> str: - return type(self).__name__ - - @abstractmethod - def process(self, source: Stream[T], target: Stream[R]) -> None: - """Batch/historical processing. - - Has full access to the source stream — can query, filter, batch, skip, etc. - """ - - def on_append(self, obs: Observation[Any], target: Stream[R]) -> None: - """Reactive per-item processing. 
Called for each new item.""" - - -class PerItemTransformer(Transformer[T, R]): - """Wraps a simple callable as a per-item Transformer.""" - - def __init__(self, fn: Callable[[T], R | list[R] | None]) -> None: - self._fn = fn - - def process(self, source: Stream[T], target: Stream[R]) -> None: - for page in source.fetch_pages(): - for obs in page: - self._apply(obs, target) - - def on_append(self, obs: Observation[Any], target: Stream[R]) -> None: - self._apply(obs, target) - - def _apply(self, obs: Observation[Any], target: Stream[R]) -> None: - result = self._fn(obs.data) - if result is None: - return - if isinstance(result, list): - for item in result: - target.append(item, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - else: - target.append(result, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - - -class QualityWindowTransformer(Transformer[T, T]): - """Keeps the highest-quality item per time window. - - Like ``sharpness_barrier`` but operates on stored data (no wall-clock dependency). - In live mode, buffers the current window and emits the best item when a new - observation falls outside the window. 
- """ - - supports_backfill: bool = True - supports_live: bool = True - - def __init__(self, quality_fn: Callable[[T], float], window: float = 0.5) -> None: - self._quality_fn = quality_fn - self._window = window - - def __repr__(self) -> str: - fn_name = getattr(self._quality_fn, "__name__", None) or repr(self._quality_fn) - return f"QualityWindowTransformer({fn_name}, window={self._window})" - # Live state - self._window_start: float | None = None - self._best_obs: Observation[T] | None = None - self._best_score: float = -1.0 - - def process(self, source: Stream[T], target: Stream[T]) -> None: - window_start: float | None = None - best_obs: Observation[T] | None = None - best_score: float = -1.0 - - for obs in source: - ts = obs.ts or 0.0 - if window_start is None: - window_start = ts - - if (ts - window_start) >= self._window: - if best_obs is not None: - target.append( - best_obs.data, - ts=best_obs.ts, - pose=best_obs.pose, - tags=best_obs.tags, - parent_id=best_obs.id, - ) - window_start = ts - best_score = -1.0 - best_obs = None - - score = self._quality_fn(obs.data) - if score > best_score: - best_score = score - best_obs = obs - - if best_obs is not None: - target.append( - best_obs.data, - ts=best_obs.ts, - pose=best_obs.pose, - tags=best_obs.tags, - parent_id=best_obs.id, - ) - - def on_append(self, obs: Observation[T], target: Stream[T]) -> None: # type: ignore[override] - ts = obs.ts or 0.0 - - if self._window_start is None: - self._window_start = ts - - if (ts - self._window_start) >= self._window: - if self._best_obs is not None: - target.append( - self._best_obs.data, - ts=self._best_obs.ts, - pose=self._best_obs.pose, - tags=self._best_obs.tags, - parent_id=self._best_obs.id, - ) - self._window_start = ts - self._best_score = -1.0 - self._best_obs = None - - score = self._quality_fn(obs.data) - if score > self._best_score: - self._best_score = score - self._best_obs = obs - - -class CaptionTransformer(Transformer[Any, str]): - """Wraps a Captioner 
(or VlModel) to produce text captions from images. - - When stored, the output stream becomes a TextStream with FTS index. - Uses caption_batch() during backfill for efficiency. - """ - - supports_backfill: bool = True - supports_live: bool = True - - def __init__(self, model: Captioner, *, batch_size: int = 16) -> None: - self.model = model - self.batch_size = batch_size - self.output_type: type | None = str - - def __repr__(self) -> str: - model_name = type(self.model).__name__ - parts = [model_name] - if self.batch_size != 16: - parts.append(f"batch_size={self.batch_size}") - return f"CaptionTransformer({', '.join(parts)})" - - def process(self, source: Stream[Any], target: Stream[str]) -> None: - for page in source.fetch_pages(batch_size=self.batch_size): - images = [obs.data for obs in page] - if not images: - continue - captions = self.model.caption_batch(*images) - for obs, cap in zip(page, captions, strict=True): - target.append(cap, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - - def on_append(self, obs: Observation[Any], target: Stream[str]) -> None: - caption = self.model.caption(obs.data) - target.append(caption, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - - -class TextEmbeddingTransformer(Transformer[Any, "Embedding"]): - """Wraps an EmbeddingModel to embed text payloads (strings) into vectors. - - Use this for semantic search over logs, captions, or any text data. - When stored, the output stream becomes an EmbeddingStream with vector index. 
- """ - - supports_backfill: bool = True - supports_live: bool = True - - def __init__(self, model: EmbeddingModel) -> None: - from dimos.models.embedding.base import Embedding as EmbeddingCls - - self.model = model - self.output_type: type | None = EmbeddingCls - - def __repr__(self) -> str: - return f"TextEmbeddingTransformer({type(self.model).__name__})" - - def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: - for page in source.fetch_pages(): - texts = [str(obs.data) for obs in page] - if not texts: - continue - embeddings = self.model.embed_text(*texts) - if not isinstance(embeddings, list): - embeddings = [embeddings] - for obs, emb in zip(page, embeddings, strict=True): - target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - - def on_append(self, obs: Observation[Any], target: Stream[Embedding]) -> None: - emb = self.model.embed_text(str(obs.data)) - if isinstance(emb, list): - emb = emb[0] - target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - - -class EmbeddingTransformer(Transformer[Any, "Embedding"]): - """Wraps an EmbeddingModel as a Transformer that produces Embedding output. - - When stored, the output stream becomes an EmbeddingStream with vector index. 
- """ - - supports_backfill: bool = True - supports_live: bool = True - - def __init__(self, model: EmbeddingModel) -> None: - from dimos.models.embedding.base import Embedding as EmbeddingCls - - self.model = model - self.output_type: type | None = EmbeddingCls - - def __repr__(self) -> str: - return f"EmbeddingTransformer({type(self.model).__name__})" - - def process(self, source: Stream[Any], target: Stream[Embedding]) -> None: - for page in source.fetch_pages(): - images = [obs.data for obs in page] - if not images: - continue - embeddings = self.model.embed(*images) - if not isinstance(embeddings, list): - embeddings = [embeddings] - for obs, emb in zip(page, embeddings, strict=True): - target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - - def on_append(self, obs: Observation[Any], target: Stream[Embedding]) -> None: - emb = self.model.embed(obs.data) - if isinstance(emb, list): - emb = emb[0] - target.append(emb, ts=obs.ts, pose=obs.pose, tags=obs.tags, parent_id=obs.id) - - -class DetectionTransformer(Transformer[Any, "ImageDetections2D"]): - """Runs VLM object detection on images, producing ImageDetections2D. - - Strips image references from detections before storage to avoid - duplicating image data. Use project_to(image_stream) to recover - source images via lineage. 
- """ - - supports_backfill = True - supports_live = True - - def __init__(self, model: VlModel, query: str, *, skip_empty: bool = True) -> None: - from dimos.perception.detection.type import ImageDetections2D as IDet2D - - self.model = model - self.query = query - self.skip_empty = skip_empty - self.output_type: type | None = IDet2D - - def __repr__(self) -> str: - model_name = type(self.model).__name__ - parts = [f"{model_name}, {self.query!r}"] - if not self.skip_empty: - parts.append("skip_empty=False") - return f"DetectionTransformer({', '.join(parts)})" - - def process(self, source: Stream[Any], target: Stream[ImageDetections2D]) -> None: - for page in source.fetch_pages(): - for obs in page: - self._detect_and_append(obs, target) - - def on_append(self, obs: Observation[Any], target: Stream[ImageDetections2D]) -> None: - self._detect_and_append(obs, target) - - def _detect_and_append(self, obs: Observation[Any], target: Stream[ImageDetections2D]) -> None: - try: - detections = self.model.query_detections(obs.data, self.query) - except Exception: - logger.warning("Detection failed for obs %s, skipping", obs.id, exc_info=True) - return - - count = len(detections) - if count == 0 and self.skip_empty: - return - - # Strip image refs to avoid duplicating image data in storage - detections.image = None - for det in detections.detections: - det.image = None - - tags = {**(obs.tags or {}), "query": self.query, "count": count} - target.append(detections, ts=obs.ts, pose=obs.pose, tags=tags, parent_id=obs.id) diff --git a/dimos/memory/type.py b/dimos/memory/type.py index 35982818f6..85cfab9640 100644 --- a/dimos/memory/type.py +++ b/dimos/memory/type.py @@ -14,230 +14,101 @@ from __future__ import annotations -from collections.abc import Callable from dataclasses import dataclass, field -import math -from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar - -from dimos.models.embedding.base import Embedding +from typing import TYPE_CHECKING, Any, Generic, 
TypeVar if TYPE_CHECKING: - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from collections.abc import Callable -PoseProvider: TypeAlias = Callable[[float], Any] # (ts) -> PoseLike | None + from dimos.models.embedding.base import Embedding T = TypeVar("T") -class _Unset: - """Sentinel indicating no data has been loaded yet.""" +# ── Lazy data sentinel ────────────────────────────────────────────── + + +class _Unloaded: + """Sentinel indicating data has not been loaded yet.""" __slots__ = () + def __repr__(self) -> str: + return "" + -_UNSET = _Unset() +_UNLOADED = _Unloaded() + + +# ── Observation ───────────────────────────────────────────────────── @dataclass class Observation(Generic[T]): + """A single timestamped observation with optional spatial pose and metadata.""" + id: int ts: float - pose: PoseStamped | None = None + pose: Any | None = None tags: dict[str, Any] = field(default_factory=dict) - parent_id: int | None = field(default=None, repr=False) - _data: T | _Unset = field(default_factory=lambda: _UNSET, repr=False) - _data_loader: Callable[[], T] | None = field(default=None, repr=False, compare=False) + _data: T | _Unloaded = field(default=_UNLOADED, repr=False) + _loader: Callable[[], T] | None = field(default=None, repr=False) @property def data(self) -> T: - if not isinstance(self._data, _Unset): - return self._data - if self._data_loader is not None: - loaded = self._data_loader() + val = self._data + if isinstance(val, _Unloaded): + if self._loader is None: + raise LookupError("No data and no loader set on this observation") + loaded = self._loader() self._data = loaded + self._loader = None # release closure return loaded - raise LookupError("No data available; observation was not fetched with payload") - - def load(self) -> Observation[T]: - """Force-load .data and return self. 
Safe to pass across threads after this.""" - self.data # noqa: B018 - return self + return val + + def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: + """Create a new observation preserving ts/pose/tags, replacing data. + + If ``embedding`` is passed, promotes the result to + :class:`EmbeddedObservation`. + """ + if "embedding" in overrides: + return EmbeddedObservation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + embedding=overrides["embedding"], + similarity=overrides.get("similarity"), + ) + return Observation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + ) + + +# ── EmbeddedObservation ────────────────────────────────────────── @dataclass -class EmbeddingObservation(Observation[Embedding]): - """Returned by EmbeddingStream terminals. - - .data returns the Embedding stored in this stream. - .embedding is a convenience alias for .data (typed as Embedding). - .similarity is populated (0..1) when fetched via search_embedding (vec0 cosine). - - To get source data (e.g. the original Image), use .project_to(source_stream). 
- """ - - similarity: float | None = field(default=None, repr=True) - - @property - def embedding(self) -> Embedding: - return self.data - - -# ── Filter types ────────────────────────────────────────────────────── - - -@dataclass(frozen=True) -class AfterFilter: - t: float - - def matches(self, obs: Observation[Any]) -> bool: - return obs.ts is not None and obs.ts > self.t - - def __str__(self) -> str: - return f"after(t={self.t})" - - -@dataclass(frozen=True) -class BeforeFilter: - t: float - - def matches(self, obs: Observation[Any]) -> bool: - return obs.ts is not None and obs.ts < self.t - - def __str__(self) -> str: - return f"before(t={self.t})" - - -@dataclass(frozen=True) -class TimeRangeFilter: - t1: float - t2: float - - def matches(self, obs: Observation[Any]) -> bool: - return obs.ts is not None and self.t1 <= obs.ts <= self.t2 - - def __str__(self) -> str: - return f"time_range({self.t1}, {self.t2})" - - -@dataclass(frozen=True) -class AtFilter: - t: float - tolerance: float - - def matches(self, obs: Observation[Any]) -> bool: - return obs.ts is not None and abs(obs.ts - self.t) <= self.tolerance - - def __str__(self) -> str: - return f"at(t={self.t}, tol={self.tolerance})" - - -@dataclass(frozen=True) -class NearFilter: - pose: Any # PoseLike - radius: float - - def matches(self, obs: Observation[Any]) -> bool: - if obs.pose is None: - return False - p1 = obs.pose.position - p2 = self.pose.position - dist = math.sqrt((p1.x - p2.x) ** 2 + (p1.y - p2.y) ** 2 + (p1.z - p2.z) ** 2) - return dist <= self.radius - - def __str__(self) -> str: - return f"near(radius={self.radius})" - - -@dataclass(frozen=True) -class TagsFilter: - tags: tuple[tuple[str, Any], ...] 
- - def matches(self, obs: Observation[Any]) -> bool: - return all(obs.tags.get(k) == v for k, v in self.tags) - - def __str__(self) -> str: - pairs = ", ".join(f"{k}={v!r}" for k, v in self.tags) - return f"tags({pairs})" - - -@dataclass(frozen=True) -class EmbeddingSearchFilter: - query: list[float] - k: int - label: str | None = None - - def matches(self, obs: Observation[Any]) -> bool: - return True # top-k handled as special pass in ListBackend - - def __str__(self) -> str: - parts = [f"k={self.k}"] - if self.label: - parts.insert(0, repr(self.label)) - return f"search_embedding({', '.join(parts)})" - - -@dataclass(frozen=True) -class TextSearchFilter: - text: str - k: int | None - - def matches(self, obs: Observation[Any]) -> bool: - return self.text.lower() in str(obs.data).lower() - - def __str__(self) -> str: - return f"text({self.text!r})" - - -@dataclass(frozen=True) -class LineageFilter: - """Filter to rows that are ancestors of observations in another stream. - - Used by ``project_to`` — compiles to a nested SQL subquery that walks the - ``parent_id`` chain from *source_table* through *hops* to the target. - """ - - source_table: str - source_query: StreamQuery - hops: tuple[str, ...] # intermediate tables between source and target - - def matches(self, obs: Observation[Any]) -> bool: - raise NotImplementedError("LineageFilter requires a database backend") - - def __str__(self) -> str: - hops = " -> ".join(self.hops) if self.hops else "direct" - return f"lineage({self.source_table} -> {hops})" - - -Filter: TypeAlias = ( - AfterFilter - | BeforeFilter - | TimeRangeFilter - | AtFilter - | NearFilter - | TagsFilter - | EmbeddingSearchFilter - | TextSearchFilter - | LineageFilter -) - - -@dataclass(frozen=True) -class StreamQuery: - """Immutable bundle of query parameters passed to backends.""" - - filters: tuple[Filter, ...] 
= () - order_field: str | None = None - order_desc: bool = False - limit_val: int | None = None - offset_val: int | None = None - - def __str__(self) -> str: - parts: list[str] = [str(f) for f in self.filters] - if self.order_field: - direction = "desc" if self.order_desc else "asc" - parts.append(f"order({self.order_field}, {direction})") - if self.limit_val is not None: - parts.append(f"limit({self.limit_val})") - if self.offset_val is not None: - parts.append(f"offset({self.offset_val})") - return " | ".join(parts) +class EmbeddedObservation(Observation[T]): + """Observation enriched with a vector embedding and optional similarity score.""" + + embedding: Embedding | None = None + similarity: float | None = None + + def derive(self, *, data: Any, **overrides: Any) -> EmbeddedObservation[Any]: + """Preserve embedding unless explicitly replaced.""" + return EmbeddedObservation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + embedding=overrides.get("embedding", self.embedding), + similarity=overrides.get("similarity", self.similarity), + ) diff --git a/dimos/memory2/vectorstore/__init__.py b/dimos/memory/vectorstore/__init__.py similarity index 83% rename from dimos/memory2/vectorstore/__init__.py rename to dimos/memory/vectorstore/__init__.py index d8f3395cb8..fa9ff33c8a 100644 --- a/dimos/memory2/vectorstore/__init__.py +++ b/dimos/memory/vectorstore/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dimos.memory2.vectorstore.memory import MemoryVectorStore -from dimos.memory2.vectorstore.sqlite import SqliteVectorStore +from dimos.memory.vectorstore.memory import MemoryVectorStore +from dimos.memory.vectorstore.sqlite import SqliteVectorStore __all__ = ["MemoryVectorStore", "SqliteVectorStore"] diff --git a/dimos/memory2/vectorstore/memory.py b/dimos/memory/vectorstore/memory.py similarity index 97% rename from dimos/memory2/vectorstore/memory.py rename to dimos/memory/vectorstore/memory.py index 22532c6ad1..3fcad4e02a 100644 --- a/dimos/memory2/vectorstore/memory.py +++ b/dimos/memory/vectorstore/memory.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING -from dimos.memory2.backend import VectorStore +from dimos.memory.backend import VectorStore if TYPE_CHECKING: from dimos.models.embedding.base import Embedding diff --git a/dimos/memory2/vectorstore/sqlite.py b/dimos/memory/vectorstore/sqlite.py similarity index 98% rename from dimos/memory2/vectorstore/sqlite.py rename to dimos/memory/vectorstore/sqlite.py index 736cc16e27..5ff49a2255 100644 --- a/dimos/memory2/vectorstore/sqlite.py +++ b/dimos/memory/vectorstore/sqlite.py @@ -17,7 +17,7 @@ import json from typing import TYPE_CHECKING -from dimos.memory2.backend import VectorStore +from dimos.memory.backend import VectorStore if TYPE_CHECKING: import sqlite3 diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py deleted file mode 100644 index 0b358fe438..0000000000 --- a/dimos/memory2/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -from dimos.memory2.backend import Backend, LiveChannel, VectorStore -from dimos.memory2.buffer import ( - BackpressureBuffer, - Bounded, - ClosedError, - DropNew, - KeepLast, - Unbounded, -) -from dimos.memory2.embed import EmbedImages, EmbedText -from dimos.memory2.filter import ( - AfterFilter, - AtFilter, - BeforeFilter, - Filter, - NearFilter, - PredicateFilter, - StreamQuery, - TagsFilter, - TimeRangeFilter, -) -from dimos.memory2.impl.memory import 
ListBackend, MemorySession, MemoryStore -from dimos.memory2.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore, SqliteStoreConfig -from dimos.memory2.livechannel import SubjectChannel -from dimos.memory2.store import Session, SessionConfig, Store, StoreConfig, StreamNamespace -from dimos.memory2.stream import Stream -from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer -from dimos.memory2.type import EmbeddedObservation, Observation - -__all__ = [ - "AfterFilter", - "AtFilter", - "Backend", - "BackpressureBuffer", - "BeforeFilter", - "Bounded", - "ClosedError", - "DropNew", - "EmbedImages", - "EmbedText", - "EmbeddedObservation", - "Filter", - "FnTransformer", - "KeepLast", - "ListBackend", - "LiveChannel", - "MemorySession", - "MemoryStore", - "NearFilter", - "Observation", - "PredicateFilter", - "QualityWindow", - "Session", - "SessionConfig", - "SqliteBackend", - "SqliteSession", - "SqliteStore", - "SqliteStoreConfig", - "Store", - "StoreConfig", - "Stream", - "StreamNamespace", - "StreamQuery", - "SubjectChannel", - "TagsFilter", - "TimeRangeFilter", - "Transformer", - "Unbounded", - "VectorStore", -] diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py deleted file mode 100644 index b4ad7bd520..0000000000 --- a/dimos/memory2/impl/sqlite.py +++ /dev/null @@ -1,671 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -from dataclasses import dataclass, replace -from itertools import islice -import json -import re -import sqlite3 -import threading -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -from dimos.memory2.backend import BackendConfig -from dimos.memory2.blobstore.sqlite import SqliteBlobStore -from dimos.memory2.codecs.base import Codec, codec_for -from dimos.memory2.filter import ( - AfterFilter, - AtFilter, - BeforeFilter, - NearFilter, - TagsFilter, - TimeRangeFilter, - _xyz, -) -from dimos.memory2.livechannel.subject import SubjectChannel -from dimos.memory2.store import Session, Store, StoreConfig -from dimos.memory2.type import _UNLOADED, Observation -from dimos.protocol.service.spec import Configurable - -if TYPE_CHECKING: - from collections.abc import Iterator - - from reactivex.abc import DisposableBase - - from dimos.memory2.backend import Backend, LiveChannel - from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.filter import Filter, StreamQuery - -T = TypeVar("T") - -_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -# ── Helpers ────────────────────────────────────────────────────── - - -def _validate_identifier(name: str) -> None: - if not _IDENT_RE.match(name): - raise ValueError(f"Invalid stream name: {name!r}") - - -def _decompose_pose(pose: Any) -> tuple[float, ...] 
| None: - if pose is None: - return None - if hasattr(pose, "position"): - pos = pose.position - orient = getattr(pose, "orientation", None) - x, y, z = float(pos.x), float(pos.y), float(getattr(pos, "z", 0.0)) - if orient is not None: - return (x, y, z, float(orient.x), float(orient.y), float(orient.z), float(orient.w)) - return (x, y, z, 0.0, 0.0, 0.0, 1.0) - if isinstance(pose, (list, tuple)): - vals = [float(v) for v in pose] - while len(vals) < 7: - vals.append(0.0 if len(vals) < 6 else 1.0) - return tuple(vals[:7]) - return None - - -def _reconstruct_pose( - x: float | None, - y: float | None, - z: float | None, - qx: float | None, - qy: float | None, - qz: float | None, - qw: float | None, -) -> tuple[float, ...] | None: - if x is None: - return None - return (x, y or 0.0, z or 0.0, qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0) - - -def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list[Any]] | None: - """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters. - - ``stream`` is the raw stream name (for R*Tree table references). - ``prefix`` is a column qualifier (e.g. ``"meta."`` for JOIN queries). - """ - if isinstance(f, AfterFilter): - return (f"{prefix}ts > ?", [f.t]) - if isinstance(f, BeforeFilter): - return (f"{prefix}ts < ?", [f.t]) - if isinstance(f, TimeRangeFilter): - return (f"{prefix}ts >= ? AND {prefix}ts <= ?", [f.t1, f.t2]) - if isinstance(f, AtFilter): - return (f"ABS({prefix}ts - ?) 
<= ?", [f.t, f.tolerance]) - if isinstance(f, TagsFilter): - clauses = [] - params: list[Any] = [] - for k, v in f.tags.items(): - clauses.append(f"json_extract({prefix}tags, '$.{k}') = ?") - params.append(v) - return (" AND ".join(clauses), params) - if isinstance(f, NearFilter): - pose = f.pose - if pose is None: - return None - if hasattr(pose, "position"): - pose = pose.position - cx, cy, cz = _xyz(pose) - r = f.radius - # R*Tree bounding-box pre-filter + exact squared-distance check - rtree_sql = ( - f'{prefix}id IN (SELECT id FROM "{stream}_rtree" ' - f"WHERE x_min >= ? AND x_max <= ? " - f"AND y_min >= ? AND y_max <= ? " - f"AND z_min >= ? AND z_max <= ?)" - ) - dist_sql = ( - f"(({prefix}pose_x - ?) * ({prefix}pose_x - ?) + " - f"({prefix}pose_y - ?) * ({prefix}pose_y - ?) + " - f"({prefix}pose_z - ?) * ({prefix}pose_z - ?) <= ?)" - ) - return ( - f"{rtree_sql} AND {dist_sql}", - [ - cx - r, - cx + r, - cy - r, - cy + r, - cz - r, - cz + r, # R*Tree bbox - cx, - cx, - cy, - cy, - cz, - cz, - r * r, # squared distance - ], - ) - # PredicateFilter — not pushable - return None - - -def _compile_query( - query: StreamQuery, - table: str, - *, - join_blob: bool = False, -) -> tuple[str, list[Any], list[Filter]]: - """Compile a StreamQuery to SQL. - - Returns (sql, params, python_filters) where python_filters must be - applied as post-filters in Python. - """ - prefix = "meta." 
if join_blob else "" - if join_blob: - select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' - else: - select = f'SELECT id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"' - - where_parts: list[str] = [] - params: list[Any] = [] - python_filters: list[Filter] = [] - - for f in query.filters: - compiled = _compile_filter(f, table, prefix) - if compiled is not None: - sql_part, sql_params = compiled - where_parts.append(sql_part) - params.extend(sql_params) - else: - python_filters.append(f) - - sql = select - if where_parts: - sql += " WHERE " + " AND ".join(where_parts) - - # ORDER BY - if query.order_field: - direction = "DESC" if query.order_desc else "ASC" - sql += f" ORDER BY {prefix}{query.order_field} {direction}" - else: - sql += f" ORDER BY {prefix}id ASC" - - # Only push LIMIT/OFFSET to SQL when there are no Python post-filters - if not python_filters and not query.search_text: - if query.limit_val is not None: - if query.offset_val: - sql += f" LIMIT {query.limit_val} OFFSET {query.offset_val}" - else: - sql += f" LIMIT {query.limit_val}" - elif query.offset_val: - sql += f" LIMIT -1 OFFSET {query.offset_val}" - - return (sql, params, python_filters) - - -def _compile_count( - query: StreamQuery, - table: str, -) -> tuple[str, list[Any], list[Filter]]: - """Compile a StreamQuery to a COUNT SQL query.""" - where_parts: list[str] = [] - params: list[Any] = [] - python_filters: list[Filter] = [] - - for f in query.filters: - compiled = _compile_filter(f, table) - if compiled is not None: - sql_part, sql_params = compiled - where_parts.append(sql_part) - params.extend(sql_params) - else: - python_filters.append(f) - - sql = f'SELECT COUNT(*) FROM "{table}"' - if where_parts: - sql += " WHERE " + " AND ".join(where_parts) - - return (sql, 
params, python_filters) - - -# ── SqliteBackend ──────────────────────────────────────────────── - - -class SqliteBackend(Configurable[BackendConfig], Generic[T]): - """SQLite-backed observation storage for a single stream (table).""" - - default_config: type[BackendConfig] = BackendConfig - - def __init__(self, conn: sqlite3.Connection, name: str, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._conn = conn - self._name = name - self._codec: Codec[Any] = self.config.codec # type: ignore[assignment] - self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel() - self._lock = threading.Lock() - self._tag_indexes: set[str] = set() - - @property - def name(self) -> str: - return self._name - - @property - def live_channel(self) -> LiveChannel[T]: - return self._channel - - @property - def _join_blobs(self) -> bool: - if not self.config.eager_blobs: - return False - bs = self.config.blob_store - return isinstance(bs, SqliteBlobStore) and bs._conn is self._conn - - def _make_loader(self, row_id: int) -> Any: - bs = self.config.blob_store - assert bs is not None - name, codec = self._name, self._codec - owner_tid = threading.get_ident() - - def loader() -> Any: - assert threading.get_ident() == owner_tid - raw = bs.get(name, row_id) - return codec.decode(raw) - - return loader - - def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observation[T]: - if has_blob: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row - else: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row - blob_data = None - - pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) - tags = json.loads(tags_json) if tags_json else {} - - if has_blob and blob_data is not None: - data = self._codec.decode(blob_data) - return Observation(id=row_id, ts=ts, pose=pose, tags=tags, _data=data) - - return Observation( - id=row_id, - ts=ts, - pose=pose, - tags=tags, - _data=_UNLOADED, - _loader=self._make_loader(row_id), # type: 
ignore[arg-type] - ) - - # ── Write ──────────────────────────────────────────────────── - - def _ensure_tag_indexes(self, tags: dict[str, Any]) -> None: - """Auto-create expression indexes for any new tag keys.""" - for key in tags: - if key not in self._tag_indexes and _IDENT_RE.match(key): - self._conn.execute( - f'CREATE INDEX IF NOT EXISTS "{self._name}_tag_{key}" ' - f"ON \"{self._name}\"(json_extract(tags, '$.{key}'))" - ) - self._tag_indexes.add(key) - - def append(self, obs: Observation[T]) -> Observation[T]: - encoded = self._codec.encode(obs._data) - pose = _decompose_pose(obs.pose) - tags_json = json.dumps(obs.tags) if obs.tags else "{}" - - with self._lock: - if obs.tags: - self._ensure_tag_indexes(obs.tags) - if pose: - px, py, pz, qx, qy, qz, qw = pose - else: - px = py = pz = qx = qy = qz = qw = None - - cur = self._conn.execute( - f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", - (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json), - ) - row_id = cur.lastrowid - assert row_id is not None - - bs = self.config.blob_store - assert bs is not None - bs.put(self._name, row_id, encoded) - - # R*Tree spatial index - if pose: - self._conn.execute( - f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) ' - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (row_id, px, px, py, py, pz, pz), - ) - - vs = self.config.vector_store - if vs is not None: - emb = getattr(obs, "embedding", None) - if emb is not None: - vs.put(self._name, row_id, emb) - - self._conn.commit() - - obs.id = row_id - self._channel.notify(obs) - return obs - - # ── Read ───────────────────────────────────────────────────── - - def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: - if query.search_vec is not None and query.live_buffer is not None: - raise TypeError("Cannot combine .search() with .live() — search is a batch operation.") - buf = query.live_buffer - if buf 
is not None: - sub = self._channel.subscribe(buf) - return self._iterate_live(query, buf, sub) - return self._iterate_snapshot(query) - - def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: - if query.search_vec is not None and self.config.vector_store is not None: - yield from self._vector_search(query) - return - - join = self._join_blobs - sql, params, python_filters = _compile_query(query, self._name, join_blob=join) - - cur = self._conn.execute(sql, params) - cur.arraysize = self.config.page_size - it: Iterator[Observation[T]] = (self._row_to_obs(r, has_blob=join) for r in cur) - - # Text search — requires loading data - if query.search_text is not None: - needle = query.search_text.lower() - it = (obs for obs in it if needle in str(obs.data).lower()) - - # Apply Python post-filters - if python_filters: - it = (obs for obs in it if all(f.matches(obs) for f in python_filters)) - - # Apply LIMIT/OFFSET in Python when we couldn't push to SQL - if python_filters or query.search_text: - if query.offset_val: - it = islice(it, query.offset_val, None) - if query.limit_val is not None: - it = islice(it, query.limit_val) - - yield from it - - def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]: - vs = self.config.vector_store - assert vs is not None and query.search_vec is not None - - hits = vs.search(self._name, query.search_vec, query.search_k or 10) - if not hits: - return - - ids = [h[0] for h in hits] - dict(hits) - - # Batch-fetch metadata - join = self._join_blobs - placeholders = ",".join("?" 
* len(ids)) - if join: - sql = ( - f"SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, " - f"meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data " - f'FROM "{self._name}" AS meta ' - f'JOIN "{self._name}_blob" AS blob ON blob.id = meta.id ' - f"WHERE meta.id IN ({placeholders})" - ) - else: - sql = ( - f"SELECT id, ts, pose_x, pose_y, pose_z, " - f"pose_qx, pose_qy, pose_qz, pose_qw, json(tags) " - f'FROM "{self._name}" WHERE id IN ({placeholders})' - ) - - rows = self._conn.execute(sql, ids).fetchall() - obs_by_id: dict[int, Observation[T]] = {} - for r in rows: - obs = self._row_to_obs(r, has_blob=join) - obs_by_id[obs.id] = obs - - # Preserve VectorStore ranking order, promoting to EmbeddedObservation - ranked: list[Observation[T]] = [] - for obs_id, sim in hits: - obs = obs_by_id.get(obs_id) - if obs is not None: - ranked.append(obs.derive(data=obs.data, embedding=query.search_vec, similarity=sim)) - - # Apply remaining query ops (skip vector search) - rest = replace(query, search_vec=None, search_k=None) - yield from rest.apply(iter(ranked)) - - def _iterate_live( - self, - query: StreamQuery, - buf: BackpressureBuffer[Observation[T]], - sub: DisposableBase, - ) -> Iterator[Observation[T]]: - from dimos.memory2.buffer import ClosedError - - # Backfill phase - last_id = -1 - for obs in self._iterate_snapshot(query): - last_id = max(last_id, obs.id) - yield obs - - # Live tail - filters = query.filters - try: - while True: - obs = buf.take() - if obs.id <= last_id: - continue - last_id = obs.id - if filters and not all(f.matches(obs) for f in filters): - continue - yield obs - except (ClosedError, StopIteration): - sub.dispose() - - def count(self, query: StreamQuery) -> int: - if query.search_vec or query.search_text: - return sum(1 for _ in self.iterate(query)) - - sql, params, python_filters = _compile_count(query, self._name) - if python_filters: - return sum(1 for _ in self.iterate(query)) - - row = 
self._conn.execute(sql, params).fetchone() - return int(row[0]) if row else 0 - - -# ── SqliteSession ──────────────────────────────────────────────── - - -class SqliteSession(Session): - """Session owning a single SQLite connection.""" - - def __init__( - self, conn: sqlite3.Connection, *, vec_available: bool = False, **kwargs: Any - ) -> None: - super().__init__(**kwargs) - self._conn = conn - self._vec_available = vec_available - self._blob_store: SqliteBlobStore | None = None - self._vector_store: Any | None = None - - # Create stream registry - self._conn.execute( - "CREATE TABLE IF NOT EXISTS _streams (" - " name TEXT PRIMARY KEY," - " payload_module TEXT NOT NULL," - " codec_id TEXT NOT NULL" - ")" - ) - self._conn.commit() - - def _ensure_shared_stores(self) -> None: - """Lazily create shared stores on first stream creation.""" - if self._blob_store is None: - self._blob_store = SqliteBlobStore(self._conn) - if self._vector_store is None and self._vec_available: - from dimos.memory2.vectorstore.sqlite import SqliteVectorStore - - self._vector_store = SqliteVectorStore(self._conn) - - @staticmethod - def _codec_id(codec: Codec[Any]) -> str: - from dimos.memory2.codecs.jpeg import JpegCodec - from dimos.memory2.codecs.lcm import LcmCodec - - if isinstance(codec, JpegCodec): - return "jpeg" - if isinstance(codec, LcmCodec): - return "lcm" - return "pickle" - - @staticmethod - def _codec_from_id(codec_id: str, payload_module: str) -> Codec[Any]: - from dimos.memory2.codecs.pickle import PickleCodec - - if codec_id == "jpeg": - from dimos.memory2.codecs.jpeg import JpegCodec - - return JpegCodec() - if codec_id == "lcm": - from dimos.memory2.codecs.lcm import LcmCodec - - # Resolve the payload type from module path - parts = payload_module.rsplit(".", 1) - if len(parts) == 2: - import importlib - - mod = importlib.import_module(parts[0]) - cls = getattr(mod, parts[1]) - return LcmCodec(cls) - return PickleCodec() - return PickleCodec() - - def _create_backend( - 
self, name: str, payload_type: type[Any] | None = None, **config: Any - ) -> Backend[Any]: - _validate_identifier(name) - self._ensure_shared_stores() - - # Look up existing stream in registry - row = self._conn.execute( - "SELECT payload_module, codec_id FROM _streams WHERE name = ?", (name,) - ).fetchone() - - if row is not None: - stored_module, stored_codec_id = row - if payload_type is not None: - actual_module = f"{payload_type.__module__}.{payload_type.__qualname__}" - if actual_module != stored_module: - raise ValueError( - f"Stream {name!r} was created with type {stored_module}, " - f"but opened with {actual_module}" - ) - codec = config.get("codec") or self._codec_from_id(stored_codec_id, stored_module) - else: - if payload_type is None: - raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required") - codec = config.get("codec") or codec_for(payload_type) - payload_module = f"{payload_type.__module__}.{payload_type.__qualname__}" - self._conn.execute( - "INSERT INTO _streams (name, payload_module, codec_id) VALUES (?, ?, ?)", - (name, payload_module, self._codec_id(codec)), - ) - self._conn.commit() - - # Create metadata table - self._conn.execute( - f'CREATE TABLE IF NOT EXISTS "{name}" (' - " id INTEGER PRIMARY KEY AUTOINCREMENT," - " ts REAL NOT NULL UNIQUE," - " pose_x REAL, pose_y REAL, pose_z REAL," - " pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL," - " tags BLOB DEFAULT (jsonb('{}'))" - ")" - ) - # R*Tree spatial index for pose queries - self._conn.execute( - f'CREATE VIRTUAL TABLE IF NOT EXISTS "{name}_rtree" USING rtree(' - " id," - " x_min, x_max," - " y_min, y_max," - " z_min, z_max" - ")" - ) - self._conn.commit() - - # Merge shared stores as defaults - if "blob_store" not in config or config["blob_store"] is None: - config["blob_store"] = self._blob_store - if "vector_store" not in config or config["vector_store"] is None: - config["vector_store"] = self._vector_store - config["codec"] = codec - - return 
SqliteBackend(self._conn, name, **config) - - def list_streams(self) -> list[str]: - db_names = {row[0] for row in self._conn.execute("SELECT name FROM _streams").fetchall()} - return sorted(db_names | set(self._streams.keys())) - - def delete_stream(self, name: str) -> None: - self._streams.pop(name, None) - self._conn.execute(f'DROP TABLE IF EXISTS "{name}"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') - self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) - self._conn.commit() - - def stop(self) -> None: - super().stop() - self._conn.close() - - -# ── SqliteStore ────────────────────────────────────────────────── - - -@dataclass -class SqliteStoreConfig(StoreConfig): - """Config for SQLite-backed store.""" - - path: str = "memory.db" - - -class SqliteStore(Store): - """Store backed by a SQLite database file.""" - - default_config: type[SqliteStoreConfig] = SqliteStoreConfig - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - - def session(self, **kwargs: Any) -> SqliteSession: - conn = sqlite3.connect(self.config.path, check_same_thread=False) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - - vec_available = False - try: - import sqlite_vec - - conn.enable_load_extension(True) - sqlite_vec.load(conn) - conn.enable_load_extension(False) - vec_available = True - except (ImportError, Exception): - pass - - return SqliteSession(conn, vec_available=vec_available, **kwargs) diff --git a/dimos/memory2/livechannel/__init__.py b/dimos/memory2/livechannel/__init__.py deleted file mode 100644 index 4fba822bab..0000000000 --- a/dimos/memory2/livechannel/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from dimos.memory2.backend import LiveChannel -from dimos.memory2.livechannel.subject import SubjectChannel - -__all__ = ["LiveChannel", "SubjectChannel"] diff --git 
a/dimos/memory2/store.py b/dimos/memory2/store.py deleted file mode 100644 index e9c1ec4e51..0000000000 --- a/dimos/memory2/store.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from abc import abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar, cast - -from dimos.core.resource import CompositeResource -from dimos.memory2.stream import Stream -from dimos.protocol.service.spec import Configurable - -if TYPE_CHECKING: - from collections.abc import Iterator - - from dimos.memory2.backend import Backend, BlobStore, LiveChannel, VectorStore - from dimos.memory2.codecs.base import Codec - -T = TypeVar("T") - - -# ── Configuration ───────────────────────────────────────────────── - - -@dataclass -class StoreConfig: - """Base config for Store. Subclasses extend with store-specific fields.""" - - -@dataclass -class SessionConfig: - """Session-level defaults for stream capabilities. - - These are inherited by all streams in the session unless overridden - per-stream in ``session.stream(..., **overrides)``. 
- """ - - live_channel: LiveChannel[Any] | None = None - blob_store: BlobStore | None = None - vector_store: VectorStore | None = None - eager_blobs: bool = False - codec: Codec[Any] | None = None - - -# ── Stream namespace ────────────────────────────────────────────── - - -class StreamNamespace: - """Attribute-access proxy for session streams. - - Usage:: - - session.streams.image_stream - session.streams["image_stream"] - list(session.streams) - len(session.streams) - """ - - def __init__(self, session: Session) -> None: - self._session = session - - def __getattr__(self, name: str) -> Stream[Any]: - if name.startswith("_"): - raise AttributeError(name) - if name not in self._session.list_streams(): - available = ", ".join(self._session.list_streams()) or "(none)" - raise AttributeError(f"No stream named {name!r}. Available: {available}") - return self._session.stream(name) - - def __getitem__(self, name: str) -> Stream[Any]: - if name not in self._session.list_streams(): - raise KeyError(name) - return self._session.stream(name) - - def __iter__(self) -> Iterator[Stream[Any]]: - for name in self._session.list_streams(): - yield self._session.stream(name) - - def __len__(self) -> int: - return len(self._session.list_streams()) - - def __contains__(self, name: str) -> bool: - return name in self._session.list_streams() - - def __repr__(self) -> str: - return f"StreamNamespace({self._session.list_streams()})" - - -# ── Session & Store ─────────────────────────────────────────────── - - -class Session(Configurable[SessionConfig], CompositeResource): - """A session against a store. Manages named streams over a shared connection. - - Subclasses implement ``_create_backend`` to provide storage-specific backends. 
- """ - - default_config: type[SessionConfig] = SessionConfig - - def __init__(self, **kwargs: Any) -> None: - Configurable.__init__(self, **kwargs) - CompositeResource.__init__(self) - self._streams: dict[str, Stream[Any]] = {} - - @abstractmethod - def _create_backend( - self, name: str, payload_type: type[Any] | None = None, **config: Any - ) -> Backend[Any]: - """Create a backend for the named stream. Called once per stream name.""" - ... - - def stream(self, name: str, payload_type: type[T] | None = None, **overrides: Any) -> Stream[T]: - """Get or create a named stream. Returns the same Stream on repeated calls. - - Per-stream ``overrides`` (e.g. ``live_channel=``) are merged on top of - the session-level defaults from :class:`SessionConfig`. - """ - if name not in self._streams: - resolved = {k: v for k, v in vars(self.config).items() if v is not None} - resolved.update({k: v for k, v in overrides.items() if v is not None}) - backend = self._create_backend(name, payload_type, **resolved) - self._streams[name] = Stream(source=backend) - return cast("Stream[T]", self._streams[name]) - - @abstractmethod - def list_streams(self) -> list[str]: - """Return names of all streams in this session.""" - ... - - @abstractmethod - def delete_stream(self, name: str) -> None: - """Delete a stream by name (from cache and underlying storage).""" - ... - - @property - def streams(self) -> StreamNamespace: - return StreamNamespace(self) - - -class Store(Configurable[StoreConfig], CompositeResource): - """Top-level entry point — wraps a storage location (file, URL, etc.).""" - - default_config: type[StoreConfig] = StoreConfig - - def __init__(self, **kwargs: Any) -> None: - Configurable.__init__(self, **kwargs) - CompositeResource.__init__(self) - - @abstractmethod - def session(self, **kwargs: Any) -> Session: - """Create a session. kwargs are forwarded to SessionConfig.""" - ... 
diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py deleted file mode 100644 index df6dc4636a..0000000000 --- a/dimos/memory2/stream.py +++ /dev/null @@ -1,381 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import time -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -from dimos.core.resource import Resource -from dimos.memory2.backend import Backend -from dimos.memory2.buffer import BackpressureBuffer, KeepLast -from dimos.memory2.filter import ( - AfterFilter, - AtFilter, - BeforeFilter, - Filter, - NearFilter, - PredicateFilter, - StreamQuery, - TagsFilter, - TimeRangeFilter, -) -from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer -from dimos.memory2.type import EmbeddedObservation, Observation - -if TYPE_CHECKING: - from collections.abc import Callable, Iterator - - import reactivex - from reactivex.abc import DisposableBase, ObserverBase - - from dimos.models.embedding.base import Embedding - -T = TypeVar("T") -R = TypeVar("R") - - -class Stream(Resource, Generic[T]): - """Lazy, pull-based stream over observations. - - Every filter/transform method returns a new Stream — no computation - happens until iteration. Backends handle query application for stored - data; transform sources apply filters as Python predicates. - - Implements Resource so live streams can be cleanly stopped via - ``stop()`` or used as a context manager. 
- """ - - def __init__( - self, - source: Backend[T] | Stream[Any], - *, - xf: Transformer[Any, T] | None = None, - query: StreamQuery = StreamQuery(), - ) -> None: - self._source = source - self._xf = xf - self._query = query - - def start(self) -> None: - pass - - def stop(self) -> None: - """Close the live buffer (if any), unblocking iteration.""" - buf = self._query.live_buffer - if buf is not None: - buf.close() - if isinstance(self._source, Stream): - self._source.stop() - - def __str__(self) -> str: - # Walk the source chain to collect (xf, query) pairs - chain: list[tuple[Any, StreamQuery]] = [] - current: Any = self - while isinstance(current, Stream): - chain.append((current._xf, current._query)) - current = current._source - chain.reverse() # innermost first - - # current is the Backend - name = getattr(current, "name", "?") - result = f'Stream("{name}")' - - for xf, query in chain: - if xf is not None: - result += f" -> {xf}" - q_str = str(query) - if q_str: - result += f" | {q_str}" - - return result - - def is_live(self) -> bool: - """True if this stream (or any ancestor in the chain) is in live mode.""" - if self._query.live_buffer is not None: - return True - if isinstance(self._source, Stream): - return self._source.is_live() - return False - - # ── Iteration ─────────────────────────────────────────────────── - - def __iter__(self) -> Iterator[Observation[T]]: - return self._build_iter() - - def _build_iter(self) -> Iterator[Observation[T]]: - if isinstance(self._source, Stream): - return self._iter_transform() - # Backend handles all query application (including live if requested) - return self._source.iterate(self._query) - - def _iter_transform(self) -> Iterator[Observation[T]]: - """Iterate a transform source, applying query filters in Python.""" - assert isinstance(self._source, Stream) and self._xf is not None - it: Iterator[Observation[T]] = self._xf(iter(self._source)) - return self._query.apply(it, live=self.is_live()) - - # ── Query 
builders ────────────────────────────────────────────── - - def _replace_query(self, **overrides: Any) -> Stream[T]: - q = self._query - new_q = StreamQuery( - filters=overrides.get("filters", q.filters), - order_field=overrides.get("order_field", q.order_field), - order_desc=overrides.get("order_desc", q.order_desc), - limit_val=overrides.get("limit_val", q.limit_val), - offset_val=overrides.get("offset_val", q.offset_val), - live_buffer=overrides.get("live_buffer", q.live_buffer), - search_vec=overrides.get("search_vec", q.search_vec), - search_k=overrides.get("search_k", q.search_k), - search_text=overrides.get("search_text", q.search_text), - ) - return Stream(self._source, xf=self._xf, query=new_q) - - def _with_filter(self, f: Filter) -> Stream[T]: - return self._replace_query(filters=(*self._query.filters, f)) - - def after(self, t: float) -> Stream[T]: - return self._with_filter(AfterFilter(t)) - - def before(self, t: float) -> Stream[T]: - return self._with_filter(BeforeFilter(t)) - - def time_range(self, t1: float, t2: float) -> Stream[T]: - return self._with_filter(TimeRangeFilter(t1, t2)) - - def at(self, t: float, tolerance: float = 1.0) -> Stream[T]: - return self._with_filter(AtFilter(t, tolerance)) - - def near(self, pose: Any, radius: float) -> Stream[T]: - return self._with_filter(NearFilter(pose, radius)) - - def tags(self, **tags: Any) -> Stream[T]: - return self._with_filter(TagsFilter(tags)) - - def order_by(self, field: str, desc: bool = False) -> Stream[T]: - return self._replace_query(order_field=field, order_desc=desc) - - def limit(self, k: int) -> Stream[T]: - return self._replace_query(limit_val=k) - - def offset(self, n: int) -> Stream[T]: - return self._replace_query(offset_val=n) - - def search(self, query: Embedding, k: int) -> Stream[T]: - """Return top-k observations by cosine similarity to *query*. - - The backend handles the actual computation. ListBackend does - brute-force cosine; SqliteBackend (future) pushes down to vec0. 
- """ - return self._replace_query(search_vec=query, search_k=k) - - def search_text(self, text: str) -> Stream[T]: - """Filter observations whose data contains *text*. - - ListBackend does case-insensitive substring match; - SqliteBackend (future) pushes down to FTS5. - """ - return self._replace_query(search_text=text) - - # ── Functional API ────────────────────────────────────────────── - - def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: - """Filter by arbitrary predicate on the full Observation.""" - return self._with_filter(PredicateFilter(pred)) - - def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[Any]: - """Transform each observation's data via callable.""" - return self.transform(FnTransformer(lambda obs: fn(obs))) - - # ── Transform ─────────────────────────────────────────────────── - - def transform( - self, - xf: Transformer[T, R] | Callable[[Iterator[Observation[T]]], Iterator[Observation[R]]], - ) -> Stream[R]: - """Wrap this stream with a transformer. Returns a new lazy Stream. - - Accepts a ``Transformer`` subclass or a bare callable / generator - function with the same ``Iterator[Obs] → Iterator[Obs]`` signature:: - - def detect(upstream): - for obs in upstream: - yield obs.derive(data=run_detector(obs.data)) - - images.transform(detect).save(detections) - """ - if not isinstance(xf, Transformer): - xf = FnIterTransformer(xf) - return Stream(source=self, xf=xf, query=StreamQuery()) - - # ── Live mode ─────────────────────────────────────────────────── - - def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: - """Return a stream whose iteration never ends — backfill then live tail. - - All backends support live mode via their built-in ``LiveChannel``. - Call .live() before .transform(), not after. - - Default buffer: KeepLast(). The backend handles subscription, dedup, - and backpressure — how it does so is its business. 
- """ - if isinstance(self._source, Stream): - raise TypeError( - "Cannot call .live() on a transform stream. " - "Call .live() on the source stream, then .transform()." - ) - buf = buffer if buffer is not None else KeepLast() - return self._replace_query(live_buffer=buf) - - # ── Save ───────────────────────────────────────────────────────── - - def save(self, target: Stream[T]) -> Stream[T]: - """Sync terminal: iterate self, append each obs to target's backend. - - Returns the target stream for continued querying. - """ - if isinstance(target._source, Stream): - raise TypeError("Cannot save to a transform stream. Target must be backend-backed.") - backend = target._source - for obs in self: - backend.append(obs) - return target - - # ── Terminals ─────────────────────────────────────────────────── - - def fetch(self) -> list[Observation[T]]: - """Materialize all observations into a list.""" - if self.is_live(): - raise TypeError( - ".fetch() on a live stream would block forever. " - "Use .drain() or .save(target) instead." 
- ) - return list(self) - - def first(self) -> Observation[T]: - """Return the first matching observation.""" - it = iter(self.limit(1)) - try: - return next(it) - except StopIteration: - raise LookupError("No matching observation") from None - - def last(self) -> Observation[T]: - """Return the last matching observation (by timestamp).""" - return self.order_by("ts", desc=True).first() - - def count(self) -> int: - """Count matching observations.""" - if isinstance(self._source, Backend): - return self._source.count(self._query) - if self.is_live(): - raise TypeError(".count() on a live transform stream would block forever.") - return sum(1 for _ in self) - - def exists(self) -> bool: - """Check if any matching observation exists.""" - return next(iter(self.limit(1)), None) is not None - - def get_time_range(self) -> tuple[float, float]: - """Return (min_ts, max_ts) for matching observations.""" - first = self.first() - last = self.last() - return (first.ts, last.ts) - - def summary(self) -> str: - """Return a short human-readable summary: count, time range, duration.""" - from datetime import datetime, timezone - - n = self.count() - if n == 0: - return f"{self}: empty" - - (t0, t1) = self.get_time_range() - - fmt = "%Y-%m-%d %H:%M:%S" - dt0 = datetime.fromtimestamp(t0, tz=timezone.utc).strftime(fmt) - dt1 = datetime.fromtimestamp(t1, tz=timezone.utc).strftime(fmt) - dur = t1 - t0 - return f"{self}: {n} items, {dt0} — {dt1} ({dur:.1f}s)" - - def drain(self) -> int: - """Consume all observations, discarding results. Returns count consumed. - - Use for side-effect pipelines (e.g. live embed-and-store) where you - don't need to collect results in memory. - """ - n = 0 - for _ in self: - n += 1 - return n - - # ── Reactive ───────────────────────────────────────────────────── - - def observable(self) -> reactivex.Observable[Observation[T]]: - """Convert this stream to an RxPY Observable. 
- - Iteration is scheduled on the dimos thread pool so subscribe() never - blocks the calling thread. - """ - import reactivex - import reactivex.operators as ops - - from dimos.utils.threadpool import get_scheduler - - return reactivex.from_iterable(self).pipe( - ops.subscribe_on(get_scheduler()), - ) - - def subscribe( - self, - on_next: Callable[[Observation[T]], None] | ObserverBase[Observation[T]] | None = None, - on_error: Callable[[Exception], None] | None = None, - on_completed: Callable[[], None] | None = None, - ) -> DisposableBase: - """Subscribe to this stream as an RxPY Observable.""" - return self.observable().subscribe( # type: ignore[call-overload] - on_next=on_next, - on_error=on_error, - on_completed=on_completed, - ) - - # ── Write ─────────────────────────────────────────────────────── - - def append( - self, - payload: T, - *, - ts: float | None = None, - pose: Any | None = None, - tags: dict[str, Any] | None = None, - embedding: Embedding | None = None, - ) -> Observation[T]: - """Append to the backing store. Only works if source is a Backend.""" - if isinstance(self._source, Stream): - raise TypeError("Cannot append to a transform stream. Append to the source stream.") - _ts = ts if ts is not None else time.time() - _tags = tags or {} - if embedding is not None: - obs: Observation[T] = EmbeddedObservation( - id=-1, - ts=_ts, - pose=pose, - tags=_tags, - _data=payload, - embedding=embedding, - ) - else: - obs = Observation(id=-1, ts=_ts, pose=pose, tags=_tags, _data=payload) - return self._source.append(obs) diff --git a/dimos/memory2/type.py b/dimos/memory2/type.py deleted file mode 100644 index 85cfab9640..0000000000 --- a/dimos/memory2/type.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -if TYPE_CHECKING: - from collections.abc import Callable - - from dimos.models.embedding.base import Embedding - -T = TypeVar("T") - - -# ── Lazy data sentinel ────────────────────────────────────────────── - - -class _Unloaded: - """Sentinel indicating data has not been loaded yet.""" - - __slots__ = () - - def __repr__(self) -> str: - return "" - - -_UNLOADED = _Unloaded() - - -# ── Observation ───────────────────────────────────────────────────── - - -@dataclass -class Observation(Generic[T]): - """A single timestamped observation with optional spatial pose and metadata.""" - - id: int - ts: float - pose: Any | None = None - tags: dict[str, Any] = field(default_factory=dict) - _data: T | _Unloaded = field(default=_UNLOADED, repr=False) - _loader: Callable[[], T] | None = field(default=None, repr=False) - - @property - def data(self) -> T: - val = self._data - if isinstance(val, _Unloaded): - if self._loader is None: - raise LookupError("No data and no loader set on this observation") - loaded = self._loader() - self._data = loaded - self._loader = None # release closure - return loaded - return val - - def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: - """Create a new observation preserving ts/pose/tags, replacing data. - - If ``embedding`` is passed, promotes the result to - :class:`EmbeddedObservation`. 
- """ - if "embedding" in overrides: - return EmbeddedObservation( - id=self.id, - ts=overrides.get("ts", self.ts), - pose=overrides.get("pose", self.pose), - tags=overrides.get("tags", self.tags), - _data=data, - embedding=overrides["embedding"], - similarity=overrides.get("similarity"), - ) - return Observation( - id=self.id, - ts=overrides.get("ts", self.ts), - pose=overrides.get("pose", self.pose), - tags=overrides.get("tags", self.tags), - _data=data, - ) - - -# ── EmbeddedObservation ────────────────────────────────────────── - - -@dataclass -class EmbeddedObservation(Observation[T]): - """Observation enriched with a vector embedding and optional similarity score.""" - - embedding: Embedding | None = None - similarity: float | None = None - - def derive(self, *, data: Any, **overrides: Any) -> EmbeddedObservation[Any]: - """Preserve embedding unless explicitly replaced.""" - return EmbeddedObservation( - id=self.id, - ts=overrides.get("ts", self.ts), - pose=overrides.get("pose", self.pose), - tags=overrides.get("tags", self.tags), - _data=data, - embedding=overrides.get("embedding", self.embedding), - similarity=overrides.get("similarity", self.similarity), - ) From 1a6c8a10b1e3cdb0d0a7db3031e26c45dcb5a3d6 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 20:19:48 +0800 Subject: [PATCH 096/118] Revert memory rename: restore memory/ from dev, new code lives in memory2/ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Restore dimos/memory/ (old timeseries memory) to match dev - Move new memory system back to dimos/memory2/ with corrected imports - Delete dimos/memory_old/ (no longer needed) - Fix memory_old imports in tf.py, timestamped.py, replay.py → dimos.memory - Remove dps CLI util and pyproject entry - Remove unitree_go2_memory blueprint (depends on deleted modules) --- dimos/{memory_old => memory}/embedding.py | 21 +- dimos/memory/impl/sqlite.py | 662 +---------------- 
dimos/memory/livechannel/__init__.py | 4 - dimos/memory/test_embedding.py | 476 +------------ .../timeseries/__init__.py | 14 +- .../{memory_old => memory}/timeseries/base.py | 0 .../timeseries/inmemory.py | 2 +- .../timeseries/legacy.py | 2 +- .../timeseries/pickledir.py | 2 +- .../timeseries/postgres.py | 2 +- .../timeseries/sqlite.py | 12 +- .../timeseries/test_base.py | 16 +- .../timeseries/test_legacy.py | 2 +- dimos/{memory => memory2}/__init__.py | 22 +- dimos/{memory => memory2}/architecture.md | 2 +- dimos/{memory => memory2}/backend.py | 8 +- .../{memory => memory2}/blobstore/__init__.py | 6 +- .../blobstore/blobstore.md | 0 dimos/{memory => memory2}/blobstore/file.py | 2 +- dimos/{memory => memory2}/blobstore/sqlite.py | 2 +- .../blobstore/test_blobstore.py | 6 +- dimos/{memory => memory2}/buffer.py | 0 dimos/{memory => memory2}/codecs/README.md | 2 +- dimos/{memory => memory2}/codecs/__init__.py | 4 +- dimos/{memory => memory2}/codecs/base.py | 6 +- dimos/{memory => memory2}/codecs/jpeg.py | 0 dimos/{memory => memory2}/codecs/lcm.py | 0 dimos/{memory => memory2}/codecs/pickle.py | 0 .../{memory => memory2}/codecs/test_codecs.py | 16 +- dimos/{memory => memory2}/embed.py | 4 +- dimos/{memory => memory2}/embeddings.md | 0 dimos/{memory => memory2}/filter.py | 4 +- dimos/{memory => memory2}/formatting.py | 0 dimos/{memory => memory2}/impl/README.md | 10 +- dimos/{memory => memory2}/impl/__init__.py | 0 dimos/{memory => memory2}/impl/memory.py | 20 +- dimos/memory2/impl/sqlite.py | 674 ++++++++++++++++++ dimos/{memory => memory2}/intro.md | 4 +- dimos/memory2/livechannel/__init__.py | 4 + .../livechannel/subject.py | 6 +- dimos/{memory => memory2}/store.py | 6 +- dimos/{memory => memory2}/stream.py | 10 +- dimos/{memory => memory2}/streaming.md | 0 dimos/{memory => memory2}/test_blobstore.py | 8 +- dimos/{memory => memory2}/test_buffer.py | 2 +- dimos/{memory => memory2}/test_e2e_import.py | 4 +- .../test_e2e_processing.py | 0 dimos/{memory => 
memory2}/test_e2e_query.py | 6 +- dimos/memory2/test_embedding.py | 455 ++++++++++++ dimos/{memory => memory2}/test_impl.py | 28 +- dimos/{memory => memory2}/test_save.py | 10 +- dimos/{memory => memory2}/test_stream.py | 10 +- dimos/{memory => memory2}/transform.py | 4 +- dimos/{memory => memory2}/type.py | 0 .../vectorstore/__init__.py | 4 +- .../{memory => memory2}/vectorstore/memory.py | 2 +- .../{memory => memory2}/vectorstore/sqlite.py | 2 +- dimos/memory_old/impl/sqlite.py | 14 - dimos/memory_old/test_embedding.py | 53 -- dimos/protocol/tf/tf.py | 2 +- dimos/robot/all_blueprints.py | 1 - dimos/robot/cli/dimos.py | 9 - .../blueprints/smart/unitree_go2_memory.py | 67 -- dimos/types/test_timestamped.py | 2 +- dimos/types/timestamped.py | 2 +- dimos/utils/cli/dps.py | 139 ---- dimos/utils/testing/replay.py | 2 +- pyproject.toml | 1 - 68 files changed, 1319 insertions(+), 1541 deletions(-) rename dimos/{memory_old => memory}/embedding.py (92%) delete mode 100644 dimos/memory/livechannel/__init__.py rename dimos/{memory_old => memory}/timeseries/__init__.py (70%) rename dimos/{memory_old => memory}/timeseries/base.py (100%) rename dimos/{memory_old => memory}/timeseries/inmemory.py (98%) rename dimos/{memory_old => memory}/timeseries/legacy.py (99%) rename dimos/{memory_old => memory}/timeseries/pickledir.py (99%) rename dimos/{memory_old => memory}/timeseries/postgres.py (99%) rename dimos/{memory_old => memory}/timeseries/sqlite.py (96%) rename dimos/{memory_old => memory}/timeseries/test_base.py (96%) rename dimos/{memory_old => memory}/timeseries/test_legacy.py (96%) rename dimos/{memory => memory2}/__init__.py (59%) rename dimos/{memory => memory2}/architecture.md (99%) rename dimos/{memory => memory2}/backend.py (96%) rename dimos/{memory => memory2}/blobstore/__init__.py (80%) rename dimos/{memory => memory2}/blobstore/blobstore.md (100%) rename dimos/{memory => memory2}/blobstore/file.py (97%) rename dimos/{memory => memory2}/blobstore/sqlite.py (98%) 
rename dimos/{memory => memory2}/blobstore/test_blobstore.py (95%) rename dimos/{memory => memory2}/buffer.py (100%) rename dimos/{memory => memory2}/codecs/README.md (97%) rename dimos/{memory => memory2}/codecs/__init__.py (85%) rename dimos/{memory => memory2}/codecs/base.py (88%) rename dimos/{memory => memory2}/codecs/jpeg.py (100%) rename dimos/{memory => memory2}/codecs/lcm.py (100%) rename dimos/{memory => memory2}/codecs/pickle.py (100%) rename dimos/{memory => memory2}/codecs/test_codecs.py (91%) rename dimos/{memory => memory2}/embed.py (96%) rename dimos/{memory => memory2}/embeddings.md (100%) rename dimos/{memory => memory2}/filter.py (98%) rename dimos/{memory => memory2}/formatting.py (100%) rename dimos/{memory => memory2}/impl/README.md (95%) rename dimos/{memory => memory2}/impl/__init__.py (100%) rename dimos/{memory => memory2}/impl/memory.py (93%) create mode 100644 dimos/memory2/impl/sqlite.py rename dimos/{memory => memory2}/intro.md (98%) create mode 100644 dimos/memory2/livechannel/__init__.py rename dimos/{memory => memory2}/livechannel/subject.py (92%) rename dimos/{memory => memory2}/store.py (97%) rename dimos/{memory => memory2}/stream.py (98%) rename dimos/{memory => memory2}/streaming.md (100%) rename dimos/{memory => memory2}/test_blobstore.py (97%) rename dimos/{memory => memory2}/test_buffer.py (96%) rename dimos/{memory => memory2}/test_e2e_import.py (97%) rename dimos/{memory => memory2}/test_e2e_processing.py (100%) rename dimos/{memory => memory2}/test_e2e_query.py (97%) create mode 100644 dimos/memory2/test_embedding.py rename dimos/{memory => memory2}/test_impl.py (95%) rename dimos/{memory => memory2}/test_save.py (95%) rename dimos/{memory => memory2}/test_stream.py (99%) rename dimos/{memory => memory2}/transform.py (97%) rename dimos/{memory => memory2}/type.py (100%) rename dimos/{memory => memory2}/vectorstore/__init__.py (83%) rename dimos/{memory => memory2}/vectorstore/memory.py (97%) rename dimos/{memory => 
memory2}/vectorstore/sqlite.py (98%) delete mode 100644 dimos/memory_old/impl/sqlite.py delete mode 100644 dimos/memory_old/test_embedding.py delete mode 100644 dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py delete mode 100644 dimos/utils/cli/dps.py diff --git a/dimos/memory_old/embedding.py b/dimos/memory/embedding.py similarity index 92% rename from dimos/memory_old/embedding.py rename to dimos/memory/embedding.py index 6fa3445208..4627ecfc35 100644 --- a/dimos/memory_old/embedding.py +++ b/dimos/memory/embedding.py @@ -29,30 +29,25 @@ from dimos.msgs.nav_msgs import OccupancyGrid from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier -from dimos.types.timestamped import Timestamped from dimos.utils.reactive import getter_hot @dataclass -class SpatialEntry(Timestamped): - pose: PoseStamped +class Config(ModuleConfig): + embedding_model: EmbeddingModel = field(default_factory=CLIPModel) @dataclass -class SpatialImage(SpatialEntry): +class SpatialEntry: image: Image + pose: PoseStamped @dataclass -class SpatialEmbedding(SpatialImage): +class SpatialEmbedding(SpatialEntry): embedding: Embedding -@dataclass -class Config(ModuleConfig): - embedding_model: EmbeddingModel = field(default_factory=CLIPModel) - - class EmbeddingMemory(Module[Config]): default_config = Config config: Config @@ -88,13 +83,13 @@ def start(self) -> None: ops.map(self._store_spatial_entry), ).subscribe(print) - def _try_create_spatial_entry(self, img: Image) -> Observable[SpatialImage]: + def _try_create_spatial_entry(self, img: Image) -> Observable[SpatialEntry]: pose = self.tf.get_pose("world", "base_link") if not pose: return rx.empty() - return rx.of(SpatialImage(image=img, pose=pose)) + return rx.of(SpatialEntry(image=img, pose=pose)) - def _embed_spatial_entry(self, spatial_entry: SpatialImage) -> SpatialEmbedding: + def _embed_spatial_entry(self, spatial_entry: SpatialEntry) -> SpatialEmbedding: embedding = 
cast("Embedding", self.config.embedding_model.embed(spatial_entry.image)) return SpatialEmbedding( image=spatial_entry.image, diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py index e511608387..20caceb8a7 100644 --- a/dimos/memory/impl/sqlite.py +++ b/dimos/memory/impl/sqlite.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright 2026 Dimensional Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,664 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import annotations - -from dataclasses import dataclass, replace -from itertools import islice -import json -import re -import sqlite3 -import threading -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -from dimos.memory.backend import BackendConfig -from dimos.memory.blobstore.sqlite import SqliteBlobStore -from dimos.memory.codecs.base import Codec, codec_for -from dimos.memory.filter import ( - AfterFilter, - AtFilter, - BeforeFilter, - NearFilter, - TagsFilter, - TimeRangeFilter, - _xyz, -) -from dimos.memory.livechannel.subject import SubjectChannel -from dimos.memory.store import Session, Store, StoreConfig -from dimos.memory.type import _UNLOADED, Observation -from dimos.protocol.service.spec import Configurable - -if TYPE_CHECKING: - from collections.abc import Iterator - - from reactivex.abc import DisposableBase - - from dimos.memory.backend import Backend, LiveChannel - from dimos.memory.buffer import BackpressureBuffer - from dimos.memory.filter import Filter, StreamQuery - -T = TypeVar("T") - -_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -# ── Helpers ────────────────────────────────────────────────────── - - -def _validate_identifier(name: str) -> None: - if not _IDENT_RE.match(name): - raise ValueError(f"Invalid stream name: {name!r}") - - -def _decompose_pose(pose: Any) -> 
tuple[float, ...] | None: - if pose is None: - return None - if hasattr(pose, "position"): - pos = pose.position - orient = getattr(pose, "orientation", None) - x, y, z = float(pos.x), float(pos.y), float(getattr(pos, "z", 0.0)) - if orient is not None: - return (x, y, z, float(orient.x), float(orient.y), float(orient.z), float(orient.w)) - return (x, y, z, 0.0, 0.0, 0.0, 1.0) - if isinstance(pose, (list, tuple)): - vals = [float(v) for v in pose] - while len(vals) < 7: - vals.append(0.0 if len(vals) < 6 else 1.0) - return tuple(vals[:7]) - return None - - -def _reconstruct_pose( - x: float | None, - y: float | None, - z: float | None, - qx: float | None, - qy: float | None, - qz: float | None, - qw: float | None, -) -> tuple[float, ...] | None: - if x is None: - return None - return (x, y or 0.0, z or 0.0, qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0) - - -def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list[Any]] | None: - """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters. - - ``stream`` is the raw stream name (for R*Tree table references). - ``prefix`` is a column qualifier (e.g. ``"meta."`` for JOIN queries). - """ - if isinstance(f, AfterFilter): - return (f"{prefix}ts > ?", [f.t]) - if isinstance(f, BeforeFilter): - return (f"{prefix}ts < ?", [f.t]) - if isinstance(f, TimeRangeFilter): - return (f"{prefix}ts >= ? AND {prefix}ts <= ?", [f.t1, f.t2]) - if isinstance(f, AtFilter): - return (f"ABS({prefix}ts - ?) 
<= ?", [f.t, f.tolerance]) - if isinstance(f, TagsFilter): - clauses = [] - params: list[Any] = [] - for k, v in f.tags.items(): - clauses.append(f"json_extract({prefix}tags, '$.{k}') = ?") - params.append(v) - return (" AND ".join(clauses), params) - if isinstance(f, NearFilter): - pose = f.pose - if pose is None: - return None - if hasattr(pose, "position"): - pose = pose.position - cx, cy, cz = _xyz(pose) - r = f.radius - # R*Tree bounding-box pre-filter + exact squared-distance check - rtree_sql = ( - f'{prefix}id IN (SELECT id FROM "{stream}_rtree" ' - f"WHERE x_min >= ? AND x_max <= ? " - f"AND y_min >= ? AND y_max <= ? " - f"AND z_min >= ? AND z_max <= ?)" - ) - dist_sql = ( - f"(({prefix}pose_x - ?) * ({prefix}pose_x - ?) + " - f"({prefix}pose_y - ?) * ({prefix}pose_y - ?) + " - f"({prefix}pose_z - ?) * ({prefix}pose_z - ?) <= ?)" - ) - return ( - f"{rtree_sql} AND {dist_sql}", - [ - cx - r, - cx + r, - cy - r, - cy + r, - cz - r, - cz + r, # R*Tree bbox - cx, - cx, - cy, - cy, - cz, - cz, - r * r, # squared distance - ], - ) - # PredicateFilter — not pushable - return None - - -def _compile_query( - query: StreamQuery, - table: str, - *, - join_blob: bool = False, -) -> tuple[str, list[Any], list[Filter]]: - """Compile a StreamQuery to SQL. - - Returns (sql, params, python_filters) where python_filters must be - applied as post-filters in Python. - """ - prefix = "meta." 
if join_blob else "" - if join_blob: - select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' - else: - select = f'SELECT id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"' - - where_parts: list[str] = [] - params: list[Any] = [] - python_filters: list[Filter] = [] - - for f in query.filters: - compiled = _compile_filter(f, table, prefix) - if compiled is not None: - sql_part, sql_params = compiled - where_parts.append(sql_part) - params.extend(sql_params) - else: - python_filters.append(f) - - sql = select - if where_parts: - sql += " WHERE " + " AND ".join(where_parts) - - # ORDER BY - if query.order_field: - direction = "DESC" if query.order_desc else "ASC" - sql += f" ORDER BY {prefix}{query.order_field} {direction}" - else: - sql += f" ORDER BY {prefix}id ASC" - - # Only push LIMIT/OFFSET to SQL when there are no Python post-filters - if not python_filters and not query.search_text: - if query.limit_val is not None: - if query.offset_val: - sql += f" LIMIT {query.limit_val} OFFSET {query.offset_val}" - else: - sql += f" LIMIT {query.limit_val}" - elif query.offset_val: - sql += f" LIMIT -1 OFFSET {query.offset_val}" - - return (sql, params, python_filters) - - -def _compile_count( - query: StreamQuery, - table: str, -) -> tuple[str, list[Any], list[Filter]]: - """Compile a StreamQuery to a COUNT SQL query.""" - where_parts: list[str] = [] - params: list[Any] = [] - python_filters: list[Filter] = [] - - for f in query.filters: - compiled = _compile_filter(f, table) - if compiled is not None: - sql_part, sql_params = compiled - where_parts.append(sql_part) - params.extend(sql_params) - else: - python_filters.append(f) - - sql = f'SELECT COUNT(*) FROM "{table}"' - if where_parts: - sql += " WHERE " + " AND ".join(where_parts) - - return (sql, 
params, python_filters) - - -# ── SqliteBackend ──────────────────────────────────────────────── - - -class SqliteBackend(Configurable[BackendConfig], Generic[T]): - """SQLite-backed observation storage for a single stream (table).""" - - default_config: type[BackendConfig] = BackendConfig - - def __init__(self, conn: sqlite3.Connection, name: str, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._conn = conn - self._name = name - self._codec: Codec[Any] = self.config.codec # type: ignore[assignment] - self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel() - self._lock = threading.Lock() - self._tag_indexes: set[str] = set() - - @property - def name(self) -> str: - return self._name - - @property - def live_channel(self) -> LiveChannel[T]: - return self._channel - - @property - def _join_blobs(self) -> bool: - if not self.config.eager_blobs: - return False - bs = self.config.blob_store - return isinstance(bs, SqliteBlobStore) and bs._conn is self._conn - - def _make_loader(self, row_id: int) -> Any: - bs = self.config.blob_store - assert bs is not None - name, codec = self._name, self._codec - owner_tid = threading.get_ident() - - def loader() -> Any: - assert threading.get_ident() == owner_tid - raw = bs.get(name, row_id) - return codec.decode(raw) - - return loader - - def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observation[T]: - if has_blob: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row - else: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row - blob_data = None - - pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) - tags = json.loads(tags_json) if tags_json else {} - - if has_blob and blob_data is not None: - data = self._codec.decode(blob_data) - return Observation(id=row_id, ts=ts, pose=pose, tags=tags, _data=data) - - return Observation( - id=row_id, - ts=ts, - pose=pose, - tags=tags, - _data=_UNLOADED, - _loader=self._make_loader(row_id), # type: 
ignore[arg-type] - ) - - # ── Write ──────────────────────────────────────────────────── - - def _ensure_tag_indexes(self, tags: dict[str, Any]) -> None: - """Auto-create expression indexes for any new tag keys.""" - for key in tags: - if key not in self._tag_indexes and _IDENT_RE.match(key): - self._conn.execute( - f'CREATE INDEX IF NOT EXISTS "{self._name}_tag_{key}" ' - f"ON \"{self._name}\"(json_extract(tags, '$.{key}'))" - ) - self._tag_indexes.add(key) - - def append(self, obs: Observation[T]) -> Observation[T]: - encoded = self._codec.encode(obs._data) - pose = _decompose_pose(obs.pose) - tags_json = json.dumps(obs.tags) if obs.tags else "{}" - - with self._lock: - if obs.tags: - self._ensure_tag_indexes(obs.tags) - if pose: - px, py, pz, qx, qy, qz, qw = pose - else: - px = py = pz = qx = qy = qz = qw = None # type: ignore[assignment] - - cur = self._conn.execute( - f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", - (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json), - ) - row_id = cur.lastrowid - assert row_id is not None - - bs = self.config.blob_store - assert bs is not None - bs.put(self._name, row_id, encoded) - - # R*Tree spatial index - if pose: - self._conn.execute( - f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) ' - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (row_id, px, px, py, py, pz, pz), - ) - - vs = self.config.vector_store - if vs is not None: - emb = getattr(obs, "embedding", None) - if emb is not None: - vs.put(self._name, row_id, emb) - - self._conn.commit() - - obs.id = row_id - self._channel.notify(obs) - return obs - - # ── Read ───────────────────────────────────────────────────── - - def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: - if query.search_vec is not None and query.live_buffer is not None: - raise TypeError("Cannot combine .search() with .live() — search is a batch operation.") - buf = 
query.live_buffer - if buf is not None: - sub = self._channel.subscribe(buf) - return self._iterate_live(query, buf, sub) - return self._iterate_snapshot(query) - - def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: - if query.search_vec is not None and self.config.vector_store is not None: - yield from self._vector_search(query) - return - - join = self._join_blobs - sql, params, python_filters = _compile_query(query, self._name, join_blob=join) - - cur = self._conn.execute(sql, params) - cur.arraysize = self.config.page_size - it: Iterator[Observation[T]] = (self._row_to_obs(r, has_blob=join) for r in cur) - - # Text search — requires loading data - if query.search_text is not None: - needle = query.search_text.lower() - it = (obs for obs in it if needle in str(obs.data).lower()) - - # Apply Python post-filters - if python_filters: - it = (obs for obs in it if all(f.matches(obs) for f in python_filters)) - - # Apply LIMIT/OFFSET in Python when we couldn't push to SQL - if python_filters or query.search_text: - if query.offset_val: - it = islice(it, query.offset_val, None) - if query.limit_val is not None: - it = islice(it, query.limit_val) - - yield from it - - def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]: - vs = self.config.vector_store - assert vs is not None and query.search_vec is not None - - hits = vs.search(self._name, query.search_vec, query.search_k or 10) - if not hits: - return - - ids = [h[0] for h in hits] - dict(hits) - - # Batch-fetch metadata - join = self._join_blobs - placeholders = ",".join("?" 
* len(ids)) - if join: - sql = ( - f"SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, " - f"meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data " - f'FROM "{self._name}" AS meta ' - f'JOIN "{self._name}_blob" AS blob ON blob.id = meta.id ' - f"WHERE meta.id IN ({placeholders})" - ) - else: - sql = ( - f"SELECT id, ts, pose_x, pose_y, pose_z, " - f"pose_qx, pose_qy, pose_qz, pose_qw, json(tags) " - f'FROM "{self._name}" WHERE id IN ({placeholders})' - ) - - rows = self._conn.execute(sql, ids).fetchall() - obs_by_id: dict[int, Observation[T]] = {} - for r in rows: - obs = self._row_to_obs(r, has_blob=join) - obs_by_id[obs.id] = obs - - # Preserve VectorStore ranking order, promoting to EmbeddedObservation - ranked: list[Observation[T]] = [] - for obs_id, sim in hits: - match = obs_by_id.get(obs_id) - if match is not None: - ranked.append( - match.derive(data=match.data, embedding=query.search_vec, similarity=sim) - ) - - # Apply remaining query ops (skip vector search) - rest = replace(query, search_vec=None, search_k=None) - yield from rest.apply(iter(ranked)) - - def _iterate_live( - self, - query: StreamQuery, - buf: BackpressureBuffer[Observation[T]], - sub: DisposableBase, - ) -> Iterator[Observation[T]]: - from dimos.memory.buffer import ClosedError - - # Backfill phase - last_id = -1 - for obs in self._iterate_snapshot(query): - last_id = max(last_id, obs.id) - yield obs - - # Live tail - filters = query.filters - try: - while True: - obs = buf.take() - if obs.id <= last_id: - continue - last_id = obs.id - if filters and not all(f.matches(obs) for f in filters): - continue - yield obs - except (ClosedError, StopIteration): - sub.dispose() - - def count(self, query: StreamQuery) -> int: - if query.search_vec or query.search_text: - return sum(1 for _ in self.iterate(query)) - - sql, params, python_filters = _compile_count(query, self._name) - if python_filters: - return sum(1 for _ in self.iterate(query)) - - row = 
self._conn.execute(sql, params).fetchone() - return int(row[0]) if row else 0 - - -# ── SqliteSession ──────────────────────────────────────────────── - - -class SqliteSession(Session): - """Session owning a single SQLite connection.""" - - def __init__( - self, conn: sqlite3.Connection, *, vec_available: bool = False, **kwargs: Any - ) -> None: - super().__init__(**kwargs) - self._conn = conn - self._vec_available = vec_available - self._blob_store: SqliteBlobStore | None = None - self._vector_store: Any | None = None - - # Create stream registry - self._conn.execute( - "CREATE TABLE IF NOT EXISTS _streams (" - " name TEXT PRIMARY KEY," - " payload_module TEXT NOT NULL," - " codec_id TEXT NOT NULL" - ")" - ) - self._conn.commit() - - def _ensure_shared_stores(self) -> None: - """Lazily create shared stores on first stream creation.""" - if self._blob_store is None: - self._blob_store = SqliteBlobStore(self._conn) - if self._vector_store is None and self._vec_available: - from dimos.memory.vectorstore.sqlite import SqliteVectorStore - - self._vector_store = SqliteVectorStore(self._conn) - - @staticmethod - def _codec_id(codec: Codec[Any]) -> str: - from dimos.memory.codecs.jpeg import JpegCodec - from dimos.memory.codecs.lcm import LcmCodec - - if isinstance(codec, JpegCodec): - return "jpeg" - if isinstance(codec, LcmCodec): - return "lcm" - return "pickle" - - @staticmethod - def _codec_from_id(codec_id: str, payload_module: str) -> Codec[Any]: - from dimos.memory.codecs.pickle import PickleCodec - - if codec_id == "jpeg": - from dimos.memory.codecs.jpeg import JpegCodec - - return JpegCodec() - if codec_id == "lcm": - from dimos.memory.codecs.lcm import LcmCodec - - # Resolve the payload type from module path - parts = payload_module.rsplit(".", 1) - if len(parts) == 2: - import importlib - - mod = importlib.import_module(parts[0]) - cls = getattr(mod, parts[1]) - return LcmCodec(cls) - return PickleCodec() - return PickleCodec() - - def _create_backend( - self, 
name: str, payload_type: type[Any] | None = None, **config: Any - ) -> Backend[Any]: - _validate_identifier(name) - self._ensure_shared_stores() - - # Look up existing stream in registry - row = self._conn.execute( - "SELECT payload_module, codec_id FROM _streams WHERE name = ?", (name,) - ).fetchone() - - if row is not None: - stored_module, stored_codec_id = row - if payload_type is not None: - actual_module = f"{payload_type.__module__}.{payload_type.__qualname__}" - if actual_module != stored_module: - raise ValueError( - f"Stream {name!r} was created with type {stored_module}, " - f"but opened with {actual_module}" - ) - codec = config.get("codec") or self._codec_from_id(stored_codec_id, stored_module) - else: - if payload_type is None: - raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required") - codec = config.get("codec") or codec_for(payload_type) - payload_module = f"{payload_type.__module__}.{payload_type.__qualname__}" - self._conn.execute( - "INSERT INTO _streams (name, payload_module, codec_id) VALUES (?, ?, ?)", - (name, payload_module, self._codec_id(codec)), - ) - self._conn.commit() - - # Create metadata table - self._conn.execute( - f'CREATE TABLE IF NOT EXISTS "{name}" (' - " id INTEGER PRIMARY KEY AUTOINCREMENT," - " ts REAL NOT NULL UNIQUE," - " pose_x REAL, pose_y REAL, pose_z REAL," - " pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL," - " tags BLOB DEFAULT (jsonb('{}'))" - ")" - ) - # R*Tree spatial index for pose queries - self._conn.execute( - f'CREATE VIRTUAL TABLE IF NOT EXISTS "{name}_rtree" USING rtree(' - " id," - " x_min, x_max," - " y_min, y_max," - " z_min, z_max" - ")" - ) - self._conn.commit() - - # Merge shared stores as defaults - if "blob_store" not in config or config["blob_store"] is None: - config["blob_store"] = self._blob_store - if "vector_store" not in config or config["vector_store"] is None: - config["vector_store"] = self._vector_store - config["codec"] = codec - - return 
SqliteBackend(self._conn, name, **config) - - def list_streams(self) -> list[str]: - db_names = {row[0] for row in self._conn.execute("SELECT name FROM _streams").fetchall()} - return sorted(db_names | set(self._streams.keys())) - - def delete_stream(self, name: str) -> None: - self._streams.pop(name, None) - self._conn.execute(f'DROP TABLE IF EXISTS "{name}"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') - self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) - self._conn.commit() - - def stop(self) -> None: - super().stop() - self._conn.close() - - -# ── SqliteStore ────────────────────────────────────────────────── - - -@dataclass -class SqliteStoreConfig(StoreConfig): - """Config for SQLite-backed store.""" - - path: str = "memory.db" - - -class SqliteStore(Store): - """Store backed by a SQLite database file.""" - - default_config: type[SqliteStoreConfig] = SqliteStoreConfig - config: SqliteStoreConfig - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - - def session(self, **kwargs: Any) -> SqliteSession: - conn = sqlite3.connect(self.config.path, check_same_thread=False) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - - vec_available = False - try: - import sqlite_vec - - conn.enable_load_extension(True) - sqlite_vec.load(conn) - conn.enable_load_extension(False) - vec_available = True - except (ImportError, Exception): - pass - - return SqliteSession(conn, vec_available=vec_available, **kwargs) diff --git a/dimos/memory/livechannel/__init__.py b/dimos/memory/livechannel/__init__.py deleted file mode 100644 index 143c8e95bf..0000000000 --- a/dimos/memory/livechannel/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from dimos.memory.backend import LiveChannel -from dimos.memory.livechannel.subject import SubjectChannel - -__all__ = ["LiveChannel", 
"SubjectChannel"] diff --git a/dimos/memory/test_embedding.py b/dimos/memory/test_embedding.py index b05ae619a0..b7e7fbb294 100644 --- a/dimos/memory/test_embedding.py +++ b/dimos/memory/test_embedding.py @@ -12,444 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for embedding layer: EmbeddedObservation, vector search, text search, transformers.""" - -from __future__ import annotations - -import numpy as np import pytest -from dimos.memory.impl.memory import MemoryStore -from dimos.memory.type import EmbeddedObservation, Observation -from dimos.models.embedding.base import Embedding - -# ── Helpers ─────────────────────────────────────────────────────── - - -def _emb(vec: list[float]) -> Embedding: - """Return a unit-normalized Embedding.""" - v = np.array(vec, dtype=np.float32) - v /= np.linalg.norm(v) + 1e-10 - return Embedding(vector=v) - - -# ── EmbeddedObservation ────────────────────────────────────────── - - -class TestEmbeddedObservation: - def test_construction(self) -> None: - emb = _emb([1, 0, 0]) - obs = EmbeddedObservation(id=0, ts=1.0, _data="hello", embedding=emb) - assert obs.data == "hello" - assert obs.embedding is emb - assert obs.similarity is None - - def test_is_observation(self) -> None: - obs = EmbeddedObservation(id=0, ts=1.0, _data="x", embedding=_emb([1, 0])) - assert isinstance(obs, Observation) - - def test_derive_preserves_embedding(self) -> None: - emb = _emb([1, 0, 0]) - obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=emb) - derived = obs.derive(data="b") - assert isinstance(derived, EmbeddedObservation) - assert derived.embedding is emb - assert derived.data == "b" - - def test_derive_replaces_embedding(self) -> None: - old = _emb([1, 0, 0]) - new = _emb([0, 1, 0]) - obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=old) - derived = obs.derive(data="a", embedding=new) - assert derived.embedding is new - - def 
test_derive_preserves_similarity(self) -> None: - obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=_emb([1, 0]), similarity=0.95) - derived = obs.derive(data="b") - assert derived.similarity == 0.95 - - def test_observation_derive_promotes_to_embedded(self) -> None: - obs = Observation(id=0, ts=1.0, _data="plain") - emb = _emb([1, 0, 0]) - derived = obs.derive(data="plain", embedding=emb) - assert isinstance(derived, EmbeddedObservation) - assert derived.embedding is emb - - def test_observation_derive_without_embedding_stays_observation(self) -> None: - obs = Observation(id=0, ts=1.0, _data="plain") - derived = obs.derive(data="still plain") - assert type(derived) is Observation - - -# ── ListBackend embedding support ──────────────────────────────── - - -class TestListBackendEmbedding: - def test_append_with_embedding(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - emb = _emb([1, 0, 0]) - obs = s.append("hello", embedding=emb) - assert isinstance(obs, EmbeddedObservation) - assert obs.embedding is emb - - def test_append_without_embedding(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("plain", str) - obs = s.append("hello") - assert type(obs) is Observation - - def test_search_returns_top_k(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - s.append("north", embedding=_emb([0, 1, 0])) - s.append("east", embedding=_emb([1, 0, 0])) - s.append("south", embedding=_emb([0, -1, 0])) - s.append("west", embedding=_emb([-1, 0, 0])) - - results = s.search(_emb([0, 1, 0]), k=2).fetch() - assert len(results) == 2 - assert results[0].data == "north" - assert results[0].similarity is not None - assert results[0].similarity > 0.99 - - def test_search_sorted_by_similarity(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - s.append("far", 
embedding=_emb([0, -1, 0])) - s.append("close", embedding=_emb([0.9, 0.1, 0])) - s.append("exact", embedding=_emb([1, 0, 0])) - - results = s.search(_emb([1, 0, 0]), k=3).fetch() - assert results[0].data == "exact" - assert results[1].data == "close" - assert results[2].data == "far" - # Descending similarity - assert results[0].similarity >= results[1].similarity >= results[2].similarity - - def test_search_skips_non_embedded(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("mixed", str) - s.append("plain") # no embedding - s.append("embedded", embedding=_emb([1, 0, 0])) - - results = s.search(_emb([1, 0, 0]), k=10).fetch() - assert len(results) == 1 - assert results[0].data == "embedded" - - def test_search_with_filters(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) - s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) - - # Only the late one should pass the after filter - results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() - assert len(results) == 1 - assert results[0].data == "late" - - def test_search_with_limit(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - for i in range(10): - s.append(f"item{i}", embedding=_emb([1, 0, 0])) - - # search k=5 then limit 2 - results = s.search(_emb([1, 0, 0]), k=5).limit(2).fetch() - assert len(results) == 2 - - def test_search_with_live_raises(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - s.append("x", embedding=_emb([1, 0, 0])) - with pytest.raises(TypeError, match="Cannot combine"): - list(s.live().search(_emb([1, 0, 0]), k=5)) - - -# ── Text search ────────────────────────────────────────────────── - - -class TestTextSearch: - def test_search_text_substring(self) -> None: - store = MemoryStore() - with store.session() as 
session: - s = session.stream("logs", str) - s.append("motor fault detected") - s.append("temperature normal") - s.append("motor overheating") - - results = s.search_text("motor").fetch() - assert len(results) == 2 - assert {r.data for r in results} == {"motor fault detected", "motor overheating"} - - def test_search_text_case_insensitive(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("Motor Fault") - s.append("other event") - - results = s.search_text("motor fault").fetch() - assert len(results) == 1 - - def test_search_text_with_filters(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("motor fault", ts=10.0) - s.append("motor warning", ts=20.0) - s.append("motor fault", ts=30.0) - - results = s.after(15.0).search_text("fault").fetch() - assert len(results) == 1 - assert results[0].ts == 30.0 - - def test_search_text_no_match(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("all clear") - - results = s.search_text("motor").fetch() - assert len(results) == 0 - - -# ── Save preserves embeddings ──────────────────────────────────── - - -class TestSaveEmbeddings: - def test_save_preserves_embeddings(self) -> None: - store = MemoryStore() - with store.session() as session: - src = session.stream("source", str) - dst = session.stream("dest", str) - - emb = _emb([1, 0, 0]) - src.append("item", embedding=emb) - src.save(dst) - - results = dst.fetch() - assert len(results) == 1 - assert isinstance(results[0], EmbeddedObservation) - # Same vector content (different Embedding instance after re-append) - np.testing.assert_array_almost_equal(results[0].embedding.to_numpy(), emb.to_numpy()) - - def test_save_mixed_embedded_and_plain(self) -> None: - store = MemoryStore() - with store.session() as session: - src = session.stream("source", str) - dst = 
session.stream("dest", str) - - src.append("plain") - src.append("embedded", embedding=_emb([0, 1, 0])) - src.save(dst) - - results = dst.fetch() - assert len(results) == 2 - assert type(results[0]) is Observation - assert isinstance(results[1], EmbeddedObservation) - - -# ── Embed transformers (mock model) ───────────────────────────── - - -class _MockEmbeddingModel: - """Fake EmbeddingModel that returns deterministic unit vectors.""" - - device = "cpu" - - def embed(self, *images): - vecs = [] - for img in images: - rng = np.random.default_rng(hash(str(img)) % 2**32) - v = rng.standard_normal(8).astype(np.float32) - v /= np.linalg.norm(v) - vecs.append(Embedding(vector=v)) - return vecs if len(vecs) > 1 else vecs[0] - - def embed_text(self, *texts): - vecs = [] - for text in texts: - rng = np.random.default_rng(hash(text) % 2**32) - v = rng.standard_normal(8).astype(np.float32) - v /= np.linalg.norm(v) - vecs.append(Embedding(vector=v)) - return vecs if len(vecs) > 1 else vecs[0] - - -class TestEmbedTransformers: - def test_embed_images_produces_embedded_observations(self) -> None: - from dimos.memory.embed import EmbedImages - - model = _MockEmbeddingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("imgs", str) - s.append("img1", ts=1.0) - s.append("img2", ts=2.0) - - results = s.transform(EmbedImages(model)).fetch() - assert len(results) == 2 - for obs in results: - assert isinstance(obs, EmbeddedObservation) - assert isinstance(obs.embedding, Embedding) - assert obs.embedding.to_numpy().shape == (8,) - - def test_embed_text_produces_embedded_observations(self) -> None: - from dimos.memory.embed import EmbedText - - model = _MockEmbeddingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("motor fault", ts=1.0) - s.append("all clear", ts=2.0) - - results = s.transform(EmbedText(model)).fetch() - assert len(results) == 2 - for obs in results: - assert 
isinstance(obs, EmbeddedObservation) - assert isinstance(obs.embedding, Embedding) - - def test_embed_preserves_data(self) -> None: - from dimos.memory.embed import EmbedText - - model = _MockEmbeddingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("hello", ts=1.0) - - result = s.transform(EmbedText(model)).first() - assert result.data == "hello" - - def test_embed_then_search(self) -> None: - from dimos.memory.embed import EmbedText - - model = _MockEmbeddingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - for i in range(10): - s.append(f"log entry {i}", ts=float(i)) - - embedded = s.transform(EmbedText(model)) - # Get the embedding for the first item, then search for similar - first_emb = embedded.first().embedding - results = embedded.search(first_emb, k=3).fetch() - assert len(results) == 3 - # First result should be the exact match - assert results[0].similarity is not None - assert results[0].similarity > 0.99 - - def test_embed_batching(self) -> None: - from dimos.memory.embed import EmbedText - - call_sizes: list[int] = [] - - class _TrackingModel(_MockEmbeddingModel): - def embed_text(self, *texts): - call_sizes.append(len(texts)) - return super().embed_text(*texts) - - model = _TrackingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - for i in range(5): - s.append(f"entry {i}") - - list(s.transform(EmbedText(model, batch_size=2))) - # 5 items with batch_size=2 → 3 calls (2, 2, 1) - assert call_sizes == [2, 2, 1] - - -# ── Pluggable VectorStore ──────────────────────────────────────── - - -class TestPluggableVectorStore: - """Verify that injecting a VectorStore via session config actually delegates search.""" - - def test_append_stores_in_vector_store(self) -> None: - from dimos.memory.vectorstore import MemoryVectorStore - - vs = MemoryVectorStore() - store = MemoryStore() - with 
store.session(vector_store=vs) as session: - s = session.stream("vecs", str) - s.append("hello", embedding=_emb([1, 0, 0])) - s.append("world", embedding=_emb([0, 1, 0])) - - assert len(vs._vectors["vecs"]) == 2 - - def test_append_without_embedding_skips_vector_store(self) -> None: - from dimos.memory.vectorstore import MemoryVectorStore - - vs = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs) as session: - s = session.stream("plain", str) - s.append("no embedding") - - assert "plain" not in vs._vectors - - def test_search_uses_vector_store(self) -> None: - from dimos.memory.vectorstore import MemoryVectorStore - - vs = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs) as session: - s = session.stream("vecs", str) - s.append("north", embedding=_emb([0, 1, 0])) - s.append("east", embedding=_emb([1, 0, 0])) - s.append("south", embedding=_emb([0, -1, 0])) - s.append("west", embedding=_emb([-1, 0, 0])) - - results = s.search(_emb([0, 1, 0]), k=2).fetch() - assert len(results) == 2 - assert results[0].data == "north" - assert results[0].similarity is not None - assert results[0].similarity > 0.99 - - def test_search_with_filters_via_vector_store(self) -> None: - from dimos.memory.vectorstore import MemoryVectorStore - - vs = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs) as session: - s = session.stream("vecs", str) - s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) - s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) - - # Filter + search: only "late" passes the after filter - results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() - assert len(results) == 1 - assert results[0].data == "late" - - def test_per_stream_vector_store_override(self) -> None: - from dimos.memory.vectorstore import MemoryVectorStore - - vs_default = MemoryVectorStore() - vs_override = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs_default) as 
session: - # Stream with default vector store - s1 = session.stream("s1", str) - s1.append("a", embedding=_emb([1, 0, 0])) - - # Stream with overridden vector store - s2 = session.stream("s2", str, vector_store=vs_override) - s2.append("b", embedding=_emb([0, 1, 0])) - - assert "s1" in vs_default._vectors - assert "s1" not in vs_override._vectors - assert "s2" in vs_override._vectors - assert "s2" not in vs_default._vectors +from dimos.memory.embedding import EmbeddingMemory, SpatialEntry +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.utils.data import get_data +from dimos.utils.testing import TimedSensorReplay + +dir_name = "unitree_go2_bigoffice" + + +@pytest.mark.skip +def test_embed_frame() -> None: + """Test embedding a single frame.""" + # Load a frame from recorded data + video = TimedSensorReplay(get_data(dir_name) / "video") + frame = video.find_closest_seek(10) + + # Create memory and embed + memory = EmbeddingMemory() + + try: + # Create a spatial entry with dummy pose (no TF needed for this test) + dummy_pose = PoseStamped( + position=[0, 0, 0], + orientation=[0, 0, 0, 1], # identity quaternion + ) + spatial_entry = SpatialEntry(image=frame, pose=dummy_pose) + + # Embed the frame + result = memory._embed_spatial_entry(spatial_entry) + + # Verify + assert result is not None + assert result.embedding is not None + assert result.embedding.vector is not None + print(f"Embedding shape: {result.embedding.vector.shape}") + print(f"Embedding vector (first 5): {result.embedding.vector[:5]}") + finally: + memory.stop() diff --git a/dimos/memory_old/timeseries/__init__.py b/dimos/memory/timeseries/__init__.py similarity index 70% rename from dimos/memory_old/timeseries/__init__.py rename to dimos/memory/timeseries/__init__.py index 51130005b3..debc14ab3a 100644 --- a/dimos/memory_old/timeseries/__init__.py +++ b/dimos/memory/timeseries/__init__.py @@ -13,19 +13,19 @@ # limitations under the License. 
"""Time series storage and replay.""" -from dimos.memory_old.timeseries.base import TimeSeriesStore -from dimos.memory_old.timeseries.inmemory import InMemoryStore -from dimos.memory_old.timeseries.pickledir import PickleDirStore -from dimos.memory_old.timeseries.sqlite import SqliteTSStore +from dimos.memory.timeseries.base import TimeSeriesStore +from dimos.memory.timeseries.inmemory import InMemoryStore +from dimos.memory.timeseries.pickledir import PickleDirStore +from dimos.memory.timeseries.sqlite import SqliteStore def __getattr__(name: str): # type: ignore[no-untyped-def] if name == "PostgresStore": - from dimos.memory_old.timeseries.postgres import PostgresStore + from dimos.memory.timeseries.postgres import PostgresStore return PostgresStore if name == "reset_db": - from dimos.memory_old.timeseries.postgres import reset_db + from dimos.memory.timeseries.postgres import reset_db return reset_db raise AttributeError(f"module {__name__!r} has no attribute {name!r}") @@ -35,7 +35,7 @@ def __getattr__(name: str): # type: ignore[no-untyped-def] "InMemoryStore", "PickleDirStore", "PostgresStore", - "SqliteTSStore", + "SqliteStore", "TimeSeriesStore", "reset_db", ] diff --git a/dimos/memory_old/timeseries/base.py b/dimos/memory/timeseries/base.py similarity index 100% rename from dimos/memory_old/timeseries/base.py rename to dimos/memory/timeseries/base.py diff --git a/dimos/memory_old/timeseries/inmemory.py b/dimos/memory/timeseries/inmemory.py similarity index 98% rename from dimos/memory_old/timeseries/inmemory.py rename to dimos/memory/timeseries/inmemory.py index 608235c11d..b67faca644 100644 --- a/dimos/memory_old/timeseries/inmemory.py +++ b/dimos/memory/timeseries/inmemory.py @@ -17,7 +17,7 @@ from sortedcontainers import SortedKeyList # type: ignore[import-untyped] -from dimos.memory_old.timeseries.base import T, TimeSeriesStore +from dimos.memory.timeseries.base import T, TimeSeriesStore class InMemoryStore(TimeSeriesStore[T]): diff --git 
a/dimos/memory_old/timeseries/legacy.py b/dimos/memory/timeseries/legacy.py similarity index 99% rename from dimos/memory_old/timeseries/legacy.py rename to dimos/memory/timeseries/legacy.py index abc0adff3a..15a4ff90fa 100644 --- a/dimos/memory_old/timeseries/legacy.py +++ b/dimos/memory/timeseries/legacy.py @@ -30,7 +30,7 @@ from reactivex.observable import Observable from reactivex.scheduler import TimeoutScheduler -from dimos.memory_old.timeseries.base import T, TimeSeriesStore +from dimos.memory.timeseries.base import T, TimeSeriesStore from dimos.utils.data import get_data, get_data_dir diff --git a/dimos/memory_old/timeseries/pickledir.py b/dimos/memory/timeseries/pickledir.py similarity index 99% rename from dimos/memory_old/timeseries/pickledir.py rename to dimos/memory/timeseries/pickledir.py index 719c9f8a94..9e8cd5a249 100644 --- a/dimos/memory_old/timeseries/pickledir.py +++ b/dimos/memory/timeseries/pickledir.py @@ -20,7 +20,7 @@ from pathlib import Path import pickle -from dimos.memory_old.timeseries.base import T, TimeSeriesStore +from dimos.memory.timeseries.base import T, TimeSeriesStore from dimos.utils.data import get_data, get_data_dir diff --git a/dimos/memory_old/timeseries/postgres.py b/dimos/memory/timeseries/postgres.py similarity index 99% rename from dimos/memory_old/timeseries/postgres.py rename to dimos/memory/timeseries/postgres.py index c6774d3920..0daae44adb 100644 --- a/dimos/memory_old/timeseries/postgres.py +++ b/dimos/memory/timeseries/postgres.py @@ -21,7 +21,7 @@ import psycopg2.extensions # type: ignore[import-untyped] from dimos.core.resource import Resource -from dimos.memory_old.timeseries.base import T, TimeSeriesStore +from dimos.memory.timeseries.base import T, TimeSeriesStore # Valid SQL identifier: alphanumeric and underscores, not starting with digit _VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") diff --git a/dimos/memory_old/timeseries/sqlite.py b/dimos/memory/timeseries/sqlite.py similarity index 96% 
rename from dimos/memory_old/timeseries/sqlite.py rename to dimos/memory/timeseries/sqlite.py index a7d3fcbb35..6e2ac7a7f5 100644 --- a/dimos/memory_old/timeseries/sqlite.py +++ b/dimos/memory/timeseries/sqlite.py @@ -19,7 +19,7 @@ import re import sqlite3 -from dimos.memory_old.timeseries.base import T, TimeSeriesStore +from dimos.memory.timeseries.base import T, TimeSeriesStore from dimos.utils.data import get_data, get_data_dir # Valid SQL identifier: alphanumeric and underscores, not starting with digit @@ -37,24 +37,24 @@ def _validate_identifier(name: str) -> str: return name -class SqliteTSStore(TimeSeriesStore[T]): +class SqliteStore(TimeSeriesStore[T]): """SQLite backend for sensor data. Good for indexed queries and single-file storage. Data is stored as pickled BLOBs with timestamp as indexed column. Usage: # Named store (uses data/ directory, auto-downloads from LFS if needed) - store = SqliteTSStore("recordings/lidar") # -> data/recordings/lidar.db + store = SqliteStore("recordings/lidar") # -> data/recordings/lidar.db store.save(data) # saves using data.ts # Absolute path - store = SqliteTSStore("/path/to/sensors.db") + store = SqliteStore("/path/to/sensors.db") # In-memory (for testing) - store = SqliteTSStore(":memory:") + store = SqliteStore(":memory:") # Multiple tables in one DB - store = SqliteTSStore("recordings/sensors", table="lidar") + store = SqliteStore("recordings/sensors", table="lidar") """ def __init__(self, name: str | Path, table: str = "sensor_data") -> None: diff --git a/dimos/memory_old/timeseries/test_base.py b/dimos/memory/timeseries/test_base.py similarity index 96% rename from dimos/memory_old/timeseries/test_base.py rename to dimos/memory/timeseries/test_base.py index 491f0ed534..9491d2c93c 100644 --- a/dimos/memory_old/timeseries/test_base.py +++ b/dimos/memory/timeseries/test_base.py @@ -20,11 +20,11 @@ import pytest -from dimos.memory_old.timeseries.base import TimeSeriesStore -from dimos.memory_old.timeseries.inmemory 
import InMemoryStore -from dimos.memory_old.timeseries.legacy import LegacyPickleStore -from dimos.memory_old.timeseries.pickledir import PickleDirStore -from dimos.memory_old.timeseries.sqlite import SqliteTSStore +from dimos.memory.timeseries.base import TimeSeriesStore +from dimos.memory.timeseries.inmemory import InMemoryStore +from dimos.memory.timeseries.legacy import LegacyPickleStore +from dimos.memory.timeseries.pickledir import PickleDirStore +from dimos.memory.timeseries.sqlite import SqliteStore from dimos.types.timestamped import Timestamped @@ -60,7 +60,7 @@ def make_pickle_dir_store(tmpdir: str) -> TimeSeriesStore[SampleData]: def make_sqlite_store(tmpdir: str) -> TimeSeriesStore[SampleData]: - return SqliteTSStore[SampleData](Path(tmpdir) / "test.db") + return SqliteStore[SampleData](Path(tmpdir) / "test.db") def make_legacy_pickle_store(tmpdir: str) -> TimeSeriesStore[SampleData]: @@ -71,7 +71,7 @@ def make_legacy_pickle_store(tmpdir: str) -> TimeSeriesStore[SampleData]: testdata: list[tuple[object, str]] = [ (lambda _: make_in_memory_store(), "InMemoryStore"), (lambda tmpdir: make_pickle_dir_store(tmpdir), "PickleDirStore"), - (lambda tmpdir: make_sqlite_store(tmpdir), "SqliteTSStore"), + (lambda tmpdir: make_sqlite_store(tmpdir), "SqliteStore"), (lambda tmpdir: make_legacy_pickle_store(tmpdir), "LegacyPickleStore"), ] @@ -81,7 +81,7 @@ def make_legacy_pickle_store(tmpdir: str) -> TimeSeriesStore[SampleData]: try: import psycopg2 - from dimos.memory_old.timeseries.postgres import PostgresStore + from dimos.memory.timeseries.postgres import PostgresStore # Test connection _test_conn = psycopg2.connect(dbname="dimensional") diff --git a/dimos/memory_old/timeseries/test_legacy.py b/dimos/memory/timeseries/test_legacy.py similarity index 96% rename from dimos/memory_old/timeseries/test_legacy.py rename to dimos/memory/timeseries/test_legacy.py index 145af0d1f4..c77ec64a76 100644 --- a/dimos/memory_old/timeseries/test_legacy.py +++ 
b/dimos/memory/timeseries/test_legacy.py @@ -15,7 +15,7 @@ import pytest -from dimos.memory_old.timeseries.legacy import LegacyPickleStore +from dimos.memory.timeseries.legacy import LegacyPickleStore class TestLegacyPickleStoreRealData: diff --git a/dimos/memory/__init__.py b/dimos/memory2/__init__.py similarity index 59% rename from dimos/memory/__init__.py rename to dimos/memory2/__init__.py index b98418e415..0b358fe438 100644 --- a/dimos/memory/__init__.py +++ b/dimos/memory2/__init__.py @@ -1,5 +1,5 @@ -from dimos.memory.backend import Backend, LiveChannel, VectorStore -from dimos.memory.buffer import ( +from dimos.memory2.backend import Backend, LiveChannel, VectorStore +from dimos.memory2.buffer import ( BackpressureBuffer, Bounded, ClosedError, @@ -7,8 +7,8 @@ KeepLast, Unbounded, ) -from dimos.memory.embed import EmbedImages, EmbedText -from dimos.memory.filter import ( +from dimos.memory2.embed import EmbedImages, EmbedText +from dimos.memory2.filter import ( AfterFilter, AtFilter, BeforeFilter, @@ -19,13 +19,13 @@ TagsFilter, TimeRangeFilter, ) -from dimos.memory.impl.memory import ListBackend, MemorySession, MemoryStore -from dimos.memory.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore, SqliteStoreConfig -from dimos.memory.livechannel import SubjectChannel -from dimos.memory.store import Session, SessionConfig, Store, StoreConfig, StreamNamespace -from dimos.memory.stream import Stream -from dimos.memory.transform import FnTransformer, QualityWindow, Transformer -from dimos.memory.type import EmbeddedObservation, Observation +from dimos.memory2.impl.memory import ListBackend, MemorySession, MemoryStore +from dimos.memory2.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore, SqliteStoreConfig +from dimos.memory2.livechannel import SubjectChannel +from dimos.memory2.store import Session, SessionConfig, Store, StoreConfig, StreamNamespace +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer, 
QualityWindow, Transformer +from dimos.memory2.type import EmbeddedObservation, Observation __all__ = [ "AfterFilter", diff --git a/dimos/memory/architecture.md b/dimos/memory2/architecture.md similarity index 99% rename from dimos/memory/architecture.md rename to dimos/memory2/architecture.md index 7fba703f4c..4acaf4a7f9 100644 --- a/dimos/memory/architecture.md +++ b/dimos/memory2/architecture.md @@ -74,7 +74,7 @@ Transform-sourced streams (post `.transform()`) always use `StreamQuery.apply()` ## Quick start ```python -from dimos.memory import MemoryStore +from dimos.memory2 import MemoryStore store = MemoryStore() with store.session() as session: diff --git a/dimos/memory/backend.py b/dimos/memory2/backend.py similarity index 96% rename from dimos/memory/backend.py rename to dimos/memory2/backend.py index cc36f79239..928b74e229 100644 --- a/dimos/memory/backend.py +++ b/dimos/memory2/backend.py @@ -25,10 +25,10 @@ from reactivex.abc import DisposableBase - from dimos.memory.buffer import BackpressureBuffer - from dimos.memory.codecs.base import Codec - from dimos.memory.filter import StreamQuery - from dimos.memory.type import Observation + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.codecs.base import Codec + from dimos.memory2.filter import StreamQuery + from dimos.memory2.type import Observation from dimos.models.embedding.base import Embedding T = TypeVar("T") diff --git a/dimos/memory/blobstore/__init__.py b/dimos/memory2/blobstore/__init__.py similarity index 80% rename from dimos/memory/blobstore/__init__.py rename to dimos/memory2/blobstore/__init__.py index f0b3fe76f5..8f78d7c439 100644 --- a/dimos/memory/blobstore/__init__.py +++ b/dimos/memory2/blobstore/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dimos.memory.backend import BlobStore -from dimos.memory.blobstore.file import FileBlobStore -from dimos.memory.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.backend import BlobStore +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore __all__ = ["BlobStore", "FileBlobStore", "SqliteBlobStore"] diff --git a/dimos/memory/blobstore/blobstore.md b/dimos/memory2/blobstore/blobstore.md similarity index 100% rename from dimos/memory/blobstore/blobstore.md rename to dimos/memory2/blobstore/blobstore.md diff --git a/dimos/memory/blobstore/file.py b/dimos/memory2/blobstore/file.py similarity index 97% rename from dimos/memory/blobstore/file.py rename to dimos/memory2/blobstore/file.py index de8c6e8bc2..54ec80e284 100644 --- a/dimos/memory/blobstore/file.py +++ b/dimos/memory2/blobstore/file.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from dimos.memory.backend import BlobStore +from dimos.memory2.backend import BlobStore if TYPE_CHECKING: import os diff --git a/dimos/memory/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py similarity index 98% rename from dimos/memory/blobstore/sqlite.py rename to dimos/memory2/blobstore/sqlite.py index 235152d796..0fd144c532 100644 --- a/dimos/memory/blobstore/sqlite.py +++ b/dimos/memory2/blobstore/sqlite.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING -from dimos.memory.backend import BlobStore +from dimos.memory2.backend import BlobStore if TYPE_CHECKING: import sqlite3 diff --git a/dimos/memory/blobstore/test_blobstore.py b/dimos/memory2/blobstore/test_blobstore.py similarity index 95% rename from dimos/memory/blobstore/test_blobstore.py rename to dimos/memory2/blobstore/test_blobstore.py index 83f76fa2ec..fe05cfa84f 100644 --- a/dimos/memory/blobstore/test_blobstore.py +++ b/dimos/memory2/blobstore/test_blobstore.py @@ -22,14 +22,14 @@ import pytest -from dimos.memory.blobstore.file import FileBlobStore 
-from dimos.memory.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore if TYPE_CHECKING: from collections.abc import Callable, Generator from pathlib import Path - from dimos.memory.backend import BlobStore + from dimos.memory2.backend import BlobStore # ── Case definition ──────────────────────────────────────────────── diff --git a/dimos/memory/buffer.py b/dimos/memory2/buffer.py similarity index 100% rename from dimos/memory/buffer.py rename to dimos/memory2/buffer.py diff --git a/dimos/memory/codecs/README.md b/dimos/memory2/codecs/README.md similarity index 97% rename from dimos/memory/codecs/README.md rename to dimos/memory2/codecs/README.md index 719369f29a..8ad40e95fd 100644 --- a/dimos/memory/codecs/README.md +++ b/dimos/memory2/codecs/README.md @@ -23,7 +23,7 @@ class Codec(Protocol[T]): `codec_for(payload_type)` picks the right codec: ```python -from dimos.memory.codecs import codec_for +from dimos.memory2.codecs import codec_for codec_for(Image) # → JpegCodec(quality=50) codec_for(SomeLcmMsg) # → LcmCodec(SomeLcmMsg) (if has lcm_encode/lcm_decode) diff --git a/dimos/memory/codecs/__init__.py b/dimos/memory2/codecs/__init__.py similarity index 85% rename from dimos/memory/codecs/__init__.py rename to dimos/memory2/codecs/__init__.py index fe4b870250..a7feb3bce3 100644 --- a/dimos/memory/codecs/__init__.py +++ b/dimos/memory2/codecs/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dimos.memory.codecs.base import Codec, codec_for -from dimos.memory.codecs.pickle import PickleCodec +from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory2.codecs.pickle import PickleCodec __all__ = ["Codec", "PickleCodec", "codec_for"] diff --git a/dimos/memory/codecs/base.py b/dimos/memory2/codecs/base.py similarity index 88% rename from dimos/memory/codecs/base.py rename to dimos/memory2/codecs/base.py index 12ea658906..4c2b3865f5 100644 --- a/dimos/memory/codecs/base.py +++ b/dimos/memory2/codecs/base.py @@ -28,17 +28,17 @@ def decode(self, data: bytes) -> T: ... def codec_for(payload_type: type[Any] | None = None) -> Codec[Any]: """Auto-select codec based on payload type.""" - from dimos.memory.codecs.pickle import PickleCodec + from dimos.memory2.codecs.pickle import PickleCodec if payload_type is not None: from dimos.msgs.sensor_msgs.Image import Image if issubclass(payload_type, Image): - from dimos.memory.codecs.jpeg import JpegCodec + from dimos.memory2.codecs.jpeg import JpegCodec return JpegCodec() if hasattr(payload_type, "lcm_encode") and hasattr(payload_type, "lcm_decode"): - from dimos.memory.codecs.lcm import LcmCodec + from dimos.memory2.codecs.lcm import LcmCodec return LcmCodec(payload_type) return PickleCodec() diff --git a/dimos/memory/codecs/jpeg.py b/dimos/memory2/codecs/jpeg.py similarity index 100% rename from dimos/memory/codecs/jpeg.py rename to dimos/memory2/codecs/jpeg.py diff --git a/dimos/memory/codecs/lcm.py b/dimos/memory2/codecs/lcm.py similarity index 100% rename from dimos/memory/codecs/lcm.py rename to dimos/memory2/codecs/lcm.py diff --git a/dimos/memory/codecs/pickle.py b/dimos/memory2/codecs/pickle.py similarity index 100% rename from dimos/memory/codecs/pickle.py rename to dimos/memory2/codecs/pickle.py diff --git a/dimos/memory/codecs/test_codecs.py b/dimos/memory2/codecs/test_codecs.py similarity index 91% rename from dimos/memory/codecs/test_codecs.py rename to 
dimos/memory2/codecs/test_codecs.py index 7d5057d589..8f3eb17c10 100644 --- a/dimos/memory/codecs/test_codecs.py +++ b/dimos/memory2/codecs/test_codecs.py @@ -24,7 +24,7 @@ import pytest -from dimos.memory.codecs.base import Codec, codec_for +from dimos.memory2.codecs.base import Codec, codec_for if TYPE_CHECKING: from collections.abc import Callable @@ -44,7 +44,7 @@ class Case: def _pickle_case() -> Case: - from dimos.memory.codecs.pickle import PickleCodec + from dimos.memory2.codecs.pickle import PickleCodec return Case( name="pickle", @@ -54,7 +54,7 @@ def _pickle_case() -> Case: def _lcm_case() -> Case: - from dimos.memory.codecs.lcm import LcmCodec + from dimos.memory2.codecs.lcm import LcmCodec from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -86,7 +86,7 @@ def _jpeg_eq(original: Any, decoded: Any) -> bool: def _jpeg_case() -> Case: - from dimos.memory.codecs.jpeg import JpegCodec + from dimos.memory2.codecs.jpeg import JpegCodec from dimos.utils.testing import TimedSensorReplay replay = TimedSensorReplay("unitree_go2_bigoffice/video") @@ -132,23 +132,23 @@ class TestCodecFor: """codec_for() auto-selects the right codec.""" def test_none_returns_pickle(self) -> None: - from dimos.memory.codecs.pickle import PickleCodec + from dimos.memory2.codecs.pickle import PickleCodec assert isinstance(codec_for(None), PickleCodec) def test_unknown_type_returns_pickle(self) -> None: - from dimos.memory.codecs.pickle import PickleCodec + from dimos.memory2.codecs.pickle import PickleCodec assert isinstance(codec_for(dict), PickleCodec) def test_lcm_type_returns_lcm(self) -> None: - from dimos.memory.codecs.lcm import LcmCodec + from dimos.memory2.codecs.lcm import LcmCodec from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped assert isinstance(codec_for(PoseStamped), LcmCodec) def test_image_type_returns_jpeg(self) -> None: - from 
dimos.memory.codecs.jpeg import JpegCodec + from dimos.memory2.codecs.jpeg import JpegCodec from dimos.msgs.sensor_msgs.Image import Image assert isinstance(codec_for(Image), JpegCodec) diff --git a/dimos/memory/embed.py b/dimos/memory2/embed.py similarity index 96% rename from dimos/memory/embed.py rename to dimos/memory2/embed.py index 04e68dd540..981bd83b73 100644 --- a/dimos/memory/embed.py +++ b/dimos/memory2/embed.py @@ -17,12 +17,12 @@ from itertools import islice from typing import TYPE_CHECKING, Any, TypeVar -from dimos.memory.transform import Transformer +from dimos.memory2.transform import Transformer if TYPE_CHECKING: from collections.abc import Iterator - from dimos.memory.type import Observation + from dimos.memory2.type import Observation from dimos.models.embedding.base import EmbeddingModel T = TypeVar("T") diff --git a/dimos/memory/embeddings.md b/dimos/memory2/embeddings.md similarity index 100% rename from dimos/memory/embeddings.md rename to dimos/memory2/embeddings.md diff --git a/dimos/memory/filter.py b/dimos/memory2/filter.py similarity index 98% rename from dimos/memory/filter.py rename to dimos/memory2/filter.py index 2d9f98c1d4..8c80546ef4 100644 --- a/dimos/memory/filter.py +++ b/dimos/memory2/filter.py @@ -22,8 +22,8 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator - from dimos.memory.buffer import BackpressureBuffer - from dimos.memory.type import Observation + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type import Observation from dimos.models.embedding.base import Embedding diff --git a/dimos/memory/formatting.py b/dimos/memory2/formatting.py similarity index 100% rename from dimos/memory/formatting.py rename to dimos/memory2/formatting.py diff --git a/dimos/memory/impl/README.md b/dimos/memory2/impl/README.md similarity index 95% rename from dimos/memory/impl/README.md rename to dimos/memory2/impl/README.md index f2475c7eec..c10c5b235c 100644 --- a/dimos/memory/impl/README.md +++ 
b/dimos/memory2/impl/README.md @@ -14,10 +14,10 @@ Storage backends for memory. Each backend implements the `Backend` protocol to p ### 1. Implement the Backend protocol ```python -from dimos.memory.backend import Backend, BackendConfig, LiveChannel -from dimos.memory.filter import StreamQuery -from dimos.memory.livechannel.subject import SubjectChannel -from dimos.memory.type import Observation +from dimos.memory2.backend import Backend, BackendConfig, LiveChannel +from dimos.memory2.filter import StreamQuery +from dimos.memory2.livechannel.subject import SubjectChannel +from dimos.memory2.type import Observation from dimos.protocol.service.spec import Configurable class MyBackend(Configurable[BackendConfig], Generic[T]): @@ -78,7 +78,7 @@ See `ListBackend._iterate_live()` for the reference implementation. ### 3. Add Store and Session ```python -from dimos.memory.store import Session, Store +from dimos.memory2.store import Session, Store class MySession(Session): def _create_backend( diff --git a/dimos/memory/impl/__init__.py b/dimos/memory2/impl/__init__.py similarity index 100% rename from dimos/memory/impl/__init__.py rename to dimos/memory2/impl/__init__.py diff --git a/dimos/memory/impl/memory.py b/dimos/memory2/impl/memory.py similarity index 93% rename from dimos/memory/impl/memory.py rename to dimos/memory2/impl/memory.py index 0956fa83e3..0525326fcc 100644 --- a/dimos/memory/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -18,11 +18,11 @@ import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory.backend import BackendConfig -from dimos.memory.codecs.base import Codec, codec_for -from dimos.memory.livechannel.subject import SubjectChannel -from dimos.memory.store import Session, Store -from dimos.memory.type import _UNLOADED +from dimos.memory2.backend import BackendConfig +from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory2.livechannel.subject import SubjectChannel +from dimos.memory2.store 
import Session, Store +from dimos.memory2.type import _UNLOADED from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: @@ -30,10 +30,10 @@ from reactivex.abc import DisposableBase - from dimos.memory.backend import Backend, LiveChannel - from dimos.memory.buffer import BackpressureBuffer - from dimos.memory.filter import StreamQuery - from dimos.memory.type import Observation + from dimos.memory2.backend import Backend, LiveChannel + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.filter import StreamQuery + from dimos.memory2.type import Observation T = TypeVar("T") @@ -154,7 +154,7 @@ def _iterate_live( buf: BackpressureBuffer[Observation[T]], sub: DisposableBase, ) -> Iterator[Observation[T]]: - from dimos.memory.buffer import ClosedError + from dimos.memory2.buffer import ClosedError eager = self.config.eager_blobs and self.config.blob_store is not None diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py new file mode 100644 index 0000000000..03857d9bd8 --- /dev/null +++ b/dimos/memory2/impl/sqlite.py @@ -0,0 +1,674 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from dataclasses import dataclass, replace +from itertools import islice +import json +import re +import sqlite3 +import threading +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.memory2.backend import BackendConfig +from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory2.filter import ( + AfterFilter, + AtFilter, + BeforeFilter, + NearFilter, + TagsFilter, + TimeRangeFilter, + _xyz, +) +from dimos.memory2.livechannel.subject import SubjectChannel +from dimos.memory2.store import Session, Store, StoreConfig +from dimos.memory2.type import _UNLOADED, Observation +from dimos.protocol.service.spec import Configurable + +if TYPE_CHECKING: + from collections.abc import Iterator + + from reactivex.abc import DisposableBase + + from dimos.memory2.backend import Backend, LiveChannel + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.filter import Filter, StreamQuery + +T = TypeVar("T") + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +# ── Helpers ────────────────────────────────────────────────────── + + +def _validate_identifier(name: str) -> None: + if not _IDENT_RE.match(name): + raise ValueError(f"Invalid stream name: {name!r}") + + +def _decompose_pose(pose: Any) -> tuple[float, ...] 
| None: + if pose is None: + return None + if hasattr(pose, "position"): + pos = pose.position + orient = getattr(pose, "orientation", None) + x, y, z = float(pos.x), float(pos.y), float(getattr(pos, "z", 0.0)) + if orient is not None: + return (x, y, z, float(orient.x), float(orient.y), float(orient.z), float(orient.w)) + return (x, y, z, 0.0, 0.0, 0.0, 1.0) + if isinstance(pose, (list, tuple)): + vals = [float(v) for v in pose] + while len(vals) < 7: + vals.append(0.0 if len(vals) < 6 else 1.0) + return tuple(vals[:7]) + return None + + +def _reconstruct_pose( + x: float | None, + y: float | None, + z: float | None, + qx: float | None, + qy: float | None, + qz: float | None, + qw: float | None, +) -> tuple[float, ...] | None: + if x is None: + return None + return (x, y or 0.0, z or 0.0, qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0) + + +def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list[Any]] | None: + """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters. + + ``stream`` is the raw stream name (for R*Tree table references). + ``prefix`` is a column qualifier (e.g. ``"meta."`` for JOIN queries). + """ + if isinstance(f, AfterFilter): + return (f"{prefix}ts > ?", [f.t]) + if isinstance(f, BeforeFilter): + return (f"{prefix}ts < ?", [f.t]) + if isinstance(f, TimeRangeFilter): + return (f"{prefix}ts >= ? AND {prefix}ts <= ?", [f.t1, f.t2]) + if isinstance(f, AtFilter): + return (f"ABS({prefix}ts - ?) 
def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list[Any]] | None:
    """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters.

    ``stream`` is the raw stream name (for R*Tree table references).
    ``prefix`` is a column qualifier (e.g. ``"meta."`` for JOIN queries).

    A ``None`` return means the caller must evaluate the filter in
    Python as a post-filter (via ``Filter.matches``).
    """
    if isinstance(f, AfterFilter):
        return (f"{prefix}ts > ?", [f.t])
    if isinstance(f, BeforeFilter):
        return (f"{prefix}ts < ?", [f.t])
    if isinstance(f, TimeRangeFilter):
        return (f"{prefix}ts >= ? AND {prefix}ts <= ?", [f.t1, f.t2])
    if isinstance(f, AtFilter):
        return (f"ABS({prefix}ts - ?) <= ?", [f.t, f.tolerance])
    if isinstance(f, TagsFilter):
        clauses = []
        params: list[Any] = []
        for k, v in f.tags.items():
            # The tag key is spliced into the SQL text as a JSON path
            # literal. Only identifier-shaped keys are safe to inline —
            # and only the literal '$.{k}' form matches the expression
            # indexes auto-created by _ensure_tag_indexes. Anything else
            # (quotes, dots, injection attempts) falls back to Python
            # post-filtering instead of corrupting the statement.
            if not _IDENT_RE.match(k):
                return None
            clauses.append(f"json_extract({prefix}tags, '$.{k}') = ?")
            params.append(v)
        if not clauses:
            # Empty tag filter: no SQL fragment to emit (an empty string
            # would produce "WHERE  AND ..."); let Python handle it.
            return None
        return (" AND ".join(clauses), params)
    if isinstance(f, NearFilter):
        pose = f.pose
        if pose is None:
            return None
        if hasattr(pose, "position"):
            pose = pose.position
        cx, cy, cz = _xyz(pose)
        r = f.radius
        # R*Tree bounding-box pre-filter + exact squared-distance check
        rtree_sql = (
            f'{prefix}id IN (SELECT id FROM "{stream}_rtree" '
            f"WHERE x_min >= ? AND x_max <= ? "
            f"AND y_min >= ? AND y_max <= ? "
            f"AND z_min >= ? AND z_max <= ?)"
        )
        dist_sql = (
            f"(({prefix}pose_x - ?) * ({prefix}pose_x - ?) + "
            f"({prefix}pose_y - ?) * ({prefix}pose_y - ?) + "
            f"({prefix}pose_z - ?) * ({prefix}pose_z - ?) <= ?)"
        )
        return (
            f"{rtree_sql} AND {dist_sql}",
            [
                cx - r,
                cx + r,
                cy - r,
                cy + r,
                cz - r,
                cz + r,  # R*Tree bbox
                cx,
                cx,
                cy,
                cy,
                cz,
                cz,
                r * r,  # squared distance
            ],
        )
    # PredicateFilter — not pushable
    return None
def _compile_query(
    query: StreamQuery,
    table: str,
    *,
    join_blob: bool = False,
) -> tuple[str, list[Any], list[Filter]]:
    """Compile a StreamQuery to SQL.

    Returns (sql, params, python_filters) where python_filters must be
    applied as post-filters in Python.

    With ``join_blob=True`` the payload blob is fetched in the same
    statement via a JOIN on the per-stream ``<table>_blob`` table, and
    all metadata columns are qualified with ``meta.``.
    """
    prefix = "meta." if join_blob else ""
    if join_blob:
        select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id'
    else:
        select = f'SELECT id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"'

    where_parts: list[str] = []
    params: list[Any] = []
    python_filters: list[Filter] = []

    # Push each filter down to SQL when possible; everything else is
    # returned to the caller for Python-side evaluation.
    for f in query.filters:
        compiled = _compile_filter(f, table, prefix)
        if compiled is not None:
            sql_part, sql_params = compiled
            where_parts.append(sql_part)
            params.extend(sql_params)
        else:
            python_filters.append(f)

    sql = select
    if where_parts:
        sql += " WHERE " + " AND ".join(where_parts)

    # ORDER BY
    # NOTE(review): order_field is interpolated directly into the SQL —
    # assumes it is an internal column name ("ts"/"id"), never
    # user-supplied; confirm at the StreamQuery construction site.
    if query.order_field:
        direction = "DESC" if query.order_desc else "ASC"
        sql += f" ORDER BY {prefix}{query.order_field} {direction}"
    else:
        sql += f" ORDER BY {prefix}id ASC"

    # Only push LIMIT/OFFSET to SQL when there are no Python post-filters
    # (or text search): those drop rows after SQL, so a SQL-side LIMIT
    # would undercount. The caller re-applies limit/offset in Python.
    if not python_filters and not query.search_text:
        if query.limit_val is not None:
            if query.offset_val:
                sql += f" LIMIT {query.limit_val} OFFSET {query.offset_val}"
            else:
                sql += f" LIMIT {query.limit_val}"
        elif query.offset_val:
            # SQLite requires a LIMIT clause to use OFFSET; -1 = no limit.
            sql += f" LIMIT -1 OFFSET {query.offset_val}"

    return (sql, params, python_filters)


def _compile_count(
    query: StreamQuery,
    table: str,
) -> tuple[str, list[Any], list[Filter]]:
    """Compile a StreamQuery to a COUNT SQL query.

    Same pushdown split as _compile_query, but no ORDER BY / LIMIT /
    OFFSET: if any python_filters remain, the caller cannot use this
    COUNT and must iterate instead.
    """
    where_parts: list[str] = []
    params: list[Any] = []
    python_filters: list[Filter] = []

    for f in query.filters:
        compiled = _compile_filter(f, table)
        if compiled is not None:
            sql_part, sql_params = compiled
            where_parts.append(sql_part)
            params.extend(sql_params)
        else:
            python_filters.append(f)

    sql = f'SELECT COUNT(*) FROM "{table}"'
    if where_parts:
        sql += " WHERE " + " AND ".join(where_parts)

    return (sql, params, python_filters)
class SqliteBackend(Configurable[BackendConfig], Generic[T]):
    """SQLite-backed observation storage for a single stream (table).

    Metadata (timestamp, pose columns, JSON tags) lives in the stream's
    own table; payload bytes go through ``config.blob_store``; an
    optional R*Tree table (``<name>_rtree``) indexes poses spatially,
    and an optional vector store indexes embeddings. Live subscribers
    are notified through ``config.live_channel`` (SubjectChannel by
    default). Writes are serialized with an internal lock; reads assume
    the caller stays on the connection's thread.
    """

    default_config: type[BackendConfig] = BackendConfig

    def __init__(self, conn: sqlite3.Connection, name: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._conn = conn
        self._name = name
        self._codec: Codec[Any] = self.config.codec  # type: ignore[assignment]
        self._channel: LiveChannel[T] = self.config.live_channel or SubjectChannel()
        self._lock = threading.Lock()
        # Tag keys for which an expression index has already been created.
        self._tag_indexes: set[str] = set()

    @property
    def name(self) -> str:
        """Stream name — doubles as the SQL table name."""
        return self._name

    @property
    def live_channel(self) -> LiveChannel[T]:
        """Channel that receives every appended observation."""
        return self._channel

    @property
    def _join_blobs(self) -> bool:
        """True when blobs can be fetched in the same SQL statement.

        Only possible when eager loading is on and the blob store is a
        SqliteBlobStore sharing this backend's connection.
        """
        if not self.config.eager_blobs:
            return False
        bs = self.config.blob_store
        return isinstance(bs, SqliteBlobStore) and bs._conn is self._conn

    def _make_loader(self, row_id: int) -> Any:
        """Build a lazy payload loader for a row.

        The loader asserts it runs on the thread that created it,
        because the underlying sqlite3 connection is not shared across
        threads.
        """
        bs = self.config.blob_store
        assert bs is not None
        name, codec = self._name, self._codec
        owner_tid = threading.get_ident()

        def loader() -> Any:
            assert threading.get_ident() == owner_tid
            raw = bs.get(name, row_id)
            return codec.decode(raw)

        return loader

    def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observation[T]:
        """Convert a SQL result row into an Observation.

        With ``has_blob`` the row carries the payload bytes (JOIN query)
        and the observation is decoded eagerly; otherwise the data slot
        stays _UNLOADED behind a lazy loader.
        """
        if has_blob:
            row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row
        else:
            row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row
            blob_data = None

        pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw)
        tags = json.loads(tags_json) if tags_json else {}

        if has_blob and blob_data is not None:
            data = self._codec.decode(blob_data)
            return Observation(id=row_id, ts=ts, pose=pose, tags=tags, _data=data)

        return Observation(
            id=row_id,
            ts=ts,
            pose=pose,
            tags=tags,
            _data=_UNLOADED,
            _loader=self._make_loader(row_id),  # type: ignore[arg-type]
        )

    # ── Write ──────────────────────────────────────────────────── 

    def _ensure_tag_indexes(self, tags: dict[str, Any]) -> None:
        """Auto-create expression indexes for any new tag keys."""
        for key in tags:
            # Only identifier-shaped keys are safe to embed in DDL.
            if key not in self._tag_indexes and _IDENT_RE.match(key):
                self._conn.execute(
                    f'CREATE INDEX IF NOT EXISTS "{self._name}_tag_{key}" '
                    f"ON \"{self._name}\"(json_extract(tags, '$.{key}'))"
                )
                self._tag_indexes.add(key)

    def append(self, obs: Observation[T]) -> Observation[T]:
        """Persist one observation; assigns its row id and notifies live subscribers.

        Writes metadata row, payload blob, optional R*Tree entry and
        optional embedding in a single locked section ended by one
        commit. Encoding happens outside the lock.
        """
        encoded = self._codec.encode(obs._data)
        pose = _decompose_pose(obs.pose)
        tags_json = json.dumps(obs.tags) if obs.tags else "{}"

        with self._lock:
            if obs.tags:
                self._ensure_tag_indexes(obs.tags)
            if pose:
                px, py, pz, qx, qy, qz, qw = pose
            else:
                px = py = pz = qx = qy = qz = qw = None  # type: ignore[assignment]

            cur = self._conn.execute(
                f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) '
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))",
                (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json),
            )
            row_id = cur.lastrowid
            assert row_id is not None

            bs = self.config.blob_store
            assert bs is not None
            bs.put(self._name, row_id, encoded)

            # R*Tree spatial index
            if pose:
                self._conn.execute(
                    f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) '
                    "VALUES (?, ?, ?, ?, ?, ?, ?)",
                    (row_id, px, px, py, py, pz, pz),
                )

            vs = self.config.vector_store
            if vs is not None:
                emb = getattr(obs, "embedding", None)
                if emb is not None:
                    vs.put(self._name, row_id, emb)

            self._conn.commit()

        obs.id = row_id
        self._channel.notify(obs)
        return obs

    # ── Read ───────────────────────────────────────────────────── 

    def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]:
        """Iterate observations matching ``query`` — snapshot, or backfill + live tail."""
        if query.search_vec is not None and query.live_buffer is not None:
            raise TypeError("Cannot combine .search() with .live() — search is a batch operation.")
        buf = query.live_buffer
        if buf is not None:
            # Subscribe BEFORE the snapshot so no append is missed
            # between backfill and tail; duplicates are filtered by id.
            sub = self._channel.subscribe(buf)
            return self._iterate_live(query, buf, sub)
        return self._iterate_snapshot(query)

    def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]:
        """One-shot iteration over stored rows (no live tail)."""
        if query.search_vec is not None and self.config.vector_store is not None:
            yield from self._vector_search(query)
            return

        join = self._join_blobs
        sql, params, python_filters = _compile_query(query, self._name, join_blob=join)

        cur = self._conn.execute(sql, params)
        cur.arraysize = self.config.page_size
        it: Iterator[Observation[T]] = (self._row_to_obs(r, has_blob=join) for r in cur)

        # Text search — requires loading data
        if query.search_text is not None:
            needle = query.search_text.lower()
            it = (obs for obs in it if needle in str(obs.data).lower())

        # Apply Python post-filters
        if python_filters:
            it = (obs for obs in it if all(f.matches(obs) for f in python_filters))

        # Apply LIMIT/OFFSET in Python when we couldn't push to SQL
        # (mirrors _compile_query's pushdown condition).
        if python_filters or query.search_text:
            if query.offset_val:
                it = islice(it, query.offset_val, None)
            if query.limit_val is not None:
                it = islice(it, query.limit_val)

        yield from it

    def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]:
        """Similarity search via the vector store, re-ranked over stored rows.

        Bug fix: removed a stray discarded ``dict(hits)`` expression —
        dead code, apparently a leftover from an earlier similarity
        lookup table (ranking below reads ``hits`` pairs directly).
        """
        vs = self.config.vector_store
        assert vs is not None and query.search_vec is not None

        hits = vs.search(self._name, query.search_vec, query.search_k or 10)
        if not hits:
            return

        ids = [h[0] for h in hits]

        # Batch-fetch metadata
        join = self._join_blobs
        placeholders = ",".join("?" * len(ids))
        if join:
            sql = (
                f"SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, "
                f"meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data "
                f'FROM "{self._name}" AS meta '
                f'JOIN "{self._name}_blob" AS blob ON blob.id = meta.id '
                f"WHERE meta.id IN ({placeholders})"
            )
        else:
            sql = (
                f"SELECT id, ts, pose_x, pose_y, pose_z, "
                f"pose_qx, pose_qy, pose_qz, pose_qw, json(tags) "
                f'FROM "{self._name}" WHERE id IN ({placeholders})'
            )

        rows = self._conn.execute(sql, ids).fetchall()
        obs_by_id: dict[int, Observation[T]] = {}
        for r in rows:
            obs = self._row_to_obs(r, has_blob=join)
            obs_by_id[obs.id] = obs

        # Preserve VectorStore ranking order, promoting to EmbeddedObservation
        ranked: list[Observation[T]] = []
        for obs_id, sim in hits:
            match = obs_by_id.get(obs_id)
            if match is not None:
                ranked.append(
                    match.derive(data=match.data, embedding=query.search_vec, similarity=sim)
                )

        # Apply remaining query ops (skip vector search)
        rest = replace(query, search_vec=None, search_k=None)
        yield from rest.apply(iter(ranked))

    def _iterate_live(
        self,
        query: StreamQuery,
        buf: BackpressureBuffer[Observation[T]],
        sub: DisposableBase,
    ) -> Iterator[Observation[T]]:
        """Backfill from storage, then block on the live buffer until it closes."""
        from dimos.memory2.buffer import ClosedError

        # Backfill phase
        last_id = -1
        for obs in self._iterate_snapshot(query):
            last_id = max(last_id, obs.id)
            yield obs

        # Live tail — skip anything already yielded during backfill.
        filters = query.filters
        try:
            while True:
                obs = buf.take()
                if obs.id <= last_id:
                    continue
                last_id = obs.id
                if filters and not all(f.matches(obs) for f in filters):
                    continue
                yield obs
        except (ClosedError, StopIteration):
            sub.dispose()

    def count(self, query: StreamQuery) -> int:
        """Count matching observations, using SQL COUNT(*) when fully pushable."""
        if query.search_vec or query.search_text:
            return sum(1 for _ in self.iterate(query))

        sql, params, python_filters = _compile_count(query, self._name)
        if python_filters:
            # Non-pushable filters force a full iteration.
            return sum(1 for _ in self.iterate(query))

        row = self._conn.execute(sql, params).fetchone()
        return int(row[0]) if row else 0
self._conn.execute(sql, params).fetchone() + return int(row[0]) if row else 0 + + +# ── SqliteSession ──────────────────────────────────────────────── + + +class SqliteSession(Session): + """Session owning a single SQLite connection.""" + + def __init__( + self, conn: sqlite3.Connection, *, vec_available: bool = False, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self._conn = conn + self._vec_available = vec_available + self._blob_store: SqliteBlobStore | None = None + self._vector_store: Any | None = None + + # Create stream registry + self._conn.execute( + "CREATE TABLE IF NOT EXISTS _streams (" + " name TEXT PRIMARY KEY," + " payload_module TEXT NOT NULL," + " codec_id TEXT NOT NULL" + ")" + ) + self._conn.commit() + + def _ensure_shared_stores(self) -> None: + """Lazily create shared stores on first stream creation.""" + if self._blob_store is None: + self._blob_store = SqliteBlobStore(self._conn) + if self._vector_store is None and self._vec_available: + from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + + self._vector_store = SqliteVectorStore(self._conn) + + @staticmethod + def _codec_id(codec: Codec[Any]) -> str: + from dimos.memory2.codecs.jpeg import JpegCodec + from dimos.memory2.codecs.lcm import LcmCodec + + if isinstance(codec, JpegCodec): + return "jpeg" + if isinstance(codec, LcmCodec): + return "lcm" + return "pickle" + + @staticmethod + def _codec_from_id(codec_id: str, payload_module: str) -> Codec[Any]: + from dimos.memory2.codecs.pickle import PickleCodec + + if codec_id == "jpeg": + from dimos.memory2.codecs.jpeg import JpegCodec + + return JpegCodec() + if codec_id == "lcm": + from dimos.memory2.codecs.lcm import LcmCodec + + # Resolve the payload type from module path + parts = payload_module.rsplit(".", 1) + if len(parts) == 2: + import importlib + + mod = importlib.import_module(parts[0]) + cls = getattr(mod, parts[1]) + return LcmCodec(cls) + return PickleCodec() + return PickleCodec() + + def _create_backend( + 
self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + _validate_identifier(name) + self._ensure_shared_stores() + + # Look up existing stream in registry + row = self._conn.execute( + "SELECT payload_module, codec_id FROM _streams WHERE name = ?", (name,) + ).fetchone() + + if row is not None: + stored_module, stored_codec_id = row + if payload_type is not None: + actual_module = f"{payload_type.__module__}.{payload_type.__qualname__}" + if actual_module != stored_module: + raise ValueError( + f"Stream {name!r} was created with type {stored_module}, " + f"but opened with {actual_module}" + ) + codec = config.get("codec") or self._codec_from_id(stored_codec_id, stored_module) + else: + if payload_type is None: + raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required") + codec = config.get("codec") or codec_for(payload_type) + payload_module = f"{payload_type.__module__}.{payload_type.__qualname__}" + self._conn.execute( + "INSERT INTO _streams (name, payload_module, codec_id) VALUES (?, ?, ?)", + (name, payload_module, self._codec_id(codec)), + ) + self._conn.commit() + + # Create metadata table + self._conn.execute( + f'CREATE TABLE IF NOT EXISTS "{name}" (' + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " ts REAL NOT NULL UNIQUE," + " pose_x REAL, pose_y REAL, pose_z REAL," + " pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL," + " tags BLOB DEFAULT (jsonb('{}'))" + ")" + ) + # R*Tree spatial index for pose queries + self._conn.execute( + f'CREATE VIRTUAL TABLE IF NOT EXISTS "{name}_rtree" USING rtree(' + " id," + " x_min, x_max," + " y_min, y_max," + " z_min, z_max" + ")" + ) + self._conn.commit() + + # Merge shared stores as defaults + if "blob_store" not in config or config["blob_store"] is None: + config["blob_store"] = self._blob_store + if "vector_store" not in config or config["vector_store"] is None: + config["vector_store"] = self._vector_store + config["codec"] = codec + + return 
SqliteBackend(self._conn, name, **config) + + def list_streams(self) -> list[str]: + db_names = {row[0] for row in self._conn.execute("SELECT name FROM _streams").fetchall()} + return sorted(db_names | set(self._streams.keys())) + + def delete_stream(self, name: str) -> None: + self._streams.pop(name, None) + self._conn.execute(f'DROP TABLE IF EXISTS "{name}"') + self._conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') + self._conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') + self._conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') + self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) + self._conn.commit() + + def stop(self) -> None: + super().stop() + self._conn.close() + + +# ── SqliteStore ────────────────────────────────────────────────── + + +@dataclass +class SqliteStoreConfig(StoreConfig): + """Config for SQLite-backed store.""" + + path: str = "memory.db" + + +class SqliteStore(Store): + """Store backed by a SQLite database file.""" + + default_config: type[SqliteStoreConfig] = SqliteStoreConfig + config: SqliteStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def session(self, **kwargs: Any) -> SqliteSession: + conn = sqlite3.connect(self.config.path, check_same_thread=False) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + + vec_available = False + try: + import sqlite_vec + + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + vec_available = True + except (ImportError, Exception): + pass + + return SqliteSession(conn, vec_available=vec_available, **kwargs) diff --git a/dimos/memory/intro.md b/dimos/memory2/intro.md similarity index 98% rename from dimos/memory/intro.md rename to dimos/memory2/intro.md index 269807cb4f..341d89608c 100644 --- a/dimos/memory/intro.md +++ b/dimos/memory2/intro.md @@ -3,7 +3,7 @@ ## Quick start ```python session=memory ansi=false no-result -from dimos.memory.impl.sqlite import 
SqliteStore +from dimos.memory2.impl.sqlite import SqliteStore store = SqliteStore(path="/tmp/memory_readme.db") session = store.session() @@ -140,7 +140,7 @@ Use `EmbedText` transformer with CLIP to enrich observations with embeddings, th ```python session=memory ansi=false from dimos.models.embedding.clip import CLIPModel -from dimos.memory.embed import EmbedText +from dimos.memory2.embed import EmbedText clip = CLIPModel() diff --git a/dimos/memory2/livechannel/__init__.py b/dimos/memory2/livechannel/__init__.py new file mode 100644 index 0000000000..4fba822bab --- /dev/null +++ b/dimos/memory2/livechannel/__init__.py @@ -0,0 +1,4 @@ +from dimos.memory2.backend import LiveChannel +from dimos.memory2.livechannel.subject import SubjectChannel + +__all__ = ["LiveChannel", "SubjectChannel"] diff --git a/dimos/memory/livechannel/subject.py b/dimos/memory2/livechannel/subject.py similarity index 92% rename from dimos/memory/livechannel/subject.py rename to dimos/memory2/livechannel/subject.py index 8debe229d7..2d2b848f9f 100644 --- a/dimos/memory/livechannel/subject.py +++ b/dimos/memory2/livechannel/subject.py @@ -21,13 +21,13 @@ from reactivex.disposable import Disposable -from dimos.memory.backend import LiveChannel +from dimos.memory2.backend import LiveChannel if TYPE_CHECKING: from reactivex.abc import DisposableBase - from dimos.memory.buffer import BackpressureBuffer - from dimos.memory.type import Observation + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type import Observation T = TypeVar("T") diff --git a/dimos/memory/store.py b/dimos/memory2/store.py similarity index 97% rename from dimos/memory/store.py rename to dimos/memory2/store.py index 213df34d84..e9c1ec4e51 100644 --- a/dimos/memory/store.py +++ b/dimos/memory2/store.py @@ -19,14 +19,14 @@ from typing import TYPE_CHECKING, Any, TypeVar, cast from dimos.core.resource import CompositeResource -from dimos.memory.stream import Stream +from dimos.memory2.stream import Stream 
from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: from collections.abc import Iterator - from dimos.memory.backend import Backend, BlobStore, LiveChannel, VectorStore - from dimos.memory.codecs.base import Codec + from dimos.memory2.backend import Backend, BlobStore, LiveChannel, VectorStore + from dimos.memory2.codecs.base import Codec T = TypeVar("T") diff --git a/dimos/memory/stream.py b/dimos/memory2/stream.py similarity index 98% rename from dimos/memory/stream.py rename to dimos/memory2/stream.py index 60d8a6ed7c..df6dc4636a 100644 --- a/dimos/memory/stream.py +++ b/dimos/memory2/stream.py @@ -18,9 +18,9 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.core.resource import Resource -from dimos.memory.backend import Backend -from dimos.memory.buffer import BackpressureBuffer, KeepLast -from dimos.memory.filter import ( +from dimos.memory2.backend import Backend +from dimos.memory2.buffer import BackpressureBuffer, KeepLast +from dimos.memory2.filter import ( AfterFilter, AtFilter, BeforeFilter, @@ -31,8 +31,8 @@ TagsFilter, TimeRangeFilter, ) -from dimos.memory.transform import FnIterTransformer, FnTransformer, Transformer -from dimos.memory.type import EmbeddedObservation, Observation +from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer +from dimos.memory2.type import EmbeddedObservation, Observation if TYPE_CHECKING: from collections.abc import Callable, Iterator diff --git a/dimos/memory/streaming.md b/dimos/memory2/streaming.md similarity index 100% rename from dimos/memory/streaming.md rename to dimos/memory2/streaming.md diff --git a/dimos/memory/test_blobstore.py b/dimos/memory2/test_blobstore.py similarity index 97% rename from dimos/memory/test_blobstore.py rename to dimos/memory2/test_blobstore.py index 8e5ab37744..b8e8668ff8 100644 --- a/dimos/memory/test_blobstore.py +++ b/dimos/memory2/test_blobstore.py @@ -20,9 +20,9 @@ import numpy as np -from dimos.memory.blobstore.file 
import FileBlobStore -from dimos.memory.impl.memory import MemoryStore -from dimos.memory.type import _UNLOADED +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.impl.memory import MemoryStore +from dimos.memory2.type import _UNLOADED from dimos.models.embedding.base import Embedding if TYPE_CHECKING: @@ -126,7 +126,7 @@ def test_no_blobstore_unchanged(self) -> None: assert obs.data == "inline" def test_blobstore_with_vector_search(self, tmp_path: Path) -> None: - from dimos.memory.vectorstore import MemoryVectorStore + from dimos.memory2.vectorstore import MemoryVectorStore bs = FileBlobStore(tmp_path / "blobs") bs.start() diff --git a/dimos/memory/test_buffer.py b/dimos/memory2/test_buffer.py similarity index 96% rename from dimos/memory/test_buffer.py rename to dimos/memory2/test_buffer.py index 33235890e1..f851a6fcee 100644 --- a/dimos/memory/test_buffer.py +++ b/dimos/memory2/test_buffer.py @@ -21,7 +21,7 @@ import pytest -from dimos.memory.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded +from dimos.memory2.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded class TestBackpressureBuffers: diff --git a/dimos/memory/test_e2e_import.py b/dimos/memory2/test_e2e_import.py similarity index 97% rename from dimos/memory/test_e2e_import.py rename to dimos/memory2/test_e2e_import.py index 1f8f863d21..0fd44a3329 100644 --- a/dimos/memory/test_e2e_import.py +++ b/dimos/memory2/test_e2e_import.py @@ -21,7 +21,7 @@ import pytest -from dimos.memory.impl.sqlite import SqliteStore +from dimos.memory2.impl.sqlite import SqliteStore from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data_dir @@ -30,7 +30,7 @@ if TYPE_CHECKING: from collections.abc import Generator - from dimos.memory.impl.sqlite import SqliteSession + from dimos.memory2.impl.sqlite import SqliteSession DB_PATH = get_data_dir("go2_bigoffice_v2.db") diff --git 
a/dimos/memory/test_e2e_processing.py b/dimos/memory2/test_e2e_processing.py similarity index 100% rename from dimos/memory/test_e2e_processing.py rename to dimos/memory2/test_e2e_processing.py diff --git a/dimos/memory/test_e2e_query.py b/dimos/memory2/test_e2e_query.py similarity index 97% rename from dimos/memory/test_e2e_query.py rename to dimos/memory2/test_e2e_query.py index 6c9faed17b..ac26e865ff 100644 --- a/dimos/memory/test_e2e_query.py +++ b/dimos/memory2/test_e2e_query.py @@ -23,7 +23,7 @@ import pytest -from dimos.memory.impl.sqlite import SqliteStore +from dimos.memory2.impl.sqlite import SqliteStore from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data @@ -31,7 +31,7 @@ if TYPE_CHECKING: from collections.abc import Generator - from dimos.memory.impl.sqlite import SqliteSession + from dimos.memory2.impl.sqlite import SqliteSession @pytest.fixture(scope="module") @@ -97,7 +97,7 @@ def test_order_by_desc(self, session: SqliteSession) -> None: def test_lazy_data_loads_correctly(self, session: SqliteSession) -> None: """Verify lazy blob loading returns valid Image data.""" - from dimos.memory.type import _Unloaded + from dimos.memory2.type import _Unloaded video = session.stream("color_image", Image) obs = next(iter(video.limit(1))) diff --git a/dimos/memory2/test_embedding.py b/dimos/memory2/test_embedding.py new file mode 100644 index 0000000000..f1d22addf2 --- /dev/null +++ b/dimos/memory2/test_embedding.py @@ -0,0 +1,455 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for embedding layer: EmbeddedObservation, vector search, text search, transformers.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from dimos.memory2.impl.memory import MemoryStore +from dimos.memory2.type import EmbeddedObservation, Observation +from dimos.models.embedding.base import Embedding + +# ── Helpers ─────────────────────────────────────────────────────── + + +def _emb(vec: list[float]) -> Embedding: + """Return a unit-normalized Embedding.""" + v = np.array(vec, dtype=np.float32) + v /= np.linalg.norm(v) + 1e-10 + return Embedding(vector=v) + + +# ── EmbeddedObservation ────────────────────────────────────────── + + +class TestEmbeddedObservation: + def test_construction(self) -> None: + emb = _emb([1, 0, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="hello", embedding=emb) + assert obs.data == "hello" + assert obs.embedding is emb + assert obs.similarity is None + + def test_is_observation(self) -> None: + obs = EmbeddedObservation(id=0, ts=1.0, _data="x", embedding=_emb([1, 0])) + assert isinstance(obs, Observation) + + def test_derive_preserves_embedding(self) -> None: + emb = _emb([1, 0, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=emb) + derived = obs.derive(data="b") + assert isinstance(derived, EmbeddedObservation) + assert derived.embedding is emb + assert derived.data == "b" + + def test_derive_replaces_embedding(self) -> None: + old = _emb([1, 0, 0]) + new = _emb([0, 1, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=old) + derived = 
obs.derive(data="a", embedding=new) + assert derived.embedding is new + + def test_derive_preserves_similarity(self) -> None: + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=_emb([1, 0]), similarity=0.95) + derived = obs.derive(data="b") + assert derived.similarity == 0.95 + + def test_observation_derive_promotes_to_embedded(self) -> None: + obs = Observation(id=0, ts=1.0, _data="plain") + emb = _emb([1, 0, 0]) + derived = obs.derive(data="plain", embedding=emb) + assert isinstance(derived, EmbeddedObservation) + assert derived.embedding is emb + + def test_observation_derive_without_embedding_stays_observation(self) -> None: + obs = Observation(id=0, ts=1.0, _data="plain") + derived = obs.derive(data="still plain") + assert type(derived) is Observation + + +# ── ListBackend embedding support ──────────────────────────────── + + +class TestListBackendEmbedding: + def test_append_with_embedding(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + emb = _emb([1, 0, 0]) + obs = s.append("hello", embedding=emb) + assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb + + def test_append_without_embedding(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("plain", str) + obs = s.append("hello") + assert type(obs) is Observation + + def test_search_returns_top_k(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_sorted_by_similarity(self) -> None: + store = MemoryStore() + with 
store.session() as session: + s = session.stream("vecs", str) + s.append("far", embedding=_emb([0, -1, 0])) + s.append("close", embedding=_emb([0.9, 0.1, 0])) + s.append("exact", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=3).fetch() + assert results[0].data == "exact" + assert results[1].data == "close" + assert results[2].data == "far" + # Descending similarity + assert results[0].similarity >= results[1].similarity >= results[2].similarity + + def test_search_skips_non_embedded(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("mixed", str) + s.append("plain") # no embedding + s.append("embedded", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "embedded" + + def test_search_with_filters(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) + + # Only the late one should pass the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_search_with_limit(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + for i in range(10): + s.append(f"item{i}", embedding=_emb([1, 0, 0])) + + # search k=5 then limit 2 + results = s.search(_emb([1, 0, 0]), k=5).limit(2).fetch() + assert len(results) == 2 + + def test_search_with_live_raises(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("vecs", str) + s.append("x", embedding=_emb([1, 0, 0])) + with pytest.raises(TypeError, match="Cannot combine"): + list(s.live().search(_emb([1, 0, 0]), k=5)) + + +# ── Text search ────────────────────────────────────────────────── + + +class TestTextSearch: + def 
test_search_text_substring(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("motor fault detected") + s.append("temperature normal") + s.append("motor overheating") + + results = s.search_text("motor").fetch() + assert len(results) == 2 + assert {r.data for r in results} == {"motor fault detected", "motor overheating"} + + def test_search_text_case_insensitive(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("Motor Fault") + s.append("other event") + + results = s.search_text("motor fault").fetch() + assert len(results) == 1 + + def test_search_text_with_filters(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("motor fault", ts=10.0) + s.append("motor warning", ts=20.0) + s.append("motor fault", ts=30.0) + + results = s.after(15.0).search_text("fault").fetch() + assert len(results) == 1 + assert results[0].ts == 30.0 + + def test_search_text_no_match(self) -> None: + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("all clear") + + results = s.search_text("motor").fetch() + assert len(results) == 0 + + +# ── Save preserves embeddings ──────────────────────────────────── + + +class TestSaveEmbeddings: + def test_save_preserves_embeddings(self) -> None: + store = MemoryStore() + with store.session() as session: + src = session.stream("source", str) + dst = session.stream("dest", str) + + emb = _emb([1, 0, 0]) + src.append("item", embedding=emb) + src.save(dst) + + results = dst.fetch() + assert len(results) == 1 + assert isinstance(results[0], EmbeddedObservation) + # Same vector content (different Embedding instance after re-append) + np.testing.assert_array_almost_equal(results[0].embedding.to_numpy(), emb.to_numpy()) + + def test_save_mixed_embedded_and_plain(self) -> None: + store = MemoryStore() + with 
store.session() as session: + src = session.stream("source", str) + dst = session.stream("dest", str) + + src.append("plain") + src.append("embedded", embedding=_emb([0, 1, 0])) + src.save(dst) + + results = dst.fetch() + assert len(results) == 2 + assert type(results[0]) is Observation + assert isinstance(results[1], EmbeddedObservation) + + +# ── Embed transformers (mock model) ───────────────────────────── + + +class _MockEmbeddingModel: + """Fake EmbeddingModel that returns deterministic unit vectors.""" + + device = "cpu" + + def embed(self, *images): + vecs = [] + for img in images: + rng = np.random.default_rng(hash(str(img)) % 2**32) + v = rng.standard_normal(8).astype(np.float32) + v /= np.linalg.norm(v) + vecs.append(Embedding(vector=v)) + return vecs if len(vecs) > 1 else vecs[0] + + def embed_text(self, *texts): + vecs = [] + for text in texts: + rng = np.random.default_rng(hash(text) % 2**32) + v = rng.standard_normal(8).astype(np.float32) + v /= np.linalg.norm(v) + vecs.append(Embedding(vector=v)) + return vecs if len(vecs) > 1 else vecs[0] + + +class TestEmbedTransformers: + def test_embed_images_produces_embedded_observations(self) -> None: + from dimos.memory2.embed import EmbedImages + + model = _MockEmbeddingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("imgs", str) + s.append("img1", ts=1.0) + s.append("img2", ts=2.0) + + results = s.transform(EmbedImages(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + assert obs.embedding.to_numpy().shape == (8,) + + def test_embed_text_produces_embedded_observations(self) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("all clear", ts=2.0) + + results = 
s.transform(EmbedText(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + + def test_embed_preserves_data(self) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + s.append("hello", ts=1.0) + + result = s.transform(EmbedText(model)).first() + assert result.data == "hello" + + def test_embed_then_search(self) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + for i in range(10): + s.append(f"log entry {i}", ts=float(i)) + + embedded = s.transform(EmbedText(model)) + # Get the embedding for the first item, then search for similar + first_emb = embedded.first().embedding + results = embedded.search(first_emb, k=3).fetch() + assert len(results) == 3 + # First result should be the exact match + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_embed_batching(self) -> None: + from dimos.memory2.embed import EmbedText + + call_sizes: list[int] = [] + + class _TrackingModel(_MockEmbeddingModel): + def embed_text(self, *texts): + call_sizes.append(len(texts)) + return super().embed_text(*texts) + + model = _TrackingModel() + store = MemoryStore() + with store.session() as session: + s = session.stream("logs", str) + for i in range(5): + s.append(f"entry {i}") + + list(s.transform(EmbedText(model, batch_size=2))) + # 5 items with batch_size=2 → 3 calls (2, 2, 1) + assert call_sizes == [2, 2, 1] + + +# ── Pluggable VectorStore ──────────────────────────────────────── + + +class TestPluggableVectorStore: + """Verify that injecting a VectorStore via session config actually delegates search.""" + + def test_append_stores_in_vector_store(self) -> None: + from 
dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs) as session: + s = session.stream("vecs", str) + s.append("hello", embedding=_emb([1, 0, 0])) + s.append("world", embedding=_emb([0, 1, 0])) + + assert len(vs._vectors["vecs"]) == 2 + + def test_append_without_embedding_skips_vector_store(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs) as session: + s = session.stream("plain", str) + s.append("no embedding") + + assert "plain" not in vs._vectors + + def test_search_uses_vector_store(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs) as session: + s = session.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_with_filters_via_vector_store(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs) as session: + s = session.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) + + # Filter + search: only "late" passes the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_per_stream_vector_store_override(self) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs_default = 
MemoryVectorStore() + vs_override = MemoryVectorStore() + store = MemoryStore() + with store.session(vector_store=vs_default) as session: + # Stream with default vector store + s1 = session.stream("s1", str) + s1.append("a", embedding=_emb([1, 0, 0])) + + # Stream with overridden vector store + s2 = session.stream("s2", str, vector_store=vs_override) + s2.append("b", embedding=_emb([0, 1, 0])) + + assert "s1" in vs_default._vectors + assert "s1" not in vs_override._vectors + assert "s2" in vs_override._vectors + assert "s2" not in vs_default._vectors diff --git a/dimos/memory/test_impl.py b/dimos/memory2/test_impl.py similarity index 95% rename from dimos/memory/test_impl.py rename to dimos/memory2/test_impl.py index 990818a8d6..3846654455 100644 --- a/dimos/memory/test_impl.py +++ b/dimos/memory2/test_impl.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator - from dimos.memory.store import Session + from dimos.memory2.store import Session # ── Case definition ──────────────────────────────────────────────── @@ -45,7 +45,7 @@ class Case: @contextmanager def memory_session() -> Generator[Session, None, None]: - from dimos.memory.impl.memory import MemoryStore + from dimos.memory2.impl.memory import MemoryStore store = MemoryStore() with store.session() as session: @@ -56,7 +56,7 @@ def memory_session() -> Generator[Session, None, None]: def sqlite_session() -> Generator[Session, None, None]: import tempfile - from dimos.memory.impl.sqlite import SqliteStore + from dimos.memory2.impl.sqlite import SqliteStore with tempfile.NamedTemporaryFile(suffix=".db") as f: store = SqliteStore(path=f.name) @@ -224,7 +224,7 @@ def test_same_stream_on_repeated_calls(self, case: Case) -> None: def test_append_with_embedding(self, case: Case) -> None: import numpy as np - from dimos.memory.type import EmbeddedObservation + from dimos.memory2.type import EmbeddedObservation from dimos.models.embedding.base import Embedding with case.session_factory() 
as session: @@ -275,8 +275,8 @@ def test_sqlite_lazy_by_default(self) -> None: """Default sqlite iteration uses lazy loaders — data is _UNLOADED until accessed.""" import tempfile - from dimos.memory.impl.sqlite import SqliteStore - from dimos.memory.type import _Unloaded + from dimos.memory2.impl.sqlite import SqliteStore + from dimos.memory2.type import _Unloaded with tempfile.NamedTemporaryFile(suffix=".db") as f: store = SqliteStore(path=f.name) @@ -299,8 +299,8 @@ def test_sqlite_eager_loads_inline(self) -> None: """With eager_blobs=True, data is loaded via JOIN — no lazy loader.""" import tempfile - from dimos.memory.impl.sqlite import SqliteStore - from dimos.memory.type import _Unloaded + from dimos.memory2.impl.sqlite import SqliteStore + from dimos.memory2.type import _Unloaded with tempfile.NamedTemporaryFile(suffix=".db") as f: store = SqliteStore(path=f.name) @@ -319,7 +319,7 @@ def test_sqlite_lazy_and_eager_same_values(self) -> None: """Both paths must return identical data.""" import tempfile - from dimos.memory.impl.sqlite import SqliteStore + from dimos.memory2.impl.sqlite import SqliteStore with tempfile.NamedTemporaryFile(suffix=".db") as f: store = SqliteStore(path=f.name) @@ -341,9 +341,9 @@ def test_sqlite_lazy_and_eager_same_values(self) -> None: def test_memory_lazy_with_blobstore(self) -> None: """MemoryStore with a BlobStore uses lazy loaders.""" - from dimos.memory.blobstore.file import FileBlobStore - from dimos.memory.impl.memory import MemoryStore - from dimos.memory.type import _Unloaded + from dimos.memory2.blobstore.file import FileBlobStore + from dimos.memory2.impl.memory import MemoryStore + from dimos.memory2.type import _Unloaded store = MemoryStore() import tempfile @@ -435,7 +435,7 @@ class SpyCase: @contextmanager def memory_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None]: - from dimos.memory.impl.memory import MemoryStore + from dimos.memory2.impl.memory import MemoryStore blob_spy = 
SpyBlobStore() vec_spy = SpyVectorStore() @@ -448,7 +448,7 @@ def memory_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStor def sqlite_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None]: import tempfile - from dimos.memory.impl.sqlite import SqliteStore + from dimos.memory2.impl.sqlite import SqliteStore blob_spy = SpyBlobStore() vec_spy = SpyVectorStore() diff --git a/dimos/memory/test_save.py b/dimos/memory2/test_save.py similarity index 95% rename from dimos/memory/test_save.py rename to dimos/memory2/test_save.py index ba672f76fd..74c1be89f0 100644 --- a/dimos/memory/test_save.py +++ b/dimos/memory2/test_save.py @@ -18,11 +18,11 @@ import pytest -from dimos.memory.backend import Backend, LiveChannel -from dimos.memory.impl.memory import ListBackend -from dimos.memory.stream import Stream -from dimos.memory.transform import FnTransformer -from dimos.memory.type import Observation +from dimos.memory2.backend import Backend, LiveChannel +from dimos.memory2.impl.memory import ListBackend +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer +from dimos.memory2.type import Observation # ── Helpers ────────────────────────────────────────────────────────── diff --git a/dimos/memory/test_stream.py b/dimos/memory2/test_stream.py similarity index 99% rename from dimos/memory/test_stream.py rename to dimos/memory2/test_stream.py index 1fa4bdbbb2..a2261036b3 100644 --- a/dimos/memory/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -25,13 +25,13 @@ import pytest -from dimos.memory.buffer import KeepLast, Unbounded -from dimos.memory.impl.memory import MemoryStore -from dimos.memory.transform import FnTransformer, QualityWindow, Transformer -from dimos.memory.type import Observation +from dimos.memory2.buffer import KeepLast, Unbounded +from dimos.memory2.impl.memory import MemoryStore +from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer +from 
dimos.memory2.type import Observation if TYPE_CHECKING: - from dimos.memory.stream import Stream + from dimos.memory2.stream import Stream # ── Helpers ────────────────────────────────────────────────────────── diff --git a/dimos/memory/transform.py b/dimos/memory2/transform.py similarity index 97% rename from dimos/memory/transform.py rename to dimos/memory2/transform.py index ebdb6416cf..d68e25344a 100644 --- a/dimos/memory/transform.py +++ b/dimos/memory2/transform.py @@ -18,12 +18,12 @@ import inspect from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory.formatting import FilterRepr +from dimos.memory2.formatting import FilterRepr if TYPE_CHECKING: from collections.abc import Callable, Iterator - from dimos.memory.type import Observation + from dimos.memory2.type import Observation T = TypeVar("T") R = TypeVar("R") diff --git a/dimos/memory/type.py b/dimos/memory2/type.py similarity index 100% rename from dimos/memory/type.py rename to dimos/memory2/type.py diff --git a/dimos/memory/vectorstore/__init__.py b/dimos/memory2/vectorstore/__init__.py similarity index 83% rename from dimos/memory/vectorstore/__init__.py rename to dimos/memory2/vectorstore/__init__.py index fa9ff33c8a..d8f3395cb8 100644 --- a/dimos/memory/vectorstore/__init__.py +++ b/dimos/memory2/vectorstore/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dimos.memory.vectorstore.memory import MemoryVectorStore -from dimos.memory.vectorstore.sqlite import SqliteVectorStore +from dimos.memory2.vectorstore.memory import MemoryVectorStore +from dimos.memory2.vectorstore.sqlite import SqliteVectorStore __all__ = ["MemoryVectorStore", "SqliteVectorStore"] diff --git a/dimos/memory/vectorstore/memory.py b/dimos/memory2/vectorstore/memory.py similarity index 97% rename from dimos/memory/vectorstore/memory.py rename to dimos/memory2/vectorstore/memory.py index 3fcad4e02a..22532c6ad1 100644 --- a/dimos/memory/vectorstore/memory.py +++ b/dimos/memory2/vectorstore/memory.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING -from dimos.memory.backend import VectorStore +from dimos.memory2.backend import VectorStore if TYPE_CHECKING: from dimos.models.embedding.base import Embedding diff --git a/dimos/memory/vectorstore/sqlite.py b/dimos/memory2/vectorstore/sqlite.py similarity index 98% rename from dimos/memory/vectorstore/sqlite.py rename to dimos/memory2/vectorstore/sqlite.py index 5ff49a2255..736cc16e27 100644 --- a/dimos/memory/vectorstore/sqlite.py +++ b/dimos/memory2/vectorstore/sqlite.py @@ -17,7 +17,7 @@ import json from typing import TYPE_CHECKING -from dimos.memory.backend import VectorStore +from dimos.memory2.backend import VectorStore if TYPE_CHECKING: import sqlite3 diff --git a/dimos/memory_old/impl/sqlite.py b/dimos/memory_old/impl/sqlite.py deleted file mode 100644 index 20caceb8a7..0000000000 --- a/dimos/memory_old/impl/sqlite.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dimos/memory_old/test_embedding.py b/dimos/memory_old/test_embedding.py deleted file mode 100644 index 5e8de6b3bf..0000000000 --- a/dimos/memory_old/test_embedding.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest - -from dimos.memory_old.embedding import EmbeddingMemory, SpatialEntry -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay - -dir_name = "unitree_go2_bigoffice" - - -@pytest.mark.skip -def test_embed_frame() -> None: - """Test embedding a single frame.""" - # Load a frame from recorded data - video = TimedSensorReplay(get_data(dir_name) / "video") - frame = video.find_closest_seek(10) - - # Create memory and embed - memory = EmbeddingMemory() - - try: - # Create a spatial entry with dummy pose (no TF needed for this test) - dummy_pose = PoseStamped( - position=[0, 0, 0], - orientation=[0, 0, 0, 1], # identity quaternion - ) - spatial_entry = SpatialEntry(image=frame, pose=dummy_pose) - - # Embed the frame - result = memory._embed_spatial_entry(spatial_entry) - - # Verify - assert result is not None - assert result.embedding is not None - assert result.embedding.vector is not None - print(f"Embedding shape: {result.embedding.vector.shape}") - print(f"Embedding vector (first 5): {result.embedding.vector[:5]}") - finally: - memory.stop() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 6e75813d8b..825e89fc8c 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -20,7 +20,7 @@ from functools import reduce from typing import TypeVar -from dimos.memory_old.timeseries.inmemory import InMemoryStore +from dimos.memory.timeseries.inmemory import InMemoryStore from dimos.msgs.geometry_msgs import PoseStamped, Transform from dimos.msgs.tf2_msgs import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 3db0944800..e4e7705925 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -82,7 +82,6 @@ "unitree-go2-basic": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic:unitree_go2_basic", 
"unitree-go2-detection": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_detection:unitree_go2_detection", "unitree-go2-fleet": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_fleet:unitree_go2_fleet", - "unitree-go2-memory": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_memory:unitree_go2_memory", "unitree-go2-ros": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_ros:unitree_go2_ros", "unitree-go2-spatial": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial:unitree_go2_spatial", "unitree-go2-temporal-memory": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_temporal_memory:unitree_go2_temporal_memory", diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index 7c024469c1..1137a612f3 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -523,15 +523,6 @@ def top(ctx: typer.Context) -> None: dtop_main() -@main.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) -def ps(ctx: typer.Context) -> None: - """List running worker processes (non-interactive).""" - from dimos.utils.cli.dps import main as dps_main - - sys.argv = ["dps", *ctx.args] - dps_main() - - topic_app = typer.Typer(help="Topic commands for pub/sub") main.add_typer(topic_app, name="topic") diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py deleted file mode 100644 index 4891d307e2..0000000000 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_memory.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from dimos.core.blueprints import autoconnect -from dimos.core.core import rpc -from dimos.memory.module import MemoryModule, MemoryModuleConfig -from dimos.memory.transformer import EmbeddingTransformer -from dimos.models.embedding.clip import CLIPModel -from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 - -if TYPE_CHECKING: - from dimos.core.stream import In - from dimos.memory.stream import Stream - from dimos.models.embedding.base import Embedding - from dimos.msgs.sensor_msgs import Image, PointCloud2 - - -@dataclass -class UnitreeGo2MemoryConfig(MemoryModuleConfig): - image_fps: float = 5.0 - - -class UnitreeGo2Memory(MemoryModule): - color_image: In[Image] - lidar: In[PointCloud2] - config: UnitreeGo2MemoryConfig # type: ignore[assignment] - default_config: type[UnitreeGo2MemoryConfig] = UnitreeGo2MemoryConfig - - @rpc - def start(self) -> None: - super().start() - - self.image_memory: Stream[Image] = self.memory( - self.color_image, - ) - - self.pointcloud_memory: Stream[PointCloud2] = self.memory(self.lidar) - - clip = CLIPModel() - clip.start() - self._disposables.add(clip) - - self.image_embeddings: Stream[Embedding] = self.image_memory.transform( - EmbeddingTransformer(clip), live=True - ).store("clip_embeddings") - - -unitree_go2_memory = autoconnect( - unitree_go2, - UnitreeGo2Memory.blueprint(), -).global_config(n_workers=8) diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py 
index eaad794384..7de82e8f9a 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -19,7 +19,7 @@ from reactivex import operators as ops from reactivex.scheduler import ThreadPoolScheduler -from dimos.memory_old.timeseries.inmemory import InMemoryStore +from dimos.memory.timeseries.inmemory import InMemoryStore from dimos.msgs.sensor_msgs import Image from dimos.types.timestamped import ( Timestamped, diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index a02cd392e1..b229a2478e 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -22,7 +22,7 @@ # from dimos_lcm.std_msgs import Time as ROSTime from reactivex.observable import Observable -from dimos.memory_old.timeseries.inmemory import InMemoryStore +from dimos.memory.timeseries.inmemory import InMemoryStore from dimos.types.weaklist import WeakList from dimos.utils.logging_config import setup_logger diff --git a/dimos/utils/cli/dps.py b/dimos/utils/cli/dps.py deleted file mode 100644 index 0ab36d5a71..0000000000 --- a/dimos/utils/cli/dps.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""dps — Non-interactive process list over LCM (like `docker ps`). - -Waits for one dtop resource_stats message and prints a table. 
- -Usage: - dps [--topic /dimos/resource_stats] [--timeout 5] -""" - -from __future__ import annotations - -import sys -import threading -from typing import Any - -from rich.console import Console -from rich.table import Table - -from dimos.protocol.pubsub.impl.lcmpubsub import PickleLCM, Topic - - -def _fmt_pct(v: float) -> str: - return f"{v:.0f}%" - - -def _fmt_mem(v: float) -> str: - mb = v / 1048576 - if mb >= 1024: - return f"{mb / 1024:.1f}G" - return f"{mb:.0f}M" - - -def _fmt_secs(v: float) -> str: - if v >= 3600: - return f"{v / 3600:.1f}h" - if v >= 60: - return f"{v / 60:.1f}m" - return f"{v:.1f}s" - - -def ps(topic: str = "/dimos/resource_stats", timeout: float = 5.0) -> None: - """Wait for one LCM message and print a process table.""" - lcm = PickleLCM(autoconf=True) - result: dict[str, Any] = {} - event = threading.Event() - - def on_msg(msg: dict[str, Any], _topic: str) -> None: - nonlocal result - result = msg - event.set() - - lcm.subscribe(Topic(topic), on_msg) - lcm.start() - - if not event.wait(timeout): - lcm.stop() - Console(stderr=True).print( - f"[red]No dtop message within {timeout:.0f}s. 
Is --dtop enabled?[/red]" - ) - sys.exit(1) - - lcm.stop() - - table = Table(show_header=True, header_style="bold", padding=(0, 1)) - table.add_column("PID", style="dim") - table.add_column("Role") - table.add_column("Modules") - table.add_column("CPU", justify="right") - table.add_column("Mem", justify="right") - table.add_column("Threads", justify="right") - table.add_column("FDs", justify="right") - table.add_column("User", justify="right") - table.add_column("Sys", justify="right") - - coord = result.get("coordinator", {}) - table.add_row( - str(coord.get("pid", "")), - "[cyan]coordinator[/cyan]", - "", - _fmt_pct(coord.get("cpu_percent", 0)), - _fmt_mem(coord.get("pss", 0)), - str(int(coord.get("num_threads", 0))), - str(int(coord.get("num_fds", 0))), - _fmt_secs(coord.get("cpu_time_user", 0)), - _fmt_secs(coord.get("cpu_time_system", 0)), - ) - - for w in result.get("workers", []): - alive = w.get("alive", False) - wid = w.get("worker_id", "?") - role_style = "green" if alive else "red" - modules = ", ".join(w.get("modules", [])) - table.add_row( - str(w.get("pid", "")), - f"[{role_style}]worker {wid}[/{role_style}]", - modules, - _fmt_pct(w.get("cpu_percent", 0)), - _fmt_mem(w.get("pss", 0)), - str(int(w.get("num_threads", 0))), - str(int(w.get("num_fds", 0))), - _fmt_secs(w.get("cpu_time_user", 0)), - _fmt_secs(w.get("cpu_time_system", 0)), - ) - - Console().print(table) - - -def main() -> None: - topic = "/dimos/resource_stats" - timeout = 5.0 - args = sys.argv[1:] - i = 0 - while i < len(args): - if args[i] == "--topic" and i + 1 < len(args): - topic = args[i + 1] - i += 2 - elif args[i] == "--timeout" and i + 1 < len(args): - timeout = float(args[i + 1]) - i += 2 - else: - i += 1 - ps(topic=topic, timeout=timeout) - - -if __name__ == "__main__": - main() diff --git a/dimos/utils/testing/replay.py b/dimos/utils/testing/replay.py index 68d3ca8fe8..588b63e099 100644 --- a/dimos/utils/testing/replay.py +++ b/dimos/utils/testing/replay.py @@ -14,7 +14,7 @@ 
"""Shim for TimedSensorReplay/TimedSensorStorage.""" -from dimos.memory_old.timeseries.legacy import LegacyPickleStore +from dimos.memory.timeseries.legacy import LegacyPickleStore SensorReplay = LegacyPickleStore SensorStorage = LegacyPickleStore diff --git a/pyproject.toml b/pyproject.toml index 45d396c358..10a6c40ec8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,6 @@ dimos = "dimos.robot.cli.dimos:main" rerun-bridge = "dimos.visualization.rerun.bridge:app" doclinks = "dimos.utils.docs.doclinks:main" dtop = "dimos.utils.cli.dtop:main" -dps = "dimos.utils.cli.dps:main" [project.urls] Homepage = "https://dimensionalos.com" From 2076ba4b372edbb6665ca63169dcfc521c198dd3 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 20:29:29 +0800 Subject: [PATCH 097/118] Remove stray old memory module references - Delete empty dimos/memory/impl/sqlite.py - Remove nonexistent memory-module entry from all_blueprints - Restore codeblocks.md from dev --- dimos/memory/impl/sqlite.py | 14 -------------- dimos/robot/all_blueprints.py | 1 - docs/agents/docs/codeblocks.md | 14 +++++++------- 3 files changed, 7 insertions(+), 22 deletions(-) delete mode 100644 dimos/memory/impl/sqlite.py diff --git a/dimos/memory/impl/sqlite.py b/dimos/memory/impl/sqlite.py deleted file mode 100644 index 20caceb8a7..0000000000 --- a/dimos/memory/impl/sqlite.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index e4e7705925..e82cb656ce 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -122,7 +122,6 @@ "manipulation-module": "dimos.manipulation.manipulation_module", "mapper": "dimos.robot.unitree.type.map", "mcp-client": "dimos.agents.mcp.mcp_client", - "memory-module": "dimos.memory.module", "mid360-module": "dimos.hardware.sensors.lidar.livox.module", "navigation-skill": "dimos.agents.skills.navigation", "object-scene-registration-module": "dimos.perception.object_scene_registration", diff --git a/docs/agents/docs/codeblocks.md b/docs/agents/docs/codeblocks.md index d56ee97015..323f1c0c50 100644 --- a/docs/agents/docs/codeblocks.md +++ b/docs/agents/docs/codeblocks.md @@ -22,13 +22,13 @@ Python, Shell (sh), Node.js, plus visualization: Matplotlib, Graphviz, Pikchr, A Add flags after the language identifier: -| Flag | Effect | -|-------------------|---------------------------------------------------| -| `session=NAME` | Share state between blocks with same session name | -| `output=path.png` | Write output to file instead of inline | -| `no-result` | Execute but don't insert result | -| `skip` | Don't execute this block | -| `expected-error` | Block is expected to fail | +| Flag | Effect | +|------|--------| +| `session=NAME` | Share state between blocks with same session name | +| `output=path.png` | Write output to file instead of inline | +| `no-result` | Execute but don't insert result | +| `skip` | Don't execute this block | +| `expected-error` | Block is expected to fail | ## Examples From 05c091d488d9135727c0663a15d6e9d20c82f241 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 20:44:43 +0800 Subject: [PATCH 098/118] Remove LFS test databases from PR These were added during development but shouldn't be in the PR. 
--- data/.lfs/go2_bigoffice.db.tar.gz | 3 --- data/.lfs/go2_bigoffice_v2.db.tar.gz | 3 --- 2 files changed, 6 deletions(-) delete mode 100644 data/.lfs/go2_bigoffice.db.tar.gz delete mode 100644 data/.lfs/go2_bigoffice_v2.db.tar.gz diff --git a/data/.lfs/go2_bigoffice.db.tar.gz b/data/.lfs/go2_bigoffice.db.tar.gz deleted file mode 100644 index 843a97b9b1..0000000000 --- a/data/.lfs/go2_bigoffice.db.tar.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6a4ce6670e3a48fdf378188ababe8dc607ed83b5160802ff7f309aa43f8e72ce -size 406735715 diff --git a/data/.lfs/go2_bigoffice_v2.db.tar.gz b/data/.lfs/go2_bigoffice_v2.db.tar.gz deleted file mode 100644 index f091edf861..0000000000 --- a/data/.lfs/go2_bigoffice_v2.db.tar.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7f7a94b136be71044b8c4f645eaa9fbc672df5d241ae9ceb2f5de5f85ffb3668 -size 254406884 From 0570bc3616ff15e3ae5b0da7a89da97d2e0df986 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 20:53:16 +0800 Subject: [PATCH 099/118] Address review findings: SQL injection guards, type fixes, cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove dead dict(hits) and thread-affinity assertion in SqliteBackend - Validate order_field and tag keys against _IDENT_RE to prevent SQL injection - Replace assert bs is not None with RuntimeError for -O safety - Add hash=False to NearFilter.pose, TagsFilter.tags, PredicateFilter.fn - Collapse CaptionDetail enum to 3 distinct levels (BRIEF/NORMAL/DETAILED) - Fix Stream.map() return type: Stream[Any] → Stream[R] - Update architecture.md: SqliteBackend status Stub → Complete - Document SqliteBlobStore commit responsibility - Guard ImageDetections.ts against image=None --- dimos/memory2/architecture.md | 2 +- dimos/memory2/blobstore/sqlite.py | 1 + dimos/memory2/filter.py | 8 ++++---- dimos/memory2/impl/sqlite.py | 13 ++++++++----- 
dimos/memory2/stream.py | 2 +- dimos/models/vl/florence.py | 5 ++--- dimos/perception/detection/type/imageDetections.py | 2 ++ 7 files changed, 19 insertions(+), 14 deletions(-) diff --git a/dimos/memory2/architecture.md b/dimos/memory2/architecture.md index 4acaf4a7f9..f1026040d3 100644 --- a/dimos/memory2/architecture.md +++ b/dimos/memory2/architecture.md @@ -112,4 +112,4 @@ with store.session() as session: | Backend | Status | Storage | |-----------------|----------|----------------------------------------| | `ListBackend` | Complete | In-memory (lists + brute-force search) | -| `SqliteBackend` | Stub | SQLite (WAL, FTS5, vec0) | +| `SqliteBackend` | Complete | SQLite (WAL, FTS5, vec0) | diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py index 0fd144c532..fac00bef7b 100644 --- a/dimos/memory2/blobstore/sqlite.py +++ b/dimos/memory2/blobstore/sqlite.py @@ -33,6 +33,7 @@ class SqliteBlobStore(BlobStore): ); Does NOT own the connection — lifecycle managed externally. + Does NOT commit; the caller (typically SqliteBackend) is responsible for commits. 
""" def __init__(self, conn: sqlite3.Connection) -> None: diff --git a/dimos/memory2/filter.py b/dimos/memory2/filter.py index 8c80546ef4..4a1846d3e1 100644 --- a/dimos/memory2/filter.py +++ b/dimos/memory2/filter.py @@ -81,8 +81,8 @@ def matches(self, obs: Observation[Any]) -> bool: @dataclass(frozen=True) class NearFilter(Filter): - pose: Any - radius: float + pose: Any = field(hash=False) + radius: float = 0.0 def matches(self, obs: Observation[Any]) -> bool: if obs.pose is None or self.pose is None: @@ -109,7 +109,7 @@ def _xyz(p: Any) -> tuple[float, float, float]: @dataclass(frozen=True) class TagsFilter(Filter): - tags: dict[str, Any] + tags: dict[str, Any] = field(default_factory=dict, hash=False) def matches(self, obs: Observation[Any]) -> bool: for k, v in self.tags.items(): @@ -122,7 +122,7 @@ def matches(self, obs: Observation[Any]) -> bool: class PredicateFilter(Filter): """Wraps an arbitrary predicate function for use with .filter().""" - fn: Callable[[Observation[Any]], bool] + fn: Callable[[Observation[Any]], bool] = field(hash=False) def matches(self, obs: Observation[Any]) -> bool: return bool(self.fn(obs)) diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 03857d9bd8..a07711bc63 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -111,6 +111,8 @@ def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list clauses = [] params: list[Any] = [] for k, v in f.tags.items(): + if not _IDENT_RE.match(k): + raise ValueError(f"Invalid tag key: {k!r}") clauses.append(f"json_extract({prefix}tags, '$.{k}') = ?") params.append(v) return (" AND ".join(clauses), params) @@ -192,6 +194,8 @@ def _compile_query( # ORDER BY if query.order_field: + if not _IDENT_RE.match(query.order_field): + raise ValueError(f"Invalid order_field: {query.order_field!r}") direction = "DESC" if query.order_desc else "ASC" sql += f" ORDER BY {prefix}{query.order_field} {direction}" else: @@ -269,12 +273,11 @@ 
def _join_blobs(self) -> bool: def _make_loader(self, row_id: int) -> Any: bs = self.config.blob_store - assert bs is not None + if bs is None: + raise RuntimeError("BlobStore required but not configured") name, codec = self._name, self._codec - owner_tid = threading.get_ident() def loader() -> Any: - assert threading.get_ident() == owner_tid raw = bs.get(name, row_id) return codec.decode(raw) @@ -337,7 +340,8 @@ def append(self, obs: Observation[T]) -> Observation[T]: assert row_id is not None bs = self.config.blob_store - assert bs is not None + if bs is None: + raise RuntimeError("BlobStore required but not configured") bs.put(self._name, row_id, encoded) # R*Tree spatial index @@ -410,7 +414,6 @@ def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]: return ids = [h[0] for h in hits] - dict(hits) # Batch-fetch metadata join = self._join_blobs diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index df6dc4636a..14eb7cd5ee 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -195,7 +195,7 @@ def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: """Filter by arbitrary predicate on the full Observation.""" return self._with_filter(PredicateFilter(pred)) - def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[Any]: + def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[R]: """Transform each observation's data via callable.""" return self.transform(FnTransformer(lambda obs: fn(obs))) diff --git a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py index 19d99ccdfd..a42b06b6a7 100644 --- a/dimos/models/vl/florence.py +++ b/dimos/models/vl/florence.py @@ -28,9 +28,8 @@ class CaptionDetail(Enum): """Florence-2 caption detail level.""" BRIEF = "" - NORMAL = "" - DETAILED = "" - MORE_DETAILED = "" + NORMAL = "" + DETAILED = "" class Florence2Model(HuggingFaceModel, Captioner): diff --git a/dimos/perception/detection/type/imageDetections.py 
b/dimos/perception/detection/type/imageDetections.py index a3d8acebd1..9ef31a34dd 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -43,6 +43,8 @@ class ImageDetections(Generic[T], TableStr): @property def ts(self) -> float: + if self.image is None: + return 0.0 return self.image.ts def __init__(self, image: Image | None = None, detections: list[T] | None = None) -> None: From 2dcfcd93064e9bbbdeb3f8d143d725580a574e9d Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 20:59:32 +0800 Subject: [PATCH 100/118] Revert detection type changes: keep image as required field MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restores detection2d/bbox.py, imageDetections.py, and utils.py to dev state — the image-optional decoupling is not needed for memory2. --- dimos/perception/detection/type/detection2d/bbox.py | 2 +- dimos/perception/detection/type/imageDetections.py | 11 ++++------- dimos/perception/detection/type/utils.py | 8 ++++---- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py index 6022e010cb..45dc848e9d 100644 --- a/dimos/perception/detection/type/detection2d/bbox.py +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -85,7 +85,7 @@ class Detection2DBBox(Detection2D): confidence: float name: str ts: float - image: Image | None + image: Image def to_repr_dict(self) -> dict[str, Any]: """Return a dictionary representation of the detection for display purposes.""" diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 9ef31a34dd..12a1f4efb9 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -43,17 +43,14 @@ class ImageDetections(Generic[T], TableStr): @property def ts(self) 
-> float: - if self.image is None: - return 0.0 return self.image.ts - def __init__(self, image: Image | None = None, detections: list[T] | None = None) -> None: + def __init__(self, image: Image, detections: list[T] | None = None) -> None: self.image = image self.detections = detections or [] - if image is not None: - for det in self.detections: - if not det.ts: - det.ts = image.ts + for det in self.detections: + if not det.ts: + det.ts = image.ts def __len__(self) -> int: return len(self.detections) diff --git a/dimos/perception/detection/type/utils.py b/dimos/perception/detection/type/utils.py index 35c3909698..eb924cbd1a 100644 --- a/dimos/perception/detection/type/utils.py +++ b/dimos/perception/detection/type/utils.py @@ -53,18 +53,18 @@ class TableStr: def __str__(self) -> str: console = Console(force_terminal=True, legacy_windows=False) - ts_str = f"{to_timestamp(self.image.ts):.3f}" if self.image is not None else "?" # type: ignore[attr-defined] - # Create a table for detections table = Table( - title=f"{self.__class__.__name__} [{len(self.detections)} detections @ {ts_str}]", # type: ignore[attr-defined] + title=f"{self.__class__.__name__} [{len(self.detections)} detections @ {to_timestamp(self.image.ts):.3f}]", # type: ignore[attr-defined] show_header=True, show_edge=True, ) # Dynamically build columns based on the first detection's dict keys if not self.detections: # type: ignore[attr-defined] - return f" {self.__class__.__name__} [0 detections @ {ts_str}]" # type: ignore[attr-defined] + return ( + f" {self.__class__.__name__} [0 detections @ {to_timestamp(self.image.ts):.3f}]" # type: ignore[attr-defined] + ) # Cache all repr_dicts to avoid double computation detection_dicts = [det.to_repr_dict() for det in self] # type: ignore[attr-defined] From e88e0e5c762dc911f10ab2bd010b7c0a3309203c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 21:49:25 +0800 Subject: [PATCH 101/118] add libturbojpeg to docker image --- docker/python/Dockerfile 
| 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/python/Dockerfile b/docker/python/Dockerfile index 16b4db1807..d14b281603 100644 --- a/docker/python/Dockerfile +++ b/docker/python/Dockerfile @@ -31,7 +31,8 @@ RUN apt-get update && apt-get install -y \ qtbase5-dev-tools \ supervisor \ iproute2 # for LCM networking system config \ - liblcm-dev + liblcm-dev \ + libturbojpeg0-dev # Fix distutils-installed packages that block pip upgrades RUN apt-get purge -y python3-blinker python3-sympy python3-oauthlib || true From f29f766512054ad22d8e8761e76b039f53bdff0c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 21:57:07 +0800 Subject: [PATCH 102/118] Make turbojpeg import lazy so tests skip gracefully in CI Move top-level turbojpeg import in Image.py to the two methods that use it, and guard jpeg codec tests behind ImportError / importorskip so the test suite passes when libturbojpeg is not installed. --- dimos/memory2/codecs/test_codecs.py | 19 ++++++++++++------- dimos/msgs/sensor_msgs/Image.py | 5 ++++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/dimos/memory2/codecs/test_codecs.py b/dimos/memory2/codecs/test_codecs.py index 8f3eb17c10..03226b491b 100644 --- a/dimos/memory2/codecs/test_codecs.py +++ b/dimos/memory2/codecs/test_codecs.py @@ -85,22 +85,26 @@ def _jpeg_eq(original: Any, decoded: Any) -> bool: return bool(np.mean(np.abs(decoded.data.astype(float) - original.data.astype(float))) < 5) -def _jpeg_case() -> Case: - from dimos.memory2.codecs.jpeg import JpegCodec - from dimos.utils.testing import TimedSensorReplay +def _jpeg_case() -> Case | None: + try: + from dimos.memory2.codecs.jpeg import JpegCodec + from dimos.utils.testing import TimedSensorReplay - replay = TimedSensorReplay("unitree_go2_bigoffice/video") - frames = [replay.find_closest_seek(float(i)) for i in range(1, 4)] + replay = TimedSensorReplay("unitree_go2_bigoffice/video") + frames = [replay.find_closest_seek(float(i)) for i in 
range(1, 4)] + codec = JpegCodec(quality=95) + except ImportError: + return None return Case( name="jpeg", - codec=JpegCodec(quality=95), + codec=codec, values=frames, eq=_jpeg_eq, ) -testcases = [_pickle_case(), _lcm_case(), _jpeg_case()] +testcases = [c for c in [_pickle_case(), _lcm_case(), _jpeg_case()] if c is not None] # ── Tests ────────────────────────────────────────────────────────── @@ -148,6 +152,7 @@ def test_lcm_type_returns_lcm(self) -> None: assert isinstance(codec_for(PoseStamped), LcmCodec) def test_image_type_returns_jpeg(self) -> None: + pytest.importorskip("turbojpeg") from dimos.memory2.codecs.jpeg import JpegCodec from dimos.msgs.sensor_msgs.Image import Image diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 3f2e049920..8aee99435d 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -27,7 +27,6 @@ import reactivex as rx from reactivex import operators as ops import rerun as rr -from turbojpeg import TurboJPEG # type: ignore[import-untyped] from dimos.types.timestamped import Timestamped, TimestampedBufferCollection, to_human_readable from dimos.utils.reactive import quality_barrier @@ -510,6 +509,8 @@ def lcm_jpeg_encode(self, quality: int = 75, frame_id: str | None = None) -> byt Returns: LCM-encoded bytes with JPEG-compressed image data """ + from turbojpeg import TurboJPEG # type: ignore[import-untyped] + jpeg = TurboJPEG() msg = LCMImage() @@ -555,6 +556,8 @@ def lcm_jpeg_decode(cls, data: bytes, **kwargs: Any) -> Image: Returns: Image instance """ + from turbojpeg import TurboJPEG # type: ignore[import-untyped] + jpeg = TurboJPEG() msg = LCMImage.lcm_decode(data) From c56e2834997f217252307df9047e2030ac4c4c91 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 22:02:39 +0800 Subject: [PATCH 103/118] Give each SqliteBackend its own connection for WAL-mode concurrency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
Previously all backends shared a single sqlite3.Connection — concurrent writes from different streams could interleave commits/rollbacks. Now SqliteSession opens a dedicated connection per backend, with per-backend blob/vector stores wrapping the same connection for atomicity. A separate registry connection handles the _streams table. Also makes SqliteBackend a CompositeResource so session.own(backend) properly closes connections on stop, and fixes live iterator cleanup in both backends (backfill phase now inside try/finally). --- dimos/memory2/impl/memory.py | 18 +++-- dimos/memory2/impl/sqlite.py | 139 +++++++++++++++++++---------------- 2 files changed, 84 insertions(+), 73 deletions(-) diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index 0525326fcc..1b4fe91b8c 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -158,15 +158,15 @@ def _iterate_live( eager = self.config.eager_blobs and self.config.blob_store is not None - # Backfill phase — use snapshot query (without live) for the backfill - last_id = -1 - for obs in self._iterate_snapshot(query): - last_id = max(last_id, obs.id) - yield obs - - # Live tail - filters = query.filters try: + # Backfill phase — use snapshot query (without live) for the backfill + last_id = -1 + for obs in self._iterate_snapshot(query): + last_id = max(last_id, obs.id) + yield obs + + # Live tail + filters = query.filters while True: obs = buf.take() if obs.id <= last_id: @@ -178,6 +178,8 @@ def _iterate_live( _ = obs.data # trigger lazy loader yield obs except (ClosedError, StopIteration): + pass + finally: sub.dispose() def count(self, query: StreamQuery) -> int: diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index a07711bc63..a370e340a1 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -22,6 +22,7 @@ import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.core.resource import CompositeResource 
from dimos.memory2.backend import BackendConfig from dimos.memory2.blobstore.sqlite import SqliteBlobStore from dimos.memory2.codecs.base import Codec, codec_for @@ -242,13 +243,18 @@ def _compile_count( # ── SqliteBackend ──────────────────────────────────────────────── -class SqliteBackend(Configurable[BackendConfig], Generic[T]): - """SQLite-backed observation storage for a single stream (table).""" +class SqliteBackend(Configurable[BackendConfig], CompositeResource, Generic[T]): + """SQLite-backed observation storage for a single stream (table). + + Owns its ``sqlite3.Connection``. When disposed (via + ``CompositeResource.stop()``), the connection is closed. + """ default_config: type[BackendConfig] = BackendConfig def __init__(self, conn: sqlite3.Connection, name: str, **kwargs: Any) -> None: - super().__init__(**kwargs) + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) self._conn = conn self._name = name self._codec: Codec[Any] = self.config.codec # type: ignore[assignment] @@ -460,15 +466,15 @@ def _iterate_live( ) -> Iterator[Observation[T]]: from dimos.memory2.buffer import ClosedError - # Backfill phase - last_id = -1 - for obs in self._iterate_snapshot(query): - last_id = max(last_id, obs.id) - yield obs - - # Live tail - filters = query.filters try: + # Backfill phase + last_id = -1 + for obs in self._iterate_snapshot(query): + last_id = max(last_id, obs.id) + yield obs + + # Live tail + filters = query.filters while True: obs = buf.take() if obs.id <= last_id: @@ -478,6 +484,8 @@ def _iterate_live( continue yield obs except (ClosedError, StopIteration): + pass + finally: sub.dispose() def count(self, query: StreamQuery) -> int: @@ -491,40 +499,48 @@ def count(self, query: StreamQuery) -> int: row = self._conn.execute(sql, params).fetchone() return int(row[0]) if row else 0 + def stop(self) -> None: + super().stop() + self._conn.close() + # ── SqliteSession ──────────────────────────────────────────────── class 
SqliteSession(Session): - """Session owning a single SQLite connection.""" + """Session owning a SQLite database. - def __init__( - self, conn: sqlite3.Connection, *, vec_available: bool = False, **kwargs: Any - ) -> None: + Each backend gets its own ``sqlite3.Connection`` so SQLite WAL mode + handles cross-stream concurrency natively. A separate + ``_registry_conn`` is used only for the ``_streams`` registry table. + """ + + def __init__(self, db_path: str, **kwargs: Any) -> None: super().__init__(**kwargs) - self._conn = conn - self._vec_available = vec_available - self._blob_store: SqliteBlobStore | None = None - self._vector_store: Any | None = None + self._db_path = db_path - # Create stream registry - self._conn.execute( + # Dedicated connection for the stream registry table + self._registry_conn = self._open_connection() + self._registry_conn.execute( "CREATE TABLE IF NOT EXISTS _streams (" " name TEXT PRIMARY KEY," " payload_module TEXT NOT NULL," " codec_id TEXT NOT NULL" ")" ) - self._conn.commit() + self._registry_conn.commit() - def _ensure_shared_stores(self) -> None: - """Lazily create shared stores on first stream creation.""" - if self._blob_store is None: - self._blob_store = SqliteBlobStore(self._conn) - if self._vector_store is None and self._vec_available: - from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + def _open_connection(self) -> sqlite3.Connection: + """Open a new WAL-mode connection with sqlite-vec loaded.""" + import sqlite_vec - self._vector_store = SqliteVectorStore(self._conn) + conn = sqlite3.connect(self._db_path, check_same_thread=False) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + return conn @staticmethod def _codec_id(codec: Codec[Any]) -> str: @@ -563,10 +579,9 @@ def _create_backend( self, name: str, payload_type: type[Any] | None = None, **config: Any ) -> Backend[Any]: 
_validate_identifier(name) - self._ensure_shared_stores() # Look up existing stream in registry - row = self._conn.execute( + row = self._registry_conn.execute( "SELECT payload_module, codec_id FROM _streams WHERE name = ?", (name,) ).fetchone() @@ -585,14 +600,17 @@ def _create_backend( raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required") codec = config.get("codec") or codec_for(payload_type) payload_module = f"{payload_type.__module__}.{payload_type.__qualname__}" - self._conn.execute( + self._registry_conn.execute( "INSERT INTO _streams (name, payload_module, codec_id) VALUES (?, ?, ?)", (name, payload_module, self._codec_id(codec)), ) - self._conn.commit() + self._registry_conn.commit() + + # Each backend gets its own connection for WAL-mode concurrency + backend_conn = self._open_connection() # Create metadata table - self._conn.execute( + backend_conn.execute( f'CREATE TABLE IF NOT EXISTS "{name}" (' " id INTEGER PRIMARY KEY AUTOINCREMENT," " ts REAL NOT NULL UNIQUE," @@ -602,7 +620,7 @@ def _create_backend( ")" ) # R*Tree spatial index for pose queries - self._conn.execute( + backend_conn.execute( f'CREATE VIRTUAL TABLE IF NOT EXISTS "{name}_rtree" USING rtree(' " id," " x_min, x_max," @@ -610,33 +628,39 @@ def _create_backend( " z_min, z_max" ")" ) - self._conn.commit() + backend_conn.commit() - # Merge shared stores as defaults + # Create per-backend stores wrapping the backend's own connection if "blob_store" not in config or config["blob_store"] is None: - config["blob_store"] = self._blob_store + config["blob_store"] = SqliteBlobStore(backend_conn) if "vector_store" not in config or config["vector_store"] is None: - config["vector_store"] = self._vector_store + from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + + config["vector_store"] = SqliteVectorStore(backend_conn) config["codec"] = codec - return SqliteBackend(self._conn, name, **config) + backend: SqliteBackend[Any] = SqliteBackend(backend_conn, name, 
**config) + self.own(backend) + return backend def list_streams(self) -> list[str]: - db_names = {row[0] for row in self._conn.execute("SELECT name FROM _streams").fetchall()} + db_names = { + row[0] for row in self._registry_conn.execute("SELECT name FROM _streams").fetchall() + } return sorted(db_names | set(self._streams.keys())) def delete_stream(self, name: str) -> None: self._streams.pop(name, None) - self._conn.execute(f'DROP TABLE IF EXISTS "{name}"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') - self._conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') - self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) - self._conn.commit() + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') + self._registry_conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) + self._registry_conn.commit() def stop(self) -> None: - super().stop() - self._conn.close() + super().stop() # disposes owned backends (closes their connections) + self._registry_conn.close() # ── SqliteStore ────────────────────────────────────────────────── @@ -659,19 +683,4 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) def session(self, **kwargs: Any) -> SqliteSession: - conn = sqlite3.connect(self.config.path, check_same_thread=False) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - - vec_available = False - try: - import sqlite_vec - - conn.enable_load_extension(True) - sqlite_vec.load(conn) - conn.enable_load_extension(False) - vec_available = True - except (ImportError, Exception): - pass - - return SqliteSession(conn, vec_available=vec_available, **kwargs) + return SqliteSession(self.config.path, **kwargs) From 
93d6afeae1ab7631c3733f39fcf4fe2d045379f0 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 22:29:46 +0800 Subject: [PATCH 104/118] Block search_text on SqliteBackend to prevent full table scans MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit search_text previously loaded every blob from the DB and did Python substring matching — a silent full table scan. Raise NotImplementedError instead until proper SQL pushdown is implemented. --- dimos/memory2/impl/sqlite.py | 15 ++++++--------- dimos/memory2/test_impl.py | 10 +++++++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index a370e340a1..f2dc514d1f 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -203,7 +203,7 @@ def _compile_query( sql += f" ORDER BY {prefix}id ASC" # Only push LIMIT/OFFSET to SQL when there are no Python post-filters - if not python_filters and not query.search_text: + if not python_filters: if query.limit_val is not None: if query.offset_val: sql += f" LIMIT {query.limit_val} OFFSET {query.offset_val}" @@ -382,6 +382,9 @@ def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: return self._iterate_snapshot(query) def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: + if query.search_text is not None: + raise NotImplementedError("search_text is not supported by SqliteBackend") + if query.search_vec is not None and self.config.vector_store is not None: yield from self._vector_search(query) return @@ -393,17 +396,11 @@ def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: cur.arraysize = self.config.page_size it: Iterator[Observation[T]] = (self._row_to_obs(r, has_blob=join) for r in cur) - # Text search — requires loading data - if query.search_text is not None: - needle = query.search_text.lower() - it = (obs for obs in it if needle in str(obs.data).lower()) - # Apply Python 
post-filters if python_filters: it = (obs for obs in it if all(f.matches(obs) for f in python_filters)) - # Apply LIMIT/OFFSET in Python when we couldn't push to SQL - if python_filters or query.search_text: + # Apply LIMIT/OFFSET in Python when we couldn't push to SQL if query.offset_val: it = islice(it, query.offset_val, None) if query.limit_val is not None: @@ -489,7 +486,7 @@ def _iterate_live( sub.dispose() def count(self, query: StreamQuery) -> int: - if query.search_vec or query.search_text: + if query.search_vec: return sum(1 for _ in self.iterate(query)) sql, params, python_filters = _compile_count(query, self._name) diff --git a/dimos/memory2/test_impl.py b/dimos/memory2/test_impl.py index 3846654455..bcbd06ffd1 100644 --- a/dimos/memory2/test_impl.py +++ b/dimos/memory2/test_impl.py @@ -260,9 +260,13 @@ def test_search_text(self, case: Case) -> None: s.append("motor fault") s.append("temperature ok") - results = s.search_text("motor").fetch() - assert len(results) == 1 - assert results[0].data == "motor fault" + if case.name == "sqlite": + with pytest.raises(NotImplementedError): + s.search_text("motor").fetch() + else: + results = s.search_text("motor").fetch() + assert len(results) == 1 + assert results[0].data == "motor fault" # ── Lazy / eager blob loading tests ────────────────────────────── From 317562c773ceb33e8441fe7d39c7596420d36431 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 22:58:17 +0800 Subject: [PATCH 105/118] Catch RuntimeError from missing turbojpeg native library in codec tests TurboJPEG import succeeds but instantiation raises RuntimeError when the native library isn't installed. Skip the test case gracefully. 
--- dimos/memory2/codecs/test_codecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/memory2/codecs/test_codecs.py b/dimos/memory2/codecs/test_codecs.py index 03226b491b..5c17281e48 100644 --- a/dimos/memory2/codecs/test_codecs.py +++ b/dimos/memory2/codecs/test_codecs.py @@ -93,7 +93,7 @@ def _jpeg_case() -> Case | None: replay = TimedSensorReplay("unitree_go2_bigoffice/video") frames = [replay.find_closest_seek(float(i)) for i in range(1, 4)] codec = JpegCodec(quality=95) - except ImportError: + except (ImportError, RuntimeError): return None return Case( From 5a418c6d1c1ee9d8d549192668d8aae21fabd00c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Thu, 12 Mar 2026 23:12:54 +0800 Subject: [PATCH 106/118] pr comments --- dimos/memory2/impl/sqlite.py | 54 +++++++++++++++++++----------------- dimos/memory2/intro.md | 15 ---------- dimos/memory2/type.py | 19 +++++++++---- 3 files changed, 42 insertions(+), 46 deletions(-) diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index f2dc514d1f..8672aa2eb9 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -337,34 +337,38 @@ def append(self, obs: Observation[T]) -> Observation[T]: else: px = py = pz = qx = qy = qz = qw = None # type: ignore[assignment] - cur = self._conn.execute( - f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", - (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json), - ) - row_id = cur.lastrowid - assert row_id is not None - - bs = self.config.blob_store - if bs is None: - raise RuntimeError("BlobStore required but not configured") - bs.put(self._name, row_id, encoded) - - # R*Tree spatial index - if pose: - self._conn.execute( - f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) ' - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (row_id, px, px, py, py, pz, pz), + try: + cur = self._conn.execute( + f'INSERT INTO 
"{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", + (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json), ) + row_id = cur.lastrowid + assert row_id is not None + + bs = self.config.blob_store + if bs is None: + raise RuntimeError("BlobStore required but not configured") + bs.put(self._name, row_id, encoded) + + # R*Tree spatial index + if pose: + self._conn.execute( + f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) ' + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (row_id, px, px, py, py, pz, pz), + ) - vs = self.config.vector_store - if vs is not None: - emb = getattr(obs, "embedding", None) - if emb is not None: - vs.put(self._name, row_id, emb) + vs = self.config.vector_store + if vs is not None: + emb = getattr(obs, "embedding", None) + if emb is not None: + vs.put(self._name, row_id, emb) - self._conn.commit() + self._conn.commit() + except BaseException: + self._conn.rollback() + raise obs.id = row_id self._channel.notify(obs) diff --git a/dimos/memory2/intro.md b/dimos/memory2/intro.md index 341d89608c..1a9b224204 100644 --- a/dimos/memory2/intro.md +++ b/dimos/memory2/intro.md @@ -169,18 +169,3 @@ threading.Thread( # every new log is now automatically embedded and stored # embedded_logs.search(query, k=5).fetch() to query at any time ``` - -## Full text search - -`.search_text(text)` does efficient substring matching: - -```python session=memory ansi=false -for obs in logs.search_text("motor").fetch(): - print(f"{obs.data}") -``` - - -``` -Motor started -Motor stopped -``` diff --git a/dimos/memory2/type.py b/dimos/memory2/type.py index 85cfab9640..1b0427eb3d 100644 --- a/dimos/memory2/type.py +++ b/dimos/memory2/type.py @@ -15,6 +15,7 @@ from __future__ import annotations from dataclasses import dataclass, field +import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar if TYPE_CHECKING: @@ -53,17 +54,23 @@ class Observation(Generic[T]): 
tags: dict[str, Any] = field(default_factory=dict) _data: T | _Unloaded = field(default=_UNLOADED, repr=False) _loader: Callable[[], T] | None = field(default=None, repr=False) + _data_lock: threading.Lock = field(default_factory=threading.Lock, repr=False) @property def data(self) -> T: val = self._data if isinstance(val, _Unloaded): - if self._loader is None: - raise LookupError("No data and no loader set on this observation") - loaded = self._loader() - self._data = loaded - self._loader = None # release closure - return loaded + with self._data_lock: + # Re-check after acquiring lock (double-checked locking) + val = self._data + if isinstance(val, _Unloaded): + if self._loader is None: + raise LookupError("No data and no loader set on this observation") + loaded = self._loader() + self._data = loaded + self._loader = None # release closure + return loaded + return val # type: ignore[return-value] return val def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: From 99c3f3eb29c0fcdd91dd11364d1d2d5e56feca11 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 11:15:03 +0800 Subject: [PATCH 107/118] occupancy change undo --- dimos/mapping/occupancy/gradient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/mapping/occupancy/gradient.py b/dimos/mapping/occupancy/gradient.py index 51a0c013ad..880f2692da 100644 --- a/dimos/mapping/occupancy/gradient.py +++ b/dimos/mapping/occupancy/gradient.py @@ -50,7 +50,7 @@ def gradient( # Compute distance transform (distance to nearest obstacle in cells) # Unknown cells are treated as if they don't exist for distance calculation - distance_cells: np.ndarray = ndimage.distance_transform_edt(1 - obstacle_map) # type: ignore[assignment] + distance_cells = ndimage.distance_transform_edt(1 - obstacle_map) # Convert to meters and clip to max distance distance_meters = np.clip(distance_cells * occupancy_grid.resolution, 0, max_distance) From 1103e3d682c4b0c834a7b4dffe29e4849c48f7ec 
Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 15:15:25 +0800 Subject: [PATCH 108/118] tests cleanup --- dimos/memory2/blobstore/test_blobstore.py | 55 +- dimos/memory2/conftest.py | 96 ++++ dimos/memory2/test_blobstore.py | 185 ------ dimos/memory2/test_blobstore_integration.py | 178 ++++++ dimos/memory2/test_e2e_import.py | 148 ----- dimos/memory2/test_e2e_query.py | 150 ----- dimos/memory2/test_embedding.py | 392 ++++++------- dimos/memory2/test_impl.py | 586 ++++++++------------ dimos/memory2/test_stream.py | 238 ++++---- 9 files changed, 792 insertions(+), 1236 deletions(-) create mode 100644 dimos/memory2/conftest.py delete mode 100644 dimos/memory2/test_blobstore.py create mode 100644 dimos/memory2/test_blobstore_integration.py delete mode 100644 dimos/memory2/test_e2e_import.py delete mode 100644 dimos/memory2/test_e2e_query.py diff --git a/dimos/memory2/blobstore/test_blobstore.py b/dimos/memory2/blobstore/test_blobstore.py index fe05cfa84f..ebe051a17f 100644 --- a/dimos/memory2/blobstore/test_blobstore.py +++ b/dimos/memory2/blobstore/test_blobstore.py @@ -16,69 +16,15 @@ from __future__ import annotations -from dataclasses import dataclass -import sqlite3 from typing import TYPE_CHECKING import pytest -from dimos.memory2.blobstore.file import FileBlobStore -from dimos.memory2.blobstore.sqlite import SqliteBlobStore - if TYPE_CHECKING: - from collections.abc import Callable, Generator - from pathlib import Path - from dimos.memory2.backend import BlobStore -# ── Case definition ──────────────────────────────────────────────── - - -@dataclass -class Case: - name: str - factory: Callable[..., Generator[BlobStore, None, None]] - - -# ── Factories ────────────────────────────────────────────────────── - - -@pytest.fixture() -def file_store(tmp_path: Path) -> Generator[FileBlobStore, None, None]: - store = FileBlobStore(tmp_path / "blobs") - store.start() - yield store - store.stop() - - -@pytest.fixture() -def sqlite_store() -> 
Generator[SqliteBlobStore, None, None]: - conn = sqlite3.connect(":memory:") - store = SqliteBlobStore(conn) - store.start() - yield store - store.stop() - conn.close() - - -@pytest.fixture(params=["file", "sqlite"]) -def blob_store( - request: pytest.FixtureRequest, - file_store: FileBlobStore, - sqlite_store: SqliteBlobStore, -) -> BlobStore: - if request.param == "file": - return file_store - return sqlite_store - - -# ── Tests ────────────────────────────────────────────────────────── - - class TestBlobStore: - """Every BlobStore must satisfy these contracts.""" - def test_put_get_roundtrip(self, blob_store: BlobStore) -> None: data = b"hello world" blob_store.put("stream_a", 1, data) @@ -112,3 +58,4 @@ def test_large_blob(self, blob_store: BlobStore) -> None: data = bytes(range(256)) * 1000 # 256 KB blob_store.put("big", 0, data) assert blob_store.get("big", 0) == data + assert blob_store.get("big", 0) == data diff --git a/dimos/memory2/conftest.py b/dimos/memory2/conftest.py new file mode 100644 index 0000000000..de73c249bc --- /dev/null +++ b/dimos/memory2/conftest.py @@ -0,0 +1,96 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared fixtures for memory2 tests.""" + +from __future__ import annotations + +import sqlite3 +import tempfile +from typing import TYPE_CHECKING + +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.impl.memory import MemoryStore +from dimos.memory2.impl.sqlite import SqliteStore + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + + from dimos.memory2.backend import BlobStore + from dimos.memory2.impl.memory import MemorySession + from dimos.memory2.store import Session + + +# ── Stores ──────────────────────────────────────────────────────── + + +@pytest.fixture +def memory_store() -> Generator[MemoryStore, None, None]: + with MemoryStore() as store: + yield store + + +@pytest.fixture +def memory_session(memory_store: MemoryStore) -> Generator[MemorySession, None, None]: + with memory_store.session() as session: + yield session + + +@pytest.fixture +def sqlite_store() -> Generator[SqliteStore, None, None]: + with tempfile.NamedTemporaryFile(suffix=".db") as f: + store = SqliteStore(path=f.name) + with store: + yield store + + +@pytest.fixture +def sqlite_session(sqlite_store: SqliteStore) -> Generator[Session, None, None]: + with sqlite_store.session() as session: + yield session + + +@pytest.fixture(params=["memory_session", "sqlite_session"]) +def session(request: pytest.FixtureRequest) -> Session: + return request.getfixturevalue(request.param) + + +# ── Blob Stores ─────────────────────────────────────────────────── + + +@pytest.fixture +def file_blob_store(tmp_path: Path) -> Generator[FileBlobStore, None, None]: + store = FileBlobStore(tmp_path / "blobs") + store.start() + yield store + store.stop() + + +@pytest.fixture +def sqlite_blob_store() -> Generator[SqliteBlobStore, None, None]: + conn = sqlite3.connect(":memory:") + store = SqliteBlobStore(conn) + store.start() + yield store + store.stop() + conn.close() + + 
+@pytest.fixture(params=["file_blob_store", "sqlite_blob_store"]) +def blob_store(request: pytest.FixtureRequest) -> BlobStore: + return request.getfixturevalue(request.param) diff --git a/dimos/memory2/test_blobstore.py b/dimos/memory2/test_blobstore.py deleted file mode 100644 index b8e8668ff8..0000000000 --- a/dimos/memory2/test_blobstore.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for BlobStore integration with ListBackend.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np - -from dimos.memory2.blobstore.file import FileBlobStore -from dimos.memory2.impl.memory import MemoryStore -from dimos.memory2.type import _UNLOADED -from dimos.models.embedding.base import Embedding - -if TYPE_CHECKING: - from pathlib import Path - -# ── Helpers ─────────────────────────────────────────────────────── - - -def _emb(vec: list[float]) -> Embedding: - v = np.array(vec, dtype=np.float32) - v /= np.linalg.norm(v) + 1e-10 - return Embedding(vector=v) - - -# ── Tests ───────────────────────────────────────────────────────── - - -class TestBlobStoreIntegration: - def test_append_stores_in_blobstore(self, tmp_path: Path) -> None: - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - store = MemoryStore() - with store.session(blob_store=bs) as session: - s = session.stream("data", bytes) - s.append(b"hello", ts=1.0) - - # Blob was written to the file 
store - raw = bs.get("data", 0) - assert len(raw) > 0 - - def test_lazy_data_not_loaded_until_access(self, tmp_path: Path) -> None: - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - store = MemoryStore() - with store.session(blob_store=bs) as session: - s = session.stream("data", str) - obs = s.append("payload", ts=1.0) - - # Data replaced with sentinel after append - assert isinstance(obs._data, type(_UNLOADED)) - assert obs._loader is not None - - def test_lazy_data_loads_correctly(self, tmp_path: Path) -> None: - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - store = MemoryStore() - with store.session(blob_store=bs) as session: - s = session.stream("data", str) - s.append("payload", ts=1.0) - - result = s.first() - assert result.data == "payload" - - def test_eager_preloads_data(self, tmp_path: Path) -> None: - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - store = MemoryStore() - with store.session(blob_store=bs, eager_blobs=True) as session: - s = session.stream("data", str) - s.append("payload", ts=1.0) - - # Iterating with eager_blobs triggers load - results = s.fetch() - assert len(results) == 1 - # Data should be loaded (not _UNLOADED) - assert not isinstance(results[0]._data, type(_UNLOADED)) - assert results[0].data == "payload" - - def test_per_stream_eager_override(self, tmp_path: Path) -> None: - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - store = MemoryStore() - with store.session(blob_store=bs) as session: - # Default: lazy - lazy_stream = session.stream("lazy", str) - lazy_stream.append("lazy-val", ts=1.0) - - # Override: eager - eager_stream = session.stream("eager", str, eager_blobs=True) - eager_stream.append("eager-val", ts=1.0) - - lazy_results = lazy_stream.fetch() - eager_results = eager_stream.fetch() - - # Lazy: data stays unloaded until accessed - assert lazy_results[0].data == "lazy-val" - - # Eager: data pre-loaded during iteration - assert not isinstance(eager_results[0]._data, type(_UNLOADED)) - assert 
eager_results[0].data == "eager-val" - - def test_no_blobstore_unchanged(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("data", str) - obs = s.append("inline", ts=1.0) - - # Without blob store, data stays inline - assert obs._data == "inline" - assert obs._loader is None - assert obs.data == "inline" - - def test_blobstore_with_vector_search(self, tmp_path: Path) -> None: - from dimos.memory2.vectorstore import MemoryVectorStore - - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - vs = MemoryVectorStore() - store = MemoryStore() - with store.session(blob_store=bs, vector_store=vs) as session: - s = session.stream("vecs", str) - s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) - s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) - s.append("south", ts=3.0, embedding=_emb([0, -1, 0])) - - # Vector search triggers lazy load via obs.derive(data=obs.data, ...) - results = s.search(_emb([0, 1, 0]), k=2).fetch() - assert len(results) == 2 - assert results[0].data == "north" - assert results[0].similarity > 0.99 - - def test_blobstore_with_text_search(self, tmp_path: Path) -> None: - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - store = MemoryStore() - with store.session(blob_store=bs) as session: - s = session.stream("logs", str) - s.append("motor fault", ts=1.0) - s.append("temperature ok", ts=2.0) - - # Text search triggers lazy load via str(obs.data) - results = s.search_text("motor").fetch() - assert len(results) == 1 - assert results[0].data == "motor fault" - - def test_multiple_appends_get_unique_blobs(self, tmp_path: Path) -> None: - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - store = MemoryStore() - with store.session(blob_store=bs) as session: - s = session.stream("multi", str) - s.append("first", ts=1.0) - s.append("second", ts=2.0) - s.append("third", ts=3.0) - - results = s.fetch() - assert [r.data for r in results] == ["first", "second", "third"] - - def 
test_fetch_preserves_metadata(self, tmp_path: Path) -> None: - bs = FileBlobStore(tmp_path / "blobs") - bs.start() - store = MemoryStore() - with store.session(blob_store=bs) as session: - s = session.stream("meta", str) - s.append("val", ts=42.0, tags={"kind": "info"}) - - result = s.first() - assert result.ts == 42.0 - assert result.tags == {"kind": "info"} - assert result.data == "val" diff --git a/dimos/memory2/test_blobstore_integration.py b/dimos/memory2/test_blobstore_integration.py new file mode 100644 index 0000000000..c961d1fe31 --- /dev/null +++ b/dimos/memory2/test_blobstore_integration.py @@ -0,0 +1,178 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for BlobStore integration with ListBackend.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.impl.memory import MemoryStore +from dimos.memory2.type import _UNLOADED +from dimos.models.embedding.base import Embedding + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + + from dimos.memory2.impl.memory import MemorySession + +# ── Helpers ─────────────────────────────────────────────────────── + + +def _emb(vec: list[float]) -> Embedding: + v = np.array(vec, dtype=np.float32) + v /= np.linalg.norm(v) + 1e-10 + return Embedding(vector=v) + + +# ── Fixtures ───────────────────────────────────────────────────── + + +@pytest.fixture +def bs(tmp_path: Path) -> Generator[FileBlobStore, None, None]: + blob_store = FileBlobStore(tmp_path / "blobs") + blob_store.start() + yield blob_store + blob_store.stop() + + +@pytest.fixture +def store() -> Generator[MemoryStore, None, None]: + s = MemoryStore() + yield s + s.stop() + + +@pytest.fixture +def session(store: MemoryStore, bs: FileBlobStore) -> Generator[MemorySession, None, None]: + with store.session(blob_store=bs) as sess: + yield sess + + +# ── Tests ───────────────────────────────────────────────────────── + + +class TestBlobStoreIntegration: + def test_append_stores_in_blobstore(self, bs: FileBlobStore, session: MemorySession) -> None: + s = session.stream("data", bytes) + s.append(b"hello", ts=1.0) + + # Blob was written to the file store + raw = bs.get("data", 0) + assert len(raw) > 0 + + def test_lazy_data_not_loaded_until_access(self, session: MemorySession) -> None: + s = session.stream("data", str) + obs = s.append("payload", ts=1.0) + + # Data replaced with sentinel after append + assert isinstance(obs._data, type(_UNLOADED)) + assert obs._loader is not None + + def test_lazy_data_loads_correctly(self, session: 
MemorySession) -> None: + s = session.stream("data", str) + s.append("payload", ts=1.0) + + result = s.first() + assert result.data == "payload" + + def test_eager_preloads_data(self, store: MemoryStore, bs: FileBlobStore) -> None: + with store.session(blob_store=bs, eager_blobs=True) as session: + s = session.stream("data", str) + s.append("payload", ts=1.0) + + # Iterating with eager_blobs triggers load + results = s.fetch() + assert len(results) == 1 + # Data should be loaded (not _UNLOADED) + assert not isinstance(results[0]._data, type(_UNLOADED)) + assert results[0].data == "payload" + + def test_per_stream_eager_override(self, session: MemorySession) -> None: + # Default: lazy + lazy_stream = session.stream("lazy", str) + lazy_stream.append("lazy-val", ts=1.0) + + # Override: eager + eager_stream = session.stream("eager", str, eager_blobs=True) + eager_stream.append("eager-val", ts=1.0) + + lazy_results = lazy_stream.fetch() + eager_results = eager_stream.fetch() + + # Lazy: data stays unloaded until accessed + assert lazy_results[0].data == "lazy-val" + + # Eager: data pre-loaded during iteration + assert not isinstance(eager_results[0]._data, type(_UNLOADED)) + assert eager_results[0].data == "eager-val" + + def test_no_blobstore_unchanged(self, store: MemoryStore) -> None: + with store.session() as session: + s = session.stream("data", str) + obs = s.append("inline", ts=1.0) + + # Without blob store, data stays inline + assert obs._data == "inline" + assert obs._loader is None + assert obs.data == "inline" + + def test_blobstore_with_vector_search(self, store: MemoryStore, bs: FileBlobStore) -> None: + from dimos.memory2.vectorstore import MemoryVectorStore + + vs = MemoryVectorStore() + with store.session(blob_store=bs, vector_store=vs) as session: + s = session.stream("vecs", str) + s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) + s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) + s.append("south", ts=3.0, embedding=_emb([0, -1, 0])) + + # 
Vector search triggers lazy load via obs.derive(data=obs.data, ...) + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity > 0.99 + + def test_blobstore_with_text_search(self, session: MemorySession) -> None: + s = session.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("temperature ok", ts=2.0) + + # Text search triggers lazy load via str(obs.data) + results = s.search_text("motor").fetch() + assert len(results) == 1 + assert results[0].data == "motor fault" + + def test_multiple_appends_get_unique_blobs(self, session: MemorySession) -> None: + s = session.stream("multi", str) + s.append("first", ts=1.0) + s.append("second", ts=2.0) + s.append("third", ts=3.0) + + results = s.fetch() + assert [r.data for r in results] == ["first", "second", "third"] + + def test_fetch_preserves_metadata(self, session: MemorySession) -> None: + s = session.stream("meta", str) + s.append("val", ts=42.0, tags={"kind": "info"}) + + result = s.first() + assert result.ts == 42.0 + assert result.tags == {"kind": "info"} + assert result.data == "val" diff --git a/dimos/memory2/test_e2e_import.py b/dimos/memory2/test_e2e_import.py deleted file mode 100644 index 0fd44a3329..0000000000 --- a/dimos/memory2/test_e2e_import.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""E2E test: import legacy pickle replays into memory SqliteStore.""" - -from __future__ import annotations - -import bisect -from typing import TYPE_CHECKING, Any - -import pytest - -from dimos.memory2.impl.sqlite import SqliteStore -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.utils.data import get_data_dir -from dimos.utils.testing import TimedSensorReplay - -if TYPE_CHECKING: - from collections.abc import Generator - - from dimos.memory2.impl.sqlite import SqliteSession - -DB_PATH = get_data_dir("go2_bigoffice_v2.db") - - -class PoseIndex: - """Preloaded odom data with O(log n) closest-timestamp lookup.""" - - def __init__(self, replay: TimedSensorReplay) -> None: # type: ignore[type-arg] - self._timestamps: list[float] = [] - self._data: list[Any] = [] - for ts, data in replay.iterate_ts(): - self._timestamps.append(ts) - self._data.append(data) - - def find_closest(self, ts: float) -> Any | None: - if not self._timestamps: - return None - idx = bisect.bisect_left(self._timestamps, ts) - # Compare the two candidates around the insertion point - if idx == 0: - return self._data[0] - if idx >= len(self._timestamps): - return self._data[-1] - if ts - self._timestamps[idx - 1] <= self._timestamps[idx] - ts: - return self._data[idx - 1] - return self._data[idx] - - -@pytest.fixture(scope="module") -def store() -> Generator[SqliteStore, None, None]: - s = SqliteStore(path=str(DB_PATH)) - yield s - - -@pytest.fixture(scope="module") -def session(store: SqliteStore) -> Generator[SqliteSession, None, None]: - with store.session() as session: - yield session - - -@pytest.fixture(scope="module") -def video_replay() -> TimedSensorReplay: # type: ignore[type-arg] - return TimedSensorReplay("unitree_go2_bigoffice/video") - - -@pytest.fixture(scope="module") -def odom_index() -> PoseIndex: - return PoseIndex(TimedSensorReplay("unitree_go2_bigoffice/odom")) - - -@pytest.fixture(scope="module") -def 
lidar_replay() -> TimedSensorReplay: # type: ignore[type-arg] - return TimedSensorReplay("unitree_go2_bigoffice/lidar") - - -@pytest.mark.tool -class TestImportReplay: - """Import legacy pickle replay data into a memory SqliteStore.""" - - def test_import_video( - self, - session: SqliteSession, - video_replay: TimedSensorReplay, # type: ignore[type-arg] - odom_index: PoseIndex, - ) -> None: - video = session.stream("color_image", Image) - - count = 0 - for ts, frame in video_replay.iterate_ts(): - pose = odom_index.find_closest(ts) - print(frame) - video.append(frame, ts=ts, pose=pose) - count += 1 - - assert count > 0 - assert video.count() == count - print(f"Imported {count} video frames") - - def test_import_lidar( - self, - session: SqliteSession, - lidar_replay: TimedSensorReplay, # type: ignore[type-arg] - odom_index: PoseIndex, - ) -> None: - lidar = session.stream("lidar", PointCloud2) - - count = 0 - for ts, frame in lidar_replay.iterate_ts(): - pose = odom_index.find_closest(ts) - print(frame) - lidar.append(frame, ts=ts, pose=pose) - count += 1 - - assert count > 0 - assert lidar.count() == count - print(f"Imported {count} lidar frames") - - def test_query_imported_data(self, session: SqliteSession) -> None: - video = session.stream("color_image", Image) - lidar = session.stream("lidar", PointCloud2) - - assert video.exists() - assert lidar.exists() - - first_frame = video.first() - last_frame = video.last() - assert first_frame.ts < last_frame.ts - - mid_ts = (first_frame.ts + last_frame.ts) / 2 - subset = video.time_range(first_frame.ts, mid_ts).fetch() - assert 0 < len(subset) < video.count() - - streams = session.list_streams() - assert "color_image" in streams - assert "lidar" in streams diff --git a/dimos/memory2/test_e2e_query.py b/dimos/memory2/test_e2e_query.py deleted file mode 100644 index ac26e865ff..0000000000 --- a/dimos/memory2/test_e2e_query.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2026 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""E2E query tests against pre-built go2_bigoffice_v2.db. - -Read-only — no writes, just verifies query paths against real robot data. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from dimos.memory2.impl.sqlite import SqliteStore -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.utils.data import get_data - -if TYPE_CHECKING: - from collections.abc import Generator - - from dimos.memory2.impl.sqlite import SqliteSession - - -@pytest.fixture(scope="module") -def session() -> Generator[SqliteSession, None, None]: - db_path = get_data("go2_bigoffice_v2.db") - store = SqliteStore(path=str(db_path)) - with store.session() as s: - yield s - - -@pytest.mark.tool -class TestE2EQuery: - """Query operations against real robot replay data.""" - - def test_list_streams(self, session: SqliteSession) -> None: - streams = session.list_streams() - print(streams) - - assert "color_image" in streams - assert "lidar" in streams - - def test_video_count(self, session: SqliteSession) -> None: - video = session.stream("color_image", Image) - assert video.count() > 1000 - - def test_lidar_count(self, session: SqliteSession) -> None: - lidar = session.stream("lidar", PointCloud2) - assert lidar.count() > 1000 - - def test_first_last_timestamps(self, session: SqliteSession) -> None: - video = session.stream("color_image", 
Image) - first = video.first() - last = video.last() - assert first.ts < last.ts - duration = last.ts - first.ts - assert duration > 10.0 # at least 10s of data - - def test_time_range_filter(self, session: SqliteSession) -> None: - video = session.stream("color_image", Image) - first = video.first() - - # Grab first 5 seconds - window = video.time_range(first.ts, first.ts + 5.0).fetch() - assert len(window) > 0 - assert len(window) < video.count() - assert all(first.ts <= obs.ts <= first.ts + 5.0 for obs in window) - - def test_limit_offset_pagination(self, session: SqliteSession) -> None: - video = session.stream("color_image", Image) - page1 = video.limit(10).fetch() - page2 = video.offset(10).limit(10).fetch() - - assert len(page1) == 10 - assert len(page2) == 10 - assert page1[-1].ts < page2[0].ts # no overlap - - def test_order_by_desc(self, session: SqliteSession) -> None: - video = session.stream("color_image", Image) - last_10 = video.order_by("ts", desc=True).limit(10).fetch() - - assert len(last_10) == 10 - assert all(last_10[i].ts >= last_10[i + 1].ts for i in range(9)) - - def test_lazy_data_loads_correctly(self, session: SqliteSession) -> None: - """Verify lazy blob loading returns valid Image data.""" - from dimos.memory2.type import _Unloaded - - video = session.stream("color_image", Image) - obs = next(iter(video.limit(1))) - - # Should start lazy - assert isinstance(obs._data, _Unloaded) - - # Trigger load - frame = obs.data - assert isinstance(frame, Image) - assert frame.width > 0 - assert frame.height > 0 - - def test_iterate_window_decodes_all(self, session: SqliteSession) -> None: - """Iterate a time window and verify every frame decodes.""" - video = session.stream("color_image", Image) - first_ts = video.first().ts - - window = video.time_range(first_ts, first_ts + 2.0) - count = 0 - for obs in window: - frame = obs.data - assert isinstance(frame, Image) - count += 1 - assert count > 0 - - def test_lidar_data_loads(self, session: 
SqliteSession) -> None: - """Verify lidar blobs decode to PointCloud2.""" - lidar = session.stream("lidar", PointCloud2) - frame = lidar.first().data - assert isinstance(frame, PointCloud2) - - def test_poses_present(self, session: SqliteSession) -> None: - """Verify poses were stored during import.""" - video = session.stream("color_image", Image) - obs = video.first() - assert obs.pose is not None - - def test_cross_stream_time_alignment(self, session: SqliteSession) -> None: - """Video and lidar should overlap in time.""" - video = session.stream("color_image", Image) - lidar = session.stream("lidar", PointCloud2) - - v_first, v_last = video.first().ts, video.last().ts - l_first, l_last = lidar.first().ts, lidar.last().ts - - # Overlap: max of starts < min of ends - overlap_start = max(v_first, l_first) - overlap_end = min(v_last, l_last) - assert overlap_start < overlap_end, "Video and lidar should overlap in time" diff --git a/dimos/memory2/test_embedding.py b/dimos/memory2/test_embedding.py index f1d22addf2..d2b37bf210 100644 --- a/dimos/memory2/test_embedding.py +++ b/dimos/memory2/test_embedding.py @@ -19,7 +19,6 @@ import numpy as np import pytest -from dimos.memory2.impl.memory import MemoryStore from dimos.memory2.type import EmbeddedObservation, Observation from dimos.models.embedding.base import Embedding @@ -85,177 +84,149 @@ def test_observation_derive_without_embedding_stays_observation(self) -> None: class TestListBackendEmbedding: - def test_append_with_embedding(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - emb = _emb([1, 0, 0]) - obs = s.append("hello", embedding=emb) - assert isinstance(obs, EmbeddedObservation) - assert obs.embedding is emb - - def test_append_without_embedding(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("plain", str) - obs = s.append("hello") - assert type(obs) is Observation - - def 
test_search_returns_top_k(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - s.append("north", embedding=_emb([0, 1, 0])) - s.append("east", embedding=_emb([1, 0, 0])) - s.append("south", embedding=_emb([0, -1, 0])) - s.append("west", embedding=_emb([-1, 0, 0])) - - results = s.search(_emb([0, 1, 0]), k=2).fetch() - assert len(results) == 2 - assert results[0].data == "north" - assert results[0].similarity is not None - assert results[0].similarity > 0.99 - - def test_search_sorted_by_similarity(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - s.append("far", embedding=_emb([0, -1, 0])) - s.append("close", embedding=_emb([0.9, 0.1, 0])) - s.append("exact", embedding=_emb([1, 0, 0])) - - results = s.search(_emb([1, 0, 0]), k=3).fetch() - assert results[0].data == "exact" - assert results[1].data == "close" - assert results[2].data == "far" - # Descending similarity - assert results[0].similarity >= results[1].similarity >= results[2].similarity - - def test_search_skips_non_embedded(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("mixed", str) - s.append("plain") # no embedding - s.append("embedded", embedding=_emb([1, 0, 0])) - - results = s.search(_emb([1, 0, 0]), k=10).fetch() - assert len(results) == 1 - assert results[0].data == "embedded" - - def test_search_with_filters(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) - s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) - - # Only the late one should pass the after filter - results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() - assert len(results) == 1 - assert results[0].data == "late" - - def test_search_with_limit(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", 
str) - for i in range(10): - s.append(f"item{i}", embedding=_emb([1, 0, 0])) - - # search k=5 then limit 2 - results = s.search(_emb([1, 0, 0]), k=5).limit(2).fetch() - assert len(results) == 2 + def test_append_with_embedding(self, memory_session) -> None: + s = memory_session.stream("vecs", str) + emb = _emb([1, 0, 0]) + obs = s.append("hello", embedding=emb) + assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb - def test_search_with_live_raises(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("vecs", str) - s.append("x", embedding=_emb([1, 0, 0])) - with pytest.raises(TypeError, match="Cannot combine"): - list(s.live().search(_emb([1, 0, 0]), k=5)) + def test_append_without_embedding(self, memory_session) -> None: + s = memory_session.stream("plain", str) + obs = s.append("hello") + assert type(obs) is Observation + + def test_search_returns_top_k(self, memory_session) -> None: + s = memory_session.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_sorted_by_similarity(self, memory_session) -> None: + s = memory_session.stream("vecs", str) + s.append("far", embedding=_emb([0, -1, 0])) + s.append("close", embedding=_emb([0.9, 0.1, 0])) + s.append("exact", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=3).fetch() + assert results[0].data == "exact" + assert results[1].data == "close" + assert results[2].data == "far" + # Descending similarity + assert results[0].similarity >= results[1].similarity >= results[2].similarity + + def test_search_skips_non_embedded(self, memory_session) -> None: + s = 
memory_session.stream("mixed", str) + s.append("plain") # no embedding + s.append("embedded", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "embedded" + + def test_search_with_filters(self, memory_session) -> None: + s = memory_session.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) + + # Only the late one should pass the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_search_with_limit(self, memory_session) -> None: + s = memory_session.stream("vecs", str) + for i in range(10): + s.append(f"item{i}", embedding=_emb([1, 0, 0])) + + # search k=5 then limit 2 + results = s.search(_emb([1, 0, 0]), k=5).limit(2).fetch() + assert len(results) == 2 + + def test_search_with_live_raises(self, memory_session) -> None: + s = memory_session.stream("vecs", str) + s.append("x", embedding=_emb([1, 0, 0])) + with pytest.raises(TypeError, match="Cannot combine"): + list(s.live().search(_emb([1, 0, 0]), k=5)) # ── Text search ────────────────────────────────────────────────── class TestTextSearch: - def test_search_text_substring(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("motor fault detected") - s.append("temperature normal") - s.append("motor overheating") - - results = s.search_text("motor").fetch() - assert len(results) == 2 - assert {r.data for r in results} == {"motor fault detected", "motor overheating"} + def test_search_text_substring(self, memory_session) -> None: + s = memory_session.stream("logs", str) + s.append("motor fault detected") + s.append("temperature normal") + s.append("motor overheating") - def test_search_text_case_insensitive(self) -> None: - store = MemoryStore() - with store.session() as session: - s = 
session.stream("logs", str) - s.append("Motor Fault") - s.append("other event") + results = s.search_text("motor").fetch() + assert len(results) == 2 + assert {r.data for r in results} == {"motor fault detected", "motor overheating"} - results = s.search_text("motor fault").fetch() - assert len(results) == 1 + def test_search_text_case_insensitive(self, memory_session) -> None: + s = memory_session.stream("logs", str) + s.append("Motor Fault") + s.append("other event") - def test_search_text_with_filters(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("motor fault", ts=10.0) - s.append("motor warning", ts=20.0) - s.append("motor fault", ts=30.0) + results = s.search_text("motor fault").fetch() + assert len(results) == 1 - results = s.after(15.0).search_text("fault").fetch() - assert len(results) == 1 - assert results[0].ts == 30.0 + def test_search_text_with_filters(self, memory_session) -> None: + s = memory_session.stream("logs", str) + s.append("motor fault", ts=10.0) + s.append("motor warning", ts=20.0) + s.append("motor fault", ts=30.0) + + results = s.after(15.0).search_text("fault").fetch() + assert len(results) == 1 + assert results[0].ts == 30.0 - def test_search_text_no_match(self) -> None: - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("all clear") + def test_search_text_no_match(self, memory_session) -> None: + s = memory_session.stream("logs", str) + s.append("all clear") - results = s.search_text("motor").fetch() - assert len(results) == 0 + results = s.search_text("motor").fetch() + assert len(results) == 0 # ── Save preserves embeddings ──────────────────────────────────── class TestSaveEmbeddings: - def test_save_preserves_embeddings(self) -> None: - store = MemoryStore() - with store.session() as session: - src = session.stream("source", str) - dst = session.stream("dest", str) + def test_save_preserves_embeddings(self, 
memory_session) -> None: + src = memory_session.stream("source", str) + dst = memory_session.stream("dest", str) - emb = _emb([1, 0, 0]) - src.append("item", embedding=emb) - src.save(dst) + emb = _emb([1, 0, 0]) + src.append("item", embedding=emb) + src.save(dst) - results = dst.fetch() - assert len(results) == 1 - assert isinstance(results[0], EmbeddedObservation) - # Same vector content (different Embedding instance after re-append) - np.testing.assert_array_almost_equal(results[0].embedding.to_numpy(), emb.to_numpy()) + results = dst.fetch() + assert len(results) == 1 + assert isinstance(results[0], EmbeddedObservation) + # Same vector content (different Embedding instance after re-append) + np.testing.assert_array_almost_equal(results[0].embedding.to_numpy(), emb.to_numpy()) - def test_save_mixed_embedded_and_plain(self) -> None: - store = MemoryStore() - with store.session() as session: - src = session.stream("source", str) - dst = session.stream("dest", str) + def test_save_mixed_embedded_and_plain(self, memory_session) -> None: + src = memory_session.stream("source", str) + dst = memory_session.stream("dest", str) - src.append("plain") - src.append("embedded", embedding=_emb([0, 1, 0])) - src.save(dst) + src.append("plain") + src.append("embedded", embedding=_emb([0, 1, 0])) + src.save(dst) - results = dst.fetch() - assert len(results) == 2 - assert type(results[0]) is Observation - assert isinstance(results[1], EmbeddedObservation) + results = dst.fetch() + assert len(results) == 2 + assert type(results[0]) is Observation + assert isinstance(results[1], EmbeddedObservation) # ── Embed transformers (mock model) ───────────────────────────── @@ -286,71 +257,63 @@ def embed_text(self, *texts): class TestEmbedTransformers: - def test_embed_images_produces_embedded_observations(self) -> None: + def test_embed_images_produces_embedded_observations(self, memory_session) -> None: from dimos.memory2.embed import EmbedImages model = _MockEmbeddingModel() - store = 
MemoryStore() - with store.session() as session: - s = session.stream("imgs", str) - s.append("img1", ts=1.0) - s.append("img2", ts=2.0) + s = memory_session.stream("imgs", str) + s.append("img1", ts=1.0) + s.append("img2", ts=2.0) - results = s.transform(EmbedImages(model)).fetch() - assert len(results) == 2 - for obs in results: - assert isinstance(obs, EmbeddedObservation) - assert isinstance(obs.embedding, Embedding) - assert obs.embedding.to_numpy().shape == (8,) + results = s.transform(EmbedImages(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + assert obs.embedding.to_numpy().shape == (8,) - def test_embed_text_produces_embedded_observations(self) -> None: + def test_embed_text_produces_embedded_observations(self, memory_session) -> None: from dimos.memory2.embed import EmbedText model = _MockEmbeddingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("motor fault", ts=1.0) - s.append("all clear", ts=2.0) + s = memory_session.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("all clear", ts=2.0) - results = s.transform(EmbedText(model)).fetch() - assert len(results) == 2 - for obs in results: - assert isinstance(obs, EmbeddedObservation) - assert isinstance(obs.embedding, Embedding) + results = s.transform(EmbedText(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) - def test_embed_preserves_data(self) -> None: + def test_embed_preserves_data(self, memory_session) -> None: from dimos.memory2.embed import EmbedText model = _MockEmbeddingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - s.append("hello", ts=1.0) + s = memory_session.stream("logs", str) + s.append("hello", ts=1.0) - result = 
s.transform(EmbedText(model)).first() - assert result.data == "hello" + result = s.transform(EmbedText(model)).first() + assert result.data == "hello" - def test_embed_then_search(self) -> None: + def test_embed_then_search(self, memory_session) -> None: from dimos.memory2.embed import EmbedText model = _MockEmbeddingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - for i in range(10): - s.append(f"log entry {i}", ts=float(i)) - - embedded = s.transform(EmbedText(model)) - # Get the embedding for the first item, then search for similar - first_emb = embedded.first().embedding - results = embedded.search(first_emb, k=3).fetch() - assert len(results) == 3 - # First result should be the exact match - assert results[0].similarity is not None - assert results[0].similarity > 0.99 - - def test_embed_batching(self) -> None: + s = memory_session.stream("logs", str) + for i in range(10): + s.append(f"log entry {i}", ts=float(i)) + + embedded = s.transform(EmbedText(model)) + # Get the embedding for the first item, then search for similar + first_emb = embedded.first().embedding + results = embedded.search(first_emb, k=3).fetch() + assert len(results) == 3 + # First result should be the exact match + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_embed_batching(self, memory_session) -> None: from dimos.memory2.embed import EmbedText call_sizes: list[int] = [] @@ -361,15 +324,13 @@ def embed_text(self, *texts): return super().embed_text(*texts) model = _TrackingModel() - store = MemoryStore() - with store.session() as session: - s = session.stream("logs", str) - for i in range(5): - s.append(f"entry {i}") + s = memory_session.stream("logs", str) + for i in range(5): + s.append(f"entry {i}") - list(s.transform(EmbedText(model, batch_size=2))) - # 5 items with batch_size=2 → 3 calls (2, 2, 1) - assert call_sizes == [2, 2, 1] + list(s.transform(EmbedText(model, batch_size=2))) + # 5 
items with batch_size=2 → 3 calls (2, 2, 1) + assert call_sizes == [2, 2, 1] # ── Pluggable VectorStore ──────────────────────────────────────── @@ -378,35 +339,32 @@ def embed_text(self, *texts): class TestPluggableVectorStore: """Verify that injecting a VectorStore via session config actually delegates search.""" - def test_append_stores_in_vector_store(self) -> None: + def test_append_stores_in_vector_store(self, memory_store) -> None: from dimos.memory2.vectorstore import MemoryVectorStore vs = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs) as session: + with memory_store.session(vector_store=vs) as session: s = session.stream("vecs", str) s.append("hello", embedding=_emb([1, 0, 0])) s.append("world", embedding=_emb([0, 1, 0])) assert len(vs._vectors["vecs"]) == 2 - def test_append_without_embedding_skips_vector_store(self) -> None: + def test_append_without_embedding_skips_vector_store(self, memory_store) -> None: from dimos.memory2.vectorstore import MemoryVectorStore vs = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs) as session: + with memory_store.session(vector_store=vs) as session: s = session.stream("plain", str) s.append("no embedding") assert "plain" not in vs._vectors - def test_search_uses_vector_store(self) -> None: + def test_search_uses_vector_store(self, memory_store) -> None: from dimos.memory2.vectorstore import MemoryVectorStore vs = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs) as session: + with memory_store.session(vector_store=vs) as session: s = session.stream("vecs", str) s.append("north", embedding=_emb([0, 1, 0])) s.append("east", embedding=_emb([1, 0, 0])) @@ -419,12 +377,11 @@ def test_search_uses_vector_store(self) -> None: assert results[0].similarity is not None assert results[0].similarity > 0.99 - def test_search_with_filters_via_vector_store(self) -> None: + def test_search_with_filters_via_vector_store(self, memory_store) 
-> None: from dimos.memory2.vectorstore import MemoryVectorStore vs = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs) as session: + with memory_store.session(vector_store=vs) as session: s = session.stream("vecs", str) s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) @@ -434,13 +391,12 @@ def test_search_with_filters_via_vector_store(self) -> None: assert len(results) == 1 assert results[0].data == "late" - def test_per_stream_vector_store_override(self) -> None: + def test_per_stream_vector_store_override(self, memory_store) -> None: from dimos.memory2.vectorstore import MemoryVectorStore vs_default = MemoryVectorStore() vs_override = MemoryVectorStore() - store = MemoryStore() - with store.session(vector_store=vs_default) as session: + with memory_store.session(vector_store=vs_default) as session: # Stream with default vector store s1 = session.stream("s1", str) s1.append("a", embedding=_emb([1, 0, 0])) diff --git a/dimos/memory2/test_impl.py b/dimos/memory2/test_impl.py index bcbd06ffd1..faf5dc6258 100644 --- a/dimos/memory2/test_impl.py +++ b/dimos/memory2/test_impl.py @@ -15,226 +15,159 @@ """Grid tests for Store implementations. Runs the same test logic against every Store backend (MemoryStore, SqliteStore, …). +The parametrized ``session`` fixture from conftest runs each test against both backends. 
""" from __future__ import annotations -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest if TYPE_CHECKING: - from collections.abc import Callable, Generator - from dimos.memory2.store import Session -# ── Case definition ──────────────────────────────────────────────── - - -@dataclass -class Case: - name: str - session_factory: Callable[[], Generator[Session, None, None]] - tags: set[str] = field(default_factory=set) - - -# ── Context managers ─────────────────────────────────────────────── - - -@contextmanager -def memory_session() -> Generator[Session, None, None]: - from dimos.memory2.impl.memory import MemoryStore - - store = MemoryStore() - with store.session() as session: - yield session - - -@contextmanager -def sqlite_session() -> Generator[Session, None, None]: - import tempfile - - from dimos.memory2.impl.sqlite import SqliteStore - - with tempfile.NamedTemporaryFile(suffix=".db") as f: - store = SqliteStore(path=f.name) - with store.session() as session: - yield session +# ── Tests ───────────────────────────────────────────────────────── -# ── Test cases ───────────────────────────────────────────────────── -testcases = [ - Case(name="memory", session_factory=memory_session, tags={"basic", "live"}), - Case( - name="sqlite", - session_factory=sqlite_session, - tags={"basic"}, - ), -] - -basic_cases = [c for c in testcases if "basic" in c.tags] - - -# ── Tests ────────────────────────────────────────────────────────── - - -@pytest.mark.parametrize("case", basic_cases, ids=lambda c: c.name) class TestStoreBasic: """Core store operations that every backend must support.""" - def test_create_stream_and_append(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("images", bytes) - obs = s.append(b"frame1", tags={"camera": "front"}) - - assert obs.data == b"frame1" - assert obs.tags["camera"] == "front" - 
assert obs.ts > 0 - - def test_append_multiple_and_fetch(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("sensor", float) - s.append(1.0, ts=100.0) - s.append(2.0, ts=200.0) - s.append(3.0, ts=300.0) - - results = s.fetch() - assert len(results) == 3 - assert [o.data for o in results] == [1.0, 2.0, 3.0] - - def test_iterate_stream(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("log", str) - s.append("a", ts=1.0) - s.append("b", ts=2.0) - - collected = [obs.data for obs in s] - assert collected == ["a", "b"] - - def test_count(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("events", str) - assert s.count() == 0 - s.append("x") - s.append("y") - assert s.count() == 2 - - def test_first_and_last(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("data", int) - s.append(10, ts=1.0) - s.append(20, ts=2.0) - s.append(30, ts=3.0) - - assert s.first().data == 10 - assert s.last().data == 30 - - def test_first_empty_raises(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("empty", int) - with pytest.raises(LookupError): - s.first() - - def test_exists(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("check", str) - assert not s.exists() - s.append("hi") - assert s.exists() - - def test_filter_after(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("ts_data", int) - s.append(1, ts=10.0) - s.append(2, ts=20.0) - s.append(3, ts=30.0) - - results = s.after(15.0).fetch() - assert [o.data for o in results] == [2, 3] - - def test_filter_before(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("ts_data", int) - s.append(1, ts=10.0) - s.append(2, ts=20.0) - s.append(3, ts=30.0) - - results = s.before(25.0).fetch() - assert [o.data for o in results] == [1, 2] - 
- def test_filter_time_range(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("ts_data", int) - s.append(1, ts=10.0) - s.append(2, ts=20.0) - s.append(3, ts=30.0) - - results = s.time_range(15.0, 25.0).fetch() - assert [o.data for o in results] == [2] - - def test_filter_tags(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("tagged", str) - s.append("a", tags={"kind": "info"}) - s.append("b", tags={"kind": "error"}) - s.append("c", tags={"kind": "info"}) - - results = s.tags(kind="info").fetch() - assert [o.data for o in results] == ["a", "c"] - - def test_limit_and_offset(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("paged", int) - for i in range(5): - s.append(i, ts=float(i)) - - page = s.offset(1).limit(2).fetch() - assert [o.data for o in page] == [1, 2] - - def test_order_by_desc(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("ordered", int) - s.append(1, ts=10.0) - s.append(2, ts=20.0) - s.append(3, ts=30.0) - - results = s.order_by("ts", desc=True).fetch() - assert [o.data for o in results] == [3, 2, 1] - - def test_separate_streams_isolated(self, case: Case) -> None: - with case.session_factory() as session: - a = session.stream("stream_a", str) - b = session.stream("stream_b", str) - - a.append("in_a") - b.append("in_b") - - assert [o.data for o in a] == ["in_a"] - assert [o.data for o in b] == ["in_b"] - - def test_same_stream_on_repeated_calls(self, case: Case) -> None: - with case.session_factory() as session: - s1 = session.stream("reuse", str) - s2 = session.stream("reuse", str) - assert s1 is s2 - - def test_append_with_embedding(self, case: Case) -> None: + def test_create_stream_and_append(self, session: Session) -> None: + s = session.stream("images", bytes) + obs = s.append(b"frame1", tags={"camera": "front"}) + + assert obs.data == b"frame1" + assert obs.tags["camera"] == 
"front" + assert obs.ts > 0 + + def test_append_multiple_and_fetch(self, session: Session) -> None: + s = session.stream("sensor", float) + s.append(1.0, ts=100.0) + s.append(2.0, ts=200.0) + s.append(3.0, ts=300.0) + + results = s.fetch() + assert len(results) == 3 + assert [o.data for o in results] == [1.0, 2.0, 3.0] + + def test_iterate_stream(self, session: Session) -> None: + s = session.stream("log", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) + + collected = [obs.data for obs in s] + assert collected == ["a", "b"] + + def test_count(self, session: Session) -> None: + s = session.stream("events", str) + assert s.count() == 0 + s.append("x") + s.append("y") + assert s.count() == 2 + + def test_first_and_last(self, session: Session) -> None: + s = session.stream("data", int) + s.append(10, ts=1.0) + s.append(20, ts=2.0) + s.append(30, ts=3.0) + + assert s.first().data == 10 + assert s.last().data == 30 + + def test_first_empty_raises(self, session: Session) -> None: + s = session.stream("empty", int) + with pytest.raises(LookupError): + s.first() + + def test_exists(self, session: Session) -> None: + s = session.stream("check", str) + assert not s.exists() + s.append("hi") + assert s.exists() + + def test_filter_after(self, session: Session) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.after(15.0).fetch() + assert [o.data for o in results] == [2, 3] + + def test_filter_before(self, session: Session) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.before(25.0).fetch() + assert [o.data for o in results] == [1, 2] + + def test_filter_time_range(self, session: Session) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.time_range(15.0, 25.0).fetch() + assert [o.data for o in results] == [2] + + def 
test_filter_tags(self, session: Session) -> None: + s = session.stream("tagged", str) + s.append("a", tags={"kind": "info"}) + s.append("b", tags={"kind": "error"}) + s.append("c", tags={"kind": "info"}) + + results = s.tags(kind="info").fetch() + assert [o.data for o in results] == ["a", "c"] + + def test_limit_and_offset(self, session: Session) -> None: + s = session.stream("paged", int) + for i in range(5): + s.append(i, ts=float(i)) + + page = s.offset(1).limit(2).fetch() + assert [o.data for o in page] == [1, 2] + + def test_order_by_desc(self, session: Session) -> None: + s = session.stream("ordered", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.order_by("ts", desc=True).fetch() + assert [o.data for o in results] == [3, 2, 1] + + def test_separate_streams_isolated(self, session: Session) -> None: + a = session.stream("stream_a", str) + b = session.stream("stream_b", str) + + a.append("in_a") + b.append("in_b") + + assert [o.data for o in a] == ["in_a"] + assert [o.data for o in b] == ["in_b"] + + def test_same_stream_on_repeated_calls(self, session: Session) -> None: + s1 = session.stream("reuse", str) + s2 = session.stream("reuse", str) + assert s1 is s2 + + def test_append_with_embedding(self, session: Session) -> None: import numpy as np from dimos.memory2.type import EmbeddedObservation from dimos.models.embedding.base import Embedding - with case.session_factory() as session: - s = session.stream("vectors", str) - emb = Embedding(vector=np.array([1.0, 0.0, 0.0], dtype=np.float32)) - obs = s.append("hello", embedding=emb) - assert isinstance(obs, EmbeddedObservation) - assert obs.embedding is emb + s = session.stream("vectors", str) + emb = Embedding(vector=np.array([1.0, 0.0, 0.0], dtype=np.float32)) + obs = s.append("hello", embedding=emb) + assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb - def test_search_top_k(self, case: Case) -> None: + def test_search_top_k(self, session: 
Session) -> None: import numpy as np from dimos.models.embedding.base import Embedding @@ -243,30 +176,28 @@ def _emb(v: list[float]) -> Embedding: a = np.array(v, dtype=np.float32) return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) - with case.session_factory() as session: - s = session.stream("searchable", str) - s.append("north", embedding=_emb([0, 1, 0])) - s.append("east", embedding=_emb([1, 0, 0])) - s.append("south", embedding=_emb([0, -1, 0])) + s = session.stream("searchable", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) - results = s.search(_emb([0, 1, 0]), k=2).fetch() - assert len(results) == 2 - assert results[0].data == "north" - assert results[0].similarity > 0.99 + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity > 0.99 - def test_search_text(self, case: Case) -> None: - with case.session_factory() as session: - s = session.stream("logs", str) - s.append("motor fault") - s.append("temperature ok") + def test_search_text(self, session: Session) -> None: + s = session.stream("logs", str) + s.append("motor fault") + s.append("temperature ok") - if case.name == "sqlite": - with pytest.raises(NotImplementedError): - s.search_text("motor").fetch() - else: - results = s.search_text("motor").fetch() - assert len(results) == 1 - assert results[0].data == "motor fault" + # SqliteBackend blocks search_text to prevent full table scans + try: + results = s.search_text("motor").fetch() + except NotImplementedError: + pytest.skip("search_text not supported on this backend") + assert len(results) == 1 + assert results[0].data == "motor fault" # ── Lazy / eager blob loading tests ────────────────────────────── @@ -275,95 +206,71 @@ def test_search_text(self, case: Case) -> None: class TestBlobLoading: """Verify lazy and eager blob loading paths.""" - def 
test_sqlite_lazy_by_default(self) -> None: + def test_sqlite_lazy_by_default(self, sqlite_session: Session) -> None: """Default sqlite iteration uses lazy loaders — data is _UNLOADED until accessed.""" - import tempfile - - from dimos.memory2.impl.sqlite import SqliteStore from dimos.memory2.type import _Unloaded - with tempfile.NamedTemporaryFile(suffix=".db") as f: - store = SqliteStore(path=f.name) - with store.session() as session: - s = session.stream("lazy_test", str) - s.append("hello", ts=1.0) - s.append("world", ts=2.0) - - for obs in s: - # Before accessing .data, _data should be the unloaded sentinel - assert isinstance(obs._data, _Unloaded) - assert obs._loader is not None - # Accessing .data triggers the loader - val = obs.data - assert isinstance(val, str) - # After loading, _loader is cleared - assert obs._loader is None - - def test_sqlite_eager_loads_inline(self) -> None: + s = sqlite_session.stream("lazy_test", str) + s.append("hello", ts=1.0) + s.append("world", ts=2.0) + + for obs in s: + # Before accessing .data, _data should be the unloaded sentinel + assert isinstance(obs._data, _Unloaded) + assert obs._loader is not None + # Accessing .data triggers the loader + val = obs.data + assert isinstance(val, str) + # After loading, _loader is cleared + assert obs._loader is None + + def test_sqlite_eager_loads_inline(self, sqlite_session: Session) -> None: """With eager_blobs=True, data is loaded via JOIN — no lazy loader.""" - import tempfile - - from dimos.memory2.impl.sqlite import SqliteStore from dimos.memory2.type import _Unloaded - with tempfile.NamedTemporaryFile(suffix=".db") as f: - store = SqliteStore(path=f.name) - with store.session() as session: - s = session.stream("eager_test", str, eager_blobs=True) - s.append("hello", ts=1.0) - s.append("world", ts=2.0) + s = sqlite_session.stream("eager_test", str, eager_blobs=True) + s.append("hello", ts=1.0) + s.append("world", ts=2.0) - for obs in s: - # Data should already be loaded — no lazy 
sentinel - assert not isinstance(obs._data, _Unloaded) - assert obs._loader is None - assert isinstance(obs.data, str) + for obs in s: + # Data should already be loaded — no lazy sentinel + assert not isinstance(obs._data, _Unloaded) + assert obs._loader is None + assert isinstance(obs.data, str) - def test_sqlite_lazy_and_eager_same_values(self) -> None: + def test_sqlite_lazy_and_eager_same_values(self, sqlite_session: Session) -> None: """Both paths must return identical data.""" - import tempfile - - from dimos.memory2.impl.sqlite import SqliteStore + lazy_s = sqlite_session.stream("vals", str) + lazy_s.append("alpha", ts=1.0, tags={"k": "v"}) + lazy_s.append("beta", ts=2.0, tags={"k": "w"}) - with tempfile.NamedTemporaryFile(suffix=".db") as f: - store = SqliteStore(path=f.name) - with store.session() as session: - lazy_s = session.stream("vals", str) - lazy_s.append("alpha", ts=1.0, tags={"k": "v"}) - lazy_s.append("beta", ts=2.0, tags={"k": "w"}) + # Lazy read + lazy_results = lazy_s.fetch() - # Lazy read - lazy_results = lazy_s.fetch() + # Eager read — new stream handle with eager_blobs on same backend + eager_s = sqlite_session.stream("vals", str, eager_blobs=True) + eager_results = eager_s.fetch() - # Eager read — new stream handle with eager_blobs on same backend - eager_s = session.stream("vals", str, eager_blobs=True) - eager_results = eager_s.fetch() + assert [o.data for o in lazy_results] == [o.data for o in eager_results] + assert [o.tags for o in lazy_results] == [o.tags for o in eager_results] + assert [o.ts for o in lazy_results] == [o.ts for o in eager_results] - assert [o.data for o in lazy_results] == [o.data for o in eager_results] - assert [o.tags for o in lazy_results] == [o.tags for o in eager_results] - assert [o.ts for o in lazy_results] == [o.ts for o in eager_results] - - def test_memory_lazy_with_blobstore(self) -> None: + def test_memory_lazy_with_blobstore(self, memory_store, tmp_path) -> None: """MemoryStore with a BlobStore uses 
lazy loaders.""" from dimos.memory2.blobstore.file import FileBlobStore - from dimos.memory2.impl.memory import MemoryStore from dimos.memory2.type import _Unloaded - store = MemoryStore() - import tempfile - - with tempfile.TemporaryDirectory() as tmpdir: - bs = FileBlobStore(root=tmpdir) - bs.start() - with store.session(blob_store=bs) as session: - s = session.stream("mem_lazy", str) - s.append("data1", ts=1.0) + bs = FileBlobStore(root=tmp_path / "blobs") + bs.start() + with memory_store.session(blob_store=bs) as session: + s = session.stream("mem_lazy", str) + s.append("data1", ts=1.0) - obs = s.first() - # ListBackend replaces _data with _UNLOADED when blob_store is set - assert isinstance(obs._data, _Unloaded) - assert obs.data == "data1" - bs.stop() + obs = s.first() + # ListBackend replaces _data with _UNLOADED when blob_store is set + assert isinstance(obs._data, _Unloaded) + assert obs.data == "data1" + bs.stop() # ── Spy stores ─────────────────────────────────────────────────── @@ -426,19 +333,11 @@ def delete(self, stream: str, key: int) -> None: self.vectors.get(stream, {}).pop(key, None) -# ── Spy grid: session factories that inject spy stores ─────────── - - -@dataclass -class SpyCase: - name: str - session_factory: Callable[ - [], Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None] - ] +# ── Spy delegation tests ───────────────────────────────────────── -@contextmanager -def memory_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, None]: +@pytest.fixture +def memory_spy_session(): from dimos.memory2.impl.memory import MemoryStore blob_spy = SpyBlobStore() @@ -446,53 +345,50 @@ def memory_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStor store = MemoryStore() with store.session(blob_store=blob_spy, vector_store=vec_spy) as session: yield session, blob_spy, vec_spy + store.stop() -@contextmanager -def sqlite_spy_session() -> Generator[tuple[Session, SpyBlobStore, SpyVectorStore], None, 
None]: - import tempfile - +@pytest.fixture +def sqlite_spy_session(tmp_path): from dimos.memory2.impl.sqlite import SqliteStore blob_spy = SpyBlobStore() vec_spy = SpyVectorStore() - with tempfile.NamedTemporaryFile(suffix=".db") as f: - store = SqliteStore(path=f.name) - with store.session(blob_store=blob_spy, vector_store=vec_spy) as session: - yield session, blob_spy, vec_spy + store = SqliteStore(path=str(tmp_path / "spy.db")) + with store.session(blob_store=blob_spy, vector_store=vec_spy) as session: + yield session, blob_spy, vec_spy + store.stop() -spy_cases = [ - SpyCase(name="memory", session_factory=memory_spy_session), - SpyCase(name="sqlite", session_factory=sqlite_spy_session), -] +@pytest.fixture(params=["memory_spy_session", "sqlite_spy_session"]) +def spy_session(request: pytest.FixtureRequest): + return request.getfixturevalue(request.param) -@pytest.mark.parametrize("case", spy_cases, ids=lambda c: c.name) class TestStoreDelegation: """Verify all backends delegate to pluggable BlobStore and VectorStore.""" - def test_append_calls_blob_put(self, case: SpyCase) -> None: - with case.session_factory() as (session, blob_spy, _vec_spy): - s = session.stream("blobs", str) - s.append("first", ts=1.0) - s.append("second", ts=2.0) + def test_append_calls_blob_put(self, spy_session) -> None: + session, blob_spy, _vec_spy = spy_session + s = session.stream("blobs", str) + s.append("first", ts=1.0) + s.append("second", ts=2.0) - assert len(blob_spy.puts) == 2 - assert all(stream == "blobs" for stream, _k, _d in blob_spy.puts) + assert len(blob_spy.puts) == 2 + assert all(stream == "blobs" for stream, _k, _d in blob_spy.puts) - def test_iterate_calls_blob_get(self, case: SpyCase) -> None: - with case.session_factory() as (session, blob_spy, _vec_spy): - s = session.stream("blobs", str) - s.append("a", ts=1.0) - s.append("b", ts=2.0) + def test_iterate_calls_blob_get(self, spy_session) -> None: + session, blob_spy, _vec_spy = spy_session + s = 
session.stream("blobs", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) - blob_spy.gets.clear() - for obs in s: - _ = obs.data - assert len(blob_spy.gets) == 2 + blob_spy.gets.clear() + for obs in s: + _ = obs.data + assert len(blob_spy.gets) == 2 - def test_append_embedding_calls_vector_put(self, case: SpyCase) -> None: + def test_append_embedding_calls_vector_put(self, spy_session) -> None: import numpy as np from dimos.models.embedding.base import Embedding @@ -501,15 +397,15 @@ def _emb(v: list[float]) -> Embedding: a = np.array(v, dtype=np.float32) return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) - with case.session_factory() as (session, _blob_spy, vec_spy): - s = session.stream("vecs", str) - s.append("a", ts=1.0, embedding=_emb([1, 0, 0])) - s.append("b", ts=2.0, embedding=_emb([0, 1, 0])) - s.append("c", ts=3.0) # no embedding + session, _blob_spy, vec_spy = spy_session + s = session.stream("vecs", str) + s.append("a", ts=1.0, embedding=_emb([1, 0, 0])) + s.append("b", ts=2.0, embedding=_emb([0, 1, 0])) + s.append("c", ts=3.0) # no embedding - assert len(vec_spy.puts) == 2 + assert len(vec_spy.puts) == 2 - def test_search_calls_vector_search(self, case: SpyCase) -> None: + def test_search_calls_vector_search(self, spy_session) -> None: import numpy as np from dimos.models.embedding.base import Embedding @@ -518,11 +414,11 @@ def _emb(v: list[float]) -> Embedding: a = np.array(v, dtype=np.float32) return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) - with case.session_factory() as (session, _blob_spy, vec_spy): - s = session.stream("vecs", str) - s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) - s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) + session, _blob_spy, vec_spy = spy_session + s = session.stream("vecs", str) + s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) + s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) - results = s.search(_emb([0, 1, 0]), k=2).fetch() - assert len(vec_spy.searches) == 1 - assert 
results[0].data == "north" + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(vec_spy.searches) == 1 + assert results[0].data == "north" diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index a2261036b3..8442527c62 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -26,24 +26,31 @@ import pytest from dimos.memory2.buffer import KeepLast, Unbounded -from dimos.memory2.impl.memory import MemoryStore from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer from dimos.memory2.type import Observation if TYPE_CHECKING: + from collections.abc import Callable, Generator + from dimos.memory2.stream import Stream -# ── Helpers ────────────────────────────────────────────────────────── +# ── Fixtures ───────────────────────────────────────────────────────── + + +@pytest.fixture +def make_stream(session) -> Generator[Callable[..., Stream[int]], None, None]: + stream_index = 0 + + def f(n: int = 5, start_ts: float = 0.0): + nonlocal stream_index + stream_index += 1 + stream = session.stream(f"test{stream_index}", int) + for i in range(n): + stream.append(i * 10, ts=start_ts + i) + return stream -def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: - """Create a MemoryStore stream with n integer observations at 1-second intervals.""" - store = MemoryStore() - session = store.session() - stream = session.stream("test") - for i in range(n): - stream.append(i * 10, ts=start_ts + i) - return stream + return f # ═══════════════════════════════════════════════════════════════════ @@ -54,26 +61,26 @@ def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: class TestBasicIteration: """Streams are lazy iterables — nothing runs until you iterate.""" - def test_iterate_yields_all_observations(self): + def test_iterate_yields_all_observations(self, make_stream): stream = make_stream(5) obs = list(stream) assert len(obs) == 5 assert [o.data for o in obs] == [0, 10, 20, 30, 40] - 
def test_iterate_preserves_timestamps(self): + def test_iterate_preserves_timestamps(self, make_stream): stream = make_stream(3, start_ts=100.0) assert [o.ts for o in stream] == [100.0, 101.0, 102.0] - def test_empty_stream(self): + def test_empty_stream(self, make_stream): stream = make_stream(0) assert list(stream) == [] - def test_fetch_materializes_to_list(self): + def test_fetch_materializes_to_list(self, make_stream): result = make_stream(3).fetch() assert isinstance(result, list) assert len(result) == 3 - def test_stream_is_reiterable(self): + def test_stream_is_reiterable(self, make_stream): """Same stream can be iterated multiple times — each time re-queries.""" stream = make_stream(3) first = [o.data for o in stream] @@ -89,27 +96,27 @@ def test_stream_is_reiterable(self): class TestTemporalFilters: """Temporal filters constrain observations by timestamp.""" - def test_after(self): + def test_after(self, make_stream): """.after(t) keeps observations with ts > t.""" result = make_stream(5).after(2.0).fetch() assert [o.ts for o in result] == [3.0, 4.0] - def test_before(self): + def test_before(self, make_stream): """.before(t) keeps observations with ts < t.""" result = make_stream(5).before(2.0).fetch() assert [o.ts for o in result] == [0.0, 1.0] - def test_time_range(self): + def test_time_range(self, make_stream): """.time_range(t1, t2) keeps t1 <= ts <= t2.""" result = make_stream(5).time_range(1.0, 3.0).fetch() assert [o.ts for o in result] == [1.0, 2.0, 3.0] - def test_at_with_tolerance(self): + def test_at_with_tolerance(self, make_stream): """.at(t, tolerance) keeps observations within tolerance of t.""" result = make_stream(5).at(2.0, tolerance=0.5).fetch() assert [o.ts for o in result] == [2.0] - def test_chained_temporal_filters(self): + def test_chained_temporal_filters(self, make_stream): """Filters compose — each narrows the result.""" result = make_stream(10).after(2.0).before(7.0).fetch() assert [o.ts for o in result] == [3.0, 4.0, 5.0, 
6.0] @@ -123,10 +130,8 @@ def test_chained_temporal_filters(self): class TestSpatialFilter: """.near(pose, radius) filters by Euclidean distance.""" - def test_near_with_tuples(self): - store = MemoryStore() - session = store.session() - stream = session.stream("spatial") + def test_near_with_tuples(self, memory_session): + stream = memory_session.stream("spatial") stream.append("origin", ts=0.0, pose=(0, 0, 0)) stream.append("close", ts=1.0, pose=(1, 1, 0)) stream.append("far", ts=2.0, pose=(10, 10, 10)) @@ -134,10 +139,8 @@ def test_near_with_tuples(self): result = stream.near((0, 0, 0), radius=2.0).fetch() assert [o.data for o in result] == ["origin", "close"] - def test_near_excludes_no_pose(self): - store = MemoryStore() - session = store.session() - stream = session.stream("spatial") + def test_near_excludes_no_pose(self, memory_session): + stream = memory_session.stream("spatial") stream.append("no_pose", ts=0.0) stream.append("has_pose", ts=1.0, pose=(0, 0, 0)) @@ -153,10 +156,8 @@ def test_near_excludes_no_pose(self): class TestTagsFilter: """.filter_tags() matches on observation metadata.""" - def test_filter_by_tag(self): - store = MemoryStore() - session = store.session() - stream = session.stream("tagged") + def test_filter_by_tag(self, memory_session): + stream = memory_session.stream("tagged") stream.append("cat", ts=0.0, tags={"type": "animal", "legs": 4}) stream.append("car", ts=1.0, tags={"type": "vehicle", "wheels": 4}) stream.append("dog", ts=2.0, tags={"type": "animal", "legs": 4}) @@ -164,10 +165,8 @@ def test_filter_by_tag(self): result = stream.tags(type="animal").fetch() assert [o.data for o in result] == ["cat", "dog"] - def test_filter_multiple_tags(self): - store = MemoryStore() - session = store.session() - stream = session.stream("tagged") + def test_filter_multiple_tags(self, memory_session): + stream = memory_session.stream("tagged") stream.append("a", ts=0.0, tags={"x": 1, "y": 2}) stream.append("b", ts=1.0, tags={"x": 1, "y": 3}) 
@@ -181,44 +180,44 @@ def test_filter_multiple_tags(self): class TestOrderLimitOffset: - def test_limit(self): + def test_limit(self, make_stream): result = make_stream(10).limit(3).fetch() assert len(result) == 3 - def test_offset(self): + def test_offset(self, make_stream): result = make_stream(5).offset(2).fetch() assert [o.data for o in result] == [20, 30, 40] - def test_limit_and_offset(self): + def test_limit_and_offset(self, make_stream): result = make_stream(10).offset(2).limit(3).fetch() assert [o.data for o in result] == [20, 30, 40] - def test_order_by_ts_desc(self): + def test_order_by_ts_desc(self, make_stream): result = make_stream(5).order_by("ts", desc=True).fetch() assert [o.ts for o in result] == [4.0, 3.0, 2.0, 1.0, 0.0] - def test_first(self): + def test_first(self, make_stream): obs = make_stream(5).first() assert obs.data == 0 - def test_last(self): + def test_last(self, make_stream): obs = make_stream(5).last() assert obs.data == 40 - def test_first_empty_raises(self): + def test_first_empty_raises(self, make_stream): with pytest.raises(LookupError): make_stream(0).first() - def test_count(self): + def test_count(self, make_stream): assert make_stream(5).count() == 5 assert make_stream(5).after(2.0).count() == 2 - def test_exists(self): + def test_exists(self, make_stream): assert make_stream(5).exists() assert not make_stream(0).exists() assert not make_stream(5).after(100.0).exists() - def test_drain(self): + def test_drain(self, make_stream): assert make_stream(5).drain() == 5 assert make_stream(5).after(2.0).drain() == 2 assert make_stream(0).drain() == 0 @@ -232,22 +231,22 @@ def test_drain(self): class TestFunctionalAPI: """Functional combinators receive the full Observation.""" - def test_filter_with_predicate(self): + def test_filter_with_predicate(self, make_stream): """.filter() takes a predicate on the full Observation.""" result = make_stream(5).filter(lambda obs: obs.data > 20).fetch() assert [o.data for o in result] == [30, 40] 
- def test_filter_on_metadata(self): + def test_filter_on_metadata(self, make_stream): """Predicates can access ts, tags, pose — not just data.""" result = make_stream(5).filter(lambda obs: obs.ts % 2 == 0).fetch() assert [o.ts for o in result] == [0.0, 2.0, 4.0] - def test_map(self): + def test_map(self, make_stream): """.map() transforms each observation's data.""" result = make_stream(3).map(lambda obs: obs.derive(data=obs.data * 2)).fetch() assert [o.data for o in result] == [0, 20, 40] - def test_map_preserves_ts(self): + def test_map_preserves_ts(self, make_stream): result = make_stream(3).map(lambda obs: obs.derive(data=str(obs.data))).fetch() assert [o.ts for o in result] == [0.0, 1.0, 2.0] assert [o.data for o in result] == ["0", "10", "20"] @@ -261,12 +260,12 @@ def test_map_preserves_ts(self): class TestTransformChaining: """Transforms chain lazily — each obs flows through the full pipeline.""" - def test_single_transform(self): + def test_single_transform(self, make_stream): xf = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) result = make_stream(3).transform(xf).fetch() assert [o.data for o in result] == [1, 11, 21] - def test_chained_transforms(self): + def test_chained_transforms(self, make_stream): """stream.transform(A).transform(B) — B pulls from A which pulls from source.""" add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) @@ -275,17 +274,15 @@ def test_chained_transforms(self): # (0+1)*2=2, (10+1)*2=22, (20+1)*2=42 assert [o.data for o in result] == [2, 22, 42] - def test_transform_can_skip(self): + def test_transform_can_skip(self, make_stream): """Returning None from a transformer skips that observation.""" keep_even = FnTransformer(lambda obs: obs if obs.data % 20 == 0 else None) result = make_stream(5).transform(keep_even).fetch() assert [o.data for o in result] == [0, 20, 40] - def test_transform_filter_transform(self): + def 
test_transform_filter_transform(self, memory_session): """stream.transform(A).near(pose).transform(B) — filter between transforms.""" - store = MemoryStore() - session = store.session() - stream = session.stream("tfft") + stream = memory_session.stream("tfft") stream.append(1, ts=0.0, pose=(0, 0, 0)) stream.append(2, ts=1.0, pose=(100, 100, 100)) stream.append(3, ts=2.0, pose=(1, 0, 0)) @@ -301,7 +298,7 @@ def test_transform_filter_transform(self): ) assert [o.data for o in result] == [22, 26] - def test_generator_function_transform(self): + def test_generator_function_transform(self, make_stream): """A bare generator function works as a transform.""" def double_all(upstream): @@ -311,7 +308,7 @@ def double_all(upstream): result = make_stream(3).transform(double_all).fetch() assert [o.data for o in result] == [0, 20, 40] - def test_generator_function_stateful(self): + def test_generator_function_stateful(self, make_stream): """Generator transforms can accumulate state and yield at their own pace.""" def running_sum(upstream): @@ -324,11 +321,9 @@ def running_sum(upstream): # 0, 0+10=10, 10+20=30 assert [o.data for o in result] == [0, 10, 30] - def test_quality_window(self): + def test_quality_window(self, memory_session): """QualityWindow keeps the best item per time window.""" - store = MemoryStore() - session = store.session() - stream = session.stream("qw") + stream = memory_session.stream("qw") # Window 1: ts 0.0-0.9 → best quality stream.append(0.3, ts=0.0) stream.append(0.9, ts=0.3) # best in window @@ -343,7 +338,7 @@ def test_quality_window(self): result = stream.transform(xf).fetch() assert [o.data for o in result] == [0.9, 0.8, 0.6] - def test_streaming_not_buffering(self): + def test_streaming_not_buffering(self, make_stream): """Transforms process lazily — early limit stops pulling from source.""" calls = [] @@ -368,24 +363,21 @@ def __call__(self, upstream): class TestStoreSession: """Store -> Session -> Stream hierarchy for named streams.""" - def 
test_basic_session(self): - store = MemoryStore() - with store.session() as session: + def test_basic_session(self, memory_store): + with memory_store.session() as session: images = session.stream("images") images.append("frame1", ts=0.0) images.append("frame2", ts=1.0) assert images.count() == 2 - def test_same_stream_on_repeated_calls(self): - store = MemoryStore() - with store.session() as session: + def test_same_stream_on_repeated_calls(self, memory_store): + with memory_store.session() as session: s1 = session.stream("images") s2 = session.stream("images") assert s1 is s2 - def test_stream_namespace(self): - store = MemoryStore() - with store.session() as session: + def test_stream_namespace(self, memory_store): + with memory_store.session() as session: session.stream("images") session.stream("lidar") assert "images" in session.streams @@ -393,15 +385,13 @@ def test_stream_namespace(self): assert session.streams.images is session.stream("images") assert session.streams["lidar"] is session.stream("lidar") - def test_namespace_missing_raises(self): - store = MemoryStore() - with store.session() as session: + def test_namespace_missing_raises(self, memory_store): + with memory_store.session() as session: with pytest.raises(AttributeError, match="No stream named"): _ = session.streams.nonexistent - def test_delete_stream(self): - store = MemoryStore() - with store.session() as session: + def test_delete_stream(self, memory_store): + with memory_store.session() as session: session.stream("temp") session.delete_stream("temp") assert "temp" not in session.streams @@ -460,11 +450,9 @@ def test_derive_preserves_metadata(self): class TestLiveMode: """Live streams yield backfill then block for new observations.""" - def test_live_sees_backfill_then_new(self): + def test_live_sees_backfill_then_new(self, memory_session): """Backfill first, then live appends come through.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live") + stream = 
memory_session.stream("live") stream.append("old", ts=0.0) live = stream.live(buffer=Unbounded()) @@ -489,11 +477,9 @@ def consumer(): t.join(timeout=2.0) assert results == ["old", "new1", "new2"] - def test_live_with_filter(self): + def test_live_with_filter(self, memory_session): """Filters apply to live data — non-matching obs are dropped silently.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_filter") + stream = memory_session.stream("live_filter") live = stream.after(5.0).live(buffer=Unbounded()) results: list[int] = [] @@ -519,11 +505,9 @@ def consumer(): t.join(timeout=2.0) assert results == [2, 4] - def test_live_deduplicates_backfill_overlap(self): + def test_live_deduplicates_backfill_overlap(self, memory_session): """Observations seen in backfill are not re-yielded from the live buffer.""" - store = MemoryStore() - session = store.session() - stream = session.stream("dedup") + stream = memory_session.stream("dedup") stream.append("backfill", ts=0.0) live = stream.live(buffer=Unbounded()) @@ -547,11 +531,9 @@ def consumer(): t.join(timeout=2.0) assert results == ["backfill", "live1"] - def test_live_with_keep_last_backpressure(self): + def test_live_with_keep_last_backpressure(self, memory_session): """KeepLast drops intermediate values when consumer is slow.""" - store = MemoryStore() - session = store.session() - stream = session.stream("bp") + stream = memory_session.stream("bp") live = stream.live(buffer=KeepLast()) results: list[int] = [] @@ -580,11 +562,9 @@ def consumer(): assert len(results) < 50 assert results[-1] >= 90 - def test_live_transform_receives_live_items(self): + def test_live_transform_receives_live_items(self, memory_session): """Transforms downstream of .live() see both backfill and live items.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_xf") + stream = memory_session.stream("live_xf") stream.append(1, ts=0.0) double = FnTransformer(lambda obs: 
obs.derive(data=obs.data * 2)) live = stream.live(buffer=Unbounded()).transform(double) @@ -611,18 +591,16 @@ def consumer(): # All items went through the double transform assert results == [2, 20, 200] - def test_live_on_transform_raises(self): + def test_live_on_transform_raises(self, make_stream): """Calling .live() on a transform stream raises TypeError.""" stream = make_stream(3) xf = FnTransformer(lambda obs: obs) with pytest.raises(TypeError, match="Cannot call .live"): stream.transform(xf).live() - def test_is_live(self): + def test_is_live(self, memory_session): """is_live() walks the source chain to detect live mode.""" - store = MemoryStore() - session = store.session() - stream = session.stream("is_live") + stream = memory_session.stream("is_live") assert not stream.is_live() live = stream.live(buffer=Unbounded()) @@ -639,11 +617,9 @@ def test_is_live(self): # Non-live transform is not live assert not stream.transform(xf).is_live() - def test_search_on_live_transform_raises(self): + def test_search_on_live_transform_raises(self, memory_session): """search() on a transform with live upstream raises immediately.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_search") + stream = memory_session.stream("live_search") xf = FnTransformer(lambda obs: obs) live_xf = stream.live(buffer=Unbounded()).transform(xf) @@ -656,65 +632,53 @@ def test_search_on_live_transform_raises(self): # Use list() to trigger iteration — fetch() would hit its own guard first list(live_xf.search(vec, k=5)) - def test_order_by_on_live_transform_raises(self): + def test_order_by_on_live_transform_raises(self, memory_session): """order_by() on a transform with live upstream raises immediately.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_order") + stream = memory_session.stream("live_order") xf = FnTransformer(lambda obs: obs) live_xf = stream.live(buffer=Unbounded()).transform(xf) with 
pytest.raises(TypeError, match="requires finite data"): list(live_xf.order_by("ts", desc=True)) - def test_fetch_on_live_without_limit_raises(self): + def test_fetch_on_live_without_limit_raises(self, memory_session): """fetch() on a live stream without limit() raises TypeError.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_fetch") + stream = memory_session.stream("live_fetch") live = stream.live(buffer=Unbounded()) with pytest.raises(TypeError, match="block forever"): live.fetch() - def test_fetch_on_live_transform_without_limit_raises(self): + def test_fetch_on_live_transform_without_limit_raises(self, memory_session): """fetch() on a live transform without limit() raises TypeError.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_fetch_xf") + stream = memory_session.stream("live_fetch_xf") xf = FnTransformer(lambda obs: obs) live_xf = stream.live(buffer=Unbounded()).transform(xf) with pytest.raises(TypeError, match="block forever"): live_xf.fetch() - def test_count_on_live_transform_raises(self): + def test_count_on_live_transform_raises(self, memory_session): """count() on a live transform stream raises TypeError.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_count") + stream = memory_session.stream("live_count") xf = FnTransformer(lambda obs: obs) live_xf = stream.live(buffer=Unbounded()).transform(xf) with pytest.raises(TypeError, match="block forever"): live_xf.count() - def test_last_on_live_transform_raises(self): + def test_last_on_live_transform_raises(self, memory_session): """last() on a live transform raises TypeError (via order_by guard).""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_last") + stream = memory_session.stream("live_last") xf = FnTransformer(lambda obs: obs) live_xf = stream.live(buffer=Unbounded()).transform(xf) with pytest.raises(TypeError, match="requires finite data"): 
live_xf.last() - def test_live_chained_transforms(self): + def test_live_chained_transforms(self, memory_session): """stream.live().transform(A).transform(B) — both transforms applied to live items.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_chain") + stream = memory_session.stream("live_chain") stream.append(1, ts=0.0) add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) @@ -742,11 +706,9 @@ def consumer(): # (1+1)*2=4, (10+1)*2=22, (100+1)*2=202 assert results == [4, 22, 202] - def test_live_filter_before_live(self): + def test_live_filter_before_live(self, memory_session): """Filters applied before .live() work on both backfill and live items.""" - store = MemoryStore() - session = store.session() - stream = session.stream("live_pre_filter") + stream = memory_session.stream("live_pre_filter") stream.append("a", ts=1.0) stream.append("b", ts=10.0) live = stream.after(5.0).live(buffer=Unbounded()) @@ -772,3 +734,7 @@ def consumer(): t.join(timeout=2.0) # "a" filtered in backfill, "c" filtered in live assert results == ["b", "d"] + # "a" filtered in backfill, "c" filtered in live + assert results == ["b", "d"] + assert results == ["b", "d"] + assert results == ["b", "d"] From 32d75d8aafa30a45db3b3ae7579d524fa4135ec3 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 15:38:35 +0800 Subject: [PATCH 109/118] compression codec added, new bigoffice db uploaded --- data/.lfs/go2_bigoffice.db.tar.gz | 3 ++ dimos/mapping/occupancy/gradient.py | 6 ++- dimos/memory2/codecs/__init__.py | 5 ++- dimos/memory2/codecs/base.py | 70 +++++++++++++++++++++++++++++ dimos/memory2/codecs/test_codecs.py | 62 ++++++++++++++++++++----- dimos/memory2/impl/sqlite.py | 51 ++++++++------------- dimos/memory2/store.py | 2 +- pyproject.toml | 1 + uv.lock | 2 + 9 files changed, 154 insertions(+), 48 deletions(-) create mode 100644 
data/.lfs/go2_bigoffice.db.tar.gz diff --git a/data/.lfs/go2_bigoffice.db.tar.gz b/data/.lfs/go2_bigoffice.db.tar.gz new file mode 100644 index 0000000000..cd4882b832 --- /dev/null +++ b/data/.lfs/go2_bigoffice.db.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af0058cbaa02198e2709dd93bbac288421a5922c132470bcfba724d1c9524ec2 +size 183793527 diff --git a/dimos/mapping/occupancy/gradient.py b/dimos/mapping/occupancy/gradient.py index 880f2692da..66e8aaa734 100644 --- a/dimos/mapping/occupancy/gradient.py +++ b/dimos/mapping/occupancy/gradient.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import numpy as np from scipy import ndimage # type: ignore[import-untyped] @@ -50,7 +52,9 @@ def gradient( # Compute distance transform (distance to nearest obstacle in cells) # Unknown cells are treated as if they don't exist for distance calculation - distance_cells = ndimage.distance_transform_edt(1 - obstacle_map) + distance_cells: np.ndarray[Any, np.dtype[np.float64]] = ndimage.distance_transform_edt( + 1 - obstacle_map + ) # type: ignore[assignment] # Convert to meters and clip to max distance distance_meters = np.clip(distance_cells * occupancy_grid.resolution, 0, max_distance) diff --git a/dimos/memory2/codecs/__init__.py b/dimos/memory2/codecs/__init__.py index a7feb3bce3..07187fa9af 100644 --- a/dimos/memory2/codecs/__init__.py +++ b/dimos/memory2/codecs/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory2.codecs.base import Codec, codec_for, codec_from_id, codec_id +from dimos.memory2.codecs.lz4 import Lz4Codec from dimos.memory2.codecs.pickle import PickleCodec -__all__ = ["Codec", "PickleCodec", "codec_for"] +__all__ = ["Codec", "Lz4Codec", "PickleCodec", "codec_for", "codec_from_id", "codec_id"] diff --git a/dimos/memory2/codecs/base.py b/dimos/memory2/codecs/base.py index 4c2b3865f5..4d082e6eb2 100644 --- a/dimos/memory2/codecs/base.py +++ b/dimos/memory2/codecs/base.py @@ -14,6 +14,7 @@ from __future__ import annotations +import importlib from typing import Any, Protocol, TypeVar T = TypeVar("T") @@ -42,3 +43,72 @@ def codec_for(payload_type: type[Any] | None = None) -> Codec[Any]: return LcmCodec(payload_type) return PickleCodec() + + +# ── Codec ID serialization ─────────────────────────────────────── + + +def codec_id(codec: Codec[Any]) -> str: + """Derive a string ID from a codec instance, e.g. ``'lz4+lcm'``. + + Walks the ``_inner`` chain for wrapper codecs, joining with ``+``. + Uses the naming convention ``FooCodec`` → ``'foo'``. + """ + parts: list[str] = [] + c: Any = codec + while hasattr(c, "_inner"): + parts.append(_class_to_id(c)) + c = c._inner + parts.append(_class_to_id(c)) + return "+".join(parts) + + +def codec_from_id(codec_id_str: str, payload_module: str) -> Codec[Any]: + """Reconstruct a codec chain from its string ID (e.g. ``'lz4+lcm'``). + + Builds inside-out: the rightmost segment is the innermost (base) codec. 
+ """ + parts = codec_id_str.split("+") + # Innermost first + result = _make_one(parts[-1], payload_module) + for name in reversed(parts[:-1]): + result = _make_one(name, payload_module, inner=result) + return result + + +def _class_to_id(codec: Any) -> str: + name = type(codec).__name__ + if name.endswith("Codec"): + return name[:-5].lower() + return name.lower() + + +def _resolve_payload_type(payload_module: str) -> type[Any]: + parts = payload_module.rsplit(".", 1) + if len(parts) != 2: + raise ValueError(f"Cannot resolve payload type from {payload_module!r}") + mod = importlib.import_module(parts[0]) + return getattr(mod, parts[1]) # type: ignore[no-any-return] + + +def _make_one(name: str, payload_module: str, inner: Codec[Any] | None = None) -> Codec[Any]: + """Instantiate a single codec by its short name.""" + if name == "lz4": + from dimos.memory2.codecs.lz4 import Lz4Codec + + if inner is None: + raise ValueError("lz4 is a wrapper codec — must have an inner codec") + return Lz4Codec(inner) + if name == "jpeg": + from dimos.memory2.codecs.jpeg import JpegCodec + + return JpegCodec() + if name == "lcm": + from dimos.memory2.codecs.lcm import LcmCodec + + return LcmCodec(_resolve_payload_type(payload_module)) + if name == "pickle": + from dimos.memory2.codecs.pickle import PickleCodec + + return PickleCodec() + raise ValueError(f"Unknown codec: {name!r}") diff --git a/dimos/memory2/codecs/test_codecs.py b/dimos/memory2/codecs/test_codecs.py index 5c17281e48..3c0055bed0 100644 --- a/dimos/memory2/codecs/test_codecs.py +++ b/dimos/memory2/codecs/test_codecs.py @@ -29,6 +29,8 @@ if TYPE_CHECKING: from collections.abc import Callable + from dimos.msgs.protocol import DimosMsg + # ── Case definition ──────────────────────────────────────────────── @@ -43,6 +45,22 @@ class Case: # ── Test cases ───────────────────────────────────────────────────── +def _lcm_values() -> list[DimosMsg]: + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from 
dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + return [ + PoseStamped( + ts=1.0, + frame_id="map", + position=Vector3(1.0, 2.0, 3.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + PoseStamped(ts=0.5, frame_id="odom"), + ] + + def _pickle_case() -> Case: from dimos.memory2.codecs.pickle import PickleCodec @@ -56,21 +74,34 @@ def _pickle_case() -> Case: def _lcm_case() -> Case: from dimos.memory2.codecs.lcm import LcmCodec from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - from dimos.msgs.geometry_msgs.Quaternion import Quaternion - from dimos.msgs.geometry_msgs.Vector3 import Vector3 return Case( name="lcm", codec=LcmCodec(PoseStamped), - values=[ - PoseStamped( - ts=1.0, - frame_id="map", - position=Vector3(1.0, 2.0, 3.0), - orientation=Quaternion(0.0, 0.0, 0.0, 1.0), - ), - PoseStamped(ts=0.5, frame_id="odom"), - ], + values=_lcm_values(), ) + + +def _lz4_pickle_case() -> Case: + from dimos.memory2.codecs.lz4 import Lz4Codec + from dimos.memory2.codecs.pickle import PickleCodec + + return Case( + name="lz4+pickle", + codec=Lz4Codec(PickleCodec()), + values=[42, "hello", b"raw bytes", {"key": "value"}, list(range(1000))], + ) + + +def _lz4_lcm_case() -> Case: + from dimos.memory2.codecs.lcm import LcmCodec + from dimos.memory2.codecs.lz4 import Lz4Codec + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + return Case( + name="lz4+lcm", + codec=Lz4Codec(LcmCodec(PoseStamped)), values=_lcm_values(), ) @@ -104,7 +135,11 @@ def _jpeg_case() -> Case | None: ) -testcases = [c for c in [_pickle_case(), _lcm_case(), _jpeg_case()] if c is not None] +testcases = [ + c + for c in [_pickle_case(), _lcm_case(), _lz4_pickle_case(), _lz4_lcm_case(), _jpeg_case()] + if c is not None +] # ── Tests ──────────────────────────────────────────────────────────
diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 8672aa2eb9..72eabb7264 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -25,7 +25,7 @@ from dimos.core.resource import CompositeResource from dimos.memory2.backend import BackendConfig from dimos.memory2.blobstore.sqlite import SqliteBlobStore -from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory2.codecs.base import Codec, codec_for, codec_from_id, codec_id from dimos.memory2.filter import ( AfterFilter, AtFilter, @@ -544,37 +544,12 @@ def _open_connection(self) -> sqlite3.Connection: return conn @staticmethod - def _codec_id(codec: Codec[Any]) -> str: - from dimos.memory2.codecs.jpeg import JpegCodec - from dimos.memory2.codecs.lcm import LcmCodec - - if isinstance(codec, JpegCodec): - return "jpeg" - if isinstance(codec, LcmCodec): - return "lcm" - return "pickle" + def _codec_id(c: Codec[Any]) -> str: + return codec_id(c) @staticmethod - def _codec_from_id(codec_id: str, payload_module: str) -> Codec[Any]: - from dimos.memory2.codecs.pickle import PickleCodec - - if codec_id == "jpeg": - from dimos.memory2.codecs.jpeg import JpegCodec - - return JpegCodec() - if codec_id == "lcm": - from dimos.memory2.codecs.lcm import LcmCodec - - # Resolve the payload type from module path - parts = payload_module.rsplit(".", 1) - if len(parts) == 2: - import importlib - - mod = importlib.import_module(parts[0]) - cls = getattr(mod, parts[1]) - return LcmCodec(cls) - return PickleCodec() - return PickleCodec() + def _codec_from_id(codec_id_str: str, payload_module: str) -> Codec[Any]: + return codec_from_id(codec_id_str, payload_module) def _create_backend( self, name: str, payload_type: type[Any] | None = None, **config: Any @@ -595,12 +570,24 @@ def
_create_backend( f"Stream {name!r} was created with type {stored_module}, " f"but opened with {actual_module}" ) - codec = config.get("codec") or self._codec_from_id(stored_codec_id, stored_module) + raw_codec = config.get("codec") + if isinstance(raw_codec, str): + codec = codec_from_id(raw_codec, stored_module) + elif raw_codec is not None: + codec = raw_codec + else: + codec = self._codec_from_id(stored_codec_id, stored_module) else: if payload_type is None: raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required") - codec = config.get("codec") or codec_for(payload_type) payload_module = f"{payload_type.__module__}.{payload_type.__qualname__}" + raw_codec = config.get("codec") + if isinstance(raw_codec, str): + codec = codec_from_id(raw_codec, payload_module) + elif raw_codec is not None: + codec = raw_codec + else: + codec = codec_for(payload_type) self._registry_conn.execute( "INSERT INTO _streams (name, payload_module, codec_id) VALUES (?, ?, ?)", (name, payload_module, self._codec_id(codec)), diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index e9c1ec4e51..c46b11e3e4 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -51,7 +51,7 @@ class SessionConfig: blob_store: BlobStore | None = None vector_store: VectorStore | None = None eager_blobs: bool = False - codec: Codec[Any] | None = None + codec: Codec[Any] | str | None = None # ── Stream namespace ────────────────────────────────────────────── diff --git a/pyproject.toml b/pyproject.toml index 10a6c40ec8..5fa0543e56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ dependencies = [ "protobuf>=6.33.5,<7", "psutil>=7.0.0", "sqlite-vec>=0.1.6", + "lz4>=4.4.5", ] diff --git a/uv.lock b/uv.lock index a3abe85559..6e6ebc9810 100644 --- a/uv.lock +++ b/uv.lock @@ -1686,6 +1686,7 @@ dependencies = [ { name = "dimos-viewer" }, { name = "lazy-loader" }, { name = "llvmlite" }, + { name = "lz4" }, { name = "numba" }, { name = "numpy", version = 
"2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -2026,6 +2027,7 @@ requires-dist = [ { name = "lcm", marker = "extra == 'docker'" }, { name = "llvmlite", specifier = ">=0.42.0" }, { name = "lxml-stubs", marker = "extra == 'dev'", specifier = ">=0.5.1,<1" }, + { name = "lz4", specifier = ">=4.4.5" }, { name = "matplotlib", marker = "extra == 'manipulation'", specifier = ">=3.7.1" }, { name = "md-babel-py", marker = "extra == 'dev'", specifier = "==1.1.1" }, { name = "moondream", marker = "extra == 'perception'" }, From b7e25a9ae61127b3c423b65bdf71d5f02f9568d5 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 15:48:58 +0800 Subject: [PATCH 110/118] correct jpeg codec --- data/.lfs/go2_bigoffice.db.tar.gz | 4 +-- dimos/memory2/codecs/jpeg.py | 41 ++++++------------------------- 2 files changed, 9 insertions(+), 36 deletions(-) diff --git a/data/.lfs/go2_bigoffice.db.tar.gz b/data/.lfs/go2_bigoffice.db.tar.gz index cd4882b832..93a0af8d25 100644 --- a/data/.lfs/go2_bigoffice.db.tar.gz +++ b/data/.lfs/go2_bigoffice.db.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:af0058cbaa02198e2709dd93bbac288421a5922c132470bcfba724d1c9524ec2 -size 183793527 +oid sha256:3a6e1bfa79c81cb40425c3dcce5b51f721f5435e83bd7c3f502b7014c156b20f +size 183890080 diff --git a/dimos/memory2/codecs/jpeg.py b/dimos/memory2/codecs/jpeg.py index 3ef605c0db..cd1cc677c8 100644 --- a/dimos/memory2/codecs/jpeg.py +++ b/dimos/memory2/codecs/jpeg.py @@ -14,50 +14,23 @@ from __future__ import annotations -import struct from typing import Any class JpegCodec: - """Codec for Image types — stores as JPEG bytes (lossy, ~10-20x smaller). + """Codec for Image types — JPEG-compressed inside an LCM Image envelope. 
- Uses TurboJPEG (libjpeg-turbo) for 2-5x faster encode/decode vs OpenCV. - Preserves ``frame_id`` as a short header: ````. - Pixel data is lossy-compressed; ``ts`` is NOT preserved (stored separately). + Uses ``Image.lcm_jpeg_encode/decode`` which preserves ``ts``, ``frame_id``, + and all LCM header fields. Pixel data is lossy-compressed via TurboJPEG. """ def __init__(self, quality: int = 50) -> None: self._quality = quality - from turbojpeg import TurboJPEG # type: ignore[import-untyped] - - self._tj = TurboJPEG() - - _TJPF_MAP: dict[str, int] | None = None - - @staticmethod - def _get_tjpf_map() -> dict[str, int]: - if JpegCodec._TJPF_MAP is None: - from turbojpeg import TJPF_BGR, TJPF_GRAY, TJPF_RGB # type: ignore[import-untyped] - - JpegCodec._TJPF_MAP = {"BGR": TJPF_BGR, "RGB": TJPF_RGB, "GRAY": TJPF_GRAY} - return JpegCodec._TJPF_MAP def encode(self, value: Any) -> bytes: - from turbojpeg import TJPF_BGR # type: ignore[import-untyped] - - pf = self._get_tjpf_map().get(value.format.value, TJPF_BGR) - jpeg_data: bytes = self._tj.encode(value.data, quality=self._quality, pixel_format=pf) - frame_id = (value.frame_id or "").encode("utf-8") - header = struct.pack(" Any: - from dimos.msgs.sensor_msgs.Image import Image, ImageFormat - - fid_len = struct.unpack(" Date: Fri, 13 Mar 2026 16:07:57 +0800 Subject: [PATCH 111/118] PR comments cleanup --- dimos/core/resource.py | 2 +- dimos/memory2/backend.py | 7 ++----- dimos/memory2/blobstore/sqlite.py | 16 ++++++++++------ dimos/memory2/impl/sqlite.py | 2 +- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/dimos/core/resource.py b/dimos/core/resource.py index 63ba31f210..25d590706f 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -65,7 +65,7 @@ class CompositeResource(Resource): def __init__(self) -> None: self._disposables = CompositeDisposable() - def own(self, *disposables: DisposableBase) -> None: + def register_disposables(self, *disposables: DisposableBase) -> None: """Register 
child disposables to be disposed when this resource stops.""" for d in disposables: self._disposables.add(d) diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index 928b74e229..9125a63f86 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -109,14 +109,11 @@ def notify(self, obs: Observation[T]) -> None: # ── Blob storage ────────────────────────────────────────────────── -class BlobStore(Resource, ABC): +class BlobStore(Resource): """Persistent storage for encoded payload blobs. Separates payload data from metadata indexing so that large blobs (images, point clouds) don't penalize metadata queries. - - Extends Resource (start/stop) but does NOT manage its dependencies' - lifecycle — the caller owns the session / connection. """ @abstractmethod @@ -138,7 +135,7 @@ def delete(self, stream: str, key: int) -> None: # ── Vector storage ─────────────────────────────────────────────── -class VectorStore(Resource, ABC): +class VectorStore(Resource): """Pluggable storage and ANN index for embedding vectors. 
Separates vector indexing from metadata so backends can swap diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py index fac00bef7b..3019b6f77e 100644 --- a/dimos/memory2/blobstore/sqlite.py +++ b/dimos/memory2/blobstore/sqlite.py @@ -67,15 +67,19 @@ def put(self, stream: str, key: int, data: bytes) -> None: ) def get(self, stream: str, key: int) -> bytes: - self._ensure_table(stream) - row = self._conn.execute( - f'SELECT data FROM "{stream}_blob" WHERE id = ?', (key,) - ).fetchone() + try: + row = self._conn.execute( + f'SELECT data FROM "{stream}_blob" WHERE id = ?', (key,) + ).fetchone() + except Exception: + raise KeyError(f"No blob for stream={stream!r}, key={key}") if row is None: raise KeyError(f"No blob for stream={stream!r}, key={key}") result: bytes = row[0] return result def delete(self, stream: str, key: int) -> None: - self._ensure_table(stream) - self._conn.execute(f'DELETE FROM "{stream}_blob" WHERE id = ?', (key,)) + try: + self._conn.execute(f'DELETE FROM "{stream}_blob" WHERE id = ?', (key,)) + except Exception: + pass diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 72eabb7264..5d7ed80cec 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -628,7 +628,7 @@ def _create_backend( config["codec"] = codec backend: SqliteBackend[Any] = SqliteBackend(backend_conn, name, **config) - self.own(backend) + self.register_disposables(backend) return backend def list_streams(self) -> list[str]: From 8be106a7d5ee36206f60abae86209db1c41321a1 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 16:12:01 +0800 Subject: [PATCH 112/118] blobstore stream -> stream_name --- dimos/memory2/backend.py | 6 +++--- dimos/memory2/blobstore/blobstore.md | 8 ++++---- dimos/memory2/blobstore/file.py | 18 +++++++++--------- dimos/memory2/blobstore/sqlite.py | 26 +++++++++++++------------- 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/dimos/memory2/backend.py 
b/dimos/memory2/backend.py index 9125a63f86..d43d614820 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -117,17 +117,17 @@ class BlobStore(Resource): """ @abstractmethod - def put(self, stream: str, key: int, data: bytes) -> None: + def put(self, stream_name: str, key: int, data: bytes) -> None: """Store a blob for the given stream and observation id.""" ... @abstractmethod - def get(self, stream: str, key: int) -> bytes: + def get(self, stream_name: str, key: int) -> bytes: """Retrieve a blob by stream name and observation id.""" ... @abstractmethod - def delete(self, stream: str, key: int) -> None: + def delete(self, stream_name: str, key: int) -> None: """Delete a blob by stream name and observation id.""" ... diff --git a/dimos/memory2/blobstore/blobstore.md b/dimos/memory2/blobstore/blobstore.md index 79a36d52ae..edc4f8a217 100644 --- a/dimos/memory2/blobstore/blobstore.md +++ b/dimos/memory2/blobstore/blobstore.md @@ -6,12 +6,12 @@ Separates payload blob storage from metadata indexing. Observation payloads vary ```python class BlobStore(Resource, ABC): - def put(self, stream: str, key: int, data: bytes) -> None: ... - def get(self, stream: str, key: int) -> bytes: ... # raises KeyError if missing - def delete(self, stream: str, key: int) -> None: ... # silent if missing + def put(self, stream_name: str, key: int, data: bytes) -> None: ... + def get(self, stream_name: str, key: int) -> bytes: ... # raises KeyError if missing + def delete(self, stream_name: str, key: int) -> None: ... 
# silent if missing ``` -- `stream` — stream name (used to organize storage: directories, tables) +- `stream_name` — stream name (used to organize storage: directories, tables) - `key` — observation id - `data` — encoded payload bytes (codec handles serialization, blob store handles persistence) - Extends `Resource` (start/stop) but does NOT own its dependencies' lifecycle diff --git a/dimos/memory2/blobstore/file.py b/dimos/memory2/blobstore/file.py index 54ec80e284..fda557ae01 100644 --- a/dimos/memory2/blobstore/file.py +++ b/dimos/memory2/blobstore/file.py @@ -34,8 +34,8 @@ class FileBlobStore(BlobStore): def __init__(self, root: str | os.PathLike[str]) -> None: self._root = Path(root) - def _path(self, stream: str, key: int) -> Path: - return self._root / stream / f"{key}.bin" + def _path(self, stream_name: str, key: int) -> Path: + return self._root / stream_name / f"{key}.bin" # ── Resource lifecycle ──────────────────────────────────────── @@ -47,18 +47,18 @@ def stop(self) -> None: # ── BlobStore interface ─────────────────────────────────────── - def put(self, stream: str, key: int, data: bytes) -> None: - p = self._path(stream, key) + def put(self, stream_name: str, key: int, data: bytes) -> None: + p = self._path(stream_name, key) p.parent.mkdir(parents=True, exist_ok=True) p.write_bytes(data) - def get(self, stream: str, key: int) -> bytes: - p = self._path(stream, key) + def get(self, stream_name: str, key: int) -> bytes: + p = self._path(stream_name, key) try: return p.read_bytes() except FileNotFoundError: - raise KeyError(f"No blob for stream={stream!r}, key={key}") from None + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") from None - def delete(self, stream: str, key: int) -> None: - p = self._path(stream, key) + def delete(self, stream_name: str, key: int) -> None: + p = self._path(stream_name, key) p.unlink(missing_ok=True) diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py index 
3019b6f77e..e9cef65159 100644 --- a/dimos/memory2/blobstore/sqlite.py +++ b/dimos/memory2/blobstore/sqlite.py @@ -40,14 +40,14 @@ def __init__(self, conn: sqlite3.Connection) -> None: self._conn = conn self._tables: set[str] = set() - def _ensure_table(self, stream: str) -> None: - if stream in self._tables: + def _ensure_table(self, stream_name: str) -> None: + if stream_name in self._tables: return self._conn.execute( - f'CREATE TABLE IF NOT EXISTS "{stream}_blob" ' + f'CREATE TABLE IF NOT EXISTS "{stream_name}_blob" ' "(id INTEGER PRIMARY KEY, data BLOB NOT NULL)" ) - self._tables.add(stream) + self._tables.add(stream_name) # ── Resource lifecycle ──────────────────────────────────────── @@ -59,27 +59,27 @@ def stop(self) -> None: # ── BlobStore interface ─────────────────────────────────────── - def put(self, stream: str, key: int, data: bytes) -> None: - self._ensure_table(stream) + def put(self, stream_name: str, key: int, data: bytes) -> None: + self._ensure_table(stream_name) self._conn.execute( - f'INSERT OR REPLACE INTO "{stream}_blob" (id, data) VALUES (?, ?)', + f'INSERT OR REPLACE INTO "{stream_name}_blob" (id, data) VALUES (?, ?)', (key, data), ) - def get(self, stream: str, key: int) -> bytes: + def get(self, stream_name: str, key: int) -> bytes: try: row = self._conn.execute( - f'SELECT data FROM "{stream}_blob" WHERE id = ?', (key,) + f'SELECT data FROM "{stream_name}_blob" WHERE id = ?', (key,) ).fetchone() except Exception: - raise KeyError(f"No blob for stream={stream!r}, key={key}") + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") if row is None: - raise KeyError(f"No blob for stream={stream!r}, key={key}") + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") result: bytes = row[0] return result - def delete(self, stream: str, key: int) -> None: + def delete(self, stream_name: str, key: int) -> None: try: - self._conn.execute(f'DELETE FROM "{stream}_blob" WHERE id = ?', (key,)) + self._conn.execute(f'DELETE FROM 
"{stream_name}_blob" WHERE id = ?', (key,)) except Exception: pass From 1e28b509f165926d0a399f9deb2ccdc70014702e Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 16:15:30 +0800 Subject: [PATCH 113/118] vectorstore stream -> stream_name --- dimos/memory2/backend.py | 6 +++--- dimos/memory2/vectorstore/memory.py | 12 ++++++------ dimos/memory2/vectorstore/sqlite.py | 28 ++++++++++++++-------------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index d43d614820..cd8062522a 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -147,16 +147,16 @@ class VectorStore(Resource): """ @abstractmethod - def put(self, stream: str, key: int, embedding: Embedding) -> None: + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: """Store an embedding vector for the given stream and observation id.""" ... @abstractmethod - def search(self, stream: str, query: Embedding, k: int) -> list[tuple[int, float]]: + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: """Return top-k (observation_id, similarity) pairs, descending.""" ... @abstractmethod - def delete(self, stream: str, key: int) -> None: + def delete(self, stream_name: str, key: int) -> None: """Remove a vector. Silent if missing.""" ... 
diff --git a/dimos/memory2/vectorstore/memory.py b/dimos/memory2/vectorstore/memory.py index 22532c6ad1..fd01514348 100644 --- a/dimos/memory2/vectorstore/memory.py +++ b/dimos/memory2/vectorstore/memory.py @@ -42,17 +42,17 @@ def stop(self) -> None: # ── VectorStore interface ──────────────────────────────────── - def put(self, stream: str, key: int, embedding: Embedding) -> None: - self._vectors.setdefault(stream, {})[key] = embedding + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: + self._vectors.setdefault(stream_name, {})[key] = embedding - def search(self, stream: str, query: Embedding, k: int) -> list[tuple[int, float]]: - vectors = self._vectors.get(stream, {}) + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + vectors = self._vectors.get(stream_name, {}) if not vectors: return [] scored = [(key, float(emb @ query)) for key, emb in vectors.items()] scored.sort(key=lambda x: x[1], reverse=True) return scored[:k] - def delete(self, stream: str, key: int) -> None: - vectors = self._vectors.get(stream, {}) + def delete(self, stream_name: str, key: int) -> None: + vectors = self._vectors.get(stream_name, {}) vectors.pop(key, None) diff --git a/dimos/memory2/vectorstore/sqlite.py b/dimos/memory2/vectorstore/sqlite.py index 736cc16e27..e6c5ee19bc 100644 --- a/dimos/memory2/vectorstore/sqlite.py +++ b/dimos/memory2/vectorstore/sqlite.py @@ -36,16 +36,16 @@ class SqliteVectorStore(VectorStore): def __init__(self, conn: sqlite3.Connection) -> None: self._conn = conn - self._tables: dict[str, int] = {} # stream -> dimensionality + self._tables: dict[str, int] = {} # stream_name -> dimensionality - def _ensure_table(self, stream: str, dim: int) -> None: - if stream in self._tables: + def _ensure_table(self, stream_name: str, dim: int) -> None: + if stream_name in self._tables: return self._conn.execute( - f'CREATE VIRTUAL TABLE IF NOT EXISTS "{stream}_vec" ' + f'CREATE VIRTUAL TABLE IF NOT EXISTS 
"{stream_name}_vec" ' f"USING vec0(embedding float[{dim}] distance_metric=cosine)" ) - self._tables[stream] = dim + self._tables[stream_name] = dim # ── Resource lifecycle ──────────────────────────────────────── @@ -57,26 +57,26 @@ def stop(self) -> None: # ── VectorStore interface ──────────────────────────────────── - def put(self, stream: str, key: int, embedding: Embedding) -> None: + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: vec = embedding.to_numpy().tolist() - self._ensure_table(stream, len(vec)) + self._ensure_table(stream_name, len(vec)) self._conn.execute( - f'INSERT OR REPLACE INTO "{stream}_vec" (rowid, embedding) VALUES (?, ?)', + f'INSERT OR REPLACE INTO "{stream_name}_vec" (rowid, embedding) VALUES (?, ?)', (key, json.dumps(vec)), ) - def search(self, stream: str, query: Embedding, k: int) -> list[tuple[int, float]]: - if stream not in self._tables: + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + if stream_name not in self._tables: return [] vec = query.to_numpy().tolist() rows = self._conn.execute( - f'SELECT rowid, distance FROM "{stream}_vec" WHERE embedding MATCH ? AND k = ?', + f'SELECT rowid, distance FROM "{stream_name}_vec" WHERE embedding MATCH ? 
AND k = ?', (json.dumps(vec), k), ).fetchall() # vec0 cosine distance = 1 - cosine_similarity return [(int(row[0]), max(0.0, 1.0 - row[1])) for row in rows] - def delete(self, stream: str, key: int) -> None: - if stream not in self._tables: + def delete(self, stream_name: str, key: int) -> None: + if stream_name not in self._tables: return - self._conn.execute(f'DELETE FROM "{stream}_vec" WHERE rowid = ?', (key,)) + self._conn.execute(f'DELETE FROM "{stream_name}_vec" WHERE rowid = ?', (key,)) From 6f3ef511f49bcf353200b66992a42b3f1ccf79fe Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 16:50:23 +0800 Subject: [PATCH 114/118] resource typing fixes --- dimos/core/resource.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/dimos/core/resource.py b/dimos/core/resource.py index 25d590706f..63b1eec4f0 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -15,6 +15,10 @@ from __future__ import annotations from abc import abstractmethod +from typing import TYPE_CHECKING, Self + +if TYPE_CHECKING: + from types import TracebackType from reactivex.abc import DisposableBase from reactivex.disposable import CompositeDisposable @@ -49,11 +53,16 @@ def dispose(self) -> None: """ self.stop() - def __enter__(self) -> Resource: + def __enter__(self) -> Self: self.start() return self - def __exit__(self, *args: object) -> None: + def __exit__( + self, + exctype: type[BaseException] | None, + excinst: BaseException | None, + exctb: TracebackType | None, + ) -> None: self.stop() From 30959af9af9afbe00d71465733af87f233674f13 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 17:14:38 +0800 Subject: [PATCH 115/118] move type definitions into dimos/memory2/type/ subpackage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Separate pure-definition files (protocols, ABCs, dataclasses) from implementation files by moving them into a type/ subpackage: - backend.py → 
type/backend.py - type.py → type/observation.py - filter.py → type/filter.py Added type/__init__.py with re-exports for convenience imports. Updated all 24 importing files across the module. --- dimos/memory2/__init__.py | 18 +++--- dimos/memory2/blobstore/__init__.py | 2 +- dimos/memory2/blobstore/file.py | 4 +- dimos/memory2/blobstore/sqlite.py | 4 +- dimos/memory2/blobstore/test_blobstore.py | 2 +- dimos/memory2/conftest.py | 2 +- dimos/memory2/embed.py | 2 +- dimos/memory2/impl/README.md | 6 +- dimos/memory2/impl/memory.py | 10 ++-- dimos/memory2/impl/sqlite.py | 22 +++---- dimos/memory2/livechannel/__init__.py | 2 +- dimos/memory2/livechannel/subject.py | 4 +- dimos/memory2/store.py | 2 +- dimos/memory2/stream.py | 8 +-- dimos/memory2/test_blobstore_integration.py | 2 +- dimos/memory2/test_embedding.py | 2 +- dimos/memory2/test_impl.py | 8 +-- dimos/memory2/test_save.py | 4 +- dimos/memory2/test_stream.py | 2 +- dimos/memory2/transform.py | 2 +- dimos/memory2/type/__init__.py | 59 +++++++++++++++++++ dimos/memory2/{ => type}/backend.py | 4 +- dimos/memory2/{ => type}/filter.py | 2 +- .../memory2/{type.py => type/observation.py} | 0 dimos/memory2/vectorstore/memory.py | 2 +- dimos/memory2/vectorstore/sqlite.py | 4 +- 26 files changed, 120 insertions(+), 59 deletions(-) create mode 100644 dimos/memory2/type/__init__.py rename dimos/memory2/{ => type}/backend.py (97%) rename dimos/memory2/{ => type}/filter.py (99%) rename dimos/memory2/{type.py => type/observation.py} (100%) diff --git a/dimos/memory2/__init__.py b/dimos/memory2/__init__.py index 0b358fe438..36dfa8ad60 100644 --- a/dimos/memory2/__init__.py +++ b/dimos/memory2/__init__.py @@ -1,4 +1,3 @@ -from dimos.memory2.backend import Backend, LiveChannel, VectorStore from dimos.memory2.buffer import ( BackpressureBuffer, Bounded, @@ -8,7 +7,14 @@ Unbounded, ) from dimos.memory2.embed import EmbedImages, EmbedText -from dimos.memory2.filter import ( +from dimos.memory2.impl.memory import ListBackend, 
MemorySession, MemoryStore +from dimos.memory2.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore, SqliteStoreConfig +from dimos.memory2.livechannel import SubjectChannel +from dimos.memory2.store import Session, SessionConfig, Store, StoreConfig, StreamNamespace +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer +from dimos.memory2.type.backend import Backend, LiveChannel, VectorStore +from dimos.memory2.type.filter import ( AfterFilter, AtFilter, BeforeFilter, @@ -19,13 +25,7 @@ TagsFilter, TimeRangeFilter, ) -from dimos.memory2.impl.memory import ListBackend, MemorySession, MemoryStore -from dimos.memory2.impl.sqlite import SqliteBackend, SqliteSession, SqliteStore, SqliteStoreConfig -from dimos.memory2.livechannel import SubjectChannel -from dimos.memory2.store import Session, SessionConfig, Store, StoreConfig, StreamNamespace -from dimos.memory2.stream import Stream -from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer -from dimos.memory2.type import EmbeddedObservation, Observation +from dimos.memory2.type.observation import EmbeddedObservation, Observation __all__ = [ "AfterFilter", diff --git a/dimos/memory2/blobstore/__init__.py b/dimos/memory2/blobstore/__init__.py index 8f78d7c439..bdc0adc034 100644 --- a/dimos/memory2/blobstore/__init__.py +++ b/dimos/memory2/blobstore/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dimos.memory2.backend import BlobStore from dimos.memory2.blobstore.file import FileBlobStore from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.type.backend import BlobStore __all__ = ["BlobStore", "FileBlobStore", "SqliteBlobStore"] diff --git a/dimos/memory2/blobstore/file.py b/dimos/memory2/blobstore/file.py index fda557ae01..e2d7d492b8 100644 --- a/dimos/memory2/blobstore/file.py +++ b/dimos/memory2/blobstore/file.py @@ -17,7 +17,8 @@ from pathlib import Path from typing import TYPE_CHECKING -from dimos.memory2.backend import BlobStore +from dimos.memory2.type.backend import BlobStore +from dimos.memory2.utils import validate_identifier if TYPE_CHECKING: import os @@ -35,6 +36,7 @@ def __init__(self, root: str | os.PathLike[str]) -> None: self._root = Path(root) def _path(self, stream_name: str, key: int) -> Path: + validate_identifier(stream_name) return self._root / stream_name / f"{key}.bin" # ── Resource lifecycle ──────────────────────────────────────── diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py index e9cef65159..426b6ab503 100644 --- a/dimos/memory2/blobstore/sqlite.py +++ b/dimos/memory2/blobstore/sqlite.py @@ -16,7 +16,8 @@ from typing import TYPE_CHECKING -from dimos.memory2.backend import BlobStore +from dimos.memory2.type.backend import BlobStore +from dimos.memory2.utils import validate_identifier if TYPE_CHECKING: import sqlite3 @@ -43,6 +44,7 @@ def __init__(self, conn: sqlite3.Connection) -> None: def _ensure_table(self, stream_name: str) -> None: if stream_name in self._tables: return + validate_identifier(stream_name) self._conn.execute( f'CREATE TABLE IF NOT EXISTS "{stream_name}_blob" ' "(id INTEGER PRIMARY KEY, data BLOB NOT NULL)" diff --git a/dimos/memory2/blobstore/test_blobstore.py b/dimos/memory2/blobstore/test_blobstore.py index ebe051a17f..70f9915722 100644 --- a/dimos/memory2/blobstore/test_blobstore.py +++ b/dimos/memory2/blobstore/test_blobstore.py @@ -21,7 
+21,7 @@ import pytest if TYPE_CHECKING: - from dimos.memory2.backend import BlobStore + from dimos.memory2.type.backend import BlobStore class TestBlobStore: diff --git a/dimos/memory2/conftest.py b/dimos/memory2/conftest.py index de73c249bc..297380213b 100644 --- a/dimos/memory2/conftest.py +++ b/dimos/memory2/conftest.py @@ -31,9 +31,9 @@ from collections.abc import Generator from pathlib import Path - from dimos.memory2.backend import BlobStore from dimos.memory2.impl.memory import MemorySession from dimos.memory2.store import Session + from dimos.memory2.type.backend import BlobStore # ── Stores ──────────────────────────────────────────────────────── diff --git a/dimos/memory2/embed.py b/dimos/memory2/embed.py index 981bd83b73..17b5b98a31 100644 --- a/dimos/memory2/embed.py +++ b/dimos/memory2/embed.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from collections.abc import Iterator - from dimos.memory2.type import Observation + from dimos.memory2.type.observation import Observation from dimos.models.embedding.base import EmbeddingModel T = TypeVar("T") diff --git a/dimos/memory2/impl/README.md b/dimos/memory2/impl/README.md index c10c5b235c..b30626fe5f 100644 --- a/dimos/memory2/impl/README.md +++ b/dimos/memory2/impl/README.md @@ -14,10 +14,10 @@ Storage backends for memory. Each backend implements the `Backend` protocol to p ### 1. 
Implement the Backend protocol ```python -from dimos.memory2.backend import Backend, BackendConfig, LiveChannel -from dimos.memory2.filter import StreamQuery +from dimos.memory2.type.backend import Backend, BackendConfig, LiveChannel +from dimos.memory2.type.filter import StreamQuery from dimos.memory2.livechannel.subject import SubjectChannel -from dimos.memory2.type import Observation +from dimos.memory2.type.observation import Observation from dimos.protocol.service.spec import Configurable class MyBackend(Configurable[BackendConfig], Generic[T]): diff --git a/dimos/memory2/impl/memory.py b/dimos/memory2/impl/memory.py index 1b4fe91b8c..1c1148bf42 100644 --- a/dimos/memory2/impl/memory.py +++ b/dimos/memory2/impl/memory.py @@ -18,11 +18,11 @@ import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.backend import BackendConfig from dimos.memory2.codecs.base import Codec, codec_for from dimos.memory2.livechannel.subject import SubjectChannel from dimos.memory2.store import Session, Store -from dimos.memory2.type import _UNLOADED +from dimos.memory2.type.backend import BackendConfig +from dimos.memory2.type.observation import _UNLOADED from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: @@ -30,10 +30,10 @@ from reactivex.abc import DisposableBase - from dimos.memory2.backend import Backend, LiveChannel from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.filter import StreamQuery - from dimos.memory2.type import Observation + from dimos.memory2.type.backend import Backend, LiveChannel + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation T = TypeVar("T") diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 5d7ed80cec..028e7c1591 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -23,10 +23,12 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.core.resource import 
CompositeResource -from dimos.memory2.backend import BackendConfig from dimos.memory2.blobstore.sqlite import SqliteBlobStore from dimos.memory2.codecs.base import Codec, codec_for, codec_from_id, codec_id -from dimos.memory2.filter import ( +from dimos.memory2.livechannel.subject import SubjectChannel +from dimos.memory2.store import Session, Store, StoreConfig +from dimos.memory2.type.backend import BackendConfig +from dimos.memory2.type.filter import ( AfterFilter, AtFilter, BeforeFilter, @@ -35,9 +37,8 @@ TimeRangeFilter, _xyz, ) -from dimos.memory2.livechannel.subject import SubjectChannel -from dimos.memory2.store import Session, Store, StoreConfig -from dimos.memory2.type import _UNLOADED, Observation +from dimos.memory2.type.observation import _UNLOADED, Observation +from dimos.memory2.utils import validate_identifier from dimos.protocol.service.spec import Configurable if TYPE_CHECKING: @@ -45,9 +46,9 @@ from reactivex.abc import DisposableBase - from dimos.memory2.backend import Backend, LiveChannel from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.filter import Filter, StreamQuery + from dimos.memory2.type.backend import Backend, LiveChannel + from dimos.memory2.type.filter import Filter, StreamQuery T = TypeVar("T") @@ -57,11 +58,6 @@ # ── Helpers ────────────────────────────────────────────────────── -def _validate_identifier(name: str) -> None: - if not _IDENT_RE.match(name): - raise ValueError(f"Invalid stream name: {name!r}") - - def _decompose_pose(pose: Any) -> tuple[float, ...] 
| None: if pose is None: return None @@ -554,7 +550,7 @@ def _codec_from_id(codec_id_str: str, payload_module: str) -> Codec[Any]: def _create_backend( self, name: str, payload_type: type[Any] | None = None, **config: Any ) -> Backend[Any]: - _validate_identifier(name) + validate_identifier(name) # Look up existing stream in registry row = self._registry_conn.execute( diff --git a/dimos/memory2/livechannel/__init__.py b/dimos/memory2/livechannel/__init__.py index 4fba822bab..fdd7c37aa9 100644 --- a/dimos/memory2/livechannel/__init__.py +++ b/dimos/memory2/livechannel/__init__.py @@ -1,4 +1,4 @@ -from dimos.memory2.backend import LiveChannel from dimos.memory2.livechannel.subject import SubjectChannel +from dimos.memory2.type.backend import LiveChannel __all__ = ["LiveChannel", "SubjectChannel"] diff --git a/dimos/memory2/livechannel/subject.py b/dimos/memory2/livechannel/subject.py index 2d2b848f9f..b1e0c40581 100644 --- a/dimos/memory2/livechannel/subject.py +++ b/dimos/memory2/livechannel/subject.py @@ -21,13 +21,13 @@ from reactivex.disposable import Disposable -from dimos.memory2.backend import LiveChannel +from dimos.memory2.type.backend import LiveChannel if TYPE_CHECKING: from reactivex.abc import DisposableBase from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.type import Observation + from dimos.memory2.type.observation import Observation T = TypeVar("T") diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index c46b11e3e4..3fc1b682ad 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -25,8 +25,8 @@ if TYPE_CHECKING: from collections.abc import Iterator - from dimos.memory2.backend import Backend, BlobStore, LiveChannel, VectorStore from dimos.memory2.codecs.base import Codec + from dimos.memory2.type.backend import Backend, BlobStore, LiveChannel, VectorStore T = TypeVar("T") diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 14eb7cd5ee..52ac516c88 100644 --- a/dimos/memory2/stream.py +++ 
b/dimos/memory2/stream.py @@ -18,9 +18,10 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.core.resource import Resource -from dimos.memory2.backend import Backend from dimos.memory2.buffer import BackpressureBuffer, KeepLast -from dimos.memory2.filter import ( +from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer +from dimos.memory2.type.backend import Backend +from dimos.memory2.type.filter import ( AfterFilter, AtFilter, BeforeFilter, @@ -31,8 +32,7 @@ TagsFilter, TimeRangeFilter, ) -from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer -from dimos.memory2.type import EmbeddedObservation, Observation +from dimos.memory2.type.observation import EmbeddedObservation, Observation if TYPE_CHECKING: from collections.abc import Callable, Iterator diff --git a/dimos/memory2/test_blobstore_integration.py b/dimos/memory2/test_blobstore_integration.py index c961d1fe31..120f0438f8 100644 --- a/dimos/memory2/test_blobstore_integration.py +++ b/dimos/memory2/test_blobstore_integration.py @@ -23,7 +23,7 @@ from dimos.memory2.blobstore.file import FileBlobStore from dimos.memory2.impl.memory import MemoryStore -from dimos.memory2.type import _UNLOADED +from dimos.memory2.type.observation import _UNLOADED from dimos.models.embedding.base import Embedding if TYPE_CHECKING: diff --git a/dimos/memory2/test_embedding.py b/dimos/memory2/test_embedding.py index d2b37bf210..2a30a761b7 100644 --- a/dimos/memory2/test_embedding.py +++ b/dimos/memory2/test_embedding.py @@ -19,7 +19,7 @@ import numpy as np import pytest -from dimos.memory2.type import EmbeddedObservation, Observation +from dimos.memory2.type.observation import EmbeddedObservation, Observation from dimos.models.embedding.base import Embedding # ── Helpers ─────────────────────────────────────────────────────── diff --git a/dimos/memory2/test_impl.py b/dimos/memory2/test_impl.py index faf5dc6258..69e6150c16 100644 --- a/dimos/memory2/test_impl.py 
+++ b/dimos/memory2/test_impl.py @@ -158,7 +158,7 @@ def test_same_stream_on_repeated_calls(self, session: Session) -> None: def test_append_with_embedding(self, session: Session) -> None: import numpy as np - from dimos.memory2.type import EmbeddedObservation + from dimos.memory2.type.observation import EmbeddedObservation from dimos.models.embedding.base import Embedding s = session.stream("vectors", str) @@ -208,7 +208,7 @@ class TestBlobLoading: def test_sqlite_lazy_by_default(self, sqlite_session: Session) -> None: """Default sqlite iteration uses lazy loaders — data is _UNLOADED until accessed.""" - from dimos.memory2.type import _Unloaded + from dimos.memory2.type.observation import _Unloaded s = sqlite_session.stream("lazy_test", str) s.append("hello", ts=1.0) @@ -226,7 +226,7 @@ def test_sqlite_lazy_by_default(self, sqlite_session: Session) -> None: def test_sqlite_eager_loads_inline(self, sqlite_session: Session) -> None: """With eager_blobs=True, data is loaded via JOIN — no lazy loader.""" - from dimos.memory2.type import _Unloaded + from dimos.memory2.type.observation import _Unloaded s = sqlite_session.stream("eager_test", str, eager_blobs=True) s.append("hello", ts=1.0) @@ -258,7 +258,7 @@ def test_sqlite_lazy_and_eager_same_values(self, sqlite_session: Session) -> Non def test_memory_lazy_with_blobstore(self, memory_store, tmp_path) -> None: """MemoryStore with a BlobStore uses lazy loaders.""" from dimos.memory2.blobstore.file import FileBlobStore - from dimos.memory2.type import _Unloaded + from dimos.memory2.type.observation import _Unloaded bs = FileBlobStore(root=tmp_path / "blobs") bs.start() diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py index 74c1be89f0..eaea775a6a 100644 --- a/dimos/memory2/test_save.py +++ b/dimos/memory2/test_save.py @@ -18,11 +18,11 @@ import pytest -from dimos.memory2.backend import Backend, LiveChannel from dimos.memory2.impl.memory import ListBackend from dimos.memory2.stream import Stream from 
dimos.memory2.transform import FnTransformer -from dimos.memory2.type import Observation +from dimos.memory2.type.backend import Backend, LiveChannel +from dimos.memory2.type.observation import Observation # ── Helpers ────────────────────────────────────────────────────────── diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index 8442527c62..6e64248607 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -27,7 +27,7 @@ from dimos.memory2.buffer import KeepLast, Unbounded from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer -from dimos.memory2.type import Observation +from dimos.memory2.type.observation import Observation if TYPE_CHECKING: from collections.abc import Callable, Generator diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py index d68e25344a..a05b75e5c4 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory2/transform.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator - from dimos.memory2.type import Observation + from dimos.memory2.type.observation import Observation T = TypeVar("T") R = TypeVar("R") diff --git a/dimos/memory2/type/__init__.py b/dimos/memory2/type/__init__.py new file mode 100644 index 0000000000..e3655af16c --- /dev/null +++ b/dimos/memory2/type/__init__.py @@ -0,0 +1,59 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.memory2.type.backend import ( + Backend, + BackendConfig, + BlobStore, + LiveChannel, + VectorStore, +) +from dimos.memory2.type.filter import ( + AfterFilter, + AtFilter, + BeforeFilter, + Filter, + NearFilter, + PredicateFilter, + StreamQuery, + TagsFilter, + TimeRangeFilter, +) +from dimos.memory2.type.observation import ( + _UNLOADED, + EmbeddedObservation, + Observation, + _Unloaded, +) + +__all__ = [ + "_UNLOADED", + "AfterFilter", + "AtFilter", + "Backend", + "BackendConfig", + "BeforeFilter", + "BlobStore", + "EmbeddedObservation", + "Filter", + "LiveChannel", + "NearFilter", + "Observation", + "PredicateFilter", + "StreamQuery", + "TagsFilter", + "TimeRangeFilter", + "VectorStore", + "_Unloaded", +] diff --git a/dimos/memory2/backend.py b/dimos/memory2/type/backend.py similarity index 97% rename from dimos/memory2/backend.py rename to dimos/memory2/type/backend.py index cd8062522a..b5c96eeb98 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/type/backend.py @@ -27,8 +27,8 @@ from dimos.memory2.buffer import BackpressureBuffer from dimos.memory2.codecs.base import Codec - from dimos.memory2.filter import StreamQuery - from dimos.memory2.type import Observation + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation from dimos.models.embedding.base import Embedding T = TypeVar("T") diff --git a/dimos/memory2/filter.py b/dimos/memory2/type/filter.py similarity index 99% rename from dimos/memory2/filter.py rename to dimos/memory2/type/filter.py index 4a1846d3e1..dd6ac0123b 100644 --- a/dimos/memory2/filter.py +++ b/dimos/memory2/type/filter.py @@ -23,7 +23,7 @@ from collections.abc import Callable, Iterator from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.type import Observation + from dimos.memory2.type.observation import Observation from dimos.models.embedding.base import Embedding diff --git a/dimos/memory2/type.py b/dimos/memory2/type/observation.py 
similarity index 100% rename from dimos/memory2/type.py rename to dimos/memory2/type/observation.py diff --git a/dimos/memory2/vectorstore/memory.py b/dimos/memory2/vectorstore/memory.py index fd01514348..f4a8bb9baf 100644 --- a/dimos/memory2/vectorstore/memory.py +++ b/dimos/memory2/vectorstore/memory.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING -from dimos.memory2.backend import VectorStore +from dimos.memory2.type.backend import VectorStore if TYPE_CHECKING: from dimos.models.embedding.base import Embedding diff --git a/dimos/memory2/vectorstore/sqlite.py b/dimos/memory2/vectorstore/sqlite.py index e6c5ee19bc..5af04c022d 100644 --- a/dimos/memory2/vectorstore/sqlite.py +++ b/dimos/memory2/vectorstore/sqlite.py @@ -17,7 +17,8 @@ import json from typing import TYPE_CHECKING -from dimos.memory2.backend import VectorStore +from dimos.memory2.type.backend import VectorStore +from dimos.memory2.utils import validate_identifier if TYPE_CHECKING: import sqlite3 @@ -41,6 +42,7 @@ def __init__(self, conn: sqlite3.Connection) -> None: def _ensure_table(self, stream_name: str, dim: int) -> None: if stream_name in self._tables: return + validate_identifier(stream_name) self._conn.execute( f'CREATE VIRTUAL TABLE IF NOT EXISTS "{stream_name}_vec" ' f"USING vec0(embedding float[{dim}] distance_metric=cosine)" From 367fa4eb2740e101a5ead4fec29af0915d1d3c97 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 17:22:54 +0800 Subject: [PATCH 116/118] lz4 codec included, utils/ cleanup --- dimos/memory2/codecs/jpeg.py | 9 +++-- dimos/memory2/codecs/lz4.py | 44 +++++++++++++++++++++++++ dimos/memory2/transform.py | 2 +- dimos/memory2/utils/__init__.py | 4 +++ dimos/memory2/{ => utils}/formatting.py | 0 dimos/memory2/utils/validation.py | 25 ++++++++++++++ 6 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 dimos/memory2/codecs/lz4.py create mode 100644 dimos/memory2/utils/__init__.py rename dimos/memory2/{ => utils}/formatting.py (100%) create 
mode 100644 dimos/memory2/utils/validation.py diff --git a/dimos/memory2/codecs/jpeg.py b/dimos/memory2/codecs/jpeg.py index cd1cc677c8..3d854400b1 100644 --- a/dimos/memory2/codecs/jpeg.py +++ b/dimos/memory2/codecs/jpeg.py @@ -14,7 +14,10 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dimos.msgs.sensor_msgs.Image import Image class JpegCodec: @@ -27,10 +30,10 @@ class JpegCodec: def __init__(self, quality: int = 50) -> None: self._quality = quality - def encode(self, value: Any) -> bytes: + def encode(self, value: Image) -> bytes: return value.lcm_jpeg_encode(quality=self._quality) - def decode(self, data: bytes) -> Any: + def decode(self, data: bytes) -> Image: from dimos.msgs.sensor_msgs.Image import Image return Image.lcm_jpeg_decode(data) diff --git a/dimos/memory2/codecs/lz4.py b/dimos/memory2/codecs/lz4.py new file mode 100644 index 0000000000..68aabb339e --- /dev/null +++ b/dimos/memory2/codecs/lz4.py @@ -0,0 +1,44 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeVar + +import lz4.frame # type: ignore[import-untyped] + +if TYPE_CHECKING: + from dimos.memory2.codecs.base import Codec + +T = TypeVar("T") + + +class Lz4Codec: + """Wraps another codec and applies LZ4 frame compression to the output. 
+ + Works with any inner codec — compresses the bytes produced by + ``inner.encode()`` and decompresses before ``inner.decode()``. + """ + + def __init__(self, inner: Codec[Any], compression_level: int = 0) -> None: + self._inner = inner + self._compression_level = compression_level + + def encode(self, value: Any) -> bytes: + raw = self._inner.encode(value) + return bytes(lz4.frame.compress(raw, compression_level=self._compression_level)) + + def decode(self, data: bytes) -> Any: + raw: bytes = lz4.frame.decompress(data) + return self._inner.decode(raw) diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py index a05b75e5c4..5ffb02aa46 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory2/transform.py @@ -18,7 +18,7 @@ import inspect from typing import TYPE_CHECKING, Any, Generic, TypeVar -from dimos.memory2.formatting import FilterRepr +from dimos.memory2.utils.formatting import FilterRepr if TYPE_CHECKING: from collections.abc import Callable, Iterator diff --git a/dimos/memory2/utils/__init__.py b/dimos/memory2/utils/__init__.py new file mode 100644 index 0000000000..d65be4ae67 --- /dev/null +++ b/dimos/memory2/utils/__init__.py @@ -0,0 +1,4 @@ +from dimos.memory2.utils.formatting import FilterRepr +from dimos.memory2.utils.validation import validate_identifier + +__all__ = ["FilterRepr", "validate_identifier"] diff --git a/dimos/memory2/formatting.py b/dimos/memory2/utils/formatting.py similarity index 100% rename from dimos/memory2/formatting.py rename to dimos/memory2/utils/formatting.py diff --git a/dimos/memory2/utils/validation.py b/dimos/memory2/utils/validation.py new file mode 100644 index 0000000000..636ff59327 --- /dev/null +++ b/dimos/memory2/utils/validation.py @@ -0,0 +1,25 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def validate_identifier(name: str) -> None: + """Reject stream names that aren't safe SQL identifiers.""" + if not _IDENT_RE.match(name): + raise ValueError(f"Invalid stream name: {name!r}") From 02a233253c7d3d52120f0f3509bd4624c9b5df9d Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 17:45:48 +0800 Subject: [PATCH 117/118] migrated stores to a new config system --- dimos/memory2/impl/sqlite.py | 3 +-- dimos/memory2/store.py | 9 +++------ dimos/memory2/type/backend.py | 5 ++--- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/dimos/memory2/impl/sqlite.py b/dimos/memory2/impl/sqlite.py index 028e7c1591..5473e5ea70 100644 --- a/dimos/memory2/impl/sqlite.py +++ b/dimos/memory2/impl/sqlite.py @@ -14,7 +14,7 @@ from __future__ import annotations -from dataclasses import dataclass, replace +from dataclasses import replace from itertools import islice import json import re @@ -650,7 +650,6 @@ def stop(self) -> None: # ── SqliteStore ────────────────────────────────────────────────── -@dataclass class SqliteStoreConfig(StoreConfig): """Config for SQLite-backed store.""" diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index 3fc1b682ad..2ad7516c6e 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -15,12 +15,11 @@ from __future__ import annotations from abc import abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TypeVar, cast from dimos.core.resource import 
CompositeResource from dimos.memory2.stream import Stream -from dimos.protocol.service.spec import Configurable +from dimos.protocol.service.spec import BaseConfig, Configurable if TYPE_CHECKING: from collections.abc import Iterator @@ -34,13 +33,11 @@ # ── Configuration ───────────────────────────────────────────────── -@dataclass -class StoreConfig: +class StoreConfig(BaseConfig): """Base config for Store. Subclasses extend with store-specific fields.""" -@dataclass -class SessionConfig: +class SessionConfig(BaseConfig): """Session-level defaults for stream capabilities. These are inherited by all streams in the session unless overridden diff --git a/dimos/memory2/type/backend.py b/dimos/memory2/type/backend.py index b5c96eeb98..0ab9a207bf 100644 --- a/dimos/memory2/type/backend.py +++ b/dimos/memory2/type/backend.py @@ -15,10 +15,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from dimos.core.resource import Resource +from dimos.protocol.service.spec import BaseConfig if TYPE_CHECKING: from collections.abc import Iterator @@ -37,8 +37,7 @@ # ── Backend configuration ─────────────────────────────────────── -@dataclass -class BackendConfig: +class BackendConfig(BaseConfig): """Configuration for backend capabilities. 
Session-level defaults are merged with per-stream overrides and From b3e72364abbe5f459acfb0bb8bac243183407704 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 13 Mar 2026 18:54:13 +0800 Subject: [PATCH 118/118] config fix --- dimos/memory2/codecs/base.py | 3 ++- dimos/memory2/store.py | 9 +++++---- dimos/memory2/test_impl.py | 6 ++++-- dimos/memory2/type/backend.py | 8 ++++---- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/dimos/memory2/codecs/base.py b/dimos/memory2/codecs/base.py index 4d082e6eb2..ee7bce9ca4 100644 --- a/dimos/memory2/codecs/base.py +++ b/dimos/memory2/codecs/base.py @@ -15,11 +15,12 @@ from __future__ import annotations import importlib -from typing import Any, Protocol, TypeVar +from typing import Any, Protocol, TypeVar, runtime_checkable T = TypeVar("T") +@runtime_checkable class Codec(Protocol[T]): """Encode/decode payloads for storage.""" diff --git a/dimos/memory2/store.py b/dimos/memory2/store.py index 2ad7516c6e..034335f8b0 100644 --- a/dimos/memory2/store.py +++ b/dimos/memory2/store.py @@ -18,14 +18,15 @@ from typing import TYPE_CHECKING, Any, TypeVar, cast from dimos.core.resource import CompositeResource +from dimos.memory2.codecs.base import Codec from dimos.memory2.stream import Stream +from dimos.memory2.type.backend import BlobStore, LiveChannel, VectorStore from dimos.protocol.service.spec import BaseConfig, Configurable if TYPE_CHECKING: from collections.abc import Iterator - from dimos.memory2.codecs.base import Codec - from dimos.memory2.type.backend import Backend, BlobStore, LiveChannel, VectorStore + from dimos.memory2.type.backend import Backend T = TypeVar("T") @@ -44,11 +45,11 @@ class SessionConfig(BaseConfig): per-stream in ``session.stream(..., **overrides)``. 
""" - live_channel: LiveChannel[Any] | None = None + live_channel: LiveChannel | None = None blob_store: BlobStore | None = None vector_store: VectorStore | None = None eager_blobs: bool = False - codec: Codec[Any] | str | None = None + codec: Codec | str | None = None # ── Stream namespace ────────────────────────────────────────────── diff --git a/dimos/memory2/test_impl.py b/dimos/memory2/test_impl.py index 69e6150c16..0f31695612 100644 --- a/dimos/memory2/test_impl.py +++ b/dimos/memory2/test_impl.py @@ -24,6 +24,8 @@ import pytest +from dimos.memory2.type.backend import BlobStore, VectorStore + if TYPE_CHECKING: from dimos.memory2.store import Session @@ -276,7 +278,7 @@ def test_memory_lazy_with_blobstore(self, memory_store, tmp_path) -> None: # ── Spy stores ─────────────────────────────────────────────────── -class SpyBlobStore: +class SpyBlobStore(BlobStore): """BlobStore that records all calls for verification.""" def __init__(self) -> None: @@ -302,7 +304,7 @@ def delete(self, stream: str, key: int) -> None: self.store.pop((stream, key), None) -class SpyVectorStore: +class SpyVectorStore(VectorStore): """VectorStore that records all calls for verification.""" def __init__(self) -> None: diff --git a/dimos/memory2/type/backend.py b/dimos/memory2/type/backend.py index 0ab9a207bf..3230fae05f 100644 --- a/dimos/memory2/type/backend.py +++ b/dimos/memory2/type/backend.py @@ -15,9 +15,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable from dimos.core.resource import Resource +from dimos.memory2.codecs.base import Codec from dimos.protocol.service.spec import BaseConfig if TYPE_CHECKING: @@ -26,7 +27,6 @@ from reactivex.abc import DisposableBase from dimos.memory2.buffer import BackpressureBuffer - from dimos.memory2.codecs.base import Codec from dimos.memory2.type.filter 
import StreamQuery from dimos.memory2.type.observation import Observation from dimos.models.embedding.base import Embedding @@ -44,11 +44,11 @@ class BackendConfig(BaseConfig): forwarded here by ``Session.stream()``. """ - live_channel: LiveChannel[Any] | None = None + live_channel: LiveChannel | None = None blob_store: BlobStore | None = None vector_store: VectorStore | None = None eager_blobs: bool = False - codec: Codec[Any] | None = None + codec: Codec | None = None page_size: int = 256