mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-26 06:44:27 +01:00
Various changes
- DistributionShift in Search object (to be set from model in embed?) - Fix issue where embedder index wasn't computed at search time - Accept as default embedder either the "default" one, or the only embedder when there is only one
This commit is contained in:
parent
12940d79a9
commit
e0cc775dc4
@ -52,7 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128};
|
|||||||
use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
|
use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
|
||||||
use meilisearch_types::milli::documents::DocumentsBatchBuilder;
|
use meilisearch_types::milli::documents::DocumentsBatchBuilder;
|
||||||
use meilisearch_types::milli::update::IndexerConfig;
|
use meilisearch_types::milli::update::IndexerConfig;
|
||||||
use meilisearch_types::milli::vector::{Embedder, EmbedderOptions};
|
use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs};
|
||||||
use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
|
use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
|
||||||
use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task};
|
use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task};
|
||||||
use puffin::FrameView;
|
use puffin::FrameView;
|
||||||
@ -1339,11 +1339,10 @@ impl IndexScheduler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: consider using a type alias or a struct embedder/template
|
// TODO: consider using a type alias or a struct embedder/template
|
||||||
#[allow(clippy::type_complexity)]
|
|
||||||
pub fn embedders(
|
pub fn embedders(
|
||||||
&self,
|
&self,
|
||||||
embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>,
|
embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>,
|
||||||
) -> Result<HashMap<String, (Arc<milli::vector::Embedder>, Arc<milli::prompt::Prompt>)>> {
|
) -> Result<EmbeddingConfigs> {
|
||||||
let res: Result<_> = embedding_configs
|
let res: Result<_> = embedding_configs
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| {
|
.map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| {
|
||||||
@ -1370,7 +1369,7 @@ impl IndexScheduler {
|
|||||||
Ok((name, (embedder, prompt)))
|
Ok((name, (embedder, prompt)))
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
res
|
res.map(EmbeddingConfigs::new)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Blocks the thread until the test handle asks to progress to/through this breakpoint.
|
/// Blocks the thread until the test handle asks to progress to/through this breakpoint.
|
||||||
|
@ -238,22 +238,28 @@ pub async fn embed(
|
|||||||
match query.vector.take() {
|
match query.vector.take() {
|
||||||
Some(VectorQuery::String(prompt)) => {
|
Some(VectorQuery::String(prompt)) => {
|
||||||
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
|
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
|
||||||
let embedder = index_scheduler.embedders(embedder_configs)?;
|
let embedders = index_scheduler.embedders(embedder_configs)?;
|
||||||
|
|
||||||
let embedder_name =
|
let embedder_name =
|
||||||
if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
|
if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
|
||||||
&query.hybrid
|
&query.hybrid
|
||||||
{
|
{
|
||||||
embedder
|
Some(embedder)
|
||||||
} else {
|
} else {
|
||||||
"default"
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let embeddings = embedder
|
let embedder = if let Some(embedder_name) = embedder_name {
|
||||||
.get(embedder_name)
|
embedders.get(embedder_name)
|
||||||
.ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned()))
|
} else {
|
||||||
|
embedders.get_default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let embedder = embedder
|
||||||
|
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
|
||||||
.map_err(milli::Error::from)?
|
.map_err(milli::Error::from)?
|
||||||
.0
|
.0;
|
||||||
|
let embeddings = embedder
|
||||||
.embed(vec![prompt])
|
.embed(vec![prompt])
|
||||||
.await
|
.await
|
||||||
.map_err(milli::vector::Error::from)
|
.map_err(milli::vector::Error::from)
|
||||||
|
@ -398,6 +398,10 @@ fn prepare_search<'t>(
|
|||||||
features.check_vector("Passing `vector` as a query parameter")?;
|
features.check_vector("Passing `vector` as a query parameter")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
|
||||||
|
search.embedder_name(embedder);
|
||||||
|
}
|
||||||
|
|
||||||
// compute the offset on the limit depending on the pagination mode.
|
// compute the offset on the limit depending on the pagination mode.
|
||||||
let (offset, limit) = if is_finite_pagination {
|
let (offset, limit) = if is_finite_pagination {
|
||||||
let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
|
let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
|
||||||
|
@ -1499,6 +1499,14 @@ impl Index {
|
|||||||
.get(rtxn, main_key::EMBEDDING_CONFIGS)?
|
.get(rtxn, main_key::EMBEDDING_CONFIGS)?
|
||||||
.unwrap_or_default())
|
.unwrap_or_default())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result<String> {
|
||||||
|
let configs = self.embedding_configs(rtxn)?;
|
||||||
|
Ok(match configs.as_slice() {
|
||||||
|
[(ref first_name, _)] => first_name.clone(),
|
||||||
|
_ => "default".to_owned(),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -218,6 +218,8 @@ impl<'a> Search<'a> {
|
|||||||
exhaustive_number_hits: self.exhaustive_number_hits,
|
exhaustive_number_hits: self.exhaustive_number_hits,
|
||||||
rtxn: self.rtxn,
|
rtxn: self.rtxn,
|
||||||
index: self.index,
|
index: self.index,
|
||||||
|
distribution_shift: self.distribution_shift,
|
||||||
|
embedder_name: self.embedder_name.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let vector_query = search.vector.take();
|
let vector_query = search.vector.take();
|
||||||
@ -265,6 +267,15 @@ impl<'a> Search<'a> {
|
|||||||
vector: &[f32],
|
vector: &[f32],
|
||||||
keyword_results: &SearchResult,
|
keyword_results: &SearchResult,
|
||||||
) -> Result<PartialSearchResult> {
|
) -> Result<PartialSearchResult> {
|
||||||
|
let embedder_name;
|
||||||
|
let embedder_name = match &self.embedder_name {
|
||||||
|
Some(embedder_name) => embedder_name,
|
||||||
|
None => {
|
||||||
|
embedder_name = self.index.default_embedding_name(self.rtxn)?;
|
||||||
|
&embedder_name
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
||||||
|
|
||||||
if let Some(searchable_attributes) = self.searchable_attributes {
|
if let Some(searchable_attributes) = self.searchable_attributes {
|
||||||
@ -282,6 +293,8 @@ impl<'a> Search<'a> {
|
|||||||
self.geo_strategy,
|
self.geo_strategy,
|
||||||
0,
|
0,
|
||||||
self.limit + self.offset,
|
self.limit + self.offset,
|
||||||
|
self.distribution_shift,
|
||||||
|
embedder_name,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ use self::new::{execute_vector_search, PartialSearchResult};
|
|||||||
use crate::error::UserError;
|
use crate::error::UserError;
|
||||||
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
|
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
|
||||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||||
|
use crate::vector::DistributionShift;
|
||||||
use crate::{
|
use crate::{
|
||||||
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
|
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
|
||||||
Result, SearchContext,
|
Result, SearchContext,
|
||||||
@ -51,6 +52,8 @@ pub struct Search<'a> {
|
|||||||
exhaustive_number_hits: bool,
|
exhaustive_number_hits: bool,
|
||||||
rtxn: &'a heed::RoTxn<'a>,
|
rtxn: &'a heed::RoTxn<'a>,
|
||||||
index: &'a Index,
|
index: &'a Index,
|
||||||
|
distribution_shift: Option<DistributionShift>,
|
||||||
|
embedder_name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
@ -117,6 +120,8 @@ impl<'a> Search<'a> {
|
|||||||
words_limit: 10,
|
words_limit: 10,
|
||||||
rtxn,
|
rtxn,
|
||||||
index,
|
index,
|
||||||
|
distribution_shift: None,
|
||||||
|
embedder_name: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -183,7 +188,29 @@ impl<'a> Search<'a> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn distribution_shift(
|
||||||
|
&mut self,
|
||||||
|
distribution_shift: Option<DistributionShift>,
|
||||||
|
) -> &mut Search<'a> {
|
||||||
|
self.distribution_shift = distribution_shift;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> {
|
||||||
|
self.embedder_name = Some(embedder_name.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn execute(&self) -> Result<SearchResult> {
|
pub fn execute(&self) -> Result<SearchResult> {
|
||||||
|
let embedder_name;
|
||||||
|
let embedder_name = match &self.embedder_name {
|
||||||
|
Some(embedder_name) => embedder_name,
|
||||||
|
None => {
|
||||||
|
embedder_name = self.index.default_embedding_name(self.rtxn)?;
|
||||||
|
&embedder_name
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
||||||
|
|
||||||
if let Some(searchable_attributes) = self.searchable_attributes {
|
if let Some(searchable_attributes) = self.searchable_attributes {
|
||||||
@ -202,6 +229,8 @@ impl<'a> Search<'a> {
|
|||||||
self.geo_strategy,
|
self.geo_strategy,
|
||||||
self.offset,
|
self.offset,
|
||||||
self.limit,
|
self.limit,
|
||||||
|
self.distribution_shift,
|
||||||
|
embedder_name,
|
||||||
)?,
|
)?,
|
||||||
None => execute_search(
|
None => execute_search(
|
||||||
&mut ctx,
|
&mut ctx,
|
||||||
@ -247,6 +276,8 @@ impl fmt::Debug for Search<'_> {
|
|||||||
exhaustive_number_hits,
|
exhaustive_number_hits,
|
||||||
rtxn: _,
|
rtxn: _,
|
||||||
index: _,
|
index: _,
|
||||||
|
distribution_shift,
|
||||||
|
embedder_name,
|
||||||
} = self;
|
} = self;
|
||||||
f.debug_struct("Search")
|
f.debug_struct("Search")
|
||||||
.field("query", query)
|
.field("query", query)
|
||||||
@ -260,6 +291,8 @@ impl fmt::Debug for Search<'_> {
|
|||||||
.field("scoring_strategy", scoring_strategy)
|
.field("scoring_strategy", scoring_strategy)
|
||||||
.field("exhaustive_number_hits", exhaustive_number_hits)
|
.field("exhaustive_number_hits", exhaustive_number_hits)
|
||||||
.field("words_limit", words_limit)
|
.field("words_limit", words_limit)
|
||||||
|
.field("distribution_shift", distribution_shift)
|
||||||
|
.field("embedder_name", embedder_name)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -266,6 +266,7 @@ fn get_ranking_rules_for_vector<'ctx>(
|
|||||||
limit_plus_offset: usize,
|
limit_plus_offset: usize,
|
||||||
target: &[f32],
|
target: &[f32],
|
||||||
distribution_shift: Option<DistributionShift>,
|
distribution_shift: Option<DistributionShift>,
|
||||||
|
embedder_name: &str,
|
||||||
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
|
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
|
||||||
// query graph search
|
// query graph search
|
||||||
|
|
||||||
@ -292,6 +293,7 @@ fn get_ranking_rules_for_vector<'ctx>(
|
|||||||
vector_candidates,
|
vector_candidates,
|
||||||
limit_plus_offset,
|
limit_plus_offset,
|
||||||
distribution_shift,
|
distribution_shift,
|
||||||
|
embedder_name,
|
||||||
)?;
|
)?;
|
||||||
ranking_rules.push(Box::new(vector_sort));
|
ranking_rules.push(Box::new(vector_sort));
|
||||||
vector = true;
|
vector = true;
|
||||||
@ -513,6 +515,8 @@ pub fn execute_vector_search(
|
|||||||
geo_strategy: geo_sort::Strategy,
|
geo_strategy: geo_sort::Strategy,
|
||||||
from: usize,
|
from: usize,
|
||||||
length: usize,
|
length: usize,
|
||||||
|
distribution_shift: Option<DistributionShift>,
|
||||||
|
embedder_name: &str,
|
||||||
) -> Result<PartialSearchResult> {
|
) -> Result<PartialSearchResult> {
|
||||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||||
|
|
||||||
@ -524,7 +528,8 @@ pub fn execute_vector_search(
|
|||||||
geo_strategy,
|
geo_strategy,
|
||||||
from + length,
|
from + length,
|
||||||
vector,
|
vector,
|
||||||
None,
|
distribution_shift,
|
||||||
|
embedder_name,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let mut placeholder_search_logger = logger::DefaultSearchLogger;
|
let mut placeholder_search_logger = logger::DefaultSearchLogger;
|
||||||
|
@ -15,16 +15,21 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
|
|||||||
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
|
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
distribution_shift: Option<DistributionShift>,
|
distribution_shift: Option<DistributionShift>,
|
||||||
|
embedder_index: u8,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
_ctx: &SearchContext,
|
ctx: &SearchContext,
|
||||||
target: Vec<f32>,
|
target: Vec<f32>,
|
||||||
vector_candidates: RoaringBitmap,
|
vector_candidates: RoaringBitmap,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
distribution_shift: Option<DistributionShift>,
|
distribution_shift: Option<DistributionShift>,
|
||||||
|
embedder_name: &str,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
/// FIXME: unwrap
|
||||||
|
let embedder_index = ctx.index.embedder_category_id.get(ctx.txn, embedder_name)?.unwrap();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
query: None,
|
query: None,
|
||||||
target,
|
target,
|
||||||
@ -32,6 +37,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
|||||||
cached_sorted_docids: Default::default(),
|
cached_sorted_docids: Default::default(),
|
||||||
limit,
|
limit,
|
||||||
distribution_shift,
|
distribution_shift,
|
||||||
|
embedder_index,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,9 +46,10 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
|||||||
ctx: &mut SearchContext<'_>,
|
ctx: &mut SearchContext<'_>,
|
||||||
vector_candidates: &RoaringBitmap,
|
vector_candidates: &RoaringBitmap,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
let writer_index = (self.embedder_index as u16) << 8;
|
||||||
let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
|
let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
|
||||||
.map_while(|k| {
|
.map_while(|k| {
|
||||||
arroy::Reader::open(ctx.txn, k.into(), ctx.index.vector_arroy)
|
arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy)
|
||||||
.map(Some)
|
.map(Some)
|
||||||
.or_else(|e| match e {
|
.or_else(|e| match e {
|
||||||
arroy::Error::MissingMetadata => Ok(None),
|
arroy::Error::MissingMetadata => Ok(None),
|
||||||
|
@ -9,10 +9,9 @@ mod extract_word_docids;
|
|||||||
mod extract_word_pair_proximity_docids;
|
mod extract_word_pair_proximity_docids;
|
||||||
mod extract_word_position_docids;
|
mod extract_word_position_docids;
|
||||||
|
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::HashSet;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::BufReader;
|
use std::io::BufReader;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crossbeam_channel::Sender;
|
use crossbeam_channel::Sender;
|
||||||
use log::debug;
|
use log::debug;
|
||||||
@ -35,9 +34,8 @@ use super::helpers::{
|
|||||||
MergeFn, MergeableReader,
|
MergeFn, MergeableReader,
|
||||||
};
|
};
|
||||||
use super::{helpers, TypedChunk};
|
use super::{helpers, TypedChunk};
|
||||||
use crate::prompt::Prompt;
|
|
||||||
use crate::proximity::ProximityPrecision;
|
use crate::proximity::ProximityPrecision;
|
||||||
use crate::vector::Embedder;
|
use crate::vector::EmbeddingConfigs;
|
||||||
use crate::{FieldId, FieldsIdsMap, Result};
|
use crate::{FieldId, FieldsIdsMap, Result};
|
||||||
|
|
||||||
/// Extract data for each databases from obkv documents in parallel.
|
/// Extract data for each databases from obkv documents in parallel.
|
||||||
@ -59,7 +57,7 @@ pub(crate) fn data_from_obkv_documents(
|
|||||||
max_positions_per_attributes: Option<u32>,
|
max_positions_per_attributes: Option<u32>,
|
||||||
exact_attributes: HashSet<FieldId>,
|
exact_attributes: HashSet<FieldId>,
|
||||||
proximity_precision: ProximityPrecision,
|
proximity_precision: ProximityPrecision,
|
||||||
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
|
embedders: EmbeddingConfigs,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
puffin::profile_function!();
|
puffin::profile_function!();
|
||||||
|
|
||||||
@ -284,7 +282,7 @@ fn send_original_documents_data(
|
|||||||
indexer: GrenadParameters,
|
indexer: GrenadParameters,
|
||||||
lmdb_writer_sx: Sender<Result<TypedChunk>>,
|
lmdb_writer_sx: Sender<Result<TypedChunk>>,
|
||||||
field_id_map: FieldsIdsMap,
|
field_id_map: FieldsIdsMap,
|
||||||
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
|
embedders: EmbeddingConfigs,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let original_documents_chunk =
|
let original_documents_chunk =
|
||||||
original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
|
original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
|
||||||
|
@ -9,7 +9,6 @@ use std::io::{Cursor, Read, Seek};
|
|||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
use std::num::NonZeroU32;
|
use std::num::NonZeroU32;
|
||||||
use std::result::Result as StdResult;
|
use std::result::Result as StdResult;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crossbeam_channel::{Receiver, Sender};
|
use crossbeam_channel::{Receiver, Sender};
|
||||||
use heed::types::Str;
|
use heed::types::Str;
|
||||||
@ -34,12 +33,11 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
|
|||||||
pub use self::transform::{Transform, TransformOutput};
|
pub use self::transform::{Transform, TransformOutput};
|
||||||
use crate::documents::{obkv_to_object, DocumentsBatchReader};
|
use crate::documents::{obkv_to_object, DocumentsBatchReader};
|
||||||
use crate::error::{Error, InternalError, UserError};
|
use crate::error::{Error, InternalError, UserError};
|
||||||
use crate::prompt::Prompt;
|
|
||||||
pub use crate::update::index_documents::helpers::CursorClonableMmap;
|
pub use crate::update::index_documents::helpers::CursorClonableMmap;
|
||||||
use crate::update::{
|
use crate::update::{
|
||||||
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
|
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
|
||||||
};
|
};
|
||||||
use crate::vector::Embedder;
|
use crate::vector::EmbeddingConfigs;
|
||||||
use crate::{CboRoaringBitmapCodec, Index, Result};
|
use crate::{CboRoaringBitmapCodec, Index, Result};
|
||||||
|
|
||||||
static MERGED_DATABASE_COUNT: usize = 7;
|
static MERGED_DATABASE_COUNT: usize = 7;
|
||||||
@ -82,7 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
|
|||||||
should_abort: FA,
|
should_abort: FA,
|
||||||
added_documents: u64,
|
added_documents: u64,
|
||||||
deleted_documents: u64,
|
deleted_documents: u64,
|
||||||
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
|
embedders: EmbeddingConfigs,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
@ -173,10 +171,7 @@ where
|
|||||||
Ok((self, Ok(indexed_documents)))
|
Ok((self, Ok(indexed_documents)))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_embedders(
|
pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self {
|
||||||
mut self,
|
|
||||||
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
|
|
||||||
) -> Self {
|
|
||||||
self.embedders = embedders;
|
self.embedders = embedders;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -14,12 +14,11 @@ use super::IndexerConfig;
|
|||||||
use crate::criterion::Criterion;
|
use crate::criterion::Criterion;
|
||||||
use crate::error::UserError;
|
use crate::error::UserError;
|
||||||
use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS};
|
use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS};
|
||||||
use crate::prompt::Prompt;
|
|
||||||
use crate::proximity::ProximityPrecision;
|
use crate::proximity::ProximityPrecision;
|
||||||
use crate::update::index_documents::IndexDocumentsMethod;
|
use crate::update::index_documents::IndexDocumentsMethod;
|
||||||
use crate::update::{IndexDocuments, UpdateIndexingStep};
|
use crate::update::{IndexDocuments, UpdateIndexingStep};
|
||||||
use crate::vector::settings::{EmbeddingSettings, PromptSettings};
|
use crate::vector::settings::{EmbeddingSettings, PromptSettings};
|
||||||
use crate::vector::{Embedder, EmbeddingConfig};
|
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
|
||||||
use crate::{FieldsIdsMap, Index, OrderBy, Result};
|
use crate::{FieldsIdsMap, Index, OrderBy, Result};
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||||
@ -422,7 +421,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||||||
fn embedders(
|
fn embedders(
|
||||||
&self,
|
&self,
|
||||||
embedding_configs: Vec<(String, EmbeddingConfig)>,
|
embedding_configs: Vec<(String, EmbeddingConfig)>,
|
||||||
) -> Result<HashMap<String, (Arc<Embedder>, Arc<Prompt>)>> {
|
) -> Result<EmbeddingConfigs> {
|
||||||
let res: Result<_> = embedding_configs
|
let res: Result<_> = embedding_configs
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(name, EmbeddingConfig { embedder_options, prompt })| {
|
.map(|(name, EmbeddingConfig { embedder_options, prompt })| {
|
||||||
@ -436,7 +435,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||||||
Ok((name, (embedder, prompt)))
|
Ok((name, (embedder, prompt)))
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
res
|
res.map(EmbeddingConfigs::new)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_displayed(&mut self) -> Result<bool> {
|
fn update_displayed(&mut self) -> Result<bool> {
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use self::error::{EmbedError, NewEmbedderError};
|
use self::error::{EmbedError, NewEmbedderError};
|
||||||
use crate::prompt::PromptData;
|
use crate::prompt::{Prompt, PromptData};
|
||||||
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod hf;
|
pub mod hf;
|
||||||
@ -82,6 +85,44 @@ pub struct EmbeddingConfig {
|
|||||||
// TODO: add metrics and anything needed
|
// TODO: add metrics and anything needed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
|
||||||
|
|
||||||
|
impl EmbeddingConfigs {
|
||||||
|
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
|
||||||
|
Self(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
|
||||||
|
self.0.get(name).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
|
||||||
|
self.get_default_embedder_name().and_then(|default| self.get(&default))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_default_embedder_name(&self) -> Option<String> {
|
||||||
|
let mut it = self.0.keys();
|
||||||
|
let first_name = it.next();
|
||||||
|
let second_name = it.next();
|
||||||
|
match (first_name, second_name) {
|
||||||
|
(None, _) => None,
|
||||||
|
(Some(first), None) => Some(first.to_owned()),
|
||||||
|
(Some(_), Some(_)) => Some("default".to_owned()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for EmbeddingConfigs {
|
||||||
|
type Item = (String, (Arc<Embedder>, Arc<Prompt>));
|
||||||
|
|
||||||
|
type IntoIter = std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>)>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
self.0.into_iter()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||||
pub enum EmbedderOptions {
|
pub enum EmbedderOptions {
|
||||||
HuggingFace(hf::EmbedderOptions),
|
HuggingFace(hf::EmbedderOptions),
|
||||||
|
Loading…
Reference in New Issue
Block a user