mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-03 11:57:07 +02:00
wraps the index embedding config in a struct
This commit is contained in:
parent
04f6523f3c
commit
9eb6f522ea
7 changed files with 112 additions and 75 deletions
|
@ -53,6 +53,7 @@ use meilisearch_types::heed::byteorder::BE;
|
|||
use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128};
|
||||
use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
|
||||
use meilisearch_types::milli::documents::DocumentsBatchBuilder;
|
||||
use meilisearch_types::milli::index::IndexEmbeddingConfig;
|
||||
use meilisearch_types::milli::update::IndexerConfig;
|
||||
use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs};
|
||||
use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
|
||||
|
@ -1459,33 +1460,39 @@ impl IndexScheduler {
|
|||
// TODO: consider using a type alias or a struct embedder/template
|
||||
pub fn embedders(
|
||||
&self,
|
||||
embedding_configs: Vec<(String, milli::vector::EmbeddingConfig, RoaringBitmap)>,
|
||||
embedding_configs: Vec<IndexEmbeddingConfig>,
|
||||
) -> Result<EmbeddingConfigs> {
|
||||
let res: Result<_> = embedding_configs
|
||||
.into_iter()
|
||||
.map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt }, _)| {
|
||||
let prompt =
|
||||
Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?);
|
||||
// optimistically return existing embedder
|
||||
{
|
||||
let embedders = self.embedders.read().unwrap();
|
||||
if let Some(embedder) = embedders.get(&embedder_options) {
|
||||
return Ok((name, (embedder.clone(), prompt)));
|
||||
.map(
|
||||
|IndexEmbeddingConfig {
|
||||
name,
|
||||
config: milli::vector::EmbeddingConfig { embedder_options, prompt },
|
||||
..
|
||||
}| {
|
||||
let prompt =
|
||||
Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?);
|
||||
// optimistically return existing embedder
|
||||
{
|
||||
let embedders = self.embedders.read().unwrap();
|
||||
if let Some(embedder) = embedders.get(&embedder_options) {
|
||||
return Ok((name, (embedder.clone(), prompt)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add missing embedder
|
||||
let embedder = Arc::new(
|
||||
Embedder::new(embedder_options.clone())
|
||||
.map_err(meilisearch_types::milli::vector::Error::from)
|
||||
.map_err(meilisearch_types::milli::Error::from)?,
|
||||
);
|
||||
{
|
||||
let mut embedders = self.embedders.write().unwrap();
|
||||
embedders.insert(embedder_options, embedder.clone());
|
||||
}
|
||||
Ok((name, (embedder, prompt)))
|
||||
})
|
||||
// add missing embedder
|
||||
let embedder = Arc::new(
|
||||
Embedder::new(embedder_options.clone())
|
||||
.map_err(meilisearch_types::milli::vector::Error::from)
|
||||
.map_err(meilisearch_types::milli::Error::from)?,
|
||||
);
|
||||
{
|
||||
let mut embedders = self.embedders.write().unwrap();
|
||||
embedders.insert(embedder_options, embedder.clone());
|
||||
}
|
||||
Ok((name, (embedder, prompt)))
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
res.map(EmbeddingConfigs::new)
|
||||
}
|
||||
|
@ -3055,10 +3062,10 @@ mod tests {
|
|||
let rtxn = index.read_txn().unwrap();
|
||||
|
||||
let configs = index.embedding_configs(&rtxn).unwrap();
|
||||
let (name, embedding_config, user_provided) = configs.first().unwrap();
|
||||
let IndexEmbeddingConfig { name, config, user_defined } = configs.first().unwrap();
|
||||
insta::assert_snapshot!(name, @"default");
|
||||
insta::assert_debug_snapshot!(user_provided, @"RoaringBitmap<[]>");
|
||||
insta::assert_json_snapshot!(embedding_config.embedder_options);
|
||||
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
|
||||
insta::assert_json_snapshot!(config.embedder_options);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -5022,15 +5029,17 @@ mod tests {
|
|||
let configs = index.embedding_configs(&rtxn).unwrap();
|
||||
// for consistency with the below
|
||||
#[allow(clippy::get_first)]
|
||||
let (name, fakerest_config, user_provided) = configs.get(0).unwrap();
|
||||
let IndexEmbeddingConfig { name, config: fakerest_config, user_defined } =
|
||||
configs.get(0).unwrap();
|
||||
insta::assert_snapshot!(name, @"A_fakerest");
|
||||
insta::assert_debug_snapshot!(user_provided, @"RoaringBitmap<[]>");
|
||||
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
|
||||
insta::assert_json_snapshot!(fakerest_config.embedder_options);
|
||||
let fakerest_name = name.clone();
|
||||
|
||||
let (name, simple_hf_config, user_provided) = configs.get(1).unwrap();
|
||||
let IndexEmbeddingConfig { name, config: simple_hf_config, user_defined } =
|
||||
configs.get(1).unwrap();
|
||||
insta::assert_snapshot!(name, @"B_small_hf");
|
||||
insta::assert_debug_snapshot!(user_provided, @"RoaringBitmap<[]>");
|
||||
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
|
||||
insta::assert_json_snapshot!(simple_hf_config.embedder_options);
|
||||
let simple_hf_name = name.clone();
|
||||
|
||||
|
@ -5102,11 +5111,11 @@ mod tests {
|
|||
let configs = index.embedding_configs(&rtxn).unwrap();
|
||||
// for consistency with the below
|
||||
#[allow(clippy::get_first)]
|
||||
let (name, _config, user_defined) = configs.get(0).unwrap();
|
||||
let IndexEmbeddingConfig { name, config: _, user_defined } = configs.get(0).unwrap();
|
||||
insta::assert_snapshot!(name, @"A_fakerest");
|
||||
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[0]>");
|
||||
|
||||
let (name, _config, user_defined) = configs.get(1).unwrap();
|
||||
let IndexEmbeddingConfig { name, config: _, user_defined } = configs.get(1).unwrap();
|
||||
insta::assert_snapshot!(name, @"B_small_hf");
|
||||
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
|
||||
|
||||
|
@ -5178,11 +5187,13 @@ mod tests {
|
|||
let configs = index.embedding_configs(&rtxn).unwrap();
|
||||
// for consistency with the below
|
||||
#[allow(clippy::get_first)]
|
||||
let (name, _config, user_defined) = configs.get(0).unwrap();
|
||||
let IndexEmbeddingConfig { name, config: _, user_defined } =
|
||||
configs.get(0).unwrap();
|
||||
insta::assert_snapshot!(name, @"A_fakerest");
|
||||
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[0]>");
|
||||
|
||||
let (name, _config, user_defined) = configs.get(1).unwrap();
|
||||
let IndexEmbeddingConfig { name, config: _, user_defined } =
|
||||
configs.get(1).unwrap();
|
||||
insta::assert_snapshot!(name, @"B_small_hf");
|
||||
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[]>");
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue