Expose an _semanticSimilarity as a dot product in the documents

This commit is contained in:
Kerollmops 2023-06-20 14:38:58 +02:00 committed by Clément Renault
parent 3e3c743392
commit 737aec1705
No known key found for this signature in database
GPG Key ID: 92ADA4E935E71FA4
5 changed files with 39 additions and 4 deletions

1
Cargo.lock generated
View File

@ -2595,6 +2595,7 @@ dependencies = [
"num_cpus",
"obkv",
"once_cell",
"ordered-float",
"parking_lot",
"permissive-json-pointer",
"pin-project-lite",

View File

@ -48,6 +48,7 @@ mime = "0.3.17"
num_cpus = "1.15.0"
obkv = "0.2.0"
once_cell = "1.17.1"
ordered-float = "3.7.0"
parking_lot = "0.12.1"
permissive-json-pointer = { path = "../permissive-json-pointer" }
pin-project-lite = "0.2.9"

View File

@ -10,6 +10,7 @@ use meilisearch_auth::IndexSearchRules;
use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::dot_product_similarity;
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
use meilisearch_types::{milli, Document};
@ -18,6 +19,7 @@ use milli::{
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
};
use ordered_float::OrderedFloat;
use regex::Regex;
use serde::Serialize;
use serde_json::{json, Value};
@ -457,6 +459,10 @@ pub fn perform_search(
insert_geo_distance(sort, &mut document);
}
if let Some(vector) = query.vector.as_ref() {
insert_semantic_similarity(&vector, &mut document);
}
let ranking_score =
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
let ranking_score_details =
@ -542,6 +548,22 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
}
}
fn insert_semantic_similarity(query: &[f32], document: &mut Document) {
if let Some(value) = document.get("_vectors") {
let vectors: Vec<Vec<f32>> = match serde_json::from_value(value.clone()) {
Ok(Either::Left(vector)) => vec![vector],
Ok(Either::Right(vectors)) => vectors,
Err(_) => return,
};
let similarity = vectors
.into_iter()
.map(|v| OrderedFloat(dot_product_similarity(query, &v)))
.max()
.map(OrderedFloat::into_inner);
document.insert("_semanticSimilarity".to_string(), json!(similarity));
}
}
fn compute_formatted_options(
attr_to_highlight: &HashSet<String>,
attr_to_crop: &[String],

View File

@ -12,13 +12,18 @@ impl Metric<Vec<f32>> for DotProduct {
//
// Following <https://docs.rs/space/0.17.0/space/trait.Metric.html>.
fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
let dist: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum();
let dist = 1.0 - dist;
let dist = 1.0 - dot_product_similarity(a, b);
debug_assert!(!dist.is_nan());
dist.to_bits()
}
}
/// Returns the dot product similarity score that will between 0.0 and 1.0
/// if both vectors are normalized. The higher the more similar the vectors are.
pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum()
}
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
pub struct Euclidean;
@ -26,9 +31,14 @@ impl Metric<Vec<f32>> for Euclidean {
type Unit = u32;
fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum();
let dist = squared.sqrt();
let dist = euclidean_squared_distance(a, b).sqrt();
debug_assert!(!dist.is_nan());
dist.to_bits()
}
}
/// Return the squared euclidean distance between both vectors that will
/// between 0.0 and +inf. The smaller the nearer the vectors are.
pub fn euclidean_squared_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum()
}

View File

@ -31,6 +31,7 @@ use std::convert::{TryFrom, TryInto};
use std::hash::BuildHasherDefault;
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
pub use distance::{dot_product_similarity, euclidean_squared_distance};
pub use filter_parser::{Condition, FilterCondition, Span, Token};
use fxhash::{FxHasher32, FxHasher64};
pub use grenad::CompressionType;