From 4cfb48fbb6e028cd68ecaa6161b060db817aec85 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 8 Nov 2023 14:47:35 +0100 Subject: [PATCH] Make the basic ranking rule boosting work --- meilisearch-types/src/settings.rs | 4 + milli/src/boost.rs | 144 ++++++++++++++++++++++++++++++ milli/src/lib.rs | 1 + milli/src/ranking_rule.rs | 109 ++++++++++++---------- milli/src/score_details.rs | 16 ++++ milli/src/search/new/boost.rs | 88 ++++++++++++++++++ milli/src/search/new/mod.rs | 7 ++ 7 files changed, 323 insertions(+), 46 deletions(-) create mode 100644 milli/src/boost.rs create mode 100644 milli/src/search/new/boost.rs diff --git a/meilisearch-types/src/settings.rs b/meilisearch-types/src/settings.rs index 05f0ce1bc..9c390ba7a 100644 --- a/meilisearch-types/src/settings.rs +++ b/meilisearch-types/src/settings.rs @@ -583,6 +583,8 @@ pub enum RankingRuleView { /// Sorted by decreasing number of matched query terms. /// Query words at the front of an attribute is considered better than if it was at the back. Words, + /// Sorted by documents matching the given filter and then documents not matching it. + Boost(String), /// Sorted by increasing number of typos. Typo, /// Sorted by increasing distance between matched query terms. @@ -648,6 +650,7 @@ impl From for RankingRuleView { fn from(value: RankingRule) -> Self { match value { RankingRule::Words => RankingRuleView::Words, + RankingRule::Boost(filter) => RankingRuleView::Boost(filter), RankingRule::Typo => RankingRuleView::Typo, RankingRule::Proximity => RankingRuleView::Proximity, RankingRule::Attribute => RankingRuleView::Attribute, @@ -662,6 +665,7 @@ impl From for RankingRule { fn from(value: RankingRuleView) -> Self { match value { RankingRuleView::Words => RankingRule::Words, + RankingRuleView::Boost(filter) => RankingRule::Boost(filter), RankingRuleView::Typo => RankingRule::Typo, RankingRuleView::Proximity => RankingRule::Proximity, RankingRuleView::Attribute => RankingRule::Attribute, diff --git a/milli/src/boost.rs b/milli/src/boost.rs new file mode 100644 index 000000000..5571722bf --- /dev/null +++ b/milli/src/boost.rs @@ -0,0 +1,144 @@ +//! This module provides the `Boost` type and defines all the errors related to this type. + +use std::str::FromStr; + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::RankingRuleError; + +/// This error type is never supposed to be shown to the end user. +/// You must always cast it to a sort error or a criterion error. +#[derive(Error, Debug)] +pub enum BoostError { + #[error("Invalid syntax for the boost parameter: expected expression ending by `boost:`, found `{name}`.")] + InvalidSyntax { name: String }, +} + +impl From for RankingRuleError { + fn from(error: BoostError) -> Self { + match error { + BoostError::InvalidSyntax { name } => RankingRuleError::InvalidName { name }, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct Boost(pub String); + +impl Boost { + pub fn filter(&self) -> &str { + &self.0 + } +} + +impl FromStr for Boost { + type Err = BoostError; + + fn from_str(text: &str) -> Result { + match text.split_once(':') { + Some(("boost", right)) => Ok(Boost(right.to_string())), // TODO check filter validity + _ => Err(BoostError::InvalidSyntax { name: text.to_string() }), + } + } +} + +#[cfg(test)] +mod tests { + use big_s::S; + use BoostError::*; + + use super::*; + + #[test] + fn parse_asc_desc() { + let valid_req = [ + ("truc:asc", Asc(Field(S("truc")))), + ("bidule:desc", Desc(Field(S("bidule")))), + ("a-b:desc", Desc(Field(S("a-b")))), + ("a:b:desc", Desc(Field(S("a:b")))), + ("a12:asc", Asc(Field(S("a12")))), + ("42:asc", Asc(Field(S("42")))), + ("_geoPoint(42, 59):asc", Asc(Geo([42., 59.]))), + ("_geoPoint(42.459, 59):desc", Desc(Geo([42.459, 59.]))), + ("_geoPoint(42, 59.895):desc", Desc(Geo([42., 59.895]))), + ("_geoPoint(42, 59.895):desc", Desc(Geo([42., 59.895]))), + ("_geoPoint(90.000000000, 180):desc", Desc(Geo([90., 180.]))), + ("_geoPoint(-90, -180.0000000000):asc", Asc(Geo([-90., -180.]))), + ("_geoPoint(42.0002, 59.895):desc", Desc(Geo([42.0002, 59.895]))), + ("_geoPoint(42., 59.):desc", Desc(Geo([42., 59.]))), + ("truc(12, 13):desc", Desc(Field(S("truc(12, 13)")))), + ]; + + for (req, expected) in valid_req { + let res = req.parse::(); + assert!( + res.is_ok(), + "Failed to parse `{}`, was expecting `{:?}` but instead got `{:?}`", + req, + expected, + res + ); + assert_eq!(res.unwrap(), expected); + } + + let invalid_req = [ + ("truc:machin", InvalidSyntax { name: S("truc:machin") }), + ("truc:deesc", InvalidSyntax { name: S("truc:deesc") }), + ("truc:asc:deesc", InvalidSyntax { name: S("truc:asc:deesc") }), + ("42desc", InvalidSyntax { name: S("42desc") }), + ("_geoPoint:asc", ReservedKeyword { name: S("_geoPoint") }), + ("_geoDistance:asc", ReservedKeyword { name: S("_geoDistance") }), + ("_geoPoint(42.12 , 59.598)", InvalidSyntax { name: S("_geoPoint(42.12 , 59.598)") }), + ( + "_geoPoint(42.12 , 59.598):deesc", + InvalidSyntax { name: S("_geoPoint(42.12 , 59.598):deesc") }, + ), + ( + "_geoPoint(42.12 , 59.598):machin", + InvalidSyntax { name: S("_geoPoint(42.12 , 59.598):machin") }, + ), + ( + "_geoPoint(42.12 , 59.598):asc:aasc", + InvalidSyntax { name: S("_geoPoint(42.12 , 59.598):asc:aasc") }, + ), + ( + "_geoPoint(42,12 , 59,598):desc", + ReservedKeyword { name: S("_geoPoint(42,12 , 59,598)") }, + ), + ("_geoPoint(35, 85, 75):asc", ReservedKeyword { name: S("_geoPoint(35, 85, 75)") }), + ("_geoPoint(18):asc", ReservedKeyword { name: S("_geoPoint(18)") }), + ("_geoPoint(200, 200):asc", GeoError(BadGeoError::Lat(200.))), + ("_geoPoint(90.000001, 0):asc", GeoError(BadGeoError::Lat(90.000001))), + ("_geoPoint(0, -180.000001):desc", GeoError(BadGeoError::Lng(-180.000001))), + ("_geoPoint(159.256, 130):asc", GeoError(BadGeoError::Lat(159.256))), + ("_geoPoint(12, -2021):desc", GeoError(BadGeoError::Lng(-2021.))), + ("_geo(12, -2021):asc", ReservedKeyword { name: S("_geo(12, -2021)") }), + ("_geo(12, -2021):desc", ReservedKeyword { name: S("_geo(12, -2021)") }), + ("_geoDistance(12, -2021):asc", ReservedKeyword { name: S("_geoDistance(12, -2021)") }), + ( + "_geoDistance(12, -2021):desc", + ReservedKeyword { name: S("_geoDistance(12, -2021)") }, + ), + ]; + + for (req, expected_error) in invalid_req { + let res = req.parse::(); + assert!( + res.is_err(), + "Should no be able to parse `{}`, was expecting an error but instead got: `{:?}`", + req, + res, + ); + let res = res.unwrap_err(); + assert_eq!( + res.to_string(), + expected_error.to_string(), + "Bad error for input {}: got `{:?}` instead of `{:?}`", + req, + res, + expected_error + ); + } + } +} diff --git a/milli/src/lib.rs b/milli/src/lib.rs index e9cef7d48..fb6201f1f 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -9,6 +9,7 @@ pub static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; pub mod documents; mod asc_desc; +mod boost; pub mod distance; mod error; mod external_documents_ids; diff --git a/milli/src/ranking_rule.rs b/milli/src/ranking_rule.rs index 45cbfe63d..18b8fc9a2 100644 --- a/milli/src/ranking_rule.rs +++ b/milli/src/ranking_rule.rs @@ -4,10 +4,11 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; use thiserror::Error; -use crate::{AscDesc, Member}; +use crate::boost::{Boost, BoostError}; +use crate::{AscDesc, AscDescError, Member}; #[derive(Error, Debug)] -pub enum CriterionError { +pub enum RankingRuleError { #[error("`{name}` ranking rule is invalid. Valid ranking rules are words, typo, sort, proximity, attribute, exactness and custom ranking rules.")] InvalidName { name: String }, #[error("`{name}` is a reserved keyword and thus can't be used as a ranking rule")] @@ -25,10 +26,12 @@ pub enum CriterionError { } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] -pub enum Criterion { +pub enum RankingRule { /// Sorted by decreasing number of matched query terms. /// Query words at the front of an attribute is considered better than if it was at the back. Words, + /// Sorted by documents matching the given filter and then documents not matching it. + Boost(String), /// Sorted by increasing number of typos. Typo, /// Sorted by increasing distance between matched query terms. @@ -47,62 +50,76 @@ pub enum Criterion { Desc(String), } -impl Criterion { +impl RankingRule { /// Returns the field name parameter of this criterion. pub fn field_name(&self) -> Option<&str> { match self { - Criterion::Asc(name) | Criterion::Desc(name) => Some(name), + RankingRule::Asc(name) | RankingRule::Desc(name) => Some(name), _otherwise => None, } } } -impl FromStr for Criterion { - type Err = CriterionError; +impl FromStr for RankingRule { + type Err = RankingRuleError; - fn from_str(text: &str) -> Result { + fn from_str(text: &str) -> Result { match text { - "words" => Ok(Criterion::Words), - "typo" => Ok(Criterion::Typo), - "proximity" => Ok(Criterion::Proximity), - "attribute" => Ok(Criterion::Attribute), - "sort" => Ok(Criterion::Sort), - "exactness" => Ok(Criterion::Exactness), - text => match AscDesc::from_str(text)? { - AscDesc::Asc(Member::Field(field)) => Ok(Criterion::Asc(field)), - AscDesc::Desc(Member::Field(field)) => Ok(Criterion::Desc(field)), - AscDesc::Asc(Member::Geo(_)) | AscDesc::Desc(Member::Geo(_)) => { - Err(CriterionError::ReservedNameForSort { name: "_geoPoint".to_string() })? - } + "words" => Ok(RankingRule::Words), + "typo" => Ok(RankingRule::Typo), + "proximity" => Ok(RankingRule::Proximity), + "attribute" => Ok(RankingRule::Attribute), + "sort" => Ok(RankingRule::Sort), + "exactness" => Ok(RankingRule::Exactness), + text => match (AscDesc::from_str(text), Boost::from_str(text)) { + (Ok(asc_desc), _) => match asc_desc { + AscDesc::Asc(Member::Field(field)) => Ok(RankingRule::Asc(field)), + AscDesc::Desc(Member::Field(field)) => Ok(RankingRule::Desc(field)), + AscDesc::Asc(Member::Geo(_)) | AscDesc::Desc(Member::Geo(_)) => { + Err(RankingRuleError::ReservedNameForSort { + name: "_geoPoint".to_string(), + })? + } + }, + (_, Ok(Boost(filter))) => Ok(RankingRule::Boost(filter)), + ( + Err(AscDescError::InvalidSyntax { name: asc_desc_name }), + Err(BoostError::InvalidSyntax { name: boost_name }), + ) => Err(RankingRuleError::InvalidName { + // TODO improve the error message quality + name: format!("{asc_desc_name} {boost_name}"), + }), + (Err(asc_desc_error), _) => Err(asc_desc_error.into()), }, } } } -pub fn default_criteria() -> Vec { +pub fn default_criteria() -> Vec { vec![ - Criterion::Words, - Criterion::Typo, - Criterion::Proximity, - Criterion::Attribute, - Criterion::Sort, - Criterion::Exactness, + RankingRule::Words, + RankingRule::Typo, + RankingRule::Proximity, + RankingRule::Attribute, + RankingRule::Sort, + RankingRule::Exactness, ] } -impl fmt::Display for Criterion { +impl fmt::Display for RankingRule { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use Criterion::*; + use RankingRule::*; match self { Words => f.write_str("words"), + Boost(filter) => write!(f, "boost:{filter}"), Typo => f.write_str("typo"), Proximity => f.write_str("proximity"), Attribute => f.write_str("attribute"), Sort => f.write_str("sort"), Exactness => f.write_str("exactness"), - Asc(attr) => write!(f, "{}:asc", attr), - Desc(attr) => write!(f, "{}:desc", attr), + Asc(attr) => write!(f, "{attr}:asc"), + Desc(attr) => write!(f, "{attr}:desc"), } } } @@ -110,29 +127,29 @@ impl fmt::Display for Criterion { #[cfg(test)] mod tests { use big_s::S; - use CriterionError::*; + use RankingRuleError::*; use super::*; #[test] fn parse_criterion() { let valid_criteria = [ - ("words", Criterion::Words), - ("typo", Criterion::Typo), - ("proximity", Criterion::Proximity), - ("attribute", Criterion::Attribute), - ("sort", Criterion::Sort), - ("exactness", Criterion::Exactness), - ("price:asc", Criterion::Asc(S("price"))), - ("price:desc", Criterion::Desc(S("price"))), - ("price:asc:desc", Criterion::Desc(S("price:asc"))), - ("truc:machin:desc", Criterion::Desc(S("truc:machin"))), - ("hello-world!:desc", Criterion::Desc(S("hello-world!"))), - ("it's spacy over there:asc", Criterion::Asc(S("it's spacy over there"))), + ("words", RankingRule::Words), + ("typo", RankingRule::Typo), + ("proximity", RankingRule::Proximity), + ("attribute", RankingRule::Attribute), + ("sort", RankingRule::Sort), + ("exactness", RankingRule::Exactness), + ("price:asc", RankingRule::Asc(S("price"))), + ("price:desc", RankingRule::Desc(S("price"))), + ("price:asc:desc", RankingRule::Desc(S("price:asc"))), + ("truc:machin:desc", RankingRule::Desc(S("truc:machin"))), + ("hello-world!:desc", RankingRule::Desc(S("hello-world!"))), + ("it's spacy over there:asc", RankingRule::Asc(S("it's spacy over there"))), ]; for (input, expected) in valid_criteria { - let res = input.parse::(); + let res = input.parse::(); assert!( res.is_ok(), "Failed to parse `{}`, was expecting `{:?}` but instead got `{:?}`", @@ -167,7 +184,7 @@ mod tests { ]; for (input, expected) in invalid_criteria { - let res = input.parse::(); + let res = input.parse::(); assert!( res.is_err(), "Should no be able to parse `{}`, was expecting an error but instead got: `{:?}`", diff --git a/milli/src/score_details.rs b/milli/src/score_details.rs index 8fc998ae4..4a264143d 100644 --- a/milli/src/score_details.rs +++ b/milli/src/score_details.rs @@ -5,6 +5,7 @@ use crate::distance_between_two_points; #[derive(Debug, Clone, PartialEq)] pub enum ScoreDetails { Words(Words), + Boost(Boost), Typo(Typo), Proximity(Rank), Fid(Rank), @@ -23,6 +24,7 @@ impl ScoreDetails { pub fn rank(&self) -> Option { match self { ScoreDetails::Words(details) => Some(details.rank()), + ScoreDetails::Boost(_) => None, ScoreDetails::Typo(details) => Some(details.rank()), ScoreDetails::Proximity(details) => Some(*details), ScoreDetails::Fid(details) => Some(*details), @@ -60,6 +62,14 @@ impl ScoreDetails { details_map.insert("words".into(), words_details); order += 1; } + ScoreDetails::Boost(Boost { filter, matching }) => { + let sort = format!("boost:{}", filter); + let sort_details = serde_json::json!({ + "value": matching, + }); + details_map.insert(sort, sort_details); + order += 1; + } ScoreDetails::Typo(typo) => { let typo_details = serde_json::json!({ "order": order, @@ -221,6 +231,12 @@ impl Words { } } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Boost { + pub filter: String, + pub matching: bool, +} + /// Structure that is super similar to [`Words`], but whose semantics is a bit distinct. /// /// In exactness, the number of matching words can actually be 0 with a non-zero score, diff --git a/milli/src/search/new/boost.rs b/milli/src/search/new/boost.rs new file mode 100644 index 000000000..f1a001a4c --- /dev/null +++ b/milli/src/search/new/boost.rs @@ -0,0 +1,88 @@ +use roaring::RoaringBitmap; + +use super::logger::SearchLogger; +use super::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait, SearchContext}; +use crate::score_details::{self, ScoreDetails}; +use crate::{Filter, Result}; + +pub struct Boost { + original_expression: String, + original_query: Option, + matching: Option>, + non_matching: Option>, +} + +impl Boost { + pub fn new(expression: String) -> Result { + Ok(Self { + original_expression: expression, + original_query: None, + matching: None, + non_matching: None, + }) + } +} + +impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Boost { + fn id(&self) -> String { + // TODO improve this + let Self { original_expression, .. } = self; + format!("boost:{original_expression}") + } + + fn start_iteration( + &mut self, + ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + parent_candidates: &RoaringBitmap, + parent_query: &Query, + ) -> Result<()> { + let universe_matching = match Filter::from_str(&self.original_expression)? { + Some(filter) => filter.evaluate(ctx.txn, ctx.index)?, + None => RoaringBitmap::default(), + }; + let matching = parent_candidates & universe_matching; + let non_matching = parent_candidates - &matching; + + self.original_query = Some(parent_query.clone()); + + self.matching = Some(RankingRuleOutput { + query: parent_query.clone(), + candidates: matching, + score: ScoreDetails::Boost(score_details::Boost { + filter: self.original_expression.clone(), + matching: true, + }), + }); + + self.non_matching = Some(RankingRuleOutput { + query: parent_query.clone(), + candidates: non_matching, + score: ScoreDetails::Boost(score_details::Boost { + filter: self.original_expression.clone(), + matching: false, + }), + }); + + Ok(()) + } + + fn next_bucket( + &mut self, + _ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + _universe: &RoaringBitmap, + ) -> Result>> { + Ok(self.matching.take().or_else(|| self.non_matching.take())) + } + + fn end_iteration( + &mut self, + _ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + ) { + self.original_query = None; + self.matching = None; + self.non_matching = None; + } +} diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index c93fd6b01..2066e5ad0 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -14,6 +14,7 @@ mod ranking_rules; mod resolve_query_graph; mod small_bitmap; +mod boost; mod exact_attribute; mod sort; @@ -22,6 +23,7 @@ mod tests; use std::collections::HashSet; +use boost::Boost; use bucket_sort::{bucket_sort, BucketSortOutput}; use charabia::TokenizerBuilder; use db_cache::DatabaseCache; @@ -208,6 +210,7 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( | crate::RankingRule::Attribute | crate::RankingRule::Proximity | crate::RankingRule::Exactness => continue, + crate::RankingRule::Boost(filter) => ranking_rules.push(Box::new(Boost::new(filter)?)), crate::RankingRule::Sort => { if sort { continue; @@ -287,6 +290,9 @@ fn get_ranking_rules_for_query_graph_search<'ctx>( ranking_rules.push(Box::new(Words::new(terms_matching_strategy))); words = true; } + crate::RankingRule::Boost(filter) => { + ranking_rules.push(Box::new(Boost::new(filter)?)); + } crate::RankingRule::Typo => { if typo { continue; @@ -332,6 +338,7 @@ fn get_ranking_rules_for_query_graph_search<'ctx>( exactness = true; } crate::RankingRule::Asc(field_name) => { + // TODO Question: Why would it be invalid to sort price:asc, typo, price:desc? if sorted_fields.contains(&field_name) { continue; }