mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-05-15 08:43:56 +02:00
Streaming supports tool calling
This commit is contained in:
parent
da7d651f4b
commit
6e8b371111
@ -1,9 +1,10 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use actix_web::web::{self, Data};
|
use actix_web::web::{self, Data};
|
||||||
use actix_web::{Either, HttpResponse, Responder};
|
use actix_web::{Either, HttpResponse, Responder};
|
||||||
use actix_web_lab::sse::{self, Event};
|
use actix_web_lab::sse::{self, Event, Sse};
|
||||||
use async_openai::config::OpenAIConfig;
|
use async_openai::config::OpenAIConfig;
|
||||||
use async_openai::types::{
|
use async_openai::types::{
|
||||||
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||||
@ -15,7 +16,6 @@ use async_openai::types::{
|
|||||||
};
|
};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use futures_util::stream;
|
|
||||||
use index_scheduler::IndexScheduler;
|
use index_scheduler::IndexScheduler;
|
||||||
use meilisearch_types::error::ResponseError;
|
use meilisearch_types::error::ResponseError;
|
||||||
use meilisearch_types::keys::actions;
|
use meilisearch_types::keys::actions;
|
||||||
@ -59,7 +59,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
|||||||
async fn chat(
|
async fn chat(
|
||||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
|
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
|
||||||
search_queue: web::Data<SearchQueue>,
|
search_queue: web::Data<SearchQueue>,
|
||||||
web::Json(mut chat_completion): web::Json<CreateChatCompletionRequest>,
|
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
|
||||||
) -> impl Responder {
|
) -> impl Responder {
|
||||||
// To enable later on, when the feature will be experimental
|
// To enable later on, when the feature will be experimental
|
||||||
// index_scheduler.features().check_chat("Using the /chat route")?;
|
// index_scheduler.features().check_chat("Using the /chat route")?;
|
||||||
@ -298,151 +298,174 @@ async fn streamed_chat(
|
|||||||
.unwrap(),
|
.unwrap(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||||
let client = Client::with_config(config);
|
let _join_handle = Handle::current().spawn(async move {
|
||||||
let response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
||||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
let client = Client::with_config(config);
|
||||||
actix_web_lab::sse::Sse::from_stream(response.flat_map(move |response| match response {
|
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||||
Ok(resp) => {
|
|
||||||
let delta = &resp.choices[0].delta;
|
|
||||||
let ChatCompletionStreamResponseDelta {
|
|
||||||
content: _,
|
|
||||||
function_call: _,
|
|
||||||
ref tool_calls,
|
|
||||||
role: _,
|
|
||||||
refusal: _,
|
|
||||||
} = delta;
|
|
||||||
|
|
||||||
match tool_calls {
|
'main: loop {
|
||||||
Some(tool_calls) => {
|
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
||||||
for chunk in tool_calls {
|
|
||||||
let ChatCompletionMessageToolCallChunk { index, id, r#type: _, function } =
|
|
||||||
chunk;
|
|
||||||
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
|
|
||||||
global_tool_calls
|
|
||||||
.entry(*index)
|
|
||||||
.or_insert_with(|| Call {
|
|
||||||
id: id.as_ref().unwrap().clone(),
|
|
||||||
function_name: name.as_ref().unwrap().clone(),
|
|
||||||
arguments: arguments.as_ref().unwrap().clone(),
|
|
||||||
})
|
|
||||||
.append(arguments.as_ref().unwrap());
|
|
||||||
}
|
|
||||||
stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))])
|
|
||||||
}
|
|
||||||
None if !global_tool_calls.is_empty() => {
|
|
||||||
dbg!(&global_tool_calls);
|
|
||||||
|
|
||||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
while let Some(result) = response.next().await {
|
||||||
mem::take(&mut global_tool_calls)
|
match result {
|
||||||
.into_iter()
|
Ok(resp) => {
|
||||||
.map(|(_, call)| ChatCompletionMessageToolCall {
|
let delta = &resp.choices[0].delta;
|
||||||
id: call.id,
|
let ChatCompletionStreamResponseDelta {
|
||||||
r#type: ChatCompletionToolType::Function,
|
content,
|
||||||
function: FunctionCall {
|
function_call: _,
|
||||||
name: call.function_name,
|
ref tool_calls,
|
||||||
arguments: call.arguments,
|
role: _,
|
||||||
},
|
refusal: _,
|
||||||
})
|
} = delta;
|
||||||
.partition(|call| call.function.name == "searchInIndex");
|
|
||||||
|
|
||||||
chat_completion.messages.push(
|
if content.is_none() && tool_calls.is_none() && global_tool_calls.is_empty()
|
||||||
ChatCompletionRequestAssistantMessageArgs::default()
|
|
||||||
.tool_calls(meili_calls.clone())
|
|
||||||
.build()
|
|
||||||
.unwrap()
|
|
||||||
.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
for call in meili_calls {
|
|
||||||
let SearchInIndexParameters { index_uid, q } =
|
|
||||||
serde_json::from_str(&call.function.arguments).unwrap();
|
|
||||||
|
|
||||||
let mut query = SearchQuery {
|
|
||||||
q,
|
|
||||||
hybrid: Some(HybridQuery {
|
|
||||||
semantic_ratio: SemanticRatio::default(),
|
|
||||||
embedder: EMBEDDER_NAME.to_string(),
|
|
||||||
}),
|
|
||||||
limit: 20,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Tenant token search_rules.
|
|
||||||
if let Some(search_rules) =
|
|
||||||
index_scheduler.filters().get_index_search_rules(&index_uid)
|
|
||||||
{
|
{
|
||||||
add_search_rules(&mut query.filter, search_rules);
|
break 'main;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TBD
|
if let Some(text) = content {
|
||||||
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
|
tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
let index = index_scheduler.index(&index_uid).unwrap();
|
match tool_calls {
|
||||||
let search_kind = search_kind(
|
Some(tool_calls) => {
|
||||||
&query,
|
for chunk in tool_calls {
|
||||||
index_scheduler.get_ref(),
|
let ChatCompletionMessageToolCallChunk {
|
||||||
index_uid.to_string(),
|
index,
|
||||||
&index,
|
id,
|
||||||
)
|
r#type: _,
|
||||||
.unwrap();
|
function,
|
||||||
|
} = chunk;
|
||||||
// let permit = search_queue.try_get_search_permit().await?;
|
let FunctionCallStream { name, arguments } =
|
||||||
let features = index_scheduler.features();
|
function.as_ref().unwrap();
|
||||||
let index_cloned = index.clone();
|
global_tool_calls
|
||||||
// let search_result = tokio::task::spawn_blocking(move || {
|
.entry(*index)
|
||||||
let search_result = perform_search(
|
.or_insert_with(|| Call {
|
||||||
index_uid.to_string(),
|
id: id.as_ref().unwrap().clone(),
|
||||||
&index_cloned,
|
function_name: name.as_ref().unwrap().clone(),
|
||||||
query,
|
arguments: arguments.as_ref().unwrap().clone(),
|
||||||
search_kind,
|
})
|
||||||
RetrieveVectors::new(false),
|
.append(arguments.as_ref().unwrap());
|
||||||
features,
|
}
|
||||||
);
|
tx.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||||
// })
|
.await
|
||||||
// .await;
|
.unwrap()
|
||||||
// permit.drop().await;
|
|
||||||
|
|
||||||
// let search_result = search_result.unwrap();
|
|
||||||
if let Ok(ref search_result) = search_result {
|
|
||||||
// aggregate.succeed(search_result);
|
|
||||||
if search_result.degraded {
|
|
||||||
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
|
||||||
}
|
}
|
||||||
|
None if !global_tool_calls.is_empty() => {
|
||||||
|
// dbg!(&global_tool_calls);
|
||||||
|
|
||||||
|
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
||||||
|
mem::take(&mut global_tool_calls)
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_, call)| ChatCompletionMessageToolCall {
|
||||||
|
id: call.id,
|
||||||
|
r#type: ChatCompletionToolType::Function,
|
||||||
|
function: FunctionCall {
|
||||||
|
name: call.function_name,
|
||||||
|
arguments: call.arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.partition(|call| call.function.name == "searchInIndex");
|
||||||
|
|
||||||
|
chat_completion.messages.push(
|
||||||
|
ChatCompletionRequestAssistantMessageArgs::default()
|
||||||
|
.tool_calls(meili_calls.clone())
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
for call in meili_calls {
|
||||||
|
let SearchInIndexParameters { index_uid, q } =
|
||||||
|
serde_json::from_str(&call.function.arguments).unwrap();
|
||||||
|
|
||||||
|
let mut query = SearchQuery {
|
||||||
|
q,
|
||||||
|
hybrid: Some(HybridQuery {
|
||||||
|
semantic_ratio: SemanticRatio::default(),
|
||||||
|
embedder: EMBEDDER_NAME.to_string(),
|
||||||
|
}),
|
||||||
|
limit: 20,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Tenant token search_rules.
|
||||||
|
if let Some(search_rules) =
|
||||||
|
index_scheduler.filters().get_index_search_rules(&index_uid)
|
||||||
|
{
|
||||||
|
add_search_rules(&mut query.filter, search_rules);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TBD
|
||||||
|
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
|
||||||
|
|
||||||
|
let index = index_scheduler.index(&index_uid).unwrap();
|
||||||
|
let search_kind = search_kind(
|
||||||
|
&query,
|
||||||
|
index_scheduler.get_ref(),
|
||||||
|
index_uid.to_string(),
|
||||||
|
&index,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let permit =
|
||||||
|
search_queue.try_get_search_permit().await.unwrap();
|
||||||
|
let features = index_scheduler.features();
|
||||||
|
let index_cloned = index.clone();
|
||||||
|
let search_result = tokio::task::spawn_blocking(move || {
|
||||||
|
perform_search(
|
||||||
|
index_uid.to_string(),
|
||||||
|
&index_cloned,
|
||||||
|
query,
|
||||||
|
search_kind,
|
||||||
|
RetrieveVectors::new(false),
|
||||||
|
features,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
permit.drop().await;
|
||||||
|
|
||||||
|
let search_result = search_result.unwrap();
|
||||||
|
if let Ok(ref search_result) = search_result {
|
||||||
|
// aggregate.succeed(search_result);
|
||||||
|
if search_result.degraded {
|
||||||
|
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// analytics.publish(aggregate, &req);
|
||||||
|
|
||||||
|
let search_result = search_result.unwrap();
|
||||||
|
let formatted = format_documents(
|
||||||
|
&index,
|
||||||
|
search_result.hits.into_iter().map(|doc| doc.document),
|
||||||
|
);
|
||||||
|
let text = formatted.join("\n");
|
||||||
|
chat_completion.messages.push(
|
||||||
|
ChatCompletionRequestMessage::Tool(
|
||||||
|
ChatCompletionRequestToolMessage {
|
||||||
|
tool_call_id: call.id,
|
||||||
|
content:
|
||||||
|
ChatCompletionRequestToolMessageContent::Text(
|
||||||
|
text,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => (),
|
||||||
}
|
}
|
||||||
// analytics.publish(aggregate, &req);
|
|
||||||
|
|
||||||
let search_result = search_result.unwrap();
|
|
||||||
let formatted = format_documents(
|
|
||||||
&index,
|
|
||||||
search_result.hits.into_iter().map(|doc| doc.document),
|
|
||||||
);
|
|
||||||
let text = formatted.join("\n");
|
|
||||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
|
|
||||||
ChatCompletionRequestToolMessage {
|
|
||||||
tool_call_id: call.id,
|
|
||||||
content: ChatCompletionRequestToolMessageContent::Text(text),
|
|
||||||
},
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
Err(_err) => {
|
||||||
let response = Handle::current().block_on(async {
|
// writeln!(lock, "error: {err}").unwrap();
|
||||||
client.chat().create_stream(chat_completion.clone()).await.unwrap()
|
}
|
||||||
});
|
|
||||||
|
|
||||||
// stream::iter(vec![
|
|
||||||
// Ok(Event::Data(sse::Data::new_json(json!({ "text": "Hello" })).unwrap())),
|
|
||||||
// Ok(Event::Data(sse::Data::new_json(json!({ "text": " world" })).unwrap())),
|
|
||||||
// Ok(Event::Data(sse::Data::new_json(json!({ "text": " !" })).unwrap())),
|
|
||||||
// ])
|
|
||||||
|
|
||||||
response
|
|
||||||
}
|
}
|
||||||
None => stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) => stream::iter(vec![Err(err)]),
|
});
|
||||||
}))
|
|
||||||
|
Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The structure used to aggregate the function calls to make.
|
/// The structure used to aggregate the function calls to make.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user