Take sort criteria from the request

This commit is contained in:
Mubelotix 2025-06-25 16:41:08 +02:00
parent 6e0526090a
commit b05cb80803
No known key found for this signature in database
GPG key ID: 0406DF6C3A69B942
3 changed files with 79 additions and 29 deletions

View file

@ -1,8 +1,8 @@
use roaring::RoaringBitmap;
use heed::Database;
use crate::{facet::{ascending_facet_sort, descending_facet_sort}, heed_codec::{facet::{FacetGroupKeyCodec, FacetGroupValueCodec}, BytesRefCodec}};
use crate::{heed_codec::{facet::{FacetGroupKeyCodec, FacetGroupValueCodec}, BytesRefCodec}, search::{facet::{ascending_facet_sort, descending_facet_sort}, new::check_sort_criteria}, AscDesc, Member};
pub fn recursive_facet_sort<'t>(
fn recursive_facet_sort_inner<'t>(
rtxn: &'t heed::RoTxn<'t>,
number_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
string_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
@ -53,7 +53,7 @@ pub fn recursive_facet_sort<'t>(
if inner_candidates.len() <= 1 || fields.len() <= 1 {
result |= inner_candidates;
} else {
let inner_candidates = recursive_facet_sort(
let inner_candidates = recursive_facet_sort_inner(
rtxn,
number_db,
string_db,
@ -66,3 +66,36 @@ pub fn recursive_facet_sort<'t>(
Ok(result)
}
pub fn recursive_facet_sort<'t>(
index: &crate::Index,
rtxn: &'t heed::RoTxn<'t>,
sort: &[AscDesc],
candidates: RoaringBitmap,
) -> crate::Result<RoaringBitmap> {
check_sort_criteria(index, rtxn, Some(sort))?;
let mut fields = Vec::new();
let fields_ids_map = index.fields_ids_map(rtxn)?;
for sort in sort {
let (field_id, ascending) = match sort {
AscDesc::Asc(Member::Field(field)) => (fields_ids_map.id(field), true),
AscDesc::Desc(Member::Field(field)) => (fields_ids_map.id(field), false),
AscDesc::Asc(Member::Geo(_)) => todo!(),
AscDesc::Desc(Member::Geo(_)) => todo!(),
};
if let Some(field_id) = field_id {
fields.push((field_id, ascending)); // FIXME: Should this return an error if the field is not found?
}
}
let number_db = index
.facet_id_f64_docids
.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
let string_db = index
.facet_id_string_docids
.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
let candidates = recursive_facet_sort_inner(rtxn, number_db, string_db, &fields, candidates)?;
Ok(candidates)
}

View file

@ -638,7 +638,7 @@ pub fn execute_vector_search(
time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
check_sort_criteria(ctx.index, ctx.txn, sort_criteria.as_deref())?;
// FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
@ -702,7 +702,7 @@ pub fn execute_search(
ranking_score_threshold: Option<f64>,
locales: Option<&Vec<Language>>,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
check_sort_criteria(ctx.index, ctx.txn, sort_criteria.as_deref())?;
let mut used_negative_operator = false;
let mut located_query_terms = None;
@ -872,9 +872,10 @@ pub fn execute_search(
})
}
fn check_sort_criteria(
ctx: &SearchContext<'_>,
sort_criteria: Option<&Vec<AscDesc>>,
pub(crate) fn check_sort_criteria(
index: &Index,
rtxn: &RoTxn<'_>,
sort_criteria: Option<&[AscDesc]>,
) -> Result<()> {
let sort_criteria = if let Some(sort_criteria) = sort_criteria {
sort_criteria
@ -888,19 +889,19 @@ fn check_sort_criteria(
// We check that the sort ranking rule exists and throw an
// error if we try to use it and that it doesn't.
let sort_ranking_rule_missing = !ctx.index.criteria(ctx.txn)?.contains(&crate::Criterion::Sort);
let sort_ranking_rule_missing = !index.criteria(rtxn)?.contains(&crate::Criterion::Sort);
if sort_ranking_rule_missing {
return Err(UserError::SortRankingRuleMissing.into());
}
// We check that we are allowed to use the sort criteria, we check
// that they are declared in the sortable fields.
let sortable_fields = ctx.index.sortable_fields(ctx.txn)?;
let sortable_fields = index.sortable_fields(rtxn)?;
for asc_desc in sort_criteria {
match asc_desc.member() {
Member::Field(ref field) if !crate::is_faceted(field, &sortable_fields) => {
let (valid_fields, hidden_fields) =
ctx.index.remove_hidden_fields(ctx.txn, sortable_fields)?;
index.remove_hidden_fields(rtxn, sortable_fields)?;
return Err(UserError::InvalidSortableAttribute {
field: field.to_string(),
@ -911,7 +912,7 @@ fn check_sort_criteria(
}
Member::Geo(_) if !sortable_fields.contains(RESERVED_GEO_FIELD_NAME) => {
let (valid_fields, hidden_fields) =
ctx.index.remove_hidden_fields(ctx.txn, sortable_fields)?;
index.remove_hidden_fields(rtxn, sortable_fields)?;
return Err(UserError::InvalidSortableAttribute {
field: RESERVED_GEO_FIELD_NAME.to_string(),