Move embedder stats out of progress

This commit is contained in:
Mubelotix 2025-06-23 15:24:14 +02:00
parent 4cadc8113b
commit 4925b30196
No known key found for this signature in database
GPG key ID: 89F391DBCC8CE7F0
30 changed files with 255 additions and 69 deletions

View file

@ -1,7 +1,7 @@
use std::collections::BTreeSet;
use std::fmt::Write;
use meilisearch_types::batches::{Batch, BatchEnqueuedAt, BatchStats};
use meilisearch_types::batches::{Batch, BatchEmbeddingStats, BatchEnqueuedAt, BatchStats};
use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str};
use meilisearch_types::heed::{Database, RoTxn};
use meilisearch_types::milli::{CboRoaringBitmapCodec, RoaringBitmapCodec, BEU32};
@ -343,6 +343,7 @@ pub fn snapshot_batch(batch: &Batch) -> String {
uid,
details,
stats,
embedder_stats,
started_at,
finished_at,
progress: _,
@ -366,6 +367,7 @@ pub fn snapshot_batch(batch: &Batch) -> String {
snap.push_str(&format!("uid: {uid}, "));
snap.push_str(&format!("details: {}, ", serde_json::to_string(details).unwrap()));
snap.push_str(&format!("stats: {}, ", serde_json::to_string(&stats).unwrap()));
snap.push_str(&format!("embedder_stats: {}, ", serde_json::to_string(&embedder_stats).unwrap()));
snap.push_str(&format!("stop reason: {}, ", serde_json::to_string(&stop_reason).unwrap()));
snap.push('}');
snap

View file

@ -1,7 +1,7 @@
use std::collections::HashSet;
use std::ops::{Bound, RangeBounds};
use meilisearch_types::batches::{Batch, BatchId};
use meilisearch_types::batches::{Batch, BatchEmbeddingStats, BatchId};
use meilisearch_types::heed::types::{DecodeIgnore, SerdeBincode, SerdeJson, Str};
use meilisearch_types::heed::{Database, Env, RoTxn, RwTxn, WithoutTls};
use meilisearch_types::milli::{CboRoaringBitmapCodec, RoaringBitmapCodec, BEU32};
@ -92,7 +92,10 @@ impl BatchQueue {
}
/// Looks up the batch stored under `batch_id` in the `all_batches` database.
///
/// The lookup result is returned as-is (an `Option<Batch>`), and any error raised by the
/// database read is propagated to the caller through `?`.
///
/// No logging is done here: this is a plain read and must not write to stdout.
pub(crate) fn get_batch(&self, rtxn: &RoTxn, batch_id: BatchId) -> Result<Option<Batch>> {
    Ok(self.all_batches.get(rtxn, &batch_id)?)
}
/// Returns the whole set of batches that belongs to this index.
@ -171,6 +174,8 @@ impl BatchQueue {
pub(crate) fn write_batch(&self, wtxn: &mut RwTxn, batch: ProcessingBatch) -> Result<()> {
let old_batch = self.all_batches.get(wtxn, &batch.uid)?;
println!("Saving batch: {}", batch.embedder_stats.is_some());
self.all_batches.put(
wtxn,
&batch.uid,
@ -179,6 +184,7 @@ impl BatchQueue {
progress: None,
details: batch.details,
stats: batch.stats,
embedder_stats: batch.embedder_stats.as_ref().map(|s| BatchEmbeddingStats::from(s.as_ref())),
started_at: batch.started_at,
finished_at: batch.finished_at,
enqueued_at: batch.enqueued_at,

View file

@ -1,10 +1,11 @@
use std::collections::{BTreeSet, HashMap, HashSet};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::sync::atomic::Ordering;
use std::sync::Arc;
use meilisearch_types::batches::{BatchEnqueuedAt, BatchId};
use meilisearch_types::heed::{RoTxn, RwTxn};
use meilisearch_types::milli::progress::{Progress, VariableNameStep};
use meilisearch_types::milli::progress::{EmbedderStats, Progress, VariableNameStep};
use meilisearch_types::milli::{self, ChannelCongestion};
use meilisearch_types::tasks::{Details, IndexSwap, Kind, KindWithContent, Status, Task};
use meilisearch_types::versioning::{VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH};
@ -163,7 +164,7 @@ impl IndexScheduler {
let pre_commit_dabases_sizes = index.database_sizes(&index_wtxn)?;
let (tasks, congestion) =
self.apply_index_operation(&mut index_wtxn, &index, op, &progress)?;
self.apply_index_operation(&mut index_wtxn, &index, op, &progress, current_batch.clone_embedder_stats())?;
{
progress.update_progress(FinalizingIndexStep::Committing);
@ -238,11 +239,21 @@ impl IndexScheduler {
);
builder.set_primary_key(primary_key);
let must_stop_processing = self.scheduler.must_stop_processing.clone();
let embedder_stats = match current_batch.embedder_stats {
Some(ref stats) => stats.clone(),
None => {
let embedder_stats: Arc<EmbedderStats> = Default::default();
current_batch.embedder_stats = Some(embedder_stats.clone());
embedder_stats
},
};
builder
.execute(
|indexing_step| tracing::debug!(update = ?indexing_step),
|| must_stop_processing.get(),
Some(progress.embedder_stats),
embedder_stats,
)
.map_err(|e| Error::from_milli(e, Some(index_uid.to_string())))?;
index_wtxn.commit()?;

View file

@ -4,7 +4,7 @@ use bumpalo::collections::CollectIn;
use bumpalo::Bump;
use meilisearch_types::heed::RwTxn;
use meilisearch_types::milli::documents::PrimaryKey;
use meilisearch_types::milli::progress::Progress;
use meilisearch_types::milli::progress::{EmbedderStats, Progress};
use meilisearch_types::milli::update::new::indexer::{self, UpdateByFunction};
use meilisearch_types::milli::update::DocumentAdditionResult;
use meilisearch_types::milli::{self, ChannelCongestion, Filter};
@ -26,7 +26,7 @@ impl IndexScheduler {
/// The list of processed tasks.
#[tracing::instrument(
level = "trace",
skip(self, index_wtxn, index, progress),
skip(self, index_wtxn, index, progress, embedder_stats),
target = "indexing::scheduler"
)]
pub(crate) fn apply_index_operation<'i>(
@ -35,6 +35,7 @@ impl IndexScheduler {
index: &'i Index,
operation: IndexOperation,
progress: &Progress,
embedder_stats: Arc<EmbedderStats>,
) -> Result<(Vec<Task>, Option<ChannelCongestion>)> {
let indexer_alloc = Bump::new();
let started_processing_at = std::time::Instant::now();
@ -179,6 +180,7 @@ impl IndexScheduler {
embedders,
&|| must_stop_processing.get(),
progress,
embedder_stats,
)
.map_err(|e| Error::from_milli(e, Some(index_uid.clone())))?,
);
@ -290,6 +292,7 @@ impl IndexScheduler {
embedders,
&|| must_stop_processing.get(),
progress,
embedder_stats,
)
.map_err(|err| Error::from_milli(err, Some(index_uid.clone())))?,
);
@ -438,6 +441,7 @@ impl IndexScheduler {
embedders,
&|| must_stop_processing.get(),
progress,
embedder_stats,
)
.map_err(|err| Error::from_milli(err, Some(index_uid.clone())))?,
);
@ -474,7 +478,7 @@ impl IndexScheduler {
.execute(
|indexing_step| tracing::debug!(update = ?indexing_step),
|| must_stop_processing.get(),
Some(Arc::clone(&progress.embedder_stats))
embedder_stats,
)
.map_err(|err| Error::from_milli(err, Some(index_uid.clone())))?;
@ -494,6 +498,7 @@ impl IndexScheduler {
tasks: cleared_tasks,
},
progress,
embedder_stats.clone(),
)?;
let (settings_tasks, _congestion) = self.apply_index_operation(
@ -501,6 +506,7 @@ impl IndexScheduler {
index,
IndexOperation::Settings { index_uid, settings, tasks: settings_tasks },
progress,
embedder_stats,
)?;
let mut tasks = settings_tasks;

View file

@ -2,8 +2,10 @@
use std::collections::{BTreeSet, HashSet};
use std::ops::Bound;
use std::sync::Arc;
use crate::milli::progress::EmbedderStats;
use meilisearch_types::batches::{Batch, BatchEnqueuedAt, BatchId, BatchStats};
use meilisearch_types::batches::{Batch, BatchEmbeddingStats, BatchEnqueuedAt, BatchId, BatchStats};
use meilisearch_types::heed::{Database, RoTxn, RwTxn};
use meilisearch_types::milli::CboRoaringBitmapCodec;
use meilisearch_types::task_view::DetailsView;
@ -27,6 +29,7 @@ pub struct ProcessingBatch {
pub uid: BatchId,
pub details: DetailsView,
pub stats: BatchStats,
pub embedder_stats: Option<Arc<EmbedderStats>>,
pub statuses: HashSet<Status>,
pub kinds: HashSet<Kind>,
@ -48,6 +51,7 @@ impl ProcessingBatch {
uid,
details: DetailsView::default(),
stats: BatchStats::default(),
embedder_stats: None,
statuses,
kinds: HashSet::default(),
@ -60,6 +64,17 @@ impl ProcessingBatch {
}
}
pub fn clone_embedder_stats(&mut self) -> Arc<EmbedderStats> {
match self.embedder_stats {
Some(ref stats) => stats.clone(),
None => {
let embedder_stats: Arc<EmbedderStats> = Default::default();
self.embedder_stats = Some(embedder_stats.clone());
embedder_stats
},
}
}
/// Update itself with the content of the task and update the batch id in the task.
pub fn processing<'a>(&mut self, tasks: impl IntoIterator<Item = &'a mut Task>) {
for task in tasks.into_iter() {
@ -141,11 +156,13 @@ impl ProcessingBatch {
}
pub fn to_batch(&self) -> Batch {
println!("Converting to batch: {:?}", self.embedder_stats);
Batch {
uid: self.uid,
progress: None,
details: self.details.clone(),
stats: self.stats.clone(),
embedder_stats: self.embedder_stats.as_ref().map(|s| BatchEmbeddingStats::from(s.as_ref())),
started_at: self.started_at,
finished_at: self.finished_at,
enqueued_at: self.enqueued_at,