From 420c6e1932876ef3c59a2e26bc46766f96600a55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Tue, 27 May 2025 11:48:12 +0200 Subject: [PATCH] Report the sources --- crates/meilisearch/src/routes/chat.rs | 123 ++++++++++++++++++++++---- 1 file changed, 107 insertions(+), 16 deletions(-) diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 3db948eb8..5de5a9367 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -32,9 +32,10 @@ use meilisearch_types::milli::prompt::{Prompt, PromptData}; use meilisearch_types::milli::update::new::document::DocumentFromDb; use meilisearch_types::milli::update::Setting; use meilisearch_types::milli::{ - DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget, + all_obkv_to_json, obkv_to_json, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, + MetadataBuilder, TimeBudget, }; -use meilisearch_types::Index; +use meilisearch_types::{Document, Index}; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::runtime::Handle; @@ -55,6 +56,8 @@ use crate::search_queue::SearchQueue; const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress"; const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage"; +const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources"; +const MEILI_REPORT_ERRORS_NAME: &str = "_meiliReportErrors"; const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex"; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -93,10 +96,16 @@ async fn chat( pub struct FunctionSupport { /// Defines if we can call the _meiliSearchProgress function /// to inform the front-end about what we are searching for. - progress: bool, + report_progress: bool, + /// Defines if we can call the _meiliSearchSources function + /// to inform the front-end about the sources of the search. 
+ report_sources: bool, /// Defines if we can call the _meiliAppendConversationMessage /// function to provide the messages to append into the conversation. append_to_conversation: bool, + /// Defines if we can call the _meiliReportErrors function + /// to inform the front-end about potential errors. + report_errors: bool, } /// Setup search tool in chat completion request @@ -112,18 +121,28 @@ fn setup_search_tool( } // Remove internal tools used for front-end notifications as they should be hidden from the LLM. - let mut progress = false; + let mut report_progress = false; + let mut report_sources = false; let mut append_to_conversation = false; + let mut report_errors = false; tools.retain(|tool| { match tool.function.name.as_str() { MEILI_SEARCH_PROGRESS_NAME => { - progress = true; + report_progress = true; + false + } + MEILI_SEARCH_SOURCES_NAME => { + report_sources = true; false } MEILI_APPEND_CONVERSATION_MESSAGE_NAME => { append_to_conversation = true; false } + MEILI_REPORT_ERRORS_NAME => { + report_errors = true; + false + } _ => true, // keep other tools } }); @@ -188,7 +207,7 @@ fn setup_search_tool( }), ); - Ok(FunctionSupport { progress, append_to_conversation }) + Ok(FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors }) } /// Process search request and return formatted results @@ -199,7 +218,7 @@ async fn process_search_request( auth_token: &str, index_uid: String, q: Option, -) -> Result<(Index, String), ResponseError> { +) -> Result<(Index, Vec, String), ResponseError> { // TBD // let mut aggregate = SearchAggregator::::from_query(&query); @@ -276,22 +295,33 @@ async fn process_search_request( permit.drop().await; let output = output?; - if let Ok((_, ref search_result)) = output { + let mut documents = Vec::new(); + if let Ok((ref rtxn, ref search_result)) = output { // aggregate.succeed(search_result); if search_result.degraded { MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); } + + let fields_ids_map = 
index.fields_ids_map(rtxn)?; + let displayed_fields = index.displayed_fields_ids(rtxn)?; + for &document_id in &search_result.documents_ids { + let obkv = index.document(rtxn, document_id)?; + let document = match displayed_fields { + Some(ref fields) => obkv_to_json(fields, &fields_ids_map, obkv)?, + None => all_obkv_to_json(obkv, &fields_ids_map)?, + }; + documents.push(document); + } } // analytics.publish(aggregate, &req); let (rtxn, search_result) = output?; - // let rtxn = index.read_txn()?; let render_alloc = Bump::new(); let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?; let text = formatted.join("\n"); drop(rtxn); - Ok((index, text)) + Ok((index, documents, text)) } async fn non_streamed_chat( @@ -319,7 +349,7 @@ async fn non_streamed_chat( let auth_token = extract_token_from_request(&req)?.unwrap(); let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); - let FunctionSupport { progress, append_to_conversation } = + let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } = setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; let mut response; @@ -359,7 +389,7 @@ async fn non_streamed_chat( }; let text = match result { - Ok((_, text)) => text, + Ok((_, documents, text)) => text, Err(err) => err, }; @@ -411,7 +441,7 @@ async fn streamed_chat( let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); - let FunctionSupport { progress, append_to_conversation } = + let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } = setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; let (tx, rx) = tokio::sync::mpsc::channel(10); @@ -507,8 +537,9 @@ async fn streamed_chat( ); for call in meili_calls { - if progress { + if report_progress { let call = 
MeiliSearchProgress { + call_id: call.id.to_string(), function_name: call.function.name.clone(), function_arguments: call .function @@ -573,7 +604,24 @@ async fn streamed_chat( }; let text = match result { - Ok((_, text)) => text, + Ok((_index, documents, text)) => { + if report_sources { + let call = MeiliSearchSources { + call_id: call.id.to_string(), + sources: documents, + }; + let resp = call.create_response(resp.clone()); + // Send the sources of this search to the front-end + if let Err(SendError(_)) = tx + .send(Event::Data(sse::Data::new_json(&resp).unwrap())) + .await + { + return; + } + } + + text + }, Err(err) => err, }; @@ -651,8 +699,10 @@ async fn streamed_chat( } #[derive(Debug, Clone, Serialize)] -/// Give context about what Meilisearch is doing. +/// Provides information about the current Meilisearch search operation. struct MeiliSearchProgress { + /// The call ID to track the sources of the search. + pub call_id: String, /// The name of the function we are executing. pub function_name: String, /// The arguments of the function we are executing, encoded in JSON. @@ -690,6 +740,47 @@ impl MeiliSearchProgress { } } +#[derive(Debug, Clone, Serialize)] +/// Provides sources of the search. +struct MeiliSearchSources { + /// The call ID to track the original search associated to those sources. + pub call_id: String, + /// The documents associated with the search (call_id). + /// Only the displayed attributes of the documents are returned. 
+ pub sources: Vec, +} + +impl MeiliSearchSources { + fn create_response( + &self, + mut resp: CreateChatCompletionStreamResponse, + ) -> CreateChatCompletionStreamResponse { + let call_text = serde_json::to_string(self).unwrap(); + let tool_call = ChatCompletionMessageToolCallChunk { + index: 0, + id: Some(uuid::Uuid::new_v4().to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()), + arguments: Some(call_text), + }), + }; + resp.choices[0] = ChatChoiceStream { + index: 0, + delta: ChatCompletionStreamResponseDelta { + content: None, + function_call: None, + tool_calls: Some(vec![tool_call]), + role: Some(Role::Assistant), + refusal: None, + }, + finish_reason: None, + logprobs: None, + }; + resp + } +} + struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage); impl MeiliAppendConversationMessage {