use std::collections::HashMap; use std::collections::HashSet; use std::time::Duration; use meilidb_core::Index; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::{Deserialize, Serialize}; use tide::querystring::ContextExt as QSContextExt; use tide::{Context, Response}; use crate::error::{ResponseError, SResult}; use crate::helpers::meilidb::{Error, IndexSearchExt, SearchHit}; use crate::helpers::tide::ContextExt; use crate::Data; #[derive(Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] struct SearchQuery { q: String, offset: Option, limit: Option, attributes_to_retrieve: Option, attributes_to_search_in: Option, attributes_to_crop: Option, crop_length: Option, attributes_to_highlight: Option, filters: Option, timeout_ms: Option, matches: Option, } pub async fn search_with_url_query(ctx: Context) -> SResult { // ctx.is_allowed(DocumentsRead)?; let index = ctx.index()?; let env = &ctx.state().db.env; let reader = env.read_txn().map_err(ResponseError::internal)?; let query: SearchQuery = ctx .url_query() .map_err(|_| ResponseError::bad_request("invalid query parameter"))?; let mut search_builder = index.new_search(query.q.clone()); if let Some(offset) = query.offset { search_builder.offset(offset); } if let Some(limit) = query.limit { search_builder.limit(limit); } if let Some(attributes_to_retrieve) = query.attributes_to_retrieve { for attr in attributes_to_retrieve.split(',') { search_builder.add_retrievable_field(attr.to_string()); } } if let Some(attributes_to_search_in) = query.attributes_to_search_in { for attr in attributes_to_search_in.split(',') { search_builder.add_retrievable_field(attr.to_string()); } } if let Some(attributes_to_crop) = query.attributes_to_crop { let crop_length = query.crop_length.unwrap_or(200); let attributes_to_crop = attributes_to_crop .split(',') .map(|r| (r.to_string(), crop_length)) .collect(); search_builder.attributes_to_crop(attributes_to_crop); } if let Some(attributes_to_highlight) = query.attributes_to_highlight { let attributes_to_highlight = attributes_to_highlight .split(',') .map(ToString::to_string) .collect(); search_builder.attributes_to_highlight(attributes_to_highlight); } if let Some(filters) = query.filters { search_builder.filters(filters); } if let Some(timeout_ms) = query.timeout_ms { search_builder.timeout(Duration::from_millis(timeout_ms)); } if let Some(matches) = query.matches { if matches { search_builder.get_matches(); } } let response = match search_builder.search(&reader) { Ok(response) => response, Err(Error::Internal(message)) => return Err(ResponseError::Internal(message)), Err(others) => return Err(ResponseError::bad_request(others)), }; Ok(tide::response::json(response)) } #[derive(Clone, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] struct SearchMultiBody { indexes: HashSet, query: String, offset: Option, limit: Option, attributes_to_retrieve: Option>, attributes_to_search_in: Option>, attributes_to_crop: Option>, attributes_to_highlight: Option>, filters: Option, timeout_ms: Option, matches: Option, } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] struct SearchMultiBodyResponse { hits: HashMap>, offset: usize, hits_per_page: usize, processing_time_ms: usize, query: String, } pub async fn search_multi_index(mut ctx: Context) -> SResult { // ctx.is_allowed(DocumentsRead)?; let body = ctx .body_json::() .await .map_err(ResponseError::bad_request)?; let mut index_list = body.clone().indexes; for index in index_list.clone() { if index == "*" { index_list = ctx .state() .db .indexes_names() .map_err(ResponseError::internal)? .into_iter() .collect(); } } let mut offset = 0; let mut count = 20; if let Some(body_offset) = body.offset { if let Some(limit) = body.limit { offset = body_offset; count = limit; } } let offset = offset; let count = count; let db = &ctx.state().db; let par_body = body.clone(); let responses_per_index: Vec> = index_list .into_par_iter() .map(move |index_name| { let index: Index = db .open_index(&index_name) .ok_or(ResponseError::index_not_found(&index_name))?; let mut search_builder = index.new_search(par_body.query.clone()); search_builder.offset(offset); search_builder.limit(count); if let Some(attributes_to_retrieve) = par_body.attributes_to_retrieve.clone() { search_builder.attributes_to_retrieve(attributes_to_retrieve); } if let Some(attributes_to_search_in) = par_body.attributes_to_search_in.clone() { search_builder.attributes_to_search_in(attributes_to_search_in); } if let Some(attributes_to_crop) = par_body.attributes_to_crop.clone() { search_builder.attributes_to_crop(attributes_to_crop); } if let Some(attributes_to_highlight) = par_body.attributes_to_highlight.clone() { search_builder.attributes_to_highlight(attributes_to_highlight); } if let Some(filters) = par_body.filters.clone() { search_builder.filters(filters); } if let Some(timeout_ms) = par_body.timeout_ms { search_builder.timeout(Duration::from_secs(timeout_ms)); } if let Some(matches) = par_body.matches { if matches { search_builder.get_matches(); } } let env = &db.env; let reader = env.read_txn().map_err(ResponseError::internal)?; let response = search_builder .search(&reader) .map_err(ResponseError::internal)?; Ok((index_name, response)) }) .collect(); let mut hits_map = HashMap::new(); let mut max_query_time = 0; for response in responses_per_index { if let Ok((index_name, response)) = response { if response.processing_time_ms > max_query_time { max_query_time = response.processing_time_ms; } hits_map.insert(index_name, response.hits); } } let response = SearchMultiBodyResponse { hits: hits_map, offset, hits_per_page: count, processing_time_ms: max_query_time, query: body.query, }; Ok(tide::response::json(response)) }