387 lines
14 KiB
Rust
Raw Permalink Normal View History

2021-08-23 11:37:18 +02:00
use std::cmp::Reverse;
2021-06-17 13:56:09 +02:00
use std::collections::HashSet;
2024-11-18 17:39:55 +01:00
use std::io::Write;
2021-06-17 13:56:09 +02:00
2021-06-03 14:44:53 +02:00
use big_s::S;
2024-11-18 17:39:55 +01:00
use bumpalo::Bump;
2021-06-17 13:56:09 +02:00
use either::{Either, Left, Right};
2021-06-03 14:44:53 +02:00
use heed::EnvOpenOptions;
use maplit::{btreemap, hashset};
2024-12-10 16:30:48 +01:00
use milli::progress::Progress;
2024-11-18 17:39:55 +01:00
use milli::update::new::indexer;
use milli::update::{IndexDocumentsMethod, IndexerConfig, Settings};
use milli::vector::EmbeddingConfigs;
use milli::{AscDesc, Criterion, DocumentId, Index, Member, TermsMatchingStrategy};
use serde::{Deserialize, Deserializer};
2021-06-03 14:44:53 +02:00
use slice_group_by::GroupBy;
2021-06-17 14:24:59 +02:00
mod distinct;
mod facet_distribution;
2021-06-17 13:56:09 +02:00
mod filters;
2022-10-13 23:34:17 +05:30
mod phrase_search;
2021-06-03 14:44:53 +02:00
mod query_criteria;
mod sort;
2022-04-01 10:50:01 +02:00
mod typo_tolerance;
2021-06-03 14:44:53 +02:00
2022-10-10 22:28:03 +09:00
/// Query shared by the search tests. Each of its words (`hello`, `world`, `america`) has a
/// synonym configured by `setup_search_index_with_criteria`.
pub const TEST_QUERY: &str = "hello world america";

/// External ids of the test documents, `"A"` through `"Q"`.
pub const EXTERNAL_DOCUMENTS_IDS: &[&str; 17] = &[
    "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q",
];
2021-06-03 14:44:53 +02:00
/// Test dataset: a stream of JSON documents embedded at compile time. It is indexed by
/// `setup_search_index_with_criteria`, and parsed again by `expected_order` and
/// `expected_filtered_ids` to compute the results the engine is expected to return.
pub const CONTENT: &str = include_str!("../assets/test_set.ndjson");
pub fn setup_search_index_with_criteria(criteria: &[Criterion]) -> Index {
let path = tempfile::tempdir().unwrap();
let mut options = EnvOpenOptions::new();
options.map_size(10 * 1024 * 1024); // 10 MB
let index = Index::new(options, &path).unwrap();
let mut wtxn = index.write_txn().unwrap();
let config = IndexerConfig::default();
2021-06-03 14:44:53 +02:00
let mut builder = Settings::new(&mut wtxn, &index, &config);
2021-06-03 14:44:53 +02:00
2023-01-11 12:14:17 +01:00
builder.set_criteria(criteria.to_vec());
2021-06-16 18:33:33 +02:00
builder.set_filterable_fields(hashset! {
2021-06-03 14:44:53 +02:00
S("tag"),
S("asc_desc_rank"),
2021-09-08 13:08:48 +02:00
S("_geo"),
S("opt1"),
S("opt1.opt2"),
S("tag_in")
2021-06-03 14:44:53 +02:00
});
2021-08-23 11:37:18 +02:00
builder.set_sortable_fields(hashset! {
S("tag"),
S("asc_desc_rank"),
});
builder.set_synonyms(btreemap! {
2021-06-03 14:44:53 +02:00
S("hello") => vec![S("good morning")],
S("world") => vec![S("earth")],
S("america") => vec![S("the united states")],
});
2021-06-16 18:33:33 +02:00
builder.set_searchable_fields(vec![S("title"), S("description")]);
builder.execute(|_| (), || false).unwrap();
2024-11-19 16:49:00 +01:00
wtxn.commit().unwrap();
2021-06-03 14:44:53 +02:00
// index documents
let config = IndexerConfig { max_memory: Some(10 * 1024 * 1024), ..Default::default() };
2024-11-19 16:49:00 +01:00
let rtxn = index.read_txn().unwrap();
let mut wtxn = index.write_txn().unwrap();
2024-11-18 17:39:55 +01:00
let db_fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
let mut new_fields_ids_map = db_fields_ids_map.clone();
2024-11-18 17:39:55 +01:00
let embedders = EmbeddingConfigs::default();
let mut indexer = indexer::DocumentOperation::new(IndexDocumentsMethod::ReplaceDocuments);
2021-10-24 14:41:36 +02:00
2024-11-18 17:39:55 +01:00
let mut file = tempfile::tempfile().unwrap();
file.write_all(CONTENT.as_bytes()).unwrap();
file.sync_all().unwrap();
let payload = unsafe { memmap2::Mmap::map(&file).unwrap() };
// index documents
2024-11-18 17:39:55 +01:00
indexer.add_documents(&payload).unwrap();
let indexer_alloc = Bump::new();
2024-11-20 14:58:25 +01:00
let (document_changes, operation_stats, primary_key) = indexer
2024-11-20 15:10:09 +01:00
.into_changes(
&indexer_alloc,
&index,
&rtxn,
None,
&mut new_fields_ids_map,
&|| false,
2024-12-10 16:30:48 +01:00
Progress::default(),
2024-11-20 15:10:09 +01:00
)
2024-11-20 14:58:25 +01:00
.unwrap();
2024-11-18 17:39:55 +01:00
2024-11-19 16:49:00 +01:00
if let Some(error) = operation_stats.into_iter().find_map(|stat| stat.error) {
panic!("{error}");
}
2024-11-18 17:39:55 +01:00
indexer::index(
&mut wtxn,
&index,
&milli::ThreadPoolNoAbortBuilder::new().build().unwrap(),
2024-11-18 17:39:55 +01:00
config.grenad_parameters(),
&db_fields_ids_map,
new_fields_ids_map,
primary_key,
&document_changes,
embedders,
&|| false,
2024-12-10 16:30:48 +01:00
&Progress::default(),
2024-11-18 17:39:55 +01:00
)
.unwrap();
2021-06-03 14:44:53 +02:00
wtxn.commit().unwrap();
2024-11-18 17:39:55 +01:00
drop(rtxn);
2021-06-03 14:44:53 +02:00
index
}
pub fn internal_to_external_ids(index: &Index, internal_ids: &[DocumentId]) -> Vec<String> {
2023-01-17 18:01:26 +01:00
let rtxn = index.read_txn().unwrap();
let docid_map = index.external_documents_ids();
let docid_map: std::collections::HashMap<_, _> = EXTERNAL_DOCUMENTS_IDS
.iter()
.map(|id| (docid_map.get(&rtxn, id).unwrap().unwrap(), id))
.collect();
2021-06-03 14:44:53 +02:00
internal_ids.iter().map(|id| docid_map.get(id).unwrap().to_string()).collect()
}
2021-06-16 18:33:33 +02:00
pub fn expected_order(
criteria: &[Criterion],
2022-08-18 17:36:08 +02:00
optional_words: TermsMatchingStrategy,
2021-08-23 11:37:18 +02:00
sort_by: &[AscDesc],
2021-06-16 18:33:33 +02:00
) -> Vec<TestDocument> {
let dataset =
serde_json::Deserializer::from_str(CONTENT).into_iter().map(|r| r.unwrap()).collect();
2021-06-03 14:44:53 +02:00
let mut groups: Vec<Vec<TestDocument>> = vec![dataset];
for criterion in criteria {
let mut new_groups = Vec::new();
for group in groups.iter_mut() {
match criterion {
Criterion::Attribute => {
group.sort_by_key(|d| d.attribute_rank);
2021-06-16 18:33:33 +02:00
new_groups
.extend(group.linear_group_by_key(|d| d.attribute_rank).map(Vec::from));
}
2021-06-03 14:44:53 +02:00
Criterion::Exactness => {
group.sort_by_key(|d| d.exact_rank);
new_groups.extend(group.linear_group_by_key(|d| d.exact_rank).map(Vec::from));
2021-06-16 18:33:33 +02:00
}
2021-06-03 14:44:53 +02:00
Criterion::Proximity => {
group.sort_by_key(|d| d.proximity_rank);
2021-06-16 18:33:33 +02:00
new_groups
.extend(group.linear_group_by_key(|d| d.proximity_rank).map(Vec::from));
}
2021-09-01 17:43:18 +02:00
Criterion::Sort if sort_by == [AscDesc::Asc(Member::Field(S("tag")))] => {
2021-08-23 11:37:18 +02:00
group.sort_by_key(|d| d.sort_by_rank);
new_groups.extend(group.linear_group_by_key(|d| d.sort_by_rank).map(Vec::from));
}
2021-09-01 17:43:18 +02:00
Criterion::Sort if sort_by == [AscDesc::Desc(Member::Field(S("tag")))] => {
2021-08-23 11:37:18 +02:00
group.sort_by_key(|d| Reverse(d.sort_by_rank));
new_groups.extend(group.linear_group_by_key(|d| d.sort_by_rank).map(Vec::from));
}
2021-06-03 14:44:53 +02:00
Criterion::Typo => {
group.sort_by_key(|d| d.typo_rank);
new_groups.extend(group.linear_group_by_key(|d| d.typo_rank).map(Vec::from));
2021-06-16 18:33:33 +02:00
}
2021-06-03 14:44:53 +02:00
Criterion::Words => {
group.sort_by_key(|d| d.word_rank);
new_groups.extend(group.linear_group_by_key(|d| d.word_rank).map(Vec::from));
2021-06-16 18:33:33 +02:00
}
2021-06-08 12:33:02 +02:00
Criterion::Asc(field_name) if field_name == "asc_desc_rank" => {
2021-06-03 14:44:53 +02:00
group.sort_by_key(|d| d.asc_desc_rank);
2021-06-16 18:33:33 +02:00
new_groups
.extend(group.linear_group_by_key(|d| d.asc_desc_rank).map(Vec::from));
}
Criterion::Desc(field_name) if field_name == "asc_desc_rank" => {
2021-08-23 11:37:18 +02:00
group.sort_by_key(|d| Reverse(d.asc_desc_rank));
2021-06-16 18:33:33 +02:00
new_groups
.extend(group.linear_group_by_key(|d| d.asc_desc_rank).map(Vec::from));
}
2021-08-23 11:37:18 +02:00
Criterion::Asc(_) | Criterion::Desc(_) | Criterion::Sort => {
new_groups.push(group.clone())
}
2021-06-03 14:44:53 +02:00
}
}
groups = std::mem::take(&mut new_groups);
}
match optional_words {
TermsMatchingStrategy::Last => groups.into_iter().flatten().collect(),
2024-05-29 11:06:39 +02:00
TermsMatchingStrategy::Frequency => groups.into_iter().flatten().collect(),
TermsMatchingStrategy::All => {
groups.into_iter().flatten().filter(|d| d.word_rank == 0).collect()
}
2021-06-03 14:44:53 +02:00
}
}
2021-06-17 13:56:09 +02:00
fn execute_filter(filter: &str, document: &TestDocument) -> Option<String> {
let mut id = None;
if let Some((field, filter)) = filter.split_once("!=") {
2023-01-17 18:01:26 +01:00
if field == "tag" && document.tag != filter
|| (field == "asc_desc_rank"
&& Ok(&document.asc_desc_rank) != filter.parse::<u32>().as_ref())
{
id = Some(document.id.clone())
}
2022-10-10 22:28:03 +09:00
} else if let Some((field, filter)) = filter.split_once('=') {
2023-01-17 18:01:26 +01:00
if field == "tag" && document.tag == filter
|| (field == "asc_desc_rank"
&& document.asc_desc_rank == filter.parse::<u32>().unwrap())
2021-06-17 13:56:09 +02:00
{
id = Some(document.id.clone())
}
2022-10-10 22:28:03 +09:00
} else if let Some(("asc_desc_rank", filter)) = filter.split_once('<') {
2021-06-17 13:56:09 +02:00
if document.asc_desc_rank < filter.parse().unwrap() {
id = Some(document.id.clone())
}
2022-10-10 22:28:03 +09:00
} else if let Some(("asc_desc_rank", filter)) = filter.split_once('>') {
2021-06-17 13:56:09 +02:00
if document.asc_desc_rank > filter.parse().unwrap() {
id = Some(document.id.clone())
}
2021-09-08 13:08:48 +02:00
} else if filter.starts_with("_geoRadius") {
id = (document.geo_rank < 100000).then(|| document.id.clone());
} else if filter.starts_with("NOT _geoRadius") {
id = (document.geo_rank > 1000000).then(|| document.id.clone());
} else if matches!(filter, "opt1 EXISTS" | "NOT opt1 NOT EXISTS") {
id = document.opt1.is_some().then(|| document.id.clone());
} else if matches!(filter, "NOT opt1 EXISTS" | "opt1 NOT EXISTS") {
id = document.opt1.is_none().then(|| document.id.clone());
} else if matches!(filter, "opt1.opt2 EXISTS") {
if document.opt1opt2.is_some() {
id = Some(document.id.clone());
} else if let Some(opt1) = &document.opt1 {
id = contains_key_rec(opt1, "opt2").then(|| document.id.clone());
}
} else if matches!(filter, "opt1 IS NULL" | "NOT opt1 IS NOT NULL") {
2023-03-09 10:04:27 +01:00
id = document.opt1.as_ref().map_or(false, |v| v.is_null()).then(|| document.id.clone());
} else if matches!(filter, "NOT opt1 IS NULL" | "opt1 IS NOT NULL") {
2023-03-09 10:04:27 +01:00
id = document.opt1.as_ref().map_or(true, |v| !v.is_null()).then(|| document.id.clone());
} else if matches!(filter, "opt1.opt2 IS NULL") {
2023-03-09 10:04:27 +01:00
if document.opt1opt2.as_ref().map_or(false, |v| v.is_null()) {
id = Some(document.id.clone());
} else if let Some(opt1) = &document.opt1 {
if !opt1.is_null() {
id = contains_null_rec(opt1, "opt2").then(|| document.id.clone());
}
2023-03-15 14:57:17 +01:00
}
} else if matches!(filter, "opt1 IS EMPTY" | "NOT opt1 IS NOT EMPTY") {
id = document.opt1.as_ref().map_or(false, is_empty_value).then(|| document.id.clone());
2023-03-15 14:57:17 +01:00
} else if matches!(filter, "NOT opt1 IS EMPTY" | "opt1 IS NOT EMPTY") {
id = document
.opt1
.as_ref()
.map_or(true, |v| !is_empty_value(v))
.then(|| document.id.clone());
} else if matches!(filter, "opt1.opt2 IS EMPTY") {
if document.opt1opt2.as_ref().map_or(false, is_empty_value) {
2023-03-15 14:57:17 +01:00
id = Some(document.id.clone());
2023-03-09 10:04:27 +01:00
}
} else if matches!(
filter,
"tag_in IN[1, 2, 3, four, five]" | "NOT tag_in NOT IN[1, 2, 3, four, five]"
) {
id = matches!(document.id.as_str(), "A" | "B" | "C" | "D" | "E")
.then(|| document.id.clone());
} else if matches!(filter, "tag_in NOT IN[1, 2, 3, four, five]") {
id = (!matches!(document.id.as_str(), "A" | "B" | "C" | "D" | "E"))
.then(|| document.id.clone());
2021-06-17 13:56:09 +02:00
}
id
}
2023-03-15 14:57:17 +01:00
pub fn is_empty_value(v: &serde_json::Value) -> bool {
match v {
serde_json::Value::String(s) => s.is_empty(),
serde_json::Value::Array(a) => a.is_empty(),
serde_json::Value::Object(o) => o.is_empty(),
_ => false,
}
}
/// Returns `true` if any object reachable from `v` has a field named `key`.
///
/// The search descends through both arrays and objects. Scalars never contain keys.
pub fn contains_key_rec(v: &serde_json::Value, key: &str) -> bool {
    match v {
        serde_json::Value::Array(items) => items.iter().any(|item| contains_key_rec(item, key)),
        serde_json::Value::Object(fields) => {
            fields.iter().any(|(name, child)| name == key || contains_key_rec(child, key))
        }
        _ => false,
    }
}
/// Returns `true` if any object reachable from `v` has a field named `key` whose value is
/// `null`.
///
/// The search descends through both objects and arrays. Scalars never match.
pub fn contains_null_rec(v: &serde_json::Value, key: &str) -> bool {
    match v {
        serde_json::Value::Object(fields) => fields
            .iter()
            .any(|(name, child)| (name == key && child.is_null()) || contains_null_rec(child, key)),
        serde_json::Value::Array(items) => items.iter().any(|item| contains_null_rec(item, key)),
        _ => false,
    }
}
2021-06-17 13:56:09 +02:00
pub fn expected_filtered_ids(filters: Vec<Either<Vec<&str>, &str>>) -> HashSet<String> {
let dataset: Vec<TestDocument> =
2021-06-17 13:56:09 +02:00
serde_json::Deserializer::from_str(CONTENT).into_iter().map(|r| r.unwrap()).collect();
let mut filtered_ids: HashSet<_> = dataset.iter().map(|d| d.id.clone()).collect();
for either in filters {
let ids = match either {
Left(array) => array
.into_iter()
.map(|f| {
let ids: HashSet<String> =
dataset.iter().filter_map(|d| execute_filter(f, d)).collect();
ids
})
.reduce(|a, b| a.union(&b).cloned().collect())
.unwrap(),
Right(filter) => {
let ids: HashSet<String> =
dataset.iter().filter_map(|d| execute_filter(filter, d)).collect();
ids
}
};
filtered_ids = filtered_ids.intersection(&ids).cloned().collect();
}
filtered_ids
}
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
2021-06-03 14:44:53 +02:00
pub struct TestDocument {
pub id: String,
pub word_rank: u32,
pub typo_rank: u32,
pub proximity_rank: u32,
pub attribute_rank: u32,
pub exact_rank: u32,
pub asc_desc_rank: u32,
2021-08-23 11:37:18 +02:00
pub sort_by_rank: u32,
2021-09-08 13:08:48 +02:00
pub geo_rank: u32,
2021-06-03 14:44:53 +02:00
pub title: String,
pub description: String,
pub tag: String,
#[serde(default, deserialize_with = "some_option")]
pub opt1: Option<serde_json::Value>,
#[serde(default, deserialize_with = "some_option", rename = "opt1.opt2")]
pub opt1opt2: Option<serde_json::Value>,
}
/// Deserializes a field that is present in the input as `Some(value)`.
///
/// An explicit JSON `null` becomes `Some(Value::Null)`. This is used together with
/// `#[serde(default)]`: serde only calls this function when the field is present, so an absent
/// field stays `None` and remains distinguishable from a `null` one.
fn some_option<'de, D>(deserializer: D) -> Result<Option<serde_json::Value>, D::Error>
where
    D: Deserializer<'de>,
{
    serde_json::Value::deserialize(deserializer).map(Some)
}