diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml index ba29f85f..356394c7 100644 --- a/lightspeed-stack.yaml +++ b/lightspeed-stack.yaml @@ -10,12 +10,12 @@ service: llama_stack: # Uses a remote llama-stack service # The instance would have already been started with a llama-stack-run.yaml file - use_as_library_client: false + # use_as_library_client: false # Alternative for "as library use" - # use_as_library_client: true - # library_client_config_path: - url: http://llama-stack:8321 - api_key: xyzzy + use_as_library_client: true + library_client_config_path: run.yaml + # url: http://llama-stack:8321 + # api_key: xyzzy user_data_collection: feedback_enabled: true feedback_storage: "/tmp/data/feedback" diff --git a/run.yaml b/run.yaml index 3680f2b3..5ed401dc 100644 --- a/run.yaml +++ b/run.yaml @@ -15,124 +15,159 @@ apis: benchmarks: [] datasets: [] image_name: starter -# external_providers_dir: /opt/app-root/src/.llama/providers.d +external_providers_dir: ${env.EXTERNAL_PROVIDERS_DIR} -providers: - inference: - - provider_id: openai # This ID is a reference to 'providers.inference' - provider_type: remote::openai - config: - api_key: ${env.OPENAI_API_KEY} - allowed_models: ["${env.E2E_OPENAI_MODEL:=gpt-4o-mini}"] - - config: {} - provider_id: sentence-transformers - provider_type: inline::sentence-transformers - files: - - config: - metadata_store: - table_name: files_metadata - backend: sql_default - storage_dir: ~/.llama/storage/files - provider_id: meta-reference-files - provider_type: inline::localfs - safety: - - config: - excluded_categories: [] - provider_id: llama-guard - provider_type: inline::llama-guard - scoring: - - provider_id: basic - provider_type: inline::basic - config: {} - - provider_id: llm-as-judge - provider_type: inline::llm-as-judge - config: {} - - provider_id: braintrust - provider_type: inline::braintrust - config: - openai_api_key: '********' - tool_runtime: - - config: {} # Enable the RAG tool - provider_id: rag-runtime - provider_type: inline::rag-runtime - vector_io: - - config: # Define the storage backend for RAG - persistence: - namespace: vector_io::faiss - backend: kv_default - provider_id: faiss - provider_type: inline::faiss - agents: - - config: - persistence: - agent_state: - namespace: agents_state - backend: kv_default - responses: - table_name: agents_responses - backend: sql_default - provider_id: meta-reference - provider_type: inline::meta-reference - batches: - - config: - kvstore: - namespace: batches_store - backend: kv_default - provider_id: reference - provider_type: inline::reference - datasetio: - - config: - kvstore: - namespace: huggingface_datasetio - backend: kv_default - provider_id: huggingface - provider_type: remote::huggingface - - config: - kvstore: - namespace: localfs_datasetio - backend: kv_default - provider_id: localfs - provider_type: inline::localfs - eval: - - config: - kvstore: - namespace: eval_store - backend: kv_default - provider_id: meta-reference - provider_type: inline::meta-reference -scoring_fns: [] -server: - port: 8321 storage: backends: - kv_default: # Define the storage backend type for RAG, in this case registry and RAG are unified i.e. information on registered resources (e.g. 
models, vector_stores) are saved together with the RAG chunks + kv_default: type: kv_sqlite db_path: ${env.KV_STORE_PATH:=~/.llama/storage/rag/kv_store.db} sql_default: type: sql_sqlite db_path: ${env.SQL_STORE_PATH:=~/.llama/storage/sql_store.db} + stores: metadata: namespace: registry backend: kv_default + inference: table_name: inference_store backend: sql_default max_write_queue_size: 10000 num_writers: 4 + conversations: table_name: openai_conversations backend: sql_default + prompts: namespace: prompts backend: kv_default + +metadata_store: + type: sqlite + db_path: ~/.llama/storage/registry.db + +inference_store: + type: sqlite + db_path: ~/.llama/storage/inference-store.db + +conversations_store: + type: sqlite + db_path: ~/.llama/storage/conversations.db + +providers: + inference: + - provider_id: openai + provider_type: remote::openai + config: + api_key: ${env.OPENAI_API_KEY} + allowed_models: + - gpt-4o-mini + + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: + allowed_models: + - ${env.EMBEDDING_MODEL_DIR} + + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ~/.llama/storage/files + metadata_store: + table_name: files_metadata + backend: sql_default + + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + + tool_runtime: + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + + vector_io: + - provider_id: solr-vector + provider_type: remote::solr_vector_io + config: + solr_url: http://localhost:8983/solr + collection_name: portal-rag + vector_field: chunk_vector + content_field: chunk + embedding_dimension: 384 + embedding_model: ${env.EMBEDDING_MODEL_DIR} + persistence: + namespace: portal-rag + backend: kv_default + + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence: + agent_state: + namespace: agents_state + backend: kv_default + responses: + table_name: agents_responses + backend: sql_default + + batches: + - provider_id: reference + provider_type: inline::reference + config: + kvstore: + namespace: batches_store + backend: kv_default + + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + namespace: huggingface_datasetio + backend: kv_default + + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + namespace: localfs_datasetio + backend: kv_default + registered_resources: - models: [] + models: + - model_id: granite-embedding-30m + model_type: embedding + provider_id: sentence-transformers + provider_model_id: ${env.EMBEDDING_MODEL_DIR} + metadata: + embedding_dimension: 384 + shields: - shield_id: llama-guard provider_id: llama-guard provider_shield_id: openai/gpt-4o-mini - vector_stores: [] + vector_stores: + - vector_store_id: portal-rag + provider_id: solr-vector + embedding_model: sentence-transformers/${env.EMBEDDING_MODEL_DIR} + embedding_dimension: 384 datasets: [] scoring_fns: [] benchmarks: [] @@ -140,9 +175,9 @@ registered_resources: - toolgroup_id: builtin::rag # Register the RAG tool provider_id: rag-runtime vector_stores: - default_provider_id: faiss - default_embedding_model: # Define the default embedding model for RAG - provider_id: sentence-transformers - model_id: 
nomic-ai/nomic-embed-text-v1.5 + vector_store_id: portal-rag + provider_id: solr-vector + embedding_model: sentence-transformers/${env.EMBEDDING_MODEL_DIR} + embedding_dimension: 384 safety: default_shield_id: llama-guard diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index eddf30f8..eb28a001 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -55,6 +55,11 @@ router = APIRouter(tags=["query"]) +# When OFFLINE is False, use reference_url for chunk source +# When OFFLINE is True, use parent_id for chunk source +# TODO: move this setting to a higher level configuration +OFFLINE = True + query_response: dict[int | str, dict[str, Any]] = { 200: QueryResponse.openapi_response(), 401: UnauthorizedResponse.openapi_response( @@ -386,9 +391,9 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 response = QueryResponse( conversation_id=conversation_id, response=summary.llm_response, - tool_calls=summary.tool_calls, - tool_results=summary.tool_results, - rag_chunks=summary.rag_chunks, + rag_chunks=rag_chunks_dict, + tool_calls=summary.tool_calls if summary.tool_calls else [], + tool_results=summary.tool_results if summary.tool_results else [], referenced_documents=referenced_documents, truncated=False, # TODO: implement truncation detection input_tokens=token_usage.input_tokens, diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index ecc39b07..3e34b07d 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -4,9 +4,12 @@ import json import logging +import traceback from typing import Annotated, Any, Optional, cast +from urllib.parse import urljoin from fastapi import APIRouter, Depends, Request +from llama_stack_client import APIConnectionError, APIStatusError from llama_stack_api.openai_responses import ( OpenAIResponseMCPApprovalRequest, OpenAIResponseMCPApprovalResponse, @@ -30,7 +33,7 @@ from authentication.interface import AuthTuple from authorization.middleware import authorize from configuration import AppConfig, configuration -from constants import DEFAULT_RAG_TOOL +from constants import DEFAULT_RAG_TOOL, MIMIR_DOC_URL from models.config import Action, ModelContextProtocolServer from models.requests import QueryRequest from models.responses import ( @@ -364,9 +367,14 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche if query_request.attachments: validate_attachments_metadata(query_request.attachments) - # Prepare tools for responses API + # Prepare tools for responses API - skip RAG tools since we're doing direct vector query toolgroups = await prepare_tools_for_responses_api( - client, query_request, token, configuration, mcp_headers + client, + query_request, + token, + configuration, + mcp_headers=mcp_headers, + skip_rag_tools=True, ) # Prepare input for Responses API @@ -420,6 +428,174 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche TokenCounter(), ) + # Extract RAG chunks from vector DB query response BEFORE calling responses API + rag_chunks = [] + doc_ids_from_chunks = [] + retrieved_chunks = [] + retrieved_scores = [] + + # When offline is False, use reference_url for chunk source + # When offline is True, use parent_id for chunk source + # TODO: move this setting to a higher level configuration + offline = True + + try: + # Get vector stores for direct querying + if query_request.vector_store_ids: + vector_store_ids = query_request.vector_store_ids + logger.info( + "Using specified vector_store_ids for direct 
query: %s", + vector_store_ids, + ) + else: + vector_store_ids = [ + vector_store.id + for vector_store in (await client.vector_stores.list()).data + ] + logger.info( + "Using all available vector_store_ids for direct query: %s", + vector_store_ids, + ) + + if vector_store_ids: + vector_store_id = vector_store_ids[0] # Use first available vector store + + params = {"k": 5, "score_threshold": 0.0, "mode": "hybrid"} + logger.info("Initial params: %s", params) + logger.info("query_request.solr: %s", query_request.solr) + if query_request.solr: + # Pass the entire solr dict under the 'solr' key + params["solr"] = query_request.solr + logger.info("Final params with solr filters: %s", params) + else: + logger.info("No solr filters provided") + logger.info("Final params being sent to vector_io.query: %s", params) + + query_response = await client.vector_io.query( + vector_store_id=vector_store_id, + query=query_request.query, + params=params, + ) + + logger.info("The query response total payload: %s", query_response) + + if query_response.chunks: + retrieved_chunks = query_response.chunks + retrieved_scores = ( + query_response.scores if hasattr(query_response, "scores") else [] + ) + + # Extract doc_ids from chunks for referenced_documents + metadata_doc_ids = set() + + for chunk in query_response.chunks: + logger.info("Extract doc ids from chunk: %s", chunk) + + # 1) dict metadata (what your code expects today) + md = getattr(chunk, "metadata", None) or {} + doc_id = md.get("doc_id") or md.get("document_id") + title = md.get("title") + + # 2) typed chunk_metadata (what your provider/logs are actually populating) + if not doc_id: + cm = getattr(chunk, "chunk_metadata", None) + if cm is not None: + # cm might be a pydantic model or a dict depending on caller + if isinstance(cm, dict): + doc_id = cm.get("doc_id") or cm.get("document_id") + title = title or cm.get("title") + reference_url = cm.get("reference_url") + else: + doc_id = getattr(cm, "doc_id", None) or getattr( + cm, "document_id", None + ) + title = title or getattr(cm, "title", None) + reference_url = getattr(cm, "reference_url", None) + else: + reference_url = None + else: + reference_url = md.get("reference_url") + + if not doc_id and not reference_url: + continue + + # Build URL based on offline flag + if offline: + # Use parent/doc path + reference_doc = doc_id + doc_url = MIMIR_DOC_URL + reference_doc + else: + # Use reference_url if online + reference_doc = reference_url or doc_id + doc_url = ( + reference_doc + if reference_doc.startswith("http") + else (MIMIR_DOC_URL + reference_doc) + ) + + if reference_doc and reference_doc not in metadata_doc_ids: + metadata_doc_ids.add(reference_doc) + doc_ids_from_chunks.append( + ReferencedDocument( + doc_title=title, + doc_url=doc_url, + ) + ) + + logger.info( + "Extracted %d unique document IDs from chunks", len(doc_ids_from_chunks) + ) + + except ( + APIConnectionError, + APIStatusError, + AttributeError, + KeyError, + ValueError, + ) as e: + logger.warning("Failed to query vector database for chunks: %s", e) + logger.debug("Vector DB query error details: %s", traceback.format_exc()) + # Continue without RAG chunks + + # Convert retrieved chunks to RAGChunk format + for i, chunk in enumerate(retrieved_chunks): + # Extract source from chunk metadata based on offline flag + source = None + if chunk.metadata: + if offline: + parent_id = chunk.metadata.get("parent_id") + if parent_id: + source = urljoin(MIMIR_DOC_URL, parent_id) + else: + source = chunk.metadata.get("reference_url") + + # 
Get score from retrieved_scores list if available + score = retrieved_scores[i] if i < len(retrieved_scores) else None + + rag_chunks.append( + RAGChunk( + content=chunk.content, + source=source, + score=score, + ) + ) + + logger.info("Retrieved %d chunks from vector DB", len(rag_chunks)) + + # Format RAG context for injection into user message + rag_context = "" + if rag_chunks: + context_chunks = [] + for chunk in rag_chunks[:5]: # Limit to top 5 chunks + chunk_text = f"Source: {chunk.source or 'Unknown'}\n{chunk.content}" + context_chunks.append(chunk_text) + rag_context = "\n\nRelevant documentation:\n" + "\n\n".join(context_chunks) + logger.info("Injecting %d RAG chunks into user message", len(context_chunks)) + + # Inject RAG context into input text + if rag_context: + input_text = input_text + rag_context + # Create OpenAI response using responses API create_kwargs: dict[str, Any] = { "input": input_text, @@ -444,18 +620,29 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche llm_response = "" tool_calls: list[ToolCallSummary] = [] tool_results: list[ToolResultSummary] = [] - rag_chunks: list[RAGChunk] = [] + response_api_rag_chunks: list[RAGChunk] = [] for output_item in response.output: message_text = extract_text_from_response_output_item(output_item) if message_text: llm_response += message_text - tool_call, tool_result = _build_tool_call_summary(output_item, rag_chunks) + tool_call, tool_result = _build_tool_call_summary( + output_item, response_api_rag_chunks + ) if tool_call: tool_calls.append(tool_call) if tool_result: tool_results.append(tool_result) + # Merge RAG chunks from direct vector query with those from responses API + all_rag_chunks = rag_chunks + response_api_rag_chunks + logger.info( + "Combined RAG chunks: %d from direct query + %d from responses API = %d total", + len(rag_chunks), + len(response_api_rag_chunks), + len(all_rag_chunks), + ) + logger.info( "Response processing complete - Tool calls: %d, Response length: %d chars", len(tool_calls), @@ -466,11 +653,21 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche llm_response=llm_response, tool_calls=tool_calls, tool_results=tool_results, - rag_chunks=rag_chunks, + rag_chunks=all_rag_chunks, ) # Extract referenced documents and token usage from Responses API response - referenced_documents = parse_referenced_documents_from_responses_api(response) + # Merge with documents from direct vector query + response_referenced_documents = parse_referenced_documents_from_responses_api( + response + ) + all_referenced_documents = doc_ids_from_chunks + response_referenced_documents + logger.info( + "Combined referenced documents: %d from direct query + %d from responses API = %d total", + len(doc_ids_from_chunks), + len(response_referenced_documents), + len(all_referenced_documents), + ) model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id token_usage = extract_token_usage_from_responses_api( response, model_label, provider_id, system_prompt @@ -485,7 +682,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche return ( summary, normalize_conversation_id(conversation_id), - referenced_documents, + all_referenced_documents, token_usage, ) @@ -687,12 +884,15 @@ def _increment_llm_call_metric(provider: str, model: str) -> None: logger.warning("Failed to update LLM call metric: %s", e) -def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]: +def get_rag_tools( + vector_store_ids: list[str], 
solr_params: Optional[dict[str, Any]] = None +) -> Optional[list[dict[str, Any]]]: """ Convert vector store IDs to tools format for Responses API. Args: vector_store_ids: List of vector store identifiers + solr_params: Optional Solr filtering parameters Returns: Optional[list[dict[str, Any]]]: List containing file_search tool configuration, @@ -701,13 +901,16 @@ def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]] if not vector_store_ids: return None - return [ - { - "type": "file_search", - "vector_store_ids": vector_store_ids, - "max_num_results": 10, - } - ] + tool_config = { + "type": "file_search", + "vector_store_ids": vector_store_ids, + "max_num_results": 10, + } + + if solr_params: + tool_config["solr"] = solr_params + + return [tool_config] def get_mcp_tools( @@ -808,7 +1011,9 @@ async def prepare_tools_for_responses_api( query_request: QueryRequest, token: str, config: AppConfig, + *, mcp_headers: Optional[dict[str, dict[str, str]]] = None, + skip_rag_tools: bool = False, ) -> Optional[list[dict[str, Any]]]: """ Prepare tools for Responses API including RAG and MCP tools. @@ -822,6 +1027,7 @@ async def prepare_tools_for_responses_api( token: Authentication token for MCP tools config: Configuration object containing MCP server settings mcp_headers: Per-request headers for MCP servers + skip_rag_tools: If True, skip adding RAG tools (used when doing direct vector querying) Returns: Optional[list[dict[str, Any]]]: List of tool configurations for the @@ -831,18 +1037,32 @@ async def prepare_tools_for_responses_api( return None toolgroups = [] - # Get vector stores for RAG tools - use specified ones or fetch all - if query_request.vector_store_ids: - vector_store_ids = query_request.vector_store_ids - else: - vector_store_ids = [ - vector_store.id for vector_store in (await client.vector_stores.list()).data - ] - # Add RAG tools if vector stores are available - rag_tools = get_rag_tools(vector_store_ids) - if rag_tools: - toolgroups.extend(rag_tools) + # Add RAG tools if not skipped + if not skip_rag_tools: + # Get vector stores for RAG tools - use specified ones or fetch all + if query_request.vector_store_ids: + vector_store_ids = query_request.vector_store_ids + logger.info("Using specified vector_store_ids: %s", vector_store_ids) + else: + vector_store_ids = [ + vector_store.id + for vector_store in (await client.vector_stores.list()).data + ] + logger.info("Using all available vector_store_ids: %s", vector_store_ids) + + # Add RAG tools if vector stores are available + if vector_store_ids: + rag_tools = get_rag_tools(vector_store_ids) + if rag_tools: + logger.info("rag_tool are: %s", rag_tools) + toolgroups.extend(rag_tools) + else: + logger.info("No RAG tools configured") + else: + logger.info("No vector stores available for RAG tools") + else: + logger.info("Skipping RAG tools - using direct vector querying instead") # Add MCP server tools mcp_tools = get_mcp_tools(config.mcp_servers, token, mcp_headers) diff --git a/src/app/endpoints/shields.py b/src/app/endpoints/shields.py index 5dd8b8b6..790c2d0b 100644 --- a/src/app/endpoints/shields.py +++ b/src/app/endpoints/shields.py @@ -70,6 +70,8 @@ async def shields_endpoint_handler( try: # try to get Llama Stack client client = AsyncLlamaStackClientHolder().get_client() + # await client.shields.delete(identifier="llama-guard-shielf") + # exit(1) # retrieve shields shields = await client.shields.list() s = [dict(s) for s in shields] diff --git a/src/app/endpoints/streaming_query.py 
b/src/app/endpoints/streaming_query.py index afd7293a..2b12f14c 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -59,7 +59,6 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["streaming_query"]) - streaming_query_responses: dict[int | str, dict[str, Any]] = { 200: StreamingQueryResponse.openapi_response(), 401: UnauthorizedResponse.openapi_response( @@ -662,6 +661,7 @@ async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-loc token, mcp_headers=mcp_headers, ) + metadata_map: dict[str, dict[str, Any]] = {} # Create context object for response generator diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py index e1c02ca4..1655d954 100644 --- a/src/app/endpoints/streaming_query_v2.py +++ b/src/app/endpoints/streaming_query_v2.py @@ -1,7 +1,9 @@ """Streaming query handler using Responses API (v2).""" import logging +import traceback from typing import Annotated, Any, AsyncIterator, Optional, cast +from urllib.parse import urljoin from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse @@ -14,7 +16,7 @@ OpenAIResponseObjectStreamResponseOutputTextDelta, OpenAIResponseObjectStreamResponseOutputTextDone, ) -from llama_stack_client import AsyncLlamaStackClient +from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient from app.endpoints.query import ( is_transcripts_enabled, @@ -42,7 +44,7 @@ from authentication.interface import AuthTuple from authorization.middleware import authorize from configuration import configuration -from constants import MEDIA_TYPE_JSON +from constants import MEDIA_TYPE_JSON, MIMIR_DOC_URL from models.config import Action from models.context import ResponseGeneratorContext from models.requests import QueryRequest @@ -51,6 +53,7 @@ InternalServerErrorResponse, NotFoundResponse, QuotaExceededResponse, + ReferencedDocument, ServiceUnavailableResponse, StreamingQueryResponse, UnauthorizedResponse, @@ -97,6 +100,7 @@ def create_responses_response_generator( # pylint: disable=too-many-locals,too-many-statements context: ResponseGeneratorContext, + doc_ids_from_chunks: Optional[list[ReferencedDocument]] = None, ) -> Any: """ Create a response generator function for Responses API streaming. 
@@ -106,6 +110,7 @@ def create_responses_response_generator( # pylint: disable=too-many-locals,too- Args: context: Context object containing all necessary parameters for response generation + doc_ids_from_chunks: Referenced documents extracted from vector DB chunks Returns: An async generator function that yields SSE-formatted strings @@ -294,9 +299,13 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat model_id=context.model_id, provider_id=context.provider_id, ) - referenced_documents = parse_referenced_documents_from_responses_api( + response_referenced_documents = parse_referenced_documents_from_responses_api( cast(OpenAIResponseObject, latest_response_object) ) + # Combine doc_ids_from_chunks with response_referenced_documents + all_referenced_documents = ( + doc_ids_from_chunks or [] + ) + response_referenced_documents available_quotas = get_available_quotas( configuration.quota_limiters, context.user_id ) @@ -304,7 +313,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat context.metadata_map, token_usage, available_quotas, - referenced_documents, + all_referenced_documents, media_type, ) @@ -382,7 +391,7 @@ async def retrieve_response( # pylint: disable=too-many-locals query_request: QueryRequest, token: str, mcp_headers: Optional[dict[str, dict[str, str]]] = None, -) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]: +) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str, list[ReferencedDocument]]: """ Retrieve response from LLMs and agents. @@ -403,8 +412,8 @@ async def retrieve_response( # pylint: disable=too-many-locals Multi-cluster proxy headers for tool integrations. Returns: - tuple: A tuple containing the streaming response object - and the conversation ID. + tuple: A tuple containing the streaming response object, + the conversation ID, and the list of referenced documents from vector DB chunks. 
""" # use system prompt from request or default one system_prompt = get_system_prompt(query_request, configuration) @@ -415,11 +424,180 @@ async def retrieve_response( # pylint: disable=too-many-locals if query_request.attachments: validate_attachments_metadata(query_request.attachments) - # Prepare tools for responses API + # Prepare tools for responses API - skip RAG tools since we're doing direct vector query toolgroups = await prepare_tools_for_responses_api( - client, query_request, token, configuration, mcp_headers + client, + query_request, + token, + configuration, + mcp_headers=mcp_headers, + skip_rag_tools=True, ) + # Extract RAG chunks from vector DB query response BEFORE calling responses API + rag_chunks = [] + doc_ids_from_chunks = [] + retrieved_chunks = [] + retrieved_scores = [] + + # When offline is False, use reference_url for chunk source + # When offline is True, use parent_id for chunk source + # TODO: move this setting to a higher level configuration + offline = True + + try: + # Get vector stores for direct querying + if query_request.vector_store_ids: + vector_store_ids = query_request.vector_store_ids + logger.info( + "Using specified vector_store_ids for direct query: %s", + vector_store_ids, + ) + else: + vector_store_ids = [ + vector_store.id + for vector_store in (await client.vector_stores.list()).data + ] + logger.info( + "Using all available vector_store_ids for direct query: %s", + vector_store_ids, + ) + + if vector_store_ids: + vector_store_id = vector_store_ids[0] # Use first available vector store + + params = {"k": 5, "score_threshold": 0.0, "mode": "hybrid"} + logger.info("Initial params: %s", params) + logger.info("query_request.solr: %s", query_request.solr) + if query_request.solr: + # Pass the entire solr dict under the 'solr' key + params["solr"] = query_request.solr + logger.info("Final params with solr filters: %s", params) + else: + logger.info("No solr filters provided") + logger.info("Final params being sent to vector_io.query: %s", params) + + query_response = await client.vector_io.query( + vector_store_id=vector_store_id, + query=query_request.query, + params=params, + ) + + logger.info("The query response total payload: %s", query_response) + + if query_response.chunks: + retrieved_chunks = query_response.chunks + retrieved_scores = ( + query_response.scores if hasattr(query_response, "scores") else [] + ) + + # Extract doc_ids from chunks for referenced_documents + metadata_doc_ids = set() + + for chunk in query_response.chunks: + logger.info("Extract doc ids from chunk: %s", chunk) + + # 1) dict metadata + md = getattr(chunk, "metadata", None) or {} + doc_id = md.get("doc_id") or md.get("document_id") + title = md.get("title") + + # 2) typed chunk_metadata + if not doc_id: + cm = getattr(chunk, "chunk_metadata", None) + if cm is not None: + # cm might be a pydantic model or a dict depending on caller + if isinstance(cm, dict): + doc_id = cm.get("doc_id") or cm.get("document_id") + title = title or cm.get("title") + reference_url = cm.get("reference_url") + else: + doc_id = getattr(cm, "doc_id", None) or getattr( + cm, "document_id", None + ) + title = title or getattr(cm, "title", None) + reference_url = getattr(cm, "reference_url", None) + else: + reference_url = None + else: + reference_url = md.get("reference_url") + + if not doc_id and not reference_url: + continue + + # Build URL based on offline flag + if offline: + # Use parent/doc path + reference_doc = doc_id + doc_url = MIMIR_DOC_URL + reference_doc + else: + # Use 
reference_url if online + reference_doc = reference_url or doc_id + doc_url = ( + reference_doc + if reference_doc.startswith("http") + else (MIMIR_DOC_URL + reference_doc) + ) + + if reference_doc and reference_doc not in metadata_doc_ids: + metadata_doc_ids.add(reference_doc) + doc_ids_from_chunks.append( + ReferencedDocument( + doc_title=title, + doc_url=doc_url, + ) + ) + + logger.info( + "Extracted %d unique document IDs from chunks", len(doc_ids_from_chunks) + ) + + except ( + APIConnectionError, + APIStatusError, + AttributeError, + KeyError, + ValueError, + ) as e: + logger.warning("Failed to query vector database for chunks: %s", e) + logger.debug("Vector DB query error details: %s", traceback.format_exc()) + # Continue without RAG chunks + + # Convert retrieved chunks to RAGChunk format + for i, chunk in enumerate(retrieved_chunks): + # Extract source from chunk metadata based on offline flag + source = None + if chunk.metadata: + if offline: + parent_id = chunk.metadata.get("parent_id") + if parent_id: + source = urljoin(MIMIR_DOC_URL, parent_id) + else: + source = chunk.metadata.get("reference_url") + + # Get score from retrieved_scores list if available + score = retrieved_scores[i] if i < len(retrieved_scores) else None + + rag_chunks.append( + RAGChunk( + content=chunk.content, + source=source, + score=score, + ) + ) + + logger.info("Retrieved %d chunks from vector DB", len(rag_chunks)) + + # Format RAG context for injection into user message + rag_context = "" + if rag_chunks: + context_chunks = [] + for chunk in rag_chunks[:5]: # Limit to top 5 chunks + chunk_text = f"Source: {chunk.source or 'Unknown'}\n{chunk.content}" + context_chunks.append(chunk_text) + rag_context = "\n\nRelevant documentation:\n" + "\n\n".join(context_chunks) + logger.info("Injecting %d RAG chunks into user message", len(context_chunks)) + # Prepare input for Responses API # Convert attachments to text and concatenate with query input_text = query_request.query @@ -430,6 +608,9 @@ async def retrieve_response( # pylint: disable=too-many-locals f"{attachment.content}" ) + # Add RAG context to input text + input_text += rag_context + # Handle conversation ID for Responses API # Create conversation upfront if not provided conversation_id = query_request.conversation_id @@ -475,4 +656,8 @@ async def retrieve_response( # pylint: disable=too-many-locals response = await client.responses.create(**create_params) response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response) - return response_stream, normalize_conversation_id(conversation_id) + return ( + response_stream, + normalize_conversation_id(conversation_id), + doc_ids_from_chunks, + ) diff --git a/src/app/main.py b/src/app/main.py index 74a6b86a..f011ee22 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -22,6 +22,11 @@ from utils.common import register_mcp_servers_async from utils.llama_stack_version import check_llama_stack_version +import faulthandler +import signal + +faulthandler.register(signal.SIGUSR1) + logger = get_logger(__name__) logger.info("Initializing app") @@ -55,6 +60,12 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: # check if the Llama Stack version is supported by the service await check_llama_stack_version(client) + # try: + # await client.vector_stores.delete(vector_store_id="portal-rag") + # logger.info("Successfully deregistered vector store: portal-rag") + # except Exception as e: + # logger.warning("Failed to deregister vector store 'portal-rag': %s", e) + logger.info("Registering MCP servers") 
await register_mcp_servers_async(logger, configuration.configuration) get_logger("app.endpoints.handlers") diff --git a/src/constants.py b/src/constants.py index 681759cd..e7ce5d27 100644 --- a/src/constants.py +++ b/src/constants.py @@ -127,7 +127,7 @@ MCP_AUTH_CLIENT = "client" # default RAG tool value -DEFAULT_RAG_TOOL = "knowledge_search" +DEFAULT_RAG_TOOL = "file_search" # Media type constants for streaming responses MEDIA_TYPE_JSON = "application/json" @@ -161,3 +161,6 @@ # quota limiters constants USER_QUOTA_LIMITER = "user_limiter" CLUSTER_QUOTA_LIMITER = "cluster_limiter" + +# SOLR OKP RAG +MIMIR_DOC_URL = "https://mimir.corp.redhat.com" diff --git a/src/models/requests.py b/src/models/requests.py index 18e5b4b6..3ac4ede6 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -1,7 +1,7 @@ """Models for REST API requests.""" +from typing import Optional, Self, Any from enum import Enum -from typing import Optional, Self from pydantic import BaseModel, Field, field_validator, model_validator @@ -166,6 +166,13 @@ class QueryRequest(BaseModel): examples=["ocp_docs", "knowledge_base", "vector_db_1"], ) + solr: Optional[dict[str, Any]] = Field( + None, + description="Solr-specific query parameters including filter queries", + examples=[ + {"fq": ["product:*openshift*", "product_version:*4.16*"]}, + ], + ) # provides examples for /docs endpoint model_config = { "extra": "forbid", diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index b0b49917..016cf95f 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -592,7 +592,7 @@ def _process_rag_chunks_for_documents( for chunk in rag_chunks: src = chunk.source - if not src or src == constants.DEFAULT_RAG_TOOL: + if not src or src == constants.DEFAULT_RAG_TOOL or src.endswith("_search"): continue if src.startswith("http"): diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py index 37468ad9..6fd6f372 100644 --- a/tests/unit/app/endpoints/test_query_v2.py +++ b/tests/unit/app/endpoints/test_query_v2.py @@ -53,6 +53,16 @@ def test_get_rag_tools() -> None: assert tools[0]["type"] == "file_search" assert tools[0]["vector_store_ids"] == ["db1", "db2"] assert tools[0]["max_num_results"] == 10 + assert "solr" not in tools[0] + + # Test with Solr parameters + solr_params = {"fq": ["product:*openshift*", "product_version:*4.16*"]} + tools_with_solr = get_rag_tools(["db1", "db2"], solr_params) + assert isinstance(tools_with_solr, list) + assert tools_with_solr[0]["type"] == "file_search" + assert tools_with_solr[0]["vector_store_ids"] == ["db1", "db2"] + assert tools_with_solr[0]["max_num_results"] == 10 + assert tools_with_solr[0]["solr"] == solr_params def test_get_mcp_tools_with_and_without_token() -> None: @@ -280,7 +290,22 @@ async def test_retrieve_response_builds_rag_and_mcp_tools( # pylint: disable=to mock_client.shields.list = mocker.AsyncMock(return_value=[]) mock_client.models.list = mocker.AsyncMock(return_value=[]) + # Mock vector_io.query for direct vector querying + mock_query_response = mocker.Mock() + mock_query_response.chunks = [] + mock_query_response.scores = [] + mock_client.vector_io.query = mocker.AsyncMock(return_value=mock_query_response) + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + + # Mock shield moderation + mock_moderation_result = mocker.Mock() + mock_moderation_result.blocked = False + mocker.patch( + "app.endpoints.query_v2.run_shield_moderation", + return_value=mock_moderation_result, + ) + mock_cfg = 
mocker.Mock() mock_cfg.mcp_servers = [ ModelContextProtocolServer( @@ -304,11 +329,9 @@ async def test_retrieve_response_builds_rag_and_mcp_tools( # pylint: disable=to kwargs = mock_client.responses.create.call_args.kwargs tools = kwargs["tools"] assert isinstance(tools, list) - # Expect one file_search and one mcp tool + # Expect only MCP tools since RAG tools are skipped when doing direct vector querying tool_types = {t.get("type") for t in tools} - assert tool_types == {"file_search", "mcp"} - file_search = next(t for t in tools if t["type"] == "file_search") - assert file_search["vector_store_ids"] == ["dbA"] + assert tool_types == {"mcp"} mcp_tool = next(t for t in tools if t["type"] == "mcp") assert mcp_tool["server_label"] == "fs" assert mcp_tool["headers"] == {"Authorization": "Bearer mytoken"} diff --git a/tests/unit/app/endpoints/test_streaming_query_v2.py b/tests/unit/app/endpoints/test_streaming_query_v2.py index d4740786..69cde6e9 100644 --- a/tests/unit/app/endpoints/test_streaming_query_v2.py +++ b/tests/unit/app/endpoints/test_streaming_query_v2.py @@ -53,6 +53,10 @@ async def test_retrieve_response_builds_rag_and_mcp_tools( # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) mock_client.models.list = mocker.AsyncMock(return_value=[]) + # Mock vector_io.query for direct vector querying + mock_query_response = mocker.Mock() + mock_query_response.chunks = [] + mock_client.vector_io.query = mocker.AsyncMock(return_value=mock_query_response) mocker.patch( "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" @@ -77,7 +81,9 @@ async def test_retrieve_response_builds_rag_and_mcp_tools( tools = kwargs["tools"] assert isinstance(tools, list) types = {t.get("type") for t in tools} - assert types == {"file_search", "mcp"} + # Since we're now skipping RAG tools and doing direct vector querying, + # we should only see MCP tools, not file_search tools + assert types == {"mcp"} @pytest.mark.asyncio @@ -95,6 +101,10 @@ async def test_retrieve_response_no_tools_passes_none(mocker: MockerFixture) -> # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) mock_client.models.list = mocker.AsyncMock(return_value=[]) + # Mock vector_io.query for direct vector querying + mock_query_response = mocker.Mock() + mock_query_response.chunks = [] + mock_client.vector_io.query = mocker.AsyncMock(return_value=mock_query_response) mocker.patch( "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT"
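
The updated get_rag_tools() in src/app/endpoints/query_v2.py now threads optional Solr filter queries into the file_search tool configuration. Below is a standalone sketch of that tool-building logic; build_file_search_tool is an illustrative name, and the "solr" key is assumed to be meaningful only to the remote::solr_vector_io provider configured in run.yaml, not to the generic Responses API tool schema.

# Sketch of the file_search tool construction added to get_rag_tools().
# The "solr" key is passed through verbatim to the Solr-backed vector_io
# provider; it is not one of the standard file_search tool fields.
from typing import Any, Optional


def build_file_search_tool(
    vector_store_ids: list[str],
    solr_params: Optional[dict[str, Any]] = None,
) -> Optional[list[dict[str, Any]]]:
    """Return a file_search tool config, optionally carrying Solr filter queries."""
    if not vector_store_ids:
        return None

    tool_config: dict[str, Any] = {
        "type": "file_search",
        "vector_store_ids": vector_store_ids,
        "max_num_results": 10,
    }
    if solr_params:
        # Mirrors QueryRequest.solr, e.g. {"fq": ["product:*openshift*"]}
        tool_config["solr"] = solr_params
    return [tool_config]


if __name__ == "__main__":
    print(
        build_file_search_tool(
            ["portal-rag"],
            {"fq": ["product:*openshift*", "product_version:*4.16*"]},
        )
    )

With skip_rag_tools=True now passed from both v2 handlers, this tool is only emitted when a caller opts back into tool-based retrieval; the direct vector query path instead forwards the same dict as params["solr"] to client.vector_io.query.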
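
The chunk post-processing duplicated in query_v2.py and streaming_query_v2.py turns each retrieved chunk into a RAG chunk with a resolved source URL (parent_id joined onto MIMIR_DOC_URL when offline, reference_url otherwise) and folds the top results into the prompt text. The sketch below condenses that logic under simplified stand-in types; RetrievedChunk and RAGChunk here are illustrative dataclasses, not the llama-stack or models.responses classes used in the patch.

# Condensed sketch of the chunk-to-RAG-context flow added in this patch.
from dataclasses import dataclass, field
from typing import Any, Optional
from urllib.parse import urljoin

MIMIR_DOC_URL = "https://mimir.corp.redhat.com"


@dataclass
class RetrievedChunk:
    """Stand-in for a vector_io.query result chunk."""

    content: str
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class RAGChunk:
    """Stand-in for models.responses.RAGChunk."""

    content: str
    source: Optional[str] = None
    score: Optional[float] = None


def to_rag_chunks(
    chunks: list[RetrievedChunk],
    scores: list[float],
    offline: bool = True,
) -> list[RAGChunk]:
    """Pair chunks with their scores and resolve source URLs."""
    rag_chunks = []
    for i, chunk in enumerate(chunks):
        source = None
        if chunk.metadata:
            if offline:
                parent_id = chunk.metadata.get("parent_id")
                if parent_id:
                    source = urljoin(MIMIR_DOC_URL, parent_id)
            else:
                source = chunk.metadata.get("reference_url")
        score = scores[i] if i < len(scores) else None
        rag_chunks.append(RAGChunk(content=chunk.content, source=source, score=score))
    return rag_chunks


def build_rag_context(rag_chunks: list[RAGChunk], limit: int = 5) -> str:
    """Format the top chunks for injection into the user message."""
    if not rag_chunks:
        return ""
    blocks = [
        f"Source: {chunk.source or 'Unknown'}\n{chunk.content}"
        for chunk in rag_chunks[:limit]
    ]
    return "\n\nRelevant documentation:\n" + "\n\n".join(blocks)

Because the same block appears verbatim in both endpoints, and the offline flag is duplicated alongside the TODO to lift it into configuration, a shared helper along these lines would keep the query and streaming handlers from drifting apart.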