diff --git a/Cargo.lock b/Cargo.lock index e82414c80..461b26ab1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3718,6 +3718,7 @@ dependencies = [ "itertools 0.14.0", "jsonwebtoken", "lazy_static", + "liquid", "manifest-dir-macros", "maplit", "meili-snap", diff --git a/crates/meilisearch/Cargo.toml b/crates/meilisearch/Cargo.toml index ff6e5ceeb..a0ce49193 100644 --- a/crates/meilisearch/Cargo.toml +++ b/crates/meilisearch/Cargo.toml @@ -48,6 +48,7 @@ is-terminal = "0.4.13" itertools = "0.14.0" jsonwebtoken = "9.3.0" lazy_static = "1.5.0" +liquid = "0.26.9" meilisearch-auth = { path = "../meilisearch-auth" } meilisearch-types = { path = "../meilisearch-types" } mimalloc = { version = "0.1.43", default-features = false } diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 1cb813acd..335d5bf43 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -1,14 +1,48 @@ +use std::mem; + use actix_web::web::{self, Data}; use actix_web::HttpResponse; use async_openai::config::OpenAIConfig; -use async_openai::types::CreateChatCompletionRequest; +use async_openai::types::{ + ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, + ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason, + FunctionObjectArgs, +}; use async_openai::Client; use index_scheduler::IndexScheduler; use meilisearch_types::error::ResponseError; use meilisearch_types::keys::actions; +use meilisearch_types::milli::index::IndexEmbeddingConfig; +use meilisearch_types::milli::prompt::PromptData; +use meilisearch_types::milli::vector::EmbeddingConfig; +use meilisearch_types::{Document, Index}; +use serde::Deserialize; +use serde_json::json; use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::GuardedData; +use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; +use crate::routes::indexes::search::search_kind; +use crate::search::{ + add_search_rules, perform_search, HybridQuery, RetrieveVectors, SearchQuery, SemanticRatio, +}; +use crate::search_queue::SearchQueue; + +/// The default description of the searchInIndex tool provided to OpenAI. +const DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION: &str = + "Search the database for relevant JSON documents using an optional query."; +/// The default description of the searchInIndex `q` parameter tool provided to OpenAI. +const DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION: &str = + "The search query string used to find relevant documents in the index. \ +This should contain keywords or phrases that best represent what the user is looking for. \ +More specific queries will yield more precise results."; +/// The default description of the searchInIndex `index` parameter tool provided to OpenAI. +const DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION: &str = +"The name of the index to search within. An index is a collection of documents organized for search. \ +Selecting the right index ensures the most relevant results for the user query"; + +const EMBEDDER_NAME: &str = "openai"; pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service(web::resource("").route(web::post().to(chat))); @@ -16,8 +50,9 @@ pub fn configure(cfg: &mut web::ServiceConfig) { /// Get a chat completion async fn chat( - _index_scheduler: GuardedData, Data>, - web::Json(chat_completion): web::Json, + index_scheduler: GuardedData, Data>, + search_queue: web::Data, + web::Json(mut chat_completion): web::Json, ) -> Result { // To enable later on, when the feature will be experimental // index_scheduler.features().check_chat("Using the /chat route")?; @@ -26,7 +61,173 @@ async fn chat( .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let client = Client::with_config(config); - let response = client.chat().create(chat_completion).await.unwrap(); + + assert_eq!( + chat_completion.n.unwrap_or(1), + 1, + "Meilisearch /chat only support one completion at a time (n = 1, n = null)" + ); + + let mut response; + loop { + let mut tools = chat_completion.tools.get_or_insert_default(); + tools.push(ChatCompletionToolArgs::default() + .r#type(ChatCompletionToolType::Function) + .function(FunctionObjectArgs::default() + .name("searchInIndex") + .description(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION) + .parameters(json!({ + "type": "object", + "properties": { + "index_uid": { + "type": "string", + "enum": ["main"], + "description": DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION, + }, + "q": { + "type": ["string", "null"], + "description": DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION, + } + }, + "required": ["index_uid", "q"], + "additionalProperties": false, + })) + .strict(true) + .build() + .unwrap(), + ) + .build() + .unwrap() + ); + response = dbg!(client.chat().create(chat_completion.clone()).await.unwrap()); + + let choice = &mut response.choices[0]; + match choice.finish_reason { + Some(FinishReason::ToolCalls) => { + let tool_calls = mem::take(&mut choice.message.tool_calls).unwrap_or_default(); + + let (meili_calls, other_calls): (Vec<_>, Vec<_>) = + tool_calls.into_iter().partition(|call| call.function.name == "searchInIndex"); + + chat_completion.messages.push( + ChatCompletionRequestAssistantMessageArgs::default() + .tool_calls(meili_calls.clone()) + .build() + .unwrap() + .into(), + ); + + for call in meili_calls { + let SearchInIndexParameters { index_uid, q } = + serde_json::from_str(dbg!(&call.function.arguments)).unwrap(); + + let mut query = SearchQuery { + q, + hybrid: Some(HybridQuery { + semantic_ratio: SemanticRatio::default(), + embedder: EMBEDDER_NAME.to_string(), + }), + ..Default::default() + }; + + // Tenant token search_rules. + if let Some(search_rules) = + index_scheduler.filters().get_index_search_rules(&index_uid) + { + add_search_rules(&mut query.filter, search_rules); + } + + // TBD + // let mut aggregate = SearchAggregator::::from_query(&query); + + let index = index_scheduler.index(&index_uid)?; + let search_kind = search_kind( + &query, + index_scheduler.get_ref(), + index_uid.to_string(), + &index, + )?; + + let permit = search_queue.try_get_search_permit().await?; + let features = index_scheduler.features(); + let index_cloned = index.clone(); + let search_result = tokio::task::spawn_blocking(move || { + perform_search( + index_uid.to_string(), + &index_cloned, + query, + search_kind, + RetrieveVectors::new(false), + features, + ) + }) + .await; + permit.drop().await; + + let search_result = search_result?; + if let Ok(ref search_result) = search_result { + // aggregate.succeed(search_result); + if search_result.degraded { + MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); + } + } + // analytics.publish(aggregate, &req); + + let search_result = search_result?; + let formatted = format_documents( + &index, + search_result.hits.into_iter().map(|doc| doc.document), + ); + let text = formatted.join("\n"); + chat_completion.messages.push(ChatCompletionRequestMessage::Tool( + ChatCompletionRequestToolMessage { + tool_call_id: call.id, + content: ChatCompletionRequestToolMessageContent::Text(text), + }, + )); + } + + // Let the client call other tools by themselves + if !other_calls.is_empty() { + response.choices[0].message.tool_calls = Some(other_calls); + break; + } + } + _ => break, + } + } Ok(HttpResponse::Ok().json(response)) } + +#[derive(Deserialize)] +struct SearchInIndexParameters { + /// The index uid to search in. + index_uid: String, + /// The query parameter to use. + q: Option, +} + +fn format_documents(index: &Index, documents: impl Iterator) -> Vec { + let rtxn = index.read_txn().unwrap(); + let IndexEmbeddingConfig { name: _, config, user_provided: _ } = index + .embedding_configs(&rtxn) + .unwrap() + .into_iter() + .find(|conf| conf.name == EMBEDDER_NAME) + .unwrap(); + + let EmbeddingConfig { + embedder_options: _, + prompt: PromptData { template, max_bytes }, + quantized: _, + } = config; + + let template = liquid::ParserBuilder::with_stdlib().build().unwrap().parse(&template).unwrap(); + documents + .map(|doc| { + let object = liquid::to_object(&doc).unwrap(); + template.render(&object).unwrap() + }) + .collect() +}