implement the binary quantization in meilisearch

This commit is contained in:
Tamo 2024-09-18 18:13:37 +02:00
parent 5f474a640d
commit cc45e264ca
20 changed files with 559 additions and 223 deletions

View file

@ -1,8 +1,12 @@
use std::collections::HashMap;
use std::sync::Arc;
use arroy::distances::{Angular, BinaryQuantizedAngular};
use arroy::ItemId;
use deserr::{DeserializeError, Deserr};
use heed::{RoTxn, RwTxn, Unspecified};
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use self::error::{EmbedError, NewEmbedderError};
@ -26,6 +30,171 @@ pub type Embedding = Vec<f32>;
pub const REQUEST_PARALLELISM: usize = 40;
pub struct ArroyReader {
quantized: bool,
index: u16,
database: arroy::Database<Unspecified>,
}
impl ArroyReader {
pub fn new(database: arroy::Database<Unspecified>, index: u16, quantized: bool) -> Self {
Self { database, index, quantized }
}
pub fn index(&self) -> u16 {
self.index
}
pub fn dimensions(&self, rtxn: &RoTxn) -> Result<usize, arroy::Error> {
if self.quantized {
Ok(arroy::Reader::open(rtxn, self.index, self.quantized_db())?.dimensions())
} else {
Ok(arroy::Reader::open(rtxn, self.index, self.angular_db())?.dimensions())
}
}
pub fn quantize(
&mut self,
wtxn: &mut RwTxn,
index: u16,
dimension: usize,
) -> Result<(), arroy::Error> {
if !self.quantized {
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
writer.prepare_changing_distance::<BinaryQuantizedAngular>(wtxn)?;
self.quantized = true;
}
Ok(())
}
pub fn need_build(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).need_build(rtxn)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).need_build(rtxn)
}
}
pub fn build<R: rand::Rng + rand::SeedableRng>(
&self,
wtxn: &mut RwTxn,
rng: &mut R,
dimension: usize,
) -> Result<(), arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).build(wtxn, rng, None)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).build(wtxn, rng, None)
}
}
pub fn add_item(
&self,
wtxn: &mut RwTxn,
dimension: usize,
item_id: arroy::ItemId,
vector: &[f32],
) -> Result<(), arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension)
.add_item(wtxn, item_id, vector)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension)
.add_item(wtxn, item_id, vector)
}
}
pub fn del_item(
&self,
wtxn: &mut RwTxn,
dimension: usize,
item_id: arroy::ItemId,
) -> Result<bool, arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).del_item(wtxn, item_id)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).del_item(wtxn, item_id)
}
}
pub fn clear(&self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).clear(wtxn)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).clear(wtxn)
}
}
pub fn is_empty(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).is_empty(rtxn)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).is_empty(rtxn)
}
}
pub fn contains_item(
&self,
rtxn: &RoTxn,
dimension: usize,
item: arroy::ItemId,
) -> Result<bool, arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).contains_item(rtxn, item)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).contains_item(rtxn, item)
}
}
pub fn nns_by_item(
&self,
rtxn: &RoTxn,
item: ItemId,
limit: usize,
filter: Option<&RoaringBitmap>,
) -> Result<Option<Vec<(ItemId, f32)>>, arroy::Error> {
if self.quantized {
arroy::Reader::open(rtxn, self.index, self.quantized_db())?
.nns_by_item(rtxn, item, limit, None, None, filter)
} else {
arroy::Reader::open(rtxn, self.index, self.angular_db())?
.nns_by_item(rtxn, item, limit, None, None, filter)
}
}
pub fn nns_by_vector(
&self,
txn: &RoTxn,
item: &[f32],
limit: usize,
filter: Option<&RoaringBitmap>,
) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
if self.quantized {
arroy::Reader::open(txn, self.index, self.quantized_db())?
.nns_by_vector(txn, item, limit, None, None, filter)
} else {
arroy::Reader::open(txn, self.index, self.angular_db())?
.nns_by_vector(txn, item, limit, None, None, filter)
}
}
pub fn item_vector(&self, rtxn: &RoTxn, docid: u32) -> Result<Option<Vec<f32>>, arroy::Error> {
if self.quantized {
arroy::Reader::open(rtxn, self.index, self.quantized_db())?.item_vector(rtxn, docid)
} else {
arroy::Reader::open(rtxn, self.index, self.angular_db())?.item_vector(rtxn, docid)
}
}
fn angular_db(&self) -> arroy::Database<Angular> {
self.database.remap_data_type()
}
fn quantized_db(&self) -> arroy::Database<BinaryQuantizedAngular> {
self.database.remap_data_type()
}
}
/// One or multiple embeddings stored consecutively in a flat vector.
pub struct Embeddings<F> {
data: Vec<F>,
@ -124,39 +293,48 @@ pub struct EmbeddingConfig {
pub embedder_options: EmbedderOptions,
/// Document template
pub prompt: PromptData,
/// If this embedder is binary quantized
pub quantized: Option<bool>,
// TODO: add metrics and anything needed
}
impl EmbeddingConfig {
pub fn quantized(&self) -> bool {
self.quantized.unwrap_or_default()
}
}
/// Map of embedder configurations.
///
/// Each configuration is mapped to a name.
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>);
impl EmbeddingConfigs {
/// Create the map from its internal component.s
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>) -> Self {
Self(data)
}
/// Get an embedder configuration and template from its name.
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>, bool)> {
self.0.get(name).cloned()
}
pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>)> {
pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
&self.0
}
pub fn into_inner(self) -> HashMap<String, (Arc<Embedder>, Arc<Prompt>)> {
pub fn into_inner(self) -> HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
self.0
}
}
impl IntoIterator for EmbeddingConfigs {
type Item = (String, (Arc<Embedder>, Arc<Prompt>));
type Item = (String, (Arc<Embedder>, Arc<Prompt>, bool));
type IntoIter = std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>)>;
type IntoIter =
std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>, bool)>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()