mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 20:37:15 +02:00
WIP multi embedders
fixed template bugs
This commit is contained in:
parent
abbe131084
commit
922a640188
20 changed files with 438 additions and 158 deletions
|
@ -71,8 +71,8 @@ impl VectorStateDelta {
|
|||
pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||
obkv_documents: grenad::Reader<R>,
|
||||
indexer: GrenadParameters,
|
||||
field_id_map: FieldsIdsMap,
|
||||
prompt: Option<&Prompt>,
|
||||
field_id_map: &FieldsIdsMap,
|
||||
prompt: &Prompt,
|
||||
) -> Result<ExtractedVectorPoints> {
|
||||
puffin::profile_function!();
|
||||
|
||||
|
@ -142,14 +142,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||
if document_is_kept {
|
||||
// becomes autogenerated
|
||||
match prompt {
|
||||
Some(prompt) => VectorStateDelta::NowGenerated(prompt.render(
|
||||
obkv,
|
||||
DelAdd::Addition,
|
||||
&field_id_map,
|
||||
)?),
|
||||
None => VectorStateDelta::NowRemoved,
|
||||
}
|
||||
VectorStateDelta::NowGenerated(prompt.render(
|
||||
obkv,
|
||||
DelAdd::Addition,
|
||||
field_id_map,
|
||||
)?)
|
||||
} else {
|
||||
VectorStateDelta::NowRemoved
|
||||
}
|
||||
|
@ -162,26 +159,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||
|
||||
if document_is_kept {
|
||||
match prompt {
|
||||
Some(prompt) => {
|
||||
// Don't give up if the old prompt was failing
|
||||
let old_prompt = prompt
|
||||
.render(obkv, DelAdd::Deletion, &field_id_map)
|
||||
.unwrap_or_default();
|
||||
let new_prompt =
|
||||
prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
|
||||
if old_prompt != new_prompt {
|
||||
log::trace!(
|
||||
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
|
||||
);
|
||||
VectorStateDelta::NowGenerated(new_prompt)
|
||||
} else {
|
||||
log::trace!("⏭️ Prompt unmodified, skipping");
|
||||
VectorStateDelta::NoChange
|
||||
}
|
||||
}
|
||||
// We no longer have a prompt, so we need to remove any existing vector
|
||||
None => VectorStateDelta::NowRemoved,
|
||||
// Don't give up if the old prompt was failing
|
||||
let old_prompt =
|
||||
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
|
||||
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
|
||||
if old_prompt != new_prompt {
|
||||
log::trace!(
|
||||
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
|
||||
);
|
||||
VectorStateDelta::NowGenerated(new_prompt)
|
||||
} else {
|
||||
log::trace!("⏭️ Prompt unmodified, skipping");
|
||||
VectorStateDelta::NoChange
|
||||
}
|
||||
} else {
|
||||
VectorStateDelta::NowRemoved
|
||||
|
@ -196,24 +185,16 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||
|
||||
if document_is_kept {
|
||||
match prompt {
|
||||
Some(prompt) => {
|
||||
// Don't give up if the old prompt was failing
|
||||
let old_prompt = prompt
|
||||
.render(obkv, DelAdd::Deletion, &field_id_map)
|
||||
.unwrap_or_default();
|
||||
let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
|
||||
if old_prompt != new_prompt {
|
||||
log::trace!(
|
||||
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
|
||||
);
|
||||
VectorStateDelta::NowGenerated(new_prompt)
|
||||
} else {
|
||||
log::trace!("⏭️ Prompt unmodified, skipping");
|
||||
VectorStateDelta::NoChange
|
||||
}
|
||||
}
|
||||
None => VectorStateDelta::NowRemoved,
|
||||
// Don't give up if the old prompt was failing
|
||||
let old_prompt =
|
||||
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
|
||||
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
|
||||
if old_prompt != new_prompt {
|
||||
log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}");
|
||||
VectorStateDelta::NowGenerated(new_prompt)
|
||||
} else {
|
||||
log::trace!("⏭️ Prompt unmodified, skipping");
|
||||
VectorStateDelta::NoChange
|
||||
}
|
||||
} else {
|
||||
VectorStateDelta::NowRemoved
|
||||
|
@ -322,7 +303,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||
prompt_reader: grenad::Reader<R>,
|
||||
indexer: GrenadParameters,
|
||||
embedder: Arc<Embedder>,
|
||||
) -> Result<(grenad::Reader<BufReader<File>>, Option<usize>)> {
|
||||
) -> Result<grenad::Reader<BufReader<File>>> {
|
||||
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
|
||||
|
||||
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
|
||||
|
@ -341,8 +322,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||
let mut chunks_ids = Vec::with_capacity(n_chunks);
|
||||
let mut cursor = prompt_reader.into_cursor()?;
|
||||
|
||||
let mut expected_dimension = None;
|
||||
|
||||
while let Some((key, value)) = cursor.move_on_next()? {
|
||||
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
|
||||
// SAFETY: precondition, the grenad value was saved from a string
|
||||
|
@ -367,7 +346,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
|
||||
)
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::UserError::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
|
||||
for (docid, embeddings) in chunks_ids
|
||||
|
@ -376,7 +354,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
|
||||
{
|
||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||
expected_dimension = Some(embeddings.dimension());
|
||||
}
|
||||
chunks_ids.clear();
|
||||
}
|
||||
|
@ -387,7 +364,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||
let chunked_embeds = rt
|
||||
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::UserError::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
for (docid, embeddings) in chunks_ids
|
||||
.iter()
|
||||
|
@ -395,7 +371,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
|
||||
{
|
||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||
expected_dimension = Some(embeddings.dimension());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -403,14 +378,12 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||
let embeds = rt
|
||||
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::UserError::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
|
||||
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
|
||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||
expected_dimension = Some(embeddings.dimension());
|
||||
}
|
||||
}
|
||||
|
||||
Ok((writer_into_reader(state_writer)?, expected_dimension))
|
||||
writer_into_reader(state_writer)
|
||||
}
|
||||
|
|
|
@ -292,43 +292,42 @@ fn send_original_documents_data(
|
|||
let documents_chunk_cloned = original_documents_chunk.clone();
|
||||
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||
rayon::spawn(move || {
|
||||
let (embedder, prompt) = embedders.get("default").cloned().unzip();
|
||||
let result =
|
||||
extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref());
|
||||
match result {
|
||||
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
||||
/// FIXME: support multiple embedders
|
||||
let results = embedder.and_then(|embedder| {
|
||||
match extract_embeddings(prompts, indexer, embedder.clone()) {
|
||||
for (name, (embedder, prompt)) in embedders {
|
||||
let result = extract_vector_points(
|
||||
documents_chunk_cloned.clone(),
|
||||
indexer,
|
||||
&field_id_map,
|
||||
&prompt,
|
||||
);
|
||||
match result {
|
||||
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
||||
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
|
||||
Ok(results) => Some(results),
|
||||
Err(error) => {
|
||||
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
let (embeddings, expected_dimension) = results.unzip();
|
||||
let expected_dimension = expected_dimension.flatten();
|
||||
if !(remove_vectors.is_empty()
|
||||
&& manual_vectors.is_empty()
|
||||
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
|
||||
{
|
||||
/// FIXME FIXME FIXME
|
||||
if expected_dimension.is_some() {
|
||||
};
|
||||
|
||||
if !(remove_vectors.is_empty()
|
||||
&& manual_vectors.is_empty()
|
||||
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
|
||||
{
|
||||
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
|
||||
remove_vectors,
|
||||
embeddings,
|
||||
/// FIXME: compute an expected dimension from the manual vectors if any
|
||||
expected_dimension: expected_dimension.unwrap(),
|
||||
expected_dimension: embedder.dimensions(),
|
||||
manual_vectors,
|
||||
embedder_name: name,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Err(error) => {
|
||||
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
||||
}
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: create a custom internal error
|
||||
|
|
|
@ -435,7 +435,7 @@ where
|
|||
let mut word_docids = None;
|
||||
let mut exact_word_docids = None;
|
||||
|
||||
let mut dimension = None;
|
||||
let mut dimension = HashMap::new();
|
||||
|
||||
for result in lmdb_writer_rx {
|
||||
if (self.should_abort)() {
|
||||
|
@ -471,13 +471,15 @@ where
|
|||
remove_vectors,
|
||||
embeddings,
|
||||
manual_vectors,
|
||||
embedder_name,
|
||||
} => {
|
||||
dimension = Some(expected_dimension);
|
||||
dimension.insert(embedder_name.clone(), expected_dimension);
|
||||
TypedChunk::VectorPoints {
|
||||
remove_vectors,
|
||||
embeddings,
|
||||
expected_dimension,
|
||||
manual_vectors,
|
||||
embedder_name,
|
||||
}
|
||||
}
|
||||
otherwise => otherwise,
|
||||
|
@ -513,14 +515,22 @@ where
|
|||
self.index.put_primary_key(self.wtxn, &primary_key)?;
|
||||
let number_of_documents = self.index.number_of_documents(self.wtxn)?;
|
||||
|
||||
if let Some(dimension) = dimension {
|
||||
for (embedder_name, dimension) in dimension {
|
||||
let wtxn = &mut *self.wtxn;
|
||||
let vector_arroy = self.index.vector_arroy;
|
||||
/// FIXME: unwrap
|
||||
let embedder_index =
|
||||
self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
|
||||
pool.install(|| {
|
||||
/// FIXME: do for each embedder
|
||||
let writer_index = (embedder_index as u16) << 8;
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
for k in 0..=u8::MAX {
|
||||
let writer = arroy::Writer::prepare(wtxn, vector_arroy, k.into(), dimension)?;
|
||||
let writer = arroy::Writer::prepare(
|
||||
wtxn,
|
||||
vector_arroy,
|
||||
writer_index | (k as u16),
|
||||
dimension,
|
||||
)?;
|
||||
if writer.is_empty(wtxn)? {
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ pub(crate) enum TypedChunk {
|
|||
embeddings: Option<grenad::Reader<BufReader<File>>>,
|
||||
expected_dimension: usize,
|
||||
manual_vectors: grenad::Reader<BufReader<File>>,
|
||||
embedder_name: String,
|
||||
},
|
||||
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
|
||||
}
|
||||
|
@ -100,8 +101,8 @@ impl TypedChunk {
|
|||
TypedChunk::GeoPoints(grenad) => {
|
||||
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
|
||||
}
|
||||
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => {
|
||||
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension)
|
||||
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => {
|
||||
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name)
|
||||
}
|
||||
TypedChunk::ScriptLanguageDocids(sl_map) => {
|
||||
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
|
||||
|
@ -360,12 +361,20 @@ pub(crate) fn write_typed_chunk_into_index(
|
|||
manual_vectors,
|
||||
embeddings,
|
||||
expected_dimension,
|
||||
embedder_name,
|
||||
} => {
|
||||
/// FIXME: allow customizing distance
|
||||
/// FIXME: unwrap
|
||||
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
|
||||
let writer_index = (embedder_index as u16) << 8;
|
||||
// FIXME: allow customizing distance
|
||||
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
|
||||
.map(|k| {
|
||||
/// FIXME: allow customizing index and then do index << 8 + k
|
||||
arroy::Writer::prepare(wtxn, index.vector_arroy, k.into(), expected_dimension)
|
||||
arroy::Writer::prepare(
|
||||
wtxn,
|
||||
index.vector_arroy,
|
||||
writer_index | (k as u16),
|
||||
expected_dimension,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let writers = writers?;
|
||||
|
@ -456,7 +465,7 @@ pub(crate) fn write_typed_chunk_into_index(
|
|||
}
|
||||
}
|
||||
|
||||
log::debug!("There are 🤷♀️ entries in the arroy so far");
|
||||
log::debug!("Finished vector chunk for {}", embedder_name);
|
||||
}
|
||||
TypedChunk::ScriptLanguageDocids(sl_map) => {
|
||||
for (key, (deletion, addition)) in sl_map {
|
||||
|
|
|
@ -431,7 +431,6 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||
let embedder = Arc::new(
|
||||
Embedder::new(embedder_options.clone())
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::UserError::from)
|
||||
.map_err(crate::Error::from)?,
|
||||
);
|
||||
Ok((name, (embedder, prompt)))
|
||||
|
@ -976,6 +975,19 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.index.embedder_category_id.clear(self.wtxn)?;
|
||||
for (index, (embedder_name, _)) in new_configs.iter().enumerate() {
|
||||
self.index.embedder_category_id.put_with_flags(
|
||||
self.wtxn,
|
||||
heed::PutFlags::APPEND,
|
||||
embedder_name,
|
||||
&index
|
||||
.try_into()
|
||||
.map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?,
|
||||
)?;
|
||||
}
|
||||
|
||||
if new_configs.is_empty() {
|
||||
self.index.delete_embedding_configs(self.wtxn)?;
|
||||
} else {
|
||||
|
@ -1062,7 +1074,7 @@ fn validate_prompt(
|
|||
match new {
|
||||
Setting::Set(EmbeddingSettings {
|
||||
embedder_options,
|
||||
prompt:
|
||||
document_template:
|
||||
Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }),
|
||||
}) => {
|
||||
// validate
|
||||
|
@ -1072,7 +1084,7 @@ fn validate_prompt(
|
|||
|
||||
Ok(Setting::Set(EmbeddingSettings {
|
||||
embedder_options,
|
||||
prompt: Setting::Set(PromptSettings {
|
||||
document_template: Setting::Set(PromptSettings {
|
||||
template: Setting::Set(template),
|
||||
strategy,
|
||||
fallback,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue