From 084c3a95b641418d5d4446f4bb21c65b13a16517 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Fri, 8 Feb 2019 15:17:42 +0100 Subject: [PATCH 1/3] feat: Add a new ranked attribute to the schema --- src/database/mod.rs | 52 +++++++-- src/database/schema.rs | 16 ++- src/database/serde/mod.rs | 1 + src/database/serde/serializer.rs | 11 +- src/database/serde/value_to_i64.rs | 169 +++++++++++++++++++++++++++++ src/database/update.rs | 32 +++++- src/database/view.rs | 16 ++- src/lib.rs | 3 + 8 files changed, 279 insertions(+), 21 deletions(-) create mode 100644 src/database/serde/value_to_i64.rs diff --git a/src/database/mod.rs b/src/database/mod.rs index 701a7e23b..ddc7e128d 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,3 +1,5 @@ +use crate::DocumentId; +use crate::database::schema::SchemaAttr; use std::sync::Arc; use std::error::Error; use std::ffi::OsStr; @@ -12,6 +14,7 @@ use rocksdb::rocksdb::{Writable, Snapshot}; use rocksdb::rocksdb_options::{DBOptions, ColumnFamilyOptions}; use rocksdb::{DB, MergeOperands}; use lockfree::map::Map; +use hashbrown::HashMap; pub use self::document_key::{DocumentKey, DocumentKeyAttr}; pub use self::view::{DatabaseView, DocumentIter}; @@ -20,8 +23,9 @@ pub use self::serde::SerializerError; pub use self::schema::Schema; pub use self::index::Index; -const DATA_INDEX: &[u8] = b"data-index"; -const DATA_SCHEMA: &[u8] = b"data-schema"; +const DATA_INDEX: &[u8] = b"data-index"; +const DATA_RANKED_MAP: &[u8] = b"data-ranked-map"; +const DATA_SCHEMA: &[u8] = b"data-schema"; pub mod schema; pub(crate) mod index; @@ -61,9 +65,17 @@ where D: Deref Ok(index) } -fn merge_indexes(key: &[u8], existing: Option<&[u8]>, operands: &mut MergeOperands) -> Vec { - assert_eq!(key, DATA_INDEX, "The merge operator only supports \"data-index\" merging"); +fn retrieve_data_ranked_map(snapshot: &Snapshot) +-> Result, Box> +where D: Deref +{ + match snapshot.get(DATA_RANKED_MAP)? { + Some(vector) => Ok(bincode::deserialize(&*vector)?), + None => Ok(HashMap::new()), + } +} +fn merge_indexes(existing: Option<&[u8]>, operands: &mut MergeOperands) -> Vec { let mut index: Option = None; for bytes in existing.into_iter().chain(operands) { let operand = Index::from_bytes(bytes.to_vec()).unwrap(); @@ -81,6 +93,28 @@ fn merge_indexes(key: &[u8], existing: Option<&[u8]>, operands: &mut MergeOperan bytes } +fn merge_ranked_maps(existing: Option<&[u8]>, operands: &mut MergeOperands) -> Vec { + let mut ranked_map: Option> = None; + for bytes in existing.into_iter().chain(operands) { + let operand: HashMap<(DocumentId, SchemaAttr), i64> = bincode::deserialize(bytes).unwrap(); + match ranked_map { + Some(ref mut ranked_map) => ranked_map.extend(operand), + None => { ranked_map.replace(operand); }, + }; + } + + let ranked_map = ranked_map.unwrap_or_default(); + bincode::serialize(&ranked_map).unwrap() +} + +fn merge_operator(key: &[u8], existing: Option<&[u8]>, operands: &mut MergeOperands) -> Vec { + match key { + DATA_INDEX => merge_indexes(existing, operands), + DATA_RANKED_MAP => merge_ranked_maps(existing, operands), + key => panic!("The merge operator does not support merging {:?}", key), + } +} + pub struct IndexUpdate { index: String, update: Update, @@ -103,14 +137,14 @@ impl DerefMut for IndexUpdate { struct DatabaseIndex { db: Arc, - // This view is updated each time the DB ingests an update + // This view is updated each time the DB ingests an update. view: ArcCell>>, - // This path is the path to the mdb folder stored on disk + // The path of the mdb folder stored on disk. path: PathBuf, // must_die false by default, must be set as true when the Index is dropped. - // It's used to erase the folder saved on disk when the user request to delete an index + // It is used to erase the folder saved on disk when the user request to delete an index. must_die: AtomicBool, } @@ -128,7 +162,7 @@ impl DatabaseIndex { // opts.error_if_exists(true); // FIXME pull request that let mut cf_opts = ColumnFamilyOptions::new(); - cf_opts.add_merge_operator("data-index merge operator", merge_indexes); + cf_opts.add_merge_operator("data merge operator", merge_operator); let db = DB::open_cf(opts, &path_lossy, vec![("default", cf_opts)])?; @@ -156,7 +190,7 @@ impl DatabaseIndex { opts.create_if_missing(false); let mut cf_opts = ColumnFamilyOptions::new(); - cf_opts.add_merge_operator("data-index merge operator", merge_indexes); + cf_opts.add_merge_operator("data merge operator", merge_operator); let db = DB::open_cf(opts, &path_lossy, vec![("default", cf_opts)])?; diff --git a/src/database/schema.rs b/src/database/schema.rs index 5b4b48731..3a8878ee3 100644 --- a/src/database/schema.rs +++ b/src/database/schema.rs @@ -13,8 +13,9 @@ use crate::database::serde::find_id::FindDocumentIdSerializer; use crate::database::serde::SerializerError; use crate::DocumentId; -pub const STORED: SchemaProps = SchemaProps { stored: true, indexed: false }; -pub const INDEXED: SchemaProps = SchemaProps { stored: false, indexed: true }; +pub const STORED: SchemaProps = SchemaProps { stored: true, indexed: false, ranked: false }; +pub const INDEXED: SchemaProps = SchemaProps { stored: false, indexed: true, ranked: false }; +pub const RANKED: SchemaProps = SchemaProps { stored: false, indexed: false, ranked: true }; #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct SchemaProps { @@ -23,6 +24,9 @@ pub struct SchemaProps { #[serde(default)] indexed: bool, + + #[serde(default)] + ranked: bool, } impl SchemaProps { @@ -33,6 +37,10 @@ impl SchemaProps { pub fn is_indexed(self) -> bool { self.indexed } + + pub fn is_ranked(self) -> bool { + self.ranked + } } impl BitOr for SchemaProps { @@ -42,6 +50,7 @@ impl BitOr for SchemaProps { SchemaProps { stored: self.stored | other.stored, indexed: self.indexed | other.indexed, + ranked: self.ranked | other.ranked, } } } @@ -185,7 +194,8 @@ impl Schema { } } -#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] +#[derive(Serialize, Deserialize)] +#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)] pub struct SchemaAttr(pub(crate) u16); impl SchemaAttr { diff --git a/src/database/serde/mod.rs b/src/database/serde/mod.rs index 2f9415c25..50a3c619e 100644 --- a/src/database/serde/mod.rs +++ b/src/database/serde/mod.rs @@ -17,6 +17,7 @@ macro_rules! forward_to_unserializable_type { pub mod find_id; pub mod key_to_string; +pub mod value_to_i64; pub mod serializer; pub mod indexer_serializer; pub mod deserializer; diff --git a/src/database/serde/serializer.rs b/src/database/serde/serializer.rs index d516be609..bc8b4d1ab 100644 --- a/src/database/serde/serializer.rs +++ b/src/database/serde/serializer.rs @@ -5,6 +5,7 @@ use serde::ser; use crate::database::serde::indexer_serializer::IndexerSerializer; use crate::database::serde::key_to_string::KeyToStringSerializer; +use crate::database::serde::value_to_i64::ValueToI64Serializer; use crate::database::update::DocumentUpdate; use crate::database::serde::SerializerError; use crate::tokenizer::TokenizerBuilder; @@ -155,8 +156,8 @@ where B: TokenizerBuilder { Ok(StructSerializer { schema: self.schema, - update: self.update, document_id: self.document_id, + update: self.update, tokenizer_builder: self.tokenizer_builder, stop_words: self.stop_words, }) @@ -229,6 +230,10 @@ where B: TokenizerBuilder }; value.serialize(serializer)?; } + if props.is_ranked() { + let integer = value.serialize(ValueToI64Serializer)?; + self.update.register_ranked_attribute(attr, integer)?; + } } Ok(()) @@ -276,6 +281,10 @@ where B: TokenizerBuilder }; value.serialize(serializer)?; } + if props.is_ranked() { + let integer = value.serialize(ValueToI64Serializer)?; + self.update.register_ranked_attribute(attr, integer)?; + } } Ok(()) diff --git a/src/database/serde/value_to_i64.rs b/src/database/serde/value_to_i64.rs new file mode 100644 index 000000000..9c046d391 --- /dev/null +++ b/src/database/serde/value_to_i64.rs @@ -0,0 +1,169 @@ +use serde::Serialize; +use serde::{ser, ser::Error}; + +use crate::database::serde::SerializerError; + +pub struct ValueToI64Serializer; + +impl ser::Serializer for ValueToI64Serializer { + type Ok = i64; + type Error = SerializerError; + type SerializeSeq = ser::Impossible; + type SerializeTuple = ser::Impossible; + type SerializeTupleStruct = ser::Impossible; + type SerializeTupleVariant = ser::Impossible; + type SerializeMap = ser::Impossible; + type SerializeStruct = ser::Impossible; + type SerializeStructVariant = ser::Impossible; + + forward_to_unserializable_type! { + bool => serialize_bool, + char => serialize_char, + + f32 => serialize_f32, + f64 => serialize_f64, + } + + fn serialize_i8(self, value: i8) -> Result { + Ok(i64::from(value)) + } + + fn serialize_i16(self, value: i16) -> Result { + Ok(i64::from(value)) + } + + fn serialize_i32(self, value: i32) -> Result { + Ok(i64::from(value)) + } + + fn serialize_i64(self, value: i64) -> Result { + Ok(i64::from(value)) + } + + fn serialize_u8(self, value: u8) -> Result { + Ok(i64::from(value)) + } + + fn serialize_u16(self, value: u16) -> Result { + Ok(i64::from(value)) + } + + fn serialize_u32(self, value: u32) -> Result { + Ok(i64::from(value)) + } + + fn serialize_u64(self, value: u64) -> Result { + // Ok(i64::from(value)) + unimplemented!() + } + + fn serialize_str(self, value: &str) -> Result { + i64::from_str_radix(value, 10).map_err(SerializerError::custom) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + Err(SerializerError::UnserializableType { name: "&[u8]" }) + } + + fn serialize_none(self) -> Result { + Err(SerializerError::UnserializableType { name: "Option" }) + } + + fn serialize_some(self, _value: &T) -> Result + where T: Serialize, + { + Err(SerializerError::UnserializableType { name: "Option" }) + } + + fn serialize_unit(self) -> Result { + Err(SerializerError::UnserializableType { name: "()" }) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(SerializerError::UnserializableType { name: "unit struct" }) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str + ) -> Result + { + Err(SerializerError::UnserializableType { name: "unit variant" }) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T + ) -> Result + where T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T + ) -> Result + where T: Serialize, + { + Err(SerializerError::UnserializableType { name: "newtype variant" }) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(SerializerError::UnserializableType { name: "sequence" }) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(SerializerError::UnserializableType { name: "tuple" }) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "tuple struct" }) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "tuple variant" }) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(SerializerError::UnserializableType { name: "map" }) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "struct" }) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "struct variant" }) + } +} diff --git a/src/database/update.rs b/src/database/update.rs index 616c070e5..0c165550d 100644 --- a/src/database/update.rs +++ b/src/database/update.rs @@ -17,7 +17,7 @@ use crate::data::{DocIds, DocIndexes}; use crate::database::schema::Schema; use crate::database::index::Index; use crate::{DocumentId, DocIndex}; -use crate::database::DATA_INDEX; +use crate::database::{DATA_INDEX, DATA_RANKED_MAP}; pub type Token = Vec; // TODO could be replaced by a SmallVec @@ -78,6 +78,7 @@ use UpdateType::{Updated, Deleted}; pub struct RawUpdateBuilder { documents_update: HashMap, + documents_ranked_fields: HashMap<(DocumentId, SchemaAttr), i64>, indexed_words: BTreeMap>, batch: WriteBatch, } @@ -86,6 +87,7 @@ impl RawUpdateBuilder { pub fn new() -> RawUpdateBuilder { RawUpdateBuilder { documents_update: HashMap::new(), + documents_ranked_fields: HashMap::new(), indexed_words: BTreeMap::new(), batch: WriteBatch::new(), } @@ -137,9 +139,12 @@ impl RawUpdateBuilder { let index = Index { negative, positive }; // write the data-index - let mut bytes = Vec::new(); - index.write_to_bytes(&mut bytes); - self.batch.merge(DATA_INDEX, &bytes)?; + let mut bytes_index = Vec::new(); + index.write_to_bytes(&mut bytes_index); + self.batch.merge(DATA_INDEX, &bytes_index)?; + + let bytes_ranked_map = bincode::serialize(&self.documents_ranked_fields).unwrap(); + self.batch.merge(DATA_RANKED_MAP, &bytes_ranked_map)?; Ok(self.batch) } @@ -195,4 +200,23 @@ impl<'a> DocumentUpdate<'a> { Ok(()) } + + pub fn register_ranked_attribute( + &mut self, + attr: SchemaAttr, + integer: i64, + ) -> Result<(), SerializerError> + { + use serde::ser::Error; + + if let Deleted = self.inner.documents_update.entry(self.document_id).or_insert(Updated) { + return Err(SerializerError::custom( + "This document has already been deleted, ranked attributes cannot be added in the same update" + )); + } + + self.inner.documents_ranked_fields.insert((self.document_id, attr), integer); + + Ok(()) + } } diff --git a/src/database/view.rs b/src/database/view.rs index b9144a281..6f04ac4b1 100644 --- a/src/database/view.rs +++ b/src/database/view.rs @@ -1,3 +1,4 @@ +use hashbrown::HashMap; use std::error::Error; use std::path::Path; use std::ops::Deref; @@ -7,12 +8,13 @@ use rocksdb::rocksdb_options::{ReadOptions, EnvOptions, ColumnFamilyOptions}; use rocksdb::rocksdb::{DB, DBVector, Snapshot, SeekKey, SstFileWriter}; use serde::de::DeserializeOwned; -use crate::database::{DocumentKey, DocumentKeyAttr}; -use crate::database::{retrieve_data_schema, retrieve_data_index}; +use crate::database::{retrieve_data_schema, retrieve_data_index, retrieve_data_ranked_map}; use crate::database::serde::deserializer::Deserializer; +use crate::database::{DocumentKey, DocumentKeyAttr}; +use crate::rank::{QueryBuilder, FilterFunc}; +use crate::database::schema::SchemaAttr; use crate::database::schema::Schema; use crate::database::index::Index; -use crate::rank::{QueryBuilder, FilterFunc}; use crate::DocumentId; pub struct DatabaseView @@ -20,6 +22,7 @@ where D: Deref { snapshot: Snapshot, index: Index, + ranked_map: HashMap<(DocumentId, SchemaAttr), i64>, schema: Schema, } @@ -29,7 +32,8 @@ where D: Deref pub fn new(snapshot: Snapshot) -> Result, Box> { let schema = retrieve_data_schema(&snapshot)?; let index = retrieve_data_index(&snapshot)?; - Ok(DatabaseView { snapshot, index, schema }) + let ranked_map = retrieve_data_ranked_map(&snapshot)?; + Ok(DatabaseView { snapshot, index, ranked_map, schema }) } pub fn schema(&self) -> &Schema { @@ -40,6 +44,10 @@ where D: Deref &self.index } + pub fn ranked_map(&self) -> &HashMap<(DocumentId, SchemaAttr), i64> { + &self.ranked_map + } + pub fn into_snapshot(self) -> Snapshot { self.snapshot } diff --git a/src/lib.rs b/src/lib.rs index bfa0b3cd9..9c0641090 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,8 @@ pub mod rank; pub mod tokenizer; mod common_words; +use serde_derive::{Serialize, Deserialize}; + pub use rocksdb; pub use self::tokenizer::Tokenizer; @@ -16,6 +18,7 @@ pub use self::common_words::CommonWords; /// /// It is used to inform the database the document you want to deserialize. /// Helpful for custom ranking. +#[derive(Serialize, Deserialize)] #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] pub struct DocumentId(u64); From 83cd071827db449cb1f62e620b030389fdee378f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Sat, 9 Feb 2019 13:49:18 +0100 Subject: [PATCH 2/3] feat: Introduce the SortByAttr custom ranking helper --- src/database/mod.rs | 12 +-- src/database/update.rs | 5 +- src/database/view.rs | 6 +- src/rank/criterion/mod.rs | 4 +- src/rank/criterion/sort_by.rs | 83 -------------------- src/rank/criterion/sort_by_attr.rs | 122 +++++++++++++++++++++++++++++ 6 files changed, 136 insertions(+), 96 deletions(-) delete mode 100644 src/rank/criterion/sort_by.rs create mode 100644 src/rank/criterion/sort_by_attr.rs diff --git a/src/database/mod.rs b/src/database/mod.rs index ddc7e128d..3e11b1b81 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -23,6 +23,8 @@ pub use self::serde::SerializerError; pub use self::schema::Schema; pub use self::index::Index; +pub type RankedMap = HashMap<(DocumentId, SchemaAttr), i64>; + const DATA_INDEX: &[u8] = b"data-index"; const DATA_RANKED_MAP: &[u8] = b"data-ranked-map"; const DATA_SCHEMA: &[u8] = b"data-schema"; @@ -65,9 +67,8 @@ where D: Deref Ok(index) } -fn retrieve_data_ranked_map(snapshot: &Snapshot) --> Result, Box> -where D: Deref +fn retrieve_data_ranked_map(snapshot: &Snapshot) -> Result> +where D: Deref, { match snapshot.get(DATA_RANKED_MAP)? { Some(vector) => Ok(bincode::deserialize(&*vector)?), @@ -94,9 +95,9 @@ fn merge_indexes(existing: Option<&[u8]>, operands: &mut MergeOperands) -> Vec, operands: &mut MergeOperands) -> Vec { - let mut ranked_map: Option> = None; + let mut ranked_map: Option = None; for bytes in existing.into_iter().chain(operands) { - let operand: HashMap<(DocumentId, SchemaAttr), i64> = bincode::deserialize(bytes).unwrap(); + let operand: RankedMap = bincode::deserialize(bytes).unwrap(); match ranked_map { Some(ref mut ranked_map) => ranked_map.extend(operand), None => { ranked_map.replace(operand); }, @@ -174,7 +175,6 @@ impl DatabaseIndex { let snapshot = Snapshot::new(db.clone()); let view = ArcCell::new(Arc::new(DatabaseView::new(snapshot)?)); - Ok(DatabaseIndex { db: db, view: view, diff --git a/src/database/update.rs b/src/database/update.rs index 0c165550d..5961b2ec8 100644 --- a/src/database/update.rs +++ b/src/database/update.rs @@ -16,8 +16,9 @@ use crate::tokenizer::TokenizerBuilder; use crate::data::{DocIds, DocIndexes}; use crate::database::schema::Schema; use crate::database::index::Index; -use crate::{DocumentId, DocIndex}; +use crate::database::RankedMap; use crate::database::{DATA_INDEX, DATA_RANKED_MAP}; +use crate::{DocumentId, DocIndex}; pub type Token = Vec; // TODO could be replaced by a SmallVec @@ -78,7 +79,7 @@ use UpdateType::{Updated, Deleted}; pub struct RawUpdateBuilder { documents_update: HashMap, - documents_ranked_fields: HashMap<(DocumentId, SchemaAttr), i64>, + documents_ranked_fields: RankedMap, indexed_words: BTreeMap>, batch: WriteBatch, } diff --git a/src/database/view.rs b/src/database/view.rs index 6f04ac4b1..e757b6021 100644 --- a/src/database/view.rs +++ b/src/database/view.rs @@ -1,4 +1,3 @@ -use hashbrown::HashMap; use std::error::Error; use std::path::Path; use std::ops::Deref; @@ -15,6 +14,7 @@ use crate::rank::{QueryBuilder, FilterFunc}; use crate::database::schema::SchemaAttr; use crate::database::schema::Schema; use crate::database::index::Index; +use crate::database::RankedMap; use crate::DocumentId; pub struct DatabaseView @@ -22,7 +22,7 @@ where D: Deref { snapshot: Snapshot, index: Index, - ranked_map: HashMap<(DocumentId, SchemaAttr), i64>, + ranked_map: RankedMap, schema: Schema, } @@ -44,7 +44,7 @@ where D: Deref &self.index } - pub fn ranked_map(&self) -> &HashMap<(DocumentId, SchemaAttr), i64> { + pub fn ranked_map(&self) -> &RankedMap { &self.ranked_map } diff --git a/src/rank/criterion/mod.rs b/src/rank/criterion/mod.rs index 07c6a37e1..fb2314b70 100644 --- a/src/rank/criterion/mod.rs +++ b/src/rank/criterion/mod.rs @@ -4,7 +4,7 @@ mod words_proximity; mod sum_of_words_attribute; mod sum_of_words_position; mod exact; -mod sort_by; +mod sort_by_attr; mod document_id; use std::cmp::Ordering; @@ -17,7 +17,7 @@ pub use self::{ sum_of_words_attribute::SumOfWordsAttribute, sum_of_words_position::SumOfWordsPosition, exact::Exact, - sort_by::SortBy, + sort_by_attr::SortByAttr, document_id::DocumentId, }; diff --git a/src/rank/criterion/sort_by.rs b/src/rank/criterion/sort_by.rs deleted file mode 100644 index d1c7abf8c..000000000 --- a/src/rank/criterion/sort_by.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::cmp::Ordering; -use std::ops::Deref; -use std::marker; - -use rocksdb::DB; -use serde::de::DeserializeOwned; - -use crate::rank::criterion::Criterion; -use crate::database::DatabaseView; -use crate::rank::RawDocument; - -/// An helper struct that permit to sort documents by -/// some of their stored attributes. -/// -/// # Note -/// -/// If a document cannot be deserialized it will be considered [`None`][]. -/// -/// Deserialized documents are compared like `Some(doc0).cmp(&Some(doc1))`, -/// so you must check the [`Ord`] of `Option` implementation. -/// -/// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None -/// [`Ord`]: https://doc.rust-lang.org/std/option/enum.Option.html#impl-Ord -/// -/// # Example -/// -/// ```ignore -/// use serde_derive::Deserialize; -/// use meilidb::rank::criterion::*; -/// -/// #[derive(Deserialize, PartialOrd, Ord, PartialEq, Eq)] -/// struct TimeOnly { -/// time: String, -/// } -/// -/// let builder = CriteriaBuilder::with_capacity(8) -/// .add(SumOfTypos) -/// .add(NumberOfWords) -/// .add(WordsProximity) -/// .add(SumOfWordsAttribute) -/// .add(SumOfWordsPosition) -/// .add(Exact) -/// .add(SortBy::::new(&view)) -/// .add(DocumentId); -/// -/// let criterion = builder.build(); -/// -/// ``` -pub struct SortBy<'a, T, D> -where D: Deref + Send + Sync, - T: Send + Sync -{ - view: &'a DatabaseView, - _phantom: marker::PhantomData, -} - -impl<'a, T, D> SortBy<'a, T, D> -where D: Deref + Send + Sync, - T: Send + Sync -{ - pub fn new(view: &'a DatabaseView) -> Self { - SortBy { view, _phantom: marker::PhantomData } - } -} - -impl<'a, T, D> Criterion for SortBy<'a, T, D> -where D: Deref + Send + Sync, - T: DeserializeOwned + Ord + Send + Sync, -{ - fn evaluate(&self, lhs: &RawDocument, rhs: &RawDocument) -> Ordering { - let lhs = match self.view.document_by_id::(lhs.id) { - Ok(doc) => Some(doc), - Err(e) => { eprintln!("{}", e); None }, - }; - - let rhs = match self.view.document_by_id::(rhs.id) { - Ok(doc) => Some(doc), - Err(e) => { eprintln!("{}", e); None }, - }; - - lhs.cmp(&rhs) - } -} diff --git a/src/rank/criterion/sort_by_attr.rs b/src/rank/criterion/sort_by_attr.rs new file mode 100644 index 000000000..15bd09ee5 --- /dev/null +++ b/src/rank/criterion/sort_by_attr.rs @@ -0,0 +1,122 @@ +use std::cmp::Ordering; +use std::error::Error; +use std::fmt; + +use crate::database::schema::{Schema, SchemaAttr}; +use crate::rank::criterion::Criterion; +use crate::database::RankedMap; +use crate::rank::RawDocument; + +/// An helper struct that permit to sort documents by +/// some of their stored attributes. +/// +/// # Note +/// +/// If a document cannot be deserialized it will be considered [`None`][]. +/// +/// Deserialized documents are compared like `Some(doc0).cmp(&Some(doc1))`, +/// so you must check the [`Ord`] of `Option` implementation. +/// +/// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None +/// [`Ord`]: https://doc.rust-lang.org/std/option/enum.Option.html#impl-Ord +/// +/// # Example +/// +/// ```ignore +/// use serde_derive::Deserialize; +/// use meilidb::rank::criterion::*; +/// +/// let custom_ranking = SortByAttr::lower_is_better(&ranked_map, &schema, "published_at")?; +/// +/// let builder = CriteriaBuilder::with_capacity(8) +/// .add(SumOfTypos) +/// .add(NumberOfWords) +/// .add(WordsProximity) +/// .add(SumOfWordsAttribute) +/// .add(SumOfWordsPosition) +/// .add(Exact) +/// .add(custom_ranking) +/// .add(DocumentId); +/// +/// let criterion = builder.build(); +/// +/// ``` +pub struct SortByAttr<'a> { + ranked_map: &'a RankedMap, + attr: SchemaAttr, + reversed: bool, +} + +impl<'a> SortByAttr<'a> { + pub fn lower_is_better( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + ) -> Result, SortByAttrError> + { + SortByAttr::new(ranked_map, schema, attr_name, false) + } + + pub fn higher_is_better( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + ) -> Result, SortByAttrError> + { + SortByAttr::new(ranked_map, schema, attr_name, true) + } + + fn new( + ranked_map: &'a RankedMap, + schema: &Schema, + attr_name: &str, + reversed: bool, + ) -> Result, SortByAttrError> + { + let attr = match schema.attribute(attr_name) { + Some(attr) => attr, + None => return Err(SortByAttrError::AttributeNotFound), + }; + + if schema.props(attr).is_ranked() { + return Err(SortByAttrError::AttributeNotRegisteredForRanking); + } + + Ok(SortByAttr { ranked_map, attr, reversed }) + } +} + +impl<'a> Criterion for SortByAttr<'a> { + fn evaluate(&self, lhs: &RawDocument, rhs: &RawDocument) -> Ordering { + let lhs = self.ranked_map.get(&(lhs.id, self.attr)); + let rhs = self.ranked_map.get(&(rhs.id, self.attr)); + + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => { + let order = lhs.cmp(&rhs); + if self.reversed { order.reverse() } else { order } + }, + (None, Some(_)) => Ordering::Greater, + (Some(_), None) => Ordering::Less, + (None, None) => Ordering::Equal, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SortByAttrError { + AttributeNotFound, + AttributeNotRegisteredForRanking, +} + +impl fmt::Display for SortByAttrError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use SortByAttrError::*; + match self { + AttributeNotFound => f.write_str("attribute not found in the schema"), + AttributeNotRegisteredForRanking => f.write_str("attribute not registered for ranking"), + } + } +} + +impl Error for SortByAttrError { } From db6210c7ee7b0dc9adf544ed7c53b41e235d1acf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Mon, 11 Feb 2019 16:58:44 +0100 Subject: [PATCH 3/3] feat: Introduce the Number type --- src/database/mod.rs | 10 +- src/database/number.rs | 98 +++++++++++++++++++ src/database/serde/mod.rs | 2 +- src/database/serde/serializer.rs | 8 +- .../{value_to_i64.rs => value_to_number.rs} | 39 +++++--- src/database/update.rs | 6 +- src/database/view.rs | 1 - 7 files changed, 135 insertions(+), 29 deletions(-) create mode 100644 src/database/number.rs rename src/database/serde/{value_to_i64.rs => value_to_number.rs} (85%) diff --git a/src/database/mod.rs b/src/database/mod.rs index 3e11b1b81..8097dd726 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -8,13 +8,13 @@ use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; use std::ops::{Deref, DerefMut}; -use crossbeam::atomic::ArcCell; -use log::{info, error, warn}; -use rocksdb::rocksdb::{Writable, Snapshot}; use rocksdb::rocksdb_options::{DBOptions, ColumnFamilyOptions}; +use rocksdb::rocksdb::{Writable, Snapshot}; use rocksdb::{DB, MergeOperands}; +use crossbeam::atomic::ArcCell; use lockfree::map::Map; use hashbrown::HashMap; +use log::{info, error, warn}; pub use self::document_key::{DocumentKey, DocumentKeyAttr}; pub use self::view::{DatabaseView, DocumentIter}; @@ -22,8 +22,9 @@ pub use self::update::Update; pub use self::serde::SerializerError; pub use self::schema::Schema; pub use self::index::Index; +pub use self::number::{Number, ParseNumberError}; -pub type RankedMap = HashMap<(DocumentId, SchemaAttr), i64>; +pub type RankedMap = HashMap<(DocumentId, SchemaAttr), Number>; const DATA_INDEX: &[u8] = b"data-index"; const DATA_RANKED_MAP: &[u8] = b"data-ranked-map"; @@ -31,6 +32,7 @@ const DATA_SCHEMA: &[u8] = b"data-schema"; pub mod schema; pub(crate) mod index; +mod number; mod document_key; mod serde; mod update; diff --git a/src/database/number.rs b/src/database/number.rs new file mode 100644 index 000000000..b2c4c9a88 --- /dev/null +++ b/src/database/number.rs @@ -0,0 +1,98 @@ +use std::cmp::Ordering; +use std::str::FromStr; +use std::fmt; + +use serde_derive::{Serialize, Deserialize}; + +#[derive(Serialize, Deserialize)] +#[derive(Debug, Copy, Clone)] +pub enum Number { + Unsigned(u64), + Signed(i64), + Float(f64), +} + +impl FromStr for Number { + type Err = ParseNumberError; + + fn from_str(s: &str) -> Result { + if let Ok(unsigned) = u64::from_str(s) { + return Ok(Number::Unsigned(unsigned)) + } + + if let Ok(signed) = i64::from_str(s) { + return Ok(Number::Signed(signed)) + } + + if let Ok(float) = f64::from_str(s) { + if float == 0.0 || float.is_normal() { + return Ok(Number::Float(float)) + } + } + + Err(ParseNumberError) + } +} + +impl PartialOrd for Number { + fn partial_cmp(&self, other: &Number) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Number { + fn cmp(&self, other: &Number) -> Ordering { + use Number::*; + match (self, other) { + (Unsigned(s), Unsigned(o)) => s.cmp(o), + (Unsigned(s), Signed(o)) => { + let s = i128::from(*s); + let o = i128::from(*o); + s.cmp(&o) + }, + (Unsigned(s), Float(o)) => { + let s = *s as f64; + s.partial_cmp(&o).unwrap_or(Ordering::Equal) + }, + + (Signed(s), Unsigned(o)) => { + let s = i128::from(*s); + let o = i128::from(*o); + s.cmp(&o) + }, + (Signed(s), Signed(o)) => s.cmp(o), + (Signed(s), Float(o)) => { + let s = *s as f64; + s.partial_cmp(o).unwrap_or(Ordering::Equal) + }, + + (Float(s), Unsigned(o)) => { + let o = *o as f64; + s.partial_cmp(&o).unwrap_or(Ordering::Equal) + }, + (Float(s), Signed(o)) => { + let o = *o as f64; + s.partial_cmp(&o).unwrap_or(Ordering::Equal) + }, + (Float(s), Float(o)) => { + s.partial_cmp(o).unwrap_or(Ordering::Equal) + }, + } + } +} + +impl PartialEq for Number { + fn eq(&self, other: &Number) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl Eq for Number { } + +pub struct ParseNumberError; + +impl fmt::Display for ParseNumberError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("can not parse number") + } +} diff --git a/src/database/serde/mod.rs b/src/database/serde/mod.rs index 50a3c619e..493124f7e 100644 --- a/src/database/serde/mod.rs +++ b/src/database/serde/mod.rs @@ -17,7 +17,7 @@ macro_rules! forward_to_unserializable_type { pub mod find_id; pub mod key_to_string; -pub mod value_to_i64; +pub mod value_to_number; pub mod serializer; pub mod indexer_serializer; pub mod deserializer; diff --git a/src/database/serde/serializer.rs b/src/database/serde/serializer.rs index bc8b4d1ab..2f41bb82c 100644 --- a/src/database/serde/serializer.rs +++ b/src/database/serde/serializer.rs @@ -5,7 +5,7 @@ use serde::ser; use crate::database::serde::indexer_serializer::IndexerSerializer; use crate::database::serde::key_to_string::KeyToStringSerializer; -use crate::database::serde::value_to_i64::ValueToI64Serializer; +use crate::database::serde::value_to_number::ValueToNumberSerializer; use crate::database::update::DocumentUpdate; use crate::database::serde::SerializerError; use crate::tokenizer::TokenizerBuilder; @@ -231,8 +231,8 @@ where B: TokenizerBuilder value.serialize(serializer)?; } if props.is_ranked() { - let integer = value.serialize(ValueToI64Serializer)?; - self.update.register_ranked_attribute(attr, integer)?; + let number = value.serialize(ValueToNumberSerializer)?; + self.update.register_ranked_attribute(attr, number)?; } } @@ -282,7 +282,7 @@ where B: TokenizerBuilder value.serialize(serializer)?; } if props.is_ranked() { - let integer = value.serialize(ValueToI64Serializer)?; + let integer = value.serialize(ValueToNumberSerializer)?; self.update.register_ranked_attribute(attr, integer)?; } } diff --git a/src/database/serde/value_to_i64.rs b/src/database/serde/value_to_number.rs similarity index 85% rename from src/database/serde/value_to_i64.rs rename to src/database/serde/value_to_number.rs index 9c046d391..a70b92fc4 100644 --- a/src/database/serde/value_to_i64.rs +++ b/src/database/serde/value_to_number.rs @@ -1,12 +1,15 @@ +use std::str::FromStr; + use serde::Serialize; use serde::{ser, ser::Error}; use crate::database::serde::SerializerError; +use crate::database::Number; -pub struct ValueToI64Serializer; +pub struct ValueToNumberSerializer; -impl ser::Serializer for ValueToI64Serializer { - type Ok = i64; +impl ser::Serializer for ValueToNumberSerializer { + type Ok = Number; type Error = SerializerError; type SerializeSeq = ser::Impossible; type SerializeTuple = ser::Impossible; @@ -19,46 +22,50 @@ impl ser::Serializer for ValueToI64Serializer { forward_to_unserializable_type! { bool => serialize_bool, char => serialize_char, - - f32 => serialize_f32, - f64 => serialize_f64, } fn serialize_i8(self, value: i8) -> Result { - Ok(i64::from(value)) + Ok(Number::Signed(value as i64)) } fn serialize_i16(self, value: i16) -> Result { - Ok(i64::from(value)) + Ok(Number::Signed(value as i64)) } fn serialize_i32(self, value: i32) -> Result { - Ok(i64::from(value)) + Ok(Number::Signed(value as i64)) } fn serialize_i64(self, value: i64) -> Result { - Ok(i64::from(value)) + Ok(Number::Signed(value as i64)) } fn serialize_u8(self, value: u8) -> Result { - Ok(i64::from(value)) + Ok(Number::Unsigned(value as u64)) } fn serialize_u16(self, value: u16) -> Result { - Ok(i64::from(value)) + Ok(Number::Unsigned(value as u64)) } fn serialize_u32(self, value: u32) -> Result { - Ok(i64::from(value)) + Ok(Number::Unsigned(value as u64)) } fn serialize_u64(self, value: u64) -> Result { - // Ok(i64::from(value)) - unimplemented!() + Ok(Number::Unsigned(value as u64)) + } + + fn serialize_f32(self, value: f32) -> Result { + Ok(Number::Float(value as f64)) + } + + fn serialize_f64(self, value: f64) -> Result { + Ok(Number::Float(value)) } fn serialize_str(self, value: &str) -> Result { - i64::from_str_radix(value, 10).map_err(SerializerError::custom) + Number::from_str(value).map_err(SerializerError::custom) } fn serialize_bytes(self, _v: &[u8]) -> Result { diff --git a/src/database/update.rs b/src/database/update.rs index 5961b2ec8..e37576e6d 100644 --- a/src/database/update.rs +++ b/src/database/update.rs @@ -16,8 +16,8 @@ use crate::tokenizer::TokenizerBuilder; use crate::data::{DocIds, DocIndexes}; use crate::database::schema::Schema; use crate::database::index::Index; -use crate::database::RankedMap; use crate::database::{DATA_INDEX, DATA_RANKED_MAP}; +use crate::database::{RankedMap, Number}; use crate::{DocumentId, DocIndex}; pub type Token = Vec; // TODO could be replaced by a SmallVec @@ -205,7 +205,7 @@ impl<'a> DocumentUpdate<'a> { pub fn register_ranked_attribute( &mut self, attr: SchemaAttr, - integer: i64, + number: Number, ) -> Result<(), SerializerError> { use serde::ser::Error; @@ -216,7 +216,7 @@ impl<'a> DocumentUpdate<'a> { )); } - self.inner.documents_ranked_fields.insert((self.document_id, attr), integer); + self.inner.documents_ranked_fields.insert((self.document_id, attr), number); Ok(()) } diff --git a/src/database/view.rs b/src/database/view.rs index e757b6021..74e4ef002 100644 --- a/src/database/view.rs +++ b/src/database/view.rs @@ -11,7 +11,6 @@ use crate::database::{retrieve_data_schema, retrieve_data_index, retrieve_data_r use crate::database::serde::deserializer::Deserializer; use crate::database::{DocumentKey, DocumentKeyAttr}; use crate::rank::{QueryBuilder, FilterFunc}; -use crate::database::schema::SchemaAttr; use crate::database::schema::Schema; use crate::database::index::Index; use crate::database::RankedMap;