From 097d03957ed7bd56769f4f99a2e4886e75670692 Mon Sep 17 00:00:00 2001 From: Lucas Date: Tue, 3 Feb 2026 11:42:26 -0500 Subject: [PATCH 1/3] Safety shield config Signed-off-by: Lucas --- src/app/endpoints/a2a.py | 1 + src/app/endpoints/query_v2.py | 4 +++- src/app/endpoints/streaming_query_v2.py | 4 +++- src/models/requests.py | 9 ++++++++ src/utils/shields.py | 28 ++++++++++++++++++++++--- 5 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py index 7e3fc0152..b0fc9d58a 100644 --- a/src/app/endpoints/a2a.py +++ b/src/app/endpoints/a2a.py @@ -312,6 +312,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals generate_topic_summary=True, media_type=None, vector_store_ids=vector_store_ids, + shield_ids=None, ) # Get LLM client and select model diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index ecc39b071..dfd5f7668 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -401,7 +401,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche ) # Run shield moderation before calling LLM - moderation_result = await run_shield_moderation(client, input_text) + moderation_result = await run_shield_moderation( + client, input_text, query_request.shield_ids + ) if moderation_result.blocked: violation_message = moderation_result.message or "" await append_turn_to_conversation( diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py index e1c02ca4a..787b4565a 100644 --- a/src/app/endpoints/streaming_query_v2.py +++ b/src/app/endpoints/streaming_query_v2.py @@ -451,7 +451,9 @@ async def retrieve_response( # pylint: disable=too-many-locals ) # Run shield moderation before calling LLM - moderation_result = await run_shield_moderation(client, input_text) + moderation_result = await run_shield_moderation( + client, input_text, query_request.shield_ids + ) if moderation_result.blocked: violation_message = moderation_result.message or "" await append_turn_to_conversation( diff --git a/src/models/requests.py b/src/models/requests.py index 18e5b4b61..ccef6a741 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -83,6 +83,7 @@ class QueryRequest(BaseModel): generate_topic_summary: Whether to generate topic summary for new conversations. media_type: The optional media type for response format (application/json or text/plain). vector_store_ids: The optional list of specific vector store IDs to query for RAG. + shield_ids: The optional list of safety shield IDs to apply. Example: ```python @@ -166,6 +167,14 @@ class QueryRequest(BaseModel): examples=["ocp_docs", "knowledge_base", "vector_db_1"], ) + shield_ids: Optional[list[str]] = Field( + None, + description="Optional list of safety shield IDs to apply. " + "If None, all configured shields are used. 
" + "If empty list, all shields are skipped.", + examples=["llama-guard", "custom-shield"], + ) + # provides examples for /docs endpoint model_config = { "extra": "forbid", diff --git a/src/utils/shields.py b/src/utils/shields.py index 065cc96e4..dc676e11f 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -1,7 +1,7 @@ """Utility functions for working with Llama Stack shields.""" import logging -from typing import Any, cast +from typing import Any, Optional, cast from fastapi import HTTPException from llama_stack_client import AsyncLlamaStackClient, BadRequestError @@ -63,16 +63,19 @@ def detect_shield_violations(output_items: list[Any]) -> bool: async def run_shield_moderation( client: AsyncLlamaStackClient, input_text: str, + shield_ids: Optional[list[str]] = None, ) -> ShieldModerationResult: """ Run shield moderation on input text. - Iterates through all configured shields and runs moderation checks. + Iterates through configured shields and runs moderation checks. Raises HTTPException if shield model is not found. Parameters: client: The Llama Stack client. input_text: The text to moderate. + shield_ids: Optional list of shield IDs to use. If None, uses all shields. + If empty list, skips all shields. Returns: ShieldModerationResult: Result indicating if content was blocked and the message. @@ -80,9 +83,28 @@ async def run_shield_moderation( Raises: HTTPException: If shield's provider_resource_id is not configured or model not found. """ + all_shields = await client.shields.list() + + # Filter shields based on shield_ids parameter + if shield_ids is not None: + if len(shield_ids) == 0: + logger.info("shield_ids=[] provided, skipping all shields") + return ShieldModerationResult(blocked=False) + + shields_to_run = [s for s in all_shields if s.identifier in shield_ids] + + # Log warning if requested shield not found + requested = set(shield_ids) + available = {s.identifier for s in shields_to_run} + missing = requested - available + if missing: + logger.warning("Requested shields not found: %s", missing) + else: + shields_to_run = list(all_shields) + available_models = {model.id for model in await client.models.list()} - for shield in await client.shields.list(): + for shield in shields_to_run: if ( not shield.provider_resource_id or shield.provider_resource_id not in available_models From 71b5bad9690202da85dbcd35bc32b830eebb11b2 Mon Sep 17 00:00:00 2001 From: Lucas Date: Tue, 3 Feb 2026 13:14:55 -0500 Subject: [PATCH 2/3] Typos for shield fail gracefully. 
Added test suite Signed-off-by: Lucas --- src/utils/shields.py | 10 ++++- tests/unit/utils/test_shields.py | 69 ++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/src/utils/shields.py b/src/utils/shields.py index dc676e11f..9e4a929c7 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -8,7 +8,7 @@ from llama_stack_client.types import CreateResponse import metrics -from models.responses import NotFoundResponse +from models.responses import NotFoundResponse, UnprocessableEntityResponse from utils.types import ShieldModerationResult logger = logging.getLogger(__name__) @@ -99,6 +99,14 @@ async def run_shield_moderation( missing = requested - available if missing: logger.warning("Requested shields not found: %s", missing) + + # Reject if no requested shields were found (prevents accidental bypass) + if not shields_to_run: + response = UnprocessableEntityResponse( + response="Invalid shield configuration", + cause=f"Requested shield_ids not found: {sorted(missing)}", + ) + raise HTTPException(**response.model_dump()) else: shields_to_run = list(all_shields) diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py index adf3fe8b1..b33a4289a 100644 --- a/tests/unit/utils/test_shields.py +++ b/tests/unit/utils/test_shields.py @@ -312,6 +312,75 @@ async def test_returns_blocked_on_bad_request_error( assert result.shield_model == "moderation-model" mock_metric.inc.assert_called_once() + @pytest.mark.asyncio + async def test_shield_ids_empty_list_skips_all_shields( + self, mocker: MockerFixture + ) -> None: + """Test that shield_ids=[] explicitly skips all shields (intentional bypass).""" + mock_client = mocker.Mock() + shield = mocker.Mock() + shield.identifier = "shield-1" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + result = await run_shield_moderation(mock_client, "test input", shield_ids=[]) + + assert result.blocked is False + mock_client.shields.list.assert_called_once() + + @pytest.mark.asyncio + async def test_shield_ids_raises_exception_when_no_shields_found( + self, mocker: MockerFixture + ) -> None: + """Test shield_ids raises HTTPException when no requested shields exist.""" + mock_client = mocker.Mock() + shield = mocker.Mock() + shield.identifier = "shield-1" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + with pytest.raises(HTTPException) as exc_info: + await run_shield_moderation( + mock_client, "test input", shield_ids=["typo-shield"] + ) + + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert "Invalid shield configuration" in exc_info.value.detail["response"] # type: ignore + assert "typo-shield" in exc_info.value.detail["cause"] # type: ignore + + @pytest.mark.asyncio + async def test_shield_ids_filters_to_specific_shield( + self, mocker: MockerFixture + ) -> None: + """Test that shield_ids filters to only specified shields.""" + mock_client = mocker.Mock() + + shield1 = mocker.Mock() + shield1.identifier = "shield-1" + shield1.provider_resource_id = "model-1" + shield2 = mocker.Mock() + shield2.identifier = "shield-2" + shield2.provider_resource_id = "model-2" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2]) + + model1 = mocker.Mock() + model1.id = "model-1" + mock_client.models.list = mocker.AsyncMock(return_value=[model1]) + + moderation_result = mocker.Mock() + moderation_result.results = [mocker.Mock(flagged=False)] + mock_client.moderations.create = mocker.AsyncMock( + 
return_value=moderation_result + ) + + result = await run_shield_moderation( + mock_client, "test input", shield_ids=["shield-1"] + ) + + assert result.blocked is False + assert mock_client.moderations.create.call_count == 1 + mock_client.moderations.create.assert_called_with( + input="test input", model="model-1" + ) + class TestAppendTurnToConversation: # pylint: disable=too-few-public-methods """Tests for append_turn_to_conversation function.""" From bdf945f5c1f531ea2bc8fe642426729df28d5432 Mon Sep 17 00:00:00 2001 From: Lucas Date: Fri, 13 Feb 2026 15:18:38 -0500 Subject: [PATCH 3/3] new configuration parameter that will enable or disable safety shield config Signed-off-by: Lucas --- README.md | 7 + src/app/endpoints/query.py | 16 +- src/app/endpoints/query_v2.py | 862 ------------------------ src/app/endpoints/streaming_query.py | 6 +- src/app/endpoints/streaming_query_v2.py | 480 ------------- src/models/config.py | 1 + src/utils/shields.py | 38 +- tests/unit/utils/test_shields.py | 81 +++ 8 files changed, 144 insertions(+), 1347 deletions(-) delete mode 100644 src/app/endpoints/query_v2.py delete mode 100644 src/app/endpoints/streaming_query_v2.py diff --git a/README.md b/README.md index 9775e8fa7..b57a9a2b4 100644 --- a/README.md +++ b/README.md @@ -657,6 +657,13 @@ utilized: 1. If the `shield_id` starts with `inout_`, it will be used both for input and output. 1. Otherwise, it will be used for input only. +Additionally, an optional list parameter `shield_ids` can be specified in `/query` and `/streaming_query` endpoints to override which shields are applied. You can use this config to disable shield overrides: + +```yaml +customization: + disable_shield_ids_override: true +``` + ## Authentication See [authentication and authorization](docs/auth.md). 
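For illustration, a request exercising the new `shield_ids` field could look like the sketch below. This is not part of the patch: the base URL, the use of the `requests` library, and the `llama-guard` identifier are placeholder assumptions; only the `/query` endpoint and the `query` / `shield_ids` request fields come from this change.

```python
# Hypothetical client call exercising shield_ids (base URL and shield name are placeholders).
import requests

payload = {
    "query": "Tell me about OpenShift operators.",
    # Run only the named shield; omit the field to run every configured shield,
    # or pass [] to intentionally skip moderation.
    "shield_ids": ["llama-guard"],
}
resp = requests.post("http://localhost:8080/query", json=payload, timeout=60)
print(resp.status_code, resp.json())
```

If `disable_shield_ids_override: true` is set under `customization`, the same request is rejected with HTTP 422 and the client must drop the `shield_ids` field.
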
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 9cd7cf353..ac1e421e1 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -4,7 +4,7 @@ import logging from datetime import UTC, datetime -from typing import Annotated, Any, cast +from typing import Annotated, Any, Optional, cast from fastapi import APIRouter, Depends, HTTPException, Request from llama_stack_api.openai_responses import OpenAIResponseObject @@ -61,6 +61,7 @@ from utils.shields import ( append_turn_to_conversation, run_shield_moderation, + validate_shield_ids_override, ) from utils.suid import normalize_conversation_id from utils.types import ResponsesApiParams, TurnSummary @@ -125,6 +126,9 @@ async def query_endpoint_handler( # Enforce RBAC: optionally disallow overriding model/provider in requests validate_model_provider_override(query_request, request.state.authorized_actions) + # Validate shield_ids override if provided + validate_shield_ids_override(query_request, configuration) + # Validate attachments if provided if query_request.attachments: validate_attachments_metadata(query_request.attachments) @@ -166,7 +170,9 @@ async def query_endpoint_handler( client = await update_azure_token(client) # Retrieve response using Responses API - turn_summary = await retrieve_response(client, responses_params) + turn_summary = await retrieve_response( + client, responses_params, query_request.shield_ids + ) # Get topic summary for new conversation if not user_conversation and query_request.generate_topic_summary: @@ -225,6 +231,7 @@ async def query_endpoint_handler( async def retrieve_response( # pylint: disable=too-many-locals client: AsyncLlamaStackClient, responses_params: ResponsesApiParams, + shield_ids: Optional[list[str]] = None, ) -> TurnSummary: """ Retrieve response from LLMs and agents. @@ -235,6 +242,7 @@ async def retrieve_response( # pylint: disable=too-many-locals Parameters: client: The AsyncLlamaStackClient to use for the request. responses_params: The Responses API parameters. + shield_ids: Optional list of shield IDs for moderation. 
Returns: TurnSummary: Summary of the LLM response content @@ -242,7 +250,9 @@ async def retrieve_response( # pylint: disable=too-many-locals summary = TurnSummary() try: - moderation_result = await run_shield_moderation(client, responses_params.input) + moderation_result = await run_shield_moderation( + client, responses_params.input, shield_ids + ) if moderation_result.blocked: # Handle shield moderation blocking violation_message = moderation_result.message or "" diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py deleted file mode 100644 index dfd5f7668..000000000 --- a/src/app/endpoints/query_v2.py +++ /dev/null @@ -1,862 +0,0 @@ -# pylint: disable=too-many-locals,too-many-branches,too-many-nested-blocks - -"""Handler for REST API call to provide answer to query using Response API.""" - -import json -import logging -from typing import Annotated, Any, Optional, cast - -from fastapi import APIRouter, Depends, Request -from llama_stack_api.openai_responses import ( - OpenAIResponseMCPApprovalRequest, - OpenAIResponseMCPApprovalResponse, - OpenAIResponseObject, - OpenAIResponseOutput, - OpenAIResponseOutputMessageFileSearchToolCall, - OpenAIResponseOutputMessageFunctionToolCall, - OpenAIResponseOutputMessageMCPCall, - OpenAIResponseOutputMessageMCPListTools, - OpenAIResponseOutputMessageWebSearchToolCall, -) -from llama_stack_client import AsyncLlamaStackClient - -import constants -import metrics -from app.endpoints.query import ( - query_endpoint_handler_base, - validate_attachments_metadata, -) -from authentication import get_auth_dependency -from authentication.interface import AuthTuple -from authorization.middleware import authorize -from configuration import AppConfig, configuration -from constants import DEFAULT_RAG_TOOL -from models.config import Action, ModelContextProtocolServer -from models.requests import QueryRequest -from models.responses import ( - ForbiddenResponse, - InternalServerErrorResponse, - NotFoundResponse, - QueryResponse, - QuotaExceededResponse, - ReferencedDocument, - ServiceUnavailableResponse, - UnauthorizedResponse, - UnprocessableEntityResponse, -) -from utils.endpoints import ( - check_configuration_loaded, - get_system_prompt, - get_topic_summary_system_prompt, -) -from utils.mcp_headers import mcp_headers_dependency -from utils.query import parse_arguments_string -from utils.responses import extract_text_from_response_output_item -from utils.shields import ( - append_turn_to_conversation, - run_shield_moderation, -) -from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id -from utils.token_counter import TokenCounter -from utils.types import RAGChunk, ToolCallSummary, ToolResultSummary, TurnSummary - -logger = logging.getLogger("app.endpoints.handlers") -router = APIRouter(tags=["query_v1"]) - -query_v2_response: dict[int | str, dict[str, Any]] = { - 200: QueryResponse.openapi_response(), - 401: UnauthorizedResponse.openapi_response( - examples=["missing header", "missing token"] - ), - 403: ForbiddenResponse.openapi_response( - examples=["endpoint", "conversation read", "model override"] - ), - 404: NotFoundResponse.openapi_response( - examples=["conversation", "model", "provider"] - ), - # 413: PromptTooLongResponse.openapi_response(), - 422: UnprocessableEntityResponse.openapi_response(), - 429: QuotaExceededResponse.openapi_response(), - 500: InternalServerErrorResponse.openapi_response(examples=["configuration"]), - 503: ServiceUnavailableResponse.openapi_response(), -} - - -def 
_build_tool_call_summary( # pylint: disable=too-many-return-statements,too-many-branches - output_item: OpenAIResponseOutput, - rag_chunks: list[RAGChunk], -) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]: - """Translate Responses API tool outputs into ToolCallSummary and ToolResultSummary records. - - Processes OpenAI response output items and extracts tool call and result information. - Also parses RAG chunks from file_search_call items and appends them to the provided list. - - Args: - output_item: An OpenAIResponseOutput item from the response.output array - rag_chunks: List to append extracted RAG chunks to (from file_search_call items) - Returns: - A tuple of (ToolCallSummary, ToolResultSummary) one of them possibly None - if current llama stack Responses API does not provide the information. - - Supported tool types: - - function_call: Function tool calls with parsed arguments (no result) - - file_search_call: File search operations with results (also extracts RAG chunks) - - web_search_call: Web search operations (incomplete) - - mcp_call: MCP calls with server labels - - mcp_list_tools: MCP server tool listings - - mcp_approval_request: MCP approval requests (no result) - - mcp_approval_response: MCP approval responses (no call) - """ - item_type = getattr(output_item, "type", None) - - if item_type == "function_call": - item = cast(OpenAIResponseOutputMessageFunctionToolCall, output_item) - return ( - ToolCallSummary( - id=item.call_id, - name=item.name, - args=parse_arguments_string(item.arguments), - type="function_call", - ), - None, # not supported by Responses API at all - ) - - if item_type == "file_search_call": - file_search_item = cast( - OpenAIResponseOutputMessageFileSearchToolCall, output_item - ) - extract_rag_chunks_from_file_search_item(file_search_item, rag_chunks) - response_payload: Optional[dict[str, Any]] = None - if file_search_item.results is not None: - response_payload = { - "results": [result.model_dump() for result in file_search_item.results] - } - return ToolCallSummary( - id=file_search_item.id, - name=DEFAULT_RAG_TOOL, - args={"queries": file_search_item.queries}, - type="file_search_call", - ), ToolResultSummary( - id=file_search_item.id, - status=file_search_item.status, - content=json.dumps(response_payload) if response_payload else "", - type="file_search_call", - round=1, - ) - - # Incomplete OpenAI Responses API definition in LLS: action attribute not supported yet - if item_type == "web_search_call": - web_search_item = cast( - OpenAIResponseOutputMessageWebSearchToolCall, output_item - ) - return ( - ToolCallSummary( - id=web_search_item.id, - name="web_search", - args={}, - type="web_search_call", - ), - ToolResultSummary( - id=web_search_item.id, - status=web_search_item.status, - content="", - type="web_search_call", - round=1, - ), - ) - - if item_type == "mcp_call": - mcp_call_item = cast(OpenAIResponseOutputMessageMCPCall, output_item) - args = parse_arguments_string(mcp_call_item.arguments) - if mcp_call_item.server_label: - args["server_label"] = mcp_call_item.server_label - content = ( - mcp_call_item.error - if mcp_call_item.error - else (mcp_call_item.output if mcp_call_item.output else "") - ) - - return ToolCallSummary( - id=mcp_call_item.id, - name=mcp_call_item.name, - args=args, - type="mcp_call", - ), ToolResultSummary( - id=mcp_call_item.id, - status="success" if mcp_call_item.error is None else "failure", - content=content, - type="mcp_call", - round=1, - ) - - if item_type == "mcp_list_tools": - 
mcp_list_tools_item = cast(OpenAIResponseOutputMessageMCPListTools, output_item) - tools_info = [ - { - "name": tool.name, - "description": tool.description, - "input_schema": tool.input_schema, - } - for tool in mcp_list_tools_item.tools - ] - content_dict = { - "server_label": mcp_list_tools_item.server_label, - "tools": tools_info, - } - return ( - ToolCallSummary( - id=mcp_list_tools_item.id, - name="mcp_list_tools", - args={"server_label": mcp_list_tools_item.server_label}, - type="mcp_list_tools", - ), - ToolResultSummary( - id=mcp_list_tools_item.id, - status="success", - content=json.dumps(content_dict), - type="mcp_list_tools", - round=1, - ), - ) - - if item_type == "mcp_approval_request": - approval_request_item = cast(OpenAIResponseMCPApprovalRequest, output_item) - args = parse_arguments_string(approval_request_item.arguments) - return ( - ToolCallSummary( - id=approval_request_item.id, - name=approval_request_item.name, - args=args, - type="tool_call", - ), - None, - ) - - if item_type == "mcp_approval_response": - approval_response_item = cast(OpenAIResponseMCPApprovalResponse, output_item) - content_dict = {} - if approval_response_item.reason: - content_dict["reason"] = approval_response_item.reason - return ( - None, - ToolResultSummary( - id=approval_response_item.approval_request_id, - status="success" if approval_response_item.approve else "denied", - content=json.dumps(content_dict), - type="mcp_approval_response", - round=1, - ), - ) - - return None, None - - -async def get_topic_summary( # pylint: disable=too-many-nested-blocks - question: str, client: AsyncLlamaStackClient, model_id: str -) -> str: - """ - Get a topic summary for a question using Responses API. - - This is the Responses API version of get_topic_summary, which uses - client.responses.create() instead of the Agent API. - - Args: - question: The question to generate a topic summary for - client: The AsyncLlamaStackClient to use for the request - model_id: The llama stack model ID (full format: provider/model) - - Returns: - str: The topic summary for the question - """ - topic_summary_system_prompt = get_topic_summary_system_prompt(configuration) - - # Use Responses API to generate topic summary - response = cast( - OpenAIResponseObject, - await client.responses.create( - input=question, - model=model_id, - instructions=topic_summary_system_prompt, - stream=False, - store=False, # Don't store topic summary requests - ), - ) - - # Extract text from response output - summary_text = "".join( - extract_text_from_response_output_item(output_item) - for output_item in response.output - ) - - return summary_text.strip() if summary_text else "" - - -@router.post("/query", responses=query_v2_response, summary="Query Endpoint Handler V1") -@authorize(Action.QUERY) -async def query_endpoint_handler_v2( - request: Request, - query_request: QueryRequest, - auth: Annotated[AuthTuple, Depends(get_auth_dependency())], - mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), -) -> QueryResponse: - """ - Handle request to the /query endpoint using Responses API. - - This is a wrapper around query_endpoint_handler_base that provides - the Responses API specific retrieve_response and get_topic_summary functions. - - Returns: - QueryResponse: Contains the conversation ID and the LLM-generated response. 
- """ - check_configuration_loaded(configuration) - return await query_endpoint_handler_base( - request=request, - query_request=query_request, - auth=auth, - mcp_headers=mcp_headers, - retrieve_response_func=retrieve_response, - get_topic_summary_func=get_topic_summary, - ) - - -async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements - client: AsyncLlamaStackClient, - model_id: str, - query_request: QueryRequest, - token: str, - mcp_headers: Optional[dict[str, dict[str, str]]] = None, - *, - provider_id: str = "", -) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]: - """ - Retrieve response from LLMs and agents. - - Retrieves a response from the Llama Stack LLM or agent for a - given query, handling shield configuration, tool usage, and - attachment validation. - - This function configures system prompts, shields, and toolgroups - (including RAG and MCP integration) as needed based on - the query request and system configuration. It - validates attachments, manages conversation and session - context, and processes MCP headers for multi-component - processing. Corresponding metrics are updated. - - Parameters: - client (AsyncLlamaStackClient): The AsyncLlamaStackClient to use for the request. - model_id (str): The identifier of the LLM model to use. - query_request (QueryRequest): The user's query and associated metadata. - token (str): The authentication token for authorization. - mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing. - provider_id (str): The identifier of the LLM provider to use. - - Returns: - tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content - and the conversation ID, the list of parsed referenced documents, - and token usage information. 
- """ - # use system prompt from request or default one - system_prompt = get_system_prompt(query_request, configuration) - logger.debug("Using system prompt: %s", system_prompt) - - # TODO(lucasagomes): redact attachments content before sending to LLM - # if attachments are provided, validate them - if query_request.attachments: - validate_attachments_metadata(query_request.attachments) - - # Prepare tools for responses API - toolgroups = await prepare_tools_for_responses_api( - client, query_request, token, configuration, mcp_headers - ) - - # Prepare input for Responses API - # Convert attachments to text and concatenate with query - input_text = query_request.query - if query_request.attachments: - for attachment in query_request.attachments: - # Append attachment content with type label - input_text += ( - f"\n\n[Attachment: {attachment.attachment_type}]\n{attachment.content}" - ) - - # Handle conversation ID for Responses API - # Create conversation upfront if not provided - conversation_id = query_request.conversation_id - if conversation_id: - # Conversation ID was provided - convert to llama-stack format - logger.debug("Using existing conversation ID: %s", conversation_id) - llama_stack_conv_id = to_llama_stack_conversation_id(conversation_id) - else: - # No conversation_id provided - create a new conversation first - logger.debug("No conversation_id provided, creating new conversation") - - conversation = await client.conversations.create(metadata={}) - llama_stack_conv_id = conversation.id - # Store the normalized version for later use - conversation_id = normalize_conversation_id(llama_stack_conv_id) - logger.info( - "Created new conversation with ID: %s (normalized: %s)", - llama_stack_conv_id, - conversation_id, - ) - - # Run shield moderation before calling LLM - moderation_result = await run_shield_moderation( - client, input_text, query_request.shield_ids - ) - if moderation_result.blocked: - violation_message = moderation_result.message or "" - await append_turn_to_conversation( - client, llama_stack_conv_id, input_text, violation_message - ) - summary = TurnSummary( - llm_response=violation_message, - tool_calls=[], - tool_results=[], - rag_chunks=[], - ) - return ( - summary, - normalize_conversation_id(conversation_id), - [], - TokenCounter(), - ) - - # Create OpenAI response using responses API - create_kwargs: dict[str, Any] = { - "input": input_text, - "model": model_id, - "instructions": system_prompt, - "tools": cast(Any, toolgroups), - "stream": False, - "store": True, - "conversation": llama_stack_conv_id, - } - - response = await client.responses.create(**create_kwargs) - response = cast(OpenAIResponseObject, response) - logger.debug( - "Received response with ID: %s, conversation ID: %s, output items: %d", - response.id, - conversation_id, - len(response.output), - ) - - # Process OpenAI response format - llm_response = "" - tool_calls: list[ToolCallSummary] = [] - tool_results: list[ToolResultSummary] = [] - rag_chunks: list[RAGChunk] = [] - for output_item in response.output: - message_text = extract_text_from_response_output_item(output_item) - if message_text: - llm_response += message_text - - tool_call, tool_result = _build_tool_call_summary(output_item, rag_chunks) - if tool_call: - tool_calls.append(tool_call) - if tool_result: - tool_results.append(tool_result) - - logger.info( - "Response processing complete - Tool calls: %d, Response length: %d chars", - len(tool_calls), - len(llm_response), - ) - - summary = TurnSummary( - 
llm_response=llm_response, - tool_calls=tool_calls, - tool_results=tool_results, - rag_chunks=rag_chunks, - ) - - # Extract referenced documents and token usage from Responses API response - referenced_documents = parse_referenced_documents_from_responses_api(response) - model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id - token_usage = extract_token_usage_from_responses_api( - response, model_label, provider_id, system_prompt - ) - - if not summary.llm_response: - logger.warning( - "Response lacks content (conversation_id=%s)", - conversation_id, - ) - - return ( - summary, - normalize_conversation_id(conversation_id), - referenced_documents, - token_usage, - ) - - -def extract_rag_chunks_from_file_search_item( - item: OpenAIResponseOutputMessageFileSearchToolCall, - rag_chunks: list[RAGChunk], -) -> None: - """Extract RAG chunks from a file search tool call item and append to rag_chunks. - - Args: - item: The file search tool call item. - rag_chunks: List to append extracted RAG chunks to. - """ - if item.results is not None: - for result in item.results: - rag_chunk = RAGChunk( - content=result.text, source=result.filename, score=result.score - ) - rag_chunks.append(rag_chunk) - - -def parse_rag_chunks_from_responses_api( - response_obj: OpenAIResponseObject, -) -> list[RAGChunk]: - """ - Extract rag_chunks from the llama-stack OpenAI response. - - Args: - response_obj: The ResponseObject from OpenAI compatible response API in llama-stack. - - Returns: - List of RAGChunk with content, source, score - """ - rag_chunks: list[RAGChunk] = [] - - for output_item in response_obj.output: - item_type = getattr(output_item, "type", None) - if item_type == "file_search_call": - item = cast(OpenAIResponseOutputMessageFileSearchToolCall, output_item) - extract_rag_chunks_from_file_search_item(item, rag_chunks) - - return rag_chunks - - -def parse_referenced_documents_from_responses_api( - response: OpenAIResponseObject, # pylint: disable=unused-argument -) -> list[ReferencedDocument]: - """ - Parse referenced documents from OpenAI Responses API response. - - Args: - response: The OpenAI Response API response object - - Returns: - list[ReferencedDocument]: List of referenced documents with doc_url and doc_title - """ - documents: list[ReferencedDocument] = [] - # Use a set to track unique documents by (doc_url, doc_title) tuple - seen_docs: set[tuple[Optional[str], Optional[str]]] = set() - - # Handle None response (e.g., when agent fails) - if response is None or not response.output: - return documents - - for output_item in response.output: - item_type = getattr(output_item, "type", None) - - # 1. 
Parse from file_search_call results - if item_type == "file_search_call": - results = getattr(output_item, "results", []) or [] - for result in results: - # Handle both object and dict access - if isinstance(result, dict): - attributes = result.get("attributes", {}) - else: - attributes = getattr(result, "attributes", {}) - - # Try to get URL from attributes - # Look for common URL fields in attributes - doc_url = ( - attributes.get("doc_url") - or attributes.get("docs_url") - or attributes.get("url") - or attributes.get("link") - ) - doc_title = attributes.get("title") - - if doc_title or doc_url: - # Treat empty string as None for URL to satisfy Optional[AnyUrl] - final_url = doc_url if doc_url else None - if (final_url, doc_title) not in seen_docs: - documents.append( - ReferencedDocument(doc_url=final_url, doc_title=doc_title) - ) - seen_docs.add((final_url, doc_title)) - - return documents - - -def extract_token_usage_from_responses_api( - response: OpenAIResponseObject, - model: str, - provider: str, - system_prompt: str = "", # pylint: disable=unused-argument -) -> TokenCounter: - """ - Extract token usage from OpenAI Responses API response and update metrics. - - This function extracts token usage information from the Responses API response - object and updates Prometheus metrics. If usage information is not available, - it returns zero values without estimation. - - Note: When llama stack internally uses chat_completions, the usage field may be - empty or a dict. This is expected and will be populated in future llama stack versions. - - Args: - response: The OpenAI Response API response object - model: The model identifier for metrics labeling - provider: The provider identifier for metrics labeling - system_prompt: The system prompt used (unused, kept for compatibility) - - Returns: - TokenCounter: Token usage information with input_tokens and output_tokens - """ - token_counter = TokenCounter() - token_counter.llm_calls = 1 - - # Extract usage from the response if available - # Note: usage attribute exists at runtime but may not be in type definitions - usage = getattr(response, "usage", None) - if usage: - try: - # Handle both dict and object cases due to llama_stack inconsistency: - # - When llama_stack converts to chat_completions internally, usage is a dict - # - When using proper Responses API, usage should be an object - # TODO: Remove dict handling once llama_stack standardizes on object type # pylint: disable=fixme - if isinstance(usage, dict): - input_tokens = usage.get("input_tokens", 0) - output_tokens = usage.get("output_tokens", 0) - else: - # Object with attributes (expected final behavior) - input_tokens = getattr(usage, "input_tokens", 0) - output_tokens = getattr(usage, "output_tokens", 0) - # Only set if we got valid values - if input_tokens or output_tokens: - token_counter.input_tokens = input_tokens or 0 - token_counter.output_tokens = output_tokens or 0 - - logger.debug( - "Extracted token usage from Responses API: input=%d, output=%d", - token_counter.input_tokens, - token_counter.output_tokens, - ) - - # Update Prometheus metrics only when we have actual usage data - try: - metrics.llm_token_sent_total.labels(provider, model).inc( - token_counter.input_tokens - ) - metrics.llm_token_received_total.labels(provider, model).inc( - token_counter.output_tokens - ) - except (AttributeError, TypeError, ValueError) as e: - logger.warning("Failed to update token metrics: %s", e) - _increment_llm_call_metric(provider, model) - else: - logger.debug( - "Usage object 
exists but tokens are 0 or None, treating as no usage info" - ) - # Still increment the call counter - _increment_llm_call_metric(provider, model) - except (AttributeError, KeyError, TypeError) as e: - logger.warning( - "Failed to extract token usage from response.usage: %s. Usage value: %s", - e, - usage, - ) - # Still increment the call counter - _increment_llm_call_metric(provider, model) - else: - # No usage information available - this is expected when llama stack - # internally converts to chat_completions - logger.debug( - "No usage information in Responses API response, token counts will be 0" - ) - # token_counter already initialized with 0 values - # Still increment the call counter - _increment_llm_call_metric(provider, model) - - return token_counter - - -def _increment_llm_call_metric(provider: str, model: str) -> None: - """Safely increment LLM call metric.""" - try: - metrics.llm_calls_total.labels(provider, model).inc() - except (AttributeError, TypeError, ValueError) as e: - logger.warning("Failed to update LLM call metric: %s", e) - - -def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]: - """ - Convert vector store IDs to tools format for Responses API. - - Args: - vector_store_ids: List of vector store identifiers - - Returns: - Optional[list[dict[str, Any]]]: List containing file_search tool configuration, - or None if no vector stores provided - """ - if not vector_store_ids: - return None - - return [ - { - "type": "file_search", - "vector_store_ids": vector_store_ids, - "max_num_results": 10, - } - ] - - -def get_mcp_tools( - mcp_servers: list[ModelContextProtocolServer], - token: str | None = None, - mcp_headers: dict[str, dict[str, str]] | None = None, -) -> list[dict[str, Any]]: - """ - Convert MCP servers to tools format for Responses API. - - Args: - mcp_servers: List of MCP server configurations - token: Optional authentication token for MCP server authorization - mcp_headers: Optional per-request headers for MCP servers, keyed by server URL - - Returns: - list[dict[str, Any]]: List of MCP tool definitions with server - details and optional auth headers - - The way it works is we go through all the defined mcp servers and - create a tool definitions for each of them. If MCP server definition - has a non-empty resolved_authorization_headers we create invocation - headers, following the algorithm: - 1. If the header value is 'kubernetes' the header value is a k8s token - 2. If the header value is 'client': - find the value for a given MCP server/header in mcp_headers. - if the value is not found omit this header, otherwise use found value - 3. otherwise use the value from resolved_authorization_headers directly - - This algorithm allows to: - 1. Use static global header values, provided by configuration - 2. Use user specific k8s token, which will work for the majority of kubernetes - based MCP servers - 3. 
Use user specific tokens (passed by the client) for user specific MCP headers - """ - - def _get_token_value(original: str, header: str) -> str | None: - """Convert to header value.""" - match original: - case constants.MCP_AUTH_KUBERNETES: - # use k8s token - if token is None or token == "": - return None - return f"Bearer {token}" - case constants.MCP_AUTH_CLIENT: - # use client provided token - if mcp_headers is None: - return None - c_headers = mcp_headers.get(mcp_server.name, None) - if c_headers is None: - return None - return c_headers.get(header, None) - case _: - # use provided - return original - - tools = [] - for mcp_server in mcp_servers: - # Base tool definition - tool_def = { - "type": "mcp", - "server_label": mcp_server.name, - "server_url": mcp_server.url, - "require_approval": "never", - } - - # Build headers - headers = {} - for name, value in mcp_server.resolved_authorization_headers.items(): - # for each defined header - h_value = _get_token_value(value, name) - # only add the header if we got value - if h_value is not None: - headers[name] = h_value - - # Skip server if auth headers were configured but not all could be resolved - if mcp_server.authorization_headers and len(headers) != len( - mcp_server.authorization_headers - ): - logger.warning( - "Skipping MCP server %s: required %d auth headers but only resolved %d", - mcp_server.name, - len(mcp_server.authorization_headers), - len(headers), - ) - continue - - if len(headers) > 0: - # add headers to tool definition - tool_def["headers"] = headers # type: ignore[index] - # collect tools info - tools.append(tool_def) - return tools - - -async def prepare_tools_for_responses_api( - client: AsyncLlamaStackClient, - query_request: QueryRequest, - token: str, - config: AppConfig, - mcp_headers: Optional[dict[str, dict[str, str]]] = None, -) -> Optional[list[dict[str, Any]]]: - """ - Prepare tools for Responses API including RAG and MCP tools. - - This function retrieves vector stores and combines them with MCP - server tools to create a unified toolgroups list for the Responses API. 
- - Args: - client: The Llama Stack client instance - query_request: The user's query request - token: Authentication token for MCP tools - config: Configuration object containing MCP server settings - mcp_headers: Per-request headers for MCP servers - - Returns: - Optional[list[dict[str, Any]]]: List of tool configurations for the - Responses API, or None if no_tools is True or no tools are available - """ - if query_request.no_tools: - return None - - toolgroups = [] - # Get vector stores for RAG tools - use specified ones or fetch all - if query_request.vector_store_ids: - vector_store_ids = query_request.vector_store_ids - else: - vector_store_ids = [ - vector_store.id for vector_store in (await client.vector_stores.list()).data - ] - - # Add RAG tools if vector stores are available - rag_tools = get_rag_tools(vector_store_ids) - if rag_tools: - toolgroups.extend(rag_tools) - - # Add MCP server tools - mcp_tools = get_mcp_tools(config.mcp_servers, token, mcp_headers) - if mcp_tools: - toolgroups.extend(mcp_tools) - logger.debug( - "Configured %d MCP tools: %s", - len(mcp_tools), - [tool.get("server_label", "unknown") for tool in mcp_tools], - ) - # Convert empty list to None for consistency with existing behavior - if not toolgroups: - return None - - return toolgroups diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 00cbe132b..1753f406d 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -80,6 +80,7 @@ from utils.shields import ( append_turn_to_conversation, run_shield_moderation, + validate_shield_ids_override, ) from utils.suid import normalize_conversation_id from utils.token_counter import TokenCounter @@ -151,6 +152,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals # Enforce RBAC: optionally disallow overriding model/provider in requests validate_model_provider_override(query_request, request.state.authorized_actions) + # Validate shield_ids override if provided + validate_shield_ids_override(query_request, configuration) + # Validate attachments if provided if query_request.attachments: validate_attachments_metadata(query_request.attachments) @@ -246,7 +250,7 @@ async def retrieve_response_generator( turn_summary = TurnSummary() try: moderation_result = await run_shield_moderation( - context.client, responses_params.input + context.client, responses_params.input, context.query_request.shield_ids ) if moderation_result.blocked: violation_message = moderation_result.message or "" diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py deleted file mode 100644 index 787b4565a..000000000 --- a/src/app/endpoints/streaming_query_v2.py +++ /dev/null @@ -1,480 +0,0 @@ -"""Streaming query handler using Responses API (v2).""" - -import logging -from typing import Annotated, Any, AsyncIterator, Optional, cast - -from fastapi import APIRouter, Depends, Request -from fastapi.responses import StreamingResponse -from llama_stack_api.openai_responses import ( - OpenAIResponseObject, - OpenAIResponseObjectStream, - OpenAIResponseObjectStreamResponseCompleted, - OpenAIResponseObjectStreamResponseFailed, - OpenAIResponseObjectStreamResponseOutputItemDone, - OpenAIResponseObjectStreamResponseOutputTextDelta, - OpenAIResponseObjectStreamResponseOutputTextDone, -) -from llama_stack_client import AsyncLlamaStackClient - -from app.endpoints.query import ( - is_transcripts_enabled, - persist_user_conversation_details, - 
validate_attachments_metadata, -) -from app.endpoints.query_v2 import ( - _build_tool_call_summary, - extract_token_usage_from_responses_api, - get_topic_summary, - parse_referenced_documents_from_responses_api, - prepare_tools_for_responses_api, -) -from app.endpoints.streaming_query import ( - LLM_TOKEN_EVENT, - LLM_TOOL_CALL_EVENT, - LLM_TOOL_RESULT_EVENT, - format_stream_data, - stream_end_event, - stream_event, - stream_start_event, - streaming_query_endpoint_handler_base, -) -from authentication import get_auth_dependency -from authentication.interface import AuthTuple -from authorization.middleware import authorize -from configuration import configuration -from constants import MEDIA_TYPE_JSON -from models.config import Action -from models.context import ResponseGeneratorContext -from models.requests import QueryRequest -from models.responses import ( - ForbiddenResponse, - InternalServerErrorResponse, - NotFoundResponse, - QuotaExceededResponse, - ServiceUnavailableResponse, - StreamingQueryResponse, - UnauthorizedResponse, - UnprocessableEntityResponse, -) -from utils.endpoints import ( - cleanup_after_streaming, - get_system_prompt, -) -from utils.query import create_violation_stream -from utils.quota import consume_tokens, get_available_quotas -from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id -from utils.mcp_headers import mcp_headers_dependency -from utils.shields import ( - append_turn_to_conversation, - run_shield_moderation, -) -from utils.token_counter import TokenCounter -from utils.transcripts import store_transcript -from utils.types import RAGChunk, TurnSummary - -logger = logging.getLogger("app.endpoints.handlers") -router = APIRouter(tags=["streaming_query_v1"]) -auth_dependency = get_auth_dependency() - -streaming_query_v2_responses: dict[int | str, dict[str, Any]] = { - 200: StreamingQueryResponse.openapi_response(), - 401: UnauthorizedResponse.openapi_response( - examples=["missing header", "missing token"] - ), - 403: ForbiddenResponse.openapi_response( - examples=["conversation read", "endpoint", "model override"] - ), - 404: NotFoundResponse.openapi_response( - examples=["conversation", "model", "provider"] - ), - # 413: PromptTooLongResponse.openapi_response(), - 422: UnprocessableEntityResponse.openapi_response(), - 429: QuotaExceededResponse.openapi_response(), - 500: InternalServerErrorResponse.openapi_response(examples=["configuration"]), - 503: ServiceUnavailableResponse.openapi_response(), -} - - -def create_responses_response_generator( # pylint: disable=too-many-locals,too-many-statements - context: ResponseGeneratorContext, -) -> Any: - """ - Create a response generator function for Responses API streaming. - - This factory function returns an async generator that processes streaming - responses from the Responses API and yields Server-Sent Events (SSE). - - Args: - context: Context object containing all necessary parameters for response generation - - Returns: - An async generator function that yields SSE-formatted strings - """ - - async def response_generator( # pylint: disable=too-many-branches,too-many-statements - turn_response: AsyncIterator[OpenAIResponseObjectStream], - ) -> AsyncIterator[str]: - """ - Generate SSE formatted streaming response. - - Asynchronously generates a stream of Server-Sent Events - (SSE) representing incremental responses from a - language model turn. - - Yields start, token, tool call, turn completion, and - end events as SSE-formatted strings. 
Collects the - complete response for transcript storage if enabled. - """ - chunk_id = 0 - summary = TurnSummary( - llm_response="", tool_calls=[], tool_results=[], rag_chunks=[] - ) - - # Determine media type for response formatting - media_type = context.query_request.media_type or MEDIA_TYPE_JSON - - # Accumulators for Responses API - text_parts: list[str] = [] - emitted_turn_complete = False - - # Use the conversation_id from context (either provided or newly created) - conv_id = context.conversation_id - - # Track the latest response object from response.completed event - latest_response_object: Optional[Any] = None - - # RAG chunks - rag_chunks: list[RAGChunk] = [] - - logger.debug("Starting streaming response (Responses API) processing") - - async for chunk in turn_response: - event_type = getattr(chunk, "type", None) - logger.debug("Processing chunk %d, type: %s", chunk_id, event_type) - - # Emit start event when response is created - if event_type == "response.created": - yield stream_start_event(conv_id) - - # Text streaming - if event_type == "response.output_text.delta": - delta_chunk = cast( - OpenAIResponseObjectStreamResponseOutputTextDelta, chunk - ) - if delta_chunk.delta: - text_parts.append(delta_chunk.delta) - yield stream_event( - { - "id": chunk_id, - "token": delta_chunk.delta, - }, - LLM_TOKEN_EVENT, - media_type, - ) - chunk_id += 1 - - # Final text of the output (capture, but emit at response.completed) - elif event_type == "response.output_text.done": - text_done_chunk = cast( - OpenAIResponseObjectStreamResponseOutputTextDone, chunk - ) - if text_done_chunk.text: - summary.llm_response = text_done_chunk.text - - # Content part started - emit an empty token to kick off UI streaming - elif event_type == "response.content_part.added": - yield stream_event( - { - "id": chunk_id, - "token": "", - }, - LLM_TOKEN_EVENT, - media_type, - ) - chunk_id += 1 - - # Process tool calls and results are emitted together when output items are done - # TODO(asimurka): support emitting tool calls and results separately when ready - elif event_type == "response.output_item.done": - output_item_done_chunk = cast( - OpenAIResponseObjectStreamResponseOutputItemDone, chunk - ) - if output_item_done_chunk.item.type == "message": - continue - tool_call, tool_result = _build_tool_call_summary( - output_item_done_chunk.item, rag_chunks - ) - if tool_call: - summary.tool_calls.append(tool_call) - yield stream_event( - tool_call.model_dump(), - LLM_TOOL_CALL_EVENT, - media_type, - ) - if tool_result: - summary.tool_results.append(tool_result) - yield stream_event( - tool_result.model_dump(), - LLM_TOOL_RESULT_EVENT, - media_type, - ) - - # Completed response - capture final text and response object - elif event_type == "response.completed": - # Capture the response object for token usage extraction - completed_chunk = cast( - OpenAIResponseObjectStreamResponseCompleted, chunk - ) - latest_response_object = completed_chunk.response - - if not emitted_turn_complete: - final_message = summary.llm_response or "".join(text_parts) - if not final_message: - final_message = "No response from the model" - summary.llm_response = final_message - yield stream_event( - { - "id": chunk_id, - "token": final_message, - }, - "turn_complete", - media_type, - ) - chunk_id += 1 - emitted_turn_complete = True - - # Incomplete response - emit error because LLS does not - # support incomplete responses "incomplete_detail" attribute yet - elif event_type == "response.incomplete": - error_response = 
InternalServerErrorResponse.query_failed( - "An unexpected error occurred while processing the request." - ) - logger.error("Error while obtaining answer for user question") - yield format_stream_data( - {"event": "error", "data": {**error_response.detail.model_dump()}} - ) - return - - # Failed response - emit error with custom cause from error message - elif event_type == "response.failed": - failed_chunk = cast(OpenAIResponseObjectStreamResponseFailed, chunk) - latest_response_object = failed_chunk.response - error_message = ( - failed_chunk.response.error.message - if failed_chunk.response.error - else "An unexpected error occurred while processing the request." - ) - error_response = InternalServerErrorResponse.query_failed(error_message) - logger.error("Error while obtaining answer for user question") - yield format_stream_data( - {"event": "error", "data": {**error_response.detail.model_dump()}} - ) - return - - logger.debug( - "Streaming complete - Tool calls: %d, Response chars: %d", - len(summary.tool_calls), - len(summary.llm_response), - ) - - # Extract token usage from the response object - token_usage = ( - extract_token_usage_from_responses_api( - latest_response_object, context.model_id, context.provider_id - ) - if latest_response_object is not None - else TokenCounter() - ) - consume_tokens( - configuration.quota_limiters, - configuration.token_usage_history, - context.user_id, - input_tokens=token_usage.input_tokens, - output_tokens=token_usage.output_tokens, - model_id=context.model_id, - provider_id=context.provider_id, - ) - referenced_documents = parse_referenced_documents_from_responses_api( - cast(OpenAIResponseObject, latest_response_object) - ) - available_quotas = get_available_quotas( - configuration.quota_limiters, context.user_id - ) - yield stream_end_event( - context.metadata_map, - token_usage, - available_quotas, - referenced_documents, - media_type, - ) - - # Perform cleanup tasks (database and cache operations)) - await cleanup_after_streaming( - user_id=context.user_id, - conversation_id=conv_id, - model_id=context.model_id, - provider_id=context.provider_id, - llama_stack_model_id=context.llama_stack_model_id, - query_request=context.query_request, - summary=summary, - metadata_map=context.metadata_map, - started_at=context.started_at, - client=context.client, - config=configuration, - skip_userid_check=context.skip_userid_check, - get_topic_summary_func=get_topic_summary, - is_transcripts_enabled_func=is_transcripts_enabled, - store_transcript_func=store_transcript, - persist_user_conversation_details_func=persist_user_conversation_details, - rag_chunks=[rag_chunk.model_dump() for rag_chunk in rag_chunks], - ) - - return response_generator - - -@router.post( - "/streaming_query", - response_class=StreamingResponse, - responses=streaming_query_v2_responses, - summary="Streaming Query Endpoint Handler V1", -) -@authorize(Action.STREAMING_QUERY) -async def streaming_query_endpoint_handler_v2( # pylint: disable=too-many-locals - request: Request, - query_request: QueryRequest, - auth: Annotated[AuthTuple, Depends(auth_dependency)], - mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), -) -> StreamingResponse: - """ - Handle request to the /streaming_query endpoint using Responses API. - - Returns a streaming response using Server-Sent Events (SSE) format with - content type text/event-stream. 
- - Returns: - StreamingResponse: An HTTP streaming response yielding - SSE-formatted events for the query lifecycle with content type - text/event-stream. - - Raises: - HTTPException: - - 401: Unauthorized - Missing or invalid credentials - - 403: Forbidden - Insufficient permissions or model override not allowed - - 404: Not Found - Conversation, model, or provider not found - - 422: Unprocessable Entity - Request validation failed - - 429: Too Many Requests - Quota limit exceeded - - 500: Internal Server Error - Configuration not loaded or other server errors - - 503: Service Unavailable - Unable to connect to Llama Stack backend - """ - return await streaming_query_endpoint_handler_base( - request=request, - query_request=query_request, - auth=auth, - mcp_headers=mcp_headers, - retrieve_response_func=retrieve_response, - create_response_generator_func=create_responses_response_generator, - ) - - -async def retrieve_response( # pylint: disable=too-many-locals - client: AsyncLlamaStackClient, - model_id: str, - query_request: QueryRequest, - token: str, - mcp_headers: Optional[dict[str, dict[str, str]]] = None, -) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]: - """ - Retrieve response from LLMs and agents. - - Asynchronously retrieves a streaming response and conversation - ID from the Llama Stack agent for a given user query. - - This function configures shields, system prompt, and tool usage - based on the request and environment. It prepares the agent with - appropriate headers and toolgroups, validates attachments if - present, and initiates a streaming turn with the user's query - and any provided documents. - - Parameters: - model_id (str): Identifier of the model to use for the query. - query_request (QueryRequest): The user's query and associated metadata. - token (str): Authentication token for downstream services. - mcp_headers (dict[str, dict[str, str]], optional): - Multi-cluster proxy headers for tool integrations. - - Returns: - tuple: A tuple containing the streaming response object - and the conversation ID. 
- """ - # use system prompt from request or default one - system_prompt = get_system_prompt(query_request, configuration) - logger.debug("Using system prompt: %s", system_prompt) - - # TODO(lucasagomes): redact attachments content before sending to LLM - # if attachments are provided, validate them - if query_request.attachments: - validate_attachments_metadata(query_request.attachments) - - # Prepare tools for responses API - toolgroups = await prepare_tools_for_responses_api( - client, query_request, token, configuration, mcp_headers - ) - - # Prepare input for Responses API - # Convert attachments to text and concatenate with query - input_text = query_request.query - if query_request.attachments: - for attachment in query_request.attachments: - input_text += ( - f"\n\n[Attachment: {attachment.attachment_type}]\n" - f"{attachment.content}" - ) - - # Handle conversation ID for Responses API - # Create conversation upfront if not provided - conversation_id = query_request.conversation_id - if conversation_id: - # Conversation ID was provided - convert to llama-stack format - logger.debug("Using existing conversation ID: %s", conversation_id) - llama_stack_conv_id = to_llama_stack_conversation_id(conversation_id) - else: - # No conversation_id provided - create a new conversation first - logger.debug("No conversation_id provided, creating new conversation") - conversation = await client.conversations.create(metadata={}) - llama_stack_conv_id = conversation.id - # Store the normalized version for later use - conversation_id = normalize_conversation_id(llama_stack_conv_id) - logger.info( - "Created new conversation with ID: %s (normalized: %s)", - llama_stack_conv_id, - conversation_id, - ) - - # Run shield moderation before calling LLM - moderation_result = await run_shield_moderation( - client, input_text, query_request.shield_ids - ) - if moderation_result.blocked: - violation_message = moderation_result.message or "" - await append_turn_to_conversation( - client, llama_stack_conv_id, input_text, violation_message - ) - return ( - create_violation_stream(violation_message, moderation_result.shield_model), - normalize_conversation_id(conversation_id), - ) - - create_params: dict[str, Any] = { - "input": input_text, - "model": model_id, - "instructions": system_prompt, - "stream": True, - "store": True, - "tools": toolgroups, - "conversation": llama_stack_conv_id, - } - - response = await client.responses.create(**create_params) - response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response) - - return response_stream, normalize_conversation_id(conversation_id) diff --git a/src/models/config.py b/src/models/config.py index 68ad168e4..ba4d76062 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1274,6 +1274,7 @@ class Customization(ConfigurationBase): profile_path: Optional[str] = None disable_query_system_prompt: bool = False + disable_shield_ids_override: bool = False system_prompt_path: Optional[FilePath] = None system_prompt: Optional[str] = None agent_card_path: Optional[FilePath] = None diff --git a/src/utils/shields.py b/src/utils/shields.py index 38aa4b19f..69d4d6e8f 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -8,8 +8,10 @@ from llama_stack_client.types import CreateResponse import metrics +from configuration import AppConfig +from models.requests import QueryRequest from models.responses import ( - NotFoundResponse, + NotFoundResponse, UnprocessableEntityResponse, ) from utils.types import ShieldModerationResult @@ -63,6 +65,40 @@ def 
detect_shield_violations(output_items: list[Any]) -> bool: return False +def validate_shield_ids_override( + query_request: QueryRequest, config: AppConfig +) -> None: + """ + Validate that shield_ids override is allowed by configuration. + + If configuration disables shield_ids override + (config.customization.disable_shield_ids_override) and the incoming + query_request contains shield_ids, an HTTP 422 Unprocessable Entity + is raised instructing the client to remove the field. + + Parameters: + query_request: The incoming query payload; may contain shield_ids. + config: Application configuration which may include customization flags. + + Raises: + HTTPException: If shield_ids override is disabled but shield_ids is provided. + """ + shield_ids_override_disabled = ( + config.customization is not None + and config.customization.disable_shield_ids_override + ) + if shield_ids_override_disabled and query_request.shield_ids is not None: + response = UnprocessableEntityResponse( + response="Shield IDs customization is disabled", + cause=( + "This instance does not support customizing shield IDs in the " + "query request (disable_shield_ids_override is set). Please remove the " + "shield_ids field from your request." + ), + ) + raise HTTPException(**response.model_dump()) + + async def run_shield_moderation( client: AsyncLlamaStackClient, input_text: str, diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py index 7702d70a8..16bf01287 100644 --- a/tests/unit/utils/test_shields.py +++ b/tests/unit/utils/test_shields.py @@ -10,6 +10,7 @@ detect_shield_violations, get_available_shields, run_shield_moderation, + validate_shield_ids_override, ) @@ -404,3 +405,83 @@ async def test_appends_user_and_assistant_messages( }, ], ) + + +class TestValidateShieldIdsOverride: + """Tests for validate_shield_ids_override function.""" + + def test_allows_shield_ids_when_override_enabled( + self, mocker: MockerFixture + ) -> None: + """Test that shield_ids is allowed when override is not disabled.""" + mock_config = mocker.Mock() + mock_config.customization = None + + query_request = mocker.Mock() + query_request.shield_ids = ["shield-1"] + + # Should not raise exception + validate_shield_ids_override(query_request, mock_config) + + def test_allows_shield_ids_when_customization_exists_but_override_not_disabled( + self, mocker: MockerFixture + ) -> None: + """Test shield_ids allowed when customization exists but override not disabled.""" + mock_config = mocker.Mock() + mock_config.customization = mocker.Mock() + mock_config.customization.disable_shield_ids_override = False + + query_request = mocker.Mock() + query_request.shield_ids = ["shield-1"] + + # Should not raise exception + validate_shield_ids_override(query_request, mock_config) + + def test_allows_none_shield_ids_when_override_disabled( + self, mocker: MockerFixture + ) -> None: + """Test that None shield_ids is allowed even when override is disabled.""" + mock_config = mocker.Mock() + mock_config.customization = mocker.Mock() + mock_config.customization.disable_shield_ids_override = True + + query_request = mocker.Mock() + query_request.shield_ids = None + + # Should not raise exception + validate_shield_ids_override(query_request, mock_config) + + def test_raises_422_when_shield_ids_provided_and_override_disabled( + self, mocker: MockerFixture + ) -> None: + """Test HTTPException 422 raised when shield_ids provided but override disabled.""" + mock_config = mocker.Mock() + mock_config.customization = mocker.Mock() + 
mock_config.customization.disable_shield_ids_override = True + + query_request = mocker.Mock() + query_request.shield_ids = ["shield-1"] + + with pytest.raises(HTTPException) as exc_info: + validate_shield_ids_override(query_request, mock_config) + + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + # pylint: disable=line-too-long + assert "Shield IDs customization is disabled" in exc_info.value.detail["response"] # type: ignore + assert "disable_shield_ids_override" in exc_info.value.detail["cause"] # type: ignore + + def test_raises_422_when_empty_list_shield_ids_and_override_disabled( + self, mocker: MockerFixture + ) -> None: + """Test that HTTPException 422 is raised when shield_ids=[] and override disabled.""" + mock_config = mocker.Mock() + mock_config.customization = mocker.Mock() + mock_config.customization.disable_shield_ids_override = True + + query_request = mocker.Mock() + query_request.shield_ids = [] + + with pytest.raises(HTTPException) as exc_info: + validate_shield_ids_override(query_request, mock_config) + + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
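
A minimal sketch of how the pieces added in this series compose at a call site. The moderate_query helper, the plain 400 response, and the "llama-guard" identifier below are illustrative assumptions rather than code from these patches; the real endpoints record the blocked turn and return a violation response instead of raising.

```python
from fastapi import HTTPException
from llama_stack_client import AsyncLlamaStackClient

from configuration import AppConfig
from models.requests import QueryRequest
from utils.shields import run_shield_moderation, validate_shield_ids_override


async def moderate_query(
    client: AsyncLlamaStackClient, query_request: QueryRequest, config: AppConfig
) -> None:
    # Rejects the request with HTTP 422 when the deployment sets
    # customization.disable_shield_ids_override but the caller still
    # supplied shield_ids.
    validate_shield_ids_override(query_request, config)

    # shield_ids=None runs every configured shield, shield_ids=[] skips
    # moderation entirely, and a non-empty list runs only the named shields.
    result = await run_shield_moderation(
        client, query_request.query, query_request.shield_ids
    )
    if result.blocked:
        # Illustrative only: a bare 400 keeps the sketch short; the actual
        # endpoints append the turn to the conversation and return a
        # violation response/stream.
        raise HTTPException(
            status_code=400, detail=result.message or "Blocked by shield"
        )


# Example request pinning moderation to a single shield. The identifier is
# hypothetical and must match a shield registered with Llama Stack; other
# optional QueryRequest fields are omitted.
example_request = QueryRequest(
    query="How do I reset my password?",
    shield_ids=["llama-guard"],
)
```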