diff --git a/Cargo.lock b/Cargo.lock index 116ffd8cc..ad726632b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,6 +471,7 @@ dependencies = [ "lazy_static", "memchr", "regex-automata", + "serde", ] [[package]] @@ -772,6 +773,28 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "csv" +version = "1.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +dependencies = [ + "bstr", + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +dependencies = [ + "memchr", +] + [[package]] name = "derivative" version = "2.2.0" @@ -1668,6 +1691,7 @@ dependencies = [ "bytes", "chrono", "crossbeam-channel", + "csv", "derivative", "either", "flate2", diff --git a/meilisearch-lib/Cargo.toml b/meilisearch-lib/Cargo.toml index df8b1e45d..43db857d2 100644 --- a/meilisearch-lib/Cargo.toml +++ b/meilisearch-lib/Cargo.toml @@ -15,6 +15,7 @@ arc-swap = "1.3.2" byte-unit = { version = "4.0.12", default-features = false, features = ["std"] } bytes = "1.1.0" chrono = { version = "0.4.19", features = ["serde"] } +csv = "1.1.6" crossbeam-channel = "0.5.1" either = "1.6.1" flate2 = "1.0.21" diff --git a/meilisearch-lib/src/index_controller/updates/csv_documents_iter.rs b/meilisearch-lib/src/index_controller/updates/csv_documents_iter.rs new file mode 100644 index 000000000..837240ceb --- /dev/null +++ b/meilisearch-lib/src/index_controller/updates/csv_documents_iter.rs @@ -0,0 +1,282 @@ +use super::error::{Result, UpdateLoopError}; +use std::io::{Read, Result as IoResult}; + +use csv::{Reader as CsvReader, StringRecordsIntoIter}; +use serde_json::{Map, Value}; + +enum AllowedType { + String, + Number, +} + +fn parse_csv_header(header: &str) -> (String, AllowedType) { + // if there are several separators we 
only split on the last one.
+    match header.rsplit_once(':') {
+        Some((field_name, field_type)) => match field_type {
+            "string" => (field_name.to_string(), AllowedType::String),
+            "number" => (field_name.to_string(), AllowedType::Number),
+            // if the pattern isn't recognized, we keep the whole field.
+            _otherwise => (header.to_string(), AllowedType::String),
+        },
+        None => (header.to_string(), AllowedType::String),
+    }
+}
+
+pub struct CsvDocumentIter<R>
+where
+    R: Read,
+{
+    documents: StringRecordsIntoIter<R>,
+    headers: Vec<(String, AllowedType)>,
+}
+
+impl<R: Read> CsvDocumentIter<R> {
+    pub fn from_reader(reader: R) -> IoResult<Self> {
+        let mut records = CsvReader::from_reader(reader);
+
+        let headers = records
+            .headers()?
+            .into_iter()
+            .map(parse_csv_header)
+            .collect();
+
+        Ok(Self {
+            documents: records.into_records(),
+            headers,
+        })
+    }
+}
+
+impl<R: Read> Iterator for CsvDocumentIter<R> {
+    type Item = Result<Map<String, Value>>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let csv_document = self.documents.next()?;
+
+        match csv_document {
+            Ok(csv_document) => {
+                let mut document = Map::new();
+
+                for ((field_name, field_type), value) in
+                    self.headers.iter().zip(csv_document.into_iter())
+                {
+                    let parsed_value = (|| match field_type {
+                        AllowedType::Number => value
+                            .parse::<f64>()
+                            .map(Value::from)
+                            .map_err(|e| UpdateLoopError::MalformedPayload(Box::new(e))),
+                        AllowedType::String => Ok(Value::String(value.to_string())),
+                    })();
+
+                    match parsed_value {
+                        Ok(value) => drop(document.insert(field_name.to_string(), value)),
+                        Err(e) => return Some(Err(e)),
+                    }
+                }
+
+                Some(Ok(document))
+            }
+            Err(e) => Some(Err(UpdateLoopError::MalformedPayload(Box::new(e)))),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn simple_csv_document() {
+        let documents = r#"city,country,pop
+"Boston","United States","4628910""#;
+
+        let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap();
+
+        assert_eq!(
+            Value::Object(csv_iter.next().unwrap().unwrap()),
+
json!({ + "city": "Boston", + "country": "United States", + "pop": "4628910", + }) + ); + } + + #[test] + fn coma_in_field() { + let documents = r#"city,country,pop +"Boston","United, States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert_eq!( + Value::Object(csv_iter.next().unwrap().unwrap()), + json!({ + "city": "Boston", + "country": "United, States", + "pop": "4628910", + }) + ); + } + + #[test] + fn quote_in_field() { + let documents = r#"city,country,pop +"Boston","United"" States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert_eq!( + Value::Object(csv_iter.next().unwrap().unwrap()), + json!({ + "city": "Boston", + "country": "United\" States", + "pop": "4628910", + }) + ); + } + + #[test] + fn integer_in_field() { + let documents = r#"city,country,pop:number +"Boston","United States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert_eq!( + Value::Object(csv_iter.next().unwrap().unwrap()), + json!({ + "city": "Boston", + "country": "United States", + "pop": 4628910.0, + }) + ); + } + + #[test] + fn float_in_field() { + let documents = r#"city,country,pop:number +"Boston","United States","4628910.01""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert_eq!( + Value::Object(csv_iter.next().unwrap().unwrap()), + json!({ + "city": "Boston", + "country": "United States", + "pop": 4628910.01, + }) + ); + } + + #[test] + fn several_double_dot_in_header() { + let documents = r#"city:love:string,country:state,pop +"Boston","United States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert_eq!( + Value::Object(csv_iter.next().unwrap().unwrap()), + json!({ + "city:love": "Boston", + "country:state": "United States", + "pop": "4628910", + }) + ); + } + + #[test] + fn 
ending_by_double_dot_in_header() { + let documents = r#"city:,country,pop +"Boston","United States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert_eq!( + Value::Object(csv_iter.next().unwrap().unwrap()), + json!({ + "city:": "Boston", + "country": "United States", + "pop": "4628910", + }) + ); + } + + #[test] + fn starting_by_double_dot_in_header() { + let documents = r#":city,country,pop +"Boston","United States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert_eq!( + Value::Object(csv_iter.next().unwrap().unwrap()), + json!({ + ":city": "Boston", + "country": "United States", + "pop": "4628910", + }) + ); + } + + #[test] + fn starting_by_double_dot_in_header2() { + let documents = r#":string,country,pop +"Boston","United States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert!(csv_iter.next().unwrap().is_err()); + } + + #[test] + fn double_double_dot_in_header() { + let documents = r#"city::string,country,pop +"Boston","United States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert_eq!( + Value::Object(csv_iter.next().unwrap().unwrap()), + json!({ + "city:": "Boston", + "country": "United States", + "pop": "4628910", + }) + ); + } + + #[test] + fn bad_type_in_header() { + let documents = r#"city,country:number,pop +"Boston","United States","4628910""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert!(csv_iter.next().unwrap().is_err()); + } + + #[test] + fn bad_column_count1() { + let documents = r#"city,country,pop +"Boston","United States","4628910", "too much""#; + + let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap(); + + assert!(csv_iter.next().unwrap().is_err()); + } + + #[test] + fn bad_column_count2() { + let documents = r#"city,country,pop 
+"Boston","United States""#;
+
+        let mut csv_iter = CsvDocumentIter::from_reader(documents.as_bytes()).unwrap();
+
+        assert!(csv_iter.next().unwrap().is_err());
+    }
+}
diff --git a/meilisearch-lib/src/index_controller/updates/error.rs b/meilisearch-lib/src/index_controller/updates/error.rs
index 8cbcf211a..217567569 100644
--- a/meilisearch-lib/src/index_controller/updates/error.rs
+++ b/meilisearch-lib/src/index_controller/updates/error.rs
@@ -25,6 +25,8 @@ pub enum UpdateLoopError {
     FatalUpdateStoreError,
     #[error("{0}")]
     InvalidPayload(#[from] DocumentFormatError),
+    #[error("{0}")]
+    MalformedPayload(Box<dyn std::error::Error + Send + Sync + 'static>),
     // TODO: The reference to actix has to go.
     #[error("{0}")]
     PayloadError(#[from] actix_web::error::PayloadError),
@@ -56,12 +58,13 @@ internal_error!(
 impl ErrorCode for UpdateLoopError {
     fn error_code(&self) -> Code {
         match self {
-            UpdateLoopError::UnexistingUpdate(_) => Code::NotFound,
-            UpdateLoopError::Internal(_) => Code::Internal,
-            //UpdateLoopError::IndexActor(e) => e.error_code(),
-            UpdateLoopError::FatalUpdateStoreError => Code::Internal,
-            UpdateLoopError::InvalidPayload(_) => Code::BadRequest,
-            UpdateLoopError::PayloadError(error) => match error {
+            Self::UnexistingUpdate(_) => Code::NotFound,
+            Self::Internal(_) => Code::Internal,
+            //Self::IndexActor(e) => e.error_code(),
+            Self::FatalUpdateStoreError => Code::Internal,
+            Self::InvalidPayload(_) => Code::BadRequest,
+            Self::MalformedPayload(_) => Code::BadRequest,
+            Self::PayloadError(error) => match error {
                 actix_web::error::PayloadError::Overflow => Code::PayloadTooLarge,
                 _ => Code::Internal,
             },
diff --git a/meilisearch-lib/src/index_controller/updates/mod.rs b/meilisearch-lib/src/index_controller/updates/mod.rs
index fad337553..14f0a7c69 100644
--- a/meilisearch-lib/src/index_controller/updates/mod.rs
+++ b/meilisearch-lib/src/index_controller/updates/mod.rs
@@ -1,8 +1,10 @@
+mod csv_documents_iter;
 pub mod error;
 mod message;
 pub mod status;
 pub mod store;
 
+use 
crate::index_controller::updates::csv_documents_iter::CsvDocumentIter;
 use std::io;
 use std::path::{Path, PathBuf};
 use std::sync::atomic::AtomicBool;
@@ -13,6 +15,7 @@ use async_stream::stream;
 use bytes::Bytes;
 use futures::{Stream, StreamExt};
 use log::trace;
+use milli::documents::DocumentBatchBuilder;
 use milli::update::IndexDocumentsMethod;
 use serde::{Deserialize, Serialize};
 use tokio::sync::mpsc;
@@ -27,7 +30,7 @@ use crate::index_controller::update_file_store::UpdateFileStore;
 use status::UpdateStatus;
 
 use super::index_resolver::HardStateIndexResolver;
-use super::{DocumentAdditionFormat, Update};
+use super::{DocumentAdditionFormat, Payload, Update};
 
 pub type UpdateSender = mpsc::Sender<UpdateMsg>;
 
@@ -222,6 +225,26 @@ impl UpdateLoop {
         Ok(status.into())
     }
 
+    async fn documents_from_csv(&self, payload: Payload) -> Result<Uuid> {
+        let file_store = self.update_file_store.clone();
+        tokio::task::spawn_blocking(move || {
+            let (uuid, mut file) = file_store.new_update().unwrap();
+            let mut builder = DocumentBatchBuilder::new(&mut *file).unwrap();
+
+            let iter = CsvDocumentIter::from_reader(StreamReader::new(payload))?;
+            for doc in iter {
+                let doc = doc?;
+                builder.add_documents(doc).unwrap();
+            }
+            builder.finish().unwrap();
+
+            file.persist();
+
+            Ok(uuid)
+        })
+        .await?
+    }
+
     async fn handle_list_updates(&self, uuid: Uuid) -> Result<Vec<UpdateStatus>> {
         let update_store = self.store.clone();
         tokio::task::spawn_blocking(move || {