Implement useful conversion strategies and clean up the code

This commit is contained in:
Kerollmops 2025-05-30 10:54:32 +02:00 committed by Clément Renault
parent 2821163b95
commit 50fafbbc8b
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
8 changed files with 224 additions and 203 deletions

View file

@ -1,14 +1,18 @@
use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::error::Error;
use std::fmt;
use std::fs::File;
use std::path::Path;
use deserr::Deserr;
use heed::types::*;
use heed::{CompactionOption, Database, DatabaseStat, RoTxn, RwTxn, Unspecified, WithoutTls};
use indexmap::IndexMap;
use roaring::RoaringBitmap;
use rstar::RTree;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use crate::constants::{self, RESERVED_GEO_FIELD_NAME, RESERVED_VECTORS_FIELD_NAME};
use crate::database_stats::DatabaseStats;
@ -25,6 +29,7 @@ use crate::heed_codec::{BEU16StrCodec, FstSetCodec, StrBEU16Codec, StrRefCodec};
use crate::order_by_map::OrderByMap;
use crate::prompt::PromptData;
use crate::proximity::ProximityPrecision;
use crate::update::new::StdResult;
use crate::vector::{ArroyStats, ArroyWrapper, Embedding, EmbeddingConfig};
use crate::{
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
@ -1962,10 +1967,46 @@ pub struct SearchParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub attributes_to_search_on: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ranking_score_threshold: Option<f64>,
pub ranking_score_threshold: Option<RankingScoreThreshold>,
}
#[derive(Debug, Default, Deserialize, Serialize)]
#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, PartialEq, Deserr, ToSchema)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThreshold(f64);
impl RankingScoreThreshold {
pub fn as_f64(&self) -> f64 {
self.0
}
}
impl TryFrom<f64> for RankingScoreThreshold {
type Error = InvalidSearchRankingScoreThreshold;
fn try_from(value: f64) -> StdResult<Self, Self::Error> {
if value < 0.0 || value > 1.0 {
Err(InvalidSearchRankingScoreThreshold)
} else {
Ok(RankingScoreThreshold(value))
}
}
}
#[derive(Debug)]
pub struct InvalidSearchRankingScoreThreshold;
impl Error for InvalidSearchRankingScoreThreshold {}
impl fmt::Display for InvalidSearchRankingScoreThreshold {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`."
)
}
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HybridQuery {
pub semantic_ratio: f32,
@ -1980,10 +2021,12 @@ pub struct PrefixSettings {
pub compute_prefixes: PrefixSearch,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Deserr, ToSchema, Serialize, Deserialize)]
#[deserr(rename_all = camelCase)]
#[serde(rename_all = "camelCase")]
pub enum MatchingStrategy {
/// Remove query words from last to first
#[default]
Last,
/// All query words are mandatory
All,

View file

@ -64,7 +64,7 @@ fn default_template() -> liquid::Template {
new_template(default_template_text()).unwrap()
}
fn default_template_text() -> &'static str {
pub fn default_template_text() -> &'static str {
"{% for field in fields %}\
{% if field.is_searchable and field.value != nil %}\
{{ field.name }}: {{ field.value }}\n\

View file

@ -10,6 +10,7 @@ pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FAC
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::{execute_vector_search, PartialSearchResult, VectorStoreStats};
use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features};
use crate::index::MatchingStrategy;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::Embedder;
use crate::{
@ -364,6 +365,16 @@ impl Default for TermsMatchingStrategy {
}
}
impl From<MatchingStrategy> for TermsMatchingStrategy {
fn from(other: MatchingStrategy) -> Self {
match other {
MatchingStrategy::Last => Self::Last,
MatchingStrategy::All => Self::All,
MatchingStrategy::Frequency => Self::Frequency,
}
}
}
fn get_first(s: &str) -> &str {
match s.chars().next() {
Some(c) => &s[..c.len_utf8()],

View file

@ -6,10 +6,9 @@ use deserr::Deserr;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use crate::index::{self, ChatConfig, SearchParameters};
use crate::index::{self, ChatConfig, MatchingStrategy, RankingScoreThreshold, SearchParameters};
use crate::prompt::{default_max_bytes, PromptData};
use crate::update::Setting;
use crate::TermsMatchingStrategy;
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
@ -70,13 +69,10 @@ impl From<ChatConfig> for ChatSettings {
HybridQuery { semantic_ratio: SemanticRatio(semantic_ratio), embedder }
});
let matching_strategy = matching_strategy.map(|ms| match ms {
index::MatchingStrategy::Last => MatchingStrategy::Last,
index::MatchingStrategy::All => MatchingStrategy::All,
index::MatchingStrategy::Frequency => MatchingStrategy::Frequency,
});
let matching_strategy = matching_strategy.map(MatchingStrategy::from);
let ranking_score_threshold = ranking_score_threshold.map(RankingScoreThreshold);
let ranking_score_threshold =
ranking_score_threshold.map(RankingScoreThreshold::from);
ChatSearchParams {
hybrid: Setting::some_or_not_set(hybrid),
@ -197,63 +193,3 @@ impl std::ops::Deref for SemanticRatio {
&self.0
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr, ToSchema, Serialize, Deserialize)]
#[deserr(rename_all = camelCase)]
#[serde(rename_all = "camelCase")]
pub enum MatchingStrategy {
/// Remove query words from last to first
Last,
/// All query words are mandatory
All,
/// Remove query words from the most frequent to the least
Frequency,
}
impl Default for MatchingStrategy {
fn default() -> Self {
Self::Last
}
}
impl From<MatchingStrategy> for TermsMatchingStrategy {
fn from(other: MatchingStrategy) -> Self {
match other {
MatchingStrategy::Last => Self::Last,
MatchingStrategy::All => Self::All,
MatchingStrategy::Frequency => Self::Frequency,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize, Deserialize)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThreshold(pub f64);
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
type Error = InvalidSearchRankingScoreThreshold;
fn try_from(f: f64) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchRankingScoreThreshold)
} else {
Ok(RankingScoreThreshold(f))
}
}
}
#[derive(Debug)]
pub struct InvalidSearchRankingScoreThreshold;
impl Error for InvalidSearchRankingScoreThreshold {}
impl fmt::Display for InvalidSearchRankingScoreThreshold {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`."
)
}
}

View file

@ -11,7 +11,7 @@ use roaring::RoaringBitmap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
use super::chat::{ChatSearchParams, RankingScoreThreshold};
use super::chat::ChatSearchParams;
use super::del_add::{DelAdd, DelAddOperation};
use super::index_documents::{IndexDocumentsConfig, Transform};
use super::{ChatSettings, IndexerConfig};
@ -23,11 +23,11 @@ use crate::error::UserError;
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
use crate::filterable_attributes_rules::match_faceted_field;
use crate::index::{
ChatConfig, IndexEmbeddingConfig, MatchingStrategy, PrefixSearch,
DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
ChatConfig, IndexEmbeddingConfig, MatchingStrategy, PrefixSearch, RankingScoreThreshold,
SearchParameters, DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
};
use crate::order_by_map::OrderByMap;
use crate::prompt::{default_max_bytes, PromptData};
use crate::prompt::{default_max_bytes, default_template_text, PromptData};
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep};
@ -1266,32 +1266,29 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
document_template_max_bytes: new_document_template_max_bytes,
search_parameters: new_search_parameters,
}) => {
let mut old = self.index.chat_config(self.wtxn)?;
let ChatConfig {
ref mut description,
prompt: PromptData { ref mut template, ref mut max_bytes },
ref mut search_parameters,
} = old;
let ChatConfig { description, prompt, search_parameters } =
self.index.chat_config(self.wtxn)?;
match new_description {
Setting::Set(d) => *description = d.clone(),
Setting::Reset => *description = Default::default(),
Setting::NotSet => (),
}
let description = match new_description {
Setting::Set(new) => new.clone(),
Setting::Reset => Default::default(),
Setting::NotSet => description,
};
match new_document_template {
Setting::Set(dt) => *template = dt.clone(),
Setting::Reset => *template = Default::default(),
Setting::NotSet => (),
}
let prompt = PromptData {
template: match new_document_template {
Setting::Set(new) => new.clone(),
Setting::Reset => default_template_text().to_string(),
Setting::NotSet => prompt.template.clone(),
},
max_bytes: match new_document_template_max_bytes {
Setting::Set(m) => NonZeroUsize::new(*m),
Setting::Reset => Some(default_max_bytes()),
Setting::NotSet => prompt.max_bytes,
},
};
match new_document_template_max_bytes {
Setting::Set(m) => *max_bytes = NonZeroUsize::new(*m),
Setting::Reset => *max_bytes = Some(default_max_bytes()),
Setting::NotSet => (),
}
match new_search_parameters {
let search_parameters = match new_search_parameters {
Setting::Set(sp) => {
let ChatSearchParams {
hybrid,
@ -1303,74 +1300,62 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
ranking_score_threshold,
} = sp;
match hybrid {
Setting::Set(hybrid) => {
search_parameters.hybrid = Some(crate::index::HybridQuery {
SearchParameters {
hybrid: match hybrid {
Setting::Set(hybrid) => Some(crate::index::HybridQuery {
semantic_ratio: *hybrid.semantic_ratio,
embedder: hybrid.embedder.clone(),
})
}
Setting::Reset => search_parameters.hybrid = None,
Setting::NotSet => (),
}
match limit {
Setting::Set(limit) => search_parameters.limit = Some(*limit),
Setting::Reset => search_parameters.limit = None,
Setting::NotSet => (),
}
match sort {
Setting::Set(sort) => search_parameters.sort = Some(sort.clone()),
Setting::Reset => search_parameters.sort = None,
Setting::NotSet => (),
}
match distinct {
Setting::Set(distinct) => {
search_parameters.distinct = Some(distinct.clone())
}
Setting::Reset => search_parameters.distinct = None,
Setting::NotSet => (),
}
match matching_strategy {
Setting::Set(matching_strategy) => {
let strategy = match matching_strategy {
super::chat::MatchingStrategy::Last => MatchingStrategy::Last,
super::chat::MatchingStrategy::All => MatchingStrategy::All,
super::chat::MatchingStrategy::Frequency => {
MatchingStrategy::Frequency
}
};
search_parameters.matching_strategy = Some(strategy)
}
Setting::Reset => search_parameters.matching_strategy = None,
Setting::NotSet => (),
}
match attributes_to_search_on {
Setting::Set(attributes_to_search_on) => {
search_parameters.attributes_to_search_on =
}),
Setting::Reset => None,
Setting::NotSet => search_parameters.hybrid.clone(),
},
limit: match limit {
Setting::Set(limit) => Some(*limit),
Setting::Reset => None,
Setting::NotSet => search_parameters.limit,
},
sort: match sort {
Setting::Set(sort) => Some(sort.clone()),
Setting::Reset => None,
Setting::NotSet => search_parameters.sort.clone(),
},
distinct: match distinct {
Setting::Set(distinct) => Some(distinct.clone()),
Setting::Reset => None,
Setting::NotSet => search_parameters.distinct.clone(),
},
matching_strategy: match matching_strategy {
Setting::Set(matching_strategy) => {
Some(MatchingStrategy::from(*matching_strategy))
}
Setting::Reset => None,
Setting::NotSet => search_parameters.matching_strategy,
},
attributes_to_search_on: match attributes_to_search_on {
Setting::Set(attributes_to_search_on) => {
Some(attributes_to_search_on.clone())
}
Setting::Reset => search_parameters.attributes_to_search_on = None,
Setting::NotSet => (),
}
match ranking_score_threshold {
Setting::Set(RankingScoreThreshold(score)) => {
search_parameters.ranking_score_threshold = Some(*score)
}
Setting::Reset => search_parameters.ranking_score_threshold = None,
Setting::NotSet => (),
}
Setting::Reset => None,
Setting::NotSet => {
search_parameters.attributes_to_search_on.clone()
}
},
ranking_score_threshold: match ranking_score_threshold {
Setting::Set(rst) => Some(RankingScoreThreshold::from(*rst)),
Setting::Reset => None,
Setting::NotSet => search_parameters.ranking_score_threshold,
},
}
}
Setting::Reset => *search_parameters = Default::default(),
Setting::NotSet => (),
}
Setting::Reset => Default::default(),
Setting::NotSet => search_parameters,
};
self.index.put_chat_config(
self.wtxn,
&ChatConfig { description, prompt, search_parameters },
)?;
self.index.put_chat_config(self.wtxn, &old)?;
Ok(true)
}
Setting::Reset => self.index.delete_chat_config(self.wtxn),

View file

@ -926,6 +926,7 @@ fn test_correct_settings_init() {
assert!(matches!(prefix_search, Setting::NotSet));
assert!(matches!(facet_search, Setting::NotSet));
assert!(matches!(disable_on_numbers, Setting::NotSet));
assert!(matches!(chat, Setting::NotSet));
})
.unwrap();
}