Mutex-based implementation

This commit is contained in:
Louis Dureuil 2025-03-13 12:00:11 +01:00
parent d0b0b90d17
commit 1876132172
No known key found for this signature in database
3 changed files with 22 additions and 28 deletions

View File

@@ -252,7 +252,7 @@ impl Embedder {
options, options,
dimensions: 0, dimensions: 0,
pooling, pooling,
cache: EmbeddingCache::new(super::CAP_PER_THREAD), cache: EmbeddingCache::new(super::CACHE_CAP),
}; };
let embeddings = this let embeddings = this

View File

@@ -1,7 +1,6 @@
use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::num::{NonZeroUsize, TryFromIntError}; use std::num::NonZeroUsize;
use std::sync::Arc; use std::sync::{Arc, Mutex};
use std::time::Instant; use std::time::Instant;
use arroy::distances::{BinaryQuantizedCosine, Cosine}; use arroy::distances::{BinaryQuantizedCosine, Cosine};
@@ -555,48 +554,43 @@ pub enum Embedder {
#[derive(Debug)] #[derive(Debug)]
struct EmbeddingCache { struct EmbeddingCache {
data: thread_local::ThreadLocal<RefCell<lru::LruCache<String, Embedding>>>, data: Option<Mutex<lru::LruCache<String, Embedding>>>,
cap_per_thread: u16,
} }
impl EmbeddingCache { impl EmbeddingCache {
pub fn new(cap_per_thread: u16) -> Self { const MAX_TEXT_LEN: usize = 2000;
Self { cap_per_thread, data: thread_local::ThreadLocal::new() }
pub fn new(cap: u16) -> Self {
let data = NonZeroUsize::new(cap.into()).map(lru::LruCache::new).map(Mutex::new);
Self { data }
} }
/// Get the embedding corresponding to `text`, if any is present in the cache. /// Get the embedding corresponding to `text`, if any is present in the cache.
pub fn get(&self, text: &str) -> Option<Embedding> { pub fn get(&self, text: &str) -> Option<Embedding> {
let mut cache = self let data = self.data.as_ref()?;
.data if text.len() > Self::MAX_TEXT_LEN {
.get_or_try(|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> { return None;
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from( }
self.cap_per_thread as usize, let mut cache = data.lock().unwrap();
)?)))
})
.ok()?
.borrow_mut();
cache.get(text).cloned() cache.get(text).cloned()
} }
/// Puts a new embedding for the specified `text` /// Puts a new embedding for the specified `text`
pub fn put(&self, text: String, embedding: Embedding) { pub fn put(&self, text: String, embedding: Embedding) {
let Ok(cache) = self.data.get_or_try( let Some(data) = self.data.as_ref() else {
|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
self.cap_per_thread as usize,
)?)))
},
) else {
return; return;
}; };
let mut cache = cache.borrow_mut(); if text.len() > Self::MAX_TEXT_LEN {
return;
}
let mut cache = data.lock().unwrap();
cache.put(text, embedding); cache.put(text, embedding);
} }
} }
pub const CAP_PER_THREAD: u16 = 20; pub const CACHE_CAP: u16 = 150;
/// Configuration for an embedder. /// Configuration for an embedder.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]

View File

@@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};
use super::error::EmbedErrorKind; use super::error::EmbedErrorKind;
use super::json_template::ValueTemplate; use super::json_template::ValueTemplate;
use super::{ use super::{
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD, DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CACHE_CAP,
REQUEST_PARALLELISM, REQUEST_PARALLELISM,
}; };
use crate::error::FaultSource; use crate::error::FaultSource;
@@ -160,7 +160,7 @@ impl Embedder {
data, data,
dimensions, dimensions,
distribution: options.distribution, distribution: options.distribution,
cache: EmbeddingCache::new(CAP_PER_THREAD), cache: EmbeddingCache::new(CACHE_CAP),
}) })
} }