add the retrieveVectors parameter to the get and fetch documents route

Tamo 2024-06-05 23:40:29 +02:00
parent ea61e5cbec
commit 6607875f49
6 changed files with 325 additions and 87 deletions
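For orientation, here is a rough, illustrative sketch of how a client could exercise the new retrieveVectors parameter on the two document-retrieval routes once this change ships. The host, index name, document id, field list, and the reqwest/serde_json setup are assumptions made for the example, not part of the commit.

// Illustrative client-side sketch only (assumes a local instance, a "movies"
// index with an embedder configured, and no API key; add an Authorization
// header if your instance requires one).
// Cargo deps assumed: reqwest = { version = "0.12", features = ["blocking", "json"] }, serde_json = "1".
use reqwest::blocking::Client;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new();

    // GET /indexes/{index_uid}/documents/{document_id}?retrieveVectors=true
    let one = client
        .get("http://localhost:7700/indexes/movies/documents/42")
        .query(&[("retrieveVectors", "true"), ("fields", "title,overview")])
        .send()?
        .text()?;
    println!("{one}");

    // POST /indexes/{index_uid}/documents/fetch with the flag in the JSON body
    let many = client
        .post("http://localhost:7700/indexes/movies/documents/fetch")
        .json(&serde_json::json!({ "limit": 10, "retrieveVectors": true }))
        .send()?
        .text()?;
    println!("{many}");

    Ok(())
}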


@@ -16,6 +16,7 @@ use meilisearch_types::error::{Code, ResponseError};
 use meilisearch_types::heed::RoTxn;
 use meilisearch_types::index_uid::IndexUid;
 use meilisearch_types::milli::update::IndexDocumentsMethod;
+use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors;
 use meilisearch_types::milli::DocumentId;
 use meilisearch_types::star_or::OptionStarOrList;
 use meilisearch_types::tasks::KindWithContent;
@@ -94,6 +95,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
 pub struct GetDocument {
     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentFields>)]
     fields: OptionStarOrList<String>,
+    #[deserr(default, error = DeserrQueryParamError<InvalidDocumentRetrieveVectors>)]
+    retrieve_vectors: Param<bool>,
 }
 
 pub async fn get_document(
@@ -109,11 +112,12 @@ pub async fn get_document(
     analytics.get_fetch_documents(&DocumentFetchKind::PerDocumentId, &req);
-    let GetDocument { fields } = params.into_inner();
+    let GetDocument { fields, retrieve_vectors } = params.into_inner();
     let attributes_to_retrieve = fields.merge_star_and_none();
     let index = index_scheduler.index(&index_uid)?;
-    let document = retrieve_document(&index, &document_id, attributes_to_retrieve)?;
+    let document =
+        retrieve_document(&index, &document_id, attributes_to_retrieve, retrieve_vectors.0)?;
     debug!(returns = ?document, "Get document");
     Ok(HttpResponse::Ok().json(document))
 }
@@ -153,6 +157,8 @@ pub struct BrowseQueryGet {
     limit: Param<usize>,
     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentFields>)]
     fields: OptionStarOrList<String>,
+    #[deserr(default, error = DeserrQueryParamError<InvalidDocumentRetrieveVectors>)]
+    retrieve_vectors: Param<bool>,
     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentFilter>)]
     filter: Option<String>,
 }
@@ -166,6 +172,8 @@ pub struct BrowseQuery {
     limit: usize,
     #[deserr(default, error = DeserrJsonError<InvalidDocumentFields>)]
     fields: Option<Vec<String>>,
+    #[deserr(default, error = DeserrJsonError<InvalidDocumentRetrieveVectors>)]
+    retrieve_vectors: bool,
     #[deserr(default, error = DeserrJsonError<InvalidDocumentFilter>)]
     filter: Option<Value>,
 }
@@ -201,7 +209,7 @@ pub async fn get_documents(
 ) -> Result<HttpResponse, ResponseError> {
     debug!(parameters = ?params, "Get documents GET");
-    let BrowseQueryGet { limit, offset, fields, filter } = params.into_inner();
+    let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter } = params.into_inner();
     let filter = match filter {
         Some(f) => match serde_json::from_str(&f) {
@@ -215,6 +223,7 @@ pub async fn get_documents(
         offset: offset.0,
         limit: limit.0,
         fields: fields.merge_star_and_none(),
+        retrieve_vectors: retrieve_vectors.0,
         filter,
     };
@@ -236,10 +245,11 @@ fn documents_by_query(
     query: BrowseQuery,
 ) -> Result<HttpResponse, ResponseError> {
     let index_uid = IndexUid::try_from(index_uid.into_inner())?;
-    let BrowseQuery { offset, limit, fields, filter } = query;
+    let BrowseQuery { offset, limit, fields, retrieve_vectors, filter } = query;
     let index = index_scheduler.index(&index_uid)?;
-    let (total, documents) = retrieve_documents(&index, offset, limit, filter, fields)?;
+    let (total, documents) =
+        retrieve_documents(&index, offset, limit, filter, fields, retrieve_vectors)?;
     let ret = PaginationView::new(offset, limit, total as usize, documents);
@@ -579,13 +589,33 @@ fn some_documents<'a, 't: 'a>(
     index: &'a Index,
     rtxn: &'t RoTxn,
     doc_ids: impl IntoIterator<Item = DocumentId> + 'a,
+    retrieve_vectors: bool,
 ) -> Result<impl Iterator<Item = Result<Document, ResponseError>> + 'a, ResponseError> {
     let fields_ids_map = index.fields_ids_map(rtxn)?;
     let all_fields: Vec<_> = fields_ids_map.iter().map(|(id, _)| id).collect();
+    let embedding_configs = index.embedding_configs(rtxn)?;
     Ok(index.iter_documents(rtxn, doc_ids)?.map(move |ret| {
-        ret.map_err(ResponseError::from).and_then(|(_key, document)| -> Result<_, ResponseError> {
-            Ok(milli::obkv_to_json(&all_fields, &fields_ids_map, document)?)
+        ret.map_err(ResponseError::from).and_then(|(key, document)| -> Result<_, ResponseError> {
+            let mut document = milli::obkv_to_json(&all_fields, &fields_ids_map, document)?;
+            if retrieve_vectors {
+                let mut vectors = serde_json::Map::new();
+                for (name, vector) in index.embeddings(rtxn, key)? {
+                    let user_provided = embedding_configs
+                        .iter()
+                        .find(|conf| conf.name == name)
+                        .is_some_and(|conf| conf.user_provided.contains(key));
+                    let embeddings = ExplicitVectors { embeddings: vector.into(), user_provided };
+                    vectors.insert(
+                        name,
+                        serde_json::to_value(embeddings).map_err(MeilisearchHttpError::from)?,
+                    );
+                }
+                document.insert("_vectors".into(), vectors.into());
+            }
+            Ok(document)
         })
     }))
 }
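As a reading aid, the branch above roughly amounts to attaching a "_vectors" object to each returned document, with one entry per embedder. Below is a minimal sketch of the resulting shape, assuming ExplicitVectors serializes its embeddings and user_provided fields in camelCase; the key names are an assumption, not taken from the diff.

// Hypothetical document shape with retrieve_vectors enabled; key names under
// "_vectors" are assumed from how ExplicitVectors appears to serialize.
// Assumes serde_json = "1" as a dependency.
fn main() {
    let document = serde_json::json!({
        "id": 42,
        "title": "Example",
        "_vectors": {
            // one entry per configured embedder name
            "default": {
                "embeddings": [[0.1, 0.2, 0.3]],
                "userProvided": false
            }
        }
    });
    println!("{document:#}");
}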
@@ -596,6 +626,7 @@ fn retrieve_documents<S: AsRef<str>>(
     limit: usize,
     filter: Option<Value>,
     attributes_to_retrieve: Option<Vec<S>>,
+    retrieve_vectors: bool,
 ) -> Result<(u64, Vec<Document>), ResponseError> {
     let rtxn = index.read_txn()?;
     let filter = &filter;
@@ -620,53 +651,58 @@ fn retrieve_documents<S: AsRef<str>>(
     let (it, number_of_documents) = {
         let number_of_documents = candidates.len();
         (
-            some_documents(index, &rtxn, candidates.into_iter().skip(offset).take(limit))?,
+            some_documents(
+                index,
+                &rtxn,
+                candidates.into_iter().skip(offset).take(limit),
+                retrieve_vectors,
+            )?,
             number_of_documents,
         )
     };
-    let documents: Result<Vec<_>, ResponseError> = it
+    let documents: Vec<_> = it
         .map(|document| {
             Ok(match &attributes_to_retrieve {
                 Some(attributes_to_retrieve) => permissive_json_pointer::select_values(
                     &document?,
-                    attributes_to_retrieve.iter().map(|s| s.as_ref()),
+                    attributes_to_retrieve
+                        .iter()
+                        .map(|s| s.as_ref())
+                        .chain(retrieve_vectors.then_some("_vectors")),
                 ),
                 None => document?,
             })
         })
-        .collect();
+        .collect::<Result<_, ResponseError>>()?;
-    Ok((number_of_documents, documents?))
+    Ok((number_of_documents, documents))
 }
 
 fn retrieve_document<S: AsRef<str>>(
     index: &Index,
     doc_id: &str,
     attributes_to_retrieve: Option<Vec<S>>,
+    retrieve_vectors: bool,
 ) -> Result<Document, ResponseError> {
     let txn = index.read_txn()?;
-    let fields_ids_map = index.fields_ids_map(&txn)?;
-    let all_fields: Vec<_> = fields_ids_map.iter().map(|(id, _)| id).collect();
     let internal_id = index
         .external_documents_ids()
         .get(&txn, doc_id)?
         .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(doc_id.to_string()))?;
-    let document = index
-        .documents(&txn, std::iter::once(internal_id))?
-        .into_iter()
+    let document = some_documents(index, &txn, Some(internal_id), retrieve_vectors)?
         .next()
-        .map(|(_, d)| d)
-        .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(doc_id.to_string()))?;
+        .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(doc_id.to_string()))??;
-    let document = meilisearch_types::milli::obkv_to_json(&all_fields, &fields_ids_map, document)?;
     let document = match &attributes_to_retrieve {
         Some(attributes_to_retrieve) => permissive_json_pointer::select_values(
             &document,
-            attributes_to_retrieve.iter().map(|s| s.as_ref()),
+            attributes_to_retrieve
+                .iter()
+                .map(|s| s.as_ref())
+                .chain(retrieve_vectors.then_some("_vectors")),
         ),
         None => document,
     };
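Both select_values call sites above rely on the same small trick: chaining an Option onto the requested-fields iterator so that "_vectors" is whitelisted only when the flag is set. A minimal, self-contained illustration of the bool::then_some plus Iterator::chain pattern follows; the helper name is made up for the example.

// Stand-alone illustration of the pattern; `selected_fields` is a hypothetical
// helper, not a function from the codebase.
fn selected_fields<'a>(fields: &[&'a str], retrieve_vectors: bool) -> Vec<&'a str> {
    fields
        .iter()
        .copied()
        // then_some yields Some("_vectors") only when the flag is true, and
        // chaining an Option appends zero or one extra item to the iterator.
        .chain(retrieve_vectors.then_some("_vectors"))
        .collect()
}

fn main() {
    assert_eq!(selected_fields(&["title"], false), ["title"]);
    assert_eq!(selected_fields(&["title"], true), ["title", "_vectors"]);
}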