Better CSV support

This commit is contained in:
Clément Renault 2024-09-11 10:02:00 +02:00
parent 8287c2644f
commit b4de06259e
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -1,9 +1,8 @@
use std::fmt::{self, Debug, Display}; use std::fmt::{self, Debug, Display};
use std::fs::File; use std::fs::File;
use std::io::{self, BufReader, BufWriter, Seek, Write}; use std::io::{self, BufWriter};
use std::marker::PhantomData; use std::marker::PhantomData;
use csv::StringRecord;
use memmap2::Mmap; use memmap2::Mmap;
use milli::documents::Error; use milli::documents::Error;
use milli::update::new::TopLevelMap; use milli::update::new::TopLevelMap;
@ -11,13 +10,13 @@ use milli::Object;
use serde::de::{SeqAccess, Visitor}; use serde::de::{SeqAccess, Visitor};
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
use serde_json::error::Category; use serde_json::error::Category;
use serde_json::{Map, Value};
use crate::error::deserr_codes::MalformedPayload;
use crate::error::{Code, ErrorCode}; use crate::error::{Code, ErrorCode};
type Result<T> = std::result::Result<T, DocumentFormatError>; type Result<T> = std::result::Result<T, DocumentFormatError>;
#[derive(Debug)] #[derive(Debug, Clone, Copy)]
pub enum PayloadType { pub enum PayloadType {
Ndjson, Ndjson,
Json, Json,
@ -101,6 +100,16 @@ impl From<(PayloadType, serde_json::Error)> for DocumentFormatError {
} }
} }
impl From<(PayloadType, csv::Error)> for DocumentFormatError {
fn from((ty, error): (PayloadType, csv::Error)) -> Self {
if error.is_io_error() {
Self::Io(error.into())
} else {
Self::MalformedPayload(Error::Csv(error), ty)
}
}
}
impl From<io::Error> for DocumentFormatError { impl From<io::Error> for DocumentFormatError {
fn from(error: io::Error) -> Self { fn from(error: io::Error) -> Self {
Self::Io(error) Self::Io(error)
@ -140,78 +149,63 @@ fn parse_csv_header(header: &str) -> (&str, AllowedType) {
/// Reads CSV from input and write an obkv batch to writer. /// Reads CSV from input and write an obkv batch to writer.
pub fn read_csv(input: &File, output: impl io::Write, delimiter: u8) -> Result<u64> { pub fn read_csv(input: &File, output: impl io::Write, delimiter: u8) -> Result<u64> {
use serde_json::{Map, Value}; let ptype = PayloadType::Csv { delimiter };
let mut output = BufWriter::new(output); let mut output = BufWriter::new(output);
let mut reader = csv::ReaderBuilder::new().delimiter(delimiter).from_reader(input); let mut reader = csv::ReaderBuilder::new().delimiter(delimiter).from_reader(input);
// TODO manage error correctly let headers = reader.headers().map_err(|e| DocumentFormatError::from((ptype, e)))?.clone();
// Make sure that we insert the fields ids in order as the obkv writer has this requirement. let typed_fields: Vec<_> = headers.iter().map(parse_csv_header).collect();
let mut typed_fields: Vec<_> = reader let mut object: Map<_, _> = headers.iter().map(|k| (k.to_string(), Value::Null)).collect();
.headers()
.unwrap()
.into_iter()
.map(parse_csv_header)
.map(|(f, t)| (f.to_string(), t))
.collect();
let mut object: Map<_, _> = let mut line = 0;
reader.headers().unwrap().iter().map(|k| (k.to_string(), Value::Null)).collect();
let mut line: usize = 0;
let mut record = csv::StringRecord::new(); let mut record = csv::StringRecord::new();
while reader.read_record(&mut record).unwrap() { while reader.read_record(&mut record).map_err(|e| DocumentFormatError::from((ptype, e)))? {
// We increment here and not at the end of the while loop to take // We increment here and not at the end of the loop
// the header offset into account. // to take the header offset into account.
line += 1; line += 1;
// Reset the document to write // Reset the document values
object.iter_mut().for_each(|(_, v)| *v = Value::Null); object.iter_mut().for_each(|(_, v)| *v = Value::Null);
for (i, (name, type_)) in typed_fields.iter().enumerate() { for (i, (name, atype)) in typed_fields.iter().enumerate() {
let value = &record[i]; let value = &record[i];
let trimmed_value = value.trim(); let trimmed_value = value.trim();
let value = match type_ { let value = match atype {
AllowedType::Number if trimmed_value.is_empty() => Value::Null, AllowedType::Number if trimmed_value.is_empty() => Value::Null,
AllowedType::Number => match trimmed_value.parse::<i64>() { AllowedType::Number => match trimmed_value.parse::<i64>() {
Ok(integer) => Value::from(integer), Ok(integer) => Value::from(integer),
Err(_) => { Err(_) => match trimmed_value.parse::<f64>() {
match trimmed_value.parse::<f64>() { Ok(float) => Value::from(float),
Ok(float) => Value::from(float), Err(error) => {
Err(error) => { return Err(DocumentFormatError::MalformedPayload(
panic!("bad float") Error::ParseFloat { error, line, value: value.to_string() },
// return Err(Error::ParseFloat { ptype,
// error, ))
// line,
// value: value.to_string(),
// });
}
} }
} },
}, },
AllowedType::Boolean if trimmed_value.is_empty() => Value::Null, AllowedType::Boolean if trimmed_value.is_empty() => Value::Null,
AllowedType::Boolean => match trimmed_value.parse::<bool>() { AllowedType::Boolean => match trimmed_value.parse::<bool>() {
Ok(bool) => Value::from(bool), Ok(bool) => Value::from(bool),
Err(error) => { Err(error) => {
panic!("bad bool") return Err(DocumentFormatError::MalformedPayload(
// return Err(Error::ParseBool { Error::ParseBool { error, line, value: value.to_string() },
// error, ptype,
// line, ))
// value: value.to_string(),
// });
} }
}, },
AllowedType::String if value.is_empty() => Value::Null, AllowedType::String if value.is_empty() => Value::Null,
AllowedType::String => Value::from(value), AllowedType::String => Value::from(value),
}; };
*object.get_mut(name).unwrap() = value; *object.get_mut(*name).expect("encountered an unknown field") = value;
} }
serde_json::to_writer(&mut output, &object).unwrap(); serde_json::to_writer(&mut output, &object)
.map_err(|e| DocumentFormatError::from((ptype, e)))?;
} }
Ok(line.saturating_sub(1) as u64) Ok(line as u64)
} }
/// Reads JSON from temporary file and write an obkv batch to writer. /// Reads JSON from temporary file and write an obkv batch to writer.