Parallelize document upload

This commit is contained in:
Clément Renault 2025-06-16 16:30:35 +02:00
parent 04ee7b0863
commit b21c983b0a
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
6 changed files with 133 additions and 82 deletions

View File

@ -7,9 +7,9 @@ use backoff::ExponentialBackoff;
use meilisearch_types::index_uid_pattern::IndexUidPattern;
use meilisearch_types::milli::constants::RESERVED_VECTORS_FIELD_NAME;
use meilisearch_types::milli::progress::{Progress, VariableNameStep};
use meilisearch_types::milli::update::Setting;
use meilisearch_types::milli::update::{request_threads, Setting};
use meilisearch_types::milli::vector::parsed_vectors::{ExplicitVectors, VectorOrArrayOfVectors};
use meilisearch_types::milli::{obkv_to_json, Filter};
use meilisearch_types::milli::{self, obkv_to_json, Filter, InternalError};
use meilisearch_types::settings::{self, SecretPolicy};
use meilisearch_types::tasks::ExportIndexSettings;
use serde::Deserialize;
@ -112,6 +112,10 @@ impl IndexScheduler {
.embedding_configs(&index_rtxn)
.map_err(|e| Error::from_milli(e, Some(uid.to_string())))?;
// We don't need to keep this one alive as we will
// spawn many threads to process the documents
drop(index_rtxn);
let total_documents = universe.len() as u32;
let (step, progress_step) = AtomicDocumentStep::new(total_documents);
progress.update_progress(progress_step);
@ -119,73 +123,107 @@ impl IndexScheduler {
let limit = 50 * 1024 * 1024; // 50 MiB
let documents_url = format!("{base_url}/indexes/{uid}/documents");
let mut buffer = Vec::new();
let mut tmp_buffer = Vec::new();
for (i, docid) in universe.into_iter().enumerate() {
let document = index
.document(&index_rtxn, docid)
.map_err(|e| Error::from_milli(e, Some(uid.to_string())))?;
request_threads()
.broadcast(|ctx| {
let index_rtxn = index
.read_txn()
.map_err(|e| Error::from_milli(e.into(), Some(uid.to_string())))?;
let mut document = obkv_to_json(&all_fields, &fields_ids_map, document)
.map_err(|e| Error::from_milli(e, Some(uid.to_string())))?;
let mut buffer = Vec::new();
let mut tmp_buffer = Vec::new();
for (i, docid) in universe.iter().enumerate() {
if i % ctx.num_threads() != ctx.index() {
continue;
}
// TODO definitely factorize this code
'inject_vectors: {
let embeddings = index
.embeddings(&index_rtxn, docid)
.map_err(|e| Error::from_milli(e, Some(uid.to_string())))?;
let document = index
.document(&index_rtxn, docid)
.map_err(|e| Error::from_milli(e, Some(uid.to_string())))?;
if embeddings.is_empty() {
break 'inject_vectors;
let mut document = obkv_to_json(&all_fields, &fields_ids_map, document)
.map_err(|e| Error::from_milli(e, Some(uid.to_string())))?;
// TODO definitely factorize this code
'inject_vectors: {
let embeddings = index
.embeddings(&index_rtxn, docid)
.map_err(|e| Error::from_milli(e, Some(uid.to_string())))?;
if embeddings.is_empty() {
break 'inject_vectors;
}
let vectors = document
.entry(RESERVED_VECTORS_FIELD_NAME)
.or_insert(serde_json::Value::Object(Default::default()));
let serde_json::Value::Object(vectors) = vectors else {
return Err(Error::from_milli(
milli::Error::UserError(
milli::UserError::InvalidVectorsMapType {
document_id: {
if let Ok(Some(Ok(index))) = index
.external_id_of(
&index_rtxn,
std::iter::once(docid),
)
.map(|it| it.into_iter().next())
{
index
} else {
format!("internal docid={docid}")
}
},
value: vectors.clone(),
},
),
Some(uid.to_string()),
));
};
for (embedder_name, embeddings) in embeddings {
let user_provided = embedding_configs
.iter()
.find(|conf| conf.name == embedder_name)
.is_some_and(|conf| conf.user_provided.contains(docid));
let embeddings = ExplicitVectors {
embeddings: Some(
VectorOrArrayOfVectors::from_array_of_vectors(embeddings),
),
regenerate: !user_provided,
};
vectors.insert(
embedder_name,
serde_json::to_value(embeddings).unwrap(),
);
}
}
tmp_buffer.clear();
serde_json::to_writer(&mut tmp_buffer, &document)
.map_err(milli::InternalError::from)
.map_err(|e| Error::from_milli(e.into(), Some(uid.to_string())))?;
if buffer.len() + tmp_buffer.len() > limit {
retry(&must_stop_processing, || {
let mut request = agent.post(&documents_url);
request = request.set("Content-Type", "application/x-ndjson");
if let Some(api_key) = api_key {
request = request
.set("Authorization", &(format!("Bearer {api_key}")));
}
request.send_bytes(&buffer).map_err(into_backoff_error)
})?;
buffer.clear();
}
buffer.extend_from_slice(&tmp_buffer);
if i % 100 == 0 {
step.fetch_add(100, atomic::Ordering::Relaxed);
}
}
let vectors = document
.entry(RESERVED_VECTORS_FIELD_NAME)
.or_insert(serde_json::Value::Object(Default::default()));
let serde_json::Value::Object(vectors) = vectors else {
return Err(Error::from_milli(
meilisearch_types::milli::Error::UserError(
meilisearch_types::milli::UserError::InvalidVectorsMapType {
document_id: {
if let Ok(Some(Ok(index))) = index
.external_id_of(&index_rtxn, std::iter::once(docid))
.map(|it| it.into_iter().next())
{
index
} else {
format!("internal docid={docid}")
}
},
value: vectors.clone(),
},
),
Some(uid.to_string()),
));
};
for (embedder_name, embeddings) in embeddings {
let user_provided = embedding_configs
.iter()
.find(|conf| conf.name == embedder_name)
.is_some_and(|conf| conf.user_provided.contains(docid));
let embeddings = ExplicitVectors {
embeddings: Some(VectorOrArrayOfVectors::from_array_of_vectors(
embeddings,
)),
regenerate: !user_provided,
};
vectors.insert(embedder_name, serde_json::to_value(embeddings).unwrap());
}
}
tmp_buffer.clear();
serde_json::to_writer(&mut tmp_buffer, &document)
.map_err(meilisearch_types::milli::InternalError::from)
.map_err(|e| Error::from_milli(e.into(), Some(uid.to_string())))?;
if buffer.len() + tmp_buffer.len() > limit {
retry(&must_stop_processing, || {
let mut request = agent.post(&documents_url);
request = request.set("Content-Type", "application/x-ndjson");
@ -194,23 +232,16 @@ impl IndexScheduler {
}
request.send_bytes(&buffer).map_err(into_backoff_error)
})?;
buffer.clear();
}
buffer.extend_from_slice(&tmp_buffer);
if i % 100 == 0 {
step.fetch_add(100, atomic::Ordering::Relaxed);
}
}
Ok(())
})
.map_err(|e| {
Error::from_milli(
milli::Error::InternalError(InternalError::PanicInThreadPool(e)),
Some(uid.to_string()),
)
})?;
retry(&must_stop_processing, || {
let mut request = agent.post(&documents_url);
request = request.set("Content-Type", "application/x-ndjson");
if let Some(api_key) = api_key {
request = request.set("Authorization", &(format!("Bearer {api_key}")));
}
request.send_bytes(&buffer).map_err(into_backoff_error)
})?;
step.store(total_documents, atomic::Ordering::Relaxed);
}

View File

@ -766,6 +766,7 @@ fn basic_get_stats() {
"documentDeletion": 0,
"documentEdition": 0,
"dumpCreation": 0,
"export": 0,
"indexCreation": 3,
"indexDeletion": 0,
"indexSwap": 0,
@ -806,6 +807,7 @@ fn basic_get_stats() {
"documentDeletion": 0,
"documentEdition": 0,
"dumpCreation": 0,
"export": 0,
"indexCreation": 3,
"indexDeletion": 0,
"indexSwap": 0,
@ -847,6 +849,7 @@ fn basic_get_stats() {
"documentDeletion": 0,
"documentEdition": 0,
"dumpCreation": 0,
"export": 0,
"indexCreation": 3,
"indexDeletion": 0,
"indexSwap": 0,

View File

@ -1,7 +1,7 @@
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use rayon::{ThreadPool, ThreadPoolBuilder};
use rayon::{BroadcastContext, ThreadPool, ThreadPoolBuilder};
use thiserror::Error;
/// A rayon ThreadPool wrapper that can catch panics in the pool
@ -32,6 +32,22 @@ impl ThreadPoolNoAbort {
}
}
pub fn broadcast<OP, R>(&self, op: OP) -> Result<Vec<R>, PanicCatched>
where
OP: Fn(BroadcastContext<'_>) -> R + Sync,
R: Send,
{
self.active_operations.fetch_add(1, Ordering::Relaxed);
let output = self.thread_pool.broadcast(op);
self.active_operations.fetch_sub(1, Ordering::Relaxed);
// While reseting the pool panic catcher we return an error if we catched one.
if self.pool_catched_panic.swap(false, Ordering::SeqCst) {
Err(PanicCatched)
} else {
Ok(output)
}
}
pub fn current_num_threads(&self) -> usize {
self.thread_pool.current_num_threads()
}

View File

@ -210,7 +210,7 @@ fn run_extraction_task<FE, FS, M>(
})
}
fn request_threads() -> &'static ThreadPoolNoAbort {
pub fn request_threads() -> &'static ThreadPoolNoAbort {
static REQUEST_THREADS: OnceLock<ThreadPoolNoAbort> = OnceLock::new();
REQUEST_THREADS.get_or_init(|| {

View File

@ -12,6 +12,7 @@ use std::sync::Arc;
use crossbeam_channel::{Receiver, Sender};
use enrich::enrich_documents_batch;
pub use extract::request_threads;
use grenad::{Merger, MergerBuilder};
use hashbrown::HashMap;
use heed::types::Str;

View File

@ -4,7 +4,7 @@ pub use self::clear_documents::ClearDocuments;
pub use self::concurrent_available_ids::ConcurrentAvailableIds;
pub use self::facet::bulk::FacetsUpdateBulk;
pub use self::facet::incremental::FacetsUpdateIncrementalInner;
pub use self::index_documents::*;
pub use self::index_documents::{request_threads, *};
pub use self::indexer_config::{default_thread_pool_and_threads, IndexerConfig};
pub use self::new::ChannelCongestion;
pub use self::settings::{validate_embedding_settings, Setting, Settings};