diff --git a/benchmarks/benches/utils.rs b/benchmarks/benches/utils.rs
index 24f5d5343..dbe8fffad 100644
--- a/benchmarks/benches/utils.rs
+++ b/benchmarks/benches/utils.rs
@@ -1,7 +1,7 @@
 #![allow(dead_code)]
 
 use std::fs::{create_dir_all, remove_dir_all, File};
-use std::io::{self, Cursor, Read, Seek};
+use std::io::{self, BufRead, BufReader, Cursor, Read, Seek};
 use std::num::ParseFloatError;
 use std::path::Path;
 
@@ -146,44 +146,34 @@ pub fn documents_from(filename: &str, filetype: &str) -> DocumentBatchReader<impl Read + Seek> {
-fn documents_from_jsonl(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
+fn documents_from_jsonl(reader: impl Read) -> anyhow::Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
     let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
 
-    let values = serde_json::Deserializer::from_reader(reader)
-        .into_iter::<serde_json::Map<String, serde_json::Value>>();
-    for document in values {
-        let document = document?;
-        documents.add_documents(document)?;
+    let mut buf = String::new();
+    let mut reader = BufReader::new(reader);
+
+    while reader.read_line(&mut buf)? > 0 {
+        documents.extend_from_json(&mut buf.as_bytes())?;
     }
 
     documents.finish()?;
 
     Ok(writer.into_inner())
 }
 
-fn documents_from_json(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
+fn documents_from_json(reader: impl Read) -> anyhow::Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
     let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
 
-    let json: serde_json::Value = serde_json::from_reader(reader)?;
-    documents.add_documents(json)?;
+    documents.extend_from_json(reader)?;
 
     documents.finish()?;
 
     Ok(writer.into_inner())
 }
 
-fn documents_from_csv(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
+fn documents_from_csv(reader: impl Read) -> anyhow::Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
-    let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
-
-    let iter = CSVDocumentDeserializer::from_reader(reader)?;
-
-    for doc in iter {
-        let doc = doc?;
-        documents.add_documents(doc)?;
-    }
-
-    documents.finish()?;
+    milli::documents::DocumentBatchBuilder::from_csv(reader, &mut writer)?.finish()?;
 
     Ok(writer.into_inner())
 }
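One detail worth flagging before the same loop reappears in `cli` and `http-ui` below: `read_line` appends to its buffer rather than replacing it, so as written each iteration hands `extend_from_json` every line read so far. A variant with an explicit `buf.clear()` (a sketch against the same builder API, not part of this patch) avoids that:

```rust
use std::io::{BufRead, BufReader, Cursor, Read};

fn documents_from_jsonl(reader: impl Read) -> anyhow::Result<Vec<u8>> {
    let mut writer = Cursor::new(Vec::new());
    let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;

    let mut reader = BufReader::new(reader);
    let mut buf = String::new();

    while reader.read_line(&mut buf)? > 0 {
        documents.extend_from_json(&mut buf.as_bytes())?;
        // `read_line` appends, so drop the line that was just consumed.
        buf.clear();
    }

    documents.finish()?;
    Ok(writer.into_inner())
}
```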
diff --git a/cli/src/main.rs b/cli/src/main.rs
index b84ff3243..8e28d4a25 100644
--- a/cli/src/main.rs
+++ b/cli/src/main.rs
@@ -1,5 +1,5 @@
 use std::fs::File;
-use std::io::{stdin, Cursor, Read};
+use std::io::{stdin, BufRead, BufReader, Cursor, Read};
 use std::path::PathBuf;
 use std::str::FromStr;
 
@@ -9,7 +9,6 @@ use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
 use milli::update::UpdateIndexingStep::{
     ComputeIdsAndMergeDocuments, IndexDocuments, MergeDataIntoFinalDatabase, RemapDocumentAddition,
 };
-use serde_json::{Map, Value};
 use structopt::StructOpt;
 
 #[cfg(target_os = "linux")]
@@ -202,11 +201,11 @@ fn documents_from_jsonl(reader: impl Read) -> Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
     let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
 
-    let values = serde_json::Deserializer::from_reader(reader)
-        .into_iter::<Map<String, Value>>();
-    for document in values {
-        let document = document?;
-        documents.add_documents(document)?;
+    let mut buf = String::new();
+    let mut reader = BufReader::new(reader);
+
+    while reader.read_line(&mut buf)? > 0 {
+        documents.extend_from_json(&mut buf.as_bytes())?;
     }
 
     documents.finish()?;
@@ -217,8 +216,7 @@ fn documents_from_json(reader: impl Read) -> Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
     let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
 
-    let json: serde_json::Value = serde_json::from_reader(reader)?;
-    documents.add_documents(json)?;
+    documents.extend_from_json(reader)?;
 
     documents.finish()?;
 
     Ok(writer.into_inner())
@@ -226,17 +224,7 @@ fn documents_from_json(reader: impl Read) -> Result<Vec<u8>> {
 
 fn documents_from_csv(reader: impl Read) -> Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
-    let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
-
-    let mut records = csv::Reader::from_reader(reader);
-    let iter = records.deserialize::<Map<String, Value>>();
-
-    for doc in iter {
-        let doc = doc?;
-        documents.add_documents(doc)?;
-    }
-
-    documents.finish()?;
+    milli::documents::DocumentBatchBuilder::from_csv(reader, &mut writer)?.finish()?;
 
     Ok(writer.into_inner())
 }
diff --git a/http-ui/src/documents_from_csv.rs b/http-ui/src/documents_from_csv.rs
deleted file mode 100644
index 2b62f23c2..000000000
--- a/http-ui/src/documents_from_csv.rs
+++ /dev/null
@@ -1,285 +0,0 @@
-use std::io::{Read, Result as IoResult};
-use std::num::ParseFloatError;
-
-use serde_json::{Map, Value};
-
-enum AllowedType {
-    String,
-    Number,
-}
-
-fn parse_csv_header(header: &str) -> (String, AllowedType) {
-    // if there are several separators we only split on the last one.
-    match header.rsplit_once(':') {
-        Some((field_name, field_type)) => match field_type {
-            "string" => (field_name.to_string(), AllowedType::String),
-            "number" => (field_name.to_string(), AllowedType::Number),
-            // we may return an error in this case.
-            _otherwise => (header.to_string(), AllowedType::String),
-        },
-        None => (header.to_string(), AllowedType::String),
-    }
-}
-
-pub struct CSVDocumentDeserializer<R>
-where
-    R: Read,
-{
-    documents: csv::StringRecordsIntoIter<R>,
-    headers: Vec<(String, AllowedType)>,
-}
-
-impl<R: Read> CSVDocumentDeserializer<R> {
-    pub fn from_reader(reader: R) -> IoResult<Self> {
-        let mut records = csv::Reader::from_reader(reader);
-
-        let headers = records.headers()?.into_iter().map(parse_csv_header).collect();
-
-        Ok(Self { documents: records.into_records(), headers })
-    }
-}
-
-impl<R: Read> Iterator for CSVDocumentDeserializer<R> {
-    type Item = anyhow::Result<Map<String, Value>>;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        let csv_document = self.documents.next()?;
-
-        match csv_document {
-            Ok(csv_document) => {
-                let mut document = Map::new();
-
-                for ((field_name, field_type), value) in
-                    self.headers.iter().zip(csv_document.into_iter())
-                {
-                    let parsed_value: Result<Value, ParseFloatError> = match field_type {
-                        AllowedType::Number => {
-                            value.parse::<f64>().map(Value::from).map_err(Into::into)
-                        }
-                        AllowedType::String => Ok(Value::String(value.to_string())),
-                    };
-
-                    match parsed_value {
-                        Ok(value) => drop(document.insert(field_name.to_string(), value)),
-                        Err(_e) => {
-                            return Some(Err(anyhow::anyhow!(
-                                "Value '{}' is not a valid number",
-                                value
-                            )))
-                        }
-                    }
-                }
-
-                Some(Ok(document))
-            }
-            Err(e) => Some(Err(anyhow::anyhow!("Error parsing csv document: {}", e))),
-        }
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use serde_json::json;
-
-    use super::*;
-
-    #[test]
-    fn simple_csv_document() {
-        let documents = r#"city,country,pop
-"Boston","United States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "city": "Boston",
-                "country": "United States",
-                "pop": "4628910",
-            })
-        );
-    }
-
-    #[test]
-    fn coma_in_field() {
-        let documents = r#"city,country,pop
-"Boston","United, States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "city": "Boston",
-                "country": "United, States",
-                "pop": "4628910",
-            })
-        );
-    }
-
-    #[test]
-    fn quote_in_field() {
-        let documents = r#"city,country,pop
-"Boston","United"" States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "city": "Boston",
-                "country": "United\" States",
-                "pop": "4628910",
-            })
-        );
-    }
-
-    #[test]
-    fn integer_in_field() {
-        let documents = r#"city,country,pop:number
-"Boston","United States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "city": "Boston",
-                "country": "United States",
-                "pop": 4628910.0,
-            })
-        );
-    }
-
-    #[test]
-    fn float_in_field() {
-        let documents = r#"city,country,pop:number
-"Boston","United States","4628910.01""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "city": "Boston",
-                "country": "United States",
-                "pop": 4628910.01,
-            })
-        );
-    }
-
-    #[test]
-    fn several_double_dot_in_header() {
-        let documents = r#"city:love:string,country:state,pop
-"Boston","United States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "city:love": "Boston",
-                "country:state": "United States",
-                "pop": "4628910",
-            })
-        );
-    }
-
-    #[test]
-    fn ending_by_double_dot_in_header() {
-        let documents = r#"city:,country,pop
-"Boston","United States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "city:": "Boston",
-                "country": "United States",
-                "pop": "4628910",
-            })
-        );
-    }
-
-    #[test]
-    fn starting_by_double_dot_in_header() {
-        let documents = r#":city,country,pop
-"Boston","United States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                ":city": "Boston",
-                "country": "United States",
-                "pop": "4628910",
-            })
-        );
-    }
-
-    #[test]
-    fn starting_by_double_dot_in_header2() {
-        let documents = r#":string,country,pop
-"Boston","United States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "": "Boston",
-                "country": "United States",
-                "pop": "4628910",
-            })
-        );
-    }
-
-    #[test]
-    fn double_double_dot_in_header() {
-        let documents = r#"city::string,country,pop
-"Boston","United States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert_eq!(
-            Value::Object(csv_iter.next().unwrap().unwrap()),
-            json!({
-                "city:": "Boston",
-                "country": "United States",
-                "pop": "4628910",
-            })
-        );
-    }
-
-    #[test]
-    fn bad_type_in_header() {
-        let documents = r#"city,country:number,pop
-"Boston","United States","4628910""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert!(csv_iter.next().unwrap().is_err());
-    }
-
-    #[test]
-    fn bad_column_count1() {
-        let documents = r#"city,country,pop
-"Boston","United States","4628910", "too much""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert!(csv_iter.next().unwrap().is_err());
-    }
-
-    #[test]
-    fn bad_column_count2() {
-        let documents = r#"city,country,pop
-"Boston","United States""#;
-
-        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
-
-        assert!(csv_iter.next().unwrap().is_err());
-    }
-}
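The typed-header convention this module implemented (`name:string` and `name:number`, with untyped columns defaulting to strings) moves into milli itself as `DocumentBatchBuilder::from_csv`, introduced below. A minimal sketch of the replacement call, assuming the same header rules:

```rust
use std::io::Cursor;

use milli::documents::DocumentBatchBuilder;

fn main() -> Result<(), milli::documents::Error> {
    // `pop:number` is parsed as an f64; untyped columns stay strings.
    let csv = "city,pop:number\nBoston,4628910";

    let mut batch = Cursor::new(Vec::new());
    DocumentBatchBuilder::from_csv(csv.as_bytes(), &mut batch)?.finish()?;

    // `batch` now holds {"city": "Boston", "pop": 4628910.0} in obkv form.
    Ok(())
}
```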
diff --git a/http-ui/src/main.rs b/http-ui/src/main.rs
index d27c6d5bb..9e9fe4a2b 100644
--- a/http-ui/src/main.rs
+++ b/http-ui/src/main.rs
@@ -1,10 +1,9 @@
-mod documents_from_csv;
 mod update_store;
 
 use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
 use std::fmt::Display;
 use std::fs::{create_dir_all, File};
-use std::io::Cursor;
+use std::io::{BufRead, BufReader, Cursor};
 use std::net::SocketAddr;
 use std::num::{NonZeroU32, NonZeroUsize};
 use std::path::PathBuf;
@@ -39,7 +38,6 @@ use warp::http::Response;
 use warp::Filter;
 
 use self::update_store::UpdateStore;
-use crate::documents_from_csv::CSVDocumentDeserializer;
 
 #[cfg(target_os = "linux")]
 #[global_allocator]
@@ -1041,11 +1039,11 @@ fn documents_from_jsonl(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
     let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
 
-    let values = serde_json::Deserializer::from_reader(reader)
-        .into_iter::<serde_json::Map<String, serde_json::Value>>();
-    for document in values {
-        let document = document?;
-        documents.add_documents(document)?;
+    let mut buf = String::new();
+    let mut reader = BufReader::new(reader);
+
+    while reader.read_line(&mut buf)? > 0 {
+        documents.extend_from_json(&mut buf.as_bytes())?;
     }
 
     documents.finish()?;
@@ -1056,8 +1054,7 @@ fn documents_from_json(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
     let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
 
-    let json: serde_json::Value = serde_json::from_reader(reader)?;
-    documents.add_documents(json)?;
+    documents.extend_from_json(reader)?;
 
     documents.finish()?;
 
     Ok(writer.into_inner())
@@ -1065,16 +1062,7 @@ fn documents_from_json(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
 
 fn documents_from_csv(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
-    let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
-
-    let iter = CSVDocumentDeserializer::from_reader(reader)?;
-
-    for doc in iter {
-        let doc = doc?;
-        documents.add_documents(doc)?;
-    }
-
-    documents.finish()?;
+    milli::documents::DocumentBatchBuilder::from_csv(reader, &mut writer)?.finish()?;
 
     Ok(writer.into_inner())
 }
diff --git a/milli/Cargo.toml b/milli/Cargo.toml
index bda83c2df..209c8b1f7 100644
--- a/milli/Cargo.toml
+++ b/milli/Cargo.toml
@@ -48,6 +48,7 @@ itertools = "0.10.0"
 # logging
 log = "0.4.14"
 logging_timer = "1.0.0"
+csv = "1.1.6"
 
 [dev-dependencies]
 big_s = "1.0.2"
diff --git a/milli/src/documents/builder.rs b/milli/src/documents/builder.rs
index ba1319eff..6ba890b79 100644
--- a/milli/src/documents/builder.rs
+++ b/milli/src/documents/builder.rs
@@ -1,16 +1,20 @@
+use std::collections::BTreeMap;
 use std::io;
+use std::io::{Cursor, Write};
 
 use byteorder::{BigEndian, WriteBytesExt};
-use serde::ser::Serialize;
+use serde::Deserializer;
+use serde_json::Value;
 
-use super::serde::DocumentSerializer;
+use super::serde::DocumentVisitor;
 use super::{ByteCounter, DocumentsBatchIndex, DocumentsMetadata, Error};
+use crate::FieldId;
 
 /// The `DocumentsBatchBuilder` provides a way to build a documents batch in the intermediary
 /// format used by milli.
 ///
 /// The writer used by the DocumentBatchBuilder can be read using a `DocumentBatchReader` to
-/// iterate other the documents.
+/// iterate over the documents.
 ///
 /// ## example:
 /// ```
@@ -18,43 +22,48 @@ use super::{ByteCounter, DocumentsBatchIndex, DocumentsMetadata, Error};
 /// use serde_json::json;
 /// use std::io::Cursor;
 ///
+/// let json = r##"{"id": 1, "name": "foo"}"##;
 /// let mut writer = Cursor::new(Vec::new());
 /// let mut builder = DocumentBatchBuilder::new(&mut writer).unwrap();
-/// builder.add_documents(json!({"id": 1, "name": "foo"})).unwrap();
+/// builder.extend_from_json(&mut json.as_bytes()).unwrap();
 /// builder.finish().unwrap();
 /// ```
 pub struct DocumentBatchBuilder<W> {
-    serializer: DocumentSerializer<W>,
+    inner: ByteCounter<W>,
+    index: DocumentsBatchIndex,
+    obkv_buffer: Vec<u8>,
+    value_buffer: Vec<u8>,
+    values: BTreeMap<FieldId, Value>,
+    count: usize,
 }
 
 impl<W: io::Write + io::Seek> DocumentBatchBuilder<W> {
     pub fn new(writer: W) -> Result<Self, Error> {
-        let index = DocumentsBatchIndex::new();
+        let index = DocumentsBatchIndex::default();
         let mut writer = ByteCounter::new(writer);
         // add space to write the offset of the metadata at the end of the writer
         writer.write_u64::<BigEndian>(0)?;
 
-        let serializer =
-            DocumentSerializer { writer, buffer: Vec::new(), index, count: 0, allow_seq: true };
-
-        Ok(Self { serializer })
+        Ok(Self {
+            inner: writer,
+            index,
+            obkv_buffer: Vec::new(),
+            value_buffer: Vec::new(),
+            values: BTreeMap::new(),
+            count: 0,
+        })
     }
 
     /// Returns the number of documents that have been written to the builder.
     pub fn len(&self) -> usize {
-        self.serializer.count
+        self.count
     }
 
     /// This method must be called after the document addition is terminated. It will put the
     /// metadata at the end of the file, and write the metadata offset at the beginning of the
     /// file.
     pub fn finish(self) -> Result<(), Error> {
-        let DocumentSerializer {
-            writer: ByteCounter { mut writer, count: offset },
-            index,
-            count,
-            ..
-        } = self.serializer;
+        let Self { inner: ByteCounter { mut writer, count: offset }, index, count, .. } = self;
 
         let meta = DocumentsMetadata { count, index };
 
@@ -68,13 +77,478 @@ impl<W: io::Write + io::Seek> DocumentBatchBuilder<W> {
         Ok(())
     }
 
-    /// Adds documents to the builder.
+    /// Extends the builder with json documents from a reader.
+    pub fn extend_from_json<R: io::Read>(&mut self, reader: R) -> Result<(), Error> {
+        let mut de = serde_json::Deserializer::from_reader(reader);
+
+        let mut visitor = DocumentVisitor {
+            inner: &mut self.inner,
+            index: &mut self.index,
+            obkv_buffer: &mut self.obkv_buffer,
+            value_buffer: &mut self.value_buffer,
+            values: &mut self.values,
+            count: &mut self.count,
+        };
+
+        de.deserialize_any(&mut visitor).map_err(Error::JsonError)?
+    }
+
+    /// Creates a builder from a reader of CSV documents.
     ///
-    /// The internal index is updated with the fields found
-    /// in the documents. Document must either be a map or a sequences of map, anything else will
-    /// fail.
-    pub fn add_documents<T: Serialize>(&mut self, document: T) -> Result<(), Error> {
-        document.serialize(&mut self.serializer)?;
-        Ok(())
+    /// Since all fields in a csv document are guaranteed to be ordered, we are able to perform
+    /// optimisations, and extending from another CSV is not allowed.
+    pub fn from_csv<R: io::Read>(reader: R, writer: W) -> Result<Self, Error> {
+        let mut this = Self::new(writer)?;
+        // Ensure that this is the first and only addition made with this builder
+        debug_assert!(this.index.is_empty());
+
+        let mut records = csv::Reader::from_reader(reader);
+
+        let headers = records
+            .headers()?
+            .into_iter()
+            .map(parse_csv_header)
+            .map(|(k, t)| (this.index.insert(&k), t))
+            .collect::<Vec<_>>();
+
+        for (i, record) in records.into_records().enumerate() {
+            let record = record?;
+            this.obkv_buffer.clear();
+            let mut writer = obkv::KvWriter::new(&mut this.obkv_buffer);
+            for (value, (fid, ty)) in record.into_iter().zip(headers.iter()) {
+                let value = match ty {
+                    AllowedType::Number => {
+                        value.parse::<f64>().map(Value::from).map_err(|error| {
+                            Error::ParseFloat {
+                                error,
+                                // +1 for the header offset.
+                                line: i + 1,
+                                value: value.to_string(),
+                            }
+                        })?
+                    }
+                    AllowedType::String => Value::String(value.to_string()),
+                };
+
+                this.value_buffer.clear();
+                serde_json::to_writer(Cursor::new(&mut this.value_buffer), &value)?;
+                writer.insert(*fid, &this.value_buffer)?;
+            }
+
+            this.inner.write_u32::<BigEndian>(this.obkv_buffer.len() as u32)?;
+            this.inner.write_all(&this.obkv_buffer)?;
+
+            this.count += 1;
+        }
+
+        Ok(this)
+    }
+}
+
+#[derive(Debug)]
+enum AllowedType {
+    String,
+    Number,
+}
+
+fn parse_csv_header(header: &str) -> (String, AllowedType) {
+    // if there are several separators we only split on the last one.
+    match header.rsplit_once(':') {
+        Some((field_name, field_type)) => match field_type {
+            "string" => (field_name.to_string(), AllowedType::String),
+            "number" => (field_name.to_string(), AllowedType::Number),
+            // if the pattern isn't recognized, we keep the whole field.
+            _otherwise => (header.to_string(), AllowedType::String),
+        },
+        None => (header.to_string(), AllowedType::String),
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::io::Cursor;
+
+    use serde_json::{json, Map};
+
+    use super::*;
+    use crate::documents::DocumentBatchReader;
+
+    fn obkv_to_value(obkv: &obkv::KvReader<FieldId>, index: &DocumentsBatchIndex) -> Value {
+        let mut map = Map::new();
+
+        for (fid, value) in obkv.iter() {
+            let field_name = index.name(fid).unwrap().clone();
+            let value: Value = serde_json::from_slice(value).unwrap();
+
+            map.insert(field_name, value);
+        }
+
+        Value::Object(map)
+    }
+
+    #[test]
+    fn add_single_documents_json() {
+        let mut cursor = Cursor::new(Vec::new());
+        let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
+
+        let json = serde_json::json!({
+            "id": 1,
+            "field": "hello!",
+        });
+
+        builder.extend_from_json(Cursor::new(serde_json::to_vec(&json).unwrap())).unwrap();
+
+        let json = serde_json::json!({
+            "blabla": false,
+            "field": "hello!",
+            "id": 1,
+        });
+
+        builder.extend_from_json(Cursor::new(serde_json::to_vec(&json).unwrap())).unwrap();
+
+        assert_eq!(builder.len(), 2);
+
+        builder.finish().unwrap();
+
+        cursor.set_position(0);
+
+        let mut reader = DocumentBatchReader::from_reader(cursor).unwrap();
+
+        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
+        assert_eq!(index.len(), 3);
+        assert_eq!(document.iter().count(), 2);
+
+        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
+        assert_eq!(index.len(), 3);
+        assert_eq!(document.iter().count(), 3);
+
+        assert!(reader.next_document_with_index().unwrap().is_none());
+    }
+
+    #[test]
+    fn add_documents_seq_json() {
+        let mut cursor = Cursor::new(Vec::new());
+        let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
+
+        let json = serde_json::json!([{
+            "id": 1,
+            "field": "hello!",
+        },{
+            "blabla": false,
+            "field": "hello!",
+            "id": 1,
+        }
+        ]);
+
+        builder.extend_from_json(Cursor::new(serde_json::to_vec(&json).unwrap())).unwrap();
+
+        assert_eq!(builder.len(), 2);
+
+        builder.finish().unwrap();
+
+        cursor.set_position(0);
+
+        let mut reader = DocumentBatchReader::from_reader(cursor).unwrap();
+
+        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
+        assert_eq!(index.len(), 3);
+        assert_eq!(document.iter().count(), 2);
+
+        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
+        assert_eq!(index.len(), 3);
+        assert_eq!(document.iter().count(), 3);
+
+        assert!(reader.next_document_with_index().unwrap().is_none());
+    }
+
+    #[test]
+    fn add_documents_csv() {
+        let mut cursor = Cursor::new(Vec::new());
+
+        let csv = "id:number,field:string\n1,hello!\n2,blabla";
+
+        let builder =
+            DocumentBatchBuilder::from_csv(Cursor::new(csv.as_bytes()), &mut cursor).unwrap();
+        builder.finish().unwrap();
+
+        cursor.set_position(0);
+
+        let mut reader = DocumentBatchReader::from_reader(cursor).unwrap();
+
+        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
+        assert_eq!(index.len(), 2);
+        assert_eq!(document.iter().count(), 2);
+
+        let (_index, document) = reader.next_document_with_index().unwrap().unwrap();
+        assert_eq!(document.iter().count(), 2);
+
+        assert!(reader.next_document_with_index().unwrap().is_none());
+    }
+
+    #[test]
+    fn simple_csv_document() {
+        let documents = r#"city,country,pop
+"Boston","United States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                "city": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+
+        assert!(reader.next_document_with_index().unwrap().is_none());
+    }
+
+    #[test]
+    fn coma_in_field() {
+        let documents = r#"city,country,pop
+"Boston","United, States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                "city": "Boston",
+                "country": "United, States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn quote_in_field() {
+        let documents = r#"city,country,pop
+"Boston","United"" States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                "city": "Boston",
+                "country": "United\" States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn integer_in_field() {
+        let documents = r#"city,country,pop:number
+"Boston","United States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                "city": "Boston",
+                "country": "United States",
+                "pop": 4628910.0,
+            })
+        );
+    }
+
+    #[test]
+    fn float_in_field() {
+        let documents = r#"city,country,pop:number
+"Boston","United States","4628910.01""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                "city": "Boston",
+                "country": "United States",
+                "pop": 4628910.01,
+            })
+        );
+    }
+
+    #[test]
+    fn several_colon_in_header() {
+        let documents = r#"city:love:string,country:state,pop
+"Boston","United States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                "city:love": "Boston",
+                "country:state": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn ending_by_colon_in_header() {
+        let documents = r#"city:,country,pop
+"Boston","United States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                "city:": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn starting_by_colon_in_header() {
+        let documents = r#":city,country,pop
+"Boston","United States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                ":city": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[ignore]
+    #[test]
+    fn starting_by_colon_in_header2() {
+        let documents = r#":string,country,pop
+"Boston","United States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+
+        assert!(reader.next_document_with_index().is_err());
+    }
+
+    #[test]
+    fn double_colon_in_header() {
+        let documents = r#"city::string,country,pop
+"Boston","United States","4628910""#;
+
+        let mut buf = Vec::new();
+        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
+            .unwrap()
+            .finish()
+            .unwrap();
+        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
+        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
+        let val = obkv_to_value(&doc, index);
+
+        assert_eq!(
+            val,
+            json!({
+                "city:": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn bad_type_in_header() {
+        let documents = r#"city,country:number,pop
+"Boston","United States","4628910""#;
+
+        let mut buf = Vec::new();
+        assert!(
+            DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf)).is_err()
+        );
+    }
+
+    #[test]
+    fn bad_column_count1() {
+        let documents = r#"city,country,pop
+"Boston","United States","4628910", "too much""#;
+
+        let mut buf = Vec::new();
+        assert!(
+            DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf)).is_err()
+        );
+    }
+
+    #[test]
+    fn bad_column_count2() {
+        let documents = r#"city,country,pop
+"Boston","United States""#;
+
+        let mut buf = Vec::new();
+        assert!(
+            DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf)).is_err()
+        );
     }
 }
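As the doc comment and the `add_documents_seq_json` test above suggest, `extend_from_json` accepts either a single object or a top-level array (the `visit_seq` path in `serde.rs` below), and may be called repeatedly on one builder. A short usage sketch:

```rust
use std::io::Cursor;

use milli::documents::DocumentBatchBuilder;

fn main() -> Result<(), milli::documents::Error> {
    let mut writer = Cursor::new(Vec::new());
    let mut builder = DocumentBatchBuilder::new(&mut writer)?;

    // A single document...
    builder.extend_from_json(&mut r#"{"id": 1, "name": "foo"}"#.as_bytes())?;
    // ...or a whole array in one call.
    builder.extend_from_json(&mut r#"[{"id": 2}, {"id": 3}]"#.as_bytes())?;

    assert_eq!(builder.len(), 3);
    builder.finish()?;
    Ok(())
}
```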
diff --git a/milli/src/documents/mod.rs b/milli/src/documents/mod.rs
index f79c210fe..14d97ee7d 100644
--- a/milli/src/documents/mod.rs
+++ b/milli/src/documents/mod.rs
@@ -7,7 +7,8 @@
 mod builder;
 mod reader;
 mod serde;
 
-use std::{fmt, io};
+use std::fmt::{self, Debug};
+use std::io;
 
 use ::serde::{Deserialize, Serialize};
 use bimap::BiHashMap;
@@ -17,7 +18,38 @@ pub use reader::DocumentBatchReader;
 use crate::FieldId;
 
 /// A bidirectional map that links field ids to their name in a document batch.
-pub type DocumentsBatchIndex = BiHashMap<FieldId, String>;
+#[derive(Default, Debug, Serialize, Deserialize)]
+pub struct DocumentsBatchIndex(pub BiHashMap<FieldId, String>);
+
+impl DocumentsBatchIndex {
+    /// Insert the field in the map, or return its field id if it already exists.
+    pub fn insert(&mut self, field: &str) -> FieldId {
+        match self.0.get_by_right(field) {
+            Some(field_id) => *field_id,
+            None => {
+                let field_id = self.0.len() as FieldId;
+                self.0.insert(field_id, field.to_string());
+                field_id
+            }
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    pub fn iter(&self) -> bimap::hash::Iter<FieldId, String> {
+        self.0.iter()
+    }
+
+    pub fn name(&self, id: FieldId) -> Option<&String> {
+        self.0.get_by_left(&id)
+    }
+}
 
 #[derive(Debug, Serialize, Deserialize)]
 struct DocumentsMetadata {
@@ -50,14 +82,22 @@ impl<W: io::Write> io::Write for ByteCounter<W> {
 
 #[derive(Debug)]
 pub enum Error {
+    ParseFloat { error: std::num::ParseFloatError, line: usize, value: String },
     InvalidDocumentFormat,
     Custom(String),
     JsonError(serde_json::Error),
+    CsvError(csv::Error),
     Serialize(bincode::Error),
     Io(io::Error),
     DocumentTooLarge,
 }
 
+impl From<csv::Error> for Error {
+    fn from(e: csv::Error) -> Self {
+        Self::CsvError(e)
+    }
+}
+
 impl From<io::Error> for Error {
     fn from(other: io::Error) -> Self {
         Self::Io(other)
@@ -70,15 +110,25 @@ impl From<bincode::Error> for Error {
     }
 }
 
+impl From<serde_json::Error> for Error {
+    fn from(other: serde_json::Error) -> Self {
+        Self::JsonError(other)
+    }
+}
+
 impl fmt::Display for Error {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
+            Error::ParseFloat { error, line, value } => {
+                write!(f, "Error parsing number {:?} at line {}: {}", value, line, error)
+            }
             Error::Custom(s) => write!(f, "Unexpected serialization error: {}", s),
             Error::InvalidDocumentFormat => f.write_str("Invalid document addition format."),
             Error::JsonError(err) => write!(f, "Couldn't serialize document value: {}", err),
-            Error::Io(e) => e.fmt(f),
+            Error::Io(e) => write!(f, "{}", e),
             Error::DocumentTooLarge => f.write_str("Provided document is too large (>2Gib)"),
-            Error::Serialize(e) => e.fmt(f),
+            Error::Serialize(e) => write!(f, "{}", e),
+            Error::CsvError(e) => write!(f, "{}", e),
         }
     }
 }
@@ -92,7 +142,8 @@ macro_rules! documents {
         let documents = serde_json::json!($data);
         let mut writer = std::io::Cursor::new(Vec::new());
         let mut builder = crate::documents::DocumentBatchBuilder::new(&mut writer).unwrap();
-        builder.add_documents(documents).unwrap();
+        let documents = serde_json::to_vec(&documents).unwrap();
+        builder.extend_from_json(std::io::Cursor::new(documents)).unwrap();
         builder.finish().unwrap();
 
         writer.set_position(0);
@@ -103,6 +154,8 @@ macro_rules! documents {
 
 #[cfg(test)]
 mod test {
+    use std::io::Cursor;
+
     use serde_json::{json, Value};
 
     use super::*;
@@ -119,12 +172,14 @@ mod test {
             "bool": true
         });
 
+        let json = serde_json::to_vec(&json).unwrap();
+
         let mut v = Vec::new();
         let mut cursor = io::Cursor::new(&mut v);
 
         let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
 
-        builder.add_documents(json).unwrap();
+        builder.extend_from_json(Cursor::new(json)).unwrap();
 
         builder.finish().unwrap();
 
@@ -148,13 +203,16 @@ mod test {
             "toto": false,
         });
 
+        let doc1 = serde_json::to_vec(&doc1).unwrap();
+        let doc2 = serde_json::to_vec(&doc2).unwrap();
+
         let mut v = Vec::new();
         let mut cursor = io::Cursor::new(&mut v);
 
         let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
 
-        builder.add_documents(doc1).unwrap();
-        builder.add_documents(doc2).unwrap();
+        builder.extend_from_json(Cursor::new(doc1)).unwrap();
+        builder.extend_from_json(Cursor::new(doc2)).unwrap();
 
         builder.finish().unwrap();
 
@@ -177,12 +235,14 @@ mod test {
             { "tata": "hello" },
         ]);
 
+        let docs = serde_json::to_vec(&docs).unwrap();
+
         let mut v = Vec::new();
         let mut cursor = io::Cursor::new(&mut v);
 
         let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
 
-        builder.add_documents(docs).unwrap();
+        builder.extend_from_json(Cursor::new(docs)).unwrap();
 
         builder.finish().unwrap();
 
@@ -210,11 +270,13 @@ mod test {
             { "tata": "hello" },
         ]]);
 
-        assert!(builder.add_documents(docs).is_err());
+        let docs = serde_json::to_vec(&docs).unwrap();
+        assert!(builder.extend_from_json(Cursor::new(docs)).is_err());
 
         let docs = json!("hello");
+        let docs = serde_json::to_vec(&docs).unwrap();
 
-        assert!(builder.add_documents(docs).is_err());
+        assert!(builder.extend_from_json(Cursor::new(docs)).is_err());
     }
 
     #[test]
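`DocumentsBatchIndex::insert` is a get-or-insert: ids are handed out in first-seen order and an existing name keeps its id, which is what lets every document in a batch share one index. For illustration:

```rust
use milli::documents::DocumentsBatchIndex;

fn main() {
    let mut index = DocumentsBatchIndex::default();

    assert_eq!(index.insert("id"), 0);
    assert_eq!(index.insert("name"), 1);
    // Re-inserting an existing field returns its original id.
    assert_eq!(index.insert("id"), 0);

    assert_eq!(index.len(), 2);
    assert_eq!(index.name(1), Some(&"name".to_string()));
}
```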
diff --git a/milli/src/documents/serde.rs b/milli/src/documents/serde.rs
index 036ec246a..d57bf1ffb 100644
--- a/milli/src/documents/serde.rs
+++ b/milli/src/documents/serde.rs
@@ -1,474 +1,134 @@
 use std::collections::BTreeMap;
-use std::convert::TryInto;
-use std::io::Cursor;
-use std::{fmt, io};
+use std::fmt;
+use std::io::{Cursor, Write};
 
-use byteorder::{BigEndian, WriteBytesExt};
-use obkv::KvWriter;
-use serde::ser::{Impossible, Serialize, SerializeMap, SerializeSeq, Serializer};
+use byteorder::WriteBytesExt;
+use serde::de::{DeserializeSeed, MapAccess, SeqAccess, Visitor};
+use serde::Deserialize;
 use serde_json::Value;
 
 use super::{ByteCounter, DocumentsBatchIndex, Error};
 use crate::FieldId;
 
-pub struct DocumentSerializer<W> {
-    pub writer: ByteCounter<W>,
-    pub buffer: Vec<u8>,
-    pub index: DocumentsBatchIndex,
-    pub count: usize,
-    pub allow_seq: bool,
-}
-
-impl<'a, W: io::Write> Serializer for &'a mut DocumentSerializer<W> {
-    type Ok = ();
-
-    type Error = Error;
-
-    type SerializeSeq = SeqSerializer<'a, W>;
-    type SerializeTuple = Impossible<(), Self::Error>;
-    type SerializeTupleStruct = Impossible<(), Self::Error>;
-    type SerializeTupleVariant = Impossible<(), Self::Error>;
-    type SerializeMap = MapSerializer<'a, &'a mut ByteCounter<W>>;
-    type SerializeStruct = Impossible<(), Self::Error>;
-    type SerializeStructVariant = Impossible<(), Self::Error>;
-    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
-        self.buffer.clear();
-        let cursor = io::Cursor::new(&mut self.buffer);
-        self.count += 1;
-        let map_serializer = MapSerializer {
-            map: KvWriter::new(cursor),
-            index: &mut self.index,
-            writer: &mut self.writer,
-            mapped_documents: BTreeMap::new(),
-        };
-
-        Ok(map_serializer)
-    }
-
-    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
-        if self.allow_seq {
-            // Only allow sequence of documents of depth 1.
-            self.allow_seq = false;
-            Ok(SeqSerializer { serializer: self })
-        } else {
-            Err(Error::InvalidDocumentFormat)
+macro_rules! tri {
+    ($e:expr) => {
+        match $e {
+            Ok(r) => r,
+            Err(e) => return Ok(Err(e.into())),
         }
-    }
+    };
+}
 
-    fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
+
+struct FieldIdResolver<'a>(&'a mut DocumentsBatchIndex);
 
-    fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
+impl<'a, 'de> DeserializeSeed<'de> for FieldIdResolver<'a> {
+    type Value = FieldId;
 
-    fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_some<T: ?Sized>(self, _value: &T) -> Result<Self::Ok, Self::Error>
+    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
     where
-        T: Serialize,
+        D: serde::Deserializer<'de>,
     {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_unit_variant(
-        self,
-        _name: &'static str,
-        _variant_index: u32,
-        _variant: &'static str,
-    ) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_newtype_struct<T: ?Sized>(
-        self,
-        _name: &'static str,
-        _value: &T,
-    ) -> Result<Self::Ok, Self::Error>
-    where
-        T: Serialize,
-    {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_newtype_variant<T: ?Sized>(
-        self,
-        _name: &'static str,
-        _variant_index: u32,
-        _variant: &'static str,
-        _value: &T,
-    ) -> Result<Self::Ok, Self::Error>
-    where
-        T: Serialize,
-    {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_tuple_struct(
-        self,
-        _name: &'static str,
-        _len: usize,
-    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_tuple_variant(
-        self,
-        _name: &'static str,
-        _variant_index: u32,
-        _variant: &'static str,
-        _len: usize,
-    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_struct(
-        self,
-        _name: &'static str,
-        _len: usize,
-    ) -> Result<Self::SerializeStruct, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_struct_variant(
-        self,
-        _name: &'static str,
-        _variant_index: u32,
-        _variant: &'static str,
-        _len: usize,
-    ) -> Result<Self::SerializeStructVariant, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
+        deserializer.deserialize_str(self)
     }
 }
 
-pub struct SeqSerializer<'a, W> {
-    serializer: &'a mut DocumentSerializer<W>,
-}
+impl<'a, 'de> Visitor<'de> for FieldIdResolver<'a> {
+    type Value = FieldId;
 
-impl<'a, W: io::Write> SerializeSeq for SeqSerializer<'a, W> {
-    type Ok = ();
-    type Error = Error;
-
-    fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
+    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     where
-        T: Serialize,
+        E: serde::de::Error,
     {
-        value.serialize(&mut *self.serializer)?;
-        Ok(())
+        Ok(self.0.insert(v))
    }
 
-    fn end(self) -> Result<Self::Ok, Self::Error> {
-        Ok(())
+    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "a string")
    }
 }
 
-pub struct MapSerializer<'a, W> {
-    map: KvWriter<io::Cursor<&'a mut Vec<u8>>, FieldId>,
-    index: &'a mut DocumentsBatchIndex,
-    writer: W,
-    mapped_documents: BTreeMap<FieldId, Value>,
+struct ValueDeserializer;
+
+impl<'de> DeserializeSeed<'de> for ValueDeserializer {
+    type Value = serde_json::Value;
+
+    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        serde_json::Value::deserialize(deserializer)
+    }
 }
 
-/// This implementation of SerializeMap uses serilialize_entry instead of seriliaze_key and
-/// serialize_value, therefore these to methods remain unimplemented.
-impl<'a, W: io::Write> SerializeMap for MapSerializer<'a, W> {
-    type Ok = ();
-    type Error = Error;
+pub struct DocumentVisitor<'a, W> {
+    pub inner: &'a mut ByteCounter<W>,
+    pub index: &'a mut DocumentsBatchIndex,
+    pub obkv_buffer: &'a mut Vec<u8>,
+    pub value_buffer: &'a mut Vec<u8>,
+    pub values: &'a mut BTreeMap<FieldId, Value>,
+    pub count: &'a mut usize,
+}
 
-    fn serialize_key<T: ?Sized + Serialize>(&mut self, _key: &T) -> Result<(), Self::Error> {
-        unreachable!()
-    }
+impl<'a, 'de, W: Write> Visitor<'de> for &mut DocumentVisitor<'a, W> {
+    /// This Visitor value is nothing, since it writes the value to a file.
+    type Value = Result<(), Error>;
 
-    fn serialize_value<T: ?Sized + Serialize>(&mut self, _value: &T) -> Result<(), Self::Error> {
-        unreachable!()
-    }
-
-    fn end(mut self) -> Result<Self::Ok, Self::Error> {
-        let mut buf = Vec::new();
-        for (key, value) in self.mapped_documents {
-            buf.clear();
-            let mut cursor = Cursor::new(&mut buf);
-            serde_json::to_writer(&mut cursor, &value).map_err(Error::JsonError)?;
-            self.map.insert(key, cursor.into_inner()).map_err(Error::Io)?;
+    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+    where
+        A: SeqAccess<'de>,
+    {
+        while let Some(v) = seq.next_element_seed(&mut *self)? {
+            tri!(v)
         }
 
-        let data = self.map.into_inner().map_err(Error::Io)?.into_inner();
-        let data_len: u32 = data.len().try_into().map_err(|_| Error::DocumentTooLarge)?;
-
-        self.writer.write_u32::<BigEndian>(data_len).map_err(Error::Io)?;
-        self.writer.write_all(&data).map_err(Error::Io)?;
-
-        Ok(())
+        Ok(Ok(()))
    }
 
-    fn serialize_entry<K: ?Sized, V: ?Sized>(
-        &mut self,
-        key: &K,
-        value: &V,
-    ) -> Result<(), Self::Error>
+    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
     where
-        K: Serialize,
-        V: Serialize,
+        A: MapAccess<'de>,
     {
-        let field_serializer = FieldSerializer { index: &mut self.index };
-        let field_id: FieldId = key.serialize(field_serializer)?;
+        while let Some((key, value)) =
+            map.next_entry_seed(FieldIdResolver(&mut *self.index), ValueDeserializer)?
+        {
+            self.values.insert(key, value);
+        }
 
-        let value = serde_json::to_value(value).map_err(Error::JsonError)?;
+        self.obkv_buffer.clear();
+        let mut obkv = obkv::KvWriter::new(Cursor::new(&mut *self.obkv_buffer));
+        for (key, value) in self.values.iter() {
+            self.value_buffer.clear();
+            // This is guaranteed to work
+            tri!(serde_json::to_writer(Cursor::new(&mut *self.value_buffer), value));
+            tri!(obkv.insert(*key, &self.value_buffer));
+        }
 
-        self.mapped_documents.insert(field_id, value);
+        let reader = tri!(obkv.into_inner()).into_inner();
 
-        Ok(())
+        tri!(self.inner.write_u32::<byteorder::BigEndian>(reader.len() as u32));
+        tri!(self.inner.write_all(reader));
+
+        *self.count += 1;
+        self.values.clear();
+
+        Ok(Ok(()))
+    }
+
+    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "a document, or a sequence of documents.")
     }
 }
 
-struct FieldSerializer<'a> {
-    index: &'a mut DocumentsBatchIndex,
-}
+impl<'a, 'de, W> DeserializeSeed<'de> for &mut DocumentVisitor<'a, W>
+where
+    W: Write,
+{
+    type Value = Result<(), Error>;
 
-impl<'a> serde::Serializer for FieldSerializer<'a> {
-    type Ok = FieldId;
-
-    type Error = Error;
-
-    type SerializeSeq = Impossible<FieldId, Self::Error>;
-    type SerializeTuple = Impossible<FieldId, Self::Error>;
-    type SerializeTupleStruct = Impossible<FieldId, Self::Error>;
-    type SerializeTupleVariant = Impossible<FieldId, Self::Error>;
-    type SerializeMap = Impossible<FieldId, Self::Error>;
-    type SerializeStruct = Impossible<FieldId, Self::Error>;
-    type SerializeStructVariant = Impossible<FieldId, Self::Error>;
-
-    fn serialize_str(self, ws: &str) -> Result<Self::Ok, Self::Error> {
-        let field_id = match self.index.get_by_right(ws) {
-            Some(field_id) => *field_id,
-            None => {
-                let field_id = self.index.len() as FieldId;
-                self.index.insert(field_id, ws.to_string());
-                field_id
-            }
-        };
-
-        Ok(field_id)
-    }
-
-    fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_some<T: ?Sized>(self, _value: &T) -> Result<Self::Ok, Self::Error>
+    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
     where
-        T: Serialize,
+        D: serde::Deserializer<'de>,
     {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_unit_variant(
-        self,
-        _name: &'static str,
-        _variant_index: u32,
-        _variant: &'static str,
-    ) -> Result<Self::Ok, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_newtype_struct<T: ?Sized>(
-        self,
-        _name: &'static str,
-        _value: &T,
-    ) -> Result<Self::Ok, Self::Error>
-    where
-        T: Serialize,
-    {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_newtype_variant<T: ?Sized>(
-        self,
-        _name: &'static str,
-        _variant_index: u32,
-        _variant: &'static str,
-        _value: &T,
-    ) -> Result<Self::Ok, Self::Error>
-    where
-        T: Serialize,
-    {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_tuple_struct(
-        self,
-        _name: &'static str,
-        _len: usize,
-    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_tuple_variant(
-        self,
-        _name: &'static str,
-        _variant_index: u32,
-        _variant: &'static str,
-        _len: usize,
-    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_struct(
-        self,
-        _name: &'static str,
-        _len: usize,
-    ) -> Result<Self::SerializeStruct, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-
-    fn serialize_struct_variant(
-        self,
-        _name: &'static str,
-        _variant_index: u32,
-        _variant: &'static str,
-        _len: usize,
-    ) -> Result<Self::SerializeStructVariant, Self::Error> {
-        Err(Error::InvalidDocumentFormat)
-    }
-}
-
-impl serde::ser::Error for Error {
-    fn custom<T: fmt::Display>(msg: T) -> Self {
-        Error::Custom(msg.to_string())
+        deserializer.deserialize_map(self)
     }
 }
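Taken together, `new`, `visit_map`, and `finish` define the batch layout: a leading big-endian `u64` pointing at the metadata, then one `u32`-length-prefixed obkv blob per document, then the bincode-encoded `DocumentsMetadata`. A sketch of walking that layout by hand; it assumes the elided context of `finish` seeks back and fills in the leading offset, as the placeholder written in `new` suggests:

```rust
use std::io::{Cursor, Seek, SeekFrom};

use byteorder::{BigEndian, ReadBytesExt};

fn count_documents(batch: Vec<u8>) -> std::io::Result<u64> {
    let mut cursor = Cursor::new(batch);
    // Offset of the metadata, i.e. the end of the documents section.
    let meta_offset = cursor.read_u64::<BigEndian>()?;

    let mut count = 0;
    while cursor.position() < meta_offset {
        // Each document is a u32 length followed by that many obkv bytes.
        let len = cursor.read_u32::<BigEndian>()?;
        cursor.seek(SeekFrom::Current(len as i64))?;
        count += 1;
    }

    Ok(count)
}
```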
diff --git a/milli/src/index.rs b/milli/src/index.rs
index 6ce693fbe..fe89fe734 100644
--- a/milli/src/index.rs
+++ b/milli/src/index.rs
@@ -953,6 +953,7 @@ pub(crate) mod tests {
             { "id": 1, "name": "kevin", "has_dog": true },
             { "id": 2, "name": "bob" }
         ]);
+
         let mut wtxn = index.write_txn().unwrap();
         let builder = IndexDocuments::new(&mut wtxn, &index, 0);
         builder.execute(content, |_, _| ()).unwrap();
diff --git a/milli/src/search/distinct/mod.rs b/milli/src/search/distinct/mod.rs
index deb51a053..11f6379e3 100644
--- a/milli/src/search/distinct/mod.rs
+++ b/milli/src/search/distinct/mod.rs
@@ -68,7 +68,9 @@ mod test {
                 "txts": sample_txts[..(rng.gen_range(0..3))],
                 "cat-ints": sample_ints[..(rng.gen_range(0..3))],
             });
-            builder.add_documents(doc).unwrap();
+
+            let doc = Cursor::new(serde_json::to_vec(&doc).unwrap());
+            builder.extend_from_json(doc).unwrap();
         }
 
         builder.finish().unwrap();
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index 92bcab0e9..440546b10 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -877,7 +877,8 @@ mod tests {
 
         let mut cursor = Cursor::new(Vec::new());
         let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
-        builder.add_documents(big_object).unwrap();
+        let big_object = Cursor::new(serde_json::to_vec(&big_object).unwrap());
+        builder.extend_from_json(big_object).unwrap();
         builder.finish().unwrap();
         cursor.set_position(0);
         let content = DocumentBatchReader::from_reader(cursor).unwrap();
@@ -905,8 +906,9 @@ mod tests {
 
         let mut cursor = Cursor::new(Vec::new());
+        let big_object = serde_json::to_string(&big_object).unwrap();
         let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
-        builder.add_documents(big_object).unwrap();
+        builder.extend_from_json(&mut big_object.as_bytes()).unwrap();
         builder.finish().unwrap();
         cursor.set_position(0);
         let content = DocumentBatchReader::from_reader(cursor).unwrap();
diff --git a/milli/src/update/index_documents/transform.rs b/milli/src/update/index_documents/transform.rs
index c0c88abed..08aa72d35 100644
--- a/milli/src/update/index_documents/transform.rs
+++ b/milli/src/update/index_documents/transform.rs
@@ -75,7 +75,7 @@ fn create_fields_mapping(
         .collect()
 }
 
-fn find_primary_key(index: &bimap::BiHashMap<FieldId, String>) -> Option<&str> {
+fn find_primary_key(index: &DocumentsBatchIndex) -> Option<&str> {
     index
         .iter()
         .sorted_by_key(|(k, _)| *k)
@@ -179,7 +179,7 @@ impl Transform<'_, '_> {
             if !self.autogenerate_docids {
                 let mut json = Map::new();
                 for (key, value) in document.iter() {
-                    let key = addition_index.get_by_left(&key).cloned();
+                    let key = addition_index.name(key).cloned();
                     let value = serde_json::from_slice::<Value>(&value).ok();
 
                     if let Some((k, v)) = key.zip(value) {
@@ -544,6 +544,7 @@ mod test {
     mod primary_key_inference {
         use bimap::BiHashMap;
 
+        use crate::documents::DocumentsBatchIndex;
         use crate::update::index_documents::transform::find_primary_key;
 
         #[test]
@@ -557,7 +558,7 @@ mod test {
             map.insert(4, "fakeId".to_string());
             map.insert(0, "realId".to_string());
 
-            assert_eq!(find_primary_key(&map), Some("realId"));
+            assert_eq!(find_primary_key(&DocumentsBatchIndex(map)), Some("realId"));
         }
     }
 }
diff --git a/milli/tests/search/mod.rs b/milli/tests/search/mod.rs
index cda0da617..e8fb3fdfa 100644
--- a/milli/tests/search/mod.rs
+++ b/milli/tests/search/mod.rs
@@ -61,9 +61,12 @@ pub fn setup_search_index_with_criteria(criteria: &[Criterion]) -> Index {
     let mut cursor = Cursor::new(Vec::new());
     let mut documents_builder = DocumentBatchBuilder::new(&mut cursor).unwrap();
     let reader = Cursor::new(CONTENT.as_bytes());
+
     for doc in serde_json::Deserializer::from_reader(reader).into_iter::<serde_json::Value>() {
-        documents_builder.add_documents(doc.unwrap()).unwrap();
+        let doc = Cursor::new(serde_json::to_vec(&doc.unwrap()).unwrap());
+        documents_builder.extend_from_json(doc).unwrap();
     }
+
     documents_builder.finish().unwrap();
 
     cursor.set_position(0);
diff --git a/milli/tests/search/query_criteria.rs b/milli/tests/search/query_criteria.rs
index f3b04c4fa..e5dde049c 100644
--- a/milli/tests/search/query_criteria.rs
+++ b/milli/tests/search/query_criteria.rs
@@ -409,7 +409,8 @@ fn criteria_ascdesc() {
                 "age": age,
             });
 
-            batch_builder.add_documents(json).unwrap();
+            let json = Cursor::new(serde_json::to_vec(&json).unwrap());
+            batch_builder.extend_from_json(json).unwrap();
         });
 
         batch_builder.finish().unwrap();
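The test updates above all repeat the same dance: serialize a `serde_json::Value`, wrap the bytes in a `Cursor`, and feed that to `extend_from_json`. A small hypothetical helper (not part of this diff; the `Write + Seek` bound mirrors the builder impl) would fold that pattern away:

```rust
use std::io::{Cursor, Seek, Write};

use milli::documents::{DocumentBatchBuilder, Error};

fn extend_from_value<W: Write + Seek>(
    builder: &mut DocumentBatchBuilder<W>,
    doc: &serde_json::Value,
) -> Result<(), Error> {
    // `From<serde_json::Error> for Error` is provided in documents/mod.rs above.
    let bytes = serde_json::to_vec(doc)?;
    builder.extend_from_json(Cursor::new(bytes))
}
```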