diff --git a/Cargo.lock b/Cargo.lock index 966394cf6..e31943cf3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -500,7 +500,7 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "benchmarks" -version = "1.8.1" +version = "1.9.0" dependencies = [ "anyhow", "bytes", @@ -645,7 +645,7 @@ dependencies = [ [[package]] name = "build-info" -version = "1.8.1" +version = "1.9.0" dependencies = [ "anyhow", "time", @@ -1545,7 +1545,7 @@ dependencies = [ [[package]] name = "dump" -version = "1.8.1" +version = "1.9.0" dependencies = [ "anyhow", "big_s", @@ -1793,7 +1793,7 @@ dependencies = [ [[package]] name = "file-store" -version = "1.8.1" +version = "1.9.0" dependencies = [ "faux", "tempfile", @@ -1816,7 +1816,7 @@ dependencies = [ [[package]] name = "filter-parser" -version = "1.8.1" +version = "1.9.0" dependencies = [ "insta", "nom", @@ -1836,7 +1836,7 @@ dependencies = [ [[package]] name = "flatten-serde-json" -version = "1.8.1" +version = "1.9.0" dependencies = [ "criterion", "serde_json", @@ -1954,7 +1954,7 @@ dependencies = [ [[package]] name = "fuzzers" -version = "1.8.1" +version = "1.9.0" dependencies = [ "arbitrary", "clap", @@ -2447,7 +2447,7 @@ checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d" [[package]] name = "index-scheduler" -version = "1.8.1" +version = "1.9.0" dependencies = [ "anyhow", "big_s", @@ -2642,7 +2642,7 @@ dependencies = [ [[package]] name = "json-depth-checker" -version = "1.8.1" +version = "1.9.0" dependencies = [ "criterion", "serde_json", @@ -3272,7 +3272,7 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" [[package]] name = "meili-snap" -version = "1.8.1" +version = "1.9.0" dependencies = [ "insta", "md5", @@ -3281,7 +3281,7 @@ dependencies = [ [[package]] name = "meilisearch" -version = "1.8.1" +version = "1.9.0" dependencies = [ "actix-cors", "actix-http", @@ -3373,7 +3373,7 @@ dependencies = [ [[package]] name = "meilisearch-auth" -version = "1.8.1" +version = "1.9.0" dependencies = [ "base64 0.21.7", "enum-iterator", @@ -3392,7 +3392,7 @@ dependencies = [ [[package]] name = "meilisearch-types" -version = "1.8.1" +version = "1.9.0" dependencies = [ "actix-web", "anyhow", @@ -3422,7 +3422,7 @@ dependencies = [ [[package]] name = "meilitool" -version = "1.8.1" +version = "1.9.0" dependencies = [ "anyhow", "clap", @@ -3461,7 +3461,7 @@ dependencies = [ [[package]] name = "milli" -version = "1.8.1" +version = "1.9.0" dependencies = [ "arroy", "big_s", @@ -3901,7 +3901,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "permissive-json-pointer" -version = "1.8.1" +version = "1.9.0" dependencies = [ "big_s", "serde_json", @@ -6052,7 +6052,7 @@ dependencies = [ [[package]] name = "xtask" -version = "1.8.1" +version = "1.9.0" dependencies = [ "anyhow", "build-info", diff --git a/Cargo.toml b/Cargo.toml index eadef3a1b..5c6c8b376 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ members = [ ] [workspace.package] -version = "1.8.1" +version = "1.9.0" authors = [ "Quentin de Quelen ", "Clément Renault ", diff --git a/meilisearch-types/src/deserr/mod.rs b/meilisearch-types/src/deserr/mod.rs index bf1aa1da5..c593c50fb 100644 --- a/meilisearch-types/src/deserr/mod.rs +++ b/meilisearch-types/src/deserr/mod.rs @@ -189,3 +189,4 @@ merge_with_error_impl_take_error_message!(ParseTaskKindError); merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(IndexUidFormatError); merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); +merge_with_error_impl_take_error_message!(InvalidSimilarId); diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 85a2cd767..d2218807f 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -239,18 +239,23 @@ InvalidIndexUid , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToSearchOn , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToCrop , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ; +InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; +InvalidSimilarId , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; +InvalidSimilarFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPreTag , InvalidRequest , BAD_REQUEST ; InvalidSearchHitsPerPage , InvalidRequest , BAD_REQUEST ; +InvalidSimilarLimit , InvalidRequest , BAD_REQUEST ; InvalidSearchLimit , InvalidRequest , BAD_REQUEST ; InvalidSearchMatchingStrategy , InvalidRequest , BAD_REQUEST ; +InvalidSimilarOffset , InvalidRequest , BAD_REQUEST ; InvalidSearchOffset , InvalidRequest , BAD_REQUEST ; InvalidSearchPage , InvalidRequest , BAD_REQUEST ; InvalidSearchQ , InvalidRequest , BAD_REQUEST ; @@ -259,7 +264,9 @@ InvalidFacetSearchName , InvalidRequest , BAD_REQUEST ; InvalidSearchVector , InvalidRequest , BAD_REQUEST ; InvalidSearchShowMatchesPosition , InvalidRequest , BAD_REQUEST ; InvalidSearchShowRankingScore , InvalidRequest , BAD_REQUEST ; +InvalidSimilarShowRankingScore , InvalidRequest , BAD_REQUEST ; InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ; +InvalidSimilarShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ; InvalidSearchSort , InvalidRequest , BAD_REQUEST ; InvalidSettingsDisplayedAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsDistinctAttribute , InvalidRequest , BAD_REQUEST ; @@ -322,7 +329,8 @@ UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ; UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ; // Experimental features -VectorEmbeddingError , InvalidRequest , BAD_REQUEST +VectorEmbeddingError , InvalidRequest , BAD_REQUEST ; +NotFoundSimilarId , InvalidRequest , BAD_REQUEST } impl ErrorCode for JoinError { @@ -486,6 +494,17 @@ impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio { } } +impl fmt::Display for deserr_codes::InvalidSimilarId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "the value of `id` is invalid. \ + A document identifier can be of type integer or string, \ + only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and underscores (_)." + ) + } +} + #[macro_export] macro_rules! internal_error { ($target:ty : $($other:path), *) => { diff --git a/meilisearch/src/analytics/mock_analytics.rs b/meilisearch/src/analytics/mock_analytics.rs index 1687e9e19..8f2fe0333 100644 --- a/meilisearch/src/analytics/mock_analytics.rs +++ b/meilisearch/src/analytics/mock_analytics.rs @@ -25,6 +25,18 @@ impl SearchAggregator { pub fn succeed(&mut self, _: &dyn Any) {} } +#[derive(Default)] +pub struct SimilarAggregator; + +#[allow(dead_code)] +impl SimilarAggregator { + pub fn from_query(_: &dyn Any, _: &dyn Any) -> Self { + Self + } + + pub fn succeed(&mut self, _: &dyn Any) {} +} + #[derive(Default)] pub struct MultiSearchAggregator; @@ -66,6 +78,8 @@ impl Analytics for MockAnalytics { fn publish(&self, _event_name: String, _send: Value, _request: Option<&HttpRequest>) {} fn get_search(&self, _aggregate: super::SearchAggregator) {} fn post_search(&self, _aggregate: super::SearchAggregator) {} + fn get_similar(&self, _aggregate: super::SimilarAggregator) {} + fn post_similar(&self, _aggregate: super::SimilarAggregator) {} fn post_multi_search(&self, _aggregate: super::MultiSearchAggregator) {} fn post_facet_search(&self, _aggregate: super::FacetSearchAggregator) {} fn add_documents( diff --git a/meilisearch/src/analytics/mod.rs b/meilisearch/src/analytics/mod.rs index 09c0a05df..3468ad2c7 100644 --- a/meilisearch/src/analytics/mod.rs +++ b/meilisearch/src/analytics/mod.rs @@ -22,6 +22,8 @@ pub type SegmentAnalytics = mock_analytics::MockAnalytics; #[cfg(not(feature = "analytics"))] pub type SearchAggregator = mock_analytics::SearchAggregator; #[cfg(not(feature = "analytics"))] +pub type SimilarAggregator = mock_analytics::SimilarAggregator; +#[cfg(not(feature = "analytics"))] pub type MultiSearchAggregator = mock_analytics::MultiSearchAggregator; #[cfg(not(feature = "analytics"))] pub type FacetSearchAggregator = mock_analytics::FacetSearchAggregator; @@ -32,6 +34,8 @@ pub type SegmentAnalytics = segment_analytics::SegmentAnalytics; #[cfg(feature = "analytics")] pub type SearchAggregator = segment_analytics::SearchAggregator; #[cfg(feature = "analytics")] +pub type SimilarAggregator = segment_analytics::SimilarAggregator; +#[cfg(feature = "analytics")] pub type MultiSearchAggregator = segment_analytics::MultiSearchAggregator; #[cfg(feature = "analytics")] pub type FacetSearchAggregator = segment_analytics::FacetSearchAggregator; @@ -86,6 +90,12 @@ pub trait Analytics: Sync + Send { /// This method should be called to aggregate a post search fn post_search(&self, aggregate: SearchAggregator); + /// This method should be called to aggregate a get similar request + fn get_similar(&self, aggregate: SimilarAggregator); + + /// This method should be called to aggregate a post similar request + fn post_similar(&self, aggregate: SimilarAggregator); + /// This method should be called to aggregate a post array of searches fn post_multi_search(&self, aggregate: MultiSearchAggregator); diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index 8c20c82c2..add430893 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -36,8 +36,9 @@ use crate::routes::indexes::facet_search::FacetSearchQuery; use crate::routes::{create_all_stats, Stats}; use crate::search::{ FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult, - DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, - DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO, + SimilarQuery, SimilarResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, + DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, + DEFAULT_SEMANTIC_RATIO, }; use crate::Opt; @@ -73,6 +74,8 @@ pub enum AnalyticsMsg { BatchMessage(Track), AggregateGetSearch(SearchAggregator), AggregatePostSearch(SearchAggregator), + AggregateGetSimilar(SimilarAggregator), + AggregatePostSimilar(SimilarAggregator), AggregatePostMultiSearch(MultiSearchAggregator), AggregatePostFacetSearch(FacetSearchAggregator), AggregateAddDocuments(DocumentsAggregator), @@ -149,6 +152,8 @@ impl SegmentAnalytics { update_documents_aggregator: DocumentsAggregator::default(), get_fetch_documents_aggregator: DocumentsFetchAggregator::default(), post_fetch_documents_aggregator: DocumentsFetchAggregator::default(), + get_similar_aggregator: SimilarAggregator::default(), + post_similar_aggregator: SimilarAggregator::default(), }); tokio::spawn(segment.run(index_scheduler.clone(), auth_controller.clone())); @@ -184,6 +189,14 @@ impl super::Analytics for SegmentAnalytics { let _ = self.sender.try_send(AnalyticsMsg::AggregatePostSearch(aggregate)); } + fn get_similar(&self, aggregate: SimilarAggregator) { + let _ = self.sender.try_send(AnalyticsMsg::AggregateGetSimilar(aggregate)); + } + + fn post_similar(&self, aggregate: SimilarAggregator) { + let _ = self.sender.try_send(AnalyticsMsg::AggregatePostSimilar(aggregate)); + } + fn post_facet_search(&self, aggregate: FacetSearchAggregator) { let _ = self.sender.try_send(AnalyticsMsg::AggregatePostFacetSearch(aggregate)); } @@ -379,6 +392,8 @@ pub struct Segment { update_documents_aggregator: DocumentsAggregator, get_fetch_documents_aggregator: DocumentsFetchAggregator, post_fetch_documents_aggregator: DocumentsFetchAggregator, + get_similar_aggregator: SimilarAggregator, + post_similar_aggregator: SimilarAggregator, } impl Segment { @@ -441,6 +456,8 @@ impl Segment { Some(AnalyticsMsg::AggregateUpdateDocuments(agreg)) => self.update_documents_aggregator.aggregate(agreg), Some(AnalyticsMsg::AggregateGetFetchDocuments(agreg)) => self.get_fetch_documents_aggregator.aggregate(agreg), Some(AnalyticsMsg::AggregatePostFetchDocuments(agreg)) => self.post_fetch_documents_aggregator.aggregate(agreg), + Some(AnalyticsMsg::AggregateGetSimilar(agreg)) => self.get_similar_aggregator.aggregate(agreg), + Some(AnalyticsMsg::AggregatePostSimilar(agreg)) => self.post_similar_aggregator.aggregate(agreg), None => (), } } @@ -494,6 +511,8 @@ impl Segment { update_documents_aggregator, get_fetch_documents_aggregator, post_fetch_documents_aggregator, + get_similar_aggregator, + post_similar_aggregator, } = self; if let Some(get_search) = @@ -541,6 +560,18 @@ impl Segment { { let _ = self.batcher.push(post_fetch_documents).await; } + + if let Some(get_similar_documents) = + take(get_similar_aggregator).into_event(user, "Similar GET") + { + let _ = self.batcher.push(get_similar_documents).await; + } + + if let Some(post_similar_documents) = + take(post_similar_aggregator).into_event(user, "Similar POST") + { + let _ = self.batcher.push(post_similar_documents).await; + } let _ = self.batcher.flush().await; } } @@ -1558,3 +1589,235 @@ impl DocumentsFetchAggregator { }) } } + +#[derive(Default)] +pub struct SimilarAggregator { + timestamp: Option, + + // context + user_agents: HashSet, + + // requests + total_received: usize, + total_succeeded: usize, + time_spent: BinaryHeap, + + // filter + filter_with_geo_radius: bool, + filter_with_geo_bounding_box: bool, + // every time a request has a filter, this field must be incremented by the number of terms it contains + filter_sum_of_criteria_terms: usize, + // every time a request has a filter, this field must be incremented by one + filter_total_number_of_criteria: usize, + used_syntax: HashMap, + + // Whether a non-default embedder was specified + embedder: bool, + + // pagination + max_limit: usize, + max_offset: usize, + + // formatting + max_attributes_to_retrieve: usize, + + // scoring + show_ranking_score: bool, + show_ranking_score_details: bool, +} + +impl SimilarAggregator { + #[allow(clippy::field_reassign_with_default)] + pub fn from_query(query: &SimilarQuery, request: &HttpRequest) -> Self { + let SimilarQuery { + id: _, + embedder, + offset, + limit, + attributes_to_retrieve: _, + show_ranking_score, + show_ranking_score_details, + filter, + } = query; + + let mut ret = Self::default(); + ret.timestamp = Some(OffsetDateTime::now_utc()); + + ret.total_received = 1; + ret.user_agents = extract_user_agents(request).into_iter().collect(); + + if let Some(ref filter) = filter { + static RE: Lazy = Lazy::new(|| Regex::new("AND | OR").unwrap()); + ret.filter_total_number_of_criteria = 1; + + let syntax = match filter { + Value::String(_) => "string".to_string(), + Value::Array(values) => { + if values.iter().map(|v| v.to_string()).any(|s| RE.is_match(&s)) { + "mixed".to_string() + } else { + "array".to_string() + } + } + _ => "none".to_string(), + }; + // convert the string to a HashMap + ret.used_syntax.insert(syntax, 1); + + let stringified_filters = filter.to_string(); + ret.filter_with_geo_radius = stringified_filters.contains("_geoRadius("); + ret.filter_with_geo_bounding_box = stringified_filters.contains("_geoBoundingBox("); + ret.filter_sum_of_criteria_terms = RE.split(&stringified_filters).count(); + } + + ret.max_limit = *limit; + ret.max_offset = *offset; + + ret.show_ranking_score = *show_ranking_score; + ret.show_ranking_score_details = *show_ranking_score_details; + + ret.embedder = embedder.is_some(); + + ret + } + + pub fn succeed(&mut self, result: &SimilarResult) { + let SimilarResult { id: _, hits: _, processing_time_ms, hits_info: _ } = result; + + self.total_succeeded = self.total_succeeded.saturating_add(1); + + self.time_spent.push(*processing_time_ms as usize); + } + + /// Aggregate one [SimilarAggregator] into another. + pub fn aggregate(&mut self, mut other: Self) { + let Self { + timestamp, + user_agents, + total_received, + total_succeeded, + ref mut time_spent, + filter_with_geo_radius, + filter_with_geo_bounding_box, + filter_sum_of_criteria_terms, + filter_total_number_of_criteria, + used_syntax, + max_limit, + max_offset, + max_attributes_to_retrieve, + show_ranking_score, + show_ranking_score_details, + embedder, + } = other; + + if self.timestamp.is_none() { + self.timestamp = timestamp; + } + + // context + for user_agent in user_agents.into_iter() { + self.user_agents.insert(user_agent); + } + + // request + self.total_received = self.total_received.saturating_add(total_received); + self.total_succeeded = self.total_succeeded.saturating_add(total_succeeded); + self.time_spent.append(time_spent); + + // filter + self.filter_with_geo_radius |= filter_with_geo_radius; + self.filter_with_geo_bounding_box |= filter_with_geo_bounding_box; + self.filter_sum_of_criteria_terms = + self.filter_sum_of_criteria_terms.saturating_add(filter_sum_of_criteria_terms); + self.filter_total_number_of_criteria = + self.filter_total_number_of_criteria.saturating_add(filter_total_number_of_criteria); + for (key, value) in used_syntax.into_iter() { + let used_syntax = self.used_syntax.entry(key).or_insert(0); + *used_syntax = used_syntax.saturating_add(value); + } + + self.embedder |= embedder; + + // pagination + self.max_limit = self.max_limit.max(max_limit); + self.max_offset = self.max_offset.max(max_offset); + + // formatting + self.max_attributes_to_retrieve = + self.max_attributes_to_retrieve.max(max_attributes_to_retrieve); + + // scoring + self.show_ranking_score |= show_ranking_score; + self.show_ranking_score_details |= show_ranking_score_details; + } + + pub fn into_event(self, user: &User, event_name: &str) -> Option { + let Self { + timestamp, + user_agents, + total_received, + total_succeeded, + time_spent, + filter_with_geo_radius, + filter_with_geo_bounding_box, + filter_sum_of_criteria_terms, + filter_total_number_of_criteria, + used_syntax, + max_limit, + max_offset, + max_attributes_to_retrieve, + show_ranking_score, + show_ranking_score_details, + embedder, + } = self; + + if total_received == 0 { + None + } else { + // we get all the values in a sorted manner + let time_spent = time_spent.into_sorted_vec(); + // the index of the 99th percentage of value + let percentile_99th = time_spent.len() * 99 / 100; + // We are only interested by the slowest value of the 99th fastest results + let time_spent = time_spent.get(percentile_99th); + + let properties = json!({ + "user-agent": user_agents, + "requests": { + "99th_response_time": time_spent.map(|t| format!("{:.2}", t)), + "total_succeeded": total_succeeded, + "total_failed": total_received.saturating_sub(total_succeeded), // just to be sure we never panics + "total_received": total_received, + }, + "filter": { + "with_geoRadius": filter_with_geo_radius, + "with_geoBoundingBox": filter_with_geo_bounding_box, + "avg_criteria_number": format!("{:.2}", filter_sum_of_criteria_terms as f64 / filter_total_number_of_criteria as f64), + "most_used_syntax": used_syntax.iter().max_by_key(|(_, v)| *v).map(|(k, _)| json!(k)).unwrap_or_else(|| json!(null)), + }, + "hybrid": { + "embedder": embedder, + }, + "pagination": { + "max_limit": max_limit, + "max_offset": max_offset, + }, + "formatting": { + "max_attributes_to_retrieve": max_attributes_to_retrieve, + }, + "scoring": { + "show_ranking_score": show_ranking_score, + "show_ranking_score_details": show_ranking_score_details, + }, + }); + + Some(Track { + timestamp, + user: user.clone(), + event: event_name.to_string(), + properties, + ..Default::default() + }) + } + } +} diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs index 4d6950988..3f05fa846 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -69,7 +69,7 @@ pub async fn search( // Tenant token search_rules. if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { - add_search_rules(&mut search_query, search_rules); + add_search_rules(&mut search_query.filter, search_rules); } let index = index_scheduler.index(&index_uid)?; diff --git a/meilisearch/src/routes/indexes/mod.rs b/meilisearch/src/routes/indexes/mod.rs index 651977723..35b747ccf 100644 --- a/meilisearch/src/routes/indexes/mod.rs +++ b/meilisearch/src/routes/indexes/mod.rs @@ -29,6 +29,7 @@ pub mod documents; pub mod facet_search; pub mod search; pub mod settings; +pub mod similar; pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service( @@ -48,6 +49,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) { .service(web::scope("/documents").configure(documents::configure)) .service(web::scope("/search").configure(search::configure)) .service(web::scope("/facet-search").configure(facet_search::configure)) + .service(web::scope("/similar").configure(similar::configure)) .service(web::scope("/settings").configure(settings::configure)), ); } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 5581e6a68..8628da6d9 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -196,7 +196,7 @@ pub async fn search_with_url_query( // Tenant token search_rules. if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { - add_search_rules(&mut query, search_rules); + add_search_rules(&mut query.filter, search_rules); } let mut aggregate = SearchAggregator::from_query(&query, &req); @@ -235,7 +235,7 @@ pub async fn search_with_post( // Tenant token search_rules. if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { - add_search_rules(&mut query, search_rules); + add_search_rules(&mut query.filter, search_rules); } let mut aggregate = SearchAggregator::from_query(&query, &req); diff --git a/meilisearch/src/routes/indexes/similar.rs b/meilisearch/src/routes/indexes/similar.rs new file mode 100644 index 000000000..da73dd63b --- /dev/null +++ b/meilisearch/src/routes/indexes/similar.rs @@ -0,0 +1,171 @@ +use actix_web::web::{self, Data}; +use actix_web::{HttpRequest, HttpResponse}; +use deserr::actix_web::{AwebJson, AwebQueryParameter}; +use index_scheduler::IndexScheduler; +use meilisearch_types::deserr::query_params::Param; +use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; +use meilisearch_types::error::deserr_codes::{ + InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId, + InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarShowRankingScore, + InvalidSimilarShowRankingScoreDetails, +}; +use meilisearch_types::error::{ErrorCode as _, ResponseError}; +use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::keys::actions; +use meilisearch_types::serde_cs::vec::CS; +use serde_json::Value; +use tracing::debug; + +use super::ActionPolicy; +use crate::analytics::{Analytics, SimilarAggregator}; +use crate::extractors::authentication::GuardedData; +use crate::extractors::sequential_extractor::SeqHandler; +use crate::search::{ + add_search_rules, perform_similar, SearchKind, SimilarQuery, SimilarResult, + DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, +}; + +pub fn configure(cfg: &mut web::ServiceConfig) { + cfg.service( + web::resource("") + .route(web::get().to(SeqHandler(similar_get))) + .route(web::post().to(SeqHandler(similar_post))), + ); +} + +pub async fn similar_get( + index_scheduler: GuardedData, Data>, + index_uid: web::Path, + params: AwebQueryParameter, + req: HttpRequest, + analytics: web::Data, +) -> Result { + let index_uid = IndexUid::try_from(index_uid.into_inner())?; + + let query = params.0.try_into().map_err(|code: InvalidSimilarId| { + ResponseError::from_msg(code.to_string(), code.error_code()) + })?; + + let mut aggregate = SimilarAggregator::from_query(&query, &req); + + debug!(parameters = ?query, "Similar get"); + + let similar = similar(index_scheduler, index_uid, query).await; + + if let Ok(similar) = &similar { + aggregate.succeed(similar); + } + analytics.get_similar(aggregate); + + let similar = similar?; + + debug!(returns = ?similar, "Similar get"); + Ok(HttpResponse::Ok().json(similar)) +} + +pub async fn similar_post( + index_scheduler: GuardedData, Data>, + index_uid: web::Path, + params: AwebJson, + req: HttpRequest, + analytics: web::Data, +) -> Result { + let index_uid = IndexUid::try_from(index_uid.into_inner())?; + + let query = params.into_inner(); + debug!(parameters = ?query, "Similar post"); + + let mut aggregate = SimilarAggregator::from_query(&query, &req); + + let similar = similar(index_scheduler, index_uid, query).await; + + if let Ok(similar) = &similar { + aggregate.succeed(similar); + } + analytics.post_similar(aggregate); + + let similar = similar?; + + debug!(returns = ?similar, "Similar post"); + Ok(HttpResponse::Ok().json(similar)) +} + +async fn similar( + index_scheduler: GuardedData, Data>, + index_uid: IndexUid, + mut query: SimilarQuery, +) -> Result { + let features = index_scheduler.features(); + + features.check_vector("Using the similar API")?; + + // Tenant token search_rules. + if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { + add_search_rules(&mut query.filter, search_rules); + } + + let index = index_scheduler.index(&index_uid)?; + + let (embedder_name, embedder) = + SearchKind::embedder(&index_scheduler, &index, query.embedder.as_deref(), None)?; + + tokio::task::spawn_blocking(move || perform_similar(&index, query, embedder_name, embedder)) + .await? +} + +#[derive(Debug, deserr::Deserr)] +#[deserr(error = DeserrQueryParamError, rename_all = camelCase, deny_unknown_fields)] +pub struct SimilarQueryGet { + #[deserr(error = DeserrQueryParamError)] + id: Param, + #[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError)] + offset: Param, + #[deserr(default = Param(DEFAULT_SEARCH_LIMIT()), error = DeserrQueryParamError)] + limit: Param, + #[deserr(default, error = DeserrQueryParamError)] + attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrQueryParamError)] + filter: Option, + #[deserr(default, error = DeserrQueryParamError)] + show_ranking_score: Param, + #[deserr(default, error = DeserrQueryParamError)] + show_ranking_score_details: Param, + #[deserr(default, error = DeserrQueryParamError)] + pub embedder: Option, +} + +impl TryFrom for SimilarQuery { + type Error = InvalidSimilarId; + + fn try_from( + SimilarQueryGet { + id, + offset, + limit, + attributes_to_retrieve, + filter, + show_ranking_score, + show_ranking_score_details, + embedder, + }: SimilarQueryGet, + ) -> Result { + let filter = match filter { + Some(f) => match serde_json::from_str(&f) { + Ok(v) => Some(v), + _ => Some(Value::String(f)), + }, + None => None, + }; + + Ok(SimilarQuery { + id: id.0.try_into()?, + offset: offset.0, + limit: limit.0, + filter, + embedder, + attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()), + show_ranking_score: show_ranking_score.0, + show_ranking_score_details: show_ranking_score_details.0, + }) + } +} diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index 7b7cbd265..a83dc4bc0 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -67,7 +67,7 @@ pub async fn multi_search_with_post( // Apply search rules from tenant token if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { - add_search_rules(&mut query, search_rules); + add_search_rules(&mut query.filter, search_rules); } let index = index_scheduler diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 34ebe463d..c6c4e88ca 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -11,7 +11,7 @@ use indexmap::IndexMap; use meilisearch_auth::IndexSearchRules; use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; -use meilisearch_types::error::ResponseError; +use meilisearch_types::error::{Code, ResponseError}; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; @@ -231,7 +231,7 @@ impl SearchKind { Ok(Self::Hybrid { embedder_name, embedder, semantic_ratio }) } - fn embedder( + pub(crate) fn embedder( index_scheduler: &index_scheduler::IndexScheduler, index: &Index, embedder_name: Option<&str>, @@ -417,6 +417,59 @@ impl SearchQueryWithIndex { } } +#[derive(Debug, Clone, PartialEq, Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct SimilarQuery { + #[deserr(error = DeserrJsonError)] + pub id: ExternalDocumentId, + #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] + pub offset: usize, + #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] + pub limit: usize, + #[deserr(default, error = DeserrJsonError)] + pub filter: Option, + #[deserr(default, error = DeserrJsonError, default)] + pub embedder: Option, + #[deserr(default, error = DeserrJsonError)] + pub attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrJsonError, default)] + pub show_ranking_score: bool, + #[deserr(default, error = DeserrJsonError, default)] + pub show_ranking_score_details: bool, +} + +#[derive(Debug, Clone, PartialEq, Deserr)] +#[deserr(try_from(Value) = TryFrom::try_from -> InvalidSimilarId)] +pub struct ExternalDocumentId(String); + +impl AsRef for ExternalDocumentId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl ExternalDocumentId { + pub fn into_inner(self) -> String { + self.0 + } +} + +impl TryFrom for ExternalDocumentId { + type Error = InvalidSimilarId; + + fn try_from(value: String) -> Result { + serde_json::Value::String(value).try_into() + } +} + +impl TryFrom for ExternalDocumentId { + type Error = InvalidSimilarId; + + fn try_from(value: Value) -> Result { + Ok(Self(milli::documents::validate_document_id_value(value).map_err(|_| InvalidSimilarId)?)) + } +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr)] #[deserr(rename_all = camelCase)] pub enum MatchingStrategy { @@ -538,6 +591,16 @@ impl fmt::Debug for SearchResult { } } +#[derive(Serialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct SimilarResult { + pub hits: Vec, + pub id: String, + pub processing_time_ms: u128, + #[serde(flatten)] + pub hits_info: HitsInfo, +} + #[derive(Serialize, Debug, Clone, PartialEq)] #[serde(rename_all = "camelCase")] pub struct SearchResultWithIndex { @@ -570,8 +633,8 @@ pub struct FacetSearchResult { } /// Incorporate search rules in search query -pub fn add_search_rules(query: &mut SearchQuery, rules: IndexSearchRules) { - query.filter = match (query.filter.take(), rules.filter) { +pub fn add_search_rules(filter: &mut Option, rules: IndexSearchRules) { + *filter = match (filter.take(), rules.filter) { (None, rules_filter) => rules_filter, (filter, None) => filter, (Some(filter), Some(rules_filter)) => { @@ -719,131 +782,52 @@ pub fn perform_search( SearchKind::Hybrid { semantic_ratio, .. } => search.execute_hybrid(*semantic_ratio)?, }; - let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); + let SearchQuery { + q, + vector: _, + hybrid: _, + // already computed from prepare_search + offset: _, + limit, + page, + hits_per_page, + attributes_to_retrieve, + attributes_to_crop, + crop_length, + attributes_to_highlight, + show_matches_position, + show_ranking_score, + show_ranking_score_details, + filter: _, + sort, + facets, + highlight_pre_tag, + highlight_post_tag, + crop_marker, + matching_strategy: _, + attributes_to_search_on: _, + } = query; - let displayed_ids = index - .displayed_fields_ids(&rtxn)? - .map(|fields| fields.into_iter().collect::>()) - .unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); - - let fids = |attrs: &BTreeSet| { - let mut ids = BTreeSet::new(); - for attr in attrs { - if attr == "*" { - ids.clone_from(&displayed_ids); - break; - } - - if let Some(id) = fields_ids_map.id(attr) { - ids.insert(id); - } - } - ids + let format = AttributesFormat { + attributes_to_retrieve, + attributes_to_highlight, + attributes_to_crop, + crop_length, + crop_marker, + highlight_pre_tag, + highlight_post_tag, + show_matches_position, + sort, + show_ranking_score, + show_ranking_score_details, }; - // The attributes to retrieve are the ones explicitly marked as to retrieve (all by default), - // but these attributes must be also be present - // - in the fields_ids_map - // - in the displayed attributes - let to_retrieve_ids: BTreeSet<_> = query - .attributes_to_retrieve - .as_ref() - .map(fids) - .unwrap_or_else(|| displayed_ids.clone()) - .intersection(&displayed_ids) - .cloned() - .collect(); - - let attr_to_highlight = query.attributes_to_highlight.unwrap_or_default(); - - let attr_to_crop = query.attributes_to_crop.unwrap_or_default(); - - // Attributes in `formatted_options` correspond to the attributes that will be in `_formatted` - // These attributes are: - // - the attributes asked to be highlighted or cropped (with `attributesToCrop` or `attributesToHighlight`) - // - the attributes asked to be retrieved: these attributes will not be highlighted/cropped - // But these attributes must be also present in displayed attributes - let formatted_options = compute_formatted_options( - &attr_to_highlight, - &attr_to_crop, - query.crop_length, - &to_retrieve_ids, - &fields_ids_map, - &displayed_ids, - ); - - let mut tokenizer_builder = TokenizerBuilder::default(); - tokenizer_builder.create_char_map(true); - - let script_lang_map = index.script_language(&rtxn)?; - if !script_lang_map.is_empty() { - tokenizer_builder.allow_list(&script_lang_map); - } - - let separators = index.allowed_separators(&rtxn)?; - let separators: Option> = - separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); - if let Some(ref separators) = separators { - tokenizer_builder.separators(separators); - } - - let dictionary = index.dictionary(&rtxn)?; - let dictionary: Option> = - dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); - if let Some(ref dictionary) = dictionary { - tokenizer_builder.words_dict(dictionary); - } - - let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build()); - formatter_builder.crop_marker(query.crop_marker); - formatter_builder.highlight_prefix(query.highlight_pre_tag); - formatter_builder.highlight_suffix(query.highlight_post_tag); - - let mut documents = Vec::new(); - let documents_iter = index.documents(&rtxn, documents_ids)?; - - for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { - // First generate a document with all the displayed fields - let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; - - // select the attributes to retrieve - let attributes_to_retrieve = to_retrieve_ids - .iter() - .map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); - let mut document = - permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); - - let (matches_position, formatted) = format_fields( - &displayed_document, - &fields_ids_map, - &formatter_builder, - &formatted_options, - query.show_matches_position, - &displayed_ids, - )?; - - if let Some(sort) = query.sort.as_ref() { - insert_geo_distance(sort, &mut document); - } - - let ranking_score = - query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); - let ranking_score_details = - query.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); - - let hit = SearchHit { - document, - formatted, - matches_position, - ranking_score_details, - ranking_score, - }; - documents.push(hit); - } + let documents = + make_hits(index, &rtxn, format, matching_words, documents_ids, document_scores)?; let number_of_hits = min(candidates.len() as usize, max_total_hits); let hits_info = if is_finite_pagination { - let hits_per_page = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); + let hits_per_page = hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); // If hit_per_page is 0, then pages can't be computed and so we respond 0. let total_pages = (number_of_hits + hits_per_page.saturating_sub(1)) .checked_div(hits_per_page) @@ -851,15 +835,15 @@ pub fn perform_search( HitsInfo::Pagination { hits_per_page, - page: query.page.unwrap_or(1), + page: page.unwrap_or(1), total_pages, total_hits: number_of_hits, } } else { - HitsInfo::OffsetLimit { limit: query.limit, offset, estimated_total_hits: number_of_hits } + HitsInfo::OffsetLimit { limit, offset, estimated_total_hits: number_of_hits } }; - let (facet_distribution, facet_stats) = match query.facets { + let (facet_distribution, facet_stats) = match facets { Some(ref fields) => { let mut facet_distribution = index.facets_distribution(&rtxn); @@ -896,7 +880,7 @@ pub fn perform_search( let result = SearchResult { hits: documents, hits_info, - query: query.q.unwrap_or_default(), + query: q.unwrap_or_default(), processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, @@ -907,6 +891,130 @@ pub fn perform_search( Ok(result) } +struct AttributesFormat { + attributes_to_retrieve: Option>, + attributes_to_highlight: Option>, + attributes_to_crop: Option>, + crop_length: usize, + crop_marker: String, + highlight_pre_tag: String, + highlight_post_tag: String, + show_matches_position: bool, + sort: Option>, + show_ranking_score: bool, + show_ranking_score_details: bool, +} + +fn make_hits( + index: &Index, + rtxn: &RoTxn<'_>, + format: AttributesFormat, + matching_words: milli::MatchingWords, + documents_ids: Vec, + document_scores: Vec>, +) -> Result, MeilisearchHttpError> { + let fields_ids_map = index.fields_ids_map(rtxn).unwrap(); + let displayed_ids = index + .displayed_fields_ids(rtxn)? + .map(|fields| fields.into_iter().collect::>()) + .unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); + let fids = |attrs: &BTreeSet| { + let mut ids = BTreeSet::new(); + for attr in attrs { + if attr == "*" { + ids.clone_from(&displayed_ids); + break; + } + + if let Some(id) = fields_ids_map.id(attr) { + ids.insert(id); + } + } + ids + }; + let to_retrieve_ids: BTreeSet<_> = format + .attributes_to_retrieve + .as_ref() + .map(fids) + .unwrap_or_else(|| displayed_ids.clone()) + .intersection(&displayed_ids) + .cloned() + .collect(); + let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); + let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); + let formatted_options = compute_formatted_options( + &attr_to_highlight, + &attr_to_crop, + format.crop_length, + &to_retrieve_ids, + &fields_ids_map, + &displayed_ids, + ); + let mut tokenizer_builder = TokenizerBuilder::default(); + tokenizer_builder.create_char_map(true); + let script_lang_map = index.script_language(rtxn)?; + if !script_lang_map.is_empty() { + tokenizer_builder.allow_list(&script_lang_map); + } + let separators = index.allowed_separators(rtxn)?; + let separators: Option> = + separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); + if let Some(ref separators) = separators { + tokenizer_builder.separators(separators); + } + let dictionary = index.dictionary(rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); + if let Some(ref dictionary) = dictionary { + tokenizer_builder.words_dict(dictionary); + } + let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build()); + formatter_builder.crop_marker(format.crop_marker); + formatter_builder.highlight_prefix(format.highlight_pre_tag); + formatter_builder.highlight_suffix(format.highlight_post_tag); + let mut documents = Vec::new(); + let documents_iter = index.documents(rtxn, documents_ids)?; + for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { + // First generate a document with all the displayed fields + let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; + + // select the attributes to retrieve + let attributes_to_retrieve = to_retrieve_ids + .iter() + .map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); + let mut document = + permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); + + let (matches_position, formatted) = format_fields( + &displayed_document, + &fields_ids_map, + &formatter_builder, + &formatted_options, + format.show_matches_position, + &displayed_ids, + )?; + + if let Some(sort) = format.sort.as_ref() { + insert_geo_distance(sort, &mut document); + } + + let ranking_score = + format.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); + let ranking_score_details = + format.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); + + let hit = SearchHit { + document, + formatted, + matches_position, + ranking_score_details, + ranking_score, + }; + documents.push(hit); + } + Ok(documents) +} + pub fn perform_facet_search( index: &Index, search_query: SearchQuery, @@ -941,6 +1049,95 @@ pub fn perform_facet_search( }) } +pub fn perform_similar( + index: &Index, + query: SimilarQuery, + embedder_name: String, + embedder: Arc, +) -> Result { + let before_search = Instant::now(); + let rtxn = index.read_txn()?; + + let SimilarQuery { + id, + offset, + limit, + filter: _, + embedder: _, + attributes_to_retrieve, + show_ranking_score, + show_ranking_score_details, + } = query; + + // using let-else rather than `?` so that the borrow checker identifies we're always returning here, + // preventing a use-after-move + let Some(internal_id) = index.external_documents_ids().get(&rtxn, &id)? else { + return Err(ResponseError::from_msg( + MeilisearchHttpError::DocumentNotFound(id.into_inner()).to_string(), + Code::NotFoundSimilarId, + )); + }; + + let mut similar = + milli::Similar::new(internal_id, offset, limit, index, &rtxn, embedder_name, embedder); + + if let Some(ref filter) = query.filter { + if let Some(facets) = parse_filter(filter) + // inject InvalidSimilarFilter code + .map_err(|e| ResponseError::from_msg(e.to_string(), Code::InvalidSimilarFilter))? + { + similar.filter(facets); + } + } + + let milli::SearchResult { + documents_ids, + matching_words: _, + candidates, + document_scores, + degraded: _, + used_negative_operator: _, + } = similar.execute().map_err(|err| match err { + milli::Error::UserError(milli::UserError::InvalidFilter(_)) => { + ResponseError::from_msg(err.to_string(), Code::InvalidSimilarFilter) + } + err => err.into(), + })?; + + let format = AttributesFormat { + attributes_to_retrieve, + attributes_to_highlight: None, + attributes_to_crop: None, + crop_length: DEFAULT_CROP_LENGTH(), + crop_marker: DEFAULT_CROP_MARKER(), + highlight_pre_tag: DEFAULT_HIGHLIGHT_PRE_TAG(), + highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), + show_matches_position: false, + sort: None, + show_ranking_score, + show_ranking_score_details, + }; + + let hits = make_hits(index, &rtxn, format, Default::default(), documents_ids, document_scores)?; + + let max_total_hits = index + .pagination_max_total_hits(&rtxn) + .map_err(milli::Error::from)? + .map(|x| x as usize) + .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); + + let number_of_hits = min(candidates.len() as usize, max_total_hits); + let hits_info = HitsInfo::OffsetLimit { limit, offset, estimated_total_hits: number_of_hits }; + + let result = SimilarResult { + hits, + hits_info, + id: id.into_inner(), + processing_time_ms: before_search.elapsed().as_millis(), + }; + Ok(result) +} + fn insert_geo_distance(sorts: &[String], document: &mut Document) { lazy_static::lazy_static! { static ref GEO_REGEX: Regex = diff --git a/meilisearch/tests/common/index.rs b/meilisearch/tests/common/index.rs index 9ed6a6077..3ac33b4e9 100644 --- a/meilisearch/tests/common/index.rs +++ b/meilisearch/tests/common/index.rs @@ -380,6 +380,43 @@ impl Index<'_> { self.service.get(url).await } + /// Performs both GET and POST similar queries + pub async fn similar( + &self, + query: Value, + test: impl Fn(Value, StatusCode) + UnwindSafe + Clone, + ) { + let post = self.similar_post(query.clone()).await; + + let query = yaup::to_string(&query).unwrap(); + let get = self.similar_get(&query).await; + + insta::allow_duplicates! { + let (response, code) = post; + let t = test.clone(); + if let Err(e) = catch_unwind(move || t(response, code)) { + eprintln!("Error with post search"); + resume_unwind(e); + } + + let (response, code) = get; + if let Err(e) = catch_unwind(move || test(response, code)) { + eprintln!("Error with get search"); + resume_unwind(e); + } + } + } + + pub async fn similar_post(&self, query: Value) -> (Value, StatusCode) { + let url = format!("/indexes/{}/similar", urlencode(self.uid.as_ref())); + self.service.post_encoded(url, query, self.encoder).await + } + + pub async fn similar_get(&self, query: &str) -> (Value, StatusCode) { + let url = format!("/indexes/{}/similar?{}", urlencode(self.uid.as_ref()), query); + self.service.get(url).await + } + pub async fn facet_search(&self, query: Value) -> (Value, StatusCode) { let url = format!("/indexes/{}/facet-search", urlencode(self.uid.as_ref())); self.service.post_encoded(url, query, self.encoder).await diff --git a/meilisearch/tests/integration.rs b/meilisearch/tests/integration.rs index 943af802a..bb77ecc63 100644 --- a/meilisearch/tests/integration.rs +++ b/meilisearch/tests/integration.rs @@ -8,6 +8,7 @@ mod index; mod logs; mod search; mod settings; +mod similar; mod snapshot; mod stats; mod swap_indexes; diff --git a/meilisearch/tests/similar/errors.rs b/meilisearch/tests/similar/errors.rs new file mode 100644 index 000000000..64386a7bf --- /dev/null +++ b/meilisearch/tests/similar/errors.rs @@ -0,0 +1,696 @@ +use meili_snap::*; + +use super::DOCUMENTS; +use crate::common::Server; +use crate::json; + +#[actix_rt::test] +async fn similar_unexisting_index() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let expected_response = json!({ + "message": "Index `test` not found.", + "code": "index_not_found", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#index_not_found" + }); + + index + .similar(json!({"id": 287947}), |response, code| { + assert_eq!(code, 404); + assert_eq!(response, expected_response); + }) + .await; +} + +#[actix_rt::test] +async fn similar_unexisting_parameter() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + index + .similar(json!({"id": 287947, "marin": "hello"}), |response, code| { + assert_eq!(code, 400, "{}", response); + assert_eq!(response["code"], "bad_request"); + }) + .await; +} + +#[actix_rt::test] +async fn similar_feature_not_enabled() { + let server = Server::new().await; + let index = server.index("test"); + + let (response, code) = index.similar_post(json!({"id": 287947})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Using the similar API requires enabling the `vector store` experimental feature. See https://github.com/meilisearch/product/discussions/677", + "code": "feature_not_enabled", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#feature_not_enabled" + } + "###); +} + +#[actix_rt::test] +async fn similar_bad_id() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let (response, code) = index.similar_post(json!({"id": ["doggo"]})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value at `.id`: the value of `id` is invalid. A document identifier can be of type integer or string, only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and underscores (_).", + "code": "invalid_similar_id", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_id" + } + "###); +} + +#[actix_rt::test] +async fn similar_invalid_id() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let (response, code) = index.similar_post(json!({"id": "http://invalid-docid/"})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value at `.id`: the value of `id` is invalid. A document identifier can be of type integer or string, only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and underscores (_).", + "code": "invalid_similar_id", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_id" + } + "###); +} + +#[actix_rt::test] +async fn similar_not_found_id() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let (response, code) = index.similar_post(json!({"id": "definitely-doesnt-exist"})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Document `definitely-doesnt-exist` not found.", + "code": "not_found_similar_id", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#not_found_similar_id" + } + "###); +} + +#[actix_rt::test] +async fn similar_bad_offset() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let (response, code) = index.similar_post(json!({"id": 287947, "offset": "doggo"})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value type at `.offset`: expected a positive integer, but found a string: `\"doggo\"`", + "code": "invalid_similar_offset", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_offset" + } + "###); + + let (response, code) = index.similar_get("id=287947&offset=doggo").await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value in parameter `offset`: could not parse `doggo` as a positive integer", + "code": "invalid_similar_offset", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_offset" + } + "###); +} + +#[actix_rt::test] +async fn similar_bad_limit() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let (response, code) = index.similar_post(json!({"id": 287947, "limit": "doggo"})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value type at `.limit`: expected a positive integer, but found a string: `\"doggo\"`", + "code": "invalid_similar_limit", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_limit" + } + "###); + + let (response, code) = index.similar_get("id=287946&limit=doggo").await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value in parameter `limit`: could not parse `doggo` as a positive integer", + "code": "invalid_similar_limit", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_limit" + } + "###); +} + +#[actix_rt::test] +async fn similar_bad_filter() { + // Since a filter is deserialized as a json Value it will never fail to deserialize. + // Thus the error message is not generated by deserr but written by us. + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + snapshot!(code, @"202 Accepted"); + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let (response, code) = index.similar_post(json!({ "id": 287947, "filter": true })).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid syntax for the filter parameter: `expected String, Array, found: true`.", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + } + "###); + // Can't make the `filter` fail with a get search since it'll accept anything as a strings. +} + +#[actix_rt::test] +async fn filter_invalid_syntax_object() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "Was expecting an operation `=`, `!=`, `>=`, `>`, `<=`, `<`, `IN`, `NOT IN`, `TO`, `EXISTS`, `NOT EXISTS`, `IS NULL`, `IS NOT NULL`, `IS EMPTY`, `IS NOT EMPTY`, `_geoRadius`, or `_geoBoundingBox` at `title & Glass`.\n1:14 title & Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": "title & Glass"}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_invalid_syntax_array() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "Was expecting an operation `=`, `!=`, `>=`, `>`, `<=`, `<`, `IN`, `NOT IN`, `TO`, `EXISTS`, `NOT EXISTS`, `IS NULL`, `IS NOT NULL`, `IS EMPTY`, `IS NOT EMPTY`, `_geoRadius`, or `_geoBoundingBox` at `title & Glass`.\n1:14 title & Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": ["title & Glass"]}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_invalid_syntax_string() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "Found unexpected characters at the end of the filter: `XOR title = Glass`. You probably forgot an `OR` or an `AND` rule.\n15:32 title = Glass XOR title = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar( + json!({"id": 287947, "filter": "title = Glass XOR title = Glass"}), + |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }, + ) + .await; +} + +#[actix_rt::test] +async fn filter_invalid_attribute_array() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "Attribute `many` is not filterable. Available filterable attributes are: `title`.\n1:5 many = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": ["many = Glass"]}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_invalid_attribute_string() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "Attribute `many` is not filterable. Available filterable attributes are: `title`.\n1:5 many = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": "many = Glass"}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_reserved_geo_attribute_array() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "`_geo` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` or `_geoBoundingBox([latitude, longitude], [latitude, longitude])` built-in rules to filter on `_geo` coordinates.\n1:13 _geo = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": ["_geo = Glass"]}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_reserved_geo_attribute_string() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "`_geo` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` or `_geoBoundingBox([latitude, longitude], [latitude, longitude])` built-in rules to filter on `_geo` coordinates.\n1:13 _geo = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": "_geo = Glass"}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_reserved_attribute_array() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "`_geoDistance` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` or `_geoBoundingBox([latitude, longitude], [latitude, longitude])` built-in rules to filter on `_geo` coordinates.\n1:21 _geoDistance = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": ["_geoDistance = Glass"]}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_reserved_attribute_string() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "`_geoDistance` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` or `_geoBoundingBox([latitude, longitude], [latitude, longitude])` built-in rules to filter on `_geo` coordinates.\n1:21 _geoDistance = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": "_geoDistance = Glass"}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_reserved_geo_point_array() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "`_geoPoint` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` or `_geoBoundingBox([latitude, longitude], [latitude, longitude])` built-in rules to filter on `_geo` coordinates.\n1:18 _geoPoint = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": ["_geoPoint = Glass"]}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} + +#[actix_rt::test] +async fn filter_reserved_geo_point_string() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + let expected_response = json!({ + "message": "`_geoPoint` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` or `_geoBoundingBox([latitude, longitude], [latitude, longitude])` built-in rules to filter on `_geo` coordinates.\n1:18 _geoPoint = Glass", + "code": "invalid_similar_filter", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_filter" + }); + index + .similar(json!({"id": 287947, "filter": "_geoPoint = Glass"}), |response, code| { + assert_eq!(response, expected_response); + assert_eq!(code, 400); + }) + .await; +} diff --git a/meilisearch/tests/similar/mod.rs b/meilisearch/tests/similar/mod.rs new file mode 100644 index 000000000..ee78917cb --- /dev/null +++ b/meilisearch/tests/similar/mod.rs @@ -0,0 +1,373 @@ +mod errors; + +use meili_snap::{json_string, snapshot}; +use once_cell::sync::Lazy; + +use crate::common::{Server, Value}; +use crate::json; + +static DOCUMENTS: Lazy = Lazy::new(|| { + json!([ + { + "title": "Shazam!", + "release_year": 2019, + "id": "287947", + // Three semantic properties: + // 1. magic, anything that reminds you of magic + // 2. authority, anything that inspires command + // 3. horror, anything that inspires fear or dread + "_vectors": { "manual": [0.8, 0.4, -0.5]}, + }, + { + "title": "Captain Marvel", + "release_year": 2019, + "id": "299537", + "_vectors": { "manual": [0.6, 0.8, -0.2] }, + }, + { + "title": "Escape Room", + "release_year": 2019, + "id": "522681", + "_vectors": { "manual": [0.1, 0.6, 0.8] }, + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "release_year": 2019, + "id": "166428", + "_vectors": { "manual": [0.7, 0.7, -0.4] }, + }, + { + "title": "All Quiet on the Western Front", + "release_year": 1930, + "id": "143", + "_vectors": { "manual": [-0.5, 0.3, 0.85] }, + } + ]) +}); + +#[actix_rt::test] +async fn basic() { + let server = Server::new().await; + let index = server.index("test"); + let (value, code) = server.set_features(json!({"vectorStore": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(value, @r###" + { + "vectorStore": true, + "metrics": false, + "logsRoute": false + } + "###); + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + index + .similar(json!({"id": 143}), |response, code| { + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "Escape Room", + "release_year": 2019, + "id": "522681", + "_vectors": { + "manual": [ + 0.1, + 0.6, + 0.8 + ] + } + }, + { + "title": "Captain Marvel", + "release_year": 2019, + "id": "299537", + "_vectors": { + "manual": [ + 0.6, + 0.8, + -0.2 + ] + } + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "release_year": 2019, + "id": "166428", + "_vectors": { + "manual": [ + 0.7, + 0.7, + -0.4 + ] + } + }, + { + "title": "Shazam!", + "release_year": 2019, + "id": "287947", + "_vectors": { + "manual": [ + 0.8, + 0.4, + -0.5 + ] + } + } + ] + "###); + }) + .await; + + index + .similar(json!({"id": "299537"}), |response, code| { + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "How to Train Your Dragon: The Hidden World", + "release_year": 2019, + "id": "166428", + "_vectors": { + "manual": [ + 0.7, + 0.7, + -0.4 + ] + } + }, + { + "title": "Shazam!", + "release_year": 2019, + "id": "287947", + "_vectors": { + "manual": [ + 0.8, + 0.4, + -0.5 + ] + } + }, + { + "title": "Escape Room", + "release_year": 2019, + "id": "522681", + "_vectors": { + "manual": [ + 0.1, + 0.6, + 0.8 + ] + } + }, + { + "title": "All Quiet on the Western Front", + "release_year": 1930, + "id": "143", + "_vectors": { + "manual": [ + -0.5, + 0.3, + 0.85 + ] + } + } + ] + "###); + }) + .await; +} + +#[actix_rt::test] +async fn filter() { + let server = Server::new().await; + let index = server.index("test"); + let (value, code) = server.set_features(json!({"vectorStore": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(value, @r###" + { + "vectorStore": true, + "metrics": false, + "logsRoute": false + } + "###); + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title", "release_year"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + index + .similar(json!({"id": 522681, "filter": "release_year = 2019"}), |response, code| { + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "Captain Marvel", + "release_year": 2019, + "id": "299537", + "_vectors": { + "manual": [ + 0.6, + 0.8, + -0.2 + ] + } + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "release_year": 2019, + "id": "166428", + "_vectors": { + "manual": [ + 0.7, + 0.7, + -0.4 + ] + } + }, + { + "title": "Shazam!", + "release_year": 2019, + "id": "287947", + "_vectors": { + "manual": [ + 0.8, + 0.4, + -0.5 + ] + } + } + ] + "###); + }) + .await; + + index + .similar(json!({"id": 522681, "filter": "release_year < 2000"}), |response, code| { + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "All Quiet on the Western Front", + "release_year": 1930, + "id": "143", + "_vectors": { + "manual": [ + -0.5, + 0.3, + 0.85 + ] + } + } + ] + "###); + }) + .await; +} + +#[actix_rt::test] +async fn limit_and_offset() { + let server = Server::new().await; + let index = server.index("test"); + let (value, code) = server.set_features(json!({"vectorStore": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(value, @r###" + { + "vectorStore": true, + "metrics": false, + "logsRoute": false + } + "###); + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + index + .similar(json!({"id": 143, "limit": 1}), |response, code| { + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "Escape Room", + "release_year": 2019, + "id": "522681", + "_vectors": { + "manual": [ + 0.1, + 0.6, + 0.8 + ] + } + } + ] + "###); + }) + .await; + + index + .similar(json!({"id": 143, "limit": 1, "offset": 1}), |response, code| { + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "Captain Marvel", + "release_year": 2019, + "id": "299537", + "_vectors": { + "manual": [ + 0.6, + 0.8, + -0.2 + ] + } + } + ] + "###); + }) + .await; +} diff --git a/meilisearch/tests/snapshot/mod.rs b/meilisearch/tests/snapshot/mod.rs index 67e80f45b..0008993fe 100644 --- a/meilisearch/tests/snapshot/mod.rs +++ b/meilisearch/tests/snapshot/mod.rs @@ -31,6 +31,7 @@ macro_rules! verify_snapshot { } #[actix_rt::test] +#[cfg_attr(target_os = "windows", ignore)] async fn perform_snapshot() { let temp = tempfile::tempdir().unwrap(); let snapshot_dir = tempfile::tempdir().unwrap(); diff --git a/milli/examples/search.rs b/milli/examples/search.rs index 3d10ec599..2779f5b15 100644 --- a/milli/examples/search.rs +++ b/milli/examples/search.rs @@ -49,7 +49,7 @@ fn main() -> Result<(), Box> { let start = Instant::now(); let mut ctx = SearchContext::new(&index, &txn)?; - let universe = filtered_universe(&ctx, &None)?; + let universe = filtered_universe(ctx.index, ctx.txn, &None)?; let docs = execute_search( &mut ctx, diff --git a/milli/src/documents/mod.rs b/milli/src/documents/mod.rs index a874ac17e..76be61275 100644 --- a/milli/src/documents/mod.rs +++ b/milli/src/documents/mod.rs @@ -12,7 +12,10 @@ use bimap::BiHashMap; pub use builder::DocumentsBatchBuilder; pub use enriched::{EnrichedDocument, EnrichedDocumentsBatchCursor, EnrichedDocumentsBatchReader}; use obkv::KvReader; -pub use primary_key::{DocumentIdExtractionError, FieldIdMapper, PrimaryKey, DEFAULT_PRIMARY_KEY}; +pub use primary_key::{ + validate_document_id_value, DocumentIdExtractionError, FieldIdMapper, PrimaryKey, + DEFAULT_PRIMARY_KEY, +}; pub use reader::{DocumentsBatchCursor, DocumentsBatchCursorError, DocumentsBatchReader}; use serde::{Deserialize, Serialize}; diff --git a/milli/src/documents/primary_key.rs b/milli/src/documents/primary_key.rs index 16a95c21f..29f95beaf 100644 --- a/milli/src/documents/primary_key.rs +++ b/milli/src/documents/primary_key.rs @@ -60,7 +60,7 @@ impl<'a> PrimaryKey<'a> { Some(document_id_bytes) => { let document_id = serde_json::from_slice(document_id_bytes) .map_err(InternalError::SerdeJson)?; - match validate_document_id_value(document_id)? { + match validate_document_id_value(document_id) { Ok(document_id) => Ok(Ok(document_id)), Err(user_error) => { Ok(Err(DocumentIdExtractionError::InvalidDocumentId(user_error))) @@ -88,7 +88,7 @@ impl<'a> PrimaryKey<'a> { } match matching_documents_ids.pop() { - Some(document_id) => match validate_document_id_value(document_id)? { + Some(document_id) => match validate_document_id_value(document_id) { Ok(document_id) => Ok(Ok(document_id)), Err(user_error) => { Ok(Err(DocumentIdExtractionError::InvalidDocumentId(user_error))) @@ -159,14 +159,14 @@ fn validate_document_id(document_id: &str) -> Option<&str> { } } -pub fn validate_document_id_value(document_id: Value) -> Result> { +pub fn validate_document_id_value(document_id: Value) -> StdResult { match document_id { Value::String(string) => match validate_document_id(&string) { - Some(s) if s.len() == string.len() => Ok(Ok(string)), - Some(s) => Ok(Ok(s.to_string())), - None => Ok(Err(UserError::InvalidDocumentId { document_id: Value::String(string) })), + Some(s) if s.len() == string.len() => Ok(string), + Some(s) => Ok(s.to_string()), + None => Err(UserError::InvalidDocumentId { document_id: Value::String(string) }), }, - Value::Number(number) if number.is_i64() => Ok(Ok(number.to_string())), - content => Ok(Err(UserError::InvalidDocumentId { document_id: content })), + Value::Number(number) if number.is_i64() => Ok(number.to_string()), + content => Err(UserError::InvalidDocumentId { document_id: content }), } } diff --git a/milli/src/index.rs b/milli/src/index.rs index 982be0139..3c502d541 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -1595,6 +1595,22 @@ impl Index { .unwrap_or_default()) } + pub fn arroy_readers<'a>( + &'a self, + rtxn: &'a RoTxn<'a>, + embedder_id: u8, + ) -> impl Iterator>> + 'a { + crate::vector::arroy_db_range_for_embedder(embedder_id).map_while(move |k| { + arroy::Reader::open(rtxn, k, self.vector_arroy) + .map(Some) + .or_else(|e| match e { + arroy::Error::MissingMetadata => Ok(None), + e => Err(e.into()), + }) + .transpose() + }) + } + pub(crate) fn put_search_cutoff(&self, wtxn: &mut RwTxn<'_>, cutoff: u64) -> heed::Result<()> { self.main.remap_types::().put(wtxn, main_key::SEARCH_CUTOFF, &cutoff) } diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 4d4cdaf9b..095fe1b94 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -63,6 +63,7 @@ pub use self::heed_codec::{ }; pub use self::index::Index; pub use self::search::facet::{FacetValueHit, SearchForFacetValues}; +pub use self::search::similar::Similar; pub use self::search::{ FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy, Search, SearchResult, SemanticSearch, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index ca0eda49e..76068b1f2 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -24,6 +24,7 @@ pub mod facet; mod fst_utils; pub mod hybrid; pub mod new; +pub mod similar; #[derive(Debug, Clone)] pub struct SemanticSearch { @@ -148,7 +149,7 @@ impl<'a> Search<'a> { pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { if has_vector_search { let ctx = SearchContext::new(self.index, self.rtxn)?; - filtered_universe(&ctx, &self.filter) + filtered_universe(ctx.index, ctx.txn, &self.filter) } else { Ok(self.execute()?.candidates) } @@ -161,7 +162,7 @@ impl<'a> Search<'a> { ctx.attributes_to_search_on(searchable_attributes)?; } - let universe = filtered_universe(&ctx, &self.filter)?; + let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; let PartialSearchResult { located_query_terms, candidates, diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index 40e6f8dc8..f121971b8 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -507,7 +507,7 @@ mod tests { impl<'a> MatcherBuilder<'a> { fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self { let mut ctx = SearchContext::new(index, rtxn).unwrap(); - let universe = filtered_universe(&ctx, &None).unwrap(); + let universe = filtered_universe(ctx.index, ctx.txn, &None).unwrap(); let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search( &mut ctx, Some(query), diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 5e4c2f829..e152dd233 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -543,11 +543,15 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( Ok(()) } -pub fn filtered_universe(ctx: &SearchContext, filters: &Option) -> Result { +pub fn filtered_universe( + index: &Index, + txn: &RoTxn<'_>, + filters: &Option, +) -> Result { Ok(if let Some(filters) = filters { - filters.evaluate(ctx.txn, ctx.index)? + filters.evaluate(txn, index)? } else { - ctx.index.documents_ids(ctx.txn)? + index.documents_ids(txn)? }) } diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index de272ed47..cd69b6c47 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -49,19 +49,8 @@ impl VectorSort { ctx: &mut SearchContext<'_>, vector_candidates: &RoaringBitmap, ) -> Result<()> { - let writer_index = (self.embedder_index as u16) << 8; - let readers: std::result::Result, _> = (0..=u8::MAX) - .map_while(|k| { - arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy) - .map(Some) - .or_else(|e| match e { - arroy::Error::MissingMetadata => Ok(None), - e => Err(e), - }) - .transpose() - }) - .collect(); - + let readers: std::result::Result, _> = + ctx.index.arroy_readers(ctx.txn, self.embedder_index).collect(); let readers = readers?; let target = &self.target; diff --git a/milli/src/search/similar.rs b/milli/src/search/similar.rs new file mode 100644 index 000000000..49b7c876f --- /dev/null +++ b/milli/src/search/similar.rs @@ -0,0 +1,111 @@ +use std::sync::Arc; + +use ordered_float::OrderedFloat; +use roaring::RoaringBitmap; + +use crate::score_details::{self, ScoreDetails}; +use crate::vector::Embedder; +use crate::{filtered_universe, DocumentId, Filter, Index, Result, SearchResult}; + +pub struct Similar<'a> { + id: DocumentId, + // this should be linked to the String in the query + filter: Option>, + offset: usize, + limit: usize, + rtxn: &'a heed::RoTxn<'a>, + index: &'a Index, + embedder_name: String, + embedder: Arc, +} + +impl<'a> Similar<'a> { + pub fn new( + id: DocumentId, + offset: usize, + limit: usize, + index: &'a Index, + rtxn: &'a heed::RoTxn<'a>, + embedder_name: String, + embedder: Arc, + ) -> Self { + Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder } + } + + pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self { + self.filter = Some(filter); + self + } + + pub fn execute(&self) -> Result { + let universe = filtered_universe(self.index, self.rtxn, &self.filter)?; + + let embedder_index = + self.index + .embedder_category_id + .get(self.rtxn, &self.embedder_name)? + .ok_or_else(|| crate::UserError::InvalidEmbedder(self.embedder_name.to_owned()))?; + + let readers: std::result::Result, _> = + self.index.arroy_readers(self.rtxn, embedder_index).collect(); + + let readers = readers?; + + let mut results = Vec::new(); + + for reader in readers.iter() { + let nns_by_item = reader.nns_by_item( + self.rtxn, + self.id, + self.limit + self.offset + 1, + None, + Some(&universe), + )?; + if let Some(mut nns_by_item) = nns_by_item { + results.append(&mut nns_by_item); + } else { + break; + } + } + + results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance)); + + let mut documents_ids = Vec::with_capacity(self.limit); + let mut document_scores = Vec::with_capacity(self.limit); + // list of documents we've already seen, so that we don't return the same document multiple times. + // initialized to the target document, that we never want to return. + let mut documents_seen = RoaringBitmap::new(); + documents_seen.insert(self.id); + + for (docid, distance) in results + .into_iter() + // skip documents we've already seen & mark that we saw the current document + .filter(|(docid, _)| documents_seen.insert(*docid)) + .skip(self.offset) + // take **after** filter and skip so that we get exactly limit elements if available + .take(self.limit) + { + documents_ids.push(docid); + + let score = 1.0 - distance; + let score = self + .embedder + .distribution() + .map(|distribution| distribution.shift(score)) + .unwrap_or(score); + + let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) }); + + document_scores.push(vec![score]); + } + + Ok(SearchResult { + matching_words: Default::default(), + candidates: universe, + documents_ids, + document_scores, + degraded: false, + used_negative_operator: false, + }) + } +} diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 046498a8b..afae8973a 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -538,10 +538,8 @@ where )?; pool.install(|| { - let writer_index = (embedder_index as u16) << 8; - for k in 0..=u8::MAX { - let writer = - arroy::Writer::new(vector_arroy, writer_index | (k as u16), dimension); + for k in crate::vector::arroy_db_range_for_embedder(embedder_index) { + let writer = arroy::Writer::new(vector_arroy, k, dimension); if writer.is_empty(wtxn)? { break; } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 27f760c2a..2ef7a8990 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -634,16 +634,9 @@ pub(crate) fn write_typed_chunk_into_index( let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, )?; - let writer_index = (embedder_index as u16) << 8; // FIXME: allow customizing distance - let writers: Vec<_> = (0..=u8::MAX) - .map(|k| { - arroy::Writer::new( - index.vector_arroy, - writer_index | (k as u16), - expected_dimension, - ) - }) + let writers: Vec<_> = crate::vector::arroy_db_range_for_embedder(embedder_index) + .map(|k| arroy::Writer::new(index.vector_arroy, k, expected_dimension)) .collect(); // remove vectors for docids we want them removed diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 1922bb389..553c8c3c1 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -442,3 +442,9 @@ impl DistributionShift { pub const fn is_cuda_enabled() -> bool { cfg!(feature = "cuda") } + +pub fn arroy_db_range_for_embedder(embedder_id: u8) -> impl Iterator { + let embedder_id = (embedder_id as u16) << 8; + + (0..=u8::MAX).map(move |k| embedder_id | (k as u16)) +}