Deserr for distribution

This commit is contained in:
Louis Dureuil 2024-03-27 11:50:33 +01:00
parent afd1da5642
commit 168ded3b9d
No known key found for this signature in database

View File

@ -318,10 +318,50 @@ pub struct DistributionShift {
pub current_sigma: OrderedFloat<f32>,
}
#[derive(Serialize, Deserialize)]
impl<E> Deserr<E> for DistributionShift
where
E: DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> Result<Self, E> {
let value = DistributionShiftSerializable::deserialize_from_value(value, location)?;
if value.mean < 0. || value.mean > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution mean must be in the range [0, 1], got {}",
value.mean
),
},
location,
)));
}
if value.sigma <= 0. || value.sigma > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution sigma must be in the range ]0, 1], got {}",
value.sigma
),
},
location,
)));
}
Ok(value.into())
}
}
#[derive(Serialize, Deserialize, Deserr)]
#[serde(deny_unknown_fields)]
#[deserr(deny_unknown_fields)]
struct DistributionShiftSerializable {
current_mean: f32,
current_sigma: f32,
mean: f32,
sigma: f32,
}
impl From<DistributionShift> for DistributionShiftSerializable {
@ -331,18 +371,13 @@ impl From<DistributionShift> for DistributionShiftSerializable {
current_sigma: OrderedFloat(current_sigma),
}: DistributionShift,
) -> Self {
Self { current_mean, current_sigma }
Self { mean: current_mean, sigma: current_sigma }
}
}
impl From<DistributionShiftSerializable> for DistributionShift {
fn from(
DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable,
) -> Self {
Self {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}
fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self {
Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }
}
}