diff --git a/crates/index-scheduler/src/lib.rs b/crates/index-scheduler/src/lib.rs index f551652c1..b2f27d66b 100644 --- a/crates/index-scheduler/src/lib.rs +++ b/crates/index-scheduler/src/lib.rs @@ -882,12 +882,12 @@ impl IndexScheduler { { let embedders = self.embedders.read().unwrap(); if let Some(embedder) = embedders.get(&embedder_options) { - let runtime = Arc::new(RuntimeEmbedder { - embedder: embedder.clone(), + let runtime = Arc::new(RuntimeEmbedder::new( + embedder.clone(), document_template, fragments, - is_quantized: quantized.unwrap_or_default(), - }); + quantized.unwrap_or_default(), + )); return Ok((name, runtime)); } @@ -906,12 +906,12 @@ impl IndexScheduler { embedders.insert(embedder_options, embedder.clone()); } - let runtime = Arc::new(RuntimeEmbedder { - embedder: embedder.clone(), + let runtime = Arc::new(RuntimeEmbedder::new( + embedder.clone(), document_template, fragments, - is_quantized: quantized.unwrap_or_default(), - }); + quantized.unwrap_or_default(), + )); Ok((name, runtime)) }, diff --git a/crates/milli/src/test_index.rs b/crates/milli/src/test_index.rs index cfd8c8492..6bb6b1345 100644 --- a/crates/milli/src/test_index.rs +++ b/crates/milli/src/test_index.rs @@ -66,7 +66,7 @@ impl TempIndex { let db_fields_ids_map = self.inner.fields_ids_map(&rtxn)?; let mut new_fields_ids_map = db_fields_ids_map.clone(); - let embedders = InnerIndexSettings::from_index(&self.inner, &rtxn, None)?.embedding_configs; + let embedders = InnerIndexSettings::from_index(&self.inner, &rtxn, None)?.runtime_embedders; let mut indexer = indexer::DocumentOperation::new(); match self.index_documents_config.update_method { IndexDocumentsMethod::ReplaceDocuments => { @@ -151,7 +151,7 @@ impl TempIndex { let db_fields_ids_map = self.inner.fields_ids_map(&rtxn)?; let mut new_fields_ids_map = db_fields_ids_map.clone(); - let embedders = InnerIndexSettings::from_index(&self.inner, &rtxn, None)?.embedding_configs; + let embedders = InnerIndexSettings::from_index(&self.inner, &rtxn, None)?.runtime_embedders; let mut indexer = indexer::DocumentOperation::new(); let external_document_ids: Vec<_> = diff --git a/crates/milli/src/update/index_documents/extract/extract_vector_points.rs b/crates/milli/src/update/index_documents/extract/extract_vector_points.rs index 0a179cfa5..d40e82b92 100644 --- a/crates/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/crates/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -236,8 +236,8 @@ pub fn extract_vector_points( let mut extractors = Vec::new(); - let mut configs = settings_diff.new.embedding_configs.clone().into_inner(); - let old_configs = &settings_diff.old.embedding_configs; + let mut configs = settings_diff.new.runtime_embedders.clone().into_inner(); + let old_configs = &settings_diff.old.runtime_embedders; if reindex_vectors { for (name, action) in settings_diff.embedding_config_updates.iter() { if let Some(action) = action.reindex() { @@ -284,16 +284,16 @@ pub fn extract_vector_points( continue; }; - let fragments = regenerate_fragments + let fragment_diffs = regenerate_fragments .iter() .filter_map(|(name, fragment)| match fragment { crate::vector::settings::RegenerateFragment::Update => { let old_value = old_runtime - .fragments + .fragments() .binary_search_by_key(&name, |fragment| &fragment.name) .ok(); let Ok(new_value) = runtime - .fragments + .fragments() .binary_search_by_key(&name, |fragment| &fragment.name) else { return None; @@ -304,7 +304,7 @@ pub fn extract_vector_points( crate::vector::settings::RegenerateFragment::Remove => None, crate::vector::settings::RegenerateFragment::Add => { let Ok(new_value) = runtime - .fragments + .fragments() .binary_search_by_key(&name, |fragment| &fragment.name) else { return None; @@ -314,8 +314,8 @@ pub fn extract_vector_points( }) .collect(); ExtractionAction::SettingsRegenerateFragments { - old_runtime, - must_regenerate_fragments: fragments, + old_runtime: old_runtime.clone(), + must_regenerate_fragments: fragment_diffs, } } @@ -325,7 +325,9 @@ pub fn extract_vector_points( continue; }; - ExtractionAction::SettingsRegeneratePrompts { old_runtime } + ExtractionAction::SettingsRegeneratePrompts { + old_runtime: old_runtime.clone(), + } } }; @@ -473,11 +475,11 @@ pub fn extract_vector_points( ); continue; } - let has_fragments = !runtime.fragments.is_empty(); + let has_fragments = !runtime.fragments().is_empty(); if has_fragments { regenerate_all_fragments( - &runtime.fragments, + runtime.fragments(), &doc_alloc, new_fields_ids_map, obkv, @@ -492,14 +494,14 @@ pub fn extract_vector_points( old_runtime, } => { if old.must_regenerate() { - let has_fragments = !runtime.fragments.is_empty(); - let old_has_fragments = !old_runtime.fragments.is_empty(); + let has_fragments = !runtime.fragments().is_empty(); + let old_has_fragments = !old_runtime.fragments().is_empty(); let is_adding_fragments = has_fragments && !old_has_fragments; if is_adding_fragments { regenerate_all_fragments( - &runtime.fragments, + runtime.fragments(), &doc_alloc, new_fields_ids_map, obkv, @@ -517,14 +519,16 @@ pub fn extract_vector_points( new_fields_ids_map, ); for (name, (old_index, new_index)) in must_regenerate_fragments { - let Some(new) = runtime.fragments.get(*new_index) else { continue }; + let Some(new) = runtime.fragments().get(*new_index) else { + continue; + }; let new = RequestFragmentExtractor::new(new, &doc_alloc).ignore_errors(); let diff = { let old = old_index.as_ref().and_then(|old| { - let old = old_runtime.fragments.get(*old)?; + let old = old_runtime.fragments().get(*old)?; Some( RequestFragmentExtractor::new(old, &doc_alloc) .ignore_errors(), @@ -555,11 +559,11 @@ pub fn extract_vector_points( ); continue; } - let has_fragments = !runtime.fragments.is_empty(); + let has_fragments = !runtime.fragments().is_empty(); if has_fragments { regenerate_all_fragments( - &runtime.fragments, + runtime.fragments(), &doc_alloc, new_fields_ids_map, obkv, @@ -607,7 +611,7 @@ pub fn extract_vector_points( manual_vectors_writer, &mut key_buffer, delta, - &runtime.fragments, + runtime.fragments(), )?; } @@ -720,7 +724,7 @@ fn extract_vector_document_diff( ManualEmbedderErrors::push_error(manual_errors, embedder_name, document_id); return Ok(VectorStateDelta::NoChange); } - let has_fragments = !runtime.fragments.is_empty(); + let has_fragments = !runtime.fragments().is_empty(); if has_fragments { let prompt = &runtime.document_template; // Don't give up if the old prompt was failing @@ -753,7 +757,7 @@ fn extract_vector_document_diff( new_fields_ids_map, ); - for new in &runtime.fragments { + for new in runtime.fragments() { let name = &new.name; let fragment = RequestFragmentExtractor::new(new, doc_alloc).ignore_errors(); @@ -791,11 +795,11 @@ fn extract_vector_document_diff( return Ok(VectorStateDelta::NoChange); } - let has_fragments = !runtime.fragments.is_empty(); + let has_fragments = !runtime.fragments().is_empty(); if has_fragments { regenerate_all_fragments( - &runtime.fragments, + runtime.fragments(), doc_alloc, new_fields_ids_map, obkv, diff --git a/crates/milli/src/update/index_documents/extract/mod.rs b/crates/milli/src/update/index_documents/extract/mod.rs index cbf4ceba2..b41fd59e1 100644 --- a/crates/milli/src/update/index_documents/extract/mod.rs +++ b/crates/milli/src/update/index_documents/extract/mod.rs @@ -242,7 +242,7 @@ fn send_original_documents_data( let index_vectors = (settings_diff.reindex_vectors() || !settings_diff.settings_update_only()) // no point in indexing vectors without embedders - && (!settings_diff.new.embedding_configs.inner_as_ref().is_empty()); + && (!settings_diff.new.runtime_embedders.inner_as_ref().is_empty()); if index_vectors { let settings_diff = settings_diff.clone(); diff --git a/crates/milli/src/update/index_documents/mod.rs b/crates/milli/src/update/index_documents/mod.rs index 055b8bbad..658ff1923 100644 --- a/crates/milli/src/update/index_documents/mod.rs +++ b/crates/milli/src/update/index_documents/mod.rs @@ -517,7 +517,7 @@ where let embedder_config = settings_diff.embedding_config_updates.get(&embedder_name); let was_quantized = settings_diff .old - .embedding_configs + .runtime_embedders .get(&embedder_name) .is_some_and(|conf| conf.is_quantized); let is_quantizing = embedder_config.is_some_and(|action| action.is_being_quantized); diff --git a/crates/milli/src/update/index_documents/typed_chunk.rs b/crates/milli/src/update/index_documents/typed_chunk.rs index 370579a6c..c93e3e0f7 100644 --- a/crates/milli/src/update/index_documents/typed_chunk.rs +++ b/crates/milli/src/update/index_documents/typed_chunk.rs @@ -673,7 +673,7 @@ pub(crate) fn write_typed_chunk_into_index( let binary_quantized = settings_diff .old - .embedding_configs + .runtime_embedders .get(&embedder_name) .is_some_and(|conf| conf.is_quantized); // FIXME: allow customizing distance diff --git a/crates/milli/src/update/settings.rs b/crates/milli/src/update/settings.rs index 03d44d785..c9ab427ea 100644 --- a/crates/milli/src/update/settings.rs +++ b/crates/milli/src/update/settings.rs @@ -1647,9 +1647,9 @@ impl InnerIndexSettingsDiff { // if the user-defined searchables changed, then we need to reindex prompts. if cache_user_defined_searchables { - for (embedder_name, runtime) in new_settings.embedding_configs.inner_as_ref() { + for (embedder_name, runtime) in new_settings.runtime_embedders.inner_as_ref() { let was_quantized = old_settings - .embedding_configs + .runtime_embedders .get(embedder_name) .is_some_and(|conf| conf.is_quantized); // skip embedders that don't use document templates @@ -1893,7 +1893,7 @@ pub(crate) struct InnerIndexSettings { pub exact_attributes: HashSet, pub disabled_typos_terms: DisabledTyposTerms, pub proximity_precision: ProximityPrecision, - pub embedding_configs: RuntimeEmbedders, + pub runtime_embedders: RuntimeEmbedders, pub embedder_category_id: HashMap, pub geo_fields_ids: Option<(FieldId, FieldId)>, pub prefix_search: PrefixSearch, @@ -1904,7 +1904,7 @@ impl InnerIndexSettings { pub fn from_index( index: &Index, rtxn: &heed::RoTxn<'_>, - embedding_configs: Option, + runtime_embedders: Option, ) -> Result { let stop_words = index.stop_words(rtxn)?; let stop_words = stop_words.map(|sw| sw.map_data(Vec::from).unwrap()); @@ -1913,13 +1913,13 @@ impl InnerIndexSettings { let mut fields_ids_map = index.fields_ids_map(rtxn)?; let exact_attributes = index.exact_attributes_ids(rtxn)?; let proximity_precision = index.proximity_precision(rtxn)?.unwrap_or_default(); - let embedding_configs = match embedding_configs { + let runtime_embedders = match runtime_embedders { Some(embedding_configs) => embedding_configs, None => embedders(index.embedding_configs().embedding_configs(rtxn)?)?, }; let embedder_category_id = index - .embedder_category_id - .iter(rtxn)? + .embedding_configs() + .iter_embedder_id(rtxn)? .map(|r| r.map(|(k, v)| (k.to_string(), v))) .collect::>()?; let prefix_search = index.prefix_search(rtxn)?.unwrap_or_default(); @@ -1960,7 +1960,7 @@ impl InnerIndexSettings { sortable_fields, exact_attributes, proximity_precision, - embedding_configs, + runtime_embedders, embedder_category_id, geo_fields_ids, prefix_search, @@ -2035,12 +2035,12 @@ fn embedders(embedding_configs: Vec) -> Result &EmbeddingConfigs; - fn old_embedders(&self) -> &EmbeddingConfigs; + fn new_embedders(&self) -> &RuntimeEmbedders; + fn old_embedders(&self) -> &RuntimeEmbedders; fn new_embedder_category_id(&self) -> &HashMap; fn embedder_actions(&self) -> &BTreeMap; fn try_for_each_fragment_diff( @@ -2407,12 +2407,12 @@ pub struct FragmentDiff<'a> { } impl SettingsDelta for InnerIndexSettingsDiff { - fn new_embedders(&self) -> &EmbeddingConfigs { - &self.new.embedding_configs + fn new_embedders(&self) -> &RuntimeEmbedders { + &self.new.runtime_embedders } - fn old_embedders(&self) -> &EmbeddingConfigs { - &self.old.embedding_configs + fn old_embedders(&self) -> &RuntimeEmbedders { + &self.old.runtime_embedders } fn new_embedder_category_id(&self) -> &HashMap { diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs index 87ecd2414..f64223e41 100644 --- a/crates/milli/src/vector/mod.rs +++ b/crates/milli/src/vector/mod.rs @@ -742,10 +742,27 @@ pub struct RuntimeEmbedders(HashMap>); pub struct RuntimeEmbedder { pub embedder: Arc, pub document_template: Prompt, - pub fragments: Vec, + fragments: Vec, pub is_quantized: bool, } +impl RuntimeEmbedder { + pub fn new( + embedder: Arc, + document_template: Prompt, + mut fragments: Vec, + is_quantized: bool, + ) -> Self { + fragments.sort_unstable_by(|left, right| left.name.cmp(&right.name)); + Self { embedder, document_template, fragments, is_quantized } + } + + /// The runtime fragments sorted by name. + pub fn fragments(&self) -> &[RuntimeFragment] { + self.fragments.as_slice() + } +} + pub struct RuntimeFragment { pub name: String, pub id: u8, @@ -763,8 +780,8 @@ impl RuntimeEmbedders { } /// Get an embedder configuration and template from its name. - pub fn get(&self, name: &str) -> Option> { - self.0.get(name).cloned() + pub fn get(&self, name: &str) -> Option<&Arc> { + self.0.get(name) } pub fn inner_as_ref(&self) -> &HashMap> { @@ -774,6 +791,14 @@ impl RuntimeEmbedders { pub fn into_inner(self) -> HashMap> { self.0 } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } impl IntoIterator for RuntimeEmbedders {