diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index c0d707657..f7cfe99f9 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -17,7 +17,7 @@ use meilisearch_types::{milli, Document}; use milli::tokenizer::TokenizerBuilder; use milli::{ AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, - SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, + SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET, }; use ordered_float::OrderedFloat; use regex::Regex; @@ -432,7 +432,6 @@ pub fn perform_search( formatter_builder.highlight_suffix(query.highlight_post_tag); let mut documents = Vec::new(); - let documents_iter = index.documents(&rtxn, documents_ids)?; for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { @@ -460,7 +459,9 @@ pub fn perform_search( } if let Some(vector) = query.vector.as_ref() { - insert_semantic_similarity(&vector, &mut document); + if let Some(vectors) = extract_field("_vectors", &fields_ids_map, obkv)? { + insert_semantic_similarity(vector, vectors, &mut document); + } } let ranking_score = @@ -548,20 +549,18 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } -fn insert_semantic_similarity(query: &[f32], document: &mut Document) { - if let Some(value) = document.get("_vectors") { - let vectors: Vec> = match serde_json::from_value(value.clone()) { - Ok(Either::Left(vector)) => vec![vector], - Ok(Either::Right(vectors)) => vectors, +fn insert_semantic_similarity(query: &[f32], vectors: Value, document: &mut Document) { + let vectors = + match serde_json::from_value(vectors).map(VectorOrArrayOfVectors::into_array_of_vectors) { + Ok(vectors) => vectors, Err(_) => return, }; - let similarity = vectors - .into_iter() - .map(|v| OrderedFloat(dot_product_similarity(query, &v))) - .max() - .map(OrderedFloat::into_inner); - document.insert("_semanticSimilarity".to_string(), json!(similarity)); - } + let similarity = vectors + .into_iter() + .map(|v| OrderedFloat(dot_product_similarity(query, &v))) + .max() + .map(OrderedFloat::into_inner); + document.insert("_semanticSimilarity".to_string(), json!(similarity)); } fn compute_formatted_options( @@ -691,6 +690,22 @@ fn make_document( Ok(document) } +/// Extract the JSON value under the field name specified +/// but doesn't support nested objects. +fn extract_field( + field_name: &str, + field_ids_map: &FieldsIdsMap, + obkv: obkv::KvReaderU16, +) -> Result, MeilisearchHttpError> { + match field_ids_map.id(field_name) { + Some(fid) => match obkv.get(fid) { + Some(value) => Ok(serde_json::from_slice(value).map(Some)?), + None => Ok(None), + }, + None => Ok(None), + } +} + fn format_fields>( document: &Document, field_ids_map: &FieldsIdsMap, diff --git a/milli/src/lib.rs b/milli/src/lib.rs index c93bf88ff..63cf6f397 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -286,6 +286,23 @@ pub fn normalize_facet(original: &str) -> String { CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase() } +/// Represents either a vector or an array of multiple vectors. +#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[serde(transparent)] +pub struct VectorOrArrayOfVectors { + #[serde(with = "either::serde_untagged")] + inner: either::Either, Vec>>, +} + +impl VectorOrArrayOfVectors { + pub fn into_array_of_vectors(self) -> Vec> { + match self.inner { + either::Either::Left(vector) => vec![vector], + either::Either::Right(vectors) => vectors, + } + } +} + /// Normalize a vector by dividing the dimensions by the lenght of it. pub fn normalize_vector(mut vector: Vec) -> Vec { let squared: f32 = vector.iter().map(|x| x * x).sum(); diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 7e2bd25c5..c2a08b320 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -3,11 +3,10 @@ use std::fs::File; use std::io; use bytemuck::cast_slice; -use either::Either; use serde_json::from_slice; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; -use crate::{FieldId, InternalError, Result}; +use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors}; /// Extracts the embedding vector contained in each document under the `_vectors` field. /// @@ -31,9 +30,11 @@ pub fn extract_vector_points( // first we retrieve the _vectors field if let Some(vectors) = obkv.get(vectors_fid) { // extract the vectors - let vectors: Either>, Vec> = - from_slice(vectors).map_err(InternalError::SerdeJson).unwrap(); - let vectors = vectors.map_right(|v| vec![v]).into_inner(); + // TODO return a user error before unwrapping + let vectors = from_slice(vectors) + .map_err(InternalError::SerdeJson) + .map(VectorOrArrayOfVectors::into_array_of_vectors) + .unwrap(); for (i, vector) in vectors.into_iter().enumerate() { match u16::try_from(i) {