Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,13 @@ utilized:
1. If the `shield_id` starts with `inout_`, it will be used both for input and output.
1. Otherwise, it will be used for input only.

Additionally, an optional list parameter `shield_ids` can be specified in the `/query` and `/streaming_query` endpoints to override which shields are applied. To disable this override capability, set the following configuration option:

```yaml
customization:
disable_shield_ids_override: true
```

## Authentication

See [authentication and authorization](docs/auth.md).
Expand Down
1 change: 1 addition & 0 deletions src/app/endpoints/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals
generate_topic_summary=True,
media_type=None,
vector_store_ids=vector_store_ids,
shield_ids=None,
)

# Get LLM client and select model
Expand Down
16 changes: 13 additions & 3 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
from datetime import UTC, datetime
from typing import Annotated, Any, cast
from typing import Annotated, Any, Optional, cast

from fastapi import APIRouter, Depends, HTTPException, Request
from llama_stack_api.openai_responses import OpenAIResponseObject
Expand Down Expand Up @@ -61,6 +61,7 @@
from utils.shields import (
append_turn_to_conversation,
run_shield_moderation,
validate_shield_ids_override,
)
from utils.suid import normalize_conversation_id
from utils.types import ResponsesApiParams, TurnSummary
Expand Down Expand Up @@ -125,6 +126,9 @@ async def query_endpoint_handler(
# Enforce RBAC: optionally disallow overriding model/provider in requests
validate_model_provider_override(query_request, request.state.authorized_actions)

# Validate shield_ids override if provided
validate_shield_ids_override(query_request, configuration)

# Validate attachments if provided
if query_request.attachments:
validate_attachments_metadata(query_request.attachments)
Expand Down Expand Up @@ -166,7 +170,9 @@ async def query_endpoint_handler(
client = await update_azure_token(client)

# Retrieve response using Responses API
turn_summary = await retrieve_response(client, responses_params)
turn_summary = await retrieve_response(
client, responses_params, query_request.shield_ids
)

# Get topic summary for new conversation
if not user_conversation and query_request.generate_topic_summary:
Expand Down Expand Up @@ -225,6 +231,7 @@ async def query_endpoint_handler(
async def retrieve_response( # pylint: disable=too-many-locals
client: AsyncLlamaStackClient,
responses_params: ResponsesApiParams,
shield_ids: Optional[list[str]] = None,
) -> TurnSummary:
"""
Retrieve response from LLMs and agents.
Expand All @@ -235,14 +242,17 @@ async def retrieve_response( # pylint: disable=too-many-locals
Parameters:
client: The AsyncLlamaStackClient to use for the request.
responses_params: The Responses API parameters.
shield_ids: Optional list of shield IDs for moderation.

Returns:
TurnSummary: Summary of the LLM response content
"""
summary = TurnSummary()

try:
moderation_result = await run_shield_moderation(client, responses_params.input)
moderation_result = await run_shield_moderation(
client, responses_params.input, shield_ids
)
if moderation_result.blocked:
# Handle shield moderation blocking
violation_message = moderation_result.message or ""
Expand Down
6 changes: 5 additions & 1 deletion src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from utils.shields import (
append_turn_to_conversation,
run_shield_moderation,
validate_shield_ids_override,
)
from utils.suid import normalize_conversation_id
from utils.token_counter import TokenCounter
Expand Down Expand Up @@ -151,6 +152,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
# Enforce RBAC: optionally disallow overriding model/provider in requests
validate_model_provider_override(query_request, request.state.authorized_actions)

# Validate shield_ids override if provided
validate_shield_ids_override(query_request, configuration)

# Validate attachments if provided
if query_request.attachments:
validate_attachments_metadata(query_request.attachments)
Expand Down Expand Up @@ -246,7 +250,7 @@ async def retrieve_response_generator(
turn_summary = TurnSummary()
try:
moderation_result = await run_shield_moderation(
context.client, responses_params.input
context.client, responses_params.input, context.query_request.shield_ids
)
if moderation_result.blocked:
violation_message = moderation_result.message or ""
Expand Down
1 change: 1 addition & 0 deletions src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,7 @@ class Customization(ConfigurationBase):

profile_path: Optional[str] = None
disable_query_system_prompt: bool = False
disable_shield_ids_override: bool = False
system_prompt_path: Optional[FilePath] = None
system_prompt: Optional[str] = None
agent_card_path: Optional[FilePath] = None
Expand Down
9 changes: 9 additions & 0 deletions src/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class QueryRequest(BaseModel):
generate_topic_summary: Whether to generate topic summary for new conversations.
media_type: The optional media type for response format (application/json or text/plain).
vector_store_ids: The optional list of specific vector store IDs to query for RAG.
shield_ids: The optional list of safety shield IDs to apply.

Example:
```python
Expand Down Expand Up @@ -166,6 +167,14 @@ class QueryRequest(BaseModel):
examples=["ocp_docs", "knowledge_base", "vector_db_1"],
)

shield_ids: Optional[list[str]] = Field(
None,
description="Optional list of safety shield IDs to apply. "
"If None, all configured shields are used. "
"If empty list, all shields are skipped.",
examples=["llama-guard", "custom-shield"],
)

# provides examples for /docs endpoint
model_config = {
"extra": "forbid",
Expand Down
77 changes: 73 additions & 4 deletions src/utils/shields.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""Utility functions for working with Llama Stack shields."""

import logging
from typing import Any, cast
from typing import Any, Optional, cast

from fastapi import HTTPException
from llama_stack_client import AsyncLlamaStackClient
from llama_stack_client.types import CreateResponse

import metrics
from configuration import AppConfig
from models.requests import QueryRequest
from models.responses import (
NotFoundResponse,
UnprocessableEntityResponse,
)
from utils.types import ShieldModerationResult

Expand Down Expand Up @@ -62,27 +65,93 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
return False


def validate_shield_ids_override(
    query_request: QueryRequest, config: AppConfig
) -> None:
    """
    Reject a shield_ids override when the configuration forbids it.

    When ``config.customization.disable_shield_ids_override`` is enabled and
    the incoming request carries a ``shield_ids`` value, an HTTP 422
    Unprocessable Entity is raised instructing the client to remove the
    field. In every other case the function is a no-op.

    Parameters:
        query_request: The incoming query payload; may contain shield_ids.
        config: Application configuration which may include customization flags.

    Raises:
        HTTPException: If shield_ids override is disabled but shield_ids is provided.
    """
    customization = config.customization
    # Override is allowed unless the flag is explicitly set.
    if customization is None or not customization.disable_shield_ids_override:
        return
    # Flag is set, but the client did not attempt an override — nothing to do.
    if query_request.shield_ids is None:
        return
    error = UnprocessableEntityResponse(
        response="Shield IDs customization is disabled",
        cause=(
            "This instance does not support customizing shield IDs in the "
            "query request (disable_shield_ids_override is set). Please remove the "
            "shield_ids field from your request."
        ),
    )
    raise HTTPException(**error.model_dump())


async def run_shield_moderation(
client: AsyncLlamaStackClient,
input_text: str,
shield_ids: Optional[list[str]] = None,
) -> ShieldModerationResult:
"""
Run shield moderation on input text.

Iterates through all configured shields and runs moderation checks.
Iterates through configured shields and runs moderation checks.
Raises HTTPException if shield model is not found.

Parameters:
client: The Llama Stack client.
input_text: The text to moderate.
shield_ids: Optional list of shield IDs to use. If None, uses all shields.
If empty list, skips all shields.

Returns:
ShieldModerationResult: Result indicating if content was blocked and the message.

Raises:
HTTPException: If shield's provider_resource_id is not configured or model not found.
"""
all_shields = await client.shields.list()

# Filter shields based on shield_ids parameter
if shield_ids is not None:
if len(shield_ids) == 0:
logger.info("shield_ids=[] provided, skipping all shields")
return ShieldModerationResult(blocked=False)

shields_to_run = [s for s in all_shields if s.identifier in shield_ids]

# Log warning if requested shield not found
requested = set(shield_ids)
available = {s.identifier for s in shields_to_run}
missing = requested - available
if missing:
logger.warning("Requested shields not found: %s", missing)

# Reject if no requested shields were found (prevents accidental bypass)
if not shields_to_run:
response = UnprocessableEntityResponse(
response="Invalid shield configuration",
cause=f"Requested shield_ids not found: {sorted(missing)}",
)
raise HTTPException(**response.model_dump())
else:
shields_to_run = list(all_shields)

available_models = {model.id for model in await client.models.list()}

shields = await client.shields.list()
for shield in shields:
for shield in shields_to_run:
if (
not shield.provider_resource_id
or shield.provider_resource_id not in available_models
Expand Down
Loading
Loading