From c16c60b5998e88a835d20505799b8f0c779d1922 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Sun, 29 Jun 2025 23:48:53 +0200 Subject: [PATCH] Add `vector::extractor` module --- crates/milli/src/vector/extractor.rs | 214 +++++++++++++++++++++++++++ crates/milli/src/vector/mod.rs | 1 + 2 files changed, 215 insertions(+) create mode 100644 crates/milli/src/vector/extractor.rs diff --git a/crates/milli/src/vector/extractor.rs b/crates/milli/src/vector/extractor.rs new file mode 100644 index 000000000..cbfc62ee1 --- /dev/null +++ b/crates/milli/src/vector/extractor.rs @@ -0,0 +1,214 @@ +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::fmt::Debug; + +use bumpalo::Bump; +use serde_json::Value; + +use super::json_template::{self, JsonTemplate}; +use crate::prompt::error::RenderPromptError; +use crate::prompt::Prompt; +use crate::update::new::document::Document; +use crate::vector::RuntimeFragment; +use crate::GlobalFieldsIdsMap; + +pub trait Extractor<'doc> { + type DocumentMetadata; + type Input: PartialEq; + type Error; + + fn extract<'a, D: Document<'a> + Debug>( + &self, + doc: D, + meta: &Self::DocumentMetadata, + ) -> Result, Self::Error>; + + fn extractor_id(&self) -> u8; + + fn diff_documents<'a, OD: Document<'a> + Debug, ND: Document<'a> + Debug>( + &self, + old: OD, + new: ND, + meta: &Self::DocumentMetadata, + ) -> Result, Self::Error> + where + 'doc: 'a, + { + let old_input = self.extract(old, meta); + let new_input = self.extract(new, meta); + to_diff(old_input, new_input) + } + + fn diff_settings<'a, D: Document<'a> + Debug>( + &self, + doc: D, + meta: &Self::DocumentMetadata, + old: Option<&Self>, + ) -> Result, Self::Error> { + let old_input = if let Some(old) = old { old.extract(&doc, meta) } else { Ok(None) }; + let new_input = self.extract(&doc, meta); + + to_diff(old_input, new_input) + } + + fn ignore_errors(self) -> IgnoreErrorExtractor + where + Self: Sized, + { + IgnoreErrorExtractor(self) + } +} + +fn to_diff( + old_input: Result, E>, + new_input: Result, E>, +) -> Result, E> { + let old_input = old_input.ok().unwrap_or(None); + let new_input = new_input?; + Ok(match (old_input, new_input) { + (Some(old), Some(new)) if old == new => ExtractorDiff::Unchanged, + (None, None) => ExtractorDiff::Unchanged, + (None, Some(input)) => ExtractorDiff::Added(input), + (Some(_), None) => ExtractorDiff::Removed, + (Some(_), Some(input)) => ExtractorDiff::Updated(input), + }) +} + +pub enum ExtractorDiff { + Removed, + Added(Input), + Updated(Input), + Unchanged, +} + +impl ExtractorDiff { + pub fn into_input(self) -> Option { + match self { + ExtractorDiff::Removed => None, + ExtractorDiff::Added(input) => Some(input), + ExtractorDiff::Updated(input) => Some(input), + ExtractorDiff::Unchanged => None, + } + } + + pub fn needs_change(&self) -> bool { + match self { + ExtractorDiff::Removed => true, + ExtractorDiff::Added(_) => true, + ExtractorDiff::Updated(_) => true, + ExtractorDiff::Unchanged => false, + } + } + + pub fn into_list_of_changes( + named_diffs: impl IntoIterator, + ) -> BTreeMap> { + named_diffs + .into_iter() + .filter(|(_, diff)| diff.needs_change()) + .map(|(name, diff)| (name, diff.into_input())) + .collect() + } +} + +pub struct DocumentTemplateExtractor<'a, 'b, 'c> { + doc_alloc: &'a Bump, + field_id_map: &'a RefCell>, + template: &'c Prompt, +} + +impl<'a, 'b, 'c> DocumentTemplateExtractor<'a, 'b, 'c> { + pub fn new( + template: &'c Prompt, + doc_alloc: &'a Bump, + field_id_map: &'a RefCell>, + ) -> Self { + Self { template, doc_alloc, field_id_map } + } +} + +impl<'doc> Extractor<'doc> for DocumentTemplateExtractor<'doc, '_, '_> { + type DocumentMetadata = &'doc str; + type Input = &'doc str; + type Error = RenderPromptError; + + fn extractor_id(&self) -> u8 { + 0 + } + + fn extract<'a, D: Document<'a> + Debug>( + &self, + doc: D, + external_docid: &Self::DocumentMetadata, + ) -> Result, Self::Error> { + Ok(Some(self.template.render_document( + external_docid, + doc, + self.field_id_map, + self.doc_alloc, + )?)) + } +} + +pub struct RequestFragmentExtractor<'a> { + fragment: &'a JsonTemplate, + extractor_id: u8, + doc_alloc: &'a Bump, +} + +impl<'a> RequestFragmentExtractor<'a> { + pub fn new(fragment: &'a RuntimeFragment, doc_alloc: &'a Bump) -> Self { + Self { fragment: &fragment.template, extractor_id: fragment.id, doc_alloc } + } +} + +impl<'doc> Extractor<'doc> for RequestFragmentExtractor<'doc> { + type DocumentMetadata = (); + type Input = Value; + type Error = json_template::Error; + + fn extractor_id(&self) -> u8 { + self.extractor_id + } + + fn extract<'a, D: Document<'a> + Debug>( + &self, + doc: D, + _meta: &Self::DocumentMetadata, + ) -> Result, Self::Error> { + Ok(Some(self.fragment.render_document(doc, self.doc_alloc)?)) + } +} + +pub struct IgnoreErrorExtractor(E); + +impl<'doc, E> Extractor<'doc> for IgnoreErrorExtractor +where + E: Extractor<'doc>, +{ + type DocumentMetadata = E::DocumentMetadata; + type Input = E::Input; + + type Error = Infallible; + + fn extractor_id(&self) -> u8 { + self.0.extractor_id() + } + + fn extract<'a, D: Document<'a> + Debug>( + &self, + doc: D, + meta: &Self::DocumentMetadata, + ) -> Result, Self::Error> { + Ok(self.0.extract(doc, meta).ok().flatten()) + } +} + +#[derive(Debug)] +pub enum Infallible {} + +impl From for crate::Error { + fn from(_: Infallible) -> Self { + unreachable!("Infallible values cannot be built") + } +} diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs index ec4ee2ccd..246f824e1 100644 --- a/crates/milli/src/vector/mod.rs +++ b/crates/milli/src/vector/mod.rs @@ -20,6 +20,7 @@ use crate::ThreadPoolNoAbort; pub mod composite; pub mod db; pub mod error; +pub mod extractor; pub mod hf; pub mod json_template; pub mod manual;