Add embedding cache

This commit is contained in:
Louis Dureuil 2025-03-13 11:13:14 +01:00
parent d9111fe8ce
commit b08544e86d
No known key found for this signature in database
8 changed files with 159 additions and 19 deletions

View File

@ -916,7 +916,7 @@ fn prepare_search<'t>(
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
embedder embedder
.embed_search(query.q.clone().unwrap(), Some(deadline)) .embed_search(query.q.as_ref().unwrap(), Some(deadline))
.map_err(milli::vector::Error::from) .map_err(milli::vector::Error::from)
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
} }

View File

@ -203,7 +203,7 @@ impl<'a> Search<'a> {
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3); let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
match embedder.embed_search(query, Some(deadline)) { match embedder.embed_search(&query, Some(deadline)) {
Ok(embedding) => embedding, Ok(embedding) => embedding,
Err(error) => { Err(error) => {
tracing::error!(error=%error, "Embedding failed"); tracing::error!(error=%error, "Embedding failed");

View File

@ -4,7 +4,8 @@ use arroy::Distance;
use super::error::CompositeEmbedderContainsHuggingFace; use super::error::CompositeEmbedderContainsHuggingFace;
use super::{ use super::{
hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, NewEmbedderError, hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache,
NewEmbedderError,
}; };
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
@ -148,6 +149,27 @@ impl SubEmbedder {
} }
} }
pub fn embed_one(
&self,
text: &str,
deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> {
match self {
SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
SubEmbedder::OpenAi(embedder) => {
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
}
SubEmbedder::Ollama(embedder) => {
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
}
SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
SubEmbedder::Rest(embedder) => embedder
.embed_ref(&[text], deadline)?
.pop()
.ok_or_else(EmbedError::missing_embedding),
}
}
/// Embed multiple chunks of texts. /// Embed multiple chunks of texts.
/// ///
/// Each chunk is composed of one or multiple texts. /// Each chunk is composed of one or multiple texts.
@ -233,6 +255,16 @@ impl SubEmbedder {
SubEmbedder::Rest(embedder) => embedder.distribution(), SubEmbedder::Rest(embedder) => embedder.distribution(),
} }
} }
pub(super) fn cache(&self) -> Option<&EmbeddingCache> {
match self {
SubEmbedder::HuggingFace(embedder) => Some(embedder.cache()),
SubEmbedder::OpenAi(embedder) => Some(embedder.cache()),
SubEmbedder::UserProvided(_) => None,
SubEmbedder::Ollama(embedder) => Some(embedder.cache()),
SubEmbedder::Rest(embedder) => Some(embedder.cache()),
}
}
} }
fn check_similarity( fn check_similarity(

View File

@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer}; use tokenizers::{PaddingParams, Tokenizer};
pub use super::error::{EmbedError, Error, NewEmbedderError}; pub use super::error::{EmbedError, Error, NewEmbedderError};
use super::{DistributionShift, Embedding}; use super::{DistributionShift, Embedding, EmbeddingCache};
#[derive( #[derive(
Debug, Debug,
@ -84,6 +84,7 @@ pub struct Embedder {
options: EmbedderOptions, options: EmbedderOptions,
dimensions: usize, dimensions: usize,
pooling: Pooling, pooling: Pooling,
cache: EmbeddingCache,
} }
impl std::fmt::Debug for Embedder { impl std::fmt::Debug for Embedder {
@ -245,7 +246,14 @@ impl Embedder {
tokenizer.with_padding(Some(pp)); tokenizer.with_padding(Some(pp));
} }
let mut this = Self { model, tokenizer, options, dimensions: 0, pooling }; let mut this = Self {
model,
tokenizer,
options,
dimensions: 0,
pooling,
cache: EmbeddingCache::new(super::CAP_PER_THREAD),
};
let embeddings = this let embeddings = this
.embed(vec!["test".into()]) .embed(vec!["test".into()])
@ -355,4 +363,8 @@ impl Embedder {
pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> { pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
texts.iter().map(|text| self.embed_one(text)).collect() texts.iter().map(|text| self.embed_one(text)).collect()
} }
pub(super) fn cache(&self) -> &EmbeddingCache {
&self.cache
}
} }

View File

@ -1,4 +1,6 @@
use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::num::{NonZeroUsize, TryFromIntError};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -551,6 +553,51 @@ pub enum Embedder {
Composite(composite::Embedder), Composite(composite::Embedder),
} }
#[derive(Debug)]
struct EmbeddingCache {
data: thread_local::ThreadLocal<RefCell<lru::LruCache<String, Embedding>>>,
cap_per_thread: u16,
}
impl EmbeddingCache {
pub fn new(cap_per_thread: u16) -> Self {
Self { cap_per_thread, data: thread_local::ThreadLocal::new() }
}
/// Get the embedding corresponding to `text`, if any is present in the cache.
pub fn get(&self, text: &str) -> Option<Embedding> {
let mut cache = self
.data
.get_or_try(|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
self.cap_per_thread as usize,
)?)))
})
.ok()?
.borrow_mut();
cache.get(text).cloned()
}
/// Puts a new embedding for the specified `text`
pub fn put(&self, text: String, embedding: Embedding) {
let Ok(cache) = self.data.get_or_try(
|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
self.cap_per_thread as usize,
)?)))
},
) else {
return;
};
let mut cache = cache.borrow_mut();
cache.put(text, embedding);
}
}
pub const CAP_PER_THREAD: u16 = 20;
/// Configuration for an embedder. /// Configuration for an embedder.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig { pub struct EmbeddingConfig {
@ -651,19 +698,36 @@ impl Embedder {
#[tracing::instrument(level = "debug", skip_all, target = "search")] #[tracing::instrument(level = "debug", skip_all, target = "search")]
pub fn embed_search( pub fn embed_search(
&self, &self,
text: String, text: &str,
deadline: Option<Instant>, deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> { ) -> std::result::Result<Embedding, EmbedError> {
let texts = vec![text]; if let Some(cache) = self.cache() {
let mut embedding = match self { if let Some(embedding) = cache.get(text) {
Embedder::HuggingFace(embedder) => embedder.embed(texts), tracing::trace!(text, "embedding found in cache");
Embedder::OpenAi(embedder) => embedder.embed(&texts, deadline), return Ok(embedding);
Embedder::Ollama(embedder) => embedder.embed(&texts, deadline), }
Embedder::UserProvided(embedder) => embedder.embed(&texts), }
Embedder::Rest(embedder) => embedder.embed(texts, deadline), let embedding = match self {
Embedder::Composite(embedder) => embedder.search.embed(texts, deadline), Embedder::HuggingFace(embedder) => embedder.embed_one(text),
Embedder::OpenAi(embedder) => {
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
}
Embedder::Ollama(embedder) => {
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
}
Embedder::UserProvided(embedder) => embedder.embed_one(text),
Embedder::Rest(embedder) => embedder
.embed_ref(&[text], deadline)?
.pop()
.ok_or_else(EmbedError::missing_embedding),
Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline),
}?; }?;
let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?;
if let Some(cache) = self.cache() {
tracing::trace!(text, "embedding added to cache");
cache.put(text.to_owned(), embedding.clone());
}
Ok(embedding) Ok(embedding)
} }
@ -759,6 +823,17 @@ impl Embedder {
Embedder::Composite(embedder) => embedder.index.uses_document_template(), Embedder::Composite(embedder) => embedder.index.uses_document_template(),
} }
} }
fn cache(&self) -> Option<&EmbeddingCache> {
match self {
Embedder::HuggingFace(embedder) => Some(embedder.cache()),
Embedder::OpenAi(embedder) => Some(embedder.cache()),
Embedder::UserProvided(_) => None,
Embedder::Ollama(embedder) => Some(embedder.cache()),
Embedder::Rest(embedder) => Some(embedder.cache()),
Embedder::Composite(embedder) => embedder.search.cache(),
}
}
} }
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space. /// Describes the mean and sigma of distribution of embedding similarity in the embedding space.

View File

@ -5,7 +5,7 @@ use rayon::slice::ParallelSlice as _;
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, REQUEST_PARALLELISM}; use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::vector::Embedding; use crate::vector::Embedding;
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
@ -182,6 +182,10 @@ impl Embedder {
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.rest_embedder.distribution() self.rest_embedder.distribution()
} }
pub(super) fn cache(&self) -> &EmbeddingCache {
self.rest_embedder.cache()
}
} }
fn get_ollama_path() -> String { fn get_ollama_path() -> String {

View File

@ -7,7 +7,7 @@ use rayon::slice::ParallelSlice as _;
use super::error::{EmbedError, NewEmbedderError}; use super::error::{EmbedError, NewEmbedderError};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, REQUEST_PARALLELISM}; use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::vector::error::EmbedErrorKind; use crate::vector::error::EmbedErrorKind;
use crate::vector::Embedding; use crate::vector::Embedding;
@ -318,6 +318,10 @@ impl Embedder {
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.options.distribution() self.options.distribution()
} }
pub(super) fn cache(&self) -> &EmbeddingCache {
self.rest_embedder.cache()
}
} }
impl fmt::Debug for Embedder { impl fmt::Debug for Embedder {

View File

@ -9,7 +9,10 @@ use serde::{Deserialize, Serialize};
use super::error::EmbedErrorKind; use super::error::EmbedErrorKind;
use super::json_template::ValueTemplate; use super::json_template::ValueTemplate;
use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM}; use super::{
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD,
REQUEST_PARALLELISM,
};
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
@ -75,6 +78,7 @@ pub struct Embedder {
data: EmbedderData, data: EmbedderData,
dimensions: usize, dimensions: usize,
distribution: Option<DistributionShift>, distribution: Option<DistributionShift>,
cache: EmbeddingCache,
} }
/// All data needed to perform requests and parse responses /// All data needed to perform requests and parse responses
@ -152,7 +156,12 @@ impl Embedder {
infer_dimensions(&data)? infer_dimensions(&data)?
}; };
Ok(Self { data, dimensions, distribution: options.distribution }) Ok(Self {
data,
dimensions,
distribution: options.distribution,
cache: EmbeddingCache::new(CAP_PER_THREAD),
})
} }
pub fn embed( pub fn embed(
@ -256,6 +265,10 @@ impl Embedder {
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution self.distribution
} }
pub(super) fn cache(&self) -> &EmbeddingCache {
&self.cache
}
} }
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> { fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {