Various changes

- DistributionShift in Search object (to be set from model in embed?)
- Fix issue where embedder index wasn't computed at search time
- Accept as default embedder either the "default" one, or the only embedder when there is only one
This commit is contained in:
Louis Dureuil 2023-12-13 15:38:44 +01:00
parent 12940d79a9
commit e0cc775dc4
No known key found for this signature in database
12 changed files with 141 additions and 33 deletions

View file

@ -9,10 +9,9 @@ mod extract_word_docids;
mod extract_word_pair_proximity_docids;
mod extract_word_position_docids;
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use crossbeam_channel::Sender;
use log::debug;
@ -35,9 +34,8 @@ use super::helpers::{
MergeFn, MergeableReader,
};
use super::{helpers, TypedChunk};
use crate::prompt::Prompt;
use crate::proximity::ProximityPrecision;
use crate::vector::Embedder;
use crate::vector::EmbeddingConfigs;
use crate::{FieldId, FieldsIdsMap, Result};
/// Extract data for each databases from obkv documents in parallel.
@ -59,7 +57,7 @@ pub(crate) fn data_from_obkv_documents(
max_positions_per_attributes: Option<u32>,
exact_attributes: HashSet<FieldId>,
proximity_precision: ProximityPrecision,
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
embedders: EmbeddingConfigs,
) -> Result<()> {
puffin::profile_function!();
@ -284,7 +282,7 @@ fn send_original_documents_data(
indexer: GrenadParameters,
lmdb_writer_sx: Sender<Result<TypedChunk>>,
field_id_map: FieldsIdsMap,
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
embedders: EmbeddingConfigs,
) -> Result<()> {
let original_documents_chunk =
original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;

View file

@ -9,7 +9,6 @@ use std::io::{Cursor, Read, Seek};
use std::iter::FromIterator;
use std::num::NonZeroU32;
use std::result::Result as StdResult;
use std::sync::Arc;
use crossbeam_channel::{Receiver, Sender};
use heed::types::Str;
@ -34,12 +33,11 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
pub use self::transform::{Transform, TransformOutput};
use crate::documents::{obkv_to_object, DocumentsBatchReader};
use crate::error::{Error, InternalError, UserError};
use crate::prompt::Prompt;
pub use crate::update::index_documents::helpers::CursorClonableMmap;
use crate::update::{
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
};
use crate::vector::Embedder;
use crate::vector::EmbeddingConfigs;
use crate::{CboRoaringBitmapCodec, Index, Result};
static MERGED_DATABASE_COUNT: usize = 7;
@ -82,7 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
should_abort: FA,
added_documents: u64,
deleted_documents: u64,
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
embedders: EmbeddingConfigs,
}
#[derive(Default, Debug, Clone)]
@ -173,10 +171,7 @@ where
Ok((self, Ok(indexed_documents)))
}
pub fn with_embedders(
mut self,
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
) -> Self {
pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self {
self.embedders = embedders;
self
}

View file

@ -14,12 +14,11 @@ use super::IndexerConfig;
use crate::criterion::Criterion;
use crate::error::UserError;
use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS};
use crate::prompt::Prompt;
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep};
use crate::vector::settings::{EmbeddingSettings, PromptSettings};
use crate::vector::{Embedder, EmbeddingConfig};
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
use crate::{FieldsIdsMap, Index, OrderBy, Result};
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -422,7 +421,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
fn embedders(
&self,
embedding_configs: Vec<(String, EmbeddingConfig)>,
) -> Result<HashMap<String, (Arc<Embedder>, Arc<Prompt>)>> {
) -> Result<EmbeddingConfigs> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, EmbeddingConfig { embedder_options, prompt })| {
@ -436,7 +435,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Ok((name, (embedder, prompt)))
})
.collect();
res
res.map(EmbeddingConfigs::new)
}
fn update_displayed(&mut self) -> Result<bool> {