From 0f86d6b28f333b2f75439ef3483b0f8f2f59c824 Mon Sep 17 00:00:00 2001
From: marin postma
Date: Thu, 21 Oct 2021 11:05:16 +0200
Subject: [PATCH] implement csv serialization

---
 milli/Cargo.toml                              |   1 +
 milli/src/documents/builder.rs                | 103 +++++++++++++++++-
 milli/src/documents/mod.rs                    |  33 +++++-
 milli/src/documents/serde.rs                  |  14 +--
 milli/src/update/index_documents/transform.rs |   8 +-
 5 files changed, 142 insertions(+), 17 deletions(-)

diff --git a/milli/Cargo.toml b/milli/Cargo.toml
index 594cc60e0..709f8d865 100644
--- a/milli/Cargo.toml
+++ b/milli/Cargo.toml
@@ -47,6 +47,7 @@ itertools = "0.10.0"
 # logging
 log = "0.4.14"
 logging_timer = "1.0.0"
+csv = "1.1.6"
 
 [dev-dependencies]
 big_s = "1.0.2"
diff --git a/milli/src/documents/builder.rs b/milli/src/documents/builder.rs
index 98213edd7..719580b4a 100644
--- a/milli/src/documents/builder.rs
+++ b/milli/src/documents/builder.rs
@@ -1,5 +1,8 @@
 use std::collections::BTreeMap;
+use std::collections::HashMap;
 use std::io;
+use std::io::Cursor;
+use std::io::Write;
 
 use byteorder::{BigEndian, WriteBytesExt};
 use serde::Deserializer;
@@ -38,7 +41,7 @@ pub struct DocumentBatchBuilder<W> {
 
 impl<W: io::Write + io::Seek> DocumentBatchBuilder<W> {
     pub fn new(writer: W) -> Result<Self, Error> {
-        let index = DocumentsBatchIndex::new();
+        let index = DocumentsBatchIndex::default();
         let mut writer = ByteCounter::new(writer);
         // add space to write the offset of the metadata at the end of the writer
         writer.write_u64::<BigEndian>(0)?;
@@ -101,6 +104,79 @@ impl<W: io::Write + io::Seek> DocumentBatchBuilder<W> {
 
         Ok(())
     }
+
+    /// Extends the builder with csv documents from a reader.
+    ///
+    /// This method can only be called once and is mutually exclusive with extending from json.
+    /// This is because the fields in a csv are guaranteed to always come in the same order,
+    /// which permits some optimizations.
+    ///
+    /// `from_csv` takes care of calling `finish` at the end.
+    pub fn from_csv<R: io::Read>(mut self, reader: R) -> Result<(), Error> {
+        // Ensure that this is the first and only addition made with this builder.
+        debug_assert!(self.index.is_empty());
+
+        let mut records = csv::Reader::from_reader(reader);
+
+        let headers = records
+            .headers()
+            .unwrap()
+            .into_iter()
+            .map(parse_csv_header)
+            .map(|(k, t)| (self.index.insert(&k), t))
+            .collect::<Vec<_>>();
+
+        let records = records.into_records();
+
+        for record in records {
+            match record {
+                Ok(record) => {
+                    let mut writer = obkv::KvWriter::new(Cursor::new(&mut self.obkv_buffer));
+                    for (value, (fid, ty)) in record.into_iter().zip(headers.iter()) {
+                        let value = match ty {
+                            AllowedType::Number => value.parse::<f64>().map(Value::from).unwrap(),
+                            AllowedType::String => Value::String(value.to_string()),
+                        };
+
+                        serde_json::to_writer(Cursor::new(&mut self.value_buffer), &value).unwrap();
+                        writer.insert(*fid, &self.value_buffer)?;
+                        self.value_buffer.clear();
+                    }
+
+                    self.inner.write_u32::<BigEndian>(self.obkv_buffer.len() as u32)?;
+                    self.inner.write_all(&self.obkv_buffer)?;
+
+                    self.obkv_buffer.clear();
+                    self.count += 1;
+                },
+                Err(_) => panic!(),
+            }
+        }
+
+        self.finish()?;
+
+        Ok(())
+    }
 }
 
+#[derive(Debug)]
+enum AllowedType {
+    String,
+    Number,
+}
+
+fn parse_csv_header(header: &str) -> (String, AllowedType) {
+    // if there are several separators we only split on the last one.
+    match header.rsplit_once(':') {
+        Some((field_name, field_type)) => match field_type {
+            "string" => (field_name.to_string(), AllowedType::String),
+            "number" => (field_name.to_string(), AllowedType::Number),
+            // if the pattern isn't recognized, we keep the whole field.
+            _otherwise => (header.to_string(), AllowedType::String),
+        },
+        None => (header.to_string(), AllowedType::String),
+    }
+}
+
 #[cfg(test)]
 mod test {
@@ -185,4 +261,29 @@ mod test {
 
         assert!(reader.next_document_with_index().unwrap().is_none());
     }
+
+    #[test]
+    fn add_documents_csv() {
+        let mut cursor = Cursor::new(Vec::new());
+        let builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
+
+        let csv = "id:number,field:string\n1,hello!\n2,blabla";
+
+        builder.from_csv(Cursor::new(csv.as_bytes())).unwrap();
+
+        cursor.set_position(0);
+
+        let mut reader = DocumentBatchReader::from_reader(cursor).unwrap();
+
+        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
+        assert_eq!(index.len(), 2);
+        assert_eq!(document.iter().count(), 2);
+
+        let (_index, document) = reader.next_document_with_index().unwrap().unwrap();
+        assert_eq!(document.iter().count(), 2);
+
+        assert!(reader.next_document_with_index().unwrap().is_none());
+    }
 }
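
Reviewer note: the typed-header convention above is the user-facing part of this change. A header may end in `:string` or `:number`; only the last `:` counts as a separator, and an unrecognized (or absent) suffix falls back to a string field named by the whole header. A minimal sketch of the expected behaviour, assuming `parse_csv_header` and `AllowedType` are in scope (for instance in the test module of builder.rs):

    assert!(matches!(parse_csv_header("id:number"), (name, AllowedType::Number) if name == "id"));
    assert!(matches!(parse_csv_header("title:string"), (name, AllowedType::String) if name == "title"));
    // An unrecognized suffix keeps the whole header and defaults to string.
    assert!(matches!(parse_csv_header("price:bool"), (name, AllowedType::String) if name == "price:bool"));
    // With several separators, only the last one acts as the type separator.
    assert!(matches!(parse_csv_header("a:b:number"), (name, AllowedType::Number) if name == "a:b"));
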
diff --git a/milli/src/documents/mod.rs b/milli/src/documents/mod.rs
index ce0539c24..9f6ebd3de 100644
--- a/milli/src/documents/mod.rs
+++ b/milli/src/documents/mod.rs
@@ -17,7 +17,38 @@ pub use reader::DocumentBatchReader;
 use crate::FieldId;
 
 /// A bidirectional map that links field ids to their name in a document batch.
-pub type DocumentsBatchIndex = BiHashMap<FieldId, String>;
+#[derive(Default, Debug, Serialize, Deserialize)]
+pub struct DocumentsBatchIndex(pub BiHashMap<FieldId, String>);
+
+impl DocumentsBatchIndex {
+    /// Insert the field in the map, or return its field id if it already exists.
+    pub fn insert(&mut self, field: &str) -> FieldId {
+        match self.0.get_by_right(field) {
+            Some(field_id) => *field_id,
+            None => {
+                let field_id = self.0.len() as FieldId;
+                self.0.insert(field_id, field.to_string());
+                field_id
+            }
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = (&FieldId, &String)> {
+        self.0.iter()
+    }
+
+    pub fn get_id(&self, id: FieldId) -> Option<&String> {
+        self.0.get_by_left(&id)
+    }
+}
 
 #[derive(Debug, Serialize, Deserialize)]
 struct DocumentsMetadata {
diff --git a/milli/src/documents/serde.rs b/milli/src/documents/serde.rs
index 0d02fff6c..86fb68534 100644
--- a/milli/src/documents/serde.rs
+++ b/milli/src/documents/serde.rs
@@ -31,17 +31,9 @@ impl<'a, 'de> Visitor<'de> for FieldIdResolver<'a> {
 
     fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     where
-        E: serde::de::Error, {
-        let field_id = match self.0.get_by_right(v) {
-            Some(field_id) => *field_id,
-            None => {
-                let field_id = self.0.len() as FieldId;
-                self.0.insert(field_id, v.to_string());
-                field_id
-            }
-        };
-
-        Ok(field_id)
+        E: serde::de::Error,
+    {
+        Ok(self.0.insert(v))
     }
 
     fn expecting(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
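
Reviewer note: the serde visitor above now shares its field-interning logic with the csv path through `DocumentsBatchIndex::insert`: an unknown name is assigned the next sequential field id, a known name returns its existing id. A small usage sketch (hypothetical snippet, not part of the patch):

    let mut index = DocumentsBatchIndex::default();
    let id = index.insert("id");
    let title = index.insert("title");
    assert_eq!((id, title), (0, 1));     // ids follow insertion order
    assert_eq!(index.insert("id"), id);  // re-inserting an existing name is idempotent
    assert_eq!(index.len(), 2);
    assert_eq!(index.get_id(id), Some(&"id".to_string()));
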
diff --git a/milli/src/update/index_documents/transform.rs b/milli/src/update/index_documents/transform.rs
index c0c88abed..5af1eda72 100644
--- a/milli/src/update/index_documents/transform.rs
+++ b/milli/src/update/index_documents/transform.rs
@@ -75,7 +75,7 @@ fn create_fields_mapping(
         .collect()
 }
 
-fn find_primary_key(index: &bimap::BiHashMap<FieldId, String>) -> Option<&str> {
+fn find_primary_key(index: &DocumentsBatchIndex) -> Option<&str> {
     index
         .iter()
         .sorted_by_key(|(k, _)| *k)
@@ -179,7 +179,7 @@ impl Transform<'_, '_> {
             if !self.autogenerate_docids {
                 let mut json = Map::new();
                 for (key, value) in document.iter() {
-                    let key = addition_index.get_by_left(&key).cloned();
+                    let key = addition_index.get_id(key).cloned();
                     let value = serde_json::from_slice::<Value>(&value).ok();
 
                     if let Some((k, v)) = key.zip(value) {
@@ -544,7 +544,7 @@ mod test {
     mod primary_key_inference {
         use bimap::BiHashMap;
 
-        use crate::update::index_documents::transform::find_primary_key;
+        use crate::{documents::DocumentsBatchIndex, update::index_documents::transform::find_primary_key};
 
         #[test]
         fn primary_key_infered_on_first_field() {
@@ -557,7 +557,7 @@ mod test {
             map.insert(4, "fakeId".to_string());
             map.insert(0, "realId".to_string());
 
-            assert_eq!(find_primary_key(&map), Some("realId"));
+            assert_eq!(find_primary_key(&DocumentsBatchIndex(map)), Some("realId"));
         }
     }
 }
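
Reviewer note: end to end, the new api is driven the way the `add_documents_csv` test does; the sketch below additionally decodes the stored values, assuming `document.iter()` yields `(FieldId, &[u8])` pairs of JSON-encoded bytes, which is how the builder writes them:

    use std::io::Cursor;
    use serde_json::Value;

    let csv = "id:number,title:string\n1,hello!\n2,blabla";

    let mut buffer = Cursor::new(Vec::new());
    let builder = DocumentBatchBuilder::new(&mut buffer).unwrap();
    builder.from_csv(Cursor::new(csv.as_bytes())).unwrap(); // calls `finish` internally

    buffer.set_position(0);
    let mut reader = DocumentBatchReader::from_reader(buffer).unwrap();
    while let Some((index, document)) = reader.next_document_with_index().unwrap() {
        for (fid, bytes) in document.iter() {
            let name = index.get_id(fid).unwrap();
            let value: Value = serde_json::from_slice(bytes).unwrap();
            println!("{}: {}", name, value); // e.g. `id: 1.0` (number cells are parsed as f64)
        }
    }
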