diff --git a/data/.lfs/go2_bigoffice.db.tar.gz b/data/.lfs/go2_bigoffice.db.tar.gz new file mode 100644 index 0000000000..93a0af8d25 --- /dev/null +++ b/data/.lfs/go2_bigoffice.db.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a6e1bfa79c81cb40425c3dcce5b51f721f5435e83bd7c3f502b7014c156b20f +size 183890080 diff --git a/dimos/core/library_config.py b/dimos/core/library_config.py new file mode 100644 index 0000000000..813fb642f6 --- /dev/null +++ b/dimos/core/library_config.py @@ -0,0 +1,27 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Process-wide library defaults. +# Modules that need different settings can override in their own start(). + + +def apply_library_config() -> None: + """Apply process-wide library defaults. Call once per process.""" + # Limit OpenCV internal threads to avoid idle thread contention. + try: + import cv2 + + cv2.setNumThreads(2) + except ImportError: + pass diff --git a/dimos/core/resource.py b/dimos/core/resource.py index ce3f735329..63b1eec4f0 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from abc import ABC, abstractmethod +from __future__ import annotations +from abc import abstractmethod +from typing import TYPE_CHECKING, Self -class Resource(ABC): +if TYPE_CHECKING: + from types import TracebackType + +from reactivex.abc import DisposableBase +from reactivex.disposable import CompositeDisposable + + +class Resource(DisposableBase): @abstractmethod def start(self) -> None: ... @@ -43,3 +52,35 @@ def dispose(self) -> None: """ self.stop() + + def __enter__(self) -> Self: + self.start() + return self + + def __exit__( + self, + exctype: type[BaseException] | None, + excinst: BaseException | None, + exctb: TracebackType | None, + ) -> None: + self.stop() + + +class CompositeResource(Resource): + """Resource that owns child disposables, disposed on stop().""" + + _disposables: CompositeDisposable + + def __init__(self) -> None: + self._disposables = CompositeDisposable() + + def register_disposables(self, *disposables: DisposableBase) -> None: + """Register child disposables to be disposed when this resource stops.""" + for d in disposables: + self._disposables.add(d) + + def start(self) -> None: + pass + + def stop(self) -> None: + self._disposables.dispose() diff --git a/dimos/core/worker.py b/dimos/core/worker.py index dca561f16c..8f3beee7ec 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.library_config import apply_library_config from dimos.utils.logging_config import setup_logger from dimos.utils.sequential_ids import SequentialIds @@ -292,6 +293,7 @@ def _suppress_console_output() -> None: def _worker_entrypoint(conn: Connection, worker_id: int) -> None: + apply_library_config() instances: dict[int, Any] = {} try: diff --git a/dimos/mapping/occupancy/gradient.py b/dimos/mapping/occupancy/gradient.py index 880f2692da..c9db43088e 100644 --- a/dimos/mapping/occupancy/gradient.py +++ 
b/dimos/mapping/occupancy/gradient.py @@ -53,7 +53,7 @@ def gradient( distance_cells = ndimage.distance_transform_edt(1 - obstacle_map) # Convert to meters and clip to max distance - distance_meters = np.clip(distance_cells * occupancy_grid.resolution, 0, max_distance) + distance_meters = np.clip(distance_cells * occupancy_grid.resolution, 0, max_distance) # type: ignore[operator] # Invert and scale to 0-100 range # Far from obstacles (max_distance) -> 0 diff --git a/dimos/memory2/architecture.md b/dimos/memory2/architecture.md new file mode 100644 index 0000000000..9dc805577f --- /dev/null +++ b/dimos/memory2/architecture.md @@ -0,0 +1,114 @@ +# memory + +Observation storage and streaming layer for DimOS. Pull-based, lazy, composable. + +## Architecture + +``` + Live Sensor Data + ↓ +Store → Stream → [filters / transforms / terminals] → Stream → [filters / transforms / terminals] → Stream → Live hooks + ↓ ↓ ↓ + Backend (ObservationStore + BlobStore + VectorStore + Notifier) Backend In Memory +``` + +**Store** owns a storage location (file, in-memory) and directly manages named streams. **Stream** is the query/iteration surface — lazy until a terminal is called. **Backend** is a concrete composite that orchestrates ObservationStore + BlobStore + VectorStore + Notifier for each stream. + +Supporting Systems: + +- BlobStore — separates large payloads from metadata. FileBlobStore (files on disk) and SqliteBlobStore (blob table per stream). Supports lazy loading. +- Codecs — codec_for() auto-selects: JpegCodec for images (TurboJPEG, ~10-20x compression), LcmCodec for DimOS messages, PickleCodec fallback. +- Transformers — Transformer[T,R] ABC wrapping iterator-to-iterator. EmbedImages/EmbedText enrich observations with embeddings. QualityWindow keeps best per time window. +- Backpressure Buffers — KeepLast, Bounded, DropNew, Unbounded — bridge push/pull for live mode. 
+ + +## Modules + +| Module | What | +|----------------|-------------------------------------------------------------------| +| `stream.py` | Stream node — filters, transforms, terminals | +| `backend.py` | Concrete Backend composite (ObservationStore + Blob + Vector + Live) | +| `store.py` | Store, StoreConfig | +| `transform.py` | Transformer ABC, FnTransformer, FnIterTransformer, QualityWindow | +| `buffer.py` | Backpressure buffers for live mode (KeepLast, Bounded, Unbounded) | +| `embed.py` | EmbedImages / EmbedText transformers | + +## Subpackages + +| Package | What | Docs | +|-----------------|------------------------------------------------------|--------------------------------------------------| +| `type/` | Observation, EmbeddedObservation, Filter/StreamQuery | | +| `store/` | Store ABC + implementations (MemoryStore, SqliteStore) | [store/README.md](store/README.md) | +| `notifier/` | Notifier ABC + SubjectNotifier | | +| `blobstore/` | BlobStore ABC + implementations (file, sqlite) | [blobstore/blobstore.md](blobstore/blobstore.md) | +| `codecs/` | Encode/decode for storage (pickle, JPEG, LCM) | [codecs/README.md](codecs/README.md) | +| `vectorstore/` | VectorStore ABC + implementations (memory, sqlite) | | +| `observationstore/` | ObservationStore Protocol + implementations | | + +## Docs + +| Doc | What | +|-----|------| +| [streaming.md](streaming.md) | Lazy vs materializing vs terminal — evaluation model, live safety | +| [embeddings.md](embeddings.md) | Embedding layer design — EmbeddedObservation, vector search, EmbedImages/EmbedText | +| [blobstore/blobstore.md](blobstore/blobstore.md) | BlobStore architecture — separate payload storage from metadata | + +## Query execution + +`StreamQuery` holds the full query spec (filters, text search, vector search, ordering, offset/limit). It also provides `apply(iterator)` — a Python-side execution path that runs all operations as in-memory predicates, brute-force cosine, and list sorts. 
+ +This is the **default fallback**. ObservationStore implementations are free to push down operations using store-specific strategies instead: + +| Operation | Python fallback (`StreamQuery.apply`) | Store push-down (example) | +|----------------|---------------------------------------|----------------------------------| +| Filters | `filter.matches()` predicates | SQL WHERE clauses | +| Text search | Case-insensitive substring | FTS5 full-text index | +| Vector search | Brute-force cosine similarity | vec0 / FAISS ANN index | +| Ordering | `sorted()` materialization | SQL ORDER BY | +| Offset / limit | `islice()` | SQL OFFSET / LIMIT | + +`ListObservationStore` delegates entirely to `StreamQuery.apply()`. `SqliteObservationStore` translates the query into SQL and only falls back to Python for operations it can't express natively. + +Transform-sourced streams (post `.transform()`) always use `StreamQuery.apply()` since there's no index to push down to. + +## Quick start + +```python +from dimos.memory2 import MemoryStore + +store = MemoryStore() +images = store.stream("images") + +# Write +images.append(frame, ts=time.time(), pose=(x, y, z), tags={"camera": "front"}) + +# Query +recent = images.after(t).limit(10).fetch() +nearest = images.near(pose, radius=2.0).fetch() +latest = images.last() + +# Transform (class or bare generator function) +edges = images.transform(Canny()).save(store.stream("edges")) + +def running_avg(upstream): + total, n = 0.0, 0 + for obs in upstream: + total += obs.data; n += 1 + yield obs.derive(data=total / n) +avgs = stream.transform(running_avg).fetch() + +# Live +for obs in images.live().transform(process): + handle(obs) + +# Embed + search +images.transform(EmbedImages(clip)).save(store.stream("embedded")) +results = store.stream("embedded").search(query_vec, k=5).fetch() +``` + +## Implementations + +| ObservationStore | Status | Storage | +|-----------------|----------|----------------------------------------| +| 
`ListObservationStore` | Complete | In-memory (lists + brute-force search) | +| `SqliteObservationStore` | Complete | SQLite (WAL, FTS5, vec0) | diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py new file mode 100644 index 0000000000..c861993de9 --- /dev/null +++ b/dimos/memory2/backend.py @@ -0,0 +1,244 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Concrete composite Backend that orchestrates ObservationStore + BlobStore + VectorStore + Notifier.""" + +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.memory2.codecs.base import Codec, codec_id +from dimos.memory2.notifier.subject import SubjectNotifier +from dimos.memory2.type.observation import _UNLOADED + +if TYPE_CHECKING: + from collections.abc import Iterator + + from reactivex.abc import DisposableBase + + from dimos.memory2.blobstore.base import BlobStore + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.notifier.base import Notifier + from dimos.memory2.observationstore.base import ObservationStore + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation + from dimos.memory2.vectorstore.base import VectorStore + +T = TypeVar("T") + + +class Backend(Generic[T]): + """Orchestrates metadata, blob, vector, and live stores for one stream. + + This is a concrete class — NOT a protocol. 
All shared orchestration logic + (encode → insert → store blob → index vector → notify) lives here, + eliminating duplication between ListObservationStore and SqliteObservationStore. + """ + + def __init__( + self, + *, + metadata_store: ObservationStore[T], + codec: Codec[Any], + blob_store: BlobStore | None = None, + vector_store: VectorStore | None = None, + notifier: Notifier[T] | None = None, + eager_blobs: bool = False, + ) -> None: + self.metadata_store = metadata_store + self.codec = codec + self.blob_store = blob_store + self.vector_store = vector_store + self.notifier: Notifier[T] = notifier or SubjectNotifier() + self.eager_blobs = eager_blobs + + @property + def name(self) -> str: + return self.metadata_store.name + + def _make_loader(self, row_id: int) -> Any: + bs = self.blob_store + if bs is None: + raise RuntimeError("BlobStore required but not configured") + name, codec = self.name, self.codec + + def loader() -> Any: + raw = bs.get(name, row_id) + return codec.decode(raw) + + return loader + + def append(self, obs: Observation[T]) -> Observation[T]: + # Encode payload before any locking (avoids holding locks during IO) + encoded: bytes | None = None + if self.blob_store is not None: + encoded = self.codec.encode(obs._data) + + try: + # Insert metadata, get assigned id + row_id = self.metadata_store.insert(obs) + obs.id = row_id + + # Store blob + if encoded is not None: + assert self.blob_store is not None + self.blob_store.put(self.name, row_id, encoded) + # Replace inline data with lazy loader + obs._data = _UNLOADED # type: ignore[assignment] + obs._loader = self._make_loader(row_id) + + # Store embedding vector + if self.vector_store is not None: + emb = getattr(obs, "embedding", None) + if emb is not None: + self.vector_store.put(self.name, row_id, emb) + + # Commit if the metadata store supports it (e.g. 
SqliteObservationStore) + if hasattr(self.metadata_store, "commit"): + self.metadata_store.commit() + except BaseException: + if hasattr(self.metadata_store, "rollback"): + self.metadata_store.rollback() + raise + + self.notifier.notify(obs) + return obs + + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + if query.search_vec is not None and query.live_buffer is not None: + raise TypeError("Cannot combine .search() with .live() — search is a batch operation.") + buf = query.live_buffer + if buf is not None: + sub = self.notifier.subscribe(buf) + return self._iterate_live(query, buf, sub) + return self._iterate_snapshot(query) + + def _attach_loaders(self, it: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + """Attach lazy blob loaders to observations from the metadata store.""" + if self.blob_store is None: + yield from it + return + for obs in it: + if obs._loader is None and isinstance(obs._data, type(_UNLOADED)): + obs._loader = self._make_loader(obs.id) + yield obs + + def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: + if query.search_vec is not None and self.vector_store is not None: + yield from self._vector_search(query) + return + + it: Iterator[Observation[T]] = self._attach_loaders(self.metadata_store.query(query)) + + # Apply python post-filters after loaders are attached (so obs.data works) + python_filters = getattr(self.metadata_store, "_pending_python_filters", None) + pending_query = getattr(self.metadata_store, "_pending_query", None) + if python_filters: + from itertools import islice as _islice + + it = (obs for obs in it if all(f.matches(obs) for f in python_filters)) + if pending_query and pending_query.offset_val: + it = _islice(it, pending_query.offset_val, None) + if pending_query and pending_query.limit_val is not None: + it = _islice(it, pending_query.limit_val) + + if self.eager_blobs and self.blob_store is not None: + for obs in it: + _ = obs.data # trigger lazy loader + yield obs + 
else: + yield from it + + def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]: + vs = self.vector_store + assert vs is not None and query.search_vec is not None + + hits = vs.search(self.name, query.search_vec, query.search_k or 10) + if not hits: + return + + ids = [h[0] for h in hits] + obs_list = list(self._attach_loaders(iter(self.metadata_store.fetch_by_ids(ids)))) + obs_by_id = {obs.id: obs for obs in obs_list} + + # Preserve VectorStore ranking order + ranked: list[Observation[T]] = [] + for obs_id, sim in hits: + match = obs_by_id.get(obs_id) + if match is not None: + ranked.append( + match.derive(data=match.data, embedding=query.search_vec, similarity=sim) + ) + + # Apply remaining query ops (skip vector search) + rest = replace(query, search_vec=None, search_k=None) + yield from rest.apply(iter(ranked)) + + def _iterate_live( + self, + query: StreamQuery, + buf: BackpressureBuffer[Observation[T]], + sub: DisposableBase, + ) -> Iterator[Observation[T]]: + from dimos.memory2.buffer import ClosedError + + eager = self.eager_blobs and self.blob_store is not None + + try: + # Backfill phase + last_id = -1 + for obs in self._iterate_snapshot(query): + last_id = max(last_id, obs.id) + yield obs + + # Live tail + filters = query.filters + while True: + obs = buf.take() + if obs.id <= last_id: + continue + last_id = obs.id + if filters and not all(f.matches(obs) for f in filters): + continue + if eager: + _ = obs.data # trigger lazy loader + yield obs + except (ClosedError, StopIteration): + pass + finally: + sub.dispose() + + def count(self, query: StreamQuery) -> int: + if query.search_vec: + return sum(1 for _ in self.iterate(query)) + return self.metadata_store.count(query) + + def serialize(self) -> dict[str, Any]: + """Serialize the fully-resolved backend config to a dict.""" + return { + "codec_id": codec_id(self.codec), + "eager_blobs": self.eager_blobs, + "metadata_store": self.metadata_store.serialize() + if 
hasattr(self.metadata_store, "serialize") + else None, + "blob_store": self.blob_store.serialize() if self.blob_store else None, + "vector_store": self.vector_store.serialize() if self.vector_store else None, + "notifier": self.notifier.serialize(), + } + + def stop(self) -> None: + """Stop the metadata store (closes per-stream connections if any).""" + if hasattr(self.metadata_store, "stop"): + self.metadata_store.stop() diff --git a/dimos/memory2/blobstore/base.py b/dimos/memory2/blobstore/base.py new file mode 100644 index 0000000000..b146d2028e --- /dev/null +++ b/dimos/memory2/blobstore/base.py @@ -0,0 +1,58 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any + +from dimos.core.resource import CompositeResource +from dimos.memory2.registry import qual +from dimos.protocol.service.spec import BaseConfig, Configurable + + +class BlobStoreConfig(BaseConfig): + pass + + +class BlobStore(Configurable[BlobStoreConfig], CompositeResource): + """Persistent storage for encoded payload blobs. + + Separates payload data from metadata indexing so that large blobs + (images, point clouds) don't penalize metadata queries. 
+ """ + + default_config: type[BlobStoreConfig] = BlobStoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + + @abstractmethod + def put(self, stream_name: str, key: int, data: bytes) -> None: + """Store a blob for the given stream and observation id.""" + ... + + @abstractmethod + def get(self, stream_name: str, key: int) -> bytes: + """Retrieve a blob by stream name and observation id.""" + ... + + @abstractmethod + def delete(self, stream_name: str, key: int) -> None: + """Delete a blob by stream name and observation id.""" + ... + + def serialize(self) -> dict[str, Any]: + return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/blobstore/blobstore.md b/dimos/memory2/blobstore/blobstore.md new file mode 100644 index 0000000000..00006cf468 --- /dev/null +++ b/dimos/memory2/blobstore/blobstore.md @@ -0,0 +1,84 @@ +# blobstore/ + +Separates payload blob storage from metadata indexing. Observation payloads vary hugely in size — a `Vector3` is 24 bytes, a camera frame is megabytes. Storing everything inline penalizes metadata queries. BlobStore lets large payloads live elsewhere. + +## ABC (`blobstore/base.py`) + +```python +class BlobStore(Resource): + def put(self, stream_name: str, key: int, data: bytes) -> None: ... + def get(self, stream_name: str, key: int) -> bytes: ... # raises KeyError if missing + def delete(self, stream_name: str, key: int) -> None: ... # silent if missing +``` + +- `stream_name` — stream name (used to organize storage: directories, tables) +- `key` — observation id +- `data` — encoded payload bytes (codec handles serialization, blob store handles persistence) +- Extends `Resource` (start/stop) but does NOT own its dependencies' lifecycle + +## Implementations + +### `file.py` — FileBlobStore + +Stores blobs as files on disk, one directory per stream. 
+ +``` +{root}/{stream}/{key}.bin +``` + +`__init__(root: str | os.PathLike[str])` — `start()` creates the root directory. + +### `sqlite.py` — SqliteBlobStore + +Stores blobs in a separate SQLite table per stream. + +```sql +CREATE TABLE "{stream}_blob" (id INTEGER PRIMARY KEY, data BLOB NOT NULL) +``` + +`__init__(conn: sqlite3.Connection)` — does NOT own the connection. + +**Internal use** (same db as metadata): `SqliteStore._create_backend()` creates one connection per stream, passes it to both the index and the blob store. + +**External use** (separate db): user creates a separate connection and passes it. User manages that connection's lifecycle. + +**JOIN optimization**: when `eager_blobs=True` and the blob store shares the same connection as the index, `SqliteObservationStore` can optimize with a JOIN instead of separate queries: + +```sql +SELECT m.id, m.ts, m.pose, m.tags, b.data +FROM "images" m JOIN "images_blob" b ON m.id = b.id +WHERE m.ts > ? +``` + +## Lazy loading + +`eager_blobs` is a store/stream-level flag, orthogonal to blob store choice. 
It controls WHEN data is loaded: + +- `eager_blobs=False` (default) → backend sets `Observation._loader`, payload loaded on `.data` access +- `eager_blobs=True` → backend triggers `.data` access during iteration (eager) + +| eager_blobs | blob store | loading strategy | +|-------------|-----------|-----------------| +| True | SqliteBlobStore (same conn) | JOIN — one round trip | +| True | any other | iterate meta, `blob_store.get()` per row | +| False | any | iterate meta only, `_loader = lambda: codec.decode(blob_store.get(...))` | + +## Usage + +```python +# Per-stream blob store choice +poses = store.stream("poses", PoseStamped) # default, lazy +images = store.stream("images", Image, eager_blobs=True) # eager +images = store.stream("images", Image, blob_store=file_blobs) # override +``` + +## Files + +``` +blobstore/ + base.py BlobStore ABC + blobstore.md this file + __init__.py re-exports BlobStore, FileBlobStore, SqliteBlobStore + file.py FileBlobStore + sqlite.py SqliteBlobStore +``` diff --git a/dimos/memory2/blobstore/file.py b/dimos/memory2/blobstore/file.py new file mode 100644 index 0000000000..e0ae80b61a --- /dev/null +++ b/dimos/memory2/blobstore/file.py @@ -0,0 +1,70 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from dimos.memory2.blobstore.base import BlobStore, BlobStoreConfig +from dimos.memory2.utils.validation import validate_identifier + + +class FileBlobStoreConfig(BlobStoreConfig): + root: str + + +class FileBlobStore(BlobStore): + """Stores blobs as files on disk, one directory per stream. + + Layout:: + + {root}/{stream}/{key}.bin + """ + + default_config = FileBlobStoreConfig + config: FileBlobStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._root = Path(self.config.root) + + def _path(self, stream_name: str, key: int) -> Path: + validate_identifier(stream_name) + return self._root / stream_name / f"{key}.bin" + + def start(self) -> None: + self._root.mkdir(parents=True, exist_ok=True) + + def stop(self) -> None: + pass + + def put(self, stream_name: str, key: int, data: bytes) -> None: + p = self._path(stream_name, key) + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(data) + + def get(self, stream_name: str, key: int) -> bytes: + p = self._path(stream_name, key) + try: + return p.read_bytes() + except FileNotFoundError: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") from None + + def delete(self, stream_name: str, key: int) -> None: + p = self._path(stream_name, key) + try: + p.unlink() + except FileNotFoundError: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") from None diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py new file mode 100644 index 0000000000..1cb5f1aa38 --- /dev/null +++ b/dimos/memory2/blobstore/sqlite.py @@ -0,0 +1,108 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sqlite3 +from typing import Any + +from pydantic import Field, model_validator + +from dimos.memory2.blobstore.base import BlobStore, BlobStoreConfig +from dimos.memory2.utils.sqlite import open_disposable_sqlite_connection +from dimos.memory2.utils.validation import validate_identifier + + +class SqliteBlobStoreConfig(BlobStoreConfig): + conn: sqlite3.Connection | None = Field(default=None, exclude=True) + path: str | None = None + + @model_validator(mode="after") + def _conn_xor_path(self) -> SqliteBlobStoreConfig: + if self.conn is not None and self.path is not None: + raise ValueError("Specify either conn or path, not both") + if self.conn is None and self.path is None: + raise ValueError("Specify either conn or path") + return self + + +class SqliteBlobStore(BlobStore): + """Stores blobs in a separate SQLite table per stream. + + Table layout per stream:: + + CREATE TABLE "{stream}_blob" ( + id INTEGER PRIMARY KEY, + data BLOB NOT NULL + ); + + Supports two construction modes: + + - ``SqliteBlobStore(conn=conn)`` — borrows an externally-managed connection. + - ``SqliteBlobStore(path="file.db")`` — opens and owns its own connection. + + Does NOT commit; the caller (typically Backend) is responsible for commits. 
+ """ + + default_config = SqliteBlobStoreConfig + config: SqliteBlobStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._conn: sqlite3.Connection = self.config.conn # type: ignore[assignment] # set in start() if None + self._path = self.config.path + self._tables: set[str] = set() + + def _ensure_table(self, stream_name: str) -> None: + if stream_name in self._tables: + return + validate_identifier(stream_name) + self._conn.execute( + f'CREATE TABLE IF NOT EXISTS "{stream_name}_blob" ' + "(id INTEGER PRIMARY KEY, data BLOB NOT NULL)" + ) + self._tables.add(stream_name) + + def start(self) -> None: + if self._conn is None: + assert self._path is not None + disposable, self._conn = open_disposable_sqlite_connection(self._path) + self.register_disposables(disposable) + + def put(self, stream_name: str, key: int, data: bytes) -> None: + self._ensure_table(stream_name) + self._conn.execute( + f'INSERT OR REPLACE INTO "{stream_name}_blob" (id, data) VALUES (?, ?)', + (key, data), + ) + + def get(self, stream_name: str, key: int) -> bytes: + try: + row = self._conn.execute( + f'SELECT data FROM "{stream_name}_blob" WHERE id = ?', (key,) + ).fetchone() + except Exception: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") + if row is None: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") + result: bytes = row[0] + return result + + def delete(self, stream_name: str, key: int) -> None: + try: + cur = self._conn.execute(f'DELETE FROM "{stream_name}_blob" WHERE id = ?', (key,)) + except Exception: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") from None + if cur.rowcount == 0: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") diff --git a/dimos/memory2/blobstore/test_blobstore.py b/dimos/memory2/blobstore/test_blobstore.py new file mode 100644 index 0000000000..ade6aa4cc6 --- /dev/null +++ b/dimos/memory2/blobstore/test_blobstore.py @@ -0,0 +1,62 @@ +# Copyright 
2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid tests for BlobStore implementations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from dimos.memory2.blobstore.base import BlobStore + + +class TestBlobStore: + def test_put_get_roundtrip(self, blob_store: BlobStore) -> None: + data = b"hello world" + blob_store.put("stream_a", 1, data) + assert blob_store.get("stream_a", 1) == data + + def test_get_missing_raises(self, blob_store: BlobStore) -> None: + with pytest.raises(KeyError): + blob_store.get("nonexistent", 999) + + def test_put_overwrite(self, blob_store: BlobStore) -> None: + blob_store.put("s", 1, b"first") + blob_store.put("s", 1, b"second") + assert blob_store.get("s", 1) == b"second" + + def test_delete(self, blob_store: BlobStore) -> None: + blob_store.put("s", 1, b"data") + blob_store.delete("s", 1) + with pytest.raises(KeyError): + blob_store.get("s", 1) + + def test_delete_missing_raises(self, blob_store: BlobStore) -> None: + with pytest.raises(KeyError): + blob_store.delete("s", 999) + + def test_stream_isolation(self, blob_store: BlobStore) -> None: + blob_store.put("a", 1, b"alpha") + blob_store.put("b", 1, b"beta") + assert blob_store.get("a", 1) == b"alpha" + assert blob_store.get("b", 1) == b"beta" + + def test_large_blob(self, blob_store: BlobStore) -> None: + data = bytes(range(256)) * 1000 # 256 KB + blob_store.put("big", 0, data) + assert 
blob_store.get("big", 0) == data + assert blob_store.get("big", 0) == data diff --git a/dimos/memory2/buffer.py b/dimos/memory2/buffer.py new file mode 100644 index 0000000000..49814eb6dc --- /dev/null +++ b/dimos/memory2/buffer.py @@ -0,0 +1,248 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Backpressure buffers — the bridge between push and pull. + +Real-world data sources (cameras, LiDAR, ROS topics) and ReactiveX pipelines +are *push-based*: they emit items whenever they please. Databases, analysis +systems, and our memory store are *pull-based*: consumers iterate at their own +pace. A BackpressureBuffer sits between the two, absorbing push bursts so +that the pull side can drain items on its own schedule. + +The choice of strategy controls what happens under load: + +- **KeepLast** — single-slot, always overwrites; best for real-time sensor + data where only the latest reading matters. +- **Bounded** — FIFO with a cap; drops the oldest item on overflow. +- **DropNew** — FIFO with a cap; rejects new items on overflow. +- **Unbounded** — unlimited FIFO; guarantees delivery at the cost of memory. + +All four share the same ABC interface and are interchangeable wherever a +buffer is accepted (e.g. ``Stream.live(buffer=...)``). 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import deque +import threading +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterator + +T = TypeVar("T") + + +class ClosedError(Exception): + """Raised when take() is called on a closed buffer.""" + + +class BackpressureBuffer(ABC, Generic[T]): + """Thread-safe buffer between push producers and pull consumers.""" + + @abstractmethod + def put(self, item: T) -> bool: + """Push an item. Returns False if the item was dropped.""" + + @abstractmethod + def take(self, timeout: float | None = None) -> T: + """Block until an item is available. Raises ClosedError if the buffer is closed.""" + + @abstractmethod + def try_take(self) -> T | None: + """Non-blocking take. Returns None if empty.""" + + @abstractmethod + def close(self) -> None: + """Signal no more items. Subsequent take() raises ClosedError.""" + + @abstractmethod + def __len__(self) -> int: ... + + def __iter__(self) -> Iterator[T]: + """Yield items until the buffer is closed.""" + while True: + try: + yield self.take() + except ClosedError: + return + + +class KeepLast(BackpressureBuffer[T]): + """Single-slot buffer. put() always overwrites. 
Default for live mode.""" + + def __init__(self) -> None: + self._item: T | None = None + self._has_item = False + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._item = item + self._has_item = True + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._has_item: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + item = self._item + assert item is not None + self._item = None + self._has_item = False + return item + + def try_take(self) -> T | None: + with self._cond: + if not self._has_item: + return None + item = self._item + self._item = None + self._has_item = False + return item + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return 1 if self._has_item else 0 + + +class Bounded(BackpressureBuffer[T]): + """FIFO queue with max size. 
Drops oldest when full.""" + + def __init__(self, maxlen: int) -> None: + self._buf: deque[T] = deque(maxlen=maxlen) + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._buf.append(item) # deque(maxlen) drops oldest automatically + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) + + +class DropNew(BackpressureBuffer[T]): + """FIFO queue. Rejects new items when full (put returns False).""" + + def __init__(self, maxlen: int) -> None: + self._buf: deque[T] = deque() + self._maxlen = maxlen + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed or len(self._buf) >= self._maxlen: + return False + self._buf.append(item) + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) + + +class Unbounded(BackpressureBuffer[T]): + """Unbounded FIFO queue. 
Use carefully — can grow without limit.""" + + def __init__(self) -> None: + self._buf: deque[T] = deque() + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._buf.append(item) + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) diff --git a/dimos/memory2/codecs/README.md b/dimos/memory2/codecs/README.md new file mode 100644 index 0000000000..8ad40e95fd --- /dev/null +++ b/dimos/memory2/codecs/README.md @@ -0,0 +1,57 @@ +# codecs + +Encode/decode payloads for persistent storage. Codecs convert typed Python objects to `bytes` and back, used by backends that store observation data as blobs. + +## Protocol + +```python +class Codec(Protocol[T]): + def encode(self, value: T) -> bytes: ... + def decode(self, data: bytes) -> T: ... +``` + +## Built-in codecs + +| Codec | Type | Notes | +|-------|------|-------| +| `PickleCodec` | Any Python object | Fallback. Uses `HIGHEST_PROTOCOL`. | +| `JpegCodec` | `Image` | Lossy compression via TurboJPEG. ~10-20x smaller. Preserves `frame_id` in header. | +| `LcmCodec` | `DimosMsg` subclasses | Uses `lcm_encode()`/`lcm_decode()`. Zero-copy for LCM message types. 
| + +## Auto-selection + +`codec_for(payload_type)` picks the right codec: + +```python +from dimos.memory2.codecs import codec_for + +codec_for(Image) # → JpegCodec(quality=50) +codec_for(SomeLcmMsg) # → LcmCodec(SomeLcmMsg) (if has lcm_encode/lcm_decode) +codec_for(dict) # → PickleCodec() (fallback) +codec_for(None) # → PickleCodec() +``` + +## Writing a new codec + +1. Create `dimos/memory/codecs/mycodec.py`: + +```python +class MyCodec: + def encode(self, value: MyType) -> bytes: + ... + + def decode(self, data: bytes) -> MyType: + ... +``` + +2. Add a branch in `codec_for()` in `base.py` to auto-select it for the relevant type. + +3. Add a test case to `test_codecs.py` — the grid fixture makes this easy: + +```python +@pytest.fixture(params=[..., ("mycodec", MyCodec(), sample_value)]) +def codec_case(request): + ... +``` + +No base class needed — `Codec` is a protocol. Just implement `encode` and `decode`. diff --git a/dimos/memory2/codecs/base.py b/dimos/memory2/codecs/base.py new file mode 100644 index 0000000000..821b36b60f --- /dev/null +++ b/dimos/memory2/codecs/base.py @@ -0,0 +1,112 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +from typing import Any, Protocol, TypeVar, runtime_checkable + +T = TypeVar("T") + + +@runtime_checkable +class Codec(Protocol[T]): + """Encode/decode payloads for storage.""" + + def encode(self, value: T) -> bytes: ... 
def codec_for(payload_type: type[Any] | None = None) -> Codec[Any]:
    """Auto-select a codec for ``payload_type``.

    Image subclasses get JPEG, LCM-capable message types get LCM, and
    everything else (including ``None``) falls back to pickle.
    """
    from dimos.memory2.codecs.pickle import PickleCodec

    if payload_type is not None:
        from dimos.msgs.sensor_msgs.Image import Image

        if issubclass(payload_type, Image):
            from dimos.memory2.codecs.jpeg import JpegCodec

            return JpegCodec()
        if hasattr(payload_type, "lcm_encode") and hasattr(payload_type, "lcm_decode"):
            from dimos.memory2.codecs.lcm import LcmCodec

            return LcmCodec(payload_type)
    return PickleCodec()


def codec_id(codec: Codec[Any]) -> str:
    """Derive a string ID from a codec instance, e.g. ``'lz4+lcm'``.

    Walks the ``_inner`` chain for wrapper codecs, joining with ``+``.
    Uses the naming convention ``FooCodec`` → ``'foo'``.
    """
    parts: list[str] = []
    c: Any = codec
    while hasattr(c, "_inner"):
        parts.append(_class_to_id(c))
        c = c._inner
    parts.append(_class_to_id(c))
    return "+".join(parts)


def codec_from_id(codec_id_str: str, payload_module: str) -> Codec[Any]:
    """Reconstruct a codec chain from its string ID (e.g. ``'lz4+lcm'``).

    Builds inside-out: the rightmost segment is the innermost (base) codec.
    Raises ValueError for unknown segments or malformed chains.
    """
    parts = codec_id_str.split("+")
    # Innermost first
    result = _make_one(parts[-1], payload_module)
    for name in reversed(parts[:-1]):
        result = _make_one(name, payload_module, inner=result)
    return result


def _class_to_id(codec: Any) -> str:
    """``FooCodec`` instance → ``'foo'`` (lowercased, ``Codec`` suffix stripped)."""
    name = type(codec).__name__
    if name.endswith("Codec"):
        return name[:-5].lower()
    return name.lower()


def _resolve_payload_type(payload_module: str) -> type[Any]:
    """Resolve a dotted ``pkg.module.ClassName`` string to the class object."""
    parts = payload_module.rsplit(".", 1)
    if len(parts) != 2:
        raise ValueError(f"Cannot resolve payload type from {payload_module!r}")
    mod = importlib.import_module(parts[0])
    return getattr(mod, parts[1])  # type: ignore[no-any-return]


def _make_one(name: str, payload_module: str, inner: Codec[Any] | None = None) -> Codec[Any]:
    """Instantiate a single codec by its short name.

    Raises ValueError if ``name`` is unknown, or if ``inner`` is supplied for a
    codec that is not a wrapper. (Previously a non-wrapper segment silently
    dropped its inner codec, which would reconstruct a chain such as
    ``'jpeg+lcm'`` into a codec that decodes incorrectly.)
    """
    if name == "lz4":
        from dimos.memory2.codecs.lz4 import Lz4Codec

        if inner is None:
            raise ValueError("lz4 is a wrapper codec — must have an inner codec")
        return Lz4Codec(inner)
    if inner is not None:
        raise ValueError(f"{name!r} is not a wrapper codec — cannot take an inner codec")
    if name == "jpeg":
        from dimos.memory2.codecs.jpeg import JpegCodec

        return JpegCodec()
    if name == "lcm":
        from dimos.memory2.codecs.lcm import LcmCodec

        return LcmCodec(_resolve_payload_type(payload_module))
    if name == "pickle":
        from dimos.memory2.codecs.pickle import PickleCodec

        return PickleCodec()
    raise ValueError(f"Unknown codec: {name!r}")
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dimos.msgs.sensor_msgs.Image import Image + + +class JpegCodec: + """Codec for Image types — JPEG-compressed inside an LCM Image envelope. + + Uses ``Image.lcm_jpeg_encode/decode`` which preserves ``ts``, ``frame_id``, + and all LCM header fields. Pixel data is lossy-compressed via TurboJPEG. + """ + + def __init__(self, quality: int = 50) -> None: + self._quality = quality + + def encode(self, value: Image) -> bytes: + return value.lcm_jpeg_encode(quality=self._quality) + + def decode(self, data: bytes) -> Image: + from dimos.msgs.sensor_msgs.Image import Image + + return Image.lcm_jpeg_decode(data) diff --git a/dimos/memory2/codecs/lcm.py b/dimos/memory2/codecs/lcm.py new file mode 100644 index 0000000000..fe7055d9c8 --- /dev/null +++ b/dimos/memory2/codecs/lcm.py @@ -0,0 +1,33 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class LcmCodec:
    """Codec for DimosMsg types — serializes via lcm_encode/lcm_decode."""

    def __init__(self, msg_type: type[DimosMsg]) -> None:
        # Decoding needs the concrete message class; encoding is polymorphic.
        self._msg_type = msg_type

    def encode(self, value: DimosMsg) -> bytes:
        """Serialize via the message's own lcm_encode()."""
        return value.lcm_encode()

    def decode(self, data: bytes) -> DimosMsg:
        """Deserialize bytes back into the configured message type."""
        return self._msg_type.lcm_decode(data)


class Lz4Codec:
    """Wraps another codec and applies LZ4 frame compression to the output.

    Works with any inner codec — compresses the bytes produced by
    ``inner.encode()`` and decompresses before ``inner.decode()``.
    """

    def __init__(self, inner: Codec[Any], compression_level: int = 0) -> None:
        # Imported lazily so this module stays importable when the optional
        # lz4 package is missing; the failure then surfaces at construction
        # time for Lz4Codec users only, instead of breaking every import.
        import lz4.frame  # type: ignore[import-untyped]

        self._lz4_frame = lz4.frame
        self._inner = inner
        self._compression_level = compression_level

    def encode(self, value: Any) -> bytes:
        """Encode with the inner codec, then LZ4-frame-compress the result."""
        raw = self._inner.encode(value)
        return bytes(self._lz4_frame.compress(raw, compression_level=self._compression_level))

    def decode(self, data: bytes) -> Any:
        """Decompress, then hand the raw bytes to the inner codec."""
        raw: bytes = self._lz4_frame.decompress(data)
        return self._inner.decode(raw)


class PickleCodec:
    """Fallback codec for arbitrary Python objects (pickle, highest protocol)."""

    def encode(self, value: Any) -> bytes:
        return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)

    def decode(self, data: bytes) -> Any:
        return pickle.loads(data)
"""Grid tests for Codec implementations.

Runs roundtrip encode→decode tests across every codec, verifying data preservation.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import pytest

from dimos.memory2.codecs.base import Codec, codec_for
from dimos.memory2.codecs.jpeg import JpegCodec
from dimos.memory2.codecs.lcm import LcmCodec
from dimos.memory2.codecs.pickle import PickleCodec
from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped
from dimos.msgs.sensor_msgs.Image import Image

if TYPE_CHECKING:
    from collections.abc import Callable

    from dimos.msgs.protocol import DimosMsg


@dataclass
class Case:
    """One codec under test plus the sample values it must roundtrip."""

    name: str
    codec: Codec[Any]
    values: list[Any]
    eq: Callable[[Any, Any], bool] | None = None  # custom equality: (original, decoded) -> bool


def _lcm_values() -> list[DimosMsg]:
    """Sample LCM messages shared by the lcm-based cases."""
    from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped
    from dimos.msgs.geometry_msgs.Quaternion import Quaternion
    from dimos.msgs.geometry_msgs.Vector3 import Vector3

    fully_populated = PoseStamped(
        ts=1.0,
        frame_id="map",
        position=Vector3(1.0, 2.0, 3.0),
        orientation=Quaternion(0.0, 0.0, 0.0, 1.0),
    )
    defaults_only = PoseStamped(ts=0.5, frame_id="odom")
    return [fully_populated, defaults_only]


def _pickle_case() -> Case:
    """Plain pickle over heterogeneous Python values."""
    from dimos.memory2.codecs.pickle import PickleCodec

    return Case(
        name="pickle",
        codec=PickleCodec(),
        values=[42, "hello", b"raw bytes", {"key": "value"}],
    )


def _lcm_case() -> Case:
    """LCM encoding of PoseStamped messages."""
    from dimos.memory2.codecs.lcm import LcmCodec
    from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped

    return Case(name="lcm", codec=LcmCodec(PoseStamped), values=_lcm_values())


def _lz4_pickle_case() -> Case:
    """LZ4 wrapper around pickle."""
    from dimos.memory2.codecs.lz4 import Lz4Codec
    from dimos.memory2.codecs.pickle import PickleCodec

    return Case(
        name="lz4+pickle",
        codec=Lz4Codec(PickleCodec()),
        values=[42, "hello", b"raw bytes", {"key": "value"}, list(range(1000))],
    )


def _lz4_lcm_case() -> Case:
    """LZ4 wrapper around LCM."""
    from dimos.memory2.codecs.lcm import LcmCodec
    from dimos.memory2.codecs.lz4 import Lz4Codec
    from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped

    return Case(name="lz4+lcm", codec=Lz4Codec(LcmCodec(PoseStamped)), values=_lcm_values())


def _jpeg_eq(original: Any, decoded: Any) -> bool:
    """JPEG is lossy — check shape, frame_id, and pixel closeness."""
    import numpy as np

    if decoded.data.shape != original.data.shape:
        return False
    if decoded.frame_id != original.frame_id:
        return False
    pixel_error = np.mean(np.abs(decoded.data.astype(float) - original.data.astype(float)))
    return bool(pixel_error < 5)


def _jpeg_case() -> Case | None:
    """JPEG case over replayed camera frames; None when turbojpeg or the fixture data is unavailable."""
    try:
        from dimos.memory2.codecs.jpeg import JpegCodec
        from dimos.utils.testing.replay import TimedSensorReplay

        replay = TimedSensorReplay("unitree_go2_bigoffice/video")
        frames = [replay.find_closest_seek(float(i)) for i in range(1, 4)]
        codec = JpegCodec(quality=95)
    except (ImportError, RuntimeError):
        return None

    return Case(name="jpeg", codec=codec, values=frames, eq=_jpeg_eq)


testcases = [
    case
    for case in (_pickle_case(), _lcm_case(), _lz4_pickle_case(), _lz4_lcm_case(), _jpeg_case())
    if case is not None
]


@pytest.mark.parametrize("case", testcases, ids=lambda c: c.name)
class TestCodecRoundtrip:
    """Every codec must perfectly roundtrip its values."""

    def test_roundtrip_preserves_value(self, case: Case) -> None:
        equal = case.eq if case.eq is not None else (lambda a, b: a == b)
        for value in case.values:
            blob = case.codec.encode(value)
            assert isinstance(blob, bytes)
            restored = case.codec.decode(blob)
            assert equal(value, restored), f"Roundtrip failed for {value!r}: got {restored!r}"

    def test_encode_returns_nonempty_bytes(self, case: Case) -> None:
        for value in case.values:
            assert len(case.codec.encode(value)) > 0, f"Empty encoding for {value!r}"

    def test_different_values_produce_different_bytes(self, case: Case) -> None:
        distinct = {case.codec.encode(v) for v in case.values}
        assert len(distinct) > 1, "All values encoded to identical bytes"


class TestCodecFor:
    """codec_for() auto-selects the right codec."""

    def test_none_returns_pickle(self) -> None:
        assert isinstance(codec_for(None), PickleCodec)

    def test_unknown_type_returns_pickle(self) -> None:
        assert isinstance(codec_for(dict), PickleCodec)

    def test_lcm_type_returns_lcm(self) -> None:
        assert isinstance(codec_for(PoseStamped), LcmCodec)

    def test_image_type_returns_jpeg(self) -> None:
        pytest.importorskip("turbojpeg")
        assert isinstance(codec_for(Image), JpegCodec)
+ +"""Shared fixtures for memory2 tests.""" + +from __future__ import annotations + +import sqlite3 +import tempfile +from typing import TYPE_CHECKING + +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.store.memory import MemoryStore +from dimos.memory2.store.sqlite import SqliteStore + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + + from dimos.memory2.blobstore.base import BlobStore + from dimos.memory2.store.base import Store + + +@pytest.fixture +def memory_store() -> Generator[MemoryStore, None, None]: + with MemoryStore() as store: + yield store + + +@pytest.fixture +def memory_session(memory_store: MemoryStore) -> Generator[MemoryStore, None, None]: + """Alias: in the new architecture, the store IS the session.""" + yield memory_store + + +@pytest.fixture +def sqlite_store() -> Generator[SqliteStore, None, None]: + with tempfile.NamedTemporaryFile(suffix=".db") as f: + store = SqliteStore(path=f.name) + with store: + yield store + + +@pytest.fixture +def sqlite_session(sqlite_store: SqliteStore) -> Generator[SqliteStore, None, None]: + """Alias: in the new architecture, the store IS the session.""" + yield sqlite_store + + +@pytest.fixture(params=["memory_store", "sqlite_store"]) +def session(request: pytest.FixtureRequest) -> Store: + """Parametrized fixture that runs tests against both backends. + + Named 'session' to minimize test changes — tests use session.stream() which + now goes directly to Store.stream(). 
+ """ + return request.getfixturevalue(request.param) + + +@pytest.fixture +def file_blob_store(tmp_path: Path) -> Generator[FileBlobStore, None, None]: + store = FileBlobStore(root=str(tmp_path / "blobs")) + store.start() + yield store + store.stop() + + +@pytest.fixture +def sqlite_blob_store() -> Generator[SqliteBlobStore, None, None]: + conn = sqlite3.connect(":memory:") + store = SqliteBlobStore(conn=conn) + store.start() + yield store + store.stop() + conn.close() + + +@pytest.fixture(params=["file_blob_store", "sqlite_blob_store"]) +def blob_store(request: pytest.FixtureRequest) -> BlobStore: + return request.getfixturevalue(request.param) diff --git a/dimos/memory2/embed.py b/dimos/memory2/embed.py new file mode 100644 index 0000000000..17b5b98a31 --- /dev/null +++ b/dimos/memory2/embed.py @@ -0,0 +1,79 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

from itertools import islice
from typing import TYPE_CHECKING, Any, TypeVar

from dimos.memory2.transform import Transformer

if TYPE_CHECKING:
    from collections.abc import Iterator

    from dimos.memory2.type.observation import Observation
    from dimos.models.embedding.base import EmbeddingModel

T = TypeVar("T")


def _batched(it: Iterator[T], n: int) -> Iterator[list[T]]:
    """Yield successive n-sized chunks from an iterator (the last may be short)."""
    for head in it:
        yield [head, *islice(it, n - 1)]


class EmbedImages(Transformer[Any, Any]):
    """Embed images using ``model.embed()``.

    Data type stays the same — observations are enriched with an
    ``.embedding`` field, yielding :class:`EmbeddedObservation` instances.
    """

    def __init__(self, model: EmbeddingModel, batch_size: int = 32) -> None:
        self.model = model
        self.batch_size = batch_size

    def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[Observation[Any]]:
        for chunk in _batched(upstream, self.batch_size):
            result = self.model.embed(*(obs.data for obs in chunk))
            # A single-item batch may come back as one embedding, not a list.
            vectors = result if isinstance(result, list) else [result]
            # NOTE(review): strict=False silently drops observations if the
            # model returns fewer vectors than inputs — confirm that's intended.
            for obs, vector in zip(chunk, vectors, strict=False):
                yield obs.derive(data=obs.data, embedding=vector)


class EmbedText(Transformer[Any, Any]):
    """Embed text using ``model.embed_text()``.

    Data type stays the same — observations are enriched with an
    ``.embedding`` field, yielding :class:`EmbeddedObservation` instances.
    """

    def __init__(self, model: EmbeddingModel, batch_size: int = 32) -> None:
        self.model = model
        self.batch_size = batch_size

    def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[Observation[Any]]:
        for chunk in _batched(upstream, self.batch_size):
            result = self.model.embed_text(*(str(obs.data) for obs in chunk))
            # A single-item batch may come back as one embedding, not a list.
            vectors = result if isinstance(result, list) else [result]
            # NOTE(review): strict=False silently drops observations if the
            # model returns fewer vectors than inputs — confirm that's intended.
            for obs, vector in zip(chunk, vectors, strict=False):
                yield obs.derive(data=obs.data, embedding=vector)
It populates `.embedding` on each observation: + +```python +class Embed(Transformer[T, T]): + def __init__(self, model: EmbeddingModel): + self.model = model + + def __call__(self, upstream): + for batch in batched(upstream, 32): + vecs = self.model.embed_batch([obs.data for obs in batch]) + for obs, vec in zip(batch, vecs): + yield obs.derive(data=obs.data, embedding=vec) +``` + +`Stream[Image]` stays `Stream[Image]` after embedding — `T` is about `.data`, not the observation subclass. + +## Search + +`.search(query_vec, k)` lives on `Stream` itself. Returns a new Stream filtered to top-k by cosine similarity: + +```python +query_vec = clip.embed_text("a cat in the kitchen") + +results = images.transform(Embed(clip)).search(query_vec, k=20).fetch() +# results[0].data → Image +# results[0].embedding → np.ndarray +# results[0].similarity → 0.93 + +# Chainable with other filters +results = images.transform(Embed(clip)) \ + .search(query_vec, k=50) \ + .after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .fetch() +``` + +## Backend Handles Storage Strategy + +The Backend composite decides how to route storage based on what it sees: + +- `append(image, ts=now, embedding=vec)` → backend routes: blob via BlobStore, vector via VectorStore, metadata via ObservationStore +- `append(image, ts=now)` → blob + metadata only (no embedding) +- `ListObservationStore`: stores metadata in-memory, brute-force cosine via MemoryVectorStore +- `SqliteObservationStore`: metadata in SQLite, vec0 side table for fast ANN search via SqliteVectorStore +- Future backends (Postgres/pgvector, Qdrant, etc.) do their thing + +Search is pushed down to the VectorStore. Stream just passes `.search()` calls through. + +## Projection / Lineage + +**Usually not needed.** Since `.data` IS the original data, search results give you the image directly. 
+ +When a downstream transform replaces `.data` (e.g., Image → Detection), use temporal join to get back to the source: + +```python +detection = detections.first() +detection.data # → Detection +detection.ts # → timestamp preserved by derive() + +# Get the source image via temporal join +source_image = images.at(detection.ts).first() +``` + +## Multi-Modal + +**Same embedding space = same stream.** CLIP maps images and text to the same 512-d space: + +```python +unified = store.stream("clip_unified") + +for obs in images.transform(Embed(clip.vision)): + unified.append(obs.data, ts=obs.ts, + tags={"modality": "image"}, embedding=obs.embedding) + +for obs in logs.transform(Embed(clip.text)): + unified.append(obs.data, ts=obs.ts, + tags={"modality": "text"}, embedding=obs.embedding) + +results = unified.search(query_vec, k=20).fetch() +# results[i].tags["modality"] tells you what it is +``` + +**Different embedding spaces = different streams.** Can't mix CLIP and sentence-transformer vectors. + +## Chaining — Embedding as Cheap Pre-Filter + +```python +smoke_query = clip.embed_text("smoke or fire") + +detections = images.transform(Embed(clip)) \ + .search(smoke_query, k=100) \ + .transform(ExpensiveVLMDetector()) +# VLM only runs on 100 most promising frames + +# Smart transformer can use embedding directly +class SmartDetector(Transformer[Image, Detection]): + def __call__(self, upstream: Iterator[EmbeddedObservation[Image]]) -> ...: + for obs in upstream: + if obs.embedding @ self.query > 0.3: + yield obs.derive(data=self.detect(obs.data)) +``` + +## Text Search (FTS) — Separate Concern + +FTS is keyword-based, not embedding-based. 
Complementary, not competing: + +```python +# Keyword search via FTS5 +logs = store.stream("logs") +logs.search_text("motor fault").fetch() + +# Semantic search via embeddings +log_idx = logs.transform(Embed(sentence_model)).store("log_emb") +log_idx.search(model.embed("motor problems"), k=10).fetch() +``` diff --git a/dimos/memory2/intro.md b/dimos/memory2/intro.md new file mode 100644 index 0000000000..e88561c283 --- /dev/null +++ b/dimos/memory2/intro.md @@ -0,0 +1,170 @@ +# Memory Intro + +## Quick start + +```python session=memory ansi=false no-result +from dimos.memory2.store.sqlite import SqliteStore + +store = SqliteStore(path="/tmp/memory_readme.db") +``` + + +```python session=memory ansi=false +logs = store.stream("logs", str) +print(logs) +``` + + +``` +Stream("logs") +``` + +Append observations: + +```python session=memory ansi=false +logs.append("Motor started", ts=1.0, tags={"level": "info"}) +logs.append("Joint 3 fault", ts=2.0, tags={"level": "error"}) +logs.append("Motor stopped", ts=3.0, tags={"level": "info"}) + +print(logs.summary()) +``` + + +``` +Stream("logs"): 3 items, 1970-01-01 00:00:01 — 1970-01-01 00:00:03 (2.0s) +``` + +## Filters + +Queries are lazy — chaining filters builds a pipeline without fetching: + +```python session=memory ansi=false +print(logs.at(1.0).before(5.0).tags(level="error")) +``` + + +``` +Stream("logs") | AtFilter(t=1.0, tolerance=1.0) | BeforeFilter(t=5.0) | TagsFilter(tags={'level': 'error'}) +``` + +Available filters: `.after(t)`, `.before(t)`, `.at(t)`, `.near(pose, radius)`, `.tags(**kv)`, `.filter(predicate)`, `.search(embedding, k)`, `.order_by(field)`, `.limit(k)`, `.offset(n)`. 
+ +## Terminals + +Terminals materialize or consume the stream: + +```python session=memory ansi=false +print(logs.before(5.0).tags(level="error").fetch()) +``` + + +``` +[Observation(id=2, ts=2.0, pose=None, tags={'level': 'error'})] +``` + +Available terminals: `.fetch()`, `.first()`, `.last()`, `.count()`, `.exists()`, `.summary()`, `.get_time_range()`, `.drain()`, `.save(target)`. + +## Transforms + +`.map(fn)` transforms each observation, returning a new stream: + +```python session=memory ansi=false +print(logs.map(lambda obs: obs.data.upper()).first()) +``` + + +``` +MOTOR STARTED +``` + +## Live queries + +Live queries backfill existing matches, then emit new ones as they arrive: + +```python session=memory ansi=false +import time + +def emit_some_logs(): + last_ts = logs.last().ts + logs.append("Heartbeat ok", ts=last_ts + 1, pose=(3.0, 1.5, 0.0), tags={"level": "info"}) + time.sleep(0.1) + logs.append("Sensor fault", ts=last_ts + 2, pose=(4.1, 2.0, 0.0), tags={"level": "error"}) + time.sleep(0.1) + logs.append("Battery low: 30%", ts=last_ts + 3, pose=(5.3, 2.5, 0.0), tags={"level": "info"}) + time.sleep(0.1) + logs.append("Overtemp", ts=last_ts + 4, pose=(6.0, 3.0, 0.0), tags={"level": "error"}) + time.sleep(0.1) + + +with logs.tags(level="error").live() as errors: + sub = errors.subscribe(lambda obs: print(f"{obs.ts} - {obs.data}")) + emit_some_logs() + sub.dispose() + +``` + + +``` +2.0 - Joint 3 fault +5.0 - Sensor fault +7.0 - Overtemp +``` + +## Spatial + live + +Filters compose freely. 
Here `.near()` + `.live()` + `.map()` watches for logs near a physical location — backfilling past matches and tailing new ones: + +```python session=memory ansi=false +near_query = logs.near((5.0, 2.0), radius=2.0).live() +with near_query.map(lambda obs: f"near POI - {obs.data}") as logs_near: + with logs_near.subscribe(print): + emit_some_logs() +``` + + +``` +near POI - Sensor fault +near POI - Battery low: 30% +near POI - Overtemp +near POI - Sensor fault +near POI - Battery low: 30% +near POI - Overtemp +``` + +## Embeddings + +Use `EmbedText` transformer with CLIP to enrich observations with embeddings, then search by similarity: + +`.search(embedding, k)` returns the top-k most similar observations by cosine similarity: + +```python session=memory ansi=false +from dimos.models.embedding.clip import CLIPModel +from dimos.memory2.embed import EmbedText + +clip = CLIPModel() + +for obs in logs.transform(EmbedText(clip)).search(clip.embed_text("hardware problem"), k=3).fetch(): + print(f"{obs.similarity:.3f} {obs.data}") +``` + + +``` +0.897 Sensor fault +0.897 Sensor fault +0.887 Battery low: 30% +``` + +The embedded stream above was ephemeral — built on the fly for one query. To persist embeddings automatically as logs arrive, pipe a live stream through the transform into a stored stream: + +```python skip +import threading + +embedded_logs = store.stream("embedded_logs", str) +threading.Thread( + target=lambda: logs.live().transform(EmbedText(clip)).save(embedded_logs), + daemon=True, +).start() + +# every new log is now automatically embedded and stored +# embedded_logs.search(query, k=5).fetch() to query at any time +``` diff --git a/dimos/memory2/notes.md b/dimos/memory2/notes.md new file mode 100644 index 0000000000..8a9a05c30c --- /dev/null +++ b/dimos/memory2/notes.md @@ -0,0 +1,10 @@ + +```python +with db() as db: + with db.stream as image: + image.put(...) +``` + +DB specifies some general configuration for all sessions/streams. 
+ +`db.stream` initializes these sessions? diff --git a/dimos/memory2/notifier/base.py b/dimos/memory2/notifier/base.py new file mode 100644 index 0000000000..022d26d4e0 --- /dev/null +++ b/dimos/memory2/notifier/base.py @@ -0,0 +1,62 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.memory2.registry import qual +from dimos.protocol.service.spec import BaseConfig, Configurable + +if TYPE_CHECKING: + from reactivex.abc import DisposableBase + + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") + + +class NotifierConfig(BaseConfig): + pass + + +class Notifier(Configurable[NotifierConfig], Generic[T]): + """Push-notification for live observation delivery. + + Decouples the notification mechanism from storage. The built-in + ``SubjectNotifier`` handles same-process fan-out (thread-safe, zero + config). External implementations (Redis pub/sub, Postgres + LISTEN/NOTIFY, inotify) can be injected for cross-process use. + """ + + default_config: type[NotifierConfig] = NotifierConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + + @abstractmethod + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + """Register *buf* to receive new observations. 
Returns a + disposable that unsubscribes when disposed.""" + ... + + @abstractmethod + def notify(self, obs: Observation[T]) -> None: + """Fan out *obs* to all current subscribers.""" + ... + + def serialize(self) -> dict[str, Any]: + return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/notifier/subject.py b/dimos/memory2/notifier/subject.py new file mode 100644 index 0000000000..d804b03b66 --- /dev/null +++ b/dimos/memory2/notifier/subject.py @@ -0,0 +1,70 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""In-memory fan-out notifier (same-process, thread-safe).""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from reactivex.disposable import Disposable + +from dimos.memory2.notifier.base import Notifier, NotifierConfig + +if TYPE_CHECKING: + from reactivex.abc import DisposableBase + + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") + + +class SubjectNotifierConfig(NotifierConfig): + pass + + +class SubjectNotifier(Notifier[T], Generic[T]): + """In-memory fan-out notifier for same-process live notification. + + Thread-safe. ``notify()`` copies the subscriber list under the lock, + then iterates outside the lock to avoid deadlocks with slow consumers. 
+ """ + + default_config = SubjectNotifierConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._subscribers: list[BackpressureBuffer[Observation[T]]] = [] + self._lock = threading.Lock() + + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + with self._lock: + self._subscribers.append(buf) + + def _unsubscribe() -> None: + with self._lock: + try: + self._subscribers.remove(buf) + except ValueError: + pass + + return Disposable(action=_unsubscribe) + + def notify(self, obs: Observation[T]) -> None: + with self._lock: + subs = list(self._subscribers) + for buf in subs: + buf.put(obs) diff --git a/dimos/memory2/observationstore/base.py b/dimos/memory2/observationstore/base.py new file mode 100644 index 0000000000..4d94889fb0 --- /dev/null +++ b/dimos/memory2/observationstore/base.py @@ -0,0 +1,73 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.core.resource import CompositeResource +from dimos.memory2.registry import qual +from dimos.protocol.service.spec import BaseConfig, Configurable + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") + + +class ObservationStoreConfig(BaseConfig): + pass + + +class ObservationStore(Configurable[ObservationStoreConfig], CompositeResource, Generic[T]): + """Core metadata storage and query engine for observations. + + Handles only observation metadata storage, query pushdown, and count. + Blob/vector/live orchestration is handled by the concrete Backend class. + """ + + default_config: type[ObservationStoreConfig] = ObservationStoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + def insert(self, obs: Observation[T]) -> int: + """Insert observation metadata, return assigned id.""" + ... + + @abstractmethod + def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + """Execute query against metadata. Blobs are NOT loaded here.""" + ... + + @abstractmethod + def count(self, q: StreamQuery) -> int: ... + + @abstractmethod + def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: + """Batch fetch by id (for vector search results).""" + ... + + def serialize(self) -> dict[str, Any]: + return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/observationstore/memory.py b/dimos/memory2/observationstore/memory.py new file mode 100644 index 0000000000..529cd06394 --- /dev/null +++ b/dimos/memory2/observationstore/memory.py @@ -0,0 +1,80 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Any, TypeVar + +from dimos.memory2.observationstore.base import ObservationStore, ObservationStoreConfig + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") + + +class ListObservationStoreConfig(ObservationStoreConfig): + name: str = "" + + +class ListObservationStore(ObservationStore[T]): + """In-memory metadata store for experimentation. 
Thread-safe.""" + + default_config = ListObservationStoreConfig + config: ListObservationStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._name = self.config.name + self._observations: list[Observation[T]] = [] + self._next_id = 0 + self._lock = threading.Lock() + + @property + def name(self) -> str: + return self._name + + def insert(self, obs: Observation[T]) -> int: + with self._lock: + obs.id = self._next_id + row_id = self._next_id + self._next_id += 1 + self._observations.append(obs) + return row_id + + def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + with self._lock: + snapshot = list(self._observations) + + # Text search — substring match + if q.search_text is not None: + needle = q.search_text.lower() + it: Iterator[Observation[T]] = ( + obs for obs in snapshot if needle in str(obs.data).lower() + ) + return q.apply(it) + + return q.apply(iter(snapshot)) + + def count(self, q: StreamQuery) -> int: + return sum(1 for _ in self.query(q)) + + def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: + id_set = set(ids) + with self._lock: + return [obs for obs in self._observations if obs.id in id_set] diff --git a/dimos/memory2/observationstore/sqlite.py b/dimos/memory2/observationstore/sqlite.py new file mode 100644 index 0000000000..5d680c540a --- /dev/null +++ b/dimos/memory2/observationstore/sqlite.py @@ -0,0 +1,444 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import re +import sqlite3 +import threading +from typing import TYPE_CHECKING, Any, TypeVar + +from pydantic import Field, model_validator + +from dimos.memory2.codecs.base import Codec +from dimos.memory2.observationstore.base import ObservationStore, ObservationStoreConfig +from dimos.memory2.type.filter import ( + AfterFilter, + AtFilter, + BeforeFilter, + NearFilter, + TagsFilter, + TimeRangeFilter, + _xyz, +) +from dimos.memory2.type.observation import _UNLOADED, Observation +from dimos.memory2.utils.sqlite import open_disposable_sqlite_connection + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.type.filter import Filter, StreamQuery + +T = TypeVar("T") + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _decompose_pose(pose: Any) -> tuple[float, ...] | None: + if pose is None: + return None + if hasattr(pose, "position"): + pos = pose.position + orient = getattr(pose, "orientation", None) + x, y, z = float(pos.x), float(pos.y), float(getattr(pos, "z", 0.0)) + if orient is not None: + return (x, y, z, float(orient.x), float(orient.y), float(orient.z), float(orient.w)) + return (x, y, z, 0.0, 0.0, 0.0, 1.0) + if isinstance(pose, (list, tuple)): + vals = [float(v) for v in pose] + while len(vals) < 7: + vals.append(0.0 if len(vals) < 6 else 1.0) + return tuple(vals[:7]) + return None + + +def _reconstruct_pose( + x: float | None, + y: float | None, + z: float | None, + qx: float | None, + qy: float | None, + qz: float | None, + qw: float | None, +) -> tuple[float, ...] | None: + if x is None: + return None + return (x, y or 0.0, z or 0.0, qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0) + + +def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list[Any]] | None: + """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters. 
+ + ``stream`` is the raw stream name (for R*Tree table references). + ``prefix`` is a column qualifier (e.g. ``"meta."`` for JOIN queries). + """ + if isinstance(f, AfterFilter): + return (f"{prefix}ts > ?", [f.t]) + if isinstance(f, BeforeFilter): + return (f"{prefix}ts < ?", [f.t]) + if isinstance(f, TimeRangeFilter): + return (f"{prefix}ts >= ? AND {prefix}ts <= ?", [f.t1, f.t2]) + if isinstance(f, AtFilter): + return (f"ABS({prefix}ts - ?) <= ?", [f.t, f.tolerance]) + if isinstance(f, TagsFilter): + clauses = [] + params: list[Any] = [] + for k, v in f.tags.items(): + if not _IDENT_RE.match(k): + raise ValueError(f"Invalid tag key: {k!r}") + clauses.append(f"json_extract({prefix}tags, '$.{k}') = ?") + params.append(v) + return (" AND ".join(clauses), params) + if isinstance(f, NearFilter): + pose = f.pose + if pose is None: + return None + if hasattr(pose, "position"): + pose = pose.position + cx, cy, cz = _xyz(pose) + r = f.radius + # R*Tree bounding-box pre-filter + exact squared-distance check + rtree_sql = ( + f'{prefix}id IN (SELECT id FROM "{stream}_rtree" ' + f"WHERE x_min >= ? AND x_max <= ? " + f"AND y_min >= ? AND y_max <= ? " + f"AND z_min >= ? AND z_max <= ?)" + ) + dist_sql = ( + f"(({prefix}pose_x - ?) * ({prefix}pose_x - ?) + " + f"({prefix}pose_y - ?) * ({prefix}pose_y - ?) + " + f"({prefix}pose_z - ?) * ({prefix}pose_z - ?) <= ?)" + ) + return ( + f"{rtree_sql} AND {dist_sql}", + [ + cx - r, + cx + r, + cy - r, + cy + r, + cz - r, + cz + r, # R*Tree bbox + cx, + cx, + cy, + cy, + cz, + cz, + r * r, # squared distance + ], + ) + # PredicateFilter — not pushable + return None + + +def _compile_query( + query: StreamQuery, + table: str, + *, + join_blob: bool = False, +) -> tuple[str, list[Any], list[Filter]]: + """Compile a StreamQuery to SQL. + + Returns (sql, params, python_filters) where python_filters must be + applied as post-filters in Python. + """ + prefix = "meta." 
if join_blob else "" + if join_blob: + select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' + else: + select = f'SELECT id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"' + + where_parts: list[str] = [] + params: list[Any] = [] + python_filters: list[Filter] = [] + + for f in query.filters: + compiled = _compile_filter(f, table, prefix) + if compiled is not None: + sql_part, sql_params = compiled + where_parts.append(sql_part) + params.extend(sql_params) + else: + python_filters.append(f) + + sql = select + if where_parts: + sql += " WHERE " + " AND ".join(where_parts) + + # ORDER BY + if query.order_field: + if not _IDENT_RE.match(query.order_field): + raise ValueError(f"Invalid order_field: {query.order_field!r}") + direction = "DESC" if query.order_desc else "ASC" + sql += f" ORDER BY {prefix}{query.order_field} {direction}" + else: + sql += f" ORDER BY {prefix}id ASC" + + # Only push LIMIT/OFFSET to SQL when there are no Python post-filters + if not python_filters: + if query.limit_val is not None: + if query.offset_val: + sql += f" LIMIT {query.limit_val} OFFSET {query.offset_val}" + else: + sql += f" LIMIT {query.limit_val}" + elif query.offset_val: + sql += f" LIMIT -1 OFFSET {query.offset_val}" + + return (sql, params, python_filters) + + +def _compile_count( + query: StreamQuery, + table: str, +) -> tuple[str, list[Any], list[Filter]]: + """Compile a StreamQuery to a COUNT SQL query.""" + where_parts: list[str] = [] + params: list[Any] = [] + python_filters: list[Filter] = [] + + for f in query.filters: + compiled = _compile_filter(f, table) + if compiled is not None: + sql_part, sql_params = compiled + where_parts.append(sql_part) + params.extend(sql_params) + else: + python_filters.append(f) + + sql = f'SELECT COUNT(*) FROM "{table}"' + 
if where_parts: + sql += " WHERE " + " AND ".join(where_parts) + + return (sql, params, python_filters) + + +class SqliteObservationStoreConfig(ObservationStoreConfig): + conn: sqlite3.Connection | None = Field(default=None, exclude=True) + name: str = "" + codec: Codec[Any] | None = Field(default=None, exclude=True) + blob_store_conn_match: bool = Field(default=False, exclude=True) + page_size: int = 256 + path: str | None = None + + @model_validator(mode="after") + def _conn_xor_path(self) -> SqliteObservationStoreConfig: + if self.conn is not None and self.path is not None: + raise ValueError("Specify either conn or path, not both") + if self.conn is None and self.path is None: + raise ValueError("Specify either conn or path") + return self + + +class SqliteObservationStore(ObservationStore[T]): + """SQLite-backed metadata store for a single stream (table). + + Handles only metadata storage and query pushdown. + Blob/vector/live orchestration is handled by Backend. + + Supports two construction modes: + + - ``SqliteObservationStore(conn=conn, name="x", codec=...)`` — borrows an externally-managed connection. + - ``SqliteObservationStore(path="file.db", name="x", codec=...)`` — opens and owns its own connection. 
+ """ + + default_config = SqliteObservationStoreConfig + config: SqliteObservationStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._conn: sqlite3.Connection = self.config.conn # type: ignore[assignment] # set in start() if None + self._path = self.config.path + self._name = self.config.name + self._codec = self.config.codec + self._blob_store_conn_match = self.config.blob_store_conn_match + self._page_size = self.config.page_size + self._lock = threading.Lock() + self._tag_indexes: set[str] = set() + self._pending_python_filters: list[Any] = [] + self._pending_query: StreamQuery | None = None + + def start(self) -> None: + if self._conn is None: + assert self._path is not None + disposable, self._conn = open_disposable_sqlite_connection(self._path) + self.register_disposables(disposable) + self._ensure_tables() + + def _ensure_tables(self) -> None: + """Create the metadata table and R*Tree index if they don't exist.""" + self._conn.execute( + f'CREATE TABLE IF NOT EXISTS "{self._name}" (' + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " ts REAL NOT NULL UNIQUE," + " pose_x REAL, pose_y REAL, pose_z REAL," + " pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL," + " tags BLOB DEFAULT (jsonb('{}'))" + ")" + ) + self._conn.execute( + f'CREATE VIRTUAL TABLE IF NOT EXISTS "{self._name}_rtree" USING rtree(' + " id," + " x_min, x_max," + " y_min, y_max," + " z_min, z_max" + ")" + ) + self._conn.commit() + + @property + def name(self) -> str: + return self._name + + @property + def _join_blobs(self) -> bool: + return self._blob_store_conn_match + + def _make_loader(self, row_id: int, blob_store: Any) -> Any: + name = self._name + codec = self._codec + assert codec is not None, "codec is required for data loading" + + def loader() -> Any: + raw = blob_store.get(name, row_id) + return codec.decode(raw) + + return loader + + def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observation[T]: + if has_blob: 
+ row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row + else: + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row + blob_data = None + + pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) + tags = json.loads(tags_json) if tags_json else {} + + if has_blob and blob_data is not None: + assert self._codec is not None, "codec is required for data loading" + data = self._codec.decode(blob_data) + return Observation(id=row_id, ts=ts, pose=pose, tags=tags, _data=data) + + return Observation( + id=row_id, + ts=ts, + pose=pose, + tags=tags, + _data=_UNLOADED, + ) + + def _ensure_tag_indexes(self, tags: dict[str, Any]) -> None: + for key in tags: + if key not in self._tag_indexes and _IDENT_RE.match(key): + self._conn.execute( + f'CREATE INDEX IF NOT EXISTS "{self._name}_tag_{key}" ' + f"ON \"{self._name}\"(json_extract(tags, '$.{key}'))" + ) + self._tag_indexes.add(key) + + def insert(self, obs: Observation[T]) -> int: + pose = _decompose_pose(obs.pose) + tags_json = json.dumps(obs.tags) if obs.tags else "{}" + + with self._lock: + if obs.tags: + self._ensure_tag_indexes(obs.tags) + if pose: + px, py, pz, qx, qy, qz, qw = pose + else: + px = py = pz = qx = qy = qz = qw = None # type: ignore[assignment] + + cur = self._conn.execute( + f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", + (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json), + ) + row_id = cur.lastrowid + assert row_id is not None + + # R*Tree spatial index + if pose: + self._conn.execute( + f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) ' + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (row_id, px, px, py, py, pz, pz), + ) + + # Do NOT commit here — Backend calls commit() after blob/vector writes + + return row_id + + def commit(self) -> None: + self._conn.commit() + + def rollback(self) -> None: + self._conn.rollback() + + def query(self, q: StreamQuery) -> 
Iterator[Observation[T]]: + if q.search_text is not None: + raise NotImplementedError("search_text is not supported by SqliteObservationStore") + + join = self._join_blobs + sql, params, python_filters = _compile_query(q, self._name, join_blob=join) + + cur = self._conn.execute(sql, params) + cur.arraysize = self._page_size + it: Iterator[Observation[T]] = (self._row_to_obs(r, has_blob=join) for r in cur) + + # Don't apply python post-filters here — Backend._attach_loaders must + # run first so that obs.data works for PredicateFilter etc. + # Store them so Backend can retrieve and apply after attaching loaders. + self._pending_python_filters = python_filters + self._pending_query = q + + return it + + def count(self, q: StreamQuery) -> int: + if q.search_vec: + # Delegate to Backend for vector-aware counting + raise NotImplementedError("count with search_vec must go through Backend") + + sql, params, python_filters = _compile_count(q, self._name) + if python_filters: + return sum(1 for _ in self.query(q)) + + row = self._conn.execute(sql, params).fetchone() + return int(row[0]) if row else 0 + + def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: + if not ids: + return [] + join = self._join_blobs + placeholders = ",".join("?" 
* len(ids)) + if join: + sql = ( + f"SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, " + f"meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data " + f'FROM "{self._name}" AS meta ' + f'JOIN "{self._name}_blob" AS blob ON blob.id = meta.id ' + f"WHERE meta.id IN ({placeholders})" + ) + else: + sql = ( + f"SELECT id, ts, pose_x, pose_y, pose_z, " + f"pose_qx, pose_qy, pose_qz, pose_qw, json(tags) " + f'FROM "{self._name}" WHERE id IN ({placeholders})' + ) + + rows = self._conn.execute(sql, ids).fetchall() + return [self._row_to_obs(r, has_blob=join) for r in rows] + + def stop(self) -> None: + super().stop() diff --git a/dimos/memory2/registry.py b/dimos/memory2/registry.py new file mode 100644 index 0000000000..bf9bd5de55 --- /dev/null +++ b/dimos/memory2/registry.py @@ -0,0 +1,82 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stream registry: persists fully-resolved backend config per stream.""" + +from __future__ import annotations + +import importlib +import json +import sqlite3 +from typing import Any + +from pydantic import Field + +from dimos.protocol.service.spec import BaseConfig, Configurable + + +def qual(cls: type) -> str: + """Fully qualified class name, e.g. 
'dimos.memory2.blobstore.sqlite.SqliteBlobStore'.""" + return f"{cls.__module__}.{cls.__qualname__}" + + +def deserialize_component(data: dict[str, Any]) -> Any: + """Instantiate a component from its ``{"class": ..., "config": ...}`` dict.""" + module_path, _, cls_name = data["class"].rpartition(".") + mod = importlib.import_module(module_path) + cls = getattr(mod, cls_name) + return cls(**data["config"]) + + +class RegistryStoreConfig(BaseConfig): + conn: sqlite3.Connection | None = Field(default=None, exclude=True) + + +class RegistryStore(Configurable[RegistryStoreConfig]): + """SQLite persistence for stream name -> config JSON.""" + + default_config: type[RegistryStoreConfig] = RegistryStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + assert self.config.conn is not None, "conn is required" + self._conn: sqlite3.Connection = self.config.conn + self._conn.execute( + "CREATE TABLE IF NOT EXISTS _streams (" + " name TEXT PRIMARY KEY," + " config TEXT NOT NULL" + ")" + ) + self._conn.commit() + + def get(self, name: str) -> dict[str, Any] | None: + row = self._conn.execute("SELECT config FROM _streams WHERE name = ?", (name,)).fetchone() + if row is None: + return None + return json.loads(row[0]) # type: ignore[no-any-return] + + def put(self, name: str, config: dict[str, Any]) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO _streams (name, config) VALUES (?, ?)", + (name, json.dumps(config)), + ) + self._conn.commit() + + def delete(self, name: str) -> None: + self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) + self._conn.commit() + + def list_streams(self) -> list[str]: + rows = self._conn.execute("SELECT name FROM _streams").fetchall() + return [r[0] for r in rows] diff --git a/dimos/memory2/store/README.md b/dimos/memory2/store/README.md new file mode 100644 index 0000000000..ff18640c0b --- /dev/null +++ b/dimos/memory2/store/README.md @@ -0,0 +1,130 @@ +# store — Store implementations + +Metadata 
index backends for memory. Each index implements the `ObservationStore` protocol to provide observation metadata storage with query support. The concrete `Backend` class handles orchestration (blob, vector, live) on top of any index. + +## Existing implementations + +| ObservationStore | File | Status | Storage | +|-----------------|-------------|----------|-------------------------------------| +| `ListObservationStore` | `memory.py` | Complete | In-memory lists, brute-force search | +| `SqliteObservationStore` | `sqlite.py` | Complete | SQLite (WAL, R*Tree, vec0) | + +## Writing a new index + +### 1. Implement the ObservationStore protocol + +```python +from dimos.memory2.observationstore.base import ObservationStore +from dimos.memory2.type.filter import StreamQuery +from dimos.memory2.type.observation import Observation + +class MyObservationStore(Generic[T]): + def __init__(self, name: str) -> None: + self._name = name + + @property + def name(self) -> str: + return self._name + + def insert(self, obs: Observation[T]) -> int: + """Insert observation metadata, return assigned id.""" + row_id = self._next_id + self._next_id += 1 + # ... persist metadata ... + return row_id + + def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + """Yield observations matching the query.""" + # The index handles metadata query fields: + # q.filters — list of Filter objects (each has .matches(obs)) + # q.order_field — sort field name (e.g. "ts") + # q.order_desc — sort direction + # q.limit_val — max results + # q.offset_val — skip first N + # q.search_text — substring text search + ... + + def count(self, q: StreamQuery) -> int: + """Count matching observations.""" + ... + + def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: + """Batch fetch by id (for vector search results).""" + ... +``` + +`ObservationStore` is a `@runtime_checkable` Protocol — no base class needed, just implement the methods. + +### 2. 
Create a Store subclass + +```python +from dimos.memory2.backend import Backend +from dimos.memory2.codecs.base import codec_for +from dimos.memory2.store.base import Store + +class MyStore(Store): + def _create_backend( + self, name: str, payload_type: type | None = None, **config: Any + ) -> Backend: + index = MyObservationStore(name) + codec = codec_for(payload_type) + return Backend( + metadata_store=index, + codec=codec, + blob_store=config.get("blob_store"), + vector_store=config.get("vector_store"), + notifier=config.get("notifier"), + eager_blobs=config.get("eager_blobs", False), + ) + + def list_streams(self) -> list[str]: + return list(self._streams.keys()) + + def delete_stream(self, name: str) -> None: + self._streams.pop(name, None) +``` + +The Store creates a `Backend` composite for each stream. The `Backend` handles all orchestration (encode → insert → store blob → index vector → notify) so your index only needs to handle metadata. + +### 3. Add to the grid test + +In `test_impl.py`, add your store to the fixture so all standard tests run against it: + +```python +@pytest.fixture(params=["memory", "sqlite", "myindex"]) +def store(request, tmp_path): + if request.param == "myindex": + return MyStore(...) + ... +``` + +Use `pytest.mark.xfail` for features not yet implemented — the grid test covers: append, fetch, iterate, count, first/last, exists, all filters, ordering, limit/offset, embeddings, text search. + +### Query contract + +The index must handle the `StreamQuery` metadata fields. Vector search and blob loading are handled by the `Backend` composite — the index never needs to deal with them. + +`StreamQuery.apply(iterator)` provides a complete Python-side execution path — filters, text search, vector search, ordering, offset/limit — all as in-memory operations. 
ObservationStores can use it in three ways: + +**Full delegation** — simplest, good enough for in-memory indexes: +```python +def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + return q.apply(iter(self._data)) +``` + +**Partial push-down** — handle some operations natively, delegate the rest: +```python +def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + # Handle filters and ordering in SQL + rows = self._sql_query(q.filters, q.order_field, q.order_desc) + # Delegate remaining operations to Python + remaining = StreamQuery( + search_text=q.search_text, + offset_val=q.offset_val, limit_val=q.limit_val, + ) + return remaining.apply(iter(rows)) +``` + +**Full push-down** — translate everything to native queries (SQL WHERE, FTS5 MATCH) without calling `apply()` at all. + +For filters, each `Filter` object has a `.matches(obs) -> bool` method that indexes can use directly if they don't have a native equivalent. diff --git a/dimos/memory2/store/base.py b/dimos/memory2/store/base.py new file mode 100644 index 0000000000..cf571f23b0 --- /dev/null +++ b/dimos/memory2/store/base.py @@ -0,0 +1,166 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Any, TypeVar, cast + +from dimos.core.resource import CompositeResource +from dimos.memory2.backend import Backend +from dimos.memory2.blobstore.base import BlobStore +from dimos.memory2.codecs.base import Codec, codec_for, codec_from_id +from dimos.memory2.notifier.base import Notifier +from dimos.memory2.notifier.subject import SubjectNotifier +from dimos.memory2.observationstore.base import ObservationStore +from dimos.memory2.observationstore.memory import ListObservationStore +from dimos.memory2.stream import Stream +from dimos.memory2.vectorstore.base import VectorStore +from dimos.protocol.service.spec import BaseConfig, Configurable + +T = TypeVar("T") + + +class StreamAccessor: + """Attribute-style access: ``store.streams.name`` -> ``store.stream(name)``.""" + + __slots__ = ("_store",) + + def __init__(self, store: Store) -> None: + object.__setattr__(self, "_store", store) + + def __getattr__(self, name: str) -> Stream[Any]: + if name.startswith("_"): + raise AttributeError(name) + store: Store = object.__getattribute__(self, "_store") + if name not in store.list_streams(): + raise AttributeError(f"No stream {name!r}. Available: {store.list_streams()}") + return store.stream(name) + + def __getitem__(self, name: str) -> Stream[Any]: + store: Store = object.__getattribute__(self, "_store") + if name not in store.list_streams(): + raise KeyError(name) + return store.stream(name) + + def __dir__(self) -> list[str]: + store: Store = object.__getattribute__(self, "_store") + return store.list_streams() + + def __repr__(self) -> str: + names = object.__getattribute__(self, "_store").list_streams() + return f"StreamAccessor({names})" + + +class StoreConfig(BaseConfig): + """Store-level config. These are defaults inherited by all streams. + + Component fields accept either a class (instantiated per-stream) or + a live instance (used directly). Classes are the default; instances + are for overrides (e.g. 
spy stores in tests, shared external stores). + """ + + observation_store: type[ObservationStore] | ObservationStore | None = None # type: ignore[type-arg] + blob_store: type[BlobStore] | BlobStore | None = None + vector_store: type[VectorStore] | VectorStore | None = None + notifier: type[Notifier] | Notifier | None = None # type: ignore[type-arg] + eager_blobs: bool = False + + +class Store(Configurable[StoreConfig], CompositeResource): + """Top-level entry point — wraps a storage location (file, URL, etc.). + + Store directly manages streams. No Session layer. + """ + + default_config: type[StoreConfig] = StoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + self._streams: dict[str, Stream[Any]] = {} + + @property + def streams(self) -> StreamAccessor: + """Attribute-style access to streams: ``store.streams.name``.""" + return StreamAccessor(self) + + @staticmethod + def _resolve_codec( + payload_type: type[Any] | None, raw_codec: Codec[Any] | str | None + ) -> Codec[Any]: + if isinstance(raw_codec, Codec): + return raw_codec + if isinstance(raw_codec, str): + module = ( + f"{payload_type.__module__}.{payload_type.__qualname__}" + if payload_type + else "builtins.object" + ) + return codec_from_id(raw_codec, module) + return codec_for(payload_type) + + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + """Create a Backend for the named stream. 
Called once per stream name.""" + codec = self._resolve_codec(payload_type, config.pop("codec", None)) + + # Instantiate or use provided instances + obs = config.pop("observation_store", self.config.observation_store) + if obs is None or isinstance(obs, type): + obs = (obs or ListObservationStore)(name=name) + obs.start() + + bs = config.pop("blob_store", self.config.blob_store) + if isinstance(bs, type): + bs = bs() + bs.start() + + vs = config.pop("vector_store", self.config.vector_store) + if isinstance(vs, type): + vs = vs() + vs.start() + + notifier = config.pop("notifier", self.config.notifier) + if notifier is None or isinstance(notifier, type): + notifier = (notifier or SubjectNotifier)() + + return Backend( + metadata_store=obs, + codec=codec, + blob_store=bs, + vector_store=vs, + notifier=notifier, + eager_blobs=config.get("eager_blobs", False), + ) + + def stream(self, name: str, payload_type: type[T] | None = None, **overrides: Any) -> Stream[T]: + """Get or create a named stream. Returns the same Stream on repeated calls. + + Per-stream ``overrides`` (e.g. ``blob_store=``, ``codec=``) are merged + on top of the store-level defaults from :class:`StoreConfig`. + """ + if name not in self._streams: + resolved = {**self.config.model_dump(exclude_none=True), **overrides} + backend = self._create_backend(name, payload_type, **resolved) + self._streams[name] = Stream(source=backend) + return cast("Stream[T]", self._streams[name]) + + def list_streams(self) -> list[str]: + """Return names of all streams in this store.""" + return list(self._streams.keys()) + + def delete_stream(self, name: str) -> None: + """Delete a stream by name (from cache and underlying storage).""" + self._streams.pop(name, None) diff --git a/dimos/memory2/store/memory.py b/dimos/memory2/store/memory.py new file mode 100644 index 0000000000..6aecde29dd --- /dev/null +++ b/dimos/memory2/store/memory.py @@ -0,0 +1,21 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.memory2.store.base import Store + + +class MemoryStore(Store): + """In-memory store for experimentation.""" + + pass diff --git a/dimos/memory2/store/sqlite.py b/dimos/memory2/store/sqlite.py new file mode 100644 index 0000000000..b655e0a8bc --- /dev/null +++ b/dimos/memory2/store/sqlite.py @@ -0,0 +1,217 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import sqlite3 +from typing import Any + +from dimos.memory2.backend import Backend +from dimos.memory2.blobstore.base import BlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.codecs.base import codec_id +from dimos.memory2.observationstore.sqlite import SqliteObservationStore +from dimos.memory2.registry import RegistryStore, deserialize_component, qual +from dimos.memory2.store.base import Store, StoreConfig +from dimos.memory2.utils.sqlite import open_disposable_sqlite_connection +from dimos.memory2.utils.validation import validate_identifier +from dimos.memory2.vectorstore.base import VectorStore +from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + + +class SqliteStoreConfig(StoreConfig): + """Config for SQLite-backed store.""" + + path: str = "memory.db" + page_size: int = 256 + + +class SqliteStore(Store): + """Store backed by a SQLite database file.""" + + default_config = SqliteStoreConfig + config: SqliteStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._registry_conn = self._open_connection() + self._registry = RegistryStore(conn=self._registry_conn) + + def _open_connection(self) -> sqlite3.Connection: + """Open a new WAL-mode connection with sqlite-vec loaded.""" + disposable, connection = open_disposable_sqlite_connection(self.config.path) + self.register_disposables(disposable) + return connection + + def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: + """Reconstruct a Backend from a stored config dict.""" + from dimos.memory2.codecs.base import codec_from_id + + payload_module = stored["payload_module"] + codec = codec_from_id(stored["codec_id"], payload_module) + eager_blobs = stored.get("eager_blobs", False) + page_size = stored.get("page_size", self.config.page_size) + + backend_conn = self._open_connection() + + # Reconstruct components from serialized config + bs_data = 
stored.get("blob_store") + if bs_data is not None: + bs_cfg = bs_data.get("config", {}) + if bs_cfg.get("path") is None and bs_data["class"] == qual(SqliteBlobStore): + bs: Any = SqliteBlobStore(conn=backend_conn) + else: + bs = deserialize_component(bs_data) + else: + bs = SqliteBlobStore(conn=backend_conn) + bs.start() + + vs_data = stored.get("vector_store") + if vs_data is not None: + vs_cfg = vs_data.get("config", {}) + if vs_cfg.get("path") is None and vs_data["class"] == qual(SqliteVectorStore): + vs: Any = SqliteVectorStore(conn=backend_conn) + else: + vs = deserialize_component(vs_data) + else: + vs = SqliteVectorStore(conn=backend_conn) + vs.start() + + notifier_data = stored.get("notifier") + if notifier_data is not None: + notifier = deserialize_component(notifier_data) + else: + from dimos.memory2.notifier.subject import SubjectNotifier + + notifier = SubjectNotifier() + + blob_store_conn_match = isinstance(bs, SqliteBlobStore) and bs._conn is backend_conn + + metadata_store: SqliteObservationStore[Any] = SqliteObservationStore( + conn=backend_conn, + name=name, + codec=codec, + blob_store_conn_match=blob_store_conn_match and eager_blobs, + page_size=page_size, + ) + metadata_store.start() + + backend: Backend[Any] = Backend( + metadata_store=metadata_store, + codec=codec, + blob_store=bs, + vector_store=vs, + notifier=notifier, + eager_blobs=eager_blobs, + ) + return backend + + @staticmethod + def _serialize_backend( + backend: Backend[Any], payload_module: str, page_size: int + ) -> dict[str, Any]: + """Serialize a backend's config for registry storage.""" + cfg: dict[str, Any] = { + "payload_module": payload_module, + "codec_id": codec_id(backend.codec), + "eager_blobs": backend.eager_blobs, + "page_size": page_size, + } + if backend.blob_store is not None: + cfg["blob_store"] = backend.blob_store.serialize() + if backend.vector_store is not None: + cfg["vector_store"] = backend.vector_store.serialize() + cfg["notifier"] = 
backend.notifier.serialize() + return cfg + + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + validate_identifier(name) + + stored = self._registry.get(name) + + if stored is not None: + # Load path: validate type, assemble from stored config + if payload_type is not None: + actual_module = f"{payload_type.__module__}.{payload_type.__qualname__}" + if actual_module != stored["payload_module"]: + raise ValueError( + f"Stream {name!r} was created with type {stored['payload_module']}, " + f"but opened with {actual_module}" + ) + return self._assemble_backend(name, stored) + + # Create path: inject conn-shared defaults, then delegate to base + if payload_type is None: + raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required") + + backend_conn = self._open_connection() + + # Inject conn-shared instances unless user provided overrides + if not isinstance(config.get("blob_store"), BlobStore): + bs = SqliteBlobStore(conn=backend_conn) + bs.start() + config["blob_store"] = bs + if not isinstance(config.get("vector_store"), VectorStore): + vs = SqliteVectorStore(conn=backend_conn) + vs.start() + config["vector_store"] = vs + + # Resolve codec early — needed for SqliteObservationStore + codec = self._resolve_codec(payload_type, config.get("codec")) + config["codec"] = codec + + # Create SqliteObservationStore with conn-sharing + bs = config["blob_store"] + blob_conn_match = isinstance(bs, SqliteBlobStore) and bs._conn is backend_conn + eager_blobs = config.get("eager_blobs", False) + obs_store: SqliteObservationStore[Any] = SqliteObservationStore( + conn=backend_conn, + name=name, + codec=codec, + blob_store_conn_match=blob_conn_match and eager_blobs, + page_size=config.pop("page_size", self.config.page_size), + ) + obs_store.start() + config["observation_store"] = obs_store + + backend = super()._create_backend(name, payload_type, **config) + + # Persist to registry + payload_module 
= f"{payload_type.__module__}.{payload_type.__qualname__}" + self._registry.put( + name, + self._serialize_backend( + backend, payload_module, config["observation_store"].config.page_size + ), + ) + + return backend + + def list_streams(self) -> list[str]: + db_names = set(self._registry.list_streams()) + return sorted(db_names | set(self._streams.keys())) + + def delete_stream(self, name: str) -> None: + super().delete_stream(name) + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') + self._registry.delete(name) + + def stop(self) -> None: + super().stop() + self._registry_conn.close() diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py new file mode 100644 index 0000000000..545d387c32 --- /dev/null +++ b/dimos/memory2/stream.py @@ -0,0 +1,363 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.core.resource import Resource +from dimos.memory2.buffer import BackpressureBuffer, KeepLast +from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer +from dimos.memory2.type.filter import ( + AfterFilter, + AtFilter, + BeforeFilter, + Filter, + NearFilter, + PredicateFilter, + StreamQuery, + TagsFilter, + TimeRangeFilter, +) +from dimos.memory2.type.observation import EmbeddedObservation, Observation + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + import reactivex + from reactivex.abc import DisposableBase, ObserverBase + + from dimos.memory2.backend import Backend + from dimos.models.embedding.base import Embedding + +T = TypeVar("T") +R = TypeVar("R") + + +class Stream(Resource, Generic[T]): + """Lazy, pull-based stream over observations. + + Every filter/transform method returns a new Stream — no computation + happens until iteration. Backends handle query application for stored + data; transform sources apply filters as Python predicates. + + Implements Resource so live streams can be cleanly stopped via + ``stop()`` or used as a context manager. 
+ """ + + def __init__( + self, + source: Backend[T] | Stream[Any], + *, + xf: Transformer[Any, T] | None = None, + query: StreamQuery = StreamQuery(), + ) -> None: + self._source = source + self._xf = xf + self._query = query + + def start(self) -> None: + pass + + def stop(self) -> None: + """Close the live buffer (if any), unblocking iteration.""" + buf = self._query.live_buffer + if buf is not None: + buf.close() + if isinstance(self._source, Stream): + self._source.stop() + + def __str__(self) -> str: + # Walk the source chain to collect (xf, query) pairs + chain: list[tuple[Any, StreamQuery]] = [] + current: Any = self + while isinstance(current, Stream): + chain.append((current._xf, current._query)) + current = current._source + chain.reverse() # innermost first + + # current is the Backend + name = getattr(current, "name", "?") + result = f'Stream("{name}")' + + for xf, query in chain: + if xf is not None: + result += f" -> {xf}" + q_str = str(query) + if q_str: + result += f" | {q_str}" + + return result + + def is_live(self) -> bool: + """True if this stream (or any ancestor in the chain) is in live mode.""" + if self._query.live_buffer is not None: + return True + if isinstance(self._source, Stream): + return self._source.is_live() + return False + + def __iter__(self) -> Iterator[Observation[T]]: + return self._build_iter() + + def _build_iter(self) -> Iterator[Observation[T]]: + if isinstance(self._source, Stream): + return self._iter_transform() + # Backend handles all query application (including live if requested) + return self._source.iterate(self._query) + + def _iter_transform(self) -> Iterator[Observation[T]]: + """Iterate a transform source, applying query filters in Python.""" + assert isinstance(self._source, Stream) and self._xf is not None + it: Iterator[Observation[T]] = self._xf(iter(self._source)) + return self._query.apply(it, live=self.is_live()) + + def _replace_query(self, **overrides: Any) -> Stream[T]: + q = self._query + new_q = 
StreamQuery( + filters=overrides.get("filters", q.filters), + order_field=overrides.get("order_field", q.order_field), + order_desc=overrides.get("order_desc", q.order_desc), + limit_val=overrides.get("limit_val", q.limit_val), + offset_val=overrides.get("offset_val", q.offset_val), + live_buffer=overrides.get("live_buffer", q.live_buffer), + search_vec=overrides.get("search_vec", q.search_vec), + search_k=overrides.get("search_k", q.search_k), + search_text=overrides.get("search_text", q.search_text), + ) + return Stream(self._source, xf=self._xf, query=new_q) + + def _with_filter(self, f: Filter) -> Stream[T]: + return self._replace_query(filters=(*self._query.filters, f)) + + def after(self, t: float) -> Stream[T]: + return self._with_filter(AfterFilter(t)) + + def before(self, t: float) -> Stream[T]: + return self._with_filter(BeforeFilter(t)) + + def time_range(self, t1: float, t2: float) -> Stream[T]: + return self._with_filter(TimeRangeFilter(t1, t2)) + + def at(self, t: float, tolerance: float = 1.0) -> Stream[T]: + return self._with_filter(AtFilter(t, tolerance)) + + def near(self, pose: Any, radius: float) -> Stream[T]: + return self._with_filter(NearFilter(pose, radius)) + + def tags(self, **tags: Any) -> Stream[T]: + return self._with_filter(TagsFilter(tags)) + + def order_by(self, field: str, desc: bool = False) -> Stream[T]: + return self._replace_query(order_field=field, order_desc=desc) + + def limit(self, k: int) -> Stream[T]: + return self._replace_query(limit_val=k) + + def offset(self, n: int) -> Stream[T]: + return self._replace_query(offset_val=n) + + def search(self, query: Embedding, k: int) -> Stream[T]: + """Return top-k observations by cosine similarity to *query*. + + The backend handles the actual computation. ListObservationStore does + brute-force cosine; SqliteObservationStore pushes down to vec0. 
+ """ + return self._replace_query(search_vec=query, search_k=k) + + def search_text(self, text: str) -> Stream[T]: + """Filter observations whose data contains *text*. + + ListObservationStore does case-insensitive substring match; + SqliteObservationStore (future) pushes down to FTS5. + """ + return self._replace_query(search_text=text) + + def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: + """Filter by arbitrary predicate on the full Observation.""" + return self._with_filter(PredicateFilter(pred)) + + def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[R]: + """Transform each observation's data via callable.""" + return self.transform(FnTransformer(lambda obs: fn(obs))) + + def transform( + self, + xf: Transformer[T, R] | Callable[[Iterator[Observation[T]]], Iterator[Observation[R]]], + ) -> Stream[R]: + """Wrap this stream with a transformer. Returns a new lazy Stream. + + Accepts a ``Transformer`` subclass or a bare callable / generator + function with the same ``Iterator[Obs] → Iterator[Obs]`` signature:: + + def detect(upstream): + for obs in upstream: + yield obs.derive(data=run_detector(obs.data)) + + images.transform(detect).save(detections) + """ + if not isinstance(xf, Transformer): + xf = FnIterTransformer(xf) + return Stream(source=self, xf=xf, query=StreamQuery()) + + def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: + """Return a stream whose iteration never ends — backfill then live tail. + + All backends support live mode via their built-in ``Notifier``. + Call .live() before .transform(), not after. + + Default buffer: KeepLast(). The backend handles subscription, dedup, + and backpressure — how it does so is its business. + """ + if isinstance(self._source, Stream): + raise TypeError( + "Cannot call .live() on a transform stream. " + "Call .live() on the source stream, then .transform()." 
+ ) + buf = buffer if buffer is not None else KeepLast() + return self._replace_query(live_buffer=buf) + + def save(self, target: Stream[T]) -> Stream[T]: + """Sync terminal: iterate self, append each obs to target's backend. + + Returns the target stream for continued querying. + """ + if isinstance(target._source, Stream): + raise TypeError("Cannot save to a transform stream. Target must be backend-backed.") + backend = target._source + for obs in self: + backend.append(obs) + return target + + def fetch(self) -> list[Observation[T]]: + """Materialize all observations into a list.""" + if self.is_live(): + raise TypeError( + ".fetch() on a live stream would block forever. " + "Use .drain() or .save(target) instead." + ) + return list(self) + + def first(self) -> Observation[T]: + """Return the first matching observation.""" + it = iter(self.limit(1)) + try: + return next(it) + except StopIteration: + raise LookupError("No matching observation") from None + + def last(self) -> Observation[T]: + """Return the last matching observation (by timestamp).""" + return self.order_by("ts", desc=True).first() + + def count(self) -> int: + """Count matching observations.""" + if not isinstance(self._source, Stream): + return self._source.count(self._query) + if self.is_live(): + raise TypeError(".count() on a live transform stream would block forever.") + return sum(1 for _ in self) + + def exists(self) -> bool: + """Check if any matching observation exists.""" + return next(iter(self.limit(1)), None) is not None + + def get_time_range(self) -> tuple[float, float]: + """Return (min_ts, max_ts) for matching observations.""" + first = self.first() + last = self.last() + return (first.ts, last.ts) + + def summary(self) -> str: + """Return a short human-readable summary: count, time range, duration.""" + from datetime import datetime, timezone + + n = self.count() + if n == 0: + return f"{self}: empty" + + (t0, t1) = self.get_time_range() + + fmt = "%Y-%m-%d %H:%M:%S" + dt0 = 
datetime.fromtimestamp(t0, tz=timezone.utc).strftime(fmt) + dt1 = datetime.fromtimestamp(t1, tz=timezone.utc).strftime(fmt) + dur = t1 - t0 + return f"{self}: {n} items, {dt0} — {dt1} ({dur:.1f}s)" + + def drain(self) -> int: + """Consume all observations, discarding results. Returns count consumed. + + Use for side-effect pipelines (e.g. live embed-and-store) where you + don't need to collect results in memory. + """ + n = 0 + for _ in self: + n += 1 + return n + + def observable(self) -> reactivex.Observable[Observation[T]]: + """Convert this stream to an RxPY Observable. + + Iteration is scheduled on the dimos thread pool so subscribe() never + blocks the calling thread. + """ + import reactivex + import reactivex.operators as ops + + from dimos.utils.threadpool import get_scheduler + + return reactivex.from_iterable(self).pipe( + ops.subscribe_on(get_scheduler()), + ) + + def subscribe( + self, + on_next: Callable[[Observation[T]], None] | ObserverBase[Observation[T]] | None = None, + on_error: Callable[[Exception], None] | None = None, + on_completed: Callable[[], None] | None = None, + ) -> DisposableBase: + """Subscribe to this stream as an RxPY Observable.""" + return self.observable().subscribe( # type: ignore[call-overload] + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + ) + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + embedding: Embedding | None = None, + ) -> Observation[T]: + """Append to the backing store. Only works if source is a Backend.""" + if isinstance(self._source, Stream): + raise TypeError("Cannot append to a transform stream. 
Append to the source stream.") + _ts = ts if ts is not None else time.time() + _tags = tags or {} + if embedding is not None: + obs: Observation[T] = EmbeddedObservation( + id=-1, + ts=_ts, + pose=pose, + tags=_tags, + _data=payload, + embedding=embedding, + ) + else: + obs = Observation(id=-1, ts=_ts, pose=pose, tags=_tags, _data=payload) + return self._source.append(obs) diff --git a/dimos/memory2/streaming.md b/dimos/memory2/streaming.md new file mode 100644 index 0000000000..fd7f5519a1 --- /dev/null +++ b/dimos/memory2/streaming.md @@ -0,0 +1,109 @@ +# Stream evaluation model + +Stream methods fall into three categories: **lazy**, **materializing**, and **terminal**. The distinction matters for live (infinite) streams. + +`is_live()` walks the source chain to detect live mode — any stream whose ancestor called `.live()` returns `True`. +All materializing operations and unsafe terminals check this and raise `TypeError` immediately rather than silently hanging. + +## Lazy (streaming) + +These return generators — each observation flows through one at a time. Safe with live/infinite streams. No internal buffering between stages. + +| Method | How | +|---------------------------------------------------------------------------|-------------------------------------------------| +| `.after()` `.before()` `.time_range()` `.at()` `.near()` `.filter_tags()` | Filter predicates — skip non-matching obs | +| `.filter(pred)` | Same, user-defined predicate | +| `.transform(xf_or_fn)` / `.map(fn)` | Generator — yields transformed obs one by one | +| `.search_text(text)` | Generator — substring match filter | +| `.limit(k)` | `islice` — stops after k | +| `.offset(n)` | `islice` — skips first n | +| `.live()` | Enables live tail (backfill then block for new) | + +These compose freely. A chain like `.after(t).filter(pred).transform(xf).limit(10)` pulls lazily — the source only produces what the consumer asks for. 
+ +## Materializing (collect-then-process) + +These **must consume the entire upstream** before producing output. On a live stream, they raise `TypeError` immediately. + +| Method | Why | Live behaviour | +|--------------------|----------------------------------------------|----------------| +| `.search(vec, k)` | Cosine-ranks all observations, returns top-k | TypeError | +| `.order_by(field)` | `sorted(list(it))` — needs all items to sort | TypeError | + +On a backend-backed stream (not a transform), both are pushed down to the backend which handles them on its own data structure (snapshot). The guard only fires when these appear on a **transform stream** whose upstream is live — detected via `is_live()`. + +### Rejected patterns (raise TypeError) + +```python +# TypeError: search requires finite data +stream.live().transform(Embed(model)).search(vec, k=5) + +# TypeError: order_by requires finite data +stream.live().transform(xf).order_by("ts", desc=True) + +# TypeError (via order_by): last() calls order_by internally +stream.live().transform(xf).last() +``` + +### Safe equivalents + +```python +# Search the stored data, not the live tail +results = stream.search(vec, k=5).fetch() + +# First works fine (uses limit(1), no materialization) +obs = stream.live().transform(xf).first() +``` + +## Terminal (consume the iterator) + +Terminals trigger iteration and return a value. They're the "go" button — nothing executes until a terminal is called. 
+ +| Method | Returns | Memory | Live behaviour | +|-----------------|---------------------|--------------------|-----------------------------------------| +| `.fetch()` | `list[Observation]` | Grows with results | TypeError without `.limit()` first | +| `.drain()` | `int` (count) | Constant | Blocks forever, memory stays flat | +| `.save(target)` | target `Stream` | Constant | Blocks forever, appends each to store | +| `.first()` | `Observation` | Constant | Returns first item, then stops | +| `.exists()` | `bool` | Constant | Returns after one item check | +| `.last()` | `Observation` | Materializes | TypeError (uses order_by internally) | +| `.count()` | `int` | Constant | TypeError on transform streams | + +### Choosing the right terminal + +**Batch query** — collect results into memory: +```python +results = stream.after(t).search(vec, k=10).fetch() +``` + +**Live ingestion** — process forever, constant memory: +```python +# Embed and store continuously +stream.live().transform(EmbedImages(clip)).save(target) + +# Side-effect pipeline (no storage) +stream.live().transform(process).drain() +``` + +**One-shot** — get a single observation: +```python +obs = stream.live().transform(xf).first() # blocks until one arrives +has_data = stream.exists() # quick check +``` + +**Bounded live** — collect a fixed number from a live stream: +```python +batch = stream.live().limit(100).fetch() # OK — limit makes it finite +``` + +### Error summary + +All operations that would silently hang on live streams raise `TypeError` instead: + +| Pattern | Error | +|-------------------------------------|-----------------------------------------------| +| `live.transform(xf).search(vec, k)` | `.search() requires finite data` | +| `live.transform(xf).order_by("ts")` | `.order_by() requires finite data` | +| `live.fetch()` (without `.limit()`) | `.fetch() would collect forever` | +| `live.transform(xf).count()` | `.count() would block forever` | +| `live.transform(xf).last()` | 
`.order_by() requires finite data` (via last) | diff --git a/dimos/memory2/test_blobstore_integration.py b/dimos/memory2/test_blobstore_integration.py new file mode 100644 index 0000000000..51710005c4 --- /dev/null +++ b/dimos/memory2/test_blobstore_integration.py @@ -0,0 +1,161 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BlobStore integration with MemoryStore/Backend.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.store.memory import MemoryStore +from dimos.memory2.type.observation import _UNLOADED +from dimos.models.embedding.base import Embedding + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + + +def _emb(vec: list[float]) -> Embedding: + v = np.array(vec, dtype=np.float32) + v /= np.linalg.norm(v) + 1e-10 + return Embedding(vector=v) + + +@pytest.fixture +def bs(tmp_path: Path) -> Generator[FileBlobStore, None, None]: + blob_store = FileBlobStore(root=str(tmp_path / "blobs")) + blob_store.start() + yield blob_store + blob_store.stop() + + +@pytest.fixture +def store(bs: FileBlobStore) -> Generator[MemoryStore, None, None]: + with MemoryStore(blob_store=bs) as s: + yield s + + +class TestBlobStoreIntegration: + def test_append_stores_in_blobstore(self, bs: FileBlobStore, store: MemoryStore) -> None: + s = store.stream("data", 
bytes) + s.append(b"hello", ts=1.0) + + # Blob was written to the file store + raw = bs.get("data", 0) + assert len(raw) > 0 + + def test_lazy_data_not_loaded_until_access(self, store: MemoryStore) -> None: + s = store.stream("data", str) + obs = s.append("payload", ts=1.0) + + # Data replaced with sentinel after append + assert isinstance(obs._data, type(_UNLOADED)) + assert obs._loader is not None + + def test_lazy_data_loads_correctly(self, store: MemoryStore) -> None: + s = store.stream("data", str) + s.append("payload", ts=1.0) + + result = s.first() + assert result.data == "payload" + + def test_eager_preloads_data(self, bs: FileBlobStore) -> None: + with MemoryStore(blob_store=bs, eager_blobs=True) as store: + s = store.stream("data", str) + s.append("payload", ts=1.0) + + # Iterating with eager_blobs triggers load + results = s.fetch() + assert len(results) == 1 + # Data should be loaded (not _UNLOADED) + assert not isinstance(results[0]._data, type(_UNLOADED)) + assert results[0].data == "payload" + + def test_per_stream_eager_override(self, store: MemoryStore) -> None: + # Default: lazy + lazy_stream = store.stream("lazy", str) + lazy_stream.append("lazy-val", ts=1.0) + + # Override: eager + eager_stream = store.stream("eager", str, eager_blobs=True) + eager_stream.append("eager-val", ts=1.0) + + lazy_results = lazy_stream.fetch() + eager_results = eager_stream.fetch() + + # Lazy: data stays unloaded until accessed + assert lazy_results[0].data == "lazy-val" + + # Eager: data pre-loaded during iteration + assert not isinstance(eager_results[0]._data, type(_UNLOADED)) + assert eager_results[0].data == "eager-val" + + def test_no_blobstore_unchanged(self) -> None: + with MemoryStore() as store: + s = store.stream("data", str) + obs = s.append("inline", ts=1.0) + + # Without blob store, data stays inline + assert obs._data == "inline" + assert obs._loader is None + assert obs.data == "inline" + + def test_blobstore_with_vector_search(self, bs: FileBlobStore) 
-> None: + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(blob_store=bs, vector_store=vs) as store: + s = store.stream("vecs", str) + s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) + s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) + s.append("south", ts=3.0, embedding=_emb([0, -1, 0])) + + # Vector search triggers lazy load via obs.derive(data=obs.data, ...) + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity > 0.99 + + def test_blobstore_with_text_search(self, store: MemoryStore) -> None: + s = store.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("temperature ok", ts=2.0) + + # Text search triggers lazy load via str(obs.data) + results = s.search_text("motor").fetch() + assert len(results) == 1 + assert results[0].data == "motor fault" + + def test_multiple_appends_get_unique_blobs(self, store: MemoryStore) -> None: + s = store.stream("multi", str) + s.append("first", ts=1.0) + s.append("second", ts=2.0) + s.append("third", ts=3.0) + + results = s.fetch() + assert [r.data for r in results] == ["first", "second", "third"] + + def test_fetch_preserves_metadata(self, store: MemoryStore) -> None: + s = store.stream("meta", str) + s.append("val", ts=42.0, tags={"kind": "info"}) + + result = s.first() + assert result.ts == 42.0 + assert result.tags == {"kind": "info"} + assert result.data == "val" diff --git a/dimos/memory2/test_buffer.py b/dimos/memory2/test_buffer.py new file mode 100644 index 0000000000..f851a6fcee --- /dev/null +++ b/dimos/memory2/test_buffer.py @@ -0,0 +1,86 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for backpressure buffers.""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from dimos.memory2.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded + + +class TestBackpressureBuffers: + """Thread-safe buffers bridging push sources to pull consumers.""" + + def test_keep_last_overwrites(self): + buf = KeepLast[int]() + buf.put(1) + buf.put(2) + buf.put(3) + assert buf.take() == 3 + assert len(buf) == 0 + + def test_bounded_drops_oldest(self): + buf = Bounded[int](maxlen=2) + buf.put(1) + buf.put(2) + buf.put(3) # drops 1 + assert buf.take() == 2 + assert buf.take() == 3 + + def test_drop_new_rejects(self): + buf = DropNew[int](maxlen=2) + assert buf.put(1) is True + assert buf.put(2) is True + assert buf.put(3) is False # rejected + assert buf.take() == 1 + assert buf.take() == 2 + + def test_unbounded_keeps_all(self): + buf = Unbounded[int]() + for i in range(100): + buf.put(i) + assert len(buf) == 100 + + def test_close_signals_end(self): + buf = KeepLast[int]() + buf.close() + with pytest.raises(ClosedError): + buf.take() + + def test_buffer_is_iterable(self): + """Iterating a buffer yields items until closed.""" + buf = Unbounded[int]() + buf.put(1) + buf.put(2) + buf.close() + assert list(buf) == [1, 2] + + def test_take_blocks_until_put(self): + buf = KeepLast[int]() + result = [] + + def producer(): + time.sleep(0.05) + buf.put(42) + + t = threading.Thread(target=producer) + t.start() + result.append(buf.take(timeout=2.0)) + t.join() + assert result == [42] diff --git 
a/dimos/memory2/test_e2e.py b/dimos/memory2/test_e2e.py new file mode 100644 index 0000000000..48f5e680fd --- /dev/null +++ b/dimos/memory2/test_e2e.py @@ -0,0 +1,255 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""E2E test: import legacy pickle replays into memory2 SqliteStore.""" + +from __future__ import annotations + +import bisect +from typing import TYPE_CHECKING, Any + +import pytest + +from dimos.memory2.store.sqlite import SqliteStore +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.data import get_data_dir +from dimos.utils.testing.replay import TimedSensorReplay + +if TYPE_CHECKING: + from collections.abc import Generator + +DB_PATH = get_data_dir() / "go2_bigoffice.db" + + +@pytest.fixture(scope="module") +def session() -> Generator[SqliteStore, None, None]: + store = SqliteStore(path=str(DB_PATH)) + with store: + yield store + store.stop() + + +class PoseIndex: + """Preloaded odom data with O(log n) closest-timestamp lookup.""" + + def __init__(self, replay: TimedSensorReplay) -> None: # type: ignore[type-arg] + self._timestamps: list[float] = [] + self._data: list[Any] = [] + for ts, data in replay.iterate_ts(): + self._timestamps.append(ts) + self._data.append(data) + + def find_closest(self, ts: float) -> Any | None: + if not self._timestamps: + return None + idx = bisect.bisect_left(self._timestamps, ts) + # Compare the two 
candidates around the insertion point + if idx == 0: + return self._data[0] + if idx >= len(self._timestamps): + return self._data[-1] + if ts - self._timestamps[idx - 1] <= self._timestamps[idx] - ts: + return self._data[idx - 1] + return self._data[idx] + + +@pytest.fixture(scope="module") +def video_replay() -> TimedSensorReplay: + return TimedSensorReplay("unitree_go2_bigoffice/video") + + +@pytest.fixture(scope="module") +def odom_index() -> PoseIndex: + return PoseIndex(TimedSensorReplay("unitree_go2_bigoffice/odom")) + + +@pytest.fixture(scope="module") +def lidar_replay() -> TimedSensorReplay: + return TimedSensorReplay("unitree_go2_bigoffice/lidar") + + +@pytest.mark.tool +class TestImportReplay: + """Import legacy pickle replay data into a memory2 SqliteStore.""" + + def test_import_video( + self, + session: SqliteStore, + video_replay: TimedSensorReplay, # type: ignore[type-arg] + odom_index: PoseIndex, + ) -> None: + with session.stream("color_image", Image) as video: + count = 0 + for ts, frame in video_replay.iterate_ts(): + pose = odom_index.find_closest(ts) + print("import", frame) + video.append(frame, ts=ts, pose=pose) + count += 1 + + assert count > 0 + assert video.count() == count + print(f"Imported {count} video frames") + + def test_import_lidar( + self, + session: SqliteStore, + lidar_replay: TimedSensorReplay, # type: ignore[type-arg] + odom_index: PoseIndex, + ) -> None: + # can also be explicit here + # lidar = session.stream("lidar", PointCloud2, codec=Lz4Codec(LcmCodec(PointCloud2))) + lidar = session.stream("lidar", PointCloud2, codec="lz4+lcm") + + count = 0 + for ts, frame in lidar_replay.iterate_ts(): + pose = odom_index.find_closest(ts) + print("import", frame) + lidar.append(frame, ts=ts, pose=pose) + count += 1 + + assert count > 0 + assert lidar.count() == count + print(f"Imported {count} lidar frames") + + def test_query_imported_data(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + lidar = 
session.stream("lidar", PointCloud2) + + assert video.exists() + assert lidar.exists() + + first_frame = video.first() + last_frame = video.last() + assert first_frame.ts < last_frame.ts + + mid_ts = (first_frame.ts + last_frame.ts) / 2 + subset = video.time_range(first_frame.ts, mid_ts).fetch() + assert 0 < len(subset) < video.count() + + streams = session.list_streams() + assert "color_image" in streams + assert "lidar" in streams + + +class TestE2EQuery: + """Query operations against real robot replay data.""" + + def test_list_streams(self, session: SqliteStore) -> None: + streams = session.list_streams() + print(streams) + + assert "color_image" in streams + assert "lidar" in streams + assert session.streams.color_image + assert session.streams.lidar + + print(session.streams.lidar) + + def test_video_count(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + assert video.count() > 1000 + + def test_lidar_count(self, session: SqliteStore) -> None: + lidar = session.stream("lidar", PointCloud2) + assert lidar.count() > 1000 + + def test_first_last_timestamps(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + first = video.first() + last = video.last() + assert first.ts < last.ts + duration = last.ts - first.ts + assert duration > 10.0 # at least 10s of data + + def test_time_range_filter(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + first = video.first() + + # Grab first 5 seconds + window = video.time_range(first.ts, first.ts + 5.0).fetch() + assert len(window) > 0 + assert len(window) < video.count() + assert all(first.ts <= obs.ts <= first.ts + 5.0 for obs in window) + + def test_limit_offset_pagination(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + page1 = video.limit(10).fetch() + page2 = video.offset(10).limit(10).fetch() + + assert len(page1) == 10 + assert len(page2) == 10 + assert page1[-1].ts < page2[0].ts 
# no overlap + + def test_order_by_desc(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + last_10 = video.order_by("ts", desc=True).limit(10).fetch() + + assert len(last_10) == 10 + assert all(last_10[i].ts >= last_10[i + 1].ts for i in range(9)) + + def test_lazy_data_loads_correctly(self, session: SqliteStore) -> None: + """Verify lazy blob loading returns valid Image data.""" + from dimos.memory2.type.observation import _Unloaded + + video = session.stream("color_image", Image) + obs = next(iter(video.limit(1))) + + # Should start lazy + assert isinstance(obs._data, _Unloaded) + + # Trigger load + frame = obs.data + assert isinstance(frame, Image) + assert frame.width > 0 + assert frame.height > 0 + + def test_iterate_window_decodes_all(self, session: SqliteStore) -> None: + """Iterate a time window and verify every frame decodes.""" + video = session.stream("color_image", Image) + first_ts = video.first().ts + + window = video.time_range(first_ts, first_ts + 2.0) + count = 0 + for obs in window: + frame = obs.data + assert isinstance(frame, Image) + count += 1 + assert count > 0 + + def test_lidar_data_loads(self, session: SqliteStore) -> None: + """Verify lidar blobs decode to PointCloud2.""" + lidar = session.stream("lidar", PointCloud2) + frame = lidar.first().data + assert isinstance(frame, PointCloud2) + + def test_poses_present(self, session: SqliteStore) -> None: + """Verify poses were stored during import.""" + video = session.stream("color_image", Image) + obs = video.first() + assert obs.pose is not None + + def test_cross_stream_time_alignment(self, session: SqliteStore) -> None: + """Video and lidar should overlap in time.""" + video = session.stream("color_image", Image) + lidar = session.stream("lidar", PointCloud2) + + v_first, v_last = video.first().ts, video.last().ts + l_first, l_last = lidar.first().ts, lidar.last().ts + + # Overlap: max of starts < min of ends + overlap_start = max(v_first, l_first) + 
overlap_end = min(v_last, l_last) + assert overlap_start < overlap_end, "Video and lidar should overlap in time" + # (duplicate assertion removed; overlap already asserted above) diff --git a/dimos/memory2/test_e2e_processing.py b/dimos/memory2/test_e2e_processing.py new file mode 100644 index 0000000000..81eba5c2a8 --- /dev/null +++ b/dimos/memory2/test_e2e_processing.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + + +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dimos/memory2/test_embedding.py b/dimos/memory2/test_embedding.py new file mode 100644 index 0000000000..57d66da278 --- /dev/null +++ b/dimos/memory2/test_embedding.py @@ -0,0 +1,396 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Tests for embedding layer: EmbeddedObservation, vector search, text search, transformers.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from dimos.memory2.type.observation import EmbeddedObservation, Observation +from dimos.models.embedding.base import Embedding + + +def _emb(vec: list[float]) -> Embedding: + """Return a unit-normalized Embedding.""" + v = np.array(vec, dtype=np.float32) + v /= np.linalg.norm(v) + 1e-10 + return Embedding(vector=v) + + +class TestEmbeddedObservation: + def test_construction(self) -> None: + emb = _emb([1, 0, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="hello", embedding=emb) + assert obs.data == "hello" + assert obs.embedding is emb + assert obs.similarity is None + + def test_is_observation(self) -> None: + obs = EmbeddedObservation(id=0, ts=1.0, _data="x", embedding=_emb([1, 0])) + assert isinstance(obs, Observation) + + def test_derive_preserves_embedding(self) -> None: + emb = _emb([1, 0, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=emb) + derived = obs.derive(data="b") + assert isinstance(derived, EmbeddedObservation) + assert derived.embedding is emb + assert derived.data == "b" + + def test_derive_replaces_embedding(self) -> None: + old = _emb([1, 0, 0]) + new = _emb([0, 1, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=old) + derived = obs.derive(data="a", embedding=new) + assert derived.embedding is new + + def test_derive_preserves_similarity(self) -> None: + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=_emb([1, 0]), similarity=0.95) + derived = obs.derive(data="b") + assert derived.similarity == 0.95 + + def test_observation_derive_promotes_to_embedded(self) -> None: + obs = Observation(id=0, ts=1.0, _data="plain") + emb = _emb([1, 0, 0]) + derived = obs.derive(data="plain", embedding=emb) + assert isinstance(derived, EmbeddedObservation) + assert derived.embedding is emb + + def 
test_observation_derive_without_embedding_stays_observation(self) -> None: + obs = Observation(id=0, ts=1.0, _data="plain") + derived = obs.derive(data="still plain") + assert type(derived) is Observation + + +class TestListBackendEmbedding: + def test_append_with_embedding(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + emb = _emb([1, 0, 0]) + obs = s.append("hello", embedding=emb) + assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb + + def test_append_without_embedding(self, memory_store) -> None: + s = memory_store.stream("plain", str) + obs = s.append("hello") + assert type(obs) is Observation + + def test_search_returns_top_k(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_sorted_by_similarity(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + s.append("far", embedding=_emb([0, -1, 0])) + s.append("close", embedding=_emb([0.9, 0.1, 0])) + s.append("exact", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=3).fetch() + assert results[0].data == "exact" + assert results[1].data == "close" + assert results[2].data == "far" + # Descending similarity + assert results[0].similarity >= results[1].similarity >= results[2].similarity + + def test_search_skips_non_embedded(self, memory_store) -> None: + s = memory_store.stream("mixed", str) + s.append("plain") # no embedding + s.append("embedded", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "embedded" + + def 
test_search_with_filters(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) + + # Only the late one should pass the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_search_with_limit(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + for i in range(10): + s.append(f"item{i}", embedding=_emb([1, 0, 0])) + + # search k=5 then limit 2 + results = s.search(_emb([1, 0, 0]), k=5).limit(2).fetch() + assert len(results) == 2 + + def test_search_with_live_raises(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + s.append("x", embedding=_emb([1, 0, 0])) + with pytest.raises(TypeError, match="Cannot combine"): + list(s.live().search(_emb([1, 0, 0]), k=5)) + + +class TestTextSearch: + def test_search_text_substring(self, memory_store) -> None: + s = memory_store.stream("logs", str) + s.append("motor fault detected") + s.append("temperature normal") + s.append("motor overheating") + + results = s.search_text("motor").fetch() + assert len(results) == 2 + assert {r.data for r in results} == {"motor fault detected", "motor overheating"} + + def test_search_text_case_insensitive(self, memory_store) -> None: + s = memory_store.stream("logs", str) + s.append("Motor Fault") + s.append("other event") + + results = s.search_text("motor fault").fetch() + assert len(results) == 1 + + def test_search_text_with_filters(self, memory_store) -> None: + s = memory_store.stream("logs", str) + s.append("motor fault", ts=10.0) + s.append("motor warning", ts=20.0) + s.append("motor fault", ts=30.0) + + results = s.after(15.0).search_text("fault").fetch() + assert len(results) == 1 + assert results[0].ts == 30.0 + + def test_search_text_no_match(self, memory_store) -> None: + s = memory_store.stream("logs", str) + s.append("all clear") 
+ + results = s.search_text("motor").fetch() + assert len(results) == 0 + + +class TestSaveEmbeddings: + def test_save_preserves_embeddings(self, memory_store) -> None: + src = memory_store.stream("source", str) + dst = memory_store.stream("dest", str) + + emb = _emb([1, 0, 0]) + src.append("item", embedding=emb) + src.save(dst) + + results = dst.fetch() + assert len(results) == 1 + assert isinstance(results[0], EmbeddedObservation) + # Same vector content (different Embedding instance after re-append) + np.testing.assert_array_almost_equal(results[0].embedding.to_numpy(), emb.to_numpy()) + + def test_save_mixed_embedded_and_plain(self, memory_store) -> None: + src = memory_store.stream("source", str) + dst = memory_store.stream("dest", str) + + src.append("plain") + src.append("embedded", embedding=_emb([0, 1, 0])) + src.save(dst) + + results = dst.fetch() + assert len(results) == 2 + assert type(results[0]) is Observation + assert isinstance(results[1], EmbeddedObservation) + + +class _MockEmbeddingModel: + """Fake EmbeddingModel that returns deterministic unit vectors.""" + + device = "cpu" + + def embed(self, *images): + vecs = [] + for img in images: + rng = np.random.default_rng(hash(str(img)) % 2**32) + v = rng.standard_normal(8).astype(np.float32) + v /= np.linalg.norm(v) + vecs.append(Embedding(vector=v)) + return vecs if len(vecs) > 1 else vecs[0] + + def embed_text(self, *texts): + vecs = [] + for text in texts: + rng = np.random.default_rng(hash(text) % 2**32) + v = rng.standard_normal(8).astype(np.float32) + v /= np.linalg.norm(v) + vecs.append(Embedding(vector=v)) + return vecs if len(vecs) > 1 else vecs[0] + + +class TestEmbedTransformers: + def test_embed_images_produces_embedded_observations(self, memory_store) -> None: + from dimos.memory2.embed import EmbedImages + + model = _MockEmbeddingModel() + s = memory_store.stream("imgs", str) + s.append("img1", ts=1.0) + s.append("img2", ts=2.0) + + results = s.transform(EmbedImages(model)).fetch() + 
assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + assert obs.embedding.to_numpy().shape == (8,) + + def test_embed_text_produces_embedded_observations(self, memory_store) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + s = memory_store.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("all clear", ts=2.0) + + results = s.transform(EmbedText(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + + def test_embed_preserves_data(self, memory_store) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + s = memory_store.stream("logs", str) + s.append("hello", ts=1.0) + + result = s.transform(EmbedText(model)).first() + assert result.data == "hello" + + def test_embed_then_search(self, memory_store) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + s = memory_store.stream("logs", str) + for i in range(10): + s.append(f"log entry {i}", ts=float(i)) + + embedded = s.transform(EmbedText(model)) + # Get the embedding for the first item, then search for similar + first_emb = embedded.first().embedding + results = embedded.search(first_emb, k=3).fetch() + assert len(results) == 3 + # First result should be the exact match + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_embed_batching(self, memory_store) -> None: + from dimos.memory2.embed import EmbedText + + call_sizes: list[int] = [] + + class _TrackingModel(_MockEmbeddingModel): + def embed_text(self, *texts): + call_sizes.append(len(texts)) + return super().embed_text(*texts) + + model = _TrackingModel() + s = memory_store.stream("logs", str) + for i in range(5): + s.append(f"entry {i}") + + list(s.transform(EmbedText(model, batch_size=2))) + # 5 items with 
batch_size=2 → 3 calls (2, 2, 1) + assert call_sizes == [2, 2, 1] + + +class TestPluggableVectorStore: + """Verify that injecting a VectorStore via store config actually delegates search.""" + + def test_append_stores_in_vector_store(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(vector_store=vs) as store: + s = store.stream("vecs", str) + s.append("hello", embedding=_emb([1, 0, 0])) + s.append("world", embedding=_emb([0, 1, 0])) + + assert len(vs._vectors["vecs"]) == 2 + + def test_append_without_embedding_skips_vector_store(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(vector_store=vs) as store: + s = store.stream("plain", str) + s.append("no embedding") + + assert "plain" not in vs._vectors + + def test_search_uses_vector_store(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(vector_store=vs) as store: + s = store.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_with_filters_via_vector_store(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(vector_store=vs) as store: + s = store.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", 
ts=20.0, embedding=_emb([1, 0, 0])) + + # Filter + search: only "late" passes the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_per_stream_vector_store_override(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs_default = MemoryVectorStore() + vs_override = MemoryVectorStore() + with MemoryStore(vector_store=vs_default) as store: + # Stream with default vector store + s1 = store.stream("s1", str) + s1.append("a", embedding=_emb([1, 0, 0])) + + # Stream with overridden vector store + s2 = store.stream("s2", str, vector_store=vs_override) + s2.append("b", embedding=_emb([0, 1, 0])) + + assert "s1" in vs_default._vectors + assert "s1" not in vs_override._vectors + assert "s2" in vs_override._vectors + assert "s2" not in vs_default._vectors diff --git a/dimos/memory2/test_registry.py b/dimos/memory2/test_registry.py new file mode 100644 index 0000000000..d611073075 --- /dev/null +++ b/dimos/memory2/test_registry.py @@ -0,0 +1,263 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RegistryStore and serialization round-trips.""" + +from __future__ import annotations + +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore, SqliteBlobStoreConfig +from dimos.memory2.notifier.subject import SubjectNotifier +from dimos.memory2.observationstore.sqlite import SqliteObservationStoreConfig +from dimos.memory2.registry import RegistryStore, deserialize_component, qual +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.vectorstore.sqlite import SqliteVectorStore, SqliteVectorStoreConfig + + +class TestQual: + def test_qual_blob_store(self) -> None: + assert qual(SqliteBlobStore) == "dimos.memory2.blobstore.sqlite.SqliteBlobStore" + + def test_qual_file_blob_store(self) -> None: + assert qual(FileBlobStore) == "dimos.memory2.blobstore.file.FileBlobStore" + + def test_qual_vector_store(self) -> None: + assert qual(SqliteVectorStore) == "dimos.memory2.vectorstore.sqlite.SqliteVectorStore" + + def test_qual_notifier(self) -> None: + assert qual(SubjectNotifier) == "dimos.memory2.notifier.subject.SubjectNotifier" + + +class TestRegistryStore: + def test_put_get_round_trip(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + + config = {"payload_module": "builtins.str", "codec_id": "pickle"} + reg.put("my_stream", config) + result = reg.get("my_stream") + assert result == config + conn.close() + + def test_get_missing(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + assert reg.get("nonexistent") is None + conn.close() + + def test_list_streams(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / 
"reg.db")) + reg = RegistryStore(conn=conn) + reg.put("a", {"x": 1}) + reg.put("b", {"x": 2}) + assert sorted(reg.list_streams()) == ["a", "b"] + conn.close() + + def test_delete(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + reg.put("x", {"y": 1}) + reg.delete("x") + assert reg.get("x") is None + conn.close() + + def test_upsert(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + reg.put("x", {"v": 1}) + reg.put("x", {"v": 2}) + assert reg.get("x") == {"v": 2} + conn.close() + + +class TestComponentSerialization: + def test_sqlite_observation_store_config(self) -> None: + cfg = SqliteObservationStoreConfig(page_size=512, path="test.db") + dumped = cfg.model_dump() + restored = SqliteObservationStoreConfig(**dumped) + assert restored.page_size == 512 + + def test_sqlite_blob_store_config(self) -> None: + cfg = SqliteBlobStoreConfig(path="/tmp/test.db") + dumped = cfg.model_dump() + restored = SqliteBlobStoreConfig(**dumped) + assert restored.path == "/tmp/test.db" + + def test_sqlite_blob_store_roundtrip(self, tmp_path) -> None: + store = SqliteBlobStore(path=str(tmp_path / "blob.db")) + data = store.serialize() + assert data["class"] == qual(SqliteBlobStore) + restored = deserialize_component(data) + assert isinstance(restored, SqliteBlobStore) + + def test_file_blob_store_roundtrip(self, tmp_path) -> None: + store = FileBlobStore(root=str(tmp_path / "blobs")) + data = store.serialize() + assert data["class"] == qual(FileBlobStore) + restored = deserialize_component(data) + assert isinstance(restored, FileBlobStore) + assert str(restored._root) == str(tmp_path / "blobs") + + def test_sqlite_vector_store_config(self) -> None: + cfg = SqliteVectorStoreConfig(path="/tmp/vec.db") + dumped = 
cfg.model_dump() + restored = SqliteVectorStoreConfig(**dumped) + assert restored.path == "/tmp/vec.db" + + def test_sqlite_vector_store_roundtrip(self, tmp_path) -> None: + store = SqliteVectorStore(path=str(tmp_path / "vec.db")) + data = store.serialize() + assert data["class"] == qual(SqliteVectorStore) + restored = deserialize_component(data) + assert isinstance(restored, SqliteVectorStore) + + def test_subject_notifier_roundtrip(self) -> None: + notifier = SubjectNotifier() + data = notifier.serialize() + assert data["class"] == qual(SubjectNotifier) + restored = deserialize_component(data) + assert isinstance(restored, SubjectNotifier) + + def test_deserialize_component(self, tmp_path) -> None: + store = FileBlobStore(root=str(tmp_path / "blobs")) + data = store.serialize() + restored = deserialize_component(data) + assert isinstance(restored, FileBlobStore) + + +class TestBackendSerialization: + def test_backend_serialize(self, tmp_path) -> None: + from dimos.memory2.backend import Backend + from dimos.memory2.codecs.pickle import PickleCodec + from dimos.memory2.observationstore.memory import ListObservationStore + + backend = Backend( + metadata_store=ListObservationStore(name="test"), + codec=PickleCodec(), + blob_store=FileBlobStore(root=str(tmp_path / "blobs")), + notifier=SubjectNotifier(), + ) + data = backend.serialize() + assert data["codec_id"] == "pickle" + assert data["blob_store"]["class"] == qual(FileBlobStore) + assert data["notifier"]["class"] == qual(SubjectNotifier) + + +class TestStoreReopen: + def test_reopen_preserves_data(self, tmp_path) -> None: + """Create a store, write data, close, reopen, read back.""" + db = str(tmp_path / "test.db") + with SqliteStore(path=db) as store: + s = store.stream("nums", int) + s.append(42, ts=1.0) + s.append(99, ts=2.0) + + with SqliteStore(path=db) as store2: + s2 = store2.stream("nums", int) + assert s2.count() == 2 + obs = s2.fetch() + assert [o.data for o in obs] == [42, 99] + + def 
test_reopen_preserves_codec(self, tmp_path) -> None: + """Codec ID is stored and restored on reopen.""" + db = str(tmp_path / "codec.db") + with SqliteStore(path=db) as store: + s = store.stream("data", str, codec="pickle") + s.append("hello", ts=1.0) + + with SqliteStore(path=db) as store2: + s2 = store2.stream("data", str) + assert s2.first().data == "hello" + + def test_reopen_preserves_eager_blobs(self, tmp_path) -> None: + """eager_blobs override is stored in registry and restored on reopen.""" + db = str(tmp_path / "eager.db") + with SqliteStore(path=db) as store: + s = store.stream("data", str, eager_blobs=True) + s.append("test", ts=1.0) + + with SqliteStore(path=db) as store2: + stored = store2._registry.get("data") + assert stored is not None + assert stored["eager_blobs"] is True + + def test_reopen_preserves_file_blob_store(self, tmp_path) -> None: + """FileBlobStore override is stored and restored on reopen.""" + db = str(tmp_path / "file_blob.db") + blob_dir = str(tmp_path / "blobs") + with SqliteStore(path=db) as store: + fbs = FileBlobStore(root=blob_dir) + fbs.start() + s = store.stream("imgs", str, blob_store=fbs) + s.append("image_data", ts=1.0) + + with SqliteStore(path=db) as store2: + stored = store2._registry.get("imgs") + assert stored is not None + assert stored["blob_store"]["class"] == qual(FileBlobStore) + assert stored["blob_store"]["config"]["root"] == blob_dir + + def test_reopen_type_mismatch_raises(self, tmp_path) -> None: + """Opening a stream with a different payload type raises ValueError.""" + db = str(tmp_path / "mismatch.db") + with SqliteStore(path=db) as store: + store.stream("nums", int) + + with SqliteStore(path=db) as store2: + with pytest.raises(ValueError, match="was created with type"): + store2.stream("nums", str) + + def test_reopen_list_streams(self, tmp_path) -> None: + """list_streams includes streams from registry on reopen.""" + db = str(tmp_path / "list.db") + with SqliteStore(path=db) as store: + 
store.stream("a", int) + store.stream("b", str) + + with SqliteStore(path=db) as store2: + assert sorted(store2.list_streams()) == ["a", "b"] + + def test_reopen_without_payload_type(self, tmp_path) -> None: + """Reopening a known stream without payload_type works.""" + db = str(tmp_path / "no_type.db") + with SqliteStore(path=db) as store: + s = store.stream("data", str) + s.append("hello", ts=1.0) + + with SqliteStore(path=db) as store2: + s2 = store2.stream("data") + assert s2.first().data == "hello" + + def test_reopen_preserves_page_size(self, tmp_path) -> None: + """page_size is stored in registry and restored on reopen.""" + db = str(tmp_path / "page.db") + with SqliteStore(path=db, page_size=512) as store: + store.stream("data", str) + + with SqliteStore(path=db) as store2: + stored = store2._registry.get("data") + assert stored is not None + assert stored["page_size"] == 512 diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py new file mode 100644 index 0000000000..13ee73d46a --- /dev/null +++ b/dimos/memory2/test_save.py @@ -0,0 +1,123 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Stream.save() and Notifier integration.""" + +from __future__ import annotations + +import pytest + +from dimos.memory2.backend import Backend +from dimos.memory2.codecs.pickle import PickleCodec +from dimos.memory2.notifier.base import Notifier +from dimos.memory2.observationstore.memory import ListObservationStore +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer +from dimos.memory2.type.observation import Observation + + +def _make_backend(name: str = "test") -> Backend[int]: + return Backend(metadata_store=ListObservationStore[int](name=name), codec=PickleCodec()) + + +def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: + backend = _make_backend() + for i in range(n): + backend.append(Observation(id=-1, ts=start_ts + i, _data=i * 10)) + return Stream(source=backend) + + +# ═══════════════════════════════════════════════════════════════════ +# Protocol checks +# ═══════════════════════════════════════════════════════════════════ + + +class TestProtocol: + def test_backend_has_notifier(self) -> None: + b = _make_backend("x") + assert isinstance(b.notifier, Notifier) + + +# ═══════════════════════════════════════════════════════════════════ +# .save() +# ═══════════════════════════════════════════════════════════════════ + + +class TestSave: + def test_save_populates_target(self) -> None: + source = make_stream(3) + target = Stream(source=_make_backend("target")) + + source.save(target) + + results = target.fetch() + assert len(results) == 3 + assert [o.data for o in results] == [0, 10, 20] + + def test_save_returns_target_stream(self) -> None: + source = make_stream(2) + target = Stream(source=_make_backend("target")) + + result = source.save(target) + + assert result is target + + def test_save_preserves_data(self) -> None: + backend = _make_backend("src") + backend.append(Observation(id=-1, ts=1.0, pose=(1, 2, 3), tags={"label": "cat"}, _data=42)) + source = Stream(source=backend) + + 
target = Stream(source=_make_backend("dst")) + source.save(target) + + obs = target.first() + assert obs.data == 42 + assert obs.ts == 1.0 + assert obs.pose == (1, 2, 3) + assert obs.tags == {"label": "cat"} + + def test_save_with_transform(self) -> None: + source = make_stream(3) # data: 0, 10, 20 + doubled = source.transform(FnTransformer(lambda obs: obs.derive(data=obs.data * 2))) + + target = Stream(source=_make_backend("target")) + doubled.save(target) + + assert [o.data for o in target.fetch()] == [0, 20, 40] + + def test_save_rejects_transform_target(self) -> None: + source = make_stream(2) + base = make_stream(2) + transform_stream = base.transform(FnTransformer(lambda obs: obs.derive(obs.data))) + + with pytest.raises(TypeError, match="Cannot save to a transform stream"): + source.save(transform_stream) + + def test_save_target_queryable(self) -> None: + source = make_stream(5, start_ts=0.0) # ts: 0,1,2,3,4 + + target = Stream(source=_make_backend("target")) + result = source.save(target) + + after_2 = result.after(2.0).fetch() + assert [o.data for o in after_2] == [30, 40] + + def test_save_empty_source(self) -> None: + source = make_stream(0) + target = Stream(source=_make_backend("target")) + + result = source.save(target) + + assert result.count() == 0 + assert result.fetch() == [] diff --git a/dimos/memory2/test_store.py b/dimos/memory2/test_store.py new file mode 100644 index 0000000000..dfba6d6d2b --- /dev/null +++ b/dimos/memory2/test_store.py @@ -0,0 +1,527 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid tests for Store implementations. + +Runs the same test logic against every Store backend (MemoryStore, SqliteStore, ...). +The parametrized ``session`` fixture from conftest runs each test against both backends. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +from dimos.memory2.blobstore.base import BlobStore +from dimos.memory2.vectorstore.base import VectorStore + +if TYPE_CHECKING: + from dimos.memory2.store.base import Store + + +class TestStoreBasic: + """Core store operations that every backend must support.""" + + def test_create_stream_and_append(self, session: Store) -> None: + s = session.stream("images", bytes) + obs = s.append(b"frame1", tags={"camera": "front"}) + + assert obs.data == b"frame1" + assert obs.tags["camera"] == "front" + assert obs.ts > 0 + + def test_append_multiple_and_fetch(self, session: Store) -> None: + s = session.stream("sensor", float) + s.append(1.0, ts=100.0) + s.append(2.0, ts=200.0) + s.append(3.0, ts=300.0) + + results = s.fetch() + assert len(results) == 3 + assert [o.data for o in results] == [1.0, 2.0, 3.0] + + def test_iterate_stream(self, session: Store) -> None: + s = session.stream("log", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) + + collected = [obs.data for obs in s] + assert collected == ["a", "b"] + + def test_count(self, session: Store) -> None: + s = session.stream("events", str) + assert s.count() == 0 + s.append("x") + s.append("y") + assert s.count() == 2 + + def test_first_and_last(self, session: Store) -> None: + s = session.stream("data", int) + s.append(10, ts=1.0) + s.append(20, ts=2.0) + s.append(30, ts=3.0) + + assert s.first().data == 10 + assert s.last().data == 30 + + def test_first_empty_raises(self, session: Store) -> None: + s = session.stream("empty", int) + with pytest.raises(LookupError): + s.first() + + def 
test_exists(self, session: Store) -> None: + s = session.stream("check", str) + assert not s.exists() + s.append("hi") + assert s.exists() + + def test_filter_after(self, session: Store) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.after(15.0).fetch() + assert [o.data for o in results] == [2, 3] + + def test_filter_before(self, session: Store) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.before(25.0).fetch() + assert [o.data for o in results] == [1, 2] + + def test_filter_time_range(self, session: Store) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.time_range(15.0, 25.0).fetch() + assert [o.data for o in results] == [2] + + def test_filter_tags(self, session: Store) -> None: + s = session.stream("tagged", str) + s.append("a", tags={"kind": "info"}) + s.append("b", tags={"kind": "error"}) + s.append("c", tags={"kind": "info"}) + + results = s.tags(kind="info").fetch() + assert [o.data for o in results] == ["a", "c"] + + def test_limit_and_offset(self, session: Store) -> None: + s = session.stream("paged", int) + for i in range(5): + s.append(i, ts=float(i)) + + page = s.offset(1).limit(2).fetch() + assert [o.data for o in page] == [1, 2] + + def test_order_by_desc(self, session: Store) -> None: + s = session.stream("ordered", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.order_by("ts", desc=True).fetch() + assert [o.data for o in results] == [3, 2, 1] + + def test_separate_streams_isolated(self, session: Store) -> None: + a = session.stream("stream_a", str) + b = session.stream("stream_b", str) + + a.append("in_a") + b.append("in_b") + + assert [o.data for o in a] == ["in_a"] + assert [o.data for o in b] == ["in_b"] + + def test_same_stream_on_repeated_calls(self, 
session: Store) -> None: + s1 = session.stream("reuse", str) + s2 = session.stream("reuse", str) + assert s1 is s2 + + def test_append_with_embedding(self, session: Store) -> None: + import numpy as np + + from dimos.memory2.type.observation import EmbeddedObservation + from dimos.models.embedding.base import Embedding + + s = session.stream("vectors", str) + emb = Embedding(vector=np.array([1.0, 0.0, 0.0], dtype=np.float32)) + obs = s.append("hello", embedding=emb) + assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb + + def test_search_top_k(self, session: Store) -> None: + import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) + + s = session.stream("searchable", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity > 0.99 + + def test_search_text(self, session: Store) -> None: + s = session.stream("logs", str) + s.append("motor fault") + s.append("temperature ok") + + # SqliteObservationStore blocks search_text to prevent full table scans + try: + results = s.search_text("motor").fetch() + except NotImplementedError: + pytest.skip("search_text not supported on this backend") + assert len(results) == 1 + assert results[0].data == "motor fault" + + +class TestBlobLoading: + """Verify lazy and eager blob loading paths.""" + + def test_sqlite_lazy_by_default(self, sqlite_store: Store) -> None: + """Default sqlite iteration uses lazy loaders — data is _UNLOADED until accessed.""" + from dimos.memory2.type.observation import _Unloaded + + s = sqlite_store.stream("lazy_test", str) + s.append("hello", ts=1.0) + s.append("world", ts=2.0) + + for obs in s: + # 
Before accessing .data, _data should be the unloaded sentinel + assert isinstance(obs._data, _Unloaded) + assert obs._loader is not None + # Accessing .data triggers the loader + val = obs.data + assert isinstance(val, str) + # After loading, _loader is cleared + assert obs._loader is None + + def test_sqlite_eager_loads_inline(self, sqlite_store: Store) -> None: + """With eager_blobs=True, data is loaded via JOIN — no lazy loader.""" + from dimos.memory2.type.observation import _Unloaded + + s = sqlite_store.stream("eager_test", str, eager_blobs=True) + s.append("hello", ts=1.0) + s.append("world", ts=2.0) + + for obs in s: + # Data should already be loaded — no lazy sentinel + assert not isinstance(obs._data, _Unloaded) + assert obs._loader is None + assert isinstance(obs.data, str) + + def test_sqlite_lazy_and_eager_same_values(self, sqlite_store: Store) -> None: + """Both paths must return identical data.""" + lazy_s = sqlite_store.stream("vals", str) + lazy_s.append("alpha", ts=1.0, tags={"k": "v"}) + lazy_s.append("beta", ts=2.0, tags={"k": "w"}) + + # Lazy read + lazy_results = lazy_s.fetch() + + # Eager read — new stream handle with eager_blobs on same backend + eager_s = sqlite_store.stream("vals", str, eager_blobs=True) + eager_results = eager_s.fetch() + + assert [o.data for o in lazy_results] == [o.data for o in eager_results] + assert [o.tags for o in lazy_results] == [o.tags for o in eager_results] + assert [o.ts for o in lazy_results] == [o.ts for o in eager_results] + + def test_memory_lazy_with_blobstore(self, tmp_path) -> None: + """MemoryStore with a BlobStore uses lazy loaders.""" + from dimos.memory2.blobstore.file import FileBlobStore + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.type.observation import _Unloaded + + bs = FileBlobStore(root=str(tmp_path / "blobs")) + bs.start() + with MemoryStore(blob_store=bs) as store: + s = store.stream("mem_lazy", str) + s.append("data1", ts=1.0) + + obs = s.first() + # Backend 
replaces _data with _UNLOADED when blob_store is set + assert isinstance(obs._data, _Unloaded) + assert obs.data == "data1" + bs.stop() + + +class SpyBlobStore(BlobStore): + """BlobStore that records all calls for verification.""" + + def __init__(self) -> None: + super().__init__() + self.puts: list[tuple[str, int, bytes]] = [] + self.gets: list[tuple[str, int]] = [] + self.store: dict[tuple[str, int], bytes] = {} + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def put(self, stream: str, key: int, data: bytes) -> None: + self.puts.append((stream, key, data)) + self.store[(stream, key)] = data + + def get(self, stream: str, key: int) -> bytes: + self.gets.append((stream, key)) + return self.store[(stream, key)] + + def delete(self, stream: str, key: int) -> None: + self.store.pop((stream, key), None) + + +class SpyVectorStore(VectorStore): + """VectorStore that records all calls for verification.""" + + def __init__(self) -> None: + super().__init__() + self.puts: list[tuple[str, int]] = [] + self.searches: list[tuple[str, int]] = [] + self.vectors: dict[str, dict[int, Any]] = {} + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def put(self, stream: str, key: int, embedding: Any) -> None: + self.puts.append((stream, key)) + self.vectors.setdefault(stream, {})[key] = embedding + + def search(self, stream: str, query: Any, k: int) -> list[tuple[int, float]]: + self.searches.append((stream, k)) + vectors = self.vectors.get(stream, {}) + if not vectors: + return [] + scored = [(key, float(emb @ query)) for key, emb in vectors.items()] + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:k] + + def delete(self, stream: str, key: int) -> None: + self.vectors.get(stream, {}).pop(key, None) + + +@pytest.fixture +def memory_spy_session(): + from dimos.memory2.store.memory import MemoryStore + + blob_spy = SpyBlobStore() + vec_spy = SpyVectorStore() + with MemoryStore(blob_store=blob_spy, vector_store=vec_spy) 
as store: + yield store, blob_spy, vec_spy + + +@pytest.fixture +def sqlite_spy_session(tmp_path): + from dimos.memory2.store.sqlite import SqliteStore + + blob_spy = SpyBlobStore() + vec_spy = SpyVectorStore() + with SqliteStore( + path=str(tmp_path / "spy.db"), blob_store=blob_spy, vector_store=vec_spy + ) as store: + yield store, blob_spy, vec_spy + + +@pytest.fixture(params=["memory_spy_session", "sqlite_spy_session"]) +def spy_session(request: pytest.FixtureRequest): + return request.getfixturevalue(request.param) + + +class TestStoreDelegation: + """Verify all backends delegate to pluggable BlobStore and VectorStore.""" + + def test_append_calls_blob_put(self, spy_session) -> None: + store, blob_spy, _vec_spy = spy_session + s = store.stream("blobs", str) + s.append("first", ts=1.0) + s.append("second", ts=2.0) + + assert len(blob_spy.puts) == 2 + assert all(stream == "blobs" for stream, _k, _d in blob_spy.puts) + + def test_iterate_calls_blob_get(self, spy_session) -> None: + store, blob_spy, _vec_spy = spy_session + s = store.stream("blobs", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) + + blob_spy.gets.clear() + for obs in s: + _ = obs.data + assert len(blob_spy.gets) == 2 + + def test_append_embedding_calls_vector_put(self, spy_session) -> None: + import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) + + store, _blob_spy, vec_spy = spy_session + s = store.stream("vecs", str) + s.append("a", ts=1.0, embedding=_emb([1, 0, 0])) + s.append("b", ts=2.0, embedding=_emb([0, 1, 0])) + s.append("c", ts=3.0) # no embedding + + assert len(vec_spy.puts) == 2 + + def test_search_calls_vector_search(self, spy_session) -> None: + import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / 
(np.linalg.norm(a) + 1e-10)) + + store, _blob_spy, vec_spy = spy_session + s = store.stream("vecs", str) + s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) + s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(vec_spy.searches) == 1 + assert results[0].data == "north" + + +class TestStandaloneComponents: + """Verify each SQLite component works standalone with path= (no Store needed).""" + + def test_observation_store_standalone(self, tmp_path) -> None: + from dimos.memory2.codecs.base import codec_for + from dimos.memory2.observationstore.sqlite import SqliteObservationStore + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation + + db = str(tmp_path / "obs.db") + codec = codec_for(str) + with SqliteObservationStore(path=db, name="events", codec=codec) as store: + obs = Observation(id=0, ts=1.0, _data="hello") + row_id = store.insert(obs) + store.commit() + assert row_id == 1 + + results = list(store.query(StreamQuery())) + assert len(results) == 1 + assert results[0].ts == 1.0 + + def test_blob_store_standalone(self, tmp_path) -> None: + from dimos.memory2.blobstore.sqlite import SqliteBlobStore + + db = str(tmp_path / "blob.db") + with SqliteBlobStore(path=db) as store: + store.put("stream1", 1, b"data1") + store.put("stream1", 2, b"data2") + assert store.get("stream1", 1) == b"data1" + assert store.get("stream1", 2) == b"data2" + + def test_vector_store_standalone(self, tmp_path) -> None: + import numpy as np + + from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + from dimos.models.embedding.base import Embedding + + db = str(tmp_path / "vec.db") + with SqliteVectorStore(path=db) as store: + emb1 = Embedding(vector=np.array([1, 0, 0], dtype=np.float32)) + emb2 = Embedding(vector=np.array([0, 1, 0], dtype=np.float32)) + store.put("vecs", 1, emb1) + store.put("vecs", 2, emb2) + + results = store.search("vecs", emb1, k=2) + assert 
len(results) == 2 + assert results[0][0] == 1 # closest to emb1 is itself + + def test_conn_and_path_mutually_exclusive(self, tmp_path) -> None: + import sqlite3 + + from dimos.memory2.blobstore.sqlite import SqliteBlobStore + from dimos.memory2.observationstore.sqlite import SqliteObservationStore + from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + + conn = sqlite3.connect(":memory:") + db = str(tmp_path / "test.db") + + with pytest.raises(ValueError, match="either conn or path"): + SqliteBlobStore(conn=conn, path=db) + with pytest.raises(ValueError, match="either conn or path"): + SqliteVectorStore(conn=conn, path=db) + with pytest.raises(ValueError, match="either conn or path"): + SqliteObservationStore(conn=conn, name="x", path=db) + with pytest.raises(ValueError, match="either conn or path"): + SqliteBlobStore() + with pytest.raises(ValueError, match="either conn or path"): + SqliteVectorStore() + with pytest.raises(ValueError, match="either conn or path"): + SqliteObservationStore(name="x") + conn.close() + + +class TestStreamAccessor: + """Test attribute-style stream access via store.streams.""" + + def test_accessor_returns_same_stream(self, session: Store) -> None: + s = session.stream("images", bytes) + assert session.streams.images is s + + def test_accessor_dir_lists_streams(self, session: Store) -> None: + session.stream("alpha", str) + session.stream("beta", int) + names = dir(session.streams) + assert "alpha" in names + assert "beta" in names + + def test_accessor_missing_raises(self, session: Store) -> None: + with pytest.raises(AttributeError, match="nonexistent"): + _ = session.streams.nonexistent + + def test_accessor_getitem(self, session: Store) -> None: + s = session.stream("data", float) + assert session.streams["data"] is s + + def test_accessor_getitem_missing_raises(self, session: Store) -> None: + with pytest.raises(KeyError): + session.streams["nope"] + + def test_accessor_repr(self, session: Store) -> None: + 
session.stream("x", str) + r = repr(session.streams) + assert "x" in r + assert "StreamAccessor" in r + + def test_accessor_dynamic(self, session: Store) -> None: + assert "late" not in dir(session.streams) + session.stream("late", str) + assert "late" in dir(session.streams) diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py new file mode 100644 index 0000000000..adfa15ac14 --- /dev/null +++ b/dimos/memory2/test_stream.py @@ -0,0 +1,728 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""memory stream tests — serves as living documentation of the lazy stream API. + +Each test demonstrates a specific capability with clear setup, action, and assertion. 
+""" + +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING + +import pytest + +from dimos.memory2.buffer import KeepLast, Unbounded +from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer +from dimos.memory2.type.observation import Observation + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + + from dimos.memory2.stream import Stream + + +@pytest.fixture +def make_stream(session) -> Generator[Callable[..., Stream[int]], None, None]: + stream_index = 0 + + def f(n: int = 5, start_ts: float = 0.0): + nonlocal stream_index + stream_index += 1 + stream = session.stream(f"test{stream_index}", int) + for i in range(n): + stream.append(i * 10, ts=start_ts + i) + return stream + + return f + + +# ═══════════════════════════════════════════════════════════════════ +# 1. Basic iteration +# ═══════════════════════════════════════════════════════════════════ + + +class TestBasicIteration: + """Streams are lazy iterables — nothing runs until you iterate.""" + + def test_iterate_yields_all_observations(self, make_stream): + stream = make_stream(5) + obs = list(stream) + assert len(obs) == 5 + assert [o.data for o in obs] == [0, 10, 20, 30, 40] + + def test_iterate_preserves_timestamps(self, make_stream): + stream = make_stream(3, start_ts=100.0) + assert [o.ts for o in stream] == [100.0, 101.0, 102.0] + + def test_empty_stream(self, make_stream): + stream = make_stream(0) + assert list(stream) == [] + + def test_fetch_materializes_to_list(self, make_stream): + result = make_stream(3).fetch() + assert isinstance(result, list) + assert len(result) == 3 + + def test_stream_is_reiterable(self, make_stream): + """Same stream can be iterated multiple times — each time re-queries.""" + stream = make_stream(3) + first = [o.data for o in stream] + second = [o.data for o in stream] + assert first == second == [0, 10, 20] + + +# 
class TestTemporalFilters:
    """Temporal filters constrain observations by timestamp."""

    def test_after(self, make_stream):
        """.after(t) keeps observations with ts > t (strict)."""
        kept = make_stream(5).after(2.0).fetch()
        assert [o.ts for o in kept] == [3.0, 4.0]

    def test_before(self, make_stream):
        """.before(t) keeps observations with ts < t (strict)."""
        kept = make_stream(5).before(2.0).fetch()
        assert [o.ts for o in kept] == [0.0, 1.0]

    def test_time_range(self, make_stream):
        """.time_range(t1, t2) keeps t1 <= ts <= t2 (inclusive at both ends)."""
        kept = make_stream(5).time_range(1.0, 3.0).fetch()
        assert [o.ts for o in kept] == [1.0, 2.0, 3.0]

    def test_at_with_tolerance(self, make_stream):
        """.at(t, tolerance) keeps observations within tolerance of t."""
        kept = make_stream(5).at(2.0, tolerance=0.5).fetch()
        assert [o.ts for o in kept] == [2.0]

    def test_chained_temporal_filters(self, make_stream):
        """Filters compose — each call narrows the result further."""
        kept = make_stream(10).after(2.0).before(7.0).fetch()
        assert [o.ts for o in kept] == [3.0, 4.0, 5.0, 6.0]
class TestSpatialFilter:
    """.near(pose, radius) filters by Euclidean distance."""

    def test_near_with_tuples(self, memory_session):
        s = memory_session.stream("spatial")
        s.append("origin", ts=0.0, pose=(0, 0, 0))
        s.append("close", ts=1.0, pose=(1, 1, 0))
        s.append("far", ts=2.0, pose=(10, 10, 10))

        within = s.near((0, 0, 0), radius=2.0).fetch()
        assert [o.data for o in within] == ["origin", "close"]

    def test_near_excludes_no_pose(self, memory_session):
        s = memory_session.stream("spatial")
        s.append("no_pose", ts=0.0)
        s.append("has_pose", ts=1.0, pose=(0, 0, 0))

        # Observations without a pose can never satisfy a spatial filter.
        within = s.near((0, 0, 0), radius=10.0).fetch()
        assert [o.data for o in within] == ["has_pose"]


# ═══════════════════════════════════════════════════════════════════
# 4. Tags filter
# ═══════════════════════════════════════════════════════════════════


class TestTagsFilter:
    """.tags() matches on observation metadata key/value pairs."""

    def test_filter_by_tag(self, memory_session):
        s = memory_session.stream("tagged")
        s.append("cat", ts=0.0, tags={"type": "animal", "legs": 4})
        s.append("car", ts=1.0, tags={"type": "vehicle", "wheels": 4})
        s.append("dog", ts=2.0, tags={"type": "animal", "legs": 4})

        matched = s.tags(type="animal").fetch()
        assert [o.data for o in matched] == ["cat", "dog"]

    def test_filter_multiple_tags(self, memory_session):
        s = memory_session.stream("tagged")
        s.append("a", ts=0.0, tags={"x": 1, "y": 2})
        s.append("b", ts=1.0, tags={"x": 1, "y": 3})

        # All given tags must match simultaneously (AND semantics).
        matched = s.tags(x=1, y=2).fetch()
        assert [o.data for o in matched] == ["a"]
class TestOrderLimitOffset:
    """Ordering, pagination, and terminal query operators."""

    def test_limit(self, make_stream):
        assert len(make_stream(10).limit(3).fetch()) == 3

    def test_offset(self, make_stream):
        got = make_stream(5).offset(2).fetch()
        assert [o.data for o in got] == [20, 30, 40]

    def test_limit_and_offset(self, make_stream):
        got = make_stream(10).offset(2).limit(3).fetch()
        assert [o.data for o in got] == [20, 30, 40]

    def test_order_by_ts_desc(self, make_stream):
        got = make_stream(5).order_by("ts", desc=True).fetch()
        assert [o.ts for o in got] == [4.0, 3.0, 2.0, 1.0, 0.0]

    def test_first(self, make_stream):
        assert make_stream(5).first().data == 0

    def test_last(self, make_stream):
        assert make_stream(5).last().data == 40

    def test_first_empty_raises(self, make_stream):
        # first() on an empty stream raises rather than returning None.
        with pytest.raises(LookupError):
            make_stream(0).first()

    def test_count(self, make_stream):
        assert make_stream(5).count() == 5
        assert make_stream(5).after(2.0).count() == 2

    def test_exists(self, make_stream):
        assert make_stream(5).exists()
        assert not make_stream(0).exists()
        assert not make_stream(5).after(100.0).exists()

    def test_drain(self, make_stream):
        assert make_stream(5).drain() == 5
        assert make_stream(5).after(2.0).drain() == 2
        assert make_stream(0).drain() == 0
class TestFunctionalAPI:
    """Functional combinators receive the full Observation, not just its data."""

    def test_filter_with_predicate(self, make_stream):
        """.filter() takes a predicate over the whole Observation."""
        kept = make_stream(5).filter(lambda o: o.data > 20).fetch()
        assert [o.data for o in kept] == [30, 40]

    def test_filter_on_metadata(self, make_stream):
        """Predicates can inspect ts, tags, pose — not only data."""
        kept = make_stream(5).filter(lambda o: o.ts % 2 == 0).fetch()
        assert [o.ts for o in kept] == [0.0, 2.0, 4.0]

    def test_map(self, make_stream):
        """.map() replaces each observation's data via derive()."""
        mapped = make_stream(3).map(lambda o: o.derive(data=o.data * 2)).fetch()
        assert [o.data for o in mapped] == [0, 20, 40]

    def test_map_preserves_ts(self, make_stream):
        mapped = make_stream(3).map(lambda o: o.derive(data=str(o.data))).fetch()
        assert [o.ts for o in mapped] == [0.0, 1.0, 2.0]
        assert [o.data for o in mapped] == ["0", "10", "20"]
class TestTransformChaining:
    """Transforms chain lazily — each observation flows through the whole pipeline."""

    def test_single_transform(self, make_stream):
        bump = FnTransformer(lambda o: o.derive(data=o.data + 1))
        out = make_stream(3).transform(bump).fetch()
        assert [o.data for o in out] == [1, 11, 21]

    def test_chained_transforms(self, make_stream):
        """stream.transform(A).transform(B) — B pulls from A which pulls from source."""
        bump = FnTransformer(lambda o: o.derive(data=o.data + 1))
        twice = FnTransformer(lambda o: o.derive(data=o.data * 2))

        out = make_stream(3).transform(bump).transform(twice).fetch()
        # (0+1)*2=2, (10+1)*2=22, (20+1)*2=42
        assert [o.data for o in out] == [2, 22, 42]

    def test_transform_can_skip(self, make_stream):
        """Returning None from a transformer drops that observation."""
        evens_only = FnTransformer(lambda o: o if o.data % 20 == 0 else None)
        out = make_stream(5).transform(evens_only).fetch()
        assert [o.data for o in out] == [0, 20, 40]

    def test_transform_filter_transform(self, memory_session):
        """stream.transform(A).near(pose).transform(B) — filters sit between transforms."""
        s = memory_session.stream("tfft")
        s.append(1, ts=0.0, pose=(0, 0, 0))
        s.append(2, ts=1.0, pose=(100, 100, 100))
        s.append(3, ts=2.0, pose=(1, 0, 0))

        plus_ten = FnTransformer(lambda o: o.derive(data=o.data + 10))
        twice = FnTransformer(lambda o: o.derive(data=o.data * 2))

        out = (
            s.transform(plus_ten)  # 11, 12, 13
            .near((0, 0, 0), 5.0)  # keeps poses (0,0,0) and (1,0,0)
            .transform(twice)  # 22, 26
            .fetch()
        )
        assert [o.data for o in out] == [22, 26]

    def test_generator_function_transform(self, make_stream):
        """A bare generator function works as a transform."""

        def doubled(upstream):
            for o in upstream:
                yield o.derive(data=o.data * 2)

        out = make_stream(3).transform(doubled).fetch()
        assert [o.data for o in out] == [0, 20, 40]

    def test_generator_function_stateful(self, make_stream):
        """Generator transforms may hold state and yield at their own pace."""

        def running_sum(upstream):
            acc = 0
            for o in upstream:
                acc += o.data
                yield o.derive(data=acc)

        out = make_stream(3).transform(running_sum).fetch()
        # 0, 0+10=10, 10+20=30
        assert [o.data for o in out] == [0, 10, 30]

    def test_quality_window(self, memory_session):
        """QualityWindow keeps the single best item per time window."""
        s = memory_session.stream("qw")
        # Window 1: ts 0.0-0.9
        s.append(0.3, ts=0.0)
        s.append(0.9, ts=0.3)  # best in window
        s.append(0.1, ts=0.7)
        # Window 2: ts 1.0-1.9
        s.append(0.5, ts=1.0)
        s.append(0.8, ts=1.5)  # best in window
        # Window 3: ts 2.0+ — emitted when the iterator exhausts
        s.append(0.6, ts=2.2)

        best_per_window = s.transform(
            QualityWindow(quality_fn=lambda v: v, window=1.0)
        ).fetch()
        assert [o.data for o in best_per_window] == [0.9, 0.8, 0.6]

    def test_streaming_not_buffering(self, make_stream):
        """Transforms run lazily — an early limit stops pulling from the source."""
        seen = []

        class CountingXf(Transformer[int, int]):
            def __call__(self, upstream):
                for o in upstream:
                    seen.append(o.data)
                    yield o

        out = make_stream(100).transform(CountingXf()).limit(3).fetch()
        assert len(out) == 3
        # Laziness: the transformer processed only what limit(3) demanded,
        # not all 100 source observations.
        assert len(seen) == 3
class TestStore:
    """Store -> Stream hierarchy for named streams."""

    def test_basic_store(self, memory_store):
        frames = memory_store.stream("images")
        frames.append("frame1", ts=0.0)
        frames.append("frame2", ts=1.0)
        assert frames.count() == 2

    def test_same_stream_on_repeated_calls(self, memory_store):
        # stream() is idempotent: the same name yields the very same object.
        assert memory_store.stream("images") is memory_store.stream("images")

    def test_list_streams(self, memory_store):
        memory_store.stream("images")
        memory_store.stream("lidar")
        names = memory_store.list_streams()
        assert len(names) == 2
        assert "images" in names
        assert "lidar" in names

    def test_delete_stream(self, memory_store):
        memory_store.stream("temp")
        memory_store.delete_stream("temp")
        assert "temp" not in memory_store.list_streams()
"v"}, _data="original") + derived = obs.derive(data="transformed") + assert derived.id == 42 + assert derived.ts == 1.5 + assert derived.pose == (1, 2, 3) + assert derived.tags == {"k": "v"} + assert derived.data == "transformed" + + +# ═══════════════════════════════════════════════════════════════════ +# 10. Live mode +# ═══════════════════════════════════════════════════════════════════ + + +class TestLiveMode: + """Live streams yield backfill then block for new observations.""" + + def test_live_sees_backfill_then_new(self, memory_session): + """Backfill first, then live appends come through.""" + stream = memory_session.stream("live") + stream.append("old", ts=0.0) + live = stream.live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append("new1", ts=1.0) + stream.append("new2", ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == ["old", "new1", "new2"] + + def test_live_with_filter(self, memory_session): + """Filters apply to live data — non-matching obs are dropped silently.""" + stream = memory_session.stream("live_filter") + live = stream.after(5.0).live(buffer=Unbounded()) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append(1, ts=1.0) # filtered out (ts <= 5.0) + stream.append(2, ts=6.0) # passes + stream.append(3, ts=3.0) # filtered out + stream.append(4, ts=10.0) # passes + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == [2, 4] + + def test_live_deduplicates_backfill_overlap(self, memory_session): + """Observations seen in backfill are not re-yielded from the live 
buffer.""" + stream = memory_session.stream("dedup") + stream.append("backfill", ts=0.0) + live = stream.live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append("live1", ts=1.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == ["backfill", "live1"] + + def test_live_with_keep_last_backpressure(self, memory_session): + """KeepLast drops intermediate values when consumer is slow.""" + stream = memory_session.stream("bp") + live = stream.live(buffer=KeepLast()) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if obs.data >= 90: + consumed.set() + return + time.sleep(0.1) # slow consumer + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + # Rapid producer — KeepLast will drop most of these + for i in range(100): + stream.append(i, ts=float(i)) + time.sleep(0.001) + + consumed.wait(timeout=5.0) + t.join(timeout=2.0) + # KeepLast means many values were dropped — far fewer than 100 + assert len(results) < 50 + assert results[-1] >= 90 + + def test_live_transform_receives_live_items(self, memory_session): + """Transforms downstream of .live() see both backfill and live items.""" + stream = memory_session.stream("live_xf") + stream.append(1, ts=0.0) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + live = stream.live(buffer=Unbounded()).transform(double) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append(10, ts=1.0) + stream.append(100, ts=2.0) + + consumed.wait(timeout=2.0) + 
t.join(timeout=2.0) + # All items went through the double transform + assert results == [2, 20, 200] + + def test_live_on_transform_raises(self, make_stream): + """Calling .live() on a transform stream raises TypeError.""" + stream = make_stream(3) + xf = FnTransformer(lambda obs: obs) + with pytest.raises(TypeError, match="Cannot call .live"): + stream.transform(xf).live() + + def test_is_live(self, memory_session): + """is_live() walks the source chain to detect live mode.""" + stream = memory_session.stream("is_live") + assert not stream.is_live() + + live = stream.live(buffer=Unbounded()) + assert live.is_live() + + xf = FnTransformer(lambda obs: obs) + transformed = live.transform(xf) + assert transformed.is_live() + + # Two levels deep + double_xf = transformed.transform(xf) + assert double_xf.is_live() + + # Non-live transform is not live + assert not stream.transform(xf).is_live() + + def test_search_on_live_transform_raises(self, memory_session): + """search() on a transform with live upstream raises immediately.""" + stream = memory_session.stream("live_search") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + import numpy as np + + from dimos.models.embedding.base import Embedding + + vec = Embedding(vector=np.array([1.0, 0.0, 0.0])) + with pytest.raises(TypeError, match="requires finite data"): + # Use list() to trigger iteration — fetch() would hit its own guard first + list(live_xf.search(vec, k=5)) + + def test_order_by_on_live_transform_raises(self, memory_session): + """order_by() on a transform with live upstream raises immediately.""" + stream = memory_session.stream("live_order") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="requires finite data"): + list(live_xf.order_by("ts", desc=True)) + + def test_fetch_on_live_without_limit_raises(self, memory_session): + """fetch() on a live stream without limit() 
raises TypeError.""" + stream = memory_session.stream("live_fetch") + live = stream.live(buffer=Unbounded()) + + with pytest.raises(TypeError, match="block forever"): + live.fetch() + + def test_fetch_on_live_transform_without_limit_raises(self, memory_session): + """fetch() on a live transform without limit() raises TypeError.""" + stream = memory_session.stream("live_fetch_xf") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="block forever"): + live_xf.fetch() + + def test_count_on_live_transform_raises(self, memory_session): + """count() on a live transform stream raises TypeError.""" + stream = memory_session.stream("live_count") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="block forever"): + live_xf.count() + + def test_last_on_live_transform_raises(self, memory_session): + """last() on a live transform raises TypeError (via order_by guard).""" + stream = memory_session.stream("live_last") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="requires finite data"): + live_xf.last() + + def test_live_chained_transforms(self, memory_session): + """stream.live().transform(A).transform(B) — both transforms applied to live items.""" + stream = memory_session.stream("live_chain") + stream.append(1, ts=0.0) + add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + live = stream.live(buffer=Unbounded()).transform(add_one).transform(double) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append(10, ts=1.0) + 
stream.append(100, ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + # (1+1)*2=4, (10+1)*2=22, (100+1)*2=202 + assert results == [4, 22, 202] + + def test_live_filter_before_live(self, memory_session): + """Filters applied before .live() work on both backfill and live items.""" + stream = memory_session.stream("live_pre_filter") + stream.append("a", ts=1.0) + stream.append("b", ts=10.0) + live = stream.after(5.0).live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append("c", ts=3.0) # filtered + stream.append("d", ts=20.0) # passes + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + # "a" filtered in backfill, "c" filtered in live + assert results == ["b", "d"] + # "a" filtered in backfill, "c" filtered in live + assert results == ["b", "d"] + assert results == ["b", "d"] + assert results == ["b", "d"] diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py new file mode 100644 index 0000000000..1e5dc35c2c --- /dev/null +++ b/dimos/memory2/transform.py @@ -0,0 +1,115 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
T = TypeVar("T")
R = TypeVar("R")


class Transformer(FilterRepr, ABC, Generic[T, R]):
    """Transforms a stream of observations lazily via iterator -> iterator.

    Pull from upstream, yield transformed observations. Naturally supports
    batching, windowing, fan-out. The generator cleans
    up when upstream exhausts.
    """

    @abstractmethod
    def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: ...

    def __str__(self) -> str:
        def render(param: str) -> str | None:
            # Constructor args are conventionally stored as `param` or `_param`;
            # show the first attribute that exists, callables by name only.
            for attr in (param, f"_{param}"):
                if hasattr(self, attr):
                    val = getattr(self, attr)
                    if callable(val):
                        return f"{param}={getattr(val, '__name__', '...')}"
                    return f"{param}={val!r}"
            return None

        params = inspect.signature(self.__init__).parameters  # type: ignore[misc]
        shown = [piece for name in params if (piece := render(name)) is not None]
        return f"{self.__class__.__name__}({', '.join(shown)})"


class FnTransformer(Transformer[T, R]):
    """Wraps a callable that receives an Observation and returns a new one (or None to skip)."""

    def __init__(self, fn: Callable[[Observation[T]], Observation[R] | None]) -> None:
        self._fn = fn

    def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]:
        apply = self._fn
        for item in upstream:
            mapped = apply(item)
            # None is the skip sentinel — the observation is dropped.
            if mapped is not None:
                yield mapped
class FnIterTransformer(Transformer[T, R]):
    """Wraps a bare ``Iterator → Iterator`` callable (e.g. a generator function)."""

    def __init__(self, fn: Callable[[Iterator[Observation[T]]], Iterator[Observation[R]]]) -> None:
        self._fn = fn

    def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]:
        return self._fn(upstream)


class QualityWindow(Transformer[T, T]):
    """Keeps the highest-quality item per time window.

    Emits the best observation when the window advances. The last window
    is emitted when the upstream iterator exhausts — no flush needed.
    """

    def __init__(self, quality_fn: Callable[[Any], float], window: float) -> None:
        self._quality_fn = quality_fn
        self._window = window

    def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]:
        quality_fn = self._quality_fn
        window = self._window
        best: Observation[T] | None = None
        best_score = float("-inf")
        window_start: float | None = None

        for obs in upstream:
            # Window rolled over: emit the current winner and re-anchor the
            # window at this observation's timestamp.
            if window_start is not None and (obs.ts - window_start) >= window:
                if best is not None:
                    yield best
                best = None
                best_score = float("-inf")
                window_start = obs.ts

            score = quality_fn(obs.data)
            # Fix: `best is None or ...` instead of comparing against a -1.0
            # sentinel — with the sentinel, any observation scoring <= -1.0
            # could never become a window's best, silently dropping entire
            # windows whenever quality_fn returns low/negative scores.
            if best is None or score > best_score:
                best = obs
                best_score = score
            if window_start is None:
                window_start = obs.ts

        # Final (possibly partial) window.
        if best is not None:
            yield best
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, fields +from itertools import islice +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type.observation import Observation + from dimos.models.embedding.base import Embedding + + +@dataclass(frozen=True) +class Filter(ABC): + """Any object with a .matches(obs) -> bool method can be a filter.""" + + @abstractmethod + def matches(self, obs: Observation[Any]) -> bool: ... + + def __str__(self) -> str: + args = ", ".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self)) + return f"{self.__class__.__name__}({args})" + + +@dataclass(frozen=True) +class AfterFilter(Filter): + t: float + + def matches(self, obs: Observation[Any]) -> bool: + return obs.ts > self.t + + +@dataclass(frozen=True) +class BeforeFilter(Filter): + t: float + + def matches(self, obs: Observation[Any]) -> bool: + return obs.ts < self.t + + +@dataclass(frozen=True) +class TimeRangeFilter(Filter): + t1: float + t2: float + + def matches(self, obs: Observation[Any]) -> bool: + return self.t1 <= obs.ts <= self.t2 + + +@dataclass(frozen=True) +class AtFilter(Filter): + t: float + tolerance: float = 1.0 + + def matches(self, obs: Observation[Any]) -> bool: + return abs(obs.ts - self.t) <= self.tolerance + + +@dataclass(frozen=True) +class NearFilter(Filter): + pose: Any = field(hash=False) + radius: float = 0.0 + + def matches(self, obs: Observation[Any]) -> bool: + if obs.pose is None or self.pose is None: + return False + p1 = self.pose + p2 = obs.pose + # Support both raw (x,y,z) tuples and PoseStamped objects + if hasattr(p1, "position"): + p1 = p1.position + if hasattr(p2, "position"): + p2 = p2.position + x1, y1, z1 = _xyz(p1) + x2, 
y2, z2 = _xyz(p2) + dist_sq = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 + return dist_sq <= self.radius**2 + + +def _xyz(p: Any) -> tuple[float, float, float]: + """Extract (x, y, z) from various pose representations.""" + if isinstance(p, (list, tuple)): + return (float(p[0]), float(p[1]), float(p[2]) if len(p) > 2 else 0.0) + return (float(p.x), float(p.y), float(getattr(p, "z", 0.0))) + + +@dataclass(frozen=True) +class TagsFilter(Filter): + tags: dict[str, Any] = field(default_factory=dict, hash=False) + + def matches(self, obs: Observation[Any]) -> bool: + for k, v in self.tags.items(): + if obs.tags.get(k) != v: + return False + return True + + +@dataclass(frozen=True) +class PredicateFilter(Filter): + """Wraps an arbitrary predicate function for use with .filter().""" + + fn: Callable[[Observation[Any]], bool] = field(hash=False) + + def matches(self, obs: Observation[Any]) -> bool: + return bool(self.fn(obs)) + + +@dataclass(frozen=True) +class StreamQuery: + filters: tuple[Filter, ...] 
= () + order_field: str | None = None + order_desc: bool = False + limit_val: int | None = None + offset_val: int | None = None + live_buffer: BackpressureBuffer[Any] | None = None + # Vector search (embedding similarity) + search_vec: Embedding | None = field(default=None, hash=False, compare=False) + search_k: int | None = None + # Full-text search (substring / FTS5) + search_text: str | None = None + + def __str__(self) -> str: + parts: list[str] = [str(f) for f in self.filters] + if self.search_text is not None: + parts.append(f"search({self.search_text!r})") + if self.search_vec is not None: + k = f", k={self.search_k}" if self.search_k is not None else "" + parts.append(f"vector_search({k.lstrip(', ')})" if k else "vector_search()") + if self.order_field: + direction = " DESC" if self.order_desc else "" + parts.append(f"order_by({self.order_field}{direction})") + if self.offset_val: + parts.append(f"offset({self.offset_val})") + if self.limit_val is not None: + parts.append(f"limit({self.limit_val})") + return " | ".join(parts) + + def apply( + self, it: Iterator[Observation[Any]], *, live: bool = False + ) -> Iterator[Observation[Any]]: + """Apply all query operations to an iterator in Python. + + Used as the fallback execution path for transform-sourced streams + and in-memory backends. Backends with native query support (SQL, + ANN indexes) should push down operations instead. + """ + # Filters + if self.filters: + it = (obs for obs in it if all(f.matches(obs) for f in self.filters)) + + # Text search — substring match + if self.search_text is not None: + needle = self.search_text.lower() + it = (obs for obs in it if needle in str(obs.data).lower()) + + # Vector search — brute-force cosine (materializes) + if self.search_vec is not None: + if live: + raise TypeError( + ".search() requires finite data — cannot rank an infinite live stream." 
+ ) + query_emb = self.search_vec + scored = [] + for obs in it: + emb = getattr(obs, "embedding", None) + if emb is not None: + sim = float(emb @ query_emb) + scored.append(obs.derive(data=obs.data, similarity=sim)) + scored.sort(key=lambda o: getattr(o, "similarity", 0.0) or 0.0, reverse=True) + if self.search_k is not None: + scored = scored[: self.search_k] + it = iter(scored) + + # Sort (materializes) + if self.order_field: + if live: + raise TypeError( + ".order_by() requires finite data — cannot sort an infinite live stream." + ) + key = self.order_field + desc = self.order_desc + items = sorted( + list(it), + key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, + reverse=desc, + ) + it = iter(items) + + # Offset + limit + if self.offset_val: + it = islice(it, self.offset_val, None) + if self.limit_val is not None: + it = islice(it, self.limit_val) + + return it diff --git a/dimos/memory2/type/observation.py b/dimos/memory2/type/observation.py new file mode 100644 index 0000000000..0a6dd16ea5 --- /dev/null +++ b/dimos/memory2/type/observation.py @@ -0,0 +1,112 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
T = TypeVar("T")


class _Unloaded:
    """Sentinel indicating data has not been loaded yet."""

    __slots__ = ()

    def __repr__(self) -> str:
        # Fix: return an identifiable token instead of the empty string,
        # so the sentinel is visible in debugger/logging output.
        return "<unloaded>"


_UNLOADED = _Unloaded()


@dataclass
class Observation(Generic[T]):
    """A single timestamped observation with optional spatial pose and metadata.

    ``data`` is either eager (``_data`` set directly) or produced lazily by
    ``_loader`` on first access; the loaded value is cached and the loader
    closure released.
    """

    id: int
    ts: float
    pose: Any | None = None
    tags: dict[str, Any] = field(default_factory=dict)
    _data: T | _Unloaded = field(default=_UNLOADED, repr=False)
    # Fix: exclude loader and lock from __eq__ — closures and locks are
    # unique per instance, so including them made every pair of otherwise
    # identical observations compare unequal.
    _loader: Callable[[], T] | None = field(default=None, repr=False, compare=False)
    _data_lock: threading.Lock = field(
        default_factory=threading.Lock, repr=False, compare=False
    )

    @property
    def data(self) -> T:
        """Return the payload, loading it on first access (thread-safe)."""
        val = self._data
        if isinstance(val, _Unloaded):
            with self._data_lock:
                # Re-check after acquiring lock (double-checked locking)
                val = self._data
                if isinstance(val, _Unloaded):
                    if self._loader is None:
                        raise LookupError("No data and no loader set on this observation")
                    loaded = self._loader()
                    self._data = loaded
                    self._loader = None  # release closure
                    return loaded
                return val  # type: ignore[return-value]
        return val

    def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]:
        """Create a new observation preserving ts/pose/tags, replacing data.

        If ``embedding`` is passed, promotes the result to
        :class:`EmbeddedObservation`.
        """
        if "embedding" in overrides:
            return EmbeddedObservation(
                id=self.id,
                ts=overrides.get("ts", self.ts),
                pose=overrides.get("pose", self.pose),
                tags=overrides.get("tags", self.tags),
                _data=data,
                embedding=overrides["embedding"],
                similarity=overrides.get("similarity"),
            )
        return Observation(
            id=self.id,
            ts=overrides.get("ts", self.ts),
            pose=overrides.get("pose", self.pose),
            tags=overrides.get("tags", self.tags),
            _data=data,
        )


@dataclass
class EmbeddedObservation(Observation[T]):
    """Observation enriched with a vector embedding and optional similarity score."""

    embedding: Embedding | None = None
    similarity: float | None = None

    def derive(self, *, data: Any, **overrides: Any) -> EmbeddedObservation[Any]:
        """Preserve embedding unless explicitly replaced."""
        return EmbeddedObservation(
            id=self.id,
            ts=overrides.get("ts", self.ts),
            pose=overrides.get("pose", self.pose),
            tags=overrides.get("tags", self.tags),
            _data=data,
            embedding=overrides.get("embedding", self.embedding),
            similarity=overrides.get("similarity", self.similarity),
        )
Other modules import the mixin and +``render_text`` — nothing else needs to touch ``rich`` directly. +""" + +from __future__ import annotations + +from rich.console import Console +from rich.text import Text + +_console = Console(force_terminal=True, highlight=False) + + +def render_text(text: Text) -> str: + """Render rich Text to a terminal string with ANSI codes.""" + with _console.capture() as cap: + _console.print(text, end="", soft_wrap=True) + return cap.get() + + +def _colorize(plain: str) -> Text: + """Turn ``'name(args)'``, ``'a | b'``, or ``'a -> b'`` into rich Text with cyan names.""" + t = Text() + pipe = Text(" | ", style="dim") + arrow = Text(" -> ", style="dim") + for i, seg in enumerate(plain.split(" | ")): + if i > 0: + t.append_text(pipe) + for j, part in enumerate(seg.split(" -> ")): + if j > 0: + t.append_text(arrow) + name, _, rest = part.partition("(") + t.append(name, style="cyan") + if rest: + t.append(f"({rest}") + return t + + +class FilterRepr: + """Mixin for filters: subclass defines ``__str__``, gets colored ``__repr__`` free.""" + + def __repr__(self) -> str: + return render_text(_colorize(str(self))) diff --git a/dimos/memory2/utils/sqlite.py b/dimos/memory2/utils/sqlite.py new file mode 100644 index 0000000000..e242a6e1f5 --- /dev/null +++ b/dimos/memory2/utils/sqlite.py @@ -0,0 +1,43 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import sqlite3 + +from reactivex.disposable import Disposable + + +def open_sqlite_connection(path: str) -> sqlite3.Connection: + """Open a WAL-mode SQLite connection with sqlite-vec loaded.""" + import sqlite_vec + + conn = sqlite3.connect(path, check_same_thread=False) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + return conn + + +def open_disposable_sqlite_connection( + path: str, +) -> tuple[Disposable, sqlite3.Connection]: + """Open a WAL-mode SQLite connection and return (disposable, connection). + + The disposable closes the connection when disposed. + """ + conn = open_sqlite_connection(path) + return Disposable(lambda: conn.close()), conn diff --git a/dimos/memory2/utils/validation.py b/dimos/memory2/utils/validation.py new file mode 100644 index 0000000000..636ff59327 --- /dev/null +++ b/dimos/memory2/utils/validation.py @@ -0,0 +1,25 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import re + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def validate_identifier(name: str) -> None: + """Reject stream names that aren't safe SQL identifiers.""" + if not _IDENT_RE.match(name): + raise ValueError(f"Invalid stream name: {name!r}") diff --git a/dimos/memory2/vectorstore/base.py b/dimos/memory2/vectorstore/base.py new file mode 100644 index 0000000000..2b26520fd6 --- /dev/null +++ b/dimos/memory2/vectorstore/base.py @@ -0,0 +1,65 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any + +from dimos.core.resource import CompositeResource +from dimos.memory2.registry import qual +from dimos.protocol.service.spec import BaseConfig, Configurable + +if TYPE_CHECKING: + from dimos.models.embedding.base import Embedding + + +class VectorStoreConfig(BaseConfig): + pass + + +class VectorStore(Configurable[VectorStoreConfig], CompositeResource): + """Pluggable storage and ANN index for embedding vectors. + + Separates vector indexing from metadata so backends can swap + search strategies (brute-force, vec0, FAISS, Qdrant) independently. + + Same shape as BlobStore: ``put`` / ``search`` / ``delete``, keyed + by ``(stream, observation_id)``. Vector index creation is lazy — the + first ``put`` for a stream determines dimensionality. 
+ """ + + default_config: type[VectorStoreConfig] = VectorStoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + + @abstractmethod + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: + """Store an embedding vector for the given stream and observation id.""" + ... + + @abstractmethod + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + """Return top-k (observation_id, similarity) pairs, descending.""" + ... + + @abstractmethod + def delete(self, stream_name: str, key: int) -> None: + """Remove a vector. Silent if missing.""" + ... + + def serialize(self) -> dict[str, Any]: + return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/vectorstore/memory.py b/dimos/memory2/vectorstore/memory.py new file mode 100644 index 0000000000..a34ce29108 --- /dev/null +++ b/dimos/memory2/vectorstore/memory.py @@ -0,0 +1,61 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from dimos.memory2.vectorstore.base import VectorStore, VectorStoreConfig + +if TYPE_CHECKING: + from dimos.models.embedding.base import Embedding + + +class MemoryVectorStoreConfig(VectorStoreConfig): + pass + + +class MemoryVectorStore(VectorStore): + """In-memory brute-force vector store for testing. 
+ + Stores embeddings in a dict keyed by ``(stream, observation_id)``. + Search computes cosine similarity against all vectors in the stream. + """ + + default_config: type[MemoryVectorStoreConfig] = MemoryVectorStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._vectors: dict[str, dict[int, Embedding]] = {} + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: + self._vectors.setdefault(stream_name, {})[key] = embedding + + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + vectors = self._vectors.get(stream_name, {}) + if not vectors: + return [] + scored = [(key, float(emb @ query)) for key, emb in vectors.items()] + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:k] + + def delete(self, stream_name: str, key: int) -> None: + vectors = self._vectors.get(stream_name, {}) + vectors.pop(key, None) diff --git a/dimos/memory2/vectorstore/sqlite.py b/dimos/memory2/vectorstore/sqlite.py new file mode 100644 index 0000000000..fb4613825b --- /dev/null +++ b/dimos/memory2/vectorstore/sqlite.py @@ -0,0 +1,103 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +import sqlite3 +from typing import TYPE_CHECKING, Any + +from pydantic import Field, model_validator + +from dimos.memory2.utils.sqlite import open_disposable_sqlite_connection +from dimos.memory2.utils.validation import validate_identifier +from dimos.memory2.vectorstore.base import VectorStore, VectorStoreConfig + +if TYPE_CHECKING: + from dimos.models.embedding.base import Embedding + + +class SqliteVectorStoreConfig(VectorStoreConfig): + conn: sqlite3.Connection | None = Field(default=None, exclude=True) + path: str | None = None + + @model_validator(mode="after") + def _conn_xor_path(self) -> SqliteVectorStoreConfig: + if self.conn is not None and self.path is not None: + raise ValueError("Specify either conn or path, not both") + if self.conn is None and self.path is None: + raise ValueError("Specify either conn or path") + return self + + +class SqliteVectorStore(VectorStore): + """Vector store backed by sqlite-vec's vec0 virtual tables. + + Creates one virtual table per stream: ``"{stream}_vec"``. + Dimensionality is determined lazily on the first ``put()``. + + Supports two construction modes: + + - ``SqliteVectorStore(conn=conn)`` — borrows an externally-managed connection. + - ``SqliteVectorStore(path="file.db")`` — opens and owns its own connection. 
+ """ + + default_config = SqliteVectorStoreConfig + config: SqliteVectorStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._conn: sqlite3.Connection = self.config.conn # type: ignore[assignment] # set in start() if None + self._path = self.config.path + self._tables: dict[str, int] = {} # stream_name -> dimensionality + + def _ensure_table(self, stream_name: str, dim: int) -> None: + if stream_name in self._tables: + return + validate_identifier(stream_name) + self._conn.execute( + f'CREATE VIRTUAL TABLE IF NOT EXISTS "{stream_name}_vec" ' + f"USING vec0(embedding float[{dim}] distance_metric=cosine)" + ) + self._tables[stream_name] = dim + + def start(self) -> None: + if self._conn is None: + assert self._path is not None + disposable, self._conn = open_disposable_sqlite_connection(self._path) + self.register_disposables(disposable) + + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: + vec = embedding.to_numpy().tolist() + self._ensure_table(stream_name, len(vec)) + self._conn.execute( + f'INSERT OR REPLACE INTO "{stream_name}_vec" (rowid, embedding) VALUES (?, ?)', + (key, json.dumps(vec)), + ) + + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + if stream_name not in self._tables: + return [] + vec = query.to_numpy().tolist() + rows = self._conn.execute( + f'SELECT rowid, distance FROM "{stream_name}_vec" WHERE embedding MATCH ? 
AND k = ?',
+            (json.dumps(vec), k),
+        ).fetchall()
+        # vec0 cosine distance = 1 - cosine_similarity
+        return [(int(row[0]), max(0.0, 1.0 - row[1])) for row in rows]
+
+    def delete(self, stream_name: str, key: int) -> None:
+        if stream_name not in self._tables:
+            return
+        self._conn.execute(f'DELETE FROM "{stream_name}_vec" WHERE rowid = ?', (key,))
diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py
index 6fb42b7ccf..10e44f1cc5 100644
--- a/dimos/models/embedding/clip.py
+++ b/dimos/models/embedding/clip.py
@@ -75,9 +75,9 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]:
         Returns embeddings as torch.Tensor on device for efficient GPU comparisons.
         """
         with torch.inference_mode():
-            inputs = self._processor(text=list(texts), return_tensors="pt", padding=True).to(
-                self.config.device
-            )
+            inputs = self._processor(
+                text=list(texts), return_tensors="pt", padding=True, truncation=True
+            ).to(self.config.device)
             text_features = self._model.get_text_features(**inputs)
 
         if self.config.normalize:
diff --git a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py
index b68441328a..6fa7ba3d12 100644
--- a/dimos/models/vl/florence.py
+++ b/dimos/models/vl/florence.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from enum import Enum
 from functools import cached_property
 
 from PIL import Image as PILImage
@@ -23,6 +24,14 @@
 from dimos.msgs.sensor_msgs.Image import Image
 
 
+class CaptionDetail(Enum):
+    """Florence-2 caption detail level."""
+
+    BRIEF = "<CAPTION>"
+    NORMAL = "<DETAILED_CAPTION>"
+    DETAILED = "<MORE_DETAILED_CAPTION>"
+
+
 class Florence2Model(HuggingFaceModel, Captioner):
     """Florence-2 captioning model from Microsoft.
 
@@ -35,6 +44,7 @@ class Florence2Model(HuggingFaceModel, Captioner):
     def __init__(
         self,
         model_name: str = "microsoft/Florence-2-base",
+        detail: CaptionDetail = CaptionDetail.NORMAL,
         **kwargs: object,
     ) -> None:
         """Initialize Florence-2 model.
@@ -43,9 +53,11 @@ def __init__( model_name: HuggingFace model name. Options: - "microsoft/Florence-2-base" (~0.2B, fastest) - "microsoft/Florence-2-large" (~0.8B, better quality) + detail: Caption detail level **kwargs: Additional config options (device, dtype, warmup, etc.) """ super().__init__(model_name=model_name, **kwargs) + self._task_prompt = detail.value @cached_property def _processor(self) -> AutoProcessor: @@ -53,27 +65,22 @@ def _processor(self) -> AutoProcessor: self.config.model_name, trust_remote_code=self.config.trust_remote_code ) - def caption(self, image: Image, detail: str = "normal") -> str: - """Generate a caption for the image. + _STRIP_PREFIXES = ("The image shows ", "The image is a ", "A ") - Args: - image: Input image to caption - detail: Level of detail for caption: - - "brief": Short, concise caption - - "normal": Standard caption (default) - - "detailed": More detailed description + @staticmethod + def _clean_caption(text: str) -> str: + for prefix in Florence2Model._STRIP_PREFIXES: + if text.startswith(prefix): + return text[len(prefix) :] + return text + + def caption(self, image: Image) -> str: + """Generate a caption for the image. Returns: Text description of the image """ - # Map detail level to Florence-2 task prompts - task_prompts = { - "brief": "", - "normal": "", - "detailed": "", - "more_detailed": "", - } - task_prompt = task_prompts.get(detail, "") + task_prompt = self._task_prompt # Convert to PIL pil_image = PILImage.fromarray(image.to_rgb().data) @@ -101,21 +108,18 @@ def caption(self, image: Image, detail: str = "normal") -> str: # Extract caption from parsed output caption: str = parsed.get(task_prompt, generated_text) - return caption.strip() + return self._clean_caption(caption.strip()) def caption_batch(self, *images: Image) -> list[str]: """Generate captions for multiple images efficiently. 
- Args: - images: Input images to caption - Returns: List of text descriptions """ if not images: return [] - task_prompt = "" + task_prompt = self._task_prompt # Convert all to PIL pil_images = [PILImage.fromarray(img.to_rgb().data) for img in images] @@ -136,7 +140,7 @@ def caption_batch(self, *images: Image) -> list[str]: ) # Decode all - generated_texts = self._processor.batch_decode(generated_ids, skip_special_tokens=False) + generated_texts = self._processor.batch_decode(generated_ids, skip_special_tokens=True) # Parse outputs captions = [] @@ -144,7 +148,7 @@ def caption_batch(self, *images: Image) -> list[str]: parsed = self._processor.post_process_generation( text, task=task_prompt, image_size=pil_img.size ) - captions.append(parsed.get(task_prompt, text).strip()) + captions.append(self._clean_caption(parsed.get(task_prompt, text).strip())) return captions diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 66c2876b62..8aee99435d 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -27,7 +27,6 @@ import reactivex as rx from reactivex import operators as ops import rerun as rr -from turbojpeg import TurboJPEG # type: ignore[import-untyped] from dimos.types.timestamped import Timestamped, TimestampedBufferCollection, to_human_readable from dimos.utils.reactive import quality_barrier @@ -377,15 +376,21 @@ def crop(self, x: int, y: int, width: int, height: int) -> Image: @property def sharpness(self) -> float: - """Return sharpness score.""" - gray = self.to_grayscale() - sx = cv2.Sobel(gray.data, cv2.CV_32F, 1, 0, ksize=5) - sy = cv2.Sobel(gray.data, cv2.CV_32F, 0, 1, ksize=5) - magnitude = cv2.magnitude(sx, sy) - mean_mag = float(magnitude.mean()) - if mean_mag <= 0: + """Return sharpness score. + + Downsamples to ~160px wide before computing Laplacian variance + for fast evaluation (~10-20x cheaper than full-res Sobel). 
+ """ + gray = self.to_grayscale().data + # Downsample to ~160px wide for cheap evaluation + h, w = gray.shape[:2] + if w > 160: + scale = 160.0 / w + gray = cv2.resize(gray, (160, int(h * scale)), interpolation=cv2.INTER_AREA) + lap_var = cv2.Laplacian(gray, cv2.CV_32F).var() + if lap_var <= 0: return 0.0 - return float(np.clip((np.log10(mean_mag + 1) - 1.7) / 2.0, 0.0, 1.0)) + return float(np.clip((np.log10(lap_var + 1) - 1.0) / 3.0, 0.0, 1.0)) def save(self, filepath: str) -> bool: arr = self.to_opencv() @@ -504,6 +509,8 @@ def lcm_jpeg_encode(self, quality: int = 75, frame_id: str | None = None) -> byt Returns: LCM-encoded bytes with JPEG-compressed image data """ + from turbojpeg import TurboJPEG # type: ignore[import-untyped] + jpeg = TurboJPEG() msg = LCMImage() @@ -549,6 +556,8 @@ def lcm_jpeg_decode(cls, data: bytes, **kwargs: Any) -> Image: Returns: Image instance """ + from turbojpeg import TurboJPEG # type: ignore[import-untyped] + jpeg = TurboJPEG() msg = LCMImage.lcm_decode(data) diff --git a/dimos/perception/detection/type/detection3d/test_pointcloud.py b/dimos/perception/detection/type/detection3d/test_pointcloud.py index ad1c5cdf1b..2a6d7578e0 100644 --- a/dimos/perception/detection/type/detection3d/test_pointcloud.py +++ b/dimos/perception/detection/type/detection3d/test_pointcloud.py @@ -46,14 +46,14 @@ def test_detection3dpc(detection3dpc) -> None: assert aabb is not None, "Axis-aligned bounding box should not be None" # Verify AABB min values - assert aabb.min_bound[0] == pytest.approx(-3.575, abs=0.1) - assert aabb.min_bound[1] == pytest.approx(-0.375, abs=0.1) - assert aabb.min_bound[2] == pytest.approx(-0.075, abs=0.1) + assert aabb.min_bound[0] == pytest.approx(-3.575, abs=0.2) + assert aabb.min_bound[1] == pytest.approx(-0.375, abs=0.2) + assert aabb.min_bound[2] == pytest.approx(-0.075, abs=0.2) # Verify AABB max values - assert aabb.max_bound[0] == pytest.approx(-3.075, abs=0.1) - assert aabb.max_bound[1] == pytest.approx(-0.125, 
abs=0.1) - assert aabb.max_bound[2] == pytest.approx(0.475, abs=0.1) + assert aabb.max_bound[0] == pytest.approx(-3.075, abs=0.2) + assert aabb.max_bound[1] == pytest.approx(-0.125, abs=0.2) + assert aabb.max_bound[2] == pytest.approx(0.475, abs=0.2) # def test_point_cloud_properties(detection3dpc): """Test point cloud data and boundaries.""" @@ -68,13 +68,13 @@ def test_detection3dpc(detection3dpc) -> None: center = np.mean(points, axis=0) # Verify point cloud boundaries - assert min_pt[0] == pytest.approx(-3.575, abs=0.1) - assert min_pt[1] == pytest.approx(-0.375, abs=0.1) - assert min_pt[2] == pytest.approx(-0.075, abs=0.1) + assert min_pt[0] == pytest.approx(-3.575, abs=0.2) + assert min_pt[1] == pytest.approx(-0.375, abs=0.2) + assert min_pt[2] == pytest.approx(-0.075, abs=0.2) - assert max_pt[0] == pytest.approx(-3.075, abs=0.1) - assert max_pt[1] == pytest.approx(-0.125, abs=0.1) - assert max_pt[2] == pytest.approx(0.475, abs=0.1) + assert max_pt[0] == pytest.approx(-3.075, abs=0.2) + assert max_pt[1] == pytest.approx(-0.125, abs=0.2) + assert max_pt[2] == pytest.approx(0.475, abs=0.2) assert center[0] == pytest.approx(-3.326, abs=0.1) assert center[1] == pytest.approx(-0.202, abs=0.1) diff --git a/dimos/utils/docs/doclinks.py b/dimos/utils/docs/doclinks.py index 2cf5d1702f..4d2fb6dc1c 100644 --- a/dimos/utils/docs/doclinks.py +++ b/dimos/utils/docs/doclinks.py @@ -360,12 +360,13 @@ def replace_code_match(match: re.Match[str]) -> str: resolved_path = resolve_candidates(candidates, file_ref) if resolved_path is None: + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"'{file_ref}' matches multiple files: {[str(c) for c in candidates]}" + f"'{file_ref}' in {doc_rel} matches multiple files: {[str(c) for c in candidates]}" ) else: - errors.append(f"No file matching '{file_ref}' found in codebase") + errors.append(f"No file matching '{file_ref}' found in codebase (in {doc_rel})") return 
full_match # Determine line fragment @@ -438,12 +439,13 @@ def replace_link_match(match: re.Match[str]) -> str: if result != full_match: changes.append(f" {link_text}: .md -> {new_link}") return result + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"'{link_text}' matches multiple docs: {[str(c) for c in candidates]}" + f"'{link_text}' in {doc_rel} matches multiple docs: {[str(c) for c in candidates]}" ) else: - errors.append(f"No doc matching '{link_text}' found") + errors.append(f"No doc matching '{link_text}' found (in {doc_rel})") return full_match # Absolute path @@ -460,12 +462,13 @@ def replace_link_match(match: re.Match[str]) -> str: ) changes.append(f" {link_text}: {raw_link} -> {new_link} (fixed broken link)") return f"[{link_text}]({new_link})" + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"Broken link '{raw_link}': ambiguous, matches {[str(c) for c in candidates]}" + f"Broken link '{raw_link}' in {doc_rel}: ambiguous, matches {[str(c) for c in candidates]}" ) else: - errors.append(f"Broken link: '{raw_link}' does not exist") + errors.append(f"Broken link '{raw_link}' in {doc_rel}: does not exist") return full_match # Relative path — resolve from doc file's directory @@ -475,7 +478,8 @@ def replace_link_match(match: re.Match[str]) -> str: try: rel_to_root = resolved_abs.relative_to(root) except ValueError: - errors.append(f"Link '{raw_link}' resolves outside repo root") + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path + errors.append(f"Link '{raw_link}' in {doc_rel} resolves outside repo root") return full_match if resolved_abs.exists(): @@ -496,12 +500,13 @@ def replace_link_match(match: re.Match[str]) -> str: ) changes.append(f" {link_text}: {raw_link} -> {new_link} (found by search)") return f"[{link_text}]({new_link})" + doc_rel = 
doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"Broken link '{raw_link}': ambiguous, matches {[str(c) for c in candidates]}" + f"Broken link '{raw_link}' in {doc_rel}: ambiguous, matches {[str(c) for c in candidates]}" ) else: - errors.append(f"Broken link '{raw_link}': target not found") + errors.append(f"Broken link '{raw_link}' in {doc_rel}: target not found") return full_match # Split by ignore regions and only process non-ignored parts diff --git a/dimos/utils/threadpool.py b/dimos/utils/threadpool.py index a2adc90725..f2fd577d40 100644 --- a/dimos/utils/threadpool.py +++ b/dimos/utils/threadpool.py @@ -36,7 +36,7 @@ def get_max_workers() -> int: environment variable, defaulting to 4 times the CPU count. """ env_value = os.getenv("DIMOS_MAX_WORKERS", "") - return int(env_value) if env_value.strip() else multiprocessing.cpu_count() + return int(env_value) if env_value.strip() else min(8, multiprocessing.cpu_count()) # Create a ThreadPoolScheduler with a configurable number of workers. diff --git a/docker/python/Dockerfile b/docker/python/Dockerfile index 16b4db1807..d14b281603 100644 --- a/docker/python/Dockerfile +++ b/docker/python/Dockerfile @@ -31,7 +31,8 @@ RUN apt-get update && apt-get install -y \ qtbase5-dev-tools \ supervisor \ iproute2 # for LCM networking system config \ - liblcm-dev + liblcm-dev \ + libturbojpeg0-dev # Fix distutils-installed packages that block pip upgrades RUN apt-get purge -y python3-blinker python3-sympy python3-oauthlib || true diff --git a/docs/usage/transports/index.md b/docs/usage/transports/index.md index db931872bd..09ccb484ed 100644 --- a/docs/usage/transports/index.md +++ b/docs/usage/transports/index.md @@ -357,7 +357,7 @@ Received 2 messages: {'temperature': 23.0} ``` -See [`memory.py`](/dimos/protocol/pubsub/impl/memory.py) for the complete source. +See [`pubsub/impl/memory.py`](/dimos/protocol/pubsub/impl/memory.py) for the complete source. 
--- diff --git a/pyproject.toml b/pyproject.toml index 722e3b0485..52ed08505d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ dependencies = [ "annotation-protocol>=1.4.0", "lazy_loader", "plum-dispatch==2.5.7", - # Logging "structlog>=25.5.0,<26", "colorlog==6.9.0", @@ -86,6 +85,8 @@ dependencies = [ "toolz>=1.1.0", "protobuf>=6.33.5,<7", "psutil>=7.0.0", + "sqlite-vec>=0.1.6", + "lz4>=4.4.5", ] @@ -271,6 +272,7 @@ dev = [ "types-tensorflow>=2.18.0.20251008,<3", "types-tqdm>=4.67.0.20250809,<5", "types-psycopg2>=2.9.21.20251012", + "scipy-stubs>=1.15.0", "types-psutil>=7.2.2.20260130,<8", # Tools @@ -407,6 +409,7 @@ module = [ "rclpy.*", "sam2.*", "sensor_msgs.*", + "sqlite_vec", "std_msgs.*", "tf2_msgs.*", "torchreid", diff --git a/uv.lock b/uv.lock index 5ec39fff59..6e6ebc9810 100644 --- a/uv.lock +++ b/uv.lock @@ -1686,6 +1686,7 @@ dependencies = [ { name = "dimos-viewer" }, { name = "lazy-loader" }, { name = "llvmlite" }, + { name = "lz4" }, { name = "numba" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -1706,6 +1707,7 @@ dependencies = [ { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.17.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "sortedcontainers" }, + { name = "sqlite-vec" }, { name = "structlog" }, { name = "terminaltexteffects" }, { name = "textual" }, @@ -1791,6 +1793,8 @@ dds = [ { name = "python-lsp-server", extra = ["all"] }, { name = "requests-mock" }, { name = "ruff" }, + { name = "scipy-stubs", version = "1.15.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy-stubs", 
version = "1.17.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "terminaltexteffects" }, { name = "types-colorama" }, { name = "types-defusedxml" }, @@ -1828,6 +1832,8 @@ dev = [ { name = "python-lsp-server", extra = ["all"] }, { name = "requests-mock" }, { name = "ruff" }, + { name = "scipy-stubs", version = "1.15.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy-stubs", version = "1.17.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "terminaltexteffects" }, { name = "types-colorama" }, { name = "types-defusedxml" }, @@ -2021,6 +2027,7 @@ requires-dist = [ { name = "lcm", marker = "extra == 'docker'" }, { name = "llvmlite", specifier = ">=0.42.0" }, { name = "lxml-stubs", marker = "extra == 'dev'", specifier = ">=0.5.1,<1" }, + { name = "lz4", specifier = ">=4.4.5" }, { name = "matplotlib", marker = "extra == 'manipulation'", specifier = ">=3.7.1" }, { name = "md-babel-py", marker = "extra == 'dev'", specifier = "==1.1.1" }, { name = "moondream", marker = "extra == 'perception'" }, @@ -2090,11 +2097,13 @@ requires-dist = [ { name = "scikit-learn", marker = "extra == 'misc'" }, { name = "scipy", specifier = ">=1.15.1" }, { name = "scipy", marker = "extra == 'docker'", specifier = ">=1.15.1" }, + { name = "scipy-stubs", marker = "extra == 'dev'", specifier = ">=1.15.0" }, { name = "sentence-transformers", marker = "extra == 'misc'" }, { name = "sortedcontainers", specifier = "==2.4.0" }, { name = "sortedcontainers", marker = "extra == 'docker'" }, { name = "sounddevice", marker = "extra == 'agents'" }, { name = "soundfile", marker = "extra == 'web'" }, + { name = "sqlite-vec", specifier = ">=0.1.6" }, { name = "sse-starlette", marker = "extra == 'web'", specifier = ">=2.2.1" }, { name = "structlog", specifier = ">=25.5.0,<26" }, { name = "structlog", marker = "extra == 
'docker'", specifier = ">=25.5.0,<26" }, @@ -5556,6 +5565,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/ee/346fa473e666fe14c52fcdd19ec2424157290a032d4c41f98127bfb31ac7/numpy-2.3.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f16417ec91f12f814b10bafe79ef77e70113a2f5f7018640e7425ff979253425", size = 12967213, upload-time = "2025-11-16T22:52:39.38Z" }, ] +[[package]] +name = "numpy-typing-compat" +version = "20251206.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/83/dd90774d6685664cbe5525645a50c4e6c7454207aee552918790e879137f/numpy_typing_compat-20251206.2.3.tar.gz", hash = "sha256:18e00e0f4f2040fe98574890248848c7c6831a975562794da186cf4f3c90b935", size = 5009, upload-time = "2025-12-06T20:02:04.177Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/6f/dde8e2a79a3b6cbc31bc1037c1a1dbc07c90d52d946851bd7cba67e730a8/numpy_typing_compat-20251206.2.3-py3-none-any.whl", hash = "sha256:bfa2e4c4945413e84552cbd34a6d368c88a06a54a896e77ced760521b08f0f61", size = 6300, upload-time = "2025-12-06T20:01:56.664Z" }, +] + [[package]] name = "nvidia-cublas-cu12" version = "12.8.4.1" @@ -6219,6 +6240,60 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/ec/19c6cc6064c7fc8f0cd6d5b37c4747849e66040c6ca98f86565efc2c227c/optax-0.2.6-py3-none-any.whl", hash = "sha256:f875251a5ab20f179d4be57478354e8e21963373b10f9c3b762b94dcb8c36d91", size = 367782, upload-time = "2025-09-15T22:41:22.825Z" }, ] +[[package]] +name = "optype" +version = "0.9.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.11' and 
sys_platform == 'win32'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/3c/9d59b0167458b839273ad0c4fc5f62f787058d8f5aed7f71294963a99471/optype-0.9.3.tar.gz", hash = "sha256:5f09d74127d316053b26971ce441a4df01f3a01943601d3712dd6f34cdfbaf48", size = 96143, upload-time = "2025-03-31T17:00:08.392Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/d8/ac50e2982bdc2d3595dc2bfe3c7e5a0574b5e407ad82d70b5f3707009671/optype-0.9.3-py3-none-any.whl", hash = "sha256:2935c033265938d66cc4198b0aca865572e635094e60e6e79522852f029d9e8d", size = 84357, upload-time = "2025-03-31T17:00:06.464Z" }, +] + +[[package]] +name = "optype" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "(python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or 
(python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "typing-extensions", marker = "python_full_version >= '3.11' and python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/d3/c88bb4bd90867356275ca839499313851af4b36fce6919ebc5e1de26e7ca/optype-0.16.0.tar.gz", hash = "sha256:fa682fd629ef6b70ba656ebc9fdd6614ba06ce13f52e0416dd8014c7e691a2d1", size = 53498, upload-time = "2026-02-19T23:37:09.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/a8/fe26515203cff140f1afc31236fb7f703d4bb4bd5679d28afcb3661c8d9f/optype-0.16.0-py3-none-any.whl", hash = "sha256:c28905713f55630b4bb8948f38e027ad13a541499ebcf957501f486da54b74d2", size = 65893, upload-time = "2026-02-19T23:37:08.217Z" }, +] + +[package.optional-dependencies] +numpy = [ + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy-typing-compat", marker = "python_full_version >= '3.11'" }, +] + [[package]] name = "orbax-checkpoint" version = "0.11.32" @@ -8920,6 +8995,54 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/56/a5/df8f46ef7da168f1bc52cd86e09a9de5c6f19cc1da04454d51b7d4f43408/scipy-1.17.0-cp314-cp314t-win_arm64.whl", hash = "sha256:031121914e295d9791319a1875444d55079885bbae5bdc9c5e0f2ee5f09d34ff", size = 25246266, upload-time = "2026-01-10T21:30:45.923Z" }, ] +[[package]] +name = "scipy-stubs" +version = "1.15.3.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'win32'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "optype", version = "0.9.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/5f/35c43bd7d412add4adcd68475702571b2489b50c40b6564f808b2355e452/scipy_stubs-1.15.3.0.tar.gz", hash = "sha256:e8f76c9887461cf9424c1e2ad78ea5dac71dd4cbb383dc85f91adfe8f74d1e17", size = 275699, upload-time = "2025-05-08T16:58:35.139Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/42/cd8dc81f8060de1f14960885ad5b2d2651f41de8b93d09f3f919d6567a5a/scipy_stubs-1.15.3.0-py3-none-any.whl", hash = "sha256:a251254cf4fd6e7fb87c55c1feee92d32ddbc1f542ecdf6a0159cdb81c2fb62d", size = 459062, upload-time = "2025-05-08T16:58:33.356Z" }, +] + +[[package]] +name = "scipy-stubs" +version = "1.17.1.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and 
platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "(python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "optype", version = "0.16.0", source = { registry = "https://pypi.org/simple" }, extra = ["numpy"], marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/ad/413b0d18efca7bb48574d28e91253409d91ee6121e7937022d0d380dfc6a/scipy_stubs-1.17.1.0.tar.gz", hash = "sha256:5dc51c21765b145c2d132b96b63ff4f835dd5fb768006876d1554e7a59c61571", size = 381420, 
upload-time = "2026-02-23T10:33:04.742Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/ee/c6811e04ff9d5dd1d92236e8df7ebc4db6aa65c70b9938cec293348b8ec4/scipy_stubs-1.17.1.0-py3-none-any.whl", hash = "sha256:5c9c84993d36b104acb2d187b05985eb79f73491c60d83292dd738093d53d96a", size = 587059, upload-time = "2026-02-23T10:33:02.845Z" }, +] + [[package]] name = "sentence-transformers" version = "5.2.2" @@ -9114,6 +9237,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, ] +[[package]] +name = "sqlite-vec" +version = "0.1.6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ed/aabc328f29ee6814033d008ec43e44f2c595447d9cccd5f2aabe60df2933/sqlite_vec-0.1.6-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:77491bcaa6d496f2acb5cc0d0ff0b8964434f141523c121e313f9a7d8088dee3", size = 164075, upload-time = "2024-11-20T16:40:29.847Z" }, + { url = "https://files.pythonhosted.org/packages/a7/57/05604e509a129b22e303758bfa062c19afb020557d5e19b008c64016704e/sqlite_vec-0.1.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fdca35f7ee3243668a055255d4dee4dea7eed5a06da8cad409f89facf4595361", size = 165242, upload-time = "2024-11-20T16:40:31.206Z" }, + { url = "https://files.pythonhosted.org/packages/f2/48/dbb2cc4e5bad88c89c7bb296e2d0a8df58aab9edc75853728c361eefc24f/sqlite_vec-0.1.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b0519d9cd96164cd2e08e8eed225197f9cd2f0be82cb04567692a0a4be02da3", size = 103704, upload-time = "2024-11-20T16:40:33.729Z" }, + { url = 
"https://files.pythonhosted.org/packages/80/76/97f33b1a2446f6ae55e59b33869bed4eafaf59b7f4c662c8d9491b6a714a/sqlite_vec-0.1.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:823b0493add80d7fe82ab0fe25df7c0703f4752941aee1c7b2b02cec9656cb24", size = 151556, upload-time = "2024-11-20T16:40:35.387Z" }, + { url = "https://files.pythonhosted.org/packages/6a/98/e8bc58b178266eae2fcf4c9c7a8303a8d41164d781b32d71097924a6bebe/sqlite_vec-0.1.6-py3-none-win_amd64.whl", hash = "sha256:c65bcfd90fa2f41f9000052bcb8bb75d38240b2dae49225389eca6c3136d3f0c", size = 281540, upload-time = "2024-11-20T16:40:37.296Z" }, +] + [[package]] name = "sse-starlette" version = "3.2.0"