Correctly support document templates on the chat API

This commit is contained in:
Clément Renault 2025-05-21 15:32:34 +02:00
parent c6930c8819
commit 75c3f33478
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
8 changed files with 72 additions and 52 deletions

View file

@ -1,5 +1,7 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::mem;
use std::sync::RwLock;
use std::time::Duration;
use actix_web::web::{self, Data};
@ -16,27 +18,33 @@ use async_openai::types::{
FunctionObjectArgs,
};
use async_openai::Client;
use bumpalo::Bump;
use futures::StreamExt;
use index_scheduler::IndexScheduler;
use meilisearch_auth::AuthController;
use meilisearch_types::error::ResponseError;
use meilisearch_types::heed::RoTxn;
use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::IndexEmbeddingConfig;
use meilisearch_types::milli::prompt::PromptData;
use meilisearch_types::milli::vector::EmbeddingConfig;
use meilisearch_types::{Document, Index};
use serde::{Deserialize, Serialize};
use meilisearch_types::milli::index::ChatConfig;
use meilisearch_types::milli::prompt::{Prompt, PromptData};
use meilisearch_types::milli::update::new::document::DocumentFromDb;
use meilisearch_types::milli::{
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
};
use meilisearch_types::Index;
use serde::Deserialize;
use serde_json::json;
use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError;
use super::settings::chat::{ChatPrompts, GlobalChatSettings};
use crate::error::MeilisearchHttpError;
use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::routes::indexes::search::search_kind;
use crate::search::{
add_search_rules, perform_search, HybridQuery, RetrieveVectors, SearchQuery, SemanticRatio,
add_search_rules, prepare_search, search_from_kind, HybridQuery, SearchQuery, SemanticRatio,
};
use crate::search_queue::SearchQueue;
@ -175,15 +183,22 @@ async fn process_search_request(
let permit = search_queue.try_get_search_permit().await?;
let features = index_scheduler.features();
let index_cloned = index.clone();
let search_result = tokio::task::spawn_blocking(move || {
perform_search(
index_uid.to_string(),
&index_cloned,
query,
search_kind,
RetrieveVectors::new(false),
features,
)
let search_result = tokio::task::spawn_blocking(move || -> Result<_, ResponseError> {
let rtxn = index_cloned.read_txn()?;
let time_budget = match index_cloned
.search_cutoff(&rtxn)
.map_err(|e| MeilisearchHttpError::from_milli(e, Some(index_uid.clone())))?
{
Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)),
None => TimeBudget::default(),
};
let (search, _is_finite_pagination, _max_total_hits, _offset) =
prepare_search(&index_cloned, &rtxn, &query, &search_kind, time_budget, features)?;
search_from_kind(index_uid, search_kind, search)
.map(|(search_results, _)| search_results)
.map_err(ResponseError::from)
})
.await;
permit.drop().await;
@ -198,9 +213,11 @@ async fn process_search_request(
// analytics.publish(aggregate, &req);
let search_result = search_result?;
let formatted =
format_documents(&index, search_result.hits.into_iter().map(|doc| doc.document));
let rtxn = index.read_txn()?;
let render_alloc = Bump::new();
let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?;
let text = formatted.join("\n");
drop(rtxn);
Ok((index, text))
}
@ -506,31 +523,36 @@ struct SearchInIndexParameters {
q: Option<String>,
}
fn format_documents(index: &Index, documents: impl Iterator<Item = Document>) -> Vec<String> {
let rtxn = index.read_txn().unwrap();
let IndexEmbeddingConfig { name: _, config, user_provided: _ } = index
.embedding_configs(&rtxn)
.unwrap()
fn format_documents<'t, 'doc>(
rtxn: &RoTxn<'t>,
index: &Index,
doc_alloc: &'doc Bump,
internal_docids: Vec<DocumentId>,
) -> Result<Vec<&'doc str>, ResponseError> {
let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?;
let prompt = Prompt::new(template, max_bytes).unwrap();
let fid_map = index.fields_ids_map(rtxn)?;
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
let global = RwLock::new(fid_map_with_meta);
let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global));
let external_ids: Vec<String> = index
.external_id_of(rtxn, internal_docids.iter().copied())?
.into_iter()
.find(|conf| conf.name == EMBEDDER_NAME)
.unwrap();
.collect::<Result<_, _>>()?;
let EmbeddingConfig {
embedder_options: _,
prompt: PromptData { template, max_bytes: _ },
quantized: _,
} = config;
let mut renders = Vec::new();
for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) {
let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? {
Some(doc) => doc,
None => continue,
};
#[derive(Serialize)]
struct Doc<T: Serialize> {
doc: T,
let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap();
renders.push(text);
}
let template = liquid::ParserBuilder::with_stdlib().build().unwrap().parse(&template).unwrap();
documents
.map(|doc| {
let object = liquid::to_object(&Doc { doc }).unwrap();
template.render(&object).unwrap()
})
.collect()
Ok(renders)
}

View file

@ -1,5 +1,3 @@
use std::collections::BTreeMap;
use actix_web::web::{self, Data};
use actix_web::HttpResponse;
use index_scheduler::IndexScheduler;
@ -51,7 +49,6 @@ pub struct GlobalChatSettings {
pub base_api: Option<String>,
pub api_key: Option<String>,
pub prompts: ChatPrompts,
pub indexes: BTreeMap<String, ChatIndexSettings>,
}
#[derive(Debug, Serialize, Deserialize)]
@ -105,7 +102,6 @@ impl Default for GlobalChatSettings {
.to_string(),
pre_query: "".to_string(),
},
indexes: BTreeMap::new(),
}
}
}

View file

@ -882,7 +882,7 @@ pub fn add_search_rules(filter: &mut Option<Value>, rules: IndexSearchRules) {
}
}
fn prepare_search<'t>(
pub fn prepare_search<'t>(
index: &'t Index,
rtxn: &'t RoTxn,
query: &'t SearchQuery,