Take sort criteria from the request

This commit is contained in:
Mubelotix 2025-06-25 16:41:08 +02:00
parent 6e0526090a
commit b05cb80803
No known key found for this signature in database
GPG key ID: 0406DF6C3A69B942
3 changed files with 79 additions and 29 deletions

View file

@ -1,6 +1,7 @@
use std::collections::HashSet;
use std::io::{ErrorKind, Seek as _};
use std::marker::PhantomData;
use std::str::FromStr;
use actix_web::http::header::CONTENT_TYPE;
use actix_web::web::Data;
@ -18,12 +19,9 @@ use meilisearch_types::error::{Code, ResponseError};
use meilisearch_types::heed::RoTxn;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::facet::facet_sort_recursive::recursive_facet_sort;
use meilisearch_types::milli::facet::{ascending_facet_sort, descending_facet_sort};
use meilisearch_types::milli::heed_codec::facet::FacetGroupKeyCodec;
use meilisearch_types::milli::heed_codec::BytesRefCodec;
use meilisearch_types::milli::update::IndexDocumentsMethod;
use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors;
use meilisearch_types::milli::DocumentId;
use meilisearch_types::milli::{AscDesc, DocumentId};
use meilisearch_types::serde_cs::vec::CS;
use meilisearch_types::star_or::OptionStarOrList;
use meilisearch_types::tasks::KindWithContent;
@ -46,6 +44,7 @@ use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData;
use crate::extractors::payload::Payload;
use crate::extractors::sequential_extractor::SeqHandler;
use crate::routes::indexes::search::fix_sort_query_parameters;
use crate::routes::{
get_task_id, is_dry_run, PaginationView, SummarizedTaskView, PAGINATION_DEFAULT_LIMIT,
};
@ -410,6 +409,8 @@ pub struct BrowseQueryGet {
#[param(default, value_type = Option<String>, example = "popularity > 1000")]
#[deserr(default, error = DeserrQueryParamError<InvalidDocumentFilter>)]
filter: Option<String>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchSort>)]
sort: Option<String>, // TODO: change deser error
}
#[derive(Debug, Deserr, ToSchema)]
@ -434,6 +435,9 @@ pub struct BrowseQuery {
#[schema(default, value_type = Option<Value>, example = "popularity > 1000")]
#[deserr(default, error = DeserrJsonError<InvalidDocumentFilter>)]
filter: Option<Value>,
#[schema(default, value_type = Option<Vec<String>>, example = json!(["title:asc", "rating:desc"]))]
#[deserr(default, error = DeserrJsonError<InvalidSearchSort>)] // TODO: Change error
pub sort: Option<Vec<String>>,
}
/// Get documents with POST
@ -575,7 +579,7 @@ pub async fn get_documents(
) -> Result<HttpResponse, ResponseError> {
debug!(parameters = ?params, "Get documents GET");
let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter, ids } =
let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter, ids, sort } =
params.into_inner();
let filter = match filter {
@ -586,15 +590,14 @@ pub async fn get_documents(
None => None,
};
let ids = ids.map(|ids| ids.into_iter().map(Into::into).collect());
let query = BrowseQuery {
offset: offset.0,
limit: limit.0,
fields: fields.merge_star_and_none(),
retrieve_vectors: retrieve_vectors.0,
filter,
ids,
ids: ids.map(|ids| ids.into_iter().map(Into::into).collect()),
sort: sort.map(|attr| fix_sort_query_parameters(&attr)),
};
analytics.publish(
@ -619,7 +622,7 @@ fn documents_by_query(
query: BrowseQuery,
) -> Result<HttpResponse, ResponseError> {
let index_uid = IndexUid::try_from(index_uid.into_inner())?;
let BrowseQuery { offset, limit, fields, retrieve_vectors, filter, ids } = query;
let BrowseQuery { offset, limit, fields, retrieve_vectors, filter, ids, sort } = query;
let retrieve_vectors = RetrieveVectors::new(retrieve_vectors);
@ -637,6 +640,22 @@ fn documents_by_query(
None
};
let sort_criteria = if let Some(sort) = &sort {
let sorts: Vec<_> =
match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() {
Ok(sorts) => sorts,
Err(asc_desc_error) => {
return Err(milli::Error::from(milli::SortError::from(
asc_desc_error,
))
.into())
}
};
Some(sorts)
} else {
None
};
let index = index_scheduler.index(&index_uid)?;
let (total, documents) = retrieve_documents(
&index,
@ -647,6 +666,7 @@ fn documents_by_query(
fields,
retrieve_vectors,
index_scheduler.features(),
sort_criteria,
)?;
let ret = PaginationView::new(offset, limit, total as usize, documents);
@ -1505,6 +1525,7 @@ fn retrieve_documents<S: AsRef<str>>(
attributes_to_retrieve: Option<Vec<S>>,
retrieve_vectors: RetrieveVectors,
features: RoFeatures,
sort_criteria: Option<Vec<AscDesc>>,
) -> Result<(u64, Vec<Document>), ResponseError> {
let rtxn = index.read_txn()?;
let filter = &filter;
@ -1537,14 +1558,9 @@ fn retrieve_documents<S: AsRef<str>>(
})?
}
let fields = vec![(0, true)];
let number_db = index
.facet_id_f64_docids
.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
let string_db = index
.facet_id_string_docids
.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
candidates = recursive_facet_sort(&rtxn, number_db, string_db, &fields, candidates)?;
if let Some(sort) = sort_criteria {
candidates = recursive_facet_sort(index, &rtxn, &sort, candidates)?;
}
let (it, number_of_documents) = {
let number_of_documents = candidates.len();