mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 20:37:15 +02:00
Merge with main
This commit is contained in:
commit
90056c64f5
264 changed files with 8879 additions and 5892 deletions
|
@ -197,9 +197,11 @@ struct Infos {
|
|||
experimental_max_number_of_batched_tasks: usize,
|
||||
experimental_limit_batched_tasks_total_size: u64,
|
||||
experimental_network: bool,
|
||||
experimental_chat_completions: bool,
|
||||
experimental_get_task_documents_route: bool,
|
||||
experimental_composite_embedders: bool,
|
||||
experimental_embedding_cache_entries: usize,
|
||||
experimental_no_snapshot_compaction: bool,
|
||||
gpu_enabled: bool,
|
||||
db_path: bool,
|
||||
import_dump: bool,
|
||||
|
@ -248,6 +250,7 @@ impl Infos {
|
|||
experimental_max_number_of_batched_tasks,
|
||||
experimental_limit_batched_tasks_total_size,
|
||||
experimental_embedding_cache_entries,
|
||||
experimental_no_snapshot_compaction,
|
||||
http_addr,
|
||||
master_key: _,
|
||||
env,
|
||||
|
@ -294,6 +297,7 @@ impl Infos {
|
|||
network,
|
||||
get_task_documents_route,
|
||||
composite_embedders,
|
||||
chat_completions,
|
||||
} = features;
|
||||
|
||||
// We're going to override every sensible information.
|
||||
|
@ -312,9 +316,11 @@ impl Infos {
|
|||
experimental_enable_logs_route: experimental_enable_logs_route | logs_route,
|
||||
experimental_reduce_indexing_memory_usage,
|
||||
experimental_network: network,
|
||||
experimental_chat_completions: chat_completions,
|
||||
experimental_get_task_documents_route: get_task_documents_route,
|
||||
experimental_composite_embedders: composite_embedders,
|
||||
experimental_embedding_cache_entries,
|
||||
experimental_no_snapshot_compaction,
|
||||
gpu_enabled: meilisearch_types::milli::vector::is_cuda_enabled(),
|
||||
db_path: db_path != PathBuf::from("./data.ms"),
|
||||
import_dump: import_dump.is_some(),
|
||||
|
|
|
@ -4,6 +4,7 @@ use std::marker::PhantomData;
|
|||
use std::ops::Deref;
|
||||
use std::pin::Pin;
|
||||
|
||||
use actix_web::http::header::AUTHORIZATION;
|
||||
use actix_web::web::Data;
|
||||
use actix_web::FromRequest;
|
||||
pub use error::AuthenticationError;
|
||||
|
@ -94,36 +95,44 @@ impl<P: Policy + 'static, D: 'static + Clone> FromRequest for GuardedData<P, D>
|
|||
_payload: &mut actix_web::dev::Payload,
|
||||
) -> Self::Future {
|
||||
match req.app_data::<Data<AuthController>>().cloned() {
|
||||
Some(auth) => match req
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
|
||||
{
|
||||
Some(mut type_token) => match type_token.next() {
|
||||
Some("Bearer") => {
|
||||
// TODO: find a less hardcoded way?
|
||||
let index = req.match_info().get("index_uid");
|
||||
match type_token.next() {
|
||||
Some(token) => Box::pin(Self::auth_bearer(
|
||||
auth,
|
||||
token.to_string(),
|
||||
index.map(String::from),
|
||||
req.app_data::<D>().cloned(),
|
||||
)),
|
||||
None => Box::pin(err(AuthenticationError::InvalidToken.into())),
|
||||
}
|
||||
}
|
||||
_otherwise => {
|
||||
Box::pin(err(AuthenticationError::MissingAuthorizationHeader.into()))
|
||||
}
|
||||
},
|
||||
None => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
|
||||
Some(auth) => match extract_token_from_request(req) {
|
||||
Ok(Some(token)) => {
|
||||
// TODO: find a less hardcoded way?
|
||||
let index = req.match_info().get("index_uid");
|
||||
Box::pin(Self::auth_bearer(
|
||||
auth,
|
||||
token.to_string(),
|
||||
index.map(String::from),
|
||||
req.app_data::<D>().cloned(),
|
||||
))
|
||||
}
|
||||
Ok(None) => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
|
||||
Err(e) => Box::pin(err(e.into())),
|
||||
},
|
||||
None => Box::pin(err(AuthenticationError::IrretrievableState.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_token_from_request(
|
||||
req: &actix_web::HttpRequest,
|
||||
) -> Result<Option<&str>, AuthenticationError> {
|
||||
match req
|
||||
.headers()
|
||||
.get(AUTHORIZATION)
|
||||
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
|
||||
{
|
||||
Some(mut type_token) => match type_token.next() {
|
||||
Some("Bearer") => match type_token.next() {
|
||||
Some(token) => Ok(Some(token)),
|
||||
None => Err(AuthenticationError::InvalidToken),
|
||||
},
|
||||
_otherwise => Err(AuthenticationError::MissingAuthorizationHeader),
|
||||
},
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Policy {
|
||||
fn authenticate(
|
||||
auth: Data<AuthController>,
|
||||
|
@ -299,8 +308,8 @@ pub mod policies {
|
|||
auth: &AuthController,
|
||||
token: &str,
|
||||
) -> Result<TenantTokenOutcome, AuthError> {
|
||||
// Only search action can be accessed by a tenant token.
|
||||
if A != actions::SEARCH {
|
||||
// Only search and chat actions can be accessed by a tenant token.
|
||||
if A != actions::SEARCH && A != actions::CHAT_COMPLETIONS {
|
||||
return Ok(TenantTokenOutcome::NotATenantToken);
|
||||
}
|
||||
|
||||
|
|
|
@ -37,7 +37,9 @@ use index_scheduler::{IndexScheduler, IndexSchedulerOptions};
|
|||
use meilisearch_auth::{open_auth_store_env, AuthController};
|
||||
use meilisearch_types::milli::constants::VERSION_MAJOR;
|
||||
use meilisearch_types::milli::documents::{DocumentsBatchBuilder, DocumentsBatchReader};
|
||||
use meilisearch_types::milli::update::{IndexDocumentsConfig, IndexDocumentsMethod};
|
||||
use meilisearch_types::milli::update::{
|
||||
default_thread_pool_and_threads, IndexDocumentsConfig, IndexDocumentsMethod, IndexerConfig,
|
||||
};
|
||||
use meilisearch_types::settings::apply_settings_to_builder;
|
||||
use meilisearch_types::tasks::KindWithContent;
|
||||
use meilisearch_types::versioning::{
|
||||
|
@ -234,6 +236,7 @@ pub fn setup_meilisearch(opt: &Opt) -> anyhow::Result<(Arc<IndexScheduler>, Arc<
|
|||
instance_features: opt.to_instance_features(),
|
||||
auto_upgrade: opt.experimental_dumpless_upgrade,
|
||||
embedding_cache_cap: opt.experimental_embedding_cache_entries,
|
||||
experimental_no_snapshot_compaction: opt.experimental_no_snapshot_compaction,
|
||||
};
|
||||
let binary_version = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH);
|
||||
|
||||
|
@ -500,7 +503,19 @@ fn import_dump(
|
|||
let network = dump_reader.network()?.cloned().unwrap_or_default();
|
||||
index_scheduler.put_network(network)?;
|
||||
|
||||
let indexer_config = index_scheduler.indexer_config();
|
||||
// 3.1 Use all cpus to process dump if `max_indexing_threads` not configured
|
||||
let backup_config;
|
||||
let base_config = index_scheduler.indexer_config();
|
||||
|
||||
let indexer_config = if base_config.max_threads.is_none() {
|
||||
let (thread_pool, _) = default_thread_pool_and_threads();
|
||||
|
||||
let _config = IndexerConfig { thread_pool, ..*base_config };
|
||||
backup_config = _config;
|
||||
&backup_config
|
||||
} else {
|
||||
base_config
|
||||
};
|
||||
|
||||
// /!\ The tasks must be imported AFTER importing the indexes or else the scheduler might
|
||||
// try to process tasks while we're trying to import the indexes.
|
||||
|
|
|
@ -65,6 +65,7 @@ const MEILI_EXPERIMENTAL_LIMIT_BATCHED_TASKS_TOTAL_SIZE: &str =
|
|||
"MEILI_EXPERIMENTAL_LIMIT_BATCHED_TASKS_SIZE";
|
||||
const MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES: &str =
|
||||
"MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES";
|
||||
const MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION: &str = "MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION";
|
||||
const DEFAULT_CONFIG_FILE_PATH: &str = "./config.toml";
|
||||
const DEFAULT_DB_PATH: &str = "./data.ms";
|
||||
const DEFAULT_HTTP_ADDR: &str = "localhost:7700";
|
||||
|
@ -455,6 +456,15 @@ pub struct Opt {
|
|||
#[serde(default = "default_embedding_cache_entries")]
|
||||
pub experimental_embedding_cache_entries: usize,
|
||||
|
||||
/// Experimental no snapshot compaction feature.
|
||||
///
|
||||
/// When enabled, Meilisearch will not compact snapshots during creation.
|
||||
///
|
||||
/// For more information, see <https://github.com/orgs/meilisearch/discussions/833>.
|
||||
#[clap(long, env = MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION)]
|
||||
#[serde(default)]
|
||||
pub experimental_no_snapshot_compaction: bool,
|
||||
|
||||
#[serde(flatten)]
|
||||
#[clap(flatten)]
|
||||
pub indexer_options: IndexerOpts,
|
||||
|
@ -559,6 +569,7 @@ impl Opt {
|
|||
experimental_max_number_of_batched_tasks,
|
||||
experimental_limit_batched_tasks_total_size,
|
||||
experimental_embedding_cache_entries,
|
||||
experimental_no_snapshot_compaction,
|
||||
} = self;
|
||||
export_to_env_if_not_present(MEILI_DB_PATH, db_path);
|
||||
export_to_env_if_not_present(MEILI_HTTP_ADDR, http_addr);
|
||||
|
@ -655,6 +666,10 @@ impl Opt {
|
|||
MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES,
|
||||
experimental_embedding_cache_entries.to_string(),
|
||||
);
|
||||
export_to_env_if_not_present(
|
||||
MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION,
|
||||
experimental_no_snapshot_compaction.to_string(),
|
||||
);
|
||||
indexer_options.export_to_env();
|
||||
}
|
||||
|
||||
|
@ -746,10 +761,12 @@ impl IndexerOpts {
|
|||
max_indexing_memory.to_string(),
|
||||
);
|
||||
}
|
||||
export_to_env_if_not_present(
|
||||
MEILI_MAX_INDEXING_THREADS,
|
||||
max_indexing_threads.0.to_string(),
|
||||
);
|
||||
if let Some(max_indexing_threads) = max_indexing_threads.0 {
|
||||
export_to_env_if_not_present(
|
||||
MEILI_MAX_INDEXING_THREADS,
|
||||
max_indexing_threads.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -757,15 +774,15 @@ impl TryFrom<&IndexerOpts> for IndexerConfig {
|
|||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(other: &IndexerOpts) -> Result<Self, Self::Error> {
|
||||
let thread_pool = ThreadPoolNoAbortBuilder::new()
|
||||
.thread_name(|index| format!("indexing-thread:{index}"))
|
||||
.num_threads(*other.max_indexing_threads)
|
||||
let thread_pool = ThreadPoolNoAbortBuilder::new_for_indexing()
|
||||
.num_threads(other.max_indexing_threads.unwrap_or_else(|| num_cpus::get() / 2))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
thread_pool,
|
||||
log_every_n: Some(DEFAULT_LOG_EVERY_N),
|
||||
max_memory: other.max_indexing_memory.map(|b| b.as_u64() as usize),
|
||||
thread_pool: Some(thread_pool),
|
||||
max_threads: *other.max_indexing_threads,
|
||||
max_positions_per_attributes: None,
|
||||
skip_index_budget: other.skip_index_budget,
|
||||
..Default::default()
|
||||
|
@ -828,31 +845,31 @@ fn total_memory_bytes() -> Option<u64> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
|
||||
pub struct MaxThreads(usize);
|
||||
#[derive(Default, Debug, Clone, Copy, Deserialize, Serialize)]
|
||||
pub struct MaxThreads(Option<usize>);
|
||||
|
||||
impl FromStr for MaxThreads {
|
||||
type Err = ParseIntError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
usize::from_str(s).map(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MaxThreads {
|
||||
fn default() -> Self {
|
||||
MaxThreads(num_cpus::get() / 2)
|
||||
fn from_str(s: &str) -> Result<MaxThreads, Self::Err> {
|
||||
if s.is_empty() || s == "unlimited" {
|
||||
return Ok(MaxThreads::default());
|
||||
}
|
||||
usize::from_str(s).map(Some).map(MaxThreads)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for MaxThreads {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
match self.0 {
|
||||
Some(threads) => write!(f, "{}", threads),
|
||||
None => write!(f, "unlimited"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for MaxThreads {
|
||||
type Target = usize;
|
||||
type Target = Option<usize>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
|
|
744
crates/meilisearch/src/routes/chats/chat_completions.rs
Normal file
744
crates/meilisearch/src/routes/chats/chat_completions.rs
Normal file
|
@ -0,0 +1,744 @@
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
use std::mem;
|
||||
use std::ops::ControlFlow;
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_web::web::{self, Data};
|
||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{Event, Sse};
|
||||
use async_openai::types::{
|
||||
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestDeveloperMessage,
|
||||
ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestMessage,
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
||||
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
||||
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
|
||||
CreateChatCompletionRequest, CreateChatCompletionStreamResponse, FinishReason, FunctionCall,
|
||||
FunctionCallStream, FunctionObjectArgs,
|
||||
};
|
||||
use async_openai::Client;
|
||||
use bumpalo::Bump;
|
||||
use futures::StreamExt;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_auth::AuthController;
|
||||
use meilisearch_types::error::{Code, ResponseError};
|
||||
use meilisearch_types::features::{
|
||||
ChatCompletionPrompts as DbChatCompletionPrompts,
|
||||
ChatCompletionSource as DbChatCompletionSource, SystemRole,
|
||||
};
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::index::ChatConfig;
|
||||
use meilisearch_types::milli::{all_obkv_to_json, obkv_to_json, TimeBudget};
|
||||
use meilisearch_types::{Document, Index};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
|
||||
use super::config::Config;
|
||||
use super::errors::{MistralError, OpenAiOutsideError, StreamErrorEvent};
|
||||
use super::utils::format_documents;
|
||||
use super::{
|
||||
ChatsParam, MEILI_APPEND_CONVERSATION_MESSAGE_NAME, MEILI_SEARCH_IN_INDEX_FUNCTION_NAME,
|
||||
MEILI_SEARCH_PROGRESS_NAME, MEILI_SEARCH_SOURCES_NAME,
|
||||
};
|
||||
use crate::error::MeilisearchHttpError;
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
|
||||
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
||||
use crate::routes::chats::utils::SseEventSender;
|
||||
use crate::routes::indexes::search::search_kind;
|
||||
use crate::search::{add_search_rules, prepare_search, search_from_kind, SearchQuery};
|
||||
use crate::search_queue::SearchQueue;
|
||||
|
||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(web::resource("").route(web::post().to(chat)));
|
||||
}
|
||||
|
||||
/// Get a chat completion
|
||||
async fn chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_COMPLETIONS }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
chats_param: web::Path<ChatsParam>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
|
||||
) -> impl Responder {
|
||||
let ChatsParam { workspace_uid } = chats_param.into_inner();
|
||||
|
||||
if chat_completion.stream.unwrap_or(false) {
|
||||
Either::Right(
|
||||
streamed_chat(
|
||||
index_scheduler,
|
||||
auth_ctrl,
|
||||
search_queue,
|
||||
&workspace_uid,
|
||||
req,
|
||||
chat_completion,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
} else {
|
||||
Either::Left(
|
||||
non_streamed_chat(
|
||||
index_scheduler,
|
||||
auth_ctrl,
|
||||
search_queue,
|
||||
&workspace_uid,
|
||||
req,
|
||||
chat_completion,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy)]
|
||||
pub struct FunctionSupport {
|
||||
/// Defines if we can call the _meiliSearchProgress function
|
||||
/// to inform the front-end about what we are searching for.
|
||||
report_progress: bool,
|
||||
/// Defines if we can call the _meiliSearchSources function
|
||||
/// to inform the front-end about the sources of the search.
|
||||
report_sources: bool,
|
||||
/// Defines if we can call the _meiliAppendConversationMessage
|
||||
/// function to provide the messages to append into the conversation.
|
||||
append_to_conversation: bool,
|
||||
}
|
||||
|
||||
/// Setup search tool in chat completion request
|
||||
fn setup_search_tool(
|
||||
index_scheduler: &Data<IndexScheduler>,
|
||||
filters: &meilisearch_auth::AuthFilter,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
prompts: &DbChatCompletionPrompts,
|
||||
system_role: SystemRole,
|
||||
) -> Result<FunctionSupport, ResponseError> {
|
||||
let tools = chat_completion.tools.get_or_insert_default();
|
||||
for tool in &tools[..] {
|
||||
match tool.function.name.as_str() {
|
||||
MEILI_SEARCH_IN_INDEX_FUNCTION_NAME => {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("{MEILI_SEARCH_IN_INDEX_FUNCTION_NAME} function is already defined."),
|
||||
Code::BadRequest,
|
||||
));
|
||||
}
|
||||
MEILI_SEARCH_PROGRESS_NAME
|
||||
| MEILI_SEARCH_SOURCES_NAME
|
||||
| MEILI_APPEND_CONVERSATION_MESSAGE_NAME => (),
|
||||
external_function_name => {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("{external_function_name}: External functions are not supported yet."),
|
||||
Code::UnimplementedExternalFunctionCalling,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove internal tools used for front-end notifications as they should be hidden from the LLM.
|
||||
let mut report_progress = false;
|
||||
let mut report_sources = false;
|
||||
let mut append_to_conversation = false;
|
||||
tools.retain(|tool| {
|
||||
match tool.function.name.as_str() {
|
||||
MEILI_SEARCH_PROGRESS_NAME => {
|
||||
report_progress = true;
|
||||
false
|
||||
}
|
||||
MEILI_SEARCH_SOURCES_NAME => {
|
||||
report_sources = true;
|
||||
false
|
||||
}
|
||||
MEILI_APPEND_CONVERSATION_MESSAGE_NAME => {
|
||||
append_to_conversation = true;
|
||||
false
|
||||
}
|
||||
_ => true, // keep other tools
|
||||
}
|
||||
});
|
||||
|
||||
let mut index_uids = Vec::new();
|
||||
let mut function_description = prompts.search_description.clone();
|
||||
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
|
||||
// Make sure to skip unauthorized indexes
|
||||
if !filters.is_index_authorized(name) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let rtxn = index.read_txn()?;
|
||||
let chat_config = index.chat_config(&rtxn)?;
|
||||
let index_description = chat_config.description;
|
||||
let _ = writeln!(&mut function_description, "\n\n - {name}: {index_description}\n");
|
||||
index_uids.push(name.to_string());
|
||||
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
let tool = ChatCompletionToolArgs::default()
|
||||
.r#type(ChatCompletionToolType::Function)
|
||||
.function(
|
||||
FunctionObjectArgs::default()
|
||||
.name(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
.description(&function_description)
|
||||
.parameters(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index_uid": {
|
||||
"type": "string",
|
||||
"enum": index_uids,
|
||||
"description": prompts.search_index_uid_param,
|
||||
},
|
||||
"q": {
|
||||
// Unfortunately, Mistral does not support an array of types, here.
|
||||
// "type": ["string", "null"],
|
||||
"type": "string",
|
||||
"description": prompts.search_q_param,
|
||||
}
|
||||
},
|
||||
"required": ["index_uid", "q"],
|
||||
"additionalProperties": false,
|
||||
}))
|
||||
.strict(true)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
tools.push(tool);
|
||||
|
||||
let system_message = match system_role {
|
||||
SystemRole::System => {
|
||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
||||
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
|
||||
name: None,
|
||||
})
|
||||
}
|
||||
SystemRole::Developer => {
|
||||
ChatCompletionRequestMessage::Developer(ChatCompletionRequestDeveloperMessage {
|
||||
content: ChatCompletionRequestDeveloperMessageContent::Text(prompts.system.clone()),
|
||||
name: None,
|
||||
})
|
||||
}
|
||||
};
|
||||
chat_completion.messages.insert(0, system_message);
|
||||
|
||||
Ok(FunctionSupport { report_progress, report_sources, append_to_conversation })
|
||||
}
|
||||
|
||||
/// Process search request and return formatted results
|
||||
async fn process_search_request(
|
||||
index_scheduler: &GuardedData<
|
||||
ActionPolicy<{ actions::CHAT_COMPLETIONS }>,
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
index_uid: String,
|
||||
q: Option<String>,
|
||||
) -> Result<(Index, Vec<Document>, String), ResponseError> {
|
||||
let index = index_scheduler.index(&index_uid)?;
|
||||
let rtxn = index.static_read_txn()?;
|
||||
let ChatConfig { description: _, prompt: _, search_parameters } = index.chat_config(&rtxn)?;
|
||||
let mut query = SearchQuery { q, ..SearchQuery::from(search_parameters) };
|
||||
let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate(
|
||||
auth_ctrl,
|
||||
auth_token,
|
||||
Some(index_uid.as_str()),
|
||||
)?;
|
||||
|
||||
// Tenant token search_rules.
|
||||
if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
|
||||
add_search_rules(&mut query.filter, search_rules);
|
||||
}
|
||||
let search_kind =
|
||||
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
||||
|
||||
let permit = search_queue.try_get_search_permit().await?;
|
||||
let features = index_scheduler.features();
|
||||
let index_cloned = index.clone();
|
||||
let output = tokio::task::spawn_blocking(move || -> Result<_, ResponseError> {
|
||||
let time_budget = match index_cloned
|
||||
.search_cutoff(&rtxn)
|
||||
.map_err(|e| MeilisearchHttpError::from_milli(e, Some(index_uid.clone())))?
|
||||
{
|
||||
Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)),
|
||||
None => TimeBudget::default(),
|
||||
};
|
||||
|
||||
let (search, _is_finite_pagination, _max_total_hits, _offset) =
|
||||
prepare_search(&index_cloned, &rtxn, &query, &search_kind, time_budget, features)?;
|
||||
|
||||
search_from_kind(index_uid, search_kind, search)
|
||||
.map(|(search_results, _)| (rtxn, search_results))
|
||||
.map_err(ResponseError::from)
|
||||
})
|
||||
.await;
|
||||
permit.drop().await;
|
||||
|
||||
let output = output?;
|
||||
let mut documents = Vec::new();
|
||||
if let Ok((ref rtxn, ref search_result)) = output {
|
||||
// aggregate.succeed(search_result);
|
||||
if search_result.degraded {
|
||||
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
||||
}
|
||||
|
||||
let fields_ids_map = index.fields_ids_map(rtxn)?;
|
||||
let displayed_fields = index.displayed_fields_ids(rtxn)?;
|
||||
for &document_id in &search_result.documents_ids {
|
||||
let obkv = index.document(rtxn, document_id)?;
|
||||
let document = match displayed_fields {
|
||||
Some(ref fields) => obkv_to_json(fields, &fields_ids_map, obkv)?,
|
||||
None => all_obkv_to_json(obkv, &fields_ids_map)?,
|
||||
};
|
||||
documents.push(document);
|
||||
}
|
||||
}
|
||||
|
||||
let (rtxn, search_result) = output?;
|
||||
let render_alloc = Bump::new();
|
||||
let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?;
|
||||
let text = formatted.join("\n");
|
||||
drop(rtxn);
|
||||
|
||||
Ok((index, documents, text))
|
||||
}
|
||||
|
||||
#[allow(unreachable_code, unused_variables)] // will be correctly implemented in the future
|
||||
async fn non_streamed_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_COMPLETIONS }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
workspace_uid: &str,
|
||||
req: HttpRequest,
|
||||
chat_completion: CreateChatCompletionRequest,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
index_scheduler.features().check_chat_completions("using the /chats chat completions route")?;
|
||||
|
||||
if let Some(n) = chat_completion.n.filter(|&n| n != 1) {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("You tried to specify n = {n} but only single choices are supported (n = 1)."),
|
||||
Code::UnimplementedMultiChoiceChatCompletions,
|
||||
));
|
||||
}
|
||||
|
||||
return Err(ResponseError::from_msg(
|
||||
"Non-streamed chat completions is not implemented".to_string(),
|
||||
Code::UnimplementedNonStreamingChatCompletions,
|
||||
));
|
||||
|
||||
let filters = index_scheduler.filters();
|
||||
let chat_settings = match index_scheduler.chat_settings(workspace_uid).unwrap() {
|
||||
Some(settings) => settings,
|
||||
None => {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("Chat `{workspace_uid}` not found"),
|
||||
Code::ChatNotFound,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let config = Config::new(&chat_settings);
|
||||
let client = Client::with_config(config);
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||
let system_role = chat_settings.source.system_role(&chat_completion.model);
|
||||
// TODO do function support later
|
||||
let _function_support = setup_search_tool(
|
||||
&index_scheduler,
|
||||
filters,
|
||||
&mut chat_completion,
|
||||
&chat_settings.prompts,
|
||||
system_role,
|
||||
)?;
|
||||
|
||||
let mut response;
|
||||
loop {
|
||||
response = client.chat().create(chat_completion.clone()).await.unwrap();
|
||||
|
||||
let choice = &mut response.choices[0];
|
||||
match choice.finish_reason {
|
||||
Some(FinishReason::ToolCalls) => {
|
||||
let tool_calls = mem::take(&mut choice.message.tool_calls).unwrap_or_default();
|
||||
|
||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls
|
||||
.into_iter()
|
||||
.partition(|call| call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME);
|
||||
|
||||
chat_completion.messages.push(
|
||||
ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.tool_calls(meili_calls.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
.into(),
|
||||
);
|
||||
|
||||
for call in meili_calls {
|
||||
let result = match serde_json::from_str(&call.function.arguments) {
|
||||
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request(
|
||||
&index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
&search_queue,
|
||||
auth_token,
|
||||
index_uid,
|
||||
q,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
Err(err) => Err(err.to_string()),
|
||||
};
|
||||
|
||||
// TODO report documents sources later
|
||||
let answer = match result {
|
||||
Ok((_, _documents, text)) => text,
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
|
||||
ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(answer),
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
// Let the client call other tools by themselves
|
||||
if !other_calls.is_empty() {
|
||||
response.choices[0].message.tool_calls = Some(other_calls);
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(HttpResponse::Ok().json(response))
|
||||
}
|
||||
|
||||
async fn streamed_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_COMPLETIONS }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
workspace_uid: &str,
|
||||
req: HttpRequest,
|
||||
mut chat_completion: CreateChatCompletionRequest,
|
||||
) -> Result<impl Responder, ResponseError> {
|
||||
index_scheduler.features().check_chat_completions("using the /chats chat completions route")?;
|
||||
let filters = index_scheduler.filters();
|
||||
|
||||
if let Some(n) = chat_completion.n.filter(|&n| n != 1) {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("You tried to specify n = {n} but only single choices are supported (n = 1)."),
|
||||
Code::UnimplementedMultiChoiceChatCompletions,
|
||||
));
|
||||
}
|
||||
|
||||
let chat_settings = match index_scheduler.chat_settings(workspace_uid)? {
|
||||
Some(settings) => settings,
|
||||
None => {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("Chat `{workspace_uid}` not found"),
|
||||
Code::ChatNotFound,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let config = Config::new(&chat_settings);
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||
let system_role = chat_settings.source.system_role(&chat_completion.model);
|
||||
let function_support = setup_search_tool(
|
||||
&index_scheduler,
|
||||
filters,
|
||||
&mut chat_completion,
|
||||
&chat_settings.prompts,
|
||||
system_role,
|
||||
)?;
|
||||
|
||||
tracing::debug!("Conversation function support: {function_support:?}");
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
let tx = SseEventSender::new(tx);
|
||||
let _join_handle = Handle::current().spawn(async move {
|
||||
let client = Client::with_config(config.clone());
|
||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||
|
||||
// Limit the number of internal calls to satisfy the search requests of the LLM
|
||||
for _ in 0..20 {
|
||||
let output = run_conversation(
|
||||
&index_scheduler,
|
||||
&auth_ctrl,
|
||||
&search_queue,
|
||||
&auth_token,
|
||||
&client,
|
||||
chat_settings.source,
|
||||
&mut chat_completion,
|
||||
&tx,
|
||||
&mut global_tool_calls,
|
||||
function_support,
|
||||
);
|
||||
|
||||
match output.await {
|
||||
Ok(ControlFlow::Continue(())) => (),
|
||||
Ok(ControlFlow::Break(_finish_reason)) => break,
|
||||
// If the connection is closed we must stop
|
||||
Err(SendError(_)) => return,
|
||||
}
|
||||
}
|
||||
|
||||
let _ = tx.stop().await;
|
||||
});
|
||||
|
||||
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
||||
}
|
||||
|
||||
/// Updates the chat completion with the new messages, streams the LLM tokens,
|
||||
/// and report progress and errors.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn run_conversation<C: async_openai::config::Config>(
|
||||
index_scheduler: &GuardedData<
|
||||
ActionPolicy<{ actions::CHAT_COMPLETIONS }>,
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
auth_ctrl: &web::Data<AuthController>,
|
||||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
client: &Client<C>,
|
||||
source: DbChatCompletionSource,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
tx: &SseEventSender,
|
||||
global_tool_calls: &mut HashMap<u32, Call>,
|
||||
function_support: FunctionSupport,
|
||||
) -> Result<ControlFlow<Option<FinishReason>, ()>, SendError<Event>> {
|
||||
let mut finish_reason = None;
|
||||
// safety: unwrap: can only happens if `stream` was set to `false`
|
||||
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
||||
while let Some(result) = response.next().await {
|
||||
match result {
|
||||
Ok(resp) => {
|
||||
let choice = &resp.choices[0];
|
||||
finish_reason = choice.finish_reason;
|
||||
|
||||
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = &choice.delta;
|
||||
|
||||
match tool_calls {
|
||||
Some(tool_calls) => {
|
||||
for chunk in tool_calls {
|
||||
let ChatCompletionMessageToolCallChunk {
|
||||
index,
|
||||
id,
|
||||
r#type: _,
|
||||
function,
|
||||
} = chunk;
|
||||
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
|
||||
|
||||
global_tool_calls
|
||||
.entry(*index)
|
||||
.and_modify(|call| {
|
||||
if call.is_internal() {
|
||||
call.append(arguments.as_ref().unwrap())
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| {
|
||||
if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
{
|
||||
Call::Internal {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
function_name: name.as_ref().unwrap().clone(),
|
||||
arguments: arguments.as_ref().unwrap().clone(),
|
||||
}
|
||||
} else {
|
||||
Call::External
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
None => {
|
||||
if !global_tool_calls.is_empty() {
|
||||
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
|
||||
mem::take(global_tool_calls)
|
||||
.into_values()
|
||||
.flat_map(|call| match call {
|
||||
Call::Internal { id, function_name: name, arguments } => {
|
||||
Some(ChatCompletionMessageToolCall {
|
||||
id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name, arguments },
|
||||
})
|
||||
}
|
||||
Call::External => None,
|
||||
})
|
||||
.partition(|call| {
|
||||
call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||
});
|
||||
|
||||
chat_completion.messages.push(
|
||||
ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.tool_calls(meili_calls.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
.into(),
|
||||
);
|
||||
|
||||
handle_meili_tools(
|
||||
index_scheduler,
|
||||
auth_ctrl,
|
||||
search_queue,
|
||||
auth_token,
|
||||
tx,
|
||||
meili_calls,
|
||||
chat_completion,
|
||||
&resp,
|
||||
function_support,
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
tx.forward_response(&resp).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
let result = match source {
|
||||
DbChatCompletionSource::Mistral => {
|
||||
StreamErrorEvent::from_openai_error::<MistralError>(error).await
|
||||
}
|
||||
_ => StreamErrorEvent::from_openai_error::<OpenAiOutsideError>(error).await,
|
||||
};
|
||||
let error = result.unwrap_or_else(StreamErrorEvent::from_reqwest_error);
|
||||
tx.send_error(&error).await?;
|
||||
return Ok(ControlFlow::Break(None));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We must stop if the finish reason is not something we can solve with Meilisearch
|
||||
match finish_reason {
|
||||
Some(FinishReason::ToolCalls) => Ok(ControlFlow::Continue(())),
|
||||
otherwise => Ok(ControlFlow::Break(otherwise)),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_meili_tools(
|
||||
index_scheduler: &GuardedData<
|
||||
ActionPolicy<{ actions::CHAT_COMPLETIONS }>,
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
auth_ctrl: &web::Data<AuthController>,
|
||||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
tx: &SseEventSender,
|
||||
meili_calls: Vec<ChatCompletionMessageToolCall>,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
resp: &CreateChatCompletionStreamResponse,
|
||||
FunctionSupport { report_progress, report_sources, append_to_conversation, .. }: FunctionSupport,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
for call in meili_calls {
|
||||
if report_progress {
|
||||
tx.report_search_progress(
|
||||
resp.clone(),
|
||||
&call.id,
|
||||
&call.function.name,
|
||||
&call.function.arguments,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if append_to_conversation {
|
||||
tx.append_tool_call_conversation_message(
|
||||
resp.clone(),
|
||||
call.id.clone(),
|
||||
call.function.name.clone(),
|
||||
call.function.arguments.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let mut error = None;
|
||||
|
||||
let result = match serde_json::from_str(&call.function.arguments) {
|
||||
Ok(SearchInIndexParameters { index_uid, q }) => match process_search_request(
|
||||
index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
search_queue,
|
||||
auth_token,
|
||||
index_uid,
|
||||
q,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => Ok(output),
|
||||
Err(err) => {
|
||||
let error_text = format!("the search tool call failed with {err}");
|
||||
error = Some(err);
|
||||
Err(error_text)
|
||||
}
|
||||
},
|
||||
Err(err) => Err(err.to_string()),
|
||||
};
|
||||
|
||||
let answer = match result {
|
||||
Ok((_index, documents, text)) => {
|
||||
if report_sources {
|
||||
tx.report_sources(resp.clone(), &call.id, &documents).await?;
|
||||
}
|
||||
text
|
||||
}
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(answer),
|
||||
});
|
||||
|
||||
if append_to_conversation {
|
||||
tx.append_conversation_message(resp.clone(), &tool).await?;
|
||||
}
|
||||
|
||||
chat_completion.messages.push(tool);
|
||||
|
||||
if let Some(error) = error {
|
||||
tx.send_error(&StreamErrorEvent::from_response_error(error)).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// The structure used to aggregate the function calls to make.
|
||||
#[derive(Debug)]
|
||||
enum Call {
|
||||
/// Tool calls to tools that must be managed by Meilisearch internally.
|
||||
/// Typically the search functions.
|
||||
Internal { id: String, function_name: String, arguments: String },
|
||||
/// Tool calls that we track but only to know that its not our functions.
|
||||
/// We return the function calls as-is to the end-user.
|
||||
External,
|
||||
}
|
||||
|
||||
impl Call {
|
||||
fn is_internal(&self) -> bool {
|
||||
matches!(self, Call::Internal { .. })
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
///
|
||||
/// - if called on external calls
|
||||
fn append(&mut self, more: &str) {
|
||||
match self {
|
||||
Call::Internal { arguments, .. } => arguments.push_str(more),
|
||||
Call::External => panic!("Cannot append argument chunks to an external function"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SearchInIndexParameters {
|
||||
/// The index uid to search in.
|
||||
index_uid: String,
|
||||
/// The query parameter to use.
|
||||
q: Option<String>,
|
||||
}
|
88
crates/meilisearch/src/routes/chats/config.rs
Normal file
88
crates/meilisearch/src/routes/chats/config.rs
Normal file
|
@ -0,0 +1,88 @@
|
|||
use async_openai::config::{AzureConfig, OpenAIConfig};
|
||||
use meilisearch_types::features::ChatCompletionSettings as DbChatSettings;
|
||||
use reqwest::header::HeaderMap;
|
||||
use secrecy::SecretString;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Config {
|
||||
OpenAiCompatible(OpenAIConfig),
|
||||
AzureOpenAiCompatible(AzureConfig),
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new(chat_settings: &DbChatSettings) -> Self {
|
||||
use meilisearch_types::features::ChatCompletionSource::*;
|
||||
match chat_settings.source {
|
||||
OpenAi | Mistral | Gemini | VLlm => {
|
||||
let mut config = OpenAIConfig::default();
|
||||
if let Some(org_id) = chat_settings.org_id.as_ref() {
|
||||
config = config.with_org_id(org_id);
|
||||
}
|
||||
if let Some(project_id) = chat_settings.project_id.as_ref() {
|
||||
config = config.with_project_id(project_id);
|
||||
}
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
let base_url = chat_settings.base_url.as_deref();
|
||||
if let Some(base_url) = chat_settings.source.base_url().or(base_url) {
|
||||
config = config.with_api_base(base_url);
|
||||
}
|
||||
Self::OpenAiCompatible(config)
|
||||
}
|
||||
AzureOpenAi => {
|
||||
let mut config = AzureConfig::default();
|
||||
if let Some(version) = chat_settings.api_version.as_ref() {
|
||||
config = config.with_api_version(version);
|
||||
}
|
||||
if let Some(deployment_id) = chat_settings.deployment_id.as_ref() {
|
||||
config = config.with_deployment_id(deployment_id);
|
||||
}
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Some(base_url) = chat_settings.base_url.as_ref() {
|
||||
config = config.with_api_base(base_url);
|
||||
}
|
||||
Self::AzureOpenAiCompatible(config)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl async_openai::config::Config for Config {
|
||||
fn headers(&self) -> HeaderMap {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.headers(),
|
||||
Config::AzureOpenAiCompatible(config) => config.headers(),
|
||||
}
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.url(path),
|
||||
Config::AzureOpenAiCompatible(config) => config.url(path),
|
||||
}
|
||||
}
|
||||
|
||||
fn query(&self) -> Vec<(&str, &str)> {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.query(),
|
||||
Config::AzureOpenAiCompatible(config) => config.query(),
|
||||
}
|
||||
}
|
||||
|
||||
fn api_base(&self) -> &str {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.api_base(),
|
||||
Config::AzureOpenAiCompatible(config) => config.api_base(),
|
||||
}
|
||||
}
|
||||
|
||||
fn api_key(&self) -> &SecretString {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.api_key(),
|
||||
Config::AzureOpenAiCompatible(config) => config.api_key(),
|
||||
}
|
||||
}
|
||||
}
|
250
crates/meilisearch/src/routes/chats/errors.rs
Normal file
250
crates/meilisearch/src/routes/chats/errors.rs
Normal file
|
@ -0,0 +1,250 @@
|
|||
use async_openai::error::{ApiError, OpenAIError};
|
||||
use async_openai::reqwest_eventsource::Error as EventSourceError;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// The error type which is always `error`.
|
||||
const ERROR_TYPE: &str = "error";
|
||||
|
||||
/// The error struct returned by the Mistral API.
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "object": "error",
|
||||
/// "message": "Service tier capacity exceeded for this model.",
|
||||
/// "type": "invalid_request_error",
|
||||
/// "param": null,
|
||||
/// "code": null
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MistralError {
|
||||
message: String,
|
||||
r#type: String,
|
||||
param: Option<String>,
|
||||
code: Option<String>,
|
||||
}
|
||||
|
||||
impl From<MistralError> for StreamErrorEvent {
|
||||
fn from(error: MistralError) -> Self {
|
||||
let MistralError { message, r#type, param, code } = error;
|
||||
StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_owned(),
|
||||
error: StreamError { r#type, code, message, param, event_id: None },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAiOutsideError {
|
||||
/// Emitted when an error occurs.
|
||||
error: OpenAiInnerError,
|
||||
}
|
||||
|
||||
/// Emitted when an error occurs.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAiInnerError {
|
||||
/// The error code.
|
||||
code: Option<String>,
|
||||
/// The error message.
|
||||
message: String,
|
||||
/// The error parameter.
|
||||
param: Option<String>,
|
||||
/// The type of the event. Always `error`.
|
||||
r#type: String,
|
||||
}
|
||||
|
||||
impl From<OpenAiOutsideError> for StreamErrorEvent {
|
||||
fn from(error: OpenAiOutsideError) -> Self {
|
||||
let OpenAiOutsideError { error: OpenAiInnerError { code, message, param, r#type } } = error;
|
||||
StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError { r#type, code, message, param, event_id: None },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An error that occurs during the streaming process.
|
||||
///
|
||||
/// It directly comes from the OpenAI API and you can
|
||||
/// read more about error events on their website:
|
||||
/// <https://platform.openai.com/docs/api-reference/realtime-server-events/error>
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct StreamErrorEvent {
|
||||
/// The unique ID of the server event.
|
||||
pub event_id: String,
|
||||
/// The event type, must be error.
|
||||
pub r#type: String,
|
||||
/// Details of the error.
|
||||
pub error: StreamError,
|
||||
}
|
||||
|
||||
/// Details of the error.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct StreamError {
|
||||
/// The type of error (e.g., "invalid_request_error", "server_error").
|
||||
pub r#type: String,
|
||||
/// Error code, if any.
|
||||
pub code: Option<String>,
|
||||
/// A human-readable error message.
|
||||
pub message: String,
|
||||
/// Parameter related to the error, if any.
|
||||
pub param: Option<String>,
|
||||
/// The event_id of the client event that caused the error, if applicable.
|
||||
pub event_id: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamErrorEvent {
|
||||
pub async fn from_openai_error<E>(error: OpenAIError) -> Result<Self, reqwest::Error>
|
||||
where
|
||||
E: serde::de::DeserializeOwned,
|
||||
Self: From<E>,
|
||||
{
|
||||
match error {
|
||||
OpenAIError::Reqwest(e) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "internal_reqwest_error".to_string(),
|
||||
code: Some("internal".to_string()),
|
||||
message: e.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
OpenAIError::ApiError(ApiError { message, r#type, param, code }) => {
|
||||
Ok(StreamErrorEvent {
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
error: StreamError {
|
||||
r#type: r#type.unwrap_or_else(|| "unknown".to_string()),
|
||||
code,
|
||||
message,
|
||||
param,
|
||||
event_id: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
OpenAIError::JSONDeserialize(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "json_deserialize_error".to_string(),
|
||||
code: Some("internal".to_string()),
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(),
|
||||
OpenAIError::StreamError(error) => match error {
|
||||
EventSourceError::InvalidStatusCode(_status_code, response) => {
|
||||
let error = response.json::<E>().await?;
|
||||
Ok(StreamErrorEvent::from(error))
|
||||
}
|
||||
EventSourceError::InvalidContentType(_header_value, response) => {
|
||||
let error = response.json::<E>().await?;
|
||||
Ok(StreamErrorEvent::from(error))
|
||||
}
|
||||
EventSourceError::Utf8(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "invalid_utf8_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::Parser(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "parser_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::Transport(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "transport_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::InvalidLastEventId(message) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "invalid_last_event_id".to_string(),
|
||||
code: None,
|
||||
message,
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::StreamEnded => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "stream_ended".to_string(),
|
||||
code: None,
|
||||
message: "Stream ended".to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
},
|
||||
OpenAIError::InvalidArgument(message) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "invalid_argument".to_string(),
|
||||
code: None,
|
||||
message,
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_response_error(error: ResponseError) -> Self {
|
||||
let ResponseError { code, message, .. } = error;
|
||||
StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "response_error".to_string(),
|
||||
code: Some(code.as_str().to_string()),
|
||||
message,
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_reqwest_error(error: reqwest::Error) -> Self {
|
||||
StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: ERROR_TYPE.to_string(),
|
||||
error: StreamError {
|
||||
r#type: "reqwest_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
134
crates/meilisearch/src/routes/chats/mod.rs
Normal file
134
crates/meilisearch/src/routes/chats/mod.rs
Normal file
|
@ -0,0 +1,134 @@
|
|||
use actix_web::web::{self, Data};
|
||||
use actix_web::HttpResponse;
|
||||
use deserr::actix_web::AwebQueryParameter;
|
||||
use deserr::Deserr;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_types::deserr::query_params::Param;
|
||||
use meilisearch_types::deserr::DeserrQueryParamError;
|
||||
use meilisearch_types::error::deserr_codes::{InvalidIndexLimit, InvalidIndexOffset};
|
||||
use meilisearch_types::error::{Code, ResponseError};
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::keys::actions;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tracing::debug;
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
use super::Pagination;
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::GuardedData;
|
||||
use crate::routes::PAGINATION_DEFAULT_LIMIT;
|
||||
|
||||
pub mod chat_completions;
|
||||
mod config;
|
||||
mod errors;
|
||||
pub mod settings;
|
||||
mod utils;
|
||||
|
||||
/// The function name to report search progress.
|
||||
/// This function is used to report on what meilisearch is
|
||||
/// doing which must be used on the frontend to report progress.
|
||||
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
|
||||
/// The function name to append a conversation message in the user conversation.
|
||||
/// This function is used to append a conversation message in the user conversation.
|
||||
/// This must be used on the frontend to keep context of what happened on the
|
||||
/// Meilisearch-side and keep good context for follow up questions.
|
||||
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
|
||||
/// The function name to report sources to the frontend.
|
||||
/// This function is used to report sources to the frontend.
|
||||
/// The call id is associated to the one used by the search progress function.
|
||||
const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources";
|
||||
/// The *internal* function name to provide to the LLM to search in indexes.
|
||||
/// This function must not leak to the user as the LLM will call it and the
|
||||
/// main goal of Meilisearch is to provide an answer to these calls.
|
||||
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ChatsParam {
|
||||
workspace_uid: String,
|
||||
}
|
||||
|
||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(web::resource("").route(web::get().to(list_workspaces))).service(
|
||||
web::scope("/{workspace_uid}")
|
||||
.service(
|
||||
web::resource("")
|
||||
.route(web::get().to(get_chat))
|
||||
.route(web::delete().to(delete_chat)),
|
||||
)
|
||||
.service(web::scope("/chat/completions").configure(chat_completions::configure))
|
||||
.service(web::scope("/settings").configure(settings::configure)),
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn get_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHATS_GET }>, Data<IndexScheduler>>,
|
||||
workspace_uid: web::Path<String>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
index_scheduler.features().check_chat_completions("displaying a chat")?;
|
||||
|
||||
let workspace_uid = IndexUid::try_from(workspace_uid.into_inner())?;
|
||||
if index_scheduler.chat_workspace_exists(&workspace_uid)? {
|
||||
Ok(HttpResponse::Ok().json(json!({ "uid": workspace_uid })))
|
||||
} else {
|
||||
Err(ResponseError::from_msg(format!("chat {workspace_uid} not found"), Code::ChatNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHATS_DELETE }>, Data<IndexScheduler>>,
|
||||
workspace_uid: web::Path<String>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
index_scheduler.features().check_chat_completions("deleting a chat")?;
|
||||
|
||||
let workspace_uid = workspace_uid.into_inner();
|
||||
if index_scheduler.delete_chat_settings(&workspace_uid)? {
|
||||
Ok(HttpResponse::NoContent().finish())
|
||||
} else {
|
||||
Err(ResponseError::from_msg(format!("chat {workspace_uid} not found"), Code::ChatNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserr, Debug, Clone, Copy, IntoParams)]
|
||||
#[deserr(error = DeserrQueryParamError, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[into_params(rename_all = "camelCase", parameter_in = Query)]
|
||||
pub struct ListChats {
|
||||
/// The number of chat workspaces to skip before starting to retrieve anything
|
||||
#[param(value_type = Option<usize>, default, example = 100)]
|
||||
#[deserr(default, error = DeserrQueryParamError<InvalidIndexOffset>)]
|
||||
pub offset: Param<usize>,
|
||||
/// The number of chat workspaces to retrieve
|
||||
#[param(value_type = Option<usize>, default = 20, example = 1)]
|
||||
#[deserr(default = Param(PAGINATION_DEFAULT_LIMIT), error = DeserrQueryParamError<InvalidIndexLimit>)]
|
||||
pub limit: Param<usize>,
|
||||
}
|
||||
|
||||
impl ListChats {
|
||||
fn as_pagination(self) -> Pagination {
|
||||
Pagination { offset: self.offset.0, limit: self.limit.0 }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ChatWorkspaceView {
|
||||
/// Unique identifier for the index
|
||||
pub uid: String,
|
||||
}
|
||||
|
||||
pub async fn list_workspaces(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHATS_GET }>, Data<IndexScheduler>>,
|
||||
paginate: AwebQueryParameter<ListChats, DeserrQueryParamError>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
index_scheduler.features().check_chat_completions("listing the chats")?;
|
||||
|
||||
debug!(parameters = ?paginate, "List chat workspaces");
|
||||
let (total, workspaces) =
|
||||
index_scheduler.paginated_chat_workspace_uids(*paginate.offset, *paginate.limit)?;
|
||||
let workspaces =
|
||||
workspaces.into_iter().map(|uid| ChatWorkspaceView { uid }).collect::<Vec<_>>();
|
||||
let ret = paginate.as_pagination().format_with(total, workspaces);
|
||||
|
||||
debug!(returns = ?ret, "List chat workspaces");
|
||||
Ok(HttpResponse::Ok().json(ret))
|
||||
}
|
260
crates/meilisearch/src/routes/chats/settings.rs
Normal file
260
crates/meilisearch/src/routes/chats/settings.rs
Normal file
|
@ -0,0 +1,260 @@
|
|||
use actix_web::web::{self, Data};
|
||||
use actix_web::HttpResponse;
|
||||
use deserr::Deserr;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_types::deserr::DeserrJsonError;
|
||||
use meilisearch_types::error::deserr_codes::*;
|
||||
use meilisearch_types::error::{Code, ResponseError};
|
||||
use meilisearch_types::features::{
|
||||
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings,
|
||||
ChatCompletionSource as DbChatCompletionSource, DEFAULT_CHAT_SEARCH_DESCRIPTION_PROMPT,
|
||||
DEFAULT_CHAT_SEARCH_INDEX_UID_PARAM_PROMPT, DEFAULT_CHAT_SEARCH_Q_PARAM_PROMPT,
|
||||
DEFAULT_CHAT_SYSTEM_PROMPT,
|
||||
};
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::update::Setting;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use super::ChatsParam;
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::GuardedData;
|
||||
use crate::extractors::sequential_extractor::SeqHandler;
|
||||
|
||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(
|
||||
web::resource("")
|
||||
.route(web::get().to(SeqHandler(get_settings)))
|
||||
.route(web::patch().to(SeqHandler(patch_settings)))
|
||||
.route(web::delete().to(SeqHandler(reset_settings))),
|
||||
);
|
||||
}
|
||||
|
||||
async fn get_settings(
|
||||
index_scheduler: GuardedData<
|
||||
ActionPolicy<{ actions::CHATS_SETTINGS_GET }>,
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
chats_param: web::Path<ChatsParam>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
index_scheduler.features().check_chat_completions("using the /chats/settings route")?;
|
||||
|
||||
let ChatsParam { workspace_uid } = chats_param.into_inner();
|
||||
|
||||
let mut settings = match index_scheduler.chat_settings(&workspace_uid)? {
|
||||
Some(settings) => settings,
|
||||
None => {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("Chat `{workspace_uid}` not found"),
|
||||
Code::ChatNotFound,
|
||||
))
|
||||
}
|
||||
};
|
||||
settings.hide_secrets();
|
||||
Ok(HttpResponse::Ok().json(settings))
|
||||
}
|
||||
|
||||
async fn patch_settings(
|
||||
index_scheduler: GuardedData<
|
||||
ActionPolicy<{ actions::CHATS_SETTINGS_UPDATE }>,
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
chats_param: web::Path<ChatsParam>,
|
||||
web::Json(new): web::Json<ChatWorkspaceSettings>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
index_scheduler.features().check_chat_completions("using the /chats/settings route")?;
|
||||
let ChatsParam { workspace_uid } = chats_param.into_inner();
|
||||
|
||||
let old_settings = index_scheduler.chat_settings(&workspace_uid)?.unwrap_or_default();
|
||||
|
||||
let prompts = match new.prompts {
|
||||
Setting::Set(new_prompts) => DbChatCompletionPrompts {
|
||||
system: match new_prompts.system {
|
||||
Setting::Set(new_system) => new_system,
|
||||
Setting::Reset => DEFAULT_CHAT_SYSTEM_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.system,
|
||||
},
|
||||
search_description: match new_prompts.search_description {
|
||||
Setting::Set(new_description) => new_description,
|
||||
Setting::Reset => DEFAULT_CHAT_SEARCH_DESCRIPTION_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.search_description,
|
||||
},
|
||||
search_q_param: match new_prompts.search_q_param {
|
||||
Setting::Set(new_description) => new_description,
|
||||
Setting::Reset => DEFAULT_CHAT_SEARCH_Q_PARAM_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.search_q_param,
|
||||
},
|
||||
search_index_uid_param: match new_prompts.search_index_uid_param {
|
||||
Setting::Set(new_description) => new_description,
|
||||
Setting::Reset => DEFAULT_CHAT_SEARCH_INDEX_UID_PARAM_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.search_index_uid_param,
|
||||
},
|
||||
},
|
||||
Setting::Reset => DbChatCompletionPrompts::default(),
|
||||
Setting::NotSet => old_settings.prompts,
|
||||
};
|
||||
|
||||
let mut settings = ChatCompletionSettings {
|
||||
source: match new.source {
|
||||
Setting::Set(new_source) => new_source.into(),
|
||||
Setting::Reset => DbChatCompletionSource::default(),
|
||||
Setting::NotSet => old_settings.source,
|
||||
},
|
||||
org_id: match new.org_id {
|
||||
Setting::Set(new_org_id) => Some(new_org_id),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.org_id,
|
||||
},
|
||||
project_id: match new.project_id {
|
||||
Setting::Set(new_project_id) => Some(new_project_id),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.project_id,
|
||||
},
|
||||
api_version: match new.api_version {
|
||||
Setting::Set(new_api_version) => Some(new_api_version),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.api_version,
|
||||
},
|
||||
deployment_id: match new.deployment_id {
|
||||
Setting::Set(new_deployment_id) => Some(new_deployment_id),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.deployment_id,
|
||||
},
|
||||
base_url: match new.base_url {
|
||||
Setting::Set(new_base_url) => Some(new_base_url),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.base_url,
|
||||
},
|
||||
api_key: match new.api_key {
|
||||
Setting::Set(new_api_key) => Some(new_api_key),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.api_key,
|
||||
},
|
||||
prompts,
|
||||
};
|
||||
|
||||
// TODO send analytics
|
||||
// analytics.publish(
|
||||
// PatchNetworkAnalytics {
|
||||
// network_size: merged_remotes.len(),
|
||||
// network_has_self: merged_self.is_some(),
|
||||
// },
|
||||
// &req,
|
||||
// );
|
||||
|
||||
settings.validate()?;
|
||||
index_scheduler.put_chat_settings(&workspace_uid, &settings)?;
|
||||
|
||||
settings.hide_secrets();
|
||||
|
||||
Ok(HttpResponse::Ok().json(settings))
|
||||
}
|
||||
|
||||
async fn reset_settings(
|
||||
index_scheduler: GuardedData<
|
||||
ActionPolicy<{ actions::CHATS_SETTINGS_UPDATE }>,
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
chats_param: web::Path<ChatsParam>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
index_scheduler.features().check_chat_completions("using the /chats/settings route")?;
|
||||
|
||||
let ChatsParam { workspace_uid } = chats_param.into_inner();
|
||||
if index_scheduler.chat_settings(&workspace_uid)?.is_some() {
|
||||
let settings = Default::default();
|
||||
index_scheduler.put_chat_settings(&workspace_uid, &settings)?;
|
||||
Ok(HttpResponse::Ok().json(settings))
|
||||
} else {
|
||||
Err(ResponseError::from_msg(
|
||||
format!("Chat `{workspace_uid}` not found"),
|
||||
Code::ChatNotFound,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Deserr, ToSchema)]
|
||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[schema(rename_all = "camelCase")]
|
||||
pub struct ChatWorkspaceSettings {
|
||||
#[serde(default)]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<ChatCompletionSource>)]
|
||||
pub source: Setting<ChatCompletionSource>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionOrgId>)]
|
||||
#[schema(value_type = Option<String>, example = json!("dcba4321..."))]
|
||||
pub org_id: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionProjectId>)]
|
||||
#[schema(value_type = Option<String>, example = json!("4321dcba..."))]
|
||||
pub project_id: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionApiVersion>)]
|
||||
#[schema(value_type = Option<String>, example = json!("2024-02-01"))]
|
||||
pub api_version: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionDeploymentId>)]
|
||||
#[schema(value_type = Option<String>, example = json!("1234abcd..."))]
|
||||
pub deployment_id: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionBaseApi>)]
|
||||
#[schema(value_type = Option<String>, example = json!("https://api.mistral.ai/v1"))]
|
||||
pub base_url: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionApiKey>)]
|
||||
#[schema(value_type = Option<String>, example = json!("abcd1234..."))]
|
||||
pub api_key: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default)]
|
||||
#[schema(inline, value_type = Option<ChatPrompts>)]
|
||||
pub prompts: Setting<ChatPrompts>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, Deserr, ToSchema)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||
pub enum ChatCompletionSource {
|
||||
#[default]
|
||||
OpenAi,
|
||||
Mistral,
|
||||
Gemini,
|
||||
AzureOpenAi,
|
||||
VLlm,
|
||||
}
|
||||
|
||||
impl From<ChatCompletionSource> for DbChatCompletionSource {
|
||||
fn from(source: ChatCompletionSource) -> Self {
|
||||
use ChatCompletionSource::*;
|
||||
match source {
|
||||
OpenAi => DbChatCompletionSource::OpenAi,
|
||||
Mistral => DbChatCompletionSource::Mistral,
|
||||
Gemini => DbChatCompletionSource::Gemini,
|
||||
AzureOpenAi => DbChatCompletionSource::AzureOpenAi,
|
||||
VLlm => DbChatCompletionSource::VLlm,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Deserr, ToSchema)]
|
||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[schema(rename_all = "camelCase")]
|
||||
pub struct ChatPrompts {
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionSystemPrompt>)]
|
||||
#[schema(value_type = Option<String>, example = json!("You are a helpful assistant..."))]
|
||||
pub system: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionSearchDescriptionPrompt>)]
|
||||
#[schema(value_type = Option<String>, example = json!("This is the search function..."))]
|
||||
pub search_description: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionSearchQueryParamPrompt>)]
|
||||
#[schema(value_type = Option<String>, example = json!("This is query parameter..."))]
|
||||
pub search_q_param: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionSearchIndexUidParamPrompt>)]
|
||||
#[schema(value_type = Option<String>, example = json!("This is index you want to search in..."))]
|
||||
pub search_index_uid_param: Setting<String>,
|
||||
}
|
253
crates/meilisearch/src/routes/chats/utils.rs
Normal file
253
crates/meilisearch/src/routes/chats/utils.rs
Normal file
|
@ -0,0 +1,253 @@
|
|||
use std::cell::RefCell;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use actix_web_lab::sse::{self, Event};
|
||||
use async_openai::types::{
|
||||
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage,
|
||||
ChatCompletionStreamResponseDelta, ChatCompletionToolType, CreateChatCompletionStreamResponse,
|
||||
FunctionCall, FunctionCallStream, Role,
|
||||
};
|
||||
use bumpalo::Bump;
|
||||
use meilisearch_types::error::{Code, ResponseError};
|
||||
use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::milli::index::ChatConfig;
|
||||
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
||||
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
||||
use meilisearch_types::milli::{
|
||||
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder,
|
||||
};
|
||||
use meilisearch_types::{Document, Index};
|
||||
use serde::Serialize;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use super::errors::StreamErrorEvent;
|
||||
use super::MEILI_APPEND_CONVERSATION_MESSAGE_NAME;
|
||||
use crate::routes::chats::{MEILI_SEARCH_PROGRESS_NAME, MEILI_SEARCH_SOURCES_NAME};
|
||||
|
||||
pub struct SseEventSender(Sender<Event>);
|
||||
|
||||
impl SseEventSender {
|
||||
pub fn new(sender: Sender<Event>) -> Self {
|
||||
Self(sender)
|
||||
}
|
||||
|
||||
/// Ask the front-end user to append this tool *call* to the conversation
|
||||
pub async fn append_tool_call_conversation_message(
|
||||
&self,
|
||||
resp: CreateChatCompletionStreamResponse,
|
||||
call_id: String,
|
||||
function_name: String,
|
||||
function_arguments: String,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[allow(deprecated)] // function_call
|
||||
let message =
|
||||
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
||||
content: None,
|
||||
refusal: None,
|
||||
name: None,
|
||||
audio: None,
|
||||
tool_calls: Some(vec![ChatCompletionMessageToolCall {
|
||||
id: call_id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name: function_name, arguments: function_arguments },
|
||||
}]),
|
||||
function_call: None,
|
||||
});
|
||||
|
||||
self.append_conversation_message(resp, &message).await
|
||||
}
|
||||
|
||||
/// Ask the front-end user to append this tool to the conversation
|
||||
pub async fn append_conversation_message(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
message: &ChatCompletionRequestMessage,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
let call_text = serde_json::to_string(message).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)] // function_call
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_search_progress(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
call_id: &str,
|
||||
function_name: &str,
|
||||
function_arguments: &str,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides information about the current Meilisearch search operation.
|
||||
struct MeiliSearchProgress<'a> {
|
||||
/// The call ID to track the sources of the search.
|
||||
call_id: &'a str,
|
||||
/// The name of the function we are executing.
|
||||
function_name: &'a str,
|
||||
/// The arguments of the function we are executing, encoded in JSON.
|
||||
function_arguments: &'a str,
|
||||
}
|
||||
|
||||
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
|
||||
let call_text = serde_json::to_string(&progress).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)] // function_call
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_sources(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
call_id: &str,
|
||||
documents: &[Document],
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides sources of the search.
|
||||
struct MeiliSearchSources<'a> {
|
||||
/// The call ID to track the original search associated to those sources.
|
||||
call_id: &'a str,
|
||||
/// The documents associated with the search (call_id).
|
||||
/// Only the displayed attributes of the documents are returned.
|
||||
sources: &'a [Document],
|
||||
}
|
||||
|
||||
let sources = MeiliSearchSources { call_id, sources: documents };
|
||||
let call_text = serde_json::to_string(&sources).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)] // function_call
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn forward_response(
|
||||
&self,
|
||||
resp: &CreateChatCompletionStreamResponse,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
self.send_json(resp).await
|
||||
}
|
||||
|
||||
pub async fn send_error(&self, error: &StreamErrorEvent) -> Result<(), SendError<Event>> {
|
||||
self.send_json(error).await
|
||||
}
|
||||
|
||||
pub async fn stop(self) -> Result<(), SendError<Event>> {
|
||||
// It is the way OpenAI sends a correct end of stream
|
||||
// <https://platform.openai.com/docs/api-reference/assistants-streaming/events>
|
||||
const DONE_DATA: &str = "[DONE]";
|
||||
self.0.send(Event::Data(sse::Data::new(DONE_DATA))).await
|
||||
}
|
||||
|
||||
async fn send_json<S: Serialize>(&self, data: &S) -> Result<(), SendError<Event>> {
|
||||
self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Format documents based on the provided template and maximum bytes.
|
||||
///
|
||||
/// This formatting function is usually used to generate a summary of the documents for LLMs.
|
||||
pub fn format_documents<'doc>(
|
||||
rtxn: &RoTxn<'_>,
|
||||
index: &Index,
|
||||
doc_alloc: &'doc Bump,
|
||||
internal_docids: Vec<DocumentId>,
|
||||
) -> Result<Vec<&'doc str>, ResponseError> {
|
||||
let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?;
|
||||
|
||||
let prompt = Prompt::new(template, max_bytes).unwrap();
|
||||
let fid_map = index.fields_ids_map(rtxn)?;
|
||||
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
|
||||
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
|
||||
let global = RwLock::new(fid_map_with_meta);
|
||||
let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global));
|
||||
|
||||
let external_ids: Vec<String> = index
|
||||
.external_id_of(rtxn, internal_docids.iter().copied())?
|
||||
.into_iter()
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let mut renders = Vec::new();
|
||||
for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) {
|
||||
let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? {
|
||||
Some(doc) => doc,
|
||||
None => unreachable!("Document with internal ID {docid} not found"),
|
||||
};
|
||||
let text = match prompt.render_document(&external_docid, document, &gfid_map, doc_alloc) {
|
||||
Ok(text) => text,
|
||||
Err(err) => {
|
||||
return Err(ResponseError::from_msg(
|
||||
err.to_string(),
|
||||
Code::InvalidChatSettingDocumentTemplate,
|
||||
))
|
||||
}
|
||||
};
|
||||
renders.push(text);
|
||||
}
|
||||
|
||||
Ok(renders)
|
||||
}
|
|
@ -53,6 +53,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
|||
network: Some(false),
|
||||
get_task_documents_route: Some(false),
|
||||
composite_embedders: Some(false),
|
||||
chat_completions: Some(false),
|
||||
})),
|
||||
(status = 401, description = "The authorization header is missing", body = ResponseError, content_type = "application/json", example = json!(
|
||||
{
|
||||
|
@ -97,6 +98,8 @@ pub struct RuntimeTogglableFeatures {
|
|||
pub get_task_documents_route: Option<bool>,
|
||||
#[deserr(default)]
|
||||
pub composite_embedders: Option<bool>,
|
||||
#[deserr(default)]
|
||||
pub chat_completions: Option<bool>,
|
||||
}
|
||||
|
||||
impl From<meilisearch_types::features::RuntimeTogglableFeatures> for RuntimeTogglableFeatures {
|
||||
|
@ -109,6 +112,7 @@ impl From<meilisearch_types::features::RuntimeTogglableFeatures> for RuntimeTogg
|
|||
network,
|
||||
get_task_documents_route,
|
||||
composite_embedders,
|
||||
chat_completions,
|
||||
} = value;
|
||||
|
||||
Self {
|
||||
|
@ -119,6 +123,7 @@ impl From<meilisearch_types::features::RuntimeTogglableFeatures> for RuntimeTogg
|
|||
network: Some(network),
|
||||
get_task_documents_route: Some(get_task_documents_route),
|
||||
composite_embedders: Some(composite_embedders),
|
||||
chat_completions: Some(chat_completions),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -132,6 +137,7 @@ pub struct PatchExperimentalFeatureAnalytics {
|
|||
network: bool,
|
||||
get_task_documents_route: bool,
|
||||
composite_embedders: bool,
|
||||
chat_completions: bool,
|
||||
}
|
||||
|
||||
impl Aggregate for PatchExperimentalFeatureAnalytics {
|
||||
|
@ -148,6 +154,7 @@ impl Aggregate for PatchExperimentalFeatureAnalytics {
|
|||
network: new.network,
|
||||
get_task_documents_route: new.get_task_documents_route,
|
||||
composite_embedders: new.composite_embedders,
|
||||
chat_completions: new.chat_completions,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -173,6 +180,7 @@ impl Aggregate for PatchExperimentalFeatureAnalytics {
|
|||
network: Some(false),
|
||||
get_task_documents_route: Some(false),
|
||||
composite_embedders: Some(false),
|
||||
chat_completions: Some(false),
|
||||
})),
|
||||
(status = 401, description = "The authorization header is missing", body = ResponseError, content_type = "application/json", example = json!(
|
||||
{
|
||||
|
@ -214,6 +222,7 @@ async fn patch_features(
|
|||
.0
|
||||
.composite_embedders
|
||||
.unwrap_or(old_features.composite_embedders),
|
||||
chat_completions: new_features.0.chat_completions.unwrap_or(old_features.chat_completions),
|
||||
};
|
||||
|
||||
// explicitly destructure for analytics rather than using the `Serialize` implementation, because
|
||||
|
@ -227,6 +236,7 @@ async fn patch_features(
|
|||
network,
|
||||
get_task_documents_route,
|
||||
composite_embedders,
|
||||
chat_completions,
|
||||
} = new_features;
|
||||
|
||||
analytics.publish(
|
||||
|
@ -238,6 +248,7 @@ async fn patch_features(
|
|||
network,
|
||||
get_task_documents_route,
|
||||
composite_embedders,
|
||||
chat_completions,
|
||||
},
|
||||
&req,
|
||||
);
|
||||
|
|
|
@ -172,7 +172,7 @@ pub async fn list_indexes(
|
|||
debug!(parameters = ?paginate, "List indexes");
|
||||
let filters = index_scheduler.filters();
|
||||
let (total, indexes) =
|
||||
index_scheduler.get_paginated_indexes_stats(filters, *paginate.offset, *paginate.limit)?;
|
||||
index_scheduler.paginated_indexes_stats(filters, *paginate.offset, *paginate.limit)?;
|
||||
let indexes = indexes
|
||||
.into_iter()
|
||||
.map(|(name, stats)| IndexView {
|
||||
|
|
|
@ -5,8 +5,9 @@ use index_scheduler::IndexScheduler;
|
|||
use meilisearch_types::deserr::DeserrJsonError;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::update::Setting;
|
||||
use meilisearch_types::settings::{
|
||||
settings, SecretPolicy, SettingEmbeddingSettings, Settings, Unchecked,
|
||||
settings, ChatSettings, SecretPolicy, SettingEmbeddingSettings, Settings, Unchecked,
|
||||
};
|
||||
use meilisearch_types::tasks::KindWithContent;
|
||||
use tracing::debug;
|
||||
|
@ -508,6 +509,17 @@ make_setting_routes!(
|
|||
camelcase_attr: "prefixSearch",
|
||||
analytics: PrefixSearchAnalytics
|
||||
},
|
||||
{
|
||||
route: "/chat",
|
||||
update_verb: put,
|
||||
value_type: ChatSettings,
|
||||
err_type: meilisearch_types::deserr::DeserrJsonError<
|
||||
meilisearch_types::error::deserr_codes::InvalidSettingsIndexChat,
|
||||
>,
|
||||
attr: chat,
|
||||
camelcase_attr: "chat",
|
||||
analytics: ChatAnalytics
|
||||
},
|
||||
);
|
||||
|
||||
#[utoipa::path(
|
||||
|
@ -597,6 +609,7 @@ pub async fn update_all(
|
|||
),
|
||||
facet_search: FacetSearchAnalytics::new(new_settings.facet_search.as_ref().set()),
|
||||
prefix_search: PrefixSearchAnalytics::new(new_settings.prefix_search.as_ref().set()),
|
||||
chat: ChatAnalytics::new(new_settings.chat.as_ref().set()),
|
||||
},
|
||||
&req,
|
||||
);
|
||||
|
@ -651,7 +664,11 @@ pub async fn get_all(
|
|||
|
||||
let index = index_scheduler.index(&index_uid)?;
|
||||
let rtxn = index.read_txn()?;
|
||||
let new_settings = settings(&index, &rtxn, SecretPolicy::HideSecrets)?;
|
||||
let mut new_settings = settings(&index, &rtxn, SecretPolicy::HideSecrets)?;
|
||||
if index_scheduler.features().check_chat_completions("showing index `chat` settings").is_err() {
|
||||
new_settings.chat = Setting::NotSet;
|
||||
}
|
||||
|
||||
debug!(returns = ?new_settings, "Get all settings");
|
||||
Ok(HttpResponse::Ok().json(new_settings))
|
||||
}
|
||||
|
@ -741,5 +758,9 @@ fn validate_settings(
|
|||
}
|
||||
}
|
||||
|
||||
if let Setting::Set(_chat) = &settings.chat {
|
||||
features.check_chat_completions("setting `chat` in the index settings")?;
|
||||
}
|
||||
|
||||
Ok(settings.validate()?)
|
||||
}
|
||||
|
|
|
@ -10,8 +10,8 @@ use meilisearch_types::locales::{Locale, LocalizedAttributesRuleView};
|
|||
use meilisearch_types::milli::update::Setting;
|
||||
use meilisearch_types::milli::FilterableAttributesRule;
|
||||
use meilisearch_types::settings::{
|
||||
FacetingSettings, PaginationSettings, PrefixSearchSettings, ProximityPrecisionView,
|
||||
RankingRuleView, SettingEmbeddingSettings, TypoSettings,
|
||||
ChatSettings, FacetingSettings, PaginationSettings, PrefixSearchSettings,
|
||||
ProximityPrecisionView, RankingRuleView, SettingEmbeddingSettings, TypoSettings,
|
||||
};
|
||||
use serde::Serialize;
|
||||
|
||||
|
@ -39,6 +39,7 @@ pub struct SettingsAnalytics {
|
|||
pub non_separator_tokens: NonSeparatorTokensAnalytics,
|
||||
pub facet_search: FacetSearchAnalytics,
|
||||
pub prefix_search: PrefixSearchAnalytics,
|
||||
pub chat: ChatAnalytics,
|
||||
}
|
||||
|
||||
impl Aggregate for SettingsAnalytics {
|
||||
|
@ -198,6 +199,7 @@ impl Aggregate for SettingsAnalytics {
|
|||
set: new.prefix_search.set | self.prefix_search.set,
|
||||
value: new.prefix_search.value.or(self.prefix_search.value),
|
||||
},
|
||||
chat: ChatAnalytics { set: new.chat.set | self.chat.set },
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -454,7 +456,9 @@ pub struct PaginationAnalytics {
|
|||
|
||||
impl PaginationAnalytics {
|
||||
pub fn new(setting: Option<&PaginationSettings>) -> Self {
|
||||
Self { max_total_hits: setting.as_ref().and_then(|s| s.max_total_hits.set()) }
|
||||
Self {
|
||||
max_total_hits: setting.as_ref().and_then(|s| s.max_total_hits.set().map(|x| x.into())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_settings(self) -> SettingsAnalytics {
|
||||
|
@ -674,3 +678,18 @@ impl PrefixSearchAnalytics {
|
|||
SettingsAnalytics { prefix_search: self, ..Default::default() }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Default)]
|
||||
pub struct ChatAnalytics {
|
||||
pub set: bool,
|
||||
}
|
||||
|
||||
impl ChatAnalytics {
|
||||
pub fn new(settings: Option<&ChatSettings>) -> Self {
|
||||
Self { set: settings.is_some() }
|
||||
}
|
||||
|
||||
pub fn into_settings(self) -> SettingsAnalytics {
|
||||
SettingsAnalytics { chat: self, ..Default::default() }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,6 +52,7 @@ const PAGINATION_DEFAULT_LIMIT_FN: fn() -> usize = || 20;
|
|||
|
||||
mod api_key;
|
||||
pub mod batches;
|
||||
pub mod chats;
|
||||
mod dump;
|
||||
pub mod features;
|
||||
pub mod indexes;
|
||||
|
@ -113,7 +114,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
|||
.service(web::scope("/swap-indexes").configure(swap_indexes::configure))
|
||||
.service(web::scope("/metrics").configure(metrics::configure))
|
||||
.service(web::scope("/experimental-features").configure(features::configure))
|
||||
.service(web::scope("/network").configure(network::configure));
|
||||
.service(web::scope("/network").configure(network::configure))
|
||||
.service(web::scope("/chats").configure(chats::configure));
|
||||
|
||||
#[cfg(feature = "swagger")]
|
||||
{
|
||||
|
|
|
@ -17,6 +17,7 @@ use meilisearch_types::error::{Code, ResponseError};
|
|||
use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::locales::Locale;
|
||||
use meilisearch_types::milli::index::{self, SearchParameters};
|
||||
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors;
|
||||
use meilisearch_types::milli::vector::Embedder;
|
||||
|
@ -122,9 +123,58 @@ pub struct SearchQuery {
|
|||
pub locales: Option<Vec<Locale>>,
|
||||
}
|
||||
|
||||
impl From<SearchParameters> for SearchQuery {
|
||||
fn from(parameters: SearchParameters) -> Self {
|
||||
let SearchParameters {
|
||||
hybrid,
|
||||
limit,
|
||||
sort,
|
||||
distinct,
|
||||
matching_strategy,
|
||||
attributes_to_search_on,
|
||||
ranking_score_threshold,
|
||||
} = parameters;
|
||||
|
||||
SearchQuery {
|
||||
hybrid: hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| HybridQuery {
|
||||
semantic_ratio: SemanticRatio::try_from(semantic_ratio)
|
||||
.ok()
|
||||
.unwrap_or_else(DEFAULT_SEMANTIC_RATIO),
|
||||
embedder,
|
||||
}),
|
||||
limit: limit.unwrap_or_else(DEFAULT_SEARCH_LIMIT),
|
||||
sort,
|
||||
distinct,
|
||||
matching_strategy: matching_strategy.map(MatchingStrategy::from).unwrap_or_default(),
|
||||
attributes_to_search_on,
|
||||
ranking_score_threshold: ranking_score_threshold.map(RankingScoreThreshold::from),
|
||||
q: None,
|
||||
vector: None,
|
||||
offset: DEFAULT_SEARCH_OFFSET(),
|
||||
page: None,
|
||||
hits_per_page: None,
|
||||
attributes_to_retrieve: None,
|
||||
retrieve_vectors: false,
|
||||
attributes_to_crop: None,
|
||||
crop_length: DEFAULT_CROP_LENGTH(),
|
||||
attributes_to_highlight: None,
|
||||
show_matches_position: false,
|
||||
show_ranking_score: false,
|
||||
show_ranking_score_details: false,
|
||||
filter: None,
|
||||
facets: None,
|
||||
highlight_pre_tag: DEFAULT_HIGHLIGHT_PRE_TAG(),
|
||||
highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
|
||||
crop_marker: DEFAULT_CROP_MARKER(),
|
||||
locales: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize)]
|
||||
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
|
||||
pub struct RankingScoreThreshold(f64);
|
||||
|
||||
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
|
||||
type Error = InvalidSearchRankingScoreThreshold;
|
||||
|
||||
|
@ -139,6 +189,14 @@ impl std::convert::TryFrom<f64> for RankingScoreThreshold {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<index::RankingScoreThreshold> for RankingScoreThreshold {
|
||||
fn from(threshold: index::RankingScoreThreshold) -> Self {
|
||||
let threshold = threshold.as_f64();
|
||||
assert!((0.0..=1.0).contains(&threshold));
|
||||
RankingScoreThreshold(threshold)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Deserr)]
|
||||
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSimilarRankingScoreThreshold)]
|
||||
pub struct RankingScoreThresholdSimilar(f64);
|
||||
|
@ -282,8 +340,8 @@ impl fmt::Debug for SearchQuery {
|
|||
#[deserr(error = DeserrJsonError<InvalidSearchHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HybridQuery {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
|
||||
#[schema(value_type = f32, default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>)]
|
||||
#[schema(default, value_type = f32)]
|
||||
#[serde(default)]
|
||||
pub semantic_ratio: SemanticRatio,
|
||||
#[deserr(error = DeserrJsonError<InvalidSearchEmbedder>)]
|
||||
|
@ -720,6 +778,16 @@ impl From<MatchingStrategy> for TermsMatchingStrategy {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<index::MatchingStrategy> for MatchingStrategy {
|
||||
fn from(other: index::MatchingStrategy) -> Self {
|
||||
match other {
|
||||
index::MatchingStrategy::Last => Self::Last,
|
||||
index::MatchingStrategy::All => Self::All,
|
||||
index::MatchingStrategy::Frequency => Self::Frequency,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq, Deserr)]
|
||||
#[deserr(rename_all = camelCase)]
|
||||
pub enum FacetValuesSort {
|
||||
|
@ -883,7 +951,7 @@ pub fn add_search_rules(filter: &mut Option<Value>, rules: IndexSearchRules) {
|
|||
}
|
||||
}
|
||||
|
||||
fn prepare_search<'t>(
|
||||
pub fn prepare_search<'t>(
|
||||
index: &'t Index,
|
||||
rtxn: &'t RoTxn,
|
||||
query: &'t SearchQuery,
|
||||
|
@ -1266,7 +1334,7 @@ struct HitMaker<'a> {
|
|||
vectors_fid: Option<FieldId>,
|
||||
retrieve_vectors: RetrieveVectors,
|
||||
to_retrieve_ids: BTreeSet<FieldId>,
|
||||
embedding_configs: Vec<milli::index::IndexEmbeddingConfig>,
|
||||
embedding_configs: Vec<index::IndexEmbeddingConfig>,
|
||||
matcher_builder: MatcherBuilder<'a>,
|
||||
formatted_options: BTreeMap<FieldId, FormatOptions>,
|
||||
show_ranking_score: bool,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
//! This file implements a queue of searches to process and the ability to control how many searches can be run in parallel.
|
||||
//! We need this because we don't want to process more search requests than we have cores.
|
||||
//! We need this because we don't want to process more search requests than the available CPU cores.
|
||||
//! That slows down everything and consumes RAM for no reason.
|
||||
//! The steps to do a search are to get the `SearchQueue` data structure and try to get a search permit.
|
||||
//! This can fail if the queue is full, and we need to drop your search request to register a new one.
|
||||
|
@ -8,7 +8,7 @@
|
|||
//!
|
||||
//! In order to do a search request you should try to get a search permit.
|
||||
//! Retrieve the `SearchQueue` structure from actix-web (`search_queue: Data<SearchQueue>`)
|
||||
//! and right before processing the search, calls the `SearchQueue::try_get_search_permit` method: `search_queue.try_get_search_permit().await?;`
|
||||
//! and right before processing the search, call the `SearchQueue::try_get_search_permit` method: `search_queue.try_get_search_permit().await?;`
|
||||
//!
|
||||
//! What is going to happen at this point is that you're going to send a oneshot::Sender over an async mpsc channel.
|
||||
//! Then, the queue/scheduler is going to either:
|
||||
|
@ -121,12 +121,12 @@ impl SearchQueue {
|
|||
let mut queue: Vec<oneshot::Sender<Permit>> = Default::default();
|
||||
let mut rng: StdRng = StdRng::from_entropy();
|
||||
let mut searches_running: usize = 0;
|
||||
// By having a capacity of parallelism we ensures that every time a search finish it can release its RAM asap
|
||||
// By having a capacity of parallelism we ensure that every time a search finish it can release its RAM asap
|
||||
let (sender, mut search_finished) = mpsc::channel(parallelism.into());
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// biased select because we wants to free up space before trying to register new tasks
|
||||
// biased select because we want to free up space before trying to register new tasks
|
||||
biased;
|
||||
_ = search_finished.recv() => {
|
||||
searches_running = searches_running.saturating_sub(1);
|
||||
|
@ -148,11 +148,11 @@ impl SearchQueue {
|
|||
|
||||
if searches_running < usize::from(parallelism) && queue.is_empty() {
|
||||
searches_running += 1;
|
||||
// if the search requests die it's not a hard error on our side
|
||||
// if the search requests die, it's not a hard error on our side
|
||||
let _ = search_request.send(Permit { sender: sender.clone() });
|
||||
continue;
|
||||
} else if capacity == 0 {
|
||||
// in the very specific case where we have a capacity of zero
|
||||
// in the very specific case where we have a capacity of zero,
|
||||
// we must refuse the request straight away without going through
|
||||
// the queue stuff.
|
||||
drop(search_request);
|
||||
|
@ -183,7 +183,7 @@ impl SearchQueue {
|
|||
.map_err(|_| MeilisearchHttpError::TooManySearchRequests(self.capacity))?;
|
||||
|
||||
// If we've been for more than one minute to get a search permit, it's better to simply
|
||||
// abort the search request than spending time processing something were the client
|
||||
// abort the search request than spending time processing something where the client
|
||||
// most certainly exited or got a timeout a long time ago.
|
||||
// We may find a better solution in https://github.com/actix/actix-web/issues/3462.
|
||||
if now.elapsed() > self.time_to_abort {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue