mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-01-09 21:14:30 +01:00
WIP
- manual embedder - multi embedders OK - clippy + tests OK
This commit is contained in:
parent
922a640188
commit
12940d79a9
@ -305,6 +305,7 @@ NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENT
|
||||
PayloadTooLarge , InvalidRequest , PAYLOAD_TOO_LARGE ;
|
||||
TaskNotFound , InvalidRequest , NOT_FOUND ;
|
||||
TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ;
|
||||
TooManyVectors , InvalidRequest , BAD_REQUEST ;
|
||||
UnretrievableDocument , Internal , BAD_REQUEST ;
|
||||
UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ;
|
||||
UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ;
|
||||
@ -362,7 +363,9 @@ impl ErrorCode for milli::Error {
|
||||
UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
|
||||
UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
|
||||
UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
|
||||
UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType,
|
||||
UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType,
|
||||
UserError::TooManyVectors(_, _) => Code::TooManyVectors,
|
||||
UserError::SortError(_) => Code::InvalidSearchSort,
|
||||
UserError::InvalidMinTypoWordLenSetting(_, _) => {
|
||||
Code::InvalidSettingsTypoTolerance
|
||||
|
@ -235,14 +235,14 @@ pub async fn embed(
|
||||
index_scheduler: &IndexScheduler,
|
||||
index: &milli::Index,
|
||||
) -> Result<(), ResponseError> {
|
||||
if let Some(VectorQuery::String(prompt)) = query.vector.take() {
|
||||
match query.vector.take() {
|
||||
Some(VectorQuery::String(prompt)) => {
|
||||
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
|
||||
let embedder = index_scheduler.embedders(embedder_configs)?;
|
||||
|
||||
let embedder_name = if let Some(HybridQuery {
|
||||
semantic_ratio: _,
|
||||
embedder: Some(embedder),
|
||||
}) = &query.hybrid
|
||||
let embedder_name =
|
||||
if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
|
||||
&query.hybrid
|
||||
{
|
||||
embedder
|
||||
} else {
|
||||
@ -263,10 +263,14 @@ pub async fn embed(
|
||||
|
||||
if embeddings.iter().nth(1).is_some() {
|
||||
warn!("Ignoring embeddings past the first one in long search query");
|
||||
query.vector = Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
|
||||
query.vector =
|
||||
Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
|
||||
} else {
|
||||
query.vector = Some(VectorQuery::Vector(embeddings.into_inner()));
|
||||
}
|
||||
}
|
||||
Some(vector) => query.vector = Some(vector),
|
||||
None => {}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ use meilisearch_types::deserr::DeserrJsonError;
|
||||
use meilisearch_types::error::deserr_codes::*;
|
||||
use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
|
||||
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery};
|
||||
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
|
||||
use meilisearch_types::{milli, Document};
|
||||
@ -562,8 +562,17 @@ pub fn perform_search(
|
||||
insert_geo_distance(sort, &mut document);
|
||||
}
|
||||
|
||||
/// FIXME: remove this or set to value from the score details
|
||||
let semantic_score = None;
|
||||
let mut semantic_score = None;
|
||||
for details in &score {
|
||||
if let ScoreDetails::Vector(score_details::Vector {
|
||||
target_vector: _,
|
||||
value_similarity: Some((_matching_vector, similarity)),
|
||||
}) = details
|
||||
{
|
||||
semantic_score = Some(*similarity);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let ranking_score =
|
||||
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
|
||||
@ -648,8 +657,10 @@ pub fn perform_search(
|
||||
hits: documents,
|
||||
hits_info,
|
||||
query: query.q.unwrap_or_default(),
|
||||
// FIXME: display input vector
|
||||
vector: None,
|
||||
vector: match query.vector {
|
||||
Some(VectorQuery::Vector(vector)) => Some(vector),
|
||||
_ => None,
|
||||
},
|
||||
processing_time_ms: before_search.elapsed().as_millis(),
|
||||
facet_distribution,
|
||||
facet_stats,
|
||||
|
@ -77,7 +77,8 @@ async fn import_dump_v1_movie_raw() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -238,7 +239,8 @@ async fn import_dump_v1_movie_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -385,7 +387,8 @@ async fn import_dump_v1_rubygems_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -518,7 +521,8 @@ async fn import_dump_v2_movie_raw() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -663,7 +667,8 @@ async fn import_dump_v2_movie_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -807,7 +812,8 @@ async fn import_dump_v2_rubygems_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -940,7 +946,8 @@ async fn import_dump_v3_movie_raw() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1085,7 +1092,8 @@ async fn import_dump_v3_movie_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1229,7 +1237,8 @@ async fn import_dump_v3_rubygems_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1362,7 +1371,8 @@ async fn import_dump_v4_movie_raw() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1507,7 +1517,8 @@ async fn import_dump_v4_movie_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1651,7 +1662,8 @@ async fn import_dump_v4_rubygems_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1896,7 +1908,8 @@ async fn import_dump_v6_containing_experimental_features() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"embedders": {}
|
||||
}
|
||||
"###);
|
||||
|
||||
|
@ -876,7 +876,31 @@ async fn experimental_feature_vector_store() {
|
||||
}))
|
||||
.await;
|
||||
meili_snap::snapshot!(code, @"200 OK");
|
||||
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]");
|
||||
// vector search returns all documents that don't have vectors in the last bucket, like all sorts
|
||||
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
|
||||
[
|
||||
{
|
||||
"title": "Shazam!",
|
||||
"id": "287947"
|
||||
},
|
||||
{
|
||||
"title": "Captain Marvel",
|
||||
"id": "299537"
|
||||
},
|
||||
{
|
||||
"title": "Escape Room",
|
||||
"id": "522681"
|
||||
},
|
||||
{
|
||||
"title": "How to Train Your Dragon: The Hidden World",
|
||||
"id": "166428"
|
||||
},
|
||||
{
|
||||
"title": "Gläss",
|
||||
"id": "450465"
|
||||
}
|
||||
]
|
||||
"###);
|
||||
}
|
||||
|
||||
#[cfg(feature = "default")]
|
||||
|
@ -54,7 +54,7 @@ async fn get_settings() {
|
||||
let (response, code) = index.settings().await;
|
||||
assert_eq!(code, 200);
|
||||
let settings = response.as_object().unwrap();
|
||||
assert_eq!(settings.keys().len(), 15);
|
||||
assert_eq!(settings.keys().len(), 16);
|
||||
assert_eq!(settings["displayedAttributes"], json!(["*"]));
|
||||
assert_eq!(settings["searchableAttributes"], json!(["*"]));
|
||||
assert_eq!(settings["filterableAttributes"], json!([]));
|
||||
@ -83,6 +83,7 @@ async fn get_settings() {
|
||||
"maxTotalHits": 1000,
|
||||
})
|
||||
);
|
||||
assert_eq!(settings["embedders"], json!({}));
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
|
@ -114,8 +114,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
|
||||
InvalidGeoField(#[from] GeoError),
|
||||
#[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)]
|
||||
InvalidVectorDimensions { expected: usize, found: usize },
|
||||
#[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
|
||||
InvalidVectorsType { document_id: Value, value: Value },
|
||||
#[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
|
||||
InvalidVectorsType { document_id: Value, value: Value, subfield: String },
|
||||
#[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")]
|
||||
InvalidVectorsMapType { document_id: Value, value: Value },
|
||||
#[error("{0}")]
|
||||
InvalidFilter(String),
|
||||
#[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]
|
||||
@ -196,6 +198,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
|
||||
TooManyEmbedders(usize),
|
||||
#[error("Cannot find embedder with name {0}.")]
|
||||
InvalidEmbedder(String),
|
||||
#[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")]
|
||||
TooManyVectors(String, usize),
|
||||
}
|
||||
|
||||
impl From<crate::vector::Error> for Error {
|
||||
|
@ -73,6 +73,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||
indexer: GrenadParameters,
|
||||
field_id_map: &FieldsIdsMap,
|
||||
prompt: &Prompt,
|
||||
embedder_name: &str,
|
||||
) -> Result<ExtractedVectorPoints> {
|
||||
puffin::profile_function!();
|
||||
|
||||
@ -115,24 +116,33 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||
// lazily get it when needed
|
||||
let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
|
||||
|
||||
let delta = if let Some(value) = vectors_fid.and_then(|vectors_fid| obkv.get(vectors_fid)) {
|
||||
let vectors_obkv = KvReaderDelAdd::new(value);
|
||||
match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) {
|
||||
let vectors_field = vectors_fid
|
||||
.and_then(|vectors_fid| obkv.get(vectors_fid))
|
||||
.map(KvReaderDelAdd::new)
|
||||
.map(|obkv| to_vector_maps(obkv, document_id))
|
||||
.transpose()?;
|
||||
|
||||
let (del_map, add_map) = vectors_field.unzip();
|
||||
let del_map = del_map.flatten();
|
||||
let add_map = add_map.flatten();
|
||||
|
||||
let del_value = del_map.and_then(|mut map| map.remove(embedder_name));
|
||||
let add_value = add_map.and_then(|mut map| map.remove(embedder_name));
|
||||
|
||||
let delta = match (del_value, add_value) {
|
||||
(Some(old), Some(new)) => {
|
||||
// no autogeneration
|
||||
let del_vectors = extract_vectors(old, document_id)?;
|
||||
let add_vectors = extract_vectors(new, document_id)?;
|
||||
let del_vectors = extract_vectors(old, document_id, embedder_name)?;
|
||||
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
|
||||
|
||||
VectorStateDelta::ManualDelta(
|
||||
del_vectors.unwrap_or_default(),
|
||||
add_vectors.unwrap_or_default(),
|
||||
)
|
||||
if add_vectors.len() > u8::MAX.into() {
|
||||
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
|
||||
document_id().to_string(),
|
||||
add_vectors.len(),
|
||||
)));
|
||||
}
|
||||
(None, Some(new)) => {
|
||||
// was possibly autogenerated, remove all vectors for that document
|
||||
let add_vectors = extract_vectors(new, document_id)?;
|
||||
|
||||
VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default())
|
||||
VectorStateDelta::ManualDelta(del_vectors, add_vectors)
|
||||
}
|
||||
(Some(_old), None) => {
|
||||
// Do we keep this document?
|
||||
@ -151,6 +161,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||
VectorStateDelta::NowRemoved
|
||||
}
|
||||
}
|
||||
(None, Some(new)) => {
|
||||
// was possibly autogenerated, remove all vectors for that document
|
||||
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
|
||||
if add_vectors.len() > u8::MAX.into() {
|
||||
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
|
||||
document_id().to_string(),
|
||||
add_vectors.len(),
|
||||
)));
|
||||
}
|
||||
|
||||
VectorStateDelta::WasGeneratedNowManual(add_vectors)
|
||||
}
|
||||
(None, None) => {
|
||||
// Do we keep this document?
|
||||
let document_is_kept = obkv
|
||||
@ -176,29 +198,6 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||
VectorStateDelta::NowRemoved
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Do we keep this document?
|
||||
let document_is_kept = obkv
|
||||
.iter()
|
||||
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
|
||||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||
|
||||
if document_is_kept {
|
||||
// Don't give up if the old prompt was failing
|
||||
let old_prompt =
|
||||
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
|
||||
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
|
||||
if old_prompt != new_prompt {
|
||||
log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}");
|
||||
VectorStateDelta::NowGenerated(new_prompt)
|
||||
} else {
|
||||
log::trace!("⏭️ Prompt unmodified, skipping");
|
||||
VectorStateDelta::NoChange
|
||||
}
|
||||
} else {
|
||||
VectorStateDelta::NowRemoved
|
||||
}
|
||||
};
|
||||
|
||||
// and we finally push the unique vectors into the writer
|
||||
@ -221,6 +220,34 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||
})
|
||||
}
|
||||
|
||||
fn to_vector_maps(
|
||||
obkv: KvReaderDelAdd,
|
||||
document_id: impl Fn() -> Value,
|
||||
) -> Result<(Option<serde_json::Map<String, Value>>, Option<serde_json::Map<String, Value>>)> {
|
||||
let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?;
|
||||
let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?;
|
||||
Ok((del, add))
|
||||
}
|
||||
|
||||
fn to_vector_map(
|
||||
obkv: KvReaderDelAdd,
|
||||
side: DelAdd,
|
||||
document_id: &impl Fn() -> Value,
|
||||
) -> Result<Option<serde_json::Map<String, Value>>> {
|
||||
Ok(if let Some(value) = obkv.get(side) {
|
||||
let Ok(value) = from_slice(value) else {
|
||||
let value = from_slice(value).map_err(InternalError::SerdeJson)?;
|
||||
return Err(crate::Error::UserError(UserError::InvalidVectorsMapType {
|
||||
document_id: document_id(),
|
||||
value,
|
||||
}));
|
||||
};
|
||||
Some(value)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
/// Computes the diff between both Del and Add numbers and
|
||||
/// only inserts the parts that differ in the sorter.
|
||||
fn push_vectors_diff(
|
||||
@ -286,12 +313,20 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering {
|
||||
}
|
||||
|
||||
/// Extracts the vectors from a JSON value.
|
||||
fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> {
|
||||
match from_slice(value) {
|
||||
Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)),
|
||||
fn extract_vectors(
|
||||
value: Value,
|
||||
document_id: impl Fn() -> Value,
|
||||
name: &str,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
// FIXME: ugly clone of the vectors here
|
||||
match serde_json::from_value(value.clone()) {
|
||||
Ok(vectors) => {
|
||||
Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default())
|
||||
}
|
||||
Err(_) => Err(UserError::InvalidVectorsType {
|
||||
document_id: document_id(),
|
||||
value: from_slice(value).map_err(InternalError::SerdeJson)?,
|
||||
value,
|
||||
subfield: name.to_owned(),
|
||||
}
|
||||
.into()),
|
||||
}
|
||||
|
@ -298,6 +298,7 @@ fn send_original_documents_data(
|
||||
indexer,
|
||||
&field_id_map,
|
||||
&prompt,
|
||||
&name,
|
||||
);
|
||||
match result {
|
||||
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
||||
|
@ -514,16 +514,18 @@ where
|
||||
// We write the primary key field id into the main database
|
||||
self.index.put_primary_key(self.wtxn, &primary_key)?;
|
||||
let number_of_documents = self.index.number_of_documents(self.wtxn)?;
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
|
||||
for (embedder_name, dimension) in dimension {
|
||||
let wtxn = &mut *self.wtxn;
|
||||
let vector_arroy = self.index.vector_arroy;
|
||||
/// FIXME: unwrap
|
||||
let embedder_index =
|
||||
self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
|
||||
|
||||
let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
|
||||
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
|
||||
)?;
|
||||
|
||||
pool.install(|| {
|
||||
let writer_index = (embedder_index as u16) << 8;
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
for k in 0..=u8::MAX {
|
||||
let writer = arroy::Writer::prepare(
|
||||
wtxn,
|
||||
|
@ -22,7 +22,9 @@ use crate::index::db_name::DOCUMENTS;
|
||||
use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd};
|
||||
use crate::update::facet::FacetsUpdate;
|
||||
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
|
||||
use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError};
|
||||
use crate::{
|
||||
lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError,
|
||||
};
|
||||
|
||||
pub(crate) enum TypedChunk {
|
||||
FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
|
||||
@ -363,8 +365,9 @@ pub(crate) fn write_typed_chunk_into_index(
|
||||
expected_dimension,
|
||||
embedder_name,
|
||||
} => {
|
||||
/// FIXME: unwrap
|
||||
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
|
||||
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
|
||||
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
|
||||
)?;
|
||||
let writer_index = (embedder_index as u16) << 8;
|
||||
// FIXME: allow customizing distance
|
||||
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
|
||||
@ -404,7 +407,20 @@ pub(crate) fn write_typed_chunk_into_index(
|
||||
// code error if we somehow got the wrong dimension
|
||||
.unwrap();
|
||||
|
||||
/// FIXME: detect overflow
|
||||
if embeddings.embedding_count() > u8::MAX.into() {
|
||||
let external_docid = if let Ok(Some(Ok(index))) = index
|
||||
.external_id_of(wtxn, std::iter::once(docid))
|
||||
.map(|it| it.into_iter().next())
|
||||
{
|
||||
index
|
||||
} else {
|
||||
format!("internal docid={docid}")
|
||||
};
|
||||
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
|
||||
external_docid,
|
||||
embeddings.embedding_count(),
|
||||
)));
|
||||
}
|
||||
for (embedding, writer) in embeddings.iter().zip(&writers) {
|
||||
writer.add_item(wtxn, docid, embedding)?;
|
||||
}
|
||||
@ -455,7 +471,7 @@ pub(crate) fn write_typed_chunk_into_index(
|
||||
if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
|
||||
let vector = pod_collect_to_vec(value);
|
||||
|
||||
/// FIXME: detect overflow
|
||||
// overflow was detected during vector extraction.
|
||||
for writer in &writers {
|
||||
if !writer.contains_item(wtxn, docid)? {
|
||||
writer.add_item(wtxn, docid, &vector)?;
|
||||
|
34
milli/src/vector/manual.rs
Normal file
34
milli/src/vector/manual.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use super::error::EmbedError;
|
||||
use super::Embeddings;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Embedder {
|
||||
dimensions: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub dimensions: usize,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> Self {
|
||||
Self { dimensions: options.dimensions }
|
||||
}
|
||||
|
||||
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let Some(text) = texts.pop() else { return Ok(Default::default()) };
|
||||
Err(EmbedError::embed_on_manual_embedder(text))
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
||||
}
|
||||
}
|
@ -31,6 +31,10 @@ impl<F> Embeddings<F> {
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
pub fn embedding_count(&self) -> usize {
|
||||
self.data.len() / self.dimension
|
||||
}
|
||||
|
||||
pub fn dimension(&self) -> usize {
|
||||
self.dimension
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user