Make the basic ranking rule boosting work

This commit is contained in:
Kerollmops 2023-11-08 14:47:35 +01:00
parent 67dc0268c5
commit 4cfb48fbb6
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
7 changed files with 323 additions and 46 deletions

View File

@ -583,6 +583,8 @@ pub enum RankingRuleView {
/// Sorted by decreasing number of matched query terms.
/// Query words at the front of an attribute are considered better than those at the back.
Words,
/// Sorted by documents matching the given filter and then documents not matching it.
Boost(String),
/// Sorted by increasing number of typos.
Typo,
/// Sorted by increasing distance between matched query terms.
@ -648,6 +650,7 @@ impl From<RankingRule> for RankingRuleView {
fn from(value: RankingRule) -> Self {
match value {
RankingRule::Words => RankingRuleView::Words,
RankingRule::Boost(filter) => RankingRuleView::Boost(filter),
RankingRule::Typo => RankingRuleView::Typo,
RankingRule::Proximity => RankingRuleView::Proximity,
RankingRule::Attribute => RankingRuleView::Attribute,
@ -662,6 +665,7 @@ impl From<RankingRuleView> for RankingRule {
fn from(value: RankingRuleView) -> Self {
match value {
RankingRuleView::Words => RankingRule::Words,
RankingRuleView::Boost(filter) => RankingRule::Boost(filter),
RankingRuleView::Typo => RankingRule::Typo,
RankingRuleView::Proximity => RankingRule::Proximity,
RankingRuleView::Attribute => RankingRule::Attribute,

144
milli/src/boost.rs Normal file
View File

@ -0,0 +1,144 @@
//! This module provides the `Boost` type and defines all the errors related to this type.
use std::str::FromStr;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::RankingRuleError;
/// This error type is never supposed to be shown to the end user.
/// You must always cast it to a sort error or a criterion error.
#[derive(Error, Debug)]
pub enum BoostError {
#[error("Invalid syntax for the boost parameter: expected expression ending by `boost:`, found `{name}`.")]
InvalidSyntax { name: String },
}
impl From<BoostError> for RankingRuleError {
fn from(error: BoostError) -> Self {
match error {
BoostError::InvalidSyntax { name } => RankingRuleError::InvalidName { name },
}
}
}
/// A boost ranking rule, wrapping the raw filter expression that follows the `boost:` prefix.
///
/// The expression is stored as-is: its validity as a filter is not checked when parsing.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Boost(pub String);
impl Boost {
pub fn filter(&self) -> &str {
&self.0
}
}
impl FromStr for Boost {
    type Err = BoostError;

    /// Parses a `boost:<filter>` expression.
    ///
    /// Everything after the leading `boost:` becomes the filter, further colons included.
    /// Any text that does not start with `boost:` is rejected with
    /// [`BoostError::InvalidSyntax`] carrying the whole input.
    fn from_str(text: &str) -> Result<Boost, Self::Err> {
        // TODO check filter validity
        text.strip_prefix("boost:")
            .map(|filter| Boost(filter.to_string()))
            .ok_or_else(|| BoostError::InvalidSyntax { name: text.to_string() })
    }
}
#[cfg(test)]
mod tests {
    use big_s::S;
    use BoostError::*;

    use super::*;

    /// Checks which inputs parse as a `boost:<filter>` expression and which are rejected.
    #[test]
    fn parse_boost() {
        let valid_req = [
            ("boost:price > 10", Boost(S("price > 10"))),
            ("boost:genre = horror", Boost(S("genre = horror"))),
            // Only the first colon separates the prefix from the filter.
            ("boost:a:b", Boost(S("a:b"))),
            ("boost:boost:truc", Boost(S("boost:truc"))),
            // The filter is kept verbatim, surrounding whitespace included.
            ("boost: spaced ", Boost(S(" spaced "))),
            // The filter is not validated at parsing time, so an empty one is accepted.
            ("boost:", Boost(S(""))),
        ];

        for (req, expected) in valid_req {
            let res = req.parse::<Boost>();
            assert!(
                res.is_ok(),
                "Failed to parse `{}`, was expecting `{:?}` but instead got `{:?}`",
                req,
                expected,
                res
            );
            assert_eq!(res.unwrap(), expected);
        }

        let invalid_req = [
            ("", InvalidSyntax { name: S("") }),
            ("boost", InvalidSyntax { name: S("boost") }),
            ("truc:asc", InvalidSyntax { name: S("truc:asc") }),
            ("truc:machin", InvalidSyntax { name: S("truc:machin") }),
            // The prefix is case-sensitive and must be the very start of the text.
            ("Boost:truc", InvalidSyntax { name: S("Boost:truc") }),
            (" boost:truc", InvalidSyntax { name: S(" boost:truc") }),
            ("boosted:truc", InvalidSyntax { name: S("boosted:truc") }),
            (":boost", InvalidSyntax { name: S(":boost") }),
        ];

        for (req, expected_error) in invalid_req {
            let res = req.parse::<Boost>();
            assert!(
                res.is_err(),
                "Should not be able to parse `{}`, was expecting an error but instead got: `{:?}`",
                req,
                res,
            );
            let res = res.unwrap_err();
            assert_eq!(
                res.to_string(),
                expected_error.to_string(),
                "Bad error for input {}: got `{:?}` instead of `{:?}`",
                req,
                res,
                expected_error
            );
        }
    }
}

View File

@ -9,6 +9,7 @@ pub static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
pub mod documents;
mod asc_desc;
mod boost;
pub mod distance;
mod error;
mod external_documents_ids;

View File

@ -4,10 +4,11 @@ use std::str::FromStr;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::{AscDesc, Member};
use crate::boost::{Boost, BoostError};
use crate::{AscDesc, AscDescError, Member};
#[derive(Error, Debug)]
pub enum CriterionError {
pub enum RankingRuleError {
#[error("`{name}` ranking rule is invalid. Valid ranking rules are words, typo, sort, proximity, attribute, exactness and custom ranking rules.")]
InvalidName { name: String },
#[error("`{name}` is a reserved keyword and thus can't be used as a ranking rule")]
@ -25,10 +26,12 @@ pub enum CriterionError {
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum Criterion {
pub enum RankingRule {
/// Sorted by decreasing number of matched query terms.
/// Query words at the front of an attribute are considered better than those at the back.
Words,
/// Sorted by documents matching the given filter and then documents not matching it.
Boost(String),
/// Sorted by increasing number of typos.
Typo,
/// Sorted by increasing distance between matched query terms.
@ -47,62 +50,76 @@ pub enum Criterion {
Desc(String),
}
impl Criterion {
impl RankingRule {
/// Returns the field name parameter of this ranking rule.
pub fn field_name(&self) -> Option<&str> {
match self {
Criterion::Asc(name) | Criterion::Desc(name) => Some(name),
RankingRule::Asc(name) | RankingRule::Desc(name) => Some(name),
_otherwise => None,
}
}
}
impl FromStr for Criterion {
type Err = CriterionError;
impl FromStr for RankingRule {
type Err = RankingRuleError;
fn from_str(text: &str) -> Result<Criterion, Self::Err> {
fn from_str(text: &str) -> Result<RankingRule, Self::Err> {
match text {
"words" => Ok(Criterion::Words),
"typo" => Ok(Criterion::Typo),
"proximity" => Ok(Criterion::Proximity),
"attribute" => Ok(Criterion::Attribute),
"sort" => Ok(Criterion::Sort),
"exactness" => Ok(Criterion::Exactness),
text => match AscDesc::from_str(text)? {
AscDesc::Asc(Member::Field(field)) => Ok(Criterion::Asc(field)),
AscDesc::Desc(Member::Field(field)) => Ok(Criterion::Desc(field)),
AscDesc::Asc(Member::Geo(_)) | AscDesc::Desc(Member::Geo(_)) => {
Err(CriterionError::ReservedNameForSort { name: "_geoPoint".to_string() })?
}
"words" => Ok(RankingRule::Words),
"typo" => Ok(RankingRule::Typo),
"proximity" => Ok(RankingRule::Proximity),
"attribute" => Ok(RankingRule::Attribute),
"sort" => Ok(RankingRule::Sort),
"exactness" => Ok(RankingRule::Exactness),
text => match (AscDesc::from_str(text), Boost::from_str(text)) {
(Ok(asc_desc), _) => match asc_desc {
AscDesc::Asc(Member::Field(field)) => Ok(RankingRule::Asc(field)),
AscDesc::Desc(Member::Field(field)) => Ok(RankingRule::Desc(field)),
AscDesc::Asc(Member::Geo(_)) | AscDesc::Desc(Member::Geo(_)) => {
Err(RankingRuleError::ReservedNameForSort {
name: "_geoPoint".to_string(),
})?
}
},
(_, Ok(Boost(filter))) => Ok(RankingRule::Boost(filter)),
(
Err(AscDescError::InvalidSyntax { name: asc_desc_name }),
Err(BoostError::InvalidSyntax { name: boost_name }),
) => Err(RankingRuleError::InvalidName {
// TODO improve the error message quality
name: format!("{asc_desc_name} {boost_name}"),
}),
(Err(asc_desc_error), _) => Err(asc_desc_error.into()),
},
}
}
}
pub fn default_criteria() -> Vec<Criterion> {
pub fn default_criteria() -> Vec<RankingRule> {
vec![
Criterion::Words,
Criterion::Typo,
Criterion::Proximity,
Criterion::Attribute,
Criterion::Sort,
Criterion::Exactness,
RankingRule::Words,
RankingRule::Typo,
RankingRule::Proximity,
RankingRule::Attribute,
RankingRule::Sort,
RankingRule::Exactness,
]
}
impl fmt::Display for Criterion {
impl fmt::Display for RankingRule {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Criterion::*;
use RankingRule::*;
match self {
Words => f.write_str("words"),
Boost(filter) => write!(f, "boost:{filter}"),
Typo => f.write_str("typo"),
Proximity => f.write_str("proximity"),
Attribute => f.write_str("attribute"),
Sort => f.write_str("sort"),
Exactness => f.write_str("exactness"),
Asc(attr) => write!(f, "{}:asc", attr),
Desc(attr) => write!(f, "{}:desc", attr),
Asc(attr) => write!(f, "{attr}:asc"),
Desc(attr) => write!(f, "{attr}:desc"),
}
}
}
@ -110,29 +127,29 @@ impl fmt::Display for Criterion {
#[cfg(test)]
mod tests {
use big_s::S;
use CriterionError::*;
use RankingRuleError::*;
use super::*;
#[test]
fn parse_criterion() {
let valid_criteria = [
("words", Criterion::Words),
("typo", Criterion::Typo),
("proximity", Criterion::Proximity),
("attribute", Criterion::Attribute),
("sort", Criterion::Sort),
("exactness", Criterion::Exactness),
("price:asc", Criterion::Asc(S("price"))),
("price:desc", Criterion::Desc(S("price"))),
("price:asc:desc", Criterion::Desc(S("price:asc"))),
("truc:machin:desc", Criterion::Desc(S("truc:machin"))),
("hello-world!:desc", Criterion::Desc(S("hello-world!"))),
("it's spacy over there:asc", Criterion::Asc(S("it's spacy over there"))),
("words", RankingRule::Words),
("typo", RankingRule::Typo),
("proximity", RankingRule::Proximity),
("attribute", RankingRule::Attribute),
("sort", RankingRule::Sort),
("exactness", RankingRule::Exactness),
("price:asc", RankingRule::Asc(S("price"))),
("price:desc", RankingRule::Desc(S("price"))),
("price:asc:desc", RankingRule::Desc(S("price:asc"))),
("truc:machin:desc", RankingRule::Desc(S("truc:machin"))),
("hello-world!:desc", RankingRule::Desc(S("hello-world!"))),
("it's spacy over there:asc", RankingRule::Asc(S("it's spacy over there"))),
];
for (input, expected) in valid_criteria {
let res = input.parse::<Criterion>();
let res = input.parse::<RankingRule>();
assert!(
res.is_ok(),
"Failed to parse `{}`, was expecting `{:?}` but instead got `{:?}`",
@ -167,7 +184,7 @@ mod tests {
];
for (input, expected) in invalid_criteria {
let res = input.parse::<Criterion>();
let res = input.parse::<RankingRule>();
assert!(
res.is_err(),
"Should no be able to parse `{}`, was expecting an error but instead got: `{:?}`",

View File

@ -5,6 +5,7 @@ use crate::distance_between_two_points;
#[derive(Debug, Clone, PartialEq)]
pub enum ScoreDetails {
Words(Words),
Boost(Boost),
Typo(Typo),
Proximity(Rank),
Fid(Rank),
@ -23,6 +24,7 @@ impl ScoreDetails {
pub fn rank(&self) -> Option<Rank> {
match self {
ScoreDetails::Words(details) => Some(details.rank()),
ScoreDetails::Boost(_) => None,
ScoreDetails::Typo(details) => Some(details.rank()),
ScoreDetails::Proximity(details) => Some(*details),
ScoreDetails::Fid(details) => Some(*details),
@ -60,6 +62,14 @@ impl ScoreDetails {
details_map.insert("words".into(), words_details);
order += 1;
}
ScoreDetails::Boost(Boost { filter, matching }) => {
let sort = format!("boost:{}", filter);
let sort_details = serde_json::json!({
"value": matching,
});
details_map.insert(sort, sort_details);
order += 1;
}
ScoreDetails::Typo(typo) => {
let typo_details = serde_json::json!({
"order": order,
@ -221,6 +231,12 @@ impl Words {
}
}
/// Score detail attached to a bucket produced by the boost ranking rule.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Boost {
/// Filter expression of the boost rule that produced the bucket.
pub filter: String,
/// `true` for the bucket of documents matching the filter, `false` for the other bucket.
pub matching: bool,
}
/// Structure that is super similar to [`Words`], but whose semantics is a bit distinct.
///
/// In exactness, the number of matching words can actually be 0 with a non-zero score,

View File

@ -0,0 +1,88 @@
use roaring::RoaringBitmap;
use super::logger::SearchLogger;
use super::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait, SearchContext};
use crate::score_details::{self, ScoreDetails};
use crate::{Filter, Result};
/// Ranking rule that splits its candidates in two buckets: the documents matching
/// a filter expression first, then the documents that do not match it.
pub struct Boost<Query> {
/// Filter expression of the rule; also used to build the rule id and the score details.
original_expression: String,
/// Query received when the iteration started; set and reset here but not read in this file.
original_query: Option<Query>,
/// Bucket of candidates matching the filter, `None` outside of an iteration or once returned.
matching: Option<RankingRuleOutput<Query>>,
/// Bucket of candidates not matching the filter, `None` outside of an iteration or once returned.
non_matching: Option<RankingRuleOutput<Query>>,
}
impl<Query> Boost<Query> {
    /// Creates a boost ranking rule for the given filter expression.
    ///
    /// The expression is only stored here: both buckets stay empty until an iteration starts.
    pub fn new(expression: String) -> Result<Self> {
        let rule = Self {
            matching: None,
            non_matching: None,
            original_query: None,
            original_expression: expression,
        };
        Ok(rule)
    }
}
/// Ranking rule behaviour of [`Boost`]: at most two buckets per iteration, the candidates
/// matching the filter expression first, then the remaining candidates.
impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Boost<Query> {
    /// Returns the rule identifier: `boost:` followed by the filter expression.
    fn id(&self) -> String {
        // TODO improve this
        let Self { original_expression, .. } = self;
        format!("boost:{original_expression}")
    }

    /// Evaluates the filter expression and prepares the matching and non-matching buckets
    /// out of `parent_candidates`.
    ///
    /// # Errors
    ///
    /// Returns an error when the expression cannot be parsed as a filter or when its
    /// evaluation fails.
    fn start_iteration(
        &mut self,
        ctx: &mut SearchContext<'ctx>,
        _logger: &mut dyn SearchLogger<Query>,
        parent_candidates: &RoaringBitmap,
        parent_query: &Query,
    ) -> Result<()> {
        // When the parser yields no filter, no document is considered as matching.
        let universe_matching = match Filter::from_str(&self.original_expression)? {
            Some(filter) => filter.evaluate(ctx.txn, ctx.index)?,
            None => RoaringBitmap::default(),
        };
        let matching = parent_candidates & universe_matching;
        let non_matching = parent_candidates - &matching;

        self.original_query = Some(parent_query.clone());

        self.matching = Some(RankingRuleOutput {
            query: parent_query.clone(),
            candidates: matching,
            score: ScoreDetails::Boost(score_details::Boost {
                filter: self.original_expression.clone(),
                matching: true,
            }),
        });

        self.non_matching = Some(RankingRuleOutput {
            query: parent_query.clone(),
            candidates: non_matching,
            score: ScoreDetails::Boost(score_details::Boost {
                filter: self.original_expression.clone(),
                matching: false,
            }),
        });

        Ok(())
    }

    /// Returns the matching bucket on the first call, the non-matching bucket on the
    /// second one and `None` afterwards.
    ///
    /// The returned candidates are always restricted to `universe`.
    fn next_bucket(
        &mut self,
        _ctx: &mut SearchContext<'ctx>,
        _logger: &mut dyn SearchLogger<Query>,
        universe: &RoaringBitmap,
    ) -> Result<Option<RankingRuleOutput<Query>>> {
        let mut bucket = self.matching.take().or_else(|| self.non_matching.take());
        if let Some(output) = bucket.as_mut() {
            // Both buckets were computed once in `start_iteration`, so they must be
            // restricted to the universe given for this call.
            // NOTE(review): presumably the universe can shrink between calls (e.g. when
            // documents get excluded by the caller) — confirm against the bucket sort driver.
            output.candidates &= universe;
        }
        Ok(bucket)
    }

    /// Drops the state kept for the current iteration.
    fn end_iteration(
        &mut self,
        _ctx: &mut SearchContext<'ctx>,
        _logger: &mut dyn SearchLogger<Query>,
    ) {
        self.original_query = None;
        self.matching = None;
        self.non_matching = None;
    }
}

View File

@ -14,6 +14,7 @@ mod ranking_rules;
mod resolve_query_graph;
mod small_bitmap;
mod boost;
mod exact_attribute;
mod sort;
@ -22,6 +23,7 @@ mod tests;
use std::collections::HashSet;
use boost::Boost;
use bucket_sort::{bucket_sort, BucketSortOutput};
use charabia::TokenizerBuilder;
use db_cache::DatabaseCache;
@ -208,6 +210,7 @@ fn get_ranking_rules_for_placeholder_search<'ctx>(
| crate::RankingRule::Attribute
| crate::RankingRule::Proximity
| crate::RankingRule::Exactness => continue,
crate::RankingRule::Boost(filter) => ranking_rules.push(Box::new(Boost::new(filter)?)),
crate::RankingRule::Sort => {
if sort {
continue;
@ -287,6 +290,9 @@ fn get_ranking_rules_for_query_graph_search<'ctx>(
ranking_rules.push(Box::new(Words::new(terms_matching_strategy)));
words = true;
}
crate::RankingRule::Boost(filter) => {
ranking_rules.push(Box::new(Boost::new(filter)?));
}
crate::RankingRule::Typo => {
if typo {
continue;
@ -332,6 +338,7 @@ fn get_ranking_rules_for_query_graph_search<'ctx>(
exactness = true;
}
crate::RankingRule::Asc(field_name) => {
// TODO Question: Why would it be invalid to sort price:asc, typo, price:desc?
if sorted_fields.contains(&field_name) {
continue;
}