Mirror of https://github.com/meilisearch/MeiliSearch

Add embedder stats in batches

Commit: 4cadc8113b
Parent: fc6cc80705
26 changed files with 188 additions and 73 deletions

@@ -1,4 +1,5 @@
 use std::fmt;
+use std::sync::Arc;
 use std::time::Instant;
 
 use ordered_float::OrderedFloat;
@@ -9,6 +10,7 @@ use super::error::{EmbedError, NewEmbedderError};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
 use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
+use crate::progress::EmbedderStats;
 use crate::vector::error::EmbedErrorKind;
 use crate::vector::Embedding;
 use crate::ThreadPoolNoAbort;
@@ -215,8 +217,9 @@ impl Embedder {
         &self,
         texts: &[S],
         deadline: Option<Instant>,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Embedding>, EmbedError> {
-        match self.rest_embedder.embed_ref(texts, deadline) {
+        match self.rest_embedder.embed_ref(texts, deadline, embedder_stats) {
             Ok(embeddings) => Ok(embeddings),
             Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
                 tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
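
For context, EmbedderStats (imported above from crate::progress) is shared across threads behind an Arc, which is why the new parameter is Option<Arc<EmbedderStats>>. A minimal sketch of the shape such a shared stats collector can take; the field and method names here are illustrative, not the actual definition:

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::RwLock;

    // Illustrative stand-in for a shared stats collector: updated
    // concurrently from many embedding jobs, read later for reporting.
    #[derive(Default)]
    pub struct EmbedderStatsSketch {
        // Total embedding requests attempted.
        pub total_count: AtomicUsize,
        // Last error message and the number of failed requests.
        pub errors: RwLock<(Option<String>, u32)>,
    }

    impl EmbedderStatsSketch {
        pub fn on_success(&self) {
            self.total_count.fetch_add(1, Ordering::Relaxed);
        }

        pub fn on_error(&self, message: String) {
            self.total_count.fetch_add(1, Ordering::Relaxed);
            let mut errors = self.errors.write().unwrap();
            errors.0 = Some(message);
            errors.1 += 1;
        }
    }

Making the parameter an Option lets call sites without a stats sink opt out; the tokenized retry in the next hunk does exactly that by passing None.
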
@@ -238,7 +241,7 @@ impl Embedder {
             let encoded = self.tokenizer.encode_ordinary(text);
             let len = encoded.len();
             if len < max_token_count {
-                all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?);
+                all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline, None)?);
                 continue;
             }
 
@@ -255,15 +258,16 @@ impl Embedder {
         &self,
         text_chunks: Vec<Vec<String>>,
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
         // This condition helps reduce the number of active rayon jobs
         // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
         if threads.active_operations() >= REQUEST_PARALLELISM {
-            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect()
+            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
         } else {
             threads
                 .install(move || {
-                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect()
+                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())).collect()
                 })
                 .map_err(|error| EmbedError {
                     kind: EmbedErrorKind::PanicInThreadPool(error),
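
Note that embedder_stats.clone() inside the move closures only clones an Option<Arc<_>>, i.e. it bumps a reference count per chunk rather than copying any stats. A self-contained sketch of the same pattern, assuming rayon as a dependency (as in the code above):

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use rayon::prelude::*;

    fn main() {
        let total = Arc::new(AtomicUsize::new(0));
        // Shape used by the diff: an optional, shareable stats handle.
        let stats: Option<Arc<AtomicUsize>> = Some(Arc::clone(&total));

        let chunks: Vec<Vec<String>> = vec![
            vec!["a".into(), "b".into()],
            vec!["c".into()],
        ];

        // Each parallel job receives its own Arc clone; all clones point
        // at the same counter, mirroring embedder_stats.clone() per chunk.
        chunks.into_par_iter().for_each(move |chunk| {
            if let Some(stats) = &stats {
                stats.fetch_add(chunk.len(), Ordering::Relaxed);
            }
        });

        assert_eq!(total.load(Ordering::Relaxed), 3);
    }
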
@@ -276,13 +280,14 @@ impl Embedder {
         &self,
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
+        embedder_stats: Option<Arc<EmbedderStats>>,
     ) -> Result<Vec<Vec<f32>>, EmbedError> {
         // This condition helps reduce the number of active rayon jobs
         // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
         if threads.active_operations() >= REQUEST_PARALLELISM {
             let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                 .chunks(self.prompt_count_in_chunk_hint())
-                .map(move |chunk| self.embed(chunk, None))
+                .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
                 .collect();
             let embeddings = embeddings?;
             Ok(embeddings.into_iter().flatten().collect())
@@ -291,7 +296,7 @@ impl Embedder {
                 .install(move || {
                     let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                         .par_chunks(self.prompt_count_in_chunk_hint())
-                        .map(move |chunk| self.embed(chunk, None))
+                        .map(move |chunk| self.embed(chunk, None, embedder_stats.clone()))
                         .collect();
 
                     let embeddings = embeddings?;
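
The sequential and parallel branches above share one shape: batch the inputs with chunks/par_chunks, embed each batch, collect into a Result so the first error short-circuits, then flatten. Reduced to plain iterators as a sketch (the embed closure stands in for self.embed; batch_size plays the role of prompt_count_in_chunk_hint() and must be nonzero):

    // Sketch of the batch-then-flatten pattern used by both branches.
    fn embed_in_batches<E>(
        texts: &[&str],
        batch_size: usize,
        mut embed: impl FnMut(&[&str]) -> Result<Vec<Vec<f32>>, E>,
    ) -> Result<Vec<Vec<f32>>, E> {
        // One Vec of embeddings per batch; `collect` stops at the first error.
        let per_batch: Result<Vec<Vec<Vec<f32>>>, E> =
            texts.chunks(batch_size).map(|chunk| embed(chunk)).collect();
        // Flatten the per-batch Vecs back into one embedding per input text.
        Ok(per_batch?.into_iter().flatten().collect())
    }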