diff --git a/src/intelstream/adapters/strategies/llm_extraction.py b/src/intelstream/adapters/strategies/llm_extraction.py index 1901b73..4fecc25 100644 --- a/src/intelstream/adapters/strategies/llm_extraction.py +++ b/src/intelstream/adapters/strategies/llm_extraction.py @@ -80,14 +80,8 @@ async def discover( if isinstance(posts_data, list): posts = [] for p in posts_data: - if ( - isinstance(p, dict) - and isinstance(p.get("url"), str) - and p.get("url") - ): - posts.append( - DiscoveredPost(url=p["url"], title=p.get("title", "")) - ) + if isinstance(p, dict) and isinstance(p.get("url"), str) and p.get("url"): + posts.append(DiscoveredPost(url=p["url"], title=p.get("title", ""))) if posts: logger.debug( "Using cached LLM extraction", @@ -131,12 +125,7 @@ def _get_content_hash(self, html: str) -> str: ): tag.decompose() - main = ( - soup.find("main") - or soup.find("article") - or soup.find(id="content") - or soup.body - ) + main = soup.find("main") or soup.find("article") or soup.find(id="content") or soup.body if main: text = " ".join(main.get_text().split()) @@ -150,16 +139,10 @@ async def _fetch_html(self, url: str) -> str | None: } try: if self._http_client: - response = await self._http_client.get( - url, headers=headers, follow_redirects=True - ) + response = await self._http_client.get(url, headers=headers, follow_redirects=True) else: - async with httpx.AsyncClient( - timeout=get_settings().http_timeout_seconds - ) as client: - response = await client.get( - url, headers=headers, follow_redirects=True - ) + async with httpx.AsyncClient(timeout=get_settings().http_timeout_seconds) as client: + response = await client.get(url, headers=headers, follow_redirects=True) response.raise_for_status() return response.text except httpx.HTTPError as e: @@ -169,9 +152,7 @@ async def _fetch_html(self, url: str) -> str | None: def _clean_html(self, html: str) -> str: soup = BeautifulSoup(html, "lxml") - for tag in soup.find_all( - ["script", "style", "noscript", "svg", "path", "iframe"] - ): + for tag in soup.find_all(["script", "style", "noscript", "svg", "path", "iframe"]): tag.decompose() for tag in soup.find_all(True): @@ -286,7 +267,5 @@ def parse_and_validate(data: str) -> list[dict[str, str]] | None: if result is not None: return result - logger.warning( - "Failed to extract JSON from LLM response", response_preview=text[:200] - ) + logger.warning("Failed to extract JSON from LLM response", response_preview=text[:200]) return [] diff --git a/src/intelstream/config.py b/src/intelstream/config.py index 49f3d8a..79f19d9 100644 --- a/src/intelstream/config.py +++ b/src/intelstream/config.py @@ -34,20 +34,12 @@ class Settings(BaseSettings): description="LLM provider for summarization: anthropic, openai, gemini, or kimi", ) - anthropic_api_key: str | None = Field( - default=None, description="Anthropic API key for Claude" - ) + anthropic_api_key: str | None = Field(default=None, description="Anthropic API key for Claude") openai_api_key: str | None = Field(default=None, description="OpenAI API key") - gemini_api_key: str | None = Field( - default=None, description="Google Gemini API key" - ) - kimi_api_key: str | None = Field( - default=None, description="Kimi (Moonshot AI) API key" - ) + gemini_api_key: str | None = Field(default=None, description="Google Gemini API key") + kimi_api_key: str | None = Field(default=None, description="Kimi (Moonshot AI) API key") - youtube_api_key: str | None = Field( - default=None, description="YouTube Data API key (optional)" - ) + youtube_api_key: str | None = Field(default=None, description="YouTube Data API key (optional)") twitter_bearer_token: str | None = Field( default=None, diff --git a/src/intelstream/database/repository.py b/src/intelstream/database/repository.py index a27c07d..8e07744 100644 --- a/src/intelstream/database/repository.py +++ b/src/intelstream/database/repository.py @@ -938,23 +938,40 @@ async def add_message_chunk_metas_batch(self, chunks: list[MessageChunkMeta]) -> session.add_all(chunks) await session.commit() - async def count_message_chunk_metas(self) -> int: + async def count_message_chunk_metas(self, guild_id: str | None = None) -> int: async with self.session() as session: - result = await session.execute(select(func.count()).select_from(MessageChunkMeta)) + query = select(func.count()).select_from(MessageChunkMeta) + if guild_id is not None: + query = query.where(MessageChunkMeta.guild_id == guild_id) + result = await session.execute(query) return int(result.scalar_one()) async def get_message_chunk_metas_batch( - self, offset: int = 0, limit: int = 100 + self, + offset: int = 0, + limit: int = 100, + guild_id: str | None = None, ) -> list[MessageChunkMeta]: async with self.session() as session: + query = select(MessageChunkMeta) + if guild_id is not None: + query = query.where(MessageChunkMeta.guild_id == guild_id) result = await session.execute( - select(MessageChunkMeta) - .order_by(MessageChunkMeta.start_timestamp.asc(), MessageChunkMeta.id.asc()) + query.order_by(MessageChunkMeta.start_timestamp.asc(), MessageChunkMeta.id.asc()) .offset(offset) .limit(limit) ) return list(result.scalars().all()) + async def get_message_chunk_guild_ids(self) -> list[str]: + async with self.session() as session: + result = await session.execute( + select(MessageChunkMeta.guild_id) + .distinct() + .order_by(MessageChunkMeta.guild_id.asc()) + ) + return [str(guild_id) for guild_id in result.scalars().all()] + async def get_message_chunk_metas_by_ids(self, chunk_ids: list[str]) -> list[MessageChunkMeta]: if not chunk_ids: return [] diff --git a/src/intelstream/database/vector_store.py b/src/intelstream/database/vector_store.py index aa51986..4bbbc96 100644 --- a/src/intelstream/database/vector_store.py +++ b/src/intelstream/database/vector_store.py @@ -35,14 +35,12 @@ def __init__(self, data_dir: str, dimensions: int = 384) -> None: self._data_dir = data_dir self._dimensions = dimensions self._articles: zvec.Collection | None = None - self._message_chunks: zvec.Collection | None = None + self._message_chunks: dict[str, zvec.Collection] = {} async def initialize(self) -> None: await asyncio.to_thread(os.makedirs, self._data_dir, exist_ok=True) self._articles = await self._open_or_create_collection(self._ARTICLES_COLLECTION) - self._message_chunks = await self._open_or_create_collection( - self._MESSAGE_CHUNKS_COLLECTION - ) + await asyncio.to_thread(self._warn_if_legacy_message_chunk_collection_present) def _collection_path(self, collection_name: str) -> str: return str(Path(self._data_dir) / collection_name) @@ -50,10 +48,27 @@ def _collection_path(self, collection_name: str) -> str: def _collection_attr_name(self, collection_name: str) -> str: if collection_name == self._ARTICLES_COLLECTION: return "_articles" - if collection_name == self._MESSAGE_CHUNKS_COLLECTION: - return "_message_chunks" raise ValueError(f"Unknown collection name: {collection_name}") + def _message_chunk_collection_name(self, guild_id: str) -> str: + return f"{self._MESSAGE_CHUNKS_COLLECTION}_{guild_id}" + + def _message_chunk_collection_path(self, guild_id: str) -> str: + return str(Path(self._data_dir) / self._MESSAGE_CHUNKS_COLLECTION / guild_id) + + def _warn_if_legacy_message_chunk_collection_present(self) -> None: + legacy_root = Path(self._collection_path(self._MESSAGE_CHUNKS_COLLECTION)) + if not legacy_root.exists(): + return + + legacy_files = [entry.name for entry in legacy_root.iterdir() if entry.is_file()] + if legacy_files: + logger.warning( + "Detected legacy global lore vector collection files; they are no longer used", + path=str(legacy_root), + files=sorted(legacy_files), + ) + def _build_schema(self, collection_name: str) -> zvec.CollectionSchema: import zvec @@ -62,10 +77,12 @@ def _build_schema(self, collection_name: str) -> zvec.CollectionSchema: vectors=zvec.VectorSchema("embedding", zvec.DataType.VECTOR_FP32, self._dimensions), ) - async def _open_or_create_collection(self, collection_name: str) -> zvec.Collection: + async def _open_or_create_collection( + self, collection_name: str, path: str | None = None + ) -> zvec.Collection: import zvec - path = self._collection_path(collection_name) + path = path or self._collection_path(collection_name) try: collection = await asyncio.to_thread( zvec.create_and_open, @@ -107,6 +124,48 @@ async def _recreate_collection(self, collection_name: str) -> zvec.Collection: setattr(self, attr_name, recreated) return recreated + async def _message_chunk_collection( + self, guild_id: str, *, create: bool + ) -> zvec.Collection | None: + collection = self._message_chunks.get(guild_id) + if collection is not None: + return collection + + path = self._message_chunk_collection_path(guild_id) + if not create and not await asyncio.to_thread(os.path.exists, path): + return None + + collection = await self._open_or_create_collection( + self._message_chunk_collection_name(guild_id), + path=path, + ) + self._message_chunks[guild_id] = collection + return collection + + async def _recreate_message_chunk_collection(self, guild_id: str) -> zvec.Collection: + collection = self._message_chunks.pop(guild_id, None) + path = self._message_chunk_collection_path(guild_id) + + if collection is not None: + try: + await asyncio.to_thread(collection.destroy) + except Exception: + logger.warning( + "Failed to destroy vector collection cleanly, removing files manually", + collection=self._message_chunk_collection_name(guild_id), + path=path, + ) + + if await asyncio.to_thread(os.path.exists, path): + await asyncio.to_thread(shutil.rmtree, path, True) + + recreated = await self._open_or_create_collection( + self._message_chunk_collection_name(guild_id), + path=path, + ) + self._message_chunks[guild_id] = recreated + return recreated + async def _doc_count(self, collection: zvec.Collection | None) -> int: if collection is None: raise RuntimeError("VectorStore not initialized") @@ -155,57 +214,68 @@ async def delete_article(self, content_item_id: str) -> None: async def article_doc_count(self) -> int: return await self._doc_count(self._articles) - async def upsert_message_chunk(self, chunk_id: str, embedding: list[float]) -> None: + async def upsert_message_chunk( + self, guild_id: str, chunk_id: str, embedding: list[float] + ) -> None: import zvec - if self._message_chunks is None: + collection = await self._message_chunk_collection(guild_id, create=True) + if collection is None: raise RuntimeError("VectorStore not initialized") doc = zvec.Doc( id=chunk_id, vectors={"embedding": embedding}, ) - await asyncio.to_thread(self._message_chunks.upsert, [doc]) + await asyncio.to_thread(collection.upsert, [doc]) - async def upsert_message_chunks_batch(self, items: list[tuple[str, list[float]]]) -> None: + async def upsert_message_chunks_batch( + self, guild_id: str, items: list[tuple[str, list[float]]] + ) -> None: import zvec - if self._message_chunks is None: + collection = await self._message_chunk_collection(guild_id, create=True) + if collection is None: raise RuntimeError("VectorStore not initialized") if not items: return docs = [zvec.Doc(id=cid, vectors={"embedding": emb}) for cid, emb in items] - await asyncio.to_thread(self._message_chunks.upsert, docs) + await asyncio.to_thread(collection.upsert, docs) async def search_message_chunks( - self, query_embedding: list[float], topk: int = 30 + self, guild_id: str, query_embedding: list[float], topk: int = 30 ) -> list[ChunkSearchResult]: import zvec - if self._message_chunks is None: - raise RuntimeError("VectorStore not initialized") + collection = await self._message_chunk_collection(guild_id, create=False) + if collection is None: + return [] results: Any = await asyncio.to_thread( - self._message_chunks.query, + collection.query, zvec.VectorQuery("embedding", vector=query_embedding), topk=topk, ) return [ChunkSearchResult(chunk_id=r.id, score=r.score) for r in results] - async def delete_message_chunks_by_ids(self, chunk_ids: list[str]) -> None: - if self._message_chunks is None: - raise RuntimeError("VectorStore not initialized") + async def delete_message_chunks_by_ids(self, guild_id: str, chunk_ids: list[str]) -> None: + collection = await self._message_chunk_collection(guild_id, create=False) + if collection is None: + return for chunk_id in chunk_ids: - await asyncio.to_thread(self._message_chunks.delete, chunk_id) + await asyncio.to_thread(collection.delete, chunk_id) - async def message_chunk_doc_count(self) -> int: - return await self._doc_count(self._message_chunks) + async def message_chunk_doc_count(self, guild_id: str) -> int: + collection = await self._message_chunk_collection(guild_id, create=False) + if collection is None: + return 0 + return await self._doc_count(collection) - async def recreate_message_chunks_collection(self) -> None: - await self._recreate_collection(self._MESSAGE_CHUNKS_COLLECTION) + async def recreate_message_chunks_collection(self, guild_id: str) -> None: + await self._recreate_message_chunk_collection(guild_id) async def close(self) -> None: if self._articles is not None: await asyncio.to_thread(self._articles.flush) self._articles = None - if self._message_chunks is not None: - await asyncio.to_thread(self._message_chunks.flush) - self._message_chunks = None + for collection in self._message_chunks.values(): + await asyncio.to_thread(collection.flush) + self._message_chunks.clear() diff --git a/src/intelstream/discord/cogs/lore.py b/src/intelstream/discord/cogs/lore.py index 675c672..1a952ba 100644 --- a/src/intelstream/discord/cogs/lore.py +++ b/src/intelstream/discord/cogs/lore.py @@ -173,44 +173,64 @@ async def _ensure_message_chunk_index(self) -> None: return try: - expected_count = await self.bot.repository.count_message_chunk_metas() - if expected_count == 0: + guild_ids = await self.bot.repository.get_message_chunk_guild_ids() + if not guild_ids: logger.info("No stored lore chunks found; skipping vector index rebuild") return - if await self._message_index_is_healthy(expected_count): - logger.info("Lore message index is healthy", chunks=expected_count) - return - - logger.warning( - "Lore message index is unhealthy; rebuilding from stored chunks", - expected_chunks=expected_count, - ) - rebuilt = await self._ingestion_service.rebuild_vector_index() - logger.info("Lore message index rebuilt", indexed=rebuilt) + for guild_id in guild_ids: + expected_count = await self.bot.repository.count_message_chunk_metas( + guild_id=guild_id + ) + if expected_count == 0: + continue + + if await self._message_index_is_healthy(guild_id, expected_count): + logger.info( + "Lore message index is healthy", + guild_id=guild_id, + chunks=expected_count, + ) + continue + + logger.warning( + "Lore message index is unhealthy; rebuilding from stored chunks", + guild_id=guild_id, + expected_chunks=expected_count, + ) + rebuilt = await self._ingestion_service.rebuild_vector_index(guild_id) + logger.info( + "Lore message index rebuilt", + guild_id=guild_id, + indexed=rebuilt, + ) except asyncio.CancelledError: raise except Exception as exc: self._index_rebuild_error = str(exc) logger.exception("Failed to rebuild lore message index", error=str(exc)) - async def _message_index_is_healthy(self, expected_count: int) -> bool: - indexed_count = await self._vector_store.message_chunk_doc_count() + async def _message_index_is_healthy(self, guild_id: str, expected_count: int) -> bool: + indexed_count = await self._vector_store.message_chunk_doc_count(guild_id) if indexed_count != expected_count: logger.warning( "Lore message index count mismatch", + guild_id=guild_id, expected=expected_count, indexed=indexed_count, ) return False - sample_batch = await self.bot.repository.get_message_chunk_metas_batch(limit=1) + sample_batch = await self.bot.repository.get_message_chunk_metas_batch( + limit=1, guild_id=guild_id + ) if not sample_batch: return True sample = sample_batch[0] query_embedding = await self._embedding_service.embed_text(sample.text) results = await self._vector_store.search_message_chunks( + guild_id, query_embedding, topk=HEALTH_CHECK_TOPK, ) @@ -219,6 +239,7 @@ async def _message_index_is_healthy(self, expected_count: int) -> bool: logger.warning( "Lore message index probe failed", + guild_id=guild_id, sample_chunk_id=sample.id, result_ids=[result.chunk_id for result in results], ) diff --git a/src/intelstream/services/message_ingestion.py b/src/intelstream/services/message_ingestion.py index 4640ca3..f29e276 100644 --- a/src/intelstream/services/message_ingestion.py +++ b/src/intelstream/services/message_ingestion.py @@ -206,7 +206,7 @@ async def store_chunks(self, chunks: list[Chunk]) -> int: embeddings = await self._embedding_service.embed_batch(texts) metas: list[MessageChunkMeta] = [] - vector_items: list[tuple[str, list[float]]] = [] + vector_items_by_guild: dict[str, list[tuple[str, list[float]]]] = {} for chunk, embedding in zip(chunks, embeddings, strict=True): meta = MessageChunkMeta( @@ -223,10 +223,11 @@ async def store_chunks(self, chunks: list[Chunk]) -> int: text=chunk.text, ) metas.append(meta) - vector_items.append((meta.id, embedding)) + vector_items_by_guild.setdefault(meta.guild_id, []).append((meta.id, embedding)) await self._repository.add_message_chunk_metas_batch(metas) - await self._vector_store.upsert_message_chunks_batch(vector_items) + for guild_id, vector_items in vector_items_by_guild.items(): + await self._vector_store.upsert_message_chunks_batch(guild_id, vector_items) total_messages = sum(len(c.messages) for c in chunks) logger.info( @@ -237,12 +238,12 @@ async def store_chunks(self, chunks: list[Chunk]) -> int: return len(metas) - async def rebuild_vector_index(self, batch_size: int = EMBED_BATCH_SIZE) -> int: - total_chunks = await self._repository.count_message_chunk_metas() - await self._vector_store.recreate_message_chunks_collection() + async def rebuild_vector_index(self, guild_id: str, batch_size: int = EMBED_BATCH_SIZE) -> int: + total_chunks = await self._repository.count_message_chunk_metas(guild_id=guild_id) + await self._vector_store.recreate_message_chunks_collection(guild_id) if total_chunks == 0: - logger.info("No stored message chunks to reindex") + logger.info("No stored message chunks to reindex", guild_id=guild_id) return 0 indexed = 0 @@ -252,6 +253,7 @@ async def rebuild_vector_index(self, batch_size: int = EMBED_BATCH_SIZE) -> int: metas = await self._repository.get_message_chunk_metas_batch( offset=offset, limit=batch_size, + guild_id=guild_id, ) if not metas: break @@ -260,7 +262,7 @@ async def rebuild_vector_index(self, batch_size: int = EMBED_BATCH_SIZE) -> int: vector_items = [ (meta.id, embedding) for meta, embedding in zip(metas, embeddings, strict=True) ] - await self._vector_store.upsert_message_chunks_batch(vector_items) + await self._vector_store.upsert_message_chunks_batch(guild_id, vector_items) indexed += len(metas) offset += len(metas) @@ -268,11 +270,17 @@ async def rebuild_vector_index(self, batch_size: int = EMBED_BATCH_SIZE) -> int: if indexed == total_chunks or indexed % (batch_size * 10) == 0: logger.info( "Lore vector index rebuild progress", + guild_id=guild_id, indexed=indexed, total=total_chunks, ) - logger.info("Lore vector index rebuild complete", indexed=indexed, total=total_chunks) + logger.info( + "Lore vector index rebuild complete", + guild_id=guild_id, + indexed=indexed, + total=total_chunks, + ) return indexed async def ingest_channel( diff --git a/src/intelstream/services/page_analyzer.py b/src/intelstream/services/page_analyzer.py index 2fc5b8c..cd36349 100644 --- a/src/intelstream/services/page_analyzer.py +++ b/src/intelstream/services/page_analyzer.py @@ -121,9 +121,7 @@ async def analyze(self, url: str) -> ExtractionProfile: validation_result = self._validate_profile(html, profile) if not validation_result["valid"]: - raise PageAnalysisError( - f"Profile validation failed: {validation_result['reason']}" - ) + raise PageAnalysisError(f"Profile validation failed: {validation_result['reason']}") logger.info( "Page analysis complete", @@ -141,24 +139,16 @@ async def _fetch_html(self, url: str) -> str: try: if self._http_client: - response = await self._http_client.get( - url, headers=headers, follow_redirects=True - ) + response = await self._http_client.get(url, headers=headers, follow_redirects=True) else: - async with httpx.AsyncClient( - timeout=get_settings().http_timeout_seconds - ) as client: - response = await client.get( - url, headers=headers, follow_redirects=True - ) + async with httpx.AsyncClient(timeout=get_settings().http_timeout_seconds) as client: + response = await client.get(url, headers=headers, follow_redirects=True) response.raise_for_status() return response.text except httpx.HTTPStatusError as e: - raise PageAnalysisError( - f"Failed to fetch page: HTTP {e.response.status_code}" - ) from e + raise PageAnalysisError(f"Failed to fetch page: HTTP {e.response.status_code}") from e except httpx.RequestError as e: raise PageAnalysisError(f"Failed to fetch page: {e}") from e @@ -255,9 +245,7 @@ async def _extract_profile_with_llm(self, url: str, html: str) -> dict[str, Any] logger.error("Anthropic API error during page analysis", error=str(e)) raise PageAnalysisError(f"LLM API error: {e}") from e - def _validate_profile( - self, html: str, profile: ExtractionProfile - ) -> dict[str, Any]: + def _validate_profile(self, html: str, profile: ExtractionProfile) -> dict[str, Any]: soup = BeautifulSoup(html, "lxml") try: diff --git a/tests/test_config.py b/tests/test_config.py index 045af15..38c61ef 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -28,9 +28,7 @@ def test_settings_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None: assert settings.default_poll_interval_minutes == 5 assert settings.log_level == "INFO" - def test_settings_with_optional_youtube( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_settings_with_optional_youtube(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("DISCORD_BOT_TOKEN", "test_token") monkeypatch.setenv("DISCORD_GUILD_ID", "123456789") monkeypatch.setenv("DISCORD_CHANNEL_ID", "987654321") @@ -42,9 +40,7 @@ def test_settings_with_optional_youtube( assert settings.youtube_api_key == "yt-api-key" - def test_settings_poll_interval_bounds( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_settings_poll_interval_bounds(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("DISCORD_BOT_TOKEN", "test_token") monkeypatch.setenv("DISCORD_GUILD_ID", "123456789") monkeypatch.setenv("DISCORD_CHANNEL_ID", "987654321") @@ -103,9 +99,7 @@ def test_repr_handles_none_keys(self, monkeypatch: pytest.MonkeyPatch) -> None: assert "youtube_api_key=None" in repr_str assert "openai_api_key=None" in repr_str - def test_empty_discord_bot_token_rejected( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_empty_discord_bot_token_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("DISCORD_BOT_TOKEN", "") monkeypatch.setenv("DISCORD_GUILD_ID", "123456789") monkeypatch.setenv("DISCORD_OWNER_ID", "111222333") @@ -137,9 +131,7 @@ def test_llm_api_key_returns_correct_provider_key( settings = Settings(_env_file=None) assert settings.llm_api_key == "sk-openai-test" - def test_llm_api_key_raises_when_key_missing( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_llm_api_key_raises_when_key_missing(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("DISCORD_BOT_TOKEN", "test_token") monkeypatch.setenv("DISCORD_GUILD_ID", "123456789") monkeypatch.setenv("DISCORD_OWNER_ID", "111222333") @@ -149,9 +141,7 @@ def test_llm_api_key_raises_when_key_missing( with pytest.raises(ValidationError, match="No API key configured"): Settings(_env_file=None) - def test_invalid_llm_provider_rejected( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_invalid_llm_provider_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("DISCORD_BOT_TOKEN", "test_token") monkeypatch.setenv("DISCORD_GUILD_ID", "123456789") monkeypatch.setenv("DISCORD_OWNER_ID", "111222333") @@ -161,9 +151,7 @@ def test_invalid_llm_provider_rejected( with pytest.raises(ValidationError): Settings(_env_file=None) - def test_valid_llm_providers_accepted( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_valid_llm_providers_accepted(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("DISCORD_BOT_TOKEN", "test_token") monkeypatch.setenv("DISCORD_GUILD_ID", "123456789") monkeypatch.setenv("DISCORD_OWNER_ID", "111222333") @@ -181,9 +169,7 @@ def test_valid_llm_providers_accepted( assert settings.llm_api_key == key_val monkeypatch.delenv(key_env, raising=False) - def test_missing_api_key_fails_at_construction( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_missing_api_key_fails_at_construction(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("DISCORD_BOT_TOKEN", "test_token") monkeypatch.setenv("DISCORD_GUILD_ID", "123456789") monkeypatch.setenv("DISCORD_OWNER_ID", "111222333") @@ -261,9 +247,7 @@ def test_explicit_model_overrides_provider_default( assert settings.summary_model == "my-custom-model" assert settings.summary_model_interactive == "my-custom-interactive" - def test_partial_override_uses_default_for_unset( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_partial_override_uses_default_for_unset(self, monkeypatch: pytest.MonkeyPatch) -> None: self._base_env(monkeypatch) monkeypatch.setenv("LLM_PROVIDER", "openai") monkeypatch.setenv("OPENAI_API_KEY", "sk-openai-test") @@ -301,9 +285,7 @@ def test_falls_back_to_default(self, monkeypatch: pytest.MonkeyPatch) -> None: assert settings.get_poll_interval(SourceType.YOUTUBE) == 10 assert settings.get_poll_interval(SourceType.RSS) == 10 - def test_type_specific_overrides_default( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_type_specific_overrides_default(self, monkeypatch: pytest.MonkeyPatch) -> None: self._base_env(monkeypatch) monkeypatch.setenv("DEFAULT_POLL_INTERVAL_MINUTES", "5") monkeypatch.setenv("TWITTER_POLL_INTERVAL_MINUTES", "20") @@ -329,9 +311,7 @@ def test_returns_parent_directory_for_sqlite_file(self) -> None: assert result == Path("./data") def test_returns_parent_for_absolute_path(self) -> None: - result = get_database_directory( - "sqlite+aiosqlite:////home/user/data/intelstream.db" - ) + result = get_database_directory("sqlite+aiosqlite:////home/user/data/intelstream.db") assert result == Path("/home/user/data") def test_returns_none_for_memory_database(self) -> None: diff --git a/tests/test_discord/test_channel_summary.py b/tests/test_discord/test_channel_summary.py index 481351f..82453f0 100644 --- a/tests/test_discord/test_channel_summary.py +++ b/tests/test_discord/test_channel_summary.py @@ -170,9 +170,7 @@ async def test_summary_with_different_channel(self, cog, mock_interaction): messages = [_make_message(f"msg {i}", f"user{i}") for i in range(6)] target_channel.history = MagicMock(return_value=_async_iter(messages)) - await cog.summary.callback( - cog, mock_interaction, count=200, channel=target_channel - ) + await cog.summary.callback(cog, mock_interaction, count=200, channel=target_channel) target_channel.history.assert_called_once() sent_text = mock_interaction.followup.send.call_args.args[0] @@ -183,19 +181,14 @@ async def test_summary_handles_summarization_failure(self, cog, mock_interaction channel = mock_interaction.channel channel.history = MagicMock(return_value=_async_iter(messages)) - cog._summarizer.summarize_chat = AsyncMock( - side_effect=SummarizationError("API error") - ) + cog._summarizer.summarize_chat = AsyncMock(side_effect=SummarizationError("API error")) await cog.summary.callback(cog, mock_interaction, count=200, channel=None) mock_interaction.followup.send.assert_called_once() sent_kwargs = mock_interaction.followup.send.call_args.kwargs assert sent_kwargs.get("ephemeral") is True - assert ( - "Failed to generate summary" - in mock_interaction.followup.send.call_args.args[0] - ) + assert "Failed to generate summary" in mock_interaction.followup.send.call_args.args[0] async def test_summary_filters_empty_messages(self, cog, mock_interaction): messages = [ diff --git a/tests/test_discord/test_content_posting.py b/tests/test_discord/test_content_posting.py index 2bdc89c..f2eb8fd 100644 --- a/tests/test_discord/test_content_posting.py +++ b/tests/test_discord/test_content_posting.py @@ -159,9 +159,7 @@ async def test_cog_unload_closes_summarizer(self, _patch_cog_deps, mock_bot): class TestContentLoop: - async def test_content_loop_skips_when_not_initialized( - self, _patch_cog_deps, mock_bot - ): + async def test_content_loop_skips_when_not_initialized(self, _patch_cog_deps, mock_bot): cog = ContentPosting(mock_bot) cog._initialized = False @@ -202,9 +200,7 @@ async def test_content_loop_posts_to_all_guilds(self, _patch_cog_deps, mock_bot) deps["poster"].post_unposted_items.assert_any_call(111) deps["poster"].post_unposted_items.assert_any_call(222) - async def test_content_loop_notifies_owner_on_error( - self, _patch_cog_deps, mock_bot - ): + async def test_content_loop_notifies_owner_on_error(self, _patch_cog_deps, mock_bot): deps = _patch_cog_deps deps["pipeline"].run_cycle = AsyncMock(side_effect=Exception("Test error")) @@ -217,14 +213,10 @@ async def test_content_loop_notifies_owner_on_error( call_args = mock_bot.notify_owner.call_args[0][0] assert "Test error" in call_args - async def test_content_loop_continues_on_guild_error( - self, _patch_cog_deps, mock_bot - ): + async def test_content_loop_continues_on_guild_error(self, _patch_cog_deps, mock_bot): deps = _patch_cog_deps deps["pipeline"].run_cycle = AsyncMock(return_value=(5, 3)) - deps["poster"].post_unposted_items = AsyncMock( - side_effect=[Exception("Guild 1 error"), 2] - ) + deps["poster"].post_unposted_items = AsyncMock(side_effect=[Exception("Guild 1 error"), 2]) guild1 = MagicMock(spec=discord.Guild) guild1.id = 111 @@ -245,9 +237,7 @@ async def test_content_loop_continues_on_guild_error( class TestContentLoopErrorHandler: - async def test_error_handler_notifies_owner_on_first_error( - self, _patch_cog_deps, mock_bot - ): + async def test_error_handler_notifies_owner_on_first_error(self, _patch_cog_deps, mock_bot): cog = ContentPosting(mock_bot) await cog.cog_load() @@ -272,9 +262,7 @@ async def test_error_handler_does_not_notify_owner_on_subsequent_errors( class TestContentLoopBackoff: - async def test_backoff_increments_consecutive_failures( - self, _patch_cog_deps, mock_bot - ): + async def test_backoff_increments_consecutive_failures(self, _patch_cog_deps, mock_bot): deps = _patch_cog_deps deps["pipeline"].run_cycle = AsyncMock(side_effect=Exception("Test error")) @@ -299,9 +287,7 @@ async def test_backoff_resets_on_success(self, _patch_cog_deps, mock_bot): assert cog._consecutive_failures == 0 - async def test_circuit_breaker_notifies_and_retries_hourly( - self, _patch_cog_deps, mock_bot - ): + async def test_circuit_breaker_notifies_and_retries_hourly(self, _patch_cog_deps, mock_bot): deps = _patch_cog_deps deps["pipeline"].run_cycle = AsyncMock(side_effect=Exception("Still failing")) @@ -330,9 +316,7 @@ async def test_circuit_breaker_recovers_on_success(self, _patch_cog_deps, mock_b assert cog._consecutive_failures == 0 assert cog.content_loop.minutes == cog._base_interval - async def test_apply_backoff_keeps_base_on_first_failure( - self, _patch_cog_deps, mock_bot - ): + async def test_apply_backoff_keeps_base_on_first_failure(self, _patch_cog_deps, mock_bot): cog = ContentPosting(mock_bot) await cog.cog_load() cog._consecutive_failures = 1 @@ -341,9 +325,7 @@ async def test_apply_backoff_keeps_base_on_first_failure( assert cog.content_loop.minutes == cog._base_interval - async def test_apply_backoff_doubles_on_second_failure( - self, _patch_cog_deps, mock_bot - ): + async def test_apply_backoff_doubles_on_second_failure(self, _patch_cog_deps, mock_bot): cog = ContentPosting(mock_bot) await cog.cog_load() cog._consecutive_failures = 2 @@ -352,9 +334,7 @@ async def test_apply_backoff_doubles_on_second_failure( assert cog.content_loop.minutes == cog._base_interval * 2 - async def test_apply_backoff_caps_at_max_multiplier( - self, _patch_cog_deps, mock_bot - ): + async def test_apply_backoff_caps_at_max_multiplier(self, _patch_cog_deps, mock_bot): cog = ContentPosting(mock_bot) await cog.cog_load() cog._consecutive_failures = 4 @@ -364,9 +344,7 @@ async def test_apply_backoff_caps_at_max_multiplier( max_interval = cog._base_interval * ContentPosting.MAX_BACKOFF_MULTIPLIER assert cog.content_loop.minutes == max_interval - async def test_reset_backoff_restores_base_interval( - self, _patch_cog_deps, mock_bot - ): + async def test_reset_backoff_restores_base_interval(self, _patch_cog_deps, mock_bot): cog = ContentPosting(mock_bot) await cog.cog_load() cog._consecutive_failures = 3 @@ -377,9 +355,7 @@ async def test_reset_backoff_restores_base_interval( assert cog._consecutive_failures == 0 assert cog.content_loop.minutes == cog._base_interval - async def test_only_notifies_owner_on_first_failure( - self, _patch_cog_deps, mock_bot - ): + async def test_only_notifies_owner_on_first_failure(self, _patch_cog_deps, mock_bot): deps = _patch_cog_deps deps["pipeline"].run_cycle = AsyncMock(side_effect=Exception("Test error")) diff --git a/tests/test_discord/test_lore.py b/tests/test_discord/test_lore.py index b8bce88..62a166d 100644 --- a/tests/test_discord/test_lore.py +++ b/tests/test_discord/test_lore.py @@ -21,6 +21,7 @@ def mock_bot(): bot.settings.summary_model_interactive = "claude-test" bot.repository = AsyncMock() bot.repository.count_message_chunk_metas = AsyncMock(return_value=0) + bot.repository.get_message_chunk_guild_ids = AsyncMock(return_value=[]) bot.repository.get_message_chunk_metas_batch = AsyncMock(return_value=[]) bot.get_guild = MagicMock(return_value=None) bot.guilds = [] @@ -173,14 +174,14 @@ async def test_message_index_healthy(self, lore_cog, mock_bot, mock_vector_store ChunkSearchResult(chunk_id="chunk-1", score=1.0) ] - result = await lore_cog._message_index_is_healthy(expected_count=1) + result = await lore_cog._message_index_is_healthy("guild-1", expected_count=1) assert result is True async def test_message_index_unhealthy_on_count_mismatch(self, lore_cog, mock_vector_store): mock_vector_store.message_chunk_doc_count.return_value = 0 - result = await lore_cog._message_index_is_healthy(expected_count=2) + result = await lore_cog._message_index_is_healthy("guild-1", expected_count=2) assert result is False mock_vector_store.search_message_chunks.assert_not_called() @@ -190,12 +191,13 @@ async def test_ensure_message_chunk_index_rebuilds_unhealthy_index( ): lore_cog._ingestion_service = MagicMock() lore_cog._ingestion_service.rebuild_vector_index = AsyncMock(return_value=3) + mock_bot.repository.get_message_chunk_guild_ids.return_value = ["guild-1"] mock_bot.repository.count_message_chunk_metas.return_value = 3 mock_vector_store.message_chunk_doc_count.return_value = 0 await lore_cog._ensure_message_chunk_index() - lore_cog._ingestion_service.rebuild_vector_index.assert_awaited_once() + lore_cog._ingestion_service.rebuild_vector_index.assert_awaited_once_with("guild-1") async def test_command_mentions_rebuild_in_progress(self, lore_cog, mock_interaction): lore_cog._index_rebuild_task = asyncio.create_task(asyncio.sleep(0.1)) @@ -253,14 +255,17 @@ async def test_skips_if_already_running(self, lore_cog): lore_cog._ingestion_service.start_backfill.assert_not_called() async def test_auto_start_uses_first_guild(self, lore_cog, mock_bot): - guild = MagicMock(spec=discord.Guild) - guild.id = 111 - guild.name = "Test Server" - mock_bot.guilds = [guild] + guild1 = MagicMock(spec=discord.Guild) + guild1.id = 111 + guild1.name = "Test Server 1" + guild2 = MagicMock(spec=discord.Guild) + guild2.id = 222 + guild2.name = "Test Server 2" + mock_bot.guilds = [guild1, guild2] mock_bot.repository.get_ingestion_progress_for_guild.return_value = [] await lore_cog.auto_start_ingestion() - lore_cog._ingestion_service.start_backfill.assert_called_once_with(guild) + lore_cog._ingestion_service.start_backfill.assert_called_once_with(guild1) async def test_auto_start_no_guilds(self, lore_cog, mock_bot): mock_bot.guilds = [] diff --git a/tests/test_discord/test_summarize.py b/tests/test_discord/test_summarize.py index 2b19d70..213c62a 100644 --- a/tests/test_discord/test_summarize.py +++ b/tests/test_discord/test_summarize.py @@ -45,42 +45,23 @@ def mock_interaction(): class TestDetectUrlType: def test_detect_youtube_com(self, summarize_cog): - assert ( - summarize_cog.detect_url_type("https://www.youtube.com/watch?v=abc123") - == "youtube" - ) - assert ( - summarize_cog.detect_url_type("https://youtube.com/watch?v=abc123") - == "youtube" - ) + assert summarize_cog.detect_url_type("https://www.youtube.com/watch?v=abc123") == "youtube" + assert summarize_cog.detect_url_type("https://youtube.com/watch?v=abc123") == "youtube" def test_detect_youtu_be(self, summarize_cog): assert summarize_cog.detect_url_type("https://youtu.be/abc123") == "youtube" def test_detect_substack(self, summarize_cog): - assert ( - summarize_cog.detect_url_type("https://example.substack.com/p/article") - == "substack" - ) - assert ( - summarize_cog.detect_url_type("https://newsletter.substack.com/p/post") - == "substack" - ) + assert summarize_cog.detect_url_type("https://example.substack.com/p/article") == "substack" + assert summarize_cog.detect_url_type("https://newsletter.substack.com/p/post") == "substack" def test_detect_twitter(self, summarize_cog): - assert ( - summarize_cog.detect_url_type("https://twitter.com/user/status/123") - == "twitter" - ) - assert ( - summarize_cog.detect_url_type("https://x.com/user/status/123") == "twitter" - ) + assert summarize_cog.detect_url_type("https://twitter.com/user/status/123") == "twitter" + assert summarize_cog.detect_url_type("https://x.com/user/status/123") == "twitter" def test_detect_generic_web(self, summarize_cog): assert summarize_cog.detect_url_type("https://example.com/article") == "web" - assert ( - summarize_cog.detect_url_type("https://nytimes.com/2024/article") == "web" - ) + assert summarize_cog.detect_url_type("https://nytimes.com/2024/article") == "web" assert summarize_cog.detect_url_type("https://blog.example.org/post") == "web" @@ -222,9 +203,7 @@ def test_sets_image_when_thumbnail_provided(self, summarize_cog): class TestSummarizeCommand: async def test_rejects_invalid_url(self, summarize_cog, mock_interaction): - await summarize_cog.summarize.callback( - summarize_cog, mock_interaction, "not-a-url" - ) + await summarize_cog.summarize.callback(summarize_cog, mock_interaction, "not-a-url") mock_interaction.followup.send.assert_called_once() call_args = mock_interaction.followup.send.call_args @@ -372,9 +351,7 @@ async def test_handles_summarization_error(self, summarize_cog, mock_interaction content="This is enough content for summarization. " * 10, ) - summarize_cog._summarizer.summarize = AsyncMock( - side_effect=Exception("API Error") - ) + summarize_cog._summarizer.summarize = AsyncMock(side_effect=Exception("API Error")) with patch.object( summarize_cog, "_fetch_web_content", AsyncMock(return_value=mock_content) @@ -414,9 +391,7 @@ async def test_cog_unload_closes_http_client(self, mock_bot): class TestSummarizeCooldown: - async def test_cooldown_error_sends_retry_message( - self, summarize_cog, mock_interaction - ): + async def test_cooldown_error_sends_retry_message(self, summarize_cog, mock_interaction): from discord import app_commands mock_interaction.response.send_message = AsyncMock() @@ -448,9 +423,7 @@ async def test_cooldown_error_shows_seconds_only_for_short_wait( assert "45s" in call_args[0][0] assert "m " not in call_args[0][0] - async def test_non_cooldown_error_is_reraised( - self, summarize_cog, mock_interaction - ): + async def test_non_cooldown_error_is_reraised(self, summarize_cog, mock_interaction): from discord import app_commands error = app_commands.MissingPermissions(["manage_guild"]) diff --git a/tests/test_services/test_message_ingestion.py b/tests/test_services/test_message_ingestion.py index 102fc57..a29595f 100644 --- a/tests/test_services/test_message_ingestion.py +++ b/tests/test_services/test_message_ingestion.py @@ -295,6 +295,7 @@ async def test_store_chunks_single(self, service, mock_deps): embedding_service.embed_batch.assert_called_once() repository.add_message_chunk_metas_batch.assert_called_once() vector_store.upsert_message_chunks_batch.assert_called_once() + assert vector_store.upsert_message_chunks_batch.call_args.args[0] == "111" metas = repository.add_message_chunk_metas_batch.call_args[0][0] assert len(metas) == 1 @@ -323,23 +324,32 @@ async def test_rebuild_vector_index(self, service, mock_deps): ] ) - result = await service.rebuild_vector_index(batch_size=2) + result = await service.rebuild_vector_index("guild-1", batch_size=2) assert result == 3 - vector_store.recreate_message_chunks_collection.assert_called_once() + vector_store.recreate_message_chunks_collection.assert_called_once_with("guild-1") assert vector_store.upsert_message_chunks_batch.await_count == 2 - repository.get_message_chunk_metas_batch.assert_any_call(offset=0, limit=2) - repository.get_message_chunk_metas_batch.assert_any_call(offset=2, limit=2) + repository.count_message_chunk_metas.assert_called_once_with(guild_id="guild-1") + repository.get_message_chunk_metas_batch.assert_any_call( + offset=0, + limit=2, + guild_id="guild-1", + ) + repository.get_message_chunk_metas_batch.assert_any_call( + offset=2, + limit=2, + guild_id="guild-1", + ) async def test_rebuild_vector_index_empty(self, service, mock_deps): repository, embedding_service, vector_store = mock_deps repository.count_message_chunk_metas = AsyncMock(return_value=0) repository.get_message_chunk_metas_batch = AsyncMock(return_value=[]) - result = await service.rebuild_vector_index() + result = await service.rebuild_vector_index("guild-1") assert result == 0 - vector_store.recreate_message_chunks_collection.assert_called_once() + vector_store.recreate_message_chunks_collection.assert_called_once_with("guild-1") embedding_service.embed_batch.assert_not_called() vector_store.upsert_message_chunks_batch.assert_not_called() diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index 77fb310..339147c 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -72,12 +72,30 @@ async def test_search_empty_collection(self, vector_store): assert results == [] async def test_message_chunk_doc_count(self, vector_store): - assert await vector_store.message_chunk_doc_count() == 0 + assert await vector_store.message_chunk_doc_count("guild-1") == 0 - await vector_store.upsert_message_chunk("chunk-1", [1.0, 0.0, 0.0, 0.0]) - await vector_store.upsert_message_chunk("chunk-2", [0.0, 1.0, 0.0, 0.0]) + await vector_store.upsert_message_chunk("guild-1", "chunk-1", [1.0, 0.0, 0.0, 0.0]) + await vector_store.upsert_message_chunk("guild-1", "chunk-2", [0.0, 1.0, 0.0, 0.0]) - assert await vector_store.message_chunk_doc_count() == 2 + assert await vector_store.message_chunk_doc_count("guild-1") == 2 + + async def test_message_chunk_collections_are_guild_scoped(self, vector_store): + await vector_store.upsert_message_chunk("guild-1", "chunk-1", [1.0, 0.0, 0.0, 0.0]) + await vector_store.upsert_message_chunk("guild-2", "chunk-2", [0.0, 1.0, 0.0, 0.0]) + + guild_1_results = await vector_store.search_message_chunks( + "guild-1", + [1.0, 0.0, 0.0, 0.0], + topk=5, + ) + guild_2_results = await vector_store.search_message_chunks( + "guild-2", + [1.0, 0.0, 0.0, 0.0], + topk=5, + ) + + assert [result.chunk_id for result in guild_1_results] == ["chunk-1"] + assert [result.chunk_id for result in guild_2_results] == ["chunk-2"] class TestUpsertBatch: @@ -108,13 +126,13 @@ async def test_delete_article(self, vector_store): class TestRecreateCollections: async def test_recreate_message_chunks_collection(self, vector_store): - await vector_store.upsert_message_chunk("chunk-1", [1.0, 0.0, 0.0, 0.0]) - assert await vector_store.message_chunk_doc_count() == 1 + await vector_store.upsert_message_chunk("guild-1", "chunk-1", [1.0, 0.0, 0.0, 0.0]) + assert await vector_store.message_chunk_doc_count("guild-1") == 1 - await vector_store.recreate_message_chunks_collection() + await vector_store.recreate_message_chunks_collection("guild-1") - assert await vector_store.message_chunk_doc_count() == 0 - results = await vector_store.search_message_chunks([1.0, 0.0, 0.0, 0.0], topk=1) + assert await vector_store.message_chunk_doc_count("guild-1") == 0 + results = await vector_store.search_message_chunks("guild-1", [1.0, 0.0, 0.0, 0.0], topk=1) assert results == []