Small commit to add hybrid search and autoembedding

This commit is contained in:
Louis Dureuil 2023-11-15 15:46:37 +01:00
parent 21bcf32109
commit 13c2c6c16b
No known key found for this signature in database
42 changed files with 4045 additions and 246 deletions

View file

@ -4,11 +4,12 @@ mod helpers;
mod transform;
mod typed_chunk;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::io::{Cursor, Read, Seek};
use std::iter::FromIterator;
use std::num::NonZeroU32;
use std::result::Result as StdResult;
use std::sync::Arc;
use crossbeam_channel::{Receiver, Sender};
use heed::types::Str;
@ -32,10 +33,12 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
pub use self::transform::{Transform, TransformOutput};
use crate::documents::{obkv_to_object, DocumentsBatchReader};
use crate::error::{Error, InternalError, UserError};
use crate::prompt::Prompt;
pub use crate::update::index_documents::helpers::CursorClonableMmap;
use crate::update::{
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
};
use crate::vector::Embedder;
use crate::{CboRoaringBitmapCodec, Index, Result};
static MERGED_DATABASE_COUNT: usize = 7;
@ -78,6 +81,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
should_abort: FA,
added_documents: u64,
deleted_documents: u64,
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
}
#[derive(Default, Debug, Clone)]
@ -121,6 +125,7 @@ where
index,
added_documents: 0,
deleted_documents: 0,
embedders: Default::default(),
})
}
@ -167,6 +172,14 @@ where
Ok((self, Ok(indexed_documents)))
}
pub fn with_embedders(
mut self,
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
) -> Self {
self.embedders = embedders;
self
}
/// Remove a batch of documents from the current builder.
///
/// Returns the number of documents deleted from the builder.
@ -322,17 +335,18 @@ where
// get filterable fields for facet databases
let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?;
// get the fid of the `_geo.lat` and `_geo.lng` fields.
let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") {
let mut field_id_map = self.index.fields_ids_map(self.wtxn)?;
// self.index.fields_ids_map($a)? ==>> field_id_map
let geo_fields_ids = match field_id_map.id("_geo") {
Some(gfid) => {
let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid);
let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid);
// if `_geo` is faceted then we get the `lat` and `lng`
if is_sortable || is_filterable {
let field_ids = self
.index
.fields_ids_map(self.wtxn)?
let field_ids = field_id_map
.insert("_geo.lat")
.zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng"))
.zip(field_id_map.insert("_geo.lng"))
.ok_or(UserError::AttributeLimitReached)?;
Some(field_ids)
} else {
@ -341,8 +355,6 @@ where
}
None => None,
};
// get the fid of the `_vectors` field.
let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors");
let stop_words = self.index.stop_words(self.wtxn)?;
let separators = self.index.allowed_separators(self.wtxn)?;
@ -364,6 +376,8 @@ where
self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB
let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes;
let cloned_embedder = self.embedders.clone();
// Run extraction pipeline in parallel.
pool.install(|| {
puffin::profile_scope!("extract_and_send_grenad_chunks");
@ -387,13 +401,14 @@ where
faceted_fields,
primary_key_id,
geo_fields_ids,
vectors_field_id,
field_id_map,
stop_words,
separators.as_deref(),
dictionary.as_deref(),
max_positions_per_attributes,
exact_attributes,
proximity_precision,
cloned_embedder,
)
});
@ -2505,7 +2520,7 @@ mod tests {
.unwrap();
let rtxn = index.read_txn().unwrap();
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap();
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap();
assert_eq!(res.documents_ids.len(), 3);
}