4 changes: 4 additions & 0 deletions lightspeed-stack.yaml
@@ -30,3 +30,7 @@ conversation_cache:

authentication:
module: "noop"


solr:
offline: True
26 changes: 23 additions & 3 deletions run.yaml
@@ -15,7 +15,7 @@ apis:
benchmarks: []
datasets: []
image_name: starter
# external_providers_dir: /opt/app-root/src/.llama/providers.d
external_providers_dir: ${env.EXTERNAL_PROVIDERS_DIR}

providers:
inference:
@@ -24,7 +24,9 @@ providers:
config:
api_key: ${env.OPENAI_API_KEY}
allowed_models: ["${env.E2E_OPENAI_MODEL:=gpt-4o-mini}"]
- config: {}
- config:
allowed_models:
- ${env.EMBEDDING_MODEL_DIR}
provider_id: sentence-transformers
provider_type: inline::sentence-transformers
files:
@@ -56,6 +58,18 @@ providers:
provider_id: rag-runtime
provider_type: inline::rag-runtime
vector_io:
- provider_id: solr-vector
provider_type: remote::solr_vector_io
config:
solr_url: http://localhost:8983/solr
collection_name: portal-rag
vector_field: chunk_vector
content_field: chunk
embedding_dimension: 384
embedding_model: ${env.EMBEDDING_MODEL_DIR}
persistence:
namespace: portal-rag
backend: kv_default
- config: # Define the storage backend for RAG
persistence:
namespace: vector_io::faiss
@@ -127,7 +141,13 @@ storage:
namespace: prompts
backend: kv_default
registered_resources:
models: []
models:
- model_id: granite-embedding-30m
model_type: embedding
provider_id: sentence-transformers
provider_model_id: ${env.EMBEDDING_MODEL_DIR}
metadata:
embedding_dimension: 384
shields:
- shield_id: llama-guard
provider_id: llama-guard
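For orientation: the new remote::solr_vector_io provider above points retrieval at Solr's dense-vector (KNN) search. Its implementation is not part of this diff, but a minimal sketch of the kind of request it would have to issue, using standard Solr 9 KNN query syntax and only the field names from the config above, looks like this:

import requests

def solr_knn_search(query_vector: list[float], k: int = 5) -> list[str]:
    """Sketch: top-k dense-vector search against the portal-rag collection."""
    solr_url = "http://localhost:8983/solr"  # config: solr_url
    collection = "portal-rag"                # config: collection_name
    vector = ",".join(str(v) for v in query_vector)  # 384 floats, per embedding_dimension
    resp = requests.get(
        f"{solr_url}/{collection}/select",
        params={
            # {!knn} is Solr's dense-vector query parser; chunk_vector is config: vector_field
            "q": f"{{!knn f=chunk_vector topK={k}}}[{vector}]",
            "fl": "chunk,score",  # chunk is config: content_field
        },
        timeout=10,
    )
    resp.raise_for_status()
    return [doc["chunk"] for doc in resp.json()["response"]["docs"]]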
52 changes: 47 additions & 5 deletions src/app/endpoints/query.py
@@ -2,8 +2,8 @@

"""Handler for REST API call to provide answer to query using Response API."""

import datetime
import logging
from datetime import UTC, datetime
from typing import Annotated, Any, cast

from fastapi import APIRouter, Depends, HTTPException, Request
@@ -25,13 +25,15 @@
from configuration import configuration
from models.config import Action
from models.requests import QueryRequest

from models.responses import (
ForbiddenResponse,
InternalServerErrorResponse,
NotFoundResponse,
PromptTooLongResponse,
QueryResponse,
QuotaExceededResponse,
ReferencedDocument,
ServiceUnavailableResponse,
UnauthorizedResponse,
UnprocessableEntityResponse,
@@ -63,7 +65,11 @@
run_shield_moderation,
)
from utils.suid import normalize_conversation_id
from utils.types import ResponsesApiParams, TurnSummary
from utils.types import (
ResponsesApiParams,
TurnSummary,
)
from utils.vector_search import perform_vector_search, format_rag_context_for_injection

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
@@ -77,9 +83,9 @@
examples=["endpoint", "conversation read", "model override"]
),
404: NotFoundResponse.openapi_response(
examples=["model", "conversation", "provider"]
examples=["conversation", "model", "provider"]
),
413: PromptTooLongResponse.openapi_response(),
# 413: PromptTooLongResponse.openapi_response(),
422: UnprocessableEntityResponse.openapi_response(),
429: QuotaExceededResponse.openapi_response(),
500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
@@ -145,6 +151,19 @@ async def query_endpoint_handler(

client = AsyncLlamaStackClientHolder().get_client()

doc_ids_from_chunks: list[ReferencedDocument] = []
pre_rag_chunks: list[Any] = [] # use your RAGChunk type (or the upstream one)

_, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
client, query_request, configuration
)

rag_context = format_rag_context_for_injection(pre_rag_chunks)
if rag_context:
# safest: mutate a local copy so we don't surprise other logic
query_request = query_request.model_copy(deep=True) # pydantic v2
query_request.query = query_request.query + rag_context

# Prepare API request parameters
responses_params = await prepare_responses_params(
client,
@@ -168,6 +187,14 @@
# Retrieve response using Responses API
turn_summary = await retrieve_response(client, responses_params)

if pre_rag_chunks:
turn_summary.rag_chunks = pre_rag_chunks + (turn_summary.rag_chunks or [])

if doc_ids_from_chunks:
turn_summary.referenced_documents = parse_referenced_docs(
doc_ids_from_chunks + (turn_summary.referenced_documents or [])
)

# Get topic summary for new conversation
if not user_conversation and query_request.generate_topic_summary:
logger.debug("Generating topic summary for new conversation")
@@ -190,7 +217,9 @@
quota_limiters=configuration.quota_limiters, user_id=user_id
)

completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
completed_at = datetime.datetime.now(datetime.timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%SZ"
)
conversation_id = normalize_conversation_id(responses_params.conversation)

logger.info("Storing query results")
@@ -222,6 +251,19 @@
)


def parse_referenced_docs(
docs: list[ReferencedDocument],
) -> list[ReferencedDocument]:
seen: set[tuple[str | None, str | None]] = set()
out: list[ReferencedDocument] = []
for d in docs:
key = (d.doc_url, d.doc_title)
if key in seen:
continue
seen.add(key)
out.append(d)
return out

async def retrieve_response( # pylint: disable=too-many-locals
client: AsyncLlamaStackClient,
responses_params: ResponsesApiParams,
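The utils.vector_search helpers imported above are not included in this diff. From the call sites, format_rag_context_for_injection must return a text block to append to the user query, and an empty string when nothing was retrieved, which keeps the if rag_context: guard meaningful. One plausible shape, with an assumed chunk type, is:

from dataclasses import dataclass

@dataclass
class RAGChunk:
    """Assumed chunk shape; the real RAGChunk lives in utils.types (see streaming_query.py)."""
    content: str
    source: str | None = None

def format_rag_context_for_injection(chunks: list[RAGChunk]) -> str:
    """Render retrieved chunks as a context block; an empty string means inject nothing."""
    if not chunks:
        return ""
    parts = [f"[{i + 1}] {chunk.content}" for i, chunk in enumerate(chunks)]
    return "\n\nRelevant retrieved context:\n" + "\n\n".join(parts)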
61 changes: 45 additions & 16 deletions src/app/endpoints/streaming_query.py
@@ -34,7 +34,6 @@
LLM_TOOL_CALL_EVENT,
LLM_TOOL_RESULT_EVENT,
LLM_TURN_COMPLETE_EVENT,
MEDIA_TYPE_EVENT_STREAM,
MEDIA_TYPE_JSON,
MEDIA_TYPE_TEXT,
)
@@ -53,7 +52,7 @@
UnauthorizedResponse,
UnprocessableEntityResponse,
)
from utils.types import ReferencedDocument
from utils.types import RAGChunk, ReferencedDocument
from utils.endpoints import (
check_configuration_loaded,
validate_and_retrieve_conversation,
@@ -85,6 +84,7 @@
from utils.suid import normalize_conversation_id
from utils.token_counter import TokenCounter
from utils.types import ResponsesApiParams, TurnSummary
from utils.vector_search import format_rag_context_for_injection, perform_vector_search

logger = logging.getLogger(__name__)
router = APIRouter(tags=["streaming_query"])
@@ -100,7 +100,7 @@
404: NotFoundResponse.openapi_response(
examples=["conversation", "model", "provider"]
),
413: PromptTooLongResponse.openapi_response(),
# 413: PromptTooLongResponse.openapi_response(),
422: UnprocessableEntityResponse.openapi_response(),
429: QuotaExceededResponse.openapi_response(),
500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
@@ -172,6 +172,17 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals

client = AsyncLlamaStackClientHolder().get_client()

pre_rag_chunks: list[RAGChunk] = []
doc_ids_from_chunks: list[ReferencedDocument] = []

_, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
client, query_request, configuration
)
rag_context = format_rag_context_for_injection(pre_rag_chunks)
if rag_context:
query_request = query_request.model_copy(deep=True)
query_request.query = query_request.query + rag_context

# Prepare API request parameters
responses_params = await prepare_responses_params(
client=client,
Expand Down Expand Up @@ -212,6 +223,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
generator, turn_summary = await retrieve_response_generator(
responses_params=responses_params,
context=context,
doc_ids_from_chunks=doc_ids_from_chunks,
)

response_media_type = (
@@ -227,13 +239,14 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
responses_params=responses_params,
turn_summary=turn_summary,
),
media_type=response_media_type,
media_type=query_request.media_type or MEDIA_TYPE_TEXT,
)


async def retrieve_response_generator(
responses_params: ResponsesApiParams,
context: ResponseGeneratorContext,
doc_ids_from_chunks: list[ReferencedDocument],
) -> tuple[AsyncIterator[str], TurnSummary]:
"""
Retrieve the appropriate response generator.
Expand Down Expand Up @@ -273,6 +286,8 @@ async def retrieve_response_generator(
response = await context.client.responses.create(
**responses_params.model_dump()
)
# Store pre-RAG documents for later merging
turn_summary.pre_rag_documents = doc_ids_from_chunks
return response_generator(response, context, turn_summary), turn_summary

# Handle known LLS client errors only at stream creation time and shield execution
@@ -571,9 +586,21 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
turn_summary.token_usage = extract_token_usage(
latest_response_object, context.model_id
)
turn_summary.referenced_documents = parse_referenced_documents(
latest_response_object
)
tool_based_documents = parse_referenced_documents(latest_response_object)

# Merge pre-RAG documents with tool-based documents (similar to query.py)
if turn_summary.pre_rag_documents:
all_documents = turn_summary.pre_rag_documents + tool_based_documents
seen = set()
deduplicated_documents = []
for doc in all_documents:
key = (doc.doc_url, doc.doc_title)
if key not in seen:
seen.add(key)
deduplicated_documents.append(doc)
turn_summary.referenced_documents = deduplicated_documents
else:
turn_summary.referenced_documents = tool_based_documents


def stream_http_error_event(
Expand Down Expand Up @@ -608,7 +635,7 @@ def stream_http_error_event(

def format_stream_data(d: dict) -> str:
"""
Format a dictionary as a Server-Sent Events (SSE) data string.
Create a response generator function for Responses API streaming.

Parameters:
d (dict): The data to be formatted as an SSE event.
@@ -694,22 +721,24 @@ def stream_event(data: dict, event_type: str, media_type: str) -> str:
"""Build an item to yield based on media type.

Args:
data: The data to yield.
event_type: The type of event (e.g. token, tool request, tool execution).
media_type: Media type of the response (e.g. text or JSON).
data: Dictionary containing the event data
event_type: Type of event (token, tool call, etc.)
media_type: The media type for the response format

Returns:
str: The formatted string or JSON to yield.
SSE-formatted string representing the event
"""
if media_type == MEDIA_TYPE_TEXT:
if event_type == LLM_TOKEN_EVENT:
return data["token"]
return data.get("token", "")
if event_type == LLM_TOOL_CALL_EVENT:
return f"\nTool call: {json.dumps(data)}\n"
return f"[Tool Call: {data.get('function_name', 'unknown')}]\n"
if event_type == LLM_TOOL_RESULT_EVENT:
return f"\nTool result: {json.dumps(data)}\n"
logger.error("Unknown event type: %s", event_type)
return "[Tool Result]\n"
if event_type == LLM_TURN_COMPLETE_EVENT:
return ""
return ""

return format_stream_data(
{
"event": event_type,
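format_stream_data itself is not rewritten by this PR apart from the docstring edit above; its original docstring describes framing a dict as a Server-Sent Events data string. For reference, a sketch of that framing, assuming the standard "data: <json>" form with a blank-line terminator (the function body is not shown in this diff):

import json

def format_stream_data(d: dict) -> str:
    """Serialize an event dict as a single Server-Sent Events data frame."""
    return f"data: {json.dumps(d)}\n\n"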
5 changes: 5 additions & 0 deletions src/app/main.py
@@ -23,6 +23,11 @@
from utils.common import register_mcp_servers_async
from utils.llama_stack_version import check_llama_stack_version

import faulthandler
Contributor: Probably leftover too
import signal

faulthandler.register(signal.SIGUSR1)

logger = get_logger(__name__)

logger.info("Initializing app")
8 changes: 8 additions & 0 deletions src/configuration.py
@@ -23,6 +23,7 @@
DatabaseConfiguration,
ConversationHistoryConfiguration,
QuotaHandlersConfiguration,
SolrConfiguration,
SplunkConfiguration,
)

@@ -363,5 +364,12 @@ def deployment_environment(self) -> str:
raise LogicError("logic error: configuration is not loaded")
return self._configuration.deployment_environment

@property
def solr(self) -> Optional[SolrConfiguration]:
"""Return Solr configuration, or None if not provided."""
if self._configuration is None:
raise LogicError("logic error: configuration is not loaded")
return self._configuration.solr


configuration: AppConfig = AppConfig()
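The SolrConfiguration model referenced here is not shown in the diff. Judging from the solr: block added to lightspeed-stack.yaml, a minimal sketch would be the following; only the offline field is grounded in this PR, anything beyond it is a guess:

from pydantic import BaseModel

class SolrConfiguration(BaseModel):
    # `offline` mirrors lightspeed-stack.yaml's `solr: offline: True`;
    # the real model in models/config.py may carry additional fields.
    offline: bool = False

With the property above, callers can then guard on configuration.solr being None before touching Solr.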
10 changes: 9 additions & 1 deletion src/constants.py
@@ -127,7 +127,7 @@
MCP_AUTH_CLIENT = "client"

# default RAG tool value
DEFAULT_RAG_TOOL = "knowledge_search"
DEFAULT_RAG_TOOL = "file_search"
Contributor: This causes error in integration tests, please rename it also there.

Contributor Author: what do you mean rename?

Contributor: If you look at the integration tests job in the original branch you see something like:

=========================== short test summary info ============================
FAILED tests/integration/endpoints/test_query_v2_integration.py::test_query_v2_endpoint_with_tool_calls - AssertionError: assert 'file_search' == 'knowledge_search'

  - knowledge_search
  + file_search

That means you renamed the constant in code but not in integration tests.

Contributor Author: Ohh got it. Thank you!
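Concretely, the fix the reviewer is asking for is to update the hard-coded tool name in the failing integration test. A sketch; only the file path and the old and new values come from the failure output above, the assertion shape is hypothetical:

# tests/integration/endpoints/test_query_v2_integration.py (hypothetical excerpt)
from constants import DEFAULT_RAG_TOOL  # "file_search" after this PR

def assert_rag_tool(tool_name: str) -> None:
    # was: assert tool_name == "knowledge_search"
    assert tool_name == DEFAULT_RAG_TOOL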


# Media type constants for streaming responses
MEDIA_TYPE_JSON = "application/json"
@@ -168,3 +168,11 @@
# quota limiters constants
USER_QUOTA_LIMITER = "user_limiter"
CLUSTER_QUOTA_LIMITER = "cluster_limiter"

# Vector search constants
VECTOR_SEARCH_DEFAULT_K = 5
VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD = 0.0
VECTOR_SEARCH_DEFAULT_MODE = "hybrid"

# SOLR OKP RAG
MIMIR_DOC_URL = "https://mimir.corp.redhat.com"