From 168ded3b9d17f290190c326ed9574c2937b216f9 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 27 Mar 2024 11:50:33 +0100 Subject: [PATCH] Deserr for distribution --- milli/src/vector/mod.rs | 57 +++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 4a3a9920e..1cb0a18f7 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -318,10 +318,50 @@ pub struct DistributionShift { pub current_sigma: OrderedFloat, } -#[derive(Serialize, Deserialize)] +impl Deserr for DistributionShift +where + E: DeserializeError, +{ + fn deserialize_from_value( + value: deserr::Value, + location: deserr::ValuePointerRef, + ) -> Result { + let value = DistributionShiftSerializable::deserialize_from_value(value, location)?; + if value.mean < 0. || value.mean > 1. { + return Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::Unexpected { + msg: format!( + "the distribution mean must be in the range [0, 1], got {}", + value.mean + ), + }, + location, + ))); + } + if value.sigma <= 0. || value.sigma > 1. { + return Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::Unexpected { + msg: format!( + "the distribution sigma must be in the range ]0, 1], got {}", + value.sigma + ), + }, + location, + ))); + } + + Ok(value.into()) + } +} + +#[derive(Serialize, Deserialize, Deserr)] +#[serde(deny_unknown_fields)] +#[deserr(deny_unknown_fields)] struct DistributionShiftSerializable { - current_mean: f32, - current_sigma: f32, + mean: f32, + sigma: f32, } impl From for DistributionShiftSerializable { @@ -331,18 +371,13 @@ impl From for DistributionShiftSerializable { current_sigma: OrderedFloat(current_sigma), }: DistributionShift, ) -> Self { - Self { current_mean, current_sigma } + Self { mean: current_mean, sigma: current_sigma } } } impl From for DistributionShift { - fn from( - DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable, - ) -> Self { - Self { - current_mean: OrderedFloat(current_mean), - current_sigma: OrderedFloat(current_sigma), - } + fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self { + Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) } } }