mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-03 11:57:07 +02:00
Revert "Revert "Merge remote-tracking branch 'origin/main' into release-v1.7.1""
This commit is contained in:
parent
c495c8eb33
commit
c5322df519
34 changed files with 1784 additions and 610 deletions
|
@ -20,13 +20,13 @@ use crate::heed_codec::facet::{
|
|||
use crate::heed_codec::{
|
||||
BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec,
|
||||
};
|
||||
use crate::order_by_map::OrderByMap;
|
||||
use crate::proximity::ProximityPrecision;
|
||||
use crate::vector::EmbeddingConfig;
|
||||
use crate::{
|
||||
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
|
||||
FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec,
|
||||
OrderBy, Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16,
|
||||
BEU32, BEU64,
|
||||
Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32, BEU64,
|
||||
};
|
||||
|
||||
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
|
||||
|
@ -1373,21 +1373,19 @@ impl Index {
|
|||
self.main.remap_key_type::<Str>().delete(txn, main_key::MAX_VALUES_PER_FACET)
|
||||
}
|
||||
|
||||
pub fn sort_facet_values_by(&self, txn: &RoTxn) -> heed::Result<HashMap<String, OrderBy>> {
|
||||
let mut orders = self
|
||||
pub fn sort_facet_values_by(&self, txn: &RoTxn) -> heed::Result<OrderByMap> {
|
||||
let orders = self
|
||||
.main
|
||||
.remap_types::<Str, SerdeJson<HashMap<String, OrderBy>>>()
|
||||
.remap_types::<Str, SerdeJson<OrderByMap>>()
|
||||
.get(txn, main_key::SORT_FACET_VALUES_BY)?
|
||||
.unwrap_or_default();
|
||||
// Insert the default ordering if it is not already overwritten by the user.
|
||||
orders.entry("*".to_string()).or_insert(OrderBy::Lexicographic);
|
||||
Ok(orders)
|
||||
}
|
||||
|
||||
pub(crate) fn put_sort_facet_values_by(
|
||||
&self,
|
||||
txn: &mut RwTxn,
|
||||
val: &HashMap<String, OrderBy>,
|
||||
val: &OrderByMap,
|
||||
) -> heed::Result<()> {
|
||||
self.main.remap_types::<Str, SerdeJson<_>>().put(txn, main_key::SORT_FACET_VALUES_BY, &val)
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ pub mod facet;
|
|||
mod fields_ids_map;
|
||||
pub mod heed_codec;
|
||||
pub mod index;
|
||||
pub mod order_by_map;
|
||||
pub mod prompt;
|
||||
pub mod proximity;
|
||||
pub mod score_details;
|
||||
|
@ -56,10 +57,10 @@ pub use self::heed_codec::{
|
|||
UncheckedU8StrStrCodec,
|
||||
};
|
||||
pub use self::index::Index;
|
||||
pub use self::search::facet::{FacetValueHit, SearchForFacetValues};
|
||||
pub use self::search::{
|
||||
FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder,
|
||||
MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy,
|
||||
DEFAULT_VALUES_PER_FACET,
|
||||
FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy,
|
||||
Search, SearchResult, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
|
||||
};
|
||||
|
||||
pub type Result<T> = std::result::Result<T, error::Error>;
|
||||
|
|
57
milli/src/order_by_map.rs
Normal file
57
milli/src/order_by_map.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
use std::collections::{hash_map, HashMap};
|
||||
use std::iter::FromIterator;
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::OrderBy;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct OrderByMap(HashMap<String, OrderBy>);
|
||||
|
||||
impl OrderByMap {
|
||||
pub fn get(&self, key: impl AsRef<str>) -> OrderBy {
|
||||
self.0
|
||||
.get(key.as_ref())
|
||||
.copied()
|
||||
.unwrap_or_else(|| self.0.get("*").copied().unwrap_or_default())
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: String, value: OrderBy) -> Option<OrderBy> {
|
||||
self.0.insert(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OrderByMap {
|
||||
fn default() -> Self {
|
||||
let mut map = HashMap::new();
|
||||
map.insert("*".to_string(), OrderBy::Lexicographic);
|
||||
OrderByMap(map)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<(String, OrderBy)> for OrderByMap {
|
||||
fn from_iter<T: IntoIterator<Item = (String, OrderBy)>>(iter: T) -> Self {
|
||||
OrderByMap(iter.into_iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for OrderByMap {
|
||||
type Item = (String, OrderBy);
|
||||
type IntoIter = hash_map::IntoIter<String, OrderBy>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.0.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for OrderByMap {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let mut map = Deserialize::deserialize(deserializer).map(OrderByMap)?;
|
||||
// Insert the default ordering if it is not already overwritten by the user.
|
||||
map.0.entry("*".to_string()).or_insert(OrderBy::default());
|
||||
Ok(map)
|
||||
}
|
||||
}
|
|
@ -168,7 +168,7 @@ impl<'t, 'b, 'bitmap> FacetRangeSearch<'t, 'b, 'bitmap> {
|
|||
}
|
||||
|
||||
// should we stop?
|
||||
// We should if the the search range doesn't include any
|
||||
// We should if the search range doesn't include any
|
||||
// element from the previous key or its successors
|
||||
let should_stop = {
|
||||
match self.right {
|
||||
|
@ -232,7 +232,7 @@ impl<'t, 'b, 'bitmap> FacetRangeSearch<'t, 'b, 'bitmap> {
|
|||
}
|
||||
|
||||
// should we stop?
|
||||
// We should if the the search range doesn't include any
|
||||
// We should if the search range doesn't include any
|
||||
// element from the previous key or its successors
|
||||
let should_stop = {
|
||||
match self.right {
|
||||
|
|
|
@ -6,15 +6,18 @@ use roaring::RoaringBitmap;
|
|||
|
||||
pub use self::facet_distribution::{FacetDistribution, OrderBy, DEFAULT_VALUES_PER_FACET};
|
||||
pub use self::filter::{BadGeoError, Filter};
|
||||
pub use self::search::{FacetValueHit, SearchForFacetValues};
|
||||
use crate::heed_codec::facet::{FacetGroupKeyCodec, FacetGroupValueCodec, OrderedF64Codec};
|
||||
use crate::heed_codec::BytesRefCodec;
|
||||
use crate::{Index, Result};
|
||||
|
||||
mod facet_distribution;
|
||||
mod facet_distribution_iter;
|
||||
mod facet_range_search;
|
||||
mod facet_sort_ascending;
|
||||
mod facet_sort_descending;
|
||||
mod filter;
|
||||
mod search;
|
||||
|
||||
fn facet_extreme_value<'t>(
|
||||
mut extreme_it: impl Iterator<Item = heed::Result<(RoaringBitmap, &'t [u8])>> + 't,
|
||||
|
|
326
milli/src/search/facet/search.rs
Normal file
326
milli/src/search/facet/search.rs
Normal file
|
@ -0,0 +1,326 @@
|
|||
use std::cmp::{Ordering, Reverse};
|
||||
use std::collections::BinaryHeap;
|
||||
use std::ops::ControlFlow;
|
||||
|
||||
use charabia::normalizer::NormalizerOption;
|
||||
use charabia::Normalize;
|
||||
use fst::automaton::{Automaton, Str};
|
||||
use fst::{IntoStreamer, Streamer};
|
||||
use roaring::RoaringBitmap;
|
||||
use tracing::error;
|
||||
|
||||
use crate::error::UserError;
|
||||
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
|
||||
use crate::search::build_dfa;
|
||||
use crate::{DocumentId, FieldId, OrderBy, Result, Search};
|
||||
|
||||
/// The maximum number of values per facet returned by the facet search route.
|
||||
const DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET: usize = 100;
|
||||
|
||||
pub struct SearchForFacetValues<'a> {
|
||||
query: Option<String>,
|
||||
facet: String,
|
||||
search_query: Search<'a>,
|
||||
max_values: usize,
|
||||
is_hybrid: bool,
|
||||
}
|
||||
|
||||
impl<'a> SearchForFacetValues<'a> {
|
||||
pub fn new(
|
||||
facet: String,
|
||||
search_query: Search<'a>,
|
||||
is_hybrid: bool,
|
||||
) -> SearchForFacetValues<'a> {
|
||||
SearchForFacetValues {
|
||||
query: None,
|
||||
facet,
|
||||
search_query,
|
||||
max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET,
|
||||
is_hybrid,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
|
||||
self.query = Some(query.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_values(&mut self, max: usize) -> &mut Self {
|
||||
self.max_values = max;
|
||||
self
|
||||
}
|
||||
|
||||
fn one_original_value_of(
|
||||
&self,
|
||||
field_id: FieldId,
|
||||
facet_str: &str,
|
||||
any_docid: DocumentId,
|
||||
) -> Result<Option<String>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
let key: (FieldId, _, &str) = (field_id, any_docid, facet_str);
|
||||
Ok(index.field_id_docid_facet_strings.get(rtxn, &key)?.map(|v| v.to_owned()))
|
||||
}
|
||||
|
||||
pub fn execute(&self) -> Result<Vec<FacetValueHit>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
|
||||
let filterable_fields = index.filterable_fields(rtxn)?;
|
||||
if !filterable_fields.contains(&self.facet) {
|
||||
let (valid_fields, hidden_fields) =
|
||||
index.remove_hidden_fields(rtxn, filterable_fields)?;
|
||||
|
||||
return Err(UserError::InvalidFacetSearchFacetName {
|
||||
field: self.facet.clone(),
|
||||
valid_fields,
|
||||
hidden_fields,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let fields_ids_map = index.fields_ids_map(rtxn)?;
|
||||
let fid = match fields_ids_map.id(&self.facet) {
|
||||
Some(fid) => fid,
|
||||
// we return an empty list of results when the attribute has been
|
||||
// set as filterable but no document contains this field (yet).
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let fst = match self.search_query.index.facet_id_string_fst.get(rtxn, &fid)? {
|
||||
Some(fst) => fst,
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let search_candidates = self
|
||||
.search_query
|
||||
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
|
||||
|
||||
let mut results = match index.sort_facet_values_by(rtxn)?.get(&self.facet) {
|
||||
OrderBy::Lexicographic => ValuesCollection::by_lexicographic(self.max_values),
|
||||
OrderBy::Count => ValuesCollection::by_count(self.max_values),
|
||||
};
|
||||
|
||||
match self.query.as_ref() {
|
||||
Some(query) => {
|
||||
let options = NormalizerOption { lossy: true, ..Default::default() };
|
||||
let query = query.normalize(&options);
|
||||
let query = query.as_ref();
|
||||
|
||||
let authorize_typos = self.search_query.index.authorize_typos(rtxn)?;
|
||||
let field_authorizes_typos =
|
||||
!self.search_query.index.exact_attributes_ids(rtxn)?.contains(&fid);
|
||||
|
||||
if authorize_typos && field_authorizes_typos {
|
||||
let exact_words_fst = self.search_query.index.exact_words(rtxn)?;
|
||||
if exact_words_fst.map_or(false, |fst| fst.contains(query)) {
|
||||
if fst.contains(query) {
|
||||
self.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
query,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
let one_typo = self.search_query.index.min_word_len_one_typo(rtxn)?;
|
||||
let two_typos = self.search_query.index.min_word_len_two_typos(rtxn)?;
|
||||
|
||||
let is_prefix = true;
|
||||
let automaton = if query.len() < one_typo as usize {
|
||||
build_dfa(query, 0, is_prefix)
|
||||
} else if query.len() < two_typos as usize {
|
||||
build_dfa(query, 1, is_prefix)
|
||||
} else {
|
||||
build_dfa(query, 2, is_prefix)
|
||||
};
|
||||
|
||||
let mut stream = fst.search(automaton).into_stream();
|
||||
while let Some(facet_value) = stream.next() {
|
||||
let value = std::str::from_utf8(facet_value)?;
|
||||
if self
|
||||
.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
value,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?
|
||||
.is_break()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let automaton = Str::new(query).starts_with();
|
||||
let mut stream = fst.search(automaton).into_stream();
|
||||
while let Some(facet_value) = stream.next() {
|
||||
let value = std::str::from_utf8(facet_value)?;
|
||||
if self
|
||||
.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
value,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?
|
||||
.is_break()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let prefix = FacetGroupKey { field_id: fid, level: 0, left_bound: "" };
|
||||
for result in index.facet_id_string_docids.prefix_iter(rtxn, &prefix)? {
|
||||
let (FacetGroupKey { left_bound, .. }, FacetGroupValue { bitmap, .. }) =
|
||||
result?;
|
||||
let count = search_candidates.intersection_len(&bitmap);
|
||||
if count != 0 {
|
||||
let value = self
|
||||
.one_original_value_of(fid, left_bound, bitmap.min().unwrap())?
|
||||
.unwrap_or_else(|| left_bound.to_string());
|
||||
if results.insert(FacetValueHit { value, count }).is_break() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results.into_sorted_vec())
|
||||
}
|
||||
|
||||
fn fetch_original_facets_using_normalized(
|
||||
&self,
|
||||
fid: FieldId,
|
||||
value: &str,
|
||||
query: &str,
|
||||
search_candidates: &RoaringBitmap,
|
||||
results: &mut ValuesCollection,
|
||||
) -> Result<ControlFlow<()>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
|
||||
let database = index.facet_id_normalized_string_strings;
|
||||
let key = (fid, value);
|
||||
let original_strings = match database.get(rtxn, &key)? {
|
||||
Some(original_strings) => original_strings,
|
||||
None => {
|
||||
error!("the facet value is missing from the facet database: {key:?}");
|
||||
return Ok(ControlFlow::Continue(()));
|
||||
}
|
||||
};
|
||||
for original in original_strings {
|
||||
let key = FacetGroupKey { field_id: fid, level: 0, left_bound: original.as_str() };
|
||||
let docids = match index.facet_id_string_docids.get(rtxn, &key)? {
|
||||
Some(FacetGroupValue { bitmap, .. }) => bitmap,
|
||||
None => {
|
||||
error!("the facet value is missing from the facet database: {key:?}");
|
||||
return Ok(ControlFlow::Continue(()));
|
||||
}
|
||||
};
|
||||
let count = search_candidates.intersection_len(&docids);
|
||||
if count != 0 {
|
||||
let value = self
|
||||
.one_original_value_of(fid, &original, docids.min().unwrap())?
|
||||
.unwrap_or_else(|| query.to_string());
|
||||
if results.insert(FacetValueHit { value, count }).is_break() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ControlFlow::Continue(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, PartialEq)]
|
||||
pub struct FacetValueHit {
|
||||
/// The original facet value
|
||||
pub value: String,
|
||||
/// The number of documents associated to this facet
|
||||
pub count: u64,
|
||||
}
|
||||
|
||||
impl PartialOrd for FacetValueHit {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for FacetValueHit {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.count.cmp(&other.count).then_with(|| self.value.cmp(&other.value))
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for FacetValueHit {}
|
||||
|
||||
/// A wrapper type that collects the best facet values by
|
||||
/// lexicographic or number of associated values.
|
||||
enum ValuesCollection {
|
||||
/// Keeps the top values according to the lexicographic order.
|
||||
Lexicographic { max: usize, content: Vec<FacetValueHit> },
|
||||
/// Keeps the top values according to the number of values associated to them.
|
||||
///
|
||||
/// Note that it is a max heap and we need to move the smallest counts
|
||||
/// at the top to be able to pop them when we reach the max_values limit.
|
||||
Count { max: usize, content: BinaryHeap<Reverse<FacetValueHit>> },
|
||||
}
|
||||
|
||||
impl ValuesCollection {
|
||||
pub fn by_lexicographic(max: usize) -> Self {
|
||||
ValuesCollection::Lexicographic { max, content: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn by_count(max: usize) -> Self {
|
||||
ValuesCollection::Count { max, content: BinaryHeap::new() }
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, value: FacetValueHit) -> ControlFlow<()> {
|
||||
match self {
|
||||
ValuesCollection::Lexicographic { max, content } => {
|
||||
if content.len() < *max {
|
||||
content.push(value);
|
||||
if content.len() < *max {
|
||||
return ControlFlow::Continue(());
|
||||
}
|
||||
}
|
||||
ControlFlow::Break(())
|
||||
}
|
||||
ValuesCollection::Count { max, content } => {
|
||||
if content.len() == *max {
|
||||
// Peeking gives us the worst value in the list as
|
||||
// this is a max-heap and we reversed it.
|
||||
let Some(mut peek) = content.peek_mut() else { return ControlFlow::Break(()) };
|
||||
if peek.0.count <= value.count {
|
||||
// Replace the current worst value in the heap
|
||||
// with the new one we received that is better.
|
||||
*peek = Reverse(value);
|
||||
}
|
||||
} else {
|
||||
content.push(Reverse(value));
|
||||
}
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the list of facet values in descending order of, either,
|
||||
/// count or lexicographic order of the value depending on the type.
|
||||
pub fn into_sorted_vec(self) -> Vec<FacetValueHit> {
|
||||
match self {
|
||||
ValuesCollection::Lexicographic { content, .. } => content.into_iter().collect(),
|
||||
ValuesCollection::Count { content, .. } => {
|
||||
// Convert the heap into a vec of hits by removing the Reverse wrapper.
|
||||
// Hits are already in the right order as they were reversed and there
|
||||
// are output in ascending order.
|
||||
content.into_sorted_vec().into_iter().map(|Reverse(hit)| hit).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,25 +1,17 @@
|
|||
use std::fmt;
|
||||
use std::ops::ControlFlow;
|
||||
|
||||
use charabia::normalizer::NormalizerOption;
|
||||
use charabia::Normalize;
|
||||
use fst::automaton::{Automaton, Str};
|
||||
use fst::{IntoStreamer, Streamer};
|
||||
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
|
||||
use once_cell::sync::Lazy;
|
||||
use roaring::bitmap::RoaringBitmap;
|
||||
use tracing::error;
|
||||
|
||||
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
|
||||
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
|
||||
use self::new::{execute_vector_search, PartialSearchResult};
|
||||
use crate::error::UserError;
|
||||
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
|
||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use crate::vector::DistributionShift;
|
||||
use crate::{
|
||||
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
|
||||
Result, SearchContext,
|
||||
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Index, Result,
|
||||
SearchContext,
|
||||
};
|
||||
|
||||
// Building these factories is not free.
|
||||
|
@ -27,9 +19,6 @@ static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
|
|||
static LEVDIST1: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(1, true));
|
||||
static LEVDIST2: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(2, true));
|
||||
|
||||
/// The maximum number of values per facet returned by the facet search route.
|
||||
const DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET: usize = 100;
|
||||
|
||||
pub mod facet;
|
||||
mod fst_utils;
|
||||
pub mod hybrid;
|
||||
|
@ -302,240 +291,6 @@ pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct SearchForFacetValues<'a> {
|
||||
query: Option<String>,
|
||||
facet: String,
|
||||
search_query: Search<'a>,
|
||||
max_values: usize,
|
||||
is_hybrid: bool,
|
||||
}
|
||||
|
||||
impl<'a> SearchForFacetValues<'a> {
|
||||
pub fn new(
|
||||
facet: String,
|
||||
search_query: Search<'a>,
|
||||
is_hybrid: bool,
|
||||
) -> SearchForFacetValues<'a> {
|
||||
SearchForFacetValues {
|
||||
query: None,
|
||||
facet,
|
||||
search_query,
|
||||
max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET,
|
||||
is_hybrid,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
|
||||
self.query = Some(query.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_values(&mut self, max: usize) -> &mut Self {
|
||||
self.max_values = max;
|
||||
self
|
||||
}
|
||||
|
||||
fn one_original_value_of(
|
||||
&self,
|
||||
field_id: FieldId,
|
||||
facet_str: &str,
|
||||
any_docid: DocumentId,
|
||||
) -> Result<Option<String>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
let key: (FieldId, _, &str) = (field_id, any_docid, facet_str);
|
||||
Ok(index.field_id_docid_facet_strings.get(rtxn, &key)?.map(|v| v.to_owned()))
|
||||
}
|
||||
|
||||
pub fn execute(&self) -> Result<Vec<FacetValueHit>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
|
||||
let filterable_fields = index.filterable_fields(rtxn)?;
|
||||
if !filterable_fields.contains(&self.facet) {
|
||||
let (valid_fields, hidden_fields) =
|
||||
index.remove_hidden_fields(rtxn, filterable_fields)?;
|
||||
|
||||
return Err(UserError::InvalidFacetSearchFacetName {
|
||||
field: self.facet.clone(),
|
||||
valid_fields,
|
||||
hidden_fields,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let fields_ids_map = index.fields_ids_map(rtxn)?;
|
||||
let fid = match fields_ids_map.id(&self.facet) {
|
||||
Some(fid) => fid,
|
||||
// we return an empty list of results when the attribute has been
|
||||
// set as filterable but no document contains this field (yet).
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let fst = match self.search_query.index.facet_id_string_fst.get(rtxn, &fid)? {
|
||||
Some(fst) => fst,
|
||||
None => return Ok(vec![]),
|
||||
};
|
||||
|
||||
let search_candidates = self
|
||||
.search_query
|
||||
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
|
||||
|
||||
match self.query.as_ref() {
|
||||
Some(query) => {
|
||||
let options = NormalizerOption { lossy: true, ..Default::default() };
|
||||
let query = query.normalize(&options);
|
||||
let query = query.as_ref();
|
||||
|
||||
let authorize_typos = self.search_query.index.authorize_typos(rtxn)?;
|
||||
let field_authorizes_typos =
|
||||
!self.search_query.index.exact_attributes_ids(rtxn)?.contains(&fid);
|
||||
|
||||
if authorize_typos && field_authorizes_typos {
|
||||
let exact_words_fst = self.search_query.index.exact_words(rtxn)?;
|
||||
if exact_words_fst.map_or(false, |fst| fst.contains(query)) {
|
||||
let mut results = vec![];
|
||||
if fst.contains(query) {
|
||||
self.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
query,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?;
|
||||
}
|
||||
Ok(results)
|
||||
} else {
|
||||
let one_typo = self.search_query.index.min_word_len_one_typo(rtxn)?;
|
||||
let two_typos = self.search_query.index.min_word_len_two_typos(rtxn)?;
|
||||
|
||||
let is_prefix = true;
|
||||
let automaton = if query.len() < one_typo as usize {
|
||||
build_dfa(query, 0, is_prefix)
|
||||
} else if query.len() < two_typos as usize {
|
||||
build_dfa(query, 1, is_prefix)
|
||||
} else {
|
||||
build_dfa(query, 2, is_prefix)
|
||||
};
|
||||
|
||||
let mut stream = fst.search(automaton).into_stream();
|
||||
let mut results = vec![];
|
||||
while let Some(facet_value) = stream.next() {
|
||||
let value = std::str::from_utf8(facet_value)?;
|
||||
if self
|
||||
.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
value,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?
|
||||
.is_break()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
} else {
|
||||
let automaton = Str::new(query).starts_with();
|
||||
let mut stream = fst.search(automaton).into_stream();
|
||||
let mut results = vec![];
|
||||
while let Some(facet_value) = stream.next() {
|
||||
let value = std::str::from_utf8(facet_value)?;
|
||||
if self
|
||||
.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
value,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?
|
||||
.is_break()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut results = vec![];
|
||||
let prefix = FacetGroupKey { field_id: fid, level: 0, left_bound: "" };
|
||||
for result in index.facet_id_string_docids.prefix_iter(rtxn, &prefix)? {
|
||||
let (FacetGroupKey { left_bound, .. }, FacetGroupValue { bitmap, .. }) =
|
||||
result?;
|
||||
let count = search_candidates.intersection_len(&bitmap);
|
||||
if count != 0 {
|
||||
let value = self
|
||||
.one_original_value_of(fid, left_bound, bitmap.min().unwrap())?
|
||||
.unwrap_or_else(|| left_bound.to_string());
|
||||
results.push(FacetValueHit { value, count });
|
||||
}
|
||||
if results.len() >= self.max_values {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_original_facets_using_normalized(
|
||||
&self,
|
||||
fid: FieldId,
|
||||
value: &str,
|
||||
query: &str,
|
||||
search_candidates: &RoaringBitmap,
|
||||
results: &mut Vec<FacetValueHit>,
|
||||
) -> Result<ControlFlow<()>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
|
||||
let database = index.facet_id_normalized_string_strings;
|
||||
let key = (fid, value);
|
||||
let original_strings = match database.get(rtxn, &key)? {
|
||||
Some(original_strings) => original_strings,
|
||||
None => {
|
||||
error!("the facet value is missing from the facet database: {key:?}");
|
||||
return Ok(ControlFlow::Continue(()));
|
||||
}
|
||||
};
|
||||
for original in original_strings {
|
||||
let key = FacetGroupKey { field_id: fid, level: 0, left_bound: original.as_str() };
|
||||
let docids = match index.facet_id_string_docids.get(rtxn, &key)? {
|
||||
Some(FacetGroupValue { bitmap, .. }) => bitmap,
|
||||
None => {
|
||||
error!("the facet value is missing from the facet database: {key:?}");
|
||||
return Ok(ControlFlow::Continue(()));
|
||||
}
|
||||
};
|
||||
let count = search_candidates.intersection_len(&docids);
|
||||
if count != 0 {
|
||||
let value = self
|
||||
.one_original_value_of(fid, &original, docids.min().unwrap())?
|
||||
.unwrap_or_else(|| query.to_string());
|
||||
results.push(FacetValueHit { value, count });
|
||||
}
|
||||
if results.len() >= self.max_values {
|
||||
return Ok(ControlFlow::Break(()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ControlFlow::Continue(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, PartialEq)]
|
||||
pub struct FacetValueHit {
|
||||
/// The original facet value
|
||||
pub value: String,
|
||||
/// The number of documents associated to this facet
|
||||
pub count: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#[allow(unused_imports)]
|
||||
|
|
|
@ -5,7 +5,7 @@ The typo ranking rule should transform the query graph such that it only contain
|
|||
the combinations of word derivations that it used to compute its bucket.
|
||||
|
||||
The proximity ranking rule should then look for proximities only between those specific derivations.
|
||||
For example, given the the search query `beautiful summer` and the dataset:
|
||||
For example, given the search query `beautiful summer` and the dataset:
|
||||
```text
|
||||
{ "id": 0, "text": "beautigul summer...... beautiful day in the summer" }
|
||||
{ "id": 1, "text": "beautiful summer" }
|
||||
|
|
|
@ -14,12 +14,13 @@ use super::IndexerConfig;
|
|||
use crate::criterion::Criterion;
|
||||
use crate::error::UserError;
|
||||
use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS};
|
||||
use crate::order_by_map::OrderByMap;
|
||||
use crate::proximity::ProximityPrecision;
|
||||
use crate::update::index_documents::IndexDocumentsMethod;
|
||||
use crate::update::{IndexDocuments, UpdateIndexingStep};
|
||||
use crate::vector::settings::{check_set, check_unset, EmbedderSource, EmbeddingSettings};
|
||||
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
|
||||
use crate::{FieldsIdsMap, Index, OrderBy, Result};
|
||||
use crate::{FieldsIdsMap, Index, Result};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
pub enum Setting<T> {
|
||||
|
@ -145,7 +146,7 @@ pub struct Settings<'a, 't, 'i> {
|
|||
/// Attributes on which typo tolerance is disabled.
|
||||
exact_attributes: Setting<HashSet<String>>,
|
||||
max_values_per_facet: Setting<usize>,
|
||||
sort_facet_values_by: Setting<HashMap<String, OrderBy>>,
|
||||
sort_facet_values_by: Setting<OrderByMap>,
|
||||
pagination_max_total_hits: Setting<usize>,
|
||||
proximity_precision: Setting<ProximityPrecision>,
|
||||
embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>,
|
||||
|
@ -340,7 +341,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||
self.max_values_per_facet = Setting::Reset;
|
||||
}
|
||||
|
||||
pub fn set_sort_facet_values_by(&mut self, value: HashMap<String, OrderBy>) {
|
||||
pub fn set_sort_facet_values_by(&mut self, value: OrderByMap) {
|
||||
self.sort_facet_values_by = Setting::Set(value);
|
||||
}
|
||||
|
||||
|
@ -1186,6 +1187,13 @@ pub fn validate_embedding_settings(
|
|||
}
|
||||
}
|
||||
}
|
||||
EmbedderSource::Ollama => {
|
||||
// Dimensions get inferred, only model name is required
|
||||
check_unset(&dimensions, "dimensions", inferred_source, name)?;
|
||||
check_set(&model, "model", inferred_source, name)?;
|
||||
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
||||
check_unset(&revision, "revision", inferred_source, name)?;
|
||||
}
|
||||
EmbedderSource::HuggingFace => {
|
||||
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
||||
check_unset(&dimensions, "dimensions", inferred_source, name)?;
|
||||
|
|
|
@ -2,6 +2,7 @@ use std::path::PathBuf;
|
|||
|
||||
use hf_hub::api::sync::ApiError;
|
||||
|
||||
use super::ollama::OllamaError;
|
||||
use crate::error::FaultSource;
|
||||
use crate::vector::openai::OpenAiError;
|
||||
|
||||
|
@ -71,6 +72,17 @@ pub enum EmbedErrorKind {
|
|||
OpenAiRuntimeInit(std::io::Error),
|
||||
#[error("initializing web client for sending embedding requests failed: {0}")]
|
||||
InitWebClient(reqwest::Error),
|
||||
// Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends.
|
||||
#[error("unexpected response from Ollama: {0}")]
|
||||
OllamaUnexpected(reqwest::Error),
|
||||
#[error("sent too many requests to Ollama: {0}")]
|
||||
OllamaTooManyRequests(OllamaError),
|
||||
#[error("received internal error from Ollama: {0}")]
|
||||
OllamaInternalServerError(OllamaError),
|
||||
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")]
|
||||
OllamaModelNotFoundError(OllamaError),
|
||||
#[error("received unhandled HTTP status code {0} from Ollama")]
|
||||
OllamaUnhandledStatusCode(u16),
|
||||
}
|
||||
|
||||
impl EmbedError {
|
||||
|
@ -129,6 +141,26 @@ impl EmbedError {
|
|||
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
@ -195,6 +227,13 @@ impl NewEmbedderError {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
|
||||
fault: FaultSource::User,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
|
||||
}
|
||||
|
|
|
@ -10,6 +10,8 @@ pub mod manual;
|
|||
pub mod openai;
|
||||
pub mod settings;
|
||||
|
||||
pub mod ollama;
|
||||
|
||||
pub use self::error::Error;
|
||||
|
||||
pub type Embedding = Vec<f32>;
|
||||
|
@ -76,6 +78,7 @@ pub enum Embedder {
|
|||
HuggingFace(hf::Embedder),
|
||||
OpenAi(openai::Embedder),
|
||||
UserProvided(manual::Embedder),
|
||||
Ollama(ollama::Embedder),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
||||
|
@ -127,6 +130,7 @@ impl IntoIterator for EmbeddingConfigs {
|
|||
pub enum EmbedderOptions {
|
||||
HuggingFace(hf::EmbedderOptions),
|
||||
OpenAi(openai::EmbedderOptions),
|
||||
Ollama(ollama::EmbedderOptions),
|
||||
UserProvided(manual::EmbedderOptions),
|
||||
}
|
||||
|
||||
|
@ -144,6 +148,10 @@ impl EmbedderOptions {
|
|||
pub fn openai(api_key: Option<String>) -> Self {
|
||||
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
|
||||
}
|
||||
|
||||
pub fn ollama() -> Self {
|
||||
Self::Ollama(ollama::EmbedderOptions::with_default_model())
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
|
@ -151,6 +159,7 @@ impl Embedder {
|
|||
Ok(match options {
|
||||
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
||||
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
||||
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
|
||||
EmbedderOptions::UserProvided(options) => {
|
||||
Self::UserProvided(manual::Embedder::new(options))
|
||||
}
|
||||
|
@ -167,6 +176,10 @@ impl Embedder {
|
|||
let client = embedder.new_client()?;
|
||||
embedder.embed(texts, &client).await
|
||||
}
|
||||
Embedder::Ollama(embedder) => {
|
||||
let client = embedder.new_client()?;
|
||||
embedder.embed(texts, &client).await
|
||||
}
|
||||
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
||||
}
|
||||
}
|
||||
|
@ -181,6 +194,7 @@ impl Embedder {
|
|||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
|
||||
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks),
|
||||
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
||||
}
|
||||
}
|
||||
|
@ -189,6 +203,7 @@ impl Embedder {
|
|||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::UserProvided(_) => 1,
|
||||
}
|
||||
}
|
||||
|
@ -197,6 +212,7 @@ impl Embedder {
|
|||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::UserProvided(_) => 1,
|
||||
}
|
||||
}
|
||||
|
@ -205,6 +221,7 @@ impl Embedder {
|
|||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.dimensions(),
|
||||
Embedder::OpenAi(embedder) => embedder.dimensions(),
|
||||
Embedder::Ollama(embedder) => embedder.dimensions(),
|
||||
Embedder::UserProvided(embedder) => embedder.dimensions(),
|
||||
}
|
||||
}
|
||||
|
@ -213,6 +230,7 @@ impl Embedder {
|
|||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.distribution(),
|
||||
Embedder::OpenAi(embedder) => embedder.distribution(),
|
||||
Embedder::Ollama(embedder) => embedder.distribution(),
|
||||
Embedder::UserProvided(_embedder) => None,
|
||||
}
|
||||
}
|
||||
|
|
307
milli/src/vector/ollama.rs
Normal file
307
milli/src/vector/ollama.rs
Normal file
|
@ -0,0 +1,307 @@
|
|||
// Copied from "openai.rs" with the sections I actually understand changed for Ollama.
|
||||
// The common components of the Ollama and OpenAI interfaces might need to be extracted.
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
use reqwest::StatusCode;
|
||||
|
||||
use super::error::{EmbedError, NewEmbedderError};
|
||||
use super::openai::Retry;
|
||||
use super::{DistributionShift, Embedding, Embeddings};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
headers: reqwest::header::HeaderMap,
|
||||
options: EmbedderOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub embedding_model: EmbeddingModel,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr,
|
||||
)]
|
||||
#[deserr(deny_unknown_fields)]
|
||||
pub struct EmbeddingModel {
|
||||
name: String,
|
||||
dimensions: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
struct OllamaRequest<'a> {
|
||||
model: &'a str,
|
||||
prompt: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct OllamaResponse {
|
||||
embedding: Embedding,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct OllamaError {
|
||||
error: String,
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
pub fn max_token(&self) -> usize {
|
||||
// this might not be the same for all models
|
||||
8192
|
||||
}
|
||||
|
||||
pub fn default_dimensions(&self) -> usize {
|
||||
// Dimensions for nomic-embed-text
|
||||
768
|
||||
}
|
||||
|
||||
pub fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
pub fn from_name(name: &str) -> Self {
|
||||
Self { name: name.to_string(), dimensions: 0 }
|
||||
}
|
||||
|
||||
pub fn supports_overriding_dimensions(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EmbeddingModel {
|
||||
fn default() -> Self {
|
||||
Self { name: "nomic-embed-text".to_string(), dimensions: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn with_default_model() -> Self {
|
||||
Self { embedding_model: Default::default() }
|
||||
}
|
||||
|
||||
pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self {
|
||||
Self { embedding_model }
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
|
||||
reqwest::ClientBuilder::new()
|
||||
.default_headers(self.headers.clone())
|
||||
.build()
|
||||
.map_err(EmbedError::openai_initialize_web_client)
|
||||
}
|
||||
|
||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
reqwest::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
let mut embedder = Self { options, headers };
|
||||
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()
|
||||
.map_err(EmbedError::openai_runtime_init)
|
||||
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
|
||||
|
||||
// Get dimensions from Ollama
|
||||
let request =
|
||||
OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" };
|
||||
// TODO: Refactor into shared error type
|
||||
let client = embedder
|
||||
.new_client()
|
||||
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
|
||||
|
||||
rt.block_on(async move {
|
||||
let response = client
|
||||
.post(get_ollama_path())
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(EmbedError::ollama_unexpected)
|
||||
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
|
||||
|
||||
// Process error in case model not found
|
||||
let response = Self::check_response(response).await.map_err(|_err| {
|
||||
let e = EmbedError::ollama_model_not_found(OllamaError {
|
||||
error: format!("model: {}", embedder.options.embedding_model.name()),
|
||||
});
|
||||
NewEmbedderError::ollama_could_not_determine_dimension(e)
|
||||
})?;
|
||||
|
||||
let response: OllamaResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::ollama_unexpected)
|
||||
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
|
||||
|
||||
let embedding = Embeddings::from_single_embedding(response.embedding);
|
||||
|
||||
embedder.options.embedding_model.dimensions = embedding.dimension();
|
||||
|
||||
tracing::info!(
|
||||
"ollama model {} with dimensionality {} added",
|
||||
embedder.options.embedding_model.name(),
|
||||
embedding.dimension()
|
||||
);
|
||||
|
||||
Ok(embedder)
|
||||
})
|
||||
}
|
||||
|
||||
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
|
||||
if !response.status().is_success() {
|
||||
// Not the same number of possible error cases covered as with OpenAI.
|
||||
match response.status() {
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
let error_response: OllamaError = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::ollama_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests(
|
||||
OllamaError { error: error_response.error },
|
||||
)));
|
||||
}
|
||||
StatusCode::SERVICE_UNAVAILABLE => {
|
||||
let error_response: OllamaError = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::ollama_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
return Err(Retry::retry_later(EmbedError::ollama_internal_server_error(
|
||||
OllamaError { error: error_response.error },
|
||||
)));
|
||||
}
|
||||
StatusCode::NOT_FOUND => {
|
||||
let error_response: OllamaError = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::ollama_unexpected)
|
||||
.map_err(Retry::give_up)?;
|
||||
|
||||
return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError {
|
||||
error: error_response.error,
|
||||
})));
|
||||
}
|
||||
code => {
|
||||
return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code(
|
||||
code.as_u16(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn embed(
|
||||
&self,
|
||||
texts: Vec<String>,
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
// Ollama only embedds one document at a time.
|
||||
let mut results = Vec::with_capacity(texts.len());
|
||||
|
||||
// The retry loop is inside the texts loop, might have to switch that around
|
||||
for text in texts {
|
||||
// Retries copied from openai.rs
|
||||
for attempt in 0..7 {
|
||||
let retry_duration = match self.try_embed(&text, client).await {
|
||||
Ok(result) => {
|
||||
results.push(result);
|
||||
break;
|
||||
}
|
||||
Err(retry) => {
|
||||
tracing::warn!("Failed: {}", retry.error);
|
||||
retry.into_duration(attempt)
|
||||
}
|
||||
}?;
|
||||
tracing::warn!(
|
||||
"Attempt #{}, retrying after {}ms.",
|
||||
attempt,
|
||||
retry_duration.as_millis()
|
||||
);
|
||||
tokio::time::sleep(retry_duration).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn try_embed(
|
||||
&self,
|
||||
text: &str,
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Embeddings<f32>, Retry> {
|
||||
let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text };
|
||||
let response = client
|
||||
.post(get_ollama_path())
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(EmbedError::openai_network)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
let response = Self::check_response(response).await?;
|
||||
|
||||
let response: OllamaResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tracing::trace!("response: {:?}", response.embedding);
|
||||
|
||||
let embedding = Embeddings::from_single_embedding(response.embedding);
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()
|
||||
.map_err(EmbedError::openai_runtime_init)?;
|
||||
let client = self.new_client()?;
|
||||
rt.block_on(futures::future::try_join_all(
|
||||
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
|
||||
))
|
||||
}
|
||||
|
||||
// Defaults copied from openai.rs
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.options.embedding_model.dimensions
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for OllamaError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.error)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_ollama_path() -> String {
|
||||
// Important: Hostname not enough, has to be entire path to embeddings endpoint
|
||||
std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())
|
||||
}
|
|
@ -419,12 +419,12 @@ impl Embedder {
|
|||
|
||||
// retrying in case of failure
|
||||
|
||||
struct Retry {
|
||||
error: EmbedError,
|
||||
pub struct Retry {
|
||||
pub error: EmbedError,
|
||||
strategy: RetryStrategy,
|
||||
}
|
||||
|
||||
enum RetryStrategy {
|
||||
pub enum RetryStrategy {
|
||||
GiveUp,
|
||||
Retry,
|
||||
RetryTokenized,
|
||||
|
@ -432,23 +432,23 @@ enum RetryStrategy {
|
|||
}
|
||||
|
||||
impl Retry {
|
||||
fn give_up(error: EmbedError) -> Self {
|
||||
pub fn give_up(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::GiveUp }
|
||||
}
|
||||
|
||||
fn retry_later(error: EmbedError) -> Self {
|
||||
pub fn retry_later(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::Retry }
|
||||
}
|
||||
|
||||
fn retry_tokenized(error: EmbedError) -> Self {
|
||||
pub fn retry_tokenized(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryTokenized }
|
||||
}
|
||||
|
||||
fn rate_limited(error: EmbedError) -> Self {
|
||||
pub fn rate_limited(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
|
||||
}
|
||||
|
||||
fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
|
||||
pub fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
|
||||
match self.strategy {
|
||||
RetryStrategy::GiveUp => Err(self.error),
|
||||
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
|
||||
|
@ -459,11 +459,11 @@ impl Retry {
|
|||
}
|
||||
}
|
||||
|
||||
fn must_tokenize(&self) -> bool {
|
||||
pub fn must_tokenize(&self) -> bool {
|
||||
matches!(self.strategy, RetryStrategy::RetryTokenized)
|
||||
}
|
||||
|
||||
fn into_error(self) -> EmbedError {
|
||||
pub fn into_error(self) -> EmbedError {
|
||||
self.error
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use deserr::Deserr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::openai;
|
||||
use super::{ollama, openai};
|
||||
use crate::prompt::PromptData;
|
||||
use crate::update::Setting;
|
||||
use crate::vector::EmbeddingConfig;
|
||||
|
@ -80,11 +80,15 @@ impl EmbeddingSettings {
|
|||
Self::SOURCE => {
|
||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
|
||||
}
|
||||
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
|
||||
Self::MODEL => {
|
||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
|
||||
}
|
||||
Self::REVISION => &[EmbedderSource::HuggingFace],
|
||||
Self::API_KEY => &[EmbedderSource::OpenAi],
|
||||
Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided],
|
||||
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
|
||||
Self::DOCUMENT_TEMPLATE => {
|
||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
|
||||
}
|
||||
_other => unreachable!("unknown field"),
|
||||
}
|
||||
}
|
||||
|
@ -101,6 +105,7 @@ impl EmbeddingSettings {
|
|||
EmbedderSource::HuggingFace => {
|
||||
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
|
||||
}
|
||||
EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE],
|
||||
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
|
||||
}
|
||||
}
|
||||
|
@ -134,6 +139,7 @@ pub enum EmbedderSource {
|
|||
#[default]
|
||||
OpenAi,
|
||||
HuggingFace,
|
||||
Ollama,
|
||||
UserProvided,
|
||||
}
|
||||
|
||||
|
@ -143,6 +149,7 @@ impl std::fmt::Display for EmbedderSource {
|
|||
EmbedderSource::OpenAi => "openAi",
|
||||
EmbedderSource::HuggingFace => "huggingFace",
|
||||
EmbedderSource::UserProvided => "userProvided",
|
||||
EmbedderSource::Ollama => "ollama",
|
||||
};
|
||||
f.write_str(s)
|
||||
}
|
||||
|
@ -195,6 +202,14 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
},
|
||||
super::EmbedderOptions::Ollama(options) => Self {
|
||||
source: Setting::Set(EmbedderSource::Ollama),
|
||||
model: Setting::Set(options.embedding_model.name().to_owned()),
|
||||
revision: Setting::NotSet,
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::NotSet,
|
||||
document_template: Setting::Set(prompt.template),
|
||||
},
|
||||
super::EmbedderOptions::UserProvided(options) => Self {
|
||||
source: Setting::Set(EmbedderSource::UserProvided),
|
||||
model: Setting::NotSet,
|
||||
|
@ -229,6 +244,14 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
|||
}
|
||||
this.embedder_options = super::EmbedderOptions::OpenAi(options);
|
||||
}
|
||||
EmbedderSource::Ollama => {
|
||||
let mut options: ollama::EmbedderOptions =
|
||||
super::ollama::EmbedderOptions::with_default_model();
|
||||
if let Some(model) = model.set() {
|
||||
options.embedding_model = super::ollama::EmbeddingModel::from_name(&model);
|
||||
}
|
||||
this.embedder_options = super::EmbedderOptions::Ollama(options);
|
||||
}
|
||||
EmbedderSource::HuggingFace => {
|
||||
let mut options = super::hf::EmbedderOptions::default();
|
||||
if let Some(model) = model.set() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue