Add embedder stats in batches

Mubelotix 2025-06-20 12:42:22 +02:00
parent fc6cc80705
commit 4cadc8113b
No known key found for this signature in database
GPG key ID: 89F391DBCC8CE7F0
26 changed files with 188 additions and 73 deletions
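The changes below thread an optional `Arc<EmbedderStats>` through the batch entry points of every embedder, so that indexing can report how many embedding requests were sent and what the last failure looked like. The struct itself lives in `crate::progress` and its definition is not part of this excerpt, but the accesses in the diffs (`total_requests.fetch_add(1, Ordering::Relaxed)`, and `errors.write()` yielding a `(Option<String>, u32)` pair) imply a shape roughly like the following sketch, inferred here rather than copied from the source:

use std::sync::atomic::AtomicUsize;
use std::sync::RwLock;

#[derive(Default)]
pub struct EmbedderStats {
    // Bumped once per outgoing embedding request, retries included.
    pub total_requests: AtomicUsize,
    // Last error message and running error count; a lock is needed because
    // a String cannot be updated atomically alongside the counter.
    pub errors: RwLock<(Option<String>, u32)>,
}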

View file

@@ -1,3 +1,4 @@
+use std::sync::Arc;
 use std::time::Instant;
 
 use arroy::Distance;
@@ -7,6 +8,7 @@ use super::{
     hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache,
     NewEmbedderError,
 };
+use crate::progress::EmbedderStats;
 use crate::ThreadPoolNoAbort;
 
 #[derive(Debug)]
@@ -81,6 +83,7 @@ impl Embedder {
                 "This is a sample text. It is meant to compare similarity.".into(),
             ],
             None,
+            None,
         )
         .map_err(|error| NewEmbedderError::composite_test_embedding_failed(error, "search"))?;
@@ -92,6 +95,7 @@ impl Embedder {
                 "This is a sample text. It is meant to compare similarity.".into(),
             ],
             None,
+            None,
         )
         .map_err(|error| {
             NewEmbedderError::composite_test_embedding_failed(error, "indexing")
@@ -150,13 +154,14 @@ impl SubEmbedder {
         &self,
         texts: Vec<String>,
         deadline: Option<Instant>,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> std::result::Result<Vec<Embedding>, EmbedError> {
         match self {
             SubEmbedder::HuggingFace(embedder) => embedder.embed(texts),
-            SubEmbedder::OpenAi(embedder) => embedder.embed(&texts, deadline),
-            SubEmbedder::Ollama(embedder) => embedder.embed(&texts, deadline),
+            SubEmbedder::OpenAi(embedder) => embedder.embed(&texts, deadline, embedder_stats),
+            SubEmbedder::Ollama(embedder) => embedder.embed(&texts, deadline, embedder_stats),
             SubEmbedder::UserProvided(embedder) => embedder.embed(&texts),
-            SubEmbedder::Rest(embedder) => embedder.embed(texts, deadline),
+            SubEmbedder::Rest(embedder) => embedder.embed(texts, deadline, embedder_stats),
         }
     }
 
@@ -164,18 +169,19 @@ impl SubEmbedder {
         &self,
         text: &str,
         deadline: Option<Instant>,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> std::result::Result<Embedding, EmbedError> {
         match self {
             SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
             SubEmbedder::OpenAi(embedder) => {
-                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+                embedder.embed(&[text], deadline, embedder_stats)?.pop().ok_or_else(EmbedError::missing_embedding)
             }
             SubEmbedder::Ollama(embedder) => {
-                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+                embedder.embed(&[text], deadline, embedder_stats)?.pop().ok_or_else(EmbedError::missing_embedding)
             }
             SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
             SubEmbedder::Rest(embedder) => embedder
-                .embed_ref(&[text], deadline)?
+                .embed_ref(&[text], deadline, embedder_stats)?
                 .pop()
                 .ok_or_else(EmbedError::missing_embedding),
         }
@@ -188,13 +194,14 @@ impl SubEmbedder {
         &self,
         text_chunks: Vec<Vec<String>>,
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
    ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
         match self {
             SubEmbedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
-            SubEmbedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads),
-            SubEmbedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads),
+            SubEmbedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
+            SubEmbedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
             SubEmbedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
-            SubEmbedder::Rest(embedder) => embedder.embed_index(text_chunks, threads),
+            SubEmbedder::Rest(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
         }
     }
@@ -203,13 +210,14 @@ impl SubEmbedder {
         &self,
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> std::result::Result<Vec<Embedding>, EmbedError> {
         match self {
             SubEmbedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
-            SubEmbedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads),
-            SubEmbedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads),
+            SubEmbedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
+            SubEmbedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
             SubEmbedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
-            SubEmbedder::Rest(embedder) => embedder.embed_index_ref(texts, threads),
+            SubEmbedder::Rest(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
         }
     }

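The composite embedder above only forwards the handle: the remote sub-embedders (OpenAI, Ollama, REST) accept it, while the local HuggingFace and UserProvided paths keep their old signatures. Carrying the stats as `Option<Arc<EmbedderStats>>` lets every fan-out site hand each job its own handle for the price of a refcount bump. A minimal, self-contained demonstration of that sharing pattern, using the stand-in struct sketched above rather than milli's own types:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::thread;

#[derive(Default)]
struct EmbedderStats {
    total_requests: AtomicUsize,
    errors: RwLock<(Option<String>, u32)>,
}

fn main() {
    let stats = Arc::new(EmbedderStats::default());
    let workers: Vec<_> = (0..4)
        .map(|i| {
            // Same move as `embedder_stats.clone()` in the diffs: each
            // parallel job gets its own handle to the one shared struct.
            let stats = Arc::clone(&stats);
            thread::spawn(move || {
                stats.total_requests.fetch_add(1, Ordering::Relaxed);
                if i == 3 {
                    let mut errors = stats.errors.write().unwrap();
                    errors.0 = Some("simulated 429 Too Many Requests".into());
                    errors.1 += 1;
                }
            })
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }
    assert_eq!(stats.total_requests.load(Ordering::Relaxed), 4);
    assert_eq!(stats.errors.read().unwrap().1, 1);
}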
View file

@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
 use utoipa::ToSchema;
 use self::error::{EmbedError, NewEmbedderError};
-use crate::progress::Progress;
+use crate::progress::{EmbedderStats, Progress};
 use crate::prompt::{Prompt, PromptData};
 use crate::ThreadPoolNoAbort;
@@ -720,17 +720,17 @@ impl Embedder {
         let embedding = match self {
             Embedder::HuggingFace(embedder) => embedder.embed_one(text),
             Embedder::OpenAi(embedder) => {
-                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+                embedder.embed(&[text], deadline, None)?.pop().ok_or_else(EmbedError::missing_embedding)
             }
             Embedder::Ollama(embedder) => {
-                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+                embedder.embed(&[text], deadline, None)?.pop().ok_or_else(EmbedError::missing_embedding)
             }
             Embedder::UserProvided(embedder) => embedder.embed_one(text),
             Embedder::Rest(embedder) => embedder
-                .embed_ref(&[text], deadline)?
+                .embed_ref(&[text], deadline, None)?
                 .pop()
                 .ok_or_else(EmbedError::missing_embedding),
-            Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline),
+            Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline, None),
         }?;
 
         if let Some(cache) = self.cache() {
@@ -747,14 +747,15 @@ impl Embedder {
         &self,
         text_chunks: Vec<Vec<String>>,
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
-            Embedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads),
-            Embedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads),
+            Embedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
+            Embedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
             Embedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
-            Embedder::Rest(embedder) => embedder.embed_index(text_chunks, threads),
-            Embedder::Composite(embedder) => embedder.index.embed_index(text_chunks, threads),
+            Embedder::Rest(embedder) => embedder.embed_index(text_chunks, threads, embedder_stats),
+            Embedder::Composite(embedder) => embedder.index.embed_index(text_chunks, threads, embedder_stats),
         }
     }
@@ -763,14 +764,15 @@ impl Embedder {
         &self,
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> std::result::Result<Vec<Embedding>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
-            Embedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads),
-            Embedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads),
+            Embedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
+            Embedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
             Embedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
-            Embedder::Rest(embedder) => embedder.embed_index_ref(texts, threads),
-            Embedder::Composite(embedder) => embedder.index.embed_index_ref(texts, threads),
+            Embedder::Rest(embedder) => embedder.embed_index_ref(texts, threads, embedder_stats),
+            Embedder::Composite(embedder) => embedder.index.embed_index_ref(texts, threads, embedder_stats),
         }
     }

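Note that in the single-text search path of the `-720` hunk the OpenAI, Ollama, REST, and Composite arms all pass `None`: only the batch `embed_index`/`embed_index_ref` paths used during indexing carry the statistics. A caller that kept its `Arc` can summarize a run afterwards; a small sketch against the stand-in struct above (the helper name is hypothetical):

use std::sync::atomic::Ordering;

fn report(stats: &EmbedderStats) -> String {
    let total = stats.total_requests.load(Ordering::Relaxed);
    let errors = stats.errors.read().expect("stats lock poisoned");
    format!("{total} requests sent, {} failed, last error: {:?}", errors.1, errors.0)
}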
View file

@@ -1,3 +1,4 @@
+use std::sync::Arc;
 use std::time::Instant;
 
 use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
@@ -7,6 +8,7 @@ use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErro
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
 use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
+use crate::progress::EmbedderStats;
 use crate::vector::Embedding;
 use crate::ThreadPoolNoAbort;
@@ -104,8 +106,9 @@ impl Embedder {
         &self,
         texts: &[S],
         deadline: Option<Instant>,
+        embedder_stats: Option<Arc<EmbedderStats>>
     ) -> Result<Vec<Embedding>, EmbedError> {
-        match self.rest_embedder.embed_ref(texts, deadline) {
+        match self.rest_embedder.embed_ref(texts, deadline, embedder_stats) {
             Ok(embeddings) => Ok(embeddings),
             Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
                 Err(EmbedError::ollama_model_not_found(error))
@@ -118,15 +121,16 @@ impl Embedder {
         &self,
         text_chunks: Vec<Vec<String>>,
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
         // This condition helps reduce the number of active rayon jobs
         // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
         if threads.active_operations() >= REQUEST_PARALLELISM {
-            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect()
+            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
         } else {
             threads
                 .install(move || {
-                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect()
+                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
                 })
                 .map_err(|error| EmbedError {
                     kind: EmbedErrorKind::PanicInThreadPool(error),
@@ -139,13 +143,14 @@ impl Embedder {
         &self,
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>
     ) -> Result<Vec<Vec<f32>>, EmbedError> {
         // This condition helps reduce the number of active rayon jobs
         // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
         if threads.active_operations() >= REQUEST_PARALLELISM {
             let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                 .chunks(self.prompt_count_in_chunk_hint())
-                .map(move |chunk| self.embed(chunk, None))
+                .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
                 .collect();
 
             let embeddings = embeddings?;
@@ -155,7 +160,7 @@ impl Embedder {
                 .install(move || {
                     let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                         .par_chunks(self.prompt_count_in_chunk_hint())
-                        .map(move |chunk| self.embed(chunk, None))
+                        .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
                         .collect();
 
                     let embeddings = embeddings?;

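The comment repeated in these hunks explains the dispatch around the thread pool: once `REQUEST_PARALLELISM` operations are already active, further chunks are embedded serially on the calling thread instead of queueing more rayon jobs, which keeps the number of LMDB read transactions and the stack depth bounded. Reduced to its skeleton, with a plain function standing in for `self.embed` and rayon's global pool for `ThreadPoolNoAbort` (the constant keeps milli's name, but its value here is assumed):

use rayon::iter::{IntoParallelIterator, ParallelIterator};

const REQUEST_PARALLELISM: usize = 10; // value assumed for the sketch

fn embed_chunk(chunk: &[String]) -> Vec<usize> {
    // Stand-in for a real embedding request.
    chunk.iter().map(|text| text.len()).collect()
}

fn embed_index(text_chunks: Vec<Vec<String>>, active_operations: usize) -> Vec<Vec<usize>> {
    if active_operations >= REQUEST_PARALLELISM {
        // Pool saturated: stay on the current thread.
        text_chunks.iter().map(|chunk| embed_chunk(chunk)).collect()
    } else {
        // Room left: fan the chunks out across the pool.
        text_chunks.into_par_iter().map(|chunk| embed_chunk(&chunk)).collect()
    }
}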
View file

@@ -1,4 +1,5 @@
 use std::fmt;
+use std::sync::Arc;
 use std::time::Instant;
 
 use ordered_float::OrderedFloat;
@@ -9,6 +10,7 @@ use super::error::{EmbedError, NewEmbedderError};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
 use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
+use crate::progress::EmbedderStats;
 use crate::vector::error::EmbedErrorKind;
 use crate::vector::Embedding;
 use crate::ThreadPoolNoAbort;
@@ -215,8 +217,9 @@ impl Embedder {
         &self,
         texts: &[S],
         deadline: Option<Instant>,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Embedding>, EmbedError> {
-        match self.rest_embedder.embed_ref(texts, deadline) {
+        match self.rest_embedder.embed_ref(texts, deadline, embedder_stats) {
             Ok(embeddings) => Ok(embeddings),
             Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
                 tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
@@ -238,7 +241,7 @@ impl Embedder {
             let encoded = self.tokenizer.encode_ordinary(text);
             let len = encoded.len();
             if len < max_token_count {
-                all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?);
+                all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline, None)?);
                 continue;
             }
@@ -255,15 +258,16 @@ impl Embedder {
         &self,
         text_chunks: Vec<Vec<String>>,
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
         // This condition helps reduce the number of active rayon jobs
         // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
         if threads.active_operations() >= REQUEST_PARALLELISM {
-            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect()
+            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
         } else {
             threads
                 .install(move || {
-                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect()
+                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
                 })
                 .map_err(|error| EmbedError {
                     kind: EmbedErrorKind::PanicInThreadPool(error),
@@ -276,13 +280,14 @@ impl Embedder {
         &self,
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Vec<f32>>, EmbedError> {
         // This condition helps reduce the number of active rayon jobs
         // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
         if threads.active_operations() >= REQUEST_PARALLELISM {
             let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                 .chunks(self.prompt_count_in_chunk_hint())
-                .map(move |chunk| self.embed(chunk, None))
+                .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
                 .collect();
 
             let embeddings = embeddings?;
             Ok(embeddings.into_iter().flatten().collect())
@@ -291,7 +296,7 @@ impl Embedder {
                 .install(move || {
                     let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                         .par_chunks(self.prompt_count_in_chunk_hint())
-                        .map(move |chunk| self.embed(chunk, None))
+                        .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
                         .collect();
 
                     let embeddings = embeddings?;

View file

@@ -1,4 +1,5 @@
 use std::collections::BTreeMap;
+use std::sync::Arc;
 use std::time::Instant;
 
 use deserr::Deserr;
@@ -14,6 +15,7 @@ use super::{
 };
 use crate::error::FaultSource;
 use crate::ThreadPoolNoAbort;
+use crate::progress::EmbedderStats;
 
 // retrying in case of failure
 pub struct Retry {
@@ -168,19 +170,21 @@ impl Embedder {
         &self,
         texts: Vec<String>,
         deadline: Option<Instant>,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Embedding>, EmbedError> {
-        embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline)
+        embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline, embedder_stats)
     }
 
     pub fn embed_ref<S>(
         &self,
         texts: &[S],
         deadline: Option<Instant>,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Embedding>, EmbedError>
     where
         S: AsRef<str> + Serialize,
     {
-        embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline)
+        embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline, embedder_stats)
     }
 
     pub fn embed_tokens(
@@ -188,7 +192,7 @@ impl Embedder {
         tokens: &[u32],
         deadline: Option<Instant>,
     ) -> Result<Embedding, EmbedError> {
-        let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
+        let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline, None)?;
         // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
         Ok(embeddings.pop().unwrap())
     }
@@ -197,15 +201,16 @@ impl Embedder {
         &self,
         text_chunks: Vec<Vec<String>>,
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
         // This condition helps reduce the number of active rayon jobs
         // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
         if threads.active_operations() >= REQUEST_PARALLELISM {
-            text_chunks.into_iter().map(move |chunk| self.embed(chunk, None)).collect()
+            text_chunks.into_iter().map(move |chunk| self.embed(chunk, None, embedder_stats.clone())).collect()
        } else {
             threads
                 .install(move || {
-                    text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
+                    text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None, embedder_stats.clone())).collect()
                 })
                 .map_err(|error| EmbedError {
                     kind: EmbedErrorKind::PanicInThreadPool(error),
@@ -218,13 +223,14 @@ impl Embedder {
         &self,
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>
     ) -> Result<Vec<Embedding>, EmbedError> {
         // This condition helps reduce the number of active rayon jobs
         // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
         if threads.active_operations() >= REQUEST_PARALLELISM {
             let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                 .chunks(self.prompt_count_in_chunk_hint())
-                .map(move |chunk| self.embed_ref(chunk, None))
+                .map(move |chunk| self.embed_ref(chunk, None, embedder_stats.clone()))
                 .collect();
 
             let embeddings = embeddings?;
@@ -234,7 +240,7 @@ impl Embedder {
                 .install(move || {
                     let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                         .par_chunks(self.prompt_count_in_chunk_hint())
-                        .map(move |chunk| self.embed_ref(chunk, None))
+                        .map(move |chunk| self.embed_ref(chunk, None, embedder_stats.clone()))
                         .collect();
 
                     let embeddings = embeddings?;
@@ -272,7 +278,7 @@ impl Embedder {
 }
 
 fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
-    let v = embed(data, ["test"].as_slice(), 1, None, None)
+    let v = embed(data, ["test"].as_slice(), 1, None, None, None)
         .map_err(NewEmbedderError::could_not_determine_dimension)?;
     // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
     Ok(v.first().unwrap().len())
@@ -284,6 +290,7 @@ fn embed<S>(
     expected_count: usize,
     expected_dimension: Option<usize>,
     deadline: Option<Instant>,
+    embedder_stats: Option<Arc<EmbedderStats>>,
 ) -> Result<Vec<Embedding>, EmbedError>
 where
     S: Serialize,
@@ -302,6 +309,9 @@ where
     let body = data.request.inject_texts(inputs);
 
     for attempt in 0..10 {
+        if let Some(embedder_stats) = &embedder_stats {
+            embedder_stats.as_ref().total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+        }
         let response = request.clone().send_json(&body);
         let result = check_response(response, data.configuration_source).and_then(|response| {
             response_to_embedding(response, data, expected_count, expected_dimension)
@@ -311,6 +321,12 @@ where
             Ok(response) => return Ok(response),
             Err(retry) => {
                 tracing::warn!("Failed: {}", retry.error);
+                if let Some(embedder_stats) = &embedder_stats {
+                    if let Ok(mut errors) = embedder_stats.errors.write() {
+                        errors.0 = Some(retry.error.to_string());
+                        errors.1 += 1;
+                    }
+                }
                 if let Some(deadline) = deadline {
                     let now = std::time::Instant::now();
                     if now > deadline {
@@ -336,12 +352,26 @@ where
         std::thread::sleep(retry_duration);
     }
 
+    if let Some(embedder_stats) = &embedder_stats {
+        embedder_stats.as_ref().total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+    }
     let response = request.send_json(&body);
-    let result = check_response(response, data.configuration_source);
-    result.map_err(Retry::into_error).and_then(|response| {
+    let result = check_response(response, data.configuration_source).and_then(|response| {
         response_to_embedding(response, data, expected_count, expected_dimension)
             .map_err(Retry::into_error)
-    })
+    });
+    match result {
+        Ok(response) => Ok(response),
+        Err(retry) => {
+            if let Some(embedder_stats) = &embedder_stats {
+                if let Ok(mut errors) = embedder_stats.errors.write() {
+                    errors.0 = Some(retry.error.to_string());
+                    errors.1 += 1;
+                }
+            }
+            Err(retry.into_error())
+        }
+    }
 }
 
 fn check_response(
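One consequence of where the counters sit in `embed` above: `total_requests` is bumped once per attempt, i.e. up to ten retried requests plus the final non-retrying one, and `errors` is updated on every failed attempt, overwriting the last message and incrementing the count each time, not only on the terminal failure. A toy loop reproducing that accounting against the stand-in struct from the top of this page (`failures_left` simulates a flaky endpoint):

use std::sync::atomic::Ordering;

fn call_with_retries(stats: &EmbedderStats, mut failures_left: u32) -> Result<(), String> {
    for _attempt in 0..10 {
        stats.total_requests.fetch_add(1, Ordering::Relaxed);
        if failures_left == 0 {
            return Ok(());
        }
        failures_left -= 1;
        if let Ok(mut errors) = stats.errors.write() {
            errors.0 = Some("transient failure".to_string());
            errors.1 += 1;
        }
    }
    // Final request outside the retry loop, mirroring the tail of `embed`.
    stats.total_requests.fetch_add(1, Ordering::Relaxed);
    if failures_left == 0 {
        return Ok(());
    }
    if let Ok(mut errors) = stats.errors.write() {
        errors.0 = Some("permanent failure".to_string());
        errors.1 += 1;
    }
    Err("all retries exhausted".to_string())
}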