diff --git a/milli/src/lib.rs b/milli/src/lib.rs
index 021880a50..5bebdbda5 100644
--- a/milli/src/lib.rs
+++ b/milli/src/lib.rs
@@ -285,6 +285,18 @@ pub fn normalize_facet(original: &str) -> String {
     CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase()
 }
 
+/// Normalize a vector by dividing each of its dimensions by the length of the vector.
+pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
+    let squared: f32 = vector.iter().map(|x| x * x).sum();
+    let length = squared.sqrt();
+    if length <= f32::EPSILON {
+        vector
+    } else {
+        vector.iter_mut().for_each(|x| *x = *x / length);
+        vector
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use serde_json::json;
diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs
index 246a89045..51917b772 100644
--- a/milli/src/search/new/mod.rs
+++ b/milli/src/search/new/mod.rs
@@ -49,7 +49,8 @@ use self::interner::Interned;
 use crate::score_details::{ScoreDetails, ScoringStrategy};
 use crate::search::new::distinct::apply_distinct_rule;
 use crate::{
-    AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32,
+    normalize_vector, AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy,
+    UserError, BEU32,
 };
 
 /// A structure used throughout the execution of a search query.
@@ -454,7 +455,8 @@ pub fn execute_search(
         let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default();
         let ef = hnsw.len().min(100);
         let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef];
-        let neighbors = hnsw.nearest(vector, ef, &mut searcher, &mut dest[..]);
+        let vector = normalize_vector(vector.clone());
+        let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]);
 
         let mut docids = Vec::new();
         for Neighbor { index, distance: _ } in neighbors.iter() {
diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index 325d52279..14f08b106 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -283,7 +283,7 @@ fn send_and_extract_flattened_documents_data(
     faceted_fields: &HashSet<FieldId>,
     primary_key_id: FieldId,
     geo_fields_ids: Option<(FieldId, FieldId)>,
-    vector_field_id: Option<FieldId>,
+    vectors_field_id: Option<FieldId>,
     stop_words: &Option<fst::Set<&[u8]>>,
     max_positions_per_attributes: Option<u32>,
 ) -> Result<(
@@ -312,11 +312,11 @@ fn send_and_extract_flattened_documents_data(
         });
     }
 
-    if let Some(vector_field_id) = vector_field_id {
+    if let Some(vectors_field_id) = vectors_field_id {
         let documents_chunk_cloned = flattened_documents_chunk.clone();
         let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
         rayon::spawn(move || {
-            let result = extract_vector_points(documents_chunk_cloned, indexer, vector_field_id);
+            let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id);
             let _ = match result {
                 Ok(vector_points) => {
                     lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))
diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs
index 7d23ef320..a63aacf83 100644
--- a/milli/src/update/index_documents/typed_chunk.rs
+++ b/milli/src/update/index_documents/typed_chunk.rs
@@ -19,6 +19,7 @@ use super::helpers::{
 use super::{ClonableMmap, MergeFn};
 use crate::error::UserError;
 use crate::facet::FacetType;
+use crate::normalize_vector;
 use crate::update::facet::FacetsUpdate;
 use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
 use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32};
@@ -253,6 +254,7 @@ pub(crate) fn write_typed_chunk_into_index(
                 return Err(UserError::InvalidVectorDimensions { expected, found })?;
             }
 
+            let vector = normalize_vector(vector);
             let vector_id = hnsw.insert(vector, &mut searcher) as u32;
             index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?;
         }
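
Note (not part of the patch): the diff normalizes vectors in two places, once before inserting a point into the HNSW index (typed_chunk.rs) and once before running a nearest-neighbour query (search/new/mod.rs), presumably so that stored points and query vectors all lie on the unit sphere and distance comparisons stay consistent. The sketch below only illustrates the expected behaviour of `normalize_vector` as added in the lib.rs hunk above; the test module name, the sample values, and the assumption that the test lives inside the milli crate are illustrative, not taken from the diff.

#[cfg(test)]
mod normalize_vector_sketch {
    use crate::normalize_vector;

    #[test]
    fn non_zero_vector_gets_unit_length() {
        // [3, 4] has length 5, so normalizing yields approximately [0.6, 0.8].
        let normalized = normalize_vector(vec![3.0, 4.0]);
        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);

        // The length of the normalized vector is (approximately) 1.
        let length: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 1e-6);
    }

    #[test]
    fn zero_vector_is_returned_unchanged() {
        // A zero vector has length <= f32::EPSILON and is passed through as-is,
        // which avoids dividing by (almost) zero.
        assert_eq!(normalize_vector(vec![0.0, 0.0, 0.0]), vec![0.0, 0.0, 0.0]);
    }
}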