diff --git a/src/database/mod.rs b/src/database/mod.rs index ddc7e128d..3e11b1b81 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -23,6 +23,8 @@ pub use self::serde::SerializerError; pub use self::schema::Schema; pub use self::index::Index; +pub type RankedMap = HashMap<(DocumentId, SchemaAttr), i64>; + const DATA_INDEX: &[u8] = b"data-index"; const DATA_RANKED_MAP: &[u8] = b"data-ranked-map"; const DATA_SCHEMA: &[u8] = b"data-schema"; @@ -65,9 +67,8 @@ where D: Deref Ok(index) } -fn retrieve_data_ranked_map(snapshot: &Snapshot) --> Result, Box> -where D: Deref +fn retrieve_data_ranked_map(snapshot: &Snapshot) -> Result> +where D: Deref, { match snapshot.get(DATA_RANKED_MAP)? { Some(vector) => Ok(bincode::deserialize(&*vector)?), @@ -94,9 +95,9 @@ fn merge_indexes(existing: Option<&[u8]>, operands: &mut MergeOperands) -> Vec, operands: &mut MergeOperands) -> Vec { - let mut ranked_map: Option> = None; + let mut ranked_map: Option = None; for bytes in existing.into_iter().chain(operands) { - let operand: HashMap<(DocumentId, SchemaAttr), i64> = bincode::deserialize(bytes).unwrap(); + let operand: RankedMap = bincode::deserialize(bytes).unwrap(); match ranked_map { Some(ref mut ranked_map) => ranked_map.extend(operand), None => { ranked_map.replace(operand); }, @@ -174,7 +175,6 @@ impl DatabaseIndex { let snapshot = Snapshot::new(db.clone()); let view = ArcCell::new(Arc::new(DatabaseView::new(snapshot)?)); - Ok(DatabaseIndex { db: db, view: view, diff --git a/src/database/update.rs b/src/database/update.rs index 0c165550d..5961b2ec8 100644 --- a/src/database/update.rs +++ b/src/database/update.rs @@ -16,8 +16,9 @@ use crate::tokenizer::TokenizerBuilder; use crate::data::{DocIds, DocIndexes}; use crate::database::schema::Schema; use crate::database::index::Index; -use crate::{DocumentId, DocIndex}; +use crate::database::RankedMap; use crate::database::{DATA_INDEX, DATA_RANKED_MAP}; +use crate::{DocumentId, DocIndex}; pub type Token = Vec; // TODO could be replaced by a SmallVec @@ -78,7 +79,7 @@ use UpdateType::{Updated, Deleted}; pub struct RawUpdateBuilder { documents_update: HashMap, - documents_ranked_fields: HashMap<(DocumentId, SchemaAttr), i64>, + documents_ranked_fields: RankedMap, indexed_words: BTreeMap>, batch: WriteBatch, } diff --git a/src/database/view.rs b/src/database/view.rs index 6f04ac4b1..e757b6021 100644 --- a/src/database/view.rs +++ b/src/database/view.rs @@ -1,4 +1,3 @@ -use hashbrown::HashMap; use std::error::Error; use std::path::Path; use std::ops::Deref; @@ -15,6 +14,7 @@ use crate::rank::{QueryBuilder, FilterFunc}; use crate::database::schema::SchemaAttr; use crate::database::schema::Schema; use crate::database::index::Index; +use crate::database::RankedMap; use crate::DocumentId; pub struct DatabaseView @@ -22,7 +22,7 @@ where D: Deref { snapshot: Snapshot, index: Index, - ranked_map: HashMap<(DocumentId, SchemaAttr), i64>, + ranked_map: RankedMap, schema: Schema, } @@ -44,7 +44,7 @@ where D: Deref &self.index } - pub fn ranked_map(&self) -> &HashMap<(DocumentId, SchemaAttr), i64> { + pub fn ranked_map(&self) -> &RankedMap { &self.ranked_map } diff --git a/src/rank/criterion/mod.rs b/src/rank/criterion/mod.rs index 07c6a37e1..fb2314b70 100644 --- a/src/rank/criterion/mod.rs +++ b/src/rank/criterion/mod.rs @@ -4,7 +4,7 @@ mod words_proximity; mod sum_of_words_attribute; mod sum_of_words_position; mod exact; -mod sort_by; +mod sort_by_attr; mod document_id; use std::cmp::Ordering; @@ -17,7 +17,7 @@ pub use self::{ sum_of_words_attribute::SumOfWordsAttribute, sum_of_words_position::SumOfWordsPosition, exact::Exact, - sort_by::SortBy, + sort_by_attr::SortByAttr, document_id::DocumentId, }; diff --git a/src/rank/criterion/sort_by.rs b/src/rank/criterion/sort_by.rs deleted file mode 100644 index d1c7abf8c..000000000 --- a/src/rank/criterion/sort_by.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::cmp::Ordering; -use std::ops::Deref; -use std::marker; - -use rocksdb::DB; -use serde::de::DeserializeOwned; - -use crate::rank::criterion::Criterion; -use crate::database::DatabaseView; -use crate::rank::RawDocument; - -/// An helper struct that permit to sort documents by -/// some of their stored attributes. -/// -/// # Note -/// -/// If a document cannot be deserialized it will be considered [`None`][]. -/// -/// Deserialized documents are compared like `Some(doc0).cmp(&Some(doc1))`, -/// so you must check the [`Ord`] of `Option` implementation. -/// -/// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None -/// [`Ord`]: https://doc.rust-lang.org/std/option/enum.Option.html#impl-Ord -/// -/// # Example -/// -/// ```ignore -/// use serde_derive::Deserialize; -/// use meilidb::rank::criterion::*; -/// -/// #[derive(Deserialize, PartialOrd, Ord, PartialEq, Eq)] -/// struct TimeOnly { -/// time: String, -/// } -/// -/// let builder = CriteriaBuilder::with_capacity(8) -/// .add(SumOfTypos) -/// .add(NumberOfWords) -/// .add(WordsProximity) -/// .add(SumOfWordsAttribute) -/// .add(SumOfWordsPosition) -/// .add(Exact) -/// .add(SortBy::::new(&view)) -/// .add(DocumentId); -/// -/// let criterion = builder.build(); -/// -/// ``` -pub struct SortBy<'a, T, D> -where D: Deref + Send + Sync, - T: Send + Sync -{ - view: &'a DatabaseView, - _phantom: marker::PhantomData, -} - -impl<'a, T, D> SortBy<'a, T, D> -where D: Deref + Send + Sync, - T: Send + Sync -{ - pub fn new(view: &'a DatabaseView) -> Self { - SortBy { view, _phantom: marker::PhantomData } - } -} - -impl<'a, T, D> Criterion for SortBy<'a, T, D> -where D: Deref + Send + Sync, - T: DeserializeOwned + Ord + Send + Sync, -{ - fn evaluate(&self, lhs: &RawDocument, rhs: &RawDocument) -> Ordering { - let lhs = match self.view.document_by_id::(lhs.id) { - Ok(doc) => Some(doc), - Err(e) => { eprintln!("{}", e); None }, - }; - - let rhs = match self.view.document_by_id::(rhs.id) { - Ok(doc) => Some(doc), - Err(e) => { eprintln!("{}", e); None }, - }; - - lhs.cmp(&rhs) - } -} diff --git a/src/rank/criterion/sort_by_attr.rs b/src/rank/criterion/sort_by_attr.rs new file mode 100644 index 000000000..15bd09ee5 --- /dev/null +++ b/src/rank/criterion/sort_by_attr.rs @@ -0,0 +1,122 @@ +use std::cmp::Ordering; +use std::error::Error; +use std::fmt; + +use crate::database::schema::{Schema, SchemaAttr}; +use crate::rank::criterion::Criterion; +use crate::database::RankedMap; +use crate::rank::RawDocument; + +/// An helper struct that permit to sort documents by +/// some of their stored attributes. +/// +/// # Note +/// +/// If a document cannot be deserialized it will be considered [`None`][]. +/// +/// Deserialized documents are compared like `Some(doc0).cmp(&Some(doc1))`, +/// so you must check the [`Ord`] of `Option` implementation. +/// +/// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None +/// [`Ord`]: https://doc.rust-lang.org/std/option/enum.Option.html#impl-Ord +/// +/// # Example +/// +/// ```ignore +/// use serde_derive::Deserialize; +/// use meilidb::rank::criterion::*; +/// +/// let custom_ranking = SortByAttr::lower_is_better(&ranked_map, &schema, "published_at")?; +/// +/// let builder = CriteriaBuilder::with_capacity(8) +/// .add(SumOfTypos) +/// .add(NumberOfWords) +/// .add(WordsProximity) +/// .add(SumOfWordsAttribute) +/// .add(SumOfWordsPosition) +/// .add(Exact) +/// .add(custom_ranking) +/// .add(DocumentId); +/// +/// let criterion = builder.build(); +/// +/// ``` +pub struct SortByAttr<'a> { + ranked_map: &'a RankedMap, + attr: SchemaAttr, + reversed: bool, +} + +impl<'a> SortByAttr<'a> { + pub fn lower_is_better( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + ) -> Result, SortByAttrError> + { + SortByAttr::new(ranked_map, schema, attr_name, false) + } + + pub fn higher_is_better( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + ) -> Result, SortByAttrError> + { + SortByAttr::new(ranked_map, schema, attr_name, true) + } + + fn new( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + reversed: bool, + ) -> Result, SortByAttrError> + { + let attr = match schema.attribute(attr_name) { + Some(attr) => attr, + None => return Err(SortByAttrError::AttributeNotFound), + }; + + if schema.props(attr).is_ranked() { + return Err(SortByAttrError::AttributeNotRegisteredForRanking); + } + + Ok(SortByAttr { ranked_map, attr, reversed }) + } +} + +impl<'a> Criterion for SortByAttr<'a> { + fn evaluate(&self, lhs: &RawDocument, rhs: &RawDocument) -> Ordering { + let lhs = self.ranked_map.get(&(lhs.id, self.attr)); + let rhs = self.ranked_map.get(&(rhs.id, self.attr)); + + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => { + let order = lhs.cmp(&rhs); + if self.reversed { order.reverse() } else { order } + }, + (None, Some(_)) => Ordering::Greater, + (Some(_), None) => Ordering::Less, + (None, None) => Ordering::Equal, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SortByAttrError { + AttributeNotFound, + AttributeNotRegisteredForRanking, +} + +impl fmt::Display for SortByAttrError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use SortByAttrError::*; + match self { + AttributeNotFound => f.write_str("attribute not found in the schema"), + AttributeNotRegisteredForRanking => f.write_str("attribute not registered for ranking"), + } + } +} + +impl Error for SortByAttrError { }