mirror of https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 04:17:10 +02:00

Remove stuff, add distribution shift (WIP)

parent e56f160032
commit 65e49b7092

10 changed files with 126 additions and 278 deletions
@@ -1,41 +0,0 @@
-use std::ops;
-
-use instant_distance::Point;
-use serde::{Deserialize, Serialize};
-
-use crate::normalize_vector;
-
-#[derive(Debug, Default, Clone, Serialize, Deserialize)]
-pub struct NDotProductPoint(Vec<f32>);
-
-impl NDotProductPoint {
-    pub fn new(point: Vec<f32>) -> Self {
-        NDotProductPoint(normalize_vector(point))
-    }
-
-    pub fn into_inner(self) -> Vec<f32> {
-        self.0
-    }
-}
-
-impl ops::Deref for NDotProductPoint {
-    type Target = [f32];
-
-    fn deref(&self) -> &Self::Target {
-        self.0.as_slice()
-    }
-}
-
-impl Point for NDotProductPoint {
-    fn distance(&self, other: &Self) -> f32 {
-        let dist = 1.0 - dot_product_similarity(&self.0, &other.0);
-        debug_assert!(!dist.is_nan());
-        dist
-    }
-}
-
-/// Returns the dot product similarity score, which will be between 0.0 and 1.0
-/// if both vectors are normalized. The higher the score, the more similar the vectors are.
-pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
-    a.iter().zip(b).map(|(a, b)| a * b).sum()
-}
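For context on the deleted helper: with both inputs normalized, the dot product is the cosine similarity, and the `Point` impl turned it into a distance as `1.0 - similarity`. A standalone sketch of that relationship (the `normalize` function below is a hypothetical stand-in for the crate's `normalize_vector`, which this diff doesn't show):

    // Sketch only; `normalize` stands in for crate::normalize_vector.
    fn normalize(mut v: Vec<f32>) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter_mut().for_each(|x| *x /= norm);
        }
        v
    }

    fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(a, b)| a * b).sum()
    }

    fn main() {
        let a = normalize(vec![1.0, 2.0, 2.0]);
        let b = normalize(vec![2.0, 1.0, 2.0]);
        let sim = dot_product_similarity(&a, &b); // 8/9 ≈ 0.889 for these non-negative vectors
        let dist = 1.0 - sim; // what Point::distance returned to the HNSW
        println!("similarity = {sim}, distance = {dist}");
    }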
@@ -10,7 +10,6 @@ use roaring::RoaringBitmap;
 use rstar::RTree;
 use time::OffsetDateTime;
 
-use crate::distance::NDotProductPoint;
 use crate::documents::PrimaryKey;
 use crate::error::{InternalError, UserError};
 use crate::fields_ids_map::FieldsIdsMap;
@@ -30,9 +29,6 @@ use crate::{
     BEU32, BEU64,
 };
 
-/// The HNSW data-structure that we serialize, fill and search in.
-pub type Hnsw = instant_distance::Hnsw<NDotProductPoint>;
-
 pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
 pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;
 
@@ -10,7 +10,6 @@ pub mod documents;
 
 mod asc_desc;
 mod criterion;
-pub mod distance;
 mod error;
 mod external_documents_ids;
 pub mod facet;
@@ -33,7 +32,6 @@ use std::convert::{TryFrom, TryInto};
 use std::hash::BuildHasherDefault;
 
 use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
-pub use distance::dot_product_similarity;
 pub use filter_parser::{Condition, FilterCondition, Span, Token};
 use fxhash::{FxHasher32, FxHasher64};
 pub use grenad::CompressionType;
@@ -50,6 +50,7 @@ use self::vector_sort::VectorSort;
 use crate::error::FieldIdMapMissingEntry;
 use crate::score_details::{ScoreDetails, ScoringStrategy};
 use crate::search::new::distinct::apply_distinct_rule;
+use crate::vector::DistributionShift;
 use crate::{
     AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
 };
@@ -264,6 +265,7 @@ fn get_ranking_rules_for_vector<'ctx>(
     geo_strategy: geo_sort::Strategy,
     limit_plus_offset: usize,
     target: &[f32],
+    distribution_shift: Option<DistributionShift>,
 ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
     // query graph search
 
@@ -289,6 +291,7 @@
             target.to_vec(),
             vector_candidates,
             limit_plus_offset,
+            distribution_shift,
         )?;
         ranking_rules.push(Box::new(vector_sort));
         vector = true;
@@ -515,8 +518,14 @@ pub fn execute_vector_search(
 
     /// FIXME: input universe = universe & documents_with_vectors
     // for now if we're computing embeddings for ALL documents, we can assume that this is just universe
-    let ranking_rules =
-        get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?;
+    let ranking_rules = get_ranking_rules_for_vector(
+        ctx,
+        sort_criteria,
+        geo_strategy,
+        from + length,
+        vector,
+        None,
+    )?;
 
     let mut placeholder_search_logger = logger::DefaultSearchLogger;
     let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
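Note that this WIP call site passes `None` for the new distribution shift parameter. A hypothetical later call that actually threads a configured shift through might look like this (the mean and sigma are made-up illustration values, not from this commit):

    // Hypothetical follow-up: `DistributionShift::new` returns an Option,
    // so it plugs into the `Option<DistributionShift>` parameter directly.
    let shift = DistributionShift::new(0.75, 0.1); // None if sigma <= 0
    let ranking_rules = get_ranking_rules_for_vector(
        ctx,
        sort_criteria,
        geo_strategy,
        from + length,
        vector,
        shift,
    )?;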
@@ -5,6 +5,7 @@ use roaring::RoaringBitmap;
 
 use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
 use crate::score_details::{self, ScoreDetails};
+use crate::vector::DistributionShift;
 use crate::{DocumentId, Result, SearchContext, SearchLogger};
 
 pub struct VectorSort<Q: RankingRuleQueryTrait> {
@@ -13,6 +14,7 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
     vector_candidates: RoaringBitmap,
     cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
     limit: usize,
+    distribution_shift: Option<DistributionShift>,
 }
 
 impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
@@ -21,6 +23,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
         target: Vec<f32>,
         vector_candidates: RoaringBitmap,
         limit: usize,
+        distribution_shift: Option<DistributionShift>,
     ) -> Result<Self> {
         Ok(Self {
             query: None,
@@ -28,6 +31,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
             vector_candidates,
             cached_sorted_docids: Default::default(),
             limit,
+            distribution_shift,
         })
     }
 
@@ -52,7 +56,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
         for reader in readers.iter() {
             let nns_by_vector = reader.nns_by_vector(
                 ctx.txn,
-                &target,
+                target,
                 self.limit,
                 None,
                 Some(&self.vector_candidates),
@@ -66,6 +70,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
         }
         results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
         self.cached_sorted_docids = results.into_iter();
+
         Ok(())
     }
 }
@@ -111,14 +116,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
             }));
         }
 
-        while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() {
+        for (docid, distance, vector) in self.cached_sorted_docids.by_ref() {
             if self.vector_candidates.contains(docid) {
+                let score = 1.0 - distance;
+                let score = self
+                    .distribution_shift
+                    .map(|distribution| distribution.shift(score))
+                    .unwrap_or(score);
                 return Ok(Some(RankingRuleOutput {
                     query,
                     candidates: RoaringBitmap::from_iter([docid]),
                     score: ScoreDetails::Vector(score_details::Vector {
                         target_vector: self.target.clone(),
-                        value_similarity: Some((vector, 1.0 - distance)),
+                        value_similarity: Some((vector, score)),
                     }),
                 }));
             }
@@ -415,7 +415,7 @@ pub(crate) fn write_typed_chunk_into_index(
 
     let mut deleted_index = None;
     for (index, writer) in writers.iter().enumerate() {
-        let Some(candidate) = writer.item_vector(&wtxn, docid)? else {
+        let Some(candidate) = writer.item_vector(wtxn, docid)? else {
            // uses invariant: vectors are packed in the first writers.
            break;
        };
@@ -429,7 +429,7 @@
     if let Some(deleted_index) = deleted_index {
         let mut last_index_with_a_vector = None;
         for (index, writer) in writers.iter().enumerate().skip(deleted_index) {
-            let Some(candidate) = writer.item_vector(&wtxn, docid)? else {
+            let Some(candidate) = writer.item_vector(wtxn, docid)? else {
                break;
            };
            last_index_with_a_vector = Some((index, candidate));
@@ -140,3 +140,47 @@ impl Embedder {
         }
     }
 }
+
+#[derive(Debug, Clone, Copy)]
+pub struct DistributionShift {
+    pub current_mean: f32,
+    pub current_sigma: f32,
+}
+
+impl DistributionShift {
+    /// `None` if sigma <= 0.
+    pub fn new(mean: f32, sigma: f32) -> Option<Self> {
+        if sigma <= 0.0 {
+            None
+        } else {
+            Some(Self { current_mean: mean, current_sigma: sigma })
+        }
+    }
+
+    pub fn shift(&self, score: f32) -> f32 {
+        // <https://math.stackexchange.com/a/2894689>
+        // We're somewhat abusively mapping the distribution of distances to a gaussian.
+        // The parameters we're given are the mean and sigma of the native result distribution.
+        // We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
+
+        let target_mean = 0.5;
+        let target_sigma = 0.4;
+
+        // a^2 * sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 are positive.
+        let factor = target_sigma / self.current_sigma;
+        // a * mu1 + b = mu2 => b = mu2 - a * mu1
+        let offset = target_mean - (factor * self.current_mean);
+
+        let mut score = factor * score + offset;
+
+        // clamp the final score to the ]0, 1] interval.
+        if score <= 0.0 {
+            score = f32::EPSILON;
+        }
+        if score > 1.0 {
+            score = 1.0;
+        }
+
+        score
+    }
+}
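To make the remapping concrete, a small usage sketch with illustrative numbers (not from this commit): with `current_mean = 0.75` and `current_sigma = 0.1`, we get `factor = 0.4 / 0.1 = 4.0` and `offset = 0.5 - 4.0 * 0.75 = -2.5`, so `shift(s) = 4.0 * s - 2.5` before clamping:

    let shift = DistributionShift::new(0.75, 0.1).unwrap();
    println!("{}", shift.shift(0.75)); // ~0.5: the native mean lands on the target mean
    println!("{}", shift.shift(0.9));  // 1.0: 1.1 before clamping, capped at 1.0
    println!("{}", shift.shift(0.55)); // f32::EPSILON: -0.3 before clamping, floored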