Remove stuff, add distribution shift (WIP)

This commit is contained in:
Louis Dureuil 2023-12-12 10:05:06 +01:00
parent e56f160032
commit 65e49b7092
No known key found for this signature in database
10 changed files with 126 additions and 278 deletions

219
Cargo.lock generated
View File

@ -56,7 +56,7 @@ dependencies = [
"flate2",
"futures-core",
"h2",
"http",
"http 0.2.9",
"httparse",
"httpdate",
"itoa",
@ -90,7 +90,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66ff4d247d2b160861fa2866457e85706833527840e4133f8f49aa423a38799"
dependencies = [
"bytestring",
"http",
"http 0.2.9",
"regex",
"serde",
"tracing",
@ -189,7 +189,7 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"http",
"http 0.2.9",
"itoa",
"language-tags",
"log",
@ -1407,7 +1407,7 @@ dependencies = [
"anyhow",
"big_s",
"flate2",
"http",
"http 0.2.9",
"log",
"maplit",
"meili-snap",
@ -1702,21 +1702,6 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.2.0"
@ -2047,7 +2032,7 @@ dependencies = [
"futures-core",
"futures-sink",
"futures-util",
"http",
"http 0.2.9",
"indexmap 1.9.3",
"slab",
"tokio",
@ -2171,13 +2156,12 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hf-hub"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
source = "git+https://github.com/dureuill/hf-hub.git?branch=rust_tls#88d4f11cb9fa079f2912bacb96f5080b16825ce8"
dependencies = [
"dirs",
"http 1.0.0",
"indicatif",
"log",
"native-tls",
"rand",
"serde",
"serde_json",
@ -2205,6 +2189,17 @@ dependencies = [
"itoa",
]
[[package]]
name = "http"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http-body"
version = "0.4.5"
@ -2212,7 +2207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1"
dependencies = [
"bytes",
"http",
"http 0.2.9",
"pin-project-lite",
]
@ -2245,7 +2240,7 @@ dependencies = [
"futures-core",
"futures-util",
"h2",
"http",
"http 0.2.9",
"http-body",
"httparse",
"httpdate",
@ -2265,7 +2260,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97"
dependencies = [
"futures-util",
"http",
"http 0.2.9",
"hyper",
"rustls 0.21.6",
"tokio",
@ -2868,21 +2863,6 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "instant-distance"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c619cdaa30bb84088963968bee12a45ea5fbbf355f2c021bcd15589f5ca494a"
dependencies = [
"num_cpus",
"ordered-float 3.7.0",
"parking_lot",
"rand",
"rayon",
"serde",
"serde-big-array",
]
[[package]]
name = "io-lifetimes"
version = "1.0.11"
@ -3531,7 +3511,7 @@ dependencies = [
"futures",
"futures-util",
"hex",
"http",
"http 0.2.9",
"index-scheduler",
"indexmap 2.0.0",
"insta",
@ -3718,7 +3698,6 @@ dependencies = [
"hf-hub",
"indexmap 2.0.0",
"insta",
"instant-distance",
"itertools 0.11.0",
"json-depth-checker",
"levenshtein_automata",
@ -3730,7 +3709,6 @@ dependencies = [
"meili-snap",
"memmap2 0.7.1",
"mimalloc",
"nolife",
"obkv",
"once_cell",
"ordered-float 3.7.0",
@ -3829,35 +3807,11 @@ dependencies = [
"syn 2.0.28",
]
[[package]]
name = "native-tls"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e"
dependencies = [
"lazy_static",
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "nelson"
version = "0.1.0"
source = "git+https://github.com/meilisearch/nelson.git?rev=675f13885548fb415ead8fbb447e9e6d9314000a#675f13885548fb415ead8fbb447e9e6d9314000a"
[[package]]
name = "nolife"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc52aaf087e8a52e7a2692f83f2dac6ac7ff9d0136bf9c6ac496635cfe3e50dc"
[[package]]
name = "nom"
version = "7.1.3"
@ -3994,50 +3948,6 @@ version = "11.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
[[package]]
name = "openssl"
version = "0.10.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a257ad03cd8fb16ad4172fedf8094451e1af1c4b70097636ef2eac9a5f0cc33"
dependencies = [
"bitflags 2.4.1",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.28",
]
[[package]]
name = "openssl-probe"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "openssl-sys"
version = "0.9.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40a4130519a360279579c2053038317e40eff64d13fd3f004f9e1b72b8a6aaf9"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "option-ext"
version = "0.2.0"
@ -4655,7 +4565,7 @@ dependencies = [
"futures-core",
"futures-util",
"h2",
"http",
"http 0.2.9",
"http-body",
"hyper",
"hyper-rustls",
@ -4802,7 +4712,7 @@ checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb"
dependencies = [
"log",
"ring",
"rustls-webpki 0.101.3",
"rustls-webpki",
"sct",
]
@ -4815,16 +4725,6 @@ dependencies = [
"base64 0.21.5",
]
[[package]]
name = "rustls-webpki"
version = "0.100.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "rustls-webpki"
version = "0.101.3"
@ -4866,15 +4766,6 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88"
dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -4891,29 +4782,6 @@ dependencies = [
"untrusted",
]
[[package]]
name = "security-framework"
version = "2.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "segment"
version = "0.2.2"
@ -4949,15 +4817,6 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]]
name = "serde-cs"
version = "0.2.4"
@ -5151,6 +5010,17 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "socks"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
dependencies = [
"byteorder",
"libc",
"winapi",
]
[[package]]
name = "spin"
version = "0.5.2"
@ -5713,21 +5583,21 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "ureq"
version = "2.7.1"
version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9"
checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97"
dependencies = [
"base64 0.21.5",
"flate2",
"log",
"native-tls",
"once_cell",
"rustls 0.21.6",
"rustls-webpki 0.100.2",
"rustls-webpki",
"serde",
"serde_json",
"socks",
"url",
"webpki-roots 0.23.1",
"webpki-roots 0.25.3",
]
[[package]]
@ -5958,15 +5828,6 @@ dependencies = [
"webpki",
]
[[package]]
name = "webpki-roots"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338"
dependencies = [
"rustls-webpki 0.100.2",
]
[[package]]
name = "webpki-roots"
version = "0.25.3"

View File

@ -14,18 +14,14 @@ use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::heed::RoTxn;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
use meilisearch_types::milli::{
dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues,
VectorQuery,
};
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery};
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
use meilisearch_types::{milli, Document};
use milli::tokenizer::TokenizerBuilder;
use milli::{
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET,
SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
};
use ordered_float::OrderedFloat;
use regex::Regex;
use serde::Serialize;
use serde_json::{json, Value};
@ -550,13 +546,8 @@ pub fn perform_search(
insert_geo_distance(sort, &mut document);
}
let semantic_score = /*match query.vector.as_ref() {
Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? {
Some(vectors) => compute_semantic_score(vector, vectors)?,
None => None,
},
None => None,
};*/ None;
/// FIXME: remove this or set to value from the score details
let semantic_score = None;
let ranking_score =
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
@ -689,18 +680,6 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
}
}
fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result<Option<f32>> {
let vectors = serde_json::from_value(vectors)
.map(VectorOrArrayOfVectors::into_array_of_vectors)
.map_err(InternalError::SerdeJson)?;
Ok(vectors
.into_iter()
.flatten()
.map(|v| OrderedFloat(dot_product_similarity(query, &v)))
.max()
.map(OrderedFloat::into_inner))
}
fn compute_formatted_options(
attr_to_highlight: &HashSet<String>,
attr_to_crop: &[String],
@ -828,22 +807,6 @@ fn make_document(
Ok(document)
}
/// Extract the JSON value under the field name specified
/// but doesn't support nested objects.
fn extract_field(
field_name: &str,
field_ids_map: &FieldsIdsMap,
obkv: obkv::KvReaderU16,
) -> Result<Option<serde_json::Value>, MeilisearchHttpError> {
match field_ids_map.id(field_name) {
Some(fid) => match obkv.get(fid) {
Some(value) => Ok(serde_json::from_slice(value).map(Some)?),
None => Ok(None),
},
None => Ok(None),
}
}
fn format_fields<'a>(
document: &Document,
field_ids_map: &FieldsIdsMap,

View File

@ -36,7 +36,6 @@ heed = { version = "0.20.0-alpha.9", default-features = false, features = [
"read-txn-no-tls",
] }
indexmap = { version = "2.0.0", features = ["serde"] }
instant-distance = { version = "0.6.1", features = ["with-serde"] }
json-depth-checker = { path = "../json-depth-checker" }
levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] }
memmap2 = "0.7.1"
@ -79,10 +78,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" }
hf-hub = "0.3.2"
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [
"online",
] }
tokio = { version = "1.34.0", features = ["rt"] }
futures = "0.3.29"
nolife = { version = "0.3.1" }
reqwest = { version = "0.11.16", features = [
"rustls-tls",
"json",
@ -102,7 +102,15 @@ meili-snap = { path = "../meili-snap" }
rand = { version = "0.8.5", features = ["small_rng"] }
[features]
all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"]
all-tokenizations = [
"charabia/chinese",
"charabia/hebrew",
"charabia/japanese",
"charabia/thai",
"charabia/korean",
"charabia/greek",
"charabia/khmer",
]
# Use POSIX semaphores instead of SysV semaphores in LMDB
# For more information on this feature, see heed's Cargo.toml

View File

@ -1,41 +0,0 @@
use std::ops;
use instant_distance::Point;
use serde::{Deserialize, Serialize};
use crate::normalize_vector;
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct NDotProductPoint(Vec<f32>);
impl NDotProductPoint {
pub fn new(point: Vec<f32>) -> Self {
NDotProductPoint(normalize_vector(point))
}
pub fn into_inner(self) -> Vec<f32> {
self.0
}
}
impl ops::Deref for NDotProductPoint {
type Target = [f32];
fn deref(&self) -> &Self::Target {
self.0.as_slice()
}
}
impl Point for NDotProductPoint {
fn distance(&self, other: &Self) -> f32 {
let dist = 1.0 - dot_product_similarity(&self.0, &other.0);
debug_assert!(!dist.is_nan());
dist
}
}
/// Returns the dot product similarity score that will between 0.0 and 1.0
/// if both vectors are normalized. The higher the more similar the vectors are.
pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum()
}

View File

@ -10,7 +10,6 @@ use roaring::RoaringBitmap;
use rstar::RTree;
use time::OffsetDateTime;
use crate::distance::NDotProductPoint;
use crate::documents::PrimaryKey;
use crate::error::{InternalError, UserError};
use crate::fields_ids_map::FieldsIdsMap;
@ -30,9 +29,6 @@ use crate::{
BEU32, BEU64,
};
/// The HNSW data-structure that we serialize, fill and search in.
pub type Hnsw = instant_distance::Hnsw<NDotProductPoint>;
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;

View File

@ -10,7 +10,6 @@ pub mod documents;
mod asc_desc;
mod criterion;
pub mod distance;
mod error;
mod external_documents_ids;
pub mod facet;
@ -33,7 +32,6 @@ use std::convert::{TryFrom, TryInto};
use std::hash::BuildHasherDefault;
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
pub use distance::dot_product_similarity;
pub use filter_parser::{Condition, FilterCondition, Span, Token};
use fxhash::{FxHasher32, FxHasher64};
pub use grenad::CompressionType;

View File

@ -50,6 +50,7 @@ use self::vector_sort::VectorSort;
use crate::error::FieldIdMapMissingEntry;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule;
use crate::vector::DistributionShift;
use crate::{
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
};
@ -264,6 +265,7 @@ fn get_ranking_rules_for_vector<'ctx>(
geo_strategy: geo_sort::Strategy,
limit_plus_offset: usize,
target: &[f32],
distribution_shift: Option<DistributionShift>,
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
// query graph search
@ -289,6 +291,7 @@ fn get_ranking_rules_for_vector<'ctx>(
target.to_vec(),
vector_candidates,
limit_plus_offset,
distribution_shift,
)?;
ranking_rules.push(Box::new(vector_sort));
vector = true;
@ -515,8 +518,14 @@ pub fn execute_vector_search(
/// FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
let ranking_rules =
get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?;
let ranking_rules = get_ranking_rules_for_vector(
ctx,
sort_criteria,
geo_strategy,
from + length,
vector,
None,
)?;
let mut placeholder_search_logger = logger::DefaultSearchLogger;
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =

View File

@ -5,6 +5,7 @@ use roaring::RoaringBitmap;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::score_details::{self, ScoreDetails};
use crate::vector::DistributionShift;
use crate::{DocumentId, Result, SearchContext, SearchLogger};
pub struct VectorSort<Q: RankingRuleQueryTrait> {
@ -13,6 +14,7 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
vector_candidates: RoaringBitmap,
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
limit: usize,
distribution_shift: Option<DistributionShift>,
}
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
@ -21,6 +23,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
target: Vec<f32>,
vector_candidates: RoaringBitmap,
limit: usize,
distribution_shift: Option<DistributionShift>,
) -> Result<Self> {
Ok(Self {
query: None,
@ -28,6 +31,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
vector_candidates,
cached_sorted_docids: Default::default(),
limit,
distribution_shift,
})
}
@ -52,7 +56,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
for reader in readers.iter() {
let nns_by_vector = reader.nns_by_vector(
ctx.txn,
&target,
target,
self.limit,
None,
Some(&self.vector_candidates),
@ -66,6 +70,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
}
results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
self.cached_sorted_docids = results.into_iter();
Ok(())
}
}
@ -111,14 +116,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
}));
}
while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() {
for (docid, distance, vector) in self.cached_sorted_docids.by_ref() {
if self.vector_candidates.contains(docid) {
let score = 1.0 - distance;
let score = self
.distribution_shift
.map(|distribution| distribution.shift(score))
.unwrap_or(score);
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([docid]),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: Some((vector, 1.0 - distance)),
value_similarity: Some((vector, score)),
}),
}));
}

View File

@ -415,7 +415,7 @@ pub(crate) fn write_typed_chunk_into_index(
let mut deleted_index = None;
for (index, writer) in writers.iter().enumerate() {
let Some(candidate) = writer.item_vector(&wtxn, docid)? else {
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
// uses invariant: vectors are packed in the first writers.
break;
};
@ -429,7 +429,7 @@ pub(crate) fn write_typed_chunk_into_index(
if let Some(deleted_index) = deleted_index {
let mut last_index_with_a_vector = None;
for (index, writer) in writers.iter().enumerate().skip(deleted_index) {
let Some(candidate) = writer.item_vector(&wtxn, docid)? else {
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
break;
};
last_index_with_a_vector = Some((index, candidate));

View File

@ -140,3 +140,47 @@ impl Embedder {
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct DistributionShift {
pub current_mean: f32,
pub current_sigma: f32,
}
impl DistributionShift {
/// `None` if sigma <= 0.
pub fn new(mean: f32, sigma: f32) -> Option<Self> {
if sigma <= 0.0 {
None
} else {
Some(Self { current_mean: mean, current_sigma: sigma })
}
}
pub fn shift(&self, score: f32) -> f32 {
// <https://math.stackexchange.com/a/2894689>
// We're somewhat abusively mapping the distribution of distances to a gaussian.
// The parameters we're given is the mean and sigma of the native result distribution.
// We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
let target_mean = 0.5;
let target_sigma = 0.4;
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
let factor = target_sigma / self.current_sigma;
// a*mu1 + b = mu2 => b = mu2 - a*mu1
let offset = target_mean - (factor * self.current_mean);
let mut score = factor * score + offset;
// clamp the final score in the ]0, 1] interval.
if score <= 0.0 {
score = f32::EPSILON;
}
if score > 1.0 {
score = 1.0;
}
score
}
}