diff --git a/Cargo.lock b/Cargo.lock index d08b1a83c..b97e712f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -300,6 +300,18 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5ab7d9e73059c86c36473f459b52adbd99c3554a4fec492caef460806006f00" +[[package]] +name = "as-slice" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45403b49e3954a4b8428a0ac21a4b7afadccf92bfd96273f1a58cd4812496ae0" +dependencies = [ + "generic-array 0.12.4", + "generic-array 0.13.3", + "generic-array 0.14.4", + "stable_deref_trait", +] + [[package]] name = "assert-json-diff" version = "1.0.1" @@ -1089,6 +1101,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "generic-array" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f797e67af32588215eaaab8327027ee8e71b9dd0b2b26996aedf20c030fce309" +dependencies = [ + "typenum", +] + [[package]] name = "generic-array" version = "0.14.4" @@ -1099,6 +1120,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "geoutils" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e006f616a407d396ace1d2ebb3f43ed73189db8b098079bd129928d7645dd1e" + [[package]] name = "getrandom" version = "0.2.3" @@ -1177,6 +1204,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hash32" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4041af86e63ac4298ce40e5cca669066e75b6f1aa3390fe2561ffa5e1d9f4cc" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.7.2" @@ -1193,6 +1229,18 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +[[package]] +name = "heapless" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634bd4d29cbf24424d0a4bfcbf80c6960129dc24424752a7d1d1390607023422" +dependencies = [ + "as-slice", + "generic-array 0.14.4", + "hash32", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.3.3" @@ -1697,6 +1745,7 @@ dependencies = [ "http", "indexmap", "itertools", + "lazy_static", "log", "main_error", "meilisearch-error", @@ -1780,8 +1829,8 @@ dependencies = [ [[package]] name = "milli" -version = "0.13.1" -source = "git+https://github.com/meilisearch/milli.git?rev=6de1b41#6de1b41f791e7d117634e63783d78b29b5228a99" +version = "0.16.0" +source = "git+https://github.com/meilisearch/milli.git#0f8320bdc24d76781e596d96d3b2e788a55655c6" dependencies = [ "bimap", "bincode", @@ -1794,6 +1843,7 @@ dependencies = [ "flate2", "fst", "fxhash", + "geoutils", "grenad", "heed", "human_format", @@ -1811,6 +1861,7 @@ dependencies = [ "pest_derive", "rayon", "roaring", + "rstar", "serde", "serde_json", "slice-group-by", @@ -1818,7 +1869,6 @@ dependencies = [ "smallvec", "tempfile", "uuid", - "vec-utils", ] [[package]] @@ -2038,6 +2088,12 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3cacbb3c4ff353b534a67fb8d7524d00229da4cb1dc8c79f4db96e375ab5b619" +[[package]] +name = "pdqselect" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec91767ecc0a0bbe558ce8c9da33c068066c57ecc8bb8477ef8c1ad3ef77c27" + [[package]] name = "percent-encoding" version = "2.1.0" @@ -2489,6 +2545,19 @@ dependencies = [ "retain_mut", ] +[[package]] +name = "rstar" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d535e658ada8c1987a113e5261f8b907f721b2854d666e72820671481b7ee125" +dependencies = [ + "heapless", + "num-traits", + "pdqselect", + "serde", + "smallvec", +] + [[package]] name = "rustc-demangle" version = "0.1.21" @@ -2762,6 +2831,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "standback" version = "0.2.17" @@ -3265,12 +3340,6 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" -[[package]] -name = "vec-utils" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dac984aa016c26ef4ed7b2c30d6a1bd570fd40a078caccaf6415a2ac5d96161" - [[package]] name = "vec_map" version = "0.8.2" diff --git a/meilisearch-error/src/lib.rs b/meilisearch-error/src/lib.rs index 9d5b79f69..2e4b50ef6 100644 --- a/meilisearch-error/src/lib.rs +++ b/meilisearch-error/src/lib.rs @@ -1,7 +1,7 @@ use std::fmt; use actix_http::http::StatusCode; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; pub trait ErrorCode: std::error::Error { fn error_code(&self) -> Code; @@ -71,6 +71,8 @@ pub enum Code { BadRequest, DocumentNotFound, Internal, + InvalidGeoField, + InvalidRankingRule, InvalidToken, MissingAuthorizationHeader, NotFound, @@ -108,6 +110,8 @@ impl Code { PrimaryKeyAlreadyPresent => { ErrCode::invalid("primary_key_already_present", StatusCode::BAD_REQUEST) } + // invalid ranking rule + InvalidRankingRule => ErrCode::invalid("invalid_request", StatusCode::BAD_REQUEST), // invalid document MaxFieldsLimitExceeded => { @@ -126,6 +130,9 @@ impl Code { BadRequest => ErrCode::invalid("bad_request", StatusCode::BAD_REQUEST), DocumentNotFound => ErrCode::invalid("document_not_found", StatusCode::NOT_FOUND), Internal => ErrCode::internal("internal", StatusCode::INTERNAL_SERVER_ERROR), + InvalidGeoField => { + ErrCode::authentication("invalid_geo_field", StatusCode::BAD_REQUEST) + } InvalidToken => ErrCode::authentication("invalid_token", StatusCode::FORBIDDEN), MissingAuthorizationHeader => { ErrCode::authentication("missing_authorization_header", StatusCode::UNAUTHORIZED) diff --git a/meilisearch-http/Cargo.toml b/meilisearch-http/Cargo.toml index eb3d550ab..f3c11e29f 100644 --- a/meilisearch-http/Cargo.toml +++ b/meilisearch-http/Cargo.toml @@ -49,7 +49,7 @@ meilisearch-lib = { path = "../meilisearch-lib" } meilisearch-error = { path = "../meilisearch-error" } meilisearch-tokenizer = { git = "https://github.com/meilisearch/tokenizer.git", tag = "v0.2.5" } memmap = "0.7.0" -milli = { git = "https://github.com/meilisearch/milli.git", rev = "6de1b41" } +milli = { git = "https://github.com/meilisearch/milli.git", version = "0.16.0" } mime = "0.3.16" num_cpus = "1.13.0" once_cell = "1.8.0" diff --git a/meilisearch-http/src/error.rs b/meilisearch-http/src/error.rs index c18c32ea5..52538c862 100644 --- a/meilisearch-http/src/error.rs +++ b/meilisearch-http/src/error.rs @@ -74,13 +74,11 @@ impl ErrorCode for MilliError<'_> { milli::Error::UserError(ref error) => { match error { // TODO: wait for spec for new error codes. - | UserError::SerdeJson(_) + UserError::SerdeJson(_) | UserError::MaxDatabaseSizeReached - | UserError::InvalidCriterionName { .. } | UserError::InvalidDocumentId { .. } | UserError::InvalidStoreFile | UserError::NoSpaceLeftOnDevice - | UserError::InvalidAscDescSyntax { .. } | UserError::DocumentLimitReached => Code::Internal, UserError::AttributeLimitReached => Code::MaxFieldsLimitExceeded, UserError::InvalidFilter(_) => Code::Filter, @@ -93,7 +91,10 @@ impl ErrorCode for MilliError<'_> { UserError::SortRankingRuleMissing => Code::Sort, UserError::UnknownInternalDocumentId { .. } => Code::DocumentNotFound, UserError::InvalidFacetsDistribution { .. } => Code::BadRequest, - UserError::InvalidSortableAttribute { .. } => Code::Sort, + UserError::InvalidGeoField { .. } => Code::InvalidGeoField, + UserError::InvalidSortableAttribute { .. } + | UserError::InvalidReservedSortName { .. } => Code::Sort, + UserError::CriterionError(_) => Code::BadRequest, } } } diff --git a/meilisearch-http/src/routes/indexes/search.rs b/meilisearch-http/src/routes/indexes/search.rs index 1ae8eb2f7..c7e987840 100644 --- a/meilisearch-http/src/routes/indexes/search.rs +++ b/meilisearch-http/src/routes/indexes/search.rs @@ -1,7 +1,7 @@ use actix_web::{web, HttpResponse}; use log::debug; +use meilisearch_lib::index::{default_crop_length, SearchQuery, DEFAULT_SEARCH_LIMIT}; use meilisearch_lib::MeiliSearch; -use meilisearch_lib::index::{default_crop_length, SearchQuery, DEFAULT_SEARCH_LIMIT}; use serde::Deserialize; use serde_json::Value; @@ -61,9 +61,7 @@ impl From for SearchQuery { None => None, }; - let sort = other - .sort - .map(|attrs| attrs.split(',').map(String::from).collect()); + let sort = other.sort.map(|attr| fix_sort_query_parameters(&attr)); Self { q: other.q, @@ -81,6 +79,30 @@ impl From for SearchQuery { } } +/// Transform the sort query parameter into something that matches the post expected format. +fn fix_sort_query_parameters(sort_query: &str) -> Vec { + let mut sort_parameters = Vec::new(); + let mut merge = false; + for current_sort in sort_query.trim_matches('"').split(',').map(|s| s.trim()) { + if current_sort.starts_with("_geoPoint(") { + sort_parameters.push(current_sort.to_string()); + merge = true; + } else if merge && !sort_parameters.is_empty() { + sort_parameters + .last_mut() + .unwrap() + .push_str(&format!(",{}", current_sort)); + if current_sort.ends_with("):desc") || current_sort.ends_with("):asc") { + merge = false; + } + } else { + sort_parameters.push(current_sort.to_string()); + merge = false; + } + } + sort_parameters +} + pub async fn search_with_url_query( meilisearch: GuardedData, path: web::Path, @@ -88,7 +110,9 @@ pub async fn search_with_url_query( ) -> Result { debug!("called with params: {:?}", params); let query = params.into_inner().into(); - let search_result = meilisearch.search(path.into_inner().index_uid, query).await?; + let search_result = meilisearch + .search(path.into_inner().index_uid, query) + .await?; // Tests that the nb_hits is always set to false #[cfg(test)] @@ -115,3 +139,42 @@ pub async fn search_with_post( debug!("returns: {:?}", search_result); Ok(HttpResponse::Ok().json(search_result)) } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_fix_sort_query_parameters() { + let sort = fix_sort_query_parameters("_geoPoint(12, 13):asc"); + assert_eq!(sort, vec!["_geoPoint(12,13):asc".to_string()]); + let sort = fix_sort_query_parameters("doggo:asc,_geoPoint(12.45,13.56):desc"); + assert_eq!( + sort, + vec![ + "doggo:asc".to_string(), + "_geoPoint(12.45,13.56):desc".to_string(), + ] + ); + let sort = fix_sort_query_parameters( + "doggo:asc , _geoPoint(12.45, 13.56, 2590352):desc , catto:desc", + ); + assert_eq!( + sort, + vec![ + "doggo:asc".to_string(), + "_geoPoint(12.45,13.56,2590352):desc".to_string(), + "catto:desc".to_string(), + ] + ); + let sort = fix_sort_query_parameters("doggo:asc , _geoPoint(1, 2), catto:desc"); + // This is ugly but eh, I don't want to write a full parser just for this unused route + assert_eq!( + sort, + vec![ + "doggo:asc".to_string(), + "_geoPoint(1,2),catto:desc".to_string(), + ] + ); + } +} diff --git a/meilisearch-lib/Cargo.toml b/meilisearch-lib/Cargo.toml index 0d9f6520b..78cbd1b96 100644 --- a/meilisearch-lib/Cargo.toml +++ b/meilisearch-lib/Cargo.toml @@ -27,12 +27,13 @@ heed = { git = "https://github.com/Kerollmops/heed", tag = "v0.12.1" } http = "0.2.4" indexmap = { version = "1.7.0", features = ["serde-1"] } itertools = "0.10.1" +lazy_static = "1.4.0" log = "0.4.14" main_error = "0.1.1" meilisearch-error = { path = "../meilisearch-error" } meilisearch-tokenizer = { git = "https://github.com/meilisearch/tokenizer.git", tag = "v0.2.5" } memmap = "0.7.0" -milli = { git = "https://github.com/meilisearch/milli.git", rev = "6de1b41" } +milli = { git = "https://github.com/meilisearch/milli.git", version = "0.16.0" } mime = "0.3.16" num_cpus = "1.13.0" once_cell = "1.8.0" diff --git a/meilisearch-lib/src/error.rs b/meilisearch-lib/src/error.rs index 80141dae5..a369381fe 100644 --- a/meilisearch-lib/src/error.rs +++ b/meilisearch-lib/src/error.rs @@ -35,13 +35,11 @@ impl ErrorCode for MilliError<'_> { milli::Error::UserError(ref error) => { match error { // TODO: wait for spec for new error codes. - | UserError::SerdeJson(_) + UserError::SerdeJson(_) | UserError::MaxDatabaseSizeReached - | UserError::InvalidCriterionName { .. } | UserError::InvalidDocumentId { .. } | UserError::InvalidStoreFile | UserError::NoSpaceLeftOnDevice - | UserError::InvalidAscDescSyntax { .. } | UserError::DocumentLimitReached => Code::Internal, UserError::AttributeLimitReached => Code::MaxFieldsLimitExceeded, UserError::InvalidFilter(_) => Code::Filter, @@ -54,7 +52,10 @@ impl ErrorCode for MilliError<'_> { UserError::SortRankingRuleMissing => Code::Sort, UserError::UnknownInternalDocumentId { .. } => Code::DocumentNotFound, UserError::InvalidFacetsDistribution { .. } => Code::BadRequest, - UserError::InvalidSortableAttribute { .. } => Code::Sort, + UserError::InvalidSortableAttribute { .. } + | UserError::InvalidReservedSortName { .. } => Code::Sort, + UserError::CriterionError(_) => Code::InvalidRankingRule, + UserError::InvalidGeoField { .. } => Code::InvalidGeoField, } } } diff --git a/meilisearch-lib/src/index/search.rs b/meilisearch-lib/src/index/search.rs index c7949fea6..37f4e2a33 100644 --- a/meilisearch-lib/src/index/search.rs +++ b/meilisearch-lib/src/index/search.rs @@ -6,9 +6,12 @@ use either::Either; use heed::RoTxn; use indexmap::IndexMap; use meilisearch_tokenizer::{Analyzer, AnalyzerConfig, Token}; -use milli::{AscDesc, FieldId, FieldsIdsMap, FilterCondition, MatchingWords, UserError}; +use milli::{ + AscDesc, AscDescError, FieldId, FieldsIdsMap, FilterCondition, MatchingWords, UserError, +}; +use regex::Regex; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{json, Value}; use crate::index::error::FacetError; use crate::index::IndexError; @@ -110,12 +113,16 @@ impl Index { if let Some(ref sort) = query.sort { let sort = match sort.iter().map(|s| AscDesc::from_str(s)).collect() { Ok(sorts) => sorts, - Err(UserError::InvalidAscDescSyntax { name }) => { + Err(AscDescError::InvalidSyntax { name }) => { return Err(IndexError::Milli( UserError::InvalidSortName { name }.into(), )) } - Err(err) => return Err(IndexError::Milli(err.into())), + Err(AscDescError::ReservedKeyword { name }) => { + return Err(IndexError::Milli( + UserError::InvalidReservedSortName { name }.into(), + )) + } }; search.sort_criteria(sort); @@ -193,7 +200,7 @@ impl Index { let documents_iter = self.documents(&rtxn, documents_ids)?; for (_id, obkv) in documents_iter { - let document = make_document(&to_retrieve_ids, &fields_ids_map, obkv)?; + let mut document = make_document(&to_retrieve_ids, &fields_ids_map, obkv)?; let matches_info = query .matches @@ -207,6 +214,10 @@ impl Index { &formatted_options, )?; + if let Some(sort) = query.sort.as_ref() { + insert_geo_distance(sort, &mut document); + } + let hit = SearchHit { document, formatted, @@ -247,6 +258,25 @@ impl Index { } } +fn insert_geo_distance(sorts: &[String], document: &mut Document) { + lazy_static::lazy_static! { + static ref GEO_REGEX: Regex = + Regex::new(r"_geoPoint\(\s*([[:digit:].\-]+)\s*,\s*([[:digit:].\-]+)\s*\)").unwrap(); + }; + if let Some(capture_group) = sorts.iter().find_map(|sort| GEO_REGEX.captures(sort)) { + // TODO: TAMO: milli encountered an internal error, what do we want to do? + let base = [ + capture_group[1].parse().unwrap(), + capture_group[2].parse().unwrap(), + ]; + let geo_point = &document.get("_geo").unwrap_or(&json!(null)); + if let Some((lat, lng)) = geo_point["lat"].as_f64().zip(geo_point["lng"].as_f64()) { + let distance = milli::distance_between_two_points(&base, &[lat, lng]); + document.insert("_geoDistance".to_string(), json!(distance.round() as usize)); + } + } +} + fn compute_matches>( matcher: &impl Matcher, document: &Document, @@ -1332,4 +1362,65 @@ mod test { r##"{"about": [MatchInfo { start: 0, length: 6 }, MatchInfo { start: 31, length: 7 }, MatchInfo { start: 191, length: 7 }, MatchInfo { start: 225, length: 7 }, MatchInfo { start: 233, length: 6 }], "color": [MatchInfo { start: 0, length: 3 }]}"## ); } + + #[test] + fn test_insert_geo_distance() { + let value: Document = serde_json::from_str( + r#"{ + "_geo": { + "lat": 50.629973371633746, + "lng": 3.0569447399419567 + }, + "city": "Lille", + "id": "1" + }"#, + ) + .unwrap(); + + let sorters = &["_geoPoint(50.629973371633746,3.0569447399419567):desc".to_string()]; + let mut document = value.clone(); + insert_geo_distance(sorters, &mut document); + assert_eq!(document.get("_geoDistance"), Some(&json!(0))); + + let sorters = &["_geoPoint(50.629973371633746, 3.0569447399419567):asc".to_string()]; + let mut document = value.clone(); + insert_geo_distance(sorters, &mut document); + assert_eq!(document.get("_geoDistance"), Some(&json!(0))); + + let sorters = + &["_geoPoint( 50.629973371633746 , 3.0569447399419567 ):desc".to_string()]; + let mut document = value.clone(); + insert_geo_distance(sorters, &mut document); + assert_eq!(document.get("_geoDistance"), Some(&json!(0))); + + let sorters = &[ + "prix:asc", + "villeneuve:desc", + "_geoPoint(50.629973371633746, 3.0569447399419567):asc", + "ubu:asc", + ] + .map(|s| s.to_string()); + let mut document = value.clone(); + insert_geo_distance(sorters, &mut document); + assert_eq!(document.get("_geoDistance"), Some(&json!(0))); + + // only the first geoPoint is used to compute the distance + let sorters = &[ + "chien:desc", + "_geoPoint(50.629973371633746, 3.0569447399419567):asc", + "pangolin:desc", + "_geoPoint(100.0, -80.0):asc", + "chat:asc", + ] + .map(|s| s.to_string()); + let mut document = value.clone(); + insert_geo_distance(sorters, &mut document); + assert_eq!(document.get("_geoDistance"), Some(&json!(0))); + + // there was no _geoPoint so nothing is inserted in the document + let sorters = &["chien:asc".to_string()]; + let mut document = value; + insert_geo_distance(sorters, &mut document); + assert_eq!(document.get("_geoDistance"), None); + } }