diff --git a/src/intelstream/bot.py b/src/intelstream/bot.py index b51b874..c02c743 100644 --- a/src/intelstream/bot.py +++ b/src/intelstream/bot.py @@ -183,6 +183,7 @@ async def _setup_search(self) -> None: self.vector_store = VectorStore( data_dir=self.settings.zvec_data_dir, dimensions=self.settings.embedding_dimensions, + model_name=self.settings.embedding_model, ) await self.vector_store.initialize() diff --git a/src/intelstream/database/repository.py b/src/intelstream/database/repository.py index 8e07744..482454e 100644 --- a/src/intelstream/database/repository.py +++ b/src/intelstream/database/repository.py @@ -321,6 +321,16 @@ async def get_summarized_content_items( ) return list(result.scalars().all()) + async def count_summarized_content_items(self) -> int: + async with self.session() as session: + result = await session.execute( + select(func.count()) + .select_from(ContentItem) + .where(ContentItem.summary.isnot(None)) + .where(ContentItem.summary != "") + ) + return int(result.scalar_one()) + async def content_item_exists(self, external_id: str) -> bool: async with self.session() as session: result = await session.execute( diff --git a/src/intelstream/database/vector_store.py b/src/intelstream/database/vector_store.py index 4bbbc96..e9d23d4 100644 --- a/src/intelstream/database/vector_store.py +++ b/src/intelstream/database/vector_store.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import os import shutil from dataclasses import dataclass @@ -14,6 +15,9 @@ logger = structlog.get_logger(__name__) +_VECTOR_FIELD_NAME = "embedding" +_METADATA_FILENAME = "intelstream-index.json" + @dataclass class SearchResult: @@ -31,20 +35,32 @@ class VectorStore: _ARTICLES_COLLECTION = "articles" _MESSAGE_CHUNKS_COLLECTION = "message_chunks" - def __init__(self, data_dir: str, dimensions: int = 384) -> None: + def __init__( + self, + data_dir: str, + dimensions: int = 384, + model_name: str | None = None, + ) -> None: self._data_dir = data_dir self._dimensions = dimensions + self._model_name = model_name self._articles: zvec.Collection | None = None self._message_chunks: dict[str, zvec.Collection] = {} async def initialize(self) -> None: await asyncio.to_thread(os.makedirs, self._data_dir, exist_ok=True) - self._articles = await self._open_or_create_collection(self._ARTICLES_COLLECTION) + self._articles = await self._open_or_create_collection( + self._ARTICLES_COLLECTION, + validate_metadata=True, + ) await asyncio.to_thread(self._warn_if_legacy_message_chunk_collection_present) def _collection_path(self, collection_name: str) -> str: return str(Path(self._data_dir) / collection_name) + def _collection_metadata_path(self, collection_name: str, path: str | None = None) -> Path: + return Path(path or self._collection_path(collection_name)) / _METADATA_FILENAME + def _collection_attr_name(self, collection_name: str) -> str: if collection_name == self._ARTICLES_COLLECTION: return "_articles" @@ -74,53 +90,173 @@ def _build_schema(self, collection_name: str) -> zvec.CollectionSchema: return zvec.CollectionSchema( name=collection_name, - vectors=zvec.VectorSchema("embedding", zvec.DataType.VECTOR_FP32, self._dimensions), + vectors=zvec.VectorSchema( + _VECTOR_FIELD_NAME, + zvec.DataType.VECTOR_FP32, + self._dimensions, + ), ) - async def _open_or_create_collection( + def _expected_collection_metadata(self, collection_name: str) -> dict[str, Any]: + return { + "collection": collection_name, + "dimensions": self._dimensions, + "model_name": self._model_name, + } + + def _read_collection_metadata( self, collection_name: str, path: str | None = None + ) -> dict[str, Any] | None: + path_obj = self._collection_metadata_path(collection_name, path) + if not path_obj.exists(): + return None + try: + data = json.loads(path_obj.read_text()) + if isinstance(data, dict): + return data + logger.warning( + "Vector collection metadata file is not a JSON object", + collection=collection_name, + path=str(path_obj), + ) + return None + except (json.JSONDecodeError, OSError): + logger.warning( + "Failed to read vector collection metadata", + collection=collection_name, + path=str(path_obj), + ) + return None + + def _write_collection_metadata(self, collection_name: str, path: str | None = None) -> None: + path_obj = self._collection_metadata_path(collection_name, path) + path_obj.write_text( + json.dumps( + self._expected_collection_metadata(collection_name), indent=2, sort_keys=True + ) + ) + + def _collection_dimension(self, collection: zvec.Collection) -> int: + schema = json.loads(str(collection.schema)) + return int(schema["vectors"][_VECTOR_FIELD_NAME]["dimension"]) + + async def _collection_needs_recreate( + self, + collection_name: str, + collection: zvec.Collection, + path: str | None = None, + ) -> str | None: + actual_dimension = await asyncio.to_thread(self._collection_dimension, collection) + if actual_dimension != self._dimensions: + return f"dimension mismatch ({actual_dimension} != {self._dimensions})" + + metadata = await asyncio.to_thread(self._read_collection_metadata, collection_name, path) + if metadata is None: + return None + + stored_dimensions = metadata.get("dimensions") + if isinstance(stored_dimensions, int) and stored_dimensions != self._dimensions: + return ( + f"stored metadata dimensions mismatch ({stored_dimensions} != {self._dimensions})" + ) + + stored_model_name = metadata.get("model_name") + if self._model_name and stored_model_name and stored_model_name != self._model_name: + return f"model mismatch ({stored_model_name} != {self._model_name})" + + return None + + async def _destroy_collection_at_path( + self, + collection_name: str, + collection: zvec.Collection | None, + path: str | None = None, + ) -> None: + collection_path = path or self._collection_path(collection_name) + if collection is not None: + try: + await asyncio.to_thread(collection.destroy) + except Exception: + logger.warning( + "Failed to destroy vector collection cleanly, removing files manually", + collection=collection_name, + path=collection_path, + ) + + if await asyncio.to_thread(os.path.exists, collection_path): + await asyncio.to_thread(shutil.rmtree, collection_path, True) + + async def _open_or_create_collection( + self, + collection_name: str, + path: str | None = None, + *, + validate_metadata: bool = False, ) -> zvec.Collection: import zvec - path = path or self._collection_path(collection_name) + collection_path = path or self._collection_path(collection_name) try: collection = await asyncio.to_thread( zvec.create_and_open, - path=path, + path=collection_path, schema=self._build_schema(collection_name), ) logger.info("Created new vector collection", collection=collection_name) + if validate_metadata: + await asyncio.to_thread( + self._write_collection_metadata, + collection_name, + collection_path, + ) return collection except Exception: collection = await asyncio.to_thread( zvec.open, - path=path, + path=collection_path, option=zvec.CollectionOption(), ) logger.info("Opened existing vector collection", collection=collection_name) + if validate_metadata: + recreate_reason = await self._collection_needs_recreate( + collection_name, + collection, + collection_path, + ) + if recreate_reason is not None: + logger.warning( + "Recreating incompatible vector collection", + collection=collection_name, + reason=recreate_reason, + ) + await self._destroy_collection_at_path( + collection_name, + collection, + collection_path, + ) + collection = await asyncio.to_thread( + zvec.create_and_open, + path=collection_path, + schema=self._build_schema(collection_name), + ) + logger.info("Recreated vector collection", collection=collection_name) + await asyncio.to_thread( + self._write_collection_metadata, + collection_name, + collection_path, + ) return collection async def _recreate_collection(self, collection_name: str) -> zvec.Collection: attr_name = self._collection_attr_name(collection_name) collection = getattr(self, attr_name) - path = self._collection_path(collection_name) + await self._destroy_collection_at_path(collection_name, collection) + setattr(self, attr_name, None) - if collection is not None: - try: - await asyncio.to_thread(collection.destroy) - except Exception: - logger.warning( - "Failed to destroy vector collection cleanly, removing files manually", - collection=collection_name, - path=path, - ) - finally: - setattr(self, attr_name, None) - - if await asyncio.to_thread(os.path.exists, path): - await asyncio.to_thread(shutil.rmtree, path, True) - - recreated = await self._open_or_create_collection(collection_name) + recreated = await self._open_or_create_collection( + collection_name, + validate_metadata=True, + ) setattr(self, attr_name, recreated) return recreated @@ -146,18 +282,11 @@ async def _recreate_message_chunk_collection(self, guild_id: str) -> zvec.Collec collection = self._message_chunks.pop(guild_id, None) path = self._message_chunk_collection_path(guild_id) - if collection is not None: - try: - await asyncio.to_thread(collection.destroy) - except Exception: - logger.warning( - "Failed to destroy vector collection cleanly, removing files manually", - collection=self._message_chunk_collection_name(guild_id), - path=path, - ) - - if await asyncio.to_thread(os.path.exists, path): - await asyncio.to_thread(shutil.rmtree, path, True) + await self._destroy_collection_at_path( + self._message_chunk_collection_name(guild_id), + collection, + path, + ) recreated = await self._open_or_create_collection( self._message_chunk_collection_name(guild_id), @@ -178,7 +307,7 @@ async def upsert_article(self, content_item_id: str, embedding: list[float]) -> raise RuntimeError("VectorStore not initialized") doc = zvec.Doc( id=content_item_id, - vectors={"embedding": embedding}, + vectors={_VECTOR_FIELD_NAME: embedding}, ) await asyncio.to_thread(self._articles.upsert, [doc]) @@ -189,7 +318,7 @@ async def upsert_articles_batch(self, items: list[tuple[str, list[float]]]) -> N raise RuntimeError("VectorStore not initialized") if not items: return - docs = [zvec.Doc(id=item_id, vectors={"embedding": emb}) for item_id, emb in items] + docs = [zvec.Doc(id=item_id, vectors={_VECTOR_FIELD_NAME: emb}) for item_id, emb in items] await asyncio.to_thread(self._articles.upsert, docs) async def search_articles( @@ -201,7 +330,7 @@ async def search_articles( raise RuntimeError("VectorStore not initialized") results: Any = await asyncio.to_thread( self._articles.query, - zvec.VectorQuery("embedding", vector=query_embedding), + zvec.VectorQuery(_VECTOR_FIELD_NAME, vector=query_embedding), topk=topk, ) return [SearchResult(content_item_id=r.id, score=r.score) for r in results] @@ -214,6 +343,9 @@ async def delete_article(self, content_item_id: str) -> None: async def article_doc_count(self) -> int: return await self._doc_count(self._articles) + async def recreate_articles_collection(self) -> None: + await self._recreate_collection(self._ARTICLES_COLLECTION) + async def upsert_message_chunk( self, guild_id: str, chunk_id: str, embedding: list[float] ) -> None: @@ -224,7 +356,7 @@ async def upsert_message_chunk( raise RuntimeError("VectorStore not initialized") doc = zvec.Doc( id=chunk_id, - vectors={"embedding": embedding}, + vectors={_VECTOR_FIELD_NAME: embedding}, ) await asyncio.to_thread(collection.upsert, [doc]) @@ -238,7 +370,7 @@ async def upsert_message_chunks_batch( raise RuntimeError("VectorStore not initialized") if not items: return - docs = [zvec.Doc(id=cid, vectors={"embedding": emb}) for cid, emb in items] + docs = [zvec.Doc(id=cid, vectors={_VECTOR_FIELD_NAME: emb}) for cid, emb in items] await asyncio.to_thread(collection.upsert, docs) async def search_message_chunks( @@ -251,7 +383,7 @@ async def search_message_chunks( return [] results: Any = await asyncio.to_thread( collection.query, - zvec.VectorQuery("embedding", vector=query_embedding), + zvec.VectorQuery(_VECTOR_FIELD_NAME, vector=query_embedding), topk=topk, ) return [ChunkSearchResult(chunk_id=r.id, score=r.score) for r in results] diff --git a/src/intelstream/discord/cogs/search.py b/src/intelstream/discord/cogs/search.py index 342b48a..dbc3604 100644 --- a/src/intelstream/discord/cogs/search.py +++ b/src/intelstream/discord/cogs/search.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +from contextlib import suppress from datetime import UTC, datetime from typing import TYPE_CHECKING @@ -16,6 +18,8 @@ logger = structlog.get_logger(__name__) MAX_SUMMARY_PREVIEW = 200 +INDEX_BATCH_SIZE = 50 +HEALTH_CHECK_TOPK = 10 class Search(commands.Cog): @@ -28,6 +32,22 @@ def __init__( self.bot = bot self._embedding_service = embedding_service self._vector_store = vector_store + self._index_rebuild_task: asyncio.Task[None] | None = None + self._index_rebuild_error: str | None = None + + async def cog_load(self) -> None: + self._index_rebuild_task = asyncio.create_task( + self._ensure_article_index(), + name="article-index-rebuild", + ) + logger.info("Search cog loaded") + + async def cog_unload(self) -> None: + if self._index_rebuild_task is not None: + self._index_rebuild_task.cancel() + with suppress(asyncio.CancelledError): + await self._index_rebuild_task + logger.info("Search cog unloaded") @app_commands.command( name="search", @@ -36,6 +56,20 @@ def __init__( @app_commands.describe(query="Natural language search query") @app_commands.checks.cooldown(rate=5, per=60.0) async def search(self, interaction: discord.Interaction, query: str) -> None: + if self._index_rebuild_task is not None and not self._index_rebuild_task.done(): + await interaction.response.send_message( + "Search is temporarily unavailable while the article index is being rebuilt.", + ephemeral=True, + ) + return + + if self._index_rebuild_error is not None: + await interaction.response.send_message( + "Search is temporarily unavailable because the article index needs recovery.", + ephemeral=True, + ) + return + await interaction.response.defer() logger.info( @@ -64,6 +98,7 @@ async def search(self, interaction: discord.Interaction, query: str) -> None: color=discord.Color.blue(), timestamp=datetime.now(UTC), ) + rendered_results = 0 for result in results: item = items_by_id.get(result.content_item_id) @@ -85,8 +120,16 @@ async def search(self, interaction: discord.Interaction, query: str) -> None: value="\n".join(value_parts), inline=False, ) + rendered_results += 1 - embed.set_footer(text=f"{len(results)} results") + if rendered_results == 0: + await interaction.followup.send( + "No results found. The search index may need rebuilding.", + ephemeral=True, + ) + return + + embed.set_footer(text=f"{rendered_results} results") await interaction.followup.send(embed=embed) @app_commands.command( @@ -99,27 +142,14 @@ async def index(self, interaction: discord.Interaction) -> None: logger.info("index command invoked", user_id=interaction.user.id) - total_indexed = 0 - offset = 0 - batch_size = 50 - - while True: - items = await self.bot.repository.get_summarized_content_items( - offset=offset, limit=batch_size + if self._index_rebuild_task is not None and not self._index_rebuild_task.done(): + await interaction.followup.send( + "The article index is already being rebuilt.", ephemeral=True ) - if not items: - break - - texts = [f"{item.title} {item.summary}" for item in items] - embeddings = await self._embedding_service.embed_batch(texts) - - batch = [(item.id, emb) for item, emb in zip(items, embeddings, strict=True)] - await self._vector_store.upsert_articles_batch(batch) - - total_indexed += len(items) - offset += batch_size + return - logger.info("Indexing progress", indexed=total_indexed) + total_indexed = await self._rebuild_article_index() + self._index_rebuild_error = None await interaction.followup.send( f"Indexed {total_indexed} articles for search.", ephemeral=True @@ -149,6 +179,97 @@ async def index_error( else: raise error + async def _ensure_article_index(self) -> None: + try: + expected_count = await self.bot.repository.count_summarized_content_items() + if expected_count == 0: + logger.info("No summarized content found; skipping article index rebuild") + return + + if await self._article_index_is_healthy(expected_count): + logger.info("Article search index is healthy", items=expected_count) + return + + logger.warning( + "Article search index is unhealthy; rebuilding from summarized content", + expected_items=expected_count, + ) + rebuilt = await self._rebuild_article_index() + logger.info("Article search index rebuilt", indexed=rebuilt) + except asyncio.CancelledError: + raise + except Exception as exc: + self._index_rebuild_error = str(exc) + logger.exception("Failed to rebuild article search index", error=str(exc)) + + async def _article_index_is_healthy(self, expected_count: int) -> bool: + indexed_count = await self._vector_store.article_doc_count() + if indexed_count != expected_count: + logger.warning( + "Article search index count mismatch", + expected=expected_count, + indexed=indexed_count, + ) + return False + + sample_batch = await self.bot.repository.get_summarized_content_items(limit=1) + if not sample_batch: + return True + + sample = sample_batch[0] + query_embedding = await self._embedding_service.embed_text( + f"{sample.title} {sample.summary}" + ) + results = await self._vector_store.search_articles( + query_embedding, + topk=HEALTH_CHECK_TOPK, + ) + if any(result.content_item_id == sample.id for result in results): + return True + + logger.warning( + "Article search index probe failed", + sample_item_id=sample.id, + result_ids=[result.content_item_id for result in results], + ) + return False + + async def _rebuild_article_index(self, batch_size: int = INDEX_BATCH_SIZE) -> int: + total_items = await self.bot.repository.count_summarized_content_items() + await self._vector_store.recreate_articles_collection() + + if total_items == 0: + logger.info("No summarized content to index") + return 0 + + indexed = 0 + offset = 0 + + while True: + items = await self.bot.repository.get_summarized_content_items( + offset=offset, + limit=batch_size, + ) + if not items: + break + + texts = [f"{item.title} {item.summary}" for item in items] + embeddings = await self._embedding_service.embed_batch(texts) + batch = [(item.id, emb) for item, emb in zip(items, embeddings, strict=True)] + await self._vector_store.upsert_articles_batch(batch) + + indexed += len(items) + offset += len(items) + + if indexed == total_items or indexed % (batch_size * 10) == 0: + logger.info( + "Article index rebuild progress", + indexed=indexed, + total=total_items, + ) + + return indexed + def _truncate(text: str, max_len: int) -> str: if len(text) <= max_len: diff --git a/tests/test_discord/test_search.py b/tests/test_discord/test_search.py index fff4169..a6881b1 100644 --- a/tests/test_discord/test_search.py +++ b/tests/test_discord/test_search.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import AsyncMock, MagicMock import discord @@ -13,6 +14,8 @@ def mock_bot(): bot.settings = MagicMock() bot.settings.search_result_limit = 5 bot.repository = AsyncMock() + bot.repository.count_summarized_content_items = AsyncMock(return_value=0) + bot.repository.get_summarized_content_items = AsyncMock(return_value=[]) return bot @@ -29,6 +32,8 @@ def mock_vector_store(): store = AsyncMock() store.search_articles = AsyncMock(return_value=[]) store.upsert_articles_batch = AsyncMock() + store.article_doc_count = AsyncMock(return_value=0) + store.recreate_articles_collection = AsyncMock() return store @@ -42,6 +47,7 @@ def mock_interaction(): interaction = MagicMock(spec=discord.Interaction) interaction.response = MagicMock() interaction.response.defer = AsyncMock() + interaction.response.send_message = AsyncMock() interaction.followup = MagicMock() interaction.followup.send = AsyncMock() interaction.user = MagicMock() @@ -97,6 +103,20 @@ async def test_search_embeds_query(self, search_cog, mock_interaction, mock_embe await search_cog.search.callback(search_cog, mock_interaction, "test query") mock_embedding_service.embed_text.assert_called_once_with("test query") + async def test_search_mentions_rebuild_in_progress(self, search_cog, mock_interaction): + search_cog._index_rebuild_task = asyncio.create_task(asyncio.sleep(0.1)) + + try: + await search_cog.search.callback(search_cog, mock_interaction, "test query") + finally: + search_cog._index_rebuild_task.cancel() + with pytest.raises(asyncio.CancelledError): + await search_cog._index_rebuild_task + + mock_interaction.response.send_message.assert_called_once() + msg = mock_interaction.response.send_message.call_args[0][0].lower() + assert "rebuilt" in msg or "rebuild" in msg + class TestIndex: async def test_index_empty(self, search_cog, mock_interaction, mock_bot): @@ -122,15 +142,29 @@ async def test_index_processes_items( [item1, item2], [], ] + mock_bot.repository.count_summarized_content_items.return_value = 2 await search_cog.index.callback(search_cog, mock_interaction) + mock_bot.repository.count_summarized_content_items.assert_called_once() + mock_vector_store.recreate_articles_collection.assert_called_once() mock_embedding_service.embed_batch.assert_called_once_with( ["Title 1 Summary 1", "Title 2 Summary 2"] ) mock_vector_store.upsert_articles_batch.assert_called_once() assert "2" in mock_interaction.followup.send.call_args.args[0] + async def test_ensure_article_index_rebuilds_unhealthy_index( + self, search_cog, mock_bot, mock_vector_store + ): + search_cog._rebuild_article_index = AsyncMock(return_value=3) + mock_bot.repository.count_summarized_content_items.return_value = 3 + mock_vector_store.article_doc_count.return_value = 0 + + await search_cog._ensure_article_index() + + search_cog._rebuild_article_index.assert_awaited_once() + class TestTruncate: def test_short_text(self): diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index 339147c..da05ab4 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -20,7 +20,7 @@ def _can_import_zvec() -> bool: @pytest.fixture async def vector_store(tmp_path): - store = VectorStore(data_dir=str(tmp_path / "vectors"), dimensions=4) + store = VectorStore(data_dir=str(tmp_path / "vectors"), dimensions=4, model_name="model-a") await store.initialize() yield store await store.close() @@ -28,25 +28,55 @@ async def vector_store(tmp_path): class TestInitialize: async def test_creates_directory(self, tmp_path): - store = VectorStore(data_dir=str(tmp_path / "new_dir"), dimensions=4) + store = VectorStore(data_dir=str(tmp_path / "new_dir"), dimensions=4, model_name="model-a") await store.initialize() assert (tmp_path / "new_dir").exists() await store.close() async def test_reopens_existing_collection(self, tmp_path): data_dir = str(tmp_path / "reopen_test") - store1 = VectorStore(data_dir=data_dir, dimensions=4) + store1 = VectorStore(data_dir=data_dir, dimensions=4, model_name="model-a") await store1.initialize() await store1.upsert_article("doc1", [0.1, 0.2, 0.3, 0.4]) await store1.close() - store2 = VectorStore(data_dir=data_dir, dimensions=4) + store2 = VectorStore(data_dir=data_dir, dimensions=4, model_name="model-a") await store2.initialize() results = await store2.search_articles([0.1, 0.2, 0.3, 0.4], topk=1) assert len(results) == 1 assert results[0].content_item_id == "doc1" await store2.close() + async def test_recreates_articles_collection_on_dimension_mismatch(self, tmp_path): + data_dir = str(tmp_path / "dimension_mismatch") + store1 = VectorStore(data_dir=data_dir, dimensions=4, model_name="model-a") + await store1.initialize() + await store1.upsert_article("doc1", [0.1, 0.2, 0.3, 0.4]) + await store1.close() + + store2 = VectorStore(data_dir=data_dir, dimensions=3, model_name="model-a") + await store2.initialize() + + assert await store2.article_doc_count() == 0 + results = await store2.search_articles([0.1, 0.2, 0.3], topk=1) + assert results == [] + await store2.close() + + async def test_recreates_articles_collection_on_model_mismatch(self, tmp_path): + data_dir = str(tmp_path / "model_mismatch") + store1 = VectorStore(data_dir=data_dir, dimensions=4, model_name="model-a") + await store1.initialize() + await store1.upsert_article("doc1", [0.1, 0.2, 0.3, 0.4]) + await store1.close() + + store2 = VectorStore(data_dir=data_dir, dimensions=4, model_name="model-b") + await store2.initialize() + + assert await store2.article_doc_count() == 0 + results = await store2.search_articles([0.1, 0.2, 0.3, 0.4], topk=1) + assert results == [] + await store2.close() + class TestUpsertAndSearch: async def test_upsert_and_search(self, vector_store): @@ -125,6 +155,16 @@ async def test_delete_article(self, vector_store): class TestRecreateCollections: + async def test_recreate_articles_collection(self, vector_store): + await vector_store.upsert_article("item-1", [1.0, 0.0, 0.0, 0.0]) + assert await vector_store.article_doc_count() == 1 + + await vector_store.recreate_articles_collection() + + assert await vector_store.article_doc_count() == 0 + results = await vector_store.search_articles([1.0, 0.0, 0.0, 0.0], topk=1) + assert results == [] + async def test_recreate_message_chunks_collection(self, vector_store): await vector_store.upsert_message_chunk("guild-1", "chunk-1", [1.0, 0.0, 0.0, 0.0]) assert await vector_store.message_chunk_doc_count("guild-1") == 1 @@ -138,16 +178,16 @@ async def test_recreate_message_chunks_collection(self, vector_store): class TestNotInitialized: async def test_upsert_raises(self): - store = VectorStore(data_dir="/tmp/noinit", dimensions=4) + store = VectorStore(data_dir="/tmp/noinit", dimensions=4, model_name="model-a") with pytest.raises(RuntimeError, match="not initialized"): await store.upsert_article("x", [1.0, 0.0, 0.0, 0.0]) async def test_search_raises(self): - store = VectorStore(data_dir="/tmp/noinit", dimensions=4) + store = VectorStore(data_dir="/tmp/noinit", dimensions=4, model_name="model-a") with pytest.raises(RuntimeError, match="not initialized"): await store.search_articles([1.0, 0.0, 0.0, 0.0]) async def test_delete_raises(self): - store = VectorStore(data_dir="/tmp/noinit", dimensions=4) + store = VectorStore(data_dir="/tmp/noinit", dimensions=4, model_name="model-a") with pytest.raises(RuntimeError, match="not initialized"): await store.delete_article("x")