Merge branch 'main' into tmp-release-v1.7.4

Louis Dureuil 2024-03-28 10:51:49 +01:00 committed by GitHub
commit 796213af9a
68 changed files with 4625 additions and 1492 deletions

View file

@ -243,6 +243,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
},
#[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")]
InvalidSettingsDimensions { embedder_name: String },
#[error("`.embedders.{embedder_name}.url`: could not parse `{url}`: {inner_error}")]
InvalidUrl { embedder_name: String, inner_error: url::ParseError, url: String },
}
impl From<crate::vector::Error> for Error {

View file

@ -20,13 +20,13 @@ use crate::heed_codec::facet::{
use crate::heed_codec::{
BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec,
};
use crate::order_by_map::OrderByMap;
use crate::proximity::ProximityPrecision;
use crate::vector::EmbeddingConfig;
use crate::{
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec,
OrderBy, Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16,
BEU32, BEU64,
Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32, BEU64,
};
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
@ -67,6 +67,7 @@ pub mod main_key {
pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits";
pub const PROXIMITY_PRECISION: &str = "proximity-precision";
pub const EMBEDDING_CONFIGS: &str = "embedding_configs";
pub const SEARCH_CUTOFF: &str = "search_cutoff";
}
pub mod db_name {
@ -1373,21 +1374,19 @@ impl Index {
self.main.remap_key_type::<Str>().delete(txn, main_key::MAX_VALUES_PER_FACET)
}
pub fn sort_facet_values_by(&self, txn: &RoTxn) -> heed::Result<HashMap<String, OrderBy>> {
let mut orders = self
pub fn sort_facet_values_by(&self, txn: &RoTxn) -> heed::Result<OrderByMap> {
let orders = self
.main
.remap_types::<Str, SerdeJson<HashMap<String, OrderBy>>>()
.remap_types::<Str, SerdeJson<OrderByMap>>()
.get(txn, main_key::SORT_FACET_VALUES_BY)?
.unwrap_or_default();
// Insert the default ordering if it is not already overwritten by the user.
orders.entry("*".to_string()).or_insert(OrderBy::Lexicographic);
Ok(orders)
}
pub(crate) fn put_sort_facet_values_by(
&self,
txn: &mut RwTxn,
val: &HashMap<String, OrderBy>,
val: &OrderByMap,
) -> heed::Result<()> {
self.main.remap_types::<Str, SerdeJson<_>>().put(txn, main_key::SORT_FACET_VALUES_BY, &val)
}
@ -1507,6 +1506,18 @@ impl Index {
_ => "default".to_owned(),
})
}
pub(crate) fn put_search_cutoff(&self, wtxn: &mut RwTxn<'_>, cutoff: u64) -> heed::Result<()> {
self.main.remap_types::<Str, BEU64>().put(wtxn, main_key::SEARCH_CUTOFF, &cutoff)
}
pub fn search_cutoff(&self, rtxn: &RoTxn<'_>) -> Result<Option<u64>> {
Ok(self.main.remap_types::<Str, BEU64>().get(rtxn, main_key::SEARCH_CUTOFF)?)
}
pub(crate) fn delete_search_cutoff(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> {
self.main.remap_key_type::<Str>().delete(wtxn, main_key::SEARCH_CUTOFF)
}
}
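The three accessors above store the cutoff as a plain number of milliseconds. A minimal sketch of how a caller could turn the stored setting into a `TimeBudget` (this wiring is an assumption, not part of the diff):

```rust
use std::time::Duration;

// Hypothetical helper: read the per-index cutoff and build the budget,
// falling back to TimeBudget's 150 ms default when no cutoff is stored.
fn budget_for(index: &Index, rtxn: &RoTxn) -> Result<TimeBudget> {
    Ok(match index.search_cutoff(rtxn)? {
        Some(ms) => TimeBudget::new(Duration::from_millis(ms)),
        None => TimeBudget::default(),
    })
}
```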
#[cfg(test)]
@ -2423,6 +2434,7 @@ pub(crate) mod tests {
candidates: _,
document_scores: _,
mut documents_ids,
degraded: _,
} = search.execute().unwrap();
let primary_key_id = index.fields_ids_map(&rtxn).unwrap().id("primary_key").unwrap();
documents_ids.sort_unstable();

View file

@ -16,6 +16,7 @@ pub mod facet;
mod fields_ids_map;
pub mod heed_codec;
pub mod index;
pub mod order_by_map;
pub mod prompt;
pub mod proximity;
pub mod score_details;
@ -29,6 +30,7 @@ pub mod snapshot_tests;
use std::collections::{BTreeMap, HashMap};
use std::convert::{TryFrom, TryInto};
use std::fmt;
use std::hash::BuildHasherDefault;
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
@ -56,10 +58,10 @@ pub use self::heed_codec::{
UncheckedU8StrStrCodec,
};
pub use self::index::Index;
pub use self::search::facet::{FacetValueHit, SearchForFacetValues};
pub use self::search::{
FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder,
MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy,
DEFAULT_VALUES_PER_FACET,
FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy,
Search, SearchResult, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
};
pub type Result<T> = std::result::Result<T, error::Error>;
@ -103,6 +105,73 @@ pub const MAX_WORD_LENGTH: usize = MAX_LMDB_KEY_LENGTH / 2;
pub const MAX_POSITION_PER_ATTRIBUTE: u32 = u16::MAX as u32 + 1;
#[derive(Clone)]
pub struct TimeBudget {
started_at: std::time::Instant,
budget: std::time::Duration,
/// When testing the time budget, it can be useful to ensure we did more than one iteration of the bucket sort.
/// But to avoid flakiness, the only option is to add the ability to stop after a specific number of calls instead of after a `Duration`.
#[cfg(test)]
stop_after: Option<(std::sync::Arc<std::sync::atomic::AtomicUsize>, usize)>,
}
impl fmt::Debug for TimeBudget {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TimeBudget")
.field("started_at", &self.started_at)
.field("budget", &self.budget)
.field("left", &(self.budget - self.started_at.elapsed()))
.finish()
}
}
impl Default for TimeBudget {
fn default() -> Self {
Self::new(std::time::Duration::from_millis(150))
}
}
impl TimeBudget {
pub fn new(budget: std::time::Duration) -> Self {
Self {
started_at: std::time::Instant::now(),
budget,
#[cfg(test)]
stop_after: None,
}
}
pub fn max() -> Self {
Self::new(std::time::Duration::from_secs(u64::MAX))
}
#[cfg(test)]
pub fn with_stop_after(mut self, stop_after: usize) -> Self {
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
self.stop_after = Some((Arc::new(AtomicUsize::new(0)), stop_after));
self
}
pub fn exceeded(&self) -> bool {
#[cfg(test)]
if let Some((current, stop_after)) = &self.stop_after {
let current = current.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if current >= *stop_after {
return true;
} else {
// if a number of calls has been specified then we entirely ignore the time budget
return false;
}
}
self.started_at.elapsed() > self.budget
}
}
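A minimal sketch of the semantics of the test-only hook, using only the API above: a budget built with `with_stop_after(n)` ignores the wall clock entirely and reports exhaustion once the call counter reaches `n`:

```rust
#[cfg(test)]
mod time_budget_sketch {
    use super::TimeBudget;

    #[test]
    fn stop_after_ignores_the_clock() {
        // `max()` alone would effectively never expire; `with_stop_after(2)`
        // makes `exceeded()` return true once it has been called twice.
        let budget = TimeBudget::max().with_stop_after(2);
        assert!(!budget.exceeded()); // counter was 0 -> not exceeded
        assert!(!budget.exceeded()); // counter was 1 -> not exceeded
        assert!(budget.exceeded()); // counter was 2 >= 2 -> exceeded
    }
}
```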
// Convert an absolute word position into a relative position.
// Return the field id of the attribute related to the absolute position
// and the relative position in the attribute.

milli/src/order_by_map.rs (new file, 57 lines)
View file

@ -0,0 +1,57 @@
use std::collections::{hash_map, HashMap};
use std::iter::FromIterator;
use serde::{Deserialize, Deserializer, Serialize};
use crate::OrderBy;
#[derive(Serialize)]
pub struct OrderByMap(HashMap<String, OrderBy>);
impl OrderByMap {
pub fn get(&self, key: impl AsRef<str>) -> OrderBy {
self.0
.get(key.as_ref())
.copied()
.unwrap_or_else(|| self.0.get("*").copied().unwrap_or_default())
}
pub fn insert(&mut self, key: String, value: OrderBy) -> Option<OrderBy> {
self.0.insert(key, value)
}
}
impl Default for OrderByMap {
fn default() -> Self {
let mut map = HashMap::new();
map.insert("*".to_string(), OrderBy::Lexicographic);
OrderByMap(map)
}
}
impl FromIterator<(String, OrderBy)> for OrderByMap {
fn from_iter<T: IntoIterator<Item = (String, OrderBy)>>(iter: T) -> Self {
OrderByMap(iter.into_iter().collect())
}
}
impl IntoIterator for OrderByMap {
type Item = (String, OrderBy);
type IntoIter = hash_map::IntoIter<String, OrderBy>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'de> Deserialize<'de> for OrderByMap {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let mut map = Deserialize::deserialize(deserializer).map(OrderByMap)?;
// Insert the default ordering if it is not already overwritten by the user.
map.0.entry("*".to_string()).or_insert(OrderBy::default());
Ok(map)
}
}
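A short usage sketch of the fallback behaviour defined above: explicit entries win, and every other key falls back to the `*` entry, which is `Lexicographic` unless the user overwrote it:

```rust
#[test]
fn order_by_map_fallback() {
    use crate::order_by_map::OrderByMap;
    use crate::OrderBy;

    let mut map = OrderByMap::default();
    map.insert("genres".to_string(), OrderBy::Count);
    // Explicit entry wins; any other facet falls back to the `*` entry.
    assert_eq!(map.get("genres"), OrderBy::Count);
    assert_eq!(map.get("anything-else"), OrderBy::Lexicographic);
}
```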

View file

@ -17,6 +17,9 @@ pub enum ScoreDetails {
Sort(Sort),
Vector(Vector),
GeoSort(GeoSort),
/// Returned when we don't have the time to finish applying all the subsequent ranking rules
Skipped,
}
#[derive(Clone, Copy)]
@ -50,6 +53,7 @@ impl ScoreDetails {
ScoreDetails::Sort(_) => None,
ScoreDetails::GeoSort(_) => None,
ScoreDetails::Vector(_) => None,
ScoreDetails::Skipped => Some(Rank { rank: 0, max_rank: 1 }),
}
}
@ -97,6 +101,7 @@ impl ScoreDetails {
ScoreDetails::Vector(vector) => RankOrValue::Score(
vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64),
),
ScoreDetails::Skipped => RankOrValue::Rank(Rank { rank: 0, max_rank: 1 }),
}
}
@ -256,6 +261,11 @@ impl ScoreDetails {
details_map.insert(vector, details);
order += 1;
}
ScoreDetails::Skipped => {
details_map
.insert("skipped".to_string(), serde_json::json!({ "order": order }));
order += 1;
}
}
}
details_map
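Downstream code can rely on the new variant to detect hits that were ranked with an exhausted budget. A minimal sketch (the helper name is hypothetical):

```rust
// Hypothetical helper: a hit was ranked with an exhausted time budget if
// any of its score details is the new Skipped variant.
fn hit_was_degraded(details: &[ScoreDetails]) -> bool {
    details.iter().any(|detail| matches!(detail, ScoreDetails::Skipped))
}
```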

View file

@ -168,7 +168,7 @@ impl<'t, 'b, 'bitmap> FacetRangeSearch<'t, 'b, 'bitmap> {
}
// should we stop?
// We should if the the search range doesn't include any
// We should if the search range doesn't include any
// element from the previous key or its successors
let should_stop = {
match self.right {
@ -232,7 +232,7 @@ impl<'t, 'b, 'bitmap> FacetRangeSearch<'t, 'b, 'bitmap> {
}
// should we stop?
// We should if the the search range doesn't include any
// We should if the search range doesn't include any
// element from the previous key or its successors
let should_stop = {
match self.right {

View file

@ -6,15 +6,18 @@ use roaring::RoaringBitmap;
pub use self::facet_distribution::{FacetDistribution, OrderBy, DEFAULT_VALUES_PER_FACET};
pub use self::filter::{BadGeoError, Filter};
pub use self::search::{FacetValueHit, SearchForFacetValues};
use crate::heed_codec::facet::{FacetGroupKeyCodec, FacetGroupValueCodec, OrderedF64Codec};
use crate::heed_codec::BytesRefCodec;
use crate::{Index, Result};
mod facet_distribution;
mod facet_distribution_iter;
mod facet_range_search;
mod facet_sort_ascending;
mod facet_sort_descending;
mod filter;
mod search;
fn facet_extreme_value<'t>(
mut extreme_it: impl Iterator<Item = heed::Result<(RoaringBitmap, &'t [u8])>> + 't,

View file

@ -0,0 +1,326 @@
use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;
use std::ops::ControlFlow;
use charabia::normalizer::NormalizerOption;
use charabia::Normalize;
use fst::automaton::{Automaton, Str};
use fst::{IntoStreamer, Streamer};
use roaring::RoaringBitmap;
use tracing::error;
use crate::error::UserError;
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
use crate::search::build_dfa;
use crate::{DocumentId, FieldId, OrderBy, Result, Search};
/// The maximum number of values per facet returned by the facet search route.
const DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET: usize = 100;
pub struct SearchForFacetValues<'a> {
query: Option<String>,
facet: String,
search_query: Search<'a>,
max_values: usize,
is_hybrid: bool,
}
impl<'a> SearchForFacetValues<'a> {
pub fn new(
facet: String,
search_query: Search<'a>,
is_hybrid: bool,
) -> SearchForFacetValues<'a> {
SearchForFacetValues {
query: None,
facet,
search_query,
max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET,
is_hybrid,
}
}
pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
self.query = Some(query.into());
self
}
pub fn max_values(&mut self, max: usize) -> &mut Self {
self.max_values = max;
self
}
fn one_original_value_of(
&self,
field_id: FieldId,
facet_str: &str,
any_docid: DocumentId,
) -> Result<Option<String>> {
let index = self.search_query.index;
let rtxn = self.search_query.rtxn;
let key: (FieldId, _, &str) = (field_id, any_docid, facet_str);
Ok(index.field_id_docid_facet_strings.get(rtxn, &key)?.map(|v| v.to_owned()))
}
pub fn execute(&self) -> Result<Vec<FacetValueHit>> {
let index = self.search_query.index;
let rtxn = self.search_query.rtxn;
let filterable_fields = index.filterable_fields(rtxn)?;
if !filterable_fields.contains(&self.facet) {
let (valid_fields, hidden_fields) =
index.remove_hidden_fields(rtxn, filterable_fields)?;
return Err(UserError::InvalidFacetSearchFacetName {
field: self.facet.clone(),
valid_fields,
hidden_fields,
}
.into());
}
let fields_ids_map = index.fields_ids_map(rtxn)?;
let fid = match fields_ids_map.id(&self.facet) {
Some(fid) => fid,
// we return an empty list of results when the attribute has been
// set as filterable but no document contains this field (yet).
None => return Ok(Vec::new()),
};
let fst = match self.search_query.index.facet_id_string_fst.get(rtxn, &fid)? {
Some(fst) => fst,
None => return Ok(Vec::new()),
};
let search_candidates = self
.search_query
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
let mut results = match index.sort_facet_values_by(rtxn)?.get(&self.facet) {
OrderBy::Lexicographic => ValuesCollection::by_lexicographic(self.max_values),
OrderBy::Count => ValuesCollection::by_count(self.max_values),
};
match self.query.as_ref() {
Some(query) => {
let options = NormalizerOption { lossy: true, ..Default::default() };
let query = query.normalize(&options);
let query = query.as_ref();
let authorize_typos = self.search_query.index.authorize_typos(rtxn)?;
let field_authorizes_typos =
!self.search_query.index.exact_attributes_ids(rtxn)?.contains(&fid);
if authorize_typos && field_authorizes_typos {
let exact_words_fst = self.search_query.index.exact_words(rtxn)?;
if exact_words_fst.map_or(false, |fst| fst.contains(query)) {
if fst.contains(query) {
self.fetch_original_facets_using_normalized(
fid,
query,
query,
&search_candidates,
&mut results,
)?;
}
} else {
let one_typo = self.search_query.index.min_word_len_one_typo(rtxn)?;
let two_typos = self.search_query.index.min_word_len_two_typos(rtxn)?;
let is_prefix = true;
let automaton = if query.len() < one_typo as usize {
build_dfa(query, 0, is_prefix)
} else if query.len() < two_typos as usize {
build_dfa(query, 1, is_prefix)
} else {
build_dfa(query, 2, is_prefix)
};
let mut stream = fst.search(automaton).into_stream();
while let Some(facet_value) = stream.next() {
let value = std::str::from_utf8(facet_value)?;
if self
.fetch_original_facets_using_normalized(
fid,
value,
query,
&search_candidates,
&mut results,
)?
.is_break()
{
break;
}
}
}
} else {
let automaton = Str::new(query).starts_with();
let mut stream = fst.search(automaton).into_stream();
while let Some(facet_value) = stream.next() {
let value = std::str::from_utf8(facet_value)?;
if self
.fetch_original_facets_using_normalized(
fid,
value,
query,
&search_candidates,
&mut results,
)?
.is_break()
{
break;
}
}
}
}
None => {
let prefix = FacetGroupKey { field_id: fid, level: 0, left_bound: "" };
for result in index.facet_id_string_docids.prefix_iter(rtxn, &prefix)? {
let (FacetGroupKey { left_bound, .. }, FacetGroupValue { bitmap, .. }) =
result?;
let count = search_candidates.intersection_len(&bitmap);
if count != 0 {
let value = self
.one_original_value_of(fid, left_bound, bitmap.min().unwrap())?
.unwrap_or_else(|| left_bound.to_string());
if results.insert(FacetValueHit { value, count }).is_break() {
break;
}
}
}
}
}
Ok(results.into_sorted_vec())
}
fn fetch_original_facets_using_normalized(
&self,
fid: FieldId,
value: &str,
query: &str,
search_candidates: &RoaringBitmap,
results: &mut ValuesCollection,
) -> Result<ControlFlow<()>> {
let index = self.search_query.index;
let rtxn = self.search_query.rtxn;
let database = index.facet_id_normalized_string_strings;
let key = (fid, value);
let original_strings = match database.get(rtxn, &key)? {
Some(original_strings) => original_strings,
None => {
error!("the facet value is missing from the facet database: {key:?}");
return Ok(ControlFlow::Continue(()));
}
};
for original in original_strings {
let key = FacetGroupKey { field_id: fid, level: 0, left_bound: original.as_str() };
let docids = match index.facet_id_string_docids.get(rtxn, &key)? {
Some(FacetGroupValue { bitmap, .. }) => bitmap,
None => {
error!("the facet value is missing from the facet database: {key:?}");
return Ok(ControlFlow::Continue(()));
}
};
let count = search_candidates.intersection_len(&docids);
if count != 0 {
let value = self
.one_original_value_of(fid, &original, docids.min().unwrap())?
.unwrap_or_else(|| query.to_string());
if results.insert(FacetValueHit { value, count }).is_break() {
break;
}
}
}
Ok(ControlFlow::Continue(()))
}
}
#[derive(Debug, Clone, serde::Serialize, PartialEq)]
pub struct FacetValueHit {
/// The original facet value
pub value: String,
/// The number of documents associated with this facet
pub count: u64,
}
impl PartialOrd for FacetValueHit {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for FacetValueHit {
fn cmp(&self, other: &Self) -> Ordering {
self.count.cmp(&other.count).then_with(|| self.value.cmp(&other.value))
}
}
impl Eq for FacetValueHit {}
/// A wrapper type that collects the best facet values, either by
/// lexicographic order or by number of associated documents.
enum ValuesCollection {
/// Keeps the top values according to the lexicographic order.
Lexicographic { max: usize, content: Vec<FacetValueHit> },
/// Keeps the top values according to the number of values associated to them.
///
/// Note that it is a max-heap and we need to move the smallest counts
/// to the top to be able to pop them when we reach the max_values limit.
Count { max: usize, content: BinaryHeap<Reverse<FacetValueHit>> },
}
impl ValuesCollection {
pub fn by_lexicographic(max: usize) -> Self {
ValuesCollection::Lexicographic { max, content: Vec::new() }
}
pub fn by_count(max: usize) -> Self {
ValuesCollection::Count { max, content: BinaryHeap::new() }
}
pub fn insert(&mut self, value: FacetValueHit) -> ControlFlow<()> {
match self {
ValuesCollection::Lexicographic { max, content } => {
if content.len() < *max {
content.push(value);
if content.len() < *max {
return ControlFlow::Continue(());
}
}
ControlFlow::Break(())
}
ValuesCollection::Count { max, content } => {
if content.len() == *max {
// Peeking gives us the worst value in the list as
// this is a max-heap and we reversed it.
let Some(mut peek) = content.peek_mut() else { return ControlFlow::Break(()) };
if peek.0.count <= value.count {
// Replace the current worst value in the heap
// with the new one we received that is better.
*peek = Reverse(value);
}
} else {
content.push(Reverse(value));
}
ControlFlow::Continue(())
}
}
}
/// Returns the list of facet values, either in descending order of count
/// or in lexicographic order of the value, depending on the variant.
pub fn into_sorted_vec(self) -> Vec<FacetValueHit> {
match self {
ValuesCollection::Lexicographic { content, .. } => content.into_iter().collect(),
ValuesCollection::Count { content, .. } => {
// Convert the heap into a vec of hits by removing the Reverse wrapper.
// Hits are already in the right order as they were reversed and
// are output in ascending order.
content.into_sorted_vec().into_iter().map(|Reverse(hit)| hit).collect()
}
}
}
}
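The `Count` variant is the classic top-K pattern: wrapping hits in `Reverse` turns Rust's max-heap into a min-heap, so the smallest retained count sits at the root and is cheap to evict. A standalone sketch of the same pattern over plain integers:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

// Keep the k largest values seen; the heap root is always the smallest
// value currently kept, so a better candidate can replace it in O(log k).
fn top_k(values: impl IntoIterator<Item = u64>, k: usize) -> Vec<u64> {
    let mut heap = BinaryHeap::with_capacity(k);
    for value in values {
        if heap.len() < k {
            heap.push(Reverse(value));
        } else if let Some(mut peek) = heap.peek_mut() {
            if peek.0 < value {
                *peek = Reverse(value); // evict the current smallest
            }
        }
    }
    // into_sorted_vec sorts the Reverse wrappers ascending,
    // i.e. the underlying values descending.
    heap.into_sorted_vec().into_iter().map(|Reverse(v)| v).collect()
}
```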

View file

@ -10,6 +10,7 @@ struct ScoreWithRatioResult {
matching_words: MatchingWords,
candidates: RoaringBitmap,
document_scores: Vec<(u32, ScoreWithRatio)>,
degraded: bool,
}
type ScoreWithRatio = (Vec<ScoreDetails>, f32);
@ -49,8 +50,12 @@ fn compare_scores(
order => return order,
}
}
(Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater,
(Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less,
(Some(ScoreValue::Score(x)), Some(_)) => {
return if x == 0. { Ordering::Less } else { Ordering::Greater }
}
(Some(_), Some(ScoreValue::Score(x))) => {
return if x == 0. { Ordering::Greater } else { Ordering::Less }
}
// if we have this, we're bad
(Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_)))
| (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => {
@ -72,6 +77,7 @@ impl ScoreWithRatioResult {
matching_words: results.matching_words,
candidates: results.candidates,
document_scores,
degraded: results.degraded,
}
}
@ -106,6 +112,7 @@ impl ScoreWithRatioResult {
candidates: left.candidates | right.candidates,
documents_ids,
document_scores,
degraded: left.degraded | right.degraded,
}
}
}
@ -131,6 +138,7 @@ impl<'a> Search<'a> {
index: self.index,
distribution_shift: self.distribution_shift,
embedder_name: self.embedder_name.clone(),
time_budget: self.time_budget.clone(),
};
let vector_query = search.vector.take();

View file

@ -1,25 +1,17 @@
use std::fmt;
use std::ops::ControlFlow;
use charabia::normalizer::NormalizerOption;
use charabia::Normalize;
use fst::automaton::{Automaton, Str};
use fst::{IntoStreamer, Streamer};
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
use once_cell::sync::Lazy;
use roaring::bitmap::RoaringBitmap;
use tracing::error;
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::{execute_vector_search, PartialSearchResult};
use crate::error::UserError;
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::DistributionShift;
use crate::{
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
Result, SearchContext,
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Index, Result,
SearchContext, TimeBudget,
};
// Building these factories is not free.
@ -27,9 +19,6 @@ static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
static LEVDIST1: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(1, true));
static LEVDIST2: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(2, true));
/// The maximum number of values per facet returned by the facet search route.
const DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET: usize = 100;
pub mod facet;
mod fst_utils;
pub mod hybrid;
@ -54,6 +43,8 @@ pub struct Search<'a> {
index: &'a Index,
distribution_shift: Option<DistributionShift>,
embedder_name: Option<String>,
time_budget: TimeBudget,
}
impl<'a> Search<'a> {
@ -75,6 +66,7 @@ impl<'a> Search<'a> {
index,
distribution_shift: None,
embedder_name: None,
time_budget: TimeBudget::max(),
}
}
@ -154,6 +146,11 @@ impl<'a> Search<'a> {
self
}
pub fn time_budget(&mut self, time_budget: TimeBudget) -> &mut Search<'a> {
self.time_budget = time_budget;
self
}
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
if has_vector_search {
let ctx = SearchContext::new(self.index, self.rtxn);
@ -180,36 +177,43 @@ impl<'a> Search<'a> {
}
let universe = filtered_universe(&ctx, &self.filter)?;
let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
match self.vector.as_ref() {
Some(vector) => execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
self.distribution_shift,
embedder_name,
)?,
None => execute_search(
&mut ctx,
self.query.as_deref(),
self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
)?,
};
let PartialSearchResult {
located_query_terms,
candidates,
documents_ids,
document_scores,
degraded,
} = match self.vector.as_ref() {
Some(vector) => execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
self.distribution_shift,
embedder_name,
self.time_budget.clone(),
)?,
None => execute_search(
&mut ctx,
self.query.as_deref(),
self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
self.time_budget.clone(),
)?,
};
// consume context and located_query_terms to build MatchingWords.
let matching_words = match located_query_terms {
@ -217,7 +221,7 @@ impl<'a> Search<'a> {
None => MatchingWords::default(),
};
Ok(SearchResult { matching_words, candidates, document_scores, documents_ids })
Ok(SearchResult { matching_words, candidates, document_scores, documents_ids, degraded })
}
}
@ -240,6 +244,7 @@ impl fmt::Debug for Search<'_> {
index: _,
distribution_shift,
embedder_name,
time_budget,
} = self;
f.debug_struct("Search")
.field("query", query)
@ -255,6 +260,7 @@ impl fmt::Debug for Search<'_> {
.field("words_limit", words_limit)
.field("distribution_shift", distribution_shift)
.field("embedder_name", embedder_name)
.field("time_budget", time_budget)
.finish()
}
}
@ -265,6 +271,7 @@ pub struct SearchResult {
pub candidates: RoaringBitmap,
pub documents_ids: Vec<DocumentId>,
pub document_scores: Vec<Vec<ScoreDetails>>,
pub degraded: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -302,240 +309,6 @@ pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA {
}
}
pub struct SearchForFacetValues<'a> {
query: Option<String>,
facet: String,
search_query: Search<'a>,
max_values: usize,
is_hybrid: bool,
}
impl<'a> SearchForFacetValues<'a> {
pub fn new(
facet: String,
search_query: Search<'a>,
is_hybrid: bool,
) -> SearchForFacetValues<'a> {
SearchForFacetValues {
query: None,
facet,
search_query,
max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET,
is_hybrid,
}
}
pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
self.query = Some(query.into());
self
}
pub fn max_values(&mut self, max: usize) -> &mut Self {
self.max_values = max;
self
}
fn one_original_value_of(
&self,
field_id: FieldId,
facet_str: &str,
any_docid: DocumentId,
) -> Result<Option<String>> {
let index = self.search_query.index;
let rtxn = self.search_query.rtxn;
let key: (FieldId, _, &str) = (field_id, any_docid, facet_str);
Ok(index.field_id_docid_facet_strings.get(rtxn, &key)?.map(|v| v.to_owned()))
}
pub fn execute(&self) -> Result<Vec<FacetValueHit>> {
let index = self.search_query.index;
let rtxn = self.search_query.rtxn;
let filterable_fields = index.filterable_fields(rtxn)?;
if !filterable_fields.contains(&self.facet) {
let (valid_fields, hidden_fields) =
index.remove_hidden_fields(rtxn, filterable_fields)?;
return Err(UserError::InvalidFacetSearchFacetName {
field: self.facet.clone(),
valid_fields,
hidden_fields,
}
.into());
}
let fields_ids_map = index.fields_ids_map(rtxn)?;
let fid = match fields_ids_map.id(&self.facet) {
Some(fid) => fid,
// we return an empty list of results when the attribute has been
// set as filterable but no document contains this field (yet).
None => return Ok(Vec::new()),
};
let fst = match self.search_query.index.facet_id_string_fst.get(rtxn, &fid)? {
Some(fst) => fst,
None => return Ok(vec![]),
};
let search_candidates = self
.search_query
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
match self.query.as_ref() {
Some(query) => {
let options = NormalizerOption { lossy: true, ..Default::default() };
let query = query.normalize(&options);
let query = query.as_ref();
let authorize_typos = self.search_query.index.authorize_typos(rtxn)?;
let field_authorizes_typos =
!self.search_query.index.exact_attributes_ids(rtxn)?.contains(&fid);
if authorize_typos && field_authorizes_typos {
let exact_words_fst = self.search_query.index.exact_words(rtxn)?;
if exact_words_fst.map_or(false, |fst| fst.contains(query)) {
let mut results = vec![];
if fst.contains(query) {
self.fetch_original_facets_using_normalized(
fid,
query,
query,
&search_candidates,
&mut results,
)?;
}
Ok(results)
} else {
let one_typo = self.search_query.index.min_word_len_one_typo(rtxn)?;
let two_typos = self.search_query.index.min_word_len_two_typos(rtxn)?;
let is_prefix = true;
let automaton = if query.len() < one_typo as usize {
build_dfa(query, 0, is_prefix)
} else if query.len() < two_typos as usize {
build_dfa(query, 1, is_prefix)
} else {
build_dfa(query, 2, is_prefix)
};
let mut stream = fst.search(automaton).into_stream();
let mut results = vec![];
while let Some(facet_value) = stream.next() {
let value = std::str::from_utf8(facet_value)?;
if self
.fetch_original_facets_using_normalized(
fid,
value,
query,
&search_candidates,
&mut results,
)?
.is_break()
{
break;
}
}
Ok(results)
}
} else {
let automaton = Str::new(query).starts_with();
let mut stream = fst.search(automaton).into_stream();
let mut results = vec![];
while let Some(facet_value) = stream.next() {
let value = std::str::from_utf8(facet_value)?;
if self
.fetch_original_facets_using_normalized(
fid,
value,
query,
&search_candidates,
&mut results,
)?
.is_break()
{
break;
}
}
Ok(results)
}
}
None => {
let mut results = vec![];
let prefix = FacetGroupKey { field_id: fid, level: 0, left_bound: "" };
for result in index.facet_id_string_docids.prefix_iter(rtxn, &prefix)? {
let (FacetGroupKey { left_bound, .. }, FacetGroupValue { bitmap, .. }) =
result?;
let count = search_candidates.intersection_len(&bitmap);
if count != 0 {
let value = self
.one_original_value_of(fid, left_bound, bitmap.min().unwrap())?
.unwrap_or_else(|| left_bound.to_string());
results.push(FacetValueHit { value, count });
}
if results.len() >= self.max_values {
break;
}
}
Ok(results)
}
}
}
fn fetch_original_facets_using_normalized(
&self,
fid: FieldId,
value: &str,
query: &str,
search_candidates: &RoaringBitmap,
results: &mut Vec<FacetValueHit>,
) -> Result<ControlFlow<()>> {
let index = self.search_query.index;
let rtxn = self.search_query.rtxn;
let database = index.facet_id_normalized_string_strings;
let key = (fid, value);
let original_strings = match database.get(rtxn, &key)? {
Some(original_strings) => original_strings,
None => {
error!("the facet value is missing from the facet database: {key:?}");
return Ok(ControlFlow::Continue(()));
}
};
for original in original_strings {
let key = FacetGroupKey { field_id: fid, level: 0, left_bound: original.as_str() };
let docids = match index.facet_id_string_docids.get(rtxn, &key)? {
Some(FacetGroupValue { bitmap, .. }) => bitmap,
None => {
error!("the facet value is missing from the facet database: {key:?}");
return Ok(ControlFlow::Continue(()));
}
};
let count = search_candidates.intersection_len(&docids);
if count != 0 {
let value = self
.one_original_value_of(fid, &original, docids.min().unwrap())?
.unwrap_or_else(|| query.to_string());
results.push(FacetValueHit { value, count });
}
if results.len() >= self.max_values {
return Ok(ControlFlow::Break(()));
}
}
Ok(ControlFlow::Continue(()))
}
}
#[derive(Debug, Clone, serde::Serialize, PartialEq)]
pub struct FacetValueHit {
/// The original facet value
pub value: String,
/// The number of documents associated to this facet
pub count: u64,
}
#[cfg(test)]
mod test {
#[allow(unused_imports)]

View file

@ -5,12 +5,14 @@ use super::ranking_rules::{BoxRankingRule, RankingRuleQueryTrait};
use super::SearchContext;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::{apply_distinct_rule, distinct_single_docid, DistinctOutput};
use crate::Result;
use crate::{Result, TimeBudget};
pub struct BucketSortOutput {
pub docids: Vec<u32>,
pub scores: Vec<Vec<ScoreDetails>>,
pub all_candidates: RoaringBitmap,
pub degraded: bool,
}
// TODO: would probably be good to regroup some of these inside of a struct?
@ -25,6 +27,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
length: usize,
scoring_strategy: ScoringStrategy,
logger: &mut dyn SearchLogger<Q>,
time_budget: TimeBudget,
) -> Result<BucketSortOutput> {
logger.initial_query(query);
logger.ranking_rules(&ranking_rules);
@ -41,6 +44,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
docids: vec![],
scores: vec![],
all_candidates: universe.clone(),
degraded: false,
});
}
if ranking_rules.is_empty() {
@ -74,6 +78,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
scores: vec![Default::default(); results.len()],
docids: results,
all_candidates,
degraded: false,
});
} else {
let docids: Vec<u32> = universe.iter().skip(from).take(length).collect();
@ -81,6 +86,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
scores: vec![Default::default(); docids.len()],
docids,
all_candidates: universe.clone(),
degraded: false,
});
};
}
@ -154,6 +160,28 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
}
while valid_docids.len() < length {
if time_budget.exceeded() {
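// Time is up: flush every remaining ranking-rule universe as one last
// bucket whose documents are scored as Skipped, then return a degraded result.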
loop {
let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]);
ranking_rule_scores.push(ScoreDetails::Skipped);
maybe_add_to_results!(bucket);
ranking_rule_scores.pop();
if cur_ranking_rule_index == 0 {
break;
}
back!();
}
return Ok(BucketSortOutput {
scores: valid_scores,
docids: valid_docids,
all_candidates,
degraded: true,
});
}
// The universe for this bucket is empty, so we don't need to sort
// anything, just go back to the parent ranking rule.
if ranking_rule_universes[cur_ranking_rule_index].is_empty()
@ -219,7 +247,12 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
)?;
}
Ok(BucketSortOutput { docids: valid_docids, scores: valid_scores, all_candidates })
Ok(BucketSortOutput {
docids: valid_docids,
scores: valid_scores,
all_candidates,
degraded: false,
})
}
/// Add the candidates to the results. Take `distinct`, `from`, `length`, and `cur_offset`

View file

@ -502,7 +502,7 @@ mod tests {
use super::*;
use crate::index::tests::TempIndex;
use crate::{execute_search, filtered_universe, SearchContext};
use crate::{execute_search, filtered_universe, SearchContext, TimeBudget};
impl<'a> MatcherBuilder<'a> {
fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self {
@ -522,6 +522,7 @@ mod tests {
Some(10),
&mut crate::DefaultSearchLogger,
&mut crate::DefaultSearchLogger,
TimeBudget::max(),
)
.unwrap();

View file

@ -52,7 +52,8 @@ use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule;
use crate::vector::DistributionShift;
use crate::{
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, TimeBudget,
UserError,
};
/// A structure used throughout the execution of a search query.
@ -518,6 +519,7 @@ pub fn execute_vector_search(
length: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
time_budget: TimeBudget,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -537,7 +539,7 @@ pub fn execute_vector_search(
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
&mut placeholder_search_logger;
let BucketSortOutput { docids, scores, all_candidates } = bucket_sort(
let BucketSortOutput { docids, scores, all_candidates, degraded } = bucket_sort(
ctx,
ranking_rules,
&PlaceholderQuery,
@ -546,6 +548,7 @@ pub fn execute_vector_search(
length,
scoring_strategy,
placeholder_search_logger,
time_budget,
)?;
Ok(PartialSearchResult {
@ -553,6 +556,7 @@ pub fn execute_vector_search(
document_scores: scores,
documents_ids: docids,
located_query_terms: None,
degraded,
})
}
@ -572,6 +576,7 @@ pub fn execute_search(
words_limit: Option<usize>,
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
time_budget: TimeBudget,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -648,6 +653,7 @@ pub fn execute_search(
length,
scoring_strategy,
query_graph_logger,
time_budget,
)?
} else {
let ranking_rules =
@ -661,10 +667,11 @@ pub fn execute_search(
length,
scoring_strategy,
placeholder_search_logger,
time_budget,
)?
};
let BucketSortOutput { docids, scores, mut all_candidates } = bucket_sort_output;
let BucketSortOutput { docids, scores, mut all_candidates, degraded } = bucket_sort_output;
let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?;
// The candidates are the universe unless the exhaustive number of hits
@ -682,6 +689,7 @@ pub fn execute_search(
document_scores: scores,
documents_ids: docids,
located_query_terms,
degraded,
})
}
@ -742,4 +750,6 @@ pub struct PartialSearchResult {
pub candidates: RoaringBitmap,
pub documents_ids: Vec<DocumentId>,
pub document_scores: Vec<Vec<ScoreDetails>>,
pub degraded: bool,
}

View file

@ -0,0 +1,429 @@
//! This module tests the search cutoff and ensures a few things:
//! 1. A basic test works and marks the search as degraded
//! 2. A test that ensures the filters are effectively applied even with a cutoff of 0
//! 3. A test that ensures the cutoff works well with the ranking scores
use std::time::Duration;
use big_s::S;
use maplit::hashset;
use meili_snap::snapshot;
use crate::index::tests::TempIndex;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::{Criterion, Filter, Search, TimeBudget};
fn create_index() -> TempIndex {
let index = TempIndex::new();
index
.update_settings(|s| {
s.set_primary_key("id".to_owned());
s.set_searchable_fields(vec!["text".to_owned()]);
s.set_filterable_fields(hashset! { S("id") });
s.set_criteria(vec![Criterion::Words, Criterion::Typo]);
})
.unwrap();
// reverse the ID / insertion order so we can better tell what was actually sorted apart from what merely kept the insertion order
index
.add_documents(documents!([
{
"id": 4,
"text": "hella puppo kefir",
},
{
"id": 3,
"text": "hella puppy kefir",
},
{
"id": 2,
"text": "hello",
},
{
"id": 1,
"text": "hello puppy",
},
{
"id": 0,
"text": "hello puppy kefir",
},
]))
.unwrap();
index
}
#[test]
fn basic_degraded_search() {
let index = create_index();
let rtxn = index.read_txn().unwrap();
let mut search = Search::new(&rtxn, &index);
search.query("hello puppy kefir");
search.limit(3);
search.time_budget(TimeBudget::new(Duration::from_millis(0)));
let result = search.execute().unwrap();
assert!(result.degraded);
}
#[test]
fn degraded_search_cannot_skip_filter() {
let index = create_index();
let rtxn = index.read_txn().unwrap();
let mut search = Search::new(&rtxn, &index);
search.query("hello puppy kefir");
search.limit(100);
search.time_budget(TimeBudget::new(Duration::from_millis(0)));
let filter_condition = Filter::from_str("id > 2").unwrap().unwrap();
search.filter(filter_condition);
let result = search.execute().unwrap();
assert!(result.degraded);
snapshot!(format!("{:?}\n{:?}", result.candidates, result.documents_ids), @r###"
RoaringBitmap<[0, 1]>
[0, 1]
"###);
}
#[test]
#[allow(clippy::format_collect)] // the test is already quite big
fn degraded_search_and_score_details() {
let index = create_index();
let rtxn = index.read_txn().unwrap();
let mut search = Search::new(&rtxn, &index);
search.query("hello puppy kefir");
search.limit(4);
search.scoring_strategy(ScoringStrategy::Detailed);
search.time_budget(TimeBudget::max());
let result = search.execute().unwrap();
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
IDs: [4, 1, 0, 3]
Scores: 1.0000 0.9167 0.8333 0.6667
Score Details:
[
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 0,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 1,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 2,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 2,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 0,
max_typo_count: 2,
},
),
],
]
"###);
// Do ONE loop iteration. Not much can be deduced: almost every document matched the first bucket of the words ranking rule.
search.time_budget(TimeBudget::max().with_stop_after(1));
let result = search.execute().unwrap();
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
IDs: [0, 1, 4, 2]
Scores: 0.6667 0.6667 0.6667 0.0000
Score Details:
[
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Skipped,
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Skipped,
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Skipped,
],
[
Skipped,
],
]
"###);
// Do TWO loop iterations. The first document should be entirely sorted
search.time_budget(TimeBudget::max().with_stop_after(2));
let result = search.execute().unwrap();
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
IDs: [4, 0, 1, 2]
Scores: 1.0000 0.6667 0.6667 0.0000
Score Details:
[
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 0,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Skipped,
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Skipped,
],
[
Skipped,
],
]
"###);
// Do THREE loop iterations. The second document should be entirely sorted as well
search.time_budget(TimeBudget::max().with_stop_after(3));
let result = search.execute().unwrap();
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
IDs: [4, 1, 0, 2]
Scores: 1.0000 0.9167 0.6667 0.0000
Score Details:
[
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 0,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 1,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Skipped,
],
[
Skipped,
],
]
"###);
// Do FOUR loop iterations. The third document should be entirely sorted as well
// The words bucket still has not progressed, thus the last document doesn't have any info yet.
search.time_budget(TimeBudget::max().with_stop_after(4));
let result = search.execute().unwrap();
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
IDs: [4, 1, 0, 2]
Scores: 1.0000 0.9167 0.8333 0.0000
Score Details:
[
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 0,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 1,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 2,
max_typo_count: 3,
},
),
],
[
Skipped,
],
]
"###);
// After SIX loop iterations, the words ranking rule gave us a new bucket.
// Since we reached the limit we were able to exit early without checking the typo ranking rule.
search.time_budget(TimeBudget::max().with_stop_after(6));
let result = search.execute().unwrap();
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
IDs: [4, 1, 0, 3]
Scores: 1.0000 0.9167 0.8333 0.3333
Score Details:
[
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 0,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 1,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 3,
max_matching_words: 3,
},
),
Typo(
Typo {
typo_count: 2,
max_typo_count: 3,
},
),
],
[
Words(
Words {
matching_words: 2,
max_matching_words: 3,
},
),
Skipped,
],
]
"###);
}

View file

@ -1,5 +1,6 @@
pub mod attribute_fid;
pub mod attribute_position;
pub mod cutoff;
pub mod distinct;
pub mod exactness;
pub mod geo_sort;

View file

@ -5,7 +5,7 @@ The typo ranking rule should transform the query graph such that it only contain
the combinations of word derivations that it used to compute its bucket.
The proximity ranking rule should then look for proximities only between those specific derivations.
For example, given the the search query `beautiful summer` and the dataset:
For example, given the search query `beautiful summer` and the dataset:
```text
{ "id": 0, "text": "beautigul summer...... beautiful day in the summer" }
{ "id": 1, "text": "beautiful summer" }

View file

@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
request_threads: &rayon::ThreadPool,
) -> Result<grenad::Reader<BufReader<File>>> {
puffin::profile_function!();
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
@ -376,7 +377,10 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
if chunks.len() == chunks.capacity() {
let chunked_embeds = embedder
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
.embed_chunks(
std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
request_threads,
)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
@ -394,7 +398,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
// send last chunk
if !chunks.is_empty() {
let chunked_embeds = embedder
.embed_chunks(std::mem::take(&mut chunks))
.embed_chunks(std::mem::take(&mut chunks), request_threads)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
@ -408,7 +412,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
if !current_chunk.is_empty() {
let embeds = embedder
.embed_chunks(vec![std::mem::take(&mut current_chunk)])
.embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
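The `std::mem::replace`/`std::mem::take` calls above are a buffer-and-flush pattern: the filled buffer is handed to `embed_chunks` while a fresh allocation is left in place for the next batch. A standalone sketch of the idea (names are illustrative):

```rust
// Push an item into the buffer; once full, hand the whole buffer to
// `flush` while leaving a fresh, pre-allocated Vec behind.
fn push_and_flush<T>(
    buffer: &mut Vec<T>,
    capacity: usize,
    item: T,
    flush: impl FnOnce(Vec<T>),
) {
    buffer.push(item);
    if buffer.len() == capacity {
        flush(std::mem::replace(buffer, Vec::with_capacity(capacity)));
    }
}
```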

View file

@ -238,6 +238,12 @@ fn send_original_documents_data(
let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
let request_threads = rayon::ThreadPoolBuilder::new()
.num_threads(crate::vector::REQUEST_PARALLELISM)
.thread_name(|index| format!("embedding-request-{index}"))
.build()?;
rayon::spawn(move || {
for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points(
@ -249,7 +255,12 @@ fn send_original_documents_data(
);
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
let embeddings = match extract_embeddings(
prompts,
indexer,
embedder.clone(),
&request_threads,
) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));

View file

@ -2646,6 +2646,12 @@ mod tests {
api_key: Setting::NotSet,
dimensions: Setting::Set(3),
document_template: Setting::NotSet,
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
}),
);
settings.set_embedder_settings(embedders);

View file

@ -14,12 +14,13 @@ use super::IndexerConfig;
use crate::criterion::Criterion;
use crate::error::UserError;
use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS};
use crate::order_by_map::OrderByMap;
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep};
use crate::vector::settings::{check_set, check_unset, EmbedderSource, EmbeddingSettings};
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
use crate::{FieldsIdsMap, Index, OrderBy, Result};
use crate::{FieldsIdsMap, Index, Result};
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum Setting<T> {
@ -145,10 +146,11 @@ pub struct Settings<'a, 't, 'i> {
/// Attributes on which typo tolerance is disabled.
exact_attributes: Setting<HashSet<String>>,
max_values_per_facet: Setting<usize>,
sort_facet_values_by: Setting<HashMap<String, OrderBy>>,
sort_facet_values_by: Setting<OrderByMap>,
pagination_max_total_hits: Setting<usize>,
proximity_precision: Setting<ProximityPrecision>,
embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>,
search_cutoff: Setting<u64>,
}
impl<'a, 't, 'i> Settings<'a, 't, 'i> {
@ -182,6 +184,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
pagination_max_total_hits: Setting::NotSet,
proximity_precision: Setting::NotSet,
embedder_settings: Setting::NotSet,
search_cutoff: Setting::NotSet,
indexer_config,
}
}
@ -340,7 +343,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
self.max_values_per_facet = Setting::Reset;
}
pub fn set_sort_facet_values_by(&mut self, value: HashMap<String, OrderBy>) {
pub fn set_sort_facet_values_by(&mut self, value: OrderByMap) {
self.sort_facet_values_by = Setting::Set(value);
}
@ -372,6 +375,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
self.embedder_settings = Setting::Reset;
}
pub fn set_search_cutoff(&mut self, value: u64) {
self.search_cutoff = Setting::Set(value);
}
pub fn reset_search_cutoff(&mut self) {
self.search_cutoff = Setting::Reset;
}
#[tracing::instrument(
level = "trace"
skip(self, progress_callback, should_abort, old_fields_ids_map),
@ -1025,6 +1036,24 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Ok(update)
}
fn update_search_cutoff(&mut self) -> Result<bool> {
let changed = match self.search_cutoff {
Setting::Set(new) => {
let old = self.index.search_cutoff(self.wtxn)?;
if old == Some(new) {
false
} else {
self.index.put_search_cutoff(self.wtxn, new)?;
true
}
}
Setting::Reset => self.index.delete_search_cutoff(self.wtxn)?,
Setting::NotSet => false,
};
Ok(changed)
}
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()>
where
FP: Fn(UpdateIndexingStep) + Sync,
@ -1073,6 +1102,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
// 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage
let embedding_configs_updated = self.update_embedding_configs()?;
// never trigger re-indexing
self.update_search_cutoff()?;
if stop_words_updated
|| non_separator_tokens_updated
|| separator_tokens_updated
@ -1131,6 +1163,12 @@ fn validate_prompt(
api_key,
dimensions,
document_template: Setting::Set(template),
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}) => {
// validate
let template = crate::prompt::Prompt::new(template)
@ -1144,6 +1182,12 @@ fn validate_prompt(
api_key,
dimensions,
document_template: Setting::Set(template),
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}))
}
new => Ok(new),
@ -1156,8 +1200,20 @@ pub fn validate_embedding_settings(
) -> Result<Setting<EmbeddingSettings>> {
let settings = validate_prompt(name, settings)?;
let Setting::Set(settings) = settings else { return Ok(settings) };
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
settings;
let EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = settings;
if let Some(0) = dimensions.set() {
return Err(crate::error::UserError::InvalidSettingsDimensions {
@ -1166,6 +1222,14 @@ pub fn validate_embedding_settings(
.into());
}
if let Some(url) = url.as_ref().set() {
url::Url::parse(url).map_err(|error| crate::error::UserError::InvalidUrl {
embedder_name: name.to_owned(),
inner_error: error,
url: url.to_owned(),
})?;
}
let Some(inferred_source) = source.set() else {
return Ok(Setting::Set(EmbeddingSettings {
source,
@ -1174,11 +1238,25 @@ pub fn validate_embedding_settings(
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}));
};
match inferred_source {
EmbedderSource::OpenAi => {
check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
if let Setting::Set(model) = &model {
let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
.ok_or(crate::error::UserError::InvalidOpenAiModel {
@ -1209,9 +1287,30 @@ pub fn validate_embedding_settings(
}
}
}
EmbedderSource::Ollama => {
// Dimensions get inferred, only model name is required
check_unset(&dimensions, "dimensions", inferred_source, name)?;
check_set(&model, "model", inferred_source, name)?;
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
}
EmbedderSource::HuggingFace => {
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
}
EmbedderSource::UserProvided => {
check_unset(&model, "model", inferred_source, name)?;
@ -1219,6 +1318,18 @@ pub fn validate_embedding_settings(
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&document_template, "documentTemplate", inferred_source, name)?;
check_set(&dimensions, "dimensions", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
}
EmbedderSource::Rest => {
check_unset(&model, "model", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
check_set(&url, "url", inferred_source, name)?;
}
}
Ok(Setting::Set(EmbeddingSettings {
@ -1228,6 +1339,12 @@ pub fn validate_embedding_settings(
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}))
}
@ -2050,6 +2167,7 @@ mod tests {
pagination_max_total_hits,
proximity_precision,
embedder_settings,
search_cutoff,
} = settings;
assert!(matches!(searchable_fields, Setting::NotSet));
assert!(matches!(displayed_fields, Setting::NotSet));
@ -2073,6 +2191,7 @@ mod tests {
assert!(matches!(pagination_max_total_hits, Setting::NotSet));
assert!(matches!(proximity_precision, Setting::NotSet));
assert!(matches!(embedder_settings, Setting::NotSet));
assert!(matches!(search_cutoff, Setting::NotSet));
})
.unwrap();
}
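A short sketch of driving the new setting end to end through the settings builder, using the `TempIndex` helper seen elsewhere in this diff:

```rust
// Set a 150 ms search cutoff, then reset it back to the default.
index.update_settings(|settings| settings.set_search_cutoff(150)).unwrap();
index.update_settings(|settings| settings.reset_search_cutoff()).unwrap();
```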

View file

@ -3,7 +3,6 @@ use std::path::PathBuf;
use hf_hub::api::sync::ApiError;
use crate::error::FaultSource;
use crate::vector::openai::OpenAiError;
#[derive(Debug, thiserror::Error)]
#[error("Error while generating embeddings: {inner}")]
@ -51,26 +50,34 @@ pub enum EmbedErrorKind {
TensorValue(candle_core::Error),
#[error("could not run model: {0}")]
ModelForward(candle_core::Error),
#[error("could not reach OpenAI: {0}")]
OpenAiNetwork(reqwest::Error),
#[error("unexpected response from OpenAI: {0}")]
OpenAiUnexpected(reqwest::Error),
#[error("could not authenticate against OpenAI: {0}")]
OpenAiAuth(OpenAiError),
#[error("sent too many requests to OpenAI: {0}")]
OpenAiTooManyRequests(OpenAiError),
#[error("received internal error from OpenAI: {0:?}")]
OpenAiInternalServerError(Option<OpenAiError>),
#[error("sent too many tokens in a request to OpenAI: {0}")]
OpenAiTooManyTokens(OpenAiError),
#[error("received unhandled HTTP status code {0} from OpenAI")]
OpenAiUnhandledStatusCode(u16),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
ManualEmbed(String),
#[error("could not initialize asynchronous runtime: {0}")]
OpenAiRuntimeInit(std::io::Error),
#[error("initializing web client for sending embedding requests failed: {0}")]
InitWebClient(reqwest::Error),
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")]
OllamaModelNotFoundError(Option<String>),
#[error("error deserialization the response body as JSON: {0}")]
RestResponseDeserialization(std::io::Error),
#[error("component `{0}` not found in path `{1}` in response: `{2}`")]
RestResponseMissingEmbeddings(String, String, String),
#[error("expected a response parseable as a vector or an array of vectors: {0}")]
RestResponseFormat(serde_json::Error),
#[error("expected a response containing {0} embeddings, got only {1}")]
RestResponseEmbeddingCount(usize, usize),
#[error("could not authenticate against embedding server: {0:?}")]
RestUnauthorized(Option<String>),
#[error("sent too many requests to embedding server: {0:?}")]
RestTooManyRequests(Option<String>),
#[error("sent a bad request to embedding server: {0:?}")]
RestBadRequest(Option<String>),
#[error("received internal error from embedding server: {0:?}")]
RestInternalServerError(u16, Option<String>),
#[error("received HTTP {0} from embedding server: {0:?}")]
RestOtherStatusCode(u16, Option<String>),
#[error("could not reach embedding server: {0}")]
RestNetwork(ureq::Transport),
#[error("was expected '{}' to be an object in query '{0}'", .1.join("."))]
RestNotAnObject(serde_json::Value, Vec<String>),
#[error("while embedding tokenized, was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")]
OpenAiUnexpectedDimension(usize, usize),
}
impl EmbedError {
@ -90,44 +97,98 @@ impl EmbedError {
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
}
pub fn openai_network(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
}
pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
}
pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
}
pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
}
pub(crate) fn openai_internal_server_error(inner: Option<OpenAiError>) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
}
pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
}
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
}
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
}
pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
}
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
Self {
kind: EmbedErrorKind::RestResponseDeserialization(error),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_response_missing_embeddings<S: AsRef<str>>(
response: serde_json::Value,
component: &str,
response_field: &[S],
) -> EmbedError {
let response_field: Vec<&str> = response_field.iter().map(AsRef::as_ref).collect();
let response_field = response_field.join(".");
Self {
kind: EmbedErrorKind::RestResponseMissingEmbeddings(
component.to_owned(),
response_field,
serde_json::to_string_pretty(&response).unwrap_or_default(),
),
fault: FaultSource::Undecided,
}
}
pub(crate) fn rest_response_format(error: serde_json::Error) -> EmbedError {
Self { kind: EmbedErrorKind::RestResponseFormat(error), fault: FaultSource::Undecided }
}
pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
Self {
kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError {
Self { kind: EmbedErrorKind::RestUnauthorized(error_response), fault: FaultSource::User }
}
pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestTooManyRequests(error_response),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError {
Self { kind: EmbedErrorKind::RestBadRequest(error_response), fault: FaultSource::User }
}
pub(crate) fn rest_internal_server_error(
code: u16,
error_response: Option<String>,
) -> EmbedError {
Self {
kind: EmbedErrorKind::RestInternalServerError(code, error_response),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
fault: FaultSource::Undecided,
}
}
pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
}
pub(crate) fn rest_not_an_object(
query: serde_json::Value,
input_path: Vec<String>,
) -> EmbedError {
Self { kind: EmbedErrorKind::RestNotAnObject(query, input_path), fault: FaultSource::User }
}
pub(crate) fn openai_unexpected_dimension(expected: usize, got: usize) -> EmbedError {
Self {
kind: EmbedErrorKind::OpenAiUnexpectedDimension(expected, got),
fault: FaultSource::Runtime,
}
}
}
@ -188,16 +249,12 @@ impl NewEmbedderError {
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
}
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::Runtime,
}
}
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
@ -244,7 +301,4 @@ pub enum NewEmbedderErrorKind {
CouldNotDetermineDimension(EmbedError),
#[error("loading model failed: {0}")]
LoadModel(candle_core::Error),
// openai
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
}

View file

@ -131,7 +131,7 @@ impl Embedder {
let embeddings = this
.embed(vec!["test".into()])
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
.map_err(NewEmbedderError::could_not_determine_dimension)?;
this.dimensions = embeddings.first().unwrap().dimension();
Ok(this)
@ -194,7 +194,10 @@ impl Embedder {
pub fn distribution(&self) -> Option<DistributionShift> {
if self.options.model == "BAAI/bge-base-en-v1.5" {
Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 })
Some(DistributionShift {
current_mean: ordered_float::OrderedFloat(0.85),
current_sigma: ordered_float::OrderedFloat(0.1),
})
} else {
None
}

View file

@ -1,6 +1,9 @@
use std::collections::HashMap;
use std::sync::Arc;
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::{Prompt, PromptData};
@ -10,50 +13,71 @@ pub mod manual;
pub mod openai;
pub mod settings;
pub mod ollama;
pub mod rest;
pub use self::error::Error;
pub type Embedding = Vec<f32>;
pub const REQUEST_PARALLELISM: usize = 40;
/// One or multiple embeddings stored consecutively in a flat vector.
pub struct Embeddings<F> {
data: Vec<F>,
dimension: usize,
}
impl<F> Embeddings<F> {
/// Declares an empty vector of embeddings of the specified dimension.
pub fn new(dimension: usize) -> Self {
Self { data: Default::default(), dimension }
}
/// Declares a vector of embeddings containing a single element.
///
/// The dimension is inferred from the length of the passed embedding.
pub fn from_single_embedding(embedding: Vec<F>) -> Self {
Self { dimension: embedding.len(), data: embedding }
}
/// Declares a vector of embeddings from its components.
///
/// `data.len()` must be a multiple of `dimension`, otherwise an error is returned.
pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
let mut this = Self::new(dimension);
this.append(data)?;
Ok(this)
}
/// Returns the number of embeddings in this vector of embeddings.
pub fn embedding_count(&self) -> usize {
self.data.len() / self.dimension
}
/// Dimension of a single embedding.
pub fn dimension(&self) -> usize {
self.dimension
}
/// Deconstructs self into the inner flat vector.
pub fn into_inner(self) -> Vec<F> {
self.data
}
/// A reference to the inner flat vector.
pub fn as_inner(&self) -> &[F] {
&self.data
}
/// Iterates over the embeddings contained in the flat vector.
pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
self.data.as_slice().chunks_exact(self.dimension)
}
/// Pushes an embedding at the end of the embeddings.
///
/// If `embedding.len() != self.dimension`, then the push operation fails.
pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
if embedding.len() != self.dimension {
return Err(embedding);
@ -62,6 +86,9 @@ impl<F> Embeddings<F> {
Ok(())
}
/// Appends a flat vector of embeddings at the end of the embeddings.
///
/// If `embeddings.len() % self.dimension != 0`, then the append operation fails.
pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
if embeddings.len() % self.dimension != 0 {
return Err(embeddings);
@ -71,36 +98,60 @@ impl<F> Embeddings<F> {
}
}
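A minimal sketch exercising the flat-vector invariants of `Embeddings` documented above (the values are illustrative, and the test module is a hypothetical addition):

#[cfg(test)]
mod embeddings_sketch {
    use super::Embeddings;

    #[test]
    fn push_and_append_respect_the_dimension() {
        // Three-dimensional embeddings, stored back to back in one flat vector.
        let mut embeddings = Embeddings::new(3);
        embeddings.push(vec![0.1, 0.2, 0.3]).unwrap();
        // Appending six floats at once succeeds: 6 is a multiple of the dimension.
        embeddings.append(vec![0.4, 0.5, 0.6, 0.7, 0.8, 0.9]).unwrap();
        assert_eq!(embeddings.embedding_count(), 3);
        // A push of the wrong dimension fails and hands the rejected vector back.
        assert!(embeddings.push(vec![1.0, 2.0]).is_err());
    }
}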
/// An embedder can be used to transform text into embeddings.
#[derive(Debug)]
pub enum Embedder {
/// An embedder based on running local models, fetched from the Hugging Face Hub.
HuggingFace(hf::Embedder),
/// An embedder based on making embedding queries against the OpenAI API.
OpenAi(openai::Embedder),
/// An embedder based on the user providing the embeddings in the documents and queries.
UserProvided(manual::Embedder),
/// An embedder based on making embedding queries against an Ollama (<https://ollama.com>) embedding server.
Ollama(ollama::Embedder),
/// An embedder based on making embedding queries against a generic JSON/REST embedding server.
Rest(rest::Embedder),
}
/// Configuration for an embedder.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig {
/// Options of the embedder, specific to each kind of embedder
pub embedder_options: EmbedderOptions,
/// Document template
pub prompt: PromptData,
// TODO: add metrics and anything needed
}
/// Map of embedder configurations.
///
/// Each configuration is mapped to a name.
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
impl EmbeddingConfigs {
/// Create the map from its internal components.
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
Self(data)
}
/// Get an embedder configuration and template from its name.
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.0.get(name).cloned()
}
/// Get the default embedder configuration, if any.
pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.get_default_embedder_name().and_then(|default| self.get(&default))
}
/// Get the name of the default embedder configuration.
///
/// The default embedder is determined as follows:
///
/// - If there is only one embedder, it is always the default.
/// - If there are multiple embedders and one of them is called `default`, then that one is the default embedder.
/// - In all other cases, there is no default embedder.
pub fn get_default_embedder_name(&self) -> Option<String> {
let mut it = self.0.keys();
let first_name = it.next();
@ -123,11 +174,14 @@ impl IntoIterator for EmbeddingConfigs {
}
}
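A sketch of the default-embedder rule documented on `get_default_embedder_name` above (the configuration names are made up for illustration):

// one embedder: it is the default, whatever its name
//     {"product": ...}                 => Some("product")
// several embedders: only a literal "default" entry qualifies
//     {"default": ..., "title": ...}   => Some("default")
//     {"product": ..., "title": ...}   => None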
/// Options of an embedder, specific to each kind of embedder.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
Ollama(ollama::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
Rest(rest::EmbedderOptions),
}
impl Default for EmbedderOptions {
@ -137,91 +191,158 @@ impl Default for EmbedderOptions {
}
impl EmbedderOptions {
/// Default options for the Hugging Face embedder
pub fn huggingface() -> Self {
Self::HuggingFace(hf::EmbedderOptions::new())
}
/// Default options for the OpenAI embedder
pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
}
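/// Default options for the Ollama embedder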
pub fn ollama() -> Self {
Self::Ollama(ollama::EmbedderOptions::with_default_model())
}
}
impl Embedder {
/// Spawns a new embedder built from its options.
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?),
})
}
pub async fn embed(
/// Embed one or multiple texts.
///
/// Each text can be embedded as one or multiple embeddings.
pub fn embed(
&self,
texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::OpenAi(embedder) => embedder.embed(texts),
Embedder::Ollama(embedder) => embedder.embed(texts),
Embedder::UserProvided(embedder) => embedder.embed(texts),
Embedder::Rest(embedder) => embedder.embed(texts),
}
}
/// # Panics
/// Embed multiple chunks of texts.
///
/// - if called from an asynchronous context
/// Each chunk is composed of one or multiple texts.
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
}
}
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
pub fn chunk_count_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
Embedder::Rest(embedder) => embedder.chunk_count_hint(),
}
}
/// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`]
pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
}
}
/// Indicates the dimensions of a single embedding produced by the embedder.
pub fn dimensions(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::Ollama(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
Embedder::Rest(embedder) => embedder.dimensions(),
}
}
/// An optional distribution used to apply an affine transformation to the similarity score of a document.
pub fn distribution(&self) -> Option<DistributionShift> {
match self {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None,
Embedder::Rest(embedder) => embedder.distribution(),
}
}
}
#[derive(Debug, Clone, Copy)]
/// Describes the mean and sigma of the distribution of embedding similarity in the embedding space.
///
/// The intended use is to make the similarity score more comparable to the regular ranking score.
/// This allows correcting effects where results are too "packed" around a certain value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(from = "DistributionShiftSerializable")]
#[serde(into = "DistributionShiftSerializable")]
pub struct DistributionShift {
pub current_mean: f32,
pub current_sigma: f32,
/// Value where the results are "packed".
///
/// Similarity scores are translated so that they are packed around 0.5 instead of the current mean
pub current_mean: OrderedFloat<f32>,
/// Standard deviation of a similarity score.
///
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
pub current_sigma: OrderedFloat<f32>,
}
#[derive(Serialize, Deserialize)]
struct DistributionShiftSerializable {
current_mean: f32,
current_sigma: f32,
}
impl From<DistributionShift> for DistributionShiftSerializable {
fn from(
DistributionShift {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}: DistributionShift,
) -> Self {
Self { current_mean, current_sigma }
}
}
impl From<DistributionShiftSerializable> for DistributionShift {
fn from(
DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable,
) -> Self {
Self {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}
}
}
impl DistributionShift {
@ -230,11 +351,13 @@ impl DistributionShift {
if sigma <= 0.0 {
None
} else {
Some(Self { current_mean: mean, current_sigma: sigma })
Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
}
}
pub fn shift(&self, score: f32) -> f32 {
let current_mean = self.current_mean.0;
let current_sigma = self.current_sigma.0;
// <https://math.stackexchange.com/a/2894689>
// We're somewhat abusively mapping the distribution of distances to a Gaussian.
// The parameters we're given are the mean and sigma of the native result distribution.
@ -244,9 +367,9 @@ impl DistributionShift {
let target_sigma = 0.4;
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
let factor = target_sigma / self.current_sigma;
let factor = target_sigma / current_sigma;
// a*mu1 + b = mu2 => b = mu2 - a*mu1
let offset = target_mean - (factor * self.current_mean);
let offset = target_mean - (factor * current_mean);
let mut score = factor * score + offset;
@ -262,6 +385,7 @@ impl DistributionShift {
}
}
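A worked instance of the affine shift above, with illustrative numbers that are not from the diff (target_sigma is the 0.4 shown in `shift`; target_mean is 0.5, per the doc comment on `current_mean`):

// With current_mean = 0.7 and current_sigma = 0.1:
//     factor = 0.4 / 0.1 = 4.0
//     offset = 0.5 - 4.0 * 0.7 = -2.3
// so a raw similarity of 0.75 (half a sigma above the mean) is shifted to
// 4.0 * 0.75 - 2.3 = 0.7, i.e. half a target sigma above the target mean.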
/// Whether CUDA is supported in this version of Meilisearch.
pub const fn is_cuda_enabled() -> bool {
cfg!(feature = "cuda")
}

102
milli/src/vector/ollama.rs Normal file
View file

@ -0,0 +1,102 @@
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, Embeddings};
#[derive(Debug)]
pub struct Embedder {
rest_embedder: RestEmbedder,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub embedding_model: String,
}
impl EmbedderOptions {
pub fn with_default_model() -> Self {
Self { embedding_model: "nomic-embed-text".into() }
}
pub fn with_embedding_model(embedding_model: String) -> Self {
Self { embedding_model }
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let model = options.embedding_model.as_str();
let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
api_key: None,
distribution: None,
dimensions: None,
url: get_ollama_path(),
query: serde_json::json!({
"model": model,
}),
input_field: vec!["prompt".to_owned()],
path_to_embeddings: Default::default(),
embedding_object: vec!["embedding".to_owned()],
input_type: super::rest::InputType::Text,
}) {
Ok(embedder) => embedder,
Err(NewEmbedderError {
kind:
NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
fault: _,
}),
fault: _,
}) => {
return Err(NewEmbedderError::could_not_determine_dimension(
EmbedError::ollama_model_not_found(error),
))
}
Err(error) => return Err(error),
};
Ok(Self { rest_embedder })
}
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
match self.rest_embedder.embed(texts) {
Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
Err(EmbedError::ollama_model_not_found(error))
}
Err(error) => Err(error),
}
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
}
pub fn chunk_count_hint(&self) -> usize {
self.rest_embedder.chunk_count_hint()
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
self.rest_embedder.prompt_count_in_chunk_hint()
}
pub fn dimensions(&self) -> usize {
self.rest_embedder.dimensions()
}
pub fn distribution(&self) -> Option<DistributionShift> {
None
}
}
fn get_ollama_path() -> String {
// Important: the hostname alone is not enough, this has to be the entire path to the embeddings endpoint
std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())
}
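Given the options above, the exchange with Ollama has roughly this shape (a sketch inferred from the configured `input_field` and `embedding_object`; the endpoint and model are the defaults from this file):

// POST http://localhost:11434/api/embeddings
//     {"model": "nomic-embed-text", "prompt": "the text to embed"}
// response, read back through the "embedding" path:
//     {"embedding": [0.1, -0.3, ...]}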

View file

@ -1,17 +1,10 @@
use std::fmt::Display;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use ordered_float::OrderedFloat;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use super::error::{EmbedError, NewEmbedderError};
use super::{DistributionShift, Embedding, Embeddings};
#[derive(Debug)]
pub struct Embedder {
headers: reqwest::header::HeaderMap,
tokenizer: tiktoken_rs::CoreBPE,
options: EmbedderOptions,
}
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, Embeddings};
use crate::vector::error::EmbedErrorKind;
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
@ -20,6 +13,32 @@ pub struct EmbedderOptions {
pub dimensions: Option<usize>,
}
impl EmbedderOptions {
pub fn dimensions(&self) -> usize {
if self.embedding_model.supports_overriding_dimensions() {
self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
} else {
self.embedding_model.default_dimensions()
}
}
pub fn query(&self) -> serde_json::Value {
let model = self.embedding_model.name();
let mut query = serde_json::json!({
"model": model,
});
if self.embedding_model.supports_overriding_dimensions() {
if let Some(dimensions) = self.dimensions {
query["dimensions"] = dimensions.into();
}
}
query
}
}
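For instance (illustrative values), with a model that supports overriding dimensions, `query()` builds:

// embedding_model = "text-embedding-3-small", dimensions = Some(512)
//     => {"model": "text-embedding-3-small", "dimensions": 512}
// For a model without dimension overriding, the "dimensions" key is omitted.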
#[derive(
Debug,
Clone,
@ -92,15 +111,18 @@ impl EmbeddingModel {
fn distribution(&self) -> Option<DistributionShift> {
match self {
EmbeddingModel::TextEmbeddingAda002 => {
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
}
EmbeddingModel::TextEmbedding3Large => {
Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 })
}
EmbeddingModel::TextEmbedding3Small => {
Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 })
}
EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift {
current_mean: OrderedFloat(0.90),
current_sigma: OrderedFloat(0.08),
}),
EmbeddingModel::TextEmbedding3Large => Some(DistributionShift {
current_mean: OrderedFloat(0.70),
current_sigma: OrderedFloat(0.1),
}),
EmbeddingModel::TextEmbedding3Small => Some(DistributionShift {
current_mean: OrderedFloat(0.75),
current_sigma: OrderedFloat(0.1),
}),
}
}
@ -125,178 +147,57 @@ impl EmbedderOptions {
}
}
impl Embedder {
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
reqwest::ClientBuilder::new()
.default_headers(self.headers.clone())
.build()
.map_err(EmbedError::openai_initialize_web_client)
}
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
}
#[derive(Debug)]
pub struct Embedder {
tokenizer: tiktoken_rs::CoreBPE,
rest_embedder: RestEmbedder,
options: EmbedderOptions,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new();
let mut inferred_api_key = Default::default();
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
inferred_api_key = infer_api_key();
&inferred_api_key
});
headers.insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
);
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_static("application/json"),
);
let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
api_key: Some(api_key.clone()),
distribution: options.embedding_model.distribution(),
dimensions: Some(options.dimensions()),
url: OPENAI_EMBEDDINGS_URL.to_owned(),
query: options.query(),
input_field: vec!["input".to_owned()],
input_type: crate::vector::rest::InputType::TextArray,
path_to_embeddings: vec!["data".to_owned()],
embedding_object: vec!["embedding".to_owned()],
})?;
// Looking at the code, it is very unclear how this could actually fail.
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
Ok(Self { options, headers, tokenizer })
Ok(Self { options, rest_embedder, tokenizer })
}
pub async fn embed(
&self,
texts: Vec<String>,
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let mut tokenized = false;
for attempt in 0..7 {
let result = if tokenized {
self.try_embed_tokenized(&texts, client).await
} else {
self.try_embed(&texts, client).await
};
let retry_duration = match result {
Ok(embeddings) => return Ok(embeddings),
Err(retry) => {
tracing::warn!("Failed: {}", retry.error);
tokenized |= retry.must_tokenize();
retry.into_duration(attempt)
}
}?;
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
tracing::warn!(
"Attempt #{}, retrying after {}ms.",
attempt,
retry_duration.as_millis()
);
tokio::time::sleep(retry_duration).await;
}
let result = if tokenized {
self.try_embed_tokenized(&texts, client).await
} else {
self.try_embed(&texts, client).await
};
result.map_err(Retry::into_error)
}
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
if !response.status().is_success() {
match response.status() {
StatusCode::UNAUTHORIZED => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::give_up(EmbedError::openai_auth_error(
error_response.error,
)));
}
StatusCode::TOO_MANY_REQUESTS => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
error_response.error,
)));
}
StatusCode::INTERNAL_SERVER_ERROR
| StatusCode::BAD_GATEWAY
| StatusCode::SERVICE_UNAVAILABLE => {
let error_response: Result<OpenAiErrorResponse, _> = response.json().await;
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
error_response.ok().map(|error_response| error_response.error),
)));
}
StatusCode::BAD_REQUEST => {
// Most probably, one text contained too many tokens
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
error_response.error,
)));
}
code => {
return Err(Retry::retry_later(EmbedError::openai_unhandled_status_code(
code.as_u16(),
)));
}
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
match self.rest_embedder.embed_ref(&texts) {
Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => {
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
self.try_embed_tokenized(&texts)
}
Err(error) => Err(error),
}
Ok(response)
}
async fn try_embed<S: AsRef<str> + serde::Serialize>(
&self,
texts: &[S],
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, Retry> {
for text in texts {
tracing::trace!("Received prompt: {}", text.as_ref())
}
let request = OpenAiRequest {
model: self.options.embedding_model.name(),
input: texts,
dimensions: self.overriden_dimensions(),
};
let response = client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let response: OpenAiResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
tracing::trace!("response: {:?}", response.data);
Ok(response
.data
.into_iter()
.map(|data| Embeddings::from_single_embedding(data.embedding))
.collect())
}
async fn try_embed_tokenized(
&self,
text: &[String],
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, Retry> {
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
pub const OVERLAP_SIZE: usize = 200;
let mut all_embeddings = Vec::with_capacity(text.len());
for text in text {
@ -304,7 +205,7 @@ impl Embedder {
let encoded = self.tokenizer.encode_ordinary(text.as_str());
let len = encoded.len();
if len < max_token_count {
all_embeddings.append(&mut self.try_embed(&[text], client).await?);
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
continue;
}
@ -312,215 +213,49 @@ impl Embedder {
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
while tokens.len() > max_token_count {
let window = &tokens[..max_token_count];
embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
let embedding = self.rest_embedder.embed_tokens(window)?;
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
})?;
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
}
// end of text
embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
let embedding = self.rest_embedder.embed_tokens(tokens)?;
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
})?;
all_embeddings.push(embeddings_for_prompt);
}
Ok(all_embeddings)
}
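The sliding window above can be pictured with illustrative numbers (`max_token_count` is set outside this hunk; 8191 is assumed here):

// With OVERLAP_SIZE = 200 and an assumed max_token_count of 8191, a prompt of
// 16_000 tokens yields:
//     window 1: tokens[0..8191]
//     remainder: tokens[7991..], i.e. 8_009 tokens (200 of them overlapping
//     window 1), which fit in a single final window.
// The embeddings of both windows are appended into one `Embeddings` value.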
async fn embed_tokens(
&self,
tokens: &[usize],
client: &reqwest::Client,
) -> Result<Embedding, Retry> {
for attempt in 0..9 {
let duration = match self.try_embed_tokens(tokens, client).await {
Ok(embedding) => return Ok(embedding),
Err(retry) => retry.into_duration(attempt),
}
.map_err(Retry::retry_later)?;
tokio::time::sleep(duration).await;
}
self.try_embed_tokens(tokens, client)
.await
.map_err(|retry| Retry::give_up(retry.into_error()))
}
async fn try_embed_tokens(
&self,
tokens: &[usize],
client: &reqwest::Client,
) -> Result<Embedding, Retry> {
let request = OpenAiTokensRequest {
model: self.options.embedding_model.name(),
input: tokens,
dimensions: self.overriden_dimensions(),
};
let response = client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let mut response: OpenAiResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.map_err(EmbedError::openai_runtime_init)?;
let client = self.new_client()?;
rt.block_on(futures::future::try_join_all(
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
))
threads.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
}
pub fn chunk_count_hint(&self) -> usize {
10
self.rest_embedder.chunk_count_hint()
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
10
self.rest_embedder.prompt_count_in_chunk_hint()
}
pub fn dimensions(&self) -> usize {
if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
} else {
self.options.embedding_model.default_dimensions()
}
self.options.dimensions()
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution()
}
fn overriden_dimensions(&self) -> Option<usize> {
if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions
} else {
None
}
}
}
// retrying in case of failure
struct Retry {
error: EmbedError,
strategy: RetryStrategy,
}
enum RetryStrategy {
GiveUp,
Retry,
RetryTokenized,
RetryAfterRateLimit,
}
impl Retry {
fn give_up(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
fn retry_later(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
fn retry_tokenized(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryTokenized }
}
fn rate_limited(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
RetryStrategy::RetryAfterRateLimit => {
Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
}
}
}
fn must_tokenize(&self) -> bool {
matches!(self.strategy, RetryStrategy::RetryTokenized)
}
fn into_error(self) -> EmbedError {
self.error
}
}
// openai api structs
#[derive(Debug, Serialize)]
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
model: &'a str,
input: &'a [S],
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
}
#[derive(Debug, Serialize)]
struct OpenAiTokensRequest<'a> {
model: &'a str,
input: &'a [usize],
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
}
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
data: Vec<OpenAiEmbedding>,
}
#[derive(Debug, Deserialize)]
struct OpenAiErrorResponse {
error: OpenAiError,
}
#[derive(Debug, Deserialize)]
pub struct OpenAiError {
message: String,
// type: String,
code: Option<String>,
}
impl Display for OpenAiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.code {
Some(code) => write!(f, "{} ({})", self.message, code),
None => write!(f, "{}", self.message),
}
}
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
embedding: Embedding,
// object: String,
// index: usize,
}
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
}

373
milli/src/vector/rest.rs Normal file
View file

@ -0,0 +1,373 @@
use deserr::Deserr;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use serde::{Deserialize, Serialize};
use super::{
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
};
// retrying in case of failure
pub struct Retry {
pub error: EmbedError,
strategy: RetryStrategy,
}
pub enum RetryStrategy {
GiveUp,
Retry,
RetryTokenized,
RetryAfterRateLimit,
}
impl Retry {
pub fn give_up(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
pub fn retry_later(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
pub fn retry_tokenized(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryTokenized }
}
pub fn rate_limited(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
RetryStrategy::RetryAfterRateLimit => {
Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
}
}
}
pub fn must_tokenize(&self) -> bool {
matches!(self.strategy, RetryStrategy::RetryTokenized)
}
pub fn into_error(self) -> EmbedError {
self.error
}
}
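Concretely, the schedule above gives the following waits (the one-minute cap is applied at the call site in `embed`):

// RetryStrategy::Retry, attempts 0, 1, 2, ...: 1 ms, 10 ms, 100 ms, 1 s, 10 s,
// then 100 s and beyond, which the caller caps to 60 s.
// RetryStrategy::RetryAfterRateLimit adds a 100 ms floor: 101 ms, 110 ms,
// 200 ms, 1.1 s, ...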
#[derive(Debug)]
pub struct Embedder {
client: ureq::Agent,
options: EmbedderOptions,
bearer: Option<String>,
dimensions: usize,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct EmbedderOptions {
pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
pub dimensions: Option<usize>,
pub url: String,
pub query: serde_json::Value,
pub input_field: Vec<String>,
// path to the array of embeddings
pub path_to_embeddings: Vec<String>,
// path to a single embedding, within each element of the array above
// (or within the whole response when the input type is `Text`)
pub embedding_object: Vec<String>,
pub input_type: InputType,
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self {
url: Default::default(),
query: Default::default(),
input_field: vec!["input".into()],
path_to_embeddings: vec!["data".into()],
embedding_object: vec!["embedding".into()],
input_type: InputType::Text,
api_key: None,
distribution: None,
dimensions: None,
}
}
}
impl std::hash::Hash for EmbedderOptions {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.api_key.hash(state);
self.distribution.hash(state);
self.dimensions.hash(state);
self.url.hash(state);
// skip hashing the query:
// collisions in regular usage should be minimal,
// and the number of embedders is limited to 256 anyway
self.input_field.hash(state);
self.path_to_embeddings.hash(state);
self.embedding_object.hash(state);
self.input_type.hash(state);
}
}
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
#[serde(rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum InputType {
Text,
TextArray,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
let client = ureq::AgentBuilder::new()
.max_idle_connections(REQUEST_PARALLELISM * 2)
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.build();
let dimensions = if let Some(dimensions) = options.dimensions {
dimensions
} else {
infer_dimensions(&client, &options, bearer.as_deref())?
};
Ok(Self { client, dimensions, options, bearer })
}
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len())
}
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
where
S: AsRef<str> + Serialize,
{
embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len())
}
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?;
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
Ok(embeddings.pop().unwrap())
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
}
pub fn chunk_count_hint(&self) -> usize {
super::REQUEST_PARALLELISM
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self.options.input_type {
InputType::Text => 1,
InputType::TextArray => 10,
}
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.distribution
}
}
fn infer_dimensions(
client: &ureq::Agent,
options: &EmbedderOptions,
bearer: Option<&str>,
) -> Result<usize, NewEmbedderError> {
let v = embed(client, options, bearer, ["test"].as_slice(), 1)
.map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
Ok(v.first().unwrap().dimension())
}
fn embed<S>(
client: &ureq::Agent,
options: &EmbedderOptions,
bearer: Option<&str>,
inputs: &[S],
expected_count: usize,
) -> Result<Vec<Embeddings<f32>>, EmbedError>
where
S: Serialize,
{
let request = client.post(&options.url);
let request =
if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
let request = request.set("Content-Type", "application/json");
let input_value = match options.input_type {
InputType::Text => serde_json::json!(inputs.first()),
InputType::TextArray => serde_json::json!(inputs),
};
let body = match options.input_field.as_slice() {
[] => {
// no input field: the input value becomes the whole body
input_value
}
[input] => {
let mut body = options.query.clone();
body.as_object_mut()
.ok_or_else(|| {
EmbedError::rest_not_an_object(
options.query.clone(),
options.input_field.clone(),
)
})?
.insert(input.clone(), input_value);
body
}
[path @ .., input] => {
let mut body = options.query.clone();
let mut current_value = &mut body;
for component in path {
current_value = current_value
.as_object_mut()
.ok_or_else(|| {
EmbedError::rest_not_an_object(
options.query.clone(),
options.input_field.clone(),
)
})?
.entry(component.clone())
.or_insert(serde_json::json!({}));
}
current_value.as_object_mut().unwrap().insert(input.clone(), input_value);
body
}
};
for attempt in 0..7 {
let response = request.clone().send_json(&body);
let result = check_response(response);
let retry_duration = match result {
Ok(response) => return response_to_embedding(response, options, expected_count),
Err(retry) => {
tracing::warn!("Failed: {}", retry.error);
retry.into_duration(attempt)
}
}?;
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
std::thread::sleep(retry_duration);
}
let response = request.send_json(&body);
let result = check_response(response);
result
.map_err(Retry::into_error)
.and_then(|response| response_to_embedding(response, options, expected_count))
}
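The body construction above can be illustrated as follows (the query and texts are made up):

// query = {"model": "all-minilm"}, input_field = ["input"], input_type = TextArray:
//     body = {"model": "all-minilm", "input": ["first text", "second text"]}
// With a nested input_field such as ["request", "input"], the intermediate
// objects are created on demand inside the query:
//     body = {"model": "all-minilm", "request": {"input": ["first text", "second text"]}}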
fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
match response {
Ok(response) => Ok(response),
Err(ureq::Error::Status(code, response)) => {
let error_response: Option<String> = response.into_string().ok();
Err(match code {
401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
500..=599 => {
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
}
402..=499 => {
Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
}
_ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
})
}
Err(ureq::Error::Transport(transport)) => {
Err(Retry::retry_later(EmbedError::rest_network(transport)))
}
}
}
fn response_to_embedding(
response: ureq::Response,
options: &EmbedderOptions,
expected_count: usize,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let response: serde_json::Value =
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
let mut current_value = &response;
for component in &options.path_to_embeddings {
let component = component.as_ref();
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.path_to_embeddings,
)
})?;
}
let embeddings = match options.input_type {
InputType::Text => {
for component in &options.embedding_object {
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.embedding_object,
)
})?;
}
let embeddings = current_value.to_owned();
let embeddings: Embedding =
serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
vec![Embeddings::from_single_embedding(embeddings)]
}
InputType::TextArray => {
let empty = vec![];
let values = current_value.as_array().unwrap_or(&empty);
let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count);
for value in values {
let mut current_value = value;
for component in &options.embedding_object {
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.embedding_object,
)
})?;
}
let embedding = current_value.to_owned();
let embedding: Embedding =
serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?;
embeddings.push(Embeddings::from_single_embedding(embedding));
}
embeddings
}
};
if embeddings.len() != expected_count {
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
}
Ok(embeddings)
}
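For an OpenAI-shaped response, the extraction above proceeds as follows (a sketch; the paths match the defaults of `EmbedderOptions`):

// path_to_embeddings = ["data"], embedding_object = ["embedding"], input_type = TextArray:
//     {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}
// => two `Embeddings` values of dimension 2; if the count differs from
//    expected_count, rest_response_embedding_count is returned instead.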

View file

@ -1,7 +1,8 @@
use deserr::Deserr;
use serde::{Deserialize, Serialize};
use super::openai;
use super::rest::InputType;
use super::{ollama, openai};
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
@ -29,6 +30,24 @@ pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub document_template: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub url: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub query: Setting<serde_json::Value>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub input_field: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub path_to_embeddings: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub embedding_object: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub input_type: Setting<InputType>,
}
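As a sketch of how these settings surface in the API, a `rest` embedder could be declared with a payload along these lines (only the field names come from this file; every value is illustrative):

// {
//   "source": "rest",
//   "url": "http://localhost:8080/embed",
//   "apiKey": "<bearer token, optional>",
//   "dimensions": 512,
//   "query": {"model": "my-model"},
//   "inputField": ["input"],
//   "pathToEmbeddings": ["data"],
//   "embeddingObject": ["embedding"],
//   "inputType": "textArray"
// }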
pub fn check_unset<T>(
@ -75,16 +94,42 @@ impl EmbeddingSettings {
pub const DIMENSIONS: &'static str = "dimensions";
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
pub const URL: &'static str = "url";
pub const QUERY: &'static str = "query";
pub const INPUT_FIELD: &'static str = "inputField";
pub const PATH_TO_EMBEDDINGS: &'static str = "pathToEmbeddings";
pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
pub const INPUT_TYPE: &'static str = "inputType";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field {
Self::SOURCE => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
Self::SOURCE => &[
EmbedderSource::HuggingFace,
EmbedderSource::OpenAi,
EmbedderSource::UserProvided,
EmbedderSource::Rest,
EmbedderSource::Ollama,
],
Self::MODEL => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
}
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
Self::REVISION => &[EmbedderSource::HuggingFace],
Self::API_KEY => &[EmbedderSource::OpenAi],
Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided],
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
Self::API_KEY => &[EmbedderSource::OpenAi, EmbedderSource::Rest],
Self::DIMENSIONS => {
&[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Rest]
}
Self::DOCUMENT_TEMPLATE => &[
EmbedderSource::HuggingFace,
EmbedderSource::OpenAi,
EmbedderSource::Ollama,
EmbedderSource::Rest,
],
Self::URL => &[EmbedderSource::Rest],
Self::QUERY => &[EmbedderSource::Rest],
Self::INPUT_FIELD => &[EmbedderSource::Rest],
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
Self::INPUT_TYPE => &[EmbedderSource::Rest],
_other => unreachable!("unknown field"),
}
}
@ -101,7 +146,20 @@ impl EmbeddingSettings {
EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE],
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
EmbedderSource::Rest => &[
Self::SOURCE,
Self::API_KEY,
Self::DIMENSIONS,
Self::DOCUMENT_TEMPLATE,
Self::URL,
Self::QUERY,
Self::INPUT_FIELD,
Self::PATH_TO_EMBEDDINGS,
Self::EMBEDDING_OBJECT,
Self::INPUT_TYPE,
],
}
}
@ -134,7 +192,9 @@ pub enum EmbedderSource {
#[default]
OpenAi,
HuggingFace,
Ollama,
UserProvided,
Rest,
}
impl std::fmt::Display for EmbedderSource {
@ -143,6 +203,8 @@ impl std::fmt::Display for EmbedderSource {
EmbedderSource::OpenAi => "openAi",
EmbedderSource::HuggingFace => "huggingFace",
EmbedderSource::UserProvided => "userProvided",
EmbedderSource::Ollama => "ollama",
EmbedderSource::Rest => "rest",
};
f.write_str(s)
}
@ -150,8 +212,20 @@ impl std::fmt::Display for EmbedderSource {
impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) {
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
new;
let EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = new;
let old_source = self.source;
self.source.apply(source);
// Reinitialize the whole setting object on a source change
@ -163,6 +237,12 @@ impl EmbeddingSettings {
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
};
return;
}
@ -172,6 +252,13 @@ impl EmbeddingSettings {
self.api_key.apply(api_key);
self.dimensions.apply(dimensions);
self.document_template.apply(document_template);
self.url.apply(url);
self.query.apply(query);
self.input_field.apply(input_field);
self.path_to_embeddings.apply(path_to_embeddings);
self.embedding_object.apply(embedding_object);
self.input_type.apply(input_type);
}
}
@ -186,6 +273,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::NotSet,
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
},
super::EmbedderOptions::OpenAi(options) => Self {
source: Setting::Set(EmbedderSource::OpenAi),
@ -194,6 +287,26 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
document_template: Setting::Set(prompt.template),
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
},
super::EmbedderOptions::Ollama(options) => Self {
source: Setting::Set(EmbedderSource::Ollama),
model: Setting::Set(options.embedding_model.to_owned()),
revision: Setting::NotSet,
api_key: Setting::NotSet,
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
},
super::EmbedderOptions::UserProvided(options) => Self {
source: Setting::Set(EmbedderSource::UserProvided),
@ -202,6 +315,37 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::NotSet,
dimensions: Setting::Set(options.dimensions),
document_template: Setting::NotSet,
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
},
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key,
// TODO: support distribution
distribution: _,
dimensions,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}) => Self {
source: Setting::Set(EmbedderSource::Rest),
model: Setting::NotSet,
revision: Setting::NotSet,
api_key: api_key.map(Setting::Set).unwrap_or_default(),
dimensions: dimensions.map(Setting::Set).unwrap_or_default(),
document_template: Setting::Set(prompt.template),
url: Setting::Set(url),
query: Setting::Set(query),
input_field: Setting::Set(input_field),
path_to_embeddings: Setting::Set(path_to_embeddings),
embedding_object: Setting::Set(embedding_object),
input_type: Setting::Set(input_type),
},
}
}
@ -210,8 +354,20 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
impl From<EmbeddingSettings> for EmbeddingConfig {
fn from(value: EmbeddingSettings) -> Self {
let mut this = Self::default();
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
value;
let EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = value;
if let Some(source) = source.set() {
match source {
EmbedderSource::OpenAi => {
@ -229,6 +385,14 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
}
this.embedder_options = super::EmbedderOptions::OpenAi(options);
}
EmbedderSource::Ollama => {
let mut options: ollama::EmbedderOptions =
super::ollama::EmbedderOptions::with_default_model();
if let Some(model) = model.set() {
options.embedding_model = model;
}
this.embedder_options = super::EmbedderOptions::Ollama(options);
}
EmbedderSource::HuggingFace => {
let mut options = super::hf::EmbedderOptions::default();
if let Some(model) = model.set() {
@ -251,6 +415,26 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
dimensions: dimensions.set().unwrap(),
});
}
EmbedderSource::Rest => {
let embedder_options = super::rest::EmbedderOptions::default();
this.embedder_options =
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key: api_key.set(),
distribution: None,
dimensions: dimensions.set(),
url: url.set().unwrap(),
query: query.set().unwrap_or(embedder_options.query),
input_field: input_field.set().unwrap_or(embedder_options.input_field),
path_to_embeddings: path_to_embeddings
.set()
.unwrap_or(embedder_options.path_to_embeddings),
embedding_object: embedding_object
.set()
.unwrap_or(embedder_options.embedding_object),
input_type: input_type.set().unwrap_or(embedder_options.input_type),
})
}
}
}