Add "position" part of the attribute ranking rule

This commit is contained in:
Loïc Lecrenier 2023-04-13 10:46:09 +02:00
parent 8edad8291b
commit bd9aba4d77
11 changed files with 314 additions and 31 deletions

View File

@ -34,6 +34,10 @@ pub struct DatabaseCache<'ctx> {
pub words_fst: Option<fst::Set<Cow<'ctx, [u8]>>>,
pub word_position_docids: FxHashMap<(Interned<String>, u16), Option<&'ctx [u8]>>,
pub word_prefix_position_docids: FxHashMap<(Interned<String>, u16), Option<&'ctx [u8]>>,
pub word_positions: FxHashMap<Interned<String>, Vec<u16>>,
pub word_prefix_positions: FxHashMap<Interned<String>, Vec<u16>>,
pub word_fid_docids: FxHashMap<(Interned<String>, u16), Option<&'ctx [u8]>>,
pub word_prefix_fid_docids: FxHashMap<(Interned<String>, u16), Option<&'ctx [u8]>>,
pub word_fids: FxHashMap<Interned<String>, Vec<u16>>,
@ -356,4 +360,77 @@ impl<'ctx> SearchContext<'ctx> {
};
Ok(fids)
}
pub fn get_db_word_prefix_position_docids(
&mut self,
word_prefix: Interned<String>,
position: u16,
) -> Result<Option<RoaringBitmap>> {
DatabaseCache::get_value(
self.txn,
(word_prefix, position),
&(self.word_interner.get(word_prefix).as_str(), position),
&mut self.db_cache.word_prefix_position_docids,
self.index.word_prefix_position_docids.remap_data_type::<ByteSlice>(),
)?
.map(|bytes| CboRoaringBitmapCodec::bytes_decode(bytes).ok_or(heed::Error::Decoding.into()))
.transpose()
}
pub fn get_db_word_positions(&mut self, word: Interned<String>) -> Result<Vec<u16>> {
let positions = match self.db_cache.word_positions.entry(word) {
Entry::Occupied(positions) => positions.get().clone(),
Entry::Vacant(entry) => {
let mut key = self.word_interner.get(word).as_bytes().to_owned();
key.push(0);
let mut positions = vec![];
let remap_key_type = self
.index
.word_position_docids
.remap_types::<ByteSlice, ByteSlice>()
.prefix_iter(self.txn, &key)?
.remap_key_type::<StrBEU16Codec>();
for result in remap_key_type {
let ((_, position), value) = result?;
// filling other caches to avoid searching for them again
self.db_cache.word_position_docids.insert((word, position), Some(value));
positions.push(position);
}
entry.insert(positions.clone());
positions
}
};
Ok(positions)
}
pub fn get_db_word_prefix_positions(
&mut self,
word_prefix: Interned<String>,
) -> Result<Vec<u16>> {
let positions = match self.db_cache.word_prefix_positions.entry(word_prefix) {
Entry::Occupied(positions) => positions.get().clone(),
Entry::Vacant(entry) => {
let mut key = self.word_interner.get(word_prefix).as_bytes().to_owned();
key.push(0);
let mut positions = vec![];
let remap_key_type = self
.index
.word_prefix_position_docids
.remap_types::<ByteSlice, ByteSlice>()
.prefix_iter(self.txn, &key)?
.remap_key_type::<StrBEU16Codec>();
for result in remap_key_type {
let ((_, position), value) = result?;
// filling other caches to avoid searching for them again
self.db_cache
.word_prefix_position_docids
.insert((word_prefix, position), Some(value));
positions.push(position);
}
entry.insert(positions.clone());
positions
}
};
Ok(positions)
}
}

View File

@ -44,7 +44,7 @@ use super::interner::{Interned, MappedInterner};
use super::logger::SearchLogger;
use super::query_graph::QueryNode;
use super::ranking_rule_graph::{
AttributeGraph, ConditionDocIdsCache, DeadEndsCache, ExactnessGraph, ProximityGraph,
ConditionDocIdsCache, DeadEndsCache, ExactnessGraph, FidGraph, PositionGraph, ProximityGraph,
RankingRuleGraph, RankingRuleGraphTrait, TypoGraph,
};
use super::small_bitmap::SmallBitmap;
@ -59,10 +59,16 @@ impl GraphBasedRankingRule<ProximityGraph> {
Self::new_with_id("proximity".to_owned(), terms_matching_strategy)
}
}
pub type Attribute = GraphBasedRankingRule<AttributeGraph>;
impl GraphBasedRankingRule<AttributeGraph> {
pub type Fid = GraphBasedRankingRule<FidGraph>;
impl GraphBasedRankingRule<FidGraph> {
pub fn new(terms_matching_strategy: Option<TermsMatchingStrategy>) -> Self {
Self::new_with_id("attribute".to_owned(), terms_matching_strategy)
Self::new_with_id("fid".to_owned(), terms_matching_strategy)
}
}
pub type Position = GraphBasedRankingRule<PositionGraph>;
impl GraphBasedRankingRule<PositionGraph> {
pub fn new(terms_matching_strategy: Option<TermsMatchingStrategy>) -> Self {
Self::new_with_id("position".to_owned(), terms_matching_strategy)
}
}
pub type Typo = GraphBasedRankingRule<TypoGraph>;

View File

@ -11,8 +11,8 @@ use crate::search::new::interner::Interned;
use crate::search::new::query_graph::QueryNodeData;
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::ranking_rule_graph::{
AttributeCondition, AttributeGraph, Edge, ProximityCondition, ProximityGraph, RankingRuleGraph,
RankingRuleGraphTrait, TypoCondition, TypoGraph,
Edge, FidCondition, FidGraph, PositionCondition, PositionGraph, ProximityCondition,
ProximityGraph, RankingRuleGraph, RankingRuleGraphTrait, TypoCondition, TypoGraph,
};
use crate::search::new::ranking_rules::BoxRankingRule;
use crate::search::new::{QueryGraph, QueryNode, RankingRule, SearchContext, SearchLogger};
@ -29,15 +29,18 @@ pub enum SearchEvents {
ProximityPaths { paths: Vec<Vec<Interned<ProximityCondition>>> },
TypoGraph { graph: RankingRuleGraph<TypoGraph> },
TypoPaths { paths: Vec<Vec<Interned<TypoCondition>>> },
AttributeGraph { graph: RankingRuleGraph<AttributeGraph> },
AttributePaths { paths: Vec<Vec<Interned<AttributeCondition>>> },
FidGraph { graph: RankingRuleGraph<FidGraph> },
FidPaths { paths: Vec<Vec<Interned<FidCondition>>> },
PositionGraph { graph: RankingRuleGraph<PositionGraph> },
PositionPaths { paths: Vec<Vec<Interned<PositionCondition>>> },
}
enum Location {
Words,
Typo,
Proximity,
Attribute,
Fid,
Position,
Other,
}
@ -84,7 +87,8 @@ impl SearchLogger<QueryGraph> for VisualSearchLogger {
"words" => Location::Words,
"typo" => Location::Typo,
"proximity" => Location::Proximity,
"attribute" => Location::Attribute,
"fid" => Location::Fid,
"position" => Location::Position,
_ => Location::Other,
});
}
@ -156,13 +160,20 @@ impl SearchLogger<QueryGraph> for VisualSearchLogger {
self.events.push(SearchEvents::ProximityPaths { paths: paths.clone() });
}
}
Location::Attribute => {
if let Some(graph) = state.downcast_ref::<RankingRuleGraph<AttributeGraph>>() {
self.events.push(SearchEvents::AttributeGraph { graph: graph.clone() });
Location::Fid => {
if let Some(graph) = state.downcast_ref::<RankingRuleGraph<FidGraph>>() {
self.events.push(SearchEvents::FidGraph { graph: graph.clone() });
}
if let Some(paths) = state.downcast_ref::<Vec<Vec<Interned<AttributeCondition>>>>()
{
self.events.push(SearchEvents::AttributePaths { paths: paths.clone() });
if let Some(paths) = state.downcast_ref::<Vec<Vec<Interned<FidCondition>>>>() {
self.events.push(SearchEvents::FidPaths { paths: paths.clone() });
}
}
Location::Position => {
if let Some(graph) = state.downcast_ref::<RankingRuleGraph<PositionGraph>>() {
self.events.push(SearchEvents::PositionGraph { graph: graph.clone() });
}
if let Some(paths) = state.downcast_ref::<Vec<Vec<Interned<PositionCondition>>>>() {
self.events.push(SearchEvents::PositionPaths { paths: paths.clone() });
}
}
Location::Other => {}
@ -327,9 +338,13 @@ impl<'ctx> DetailedLoggerFinish<'ctx> {
SearchEvents::TypoPaths { paths } => {
self.write_rr_graph_paths::<TypoGraph>(paths)?;
}
SearchEvents::AttributeGraph { graph } => self.write_rr_graph(&graph)?,
SearchEvents::AttributePaths { paths } => {
self.write_rr_graph_paths::<AttributeGraph>(paths)?;
SearchEvents::FidGraph { graph } => self.write_rr_graph(&graph)?,
SearchEvents::FidPaths { paths } => {
self.write_rr_graph_paths::<FidGraph>(paths)?;
}
SearchEvents::PositionGraph { graph } => self.write_rr_graph(&graph)?,
SearchEvents::PositionPaths { paths } => {
self.write_rr_graph_paths::<PositionGraph>(paths)?;
}
}
Ok(())

View File

@ -28,7 +28,7 @@ use std::collections::HashSet;
use bucket_sort::bucket_sort;
use charabia::TokenizerBuilder;
use db_cache::DatabaseCache;
use graph_based_ranking_rule::{Attribute, Proximity, Typo};
use graph_based_ranking_rule::{Fid, Position, Proximity, Typo};
use heed::RoTxn;
use interner::DedupInterner;
pub use logger::visual::VisualSearchLogger;
@ -223,7 +223,8 @@ fn get_ranking_rules_for_query_graph_search<'ctx>(
continue;
}
attribute = true;
ranking_rules.push(Box::new(Attribute::new(None)));
ranking_rules.push(Box::new(Fid::new(None)));
ranking_rules.push(Box::new(Position::new(None)));
}
crate::Criterion::Sort => {
if sort {

View File

@ -9,22 +9,22 @@ use crate::search::new::SearchContext;
use crate::Result;
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AttributeCondition {
pub struct FidCondition {
term: LocatedQueryTermSubset,
fid: u16,
}
pub enum AttributeGraph {}
pub enum FidGraph {}
impl RankingRuleGraphTrait for AttributeGraph {
type Condition = AttributeCondition;
impl RankingRuleGraphTrait for FidGraph {
type Condition = FidCondition;
fn resolve_condition(
ctx: &mut SearchContext,
condition: &Self::Condition,
universe: &RoaringBitmap,
) -> Result<ComputedCondition> {
let AttributeCondition { term, .. } = condition;
let FidCondition { term, .. } = condition;
// maybe compute_query_term_subset_docids_within_field_id should accept a universe as argument
let mut docids = compute_query_term_subset_docids_within_field_id(
ctx,
@ -73,7 +73,7 @@ impl RankingRuleGraphTrait for AttributeGraph {
// the term subsets associated to each field ids fetched.
edges.push((
fid as u32 * term.term_ids.len() as u32, // TODO improve the fid score i.e. fid^10.
conditions_interner.insert(AttributeCondition {
conditions_interner.insert(FidCondition {
term: term.clone(), // TODO remove this ugly clone
fid,
}),

View File

@ -11,9 +11,11 @@ mod condition_docids_cache;
mod dead_ends_cache;
/// Implementation of the `attribute` ranking rule
mod attribute;
mod fid;
/// Implementation of the `exactness` ranking rule
mod exactness;
/// Implementation of the `position` ranking rule
mod position;
/// Implementation of the `proximity` ranking rule
mod proximity;
/// Implementation of the `typo` ranking rule
@ -21,11 +23,12 @@ mod typo;
use std::hash::Hash;
pub use attribute::{AttributeCondition, AttributeGraph};
pub use fid::{FidCondition, FidGraph};
pub use cheapest_paths::PathVisitor;
pub use condition_docids_cache::ConditionDocIdsCache;
pub use dead_ends_cache::DeadEndsCache;
pub use exactness::{ExactnessCondition, ExactnessGraph};
pub use position::{PositionCondition, PositionGraph};
pub use proximity::{ProximityCondition, ProximityGraph};
use roaring::RoaringBitmap;
pub use typo::{TypoCondition, TypoGraph};

View File

@ -0,0 +1,93 @@
use fxhash::FxHashSet;
use roaring::RoaringBitmap;
use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids_within_position;
use crate::search::new::SearchContext;
use crate::Result;
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PositionCondition {
term: LocatedQueryTermSubset,
position: u16,
}
pub enum PositionGraph {}
impl RankingRuleGraphTrait for PositionGraph {
type Condition = PositionCondition;
fn resolve_condition(
ctx: &mut SearchContext,
condition: &Self::Condition,
universe: &RoaringBitmap,
) -> Result<ComputedCondition> {
let PositionCondition { term, .. } = condition;
// maybe compute_query_term_subset_docids_within_position_id should accept a universe as argument
let mut docids = compute_query_term_subset_docids_within_position(
ctx,
&term.term_subset,
condition.position,
)?;
docids &= universe;
Ok(ComputedCondition {
docids,
universe_len: universe.len(),
start_term_subset: None,
end_term_subset: term.clone(),
})
}
fn build_edges(
ctx: &mut SearchContext,
conditions_interner: &mut DedupInterner<Self::Condition>,
_from: Option<&LocatedQueryTermSubset>,
to_term: &LocatedQueryTermSubset,
) -> Result<Vec<(u32, Interned<Self::Condition>)>> {
let term = to_term;
let mut all_positions = FxHashSet::default();
for word in term.term_subset.all_single_words_except_prefix_db(ctx)? {
let positions = ctx.get_db_word_positions(word.interned())?;
all_positions.extend(positions);
}
for phrase in term.term_subset.all_phrases(ctx)? {
for &word in phrase.words(ctx).iter().flatten() {
let positions = ctx.get_db_word_positions(word)?;
all_positions.extend(positions);
}
}
if let Some(word_prefix) = term.term_subset.use_prefix_db(ctx) {
let positions = ctx.get_db_word_prefix_positions(word_prefix.interned())?;
all_positions.extend(positions);
}
let mut edges = vec![];
for position in all_positions {
let cost = {
let mut cost = 0;
for i in 0..term.term_ids.len() {
cost += position as u32 + i as u32;
}
cost
};
// TODO: We can improve performances and relevancy by storing
// the term subsets associated to each position fetched.
edges.push((
cost,
conditions_interner.insert(PositionCondition {
term: term.clone(), // TODO remove this ugly clone
position,
}),
));
}
Ok(edges)
}
}

View File

@ -87,6 +87,41 @@ pub fn compute_query_term_subset_docids_within_field_id(
Ok(docids)
}
pub fn compute_query_term_subset_docids_within_position(
ctx: &mut SearchContext,
term: &QueryTermSubset,
position: u16,
) -> Result<RoaringBitmap> {
// TODO Use the roaring::MultiOps trait
let mut docids = RoaringBitmap::new();
for word in term.all_single_words_except_prefix_db(ctx)? {
if let Some(word_position_docids) =
ctx.get_db_word_position_docids(word.interned(), position)?
{
docids |= word_position_docids;
}
}
for phrase in term.all_phrases(ctx)? {
for &word in phrase.words(ctx).iter().flatten() {
if let Some(word_position_docids) = ctx.get_db_word_position_docids(word, position)? {
docids |= word_position_docids;
}
}
}
if let Some(word_prefix) = term.use_prefix_db(ctx) {
if let Some(word_position_docids) =
ctx.get_db_word_prefix_position_docids(word_prefix.interned(), position)?
{
docids |= word_position_docids;
}
}
Ok(docids)
}
pub fn compute_query_graph_docids(
ctx: &mut SearchContext,
q: &QueryGraph,

View File

@ -95,7 +95,7 @@ fn create_index() -> TempIndex {
}
#[test]
fn test_attributes_simple() {
fn test_attribute_fid_simple() {
let index = create_index();
let txn = index.read_txn().unwrap();

View File

@ -0,0 +1,52 @@
use crate::{index::tests::TempIndex, Criterion, Search, SearchResult, TermsMatchingStrategy};
fn create_index() -> TempIndex {
let index = TempIndex::new();
index
.update_settings(|s| {
s.set_primary_key("id".to_owned());
s.set_searchable_fields(vec!["text".to_owned()]);
s.set_criteria(vec![Criterion::Attribute]);
})
.unwrap();
index
.add_documents(documents!([
{
"id": 0,
"text": "do you know about the quick and talented brown fox",
},
{
"id": 1,
"text": "do you know about the quick brown fox",
},
{
"id": 2,
"text": "the quick and talented brown fox",
},
{
"id": 3,
"text": "fox brown quick the",
},
{
"id": 4,
"text": "the quick brown fox",
},
]))
.unwrap();
index
}
#[test]
fn test_attribute_fid_simple() {
let index = create_index();
let txn = index.read_txn().unwrap();
let mut s = Search::new(&txn, &index);
s.terms_matching_strategy(TermsMatchingStrategy::All);
s.query("the quick brown fox");
let SearchResult { documents_ids, .. } = s.execute().unwrap();
insta::assert_snapshot!(format!("{documents_ids:?}"), @"[3, 4, 2, 1, 0]");
}

View File

@ -1,4 +1,5 @@
pub mod attribute;
pub mod attribute_fid;
pub mod attribute_position;
pub mod distinct;
#[cfg(feature = "default")]
pub mod language;