mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 20:37:15 +02:00
Proxy search requests
This commit is contained in:
parent
c3e5c3ba36
commit
35160788d7
7 changed files with 1787 additions and 923 deletions
10
crates/meilisearch/src/search/federated/mod.rs
Normal file
10
crates/meilisearch/src/search/federated/mod.rs
Normal file
|
@ -0,0 +1,10 @@
|
|||
mod perform;
|
||||
mod proxy;
|
||||
mod types;
|
||||
mod weighted_scores;
|
||||
|
||||
pub use perform::perform_federated_search;
|
||||
pub use proxy::{PROXY_SEARCH_HEADER, PROXY_SEARCH_HEADER_VALUE};
|
||||
pub use types::{
|
||||
FederatedSearch, FederatedSearchResult, Federation, FederationOptions, MergeFacets,
|
||||
};
|
1099
crates/meilisearch/src/search/federated/perform.rs
Normal file
1099
crates/meilisearch/src/search/federated/perform.rs
Normal file
File diff suppressed because it is too large
Load diff
267
crates/meilisearch/src/search/federated/proxy.rs
Normal file
267
crates/meilisearch/src/search/federated/proxy.rs
Normal file
|
@ -0,0 +1,267 @@
|
|||
pub use error::ProxySearchError;
|
||||
use error::ReqwestErrorWithoutUrl;
|
||||
use meilisearch_types::features::Remote;
|
||||
use rand::Rng as _;
|
||||
use reqwest::{Client, Response, StatusCode};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::Value;
|
||||
|
||||
use super::types::{FederatedSearch, FederatedSearchResult, Federation};
|
||||
use crate::search::SearchQueryWithIndex;
|
||||
|
||||
pub const PROXY_SEARCH_HEADER: &str = "Meili-Proxy-Search";
|
||||
pub const PROXY_SEARCH_HEADER_VALUE: &str = "true";
|
||||
|
||||
mod error {
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use reqwest::StatusCode;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProxySearchError {
|
||||
#[error("{0}")]
|
||||
CouldNotSendRequest(ReqwestErrorWithoutUrl),
|
||||
#[error("could not authenticate against the remote host\n - hint: check that the remote instance was registered with a valid API key having the `search` action")]
|
||||
AuthenticationError,
|
||||
#[error(
|
||||
"could not parse response from the remote host as a federated search response{}\n - hint: check that the remote instance is a Meilisearch instance running the same version",
|
||||
response_from_remote(response)
|
||||
)]
|
||||
CouldNotParseResponse { response: Result<String, ReqwestErrorWithoutUrl> },
|
||||
#[error("remote host responded with code {}{}\n - hint: check that the remote instance has the correct index configuration for that request\n - hint: check that the `network` experimental feature is enabled on the remote instance", status_code.as_u16(), response_from_remote(response))]
|
||||
BadRequest { status_code: StatusCode, response: Result<String, ReqwestErrorWithoutUrl> },
|
||||
#[error("remote host did not answer before the deadline")]
|
||||
Timeout,
|
||||
#[error("remote hit does not contain `{0}`\n - hint: check that the remote instance is a Meilisearch instance running the same version")]
|
||||
MissingPathInResponse(&'static str),
|
||||
#[error("remote host responded with code {}{}", status_code.as_u16(), response_from_remote(response))]
|
||||
RemoteError { status_code: StatusCode, response: Result<String, ReqwestErrorWithoutUrl> },
|
||||
#[error("remote hit contains an unexpected value at path `{path}`: expected {expected_type}, received `{received_value}`\n - hint: check that the remote instance is a Meilisearch instance running the same version")]
|
||||
UnexpectedValueInPath {
|
||||
path: &'static str,
|
||||
expected_type: &'static str,
|
||||
received_value: String,
|
||||
},
|
||||
#[error("could not parse weighted score values in the remote hit: {0}")]
|
||||
CouldNotParseWeightedScoreValues(serde_json::Error),
|
||||
}
|
||||
|
||||
impl ProxySearchError {
|
||||
pub fn as_response_error(&self) -> ResponseError {
|
||||
use meilisearch_types::error::Code;
|
||||
let message = self.to_string();
|
||||
let code = match self {
|
||||
ProxySearchError::CouldNotSendRequest(_) => Code::RemoteCouldNotSendRequest,
|
||||
ProxySearchError::AuthenticationError => Code::RemoteInvalidApiKey,
|
||||
ProxySearchError::BadRequest { .. } => Code::RemoteBadRequest,
|
||||
ProxySearchError::Timeout => Code::RemoteTimeout,
|
||||
ProxySearchError::RemoteError { .. } => Code::RemoteRemoteError,
|
||||
ProxySearchError::CouldNotParseResponse { .. }
|
||||
| ProxySearchError::MissingPathInResponse(_)
|
||||
| ProxySearchError::UnexpectedValueInPath { .. }
|
||||
| ProxySearchError::CouldNotParseWeightedScoreValues(_) => Code::RemoteBadResponse,
|
||||
};
|
||||
ResponseError::from_msg(message, code)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct ReqwestErrorWithoutUrl(reqwest::Error);
|
||||
impl ReqwestErrorWithoutUrl {
|
||||
pub fn new(inner: reqwest::Error) -> Self {
|
||||
Self(inner.without_url())
|
||||
}
|
||||
}
|
||||
|
||||
fn response_from_remote(response: &Result<String, ReqwestErrorWithoutUrl>) -> String {
|
||||
match response {
|
||||
Ok(response) => {
|
||||
format!(":\n - response from remote: {}", response)
|
||||
}
|
||||
Err(error) => {
|
||||
format!(":\n - additionally, could not retrieve response from remote: {error}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProxySearchParams {
|
||||
pub deadline: Option<std::time::Instant>,
|
||||
pub try_count: u32,
|
||||
pub client: reqwest::Client,
|
||||
}
|
||||
|
||||
/// Performs a federated search on a remote host and returns the results
|
||||
pub async fn proxy_search(
|
||||
node: &Remote,
|
||||
queries: Vec<SearchQueryWithIndex>,
|
||||
federation: Federation,
|
||||
params: &ProxySearchParams,
|
||||
) -> Result<FederatedSearchResult, ProxySearchError> {
|
||||
let url = format!("{}/multi-search", node.url);
|
||||
|
||||
let federated = FederatedSearch { queries, federation: Some(federation) };
|
||||
|
||||
let search_api_key = node.search_api_key.as_deref();
|
||||
|
||||
let max_deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
|
||||
|
||||
let deadline = if let Some(deadline) = params.deadline {
|
||||
std::time::Instant::min(deadline, max_deadline)
|
||||
} else {
|
||||
max_deadline
|
||||
};
|
||||
|
||||
for i in 0..params.try_count {
|
||||
match try_proxy_search(&url, search_api_key, &federated, ¶ms.client, deadline).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(retry) => {
|
||||
let duration = retry.into_duration(i)?;
|
||||
tokio::time::sleep(duration).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
try_proxy_search(&url, search_api_key, &federated, ¶ms.client, deadline)
|
||||
.await
|
||||
.map_err(Retry::into_error)
|
||||
}
|
||||
|
||||
async fn try_proxy_search(
|
||||
url: &str,
|
||||
search_api_key: Option<&str>,
|
||||
federated: &FederatedSearch,
|
||||
client: &Client,
|
||||
deadline: std::time::Instant,
|
||||
) -> Result<FederatedSearchResult, Retry> {
|
||||
let timeout = deadline.saturating_duration_since(std::time::Instant::now());
|
||||
|
||||
let request = client.post(url).json(&federated).timeout(timeout);
|
||||
let request = if let Some(search_api_key) = search_api_key {
|
||||
request.bearer_auth(search_api_key)
|
||||
} else {
|
||||
request
|
||||
};
|
||||
let request = request.header(PROXY_SEARCH_HEADER, PROXY_SEARCH_HEADER_VALUE);
|
||||
|
||||
let response = request.send().await;
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(error) if error.is_timeout() => return Err(Retry::give_up(ProxySearchError::Timeout)),
|
||||
Err(error) => {
|
||||
return Err(Retry::retry_later(ProxySearchError::CouldNotSendRequest(
|
||||
ReqwestErrorWithoutUrl::new(error),
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
match response.status() {
|
||||
status_code if status_code.is_success() => (),
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
|
||||
return Err(Retry::give_up(ProxySearchError::AuthenticationError))
|
||||
}
|
||||
status_code if status_code.is_client_error() => {
|
||||
let response = parse_error(response).await;
|
||||
return Err(Retry::give_up(ProxySearchError::BadRequest { status_code, response }));
|
||||
}
|
||||
status_code if status_code.is_server_error() => {
|
||||
let response = parse_error(response).await;
|
||||
return Err(Retry::retry_later(ProxySearchError::RemoteError {
|
||||
status_code,
|
||||
response,
|
||||
}));
|
||||
}
|
||||
status_code => {
|
||||
tracing::warn!(
|
||||
status_code = status_code.as_u16(),
|
||||
"remote replied with unexpected status code"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let response = match parse_response(response).await {
|
||||
Ok(response) => response,
|
||||
Err(response) => {
|
||||
return Err(Retry::retry_later(ProxySearchError::CouldNotParseResponse { response }))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Always parse the body of the response of a failed request as JSON.
|
||||
async fn parse_error(response: Response) -> Result<String, ReqwestErrorWithoutUrl> {
|
||||
let bytes = match response.bytes().await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(error) => return Err(ReqwestErrorWithoutUrl::new(error)),
|
||||
};
|
||||
|
||||
Ok(parse_bytes_as_error(&bytes))
|
||||
}
|
||||
|
||||
fn parse_bytes_as_error(bytes: &[u8]) -> String {
|
||||
match serde_json::from_slice::<Value>(bytes) {
|
||||
Ok(value) => value.to_string(),
|
||||
Err(_) => String::from_utf8_lossy(bytes).into_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_response<T: DeserializeOwned>(
|
||||
response: Response,
|
||||
) -> Result<T, Result<String, ReqwestErrorWithoutUrl>> {
|
||||
let bytes = match response.bytes().await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(error) => return Err(Err(ReqwestErrorWithoutUrl::new(error))),
|
||||
};
|
||||
|
||||
match serde_json::from_slice::<T>(&bytes) {
|
||||
Ok(value) => Ok(value),
|
||||
Err(_) => Err(Ok(parse_bytes_as_error(&bytes))),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Retry {
|
||||
error: ProxySearchError,
|
||||
strategy: RetryStrategy,
|
||||
}
|
||||
|
||||
pub enum RetryStrategy {
|
||||
GiveUp,
|
||||
Retry,
|
||||
}
|
||||
|
||||
impl Retry {
|
||||
pub fn give_up(error: ProxySearchError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::GiveUp }
|
||||
}
|
||||
|
||||
pub fn retry_later(error: ProxySearchError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::Retry }
|
||||
}
|
||||
|
||||
pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, ProxySearchError> {
|
||||
match self.strategy {
|
||||
RetryStrategy::GiveUp => Err(self.error),
|
||||
RetryStrategy::Retry => {
|
||||
let retry_duration = std::time::Duration::from_nanos((10u64).pow(attempt));
|
||||
let retry_duration = retry_duration.min(std::time::Duration::from_millis(100)); // don't wait more than 100ms
|
||||
|
||||
// randomly up to double the retry duration
|
||||
let retry_duration = retry_duration
|
||||
+ rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration);
|
||||
|
||||
tracing::warn!(
|
||||
"Attempt #{}, failed with {}, retrying after {}ms.",
|
||||
attempt,
|
||||
self.error,
|
||||
retry_duration.as_millis()
|
||||
);
|
||||
Ok(retry_duration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_error(self) -> ProxySearchError {
|
||||
self.error
|
||||
}
|
||||
}
|
322
crates/meilisearch/src/search/federated/types.rs
Normal file
322
crates/meilisearch/src/search/federated/types.rs
Normal file
|
@ -0,0 +1,322 @@
|
|||
use std::collections::btree_map::Entry;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt;
|
||||
use std::vec::Vec;
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use meilisearch_types::deserr::DeserrJsonError;
|
||||
use meilisearch_types::error::deserr_codes::{
|
||||
InvalidMultiSearchFacetsByIndex, InvalidMultiSearchMaxValuesPerFacet,
|
||||
InvalidMultiSearchMergeFacets, InvalidMultiSearchQueryPosition, InvalidMultiSearchRemote,
|
||||
InvalidMultiSearchWeight, InvalidSearchLimit, InvalidSearchOffset,
|
||||
};
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::order_by_map::OrderByMap;
|
||||
use meilisearch_types::milli::OrderBy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex};
|
||||
|
||||
pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0;
|
||||
|
||||
// fields in the response
|
||||
pub const FEDERATION_HIT: &str = "_federation";
|
||||
pub const INDEX_UID: &str = "indexUid";
|
||||
pub const QUERIES_POSITION: &str = "queriesPosition";
|
||||
pub const WEIGHTED_RANKING_SCORE: &str = "weightedRankingScore";
|
||||
pub const WEIGHTED_SCORE_VALUES: &str = "weightedScoreValues";
|
||||
pub const FEDERATION_REMOTE: &str = "remote";
|
||||
|
||||
#[derive(Debug, Default, Clone, PartialEq, Serialize, deserr::Deserr, ToSchema)]
|
||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
||||
pub struct FederationOptions {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidMultiSearchWeight>)]
|
||||
#[schema(value_type = f64)]
|
||||
pub weight: Weight,
|
||||
|
||||
#[deserr(default, error = DeserrJsonError<InvalidMultiSearchRemote>)]
|
||||
pub remote: Option<String>,
|
||||
|
||||
#[deserr(default, error = DeserrJsonError<InvalidMultiSearchQueryPosition>)]
|
||||
pub query_position: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, deserr::Deserr)]
|
||||
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidMultiSearchWeight)]
|
||||
pub struct Weight(f64);
|
||||
|
||||
impl Default for Weight {
|
||||
fn default() -> Self {
|
||||
Weight(DEFAULT_FEDERATED_WEIGHT)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::TryFrom<f64> for Weight {
|
||||
type Error = InvalidMultiSearchWeight;
|
||||
|
||||
fn try_from(f: f64) -> Result<Self, Self::Error> {
|
||||
if f < 0.0 {
|
||||
Err(InvalidMultiSearchWeight)
|
||||
} else {
|
||||
Ok(Weight(f))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for Weight {
|
||||
type Target = f64;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, deserr::Deserr, Serialize, ToSchema)]
|
||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[schema(rename_all = "camelCase")]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Federation {
|
||||
#[deserr(default = super::super::DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
||||
pub limit: usize,
|
||||
#[deserr(default = super::super::DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||
pub offset: usize,
|
||||
#[deserr(default, error = DeserrJsonError<InvalidMultiSearchFacetsByIndex>)]
|
||||
pub facets_by_index: BTreeMap<IndexUid, Option<Vec<String>>>,
|
||||
#[deserr(default, error = DeserrJsonError<InvalidMultiSearchMergeFacets>)]
|
||||
pub merge_facets: Option<MergeFacets>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, deserr::Deserr, Serialize, Default, ToSchema)]
|
||||
#[deserr(error = DeserrJsonError<InvalidMultiSearchMergeFacets>, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[schema(rename_all = "camelCase")]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct MergeFacets {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidMultiSearchMaxValuesPerFacet>)]
|
||||
pub max_values_per_facet: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, deserr::Deserr, Serialize, ToSchema)]
|
||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[schema(rename_all = "camelCase")]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FederatedSearch {
|
||||
pub queries: Vec<SearchQueryWithIndex>,
|
||||
#[deserr(default)]
|
||||
pub federation: Option<Federation>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[schema(rename_all = "camelCase")]
|
||||
pub struct FederatedSearchResult {
|
||||
pub hits: Vec<SearchHit>,
|
||||
pub processing_time_ms: u128,
|
||||
#[serde(flatten)]
|
||||
pub hits_info: HitsInfo,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub semantic_hit_count: Option<u32>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Option<BTreeMap<String, BTreeMap<String, u64>>>)]
|
||||
pub facet_distribution: Option<BTreeMap<String, IndexMap<String, u64>>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub facet_stats: Option<BTreeMap<String, FacetStats>>,
|
||||
#[serde(default, skip_serializing_if = "FederatedFacets::is_empty")]
|
||||
pub facets_by_index: FederatedFacets,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub remote_errors: Option<BTreeMap<String, ResponseError>>,
|
||||
|
||||
// These fields are only used for analytics purposes
|
||||
#[serde(skip)]
|
||||
pub degraded: bool,
|
||||
#[serde(skip)]
|
||||
pub used_negative_operator: bool,
|
||||
}
|
||||
|
||||
impl fmt::Debug for FederatedSearchResult {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let FederatedSearchResult {
|
||||
hits,
|
||||
processing_time_ms,
|
||||
hits_info,
|
||||
semantic_hit_count,
|
||||
degraded,
|
||||
used_negative_operator,
|
||||
facet_distribution,
|
||||
facet_stats,
|
||||
facets_by_index,
|
||||
remote_errors,
|
||||
} = self;
|
||||
|
||||
let mut debug = f.debug_struct("SearchResult");
|
||||
// The most important thing when looking at a search result is the time it took to process
|
||||
debug.field("processing_time_ms", &processing_time_ms);
|
||||
debug.field("hits", &format!("[{} hits returned]", hits.len()));
|
||||
debug.field("hits_info", &hits_info);
|
||||
if *used_negative_operator {
|
||||
debug.field("used_negative_operator", used_negative_operator);
|
||||
}
|
||||
if *degraded {
|
||||
debug.field("degraded", degraded);
|
||||
}
|
||||
if let Some(facet_distribution) = facet_distribution {
|
||||
debug.field("facet_distribution", &facet_distribution);
|
||||
}
|
||||
if let Some(facet_stats) = facet_stats {
|
||||
debug.field("facet_stats", &facet_stats);
|
||||
}
|
||||
if let Some(semantic_hit_count) = semantic_hit_count {
|
||||
debug.field("semantic_hit_count", &semantic_hit_count);
|
||||
}
|
||||
if !facets_by_index.is_empty() {
|
||||
debug.field("facets_by_index", &facets_by_index);
|
||||
}
|
||||
if let Some(remote_errors) = remote_errors {
|
||||
debug.field("remote_errors", &remote_errors);
|
||||
}
|
||||
|
||||
debug.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)]
|
||||
pub struct FederatedFacets(pub BTreeMap<String, ComputedFacets>);
|
||||
|
||||
impl FederatedFacets {
|
||||
pub fn insert(&mut self, index: String, facets: Option<ComputedFacets>) {
|
||||
if let Some(facets) = facets {
|
||||
self.0.insert(index, facets);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
pub fn merge(
|
||||
self,
|
||||
MergeFacets { max_values_per_facet }: MergeFacets,
|
||||
facet_order: BTreeMap<String, (String, OrderBy)>,
|
||||
) -> Option<ComputedFacets> {
|
||||
if self.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut distribution: BTreeMap<String, _> = Default::default();
|
||||
let mut stats: BTreeMap<String, FacetStats> = Default::default();
|
||||
|
||||
for facets_by_index in self.0.into_values() {
|
||||
for (facet, index_distribution) in facets_by_index.distribution {
|
||||
match distribution.entry(facet) {
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(index_distribution);
|
||||
}
|
||||
Entry::Occupied(mut entry) => {
|
||||
let distribution = entry.get_mut();
|
||||
|
||||
for (value, index_count) in index_distribution {
|
||||
distribution
|
||||
.entry(value)
|
||||
.and_modify(|count| *count += index_count)
|
||||
.or_insert(index_count);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (facet, index_stats) in facets_by_index.stats {
|
||||
match stats.entry(facet) {
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(index_stats);
|
||||
}
|
||||
Entry::Occupied(mut entry) => {
|
||||
let stats = entry.get_mut();
|
||||
|
||||
stats.min = f64::min(stats.min, index_stats.min);
|
||||
stats.max = f64::max(stats.max, index_stats.max);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fixup order
|
||||
for (facet, values) in &mut distribution {
|
||||
let order_by = facet_order.get(facet).map(|(_, order)| *order).unwrap_or_default();
|
||||
|
||||
match order_by {
|
||||
OrderBy::Lexicographic => {
|
||||
values.sort_unstable_by(|left, _, right, _| left.cmp(right))
|
||||
}
|
||||
OrderBy::Count => {
|
||||
values.sort_unstable_by(|_, left, _, right| {
|
||||
left.cmp(right)
|
||||
// biggest first
|
||||
.reverse()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_values_per_facet) = max_values_per_facet {
|
||||
values.truncate(max_values_per_facet)
|
||||
};
|
||||
}
|
||||
|
||||
Some(ComputedFacets { distribution, stats })
|
||||
}
|
||||
|
||||
pub(crate) fn append(&mut self, FederatedFacets(remote_facets_by_index): FederatedFacets) {
|
||||
for (index, remote_facets) in remote_facets_by_index {
|
||||
let merged_facets = self.0.entry(index).or_default();
|
||||
|
||||
for (remote_facet, remote_stats) in remote_facets.stats {
|
||||
match merged_facets.stats.entry(remote_facet) {
|
||||
Entry::Vacant(vacant_entry) => {
|
||||
vacant_entry.insert(remote_stats);
|
||||
}
|
||||
Entry::Occupied(mut occupied_entry) => {
|
||||
let stats = occupied_entry.get_mut();
|
||||
stats.min = f64::min(stats.min, remote_stats.min);
|
||||
stats.max = f64::max(stats.max, remote_stats.max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (remote_facet, remote_values) in remote_facets.distribution {
|
||||
let merged_facet = merged_facets.distribution.entry(remote_facet).or_default();
|
||||
for (remote_value, remote_count) in remote_values {
|
||||
let count = merged_facet.entry(remote_value).or_default();
|
||||
*count += remote_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sort_and_truncate(&mut self, facet_order: BTreeMap<String, (OrderByMap, usize)>) {
|
||||
for (index, facets) in &mut self.0 {
|
||||
let Some((order_by, max_values_per_facet)) = facet_order.get(index) else {
|
||||
continue;
|
||||
};
|
||||
for (facet, values) in &mut facets.distribution {
|
||||
match order_by.get(facet) {
|
||||
OrderBy::Lexicographic => {
|
||||
values.sort_unstable_by(|left, _, right, _| left.cmp(right))
|
||||
}
|
||||
OrderBy::Count => {
|
||||
values.sort_unstable_by(|_, left, _, right| {
|
||||
left.cmp(right)
|
||||
// biggest first
|
||||
.reverse()
|
||||
})
|
||||
}
|
||||
}
|
||||
values.truncate(*max_values_per_facet);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
88
crates/meilisearch/src/search/federated/weighted_scores.rs
Normal file
88
crates/meilisearch/src/search/federated/weighted_scores.rs
Normal file
|
@ -0,0 +1,88 @@
|
|||
use std::cmp::Ordering;
|
||||
|
||||
use meilisearch_types::milli::score_details::{self, WeightedScoreValue};
|
||||
|
||||
pub fn compare(
|
||||
mut left_it: impl Iterator<Item = WeightedScoreValue>,
|
||||
left_weighted_global_score: f64,
|
||||
mut right_it: impl Iterator<Item = WeightedScoreValue>,
|
||||
right_weighted_global_score: f64,
|
||||
) -> Ordering {
|
||||
loop {
|
||||
let left = left_it.next();
|
||||
let right = right_it.next();
|
||||
|
||||
match (left, right) {
|
||||
(None, None) => return Ordering::Equal,
|
||||
(None, Some(_)) => return Ordering::Less,
|
||||
(Some(_), None) => return Ordering::Greater,
|
||||
(
|
||||
Some(
|
||||
WeightedScoreValue::WeightedScore(left) | WeightedScoreValue::VectorSort(left),
|
||||
),
|
||||
Some(
|
||||
WeightedScoreValue::WeightedScore(right)
|
||||
| WeightedScoreValue::VectorSort(right),
|
||||
),
|
||||
) => {
|
||||
if (left - right).abs() <= f64::EPSILON {
|
||||
continue;
|
||||
}
|
||||
return left.partial_cmp(&right).unwrap();
|
||||
}
|
||||
(
|
||||
Some(WeightedScoreValue::Sort { asc: left_asc, value: left }),
|
||||
Some(WeightedScoreValue::Sort { asc: right_asc, value: right }),
|
||||
) => {
|
||||
if left_asc != right_asc {
|
||||
return left_weighted_global_score
|
||||
.partial_cmp(&right_weighted_global_score)
|
||||
.unwrap();
|
||||
}
|
||||
match score_details::compare_sort_values(left_asc, &left, &right) {
|
||||
Ordering::Equal => continue,
|
||||
order => return order,
|
||||
}
|
||||
}
|
||||
(
|
||||
Some(WeightedScoreValue::GeoSort { asc: left_asc, distance: left }),
|
||||
Some(WeightedScoreValue::GeoSort { asc: right_asc, distance: right }),
|
||||
) => {
|
||||
if left_asc != right_asc {
|
||||
continue;
|
||||
}
|
||||
match (left, right) {
|
||||
(None, None) => continue,
|
||||
(None, Some(_)) => return Ordering::Less,
|
||||
(Some(_), None) => return Ordering::Greater,
|
||||
(Some(left), Some(right)) => {
|
||||
if (left - right).abs() <= f64::EPSILON {
|
||||
continue;
|
||||
}
|
||||
return left.partial_cmp(&right).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
// not comparable details, use global
|
||||
(Some(WeightedScoreValue::WeightedScore(_)), Some(_))
|
||||
| (Some(_), Some(WeightedScoreValue::WeightedScore(_)))
|
||||
| (Some(WeightedScoreValue::VectorSort(_)), Some(_))
|
||||
| (Some(_), Some(WeightedScoreValue::VectorSort(_)))
|
||||
| (Some(WeightedScoreValue::GeoSort { .. }), Some(WeightedScoreValue::Sort { .. }))
|
||||
| (Some(WeightedScoreValue::Sort { .. }), Some(WeightedScoreValue::GeoSort { .. })) => {
|
||||
let left_count = left_it.count();
|
||||
let right_count = right_it.count();
|
||||
// compare how many remaining groups of rules each side has.
|
||||
// the group with the most remaining groups wins.
|
||||
return left_count
|
||||
.cmp(&right_count)
|
||||
// breaks ties with the global ranking score
|
||||
.then_with(|| {
|
||||
left_weighted_global_score
|
||||
.partial_cmp(&right_weighted_global_score)
|
||||
.unwrap()
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue