diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 70d565b99..207feb256 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -6,14 +6,16 @@ use actix_web::{Either, HttpResponse, Responder}; use actix_web_lab::sse::{self, Event}; use async_openai::config::OpenAIConfig; use async_openai::types::{ - ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs, - ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, - ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta, - ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason, - FunctionCallStream, FunctionObjectArgs, + ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, + ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, + ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType, + CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream, + FunctionObjectArgs, }; use async_openai::Client; use futures::StreamExt; +use futures_util::stream; use index_scheduler::IndexScheduler; use meilisearch_types::error::ResponseError; use meilisearch_types::keys::actions; @@ -23,6 +25,7 @@ use meilisearch_types::milli::vector::EmbeddingConfig; use meilisearch_types::{Document, Index}; use serde::{Deserialize, Serialize}; use serde_json::json; +use tokio::runtime::Handle; use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::GuardedData; @@ -297,26 +300,25 @@ async fn streamed_chat( let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let client = Client::with_config(config); - let response = client.chat().create_stream(chat_completion).await.unwrap(); + let response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); let mut global_tool_calls = HashMap::::new(); - actix_web_lab::sse::Sse::from_stream(response.map(move |response| { - response.map(|mut r| { - let delta = &r.choices[0].delta; + actix_web_lab::sse::Sse::from_stream(response.flat_map(move |response| match response { + Ok(resp) => { + let delta = &resp.choices[0].delta; let ChatCompletionStreamResponseDelta { - ref content, - ref function_call, + content: _, + function_call: _, ref tool_calls, - ref role, - ref refusal, + role: _, + refusal: _, } = delta; match tool_calls { Some(tool_calls) => { for chunk in tool_calls { - let ChatCompletionMessageToolCallChunk { index, id, r#type, function } = + let ChatCompletionMessageToolCallChunk { index, id, r#type: _, function } = chunk; - let FunctionCallStream { ref name, ref arguments } = - function.as_ref().unwrap(); + let FunctionCallStream { name, arguments } = function.as_ref().unwrap(); global_tool_calls .entry(*index) .or_insert_with(|| Call { @@ -326,15 +328,120 @@ async fn streamed_chat( }) .append(arguments.as_ref().unwrap()); } + stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]) } None if !global_tool_calls.is_empty() => { dbg!(&global_tool_calls); - } - None => (), - } - Event::Data(sse::Data::new_json(r).unwrap()) - }) + let (meili_calls, other_calls): (Vec<_>, Vec<_>) = + mem::take(&mut global_tool_calls) + .into_iter() + .map(|(_, call)| ChatCompletionMessageToolCall { + id: call.id, + r#type: ChatCompletionToolType::Function, + function: FunctionCall { + name: call.function_name, + arguments: call.arguments, + }, + }) + .partition(|call| call.function.name == "searchInIndex"); + + chat_completion.messages.push( + ChatCompletionRequestAssistantMessageArgs::default() + .tool_calls(meili_calls.clone()) + .build() + .unwrap() + .into(), + ); + + for call in meili_calls { + let SearchInIndexParameters { index_uid, q } = + serde_json::from_str(&call.function.arguments).unwrap(); + + let mut query = SearchQuery { + q, + hybrid: Some(HybridQuery { + semantic_ratio: SemanticRatio::default(), + embedder: EMBEDDER_NAME.to_string(), + }), + limit: 20, + ..Default::default() + }; + + // Tenant token search_rules. + if let Some(search_rules) = + index_scheduler.filters().get_index_search_rules(&index_uid) + { + add_search_rules(&mut query.filter, search_rules); + } + + // TBD + // let mut aggregate = SearchAggregator::::from_query(&query); + + let index = index_scheduler.index(&index_uid).unwrap(); + let search_kind = search_kind( + &query, + index_scheduler.get_ref(), + index_uid.to_string(), + &index, + ) + .unwrap(); + + // let permit = search_queue.try_get_search_permit().await?; + let features = index_scheduler.features(); + let index_cloned = index.clone(); + // let search_result = tokio::task::spawn_blocking(move || { + let search_result = perform_search( + index_uid.to_string(), + &index_cloned, + query, + search_kind, + RetrieveVectors::new(false), + features, + ); + // }) + // .await; + // permit.drop().await; + + // let search_result = search_result.unwrap(); + if let Ok(ref search_result) = search_result { + // aggregate.succeed(search_result); + if search_result.degraded { + MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); + } + } + // analytics.publish(aggregate, &req); + + let search_result = search_result.unwrap(); + let formatted = format_documents( + &index, + search_result.hits.into_iter().map(|doc| doc.document), + ); + let text = formatted.join("\n"); + chat_completion.messages.push(ChatCompletionRequestMessage::Tool( + ChatCompletionRequestToolMessage { + tool_call_id: call.id, + content: ChatCompletionRequestToolMessageContent::Text(text), + }, + )); + } + + let response = Handle::current().block_on(async { + client.chat().create_stream(chat_completion.clone()).await.unwrap() + }); + + // stream::iter(vec![ + // Ok(Event::Data(sse::Data::new_json(json!({ "text": "Hello" })).unwrap())), + // Ok(Event::Data(sse::Data::new_json(json!({ "text": " world" })).unwrap())), + // Ok(Event::Data(sse::Data::new_json(json!({ "text": " !" })).unwrap())), + // ]) + + response + } + None => stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]), + } + } + Err(err) => stream::iter(vec![Err(err)]), })) }