2023-11-15 15:46:37 +01:00
|
|
|
use std::iter::FromIterator;
|
|
|
|
|
2023-12-07 17:03:10 +01:00
|
|
|
use ordered_float::OrderedFloat;
|
2023-11-15 15:46:37 +01:00
|
|
|
use roaring::RoaringBitmap;
|
|
|
|
|
|
|
|
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
|
|
|
|
use crate::score_details::{self, ScoreDetails};
|
2023-12-07 17:03:10 +01:00
|
|
|
use crate::{DocumentId, Result, SearchContext, SearchLogger};
|
2023-11-15 15:46:37 +01:00
|
|
|
|
2023-12-07 17:03:10 +01:00
|
|
|
pub struct VectorSort<Q: RankingRuleQueryTrait> {
|
2023-11-15 15:46:37 +01:00
|
|
|
query: Option<Q>,
|
|
|
|
target: Vec<f32>,
|
|
|
|
vector_candidates: RoaringBitmap,
|
2023-12-07 17:03:10 +01:00
|
|
|
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
|
2023-12-07 13:33:15 +01:00
|
|
|
limit: usize,
|
2023-11-15 15:46:37 +01:00
|
|
|
}
|
|
|
|
|
2023-12-07 17:03:10 +01:00
|
|
|
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
2023-11-15 15:46:37 +01:00
|
|
|
pub fn new(
|
2023-12-07 17:03:10 +01:00
|
|
|
_ctx: &SearchContext,
|
2023-11-15 15:46:37 +01:00
|
|
|
target: Vec<f32>,
|
|
|
|
vector_candidates: RoaringBitmap,
|
2023-12-07 13:33:15 +01:00
|
|
|
limit: usize,
|
2023-11-15 15:46:37 +01:00
|
|
|
) -> Result<Self> {
|
2023-12-07 17:03:10 +01:00
|
|
|
Ok(Self {
|
|
|
|
query: None,
|
|
|
|
target,
|
|
|
|
vector_candidates,
|
|
|
|
cached_sorted_docids: Default::default(),
|
|
|
|
limit,
|
|
|
|
})
|
|
|
|
}
|
2023-11-15 15:46:37 +01:00
|
|
|
|
2023-12-07 17:03:10 +01:00
|
|
|
fn fill_buffer(&mut self, ctx: &mut SearchContext<'_>) -> Result<()> {
|
|
|
|
let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
|
|
|
|
.map_while(|k| {
|
|
|
|
arroy::Reader::open(ctx.txn, k.into(), ctx.index.vector_arroy)
|
|
|
|
.map(Some)
|
|
|
|
.or_else(|e| match e {
|
|
|
|
arroy::Error::MissingMetadata => Ok(None),
|
|
|
|
e => Err(e),
|
|
|
|
})
|
|
|
|
.transpose()
|
|
|
|
})
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
let readers = readers?;
|
2023-11-15 15:46:37 +01:00
|
|
|
|
2023-12-07 17:03:10 +01:00
|
|
|
let target = &self.target;
|
|
|
|
let mut results = Vec::new();
|
|
|
|
|
|
|
|
for reader in readers.iter() {
|
|
|
|
let nns_by_vector = reader.nns_by_vector(
|
|
|
|
ctx.txn,
|
|
|
|
&target,
|
|
|
|
self.limit,
|
|
|
|
None,
|
|
|
|
Some(&self.vector_candidates),
|
|
|
|
)?;
|
|
|
|
let vectors: std::result::Result<Vec<_>, _> = nns_by_vector
|
|
|
|
.iter()
|
|
|
|
.map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap())
|
|
|
|
.collect();
|
|
|
|
let vectors = vectors?;
|
|
|
|
results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z)));
|
|
|
|
}
|
|
|
|
results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
|
|
|
|
self.cached_sorted_docids = results.into_iter();
|
|
|
|
Ok(())
|
2023-11-15 15:46:37 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-07 17:03:10 +01:00
|
|
|
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
|
2023-11-15 15:46:37 +01:00
|
|
|
fn id(&self) -> String {
|
|
|
|
"vector_sort".to_owned()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn start_iteration(
|
|
|
|
&mut self,
|
2023-12-07 17:03:10 +01:00
|
|
|
ctx: &mut SearchContext<'ctx>,
|
2023-11-15 15:46:37 +01:00
|
|
|
_logger: &mut dyn SearchLogger<Q>,
|
|
|
|
universe: &RoaringBitmap,
|
|
|
|
query: &Q,
|
|
|
|
) -> Result<()> {
|
|
|
|
assert!(self.query.is_none());
|
|
|
|
|
|
|
|
self.query = Some(query.clone());
|
|
|
|
self.vector_candidates &= universe;
|
2023-12-07 17:03:10 +01:00
|
|
|
self.fill_buffer(ctx)?;
|
2023-11-15 15:46:37 +01:00
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
#[allow(clippy::only_used_in_recursion)]
|
|
|
|
fn next_bucket(
|
|
|
|
&mut self,
|
|
|
|
ctx: &mut SearchContext<'ctx>,
|
|
|
|
_logger: &mut dyn SearchLogger<Q>,
|
|
|
|
universe: &RoaringBitmap,
|
|
|
|
) -> Result<Option<RankingRuleOutput<Q>>> {
|
|
|
|
let query = self.query.as_ref().unwrap().clone();
|
|
|
|
self.vector_candidates &= universe;
|
|
|
|
|
|
|
|
if self.vector_candidates.is_empty() {
|
|
|
|
return Ok(Some(RankingRuleOutput {
|
|
|
|
query,
|
|
|
|
candidates: universe.clone(),
|
|
|
|
score: ScoreDetails::Vector(score_details::Vector {
|
|
|
|
target_vector: self.target.clone(),
|
|
|
|
value_similarity: None,
|
|
|
|
}),
|
|
|
|
}));
|
|
|
|
}
|
2023-12-07 17:03:10 +01:00
|
|
|
|
|
|
|
while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() {
|
|
|
|
if self.vector_candidates.contains(docid) {
|
|
|
|
return Ok(Some(RankingRuleOutput {
|
|
|
|
query,
|
|
|
|
candidates: RoaringBitmap::from_iter([docid]),
|
|
|
|
score: ScoreDetails::Vector(score_details::Vector {
|
|
|
|
target_vector: self.target.clone(),
|
|
|
|
value_similarity: Some((vector, 1.0 - distance)),
|
|
|
|
}),
|
|
|
|
}));
|
2023-11-15 15:46:37 +01:00
|
|
|
}
|
2023-12-07 17:03:10 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// if we got out of this loop it means we've exhausted our cache.
|
|
|
|
// we need to refill it and run the function again.
|
|
|
|
self.fill_buffer(ctx)?;
|
|
|
|
self.next_bucket(ctx, _logger, universe)
|
2023-11-15 15:46:37 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) {
|
|
|
|
self.query = None;
|
|
|
|
}
|
|
|
|
}
|