Small commit to add hybrid search and autoembedding

This commit is contained in:
Louis Dureuil 2023-11-15 15:46:37 +01:00
parent 21bcf32109
commit 13c2c6c16b
No known key found for this signature in database
42 changed files with 4045 additions and 246 deletions

View file

@ -52,6 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128};
use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
use meilisearch_types::milli::documents::DocumentsBatchBuilder;
use meilisearch_types::milli::update::IndexerConfig;
use meilisearch_types::milli::vector::{Embedder, EmbedderOptions};
use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task};
use puffin::FrameView;
@ -341,6 +342,8 @@ pub struct IndexScheduler {
/// so that a handle to the index is available from other threads (search) in an optimized manner.
currently_updating_index: Arc<RwLock<Option<(String, Index)>>>,
embedders: Arc<RwLock<HashMap<EmbedderOptions, std::sync::Arc<Embedder>>>>,
// ================= test
// The next entry is dedicated to the tests.
/// Provide a way to set a breakpoint in multiple part of the scheduler.
@ -386,6 +389,7 @@ impl IndexScheduler {
auth_path: self.auth_path.clone(),
version_file_path: self.version_file_path.clone(),
currently_updating_index: self.currently_updating_index.clone(),
embedders: self.embedders.clone(),
#[cfg(test)]
test_breakpoint_sdr: self.test_breakpoint_sdr.clone(),
#[cfg(test)]
@ -484,6 +488,7 @@ impl IndexScheduler {
auth_path: options.auth_path,
version_file_path: options.version_file_path,
currently_updating_index: Arc::new(RwLock::new(None)),
embedders: Default::default(),
#[cfg(test)]
test_breakpoint_sdr,
@ -1333,6 +1338,42 @@ impl IndexScheduler {
}
}
// TODO: consider using a type alias or a struct embedder/template
#[allow(clippy::type_complexity)]
pub fn embedders(
&self,
embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>,
) -> Result<HashMap<String, (Arc<milli::vector::Embedder>, Arc<milli::prompt::Prompt>)>> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| {
let prompt =
Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?);
// optimistically return existing embedder
{
let embedders = self.embedders.read().unwrap();
if let Some(embedder) = embedders.get(&embedder_options) {
return Ok((name, (embedder.clone(), prompt)));
}
}
// add missing embedder
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(meilisearch_types::milli::vector::Error::from)
.map_err(meilisearch_types::milli::UserError::from)
.map_err(meilisearch_types::milli::Error::from)?,
);
{
let mut embedders = self.embedders.write().unwrap();
embedders.insert(embedder_options, embedder.clone());
}
Ok((name, (embedder, prompt)))
})
.collect();
res
}
/// Blocks the thread until the test handle asks to progress to/through this breakpoint.
///
/// Two messages are sent through the channel for each breakpoint.