diff --git a/benchmarks/benches/utils.rs b/benchmarks/benches/utils.rs
index e5bdbdfaa..24f5d5343 100644
--- a/benchmarks/benches/utils.rs
+++ b/benchmarks/benches/utils.rs
@@ -175,8 +175,7 @@ fn documents_from_csv(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
     let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
 
-    let mut records = csv::Reader::from_reader(reader);
-    let iter = records.deserialize::<Map<String, Value>>();
+    let iter = CSVDocumentDeserializer::from_reader(reader)?;
 
     for doc in iter {
         let doc = doc?;
@@ -187,3 +186,96 @@ fn documents_from_csv(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
 
     Ok(writer.into_inner())
 }
+
+/// Type a CSV column is converted to, selected by the suffix of its header.
+enum AllowedType {
+    String,
+    Number,
+}
+
+/// Splits a CSV header into the field name and the type of its column.
+///
+/// Only the text after the last `:` is looked at. When it is neither `string` nor `number`,
+/// or when the header contains no `:`, the whole header is kept as the field name and the
+/// column is read as strings.
+fn parse_csv_header(header: &str) -> (String, AllowedType) {
+    // if there are several separators we only split on the last one.
+    match header.rsplit_once(':') {
+        Some((field_name, field_type)) => match field_type {
+            "string" => (field_name.to_string(), AllowedType::String),
+            "number" => (field_name.to_string(), AllowedType::Number),
+            // we may return an error in this case.
+            _otherwise => (header.to_string(), AllowedType::String),
+        },
+        None => (header.to_string(), AllowedType::String),
+    }
+}
+
+/// Iterator turning each record of a CSV source into a JSON object keyed by the header names.
+struct CSVDocumentDeserializer<R>
+where
+    R: Read,
+{
+    documents: csv::StringRecordsIntoIter<R>,
+    headers: Vec<(String, AllowedType)>,
+}
+
+impl<R: Read> CSVDocumentDeserializer<R> {
+    /// Reads the header row of `reader` and prepares to iterate over the remaining records.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the header row cannot be read.
+    fn from_reader(reader: R) -> io::Result<Self> {
+        let mut records = csv::Reader::from_reader(reader);
+
+        let headers = records.headers()?.into_iter().map(parse_csv_header).collect();
+
+        Ok(Self { documents: records.into_records(), headers })
+    }
+}
+
+impl<R: Read> Iterator for CSVDocumentDeserializer<R> {
+    type Item = anyhow::Result<Map<String, Value>>;
+
+    /// Converts the next record, failing on a malformed record or an invalid `number` value.
+    fn next(&mut self) -> Option<Self::Item> {
+        let csv_document = self.documents.next()?;
+
+        match csv_document {
+            Ok(csv_document) => {
+                let mut document = Map::new();
+
+                for ((field_name, field_type), value) in
+                    self.headers.iter().zip(csv_document.into_iter())
+                {
+                    let parsed_value = match field_type {
+                        AllowedType::Number => value
+                            .parse::<f64>()
+                            .ok()
+                            // `f64` parsing accepts `NaN` and `inf`, which `Value::from` would
+                            // silently turn into `null`: report them as invalid numbers instead.
+                            .filter(|number| number.is_finite())
+                            .map(Value::from),
+                        AllowedType::String => Some(Value::String(value.to_string())),
+                    };
+
+                    match parsed_value {
+                        Some(value) => {
+                            document.insert(field_name.to_string(), value);
+                        }
+                        None => {
+                            return Some(Err(anyhow::anyhow!(
+                                "Value '{}' is not a valid number",
+                                value
+                            )))
+                        }
+                    }
+                }
+
+                Some(Ok(document))
+            }
+            Err(e) => Some(Err(anyhow::anyhow!("Error parsing csv document: {}", e))),
+        }
+    }
+}
diff --git a/http-ui/src/documents_from_csv.rs b/http-ui/src/documents_from_csv.rs
new file mode 100644
index 000000000..2b62f23c2
--- /dev/null
+++ b/http-ui/src/documents_from_csv.rs
@@ -0,0 +1,315 @@
+use std::io::{Read, Result as IoResult};
+
+use serde_json::{Map, Value};
+
+/// Type a CSV column is converted to, selected by the suffix of its header.
+enum AllowedType {
+    String,
+    Number,
+}
+
+/// Splits a CSV header into the field name and the type of its column.
+///
+/// Only the text after the last `:` is looked at. When it is neither `string` nor `number`,
+/// or when the header contains no `:`, the whole header is kept as the field name and the
+/// column is read as strings.
+fn parse_csv_header(header: &str) -> (String, AllowedType) {
+    // if there are several separators we only split on the last one.
+    match header.rsplit_once(':') {
+        Some((field_name, field_type)) => match field_type {
+            "string" => (field_name.to_string(), AllowedType::String),
+            "number" => (field_name.to_string(), AllowedType::Number),
+            // we may return an error in this case.
+            _otherwise => (header.to_string(), AllowedType::String),
+        },
+        None => (header.to_string(), AllowedType::String),
+    }
+}
+
+/// Iterator turning each record of a CSV source into a JSON object keyed by the header names.
+pub struct CSVDocumentDeserializer<R>
+where
+    R: Read,
+{
+    documents: csv::StringRecordsIntoIter<R>,
+    headers: Vec<(String, AllowedType)>,
+}
+
+impl<R: Read> CSVDocumentDeserializer<R> {
+    /// Reads the header row of `reader` and prepares to iterate over the remaining records.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the header row cannot be read.
+    pub fn from_reader(reader: R) -> IoResult<Self> {
+        let mut records = csv::Reader::from_reader(reader);
+
+        let headers = records.headers()?.into_iter().map(parse_csv_header).collect();
+
+        Ok(Self { documents: records.into_records(), headers })
+    }
+}
+
+impl<R: Read> Iterator for CSVDocumentDeserializer<R> {
+    type Item = anyhow::Result<Map<String, Value>>;
+
+    /// Converts the next record, failing on a malformed record or an invalid `number` value.
+    fn next(&mut self) -> Option<Self::Item> {
+        let csv_document = self.documents.next()?;
+
+        match csv_document {
+            Ok(csv_document) => {
+                let mut document = Map::new();
+
+                for ((field_name, field_type), value) in
+                    self.headers.iter().zip(csv_document.into_iter())
+                {
+                    let parsed_value = match field_type {
+                        AllowedType::Number => value
+                            .parse::<f64>()
+                            .ok()
+                            // `f64` parsing accepts `NaN` and `inf`, which `Value::from` would
+                            // silently turn into `null`: report them as invalid numbers instead.
+                            .filter(|number| number.is_finite())
+                            .map(Value::from),
+                        AllowedType::String => Some(Value::String(value.to_string())),
+                    };
+
+                    match parsed_value {
+                        Some(value) => {
+                            document.insert(field_name.to_string(), value);
+                        }
+                        None => {
+                            return Some(Err(anyhow::anyhow!(
+                                "Value '{}' is not a valid number",
+                                value
+                            )))
+                        }
+                    }
+                }
+
+                Some(Ok(document))
+            }
+            Err(e) => Some(Err(anyhow::anyhow!("Error parsing csv document: {}", e))),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn simple_csv_document() {
+        let documents = r#"city,country,pop
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "city": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn comma_in_field() {
+        let documents = r#"city,country,pop
+"Boston","United, States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "city": "Boston",
+                "country": "United, States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn quote_in_field() {
+        let documents = r#"city,country,pop
+"Boston","United"" States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "city": "Boston",
+                "country": "United\" States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn integer_in_field() {
+        let documents = r#"city,country,pop:number
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "city": "Boston",
+                "country": "United States",
+                "pop": 4628910.0,
+            })
+        );
+    }
+
+    #[test]
+    fn float_in_field() {
+        let documents = r#"city,country,pop:number
+"Boston","United States","4628910.01""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "city": "Boston",
+                "country": "United States",
+                "pop": 4628910.01,
+            })
+        );
+    }
+
+    #[test]
+    fn several_double_dot_in_header() {
+        let documents = r#"city:love:string,country:state,pop
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "city:love": "Boston",
+                "country:state": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn ending_by_double_dot_in_header() {
+        let documents = r#"city:,country,pop
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "city:": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn starting_by_double_dot_in_header() {
+        let documents = r#":city,country,pop
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                ":city": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn starting_by_double_dot_in_header2() {
+        let documents = r#":string,country,pop
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn double_double_dot_in_header() {
+        let documents = r#"city::string,country,pop
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+            json!({
+                "city:": "Boston",
+                "country": "United States",
+                "pop": "4628910",
+            })
+        );
+    }
+
+    #[test]
+    fn bad_type_in_header() {
+        let documents = r#"city,country:number,pop
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert!(csv_iter.next().unwrap().is_err());
+    }
+
+    #[test]
+    fn non_finite_number_in_field() {
+        let documents = r#"city,country,pop:number
+"Boston","United States","NaN"
+"Boston","United States","inf""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert!(csv_iter.next().unwrap().is_err());
+        assert!(csv_iter.next().unwrap().is_err());
+    }
+
+    #[test]
+    fn bad_column_count1() {
+        let documents = r#"city,country,pop
+"Boston","United States","4628910", "too much""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert!(csv_iter.next().unwrap().is_err());
+    }
+
+    #[test]
+    fn bad_column_count2() {
+        let documents = r#"city,country,pop
+"Boston","United States""#;
+
+        let mut csv_iter = CSVDocumentDeserializer::from_reader(documents.as_bytes()).unwrap();
+
+        assert!(csv_iter.next().unwrap().is_err());
+    }
+}
diff --git a/http-ui/src/main.rs b/http-ui/src/main.rs
index b76547309..9efdd1371 100644
--- a/http-ui/src/main.rs
+++ b/http-ui/src/main.rs
@@ -1,3 +1,4 @@
+mod documents_from_csv;
 mod update_store;
 
 use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
@@ -38,6 +39,7 @@ use warp::http::Response;
 use warp::Filter;
 
 use self::update_store::UpdateStore;
+use crate::documents_from_csv::CSVDocumentDeserializer;
 
 #[cfg(target_os = "linux")]
 #[global_allocator]
@@ -1056,8 +1058,7 @@ fn documents_from_csv(reader: impl io::Read) -> anyhow::Result<Vec<u8>> {
     let mut writer = Cursor::new(Vec::new());
     let mut documents = milli::documents::DocumentBatchBuilder::new(&mut writer)?;
 
-    let mut records = csv::Reader::from_reader(reader);
-    let iter = records.deserialize::<Map<String, Value>>();
+    let iter = CSVDocumentDeserializer::from_reader(reader)?;
 
     for doc in iter {
         let doc = doc?;