Support filtering

This commit is contained in:
Clément Renault 2025-07-01 16:24:02 +02:00
parent 5d2535032d
commit 08e096c985
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
3 changed files with 58 additions and 5 deletions

View file

@ -27,9 +27,10 @@ use meilisearch_types::features::{
ChatCompletionPrompts as DbChatCompletionPrompts,
ChatCompletionSource as DbChatCompletionSource, SystemRole,
};
use meilisearch_types::heed::RoTxn;
use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::ChatConfig;
use meilisearch_types::milli::{all_obkv_to_json, obkv_to_json, TimeBudget};
use meilisearch_types::milli::{all_obkv_to_json, obkv_to_json, OrderBy, TimeBudget};
use meilisearch_types::{Document, Index};
use serde::Deserialize;
use serde_json::json;
@ -160,6 +161,7 @@ fn setup_search_tool(
let mut index_uids = Vec::new();
let mut function_description = prompts.search_description.clone();
let mut filter_description = prompts.search_filter_param.clone();
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
// Make sure to skip unauthorized indexes
if !filters.is_index_authorized(name) {
@ -171,16 +173,22 @@ fn setup_search_tool(
let index_description = chat_config.description;
let _ = writeln!(&mut function_description, "\n\n - {name}: {index_description}\n");
index_uids.push(name.to_string());
let facet_distributions = format_facet_distributions(&index, &rtxn, 10).unwrap(); // TODO do not unwrap
let _ = writeln!(&mut filter_description, "\n## Facet distributions of the {name} index");
let _ = writeln!(&mut filter_description, "{facet_distributions}");
Ok(())
})?;
tracing::debug!("LLM function description: {function_description}");
tracing::debug!("LLM filter description: {filter_description}");
let tool = ChatCompletionToolArgs::default()
.r#type(ChatCompletionToolType::Function)
.function(
FunctionObjectArgs::default()
.name(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
.description(&function_description)
.description(function_description)
.parameters(json!({
"type": "object",
"properties": {
@ -197,7 +205,7 @@ fn setup_search_tool(
},
"filter": {
"type": "string",
"description": prompts.search_filter_param,
"description": filter_description,
}
},
"required": ["index_uid", "q", "filter"],
@ -252,6 +260,9 @@ async fn process_search_request(
filter: filter.map(serde_json::Value::from),
..SearchQuery::from(search_parameters)
};
tracing::debug!("LLM query: {:?}", query);
let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate(
auth_ctrl,
auth_token,
@ -766,3 +777,39 @@ struct SearchInIndexParameters {
/// The filter parameter to use.
filter: Option<String>,
}
fn format_facet_distributions(
index: &Index,
rtxn: &RoTxn,
max_values_per_facet: usize,
) -> meilisearch_types::milli::Result<String> {
let universe = index.documents_ids(&rtxn)?;
let rules = index.filterable_attributes_rules(&rtxn)?;
let fields_ids_map = index.fields_ids_map(&rtxn)?;
let filterable_attributes = fields_ids_map
.names()
.filter(|name| rules.iter().any(|rule| rule.match_str(name).matches()))
.map(|name| (name, OrderBy::Count));
let facets_distribution = index
.facets_distribution(&rtxn)
.max_values_per_facet(max_values_per_facet)
.candidates(universe)
.facets(filterable_attributes)
.execute()?;
let mut output = String::new();
for (facet_name, entries) in facets_distribution {
let _ = write!(&mut output, "{}: ", facet_name);
let total_entries = entries.len();
for (i, (value, count)) in entries.into_iter().enumerate() {
let _ = if total_entries.saturating_sub(1) == i {
write!(&mut output, "{} ({}).", value, count)
} else {
write!(&mut output, "{} ({}), ", value, count)
};
}
let _ = writeln!(&mut output);
}
Ok(output)
}