Normalize the vectors during indexation and search

This commit is contained in:
Kerollmops 2023-06-20 11:45:29 +02:00 committed by Clément Renault
parent 321ec5f3fa
commit ab9f2269aa
No known key found for this signature in database
GPG Key ID: 92ADA4E935E71FA4
4 changed files with 21 additions and 5 deletions

View File

@ -285,6 +285,18 @@ pub fn normalize_facet(original: &str) -> String {
CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase() CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase()
} }
/// Normalize a vector by dividing the dimensions by the lenght of it.
pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
let squared: f32 = vector.iter().map(|x| x * x).sum();
let length = squared.sqrt();
if length <= f32::EPSILON {
vector
} else {
vector.iter_mut().for_each(|x| *x = *x / length);
vector
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serde_json::json; use serde_json::json;

View File

@ -49,7 +49,8 @@ use self::interner::Interned;
use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule; use crate::search::new::distinct::apply_distinct_rule;
use crate::{ use crate::{
AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32, normalize_vector, AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy,
UserError, BEU32,
}; };
/// A structure used throughout the execution of a search query. /// A structure used throughout the execution of a search query.
@ -454,7 +455,8 @@ pub fn execute_search(
let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default();
let ef = hnsw.len().min(100); let ef = hnsw.len().min(100);
let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef];
let neighbors = hnsw.nearest(vector, ef, &mut searcher, &mut dest[..]); let vector = normalize_vector(vector.clone());
let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]);
let mut docids = Vec::new(); let mut docids = Vec::new();
for Neighbor { index, distance: _ } in neighbors.iter() { for Neighbor { index, distance: _ } in neighbors.iter() {

View File

@ -283,7 +283,7 @@ fn send_and_extract_flattened_documents_data(
faceted_fields: &HashSet<FieldId>, faceted_fields: &HashSet<FieldId>,
primary_key_id: FieldId, primary_key_id: FieldId,
geo_fields_ids: Option<(FieldId, FieldId)>, geo_fields_ids: Option<(FieldId, FieldId)>,
vector_field_id: Option<FieldId>, vectors_field_id: Option<FieldId>,
stop_words: &Option<fst::Set<&[u8]>>, stop_words: &Option<fst::Set<&[u8]>>,
max_positions_per_attributes: Option<u32>, max_positions_per_attributes: Option<u32>,
) -> Result<( ) -> Result<(
@ -312,11 +312,11 @@ fn send_and_extract_flattened_documents_data(
}); });
} }
if let Some(vector_field_id) = vector_field_id { if let Some(vectors_field_id) = vectors_field_id {
let documents_chunk_cloned = flattened_documents_chunk.clone(); let documents_chunk_cloned = flattened_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || { rayon::spawn(move || {
let result = extract_vector_points(documents_chunk_cloned, indexer, vector_field_id); let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id);
let _ = match result { let _ = match result {
Ok(vector_points) => { Ok(vector_points) => {
lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))

View File

@ -19,6 +19,7 @@ use super::helpers::{
use super::{ClonableMmap, MergeFn}; use super::{ClonableMmap, MergeFn};
use crate::error::UserError; use crate::error::UserError;
use crate::facet::FacetType; use crate::facet::FacetType;
use crate::normalize_vector;
use crate::update::facet::FacetsUpdate; use crate::update::facet::FacetsUpdate;
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32};
@ -253,6 +254,7 @@ pub(crate) fn write_typed_chunk_into_index(
return Err(UserError::InvalidVectorDimensions { expected, found })?; return Err(UserError::InvalidVectorDimensions { expected, found })?;
} }
let vector = normalize_vector(vector);
let vector_id = hnsw.insert(vector, &mut searcher) as u32; let vector_id = hnsw.insert(vector, &mut searcher) as u32;
index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?; index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?;
} }