diff --git a/README.md b/README.md index 9775e8fa..b57a9a2b 100644 --- a/README.md +++ b/README.md @@ -657,6 +657,13 @@ utilized: 1. If the `shield_id` starts with `inout_`, it will be used both for input and output. 1. Otherwise, it will be used for input only. +Additionally, an optional list parameter `shield_ids` can be specified in requests to the `/query` and `/streaming_query` endpoints to override which shields are applied: when it is omitted, all configured shields are used; an empty list skips all shields; and the request is rejected with HTTP 422 if none of the listed IDs match an available shield. To reject any request that sets `shield_ids`, disable the override in the configuration: + +```yaml +customization: + disable_shield_ids_override: true +``` + ## Authentication See [authentication and authorization](docs/auth.md). diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py index 3a36f8b8..a3c75fdd 100644 --- a/src/app/endpoints/a2a.py +++ b/src/app/endpoints/a2a.py @@ -312,6 +312,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals generate_topic_summary=True, media_type=None, vector_store_ids=vector_store_ids, + shield_ids=None, ) # Get LLM client and select model diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 9cd7cf35..ac1e421e 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -4,7 +4,7 @@ import logging from datetime import UTC, datetime -from typing import Annotated, Any, cast +from typing import Annotated, Any, Optional, cast from fastapi import APIRouter, Depends, HTTPException, Request from llama_stack_api.openai_responses import OpenAIResponseObject @@ -61,6 +61,7 @@ from utils.shields import ( append_turn_to_conversation, run_shield_moderation, + validate_shield_ids_override, ) from utils.suid import normalize_conversation_id from utils.types import ResponsesApiParams, TurnSummary @@ -125,6 +126,9 @@ async def query_endpoint_handler( # Enforce RBAC: optionally disallow overriding model/provider in requests validate_model_provider_override(query_request, request.state.authorized_actions) + # Validate shield_ids override if provided + validate_shield_ids_override(query_request, 
configuration) + # Validate attachments if provided if query_request.attachments: validate_attachments_metadata(query_request.attachments) @@ -166,7 +170,9 @@ async def query_endpoint_handler( client = await update_azure_token(client) # Retrieve response using Responses API - turn_summary = await retrieve_response(client, responses_params) + turn_summary = await retrieve_response( + client, responses_params, query_request.shield_ids + ) # Get topic summary for new conversation if not user_conversation and query_request.generate_topic_summary: @@ -225,6 +231,7 @@ async def query_endpoint_handler( async def retrieve_response( # pylint: disable=too-many-locals client: AsyncLlamaStackClient, responses_params: ResponsesApiParams, + shield_ids: Optional[list[str]] = None, ) -> TurnSummary: """ Retrieve response from LLMs and agents. @@ -235,6 +242,7 @@ async def retrieve_response( # pylint: disable=too-many-locals Parameters: client: The AsyncLlamaStackClient to use for the request. responses_params: The Responses API parameters. + shield_ids: Optional list of shield IDs for moderation. 
Returns: TurnSummary: Summary of the LLM response content @@ -242,7 +250,9 @@ async def retrieve_response( # pylint: disable=too-many-locals summary = TurnSummary() try: - moderation_result = await run_shield_moderation(client, responses_params.input) + moderation_result = await run_shield_moderation( + client, responses_params.input, shield_ids + ) if moderation_result.blocked: # Handle shield moderation blocking violation_message = moderation_result.message or "" diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 00cbe132..1753f406 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -80,6 +80,7 @@ from utils.shields import ( append_turn_to_conversation, run_shield_moderation, + validate_shield_ids_override, ) from utils.suid import normalize_conversation_id from utils.token_counter import TokenCounter @@ -151,6 +152,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals # Enforce RBAC: optionally disallow overriding model/provider in requests validate_model_provider_override(query_request, request.state.authorized_actions) + # Validate shield_ids override if provided + validate_shield_ids_override(query_request, configuration) + # Validate attachments if provided if query_request.attachments: validate_attachments_metadata(query_request.attachments) @@ -246,7 +250,7 @@ async def retrieve_response_generator( turn_summary = TurnSummary() try: moderation_result = await run_shield_moderation( - context.client, responses_params.input + context.client, responses_params.input, context.query_request.shield_ids ) if moderation_result.blocked: violation_message = moderation_result.message or "" diff --git a/src/models/config.py b/src/models/config.py index 68ad168e..ba4d7606 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1274,6 +1274,7 @@ class Customization(ConfigurationBase): profile_path: Optional[str] = None disable_query_system_prompt: 
bool = False + disable_shield_ids_override: bool = False system_prompt_path: Optional[FilePath] = None system_prompt: Optional[str] = None agent_card_path: Optional[FilePath] = None diff --git a/src/models/requests.py b/src/models/requests.py index 38bdc9ac..1924ed97 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -83,6 +83,7 @@ class QueryRequest(BaseModel): generate_topic_summary: Whether to generate topic summary for new conversations. media_type: The optional media type for response format (application/json or text/plain). vector_store_ids: The optional list of specific vector store IDs to query for RAG. + shield_ids: The optional list of safety shield IDs to apply. Example: ```python @@ -166,6 +167,14 @@ class QueryRequest(BaseModel): examples=["ocp_docs", "knowledge_base", "vector_db_1"], ) + shield_ids: Optional[list[str]] = Field( + None, + description="Optional list of safety shield IDs to apply. " + "If None, all configured shields are used. " + "If empty list, all shields are skipped.", + examples=["llama-guard", "custom-shield"], + ) + # provides examples for /docs endpoint model_config = { "extra": "forbid", diff --git a/src/utils/shields.py b/src/utils/shields.py index ecfa80f5..69d4d6e8 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -1,15 +1,18 @@ """Utility functions for working with Llama Stack shields.""" import logging -from typing import Any, cast +from typing import Any, Optional, cast from fastapi import HTTPException from llama_stack_client import AsyncLlamaStackClient from llama_stack_client.types import CreateResponse import metrics +from configuration import AppConfig +from models.requests import QueryRequest from models.responses import ( NotFoundResponse, + UnprocessableEntityResponse, ) from utils.types import ShieldModerationResult @@ -62,27 +65,93 @@ def detect_shield_violations(output_items: list[Any]) -> bool: return False +def validate_shield_ids_override( + query_request: QueryRequest, config: 
AppConfig +) -> None: + """ + Validate that shield_ids override is allowed by configuration. + + If configuration disables shield_ids override + (config.customization.disable_shield_ids_override) and the incoming + query_request contains shield_ids, an HTTP 422 Unprocessable Entity + is raised instructing the client to remove the field. + + Parameters: + query_request: The incoming query payload; may contain shield_ids. + config: Application configuration which may include customization flags. + + Raises: + HTTPException: If shield_ids override is disabled but shield_ids is provided. + """ + shield_ids_override_disabled = ( + config.customization is not None + and config.customization.disable_shield_ids_override + ) + if shield_ids_override_disabled and query_request.shield_ids is not None: + response = UnprocessableEntityResponse( + response="Shield IDs customization is disabled", + cause=( + "This instance does not support customizing shield IDs in the " + "query request (disable_shield_ids_override is set). Please remove the " + "shield_ids field from your request." + ), + ) + raise HTTPException(**response.model_dump()) + + async def run_shield_moderation( client: AsyncLlamaStackClient, input_text: str, + shield_ids: Optional[list[str]] = None, ) -> ShieldModerationResult: """ Run shield moderation on input text. - Iterates through all configured shields and runs moderation checks. + Iterates through configured shields and runs moderation checks. Raises HTTPException if shield model is not found. Parameters: client: The Llama Stack client. input_text: The text to moderate. + shield_ids: Optional list of shield IDs to use. If None, uses all shields. + If empty list, skips all shields. Returns: ShieldModerationResult: Result indicating if content was blocked and the message. + + Raises: + HTTPException: If shield's provider_resource_id is not configured or model not found. 
""" + all_shields = await client.shields.list() + + # Filter shields based on shield_ids parameter + if shield_ids is not None: + if len(shield_ids) == 0: + logger.info("shield_ids=[] provided, skipping all shields") + return ShieldModerationResult(blocked=False) + + shields_to_run = [s for s in all_shields if s.identifier in shield_ids] + + # Log warning if requested shield not found + requested = set(shield_ids) + available = {s.identifier for s in shields_to_run} + missing = requested - available + if missing: + logger.warning("Requested shields not found: %s", missing) + + # Reject if no requested shields were found (prevents accidental bypass) + if not shields_to_run: + response = UnprocessableEntityResponse( + response="Invalid shield configuration", + cause=f"Requested shield_ids not found: {sorted(missing)}", + ) + raise HTTPException(**response.model_dump()) + else: + shields_to_run = list(all_shields) + available_models = {model.id for model in await client.models.list()} - shields = await client.shields.list() - for shield in shields: + for shield in shields_to_run: if ( not shield.provider_resource_id or shield.provider_resource_id not in available_models diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py index 5c352758..16bf0128 100644 --- a/tests/unit/utils/test_shields.py +++ b/tests/unit/utils/test_shields.py @@ -10,6 +10,7 @@ detect_shield_violations, get_available_shields, run_shield_moderation, + validate_shield_ids_override, ) @@ -305,6 +306,75 @@ async def test_returns_blocked_on_bad_request_error( assert result.shield_model == "moderation-model" mock_metric.inc.assert_called_once() + @pytest.mark.asyncio + async def test_shield_ids_empty_list_skips_all_shields( + self, mocker: MockerFixture + ) -> None: + """Test that shield_ids=[] explicitly skips all shields (intentional bypass).""" + mock_client = mocker.Mock() + shield = mocker.Mock() + shield.identifier = "shield-1" + mock_client.shields.list = 
mocker.AsyncMock(return_value=[shield]) + + result = await run_shield_moderation(mock_client, "test input", shield_ids=[]) + + assert result.blocked is False + mock_client.shields.list.assert_called_once() + + @pytest.mark.asyncio + async def test_shield_ids_raises_exception_when_no_shields_found( + self, mocker: MockerFixture + ) -> None: + """Test shield_ids raises HTTPException when no requested shields exist.""" + mock_client = mocker.Mock() + shield = mocker.Mock() + shield.identifier = "shield-1" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + with pytest.raises(HTTPException) as exc_info: + await run_shield_moderation( + mock_client, "test input", shield_ids=["typo-shield"] + ) + + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert "Invalid shield configuration" in exc_info.value.detail["response"] # type: ignore + assert "typo-shield" in exc_info.value.detail["cause"] # type: ignore + + @pytest.mark.asyncio + async def test_shield_ids_filters_to_specific_shield( + self, mocker: MockerFixture + ) -> None: + """Test that shield_ids filters to only specified shields.""" + mock_client = mocker.Mock() + + shield1 = mocker.Mock() + shield1.identifier = "shield-1" + shield1.provider_resource_id = "model-1" + shield2 = mocker.Mock() + shield2.identifier = "shield-2" + shield2.provider_resource_id = "model-2" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2]) + + model1 = mocker.Mock() + model1.id = "model-1" + mock_client.models.list = mocker.AsyncMock(return_value=[model1]) + + moderation_result = mocker.Mock() + moderation_result.results = [mocker.Mock(flagged=False)] + mock_client.moderations.create = mocker.AsyncMock( + return_value=moderation_result + ) + + result = await run_shield_moderation( + mock_client, "test input", shield_ids=["shield-1"] + ) + + assert result.blocked is False + assert mock_client.moderations.create.call_count == 1 + 
mock_client.moderations.create.assert_called_with( + input="test input", model="model-1" + ) + class TestAppendTurnToConversation: # pylint: disable=too-few-public-methods """Tests for append_turn_to_conversation function.""" @@ -335,3 +405,83 @@ async def test_appends_user_and_assistant_messages( }, ], ) + + +class TestValidateShieldIdsOverride: + """Tests for validate_shield_ids_override function.""" + + def test_allows_shield_ids_when_override_enabled( + self, mocker: MockerFixture + ) -> None: + """Test that shield_ids is allowed when override is not disabled.""" + mock_config = mocker.Mock() + mock_config.customization = None + + query_request = mocker.Mock() + query_request.shield_ids = ["shield-1"] + + # Should not raise exception + validate_shield_ids_override(query_request, mock_config) + + def test_allows_shield_ids_when_customization_exists_but_override_not_disabled( + self, mocker: MockerFixture + ) -> None: + """Test shield_ids allowed when customization exists but override not disabled.""" + mock_config = mocker.Mock() + mock_config.customization = mocker.Mock() + mock_config.customization.disable_shield_ids_override = False + + query_request = mocker.Mock() + query_request.shield_ids = ["shield-1"] + + # Should not raise exception + validate_shield_ids_override(query_request, mock_config) + + def test_allows_none_shield_ids_when_override_disabled( + self, mocker: MockerFixture + ) -> None: + """Test that None shield_ids is allowed even when override is disabled.""" + mock_config = mocker.Mock() + mock_config.customization = mocker.Mock() + mock_config.customization.disable_shield_ids_override = True + + query_request = mocker.Mock() + query_request.shield_ids = None + + # Should not raise exception + validate_shield_ids_override(query_request, mock_config) + + def test_raises_422_when_shield_ids_provided_and_override_disabled( + self, mocker: MockerFixture + ) -> None: + """Test HTTPException 422 raised when shield_ids provided but override 
disabled.""" + mock_config = mocker.Mock() + mock_config.customization = mocker.Mock() + mock_config.customization.disable_shield_ids_override = True + + query_request = mocker.Mock() + query_request.shield_ids = ["shield-1"] + + with pytest.raises(HTTPException) as exc_info: + validate_shield_ids_override(query_request, mock_config) + + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + # pylint: disable=line-too-long + assert "Shield IDs customization is disabled" in exc_info.value.detail["response"] # type: ignore + assert "disable_shield_ids_override" in exc_info.value.detail["cause"] # type: ignore + + def test_raises_422_when_empty_list_shield_ids_and_override_disabled( + self, mocker: MockerFixture + ) -> None: + """Test that HTTPException 422 is raised when shield_ids=[] and override disabled.""" + mock_config = mocker.Mock() + mock_config.customization = mocker.Mock() + mock_config.customization.disable_shield_ids_override = True + + query_request = mocker.Mock() + query_request.shield_ids = [] + + with pytest.raises(HTTPException) as exc_info: + validate_shield_ids_override(query_request, mock_config) + + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY