diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml
index ba29f85f..66326b9f 100644
--- a/lightspeed-stack.yaml
+++ b/lightspeed-stack.yaml
@@ -30,3 +30,7 @@ conversation_cache:
 
 authentication:
   module: "noop"
+
+
+solr:
+  offline: True
\ No newline at end of file
diff --git a/run.yaml b/run.yaml
index 3680f2b3..e1446f1a 100644
--- a/run.yaml
+++ b/run.yaml
@@ -15,7 +15,7 @@ apis:
 benchmarks: []
 datasets: []
 image_name: starter
-# external_providers_dir: /opt/app-root/src/.llama/providers.d
+external_providers_dir: ${env.EXTERNAL_PROVIDERS_DIR}
 
 providers:
   inference:
@@ -24,7 +24,9 @@ providers:
     config:
       api_key: ${env.OPENAI_API_KEY}
       allowed_models: ["${env.E2E_OPENAI_MODEL:=gpt-4o-mini}"]
-  - config: {}
+  - config:
+      allowed_models:
+        - ${env.EMBEDDING_MODEL_DIR}
     provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   files:
@@ -56,6 +58,18 @@ providers:
     provider_id: rag-runtime
     provider_type: inline::rag-runtime
   vector_io:
+  - provider_id: solr-vector
+    provider_type: remote::solr_vector_io
+    config:
+      solr_url: http://localhost:8983/solr
+      collection_name: portal-rag
+      vector_field: chunk_vector
+      content_field: chunk
+      embedding_dimension: 384
+      embedding_model: ${env.EMBEDDING_MODEL_DIR}
+      persistence:
+        namespace: portal-rag
+        backend: kv_default
   - config:
       # Define the storage backend for RAG
      persistence:
        namespace: vector_io::faiss
@@ -127,7 +141,13 @@ storage:
     namespace: prompts
     backend: kv_default
 registered_resources:
-  models: []
+  models:
+  - model_id: granite-embedding-30m
+    model_type: embedding
+    provider_id: sentence-transformers
+    provider_model_id: ${env.EMBEDDING_MODEL_DIR}
+    metadata:
+      embedding_dimension: 384
   shields:
   - shield_id: llama-guard
     provider_id: llama-guard
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index ae8a9071..a23a93e0 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -2,8 +2,8 @@
 """Handler for REST API call to provide answer to query using Response API."""
 
+import datetime
 import logging
-from datetime import UTC, datetime
 from typing import Annotated, Any, cast
 
 from fastapi import APIRouter, Depends, HTTPException, Request
@@ -32,6 +32,7 @@
     PromptTooLongResponse,
     QueryResponse,
     QuotaExceededResponse,
+    ReferencedDocument,
     ServiceUnavailableResponse,
     UnauthorizedResponse,
     UnprocessableEntityResponse,
@@ -63,7 +63,8 @@
     run_shield_moderation,
 )
 from utils.suid import normalize_conversation_id
 from utils.types import ResponsesApiParams, TurnSummary
+from utils.vector_search import format_rag_context_for_injection, perform_vector_search
 
 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["query"])
@@ -77,9 +77,9 @@
         examples=["endpoint", "conversation read", "model override"]
     ),
     404: NotFoundResponse.openapi_response(
-        examples=["model", "conversation", "provider"]
+        examples=["conversation", "model", "provider"]
     ),
-    413: PromptTooLongResponse.openapi_response(),
+    # 413: PromptTooLongResponse.openapi_response(),
     422: UnprocessableEntityResponse.openapi_response(),
     429: QuotaExceededResponse.openapi_response(),
     500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
@@ -145,6 +147,17 @@ async def query_endpoint_handler(
     client = AsyncLlamaStackClientHolder().get_client()
 
+    _, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
+        client, query_request, configuration
+    )
+
+    rag_context = format_rag_context_for_injection(pre_rag_chunks)
+    if rag_context:
+        # Mutate a deep copy so the caller's QueryRequest is left untouched
+        # (model_copy is the Pydantic v2 API).
+        query_request = query_request.model_copy(deep=True)
+        query_request.query = query_request.query + rag_context
+
     # Prepare API request parameters
     responses_params = await prepare_responses_params(
         client,
@@ -168,6 +181,14 @@ async def query_endpoint_handler(
     # Retrieve response using Responses API
     turn_summary = await retrieve_response(client, responses_params)
 
+    if pre_rag_chunks:
+        turn_summary.rag_chunks = pre_rag_chunks + (turn_summary.rag_chunks or [])
+
+    if doc_ids_from_chunks:
+        turn_summary.referenced_documents = parse_referenced_docs(
+            doc_ids_from_chunks + (turn_summary.referenced_documents or [])
+        )
+
     # Get topic summary for new conversation
     if not user_conversation and query_request.generate_topic_summary:
         logger.debug("Generating topic summary for new conversation")
@@ -190,7 +211,9 @@ async def query_endpoint_handler(
         quota_limiters=configuration.quota_limiters, user_id=user_id
     )
 
-    completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
+    completed_at = datetime.datetime.now(datetime.timezone.utc).strftime(
+        "%Y-%m-%dT%H:%M:%SZ"
+    )
 
     conversation_id = normalize_conversation_id(responses_params.conversation)
 
     logger.info("Storing query results")
@@ -222,6 +245,21 @@ async def query_endpoint_handler(
     )
 
 
+def parse_referenced_docs(
+    docs: list[ReferencedDocument],
+) -> list[ReferencedDocument]:
+    """Deduplicate referenced documents by (doc_url, doc_title), preserving order."""
+    seen: set[tuple[str | None, str | None]] = set()
+    out: list[ReferencedDocument] = []
+    for d in docs:
+        key = (d.doc_url, d.doc_title)
+        if key in seen:
+            continue
+        seen.add(key)
+        out.append(d)
+    return out
+
+
 async def retrieve_response(  # pylint: disable=too-many-locals
     client: AsyncLlamaStackClient,
     responses_params: ResponsesApiParams,
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 047278b1..aa1eaf7c 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -34,7 +34,6 @@
     LLM_TOOL_CALL_EVENT,
     LLM_TOOL_RESULT_EVENT,
     LLM_TURN_COMPLETE_EVENT,
-    MEDIA_TYPE_EVENT_STREAM,
     MEDIA_TYPE_JSON,
     MEDIA_TYPE_TEXT,
 )
@@ -85,6 +84,7 @@
 from utils.suid import normalize_conversation_id
 from utils.token_counter import TokenCounter
 from utils.types import ResponsesApiParams, TurnSummary
+from utils.vector_search import format_rag_context_for_injection, perform_vector_search
 
 logger = logging.getLogger(__name__)
 router = APIRouter(tags=["streaming_query"])
@@ -100,7 +100,7 @@
     404: NotFoundResponse.openapi_response(
         examples=["conversation", "model", "provider"]
     ),
-    413: PromptTooLongResponse.openapi_response(),
+    # 413: PromptTooLongResponse.openapi_response(),
     422: UnprocessableEntityResponse.openapi_response(),
     429: QuotaExceededResponse.openapi_response(),
     500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
@@ -172,6 +172,15 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     client = AsyncLlamaStackClientHolder().get_client()
 
+    _, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
+        client, query_request, configuration
+    )
+    rag_context = format_rag_context_for_injection(pre_rag_chunks)
+    if rag_context:
+        # Mutate a deep copy so the caller's QueryRequest is left untouched.
+        query_request = query_request.model_copy(deep=True)
+        query_request.query = query_request.query + rag_context
+
     # Prepare API request parameters
     responses_params = await prepare_responses_params(
         client=client,
@@ -212,6 +221,7 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     generator, turn_summary = await retrieve_response_generator(
         responses_params=responses_params,
         context=context,
+        doc_ids_from_chunks=doc_ids_from_chunks,
     )
 
     response_media_type = (
@@ -227,13 +237,14 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
             responses_params=responses_params,
             turn_summary=turn_summary,
         ),
-        media_type=response_media_type,
+        media_type=query_request.media_type or MEDIA_TYPE_TEXT,
    )
 
 
 async def retrieve_response_generator(
     responses_params: ResponsesApiParams,
     context: ResponseGeneratorContext,
+    doc_ids_from_chunks: list[ReferencedDocument],
 ) -> tuple[AsyncIterator[str], TurnSummary]:
     """
     Retrieve the appropriate response generator.
@@ -273,6 +284,8 @@ async def retrieve_response_generator(
         response = await context.client.responses.create(
             **responses_params.model_dump()
         )
+        # Store pre-RAG documents for later merging in response_generator
+        turn_summary.pre_rag_documents = doc_ids_from_chunks
         return response_generator(response, context, turn_summary), turn_summary
 
     # Handle know LLS client errors only at stream creation time and shield execution
@@ -571,9 +584,26 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-statements
                 turn_summary.token_usage = extract_token_usage(
                     latest_response_object, context.model_id
                 )
-                turn_summary.referenced_documents = parse_referenced_documents(
-                    latest_response_object
-                )
+                tool_based_documents = parse_referenced_documents(
+                    latest_response_object
+                )
+
+                # Merge pre-RAG documents with tool-based documents and
+                # deduplicate by (doc_url, doc_title), mirroring query.py.
+                if turn_summary.pre_rag_documents:
+                    all_documents = (
+                        turn_summary.pre_rag_documents + tool_based_documents
+                    )
+                    seen = set()
+                    deduplicated_documents = []
+                    for doc in all_documents:
+                        key = (doc.doc_url, doc.doc_title)
+                        if key not in seen:
+                            seen.add(key)
+                            deduplicated_documents.append(doc)
+                    turn_summary.referenced_documents = deduplicated_documents
+                else:
+                    turn_summary.referenced_documents = tool_based_documents
 
 
 def stream_http_error_event(
@@ -694,22 +724,25 @@
 def stream_event(data: dict, event_type: str, media_type: str) -> str:
     """Build an item to yield based on media type.
 
     Args:
-        data: The data to yield.
-        event_type: The type of event (e.g. token, tool request, tool execution).
-        media_type: Media type of the response (e.g. text or JSON).
+        data: Dictionary containing the event data.
+        event_type: Type of event (token, tool call, etc.).
+        media_type: Media type of the response (e.g. text or JSON).
 
     Returns:
-        str: The formatted string or JSON to yield.
+        str: The plain-text string or SSE JSON string to yield.
     """
     if media_type == MEDIA_TYPE_TEXT:
         if event_type == LLM_TOKEN_EVENT:
-            return data["token"]
+            return data.get("token", "")
         if event_type == LLM_TOOL_CALL_EVENT:
-            return f"\nTool call: {json.dumps(data)}\n"
+            return f"[Tool Call: {data.get('function_name', 'unknown')}]\n"
         if event_type == LLM_TOOL_RESULT_EVENT:
-            return f"\nTool result: {json.dumps(data)}\n"
-        logger.error("Unknown event type: %s", event_type)
+            return "[Tool Result]\n"
+        if event_type == LLM_TURN_COMPLETE_EVENT:
+            return ""
+        logger.error("Unknown event type: %s", event_type)
         return ""
+
     return format_stream_data(
         {
             "event": event_type,
diff --git a/src/app/main.py b/src/app/main.py
index e4ee8390..0bd5ff3a 100644
--- a/src/app/main.py
+++ b/src/app/main.py
@@ -23,6 +23,13 @@
 from utils.common import register_mcp_servers_async
 from utils.llama_stack_version import check_llama_stack_version
 
+import faulthandler
+import signal
+
+# Dump the tracebacks of all threads to stderr when the process receives
+# SIGUSR1, so a hung service can be diagnosed without restarting it.
+faulthandler.register(signal.SIGUSR1)
+
 logger = get_logger(__name__)
 logger.info("Initializing app")
diff --git a/src/configuration.py b/src/configuration.py
index 9a253ac7..46e2d76d 100644
--- a/src/configuration.py
+++ b/src/configuration.py
@@ -23,6 +23,7 @@
     DatabaseConfiguration,
     ConversationHistoryConfiguration,
     QuotaHandlersConfiguration,
+    SolrConfiguration,
     SplunkConfiguration,
 )
@@ -363,5 +364,12 @@ def deployment_environment(self) -> str:
             raise LogicError("logic error: configuration is not loaded")
         return self._configuration.deployment_environment
 
+    @property
+    def solr(self) -> Optional[SolrConfiguration]:
+        """Return Solr configuration, or None if not provided."""
+        if self._configuration is None:
+            raise LogicError("logic error: configuration is not loaded")
+        return self._configuration.solr
+
 
 configuration: AppConfig = AppConfig()
diff --git a/src/constants.py b/src/constants.py
index 628aeb45..6b43a2ec 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -127,7 +127,7 @@
 MCP_AUTH_CLIENT = "client"
 
 # default RAG tool value
-DEFAULT_RAG_TOOL = "knowledge_search"
+DEFAULT_RAG_TOOL = "file_search"
 
 # Media type constants for streaming responses
 MEDIA_TYPE_JSON = "application/json"
@@ -168,3 +168,11 @@
 # quota limiters constants
 USER_QUOTA_LIMITER = "user_limiter"
 CLUSTER_QUOTA_LIMITER = "cluster_limiter"
+
+# Vector search constants
+VECTOR_SEARCH_DEFAULT_K = 5
+VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD = 0.0
+VECTOR_SEARCH_DEFAULT_MODE = "hybrid"
+
+# Solr OKP RAG: base URL for building document links from chunk metadata
+MIMIR_DOC_URL = "https://mimir.corp.redhat.com"
diff --git a/src/models/config.py b/src/models/config.py
index 68ad168e..f6a649a2 100644
--- a/src/models/config.py
+++ b/src/models/config.py
@@ -1654,6 +1654,21 @@ class QuotaHandlersConfiguration(ConfigurationBase):
     )
 
 
+class SolrConfiguration(ConfigurationBase):
+    """Solr configuration for vector search queries.
+
+    Controls whether to use offline or online mode when building document URLs
+    from vector search results.
+    """
+
+    offline: bool = Field(
+        True,
+        title="Offline mode",
+        description="When True, use parent_id for chunk source URLs. "
" + "When False, use reference_url for chunk source URLs.", + ) + + class AzureEntraIdConfiguration(ConfigurationBase): """Microsoft Entra ID authentication attributes for Azure.""" @@ -1792,6 +1807,12 @@ class Configuration(ConfigurationBase): "Used in telemetry events.", ) + solr: Optional[SolrConfiguration] = Field( + default=None, + title="Solr configuration", + description="Configuration for Solr vector search operations.", + ) + @model_validator(mode="after") def validate_mcp_auth_headers(self) -> Self: """ diff --git a/src/models/requests.py b/src/models/requests.py index 38bdc9ac..4448940a 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -1,7 +1,7 @@ """Models for REST API requests.""" +from typing import Optional, Self, Any from enum import Enum -from typing import Optional, Self from pydantic import BaseModel, Field, field_validator, model_validator @@ -166,6 +166,13 @@ class QueryRequest(BaseModel): examples=["ocp_docs", "knowledge_base", "vector_db_1"], ) + solr: Optional[dict[str, Any]] = Field( + None, + description="Solr-specific query parameters including filter queries", + examples=[ + {"fq": ["product:*openshift*", "product_version:*4.16*"]}, + ], + ) # provides examples for /docs endpoint model_config = { "extra": "forbid", diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index de800ca0..332002ee 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -349,7 +349,7 @@ def _process_rag_chunks_for_documents( for chunk in rag_chunks: src = chunk.source - if not src or src == constants.DEFAULT_RAG_TOOL: + if not src or src == constants.DEFAULT_RAG_TOOL or src.endswith("_search"): continue if src.startswith("http"): diff --git a/src/utils/types.py b/src/utils/types.py index 202754a1..92421fe2 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -183,4 +183,5 @@ class TurnSummary(BaseModel): tool_results: list[ToolResultSummary] = Field(default_factory=list) rag_chunks: list[RAGChunk] = Field(default_factory=list) referenced_documents: list[ReferencedDocument] = Field(default_factory=list) + pre_rag_documents: list[ReferencedDocument] = Field(default_factory=list) token_usage: TokenCounter = Field(default_factory=TokenCounter) diff --git a/src/utils/vector_search.py b/src/utils/vector_search.py new file mode 100644 index 00000000..682e513b --- /dev/null +++ b/src/utils/vector_search.py @@ -0,0 +1,271 @@ +"""Vector search utilities for query endpoints. + +This module contains common functionality for performing vector searches +and processing RAG chunks that is shared between query_v2.py and streaming_query_v2.py. +""" + +import logging +import traceback +from typing import Any, Optional +from urllib.parse import urljoin + +from llama_stack_client import AsyncLlamaStackClient + +import constants +from configuration import AppConfig +from models.requests import QueryRequest +from models.responses import ReferencedDocument +from utils.types import RAGChunk + +logger = logging.getLogger(__name__) + + +async def perform_vector_search( + client: AsyncLlamaStackClient, + query_request: QueryRequest, + configuration: AppConfig, +) -> tuple[list[Any], list[float], list[ReferencedDocument], list[RAGChunk]]: + """ + Perform vector search and extract RAG chunks and referenced documents. 
+
+    Args:
+        client: The AsyncLlamaStackClient to use for the request
+        query_request: The user's query request
+        configuration: Application configuration
+
+    Returns:
+        Tuple containing:
+        - retrieved_chunks: Raw chunks from the vector store
+        - retrieved_scores: Scores for each chunk
+        - doc_ids_from_chunks: Referenced documents extracted from chunks
+        - rag_chunks: Processed RAG chunks ready for use
+    """
+    retrieved_chunks: list[Any] = []
+    retrieved_scores: list[float] = []
+    doc_ids_from_chunks: list[ReferencedDocument] = []
+    rag_chunks: list[RAGChunk] = []
+
+    # Get offline setting from configuration
+    offline = configuration.solr.offline if configuration.solr else True
+
+    try:
+        # Get vector stores for direct querying
+        if query_request.vector_store_ids:
+            vector_store_ids = query_request.vector_store_ids
+            logger.debug(
+                "Using specified vector_store_ids for direct query: %s",
+                vector_store_ids,
+            )
+        else:
+            vector_store_ids = [
+                vector_store.id
+                for vector_store in (await client.vector_stores.list()).data
+            ]
+            logger.debug(
+                "Using all available vector_store_ids for direct query: %s",
+                vector_store_ids,
+            )
+
+        if vector_store_ids:
+            vector_store_id = vector_store_ids[0]  # Use first available vector store
+
+            params = {
+                "k": constants.VECTOR_SEARCH_DEFAULT_K,
+                "score_threshold": constants.VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD,
+                "mode": constants.VECTOR_SEARCH_DEFAULT_MODE,
+            }
+            if query_request.solr:
+                # Pass the entire solr dict (e.g. filter queries) under the "solr" key
+                params["solr"] = query_request.solr
+            logger.debug("Params sent to vector_io.query: %s", params)
+
+            query_response = await client.vector_io.query(
+                vector_store_id=vector_store_id,
+                query=query_request.query,
+                params=params,
+            )
+            logger.debug("vector_io.query response: %s", query_response)
+
+            if query_response.chunks:
+                retrieved_chunks = query_response.chunks
+                retrieved_scores = (
+                    query_response.scores if hasattr(query_response, "scores") else []
+                )
+
+                # Extract doc_ids from chunks for referenced_documents
+                metadata_doc_ids = set()
+
+                for chunk in query_response.chunks:
+                    # 1) dict metadata
+                    metadata = getattr(chunk, "metadata", None) or {}
+                    doc_id = metadata.get("doc_id") or metadata.get("document_id")
+                    title = metadata.get("title")
+
+                    # 2) typed chunk_metadata
+                    if not doc_id:
+                        chunk_meta = getattr(chunk, "chunk_metadata", None)
+                        if chunk_meta is not None:
+                            # chunk_meta might be a pydantic model or a dict
+                            # depending on the caller
+                            if isinstance(chunk_meta, dict):
+                                doc_id = chunk_meta.get("doc_id") or chunk_meta.get(
+                                    "document_id"
+                                )
+                                title = title or chunk_meta.get("title")
+                                reference_url = chunk_meta.get("reference_url")
+                            else:
+                                doc_id = getattr(chunk_meta, "doc_id", None) or getattr(
+                                    chunk_meta, "document_id", None
+                                )
+                                title = title or getattr(chunk_meta, "title", None)
+                                reference_url = getattr(
+                                    chunk_meta, "reference_url", None
+                                )
+                        else:
+                            reference_url = None
+                    else:
+                        reference_url = metadata.get("reference_url")
+
+                    if not doc_id and not reference_url:
+                        continue
+
+                    # Build URL based on offline flag
+                    doc_url, reference_doc = _build_document_url(
+                        offline, doc_id, reference_url
+                    )
+
+                    if reference_doc and reference_doc not in metadata_doc_ids:
+                        metadata_doc_ids.add(reference_doc)
+                        doc_ids_from_chunks.append(
+                            ReferencedDocument(
+                                doc_title=title,
+                                doc_url=doc_url,
+                            )
+                        )
+
+                logger.info(
+                    "Extracted %d unique document IDs from chunks",
+                    len(doc_ids_from_chunks),
+                )
+
+            # Convert retrieved chunks to RAGChunk format
+            rag_chunks = _convert_chunks_to_rag_format(
+                retrieved_chunks, retrieved_scores, offline
+            )
+            logger.info("Retrieved %d chunks from vector DB", len(rag_chunks))
+
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        logger.warning("Failed to query vector database for chunks: %s", e)
+        logger.debug("Vector DB query error details: %s", traceback.format_exc())
+        # Continue without RAG chunks
+
+    return retrieved_chunks, retrieved_scores, doc_ids_from_chunks, rag_chunks
+
+
+def _build_document_url(
+    offline: bool, doc_id: Optional[str], reference_url: Optional[str]
+) -> tuple[str, Optional[str]]:
+    """
+    Build document URL based on offline flag and available metadata.
+
+    Args:
+        offline: Whether to use offline mode (parent_id) or online mode (reference_url)
+        doc_id: Document ID from chunk metadata
+        reference_url: Reference URL from chunk metadata
+
+    Returns:
+        Tuple of (doc_url, reference_doc) where:
+        - doc_url: The full URL for the document
+        - reference_doc: The document reference used for deduplication
+    """
+    if offline:
+        # Offline mode: link by document id, joined with a proper path
+        # separator (consistent with _convert_chunks_to_rag_format below)
+        reference_doc = doc_id
+        doc_url = (
+            urljoin(constants.MIMIR_DOC_URL, reference_doc) if reference_doc else ""
+        )
+    else:
+        # Online mode: prefer reference_url, falling back to doc_id
+        reference_doc = reference_url or doc_id
+        if reference_doc and reference_doc.startswith("http"):
+            doc_url = reference_doc
+        elif reference_doc:
+            doc_url = urljoin(constants.MIMIR_DOC_URL, reference_doc)
+        else:
+            doc_url = ""
+
+    return doc_url, reference_doc
+
+
+def _convert_chunks_to_rag_format(
+    retrieved_chunks: list[Any],
+    retrieved_scores: list[float],
+    offline: bool,
+) -> list[RAGChunk]:
+    """
+    Convert retrieved chunks to RAGChunk format.
+
+    Args:
+        retrieved_chunks: Raw chunks from the vector store
+        retrieved_scores: Scores for each chunk
+        offline: Whether to use offline mode for source URLs
+
+    Returns:
+        List of RAGChunk objects
+    """
+    rag_chunks = []
+
+    for i, chunk in enumerate(retrieved_chunks):
+        # Extract source from chunk metadata based on offline flag
+        source = None
+        metadata = getattr(chunk, "metadata", None)
+        if metadata:
+            if offline:
+                parent_id = metadata.get("parent_id")
+                if parent_id:
+                    source = urljoin(constants.MIMIR_DOC_URL, parent_id)
+            else:
+                source = metadata.get("reference_url")
+
+        # Get score from retrieved_scores list if available
+        score = retrieved_scores[i] if i < len(retrieved_scores) else None
+
+        rag_chunks.append(
+            RAGChunk(
+                content=chunk.content,
+                source=source,
+                score=score,
+            )
+        )
+
+    return rag_chunks
+
+
+def format_rag_context_for_injection(
+    rag_chunks: list[RAGChunk], max_chunks: int = 5
+) -> str:
+    """
+    Format RAG context for injection into the user message.
+
+    Args:
+        rag_chunks: List of RAG chunks to format
+        max_chunks: Maximum number of chunks to include (default: 5)
+
+    Returns:
+        Formatted RAG context string ready for injection
+    """
+    if not rag_chunks:
+        return ""
+
+    context_chunks = []
+    for chunk in rag_chunks[:max_chunks]:  # Limit to the top-ranked chunks
+        chunk_text = f"Source: {chunk.source or 'Unknown'}\n{chunk.content}"
+        context_chunks.append(chunk_text)
+
+    rag_context = "\n\nRelevant documentation:\n" + "\n\n".join(context_chunks)
+    logger.info("Injecting %d RAG chunks into user message", len(context_chunks))
+
+    return rag_context
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index 4fb8cc6e..9ac0d4fe 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -1,9 +1,8 @@
-"""Unit tests for the /query REST API endpoint."""
-
-# pylint: disable=redefined-outer-name
-# pylint: disable=too-many-lines
-# pylint: disable=ungrouped-imports
+# pylint: disable=redefined-outer-name,import-error,too-many-locals,too-many-lines
+# pyright: reportCallIssue=false
+"""Unit tests for the /query REST API endpoint using the Responses API."""
 
+from pathlib import Path
 from typing import Any
 
 import pytest
@@ -14,7 +13,7 @@
 from app.endpoints.query import query_endpoint_handler, retrieve_response
 from configuration import AppConfig
-from models.config import Action
+from models.config import ModelContextProtocolServer
 from models.database.conversations import UserConversation
 from models.requests import Attachment, QueryRequest
 from models.responses import QueryResponse
@@ -34,25 +33,15 @@
 def create_dummy_request() -> Request:
     """Create dummy request fixture for testing.
 
-    Create a minimal FastAPI Request with test-ready authorization state.
-
-    The returned Request has a minimal HTTP scope and a
-    `state.authorized_actions` attribute initialized to a set containing all
-    members of the `Action` enum, suitable for use in unit tests that require
-    an authenticated request context.
+    Create a minimal FastAPI Request object suitable for unit tests.
 
     Returns:
-        req (Request): FastAPI Request with `state.authorized_actions` set to `set(Action)`.
+        request (fastapi.Request): A Request constructed with a bare HTTP
+        scope (type "http") for use in tests.
     """
-    req = Request(
-        scope={
-            "type": "http",
-        }
-    )
-
-    req.state.authorized_actions = set(Action)
+    req = Request(scope={"type": "http"})
     return req
 
 
 @pytest.fixture(name="setup_configuration")
 def setup_configuration_fixture() -> AppConfig:
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index 0f165c4f..3f527c96 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -1,4 +1,5 @@
-"""Unit tests for the /streaming-query REST API endpoint."""
+# pylint: disable=redefined-outer-name,import-error
+"""Unit tests for the /streaming_query REST API endpoint using the Responses API."""
 
 # pylint: disable=too-many-lines,too-many-function-args
 import json
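
Review note: both endpoints now merge pre-RAG documents ahead of tool-based ones and deduplicate by the `(doc_url, doc_title)` pair, keeping the first occurrence, so pre-RAG entries win on conflict. A minimal sketch of the helper's behavior (the document titles and URLs below are hypothetical placeholders; the field names match the diff):

```python
from models.responses import ReferencedDocument
from app.endpoints.query import parse_referenced_docs

# Hypothetical documents: "doc-1" arrives twice, once from the pre-RAG vector
# search and once from the tool output, and must survive only once.
docs = [
    ReferencedDocument(doc_title="Install", doc_url="https://mimir.corp.redhat.com/doc-1"),
    ReferencedDocument(doc_title="Upgrade", doc_url="https://mimir.corp.redhat.com/doc-2"),
    ReferencedDocument(doc_title="Install", doc_url="https://mimir.corp.redhat.com/doc-1"),
]

deduped = parse_referenced_docs(docs)
assert [d.doc_title for d in deduped] == ["Install", "Upgrade"]  # first wins, order kept
```

The same invariant is enforced by the inlined merge in `response_generator` in streaming_query.py, so any future change to the dedup key should touch both places.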