From 6607875f49b3047a3fe6d8771a700d21f36a0b9e Mon Sep 17 00:00:00 2001
From: Tamo
Date: Wed, 5 Jun 2024 23:40:29 +0200
Subject: [PATCH] add the retrieveVectors parameter to the get and fetch
 documents route

---
 meilisearch-types/src/error.rs                |   1 +
 meilisearch/src/routes/indexes/documents.rs   |  80 ++++--
 meilisearch/tests/common/index.rs             |  37 +--
 meilisearch/tests/common/mod.rs               |   2 +-
 meilisearch/tests/documents/errors.rs         |  24 ++
 meilisearch/tests/documents/get_documents.rs  | 268 ++++++++++++++++---
 6 files changed, 325 insertions(+), 87 deletions(-)

diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs
index 63543fb1b..ae2a753db 100644
--- a/meilisearch-types/src/error.rs
+++ b/meilisearch-types/src/error.rs
@@ -222,6 +222,7 @@ InvalidApiKeyUid , InvalidRequest , BAD_REQUEST ;
 InvalidContentType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ;
 InvalidDocumentCsvDelimiter , InvalidRequest , BAD_REQUEST ;
 InvalidDocumentFields , InvalidRequest , BAD_REQUEST ;
+InvalidDocumentRetrieveVectors , InvalidRequest , BAD_REQUEST ;
 MissingDocumentFilter , InvalidRequest , BAD_REQUEST ;
 InvalidDocumentFilter , InvalidRequest , BAD_REQUEST ;
 InvalidDocumentGeoField , InvalidRequest , BAD_REQUEST ;
diff --git a/meilisearch/src/routes/indexes/documents.rs b/meilisearch/src/routes/indexes/documents.rs
index 43fab1dae..97ded8069 100644
--- a/meilisearch/src/routes/indexes/documents.rs
+++ b/meilisearch/src/routes/indexes/documents.rs
@@ -16,6 +16,7 @@ use meilisearch_types::error::{Code, ResponseError};
 use meilisearch_types::heed::RoTxn;
 use meilisearch_types::index_uid::IndexUid;
 use meilisearch_types::milli::update::IndexDocumentsMethod;
+use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors;
 use meilisearch_types::milli::DocumentId;
 use meilisearch_types::star_or::OptionStarOrList;
 use meilisearch_types::tasks::KindWithContent;
@@ -94,6 +95,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
 pub struct GetDocument {
     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentFields>)]
     fields: OptionStarOrList<String>,
+    #[deserr(default, error = DeserrQueryParamError<InvalidDocumentRetrieveVectors>)]
+    retrieve_vectors: Param<bool>,
 }
 
 pub async fn get_document(
@@ -109,11 +112,12 @@ pub async fn get_document(
 
     analytics.get_fetch_documents(&DocumentFetchKind::PerDocumentId, &req);
 
-    let GetDocument { fields } = params.into_inner();
+    let GetDocument { fields, retrieve_vectors } = params.into_inner();
     let attributes_to_retrieve = fields.merge_star_and_none();
 
     let index = index_scheduler.index(&index_uid)?;
-    let document = retrieve_document(&index, &document_id, attributes_to_retrieve)?;
+    let document =
+        retrieve_document(&index, &document_id, attributes_to_retrieve, retrieve_vectors.0)?;
     debug!(returns = ?document, "Get document");
     Ok(HttpResponse::Ok().json(document))
 }
@@ -153,6 +157,8 @@ pub struct BrowseQueryGet {
     limit: Param<usize>,
     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentFields>)]
     fields: OptionStarOrList<String>,
+    #[deserr(default, error = DeserrQueryParamError<InvalidDocumentRetrieveVectors>)]
+    retrieve_vectors: Param<bool>,
     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentFilter>)]
     filter: Option<String>,
 }
@@ -166,6 +172,8 @@ pub struct BrowseQuery {
     limit: usize,
     #[deserr(default, error = DeserrJsonError<InvalidDocumentFields>)]
     fields: Option<Vec<String>>,
+    #[deserr(default, error = DeserrJsonError<InvalidDocumentRetrieveVectors>)]
+    retrieve_vectors: bool,
     #[deserr(default, error = DeserrJsonError<InvalidDocumentFilter>)]
     filter: Option<Value>,
 }
@@ -201,7 +209,7 @@ pub async fn get_documents(
 ) -> Result<HttpResponse, ResponseError> {
     debug!(parameters = ?params, "Get documents GET");
 
-    let BrowseQueryGet { limit, offset, fields, filter } = params.into_inner();
+    let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter } = params.into_inner();
 
     let filter = match filter {
         Some(f) => match serde_json::from_str(&f) {
@@ -215,6 +223,7 @@ pub async fn get_documents(
         offset: offset.0,
         limit: limit.0,
         fields: fields.merge_star_and_none(),
+        retrieve_vectors: retrieve_vectors.0,
         filter,
     };
 
@@ -236,10 +245,11 @@ fn documents_by_query(
     query: BrowseQuery,
 ) -> Result<HttpResponse, ResponseError> {
     let index_uid = IndexUid::try_from(index_uid.into_inner())?;
-    let BrowseQuery { offset, limit, fields, filter } = query;
+    let BrowseQuery { offset, limit, fields, retrieve_vectors, filter } = query;
 
     let index = index_scheduler.index(&index_uid)?;
-    let (total, documents) = retrieve_documents(&index, offset, limit, filter, fields)?;
+    let (total, documents) =
+        retrieve_documents(&index, offset, limit, filter, fields, retrieve_vectors)?;
 
     let ret = PaginationView::new(offset, limit, total as usize, documents);
 
@@ -579,13 +589,33 @@ fn some_documents<'a, 't: 'a>(
     index: &'a Index,
     rtxn: &'t RoTxn,
     doc_ids: impl IntoIterator<Item = DocumentId> + 'a,
+    retrieve_vectors: bool,
 ) -> Result<impl Iterator<Item = Result<Document, ResponseError>> + 'a, ResponseError> {
     let fields_ids_map = index.fields_ids_map(rtxn)?;
     let all_fields: Vec<_> = fields_ids_map.iter().map(|(id, _)| id).collect();
+    let embedding_configs = index.embedding_configs(rtxn)?;
 
     Ok(index.iter_documents(rtxn, doc_ids)?.map(move |ret| {
-        ret.map_err(ResponseError::from).and_then(|(_key, document)| -> Result<_, ResponseError> {
-            Ok(milli::obkv_to_json(&all_fields, &fields_ids_map, document)?)
+        ret.map_err(ResponseError::from).and_then(|(key, document)| -> Result<_, ResponseError> {
+            let mut document = milli::obkv_to_json(&all_fields, &fields_ids_map, document)?;
+
+            if retrieve_vectors {
+                let mut vectors = serde_json::Map::new();
+                for (name, vector) in index.embeddings(rtxn, key)? {
+                    let user_provided = embedding_configs
+                        .iter()
+                        .find(|conf| conf.name == name)
+                        .is_some_and(|conf| conf.user_provided.contains(key));
+                    let embeddings = ExplicitVectors { embeddings: vector.into(), user_provided };
+                    vectors.insert(
+                        name,
+                        serde_json::to_value(embeddings).map_err(MeilisearchHttpError::from)?,
+                    );
+                }
+                document.insert("_vectors".into(), vectors.into());
+            }
+
+            Ok(document)
         })
     }))
 }
@@ -596,6 +626,7 @@ fn retrieve_documents<S: AsRef<str>>(
     limit: usize,
     filter: Option<Value>,
     attributes_to_retrieve: Option<Vec<S>>,
+    retrieve_vectors: bool,
 ) -> Result<(u64, Vec<Document>), ResponseError> {
     let rtxn = index.read_txn()?;
     let filter = &filter;
@@ -620,53 +651,58 @@ fn retrieve_documents<S: AsRef<str>>(
     let (it, number_of_documents) = {
         let number_of_documents = candidates.len();
         (
-            some_documents(index, &rtxn, candidates.into_iter().skip(offset).take(limit))?,
+            some_documents(
+                index,
+                &rtxn,
+                candidates.into_iter().skip(offset).take(limit),
+                retrieve_vectors,
+            )?,
             number_of_documents,
         )
     };
 
-    let documents: Result<Vec<_>, ResponseError> = it
+    let documents: Vec<_> = it
        .map(|document| {
             Ok(match &attributes_to_retrieve {
                 Some(attributes_to_retrieve) => permissive_json_pointer::select_values(
                     &document?,
-                    attributes_to_retrieve.iter().map(|s| s.as_ref()),
+                    attributes_to_retrieve
+                        .iter()
+                        .map(|s| s.as_ref())
+                        .chain(retrieve_vectors.then_some("_vectors")),
                 ),
                 None => document?,
             })
         })
-        .collect();
+        .collect::<Result<Vec<_>, ResponseError>>()?;
 
-    Ok((number_of_documents, documents?))
+    Ok((number_of_documents, documents))
 }
 
 fn retrieve_document<S: AsRef<str>>(
     index: &Index,
     doc_id: &str,
     attributes_to_retrieve: Option<Vec<S>>,
+    retrieve_vectors: bool,
 ) -> Result<Document, ResponseError> {
     let txn = index.read_txn()?;
 
-    let fields_ids_map = index.fields_ids_map(&txn)?;
-    let all_fields: Vec<_> = fields_ids_map.iter().map(|(id, _)| id).collect();
-
     let internal_id = index
         .external_documents_ids()
         .get(&txn, doc_id)?
         .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(doc_id.to_string()))?;
 
-    let document = index
-        .documents(&txn, std::iter::once(internal_id))?
-        .into_iter()
+    let document = some_documents(index, &txn, Some(internal_id), retrieve_vectors)?
         .next()
-        .map(|(_, d)| d)
-        .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(doc_id.to_string()))?;
+        .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(doc_id.to_string()))??;
 
-    let document = meilisearch_types::milli::obkv_to_json(&all_fields, &fields_ids_map, document)?;
     let document = match &attributes_to_retrieve {
         Some(attributes_to_retrieve) => permissive_json_pointer::select_values(
             &document,
-            attributes_to_retrieve.iter().map(|s| s.as_ref()),
+            attributes_to_retrieve
+                .iter()
+                .map(|s| s.as_ref())
+                .chain(retrieve_vectors.then_some("_vectors")),
         ),
         None => document,
     };
diff --git a/meilisearch/tests/common/index.rs b/meilisearch/tests/common/index.rs
index 3ac33b4e9..f81fe8c8a 100644
--- a/meilisearch/tests/common/index.rs
+++ b/meilisearch/tests/common/index.rs
@@ -182,14 +182,10 @@ impl Index<'_> {
         self.service.get(url).await
     }
 
-    pub async fn get_document(
-        &self,
-        id: u64,
-        options: Option<GetDocumentOptions>,
-    ) -> (Value, StatusCode) {
+    pub async fn get_document(&self, id: u64, options: Option<Value>) -> (Value, StatusCode) {
         let mut url = format!("/indexes/{}/documents/{}", urlencode(self.uid.as_ref()), id);
-        if let Some(fields) = options.and_then(|o| o.fields) {
-            let _ = write!(url, "?fields={}", fields.join(","));
+        if let Some(options) = options {
+            write!(url, "?{}", yaup::to_string(&options).unwrap()).unwrap();
         }
         self.service.get(url).await
     }
@@ -205,18 +201,11 @@ impl Index<'_> {
     }
 
     pub async fn get_all_documents(&self, options: GetAllDocumentsOptions) -> (Value, StatusCode) {
-        let mut url = format!("/indexes/{}/documents?", urlencode(self.uid.as_ref()));
-        if let Some(limit) = options.limit {
-            let _ = write!(url, "limit={}&", limit);
-        }
-
-        if let Some(offset) = options.offset {
-            let _ = write!(url, "offset={}&", offset);
-        }
-
-        if let Some(attributes_to_retrieve) = options.attributes_to_retrieve {
-            let _ = write!(url, "fields={}&", attributes_to_retrieve.join(","));
-        }
+        let url = format!(
+            "/indexes/{}/documents?{}",
+            urlencode(self.uid.as_ref()),
+            yaup::to_string(&options).unwrap()
+        );
 
         self.service.get(url).await
     }
@@ -435,13 +424,11 @@ impl Index<'_> {
     }
 }
 
-pub struct GetDocumentOptions {
-    pub fields: Option<Vec<&'static str>>,
-}
-
-#[derive(Debug, Default)]
+#[derive(Debug, Default, serde::Serialize)]
+#[serde(rename_all = "camelCase")]
 pub struct GetAllDocumentsOptions {
     pub limit: Option<usize>,
     pub offset: Option<usize>,
-    pub attributes_to_retrieve: Option<Vec<&'static str>>,
+    pub fields: Option<Vec<&'static str>>,
+    pub retrieve_vectors: bool,
 }
diff --git a/meilisearch/tests/common/mod.rs b/meilisearch/tests/common/mod.rs
index 3117dd185..317e5e171 100644
--- a/meilisearch/tests/common/mod.rs
+++ b/meilisearch/tests/common/mod.rs
@@ -6,7 +6,7 @@ pub mod service;
 use std::fmt::{self, Display};
 
 #[allow(unused)]
-pub use index::{GetAllDocumentsOptions, GetDocumentOptions};
+pub use index::GetAllDocumentsOptions;
 use meili_snap::json_string;
 use serde::{Deserialize, Serialize};
 #[allow(unused)]
diff --git a/meilisearch/tests/documents/errors.rs b/meilisearch/tests/documents/errors.rs
index cd2d89813..cd1be4dc4 100644
--- a/meilisearch/tests/documents/errors.rs
+++ b/meilisearch/tests/documents/errors.rs
@@ -795,3 +795,27 @@ async fn fetch_document_by_filter() {
     }
     "###);
 }
+
+#[actix_rt::test]
+async fn retrieve_vectors() { + let server = Server::new().await; + let index = server.index("doggo"); + let (response, _code) = index.get_all_documents_raw("?retrieveVectors=tamo").await; + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value in parameter `retrieveVectors`: could not parse `tamo` as a boolean, expected either `true` or `false`", + "code": "invalid_document_retrieve_vectors", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_document_retrieve_vectors" + } + "###); + let (response, _code) = index.get_document(0, Some(json!({"retrieveVectors": "tamo"}))).await; + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value in parameter `retrieveVectors`: could not parse `tamo` as a boolean, expected either `true` or `false`", + "code": "invalid_document_retrieve_vectors", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_document_retrieve_vectors" + } + "###); +} diff --git a/meilisearch/tests/documents/get_documents.rs b/meilisearch/tests/documents/get_documents.rs index 3b0629fcb..1ade00b06 100644 --- a/meilisearch/tests/documents/get_documents.rs +++ b/meilisearch/tests/documents/get_documents.rs @@ -4,7 +4,7 @@ use meili_snap::*; use urlencoding::encode as urlencode; use crate::common::encoder::Encoder; -use crate::common::{GetAllDocumentsOptions, GetDocumentOptions, Server, Value}; +use crate::common::{GetAllDocumentsOptions, Server, Value}; use crate::json; // TODO: partial test since we are testing error, amd error is not yet fully implemented in @@ -59,8 +59,7 @@ async fn get_document() { }) ); - let (response, code) = - index.get_document(0, Some(GetDocumentOptions { fields: Some(vec!["id"]) })).await; + let (response, code) = index.get_document(0, Some(json!({ "fields": ["id"] }))).await; assert_eq!(code, 200); assert_eq!( response, @@ -69,9 +68,8 @@ async fn get_document() { }) ); - let (response, code) = index - .get_document(0, Some(GetDocumentOptions { fields: Some(vec!["nested.content"]) })) - .await; + let (response, code) = + index.get_document(0, Some(json!({ "fields": ["nested.content"] }))).await; assert_eq!(code, 200); assert_eq!( response, @@ -211,7 +209,7 @@ async fn test_get_all_documents_attributes_to_retrieve() { let (response, code) = index .get_all_documents(GetAllDocumentsOptions { - attributes_to_retrieve: Some(vec!["name"]), + fields: Some(vec!["name"]), ..Default::default() }) .await; @@ -225,9 +223,21 @@ async fn test_get_all_documents_attributes_to_retrieve() { assert_eq!(response["limit"], json!(20)); assert_eq!(response["total"], json!(77)); + let (response, code) = index + .get_all_documents(GetAllDocumentsOptions { fields: Some(vec![]), ..Default::default() }) + .await; + assert_eq!(code, 200); + assert_eq!(response["results"].as_array().unwrap().len(), 20); + for results in response["results"].as_array().unwrap() { + assert_eq!(results.as_object().unwrap().keys().count(), 0); + } + assert_eq!(response["offset"], json!(0)); + assert_eq!(response["limit"], json!(20)); + assert_eq!(response["total"], json!(77)); + let (response, code) = index .get_all_documents(GetAllDocumentsOptions { - attributes_to_retrieve: Some(vec![]), + fields: Some(vec!["wrong"]), ..Default::default() }) .await; @@ -242,22 +252,7 @@ async fn test_get_all_documents_attributes_to_retrieve() { let (response, code) = index .get_all_documents(GetAllDocumentsOptions { - attributes_to_retrieve: Some(vec!["wrong"]), - ..Default::default() - }) - .await; - assert_eq!(code, 
200); - assert_eq!(response["results"].as_array().unwrap().len(), 20); - for results in response["results"].as_array().unwrap() { - assert_eq!(results.as_object().unwrap().keys().count(), 0); - } - assert_eq!(response["offset"], json!(0)); - assert_eq!(response["limit"], json!(20)); - assert_eq!(response["total"], json!(77)); - - let (response, code) = index - .get_all_documents(GetAllDocumentsOptions { - attributes_to_retrieve: Some(vec!["name", "tags"]), + fields: Some(vec!["name", "tags"]), ..Default::default() }) .await; @@ -270,10 +265,7 @@ async fn test_get_all_documents_attributes_to_retrieve() { } let (response, code) = index - .get_all_documents(GetAllDocumentsOptions { - attributes_to_retrieve: Some(vec!["*"]), - ..Default::default() - }) + .get_all_documents(GetAllDocumentsOptions { fields: Some(vec!["*"]), ..Default::default() }) .await; assert_eq!(code, 200); assert_eq!(response["results"].as_array().unwrap().len(), 20); @@ -283,7 +275,7 @@ async fn test_get_all_documents_attributes_to_retrieve() { let (response, code) = index .get_all_documents(GetAllDocumentsOptions { - attributes_to_retrieve: Some(vec!["*", "wrong"]), + fields: Some(vec!["*", "wrong"]), ..Default::default() }) .await; @@ -316,12 +308,10 @@ async fn get_document_s_nested_attributes_to_retrieve() { assert_eq!(code, 202); index.wait_task(1).await; - let (response, code) = - index.get_document(0, Some(GetDocumentOptions { fields: Some(vec!["content"]) })).await; + let (response, code) = index.get_document(0, Some(json!({ "fields": ["content"] }))).await; assert_eq!(code, 200); assert_eq!(response, json!({})); - let (response, code) = - index.get_document(1, Some(GetDocumentOptions { fields: Some(vec!["content"]) })).await; + let (response, code) = index.get_document(1, Some(json!({ "fields": ["content"] }))).await; assert_eq!(code, 200); assert_eq!( response, @@ -333,9 +323,7 @@ async fn get_document_s_nested_attributes_to_retrieve() { }) ); - let (response, code) = index - .get_document(0, Some(GetDocumentOptions { fields: Some(vec!["content.truc"]) })) - .await; + let (response, code) = index.get_document(0, Some(json!({ "fields": ["content.truc"] }))).await; assert_eq!(code, 200); assert_eq!( response, @@ -343,9 +331,7 @@ async fn get_document_s_nested_attributes_to_retrieve() { "content.truc": "foobar", }) ); - let (response, code) = index - .get_document(1, Some(GetDocumentOptions { fields: Some(vec!["content.truc"]) })) - .await; + let (response, code) = index.get_document(1, Some(json!({ "fields": ["content.truc"] }))).await; assert_eq!(code, 200); assert_eq!( response, @@ -540,3 +526,207 @@ async fn get_document_by_filter() { } "###); } + +#[actix_rt::test] +async fn get_document_with_vectors() { + let server = Server::new().await; + let index = server.index("doggo"); + let (value, code) = server.set_features(json!({"vectorStore": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(value, @r###" + { + "vectorStore": true, + "metrics": false, + "logsRoute": false + } + "###); + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + })) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = json!([ + {"id": 0, "name": "kefir", "_vectors": { "manual": [0, 0, 0] }}, + {"id": 1, "name": "echo", "_vectors": { "manual": null }}, + ]); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + 
index.wait_task(value.uid()).await; + + // by default you shouldn't see the `_vectors` object + let (documents, _code) = index.get_all_documents(Default::default()).await; + snapshot!(json_string!(documents), @r###" + { + "results": [ + { + "id": 0, + "name": "kefir" + }, + { + "id": 1, + "name": "echo" + } + ], + "offset": 0, + "limit": 20, + "total": 2 + } + "###); + let (documents, _code) = index.get_document(0, None).await; + snapshot!(json_string!(documents), @r###" + { + "id": 0, + "name": "kefir" + } + "###); + + // if we try to retrieve the vectors with the `fields` parameter they + // still shouldn't be displayed + let (documents, _code) = index + .get_all_documents(GetAllDocumentsOptions { + fields: Some(vec!["name", "_vectors"]), + ..Default::default() + }) + .await; + snapshot!(json_string!(documents), @r###" + { + "results": [ + { + "name": "kefir" + }, + { + "name": "echo" + } + ], + "offset": 0, + "limit": 20, + "total": 2 + } + "###); + let (documents, _code) = + index.get_document(0, Some(json!({"fields": ["name", "_vectors"]}))).await; + snapshot!(json_string!(documents), @r###" + { + "name": "kefir" + } + "###); + + // If we specify the retrieve vectors boolean and nothing else we should get the vectors + let (documents, _code) = index + .get_all_documents(GetAllDocumentsOptions { retrieve_vectors: true, ..Default::default() }) + .await; + snapshot!(json_string!(documents), @r###" + { + "results": [ + { + "id": 0, + "name": "kefir", + "_vectors": { + "manual": { + "embeddings": [ + [ + 0.0, + 0.0, + 0.0 + ] + ], + "userProvided": true + } + } + }, + { + "id": 1, + "name": "echo", + "_vectors": {} + } + ], + "offset": 0, + "limit": 20, + "total": 2 + } + "###); + let (documents, _code) = index.get_document(0, Some(json!({"retrieveVectors": true}))).await; + snapshot!(json_string!(documents), @r###" + { + "id": 0, + "name": "kefir", + "_vectors": { + "manual": { + "embeddings": [ + [ + 0.0, + 0.0, + 0.0 + ] + ], + "userProvided": true + } + } + } + "###); + + // If we specify the retrieve vectors boolean and exclude vectors form the `fields` we should still get the vectors + let (documents, _code) = index + .get_all_documents(GetAllDocumentsOptions { + retrieve_vectors: true, + fields: Some(vec!["name"]), + ..Default::default() + }) + .await; + snapshot!(json_string!(documents), @r###" + { + "results": [ + { + "name": "kefir", + "_vectors": { + "manual": { + "embeddings": [ + [ + 0.0, + 0.0, + 0.0 + ] + ], + "userProvided": true + } + } + }, + { + "name": "echo", + "_vectors": {} + } + ], + "offset": 0, + "limit": 20, + "total": 2 + } + "###); + let (documents, _code) = + index.get_document(0, Some(json!({"retrieveVectors": true, "fields": ["name"]}))).await; + snapshot!(json_string!(documents), @r###" + { + "name": "kefir", + "_vectors": { + "manual": { + "embeddings": [ + [ + 0.0, + 0.0, + 0.0 + ] + ], + "userProvided": true + } + } + } + "###); +}
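
Usage sketch (illustrative, derived from the tests above; the `doggo` index and
`manual` embedder are test fixtures): the new `retrieveVectors` query parameter
defaults to `false` on both the get-document and get/fetch-documents routes.
A request such as

    GET /indexes/doggo/documents/0?retrieveVectors=true

is expected to return the stored vectors inlined under `_vectors`, as
snapshotted in `get_document_with_vectors`:

    {
      "id": 0,
      "name": "kefir",
      "_vectors": {
        "manual": {
          "embeddings": [[0.0, 0.0, 0.0]],
          "userProvided": true
        }
      }
    }

Any non-boolean value (e.g. `?retrieveVectors=tamo`) is rejected with the new
`invalid_document_retrieve_vectors` error, and when `fields` is also supplied,
`_vectors` is appended to the selected attributes.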