wraps the index embedding config in a struct

This commit is contained in:
Tamo 2024-05-30 11:50:30 +02:00
parent 04f6523f3c
commit 9eb6f522ea
7 changed files with 112 additions and 75 deletions

View file

@ -53,6 +53,7 @@ use meilisearch_types::heed::byteorder::BE;
use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128};
use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
use meilisearch_types::milli::documents::DocumentsBatchBuilder;
use meilisearch_types::milli::index::IndexEmbeddingConfig;
use meilisearch_types::milli::update::IndexerConfig;
use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs};
use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
@ -1459,33 +1460,39 @@ impl IndexScheduler {
// TODO: consider using a type alias or a struct embedder/template
pub fn embedders(
&self,
embedding_configs: Vec<(String, milli::vector::EmbeddingConfig, RoaringBitmap)>,
embedding_configs: Vec<IndexEmbeddingConfig>,
) -> Result<EmbeddingConfigs> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt }, _)| {
let prompt =
Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?);
// optimistically return existing embedder
{
let embedders = self.embedders.read().unwrap();
if let Some(embedder) = embedders.get(&embedder_options) {
return Ok((name, (embedder.clone(), prompt)));
.map(
|IndexEmbeddingConfig {
name,
config: milli::vector::EmbeddingConfig { embedder_options, prompt },
..
}| {
let prompt =
Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?);
// optimistically return existing embedder
{
let embedders = self.embedders.read().unwrap();
if let Some(embedder) = embedders.get(&embedder_options) {
return Ok((name, (embedder.clone(), prompt)));
}
}
}
// add missing embedder
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(meilisearch_types::milli::vector::Error::from)
.map_err(meilisearch_types::milli::Error::from)?,
);
{
let mut embedders = self.embedders.write().unwrap();
embedders.insert(embedder_options, embedder.clone());
}
Ok((name, (embedder, prompt)))
})
// add missing embedder
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(meilisearch_types::milli::vector::Error::from)
.map_err(meilisearch_types::milli::Error::from)?,
);
{
let mut embedders = self.embedders.write().unwrap();
embedders.insert(embedder_options, embedder.clone());
}
Ok((name, (embedder, prompt)))
},
)
.collect();
res.map(EmbeddingConfigs::new)
}
@ -3055,10 +3062,10 @@ mod tests {
let rtxn = index.read_txn().unwrap();
let configs = index.embedding_configs(&rtxn).unwrap();
let (name, embedding_config, user_provided) = configs.first().unwrap();
let IndexEmbeddingConfig { name, config, user_defined } = configs.first().unwrap();
insta::assert_snapshot!(name, @"default");
insta::assert_debug_snapshot!(user_provided, @"RoaringBitmap<[]>");
insta::assert_json_snapshot!(embedding_config.embedder_options);
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
insta::assert_json_snapshot!(config.embedder_options);
}
#[test]
@ -5022,15 +5029,17 @@ mod tests {
let configs = index.embedding_configs(&rtxn).unwrap();
// for consistency with the below
#[allow(clippy::get_first)]
let (name, fakerest_config, user_provided) = configs.get(0).unwrap();
let IndexEmbeddingConfig { name, config: fakerest_config, user_defined } =
configs.get(0).unwrap();
insta::assert_snapshot!(name, @"A_fakerest");
insta::assert_debug_snapshot!(user_provided, @"RoaringBitmap<[]>");
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
insta::assert_json_snapshot!(fakerest_config.embedder_options);
let fakerest_name = name.clone();
let (name, simple_hf_config, user_provided) = configs.get(1).unwrap();
let IndexEmbeddingConfig { name, config: simple_hf_config, user_defined } =
configs.get(1).unwrap();
insta::assert_snapshot!(name, @"B_small_hf");
insta::assert_debug_snapshot!(user_provided, @"RoaringBitmap<[]>");
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
insta::assert_json_snapshot!(simple_hf_config.embedder_options);
let simple_hf_name = name.clone();
@ -5102,11 +5111,11 @@ mod tests {
let configs = index.embedding_configs(&rtxn).unwrap();
// for consistency with the below
#[allow(clippy::get_first)]
let (name, _config, user_defined) = configs.get(0).unwrap();
let IndexEmbeddingConfig { name, config: _, user_defined } = configs.get(0).unwrap();
insta::assert_snapshot!(name, @"A_fakerest");
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[0]>");
let (name, _config, user_defined) = configs.get(1).unwrap();
let IndexEmbeddingConfig { name, config: _, user_defined } = configs.get(1).unwrap();
insta::assert_snapshot!(name, @"B_small_hf");
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
@ -5178,11 +5187,13 @@ mod tests {
let configs = index.embedding_configs(&rtxn).unwrap();
// for consistency with the below
#[allow(clippy::get_first)]
let (name, _config, user_defined) = configs.get(0).unwrap();
let IndexEmbeddingConfig { name, config: _, user_defined } =
configs.get(0).unwrap();
insta::assert_snapshot!(name, @"A_fakerest");
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[0]>");
let (name, _config, user_defined) = configs.get(1).unwrap();
let IndexEmbeddingConfig { name, config: _, user_defined } =
configs.get(1).unwrap();
insta::assert_snapshot!(name, @"B_small_hf");
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");