Actually pass embedders on reindex

This commit is contained in:
Louis Dureuil 2023-12-07 23:05:26 +01:00
parent 687d92f217
commit e56f160032
No known key found for this signature in database

View File

@ -1,5 +1,7 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::convert::TryInto;
use std::result::Result as StdResult;
use std::sync::Arc;
use charabia::{Normalize, Tokenizer, TokenizerBuilder};
use deserr::{DeserializeError, Deserr};
@ -12,11 +14,12 @@ use super::IndexerConfig;
use crate::criterion::Criterion;
use crate::error::UserError;
use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS};
use crate::prompt::Prompt;
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep};
use crate::vector::settings::{EmbeddingSettings, PromptSettings};
use crate::vector::EmbeddingConfig;
use crate::vector::{Embedder, EmbeddingConfig};
use crate::{FieldsIdsMap, Index, OrderBy, Result};
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -396,6 +399,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
fields_ids_map,
)?;
let embedder_configs = self.index.embedding_configs(self.wtxn)?;
let embedders = self.embedders(embedder_configs)?;
// We index the generated `TransformOutput` which must contain
// all the documents with fields in the newly defined searchable order.
let indexing_builder = IndexDocuments::new(
@ -406,11 +412,34 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
&progress_callback,
&should_abort,
)?;
let indexing_builder = indexing_builder.with_embedders(embedders);
indexing_builder.execute_raw(output)?;
Ok(())
}
fn embedders(
&self,
embedding_configs: Vec<(String, EmbeddingConfig)>,
) -> Result<HashMap<String, (Arc<Embedder>, Arc<Prompt>)>> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, EmbeddingConfig { embedder_options, prompt })| {
let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?);
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?,
);
Ok((name, (embedder, prompt)))
})
.collect();
res
}
fn update_displayed(&mut self) -> Result<bool> {
match self.displayed_fields {
Setting::Set(ref fields) => {