Finalize the GeoExtractor

Clément Renault 2024-11-07 15:05:20 +01:00
parent a01bc7b454
commit b17896d899
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
19 changed files with 497 additions and 93 deletions

View file

@@ -54,7 +54,7 @@ impl<'a, 'extractor> Extractor<'extractor> for DocumentsExtractor<'a> {
            DocumentChange::Deletion(deletion) => {
                let docid = deletion.docid();
                let content = deletion.current(
-                    &context.txn,
+                    &context.rtxn,
                    context.index,
                    &context.db_fields_ids_map,
                )?;
@@ -72,7 +72,7 @@ impl<'a, 'extractor> Extractor<'extractor> for DocumentsExtractor<'a> {
            DocumentChange::Update(update) => {
                let docid = update.docid();
                let content =
-                    update.current(&context.txn, context.index, &context.db_fields_ids_map)?;
+                    update.current(&context.rtxn, context.index, &context.db_fields_ids_map)?;
                for res in content.iter_top_level_fields() {
                    let (f, _) = res?;
                    let entry = document_extractor_data
@@ -92,9 +92,9 @@ impl<'a, 'extractor> Extractor<'extractor> for DocumentsExtractor<'a> {
                }
                let content =
-                    update.merged(&context.txn, context.index, &context.db_fields_ids_map)?;
+                    update.merged(&context.rtxn, context.index, &context.db_fields_ids_map)?;
                let vector_content = update.merged_vectors(
-                    &context.txn,
+                    &context.rtxn,
                    context.index,
                    &context.db_fields_ids_map,
                    &context.doc_alloc,

View file

@@ -63,7 +63,7 @@ impl FacetedDocidsExtractor {
        document_change: DocumentChange,
    ) -> Result<()> {
        let index = &context.index;
-        let rtxn = &context.txn;
+        let rtxn = &context.rtxn;
        let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield();
        let mut cached_sorter = context.data.borrow_mut_or_yield();
        match document_change {

View file

@@ -10,7 +10,8 @@ pub fn extract_document_facets<'doc>(
    field_id_map: &mut GlobalFieldsIdsMap,
    facet_fn: &mut impl FnMut(FieldId, &Value) -> Result<()>,
) -> Result<()> {
-    for res in document.iter_top_level_fields() {
+    let geo = document.geo_field().transpose().map(|res| res.map(|rval| ("_geo", rval)));
+    for res in document.iter_top_level_fields().chain(geo) {
        let (field_name, value) = res?;
        let mut tokenize_field = |name: &str, value: &Value| match field_id_map.id_or_insert(name) {
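
This change folds the optional `_geo` field into the regular field loop: `geo_field()` evidently returns a `Result<Option<_>>`, `transpose()` flips it into an `Option<Result<_>>`, and `chain` appends it as one more `(name, value)` item. A minimal sketch of the same trick with plain stand-in types (all names here are illustrative, not from the diff):

```rust
fn main() {
    // Stand-ins for `iter_top_level_fields()` and `geo_field()`.
    let fields: Vec<Result<(&str, i32), String>> = vec![Ok(("title", 1)), Ok(("price", 2))];
    let geo_field: Result<Option<i32>, String> = Ok(Some(3));

    // Result<Option<T>> -> Option<Result<T>>, then tag it with its field name.
    let geo = geo_field.transpose().map(|res| res.map(|v| ("_geo", v)));

    // The extra field flows through the same loop as the regular ones.
    for res in fields.into_iter().chain(geo) {
        let (name, value) = res.unwrap();
        println!("{name} = {value}");
    }
}
```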

View file

@@ -0,0 +1,302 @@
use std::cell::RefCell;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, ErrorKind, Read, Write as _};
use std::{iter, mem, result};

use bumpalo::Bump;
use bytemuck::{bytes_of, from_bytes, pod_read_unaligned, Pod, Zeroable};
use heed::RoTxn;
use serde_json::value::RawValue;
use serde_json::Value;

use crate::error::GeoError;
use crate::update::new::document::Document;
use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor, MostlySend};
use crate::update::new::ref_cell_ext::RefCellExt as _;
use crate::update::new::DocumentChange;
use crate::update::GrenadParameters;
use crate::{lat_lng_to_xyz, DocumentId, GeoPoint, Index, InternalError, Object, Result};

pub struct GeoExtractor {
    grenad_parameters: GrenadParameters,
}

impl GeoExtractor {
    pub fn new(
        rtxn: &RoTxn,
        index: &Index,
        grenad_parameters: GrenadParameters,
    ) -> Result<Option<Self>> {
        let is_sortable = index.sortable_fields(rtxn)?.contains("_geo");
        let is_filterable = index.filterable_fields(rtxn)?.contains("_geo");
        if is_sortable || is_filterable {
            Ok(Some(GeoExtractor { grenad_parameters }))
        } else {
            Ok(None)
        }
    }
}

#[derive(Pod, Zeroable, Copy, Clone)]
#[repr(C, packed)]
pub struct ExtractedGeoPoint {
    pub docid: DocumentId,
    pub lat_lng: [f64; 2],
}

impl From<ExtractedGeoPoint> for GeoPoint {
    /// Converts the latitude and longitude back to an xyz GeoPoint.
    fn from(value: ExtractedGeoPoint) -> Self {
        let [lat, lng] = value.lat_lng;
        let point = [lat, lng];
        let xyz_point = lat_lng_to_xyz(&point);
        GeoPoint::new(xyz_point, (value.docid, point))
    }
}
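
The `From` impl above defers the actual projection to `lat_lng_to_xyz`, which is imported from the crate root and not shown in this commit. As a hedged sketch of what such a helper typically does (an assumption about, not a copy of, milli's implementation), it converts degrees to radians and projects onto the unit sphere:

```rust
/// Assumed shape of a lat/lng -> Cartesian projection; the real
/// `lat_lng_to_xyz` lives in milli's crate root and may differ in detail.
fn lat_lng_to_xyz_sketch(coord: &[f64; 2]) -> [f64; 3] {
    let [lat, lng] = coord.map(|deg| deg.to_radians());
    // Unit-sphere coordinates: latitude sets the height along the polar
    // axis (z), longitude rotates around it.
    [lat.cos() * lng.cos(), lat.cos() * lng.sin(), lat.sin()]
}
```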

pub struct GeoExtractorData<'extractor> {
    /// The set of document ids that were removed. If a document sees its geo
    /// point being updated, it is first put in the removed set, then in the inserted one.
    removed: bumpalo::collections::Vec<'extractor, ExtractedGeoPoint>,
    inserted: bumpalo::collections::Vec<'extractor, ExtractedGeoPoint>,
    /// Removed geo points spilled to disk once the in-memory buffers grew too large.
    spilled_removed: Option<BufWriter<File>>,
    /// Inserted geo points spilled to disk once the in-memory buffers grew too large.
    spilled_inserted: Option<BufWriter<File>>,
}

impl<'extractor> GeoExtractorData<'extractor> {
    pub fn freeze(self) -> Result<FrozenGeoExtractorData<'extractor>> {
        let GeoExtractorData { removed, inserted, spilled_removed, spilled_inserted } = self;
        Ok(FrozenGeoExtractorData {
            removed: removed.into_bump_slice(),
            inserted: inserted.into_bump_slice(),
            spilled_removed: spilled_removed
                .map(|bw| bw.into_inner().map(BufReader::new).map_err(|iie| iie.into_error()))
                .transpose()?,
            spilled_inserted: spilled_inserted
                .map(|bw| bw.into_inner().map(BufReader::new).map_err(|iie| iie.into_error()))
                .transpose()?,
        })
    }
}

unsafe impl MostlySend for GeoExtractorData<'_> {}

pub struct FrozenGeoExtractorData<'extractor> {
    pub removed: &'extractor [ExtractedGeoPoint],
    pub inserted: &'extractor [ExtractedGeoPoint],
    pub spilled_removed: Option<BufReader<File>>,
    pub spilled_inserted: Option<BufReader<File>>,
}

impl<'extractor> FrozenGeoExtractorData<'extractor> {
    pub fn iter_and_clear_removed(
        &mut self,
    ) -> impl IntoIterator<Item = io::Result<ExtractedGeoPoint>> + '_ {
        mem::take(&mut self.removed)
            .iter()
            .copied()
            .map(Ok)
            .chain(iterator_over_spilled_geopoints(&mut self.spilled_removed))
    }

    pub fn iter_and_clear_inserted(
        &mut self,
    ) -> impl IntoIterator<Item = io::Result<ExtractedGeoPoint>> + '_ {
        mem::take(&mut self.inserted)
            .iter()
            .copied()
            .map(Ok)
            .chain(iterator_over_spilled_geopoints(&mut self.spilled_inserted))
    }
}

fn iterator_over_spilled_geopoints(
    spilled: &mut Option<BufReader<File>>,
) -> impl IntoIterator<Item = io::Result<ExtractedGeoPoint>> + '_ {
    let mut spilled = spilled.take();
    iter::from_fn(move || match &mut spilled {
        Some(file) => {
            let geopoint_bytes = &mut [0u8; mem::size_of::<ExtractedGeoPoint>()];
            match file.read_exact(geopoint_bytes) {
                Ok(()) => Some(Ok(pod_read_unaligned(geopoint_bytes))),
                Err(e) if e.kind() == ErrorKind::UnexpectedEof => None,
                Err(e) => Some(Err(e)),
            }
        }
        None => None,
    })
}
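
Spilled entries are written as the raw bytes of the `#[repr(C, packed)]` struct: `bytes_of` on the write side, `pod_read_unaligned` into a fixed-size buffer on the read side, so records are fixed-width and need no framing. A minimal, self-contained round trip of that encoding (the `Point` type here is illustrative, not part of this commit):

```rust
use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable};

#[derive(Pod, Zeroable, Copy, Clone)]
#[repr(C, packed)]
struct Point {
    docid: u32,
    lat_lng: [f64; 2],
}

fn main() {
    let point = Point { docid: 42, lat_lng: [48.8566, 2.3522] };

    // Write side: a packed Pod struct is exactly its bytes.
    let bytes = bytes_of(&point).to_vec();
    assert_eq!(bytes.len(), std::mem::size_of::<Point>());

    // Read side: copy out of a possibly unaligned buffer.
    let back: Point = pod_read_unaligned(&bytes);
    let (docid, lat_lng) = (back.docid, back.lat_lng); // copy packed fields out
    assert_eq!(docid, 42);
    assert_eq!(lat_lng, [48.8566, 2.3522]);
}
```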

impl<'extractor> Extractor<'extractor> for GeoExtractor {
    type Data = RefCell<GeoExtractorData<'extractor>>;

    fn init_data<'doc>(&'doc self, extractor_alloc: &'extractor Bump) -> Result<Self::Data> {
        Ok(RefCell::new(GeoExtractorData {
            removed: bumpalo::collections::Vec::new_in(extractor_alloc),
            inserted: bumpalo::collections::Vec::new_in(extractor_alloc),
            spilled_inserted: None,
            spilled_removed: None,
        }))
    }

    fn process<'doc>(
        &'doc self,
        changes: impl Iterator<Item = Result<DocumentChange<'doc>>>,
        context: &'doc DocumentChangeContext<Self::Data>,
    ) -> Result<()> {
        let rtxn = &context.rtxn;
        let index = context.index;
        let max_memory = self.grenad_parameters.max_memory;
        let db_fields_ids_map = context.db_fields_ids_map;
        let mut data_ref = context.data.borrow_mut_or_yield();

        for change in changes {
            if data_ref.spilled_removed.is_none()
                && max_memory.map_or(false, |mm| context.extractor_alloc.allocated_bytes() >= mm)
            {
                // We must spill as we allocated too much memory. Only create the
                // tempfiles once: re-creating them on a later iteration would
                // silently drop the entries that were already spilled.
                data_ref.spilled_removed = tempfile::tempfile().map(BufWriter::new).map(Some)?;
                data_ref.spilled_inserted = tempfile::tempfile().map(BufWriter::new).map(Some)?;
            }

            match change? {
                DocumentChange::Deletion(deletion) => {
                    let docid = deletion.docid();
                    let external_id = deletion.external_document_id();
                    let current = deletion.current(rtxn, index, db_fields_ids_map)?;
                    let current_geo = current
                        .geo_field()?
                        .map(|geo| extract_geo_coordinates(external_id, geo))
                        .transpose()?;

                    if let Some(lat_lng) = current_geo.flatten() {
                        let geopoint = ExtractedGeoPoint { docid, lat_lng };
                        match &mut data_ref.spilled_removed {
                            Some(file) => file.write_all(bytes_of(&geopoint))?,
                            None => data_ref.removed.push(geopoint),
                        }
                    }
                }
                DocumentChange::Update(update) => {
                    let current = update.current(rtxn, index, db_fields_ids_map)?;
                    let external_id = update.external_document_id();
                    let docid = update.docid();

                    let current_geo = current
                        .geo_field()?
                        .map(|geo| extract_geo_coordinates(external_id, geo))
                        .transpose()?;

                    let updated_geo = update
                        .updated()
                        .geo_field()?
                        .map(|geo| extract_geo_coordinates(external_id, geo))
                        .transpose()?;

                    if current_geo != updated_geo {
                        // If the current and new geo points are different it means that
                        // we need to replace the current by the new point and therefore
                        // delete the current point from the RTree.
                        if let Some(lat_lng) = current_geo.flatten() {
                            let geopoint = ExtractedGeoPoint { docid, lat_lng };
                            match &mut data_ref.spilled_removed {
                                Some(file) => file.write_all(bytes_of(&geopoint))?,
                                None => data_ref.removed.push(geopoint),
                            }
                        }

                        if let Some(lat_lng) = updated_geo.flatten() {
                            let geopoint = ExtractedGeoPoint { docid, lat_lng };
                            match &mut data_ref.spilled_inserted {
                                Some(file) => file.write_all(bytes_of(&geopoint))?,
                                None => data_ref.inserted.push(geopoint),
                            }
                        }
                    }
                }
                DocumentChange::Insertion(insertion) => {
                    let external_id = insertion.external_document_id();
                    let docid = insertion.docid();

                    let inserted_geo = insertion
                        .inserted()
                        .geo_field()?
                        .map(|geo| extract_geo_coordinates(external_id, geo))
                        .transpose()?;

                    if let Some(lat_lng) = inserted_geo.flatten() {
                        let geopoint = ExtractedGeoPoint { docid, lat_lng };
                        match &mut data_ref.spilled_inserted {
                            Some(file) => file.write_all(bytes_of(&geopoint))?,
                            None => data_ref.inserted.push(geopoint),
                        }
                    }
                }
            }
        }

        Ok(())
    }
}
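
The removed/inserted split exists because, as the comment in the `Update` arm spells out, a changed geo point is applied to the geo RTree as delete-then-insert. A toy sketch of that protocol with rstar (the alias and values are illustrative; the real writer side lives elsewhere in milli):

```rust
use rstar::primitives::GeomWithData;
use rstar::RTree;

// Illustrative alias; milli's real `GeoPoint` pairs an xyz position
// with (docid, [lat, lng]) as associated data.
type GeoPoint = GeomWithData<[f64; 3], (u32, [f64; 2])>;

fn main() {
    let mut rtree: RTree<GeoPoint> = RTree::new();
    let old = GeoPoint::new([0.0, 0.0, 1.0], (42, [90.0, 0.0]));
    rtree.insert(old.clone());

    // Applying an update: remove the old point, then insert the new one.
    rtree.remove(&old);
    rtree.insert(GeoPoint::new([1.0, 0.0, 0.0], (42, [0.0, 0.0])));
    assert_eq!(rtree.size(), 1);
}
```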

/// Extracts and validates the latitude and longitude from a document geo field.
///
/// It can be of the form `{ "lat": 0.0, "lng": "1.0" }`.
fn extract_geo_coordinates(external_id: &str, raw_value: &RawValue) -> Result<Option<[f64; 2]>> {
    let mut geo = match serde_json::from_str(raw_value.get()).map_err(InternalError::SerdeJson)? {
        Value::Null => return Ok(None),
        Value::Object(map) => map,
        value => {
            return Err(
                GeoError::NotAnObject { document_id: Value::from(external_id), value }.into()
            )
        }
    };

    let [lat, lng] = match (geo.remove("lat"), geo.remove("lng")) {
        (Some(lat), Some(lng)) => [lat, lng],
        (Some(_), None) => {
            return Err(GeoError::MissingLongitude { document_id: Value::from(external_id) }.into())
        }
        (None, Some(_)) => {
            return Err(GeoError::MissingLatitude { document_id: Value::from(external_id) }.into())
        }
        (None, None) => {
            return Err(GeoError::MissingLatitudeAndLongitude {
                document_id: Value::from(external_id),
            }
            .into())
        }
    };

    let lat = extract_finite_float_from_value(lat)
        .map_err(|value| GeoError::BadLatitude { document_id: Value::from(external_id), value })?;

    let lng = extract_finite_float_from_value(lng)
        .map_err(|value| GeoError::BadLongitude { document_id: Value::from(external_id), value })?;

    Ok(Some([lat, lng]))
}

/// Extracts and validates that a serde JSON Value is actually a finite f64.
pub fn extract_finite_float_from_value(value: Value) -> result::Result<f64, Value> {
    let number = match value {
        Value::Number(ref n) => match n.as_f64() {
            Some(number) => number,
            None => return Err(value),
        },
        Value::String(ref s) => match s.parse::<f64>() {
            Ok(number) => number,
            Err(_) => return Err(value),
        },
        value => return Err(value),
    };

    if number.is_finite() {
        Ok(number)
    } else {
        Err(value)
    }
}
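
To make the accepted shapes concrete, here are a few illustrative checks against the validation rules above (assuming `extract_finite_float_from_value` is in scope; the values are made up):

```rust
use serde_json::json;

#[test]
fn geo_value_validation_rules() {
    // Plain numbers pass through.
    assert_eq!(extract_finite_float_from_value(json!(2.3522)), Ok(2.3522));
    // Numeric strings are parsed, which is why `{ "lat": 0.0, "lng": "1.0" }` is accepted.
    assert_eq!(extract_finite_float_from_value(json!("48.8566")), Ok(48.8566));
    // "NaN" parses as an f64 but is rejected by the `is_finite` check.
    assert!(extract_finite_float_from_value(json!("NaN")).is_err());
    // Any other JSON type is handed back unchanged as the error value.
    assert!(extract_finite_float_from_value(json!([2.0, 3.0])).is_err());
}
```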

View file

@@ -1,6 +1,7 @@
mod cache;
mod documents;
mod faceted;
+mod geo;
mod searchable;
mod vectors;
@@ -8,6 +9,7 @@ use bumpalo::Bump;
pub use cache::{merge_caches, transpose_and_freeze_caches, BalancedCaches, DelAddRoaringBitmap};
pub use documents::*;
pub use faceted::*;
+pub use geo::*;
pub use searchable::*;
pub use vectors::EmbeddingExtractor;

View file

@@ -326,7 +326,7 @@ impl WordDocidsExtractors {
        document_change: DocumentChange,
    ) -> Result<()> {
        let index = &context.index;
-        let rtxn = &context.txn;
+        let rtxn = &context.rtxn;
        let mut cached_sorter_ref = context.data.borrow_mut_or_yield();
        let cached_sorter = cached_sorter_ref.as_mut().unwrap();
        let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield();

View file

@@ -39,7 +39,7 @@ impl SearchableExtractor for WordPairProximityDocidsExtractor {
        let doc_alloc = &context.doc_alloc;
        let index = context.index;
-        let rtxn = &context.txn;
+        let rtxn = &context.rtxn;
        let mut key_buffer = bumpalo::collections::Vec::new_in(doc_alloc);
        let mut del_word_pair_proximity = bumpalo::collections::Vec::new_in(doc_alloc);

View file

@@ -2,13 +2,13 @@ use std::cell::RefCell;

use bumpalo::collections::Vec as BVec;
use bumpalo::Bump;
-use hashbrown::HashMap;
+use hashbrown::{DefaultHashBuilder, HashMap};

use super::cache::DelAddRoaringBitmap;
use crate::error::FaultSource;
use crate::prompt::Prompt;
use crate::update::new::channel::EmbeddingSender;
-use crate::update::new::indexer::document_changes::{Extractor, MostlySend};
+use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor, MostlySend};
use crate::update::new::vector_document::VectorDocument;
use crate::update::new::DocumentChange;
use crate::vector::error::{
@@ -37,7 +37,7 @@ impl<'a> EmbeddingExtractor<'a> {
}

pub struct EmbeddingExtractorData<'extractor>(
-    pub HashMap<String, DelAddRoaringBitmap, hashbrown::DefaultHashBuilder, &'extractor Bump>,
+    pub HashMap<String, DelAddRoaringBitmap, DefaultHashBuilder, &'extractor Bump>,
);

unsafe impl MostlySend for EmbeddingExtractorData<'_> {}
@@ -52,9 +52,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
    fn process<'doc>(
        &'doc self,
        changes: impl Iterator<Item = crate::Result<DocumentChange<'doc>>>,
-        context: &'doc crate::update::new::indexer::document_changes::DocumentChangeContext<
-            Self::Data,
-        >,
+        context: &'doc DocumentChangeContext<Self::Data>,
    ) -> crate::Result<()> {
        let embedders = self.embedders.inner_as_ref();
        let mut unused_vectors_distribution =
@@ -63,7 +61,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
        let mut all_chunks = BVec::with_capacity_in(embedders.len(), &context.doc_alloc);
        for (embedder_name, (embedder, prompt, _is_quantized)) in embedders {
            let embedder_id =
-                context.index.embedder_category_id.get(&context.txn, embedder_name)?.ok_or_else(
+                context.index.embedder_category_id.get(&context.rtxn, embedder_name)?.ok_or_else(
                    || InternalError::DatabaseMissingEntry {
                        db_name: "embedder_category_id",
                        key: None,
@@ -95,7 +93,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                }
                DocumentChange::Update(update) => {
                    let old_vectors = update.current_vectors(
-                        &context.txn,
+                        &context.rtxn,
                        context.index,
                        context.db_fields_ids_map,
                        &context.doc_alloc,
@@ -132,7 +130,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                    } else if new_vectors.regenerate {
                        let new_rendered = prompt.render_document(
                            update.current(
-                                &context.txn,
+                                &context.rtxn,
                                context.index,
                                context.db_fields_ids_map,
                            )?,
@@ -141,7 +139,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                        )?;
                        let old_rendered = prompt.render_document(
                            update.merged(
-                                &context.txn,
+                                &context.rtxn,
                                context.index,
                                context.db_fields_ids_map,
                            )?,
@@ -160,7 +158,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                    } else if old_vectors.regenerate {
                        let old_rendered = prompt.render_document(
                            update.current(
-                                &context.txn,
+                                &context.rtxn,
                                context.index,
                                context.db_fields_ids_map,
                            )?,
@@ -169,7 +167,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                        )?;
                        let new_rendered = prompt.render_document(
                            update.merged(
-                                &context.txn,
+                                &context.rtxn,
                                context.index,
                                context.db_fields_ids_map,
                            )?,