mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-05-25 09:03:59 +02:00
Nearly support tools on the streaming route
This commit is contained in:
parent
24050f06e4
commit
da7d651f4b
@ -6,14 +6,16 @@ use actix_web::{Either, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{self, Event};
|
||||
use async_openai::config::OpenAIConfig;
|
||||
use async_openai::types::{
|
||||
ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs,
|
||||
ChatCompletionRequestMessage, ChatCompletionRequestToolMessage,
|
||||
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
|
||||
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason,
|
||||
FunctionCallStream, FunctionObjectArgs,
|
||||
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
||||
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
||||
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
|
||||
CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream,
|
||||
FunctionObjectArgs,
|
||||
};
|
||||
use async_openai::Client;
|
||||
use futures::StreamExt;
|
||||
use futures_util::stream;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::keys::actions;
|
||||
@ -23,6 +25,7 @@ use meilisearch_types::milli::vector::EmbeddingConfig;
|
||||
use meilisearch_types::{Document, Index};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tokio::runtime::Handle;
|
||||
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::GuardedData;
|
||||
@ -297,26 +300,25 @@ async fn streamed_chat(
|
||||
|
||||
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
||||
let client = Client::with_config(config);
|
||||
let response = client.chat().create_stream(chat_completion).await.unwrap();
|
||||
let response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||
actix_web_lab::sse::Sse::from_stream(response.map(move |response| {
|
||||
response.map(|mut r| {
|
||||
let delta = &r.choices[0].delta;
|
||||
actix_web_lab::sse::Sse::from_stream(response.flat_map(move |response| match response {
|
||||
Ok(resp) => {
|
||||
let delta = &resp.choices[0].delta;
|
||||
let ChatCompletionStreamResponseDelta {
|
||||
ref content,
|
||||
ref function_call,
|
||||
content: _,
|
||||
function_call: _,
|
||||
ref tool_calls,
|
||||
ref role,
|
||||
ref refusal,
|
||||
role: _,
|
||||
refusal: _,
|
||||
} = delta;
|
||||
|
||||
match tool_calls {
|
||||
Some(tool_calls) => {
|
||||
for chunk in tool_calls {
|
||||
let ChatCompletionMessageToolCallChunk { index, id, r#type, function } =
|
||||
let ChatCompletionMessageToolCallChunk { index, id, r#type: _, function } =
|
||||
chunk;
|
||||
let FunctionCallStream { ref name, ref arguments } =
|
||||
function.as_ref().unwrap();
|
||||
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
|
||||
global_tool_calls
|
||||
.entry(*index)
|
||||
.or_insert_with(|| Call {
|
||||
@ -326,15 +328,120 @@ async fn streamed_chat(
|
||||
})
|
||||
.append(arguments.as_ref().unwrap());
|
||||
}
|
||||
stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))])
|
||||
}
|
||||
None if !global_tool_calls.is_empty() => {
|
||||
dbg!(&global_tool_calls);
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
|
||||
Event::Data(sse::Data::new_json(r).unwrap())
|
||||
})
|
||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
||||
mem::take(&mut global_tool_calls)
|
||||
.into_iter()
|
||||
.map(|(_, call)| ChatCompletionMessageToolCall {
|
||||
id: call.id,
|
||||
r#type: ChatCompletionToolType::Function,
|
||||
function: FunctionCall {
|
||||
name: call.function_name,
|
||||
arguments: call.arguments,
|
||||
},
|
||||
})
|
||||
.partition(|call| call.function.name == "searchInIndex");
|
||||
|
||||
chat_completion.messages.push(
|
||||
ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.tool_calls(meili_calls.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
.into(),
|
||||
);
|
||||
|
||||
for call in meili_calls {
|
||||
let SearchInIndexParameters { index_uid, q } =
|
||||
serde_json::from_str(&call.function.arguments).unwrap();
|
||||
|
||||
let mut query = SearchQuery {
|
||||
q,
|
||||
hybrid: Some(HybridQuery {
|
||||
semantic_ratio: SemanticRatio::default(),
|
||||
embedder: EMBEDDER_NAME.to_string(),
|
||||
}),
|
||||
limit: 20,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Tenant token search_rules.
|
||||
if let Some(search_rules) =
|
||||
index_scheduler.filters().get_index_search_rules(&index_uid)
|
||||
{
|
||||
add_search_rules(&mut query.filter, search_rules);
|
||||
}
|
||||
|
||||
// TBD
|
||||
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
|
||||
|
||||
let index = index_scheduler.index(&index_uid).unwrap();
|
||||
let search_kind = search_kind(
|
||||
&query,
|
||||
index_scheduler.get_ref(),
|
||||
index_uid.to_string(),
|
||||
&index,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// let permit = search_queue.try_get_search_permit().await?;
|
||||
let features = index_scheduler.features();
|
||||
let index_cloned = index.clone();
|
||||
// let search_result = tokio::task::spawn_blocking(move || {
|
||||
let search_result = perform_search(
|
||||
index_uid.to_string(),
|
||||
&index_cloned,
|
||||
query,
|
||||
search_kind,
|
||||
RetrieveVectors::new(false),
|
||||
features,
|
||||
);
|
||||
// })
|
||||
// .await;
|
||||
// permit.drop().await;
|
||||
|
||||
// let search_result = search_result.unwrap();
|
||||
if let Ok(ref search_result) = search_result {
|
||||
// aggregate.succeed(search_result);
|
||||
if search_result.degraded {
|
||||
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
||||
}
|
||||
}
|
||||
// analytics.publish(aggregate, &req);
|
||||
|
||||
let search_result = search_result.unwrap();
|
||||
let formatted = format_documents(
|
||||
&index,
|
||||
search_result.hits.into_iter().map(|doc| doc.document),
|
||||
);
|
||||
let text = formatted.join("\n");
|
||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
|
||||
ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id,
|
||||
content: ChatCompletionRequestToolMessageContent::Text(text),
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
let response = Handle::current().block_on(async {
|
||||
client.chat().create_stream(chat_completion.clone()).await.unwrap()
|
||||
});
|
||||
|
||||
// stream::iter(vec![
|
||||
// Ok(Event::Data(sse::Data::new_json(json!({ "text": "Hello" })).unwrap())),
|
||||
// Ok(Event::Data(sse::Data::new_json(json!({ "text": " world" })).unwrap())),
|
||||
// Ok(Event::Data(sse::Data::new_json(json!({ "text": " !" })).unwrap())),
|
||||
// ])
|
||||
|
||||
response
|
||||
}
|
||||
None => stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]),
|
||||
}
|
||||
}
|
||||
Err(err) => stream::iter(vec![Err(err)]),
|
||||
}))
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user