Remove stuff, add distribution shift (WIP)

Louis Dureuil 2023-12-12 10:05:06 +01:00
parent e56f160032
commit 65e49b7092
10 changed files with 126 additions and 278 deletions

View file

@@ -1,41 +0,0 @@
-use std::ops;
-
-use instant_distance::Point;
-use serde::{Deserialize, Serialize};
-
-use crate::normalize_vector;
-
-#[derive(Debug, Default, Clone, Serialize, Deserialize)]
-pub struct NDotProductPoint(Vec<f32>);
-
-impl NDotProductPoint {
-    pub fn new(point: Vec<f32>) -> Self {
-        NDotProductPoint(normalize_vector(point))
-    }
-
-    pub fn into_inner(self) -> Vec<f32> {
-        self.0
-    }
-}
-
-impl ops::Deref for NDotProductPoint {
-    type Target = [f32];
-
-    fn deref(&self) -> &Self::Target {
-        self.0.as_slice()
-    }
-}
-
-impl Point for NDotProductPoint {
-    fn distance(&self, other: &Self) -> f32 {
-        let dist = 1.0 - dot_product_similarity(&self.0, &other.0);
-        debug_assert!(!dist.is_nan());
-        dist
-    }
-}
-
-/// Returns the dot product similarity score that will be between 0.0 and 1.0
-/// if both vectors are normalized. The higher the more similar the vectors are.
-pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
-    a.iter().zip(b).map(|(a, b)| a * b).sum()
-}
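
For reference, a minimal self-contained sketch of what the removed helpers computed. The `normalize_vector` body below is an assumption (milli's version divides by the Euclidean norm); on unit vectors the dot product equals the cosine similarity, which is why `1.0 - dot_product_similarity` served as the `instant_distance` point distance above.

fn normalize_vector(mut v: Vec<f32>) -> Vec<f32> {
    // Assumed behavior: divide each component by the vector's Euclidean norm.
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
    v
}

fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(a, b)| a * b).sum()
}

fn main() {
    let a = normalize_vector(vec![1.0, 2.0, 2.0]);
    let b = normalize_vector(vec![2.0, 1.0, 2.0]);
    let similarity = dot_product_similarity(&a, &b);
    println!("similarity = {similarity}"); // 8/9 ~= 0.889 for these vectors
    println!("distance   = {}", 1.0 - similarity);
}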

View file

@@ -10,7 +10,6 @@ use roaring::RoaringBitmap;
 use rstar::RTree;
 use time::OffsetDateTime;
 
-use crate::distance::NDotProductPoint;
 use crate::documents::PrimaryKey;
 use crate::error::{InternalError, UserError};
 use crate::fields_ids_map::FieldsIdsMap;
@@ -30,9 +29,6 @@ use crate::{
     BEU32, BEU64,
 };
 
-/// The HNSW data-structure that we serialize, fill and search in.
-pub type Hnsw = instant_distance::Hnsw<NDotProductPoint>;
-
 pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
 pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;

View file

@@ -10,7 +10,6 @@ pub mod documents;
 mod asc_desc;
 mod criterion;
-pub mod distance;
 mod error;
 mod external_documents_ids;
 pub mod facet;
@@ -33,7 +32,6 @@ use std::convert::{TryFrom, TryInto};
 use std::hash::BuildHasherDefault;
 
 use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
-pub use distance::dot_product_similarity;
 pub use filter_parser::{Condition, FilterCondition, Span, Token};
 use fxhash::{FxHasher32, FxHasher64};
 pub use grenad::CompressionType;

View file

@@ -50,6 +50,7 @@ use self::vector_sort::VectorSort;
 use crate::error::FieldIdMapMissingEntry;
 use crate::score_details::{ScoreDetails, ScoringStrategy};
 use crate::search::new::distinct::apply_distinct_rule;
+use crate::vector::DistributionShift;
 use crate::{
     AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
 };
@@ -264,6 +265,7 @@ fn get_ranking_rules_for_vector<'ctx>(
     geo_strategy: geo_sort::Strategy,
     limit_plus_offset: usize,
     target: &[f32],
+    distribution_shift: Option<DistributionShift>,
 ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
     // query graph search
@@ -289,6 +291,7 @@ fn get_ranking_rules_for_vector<'ctx>(
             target.to_vec(),
             vector_candidates,
             limit_plus_offset,
+            distribution_shift,
         )?;
         ranking_rules.push(Box::new(vector_sort));
         vector = true;
@@ -515,8 +518,14 @@ pub fn execute_vector_search(
     /// FIXME: input universe = universe & documents_with_vectors
     // for now if we're computing embeddings for ALL documents, we can assume that this is just universe
-    let ranking_rules =
-        get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?;
+    let ranking_rules = get_ranking_rules_for_vector(
+        ctx,
+        sort_criteria,
+        geo_strategy,
+        from + length,
+        vector,
+        None,
+    )?;
 
     let mut placeholder_search_logger = logger::DefaultSearchLogger;
     let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
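
The call site passes `None` for now, consistent with the WIP status. Since `DistributionShift::new` (added below in the vector module) already returns an `Option`, a caller that knows its embedder's score statistics could plausibly wire the shift in as sketched here; the mean and sigma values are made up for illustration and are not part of this commit:

    // Hypothetical wiring, not part of this commit:
    let distribution_shift = DistributionShift::new(0.9, 0.05); // assumed measured mean/sigma
    let ranking_rules = get_ranking_rules_for_vector(
        ctx,
        sort_criteria,
        geo_strategy,
        from + length,
        vector,
        distribution_shift, // `new` returns `Option<DistributionShift>` directly
    )?;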

View file

@@ -5,6 +5,7 @@ use roaring::RoaringBitmap;
 use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
 use crate::score_details::{self, ScoreDetails};
+use crate::vector::DistributionShift;
 use crate::{DocumentId, Result, SearchContext, SearchLogger};
 
 pub struct VectorSort<Q: RankingRuleQueryTrait> {
@@ -13,6 +14,7 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
     vector_candidates: RoaringBitmap,
     cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
     limit: usize,
+    distribution_shift: Option<DistributionShift>,
 }
 
 impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
@@ -21,6 +23,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
         target: Vec<f32>,
         vector_candidates: RoaringBitmap,
         limit: usize,
+        distribution_shift: Option<DistributionShift>,
     ) -> Result<Self> {
         Ok(Self {
             query: None,
@@ -28,6 +31,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
             vector_candidates,
             cached_sorted_docids: Default::default(),
             limit,
+            distribution_shift,
         })
     }
@@ -52,7 +56,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
         for reader in readers.iter() {
             let nns_by_vector = reader.nns_by_vector(
                 ctx.txn,
-                &target,
+                target,
                 self.limit,
                 None,
                 Some(&self.vector_candidates),
@@ -66,6 +70,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
         }
         results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
         self.cached_sorted_docids = results.into_iter();
+
         Ok(())
     }
 }
@@ -111,14 +116,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
             }));
         }
 
-        while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() {
+        for (docid, distance, vector) in self.cached_sorted_docids.by_ref() {
             if self.vector_candidates.contains(docid) {
+                let score = 1.0 - distance;
+                let score = self
+                    .distribution_shift
+                    .map(|distribution| distribution.shift(score))
+                    .unwrap_or(score);
                 return Ok(Some(RankingRuleOutput {
                     query,
                     candidates: RoaringBitmap::from_iter([docid]),
                     score: ScoreDetails::Vector(score_details::Vector {
                         target_vector: self.target.clone(),
-                        value_similarity: Some((vector, 1.0 - distance)),
+                        value_similarity: Some((vector, score)),
                     }),
                 }));
             }
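
Two things happen in this hunk. First, the `while let` loop becomes a `for` loop over `by_ref()`, the form clippy's `while_let_on_iterator` lint suggests: borrowing the iterator keeps `self.cached_sorted_docids` usable after the early `return`, so the next call to the ranking rule resumes where iteration stopped. A standalone illustration of that property:

fn main() {
    let mut iter = vec![1, 2, 3, 4].into_iter();
    // Equivalent to `while let Some(n) = iter.next()`, but `by_ref()`
    // borrows the iterator instead of consuming it.
    for n in iter.by_ref() {
        if n == 2 {
            break;
        }
    }
    // Iteration resumes where it stopped, as the cached docids must here.
    assert_eq!(iter.next(), Some(3));
}

Second, the similarity score `1.0 - distance` is now routed through the optional `DistributionShift` before being reported, so shifted and unshifted searches flow through the same `score` variable in the score details.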

View file

@@ -415,7 +415,7 @@ pub(crate) fn write_typed_chunk_into_index(
             let mut deleted_index = None;
             for (index, writer) in writers.iter().enumerate() {
-                let Some(candidate) = writer.item_vector(&wtxn, docid)? else {
+                let Some(candidate) = writer.item_vector(wtxn, docid)? else {
                     // uses invariant: vectors are packed in the first writers.
                     break;
                 };
@@ -429,7 +429,7 @@ pub(crate) fn write_typed_chunk_into_index(
             if let Some(deleted_index) = deleted_index {
                 let mut last_index_with_a_vector = None;
                 for (index, writer) in writers.iter().enumerate().skip(deleted_index) {
-                    let Some(candidate) = writer.item_vector(&wtxn, docid)? else {
+                    let Some(candidate) = writer.item_vector(wtxn, docid)? else {
                         break;
                     };
                     last_index_with_a_vector = Some((index, candidate));

View file

@@ -140,3 +140,47 @@ impl Embedder {
         }
     }
 }
+
+#[derive(Debug, Clone, Copy)]
+pub struct DistributionShift {
+    pub current_mean: f32,
+    pub current_sigma: f32,
+}
+
+impl DistributionShift {
+    /// `None` if sigma <= 0.
+    pub fn new(mean: f32, sigma: f32) -> Option<Self> {
+        if sigma <= 0.0 {
+            None
+        } else {
+            Some(Self { current_mean: mean, current_sigma: sigma })
+        }
+    }
+
+    pub fn shift(&self, score: f32) -> f32 {
+        // <https://math.stackexchange.com/a/2894689>
+        // We're somewhat abusively mapping the distribution of distances to a gaussian.
+        // The parameters we're given are the mean and sigma of the native result distribution.
+        // We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
+        let target_mean = 0.5;
+        let target_sigma = 0.4;
+
+        // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
+        let factor = target_sigma / self.current_sigma;
+        // a*mu1 + b = mu2 => b = mu2 - a*mu1
+        let offset = target_mean - (factor * self.current_mean);
+
+        let mut score = factor * score + offset;
+
+        // clamp the final score in the ]0, 1] interval.
+        if score <= 0.0 {
+            score = f32::EPSILON;
+        }
+        if score > 1.0 {
+            score = 1.0;
+        }
+        score
+    }
+}
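
A quick worked example of the remapping, with made-up statistics (an embedder whose raw scores cluster around a mean of 0.9 with a sigma of 0.05, values not from the commit): the affine map then has factor 0.4 / 0.05 = 8 and offset 0.5 - 8 * 0.9 = -6.7, so the native mean lands on the target mean and a score one sigma above the mean lands one target sigma above 0.5.

fn main() {
    // Assumed statistics, not from the commit.
    let shift = DistributionShift::new(0.9, 0.05).unwrap();
    println!("{}", shift.shift(0.90)); // ~0.5: the native mean maps to the target mean
    println!("{}", shift.shift(0.95)); // ~0.9: one sigma above maps to 0.5 + 0.4
    println!("{}", shift.shift(0.60)); // f32::EPSILON: 8 * 0.6 - 6.7 < 0, clamped into ]0, 1]
}

Note that `DistributionShift::new` returning `Option<Self>` lets callers forward it directly as the `distribution_shift: Option<DistributionShift>` argument threaded through the ranking rules above.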
}