diff --git a/examples/serve-console.rs b/examples/serve-console.rs index 21def0676..9eff1fbac 100644 --- a/examples/serve-console.rs +++ b/examples/serve-console.rs @@ -3,11 +3,10 @@ use std::io::{self, Write}; use structopt::StructOpt; use std::path::PathBuf; -use fst::Streamer; use elapsed::measure_time; use rocksdb::{DB, DBOptions, IngestExternalFileOptions}; use raptor::{automaton, Metadata, CommonWords}; -use raptor::rank; +use raptor::rank::{criterion, RankedStreamBuilder}; #[derive(Debug, StructOpt)] pub struct CommandConsole { @@ -70,15 +69,13 @@ fn search(metadata: &Metadata, database: &DB, common_words: &CommonWords, query: automatons.push(lev); } - let config = rank::Config { - criteria: rank::criterion::default(), - metadata: &metadata, - automatons: automatons, - limit: 20, - }; + let mut builder = RankedStreamBuilder::new(metadata, automatons); + builder.criteria(criterion::default()); - let mut stream = rank::RankedStream::new(config); - while let Some(document) = stream.next() { + let mut stream = builder.build(); + let documents = stream.retrieve_documents(20); + + for document in documents { let id_key = format!("{}-id", document.id); let id = database.get(id_key.as_bytes()).unwrap().unwrap(); let id = unsafe { from_utf8_unchecked(&id) }; diff --git a/examples/serve-http.rs b/examples/serve-http.rs index 4581f512b..f059973f3 100644 --- a/examples/serve-http.rs +++ b/examples/serve-http.rs @@ -7,10 +7,9 @@ use std::path::PathBuf; use std::error::Error; use std::sync::Arc; -use raptor::rank; +use raptor::rank::{criterion, RankedStreamBuilder}; use raptor::{automaton, Metadata, CommonWords}; use rocksdb::{DB, DBOptions, IngestExternalFileOptions}; -use fst::Streamer; use warp::Filter; use structopt::StructOpt; @@ -100,19 +99,17 @@ where M: AsRef, automatons.push(lev); } - let config = rank::Config { - criteria: rank::criterion::default(), - metadata: metadata.as_ref(), - automatons: automatons, - limit: 20, - }; + let mut builder = RankedStreamBuilder::new(metadata.as_ref(), automatons); + builder.criteria(criterion::default()); + + let mut stream = builder.build(); + let documents = stream.retrieve_documents(20); - let mut stream = rank::RankedStream::new(config); let mut body = Vec::new(); write!(&mut body, "[")?; let mut first = true; - while let Some(document) = stream.next() { + for document in documents { let title_key = format!("{}-title", document.id); let title = database.as_ref().get(title_key.as_bytes()).unwrap().unwrap(); let title = unsafe { from_utf8_unchecked(&title) }; diff --git a/src/rank/criterion/exact.rs b/src/rank/criterion/exact.rs index 7c3189fbc..2cdb9c0dd 100644 --- a/src/rank/criterion/exact.rs +++ b/src/rank/criterion/exact.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use group_by::GroupBy; use crate::Match; use crate::rank::{match_query_index, Document}; +use crate::rank::criterion::Criterion; #[inline] fn contains_exact(matches: &[Match]) -> bool { @@ -13,10 +14,14 @@ fn number_exact_matches(matches: &[Match]) -> usize { GroupBy::new(matches, match_query_index).map(contains_exact).count() } -#[inline] -pub fn exact(lhs: &Document, rhs: &Document) -> Ordering { - let lhs = number_exact_matches(&lhs.matches); - let rhs = number_exact_matches(&rhs.matches); +#[derive(Debug, Clone, Copy)] +pub struct Exact; - lhs.cmp(&rhs).reverse() +impl Criterion for Exact { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + let lhs = number_exact_matches(&lhs.matches); + let rhs = number_exact_matches(&rhs.matches); + + lhs.cmp(&rhs).reverse() + } } diff --git a/src/rank/criterion/mod.rs b/src/rank/criterion/mod.rs index 31b188d63..907b5cf48 100644 --- a/src/rank/criterion/mod.rs +++ b/src/rank/criterion/mod.rs @@ -7,65 +7,64 @@ mod exact; use std::vec; use std::cmp::Ordering; +use std::ops::Deref; use crate::rank::Document; pub use self::{ - sum_of_typos::sum_of_typos, - number_of_words::number_of_words, - words_proximity::words_proximity, - sum_of_words_attribute::sum_of_words_attribute, - sum_of_words_position::sum_of_words_position, - exact::exact, + sum_of_typos::SumOfTypos, + number_of_words::NumberOfWords, + words_proximity::WordsProximity, + sum_of_words_attribute::SumOfWordsAttribute, + sum_of_words_position::SumOfWordsPosition, + exact::Exact, }; -#[inline] -pub fn document_id(lhs: &Document, rhs: &Document) -> Ordering { - lhs.id.cmp(&rhs.id) -} +pub trait Criterion { + #[inline] + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering; -#[derive(Debug)] -pub struct Criteria(Vec); - -impl Criteria { - pub fn new() -> Self { - Criteria(Vec::new()) - } - - pub fn with_capacity(cap: usize) -> Self { - Criteria(Vec::with_capacity(cap)) - } - - pub fn push(&mut self, criterion: F) { - self.0.push(criterion) - } - - pub fn add(mut self, criterion: F) -> Self { - self.push(criterion); - self + #[inline] + fn eq(&self, lhs: &Document, rhs: &Document) -> bool { + self.evaluate(lhs, rhs) == Ordering::Equal } } -impl IntoIterator for Criteria { - type Item = F; - type IntoIter = vec::IntoIter; +impl<'a, T: Criterion + ?Sized> Criterion for &'a T { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + self.deref().evaluate(lhs, rhs) + } - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() + fn eq(&self, lhs: &Document, rhs: &Document) -> bool { + self.deref().eq(lhs, rhs) } } -pub fn default() -> Criteria Ordering + Copy> { - let functions = &[ - sum_of_typos, - number_of_words, - words_proximity, - sum_of_words_attribute, - sum_of_words_position, - exact, - document_id, - ]; +impl Criterion for Box { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + self.deref().evaluate(lhs, rhs) + } - let mut criteria = Criteria::with_capacity(functions.len()); - for f in functions { criteria.push(f) } - criteria + fn eq(&self, lhs: &Document, rhs: &Document) -> bool { + self.deref().eq(lhs, rhs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DocumentId; + +impl Criterion for DocumentId { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + lhs.id.cmp(&rhs.id) + } +} + +pub fn default() -> Vec> { + vec![ + Box::new(SumOfTypos), + Box::new(NumberOfWords), + Box::new(WordsProximity), + Box::new(SumOfWordsAttribute), + Box::new(SumOfWordsPosition), + Box::new(Exact), + ] } diff --git a/src/rank/criterion/number_of_words.rs b/src/rank/criterion/number_of_words.rs index e64b3f5e0..902e49fc0 100644 --- a/src/rank/criterion/number_of_words.rs +++ b/src/rank/criterion/number_of_words.rs @@ -2,16 +2,21 @@ use std::cmp::Ordering; use group_by::GroupBy; use crate::Match; use crate::rank::{match_query_index, Document}; +use crate::rank::criterion::Criterion; #[inline] fn number_of_query_words(matches: &[Match]) -> usize { GroupBy::new(matches, match_query_index).count() } -#[inline] -pub fn number_of_words(lhs: &Document, rhs: &Document) -> Ordering { - let lhs = number_of_query_words(&lhs.matches); - let rhs = number_of_query_words(&rhs.matches); +#[derive(Debug, Clone, Copy)] +pub struct NumberOfWords; - lhs.cmp(&rhs).reverse() +impl Criterion for NumberOfWords { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + let lhs = number_of_query_words(&lhs.matches); + let rhs = number_of_query_words(&rhs.matches); + + lhs.cmp(&rhs).reverse() + } } diff --git a/src/rank/criterion/sum_of_typos.rs b/src/rank/criterion/sum_of_typos.rs index ab90595c4..911a6c67f 100644 --- a/src/rank/criterion/sum_of_typos.rs +++ b/src/rank/criterion/sum_of_typos.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use group_by::GroupBy; use crate::Match; use crate::rank::{match_query_index, Document}; +use crate::rank::criterion::Criterion; #[inline] fn sum_matches_typos(matches: &[Match]) -> i8 { @@ -18,14 +19,19 @@ fn sum_matches_typos(matches: &[Match]) -> i8 { sum_typos - number_words } -#[inline] -pub fn sum_of_typos(lhs: &Document, rhs: &Document) -> Ordering { - let lhs = sum_matches_typos(&lhs.matches); - let rhs = sum_matches_typos(&rhs.matches); +#[derive(Debug, Clone, Copy)] +pub struct SumOfTypos; - lhs.cmp(&rhs) +impl Criterion for SumOfTypos { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + let lhs = sum_matches_typos(&lhs.matches); + let rhs = sum_matches_typos(&rhs.matches); + + lhs.cmp(&rhs) + } } + #[cfg(test)] mod tests { use super::*; @@ -42,7 +48,7 @@ mod tests { Match { query_index: 1, distance: 0, attribute: 0, attribute_index: 2, is_exact: false }, ]; Document { - document_id: 0, + id: 0, matches: matches, } }; @@ -53,12 +59,12 @@ mod tests { Match { query_index: 1, distance: 0, attribute: 0, attribute_index: 2, is_exact: false }, ]; Document { - document_id: 1, + id: 1, matches: matches, } }; - assert_eq!(sum_of_typos(&doc0, &doc1), Ordering::Less); + assert_eq!(SumOfTypos.evaluate(&doc0, &doc1), Ordering::Less); } // typing: "bouton manchette" @@ -73,7 +79,7 @@ mod tests { Match { query_index: 1, distance: 0, attribute: 0, attribute_index: 1, is_exact: false }, ]; Document { - document_id: 0, + id: 0, matches: matches, } }; @@ -83,12 +89,12 @@ mod tests { Match { query_index: 0, distance: 0, attribute: 0, attribute_index: 0, is_exact: false }, ]; Document { - document_id: 1, + id: 1, matches: matches, } }; - assert_eq!(sum_of_typos(&doc0, &doc1), Ordering::Less); + assert_eq!(SumOfTypos.evaluate(&doc0, &doc1), Ordering::Less); } // typing: "bouton manchztte" @@ -103,7 +109,7 @@ mod tests { Match { query_index: 1, distance: 1, attribute: 0, attribute_index: 1, is_exact: false }, ]; Document { - document_id: 0, + id: 0, matches: matches, } }; @@ -113,11 +119,11 @@ mod tests { Match { query_index: 0, distance: 0, attribute: 0, attribute_index: 0, is_exact: false }, ]; Document { - document_id: 1, + id: 1, matches: matches, } }; - assert_eq!(sum_of_typos(&doc0, &doc1), Ordering::Equal); + assert_eq!(SumOfTypos.evaluate(&doc0, &doc1), Ordering::Equal); } } diff --git a/src/rank/criterion/sum_of_words_attribute.rs b/src/rank/criterion/sum_of_words_attribute.rs index 3666df3f2..95629e2b5 100644 --- a/src/rank/criterion/sum_of_words_attribute.rs +++ b/src/rank/criterion/sum_of_words_attribute.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use group_by::GroupBy; use crate::Match; use crate::rank::{match_query_index, Document}; +use crate::rank::criterion::Criterion; #[inline] fn sum_matches_attributes(matches: &[Match]) -> u8 { @@ -12,10 +13,14 @@ fn sum_matches_attributes(matches: &[Match]) -> u8 { }).sum() } -#[inline] -pub fn sum_of_words_attribute(lhs: &Document, rhs: &Document) -> Ordering { - let lhs = sum_matches_attributes(&lhs.matches); - let rhs = sum_matches_attributes(&rhs.matches); +#[derive(Debug, Clone, Copy)] +pub struct SumOfWordsAttribute; - lhs.cmp(&rhs) +impl Criterion for SumOfWordsAttribute { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + let lhs = sum_matches_attributes(&lhs.matches); + let rhs = sum_matches_attributes(&rhs.matches); + + lhs.cmp(&rhs) + } } diff --git a/src/rank/criterion/sum_of_words_position.rs b/src/rank/criterion/sum_of_words_position.rs index ccf075b8a..5a230fed2 100644 --- a/src/rank/criterion/sum_of_words_position.rs +++ b/src/rank/criterion/sum_of_words_position.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use group_by::GroupBy; use crate::Match; use crate::rank::{match_query_index, Document}; +use crate::rank::criterion::Criterion; #[inline] fn sum_matches_attribute_index(matches: &[Match]) -> u32 { @@ -12,10 +13,14 @@ fn sum_matches_attribute_index(matches: &[Match]) -> u32 { }).sum() } -#[inline] -pub fn sum_of_words_position(lhs: &Document, rhs: &Document) -> Ordering { - let lhs = sum_matches_attribute_index(&lhs.matches); - let rhs = sum_matches_attribute_index(&rhs.matches); +#[derive(Debug, Clone, Copy)] +pub struct SumOfWordsPosition; - lhs.cmp(&rhs) +impl Criterion for SumOfWordsPosition { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + let lhs = sum_matches_attribute_index(&lhs.matches); + let rhs = sum_matches_attribute_index(&rhs.matches); + + lhs.cmp(&rhs) + } } diff --git a/src/rank/criterion/words_proximity.rs b/src/rank/criterion/words_proximity.rs index 601897c3d..8c2344595 100644 --- a/src/rank/criterion/words_proximity.rs +++ b/src/rank/criterion/words_proximity.rs @@ -2,6 +2,7 @@ use std::cmp::{self, Ordering}; use group_by::GroupBy; use crate::Match; use crate::rank::{match_query_index, Document}; +use crate::rank::criterion::Criterion; const MAX_DISTANCE: u32 = 8; @@ -42,10 +43,19 @@ fn matches_proximity(matches: &[Match]) -> u32 { proximity } -pub fn words_proximity(lhs: &Document, rhs: &Document) -> Ordering { - matches_proximity(&lhs.matches).cmp(&matches_proximity(&rhs.matches)) +#[derive(Debug, Clone, Copy)] +pub struct WordsProximity; + +impl Criterion for WordsProximity { + fn evaluate(&self, lhs: &Document, rhs: &Document) -> Ordering { + let lhs = matches_proximity(&lhs.matches); + let rhs = matches_proximity(&rhs.matches); + + lhs.cmp(&rhs) + } } + #[cfg(test)] mod tests { use super::*; diff --git a/src/rank/mod.rs b/src/rank/mod.rs index 0fe544ae1..c846964d1 100644 --- a/src/rank/mod.rs +++ b/src/rank/mod.rs @@ -3,7 +3,7 @@ mod ranked_stream; use crate::{Match, DocumentId}; -pub use self::ranked_stream::{RankedStream, Config}; +pub use self::ranked_stream::{RankedStreamBuilder, RankedStream}; #[inline] fn match_query_index(a: &Match, b: &Match) -> bool { @@ -18,10 +18,10 @@ pub struct Document { impl Document { pub fn new(doc: DocumentId, match_: Match) -> Self { - Self::from_sorted_matches(doc, vec![match_]) + unsafe { Self::from_sorted_matches(doc, vec![match_]) } } - pub fn from_sorted_matches(id: DocumentId, matches: Vec) -> Self { + pub unsafe fn from_sorted_matches(id: DocumentId, matches: Vec) -> Self { Self { id, matches } } } diff --git a/src/rank/ranked_stream.rs b/src/rank/ranked_stream.rs index e395eb4a9..3a658a6fc 100644 --- a/src/rank/ranked_stream.rs +++ b/src/rank/ranked_stream.rs @@ -1,4 +1,3 @@ -use std::cmp::Ordering; use std::rc::Rc; use std::{mem, vec}; @@ -8,134 +7,97 @@ use group_by::GroupByMut; use crate::automaton::{DfaExt, AutomatonExt}; use crate::metadata::Metadata; -use crate::metadata::ops::{OpBuilder, Union}; -use crate::rank::criterion::Criteria; +use crate::metadata::ops::OpBuilder; +use crate::rank::criterion::Criterion; use crate::rank::Document; -use crate::{Match, DocumentId}; +use crate::Match; -pub struct Config<'m, F> { - pub criteria: Criteria, - pub metadata: &'m Metadata, - pub automatons: Vec, - pub limit: usize, +#[derive(Clone)] +pub struct RankedStreamBuilder<'m, C> { + metadata: &'m Metadata, + automatons: Vec>, + criteria: Vec, } -pub struct RankedStream<'m, F>(RankedStreamInner<'m, F>); - -impl<'m, F> RankedStream<'m, F> { - pub fn new(config: Config<'m, F>) -> Self { - let automatons: Vec<_> = config.automatons.into_iter().map(Rc::new).collect(); - let mut builder = OpBuilder::with_automatons(automatons.clone()); - builder.push(config.metadata); - - let inner = RankedStreamInner::Fed { - inner: builder.union(), - automatons: automatons, - criteria: config.criteria, - limit: config.limit, - matches: FnvHashMap::default(), - }; - - RankedStream(inner) +impl<'m, C> RankedStreamBuilder<'m, C> { + pub fn new(metadata: &'m Metadata, automatons: Vec) -> Self { + RankedStreamBuilder { + metadata: metadata, + automatons: automatons.into_iter().map(Rc::new).collect(), + criteria: Vec::new(), // hummm... prefer the criterion::default() ones ! + } } -} -impl<'m, 'a, F> fst::Streamer<'a> for RankedStream<'m, F> -where F: Fn(&Document, &Document) -> Ordering + Copy, -{ - type Item = Document; - - fn next(&'a mut self) -> Option { - self.0.next() + pub fn criteria(&mut self, criteria: Vec) { + self.criteria = criteria; } -} -enum RankedStreamInner<'m, F> { - Fed { - inner: Union<'m>, - automatons: Vec>, - criteria: Criteria, - limit: usize, - matches: FnvHashMap>, - }, - Pours { - inner: vec::IntoIter, - }, -} + pub fn build(&self) -> RankedStream { + let mut builder = OpBuilder::with_automatons(self.automatons.clone()); + builder.push(self.metadata); -impl<'m, 'a, F> fst::Streamer<'a> for RankedStreamInner<'m, F> -where F: Fn(&Document, &Document) -> Ordering + Copy, -{ - type Item = Document; - - fn next(&'a mut self) -> Option { - loop { - match self { - RankedStreamInner::Fed { inner, automatons, criteria, limit, matches } => { - match inner.next() { - Some((string, indexed_values)) => { - for iv in indexed_values { - let automaton = &automatons[iv.index]; - let distance = automaton.eval(string).to_u8(); - let same_length = string.len() == automaton.query_len(); - - for di in iv.doc_indexes.as_slice() { - let match_ = Match { - query_index: iv.index as u32, - distance: distance, - attribute: di.attribute, - attribute_index: di.attribute_index, - is_exact: distance == 0 && same_length, - }; - matches.entry(di.document) - .or_insert_with(Vec::new) - .push(match_); - } - } - }, - None => { - let matches = mem::replace(matches, FnvHashMap::default()); - let criteria = mem::replace(criteria, Criteria::new()); - *self = RankedStreamInner::Pours { - inner: matches_into_iter(matches, criteria, *limit).into_iter() - }; - }, - } - }, - RankedStreamInner::Pours { inner } => { - return inner.next() - }, - } + RankedStream { + stream: builder.union(), + automatons: &self.automatons, + criteria: &self.criteria, } } } -fn matches_into_iter(matches: FnvHashMap>, - criteria: Criteria, - limit: usize) -> vec::IntoIter -where F: Fn(&Document, &Document) -> Ordering + Copy, -{ - let mut documents: Vec<_> = matches.into_iter().map(|(id, mut matches)| { - matches.sort_unstable(); - Document::from_sorted_matches(id, matches) - }).collect(); +pub struct RankedStream<'a, 'm, C> { + stream: crate::metadata::ops::Union<'m>, + automatons: &'a [Rc], + criteria: &'a [C], +} - let mut groups = vec![documents.as_mut_slice()]; +impl<'a, 'm, C> RankedStream<'a, 'm, C> { + pub fn retrieve_documents(&mut self, limit: usize) -> Vec + where C: Criterion + { + let mut matches = FnvHashMap::default(); - for sort in criteria { - let temp = mem::replace(&mut groups, Vec::new()); - let mut computed = 0; + while let Some((string, indexed_values)) = self.stream.next() { + for iv in indexed_values { + let automaton = &self.automatons[iv.index]; + let distance = automaton.eval(string).to_u8(); + let is_exact = distance == 0 && string.len() == automaton.query_len(); - 'grp: for group in temp { - group.sort_unstable_by(sort); - for group in GroupByMut::new(group, |a, b| sort(a, b) == Ordering::Equal) { - computed += group.len(); - groups.push(group); - if computed >= limit { break 'grp } + for di in iv.doc_indexes.as_slice() { + let match_ = Match { + query_index: iv.index as u32, + distance: distance, + attribute: di.attribute, + attribute_index: di.attribute_index, + is_exact: is_exact, + }; + matches.entry(di.document).or_insert_with(Vec::new).push(match_); + } } } - } - documents.truncate(limit); - documents.into_iter() + // collect matches from an HashMap into a Vec + let mut documents: Vec<_> = matches.into_iter().map(|(id, mut matches)| { + matches.sort_unstable(); + unsafe { Document::from_sorted_matches(id, matches) } + }).collect(); + + let mut groups = vec![documents.as_mut_slice()]; + + for criterion in self.criteria { + let temp = mem::replace(&mut groups, Vec::new()); + let mut computed = 0; + + 'grp: for group in temp { + group.sort_unstable_by(|a, b| criterion.evaluate(a, b)); + for group in GroupByMut::new(group, |a, b| criterion.eq(a, b)) { + computed += group.len(); + groups.push(group); + if computed >= limit { break 'grp } + } + } + } + + documents.truncate(limit); + documents + } }