mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-26 23:04:26 +01:00
Fix vector parsing
This commit is contained in:
parent
d97af4d8e6
commit
4706a0eb49
@ -167,7 +167,7 @@ fn entry_from_raw_value(
|
|||||||
value: &RawValue,
|
value: &RawValue,
|
||||||
has_configured_embedder: bool,
|
has_configured_embedder: bool,
|
||||||
) -> std::result::Result<VectorEntry<'_>, serde_json::Error> {
|
) -> std::result::Result<VectorEntry<'_>, serde_json::Error> {
|
||||||
let value: RawVectors = serde_json::from_str(value.get())?;
|
let value: RawVectors = RawVectors::from_raw_value(value)?;
|
||||||
|
|
||||||
Ok(match value {
|
Ok(match value {
|
||||||
RawVectors::Explicit(raw_explicit_vectors) => VectorEntry {
|
RawVectors::Explicit(raw_explicit_vectors) => VectorEntry {
|
||||||
@ -177,7 +177,7 @@ fn entry_from_raw_value(
|
|||||||
},
|
},
|
||||||
RawVectors::ImplicitlyUserProvided(value) => VectorEntry {
|
RawVectors::ImplicitlyUserProvided(value) => VectorEntry {
|
||||||
has_configured_embedder,
|
has_configured_embedder,
|
||||||
embeddings: Some(Embeddings::FromJsonImplicityUserProvided(value)),
|
embeddings: value.map(Embeddings::FromJsonImplicityUserProvided),
|
||||||
regenerate: false,
|
regenerate: false,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -12,11 +12,96 @@ use crate::{DocumentId, FieldId, InternalError, UserError};
|
|||||||
|
|
||||||
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
|
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
#[derive(serde::Serialize, Debug)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum RawVectors<'doc> {
|
pub enum RawVectors<'doc> {
|
||||||
Explicit(#[serde(borrow)] RawExplicitVectors<'doc>),
|
Explicit(#[serde(borrow)] RawExplicitVectors<'doc>),
|
||||||
ImplicitlyUserProvided(#[serde(borrow)] &'doc RawValue),
|
ImplicitlyUserProvided(#[serde(borrow)] Option<&'doc RawValue>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'doc> RawVectors<'doc> {
|
||||||
|
pub fn from_raw_value(raw: &'doc RawValue) -> Result<Self, serde_json::Error> {
|
||||||
|
use serde::de::Deserializer as _;
|
||||||
|
Ok(match raw.deserialize_any(RawVectorsVisitor)? {
|
||||||
|
RawVectorsVisitorValue::ImplicitNone => RawVectors::ImplicitlyUserProvided(None),
|
||||||
|
RawVectorsVisitorValue::Implicit => RawVectors::ImplicitlyUserProvided(Some(raw)),
|
||||||
|
RawVectorsVisitorValue::Explicit { regenerate, embeddings } => {
|
||||||
|
RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RawVectorsVisitor;
|
||||||
|
|
||||||
|
enum RawVectorsVisitorValue<'doc> {
|
||||||
|
ImplicitNone,
|
||||||
|
Implicit,
|
||||||
|
Explicit { regenerate: bool, embeddings: Option<&'doc RawValue> },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor {
|
||||||
|
type Value = RawVectorsVisitorValue<'doc>;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
write!(formatter, "a map containing at least `regenerate`, or an array of floats`")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_none<E>(self) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
Ok(RawVectorsVisitorValue::ImplicitNone)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'doc>,
|
||||||
|
{
|
||||||
|
deserializer.deserialize_any(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_unit<E>(self) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
Ok(RawVectorsVisitorValue::ImplicitNone)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: serde::de::SeqAccess<'doc>,
|
||||||
|
{
|
||||||
|
// must consume all elements or parsing fails
|
||||||
|
while let Some(_) = seq.next_element::<&RawValue>()? {}
|
||||||
|
Ok(RawVectorsVisitorValue::Implicit)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: serde::de::MapAccess<'doc>,
|
||||||
|
{
|
||||||
|
use serde::de::Error as _;
|
||||||
|
let mut regenerate = None;
|
||||||
|
let mut embeddings = None;
|
||||||
|
while let Some(s) = map.next_key()? {
|
||||||
|
match s {
|
||||||
|
"regenerate" => {
|
||||||
|
let value: bool = map.next_value()?;
|
||||||
|
regenerate = Some(value);
|
||||||
|
}
|
||||||
|
"embeddings" => {
|
||||||
|
let value: &RawValue = map.next_value()?;
|
||||||
|
embeddings = Some(value);
|
||||||
|
}
|
||||||
|
other => return Err(A::Error::unknown_field(other, &["regenerate", "embeddings"])),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let Some(regenerate) = regenerate else {
|
||||||
|
return Err(A::Error::missing_field("regenerate"));
|
||||||
|
};
|
||||||
|
Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Serialize, Debug)]
|
#[derive(serde::Serialize, Debug)]
|
||||||
@ -86,7 +171,7 @@ impl<'doc> RawVectors<'doc> {
|
|||||||
}
|
}
|
||||||
pub fn embeddings(&self) -> Option<&'doc RawValue> {
|
pub fn embeddings(&self) -> Option<&'doc RawValue> {
|
||||||
match self {
|
match self {
|
||||||
RawVectors::ImplicitlyUserProvided(embeddings) => Some(embeddings),
|
RawVectors::ImplicitlyUserProvided(embeddings) => *embeddings,
|
||||||
RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate: _ }) => *embeddings,
|
RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate: _ }) => *embeddings,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user