mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-03-19 06:08:20 +01:00
Add embedding cache
This commit is contained in:
parent
d9111fe8ce
commit
b08544e86d
@ -916,7 +916,7 @@ fn prepare_search<'t>(
|
|||||||
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
|
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
|
||||||
|
|
||||||
embedder
|
embedder
|
||||||
.embed_search(query.q.clone().unwrap(), Some(deadline))
|
.embed_search(query.q.as_ref().unwrap(), Some(deadline))
|
||||||
.map_err(milli::vector::Error::from)
|
.map_err(milli::vector::Error::from)
|
||||||
.map_err(milli::Error::from)?
|
.map_err(milli::Error::from)?
|
||||||
}
|
}
|
||||||
|
@ -203,7 +203,7 @@ impl<'a> Search<'a> {
|
|||||||
|
|
||||||
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
|
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
|
||||||
|
|
||||||
match embedder.embed_search(query, Some(deadline)) {
|
match embedder.embed_search(&query, Some(deadline)) {
|
||||||
Ok(embedding) => embedding,
|
Ok(embedding) => embedding,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
tracing::error!(error=%error, "Embedding failed");
|
tracing::error!(error=%error, "Embedding failed");
|
||||||
|
@ -4,7 +4,8 @@ use arroy::Distance;
|
|||||||
|
|
||||||
use super::error::CompositeEmbedderContainsHuggingFace;
|
use super::error::CompositeEmbedderContainsHuggingFace;
|
||||||
use super::{
|
use super::{
|
||||||
hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, NewEmbedderError,
|
hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache,
|
||||||
|
NewEmbedderError,
|
||||||
};
|
};
|
||||||
use crate::ThreadPoolNoAbort;
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
@ -148,6 +149,27 @@ impl SubEmbedder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn embed_one(
|
||||||
|
&self,
|
||||||
|
text: &str,
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
) -> std::result::Result<Embedding, EmbedError> {
|
||||||
|
match self {
|
||||||
|
SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
|
||||||
|
SubEmbedder::OpenAi(embedder) => {
|
||||||
|
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
|
||||||
|
}
|
||||||
|
SubEmbedder::Ollama(embedder) => {
|
||||||
|
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
|
||||||
|
}
|
||||||
|
SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
|
||||||
|
SubEmbedder::Rest(embedder) => embedder
|
||||||
|
.embed_ref(&[text], deadline)?
|
||||||
|
.pop()
|
||||||
|
.ok_or_else(EmbedError::missing_embedding),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Embed multiple chunks of texts.
|
/// Embed multiple chunks of texts.
|
||||||
///
|
///
|
||||||
/// Each chunk is composed of one or multiple texts.
|
/// Each chunk is composed of one or multiple texts.
|
||||||
@ -233,6 +255,16 @@ impl SubEmbedder {
|
|||||||
SubEmbedder::Rest(embedder) => embedder.distribution(),
|
SubEmbedder::Rest(embedder) => embedder.distribution(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn cache(&self) -> Option<&EmbeddingCache> {
|
||||||
|
match self {
|
||||||
|
SubEmbedder::HuggingFace(embedder) => Some(embedder.cache()),
|
||||||
|
SubEmbedder::OpenAi(embedder) => Some(embedder.cache()),
|
||||||
|
SubEmbedder::UserProvided(_) => None,
|
||||||
|
SubEmbedder::Ollama(embedder) => Some(embedder.cache()),
|
||||||
|
SubEmbedder::Rest(embedder) => Some(embedder.cache()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_similarity(
|
fn check_similarity(
|
||||||
|
@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType};
|
|||||||
use tokenizers::{PaddingParams, Tokenizer};
|
use tokenizers::{PaddingParams, Tokenizer};
|
||||||
|
|
||||||
pub use super::error::{EmbedError, Error, NewEmbedderError};
|
pub use super::error::{EmbedError, Error, NewEmbedderError};
|
||||||
use super::{DistributionShift, Embedding};
|
use super::{DistributionShift, Embedding, EmbeddingCache};
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
Debug,
|
Debug,
|
||||||
@ -84,6 +84,7 @@ pub struct Embedder {
|
|||||||
options: EmbedderOptions,
|
options: EmbedderOptions,
|
||||||
dimensions: usize,
|
dimensions: usize,
|
||||||
pooling: Pooling,
|
pooling: Pooling,
|
||||||
|
cache: EmbeddingCache,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for Embedder {
|
impl std::fmt::Debug for Embedder {
|
||||||
@ -245,7 +246,14 @@ impl Embedder {
|
|||||||
tokenizer.with_padding(Some(pp));
|
tokenizer.with_padding(Some(pp));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };
|
let mut this = Self {
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
options,
|
||||||
|
dimensions: 0,
|
||||||
|
pooling,
|
||||||
|
cache: EmbeddingCache::new(super::CAP_PER_THREAD),
|
||||||
|
};
|
||||||
|
|
||||||
let embeddings = this
|
let embeddings = this
|
||||||
.embed(vec!["test".into()])
|
.embed(vec!["test".into()])
|
||||||
@ -355,4 +363,8 @@ impl Embedder {
|
|||||||
pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
|
pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
texts.iter().map(|text| self.embed_one(text)).collect()
|
texts.iter().map(|text| self.embed_one(text)).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn cache(&self) -> &EmbeddingCache {
|
||||||
|
&self.cache
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::num::{NonZeroUsize, TryFromIntError};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
@ -551,6 +553,51 @@ pub enum Embedder {
|
|||||||
Composite(composite::Embedder),
|
Composite(composite::Embedder),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EmbeddingCache {
|
||||||
|
data: thread_local::ThreadLocal<RefCell<lru::LruCache<String, Embedding>>>,
|
||||||
|
cap_per_thread: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingCache {
|
||||||
|
pub fn new(cap_per_thread: u16) -> Self {
|
||||||
|
Self { cap_per_thread, data: thread_local::ThreadLocal::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the embedding corresponding to `text`, if any is present in the cache.
|
||||||
|
pub fn get(&self, text: &str) -> Option<Embedding> {
|
||||||
|
let mut cache = self
|
||||||
|
.data
|
||||||
|
.get_or_try(|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
|
||||||
|
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
|
||||||
|
self.cap_per_thread as usize,
|
||||||
|
)?)))
|
||||||
|
})
|
||||||
|
.ok()?
|
||||||
|
.borrow_mut();
|
||||||
|
|
||||||
|
cache.get(text).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Puts a new embedding for the specified `text`
|
||||||
|
pub fn put(&self, text: String, embedding: Embedding) {
|
||||||
|
let Ok(cache) = self.data.get_or_try(
|
||||||
|
|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
|
||||||
|
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
|
||||||
|
self.cap_per_thread as usize,
|
||||||
|
)?)))
|
||||||
|
},
|
||||||
|
) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let mut cache = cache.borrow_mut();
|
||||||
|
|
||||||
|
cache.put(text, embedding);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of entries passed as `cap_per_thread` to `EmbeddingCache::new` by the embedders.
pub const CAP_PER_THREAD: u16 = 20;
|
||||||
|
|
||||||
/// Configuration for an embedder.
|
/// Configuration for an embedder.
|
||||||
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
||||||
pub struct EmbeddingConfig {
|
pub struct EmbeddingConfig {
|
||||||
@ -651,19 +698,36 @@ impl Embedder {
|
|||||||
#[tracing::instrument(level = "debug", skip_all, target = "search")]
|
#[tracing::instrument(level = "debug", skip_all, target = "search")]
|
||||||
pub fn embed_search(
|
pub fn embed_search(
|
||||||
&self,
|
&self,
|
||||||
text: String,
|
text: &str,
|
||||||
deadline: Option<Instant>,
|
deadline: Option<Instant>,
|
||||||
) -> std::result::Result<Embedding, EmbedError> {
|
) -> std::result::Result<Embedding, EmbedError> {
|
||||||
let texts = vec![text];
|
if let Some(cache) = self.cache() {
|
||||||
let mut embedding = match self {
|
if let Some(embedding) = cache.get(text) {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
tracing::trace!(text, "embedding found in cache");
|
||||||
Embedder::OpenAi(embedder) => embedder.embed(&texts, deadline),
|
return Ok(embedding);
|
||||||
Embedder::Ollama(embedder) => embedder.embed(&texts, deadline),
|
}
|
||||||
Embedder::UserProvided(embedder) => embedder.embed(&texts),
|
}
|
||||||
Embedder::Rest(embedder) => embedder.embed(texts, deadline),
|
let embedding = match self {
|
||||||
Embedder::Composite(embedder) => embedder.search.embed(texts, deadline),
|
Embedder::HuggingFace(embedder) => embedder.embed_one(text),
|
||||||
|
Embedder::OpenAi(embedder) => {
|
||||||
|
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
|
||||||
|
}
|
||||||
|
Embedder::Ollama(embedder) => {
|
||||||
|
embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
|
||||||
|
}
|
||||||
|
Embedder::UserProvided(embedder) => embedder.embed_one(text),
|
||||||
|
Embedder::Rest(embedder) => embedder
|
||||||
|
.embed_ref(&[text], deadline)?
|
||||||
|
.pop()
|
||||||
|
.ok_or_else(EmbedError::missing_embedding),
|
||||||
|
Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline),
|
||||||
}?;
|
}?;
|
||||||
let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?;
|
|
||||||
|
if let Some(cache) = self.cache() {
|
||||||
|
tracing::trace!(text, "embedding added to cache");
|
||||||
|
cache.put(text.to_owned(), embedding.clone());
|
||||||
|
}
|
||||||
|
|
||||||
Ok(embedding)
|
Ok(embedding)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -759,6 +823,17 @@ impl Embedder {
|
|||||||
Embedder::Composite(embedder) => embedder.index.uses_document_template(),
|
Embedder::Composite(embedder) => embedder.index.uses_document_template(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cache(&self) -> Option<&EmbeddingCache> {
|
||||||
|
match self {
|
||||||
|
Embedder::HuggingFace(embedder) => Some(embedder.cache()),
|
||||||
|
Embedder::OpenAi(embedder) => Some(embedder.cache()),
|
||||||
|
Embedder::UserProvided(_) => None,
|
||||||
|
Embedder::Ollama(embedder) => Some(embedder.cache()),
|
||||||
|
Embedder::Rest(embedder) => Some(embedder.cache()),
|
||||||
|
Embedder::Composite(embedder) => embedder.search.cache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
|
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
|
||||||
|
@ -5,7 +5,7 @@ use rayon::slice::ParallelSlice as _;
|
|||||||
|
|
||||||
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
||||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use super::{DistributionShift, REQUEST_PARALLELISM};
|
use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
use crate::vector::Embedding;
|
use crate::vector::Embedding;
|
||||||
use crate::ThreadPoolNoAbort;
|
use crate::ThreadPoolNoAbort;
|
||||||
@ -182,6 +182,10 @@ impl Embedder {
|
|||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
self.rest_embedder.distribution()
|
self.rest_embedder.distribution()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the embedding cache of the underlying REST embedder.
pub(super) fn cache(&self) -> &EmbeddingCache {
    let rest = &self.rest_embedder;
    rest.cache()
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_ollama_path() -> String {
|
fn get_ollama_path() -> String {
|
||||||
|
@ -7,7 +7,7 @@ use rayon::slice::ParallelSlice as _;
|
|||||||
|
|
||||||
use super::error::{EmbedError, NewEmbedderError};
|
use super::error::{EmbedError, NewEmbedderError};
|
||||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use super::{DistributionShift, REQUEST_PARALLELISM};
|
use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
use crate::vector::error::EmbedErrorKind;
|
use crate::vector::error::EmbedErrorKind;
|
||||||
use crate::vector::Embedding;
|
use crate::vector::Embedding;
|
||||||
@ -318,6 +318,10 @@ impl Embedder {
|
|||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
self.options.distribution()
|
self.options.distribution()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the embedding cache of the underlying REST embedder.
pub(super) fn cache(&self) -> &EmbeddingCache {
    let rest = &self.rest_embedder;
    rest.cache()
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Debug for Embedder {
|
impl fmt::Debug for Embedder {
|
||||||
|
@ -9,7 +9,10 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use super::error::EmbedErrorKind;
|
use super::error::EmbedErrorKind;
|
||||||
use super::json_template::ValueTemplate;
|
use super::json_template::ValueTemplate;
|
||||||
use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM};
|
use super::{
|
||||||
|
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD,
|
||||||
|
REQUEST_PARALLELISM,
|
||||||
|
};
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
use crate::ThreadPoolNoAbort;
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
@ -75,6 +78,7 @@ pub struct Embedder {
|
|||||||
data: EmbedderData,
|
data: EmbedderData,
|
||||||
dimensions: usize,
|
dimensions: usize,
|
||||||
distribution: Option<DistributionShift>,
|
distribution: Option<DistributionShift>,
|
||||||
|
cache: EmbeddingCache,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// All data needed to perform requests and parse responses
|
/// All data needed to perform requests and parse responses
|
||||||
@ -152,7 +156,12 @@ impl Embedder {
|
|||||||
infer_dimensions(&data)?
|
infer_dimensions(&data)?
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self { data, dimensions, distribution: options.distribution })
|
Ok(Self {
|
||||||
|
data,
|
||||||
|
dimensions,
|
||||||
|
distribution: options.distribution,
|
||||||
|
cache: EmbeddingCache::new(CAP_PER_THREAD),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(
|
pub fn embed(
|
||||||
@ -256,6 +265,10 @@ impl Embedder {
|
|||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
self.distribution
|
self.distribution
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn cache(&self) -> &EmbeddingCache {
|
||||||
|
&self.cache
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user