mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-03-18 21:58:20 +01:00
Mutex-based implementation
This commit is contained in:
parent
d0b0b90d17
commit
1876132172
@ -252,7 +252,7 @@ impl Embedder {
|
||||
options,
|
||||
dimensions: 0,
|
||||
pooling,
|
||||
cache: EmbeddingCache::new(super::CAP_PER_THREAD),
|
||||
cache: EmbeddingCache::new(super::CACHE_CAP),
|
||||
};
|
||||
|
||||
let embeddings = this
|
||||
|
@ -1,7 +1,6 @@
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::num::{NonZeroUsize, TryFromIntError};
|
||||
use std::sync::Arc;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
|
||||
use arroy::distances::{BinaryQuantizedCosine, Cosine};
|
||||
@ -555,48 +554,43 @@ pub enum Embedder {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EmbeddingCache {
|
||||
data: thread_local::ThreadLocal<RefCell<lru::LruCache<String, Embedding>>>,
|
||||
cap_per_thread: u16,
|
||||
data: Option<Mutex<lru::LruCache<String, Embedding>>>,
|
||||
}
|
||||
|
||||
impl EmbeddingCache {
|
||||
pub fn new(cap_per_thread: u16) -> Self {
|
||||
Self { cap_per_thread, data: thread_local::ThreadLocal::new() }
|
||||
const MAX_TEXT_LEN: usize = 2000;
|
||||
|
||||
pub fn new(cap: u16) -> Self {
|
||||
let data = NonZeroUsize::new(cap.into()).map(lru::LruCache::new).map(Mutex::new);
|
||||
Self { data }
|
||||
}
|
||||
|
||||
/// Get the embedding corresponding to `text`, if any is present in the cache.
|
||||
pub fn get(&self, text: &str) -> Option<Embedding> {
|
||||
let mut cache = self
|
||||
.data
|
||||
.get_or_try(|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
|
||||
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
|
||||
self.cap_per_thread as usize,
|
||||
)?)))
|
||||
})
|
||||
.ok()?
|
||||
.borrow_mut();
|
||||
let data = self.data.as_ref()?;
|
||||
if text.len() > Self::MAX_TEXT_LEN {
|
||||
return None;
|
||||
}
|
||||
let mut cache = data.lock().unwrap();
|
||||
|
||||
cache.get(text).cloned()
|
||||
}
|
||||
|
||||
/// Puts a new embedding for the specified `text`
|
||||
pub fn put(&self, text: String, embedding: Embedding) {
|
||||
let Ok(cache) = self.data.get_or_try(
|
||||
|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
|
||||
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
|
||||
self.cap_per_thread as usize,
|
||||
)?)))
|
||||
},
|
||||
) else {
|
||||
let Some(data) = self.data.as_ref() else {
|
||||
return;
|
||||
};
|
||||
let mut cache = cache.borrow_mut();
|
||||
if text.len() > Self::MAX_TEXT_LEN {
|
||||
return;
|
||||
}
|
||||
let mut cache = data.lock().unwrap();
|
||||
|
||||
cache.put(text, embedding);
|
||||
}
|
||||
}
|
||||
|
||||
pub const CAP_PER_THREAD: u16 = 20;
|
||||
pub const CACHE_CAP: u16 = 150;
|
||||
|
||||
/// Configuration for an embedder.
|
||||
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
||||
|
@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};
|
||||
use super::error::EmbedErrorKind;
|
||||
use super::json_template::ValueTemplate;
|
||||
use super::{
|
||||
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD,
|
||||
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CACHE_CAP,
|
||||
REQUEST_PARALLELISM,
|
||||
};
|
||||
use crate::error::FaultSource;
|
||||
@ -160,7 +160,7 @@ impl Embedder {
|
||||
data,
|
||||
dimensions,
|
||||
distribution: options.distribution,
|
||||
cache: EmbeddingCache::new(CAP_PER_THREAD),
|
||||
cache: EmbeddingCache::new(CACHE_CAP),
|
||||
})
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user