diff --git a/meilidb-data/src/serde/extract_document_id.rs b/meilidb-data/src/serde/extract_document_id.rs new file mode 100644 index 000000000..5310da538 --- /dev/null +++ b/meilidb-data/src/serde/extract_document_id.rs @@ -0,0 +1,259 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + +use meilidb_core::DocumentId; +use serde::Serialize; +use serde::ser; + +use super::{SerializerError, ExtractString}; + +pub fn extract_document_id( + identifier: &str, + document: &D, +) -> Result, SerializerError> +where D: serde::Serialize, +{ + let serializer = ExtractDocumentId { identifier }; + document.serialize(serializer) +} + +fn calculate_hash(t: &T) -> u64 { + let mut s = DefaultHasher::new(); + t.hash(&mut s); + s.finish() +} + +struct ExtractDocumentId<'a> { + identifier: &'a str, +} + +impl<'a> ser::Serializer for ExtractDocumentId<'a> { + type Ok = Option; + type Error = SerializerError; + type SerializeSeq = ser::Impossible; + type SerializeTuple = ser::Impossible; + type SerializeTupleStruct = ser::Impossible; + type SerializeTupleVariant = ser::Impossible; + type SerializeMap = ExtractDocumentIdMapSerializer<'a>; + type SerializeStruct = ExtractDocumentIdStructSerializer<'a>; + type SerializeStructVariant = ser::Impossible; + + forward_to_unserializable_type! { + bool => serialize_bool, + char => serialize_char, + + i8 => serialize_i8, + i16 => serialize_i16, + i32 => serialize_i32, + i64 => serialize_i64, + + u8 => serialize_u8, + u16 => serialize_u16, + u32 => serialize_u32, + u64 => serialize_u64, + + f32 => serialize_f32, + f64 => serialize_f64, + } + + fn serialize_str(self, value: &str) -> Result { + Err(SerializerError::UnserializableType { name: "str" }) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + Err(SerializerError::UnserializableType { name: "&[u8]" }) + } + + fn serialize_none(self) -> Result { + Err(SerializerError::UnserializableType { name: "Option" }) + } + + fn serialize_some(self, _value: &T) -> Result + where T: Serialize, + { + Err(SerializerError::UnserializableType { name: "Option" }) + } + + fn serialize_unit(self) -> Result { + Err(SerializerError::UnserializableType { name: "()" }) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(SerializerError::UnserializableType { name: "unit struct" }) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str + ) -> Result + { + Err(SerializerError::UnserializableType { name: "unit variant" }) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T + ) -> Result + where T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T + ) -> Result + where T: Serialize, + { + Err(SerializerError::UnserializableType { name: "newtype variant" }) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(SerializerError::UnserializableType { name: "sequence" }) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(SerializerError::UnserializableType { name: "tuple" }) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "tuple struct" }) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "tuple variant" }) + } + + fn serialize_map(self, _len: Option) -> Result { + let serializer = ExtractDocumentIdMapSerializer { + identifier: self.identifier, + document_id: None, + current_key_name: None, + }; + + Ok(serializer) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize + ) -> Result + { + let serializer = ExtractDocumentIdStructSerializer { + identifier: self.identifier, + document_id: None, + }; + + Ok(serializer) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "struct variant" }) + } +} + +pub struct ExtractDocumentIdMapSerializer<'a> { + identifier: &'a str, + document_id: Option, + current_key_name: Option, +} + +impl<'a> ser::SerializeMap for ExtractDocumentIdMapSerializer<'a> { + type Ok = Option; + type Error = SerializerError; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where T: Serialize, + { + let key = key.serialize(ExtractString)?; + self.current_key_name = Some(key); + Ok(()) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where T: Serialize, + { + let key = self.current_key_name.take().unwrap(); + self.serialize_entry(&key, value) + } + + fn serialize_entry( + &mut self, + key: &K, + value: &V + ) -> Result<(), Self::Error> + where K: Serialize, V: Serialize, + { + let key = key.serialize(ExtractString)?; + + if self.identifier == key { + // TODO is it possible to have multiple ids? + let id = bincode::serialize(value).unwrap(); + let hash = calculate_hash(&id); + self.document_id = Some(DocumentId(hash)); + } + + Ok(()) + } + + fn end(self) -> Result { + Ok(self.document_id) + } +} + +pub struct ExtractDocumentIdStructSerializer<'a> { + identifier: &'a str, + document_id: Option, +} + +impl<'a> ser::SerializeStruct for ExtractDocumentIdStructSerializer<'a> { + type Ok = Option; + type Error = SerializerError; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T + ) -> Result<(), Self::Error> + where T: Serialize, + { + if self.identifier == key { + // TODO can it be possible to have multiple ids? + let id = bincode::serialize(value).unwrap(); + let hash = calculate_hash(&id); + self.document_id = Some(DocumentId(hash)); + } + + Ok(()) + } + + fn end(self) -> Result { + Ok(self.document_id) + } +} diff --git a/meilidb-data/src/serde/mod.rs b/meilidb-data/src/serde/mod.rs index 284c970cf..cf85e60be 100644 --- a/meilidb-data/src/serde/mod.rs +++ b/meilidb-data/src/serde/mod.rs @@ -11,10 +11,12 @@ macro_rules! forward_to_unserializable_type { mod deserializer; mod serializer; mod extract_string; +mod extract_document_id; pub use self::deserializer::Deserializer; pub use self::serializer::Serializer; pub use self::extract_string::ExtractString; +pub use self::extract_document_id::extract_document_id; use std::{fmt, error::Error}; use rmp_serde::encode::Error as RmpError; diff --git a/meilidb-data/src/serde/serializer.rs b/meilidb-data/src/serde/serializer.rs index 7a5808cfd..9be35c2dc 100644 --- a/meilidb-data/src/serde/serializer.rs +++ b/meilidb-data/src/serde/serializer.rs @@ -1,9 +1,5 @@ -use std::collections::{HashSet, HashMap}; -use std::fmt; -use std::error::Error; - use meilidb_core::DocumentId; -use serde::{de, ser}; +use serde::ser; use crate::schema::Schema; use crate::database::RawIndex;