Mubelotix 2025-06-24 12:20:22 +02:00
parent 4a179fb3c0
commit d7721fe607
No known key found for this signature in database
GPG key ID: 89F391DBCC8CE7F0
18 changed files with 124 additions and 63 deletions

View file

@@ -25,7 +25,7 @@ pub struct Progress {
#[derive(Default)]
pub struct EmbedderStats {
pub errors: Arc<RwLock<(Option<String>, u32)>>,
pub total_count: AtomicUsize
pub total_count: AtomicUsize,
}
impl std::fmt::Debug for EmbedderStats {
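
For context on the struct above: total_count is an atomic counter of embedding requests and errors is a shared slot pairing a message with a count. A minimal, self-contained sketch (not from this commit) of how a caller might record into such a struct; reading the (Option<String>, u32) tuple as "(last error message, error count)" is an assumption:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};

#[derive(Default)]
pub struct EmbedderStats {
    pub errors: Arc<RwLock<(Option<String>, u32)>>, // assumed: (last error message, error count)
    pub total_count: AtomicUsize,                   // every embedding request, failed or not
}

pub fn record_request(stats: &EmbedderStats, outcome: Result<(), String>) {
    // Count the request regardless of its outcome.
    stats.total_count.fetch_add(1, Ordering::Relaxed);
    if let Err(message) = outcome {
        // Remember the most recent error and bump the error counter.
        let mut errors = stats.errors.write().unwrap();
        errors.0 = Some(message);
        errors.1 += 1;
    }
}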

View file

@@ -95,7 +95,7 @@ pub fn setup_search_index_with_criteria(criteria: &[Criterion]) -> Index {
embedders,
&|| false,
&Progress::default(),
Default::default(),
Default::default(),
)
.unwrap();

View file

@@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use std::sync::atomic::AtomicBool;
use std::sync::OnceLock;
use std::sync::Arc;
use std::sync::OnceLock;
use bumpalo::Bump;
use roaring::RoaringBitmap;

View file

@@ -1,7 +1,7 @@
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::{Once, RwLock};
use std::thread::{self, Builder};
use std::sync::Arc;
use big_s::S;
use document_changes::{DocumentChanges, IndexingContext};

View file

@@ -1358,7 +1358,12 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
}
}
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA, embedder_stats: Arc<EmbedderStats>) -> Result<()>
pub fn execute<FP, FA>(
mut self,
progress_callback: FP,
should_abort: FA,
embedder_stats: Arc<EmbedderStats>,
) -> Result<()>
where
FP: Fn(UpdateIndexingStep) + Sync,
FA: Fn() -> bool + Sync,
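
The execute signature above now takes the stats handle alongside the progress and abort callbacks. A stand-in sketch (not milli's real types) of the call shape and of reading the counter afterwards:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Stand-ins only; the real Settings, UpdateIndexingStep and EmbedderStats live in milli.
#[derive(Default)]
struct EmbedderStats {
    total_count: AtomicUsize,
}
struct UpdateIndexingStep;
struct Settings;

impl Settings {
    fn execute<FP, FA>(self, _progress: FP, _should_abort: FA, stats: Arc<EmbedderStats>) -> Result<(), String>
    where
        FP: Fn(UpdateIndexingStep) + Sync,
        FA: Fn() -> bool + Sync,
    {
        // Pretend one embedder request was issued while applying the settings.
        stats.total_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }
}

fn main() {
    let stats = Arc::new(EmbedderStats::default());
    Settings
        .execute(|_step: UpdateIndexingStep| (), || false, stats.clone())
        .unwrap();
    println!("embedder requests: {}", stats.total_count.load(Ordering::Relaxed));
}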

View file

@@ -173,12 +173,14 @@ impl SubEmbedder {
) -> std::result::Result<Embedding, EmbedError> {
match self {
SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
SubEmbedder::OpenAi(embedder) => {
embedder.embed(&[text], deadline, embedder_stats)?.pop().ok_or_else(EmbedError::missing_embedding)
}
SubEmbedder::Ollama(embedder) => {
embedder.embed(&[text], deadline, embedder_stats)?.pop().ok_or_else(EmbedError::missing_embedding)
}
SubEmbedder::OpenAi(embedder) => embedder
.embed(&[text], deadline, embedder_stats)?
.pop()
.ok_or_else(EmbedError::missing_embedding),
SubEmbedder::Ollama(embedder) => embedder
.embed(&[text], deadline, embedder_stats)?
.pop()
.ok_or_else(EmbedError::missing_embedding),
SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
SubEmbedder::Rest(embedder) => embedder
.embed_ref(&[text], deadline, embedder_stats)?
@@ -198,10 +200,16 @@ impl SubEmbedder {
) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
match self {
SubEmbedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
SubEmbedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
SubEmbedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
SubEmbedder::OpenAi(embedder) => {
embedder.embed_index(text_chunks, threads, embedder_stats)
}
SubEmbedder::Ollama(embedder) => {
embedder.embed_index(text_chunks, threads, embedder_stats)
}
SubEmbedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
SubEmbedder::Rest(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
SubEmbedder::Rest(embedder) => {
embedder.embed_index(text_chunks, threads, embedder_stats)
}
}
}
@@ -214,8 +222,12 @@ impl SubEmbedder {
) -> std::result::Result<Vec<Embedding>, EmbedError> {
match self {
SubEmbedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
SubEmbedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
SubEmbedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
SubEmbedder::OpenAi(embedder) => {
embedder.embed_index_ref(texts, threads, embedder_stats)
}
SubEmbedder::Ollama(embedder) => {
embedder.embed_index_ref(texts, threads, embedder_stats)
}
SubEmbedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
SubEmbedder::Rest(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
}
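
Several arms above (OpenAi, Ollama, Rest) share one shape: a single text is embedded by sending a one-element batch, popping the only result, and mapping an empty response to a missing-embedding error. A self-contained sketch of that shape with stand-in types, not the milli ones:

type Embedding = Vec<f32>;

#[derive(Debug)]
enum EmbedError {
    MissingEmbedding,
}

// Stand-in for a backend that only exposes a batch endpoint.
trait BatchEmbedder {
    fn embed(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError>;
}

// Embed one text through the batch API, as the match arms above do.
fn embed_one(embedder: &dyn BatchEmbedder, text: &str) -> Result<Embedding, EmbedError> {
    embedder
        .embed(&[text])?                     // one-element batch
        .pop()                               // take the single embedding back
        .ok_or(EmbedError::MissingEmbedding) // empty response -> explicit error
}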

View file

@@ -719,12 +719,14 @@ impl Embedder {
}
let embedding = match self {
Embedder::HuggingFace(embedder) => embedder.embed_one(text),
Embedder::OpenAi(embedder) => {
embedder.embed(&[text], deadline, None)?.pop().ok_or_else(EmbedError::missing_embedding)
}
Embedder::Ollama(embedder) => {
embedder.embed(&[text], deadline, None)?.pop().ok_or_else(EmbedError::missing_embedding)
}
Embedder::OpenAi(embedder) => embedder
.embed(&[text], deadline, None)?
.pop()
.ok_or_else(EmbedError::missing_embedding),
Embedder::Ollama(embedder) => embedder
.embed(&[text], deadline, None)?
.pop()
.ok_or_else(EmbedError::missing_embedding),
Embedder::UserProvided(embedder) => embedder.embed_one(text),
Embedder::Rest(embedder) => embedder
.embed_ref(&[text], deadline, None)?
@@ -751,11 +753,17 @@ impl Embedder {
) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
Embedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
Embedder::OpenAi(embedder) => {
embedder.embed_index(text_chunks, threads, embedder_stats)
}
Embedder::Ollama(embedder) => {
embedder.embed_index(text_chunks, threads, embedder_stats)
}
Embedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
Embedder::Rest(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
Embedder::Composite(embedder) => embedder.index.embed_index(text_chunks, threads, embedder_stats),
Embedder::Composite(embedder) => {
embedder.index.embed_index(text_chunks, threads, embedder_stats)
}
}
}
@@ -772,7 +780,9 @@ impl Embedder {
Embedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
Embedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
Embedder::Rest(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
Embedder::Composite(embedder) => embedder.index.embed_index_ref(texts, threads, embedder_stats),
Embedder::Composite(embedder) => {
embedder.index.embed_index_ref(texts, threads, embedder_stats)
}
}
}
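
The only arm in this file with no SubEmbedder counterpart is Composite, which forwards the indexing-time calls to its index sub-embedder. A stand-in sketch of that delegation; the search field is an assumption, since the diff only shows embedder.index being used:

type Embedding = Vec<f32>;

// Stand-in for the per-backend embedders dispatched on above.
struct SubEmbedder;

impl SubEmbedder {
    fn embed_index(&self, text_chunks: Vec<Vec<String>>) -> Vec<Vec<Embedding>> {
        // Placeholder embeddings, one per text.
        text_chunks
            .into_iter()
            .map(|chunk| chunk.into_iter().map(|_| vec![0.0_f32; 3]).collect())
            .collect()
    }
}

// A composite pairs a search-time embedder with an indexing-time one.
#[allow(dead_code)]
struct CompositeEmbedder {
    search: SubEmbedder, // assumed counterpart, not visible in this diff
    index: SubEmbedder,  // the arms above always delegate to this one
}

impl CompositeEmbedder {
    fn embed_index(&self, text_chunks: Vec<Vec<String>>) -> Vec<Vec<Embedding>> {
        // Same shape as `embedder.index.embed_index(text_chunks, threads, embedder_stats)`.
        self.index.embed_index(text_chunks)
    }
}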

View file

@@ -106,7 +106,7 @@ impl Embedder {
&self,
texts: &[S],
deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Embedding>, EmbedError> {
match self.rest_embedder.embed_ref(texts, deadline, embedder_stats) {
Ok(embeddings) => Ok(embeddings),
@@ -126,11 +126,17 @@ impl Embedder {
// This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM {
text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
text_chunks
.into_iter()
.map(move |chunk| self.embed(&chunk, None, embedder_stats.clone()))
.collect()
} else {
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
text_chunks
.into_par_iter()
.map(move |chunk| self.embed(&chunk, None, embedder_stats.clone()))
.collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
@@ -143,7 +149,7 @@ impl Embedder {
&self,
texts: &[&str],
threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Vec<f32>>, EmbedError> {
// This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
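
Both this file and the OpenAI one reformat the same dispatch: when the pool already has REQUEST_PARALLELISM active operations, chunks are embedded serially on the caller's thread; otherwise the work is installed into the pool and run with a parallel iterator. A sketch of that pattern on plain rayon types; ThreadPoolNoAbort, active_operations() and the panic-in-thread-pool error handling are milli-specific and not reproduced here:

use rayon::prelude::*;

const REQUEST_PARALLELISM: usize = 10; // assumed value; the real constant lives in milli

type Embedding = Vec<f32>;

// Placeholder for one embedding request covering a chunk of texts.
fn embed_chunk(chunk: &[String]) -> Vec<Embedding> {
    chunk.iter().map(|_| vec![0.0_f32; 3]).collect()
}

fn embed_chunks(
    pool: &rayon::ThreadPool,
    active_operations: usize, // stand-in for ThreadPoolNoAbort::active_operations()
    text_chunks: Vec<Vec<String>>,
) -> Vec<Vec<Embedding>> {
    if active_operations >= REQUEST_PARALLELISM {
        // Pool already saturated: stay on the caller's thread to avoid piling up
        // more rayon jobs (and the LMDB read transactions they hold).
        text_chunks.iter().map(|chunk| embed_chunk(chunk)).collect()
    } else {
        // Room left: run the chunks in parallel inside the pool.
        pool.install(|| text_chunks.par_iter().map(|chunk| embed_chunk(chunk)).collect())
    }
}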

View file

@@ -241,7 +241,11 @@ impl Embedder {
let encoded = self.tokenizer.encode_ordinary(text);
let len = encoded.len();
if len < max_token_count {
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline, None)?);
all_embeddings.append(&mut self.rest_embedder.embed_ref(
&[text],
deadline,
None,
)?);
continue;
}
@@ -263,11 +267,17 @@ impl Embedder {
// This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM {
text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
text_chunks
.into_iter()
.map(move |chunk| self.embed(&chunk, None, embedder_stats.clone()))
.collect()
} else {
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
text_chunks
.into_par_iter()
.map(move |chunk| self.embed(&chunk, None, embedder_stats.clone()))
.collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
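
The first hunk in this file gates the direct path on tokenized length: when the token count stays under max_token_count the text is embedded in a single request and the loop moves on; otherwise execution falls through to the over-limit handling that is not shown here. A rough sketch of that gate, assuming a tiktoken-style tokenizer as suggested by encode_ordinary; cl100k_base is an assumption, not necessarily the encoding milli uses:

use tiktoken_rs::{cl100k_base, CoreBPE};

// True when `text` fits in a single embedding request; callers would otherwise
// fall back to a truncated or chunked path, as the code below the gate does.
fn fits_in_one_request(tokenizer: &CoreBPE, text: &str, max_token_count: usize) -> bool {
    let encoded = tokenizer.encode_ordinary(text);
    encoded.len() < max_token_count
}

fn main() {
    let tokenizer = cl100k_base().expect("load encoding");
    println!("{}", fits_in_one_request(&tokenizer, "hello embeddings", 8192));
}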

View file

@@ -14,8 +14,8 @@ use super::{
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM,
};
use crate::error::FaultSource;
use crate::ThreadPoolNoAbort;
use crate::progress::EmbedderStats;
use crate::ThreadPoolNoAbort;
// retrying in case of failure
pub struct Retry {
@@ -172,7 +172,14 @@ impl Embedder {
deadline: Option<Instant>,
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Embedding>, EmbedError> {
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline, embedder_stats)
embed(
&self.data,
texts.as_slice(),
texts.len(),
Some(self.dimensions),
deadline,
embedder_stats,
)
}
pub fn embed_ref<S>(
@@ -206,11 +213,17 @@ impl Embedder {
// This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
if threads.active_operations() >= REQUEST_PARALLELISM {
text_chunks.into_iter().map(move |chunk| self.embed(chunk, None, embedder_stats.clone())).collect()
text_chunks
.into_iter()
.map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
.collect()
} else {
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None, embedder_stats.clone())).collect()
text_chunks
.into_par_iter()
.map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
.collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
@@ -223,7 +236,7 @@ impl Embedder {
&self,
texts: &[&str],
threads: &ThreadPoolNoAbort,
embedder_stats: Option<Arc<EmbedderStats>>
embedder_stats: Option<Arc<EmbedderStats>>,
) -> Result<Vec<Embedding>, EmbedError> {
// This condition helps reduce the number of active rayon jobs
// so that we avoid consuming all the LMDB rtxns and avoid stack overflows.