mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-03-20 06:30:38 +01:00
Mutex-based implementation
This commit is contained in:
parent
d0b0b90d17
commit
1876132172
@ -252,7 +252,7 @@ impl Embedder {
|
|||||||
options,
|
options,
|
||||||
dimensions: 0,
|
dimensions: 0,
|
||||||
pooling,
|
pooling,
|
||||||
cache: EmbeddingCache::new(super::CAP_PER_THREAD),
|
cache: EmbeddingCache::new(super::CACHE_CAP),
|
||||||
};
|
};
|
||||||
|
|
||||||
let embeddings = this
|
let embeddings = this
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
use std::cell::RefCell;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::num::{NonZeroUsize, TryFromIntError};
|
use std::num::NonZeroUsize;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use arroy::distances::{BinaryQuantizedCosine, Cosine};
|
use arroy::distances::{BinaryQuantizedCosine, Cosine};
|
||||||
@ -555,48 +554,43 @@ pub enum Embedder {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct EmbeddingCache {
|
struct EmbeddingCache {
|
||||||
data: thread_local::ThreadLocal<RefCell<lru::LruCache<String, Embedding>>>,
|
data: Option<Mutex<lru::LruCache<String, Embedding>>>,
|
||||||
cap_per_thread: u16,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbeddingCache {
|
impl EmbeddingCache {
|
||||||
pub fn new(cap_per_thread: u16) -> Self {
|
const MAX_TEXT_LEN: usize = 2000;
|
||||||
Self { cap_per_thread, data: thread_local::ThreadLocal::new() }
|
|
||||||
|
pub fn new(cap: u16) -> Self {
|
||||||
|
let data = NonZeroUsize::new(cap.into()).map(lru::LruCache::new).map(Mutex::new);
|
||||||
|
Self { data }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the embedding corresponding to `text`, if any is present in the cache.
|
/// Get the embedding corresponding to `text`, if any is present in the cache.
|
||||||
pub fn get(&self, text: &str) -> Option<Embedding> {
|
pub fn get(&self, text: &str) -> Option<Embedding> {
|
||||||
let mut cache = self
|
let data = self.data.as_ref()?;
|
||||||
.data
|
if text.len() > Self::MAX_TEXT_LEN {
|
||||||
.get_or_try(|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
|
return None;
|
||||||
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
|
}
|
||||||
self.cap_per_thread as usize,
|
let mut cache = data.lock().unwrap();
|
||||||
)?)))
|
|
||||||
})
|
|
||||||
.ok()?
|
|
||||||
.borrow_mut();
|
|
||||||
|
|
||||||
cache.get(text).cloned()
|
cache.get(text).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Puts a new embedding for the specified `text`
|
/// Puts a new embedding for the specified `text`
|
||||||
pub fn put(&self, text: String, embedding: Embedding) {
|
pub fn put(&self, text: String, embedding: Embedding) {
|
||||||
let Ok(cache) = self.data.get_or_try(
|
let Some(data) = self.data.as_ref() else {
|
||||||
|| -> Result<RefCell<lru::LruCache<String, Vec<f32>>>, TryFromIntError> {
|
|
||||||
Ok(RefCell::new(lru::LruCache::new(NonZeroUsize::try_from(
|
|
||||||
self.cap_per_thread as usize,
|
|
||||||
)?)))
|
|
||||||
},
|
|
||||||
) else {
|
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
let mut cache = cache.borrow_mut();
|
if text.len() > Self::MAX_TEXT_LEN {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let mut cache = data.lock().unwrap();
|
||||||
|
|
||||||
cache.put(text, embedding);
|
cache.put(text, embedding);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const CAP_PER_THREAD: u16 = 20;
|
pub const CACHE_CAP: u16 = 150;
|
||||||
|
|
||||||
/// Configuration for an embedder.
|
/// Configuration for an embedder.
|
||||||
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
||||||
|
@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
use super::error::EmbedErrorKind;
|
use super::error::EmbedErrorKind;
|
||||||
use super::json_template::ValueTemplate;
|
use super::json_template::ValueTemplate;
|
||||||
use super::{
|
use super::{
|
||||||
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD,
|
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CACHE_CAP,
|
||||||
REQUEST_PARALLELISM,
|
REQUEST_PARALLELISM,
|
||||||
};
|
};
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
@ -160,7 +160,7 @@ impl Embedder {
|
|||||||
data,
|
data,
|
||||||
dimensions,
|
dimensions,
|
||||||
distribution: options.distribution,
|
distribution: options.distribution,
|
||||||
cache: EmbeddingCache::new(CAP_PER_THREAD),
|
cache: EmbeddingCache::new(CACHE_CAP),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user