Add embedder stats in batches

This commit is contained in:
Mubelotix 2025-06-20 12:42:22 +02:00
parent fc6cc80705
commit 4cadc8113b
No known key found for this signature in database
GPG Key ID: 89F391DBCC8CE7F0
26 changed files with 188 additions and 73 deletions

View File

@ -65,7 +65,7 @@ fn setup_settings<'t>(
let sortable_fields = sortable_fields.iter().map(|s| s.to_string()).collect(); let sortable_fields = sortable_fields.iter().map(|s| s.to_string()).collect();
builder.set_sortable_fields(sortable_fields); builder.set_sortable_fields(sortable_fields);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
} }
fn setup_index_with_settings( fn setup_index_with_settings(

View File

@ -90,7 +90,7 @@ pub fn base_setup(conf: &Conf) -> Index {
(conf.configure)(&mut builder); (conf.configure)(&mut builder);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
let config = IndexerConfig::default(); let config = IndexerConfig::default();

View File

@ -328,6 +328,7 @@ pub(crate) mod test {
progress_trace: Default::default(), progress_trace: Default::default(),
write_channel_congestion: None, write_channel_congestion: None,
internal_database_sizes: Default::default(), internal_database_sizes: Default::default(),
embeddings: Default::default(),
}, },
enqueued_at: Some(BatchEnqueuedAt { enqueued_at: Some(BatchEnqueuedAt {
earliest: datetime!(2022-11-11 0:00 UTC), earliest: datetime!(2022-11-11 0:00 UTC),

View File

@ -242,6 +242,7 @@ impl IndexScheduler {
.execute( .execute(
|indexing_step| tracing::debug!(update = ?indexing_step), |indexing_step| tracing::debug!(update = ?indexing_step),
|| must_stop_processing.get(), || must_stop_processing.get(),
Some(progress.embedder_stats),
) )
.map_err(|e| Error::from_milli(e, Some(index_uid.to_string())))?; .map_err(|e| Error::from_milli(e, Some(index_uid.to_string())))?;
index_wtxn.commit()?; index_wtxn.commit()?;

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use bumpalo::collections::CollectIn; use bumpalo::collections::CollectIn;
use bumpalo::Bump; use bumpalo::Bump;
use meilisearch_types::heed::RwTxn; use meilisearch_types::heed::RwTxn;
@ -472,6 +474,7 @@ impl IndexScheduler {
.execute( .execute(
|indexing_step| tracing::debug!(update = ?indexing_step), |indexing_step| tracing::debug!(update = ?indexing_step),
|| must_stop_processing.get(), || must_stop_processing.get(),
Some(Arc::clone(&progress.embedder_stats))
) )
.map_err(|err| Error::from_milli(err, Some(index_uid.clone())))?; .map_err(|err| Error::from_milli(err, Some(index_uid.clone())))?;

View File

@ -82,4 +82,14 @@ pub struct BatchStats {
pub write_channel_congestion: Option<serde_json::Map<String, serde_json::Value>>, pub write_channel_congestion: Option<serde_json::Map<String, serde_json::Value>>,
#[serde(default, skip_serializing_if = "serde_json::Map::is_empty")] #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
pub internal_database_sizes: serde_json::Map<String, serde_json::Value>, pub internal_database_sizes: serde_json::Map<String, serde_json::Value>,
pub embeddings: BatchEmbeddingStats
}
/// Embedding statistics attached to a batch; this is the type of the
/// `embeddings` field of `BatchStats`.
///
/// Serialized, deserialized and exposed in the schema with camelCase names
/// (`totalCount`, `errorCount`, `lastError`).
///
/// NOTE(review): the `embeddings` field that holds this type carries no
/// `#[serde(default)]`, unlike the neighbouring `internal_database_sizes`
/// field. If `BatchStats` is deserialized from data written before the field
/// existed, that would presumably fail on the missing key — confirm.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
#[schema(rename_all = "camelCase")]
pub struct BatchEmbeddingStats {
/// Presumably the number of embedding requests made while processing the
/// batch — TODO confirm against the code that fills this struct.
pub total_count: usize,
/// Presumably the number of those requests that failed — TODO confirm.
pub error_count: usize,
/// Presumably the message of the most recent embedding error, `None` when
/// no error was recorded — TODO confirm.
pub last_error: Option<String>,
} }

View File

@ -543,7 +543,7 @@ fn import_dump(
let settings = index_reader.settings()?; let settings = index_reader.settings()?;
apply_settings_to_builder(&settings, &mut builder); apply_settings_to_builder(&settings, &mut builder);
builder builder
.execute(|indexing_step| tracing::debug!("update: {:?}", indexing_step), || false)?; .execute(|indexing_step| tracing::debug!("update: {:?}", indexing_step), || false, None)?;
// 4.3 Import the documents. // 4.3 Import the documents.
// 4.3.1 We need to recreate the grenad+obkv format accepted by the index. // 4.3.1 We need to recreate the grenad+obkv format accepted by the index.
@ -574,6 +574,7 @@ fn import_dump(
}, },
|indexing_step| tracing::trace!("update: {:?}", indexing_step), |indexing_step| tracing::trace!("update: {:?}", indexing_step),
|| false, || false,
None,
)?; )?;
let builder = builder.with_embedders(embedders); let builder = builder.with_embedders(embedders);

View File

@ -1,7 +1,7 @@
use std::any::TypeId; use std::any::TypeId;
use std::borrow::Cow; use std::borrow::Cow;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -20,6 +20,13 @@ pub trait Step: 'static + Send + Sync {
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct Progress { pub struct Progress {
steps: Arc<RwLock<InnerProgress>>, steps: Arc<RwLock<InnerProgress>>,
pub embedder_stats: Arc<EmbedderStats>,
}
/// Counters describing embedder activity, shared through
/// `Progress::embedder_stats` as an `Arc<EmbedderStats>`.
///
/// `Default` yields no recorded error, an error count of 0 and a request
/// count of 0.
#[derive(Default)]
pub struct EmbedderStats {
/// Pair of (last error message, error count). The progress view reads it
/// as `(guard.0.clone(), guard.1)` and falls back to `(None, 0)` when the
/// lock is poisoned.
///
/// NOTE(review): the struct is already handed around behind an `Arc`, so
/// the inner `Arc` around the `RwLock` looks redundant — confirm that
/// nothing clones `errors` on its own before removing it.
pub errors: Arc<RwLock<(Option<String>, u32)>>,
/// Request counter; the progress view loads it with `Ordering::Relaxed`
/// and reports it as `request_count`.
pub total_requests: AtomicUsize
} }
#[derive(Default)] #[derive(Default)]
@ -65,7 +72,19 @@ impl Progress {
}); });
} }
ProgressView { steps: step_view, percentage: percentage * 100.0 } let embedder_view = {
let (last_error, error_count) = match self.embedder_stats.errors.read() {
Ok(guard) => (guard.0.clone(), guard.1),
Err(_) => (None, 0),
};
EmbedderStatsView {
last_error,
request_count: self.embedder_stats.total_requests.load(Ordering::Relaxed) as u32,
error_count,
}
};
ProgressView { steps: step_view, percentage: percentage * 100.0, embedder: embedder_view }
} }
pub fn accumulated_durations(&self) -> IndexMap<String, String> { pub fn accumulated_durations(&self) -> IndexMap<String, String> {
@ -209,6 +228,7 @@ make_enum_progress! {
pub struct ProgressView { pub struct ProgressView {
pub steps: Vec<ProgressStepView>, pub steps: Vec<ProgressStepView>,
pub percentage: f32, pub percentage: f32,
pub embedder: EmbedderStatsView,
} }
#[derive(Debug, Serialize, Clone, ToSchema)] #[derive(Debug, Serialize, Clone, ToSchema)]
@ -220,6 +240,16 @@ pub struct ProgressStepView {
pub total: u32, pub total: u32,
} }
/// Serializable snapshot of the embedder statistics, exposed as the
/// `embedder` field of `ProgressView`.
///
/// Field names are emitted in camelCase (`lastError`, `requestCount`,
/// `errorCount`).
#[derive(Debug, Serialize, Clone, ToSchema)]
#[serde(rename_all = "camelCase")]
#[schema(rename_all = "camelCase")]
pub struct EmbedderStatsView {
/// Last recorded error message; left out of the serialized output
/// entirely when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub last_error: Option<String>,
/// Value of the `total_requests` counter at the time the view was built.
///
/// NOTE(review): the counter is an `AtomicUsize` narrowed with `as u32`
/// when the view is built, so a count above `u32::MAX` would silently
/// wrap — confirm this is acceptable or widen the field.
pub request_count: u32,
/// Error count copied from the second element of `EmbedderStats::errors`.
pub error_count: u32,
}
/// Used when the name can change but it's still the same step. /// Used when the name can change but it's still the same step.
/// To avoid conflicts on the `TypeId`, create a unique type every time you use this step: /// To avoid conflicts on the `TypeId`, create a unique type every time you use this step:
/// ```text /// ```text

View File

@ -44,7 +44,7 @@ pub fn setup_search_index_with_criteria(criteria: &[Criterion]) -> Index {
S("america") => vec![S("the united states")], S("america") => vec![S("the united states")],
}); });
builder.set_searchable_fields(vec![S("title"), S("description")]); builder.set_searchable_fields(vec![S("title"), S("description")]);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
// index documents // index documents

View File

@ -134,7 +134,7 @@ impl TempIndex {
) -> Result<(), crate::error::Error> { ) -> Result<(), crate::error::Error> {
let mut builder = update::Settings::new(wtxn, &self.inner, &self.indexer_config); let mut builder = update::Settings::new(wtxn, &self.inner, &self.indexer_config);
update(&mut builder); update(&mut builder);
builder.execute(drop, || false)?; builder.execute(drop, || false, None)?;
Ok(()) Ok(())
} }

View File

@ -17,6 +17,7 @@ use crate::constants::RESERVED_VECTORS_FIELD_NAME;
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::fields_ids_map::metadata::FieldIdMapWithMetadata; use crate::fields_ids_map::metadata::FieldIdMapWithMetadata;
use crate::index::IndexEmbeddingConfig; use crate::index::IndexEmbeddingConfig;
use crate::progress::EmbedderStats;
use crate::prompt::Prompt; use crate::prompt::Prompt;
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
use crate::update::settings::InnerIndexSettingsDiff; use crate::update::settings::InnerIndexSettingsDiff;
@ -682,6 +683,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
embedder_name: &str, embedder_name: &str,
possible_embedding_mistakes: &PossibleEmbeddingMistakes, possible_embedding_mistakes: &PossibleEmbeddingMistakes,
embedder_stats: Option<Arc<EmbedderStats>>,
unused_vectors_distribution: &UnusedVectorsDistribution, unused_vectors_distribution: &UnusedVectorsDistribution,
request_threads: &ThreadPoolNoAbort, request_threads: &ThreadPoolNoAbort,
) -> Result<grenad::Reader<BufReader<File>>> { ) -> Result<grenad::Reader<BufReader<File>>> {
@ -724,6 +726,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)), std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
embedder_name, embedder_name,
possible_embedding_mistakes, possible_embedding_mistakes,
embedder_stats.clone(),
unused_vectors_distribution, unused_vectors_distribution,
request_threads, request_threads,
)?; )?;
@ -746,6 +749,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
std::mem::take(&mut chunks), std::mem::take(&mut chunks),
embedder_name, embedder_name,
possible_embedding_mistakes, possible_embedding_mistakes,
embedder_stats.clone(),
unused_vectors_distribution, unused_vectors_distribution,
request_threads, request_threads,
)?; )?;
@ -764,6 +768,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
vec![std::mem::take(&mut current_chunk)], vec![std::mem::take(&mut current_chunk)],
embedder_name, embedder_name,
possible_embedding_mistakes, possible_embedding_mistakes,
embedder_stats,
unused_vectors_distribution, unused_vectors_distribution,
request_threads, request_threads,
)?; )?;
@ -783,10 +788,11 @@ fn embed_chunks(
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
embedder_name: &str, embedder_name: &str,
possible_embedding_mistakes: &PossibleEmbeddingMistakes, possible_embedding_mistakes: &PossibleEmbeddingMistakes,
embedder_stats: Option<Arc<EmbedderStats>>,
unused_vectors_distribution: &UnusedVectorsDistribution, unused_vectors_distribution: &UnusedVectorsDistribution,
request_threads: &ThreadPoolNoAbort, request_threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embedding>>> { ) -> Result<Vec<Vec<Embedding>>> {
match embedder.embed_index(text_chunks, request_threads) { match embedder.embed_index(text_chunks, request_threads, embedder_stats) {
Ok(chunks) => Ok(chunks), Ok(chunks) => Ok(chunks),
Err(error) => { Err(error) => {
if let FaultSource::Bug = error.fault { if let FaultSource::Bug = error.fault {

View File

@ -31,6 +31,7 @@ use self::extract_word_position_docids::extract_word_position_docids;
use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters}; use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters};
use super::{helpers, TypedChunk}; use super::{helpers, TypedChunk};
use crate::index::IndexEmbeddingConfig; use crate::index::IndexEmbeddingConfig;
use crate::progress::EmbedderStats;
use crate::update::settings::InnerIndexSettingsDiff; use crate::update::settings::InnerIndexSettingsDiff;
use crate::vector::error::PossibleEmbeddingMistakes; use crate::vector::error::PossibleEmbeddingMistakes;
use crate::{FieldId, Result, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder}; use crate::{FieldId, Result, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
@ -49,6 +50,7 @@ pub(crate) fn data_from_obkv_documents(
settings_diff: Arc<InnerIndexSettingsDiff>, settings_diff: Arc<InnerIndexSettingsDiff>,
max_positions_per_attributes: Option<u32>, max_positions_per_attributes: Option<u32>,
possible_embedding_mistakes: Arc<PossibleEmbeddingMistakes>, possible_embedding_mistakes: Arc<PossibleEmbeddingMistakes>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<()> { ) -> Result<()> {
let (original_pipeline_result, flattened_pipeline_result): (Result<_>, Result<_>) = rayon::join( let (original_pipeline_result, flattened_pipeline_result): (Result<_>, Result<_>) = rayon::join(
|| { || {
@ -62,6 +64,7 @@ pub(crate) fn data_from_obkv_documents(
embedders_configs.clone(), embedders_configs.clone(),
settings_diff.clone(), settings_diff.clone(),
possible_embedding_mistakes.clone(), possible_embedding_mistakes.clone(),
embedder_stats.clone(),
) )
}) })
.collect::<Result<()>>() .collect::<Result<()>>()
@ -231,6 +234,7 @@ fn send_original_documents_data(
embedders_configs: Arc<Vec<IndexEmbeddingConfig>>, embedders_configs: Arc<Vec<IndexEmbeddingConfig>>,
settings_diff: Arc<InnerIndexSettingsDiff>, settings_diff: Arc<InnerIndexSettingsDiff>,
possible_embedding_mistakes: Arc<PossibleEmbeddingMistakes>, possible_embedding_mistakes: Arc<PossibleEmbeddingMistakes>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<()> { ) -> Result<()> {
let original_documents_chunk = let original_documents_chunk =
original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
@ -270,6 +274,7 @@ fn send_original_documents_data(
embedder.clone(), embedder.clone(),
&embedder_name, &embedder_name,
&possible_embedding_mistakes, &possible_embedding_mistakes,
embedder_stats.clone(),
&unused_vectors_distribution, &unused_vectors_distribution,
request_threads(), request_threads(),
) { ) {

View File

@ -32,7 +32,7 @@ use crate::database_stats::DatabaseStats;
use crate::documents::{obkv_to_object, DocumentsBatchReader}; use crate::documents::{obkv_to_object, DocumentsBatchReader};
use crate::error::{Error, InternalError}; use crate::error::{Error, InternalError};
use crate::index::{PrefixSearch, PrefixSettings}; use crate::index::{PrefixSearch, PrefixSettings};
use crate::progress::Progress; use crate::progress::{EmbedderStats, Progress};
pub use crate::update::index_documents::helpers::CursorClonableMmap; pub use crate::update::index_documents::helpers::CursorClonableMmap;
use crate::update::{ use crate::update::{
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
@ -81,6 +81,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
added_documents: u64, added_documents: u64,
deleted_documents: u64, deleted_documents: u64,
embedders: EmbeddingConfigs, embedders: EmbeddingConfigs,
embedder_stats: Option<Arc<EmbedderStats>>,
} }
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone)]
@ -103,6 +104,7 @@ where
config: IndexDocumentsConfig, config: IndexDocumentsConfig,
progress: FP, progress: FP,
should_abort: FA, should_abort: FA,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<IndexDocuments<'t, 'i, 'a, FP, FA>> { ) -> Result<IndexDocuments<'t, 'i, 'a, FP, FA>> {
let transform = Some(Transform::new( let transform = Some(Transform::new(
wtxn, wtxn,
@ -123,6 +125,7 @@ where
added_documents: 0, added_documents: 0,
deleted_documents: 0, deleted_documents: 0,
embedders: Default::default(), embedders: Default::default(),
embedder_stats,
}) })
} }
@ -292,6 +295,7 @@ where
// Run extraction pipeline in parallel. // Run extraction pipeline in parallel.
let mut modified_docids = RoaringBitmap::new(); let mut modified_docids = RoaringBitmap::new();
let embedder_stats = self.embedder_stats.clone();
pool.install(|| { pool.install(|| {
let settings_diff_cloned = settings_diff.clone(); let settings_diff_cloned = settings_diff.clone();
rayon::spawn(move || { rayon::spawn(move || {
@ -326,7 +330,8 @@ where
embedders_configs.clone(), embedders_configs.clone(),
settings_diff_cloned, settings_diff_cloned,
max_positions_per_attributes, max_positions_per_attributes,
Arc::new(possible_embedding_mistakes) Arc::new(possible_embedding_mistakes),
embedder_stats.clone()
) )
}); });

View File

@ -450,7 +450,7 @@ impl<'a, 'b, 'extractor> Chunks<'a, 'b, 'extractor> {
return Err(crate::Error::UserError(crate::UserError::DocumentEmbeddingError(msg))); return Err(crate::Error::UserError(crate::UserError::DocumentEmbeddingError(msg)));
} }
let res = match embedder.embed_index_ref(texts.as_slice(), threads) { let res = match embedder.embed_index_ref(texts.as_slice(), threads, None) {
Ok(embeddings) => { Ok(embeddings) => {
for (docid, embedding) in ids.into_iter().zip(embeddings) { for (docid, embedding) in ids.into_iter().zip(embeddings) {
sender.set_vector(*docid, embedder_id, embedding).unwrap(); sender.set_vector(*docid, embedder_id, embedding).unwrap();

View File

@ -27,6 +27,7 @@ use crate::index::{
DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS, DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
}; };
use crate::order_by_map::OrderByMap; use crate::order_by_map::OrderByMap;
use crate::progress::EmbedderStats;
use crate::prompt::{default_max_bytes, default_template_text, PromptData}; use crate::prompt::{default_max_bytes, default_template_text, PromptData};
use crate::proximity::ProximityPrecision; use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod; use crate::update::index_documents::IndexDocumentsMethod;
@ -466,7 +467,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
#[tracing::instrument( #[tracing::instrument(
level = "trace" level = "trace"
skip(self, progress_callback, should_abort, settings_diff), skip(self, progress_callback, should_abort, settings_diff, embedder_stats),
target = "indexing::documents" target = "indexing::documents"
)] )]
fn reindex<FP, FA>( fn reindex<FP, FA>(
@ -474,6 +475,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
progress_callback: &FP, progress_callback: &FP,
should_abort: &FA, should_abort: &FA,
settings_diff: InnerIndexSettingsDiff, settings_diff: InnerIndexSettingsDiff,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<()> ) -> Result<()>
where where
FP: Fn(UpdateIndexingStep) + Sync, FP: Fn(UpdateIndexingStep) + Sync,
@ -505,6 +507,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
IndexDocumentsConfig::default(), IndexDocumentsConfig::default(),
&progress_callback, &progress_callback,
&should_abort, &should_abort,
embedder_stats,
)?; )?;
indexing_builder.execute_raw(output)?; indexing_builder.execute_raw(output)?;
@ -1355,7 +1358,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
} }
} }
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()> pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA, embedder_stats: Option<Arc<EmbedderStats>>) -> Result<()>
where where
FP: Fn(UpdateIndexingStep) + Sync, FP: Fn(UpdateIndexingStep) + Sync,
FA: Fn() -> bool + Sync, FA: Fn() -> bool + Sync,
@ -1413,7 +1416,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
); );
if inner_settings_diff.any_reindexing_needed() { if inner_settings_diff.any_reindexing_needed() {
self.reindex(&progress_callback, &should_abort, inner_settings_diff)?; self.reindex(&progress_callback, &should_abort, inner_settings_diff, embedder_stats)?;
} }
Ok(()) Ok(())

View File

@ -1,3 +1,4 @@
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use arroy::Distance; use arroy::Distance;
@ -7,6 +8,7 @@ use super::{
hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache, hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache,
NewEmbedderError, NewEmbedderError,
}; };
use crate::progress::EmbedderStats;
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
#[derive(Debug)] #[derive(Debug)]
@ -81,6 +83,7 @@ impl Embedder {
"This is a sample text. It is meant to compare similarity.".into(), "This is a sample text. It is meant to compare similarity.".into(),
], ],
None, None,
None,
) )
.map_err(|error| NewEmbedderError::composite_test_embedding_failed(error, "search"))?; .map_err(|error| NewEmbedderError::composite_test_embedding_failed(error, "search"))?;
@ -92,6 +95,7 @@ impl Embedder {
"This is a sample text. It is meant to compare similarity.".into(), "This is a sample text. It is meant to compare similarity.".into(),
], ],
None, None,
None,
) )
.map_err(|error| { .map_err(|error| {
NewEmbedderError::composite_test_embedding_failed(error, "indexing") NewEmbedderError::composite_test_embedding_failed(error, "indexing")
@ -150,13 +154,14 @@ impl SubEmbedder {
&self, &self,
texts: Vec<String>, texts: Vec<String>,
deadline: Option<Instant>, deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> std::result::Result<Vec<Embedding>, EmbedError> { ) -> std::result::Result<Vec<Embedding>, EmbedError> {
match self { match self {
SubEmbedder::HuggingFace(embedder) => embedder.embed(texts), SubEmbedder::HuggingFace(embedder) => embedder.embed(texts),
SubEmbedder::OpenAi(embedder) => embedder.embed(&texts, deadline), SubEmbedder::OpenAi(embedder) => embedder.embed(&texts, deadline, embedder_stats),
SubEmbedder::Ollama(embedder) => embedder.embed(&texts, deadline), SubEmbedder::Ollama(embedder) => embedder.embed(&texts, deadline, embedder_stats),
SubEmbedder::UserProvided(embedder) => embedder.embed(&texts), SubEmbedder::UserProvided(embedder) => embedder.embed(&texts),
SubEmbedder::Rest(embedder) => embedder.embed(texts, deadline), SubEmbedder::Rest(embedder) => embedder.embed(texts, deadline, embedder_stats),
} }
} }
@ -164,18 +169,19 @@ impl SubEmbedder {
&self, &self,
text: &str, text: &str,
deadline: Option<Instant>, deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> std::result::Result<Embedding, EmbedError> { ) -> std::result::Result<Embedding, EmbedError> {
match self { match self {
SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text), SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
SubEmbedder::OpenAi(embedder) => { SubEmbedder::OpenAi(embedder) => {
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding) embedder.embed(&[text], deadline, embedder_stats)?.pop().ok_or_else(EmbedError::missing_embedding)
} }
SubEmbedder::Ollama(embedder) => { SubEmbedder::Ollama(embedder) => {
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding) embedder.embed(&[text], deadline, embedder_stats)?.pop().ok_or_else(EmbedError::missing_embedding)
} }
SubEmbedder::UserProvided(embedder) => embedder.embed_one(text), SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
SubEmbedder::Rest(embedder) => embedder SubEmbedder::Rest(embedder) => embedder
.embed_ref(&[text], deadline)? .embed_ref(&[text], deadline, embedder_stats)?
.pop() .pop()
.ok_or_else(EmbedError::missing_embedding), .ok_or_else(EmbedError::missing_embedding),
} }
@ -188,13 +194,14 @@ impl SubEmbedder {
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> { ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
match self { match self {
SubEmbedder::HuggingFace(embedder) => embedder.embed_index(text_chunks), SubEmbedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
SubEmbedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads), SubEmbedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
SubEmbedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads), SubEmbedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
SubEmbedder::UserProvided(embedder) => embedder.embed_index(text_chunks), SubEmbedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
SubEmbedder::Rest(embedder) => embedder.embed_index(text_chunks, threads), SubEmbedder::Rest(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
} }
} }
@ -203,13 +210,14 @@ impl SubEmbedder {
&self, &self,
texts: &[&str], texts: &[&str],
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> std::result::Result<Vec<Embedding>, EmbedError> { ) -> std::result::Result<Vec<Embedding>, EmbedError> {
match self { match self {
SubEmbedder::HuggingFace(embedder) => embedder.embed_index_ref(texts), SubEmbedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
SubEmbedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads), SubEmbedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
SubEmbedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads), SubEmbedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
SubEmbedder::UserProvided(embedder) => embedder.embed_index_ref(texts), SubEmbedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
SubEmbedder::Rest(embedder) => embedder.embed_index_ref(texts, threads), SubEmbedder::Rest(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
} }
} }

View File

@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use utoipa::ToSchema; use utoipa::ToSchema;
use self::error::{EmbedError, NewEmbedderError}; use self::error::{EmbedError, NewEmbedderError};
use crate::progress::Progress; use crate::progress::{EmbedderStats, Progress};
use crate::prompt::{Prompt, PromptData}; use crate::prompt::{Prompt, PromptData};
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
@ -720,17 +720,17 @@ impl Embedder {
let embedding = match self { let embedding = match self {
Embedder::HuggingFace(embedder) => embedder.embed_one(text), Embedder::HuggingFace(embedder) => embedder.embed_one(text),
Embedder::OpenAi(embedder) => { Embedder::OpenAi(embedder) => {
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding) embedder.embed(&[text], deadline, None)?.pop().ok_or_else(EmbedError::missing_embedding)
} }
Embedder::Ollama(embedder) => { Embedder::Ollama(embedder) => {
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding) embedder.embed(&[text], deadline, None)?.pop().ok_or_else(EmbedError::missing_embedding)
} }
Embedder::UserProvided(embedder) => embedder.embed_one(text), Embedder::UserProvided(embedder) => embedder.embed_one(text),
Embedder::Rest(embedder) => embedder Embedder::Rest(embedder) => embedder
.embed_ref(&[text], deadline)? .embed_ref(&[text], deadline, None)?
.pop() .pop()
.ok_or_else(EmbedError::missing_embedding), .ok_or_else(EmbedError::missing_embedding),
Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline), Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline, None),
}?; }?;
if let Some(cache) = self.cache() { if let Some(cache) = self.cache() {
@ -747,14 +747,15 @@ impl Embedder {
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> { ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed_index(text_chunks), Embedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads), Embedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
Embedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads), Embedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
Embedder::UserProvided(embedder) => embedder.embed_index(text_chunks), Embedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
Embedder::Rest(embedder) => embedder.embed_index(text_chunks, threads), Embedder::Rest(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
Embedder::Composite(embedder) => embedder.index.embed_index(text_chunks, threads), Embedder::Composite(embedder) => embedder.index.embed_index(text_chunks, threads, embedder_stats),
} }
} }
@ -763,14 +764,15 @@ impl Embedder {
&self, &self,
texts: &[&str], texts: &[&str],
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> std::result::Result<Vec<Embedding>, EmbedError> { ) -> std::result::Result<Vec<Embedding>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed_index_ref(texts), Embedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
Embedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads), Embedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
Embedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads), Embedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
Embedder::UserProvided(embedder) => embedder.embed_index_ref(texts), Embedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
Embedder::Rest(embedder) => embedder.embed_index_ref(texts, threads), Embedder::Rest(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
Embedder::Composite(embedder) => embedder.index.embed_index_ref(texts, threads), Embedder::Composite(embedder) => embedder.index.embed_index_ref(texts, threads, embedder_stats),
} }
} }

View File

@ -1,3 +1,4 @@
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
@ -7,6 +8,7 @@ use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErro
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM}; use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::progress::EmbedderStats;
use crate::vector::Embedding; use crate::vector::Embedding;
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
@ -104,8 +106,9 @@ impl Embedder {
&self, &self,
texts: &[S], texts: &[S],
deadline: Option<Instant>, deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>
) -> Result<Vec<Embedding>, EmbedError> { ) -> Result<Vec<Embedding>, EmbedError> {
match self.rest_embedder.embed_ref(texts, deadline) { match self.rest_embedder.embed_ref(texts, deadline, embedder_stats) {
Ok(embeddings) => Ok(embeddings), Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
Err(EmbedError::ollama_model_not_found(error)) Err(EmbedError::ollama_model_not_found(error))
@ -118,15 +121,16 @@ impl Embedder {
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Vec<Embedding>>, EmbedError> { ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
// This condition helps reduce the number of active rayon jobs // This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows. // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM { if threads.active_operations() >= REQUEST_PARALLELISM {
text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect() text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
} else { } else {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),
@ -139,13 +143,14 @@ impl Embedder {
&self, &self,
texts: &[&str], texts: &[&str],
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>
) -> Result<Vec<Vec<f32>>, EmbedError> { ) -> Result<Vec<Vec<f32>>, EmbedError> {
// This condition helps reduce the number of active rayon jobs // This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows. // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM { if threads.active_operations() >= REQUEST_PARALLELISM {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.chunks(self.prompt_count_in_chunk_hint()) .chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed(chunk, None)) .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;
@ -155,7 +160,7 @@ impl Embedder {
.install(move || { .install(move || {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.par_chunks(self.prompt_count_in_chunk_hint()) .par_chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed(chunk, None)) .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;

View File

@ -1,4 +1,5 @@
use std::fmt; use std::fmt;
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
@ -9,6 +10,7 @@ use super::error::{EmbedError, NewEmbedderError};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM}; use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::progress::EmbedderStats;
use crate::vector::error::EmbedErrorKind; use crate::vector::error::EmbedErrorKind;
use crate::vector::Embedding; use crate::vector::Embedding;
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
@ -215,8 +217,9 @@ impl Embedder {
&self, &self,
texts: &[S], texts: &[S],
deadline: Option<Instant>, deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Embedding>, EmbedError> { ) -> Result<Vec<Embedding>, EmbedError> {
match self.rest_embedder.embed_ref(texts, deadline) { match self.rest_embedder.embed_ref(texts, deadline, embedder_stats) {
Ok(embeddings) => Ok(embeddings), Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => { Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
@ -238,7 +241,7 @@ impl Embedder {
let encoded = self.tokenizer.encode_ordinary(text); let encoded = self.tokenizer.encode_ordinary(text);
let len = encoded.len(); let len = encoded.len();
if len < max_token_count { if len < max_token_count {
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?); all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline, None)?);
continue; continue;
} }
@ -255,15 +258,16 @@ impl Embedder {
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Vec<Embedding>>, EmbedError> { ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
// This condition helps reduce the number of active rayon jobs // This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows. // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM { if threads.active_operations() >= REQUEST_PARALLELISM {
text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect() text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
} else { } else {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),
@ -276,13 +280,14 @@ impl Embedder {
&self, &self,
texts: &[&str], texts: &[&str],
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Vec<f32>>, EmbedError> { ) -> Result<Vec<Vec<f32>>, EmbedError> {
// This condition helps reduce the number of active rayon jobs // This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows. // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM { if threads.active_operations() >= REQUEST_PARALLELISM {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.chunks(self.prompt_count_in_chunk_hint()) .chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed(chunk, None)) .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;
Ok(embeddings.into_iter().flatten().collect()) Ok(embeddings.into_iter().flatten().collect())
@ -291,7 +296,7 @@ impl Embedder {
.install(move || { .install(move || {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.par_chunks(self.prompt_count_in_chunk_hint()) .par_chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed(chunk, None)) .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;

View File

@ -1,4 +1,5 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use deserr::Deserr; use deserr::Deserr;
@ -14,6 +15,7 @@ use super::{
}; };
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
use crate::progress::EmbedderStats;
// retrying in case of failure // retrying in case of failure
pub struct Retry { pub struct Retry {
@ -168,19 +170,21 @@ impl Embedder {
&self, &self,
texts: Vec<String>, texts: Vec<String>,
deadline: Option<Instant>, deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Embedding>, EmbedError> { ) -> Result<Vec<Embedding>, EmbedError> {
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline) embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline, embedder_stats)
} }
pub fn embed_ref<S>( pub fn embed_ref<S>(
&self, &self,
texts: &[S], texts: &[S],
deadline: Option<Instant>, deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Embedding>, EmbedError> ) -> Result<Vec<Embedding>, EmbedError>
where where
S: AsRef<str> + Serialize, S: AsRef<str> + Serialize,
{ {
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline) embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline, embedder_stats)
} }
pub fn embed_tokens( pub fn embed_tokens(
@ -188,7 +192,7 @@ impl Embedder {
tokens: &[u32], tokens: &[u32],
deadline: Option<Instant>, deadline: Option<Instant>,
) -> Result<Embedding, EmbedError> { ) -> Result<Embedding, EmbedError> {
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?; let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline, None)?;
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
Ok(embeddings.pop().unwrap()) Ok(embeddings.pop().unwrap())
} }
@ -197,15 +201,16 @@ impl Embedder {
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Vec<Embedding>>, EmbedError> { ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
// This condition helps reduce the number of active rayon jobs // This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows. // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM { if threads.active_operations() >= REQUEST_PARALLELISM {
text_chunks.into_iter().map(move |chunk| self.embed(chunk, None)).collect() text_chunks.into_iter().map(move |chunk| self.embed(chunk, None, embedder_stats.clone())).collect()
} else { } else {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None, embedder_stats.clone())).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),
@ -218,13 +223,14 @@ impl Embedder {
&self, &self,
texts: &[&str], texts: &[&str],
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>
) -> Result<Vec<Embedding>, EmbedError> { ) -> Result<Vec<Embedding>, EmbedError> {
// This condition helps reduce the number of active rayon jobs // This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows. // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM { if threads.active_operations() >= REQUEST_PARALLELISM {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.chunks(self.prompt_count_in_chunk_hint()) .chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed_ref(chunk, None)) .map(move |chunk| self.embed_ref(chunk, None, embedder_stats.clone()))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;
@ -234,7 +240,7 @@ impl Embedder {
.install(move || { .install(move || {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.par_chunks(self.prompt_count_in_chunk_hint()) .par_chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed_ref(chunk, None)) .map(move |chunk| self.embed_ref(chunk, None, embedder_stats.clone()))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;
@ -272,7 +278,7 @@ impl Embedder {
} }
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> { fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
let v = embed(data, ["test"].as_slice(), 1, None, None) let v = embed(data, ["test"].as_slice(), 1, None, None, None)
.map_err(NewEmbedderError::could_not_determine_dimension)?; .map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
Ok(v.first().unwrap().len()) Ok(v.first().unwrap().len())
@ -284,6 +290,7 @@ fn embed<S>(
expected_count: usize, expected_count: usize,
expected_dimension: Option<usize>, expected_dimension: Option<usize>,
deadline: Option<Instant>, deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Embedding>, EmbedError> ) -> Result<Vec<Embedding>, EmbedError>
where where
S: Serialize, S: Serialize,
@ -302,6 +309,9 @@ where
let body = data.request.inject_texts(inputs); let body = data.request.inject_texts(inputs);
for attempt in 0..10 { for attempt in 0..10 {
if let Some(embedder_stats) = &embedder_stats {
embedder_stats.as_ref().total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let response = request.clone().send_json(&body); let response = request.clone().send_json(&body);
let result = check_response(response, data.configuration_source).and_then(|response| { let result = check_response(response, data.configuration_source).and_then(|response| {
response_to_embedding(response, data, expected_count, expected_dimension) response_to_embedding(response, data, expected_count, expected_dimension)
@ -311,6 +321,12 @@ where
Ok(response) => return Ok(response), Ok(response) => return Ok(response),
Err(retry) => { Err(retry) => {
tracing::warn!("Failed: {}", retry.error); tracing::warn!("Failed: {}", retry.error);
if let Some(embedder_stats) = &embedder_stats {
if let Ok(mut errors) = embedder_stats.errors.write() {
errors.0 = Some(retry.error.to_string());
errors.1 += 1;
}
}
if let Some(deadline) = deadline { if let Some(deadline) = deadline {
let now = std::time::Instant::now(); let now = std::time::Instant::now();
if now > deadline { if now > deadline {
@ -336,12 +352,26 @@ where
std::thread::sleep(retry_duration); std::thread::sleep(retry_duration);
} }
if let Some(embedder_stats) = &embedder_stats {
embedder_stats.as_ref().total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let response = request.send_json(&body); let response = request.send_json(&body);
let result = check_response(response, data.configuration_source); let result = check_response(response, data.configuration_source).and_then(|response| {
result.map_err(Retry::into_error).and_then(|response| {
response_to_embedding(response, data, expected_count, expected_dimension) response_to_embedding(response, data, expected_count, expected_dimension)
.map_err(Retry::into_error) });
})
match result {
Ok(response) => Ok(response),
Err(retry) => {
if let Some(embedder_stats) = &embedder_stats {
if let Ok(mut errors) = embedder_stats.errors.write() {
errors.0 = Some(retry.error.to_string());
errors.1 += 1;
}
}
Err(retry.into_error())
}
}
} }
fn check_response( fn check_response(

View File

@ -19,7 +19,7 @@ macro_rules! test_distinct {
let config = milli::update::IndexerConfig::default(); let config = milli::update::IndexerConfig::default();
let mut builder = Settings::new(&mut wtxn, &index, &config); let mut builder = Settings::new(&mut wtxn, &index, &config);
builder.set_distinct_field(S(stringify!($distinct))); builder.set_distinct_field(S(stringify!($distinct)));
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
let rtxn = index.read_txn().unwrap(); let rtxn = index.read_txn().unwrap();

View File

@ -25,7 +25,7 @@ fn test_facet_distribution_with_no_facet_values() {
FilterableAttributesRule::Field(S("genres")), FilterableAttributesRule::Field(S("genres")),
FilterableAttributesRule::Field(S("tags")), FilterableAttributesRule::Field(S("tags")),
]); ]);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
// index documents // index documents

View File

@ -63,7 +63,7 @@ pub fn setup_search_index_with_criteria(criteria: &[Criterion]) -> Index {
S("america") => vec![S("the united states")], S("america") => vec![S("the united states")],
}); });
builder.set_searchable_fields(vec![S("title"), S("description")]); builder.set_searchable_fields(vec![S("title"), S("description")]);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
// index documents // index documents

View File

@ -10,7 +10,7 @@ fn set_stop_words(index: &Index, stop_words: &[&str]) {
let mut builder = Settings::new(&mut wtxn, index, &config); let mut builder = Settings::new(&mut wtxn, index, &config);
let stop_words = stop_words.iter().map(|s| s.to_string()).collect(); let stop_words = stop_words.iter().map(|s| s.to_string()).collect();
builder.set_stop_words(stop_words); builder.set_stop_words(stop_words);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
} }

View File

@ -236,7 +236,7 @@ fn criteria_mixup() {
let mut wtxn = index.write_txn().unwrap(); let mut wtxn = index.write_txn().unwrap();
let mut builder = Settings::new(&mut wtxn, &index, &config); let mut builder = Settings::new(&mut wtxn, &index, &config);
builder.set_criteria(criteria.clone()); builder.set_criteria(criteria.clone());
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
let rtxn = index.read_txn().unwrap(); let rtxn = index.read_txn().unwrap();
@ -276,7 +276,7 @@ fn criteria_ascdesc() {
S("name"), S("name"),
S("age"), S("age"),
}); });
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
let mut wtxn = index.write_txn().unwrap(); let mut wtxn = index.write_txn().unwrap();
@ -358,7 +358,7 @@ fn criteria_ascdesc() {
let mut wtxn = index.write_txn().unwrap(); let mut wtxn = index.write_txn().unwrap();
let mut builder = Settings::new(&mut wtxn, &index, &config); let mut builder = Settings::new(&mut wtxn, &index, &config);
builder.set_criteria(vec![criterion.clone()]); builder.set_criteria(vec![criterion.clone()]);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
wtxn.commit().unwrap(); wtxn.commit().unwrap();
let rtxn = index.read_txn().unwrap(); let rtxn = index.read_txn().unwrap();

View File

@ -46,7 +46,7 @@ fn test_typo_tolerance_one_typo() {
let config = IndexerConfig::default(); let config = IndexerConfig::default();
let mut builder = Settings::new(&mut txn, &index, &config); let mut builder = Settings::new(&mut txn, &index, &config);
builder.set_min_word_len_one_typo(4); builder.set_min_word_len_one_typo(4);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
// typo is now supported for 4 letters words // typo is now supported for 4 letters words
let mut search = Search::new(&txn, &index); let mut search = Search::new(&txn, &index);
@ -92,7 +92,7 @@ fn test_typo_tolerance_two_typo() {
let config = IndexerConfig::default(); let config = IndexerConfig::default();
let mut builder = Settings::new(&mut txn, &index, &config); let mut builder = Settings::new(&mut txn, &index, &config);
builder.set_min_word_len_two_typos(7); builder.set_min_word_len_two_typos(7);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
// typo is now supported for 4 letters words // typo is now supported for 4 letters words
let mut search = Search::new(&txn, &index); let mut search = Search::new(&txn, &index);
@ -180,7 +180,7 @@ fn test_typo_disabled_on_word() {
// `zealand` doesn't allow typos anymore // `zealand` doesn't allow typos anymore
exact_words.insert("zealand".to_string()); exact_words.insert("zealand".to_string());
builder.set_exact_words(exact_words); builder.set_exact_words(exact_words);
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
let mut search = Search::new(&txn, &index); let mut search = Search::new(&txn, &index);
search.query("zealand"); search.query("zealand");
@ -218,7 +218,7 @@ fn test_disable_typo_on_attribute() {
let mut builder = Settings::new(&mut txn, &index, &config); let mut builder = Settings::new(&mut txn, &index, &config);
// disable typos on `description` // disable typos on `description`
builder.set_exact_attributes(vec!["description".to_string()].into_iter().collect()); builder.set_exact_attributes(vec!["description".to_string()].into_iter().collect());
builder.execute(|_| (), || false).unwrap(); builder.execute(|_| (), || false, None).unwrap();
let mut search = Search::new(&txn, &index); let mut search = Search::new(&txn, &index);
search.query("antebelum"); search.query("antebelum");