diff --git a/meilisearch/src/routes/indexes/mod.rs b/meilisearch/src/routes/indexes/mod.rs index 59fa02dff..943f532e3 100644 --- a/meilisearch/src/routes/indexes/mod.rs +++ b/meilisearch/src/routes/indexes/mod.rs @@ -27,6 +27,7 @@ use crate::Opt; pub mod documents; pub mod facet_search; +pub mod recommend; pub mod search; pub mod settings; @@ -48,6 +49,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) { .service(web::scope("/documents").configure(documents::configure)) .service(web::scope("/search").configure(search::configure)) .service(web::scope("/facet-search").configure(facet_search::configure)) + .service(web::scope("/recommend").configure(recommend::configure)) .service(web::scope("/settings").configure(settings::configure)), ); } diff --git a/meilisearch/src/routes/indexes/recommend.rs b/meilisearch/src/routes/indexes/recommend.rs new file mode 100644 index 000000000..0a127faad --- /dev/null +++ b/meilisearch/src/routes/indexes/recommend.rs @@ -0,0 +1,53 @@ +use actix_web::web::{self, Data}; +use actix_web::{HttpRequest, HttpResponse}; +use deserr::actix_web::AwebJson; +use index_scheduler::IndexScheduler; +use meilisearch_types::deserr::DeserrJsonError; +use meilisearch_types::error::ResponseError; +use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::keys::actions; +use tracing::debug; + +use super::ActionPolicy; +use crate::analytics::Analytics; +use crate::extractors::authentication::GuardedData; +use crate::extractors::sequential_extractor::SeqHandler; +use crate::search::{perform_recommend, RecommendQuery, SearchKind}; + +pub fn configure(cfg: &mut web::ServiceConfig) { + cfg.service(web::resource("").route(web::post().to(SeqHandler(recommend)))); +} + +pub async fn recommend( + index_scheduler: GuardedData, Data>, + index_uid: web::Path, + params: AwebJson, + _req: HttpRequest, + _analytics: web::Data, +) -> Result { + let index_uid = IndexUid::try_from(index_uid.into_inner())?; + + // TODO analytics + + let query = params.into_inner(); + debug!(parameters = ?query, "Recommend post"); + + let index = index_scheduler.index(&index_uid)?; + + let features = index_scheduler.features(); + + features.check_vector("Using the recommend API.")?; + + let (embedder_name, embedder) = + SearchKind::embedder(&index_scheduler, &index, query.embedder.as_deref(), None)?; + + let recommendations = tokio::task::spawn_blocking(move || { + perform_recommend(&index, query, embedder_name, embedder) + }) + .await?; + + let recommendations = recommendations?; + + debug!(returns = ?recommendations, "Recommend post"); + Ok(HttpResponse::Ok().json(recommendations)) +} diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 5bbcf1577..3fbc20757 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -312,6 +312,27 @@ impl SearchQueryWithIndex { } } +#[derive(Debug, Clone, Default, PartialEq, Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct RecommendQuery { + #[deserr(default, error = DeserrJsonError)] + pub id: String, + #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] + pub offset: usize, + #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] + pub limit: usize, + #[deserr(default, error = DeserrJsonError)] + pub filter: Option, + #[deserr(default, error = DeserrJsonError, default)] + pub embedder: Option, + #[deserr(default, error = DeserrJsonError)] + pub attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrJsonError, default)] + pub show_ranking_score: bool, + #[deserr(default, error = DeserrJsonError, default)] + pub show_ranking_score_details: bool, +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr)] #[deserr(rename_all = camelCase)] pub enum MatchingStrategy { @@ -393,6 +414,16 @@ pub struct SearchResult { pub used_negative_operator: bool, } +#[derive(Serialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct RecommendResult { + pub hits: Vec, + pub id: String, + pub processing_time_ms: u128, + #[serde(flatten)] + pub hits_info: HitsInfo, +} + #[derive(Serialize, Debug, Clone, PartialEq)] #[serde(rename_all = "camelCase")] pub struct SearchResultWithIndex { @@ -796,6 +827,131 @@ pub fn perform_facet_search( }) } +pub fn perform_recommend( + index: &Index, + query: RecommendQuery, + embedder_name: String, + embedder: Arc, +) -> Result { + let before_search = Instant::now(); + let rtxn = index.read_txn()?; + + let internal_id = index + .external_documents_ids() + .get(&rtxn, &query.id)? + .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(query.id.clone()))?; + + let mut recommend = milli::Recommend::new( + internal_id, + query.offset, + query.limit, + index, + &rtxn, + embedder_name, + embedder, + ); + + if let Some(ref filter) = query.filter { + if let Some(facets) = parse_filter(filter)? { + recommend.filter(facets); + } + } + + let milli::SearchResult { + documents_ids, + matching_words: _, + candidates, + document_scores, + degraded: _, + used_negative_operator: _, + } = recommend.execute()?; + + let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); + + let displayed_ids = index + .displayed_fields_ids(&rtxn)? + .map(|fields| fields.into_iter().collect::>()) + .unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); + + let fids = |attrs: &BTreeSet| { + let mut ids = BTreeSet::new(); + for attr in attrs { + if attr == "*" { + ids = displayed_ids.clone(); + break; + } + + if let Some(id) = fields_ids_map.id(attr) { + ids.insert(id); + } + } + ids + }; + + // The attributes to retrieve are the ones explicitly marked as to retrieve (all by default), + // but these attributes must be also be present + // - in the fields_ids_map + // - in the displayed attributes + let to_retrieve_ids: BTreeSet<_> = query + .attributes_to_retrieve + .as_ref() + .map(fids) + .unwrap_or_else(|| displayed_ids.clone()) + .intersection(&displayed_ids) + .cloned() + .collect(); + + let mut documents = Vec::new(); + let documents_iter = index.documents(&rtxn, documents_ids)?; + + for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { + // First generate a document with all the displayed fields + let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; + + // select the attributes to retrieve + let attributes_to_retrieve = to_retrieve_ids + .iter() + .map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); + let document = + permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); + + let ranking_score = + query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); + let ranking_score_details = + query.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); + + let hit = SearchHit { + document, + formatted: Default::default(), + matches_position: None, + ranking_score_details, + ranking_score, + }; + documents.push(hit); + } + + let max_total_hits = index + .pagination_max_total_hits(&rtxn) + .map_err(milli::Error::from)? + .map(|x| x as usize) + .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); + + let number_of_hits = min(candidates.len() as usize, max_total_hits); + let hits_info = HitsInfo::OffsetLimit { + limit: query.limit, + offset: query.offset, + estimated_total_hits: number_of_hits, + }; + + let result = RecommendResult { + hits: documents, + hits_info, + id: query.id, + processing_time_ms: before_search.elapsed().as_millis(), + }; + Ok(result) +} + fn insert_geo_distance(sorts: &[String], document: &mut Document) { lazy_static::lazy_static! { static ref GEO_REGEX: Regex = diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 22816787b..712674fb9 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -59,6 +59,7 @@ pub use self::heed_codec::{ }; pub use self::index::Index; pub use self::search::facet::{FacetValueHit, SearchForFacetValues}; +pub use self::search::recommend::Recommend; pub use self::search::{ FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy, Search, SearchResult, SemanticSearch, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 3470caa23..4e49d3c07 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -24,6 +24,7 @@ pub mod facet; mod fst_utils; pub mod hybrid; pub mod new; +pub mod recommend; #[derive(Debug, Clone)] pub struct SemanticSearch { diff --git a/milli/src/search/recommend.rs b/milli/src/search/recommend.rs new file mode 100644 index 000000000..269d65c49 --- /dev/null +++ b/milli/src/search/recommend.rs @@ -0,0 +1,108 @@ +use std::sync::Arc; + +use ordered_float::OrderedFloat; + +use crate::score_details::{self, ScoreDetails}; +use crate::vector::Embedder; +use crate::{filtered_universe, DocumentId, Filter, Index, Result, SearchResult}; + +pub struct Recommend<'a> { + id: DocumentId, + // this should be linked to the String in the query + filter: Option>, + offset: usize, + limit: usize, + rtxn: &'a heed::RoTxn<'a>, + index: &'a Index, + embedder_name: String, + embedder: Arc, +} + +impl<'a> Recommend<'a> { + pub fn new( + id: DocumentId, + offset: usize, + limit: usize, + index: &'a Index, + rtxn: &'a heed::RoTxn<'a>, + embedder_name: String, + embedder: Arc, + ) -> Self { + Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder } + } + + pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self { + self.filter = Some(filter); + self + } + + pub fn execute(&self) -> Result { + let universe = filtered_universe(self.index, self.rtxn, &self.filter)?; + + let embedder_index = + self.index + .embedder_category_id + .get(self.rtxn, &self.embedder_name)? + .ok_or_else(|| crate::UserError::InvalidEmbedder(self.embedder_name.to_owned()))?; + + let writer_index = (embedder_index as u16) << 8; + let readers: std::result::Result, _> = (0..=u8::MAX) + .map_while(|k| { + arroy::Reader::open(self.rtxn, writer_index | (k as u16), self.index.vector_arroy) + .map(Some) + .or_else(|e| match e { + arroy::Error::MissingMetadata => Ok(None), + e => Err(e), + }) + .transpose() + }) + .collect(); + + let readers = readers?; + + let mut results = Vec::new(); + + for reader in readers.iter() { + let nns_by_item = reader.nns_by_item( + self.rtxn, + self.id, + self.limit + self.offset + 1, + None, + Some(&universe), + )?; + if let Some(mut nns_by_item) = nns_by_item { + results.append(&mut nns_by_item); + } + } + + results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance)); + + let mut documents_ids = Vec::with_capacity(self.limit); + let mut document_scores = Vec::with_capacity(self.limit); + + // skip offset +1 to skip the target document that is normally returned + for (docid, distance) in results.into_iter().skip(self.offset + 1) { + documents_ids.push(docid); + + let score = 1.0 - distance; + let score = self + .embedder + .distribution() + .map(|distribution| distribution.shift(score)) + .unwrap_or(score); + + let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) }); + + document_scores.push(vec![score]); + } + + Ok(SearchResult { + matching_words: Default::default(), + candidates: universe, + documents_ids, + document_scores, + degraded: false, + used_negative_operator: false, + }) + } +}