Better CSV support

This commit is contained in:
Clément Renault 2024-09-11 10:02:00 +02:00
parent 8287c2644f
commit b4de06259e
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -1,9 +1,8 @@
use std::fmt::{self, Debug, Display};
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Seek, Write};
use std::io::{self, BufWriter};
use std::marker::PhantomData;
use csv::StringRecord;
use memmap2::Mmap;
use milli::documents::Error;
use milli::update::new::TopLevelMap;
@ -11,13 +10,13 @@ use milli::Object;
use serde::de::{SeqAccess, Visitor};
use serde::{Deserialize, Deserializer};
use serde_json::error::Category;
use serde_json::{Map, Value};
use crate::error::deserr_codes::MalformedPayload;
use crate::error::{Code, ErrorCode};
type Result<T> = std::result::Result<T, DocumentFormatError>;
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub enum PayloadType {
Ndjson,
Json,
@ -101,6 +100,16 @@ impl From<(PayloadType, serde_json::Error)> for DocumentFormatError {
}
}
impl From<(PayloadType, csv::Error)> for DocumentFormatError {
fn from((ty, error): (PayloadType, csv::Error)) -> Self {
if error.is_io_error() {
Self::Io(error.into())
} else {
Self::MalformedPayload(Error::Csv(error), ty)
}
}
}
impl From<io::Error> for DocumentFormatError {
fn from(error: io::Error) -> Self {
Self::Io(error)
@ -140,78 +149,63 @@ fn parse_csv_header(header: &str) -> (&str, AllowedType) {
/// Reads CSV from input and write an obkv batch to writer.
pub fn read_csv(input: &File, output: impl io::Write, delimiter: u8) -> Result<u64> {
use serde_json::{Map, Value};
let ptype = PayloadType::Csv { delimiter };
let mut output = BufWriter::new(output);
let mut reader = csv::ReaderBuilder::new().delimiter(delimiter).from_reader(input);
// TODO manage error correctly
// Make sure that we insert the fields ids in order as the obkv writer has this requirement.
let mut typed_fields: Vec<_> = reader
.headers()
.unwrap()
.into_iter()
.map(parse_csv_header)
.map(|(f, t)| (f.to_string(), t))
.collect();
let headers = reader.headers().map_err(|e| DocumentFormatError::from((ptype, e)))?.clone();
let typed_fields: Vec<_> = headers.iter().map(parse_csv_header).collect();
let mut object: Map<_, _> = headers.iter().map(|k| (k.to_string(), Value::Null)).collect();
let mut object: Map<_, _> =
reader.headers().unwrap().iter().map(|k| (k.to_string(), Value::Null)).collect();
let mut line: usize = 0;
let mut line = 0;
let mut record = csv::StringRecord::new();
while reader.read_record(&mut record).unwrap() {
// We increment here and not at the end of the while loop to take
// the header offset into account.
while reader.read_record(&mut record).map_err(|e| DocumentFormatError::from((ptype, e)))? {
// We increment here and not at the end of the loop
// to take the header offset into account.
line += 1;
// Reset the document to write
// Reset the document values
object.iter_mut().for_each(|(_, v)| *v = Value::Null);
for (i, (name, type_)) in typed_fields.iter().enumerate() {
for (i, (name, atype)) in typed_fields.iter().enumerate() {
let value = &record[i];
let trimmed_value = value.trim();
let value = match type_ {
let value = match atype {
AllowedType::Number if trimmed_value.is_empty() => Value::Null,
AllowedType::Number => match trimmed_value.parse::<i64>() {
Ok(integer) => Value::from(integer),
Err(_) => {
match trimmed_value.parse::<f64>() {
Ok(float) => Value::from(float),
Err(error) => {
panic!("bad float")
// return Err(Error::ParseFloat {
// error,
// line,
// value: value.to_string(),
// });
}
Err(_) => match trimmed_value.parse::<f64>() {
Ok(float) => Value::from(float),
Err(error) => {
return Err(DocumentFormatError::MalformedPayload(
Error::ParseFloat { error, line, value: value.to_string() },
ptype,
))
}
}
},
},
AllowedType::Boolean if trimmed_value.is_empty() => Value::Null,
AllowedType::Boolean => match trimmed_value.parse::<bool>() {
Ok(bool) => Value::from(bool),
Err(error) => {
panic!("bad bool")
// return Err(Error::ParseBool {
// error,
// line,
// value: value.to_string(),
// });
return Err(DocumentFormatError::MalformedPayload(
Error::ParseBool { error, line, value: value.to_string() },
ptype,
))
}
},
AllowedType::String if value.is_empty() => Value::Null,
AllowedType::String => Value::from(value),
};
*object.get_mut(name).unwrap() = value;
*object.get_mut(*name).expect("encountered an unknown field") = value;
}
serde_json::to_writer(&mut output, &object).unwrap();
serde_json::to_writer(&mut output, &object)
.map_err(|e| DocumentFormatError::from((ptype, e)))?;
}
Ok(line.saturating_sub(1) as u64)
Ok(line as u64)
}
/// Reads JSON from a temporary file and writes an obkv batch to the writer.