Mutex-based implementation

This commit is contained in:
Louis Dureuil 2025-03-13 12:00:11 +01:00
parent d0b0b90d17
commit 1876132172
No known key found for this signature in database
3 changed files with 22 additions and 28 deletions

View File

@ -252,7 +252,7 @@ impl Embedder {
options,
dimensions: 0,
pooling,
cache: EmbeddingCache::new(super::CAP_PER_THREAD),
cache: EmbeddingCache::new(super::CACHE_CAP),
};
let embeddings = this

View File

@ -1,7 +1,6 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::num::{NonZeroUsize, TryFromIntError};
use std::sync::Arc;
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use arroy::distances::{BinaryQuantizedCosine, Cosine};
@ -555,48 +554,43 @@ pub enum Embedder {
#[derive(Debug)]
struct EmbeddingCache {
data: thread_local::ThreadLocal<RefCell<lru::LruCache<String, Embedding>>>,
cap_per_thread: u16,
data: Option<Mutex<lru::LruCache<String, Embedding>>>,
}
impl EmbeddingCache {
pub fn new(cap_per_thread: u16) -> Self {
Self { cap_per_thread, data: thread_local::ThreadLocal::new() }
const MAX_TEXT_LEN: usize = 2000;
pub fn new(cap: u16) -> Self {
let data = NonZeroUsize::new(cap.into()).map(lru::LruCache::new).map(Mutex::new);
Self { data }
}
/// Get the embedding corresponding to `text`, if any is present in the cache.
pub fn get(&self, text: &str) -> Option<Embedding> {
let mut cache = self
.data
.get_or_try(|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
self.cap_per_thread as usize,
)?)))
})
.ok()?
.borrow_mut();
let data = self.data.as_ref()?;
if text.len() > Self::MAX_TEXT_LEN {
return None;
}
let mut cache = data.lock().unwrap();
cache.get(text).cloned()
}
/// Puts a new embedding for the specified `text`
pub fn put(&self, text: String, embedding: Embedding) {
let Ok(cache) = self.data.get_or_try(
|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
self.cap_per_thread as usize,
)?)))
},
) else {
let Some(data) = self.data.as_ref() else {
return;
};
let mut cache = cache.borrow_mut();
if text.len() > Self::MAX_TEXT_LEN {
return;
}
let mut cache = data.lock().unwrap();
cache.put(text, embedding);
}
}
pub const CAP_PER_THREAD: u16 = 20;
pub const CACHE_CAP: u16 = 150;
/// Configuration for an embedder.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]

View File

@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};
use super::error::EmbedErrorKind;
use super::json_template::ValueTemplate;
use super::{
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD,
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CACHE_CAP,
REQUEST_PARALLELISM,
};
use crate::error::FaultSource;
@ -160,7 +160,7 @@ impl Embedder {
data,
dimensions,
distribution: options.distribution,
cache: EmbeddingCache::new(CAP_PER_THREAD),
cache: EmbeddingCache::new(CACHE_CAP),
})
}