Add embedding cache

This commit is contained in:
Louis Dureuil 2025-03-13 11:13:14 +01:00
parent d9111fe8ce
commit b08544e86d
No known key found for this signature in database
8 changed files with 159 additions and 19 deletions

View file

@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
pub use super::error::{EmbedError, Error, NewEmbedderError};
use super::{DistributionShift, Embedding};
use super::{DistributionShift, Embedding, EmbeddingCache};
#[derive(
Debug,
@ -84,6 +84,7 @@ pub struct Embedder {
options: EmbedderOptions,
dimensions: usize,
pooling: Pooling,
cache: EmbeddingCache,
}
impl std::fmt::Debug for Embedder {
@ -245,7 +246,14 @@ impl Embedder {
tokenizer.with_padding(Some(pp));
}
let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };
let mut this = Self {
model,
tokenizer,
options,
dimensions: 0,
pooling,
cache: EmbeddingCache::new(super::CAP_PER_THREAD),
};
let embeddings = this
.embed(vec!["test".into()])
@ -355,4 +363,8 @@ impl Embedder {
pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
texts.iter().map(|text| self.embed_one(text)).collect()
}
pub(super) fn cache(&self) -> &EmbeddingCache {
&self.cache
}
}