diff --git a/src/intelstream/database/repository.py b/src/intelstream/database/repository.py index 5068a24..a27c07d 100644 --- a/src/intelstream/database/repository.py +++ b/src/intelstream/database/repository.py @@ -938,6 +938,23 @@ async def add_message_chunk_metas_batch(self, chunks: list[MessageChunkMeta]) -> session.add_all(chunks) await session.commit() + async def count_message_chunk_metas(self) -> int: + async with self.session() as session: + result = await session.execute(select(func.count()).select_from(MessageChunkMeta)) + return int(result.scalar_one()) + + async def get_message_chunk_metas_batch( + self, offset: int = 0, limit: int = 100 + ) -> list[MessageChunkMeta]: + async with self.session() as session: + result = await session.execute( + select(MessageChunkMeta) + .order_by(MessageChunkMeta.start_timestamp.asc(), MessageChunkMeta.id.asc()) + .offset(offset) + .limit(limit) + ) + return list(result.scalars().all()) + async def get_message_chunk_metas_by_ids(self, chunk_ids: list[str]) -> list[MessageChunkMeta]: if not chunk_ids: return [] diff --git a/src/intelstream/database/vector_store.py b/src/intelstream/database/vector_store.py index 7a1d4f8..aa51986 100644 --- a/src/intelstream/database/vector_store.py +++ b/src/intelstream/database/vector_store.py @@ -2,7 +2,9 @@ import asyncio import os +import shutil from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any import structlog @@ -26,6 +28,9 @@ class ChunkSearchResult: class VectorStore: + _ARTICLES_COLLECTION = "articles" + _MESSAGE_CHUNKS_COLLECTION = "message_chunks" + def __init__(self, data_dir: str, dimensions: int = 384) -> None: self._data_dir = data_dir self._dimensions = dimensions @@ -33,40 +38,79 @@ def __init__(self, data_dir: str, dimensions: int = 384) -> None: self._message_chunks: zvec.Collection | None = None async def initialize(self) -> None: + await asyncio.to_thread(os.makedirs, self._data_dir, exist_ok=True) + self._articles = await self._open_or_create_collection(self._ARTICLES_COLLECTION) + self._message_chunks = await self._open_or_create_collection( + self._MESSAGE_CHUNKS_COLLECTION + ) + + def _collection_path(self, collection_name: str) -> str: + return str(Path(self._data_dir) / collection_name) + + def _collection_attr_name(self, collection_name: str) -> str: + if collection_name == self._ARTICLES_COLLECTION: + return "_articles" + if collection_name == self._MESSAGE_CHUNKS_COLLECTION: + return "_message_chunks" + raise ValueError(f"Unknown collection name: {collection_name}") + + def _build_schema(self, collection_name: str) -> zvec.CollectionSchema: import zvec - await asyncio.to_thread(os.makedirs, self._data_dir, exist_ok=True) - articles_path = f"{self._data_dir}/articles" - try: - schema = zvec.CollectionSchema( - name="articles", - vectors=zvec.VectorSchema("embedding", zvec.DataType.VECTOR_FP32, self._dimensions), - ) - self._articles = await asyncio.to_thread( - zvec.create_and_open, path=articles_path, schema=schema - ) - logger.info("Created new articles vector collection") - except Exception: - self._articles = await asyncio.to_thread( - zvec.open, path=articles_path, option=zvec.CollectionOption() - ) - logger.info("Opened existing articles vector collection") + return zvec.CollectionSchema( + name=collection_name, + vectors=zvec.VectorSchema("embedding", zvec.DataType.VECTOR_FP32, self._dimensions), + ) + + async def _open_or_create_collection(self, collection_name: str) -> zvec.Collection: + import zvec - chunks_path = f"{self._data_dir}/message_chunks" + path = self._collection_path(collection_name) try: - schema = zvec.CollectionSchema( - name="message_chunks", - vectors=zvec.VectorSchema("embedding", zvec.DataType.VECTOR_FP32, self._dimensions), + collection = await asyncio.to_thread( + zvec.create_and_open, + path=path, + schema=self._build_schema(collection_name), ) - self._message_chunks = await asyncio.to_thread( - zvec.create_and_open, path=chunks_path, schema=schema - ) - logger.info("Created new message_chunks vector collection") + logger.info("Created new vector collection", collection=collection_name) + return collection except Exception: - self._message_chunks = await asyncio.to_thread( - zvec.open, path=chunks_path, option=zvec.CollectionOption() + collection = await asyncio.to_thread( + zvec.open, + path=path, + option=zvec.CollectionOption(), ) - logger.info("Opened existing message_chunks vector collection") + logger.info("Opened existing vector collection", collection=collection_name) + return collection + + async def _recreate_collection(self, collection_name: str) -> zvec.Collection: + attr_name = self._collection_attr_name(collection_name) + collection = getattr(self, attr_name) + path = self._collection_path(collection_name) + + if collection is not None: + try: + await asyncio.to_thread(collection.destroy) + except Exception: + logger.warning( + "Failed to destroy vector collection cleanly, removing files manually", + collection=collection_name, + path=path, + ) + finally: + setattr(self, attr_name, None) + + if await asyncio.to_thread(os.path.exists, path): + await asyncio.to_thread(shutil.rmtree, path, True) + + recreated = await self._open_or_create_collection(collection_name) + setattr(self, attr_name, recreated) + return recreated + + async def _doc_count(self, collection: zvec.Collection | None) -> int: + if collection is None: + raise RuntimeError("VectorStore not initialized") + return await asyncio.to_thread(lambda: int(collection.stats.doc_count)) async def upsert_article(self, content_item_id: str, embedding: list[float]) -> None: import zvec @@ -108,6 +152,9 @@ async def delete_article(self, content_item_id: str) -> None: raise RuntimeError("VectorStore not initialized") await asyncio.to_thread(self._articles.delete, content_item_id) + async def article_doc_count(self) -> int: + return await self._doc_count(self._articles) + async def upsert_message_chunk(self, chunk_id: str, embedding: list[float]) -> None: import zvec @@ -149,6 +196,12 @@ async def delete_message_chunks_by_ids(self, chunk_ids: list[str]) -> None: for chunk_id in chunk_ids: await asyncio.to_thread(self._message_chunks.delete, chunk_id) + async def message_chunk_doc_count(self) -> int: + return await self._doc_count(self._message_chunks) + + async def recreate_message_chunks_collection(self) -> None: + await self._recreate_collection(self._MESSAGE_CHUNKS_COLLECTION) + async def close(self) -> None: if self._articles is not None: await asyncio.to_thread(self._articles.flush) diff --git a/src/intelstream/discord/cogs/lore.py b/src/intelstream/discord/cogs/lore.py index e1f954e..675c672 100644 --- a/src/intelstream/discord/cogs/lore.py +++ b/src/intelstream/discord/cogs/lore.py @@ -1,6 +1,8 @@ from __future__ import annotations +import asyncio import re +from contextlib import suppress from datetime import UTC, datetime, timedelta from typing import TYPE_CHECKING @@ -33,6 +35,7 @@ BUFFER_FLUSH_MINUTES = 5 MAX_DISCORD_MESSAGE_LENGTH = 2000 +HEALTH_CHECK_TOPK = 10 def _parse_timeframe(timeframe: str) -> tuple[datetime | None, datetime | None]: @@ -89,6 +92,8 @@ def __init__( self._llm_client: LLMClient | None = None self._message_buffers: dict[str, list[RawMessage]] = {} self._chunker: MessageChunker | None = None + self._index_rebuild_task: asyncio.Task[None] | None = None + self._index_rebuild_error: str | None = None async def cog_load(self) -> None: self._ingestion_service = MessageIngestionService( @@ -111,10 +116,18 @@ async def cog_load(self) -> None: max_messages=self.bot.settings.lore_chunk_max_messages, ) self._flush_buffers.start() + self._index_rebuild_task = asyncio.create_task( + self._ensure_message_chunk_index(), + name="lore-index-rebuild", + ) logger.info("Lore cog loaded") async def cog_unload(self) -> None: self._flush_buffers.cancel() + if self._index_rebuild_task is not None: + self._index_rebuild_task.cancel() + with suppress(asyncio.CancelledError): + await self._index_rebuild_task if self._ingestion_service and self._ingestion_service.is_running: self._ingestion_service.stop_backfill() await self._flush_all_buffers() @@ -135,12 +148,82 @@ async def lore( channel: discord.TextChannel | None = None, # noqa: ARG002 timeframe: str | None = None, # noqa: ARG002 ) -> None: + if self._index_rebuild_task is not None and not self._index_rebuild_task.done(): + message = ( + "The /lore command is temporarily disabled while the message index is being " + "rebuilt. Check back soon!" + ) + elif self._index_rebuild_error is not None: + message = ( + "The /lore command is temporarily disabled because the message index needs " + "recovery. Check logs and try again after reindexing completes." + ) + else: + message = ( + "The /lore command is temporarily disabled while the message index is being " + "built. Check back soon!" + ) await interaction.response.send_message( - "The /lore command is temporarily disabled while the message index is being built. " - "Check back soon!", + message, ephemeral=True, ) + async def _ensure_message_chunk_index(self) -> None: + if self._ingestion_service is None: + return + + try: + expected_count = await self.bot.repository.count_message_chunk_metas() + if expected_count == 0: + logger.info("No stored lore chunks found; skipping vector index rebuild") + return + + if await self._message_index_is_healthy(expected_count): + logger.info("Lore message index is healthy", chunks=expected_count) + return + + logger.warning( + "Lore message index is unhealthy; rebuilding from stored chunks", + expected_chunks=expected_count, + ) + rebuilt = await self._ingestion_service.rebuild_vector_index() + logger.info("Lore message index rebuilt", indexed=rebuilt) + except asyncio.CancelledError: + raise + except Exception as exc: + self._index_rebuild_error = str(exc) + logger.exception("Failed to rebuild lore message index", error=str(exc)) + + async def _message_index_is_healthy(self, expected_count: int) -> bool: + indexed_count = await self._vector_store.message_chunk_doc_count() + if indexed_count != expected_count: + logger.warning( + "Lore message index count mismatch", + expected=expected_count, + indexed=indexed_count, + ) + return False + + sample_batch = await self.bot.repository.get_message_chunk_metas_batch(limit=1) + if not sample_batch: + return True + + sample = sample_batch[0] + query_embedding = await self._embedding_service.embed_text(sample.text) + results = await self._vector_store.search_message_chunks( + query_embedding, + topk=HEALTH_CHECK_TOPK, + ) + if any(result.chunk_id == sample.id for result in results): + return True + + logger.warning( + "Lore message index probe failed", + sample_chunk_id=sample.id, + result_ids=[result.chunk_id for result in results], + ) + return False + @commands.Cog.listener("on_message") async def on_message(self, message: discord.Message) -> None: if not message.guild: diff --git a/src/intelstream/services/message_ingestion.py b/src/intelstream/services/message_ingestion.py index 40e9d8e..4640ca3 100644 --- a/src/intelstream/services/message_ingestion.py +++ b/src/intelstream/services/message_ingestion.py @@ -237,6 +237,44 @@ async def store_chunks(self, chunks: list[Chunk]) -> int: return len(metas) + async def rebuild_vector_index(self, batch_size: int = EMBED_BATCH_SIZE) -> int: + total_chunks = await self._repository.count_message_chunk_metas() + await self._vector_store.recreate_message_chunks_collection() + + if total_chunks == 0: + logger.info("No stored message chunks to reindex") + return 0 + + indexed = 0 + offset = 0 + + while True: + metas = await self._repository.get_message_chunk_metas_batch( + offset=offset, + limit=batch_size, + ) + if not metas: + break + + embeddings = await self._embedding_service.embed_batch([meta.text for meta in metas]) + vector_items = [ + (meta.id, embedding) for meta, embedding in zip(metas, embeddings, strict=True) + ] + await self._vector_store.upsert_message_chunks_batch(vector_items) + + indexed += len(metas) + offset += len(metas) + + if indexed == total_chunks or indexed % (batch_size * 10) == 0: + logger.info( + "Lore vector index rebuild progress", + indexed=indexed, + total=total_chunks, + ) + + logger.info("Lore vector index rebuild complete", indexed=indexed, total=total_chunks) + return indexed + async def ingest_channel( self, channel: discord.TextChannel, diff --git a/tests/test_discord/test_lore.py b/tests/test_discord/test_lore.py index 8e39f14..b8bce88 100644 --- a/tests/test_discord/test_lore.py +++ b/tests/test_discord/test_lore.py @@ -1,9 +1,11 @@ +import asyncio from datetime import UTC, datetime from unittest.mock import AsyncMock, MagicMock import discord import pytest +from intelstream.database.vector_store import ChunkSearchResult from intelstream.discord.cogs.lore import Lore, _parse_timeframe, _split_message @@ -18,6 +20,8 @@ def mock_bot(): bot.settings.llm_api_key = "test-key" bot.settings.summary_model_interactive = "claude-test" bot.repository = AsyncMock() + bot.repository.count_message_chunk_metas = AsyncMock(return_value=0) + bot.repository.get_message_chunk_metas_batch = AsyncMock(return_value=[]) bot.get_guild = MagicMock(return_value=None) bot.guilds = [] bot.cogs = {} @@ -34,6 +38,7 @@ def mock_embedding_service(): @pytest.fixture def mock_vector_store(): store = AsyncMock() + store.message_chunk_doc_count = AsyncMock(return_value=0) store.search_message_chunks = AsyncMock(return_value=[]) return store @@ -43,6 +48,7 @@ def lore_cog(mock_bot, mock_embedding_service, mock_vector_store): cog = Lore(mock_bot, mock_embedding_service, mock_vector_store) cog._ingestion_service = MagicMock() cog._ingestion_service.is_running = False + cog._ingestion_service.rebuild_vector_index = AsyncMock(return_value=0) cog._llm_client = AsyncMock() cog._llm_client.complete = AsyncMock(return_value="Here is the lore about that topic.") cog._chunker = MagicMock() @@ -153,6 +159,57 @@ async def test_cog_load_without_api_key( assert cog._llm_client is None assert cog._ingestion_service is not None assert cog._chunker is not None + if cog._index_rebuild_task is not None: + await cog._index_rebuild_task + + +class TestIndexHealth: + async def test_message_index_healthy(self, lore_cog, mock_bot, mock_vector_store): + mock_bot.repository.get_message_chunk_metas_batch.return_value = [ + MagicMock(id="chunk-1", text="sample chunk text") + ] + mock_vector_store.message_chunk_doc_count.return_value = 1 + mock_vector_store.search_message_chunks.return_value = [ + ChunkSearchResult(chunk_id="chunk-1", score=1.0) + ] + + result = await lore_cog._message_index_is_healthy(expected_count=1) + + assert result is True + + async def test_message_index_unhealthy_on_count_mismatch(self, lore_cog, mock_vector_store): + mock_vector_store.message_chunk_doc_count.return_value = 0 + + result = await lore_cog._message_index_is_healthy(expected_count=2) + + assert result is False + mock_vector_store.search_message_chunks.assert_not_called() + + async def test_ensure_message_chunk_index_rebuilds_unhealthy_index( + self, lore_cog, mock_bot, mock_vector_store + ): + lore_cog._ingestion_service = MagicMock() + lore_cog._ingestion_service.rebuild_vector_index = AsyncMock(return_value=3) + mock_bot.repository.count_message_chunk_metas.return_value = 3 + mock_vector_store.message_chunk_doc_count.return_value = 0 + + await lore_cog._ensure_message_chunk_index() + + lore_cog._ingestion_service.rebuild_vector_index.assert_awaited_once() + + async def test_command_mentions_rebuild_in_progress(self, lore_cog, mock_interaction): + lore_cog._index_rebuild_task = asyncio.create_task(asyncio.sleep(0.1)) + + try: + await lore_cog.lore.callback(lore_cog, mock_interaction, "test query") + finally: + lore_cog._index_rebuild_task.cancel() + with pytest.raises(asyncio.CancelledError): + await lore_cog._index_rebuild_task + + mock_interaction.response.send_message.assert_called_once() + msg = mock_interaction.response.send_message.call_args[0][0].lower() + assert "rebuilt" in msg or "rebuild" in msg class TestAutoStartIngestion: diff --git a/tests/test_services/test_message_ingestion.py b/tests/test_services/test_message_ingestion.py index 10fe442..102fc57 100644 --- a/tests/test_services/test_message_ingestion.py +++ b/tests/test_services/test_message_ingestion.py @@ -1,5 +1,5 @@ from datetime import UTC, datetime, timedelta -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -302,6 +302,47 @@ async def test_store_chunks_single(self, service, mock_deps): assert metas[0].channel_id == "222" assert metas[0].message_count == 3 + async def test_rebuild_vector_index(self, service, mock_deps): + repository, embedding_service, vector_store = mock_deps + repository.count_message_chunk_metas = AsyncMock(return_value=3) + + meta1 = MagicMock(id="chunk-1", text="first chunk text") + meta2 = MagicMock(id="chunk-2", text="second chunk text") + meta3 = MagicMock(id="chunk-3", text="third chunk text") + repository.get_message_chunk_metas_batch = AsyncMock( + side_effect=[ + [meta1, meta2], + [meta3], + [], + ] + ) + embedding_service.embed_batch = AsyncMock( + side_effect=[ + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + [[0.7, 0.8, 0.9]], + ] + ) + + result = await service.rebuild_vector_index(batch_size=2) + + assert result == 3 + vector_store.recreate_message_chunks_collection.assert_called_once() + assert vector_store.upsert_message_chunks_batch.await_count == 2 + repository.get_message_chunk_metas_batch.assert_any_call(offset=0, limit=2) + repository.get_message_chunk_metas_batch.assert_any_call(offset=2, limit=2) + + async def test_rebuild_vector_index_empty(self, service, mock_deps): + repository, embedding_service, vector_store = mock_deps + repository.count_message_chunk_metas = AsyncMock(return_value=0) + repository.get_message_chunk_metas_batch = AsyncMock(return_value=[]) + + result = await service.rebuild_vector_index() + + assert result == 0 + vector_store.recreate_message_chunks_collection.assert_called_once() + embedding_service.embed_batch.assert_not_called() + vector_store.upsert_message_chunks_batch.assert_not_called() + def test_is_running_no_task(self, service): assert service.is_running is False diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index 212d138..77fb310 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -71,6 +71,14 @@ async def test_search_empty_collection(self, vector_store): results = await vector_store.search_articles([1.0, 0.0, 0.0, 0.0], topk=5) assert results == [] + async def test_message_chunk_doc_count(self, vector_store): + assert await vector_store.message_chunk_doc_count() == 0 + + await vector_store.upsert_message_chunk("chunk-1", [1.0, 0.0, 0.0, 0.0]) + await vector_store.upsert_message_chunk("chunk-2", [0.0, 1.0, 0.0, 0.0]) + + assert await vector_store.message_chunk_doc_count() == 2 + class TestUpsertBatch: async def test_batch_upsert(self, vector_store): @@ -98,6 +106,18 @@ async def test_delete_article(self, vector_store): assert len(results) == 0 +class TestRecreateCollections: + async def test_recreate_message_chunks_collection(self, vector_store): + await vector_store.upsert_message_chunk("chunk-1", [1.0, 0.0, 0.0, 0.0]) + assert await vector_store.message_chunk_doc_count() == 1 + + await vector_store.recreate_message_chunks_collection() + + assert await vector_store.message_chunk_doc_count() == 0 + results = await vector_store.search_message_chunks([1.0, 0.0, 0.0, 0.0], topk=1) + assert results == [] + + class TestNotInitialized: async def test_upsert_raises(self): store = VectorStore(data_dir="/tmp/noinit", dimensions=4)