MeiliSearch/milli/src/localized_attributes_rules.rs

115 lines
3.7 KiB
Rust
Raw Normal View History

use std::collections::HashMap;
use charabia::Language;
use serde::{Deserialize, Serialize};
use crate::fields_ids_map::FieldsIdsMap;
use crate::FieldId;
/// A rule that defines which locales are supported for a given attribute.
///
/// The rule is a list of attribute patterns and a list of locales.
/// The attribute patterns are matched against the attribute name.
/// The pattern `*` matches any attribute name.
/// The pattern `attribute_name*` matches any attribute name that starts with `attribute_name`.
/// The pattern `*attribute_name` matches any attribute name that ends with `attribute_name`.
/// The pattern `*attribute_name*` matches any attribute name that contains `attribute_name`.
/// Any other pattern matches only an attribute name that is equal to it: a `*` that is neither
/// the first nor the last character of a pattern is compared literally, not as a wildcard.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LocalizedAttributesRule {
/// Patterns tested against attribute names; the rule applies as soon as one of them matches.
pub attribute_patterns: Vec<String>,
/// Locales associated with every attribute name matched by one of the patterns.
pub locales: Vec<Language>,
}
impl LocalizedAttributesRule {
pub fn new(attribute_patterns: Vec<String>, locales: Vec<Language>) -> Self {
Self { attribute_patterns, locales }
}
pub fn match_str(&self, str: &str) -> bool {
self.attribute_patterns.iter().any(|pattern| match_pattern(pattern.as_str(), str))
}
pub fn locales(&self) -> &[Language] {
&self.locales
}
}
/// Tests an attribute name against a single localization pattern.
///
/// Supported forms, where `*` is only special as the first and/or last character:
/// - `*` matches every name;
/// - `*infix*` matches names containing `infix`;
/// - `prefix*` matches names starting with `prefix`;
/// - `*suffix` matches names ending with `suffix`;
/// - anything else must be equal to the name.
///
/// Returns `true` when `str` satisfies `pattern`.
fn match_pattern(pattern: &str, str: &str) -> bool {
    if pattern == "*" {
        true
    } else if let Some(rest) = pattern.strip_prefix('*') {
        // A leading wildcard: either `*infix*` or `*suffix`, depending on the tail.
        match rest.strip_suffix('*') {
            Some(infix) => str.contains(infix),
            None => str.ends_with(rest),
        }
    } else if let Some(prefix) = pattern.strip_suffix('*') {
        str.starts_with(prefix)
    } else {
        // No leading or trailing wildcard: any inner `*` is treated literally.
        pattern == str
    }
}
/// Per-field view of the localization rules: associates field ids with the locales
/// of the rules whose patterns match the field name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalizedFieldIds {
/// Sorted, deduplicated locales for each field matched by at least one rule;
/// fields matched by no rule have no entry.
field_id_to_locales: HashMap<FieldId, Vec<Language>>,
}
impl LocalizedFieldIds {
    /// Resolves the localization rules against a set of fields.
    ///
    /// For every id yielded by `fields_ids` that has a name in `fields_ids_map`, the locales of
    /// all rules matching that name are merged, sorted and deduplicated. Ids without a name and
    /// fields matched by no rule are left out. When `rules` is `None` the result is empty and
    /// `fields_ids` is not consumed.
    pub fn new<I: Iterator<Item = FieldId>>(
        rules: &Option<Vec<LocalizedAttributesRule>>,
        fields_ids_map: &FieldsIdsMap,
        fields_ids: I,
    ) -> Self {
        let mut field_id_to_locales = HashMap::new();

        if let Some(rules) = rules {
            for field_id in fields_ids {
                // Ids unknown to the map cannot be matched against any pattern.
                let field_name = match fields_ids_map.name(field_id) {
                    Some(name) => name,
                    None => continue,
                };

                let mut matched: Vec<Language> = rules
                    .iter()
                    .filter(|rule| rule.match_str(field_name))
                    .flat_map(|rule| rule.locales.iter().copied())
                    .collect();

                if matched.is_empty() {
                    continue;
                }

                // Several rules may contribute the same locale; keep each one once.
                matched.sort();
                matched.dedup();
                field_id_to_locales.insert(field_id, matched);
            }
        }

        Self { field_id_to_locales }
    }

    /// Returns the locales recorded for `fields_id`, or `None` if no rule matched that field.
    pub fn locales<'a>(&'a self, fields_id: FieldId) -> Option<&'a [Language]> {
        let found = self.field_id_to_locales.get(&fields_id)?;
        Some(found.as_slice())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers every supported wildcard form plus the literal fallback.
    #[test]
    fn test_match_pattern() {
        // Lone wildcard.
        assert!(match_pattern("*", "test"));
        // Prefix patterns.
        assert!(match_pattern("test*", "test"));
        assert!(match_pattern("test*", "testa"));
        // Suffix patterns.
        assert!(match_pattern("*test", "test"));
        assert!(match_pattern("*test", "atest"));
        // Infix patterns.
        assert!(match_pattern("*test*", "test"));
        assert!(match_pattern("*test*", "atesta"));
        assert!(match_pattern("*test*", "atest"));
        assert!(match_pattern("*test*", "testa"));
        // Non-matching inputs; an inner `*` is not a wildcard.
        assert!(!match_pattern("test*test", "test"));
        assert!(!match_pattern("*test", "testa"));
        assert!(!match_pattern("test*", "atest"));
    }

    /// Patterns without a leading or trailing `*` require strict equality.
    #[test]
    fn test_match_pattern_exact() {
        assert!(match_pattern("test", "test"));
        assert!(!match_pattern("test", "testa"));
        assert!(!match_pattern("test", "atest"));
        assert!(match_pattern("test*test", "test*test"));
    }

    /// Degenerate patterns: empty pattern and doubled wildcard.
    #[test]
    fn test_match_pattern_edge_cases() {
        assert!(match_pattern("", ""));
        assert!(!match_pattern("", "test"));
        assert!(match_pattern("*", ""));
        assert!(match_pattern("**", "test"));
        assert!(match_pattern("**", ""));
    }
}