diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs
index 6e73c8247..ce7429d36 100644
--- a/crates/milli/src/vector/hf.rs
+++ b/crates/milli/src/vector/hf.rs
@@ -252,7 +252,7 @@ impl Embedder {
             options,
             dimensions: 0,
             pooling,
-            cache: EmbeddingCache::new(super::CAP_PER_THREAD),
+            cache: EmbeddingCache::new(super::CACHE_CAP),
         };
 
         let embeddings = this
diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs
index 55b865b4a..476ba28c9 100644
--- a/crates/milli/src/vector/mod.rs
+++ b/crates/milli/src/vector/mod.rs
@@ -1,7 +1,6 @@
-use std::cell::RefCell;
 use std::collections::HashMap;
-use std::num::{NonZeroUsize, TryFromIntError};
-use std::sync::Arc;
+use std::num::NonZeroUsize;
+use std::sync::{Arc, Mutex};
 use std::time::Instant;
 
 use arroy::distances::{BinaryQuantizedCosine, Cosine};
@@ -555,48 +554,43 @@ pub enum Embedder {
 
 #[derive(Debug)]
 struct EmbeddingCache {
-    data: thread_local::ThreadLocal<RefCell<lru::LruCache<String, Embedding>>>,
-    cap_per_thread: u16,
+    data: Option<Mutex<lru::LruCache<String, Embedding>>>,
 }
 
 impl EmbeddingCache {
-    pub fn new(cap_per_thread: u16) -> Self {
-        Self { cap_per_thread, data: thread_local::ThreadLocal::new() }
+    const MAX_TEXT_LEN: usize = 2000;
+
+    pub fn new(cap: u16) -> Self {
+        let data = NonZeroUsize::new(cap.into()).map(lru::LruCache::new).map(Mutex::new);
+        Self { data }
     }
 
     /// Get the embedding corresponding to `text`, if any is present in the cache.
     pub fn get(&self, text: &str) -> Option<Embedding> {
-        let mut cache = self
-            .data
-            .get_or_try(|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
-                Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
-                    self.cap_per_thread as usize,
-                )?)))
-            })
-            .ok()?
-            .borrow_mut();
+        let data = self.data.as_ref()?;
+        if text.len() > Self::MAX_TEXT_LEN {
+            return None;
+        }
+        let mut cache = data.lock().unwrap();
 
         cache.get(text).cloned()
     }
 
     /// Puts a new embedding for the specified `text`
     pub fn put(&self, text: String, embedding: Embedding) {
-        let Ok(cache) = self.data.get_or_try(
-            || -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
-                Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
-                    self.cap_per_thread as usize,
-                )?)))
-            },
-        ) else {
+        let Some(data) = self.data.as_ref() else {
             return;
         };
-        let mut cache = cache.borrow_mut();
+        if text.len() > Self::MAX_TEXT_LEN {
+            return;
+        }
+        let mut cache = data.lock().unwrap();
 
         cache.put(text, embedding);
     }
 }
 
-pub const CAP_PER_THREAD: u16 = 20;
+pub const CACHE_CAP: u16 = 150;
 
 /// Configuration for an embedder.
 #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs
index b9b8b0fb3..9761c753e 100644
--- a/crates/milli/src/vector/rest.rs
+++ b/crates/milli/src/vector/rest.rs
@@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};
 use super::error::EmbedErrorKind;
 use super::json_template::ValueTemplate;
 use super::{
-    DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD,
+    DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CACHE_CAP,
     REQUEST_PARALLELISM,
 };
 use crate::error::FaultSource;
@@ -160,7 +160,7 @@ impl Embedder {
             data,
             dimensions,
             distribution: options.distribution,
-            cache: EmbeddingCache::new(CAP_PER_THREAD),
+            cache: EmbeddingCache::new(CACHE_CAP),
         })
     }
 