diff --git a/crates/meilisearch/src/search/mod.rs b/crates/meilisearch/src/search/mod.rs
index 69b306abc..35bb883ad 100644
--- a/crates/meilisearch/src/search/mod.rs
+++ b/crates/meilisearch/src/search/mod.rs
@@ -916,7 +916,7 @@ fn prepare_search<'t>(
             let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
 
             embedder
-                .embed_search(query.q.clone().unwrap(), Some(deadline))
+                .embed_search(query.q.as_ref().unwrap(), Some(deadline))
                 .map_err(milli::vector::Error::from)
                 .map_err(milli::Error::from)?
         }
diff --git a/crates/milli/src/search/hybrid.rs b/crates/milli/src/search/hybrid.rs
index a1c8b71da..298248c8b 100644
--- a/crates/milli/src/search/hybrid.rs
+++ b/crates/milli/src/search/hybrid.rs
@@ -203,7 +203,7 @@ impl<'a> Search<'a> {
 
             let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
 
-            match embedder.embed_search(query, Some(deadline)) {
+            match embedder.embed_search(&query, Some(deadline)) {
                 Ok(embedding) => embedding,
                 Err(error) => {
                     tracing::error!(error=%error, "Embedding failed");
diff --git a/crates/milli/src/vector/composite.rs b/crates/milli/src/vector/composite.rs
index d174232bf..368fb7f18 100644
--- a/crates/milli/src/vector/composite.rs
+++ b/crates/milli/src/vector/composite.rs
@@ -4,7 +4,8 @@ use arroy::Distance;
 
 use super::error::CompositeEmbedderContainsHuggingFace;
 use super::{
-    hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, NewEmbedderError,
+    hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache,
+    NewEmbedderError,
 };
 use crate::ThreadPoolNoAbort;
 
@@ -148,6 +149,27 @@ impl SubEmbedder {
         }
     }
 
+    pub fn embed_one(
+        &self,
+        text: &str,
+        deadline: Option<Instant>,
+    ) -> std::result::Result<Embedding, EmbedError> {
+        match self {
+            SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
+            SubEmbedder::OpenAi(embedder) => {
+                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+            }
+            SubEmbedder::Ollama(embedder) => {
+                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+            }
+            SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
+            SubEmbedder::Rest(embedder) => embedder
+                .embed_ref(&[text], deadline)?
+                .pop()
+                .ok_or_else(EmbedError::missing_embedding),
+        }
+    }
+
     /// Embed multiple chunks of texts.
     ///
     /// Each chunk is composed of one or multiple texts.
@@ -233,6 +255,16 @@ impl SubEmbedder {
             SubEmbedder::Rest(embedder) => embedder.distribution(),
         }
     }
+
+    pub(super) fn cache(&self) -> Option<&EmbeddingCache> {
+        match self {
+            SubEmbedder::HuggingFace(embedder) => Some(embedder.cache()),
+            SubEmbedder::OpenAi(embedder) => Some(embedder.cache()),
+            SubEmbedder::UserProvided(_) => None,
+            SubEmbedder::Ollama(embedder) => Some(embedder.cache()),
+            SubEmbedder::Rest(embedder) => Some(embedder.cache()),
+        }
+    }
 }
 
 fn check_similarity(
diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs
index 60e40e367..6e73c8247 100644
--- a/crates/milli/src/vector/hf.rs
+++ b/crates/milli/src/vector/hf.rs
@@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};
 
 pub use super::error::{EmbedError, Error, NewEmbedderError};
-use super::{DistributionShift, Embedding};
+use super::{DistributionShift, Embedding, EmbeddingCache};
 
 #[derive(
     Debug,
@@ -84,6 +84,7 @@ pub struct Embedder {
     options: EmbedderOptions,
     dimensions: usize,
     pooling: Pooling,
+    cache: EmbeddingCache,
 }
 
 impl std::fmt::Debug for Embedder {
@@ -245,7 +246,14 @@ impl Embedder {
             tokenizer.with_padding(Some(pp));
         }
 
-        let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };
+        let mut this = Self {
+            model,
+            tokenizer,
+            options,
+            dimensions: 0,
+            pooling,
+            cache: EmbeddingCache::new(super::CAP_PER_THREAD),
+        };
 
         let embeddings = this
             .embed(vec!["test".into()])
@@ -355,4 +363,8 @@ impl Embedder {
     pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
         texts.iter().map(|text| self.embed_one(text)).collect()
     }
+
+    pub(super) fn cache(&self) -> &EmbeddingCache {
+        &self.cache
+    }
 }
diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs
index 80efc210d..55b865b4a 100644
--- a/crates/milli/src/vector/mod.rs
+++ b/crates/milli/src/vector/mod.rs
@@ -1,4 +1,6 @@
+use std::cell::RefCell;
 use std::collections::HashMap;
+use std::num::{NonZeroUsize, TryFromIntError};
 use std::sync::Arc;
 use std::time::Instant;
 
@@ -551,6 +553,51 @@ pub enum Embedder {
     Composite(composite::Embedder),
 }
 
+#[derive(Debug)]
+struct EmbeddingCache {
+    data: thread_local::ThreadLocal<RefCell<lru::LruCache<String, Embedding>>>,
+    cap_per_thread: u16,
+}
+
+impl EmbeddingCache {
+    pub fn new(cap_per_thread: u16) -> Self {
+        Self { cap_per_thread, data: thread_local::ThreadLocal::new() }
+    }
+
+    /// Get the embedding corresponding to `text`, if any is present in the cache.
+    pub fn get(&self, text: &str) -> Option<Embedding> {
+        let mut cache = self
+            .data
+            .get_or_try(|| -> Result<RefCell<lru::LruCache<String, Embedding>>, TryFromIntError> {
+                Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
+                    self.cap_per_thread as usize,
+                )?)))
+            })
+            .ok()?
+            .borrow_mut();
+
+        cache.get(text).cloned()
+    }
+
+    /// Puts a new embedding for the specified `text`
+    pub fn put(&self, text: String, embedding: Embedding) {
+        let Ok(cache) = self.data.get_or_try(
+            || -> Result<RefCell<lru::LruCache<String, Embedding>>, TryFromIntError> {
+                Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
+                    self.cap_per_thread as usize,
+                )?)))
+            },
+        ) else {
+            return;
+        };
+        let mut cache = cache.borrow_mut();
+
+        cache.put(text, embedding);
+    }
+}
+
+pub const CAP_PER_THREAD: u16 = 20;
+
 /// Configuration for an embedder.
 #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
 pub struct EmbeddingConfig {
@@ -651,19 +698,36 @@ impl Embedder {
     #[tracing::instrument(level = "debug", skip_all, target = "search")]
     pub fn embed_search(
         &self,
-        text: String,
+        text: &str,
         deadline: Option<Instant>,
     ) -> std::result::Result<Embedding, EmbedError> {
-        let texts = vec![text];
-        let mut embedding = match self {
-            Embedder::HuggingFace(embedder) => embedder.embed(texts),
-            Embedder::OpenAi(embedder) => embedder.embed(&texts, deadline),
-            Embedder::Ollama(embedder) => embedder.embed(&texts, deadline),
-            Embedder::UserProvided(embedder) => embedder.embed(&texts),
-            Embedder::Rest(embedder) => embedder.embed(texts, deadline),
-            Embedder::Composite(embedder) => embedder.search.embed(texts, deadline),
+        if let Some(cache) = self.cache() {
+            if let Some(embedding) = cache.get(text) {
+                tracing::trace!(text, "embedding found in cache");
+                return Ok(embedding);
+            }
+        }
+        let embedding = match self {
+            Embedder::HuggingFace(embedder) => embedder.embed_one(text),
+            Embedder::OpenAi(embedder) => {
+                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+            }
+            Embedder::Ollama(embedder) => {
+                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+            }
+            Embedder::UserProvided(embedder) => embedder.embed_one(text),
+            Embedder::Rest(embedder) => embedder
+                .embed_ref(&[text], deadline)?
+                .pop()
+                .ok_or_else(EmbedError::missing_embedding),
+            Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline),
         }?;
-        let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?;
+
+        if let Some(cache) = self.cache() {
+            tracing::trace!(text, "embedding added to cache");
+            cache.put(text.to_owned(), embedding.clone());
+        }
+
         Ok(embedding)
     }
 
@@ -759,6 +823,17 @@ impl Embedder {
             Embedder::Composite(embedder) => embedder.index.uses_document_template(),
         }
     }
+
+    fn cache(&self) -> Option<&EmbeddingCache> {
+        match self {
+            Embedder::HuggingFace(embedder) => Some(embedder.cache()),
+            Embedder::OpenAi(embedder) => Some(embedder.cache()),
+            Embedder::UserProvided(_) => None,
+            Embedder::Ollama(embedder) => Some(embedder.cache()),
+            Embedder::Rest(embedder) => Some(embedder.cache()),
+            Embedder::Composite(embedder) => embedder.search.cache(),
+        }
+    }
 }
 
 /// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
diff --git a/crates/milli/src/vector/ollama.rs b/crates/milli/src/vector/ollama.rs
index 130e90cee..57c71538e 100644
--- a/crates/milli/src/vector/ollama.rs
+++ b/crates/milli/src/vector/ollama.rs
@@ -5,7 +5,7 @@ use rayon::slice::ParallelSlice as _;
 
 use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
-use super::{DistributionShift, REQUEST_PARALLELISM};
+use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
 use crate::vector::Embedding;
 use crate::ThreadPoolNoAbort;
@@ -182,6 +182,10 @@ impl Embedder {
     pub fn distribution(&self) -> Option<DistributionShift> {
         self.rest_embedder.distribution()
     }
+
+    pub(super) fn cache(&self) -> &EmbeddingCache {
+        self.rest_embedder.cache()
+    }
 }
 
 fn get_ollama_path() -> String {
diff --git a/crates/milli/src/vector/openai.rs b/crates/milli/src/vector/openai.rs
index 8a5e6266a..66680adb0 100644
--- a/crates/milli/src/vector/openai.rs
+++ b/crates/milli/src/vector/openai.rs
@@ -7,7 +7,7 @@ use rayon::slice::ParallelSlice as _;
 
 use super::error::{EmbedError, NewEmbedderError};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
-use super::{DistributionShift, REQUEST_PARALLELISM};
+use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
 use crate::vector::error::EmbedErrorKind;
 use crate::vector::Embedding;
@@ -318,6 +318,10 @@ impl Embedder {
     pub fn distribution(&self) -> Option<DistributionShift> {
         self.options.distribution()
     }
+
+    pub(super) fn cache(&self) -> &EmbeddingCache {
+        self.rest_embedder.cache()
+    }
 }
 
 impl fmt::Debug for Embedder {
diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs
index a31bc5d2f..b9b8b0fb3 100644
--- a/crates/milli/src/vector/rest.rs
+++ b/crates/milli/src/vector/rest.rs
@@ -9,7 +9,10 @@ use serde::{Deserialize, Serialize};
 
 use super::error::EmbedErrorKind;
 use super::json_template::ValueTemplate;
-use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM};
+use super::{
+    DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD,
+    REQUEST_PARALLELISM,
+};
 use crate::error::FaultSource;
 use crate::ThreadPoolNoAbort;
 
@@ -75,6 +78,7 @@ pub struct Embedder {
     data: EmbedderData,
     dimensions: usize,
     distribution: Option<DistributionShift>,
+    cache: EmbeddingCache,
 }
 
 /// All data needed to perform requests and parse responses
@@ -152,7 +156,12 @@ impl Embedder {
             infer_dimensions(&data)?
         };
 
-        Ok(Self { data, dimensions, distribution: options.distribution })
+        Ok(Self {
+            data,
+            dimensions,
+            distribution: options.distribution,
+            cache: EmbeddingCache::new(CAP_PER_THREAD),
+        })
     }
 
     pub fn embed(
@@ -256,6 +265,10 @@ impl Embedder {
     pub fn distribution(&self) -> Option<DistributionShift> {
         self.distribution
     }
+
+    pub(super) fn cache(&self) -> &EmbeddingCache {
+        &self.cache
+    }
 }
 
 fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {