Expose the scores and detailed scores in the API

This commit is contained in:
Louis Dureuil 2023-06-06 18:26:33 +02:00
parent 701d44bd91
commit da833eb095
No known key found for this signature in database
8 changed files with 74 additions and 12 deletions

View File

@ -240,6 +240,8 @@ InvalidSearchOffset , InvalidRequest , BAD_REQUEST ;
InvalidSearchPage , InvalidRequest , BAD_REQUEST ; InvalidSearchPage , InvalidRequest , BAD_REQUEST ;
InvalidSearchQ , InvalidRequest , BAD_REQUEST ; InvalidSearchQ , InvalidRequest , BAD_REQUEST ;
InvalidSearchShowMatchesPosition , InvalidRequest , BAD_REQUEST ; InvalidSearchShowMatchesPosition , InvalidRequest , BAD_REQUEST ;
InvalidSearchShowRankingScore , InvalidRequest , BAD_REQUEST ;
InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ;
InvalidSearchSort , InvalidRequest , BAD_REQUEST ; InvalidSearchSort , InvalidRequest , BAD_REQUEST ;
InvalidSettingsDisplayedAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsDisplayedAttributes , InvalidRequest , BAD_REQUEST ;
InvalidSettingsDistinctAttribute , InvalidRequest , BAD_REQUEST ; InvalidSettingsDistinctAttribute , InvalidRequest , BAD_REQUEST ;

View File

@ -56,6 +56,10 @@ pub struct SearchQueryGet {
sort: Option<String>, sort: Option<String>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchShowMatchesPosition>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchShowMatchesPosition>)]
show_matches_position: Param<bool>, show_matches_position: Param<bool>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchShowRankingScore>)]
show_ranking_score: Param<bool>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchShowRankingScoreDetails>)]
show_ranking_score_details: Param<bool>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchFacets>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchFacets>)]
facets: Option<CS<String>>, facets: Option<CS<String>>,
#[deserr( default = DEFAULT_HIGHLIGHT_PRE_TAG(), error = DeserrQueryParamError<InvalidSearchHighlightPreTag>)] #[deserr( default = DEFAULT_HIGHLIGHT_PRE_TAG(), error = DeserrQueryParamError<InvalidSearchHighlightPreTag>)]
@ -91,6 +95,8 @@ impl From<SearchQueryGet> for SearchQuery {
filter, filter,
sort: other.sort.map(|attr| fix_sort_query_parameters(&attr)), sort: other.sort.map(|attr| fix_sort_query_parameters(&attr)),
show_matches_position: other.show_matches_position.0, show_matches_position: other.show_matches_position.0,
show_ranking_score: other.show_ranking_score.0,
show_ranking_score_details: other.show_ranking_score_details.0,
facets: other.facets.map(|o| o.into_iter().collect()), facets: other.facets.map(|o| o.into_iter().collect()),
highlight_pre_tag: other.highlight_pre_tag, highlight_pre_tag: other.highlight_pre_tag,
highlight_post_tag: other.highlight_post_tag, highlight_post_tag: other.highlight_post_tag,

View File

@ -9,6 +9,7 @@ use meilisearch_auth::IndexSearchRules;
use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
use meilisearch_types::{milli, Document}; use meilisearch_types::{milli, Document};
use milli::tokenizer::TokenizerBuilder; use milli::tokenizer::TokenizerBuilder;
@ -54,6 +55,10 @@ pub struct SearchQuery {
pub attributes_to_highlight: Option<HashSet<String>>, pub attributes_to_highlight: Option<HashSet<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchShowMatchesPosition>, default)] #[deserr(default, error = DeserrJsonError<InvalidSearchShowMatchesPosition>, default)]
pub show_matches_position: bool, pub show_matches_position: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScore>, default)]
pub show_ranking_score: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScoreDetails>, default)]
pub show_ranking_score_details: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)] #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
pub filter: Option<Value>, pub filter: Option<Value>,
#[deserr(default, error = DeserrJsonError<InvalidSearchSort>)] #[deserr(default, error = DeserrJsonError<InvalidSearchSort>)]
@ -103,6 +108,10 @@ pub struct SearchQueryWithIndex {
pub crop_length: usize, pub crop_length: usize,
#[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToHighlight>)] #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToHighlight>)]
pub attributes_to_highlight: Option<HashSet<String>>, pub attributes_to_highlight: Option<HashSet<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScore>, default)]
pub show_ranking_score: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScoreDetails>, default)]
pub show_ranking_score_details: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchShowMatchesPosition>, default)] #[deserr(default, error = DeserrJsonError<InvalidSearchShowMatchesPosition>, default)]
pub show_matches_position: bool, pub show_matches_position: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)] #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
@ -134,6 +143,8 @@ impl SearchQueryWithIndex {
attributes_to_crop, attributes_to_crop,
crop_length, crop_length,
attributes_to_highlight, attributes_to_highlight,
show_ranking_score,
show_ranking_score_details,
show_matches_position, show_matches_position,
filter, filter,
sort, sort,
@ -155,6 +166,8 @@ impl SearchQueryWithIndex {
attributes_to_crop, attributes_to_crop,
crop_length, crop_length,
attributes_to_highlight, attributes_to_highlight,
show_ranking_score,
show_ranking_score_details,
show_matches_position, show_matches_position,
filter, filter,
sort, sort,
@ -194,7 +207,7 @@ impl From<MatchingStrategy> for TermsMatchingStrategy {
} }
} }
#[derive(Debug, Clone, Serialize, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, PartialEq)]
pub struct SearchHit { pub struct SearchHit {
#[serde(flatten)] #[serde(flatten)]
pub document: Document, pub document: Document,
@ -202,6 +215,10 @@ pub struct SearchHit {
pub formatted: Document, pub formatted: Document,
#[serde(rename = "_matchesPosition", skip_serializing_if = "Option::is_none")] #[serde(rename = "_matchesPosition", skip_serializing_if = "Option::is_none")]
pub matches_position: Option<MatchesPosition>, pub matches_position: Option<MatchesPosition>,
#[serde(rename = "_rankingScore", skip_serializing_if = "Option::is_none")]
pub ranking_score: Option<f64>,
#[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")]
pub ranking_score_details: Option<serde_json::Map<String, serde_json::Value>>,
} }
#[derive(Serialize, Debug, Clone, PartialEq)] #[derive(Serialize, Debug, Clone, PartialEq)]
@ -283,6 +300,11 @@ pub fn perform_search(
.unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS);
search.exhaustive_number_hits(is_finite_pagination); search.exhaustive_number_hits(is_finite_pagination);
search.scoring_strategy(if query.show_ranking_score || query.show_ranking_score_details {
ScoringStrategy::Detailed
} else {
ScoringStrategy::Skip
});
// compute the offset on the limit depending on the pagination mode. // compute the offset on the limit depending on the pagination mode.
let (offset, limit) = if is_finite_pagination { let (offset, limit) = if is_finite_pagination {
@ -320,7 +342,8 @@ pub fn perform_search(
search.sort_criteria(sort); search.sort_criteria(sort);
} }
let milli::SearchResult { documents_ids, matching_words, candidates, .. } = search.execute()?; let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
search.execute()?;
let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
@ -392,7 +415,7 @@ pub fn perform_search(
let documents_iter = index.documents(&rtxn, documents_ids)?; let documents_iter = index.documents(&rtxn, documents_ids)?;
for (_id, obkv) in documents_iter { for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
// First generate a document with all the displayed fields // First generate a document with all the displayed fields
let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?;
@ -416,7 +439,18 @@ pub fn perform_search(
insert_geo_distance(sort, &mut document); insert_geo_distance(sort, &mut document);
} }
let hit = SearchHit { document, formatted, matches_position }; let ranking_score =
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
let ranking_score_details =
query.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter()));
let hit = SearchHit {
document,
formatted,
matches_position,
ranking_score_details,
ranking_score,
};
documents.push(hit); documents.push(hit);
} }

View File

@ -53,6 +53,7 @@ fn main() -> Result<(), Box<dyn Error>> {
&mut ctx, &mut ctx,
&(!query.trim().is_empty()).then(|| query.trim().to_owned()), &(!query.trim().is_empty()).then(|| query.trim().to_owned()),
TermsMatchingStrategy::Last, TermsMatchingStrategy::Last,
milli::score_details::ScoringStrategy::Skip,
false, false,
&None, &None,
&None, &None,

View File

@ -2488,8 +2488,12 @@ pub(crate) mod tests {
let rtxn = index.read_txn().unwrap(); let rtxn = index.read_txn().unwrap();
let search = Search::new(&rtxn, &index); let search = Search::new(&rtxn, &index);
let SearchResult { matching_words: _, candidates: _, mut documents_ids } = let SearchResult {
search.execute().unwrap(); matching_words: _,
candidates: _,
document_scores: _,
mut documents_ids,
} = search.execute().unwrap();
let primary_key_id = index.fields_ids_map(&rtxn).unwrap().id("primary_key").unwrap(); let primary_key_id = index.fields_ids_map(&rtxn).unwrap().id("primary_key").unwrap();
documents_ids.sort_unstable(); documents_ids.sort_unstable();
let docs = index.documents(&rtxn, documents_ids).unwrap(); let docs = index.documents(&rtxn, documents_ids).unwrap();

View File

@ -7,6 +7,7 @@ use roaring::bitmap::RoaringBitmap;
pub use self::facet::{FacetDistribution, Filter, DEFAULT_VALUES_PER_FACET}; pub use self::facet::{FacetDistribution, Filter, DEFAULT_VALUES_PER_FACET};
pub use self::new::matches::{FormatOptions, MatchBounds, Matcher, MatcherBuilder, MatchingWords}; pub use self::new::matches::{FormatOptions, MatchBounds, Matcher, MatcherBuilder, MatchingWords};
use self::new::PartialSearchResult; use self::new::PartialSearchResult;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::{ use crate::{
execute_search, AscDesc, DefaultSearchLogger, DocumentId, Index, Result, SearchContext, execute_search, AscDesc, DefaultSearchLogger, DocumentId, Index, Result, SearchContext,
}; };
@ -29,6 +30,7 @@ pub struct Search<'a> {
sort_criteria: Option<Vec<AscDesc>>, sort_criteria: Option<Vec<AscDesc>>,
geo_strategy: new::GeoSortStrategy, geo_strategy: new::GeoSortStrategy,
terms_matching_strategy: TermsMatchingStrategy, terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy,
words_limit: usize, words_limit: usize,
exhaustive_number_hits: bool, exhaustive_number_hits: bool,
rtxn: &'a heed::RoTxn<'a>, rtxn: &'a heed::RoTxn<'a>,
@ -45,6 +47,7 @@ impl<'a> Search<'a> {
sort_criteria: None, sort_criteria: None,
geo_strategy: new::GeoSortStrategy::default(), geo_strategy: new::GeoSortStrategy::default(),
terms_matching_strategy: TermsMatchingStrategy::default(), terms_matching_strategy: TermsMatchingStrategy::default(),
scoring_strategy: Default::default(),
exhaustive_number_hits: false, exhaustive_number_hits: false,
words_limit: 10, words_limit: 10,
rtxn, rtxn,
@ -77,6 +80,11 @@ impl<'a> Search<'a> {
self self
} }
pub fn scoring_strategy(&mut self, value: ScoringStrategy) -> &mut Search<'a> {
self.scoring_strategy = value;
self
}
pub fn words_limit(&mut self, value: usize) -> &mut Search<'a> { pub fn words_limit(&mut self, value: usize) -> &mut Search<'a> {
self.words_limit = value; self.words_limit = value;
self self
@ -93,7 +101,7 @@ impl<'a> Search<'a> {
self self
} }
/// Force the search to exhastivelly compute the number of candidates, /// Forces the search to exhaustively compute the number of candidates,
/// this will increase the search time but allows finite pagination. /// this will increase the search time but allows finite pagination.
pub fn exhaustive_number_hits(&mut self, exhaustive_number_hits: bool) -> &mut Search<'a> { pub fn exhaustive_number_hits(&mut self, exhaustive_number_hits: bool) -> &mut Search<'a> {
self.exhaustive_number_hits = exhaustive_number_hits; self.exhaustive_number_hits = exhaustive_number_hits;
@ -102,11 +110,12 @@ impl<'a> Search<'a> {
pub fn execute(&self) -> Result<SearchResult> { pub fn execute(&self) -> Result<SearchResult> {
let mut ctx = SearchContext::new(self.index, self.rtxn); let mut ctx = SearchContext::new(self.index, self.rtxn);
let PartialSearchResult { located_query_terms, candidates, documents_ids } = let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
execute_search( execute_search(
&mut ctx, &mut ctx,
&self.query, &self.query,
self.terms_matching_strategy, self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits, self.exhaustive_number_hits,
&self.filter, &self.filter,
&self.sort_criteria, &self.sort_criteria,
@ -124,7 +133,7 @@ impl<'a> Search<'a> {
None => MatchingWords::default(), None => MatchingWords::default(),
}; };
Ok(SearchResult { matching_words, candidates, documents_ids }) Ok(SearchResult { matching_words, candidates, document_scores, documents_ids })
} }
} }
@ -138,6 +147,7 @@ impl fmt::Debug for Search<'_> {
sort_criteria, sort_criteria,
geo_strategy: _, geo_strategy: _,
terms_matching_strategy, terms_matching_strategy,
scoring_strategy,
words_limit, words_limit,
exhaustive_number_hits, exhaustive_number_hits,
rtxn: _, rtxn: _,
@ -150,6 +160,7 @@ impl fmt::Debug for Search<'_> {
.field("limit", limit) .field("limit", limit)
.field("sort_criteria", sort_criteria) .field("sort_criteria", sort_criteria)
.field("terms_matching_strategy", terms_matching_strategy) .field("terms_matching_strategy", terms_matching_strategy)
.field("scoring_strategy", scoring_strategy)
.field("exhaustive_number_hits", exhaustive_number_hits) .field("exhaustive_number_hits", exhaustive_number_hits)
.field("words_limit", words_limit) .field("words_limit", words_limit)
.finish() .finish()
@ -160,8 +171,8 @@ impl fmt::Debug for Search<'_> {
pub struct SearchResult { pub struct SearchResult {
pub matching_words: MatchingWords, pub matching_words: MatchingWords,
pub candidates: RoaringBitmap, pub candidates: RoaringBitmap,
// TODO those documents ids should be associated with their criteria scores.
pub documents_ids: Vec<DocumentId>, pub documents_ids: Vec<DocumentId>,
pub document_scores: Vec<Vec<ScoreDetails>>,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]

View File

@ -510,6 +510,7 @@ mod tests {
&mut ctx, &mut ctx,
&Some(query.to_string()), &Some(query.to_string()),
crate::TermsMatchingStrategy::default(), crate::TermsMatchingStrategy::default(),
crate::score_details::ScoringStrategy::Skip,
false, false,
&None, &None,
&None, &None,

View File

@ -351,6 +351,7 @@ pub fn execute_search(
ctx: &mut SearchContext, ctx: &mut SearchContext,
query: &Option<String>, query: &Option<String>,
terms_matching_strategy: TermsMatchingStrategy, terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy,
exhaustive_number_hits: bool, exhaustive_number_hits: bool,
filters: &Option<Filter>, filters: &Option<Filter>,
sort_criteria: &Option<Vec<AscDesc>>, sort_criteria: &Option<Vec<AscDesc>>,
@ -419,7 +420,7 @@ pub fn execute_search(
&universe, &universe,
from, from,
length, length,
ScoringStrategy::Skip, scoring_strategy,
query_graph_logger, query_graph_logger,
)? )?
} else { } else {
@ -432,7 +433,7 @@ pub fn execute_search(
&universe, &universe,
from, from,
length, length,
ScoringStrategy::Skip, scoring_strategy,
placeholder_search_logger, placeholder_search_logger,
)? )?
}; };
@ -453,6 +454,7 @@ pub fn execute_search(
Ok(PartialSearchResult { Ok(PartialSearchResult {
candidates: all_candidates, candidates: all_candidates,
document_scores: scores,
documents_ids: docids, documents_ids: docids,
located_query_terms, located_query_terms,
}) })
@ -504,4 +506,5 @@ pub struct PartialSearchResult {
pub located_query_terms: Option<Vec<LocatedQueryTerm>>, pub located_query_terms: Option<Vec<LocatedQueryTerm>>,
pub candidates: RoaringBitmap, pub candidates: RoaringBitmap,
pub documents_ids: Vec<DocumentId>, pub documents_ids: Vec<DocumentId>,
pub document_scores: Vec<Vec<ScoreDetails>>,
} }