Add media to search

This commit is contained in:
Louis Dureuil 2025-06-30 00:10:46 +02:00
parent 46bceb91f1
commit d14184f4da
No known key found for this signature in database
5 changed files with 86 additions and 33 deletions

View file

@ -56,6 +56,8 @@ pub struct FacetSearchQuery {
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<Vec<f32>>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchMedia>)]
pub media: Option<Value>,
#[deserr(default, error = DeserrJsonError<InvalidSearchHybridQuery>)] #[deserr(default, error = DeserrJsonError<InvalidSearchHybridQuery>)]
pub hybrid: Option<HybridQuery>, pub hybrid: Option<HybridQuery>,
#[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)] #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
@ -94,6 +96,7 @@ impl FacetSearchAggregator {
facet_name, facet_name,
vector, vector,
q, q,
media,
filter, filter,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
@ -108,6 +111,7 @@ impl FacetSearchAggregator {
facet_names: Some(facet_name.clone()).into_iter().collect(), facet_names: Some(facet_name.clone()).into_iter().collect(),
additional_search_parameters_provided: q.is_some() additional_search_parameters_provided: q.is_some()
|| vector.is_some() || vector.is_some()
|| media.is_some()
|| filter.is_some() || filter.is_some()
|| *matching_strategy != MatchingStrategy::default() || *matching_strategy != MatchingStrategy::default()
|| attributes_to_search_on.is_some() || attributes_to_search_on.is_some()
@ -291,6 +295,7 @@ impl From<FacetSearchQuery> for SearchQuery {
facet_name: _, facet_name: _,
q, q,
vector, vector,
media,
filter, filter,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
@ -312,6 +317,7 @@ impl From<FacetSearchQuery> for SearchQuery {
SearchQuery { SearchQuery {
q, q,
media,
offset: DEFAULT_SEARCH_OFFSET(), offset: DEFAULT_SEARCH_OFFSET(),
limit: DEFAULT_SEARCH_LIMIT(), limit: DEFAULT_SEARCH_LIMIT(),
page, page,

View file

@ -205,6 +205,8 @@ impl TryFrom<SearchQueryGet> for SearchQuery {
Ok(Self { Ok(Self {
q: other.q, q: other.q,
// `media` not supported for `GET`
media: None,
vector: other.vector.map(CS::into_inner), vector: other.vector.map(CS::into_inner),
offset: other.offset.0, offset: other.offset.0,
limit: other.limit.0, limit: other.limit.0,

View file

@ -64,6 +64,8 @@ pub struct SearchQuery {
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<Vec<f32>>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchMedia>)]
pub media: Option<serde_json::Value>,
#[deserr(default, error = DeserrJsonError<InvalidSearchHybridQuery>)] #[deserr(default, error = DeserrJsonError<InvalidSearchHybridQuery>)]
pub hybrid: Option<HybridQuery>, pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
@ -147,6 +149,7 @@ impl From<SearchParameters> for SearchQuery {
ranking_score_threshold: ranking_score_threshold.map(RankingScoreThreshold::from), ranking_score_threshold: ranking_score_threshold.map(RankingScoreThreshold::from),
q: None, q: None,
vector: None, vector: None,
media: None,
offset: DEFAULT_SEARCH_OFFSET(), offset: DEFAULT_SEARCH_OFFSET(),
page: None, page: None,
hits_per_page: None, hits_per_page: None,
@ -220,6 +223,7 @@ impl fmt::Debug for SearchQuery {
let Self { let Self {
q, q,
vector, vector,
media,
hybrid, hybrid,
offset, offset,
limit, limit,
@ -274,6 +278,9 @@ impl fmt::Debug for SearchQuery {
); );
} }
} }
if let Some(media) = media {
debug.field("media", media);
}
if let Some(hybrid) = hybrid { if let Some(hybrid) = hybrid {
debug.field("hybrid", &hybrid); debug.field("hybrid", &hybrid);
} }
@ -482,8 +489,10 @@ pub struct SearchQueryWithIndex {
pub index_uid: IndexUid, pub index_uid: IndexUid,
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<Vec<f32>>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchMedia>)]
pub media: Option<serde_json::Value>,
#[deserr(default, error = DeserrJsonError<InvalidSearchHybridQuery>)] #[deserr(default, error = DeserrJsonError<InvalidSearchHybridQuery>)]
pub hybrid: Option<HybridQuery>, pub hybrid: Option<HybridQuery>,
#[deserr(default, error = DeserrJsonError<InvalidSearchOffset>)] #[deserr(default, error = DeserrJsonError<InvalidSearchOffset>)]
@ -564,6 +573,7 @@ impl SearchQueryWithIndex {
let SearchQuery { let SearchQuery {
q, q,
vector, vector,
media,
hybrid, hybrid,
offset, offset,
limit, limit,
@ -594,6 +604,7 @@ impl SearchQueryWithIndex {
index_uid, index_uid,
q, q,
vector, vector,
media,
hybrid, hybrid,
offset: if offset == DEFAULT_SEARCH_OFFSET() { None } else { Some(offset) }, offset: if offset == DEFAULT_SEARCH_OFFSET() { None } else { Some(offset) },
limit: if limit == DEFAULT_SEARCH_LIMIT() { None } else { Some(limit) }, limit: if limit == DEFAULT_SEARCH_LIMIT() { None } else { Some(limit) },
@ -628,6 +639,7 @@ impl SearchQueryWithIndex {
federation_options, federation_options,
q, q,
vector, vector,
media,
offset, offset,
limit, limit,
page, page,
@ -658,6 +670,7 @@ impl SearchQueryWithIndex {
SearchQuery { SearchQuery {
q, q,
vector, vector,
media,
offset: offset.unwrap_or(DEFAULT_SEARCH_OFFSET()), offset: offset.unwrap_or(DEFAULT_SEARCH_OFFSET()),
limit: limit.unwrap_or(DEFAULT_SEARCH_LIMIT()), limit: limit.unwrap_or(DEFAULT_SEARCH_LIMIT()),
page, page,
@ -984,14 +997,27 @@ pub fn prepare_search<'t>(
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
let q = query.q.as_deref();
let media = query.media.as_ref();
let search_query = match (q, media) {
(Some(text), None) => milli::vector::SearchQuery::Text(text),
(q, media) => milli::vector::SearchQuery::Media { q, media },
};
embedder embedder
.embed_search(query.q.as_ref().unwrap(), Some(deadline)) .embed_search(search_query, Some(deadline))
.map_err(milli::vector::Error::from) .map_err(milli::vector::Error::from)
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
} }
}; };
search.semantic(
search.semantic(embedder_name.clone(), embedder.clone(), *quantized, Some(vector)); embedder_name.clone(),
embedder.clone(),
*quantized,
Some(vector),
query.media.clone(),
);
} }
SearchKind::Hybrid { embedder_name, embedder, quantized, semantic_ratio: _ } => { SearchKind::Hybrid { embedder_name, embedder, quantized, semantic_ratio: _ } => {
if let Some(q) = &query.q { if let Some(q) = &query.q {
@ -1003,6 +1029,7 @@ pub fn prepare_search<'t>(
embedder.clone(), embedder.clone(),
*quantized, *quantized,
query.vector.clone(), query.vector.clone(),
query.media.clone(),
); );
} }
} }
@ -1127,6 +1154,7 @@ pub fn perform_search(
locales, locales,
// already used in prepare_search // already used in prepare_search
vector: _, vector: _,
media: _,
hybrid: _, hybrid: _,
offset: _, offset: _,
ranking_score_threshold: _, ranking_score_threshold: _,

View file

@ -7,6 +7,7 @@ use roaring::RoaringBitmap;
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::search::new::{distinct_fid, distinct_single_docid}; use crate::search::new::{distinct_fid, distinct_single_docid};
use crate::search::SemanticSearch; use crate::search::SemanticSearch;
use crate::vector::SearchQuery;
use crate::{Index, MatchingWords, Result, Search, SearchResult}; use crate::{Index, MatchingWords, Result, Search, SearchResult};
struct ScoreWithRatioResult { struct ScoreWithRatioResult {
@ -225,12 +226,9 @@ impl Search<'_> {
return Ok(return_keyword_results(self.limit, self.offset, keyword_results)); return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
} }
// no vector search against placeholder search
let Some(query) = search.query.take() else {
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
};
// no embedder, no semantic search // no embedder, no semantic search
let Some(SemanticSearch { vector, embedder_name, embedder, quantized }) = semantic else { let Some(SemanticSearch { vector, embedder_name, embedder, quantized, media }) = semantic
else {
return Ok(return_keyword_results(self.limit, self.offset, keyword_results)); return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
}; };
@ -241,9 +239,17 @@ impl Search<'_> {
let span = tracing::trace_span!(target: "search::hybrid", "embed_one"); let span = tracing::trace_span!(target: "search::hybrid", "embed_one");
let _entered = span.enter(); let _entered = span.enter();
let q = search.query.as_deref();
let media = media.as_ref();
let query = match (q, media) {
(Some(text), None) => SearchQuery::Text(text),
(q, media) => SearchQuery::Media { q, media },
};
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3); let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
match embedder.embed_search(&query, Some(deadline)) { match embedder.embed_search(query, Some(deadline)) {
Ok(embedding) => embedding, Ok(embedding) => embedding,
Err(error) => { Err(error) => {
tracing::error!(error=%error, "Embedding failed"); tracing::error!(error=%error, "Embedding failed");
@ -257,8 +263,13 @@ impl Search<'_> {
} }
}; };
search.semantic = search.semantic = Some(SemanticSearch {
Some(SemanticSearch { vector: Some(vector_query), embedder_name, embedder, quantized }); vector: Some(vector_query),
embedder_name,
embedder,
quantized,
media,
});
// TODO: would be better to have two distinct functions at this point // TODO: would be better to have two distinct functions at this point
let vector_results = search.execute()?; let vector_results = search.execute()?;

View file

@ -12,7 +12,7 @@ use self::new::{execute_vector_search, PartialSearchResult, VectorStoreStats};
use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features}; use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features};
use crate::index::MatchingStrategy; use crate::index::MatchingStrategy;
use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::Embedder; use crate::vector::{Embedder, Embedding};
use crate::{ use crate::{
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Error, Index, execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Error, Index,
Result, SearchContext, TimeBudget, UserError, Result, SearchContext, TimeBudget, UserError,
@ -32,6 +32,7 @@ pub mod similar;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SemanticSearch { pub struct SemanticSearch {
vector: Option<Vec<f32>>, vector: Option<Vec<f32>>,
media: Option<serde_json::Value>,
embedder_name: String, embedder_name: String,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
quantized: bool, quantized: bool,
@ -93,9 +94,10 @@ impl<'a> Search<'a> {
embedder_name: String, embedder_name: String,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
quantized: bool, quantized: bool,
vector: Option<Vec<f32>>, vector: Option<Embedding>,
media: Option<serde_json::Value>,
) -> &mut Search<'a> { ) -> &mut Search<'a> {
self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector }); self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector, media });
self self
} }
@ -231,8 +233,13 @@ impl<'a> Search<'a> {
degraded, degraded,
used_negative_operator, used_negative_operator,
} = match self.semantic.as_ref() { } = match self.semantic.as_ref() {
Some(SemanticSearch { vector: Some(vector), embedder_name, embedder, quantized }) => { Some(SemanticSearch {
execute_vector_search( vector: Some(vector),
embedder_name,
embedder,
quantized,
media: _,
}) => execute_vector_search(
&mut ctx, &mut ctx,
vector, vector,
self.scoring_strategy, self.scoring_strategy,
@ -247,8 +254,7 @@ impl<'a> Search<'a> {
*quantized, *quantized,
self.time_budget.clone(), self.time_budget.clone(),
self.ranking_score_threshold, self.ranking_score_threshold,
)? )?,
}
_ => execute_search( _ => execute_search(
&mut ctx, &mut ctx,
self.query.as_deref(), self.query.as_deref(),