diff --git a/meilidb-core/src/criterion/mod.rs b/meilidb-core/src/criterion/mod.rs index 6ce42007c..53a781e26 100644 --- a/meilidb-core/src/criterion/mod.rs +++ b/meilidb-core/src/criterion/mod.rs @@ -4,6 +4,7 @@ mod words_proximity; mod sum_of_words_attribute; mod sum_of_words_position; mod exact; +mod sort_by_attr; mod document_id; use std::cmp::Ordering; @@ -16,6 +17,7 @@ pub use self::{ sum_of_words_attribute::SumOfWordsAttribute, sum_of_words_position::SumOfWordsPosition, exact::Exact, + sort_by_attr::SortByAttr, document_id::DocumentId, }; diff --git a/meilidb-core/src/criterion/sort_by_attr.rs b/meilidb-core/src/criterion/sort_by_attr.rs new file mode 100644 index 000000000..68a5e5b69 --- /dev/null +++ b/meilidb-core/src/criterion/sort_by_attr.rs @@ -0,0 +1,125 @@ +use std::cmp::Ordering; +use std::error::Error; +use std::fmt; + +use meilidb_schema::{Schema, SchemaAttr}; +use crate::criterion::Criterion; +use crate::{RawDocument, RankedMap}; + +/// An helper struct that permit to sort documents by +/// some of their stored attributes. +/// +/// # Note +/// +/// If a document cannot be deserialized it will be considered [`None`][]. +/// +/// Deserialized documents are compared like `Some(doc0).cmp(&Some(doc1))`, +/// so you must check the [`Ord`] of `Option` implementation. +/// +/// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None +/// [`Ord`]: https://doc.rust-lang.org/std/option/enum.Option.html#impl-Ord +/// +/// # Example +/// +/// ```ignore +/// use serde_derive::Deserialize; +/// use meilidb::rank::criterion::*; +/// +/// let custom_ranking = SortByAttr::lower_is_better(&ranked_map, &schema, "published_at")?; +/// +/// let builder = CriteriaBuilder::with_capacity(8) +/// .add(SumOfTypos) +/// .add(NumberOfWords) +/// .add(WordsProximity) +/// .add(SumOfWordsAttribute) +/// .add(SumOfWordsPosition) +/// .add(Exact) +/// .add(custom_ranking) +/// .add(DocumentId); +/// +/// let criterion = builder.build(); +/// +/// ``` +pub struct SortByAttr<'a> { + ranked_map: &'a RankedMap, + attr: SchemaAttr, + reversed: bool, +} + +impl<'a> SortByAttr<'a> { + pub fn lower_is_better( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + ) -> Result, SortByAttrError> + { + SortByAttr::new(ranked_map, schema, attr_name, false) + } + + pub fn higher_is_better( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + ) -> Result, SortByAttrError> + { + SortByAttr::new(ranked_map, schema, attr_name, true) + } + + fn new( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + reversed: bool, + ) -> Result, SortByAttrError> + { + let attr = match schema.attribute(attr_name) { + Some(attr) => attr, + None => return Err(SortByAttrError::AttributeNotFound), + }; + + if !schema.props(attr).is_ranked() { + return Err(SortByAttrError::AttributeNotRegisteredForRanking); + } + + Ok(SortByAttr { ranked_map, attr, reversed }) + } +} + +impl<'a> Criterion for SortByAttr<'a> { + fn evaluate(&self, lhs: &RawDocument, rhs: &RawDocument) -> Ordering { + let lhs = self.ranked_map.get(lhs.id, self.attr); + let rhs = self.ranked_map.get(rhs.id, self.attr); + + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => { + let order = lhs.cmp(&rhs); + if self.reversed { order.reverse() } else { order } + }, + (None, Some(_)) => Ordering::Greater, + (Some(_), None) => Ordering::Less, + (None, None) => Ordering::Equal, + } + } + + fn name(&self) -> &'static str { + "SortByAttr" + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SortByAttrError { + AttributeNotFound, + AttributeNotRegisteredForRanking, +} + +impl fmt::Display for SortByAttrError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use SortByAttrError::*; + match self { + AttributeNotFound => f.write_str("attribute not found in the schema"), + AttributeNotRegisteredForRanking => f.write_str("attribute not registered for ranking"), + } + } +} + +impl Error for SortByAttrError { }