mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-22 21:04:27 +01:00
WIP multi embedders
fixed template bugs
This commit is contained in:
parent
abbe131084
commit
922a640188
@ -1361,7 +1361,6 @@ impl IndexScheduler {
|
|||||||
let embedder = Arc::new(
|
let embedder = Arc::new(
|
||||||
Embedder::new(embedder_options.clone())
|
Embedder::new(embedder_options.clone())
|
||||||
.map_err(meilisearch_types::milli::vector::Error::from)
|
.map_err(meilisearch_types::milli::vector::Error::from)
|
||||||
.map_err(meilisearch_types::milli::UserError::from)
|
|
||||||
.map_err(meilisearch_types::milli::Error::from)?,
|
.map_err(meilisearch_types::milli::Error::from)?,
|
||||||
);
|
);
|
||||||
{
|
{
|
||||||
|
@ -222,6 +222,8 @@ InvalidVectorsType , InvalidRequest , BAD_REQUEST ;
|
|||||||
InvalidDocumentId , InvalidRequest , BAD_REQUEST ;
|
InvalidDocumentId , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ;
|
InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ;
|
InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ;
|
||||||
|
InvalidEmbedder , InvalidRequest , BAD_REQUEST ;
|
||||||
|
InvalidHybridQuery , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidIndexLimit , InvalidRequest , BAD_REQUEST ;
|
InvalidIndexLimit , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidIndexOffset , InvalidRequest , BAD_REQUEST ;
|
InvalidIndexOffset , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ;
|
InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ;
|
||||||
@ -233,6 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
|
|||||||
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
|
||||||
|
InvalidSemanticRatio , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
|
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
|
||||||
@ -340,6 +343,7 @@ impl ErrorCode for milli::Error {
|
|||||||
}
|
}
|
||||||
UserError::MissingDocumentField(_) => Code::InvalidDocumentFields,
|
UserError::MissingDocumentField(_) => Code::InvalidDocumentFields,
|
||||||
UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
|
UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
|
||||||
|
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
|
||||||
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
|
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
|
||||||
UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,
|
UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,
|
||||||
UserError::MultiplePrimaryKeyCandidatesFound { .. } => {
|
UserError::MultiplePrimaryKeyCandidatesFound { .. } => {
|
||||||
@ -363,6 +367,7 @@ impl ErrorCode for milli::Error {
|
|||||||
UserError::InvalidMinTypoWordLenSetting(_, _) => {
|
UserError::InvalidMinTypoWordLenSetting(_, _) => {
|
||||||
Code::InvalidSettingsTypoTolerance
|
Code::InvalidSettingsTypoTolerance
|
||||||
}
|
}
|
||||||
|
UserError::InvalidEmbedder(_) => Code::InvalidEmbedder,
|
||||||
UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError,
|
UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,7 @@ use crate::routes::{create_all_stats, Stats};
|
|||||||
use crate::search::{
|
use crate::search::{
|
||||||
FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult,
|
FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult,
|
||||||
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
|
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
|
||||||
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT,
|
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO,
|
||||||
};
|
};
|
||||||
use crate::Opt;
|
use crate::Opt;
|
||||||
|
|
||||||
@ -586,6 +586,11 @@ pub struct SearchAggregator {
|
|||||||
// vector
|
// vector
|
||||||
// The maximum number of floats in a vector request
|
// The maximum number of floats in a vector request
|
||||||
max_vector_size: usize,
|
max_vector_size: usize,
|
||||||
|
// Whether the semantic ratio passed to a hybrid search equals the default ratio.
|
||||||
|
semantic_ratio: bool,
|
||||||
|
// Whether a non-default embedder was specified
|
||||||
|
embedder: bool,
|
||||||
|
hybrid: bool,
|
||||||
|
|
||||||
// every time a search is done, we increment the counter linked to the used settings
|
// every time a search is done, we increment the counter linked to the used settings
|
||||||
matching_strategy: HashMap<String, usize>,
|
matching_strategy: HashMap<String, usize>,
|
||||||
@ -639,6 +644,7 @@ impl SearchAggregator {
|
|||||||
crop_marker,
|
crop_marker,
|
||||||
matching_strategy,
|
matching_strategy,
|
||||||
attributes_to_search_on,
|
attributes_to_search_on,
|
||||||
|
hybrid,
|
||||||
} = query;
|
} = query;
|
||||||
|
|
||||||
let mut ret = Self::default();
|
let mut ret = Self::default();
|
||||||
@ -712,6 +718,12 @@ impl SearchAggregator {
|
|||||||
ret.show_ranking_score = *show_ranking_score;
|
ret.show_ranking_score = *show_ranking_score;
|
||||||
ret.show_ranking_score_details = *show_ranking_score_details;
|
ret.show_ranking_score_details = *show_ranking_score_details;
|
||||||
|
|
||||||
|
if let Some(hybrid) = hybrid {
|
||||||
|
ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO();
|
||||||
|
ret.embedder = hybrid.embedder.is_some();
|
||||||
|
ret.hybrid = true;
|
||||||
|
}
|
||||||
|
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -765,6 +777,9 @@ impl SearchAggregator {
|
|||||||
facets_total_number_of_facets,
|
facets_total_number_of_facets,
|
||||||
show_ranking_score,
|
show_ranking_score,
|
||||||
show_ranking_score_details,
|
show_ranking_score_details,
|
||||||
|
semantic_ratio,
|
||||||
|
embedder,
|
||||||
|
hybrid,
|
||||||
} = other;
|
} = other;
|
||||||
|
|
||||||
if self.timestamp.is_none() {
|
if self.timestamp.is_none() {
|
||||||
@ -810,6 +825,9 @@ impl SearchAggregator {
|
|||||||
|
|
||||||
// vector
|
// vector
|
||||||
self.max_vector_size = self.max_vector_size.max(max_vector_size);
|
self.max_vector_size = self.max_vector_size.max(max_vector_size);
|
||||||
|
self.semantic_ratio |= semantic_ratio;
|
||||||
|
self.hybrid |= hybrid;
|
||||||
|
self.embedder |= embedder;
|
||||||
|
|
||||||
// pagination
|
// pagination
|
||||||
self.max_limit = self.max_limit.max(max_limit);
|
self.max_limit = self.max_limit.max(max_limit);
|
||||||
@ -878,6 +896,9 @@ impl SearchAggregator {
|
|||||||
facets_total_number_of_facets,
|
facets_total_number_of_facets,
|
||||||
show_ranking_score,
|
show_ranking_score,
|
||||||
show_ranking_score_details,
|
show_ranking_score_details,
|
||||||
|
semantic_ratio,
|
||||||
|
embedder,
|
||||||
|
hybrid,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
if total_received == 0 {
|
if total_received == 0 {
|
||||||
@ -917,6 +938,11 @@ impl SearchAggregator {
|
|||||||
"vector": {
|
"vector": {
|
||||||
"max_vector_size": max_vector_size,
|
"max_vector_size": max_vector_size,
|
||||||
},
|
},
|
||||||
|
"hybrid": {
|
||||||
|
"enabled": hybrid,
|
||||||
|
"semantic_ratio": semantic_ratio,
|
||||||
|
"embedder": embedder,
|
||||||
|
},
|
||||||
"pagination": {
|
"pagination": {
|
||||||
"max_limit": max_limit,
|
"max_limit": max_limit,
|
||||||
"max_offset": max_offset,
|
"max_offset": max_offset,
|
||||||
@ -1012,6 +1038,7 @@ impl MultiSearchAggregator {
|
|||||||
crop_marker: _,
|
crop_marker: _,
|
||||||
matching_strategy: _,
|
matching_strategy: _,
|
||||||
attributes_to_search_on: _,
|
attributes_to_search_on: _,
|
||||||
|
hybrid: _,
|
||||||
} = query;
|
} = query;
|
||||||
|
|
||||||
index_uid.as_str()
|
index_uid.as_str()
|
||||||
@ -1158,6 +1185,7 @@ impl FacetSearchAggregator {
|
|||||||
filter,
|
filter,
|
||||||
matching_strategy,
|
matching_strategy,
|
||||||
attributes_to_search_on,
|
attributes_to_search_on,
|
||||||
|
hybrid,
|
||||||
} = query;
|
} = query;
|
||||||
|
|
||||||
let mut ret = Self::default();
|
let mut ret = Self::default();
|
||||||
@ -1171,7 +1199,8 @@ impl FacetSearchAggregator {
|
|||||||
|| vector.is_some()
|
|| vector.is_some()
|
||||||
|| filter.is_some()
|
|| filter.is_some()
|
||||||
|| *matching_strategy != MatchingStrategy::default()
|
|| *matching_strategy != MatchingStrategy::default()
|
||||||
|| attributes_to_search_on.is_some();
|
|| attributes_to_search_on.is_some()
|
||||||
|
|| hybrid.is_some();
|
||||||
|
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
@ -14,9 +14,9 @@ use crate::analytics::{Analytics, FacetSearchAggregator};
|
|||||||
use crate::extractors::authentication::policies::*;
|
use crate::extractors::authentication::policies::*;
|
||||||
use crate::extractors::authentication::GuardedData;
|
use crate::extractors::authentication::GuardedData;
|
||||||
use crate::search::{
|
use crate::search::{
|
||||||
add_search_rules, perform_facet_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH,
|
add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery,
|
||||||
DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG,
|
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
|
||||||
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
|
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||||
@ -37,6 +37,8 @@ pub struct FacetSearchQuery {
|
|||||||
pub q: Option<String>,
|
pub q: Option<String>,
|
||||||
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
|
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
|
||||||
pub vector: Option<Vec<f32>>,
|
pub vector: Option<Vec<f32>>,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
|
||||||
|
pub hybrid: Option<HybridQuery>,
|
||||||
#[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
|
#[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
|
||||||
pub filter: Option<Value>,
|
pub filter: Option<Value>,
|
||||||
#[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)]
|
#[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)]
|
||||||
@ -96,6 +98,7 @@ impl From<FacetSearchQuery> for SearchQuery {
|
|||||||
filter,
|
filter,
|
||||||
matching_strategy,
|
matching_strategy,
|
||||||
attributes_to_search_on,
|
attributes_to_search_on,
|
||||||
|
hybrid,
|
||||||
} = value;
|
} = value;
|
||||||
|
|
||||||
SearchQuery {
|
SearchQuery {
|
||||||
@ -120,6 +123,7 @@ impl From<FacetSearchQuery> for SearchQuery {
|
|||||||
matching_strategy,
|
matching_strategy,
|
||||||
vector: vector.map(VectorQuery::Vector),
|
vector: vector.map(VectorQuery::Vector),
|
||||||
attributes_to_search_on,
|
attributes_to_search_on,
|
||||||
|
hybrid,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
|
|||||||
use meilisearch_types::error::deserr_codes::*;
|
use meilisearch_types::error::deserr_codes::*;
|
||||||
use meilisearch_types::error::ResponseError;
|
use meilisearch_types::error::ResponseError;
|
||||||
use meilisearch_types::index_uid::IndexUid;
|
use meilisearch_types::index_uid::IndexUid;
|
||||||
use meilisearch_types::milli::VectorQuery;
|
use meilisearch_types::milli::{self, VectorQuery};
|
||||||
use meilisearch_types::serde_cs::vec::CS;
|
use meilisearch_types::serde_cs::vec::CS;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
@ -17,9 +17,9 @@ use crate::extractors::authentication::policies::*;
|
|||||||
use crate::extractors::authentication::GuardedData;
|
use crate::extractors::authentication::GuardedData;
|
||||||
use crate::extractors::sequential_extractor::SeqHandler;
|
use crate::extractors::sequential_extractor::SeqHandler;
|
||||||
use crate::search::{
|
use crate::search::{
|
||||||
add_search_rules, perform_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH,
|
add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery,
|
||||||
DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG,
|
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
|
||||||
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
|
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||||
@ -75,6 +75,10 @@ pub struct SearchQueryGet {
|
|||||||
matching_strategy: MatchingStrategy,
|
matching_strategy: MatchingStrategy,
|
||||||
#[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)]
|
#[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)]
|
||||||
pub attributes_to_search_on: Option<CS<String>>,
|
pub attributes_to_search_on: Option<CS<String>>,
|
||||||
|
#[deserr(default, error = DeserrQueryParamError<InvalidHybridQuery>)]
|
||||||
|
pub hybrid_embedder: Option<String>,
|
||||||
|
#[deserr(default, error = DeserrQueryParamError<InvalidHybridQuery>)]
|
||||||
|
pub hybrid_semantic_ratio: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<SearchQueryGet> for SearchQuery {
|
impl From<SearchQueryGet> for SearchQuery {
|
||||||
@ -87,6 +91,18 @@ impl From<SearchQueryGet> for SearchQuery {
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) {
|
||||||
|
(None, None) => None,
|
||||||
|
(None, Some(semantic_ratio)) => Some(HybridQuery { semantic_ratio, embedder: None }),
|
||||||
|
(Some(embedder), None) => Some(HybridQuery {
|
||||||
|
semantic_ratio: DEFAULT_SEMANTIC_RATIO(),
|
||||||
|
embedder: Some(embedder),
|
||||||
|
}),
|
||||||
|
(Some(embedder), Some(semantic_ratio)) => {
|
||||||
|
Some(HybridQuery { semantic_ratio, embedder: Some(embedder) })
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
q: other.q,
|
q: other.q,
|
||||||
vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector),
|
vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector),
|
||||||
@ -109,6 +125,7 @@ impl From<SearchQueryGet> for SearchQuery {
|
|||||||
crop_marker: other.crop_marker,
|
crop_marker: other.crop_marker,
|
||||||
matching_strategy: other.matching_strategy,
|
matching_strategy: other.matching_strategy,
|
||||||
attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()),
|
attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()),
|
||||||
|
hybrid,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -159,6 +176,9 @@ pub async fn search_with_url_query(
|
|||||||
|
|
||||||
let index = index_scheduler.index(&index_uid)?;
|
let index = index_scheduler.index(&index_uid)?;
|
||||||
let features = index_scheduler.features();
|
let features = index_scheduler.features();
|
||||||
|
|
||||||
|
embed(&mut query, index_scheduler.get_ref(), &index).await?;
|
||||||
|
|
||||||
let search_result =
|
let search_result =
|
||||||
tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?;
|
tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?;
|
||||||
if let Ok(ref search_result) = search_result {
|
if let Ok(ref search_result) = search_result {
|
||||||
@ -213,22 +233,31 @@ pub async fn search_with_post(
|
|||||||
pub async fn embed(
|
pub async fn embed(
|
||||||
query: &mut SearchQuery,
|
query: &mut SearchQuery,
|
||||||
index_scheduler: &IndexScheduler,
|
index_scheduler: &IndexScheduler,
|
||||||
index: &meilisearch_types::milli::Index,
|
index: &milli::Index,
|
||||||
) -> Result<(), ResponseError> {
|
) -> Result<(), ResponseError> {
|
||||||
if let Some(VectorQuery::String(prompt)) = query.vector.take() {
|
if let Some(VectorQuery::String(prompt)) = query.vector.take() {
|
||||||
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
|
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
|
||||||
let embedder = index_scheduler.embedders(embedder_configs)?;
|
let embedder = index_scheduler.embedders(embedder_configs)?;
|
||||||
|
|
||||||
/// FIXME: add error if no embedder, remove unwrap, support multiple embedders
|
let embedder_name = if let Some(HybridQuery {
|
||||||
|
semantic_ratio: _,
|
||||||
|
embedder: Some(embedder),
|
||||||
|
}) = &query.hybrid
|
||||||
|
{
|
||||||
|
embedder
|
||||||
|
} else {
|
||||||
|
"default"
|
||||||
|
};
|
||||||
|
|
||||||
let embeddings = embedder
|
let embeddings = embedder
|
||||||
.get("default")
|
.get(embedder_name)
|
||||||
.unwrap()
|
.ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned()))
|
||||||
|
.map_err(milli::Error::from)?
|
||||||
.0
|
.0
|
||||||
.embed(vec![prompt])
|
.embed(vec![prompt])
|
||||||
.await
|
.await
|
||||||
.map_err(meilisearch_types::milli::vector::Error::from)
|
.map_err(milli::vector::Error::from)
|
||||||
.map_err(meilisearch_types::milli::UserError::from)
|
.map_err(milli::Error::from)?
|
||||||
.map_err(meilisearch_types::milli::Error::from)?
|
|
||||||
.pop()
|
.pop()
|
||||||
.expect("No vector returned from embedding");
|
.expect("No vector returned from embedding");
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10;
|
|||||||
pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string();
|
pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string();
|
||||||
pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
|
pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
|
||||||
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
|
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
|
||||||
|
pub const DEFAULT_SEMANTIC_RATIO: fn() -> f32 = || 0.5;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
|
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
|
||||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||||
@ -44,6 +45,8 @@ pub struct SearchQuery {
|
|||||||
pub q: Option<String>,
|
pub q: Option<String>,
|
||||||
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
|
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
|
||||||
pub vector: Option<milli::VectorQuery>,
|
pub vector: Option<milli::VectorQuery>,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
|
||||||
|
pub hybrid: Option<HybridQuery>,
|
||||||
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||||
pub offset: usize,
|
pub offset: usize,
|
||||||
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
||||||
@ -84,6 +87,15 @@ pub struct SearchQuery {
|
|||||||
pub attributes_to_search_on: Option<Vec<String>>,
|
pub attributes_to_search_on: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
|
||||||
|
#[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
|
||||||
|
pub struct HybridQuery {
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidSemanticRatio>, default = DEFAULT_SEMANTIC_RATIO())]
|
||||||
|
pub semantic_ratio: f32,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
|
||||||
|
pub embedder: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
impl SearchQuery {
|
impl SearchQuery {
|
||||||
pub fn is_finite_pagination(&self) -> bool {
|
pub fn is_finite_pagination(&self) -> bool {
|
||||||
self.page.or(self.hits_per_page).is_some()
|
self.page.or(self.hits_per_page).is_some()
|
||||||
@ -103,6 +115,8 @@ pub struct SearchQueryWithIndex {
|
|||||||
pub q: Option<String>,
|
pub q: Option<String>,
|
||||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||||
pub vector: Option<VectorQuery>,
|
pub vector: Option<VectorQuery>,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
|
||||||
|
pub hybrid: Option<HybridQuery>,
|
||||||
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||||
pub offset: usize,
|
pub offset: usize,
|
||||||
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
||||||
@ -168,6 +182,7 @@ impl SearchQueryWithIndex {
|
|||||||
crop_marker,
|
crop_marker,
|
||||||
matching_strategy,
|
matching_strategy,
|
||||||
attributes_to_search_on,
|
attributes_to_search_on,
|
||||||
|
hybrid,
|
||||||
} = self;
|
} = self;
|
||||||
(
|
(
|
||||||
index_uid,
|
index_uid,
|
||||||
@ -193,6 +208,7 @@ impl SearchQueryWithIndex {
|
|||||||
crop_marker,
|
crop_marker,
|
||||||
matching_strategy,
|
matching_strategy,
|
||||||
attributes_to_search_on,
|
attributes_to_search_on,
|
||||||
|
hybrid,
|
||||||
// do not use ..Default::default() here,
|
// do not use ..Default::default() here,
|
||||||
// rather add any missing field from `SearchQuery` to `SearchQueryWithIndex`
|
// rather add any missing field from `SearchQuery` to `SearchQueryWithIndex`
|
||||||
},
|
},
|
||||||
|
@ -63,6 +63,8 @@ pub enum InternalError {
|
|||||||
InvalidMatchingWords,
|
InvalidMatchingWords,
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
ArroyError(#[from] arroy::Error),
|
ArroyError(#[from] arroy::Error),
|
||||||
|
#[error(transparent)]
|
||||||
|
VectorEmbeddingError(#[from] crate::vector::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
@ -188,8 +190,23 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
|
|||||||
MissingDocumentField(#[from] crate::prompt::error::RenderPromptError),
|
MissingDocumentField(#[from] crate::prompt::error::RenderPromptError),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
InvalidPrompt(#[from] crate::prompt::error::NewPromptError),
|
InvalidPrompt(#[from] crate::prompt::error::NewPromptError),
|
||||||
#[error("Invalid prompt in for embeddings with name '{0}': {1}")]
|
#[error("Invalid prompt in for embeddings with name '{0}': {1}.")]
|
||||||
InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
|
InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
|
||||||
|
#[error("Too many embedders in the configuration. Found {0}, but limited to 256.")]
|
||||||
|
TooManyEmbedders(usize),
|
||||||
|
#[error("Cannot find embedder with name {0}.")]
|
||||||
|
InvalidEmbedder(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<crate::vector::Error> for Error {
|
||||||
|
fn from(value: crate::vector::Error) -> Self {
|
||||||
|
match value.fault() {
|
||||||
|
FaultSource::User => Error::UserError(value.into()),
|
||||||
|
FaultSource::Runtime => Error::InternalError(value.into()),
|
||||||
|
FaultSource::Bug => Error::InternalError(value.into()),
|
||||||
|
FaultSource::Undecided => Error::InternalError(value.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<arroy::Error> for Error {
|
impl From<arroy::Error> for Error {
|
||||||
|
@ -110,7 +110,6 @@ impl Prompt {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// render template with special object that's OK with `doc.*` and `fields.*`
|
// render template with special object that's OK with `doc.*` and `fields.*`
|
||||||
/// FIXME: doesn't work for nested objects e.g. `doc.a.b`
|
|
||||||
this.template
|
this.template
|
||||||
.render(&template_checker::TemplateChecker)
|
.render(&template_checker::TemplateChecker)
|
||||||
.map_err(NewPromptError::invalid_fields_in_template)?;
|
.map_err(NewPromptError::invalid_fields_in_template)?;
|
||||||
@ -142,3 +141,80 @@ pub enum PromptFallbackStrategy {
|
|||||||
#[default]
|
#[default]
|
||||||
Error,
|
Error,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::Prompt;
|
||||||
|
use crate::error::FaultSource;
|
||||||
|
use crate::prompt::error::{NewPromptError, NewPromptErrorKind};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_template() {
|
||||||
|
// does not panic
|
||||||
|
Prompt::default();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_template() {
|
||||||
|
Prompt::new("".into(), None, None).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn template_ok() {
|
||||||
|
Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None, None).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn template_syntax() {
|
||||||
|
assert!(matches!(
|
||||||
|
Prompt::new("{{doc.title: {{doc.overview}}".into(), None, None),
|
||||||
|
Err(NewPromptError {
|
||||||
|
kind: NewPromptErrorKind::CannotParseTemplate(_),
|
||||||
|
fault: FaultSource::User
|
||||||
|
})
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn template_missing_doc() {
|
||||||
|
assert!(matches!(
|
||||||
|
Prompt::new("{{title}}: {{overview}}".into(), None, None),
|
||||||
|
Err(NewPromptError {
|
||||||
|
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
|
||||||
|
fault: FaultSource::User
|
||||||
|
})
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn template_nested_doc() {
|
||||||
|
Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None, None).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn template_fields() {
|
||||||
|
Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None, None).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn template_fields_ok() {
|
||||||
|
Prompt::new(
|
||||||
|
"{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn template_fields_invalid() {
|
||||||
|
assert!(matches!(
|
||||||
|
// intentionally garbled field
|
||||||
|
Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None, None),
|
||||||
|
Err(NewPromptError {
|
||||||
|
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
|
||||||
|
fault: FaultSource::User
|
||||||
|
})
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use liquid::model::{
|
use liquid::model::{
|
||||||
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
|
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
|
||||||
};
|
};
|
||||||
use liquid::{ObjectView, ValueView};
|
use liquid::{Object, ObjectView, ValueView};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TemplateChecker;
|
pub struct TemplateChecker;
|
||||||
@ -31,11 +31,11 @@ impl ObjectView for DummyField {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||||
Box::new(std::iter::empty())
|
Box::new(vec![DUMMY_VALUE.as_view(), DUMMY_VALUE.as_view()].into_iter())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
|
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
|
||||||
Box::new(std::iter::empty())
|
Box::new(self.keys().zip(self.values()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn contains_key(&self, index: &str) -> bool {
|
fn contains_key(&self, index: &str) -> bool {
|
||||||
@ -69,7 +69,12 @@ impl ValueView for DummyField {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn query_state(&self, state: State) -> bool {
|
fn query_state(&self, state: State) -> bool {
|
||||||
DUMMY_VALUE.query_state(state)
|
match state {
|
||||||
|
State::Truthy => true,
|
||||||
|
State::DefaultValue => false,
|
||||||
|
State::Empty => false,
|
||||||
|
State::Blank => false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_kstr(&self) -> KStringCow<'_> {
|
fn to_kstr(&self) -> KStringCow<'_> {
|
||||||
@ -77,7 +82,10 @@ impl ValueView for DummyField {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn to_value(&self) -> LiquidValue {
|
fn to_value(&self) -> LiquidValue {
|
||||||
LiquidValue::Nil
|
let mut this = Object::new();
|
||||||
|
this.insert("name".into(), LiquidValue::Nil);
|
||||||
|
this.insert("value".into(), LiquidValue::Nil);
|
||||||
|
LiquidValue::Object(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_object(&self) -> Option<&dyn ObjectView> {
|
fn as_object(&self) -> Option<&dyn ObjectView> {
|
||||||
@ -103,7 +111,12 @@ impl ValueView for DummyFields {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn query_state(&self, state: State) -> bool {
|
fn query_state(&self, state: State) -> bool {
|
||||||
DUMMY_VALUE.query_state(state)
|
match state {
|
||||||
|
State::Truthy => true,
|
||||||
|
State::DefaultValue => false,
|
||||||
|
State::Empty => false,
|
||||||
|
State::Blank => false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_kstr(&self) -> KStringCow<'_> {
|
fn to_kstr(&self) -> KStringCow<'_> {
|
||||||
@ -111,7 +124,7 @@ impl ValueView for DummyFields {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn to_value(&self) -> LiquidValue {
|
fn to_value(&self) -> LiquidValue {
|
||||||
LiquidValue::Nil
|
LiquidValue::Array(vec![DummyField.to_value()])
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_array(&self) -> Option<&dyn ArrayView> {
|
fn as_array(&self) -> Option<&dyn ArrayView> {
|
||||||
@ -125,15 +138,15 @@ impl ArrayView for DummyFields {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn size(&self) -> i64 {
|
fn size(&self) -> i64 {
|
||||||
i64::MAX
|
u16::MAX as i64
|
||||||
}
|
}
|
||||||
|
|
||||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||||
Box::new(std::iter::empty())
|
Box::new(std::iter::once(DummyField.as_value()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn contains_key(&self, _index: i64) -> bool {
|
fn contains_key(&self, index: i64) -> bool {
|
||||||
true
|
index < self.size()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get(&self, _index: i64) -> Option<&dyn ValueView> {
|
fn get(&self, _index: i64) -> Option<&dyn ValueView> {
|
||||||
@ -167,7 +180,8 @@ impl ObjectView for DummyDoc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> {
|
fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> {
|
||||||
Some(DUMMY_VALUE.as_view())
|
// Recursively sends itself
|
||||||
|
Some(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -189,7 +203,12 @@ impl ValueView for DummyDoc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn query_state(&self, state: State) -> bool {
|
fn query_state(&self, state: State) -> bool {
|
||||||
DUMMY_VALUE.query_state(state)
|
match state {
|
||||||
|
State::Truthy => true,
|
||||||
|
State::DefaultValue => false,
|
||||||
|
State::Empty => false,
|
||||||
|
State::Blank => false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_kstr(&self) -> KStringCow<'_> {
|
fn to_kstr(&self) -> KStringCow<'_> {
|
||||||
|
@ -516,7 +516,7 @@ pub fn execute_vector_search(
|
|||||||
) -> Result<PartialSearchResult> {
|
) -> Result<PartialSearchResult> {
|
||||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||||
|
|
||||||
/// FIXME: input universe = universe & documents_with_vectors
|
// FIXME: input universe = universe & documents_with_vectors
|
||||||
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
|
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
|
||||||
let ranking_rules = get_ranking_rules_for_vector(
|
let ranking_rules = get_ranking_rules_for_vector(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -71,8 +71,8 @@ impl VectorStateDelta {
|
|||||||
pub fn extract_vector_points<R: io::Read + io::Seek>(
|
pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||||
obkv_documents: grenad::Reader<R>,
|
obkv_documents: grenad::Reader<R>,
|
||||||
indexer: GrenadParameters,
|
indexer: GrenadParameters,
|
||||||
field_id_map: FieldsIdsMap,
|
field_id_map: &FieldsIdsMap,
|
||||||
prompt: Option<&Prompt>,
|
prompt: &Prompt,
|
||||||
) -> Result<ExtractedVectorPoints> {
|
) -> Result<ExtractedVectorPoints> {
|
||||||
puffin::profile_function!();
|
puffin::profile_function!();
|
||||||
|
|
||||||
@ -142,14 +142,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|||||||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||||
if document_is_kept {
|
if document_is_kept {
|
||||||
// becomes autogenerated
|
// becomes autogenerated
|
||||||
match prompt {
|
VectorStateDelta::NowGenerated(prompt.render(
|
||||||
Some(prompt) => VectorStateDelta::NowGenerated(prompt.render(
|
obkv,
|
||||||
obkv,
|
DelAdd::Addition,
|
||||||
DelAdd::Addition,
|
field_id_map,
|
||||||
&field_id_map,
|
)?)
|
||||||
)?),
|
|
||||||
None => VectorStateDelta::NowRemoved,
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
VectorStateDelta::NowRemoved
|
VectorStateDelta::NowRemoved
|
||||||
}
|
}
|
||||||
@ -162,26 +159,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|||||||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||||
|
|
||||||
if document_is_kept {
|
if document_is_kept {
|
||||||
match prompt {
|
// Don't give up if the old prompt was failing
|
||||||
Some(prompt) => {
|
let old_prompt =
|
||||||
// Don't give up if the old prompt was failing
|
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
|
||||||
let old_prompt = prompt
|
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
|
||||||
.render(obkv, DelAdd::Deletion, &field_id_map)
|
if old_prompt != new_prompt {
|
||||||
.unwrap_or_default();
|
log::trace!(
|
||||||
let new_prompt =
|
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
|
||||||
prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
|
);
|
||||||
if old_prompt != new_prompt {
|
VectorStateDelta::NowGenerated(new_prompt)
|
||||||
log::trace!(
|
} else {
|
||||||
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
|
log::trace!("⏭️ Prompt unmodified, skipping");
|
||||||
);
|
VectorStateDelta::NoChange
|
||||||
VectorStateDelta::NowGenerated(new_prompt)
|
|
||||||
} else {
|
|
||||||
log::trace!("⏭️ Prompt unmodified, skipping");
|
|
||||||
VectorStateDelta::NoChange
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// We no longer have a prompt, so we need to remove any existing vector
|
|
||||||
None => VectorStateDelta::NowRemoved,
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
VectorStateDelta::NowRemoved
|
VectorStateDelta::NowRemoved
|
||||||
@ -196,24 +185,16 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|||||||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||||
|
|
||||||
if document_is_kept {
|
if document_is_kept {
|
||||||
match prompt {
|
// Don't give up if the old prompt was failing
|
||||||
Some(prompt) => {
|
let old_prompt =
|
||||||
// Don't give up if the old prompt was failing
|
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
|
||||||
let old_prompt = prompt
|
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
|
||||||
.render(obkv, DelAdd::Deletion, &field_id_map)
|
if old_prompt != new_prompt {
|
||||||
.unwrap_or_default();
|
log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}");
|
||||||
let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
|
VectorStateDelta::NowGenerated(new_prompt)
|
||||||
if old_prompt != new_prompt {
|
} else {
|
||||||
log::trace!(
|
log::trace!("⏭️ Prompt unmodified, skipping");
|
||||||
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
|
VectorStateDelta::NoChange
|
||||||
);
|
|
||||||
VectorStateDelta::NowGenerated(new_prompt)
|
|
||||||
} else {
|
|
||||||
log::trace!("⏭️ Prompt unmodified, skipping");
|
|
||||||
VectorStateDelta::NoChange
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => VectorStateDelta::NowRemoved,
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
VectorStateDelta::NowRemoved
|
VectorStateDelta::NowRemoved
|
||||||
@ -322,7 +303,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
prompt_reader: grenad::Reader<R>,
|
prompt_reader: grenad::Reader<R>,
|
||||||
indexer: GrenadParameters,
|
indexer: GrenadParameters,
|
||||||
embedder: Arc<Embedder>,
|
embedder: Arc<Embedder>,
|
||||||
) -> Result<(grenad::Reader<BufReader<File>>, Option<usize>)> {
|
) -> Result<grenad::Reader<BufReader<File>>> {
|
||||||
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
|
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
|
||||||
|
|
||||||
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
|
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
|
||||||
@ -341,8 +322,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
let mut chunks_ids = Vec::with_capacity(n_chunks);
|
let mut chunks_ids = Vec::with_capacity(n_chunks);
|
||||||
let mut cursor = prompt_reader.into_cursor()?;
|
let mut cursor = prompt_reader.into_cursor()?;
|
||||||
|
|
||||||
let mut expected_dimension = None;
|
|
||||||
|
|
||||||
while let Some((key, value)) = cursor.move_on_next()? {
|
while let Some((key, value)) = cursor.move_on_next()? {
|
||||||
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
|
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
|
||||||
// SAFETY: precondition, the grenad value was saved from a string
|
// SAFETY: precondition, the grenad value was saved from a string
|
||||||
@ -367,7 +346,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
|
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
|
||||||
)
|
)
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::UserError::from)
|
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
|
|
||||||
for (docid, embeddings) in chunks_ids
|
for (docid, embeddings) in chunks_ids
|
||||||
@ -376,7 +354,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
|
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
|
||||||
{
|
{
|
||||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||||
expected_dimension = Some(embeddings.dimension());
|
|
||||||
}
|
}
|
||||||
chunks_ids.clear();
|
chunks_ids.clear();
|
||||||
}
|
}
|
||||||
@ -387,7 +364,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
let chunked_embeds = rt
|
let chunked_embeds = rt
|
||||||
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
|
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::UserError::from)
|
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
for (docid, embeddings) in chunks_ids
|
for (docid, embeddings) in chunks_ids
|
||||||
.iter()
|
.iter()
|
||||||
@ -395,7 +371,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
|
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
|
||||||
{
|
{
|
||||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||||
expected_dimension = Some(embeddings.dimension());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -403,14 +378,12 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
let embeds = rt
|
let embeds = rt
|
||||||
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
|
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::UserError::from)
|
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
|
|
||||||
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
|
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
|
||||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||||
expected_dimension = Some(embeddings.dimension());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((writer_into_reader(state_writer)?, expected_dimension))
|
writer_into_reader(state_writer)
|
||||||
}
|
}
|
||||||
|
@ -292,43 +292,42 @@ fn send_original_documents_data(
|
|||||||
let documents_chunk_cloned = original_documents_chunk.clone();
|
let documents_chunk_cloned = original_documents_chunk.clone();
|
||||||
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||||
rayon::spawn(move || {
|
rayon::spawn(move || {
|
||||||
let (embedder, prompt) = embedders.get("default").cloned().unzip();
|
for (name, (embedder, prompt)) in embedders {
|
||||||
let result =
|
let result = extract_vector_points(
|
||||||
extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref());
|
documents_chunk_cloned.clone(),
|
||||||
match result {
|
indexer,
|
||||||
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
&field_id_map,
|
||||||
/// FIXME: support multiple embedders
|
&prompt,
|
||||||
let results = embedder.and_then(|embedder| {
|
);
|
||||||
match extract_embeddings(prompts, indexer, embedder.clone()) {
|
match result {
|
||||||
|
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
||||||
|
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
|
||||||
Ok(results) => Some(results),
|
Ok(results) => Some(results),
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
});
|
|
||||||
let (embeddings, expected_dimension) = results.unzip();
|
if !(remove_vectors.is_empty()
|
||||||
let expected_dimension = expected_dimension.flatten();
|
&& manual_vectors.is_empty()
|
||||||
if !(remove_vectors.is_empty()
|
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
|
||||||
&& manual_vectors.is_empty()
|
{
|
||||||
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
|
|
||||||
{
|
|
||||||
/// FIXME FIXME FIXME
|
|
||||||
if expected_dimension.is_some() {
|
|
||||||
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
|
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
|
||||||
remove_vectors,
|
remove_vectors,
|
||||||
embeddings,
|
embeddings,
|
||||||
/// FIXME: compute an expected dimension from the manual vectors if any
|
expected_dimension: embedder.dimensions(),
|
||||||
expected_dimension: expected_dimension.unwrap(),
|
|
||||||
manual_vectors,
|
manual_vectors,
|
||||||
|
embedder_name: name,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Err(error) => {
|
||||||
|
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(error) => {
|
}
|
||||||
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// TODO: create a custom internal error
|
// TODO: create a custom internal error
|
||||||
|
@ -435,7 +435,7 @@ where
|
|||||||
let mut word_docids = None;
|
let mut word_docids = None;
|
||||||
let mut exact_word_docids = None;
|
let mut exact_word_docids = None;
|
||||||
|
|
||||||
let mut dimension = None;
|
let mut dimension = HashMap::new();
|
||||||
|
|
||||||
for result in lmdb_writer_rx {
|
for result in lmdb_writer_rx {
|
||||||
if (self.should_abort)() {
|
if (self.should_abort)() {
|
||||||
@ -471,13 +471,15 @@ where
|
|||||||
remove_vectors,
|
remove_vectors,
|
||||||
embeddings,
|
embeddings,
|
||||||
manual_vectors,
|
manual_vectors,
|
||||||
|
embedder_name,
|
||||||
} => {
|
} => {
|
||||||
dimension = Some(expected_dimension);
|
dimension.insert(embedder_name.clone(), expected_dimension);
|
||||||
TypedChunk::VectorPoints {
|
TypedChunk::VectorPoints {
|
||||||
remove_vectors,
|
remove_vectors,
|
||||||
embeddings,
|
embeddings,
|
||||||
expected_dimension,
|
expected_dimension,
|
||||||
manual_vectors,
|
manual_vectors,
|
||||||
|
embedder_name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
otherwise => otherwise,
|
otherwise => otherwise,
|
||||||
@ -513,14 +515,22 @@ where
|
|||||||
self.index.put_primary_key(self.wtxn, &primary_key)?;
|
self.index.put_primary_key(self.wtxn, &primary_key)?;
|
||||||
let number_of_documents = self.index.number_of_documents(self.wtxn)?;
|
let number_of_documents = self.index.number_of_documents(self.wtxn)?;
|
||||||
|
|
||||||
if let Some(dimension) = dimension {
|
for (embedder_name, dimension) in dimension {
|
||||||
let wtxn = &mut *self.wtxn;
|
let wtxn = &mut *self.wtxn;
|
||||||
let vector_arroy = self.index.vector_arroy;
|
let vector_arroy = self.index.vector_arroy;
|
||||||
|
/// FIXME: unwrap
|
||||||
|
let embedder_index =
|
||||||
|
self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
|
||||||
pool.install(|| {
|
pool.install(|| {
|
||||||
/// FIXME: do for each embedder
|
let writer_index = (embedder_index as u16) << 8;
|
||||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||||
for k in 0..=u8::MAX {
|
for k in 0..=u8::MAX {
|
||||||
let writer = arroy::Writer::prepare(wtxn, vector_arroy, k.into(), dimension)?;
|
let writer = arroy::Writer::prepare(
|
||||||
|
wtxn,
|
||||||
|
vector_arroy,
|
||||||
|
writer_index | (k as u16),
|
||||||
|
dimension,
|
||||||
|
)?;
|
||||||
if writer.is_empty(wtxn)? {
|
if writer.is_empty(wtxn)? {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -47,6 +47,7 @@ pub(crate) enum TypedChunk {
|
|||||||
embeddings: Option<grenad::Reader<BufReader<File>>>,
|
embeddings: Option<grenad::Reader<BufReader<File>>>,
|
||||||
expected_dimension: usize,
|
expected_dimension: usize,
|
||||||
manual_vectors: grenad::Reader<BufReader<File>>,
|
manual_vectors: grenad::Reader<BufReader<File>>,
|
||||||
|
embedder_name: String,
|
||||||
},
|
},
|
||||||
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
|
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
|
||||||
}
|
}
|
||||||
@ -100,8 +101,8 @@ impl TypedChunk {
|
|||||||
TypedChunk::GeoPoints(grenad) => {
|
TypedChunk::GeoPoints(grenad) => {
|
||||||
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
|
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
|
||||||
}
|
}
|
||||||
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => {
|
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => {
|
||||||
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension)
|
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name)
|
||||||
}
|
}
|
||||||
TypedChunk::ScriptLanguageDocids(sl_map) => {
|
TypedChunk::ScriptLanguageDocids(sl_map) => {
|
||||||
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
|
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
|
||||||
@ -360,12 +361,20 @@ pub(crate) fn write_typed_chunk_into_index(
|
|||||||
manual_vectors,
|
manual_vectors,
|
||||||
embeddings,
|
embeddings,
|
||||||
expected_dimension,
|
expected_dimension,
|
||||||
|
embedder_name,
|
||||||
} => {
|
} => {
|
||||||
/// FIXME: allow customizing distance
|
/// FIXME: unwrap
|
||||||
|
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
|
||||||
|
let writer_index = (embedder_index as u16) << 8;
|
||||||
|
// FIXME: allow customizing distance
|
||||||
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
|
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
|
||||||
.map(|k| {
|
.map(|k| {
|
||||||
/// FIXME: allow customizing index and then do index << 8 + k
|
arroy::Writer::prepare(
|
||||||
arroy::Writer::prepare(wtxn, index.vector_arroy, k.into(), expected_dimension)
|
wtxn,
|
||||||
|
index.vector_arroy,
|
||||||
|
writer_index | (k as u16),
|
||||||
|
expected_dimension,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let writers = writers?;
|
let writers = writers?;
|
||||||
@ -456,7 +465,7 @@ pub(crate) fn write_typed_chunk_into_index(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log::debug!("There are 🤷♀️ entries in the arroy so far");
|
log::debug!("Finished vector chunk for {}", embedder_name);
|
||||||
}
|
}
|
||||||
TypedChunk::ScriptLanguageDocids(sl_map) => {
|
TypedChunk::ScriptLanguageDocids(sl_map) => {
|
||||||
for (key, (deletion, addition)) in sl_map {
|
for (key, (deletion, addition)) in sl_map {
|
||||||
|
@ -431,7 +431,6 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||||||
let embedder = Arc::new(
|
let embedder = Arc::new(
|
||||||
Embedder::new(embedder_options.clone())
|
Embedder::new(embedder_options.clone())
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::UserError::from)
|
|
||||||
.map_err(crate::Error::from)?,
|
.map_err(crate::Error::from)?,
|
||||||
);
|
);
|
||||||
Ok((name, (embedder, prompt)))
|
Ok((name, (embedder, prompt)))
|
||||||
@ -976,6 +975,19 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||||||
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
|
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
self.index.embedder_category_id.clear(self.wtxn)?;
|
||||||
|
for (index, (embedder_name, _)) in new_configs.iter().enumerate() {
|
||||||
|
self.index.embedder_category_id.put_with_flags(
|
||||||
|
self.wtxn,
|
||||||
|
heed::PutFlags::APPEND,
|
||||||
|
embedder_name,
|
||||||
|
&index
|
||||||
|
.try_into()
|
||||||
|
.map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
if new_configs.is_empty() {
|
if new_configs.is_empty() {
|
||||||
self.index.delete_embedding_configs(self.wtxn)?;
|
self.index.delete_embedding_configs(self.wtxn)?;
|
||||||
} else {
|
} else {
|
||||||
@ -1062,7 +1074,7 @@ fn validate_prompt(
|
|||||||
match new {
|
match new {
|
||||||
Setting::Set(EmbeddingSettings {
|
Setting::Set(EmbeddingSettings {
|
||||||
embedder_options,
|
embedder_options,
|
||||||
prompt:
|
document_template:
|
||||||
Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }),
|
Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }),
|
||||||
}) => {
|
}) => {
|
||||||
// validate
|
// validate
|
||||||
@ -1072,7 +1084,7 @@ fn validate_prompt(
|
|||||||
|
|
||||||
Ok(Setting::Set(EmbeddingSettings {
|
Ok(Setting::Set(EmbeddingSettings {
|
||||||
embedder_options,
|
embedder_options,
|
||||||
prompt: Setting::Set(PromptSettings {
|
document_template: Setting::Set(PromptSettings {
|
||||||
template: Setting::Set(template),
|
template: Setting::Set(template),
|
||||||
strategy,
|
strategy,
|
||||||
fallback,
|
fallback,
|
||||||
|
@ -65,6 +65,8 @@ pub enum EmbedErrorKind {
|
|||||||
OpenAiTooManyTokens(OpenAiError),
|
OpenAiTooManyTokens(OpenAiError),
|
||||||
#[error("received unhandled HTTP status code {0} from OpenAI")]
|
#[error("received unhandled HTTP status code {0} from OpenAI")]
|
||||||
OpenAiUnhandledStatusCode(u16),
|
OpenAiUnhandledStatusCode(u16),
|
||||||
|
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
|
||||||
|
ManualEmbed(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedError {
|
impl EmbedError {
|
||||||
@ -111,6 +113,10 @@ impl EmbedError {
|
|||||||
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
|
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
|
||||||
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
|
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
|
||||||
|
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
@ -170,6 +176,13 @@ impl NewEmbedderError {
|
|||||||
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
|
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
|
||||||
|
Self {
|
||||||
|
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
||||||
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
||||||
}
|
}
|
||||||
@ -219,6 +232,8 @@ pub enum NewEmbedderErrorKind {
|
|||||||
NewApiFail(ApiError),
|
NewApiFail(ApiError),
|
||||||
#[error("fetching file from HG_HUB failed: {0}")]
|
#[error("fetching file from HG_HUB failed: {0}")]
|
||||||
ApiGet(ApiError),
|
ApiGet(ApiError),
|
||||||
|
#[error("could not determine model dimensions: test embedding failed with {0}")]
|
||||||
|
CouldNotDetermineDimension(EmbedError),
|
||||||
#[error("loading model failed: {0}")]
|
#[error("loading model failed: {0}")]
|
||||||
LoadModel(candle_core::Error),
|
LoadModel(candle_core::Error),
|
||||||
// openai
|
// openai
|
||||||
|
@ -62,6 +62,7 @@ pub struct Embedder {
|
|||||||
model: BertModel,
|
model: BertModel,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
options: EmbedderOptions,
|
options: EmbedderOptions,
|
||||||
|
dimensions: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for Embedder {
|
impl std::fmt::Debug for Embedder {
|
||||||
@ -126,10 +127,17 @@ impl Embedder {
|
|||||||
tokenizer.with_padding(Some(pp));
|
tokenizer.with_padding(Some(pp));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Self { model, tokenizer, options })
|
let mut this = Self { model, tokenizer, options, dimensions: 0 };
|
||||||
|
|
||||||
|
let embeddings = this
|
||||||
|
.embed(vec!["test".into()])
|
||||||
|
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
|
||||||
|
this.dimensions = embeddings.first().unwrap().dimension();
|
||||||
|
|
||||||
|
Ok(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed(
|
pub fn embed(
|
||||||
&self,
|
&self,
|
||||||
mut texts: Vec<String>,
|
mut texts: Vec<String>,
|
||||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
@ -170,12 +178,11 @@ impl Embedder {
|
|||||||
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
|
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
|
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
@ -185,6 +192,10 @@ impl Embedder {
|
|||||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||||
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
|
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dimensions(&self) -> usize {
|
||||||
|
self.dimensions
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
|
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
|
||||||
|
@ -3,6 +3,7 @@ use crate::prompt::PromptData;
|
|||||||
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod hf;
|
pub mod hf;
|
||||||
|
pub mod manual;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
pub mod settings;
|
pub mod settings;
|
||||||
|
|
||||||
@ -67,6 +68,7 @@ impl<F> Embeddings<F> {
|
|||||||
pub enum Embedder {
|
pub enum Embedder {
|
||||||
HuggingFace(hf::Embedder),
|
HuggingFace(hf::Embedder),
|
||||||
OpenAi(openai::Embedder),
|
OpenAi(openai::Embedder),
|
||||||
|
UserProvided(manual::Embedder),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
||||||
@ -80,6 +82,7 @@ pub struct EmbeddingConfig {
|
|||||||
pub enum EmbedderOptions {
|
pub enum EmbedderOptions {
|
||||||
HuggingFace(hf::EmbedderOptions),
|
HuggingFace(hf::EmbedderOptions),
|
||||||
OpenAi(openai::EmbedderOptions),
|
OpenAi(openai::EmbedderOptions),
|
||||||
|
UserProvided(manual::EmbedderOptions),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for EmbedderOptions {
|
impl Default for EmbedderOptions {
|
||||||
@ -93,7 +96,7 @@ impl EmbedderOptions {
|
|||||||
Self::HuggingFace(hf::EmbedderOptions::new())
|
Self::HuggingFace(hf::EmbedderOptions::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn openai(api_key: String) -> Self {
|
pub fn openai(api_key: Option<String>) -> Self {
|
||||||
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
|
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -103,6 +106,9 @@ impl Embedder {
|
|||||||
Ok(match options {
|
Ok(match options {
|
||||||
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
||||||
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
||||||
|
EmbedderOptions::UserProvided(options) => {
|
||||||
|
Self::UserProvided(manual::Embedder::new(options))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,8 +117,9 @@ impl Embedder {
|
|||||||
texts: Vec<String>,
|
texts: Vec<String>,
|
||||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed(texts).await,
|
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
|
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
|
||||||
|
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,8 +128,9 @@ impl Embedder {
|
|||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await,
|
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
|
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
|
||||||
|
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,6 +138,7 @@ impl Embedder {
|
|||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
|
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
|
||||||
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
||||||
|
Embedder::UserProvided(_) => 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,6 +146,15 @@ impl Embedder {
|
|||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
|
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||||
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||||
|
Embedder::UserProvided(_) => 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dimensions(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
Embedder::HuggingFace(embedder) => embedder.dimensions(),
|
||||||
|
Embedder::OpenAi(embedder) => embedder.dimensions(),
|
||||||
|
Embedder::UserProvided(embedder) => embedder.dimensions(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ pub struct Embedder {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
pub api_key: String,
|
pub api_key: Option<String>,
|
||||||
pub embedding_model: EmbeddingModel,
|
pub embedding_model: EmbeddingModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,11 +68,11 @@ impl EmbeddingModel {
|
|||||||
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
|
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
|
||||||
|
|
||||||
impl EmbedderOptions {
|
impl EmbedderOptions {
|
||||||
pub fn with_default_model(api_key: String) -> Self {
|
pub fn with_default_model(api_key: Option<String>) -> Self {
|
||||||
Self { api_key, embedding_model: Default::default() }
|
Self { api_key, embedding_model: Default::default() }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self {
|
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
|
||||||
Self { api_key, embedding_model }
|
Self { api_key, embedding_model }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -80,9 +80,14 @@ impl EmbedderOptions {
|
|||||||
impl Embedder {
|
impl Embedder {
|
||||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
let mut inferred_api_key = Default::default();
|
||||||
|
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
|
||||||
|
inferred_api_key = infer_api_key();
|
||||||
|
&inferred_api_key
|
||||||
|
});
|
||||||
headers.insert(
|
headers.insert(
|
||||||
reqwest::header::AUTHORIZATION,
|
reqwest::header::AUTHORIZATION,
|
||||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key))
|
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
|
||||||
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
|
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
|
||||||
);
|
);
|
||||||
headers.insert(
|
headers.insert(
|
||||||
@ -315,6 +320,10 @@ impl Embedder {
|
|||||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||||
10
|
10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dimensions(&self) -> usize {
|
||||||
|
self.options.embedding_model.dimensions()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrying in case of failure
|
// retrying in case of failure
|
||||||
@ -414,3 +423,9 @@ struct OpenAiEmbedding {
|
|||||||
// object: String,
|
// object: String,
|
||||||
// index: usize,
|
// index: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_api_key() -> String {
|
||||||
|
std::env::var("MEILI_OPENAI_API_KEY")
|
||||||
|
.or_else(|_| std::env::var("OPENAI_API_KEY"))
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
@ -15,14 +15,14 @@ pub struct EmbeddingSettings {
|
|||||||
pub embedder_options: Setting<EmbedderSettings>,
|
pub embedder_options: Setting<EmbedderSettings>,
|
||||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
#[deserr(default)]
|
#[deserr(default)]
|
||||||
pub prompt: Setting<PromptSettings>,
|
pub document_template: Setting<PromptSettings>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbeddingSettings {
|
impl EmbeddingSettings {
|
||||||
pub fn apply(&mut self, new: Self) {
|
pub fn apply(&mut self, new: Self) {
|
||||||
let EmbeddingSettings { embedder_options, prompt } = new;
|
let EmbeddingSettings { embedder_options, document_template: prompt } = new;
|
||||||
self.embedder_options.apply(embedder_options);
|
self.embedder_options.apply(embedder_options);
|
||||||
self.prompt.apply(prompt);
|
self.document_template.apply(prompt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
fn from(value: EmbeddingConfig) -> Self {
|
fn from(value: EmbeddingConfig) -> Self {
|
||||||
Self {
|
Self {
|
||||||
embedder_options: Setting::Set(value.embedder_options.into()),
|
embedder_options: Setting::Set(value.embedder_options.into()),
|
||||||
prompt: Setting::Set(value.prompt.into()),
|
document_template: Setting::Set(value.prompt.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -38,7 +38,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
impl From<EmbeddingSettings> for EmbeddingConfig {
|
impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||||
fn from(value: EmbeddingSettings) -> Self {
|
fn from(value: EmbeddingSettings) -> Self {
|
||||||
let mut this = Self::default();
|
let mut this = Self::default();
|
||||||
let EmbeddingSettings { embedder_options, prompt } = value;
|
let EmbeddingSettings { embedder_options, document_template: prompt } = value;
|
||||||
if let Some(embedder_options) = embedder_options.set() {
|
if let Some(embedder_options) = embedder_options.set() {
|
||||||
this.embedder_options = embedder_options.into();
|
this.embedder_options = embedder_options.into();
|
||||||
}
|
}
|
||||||
@ -105,6 +105,7 @@ impl From<PromptSettings> for PromptData {
|
|||||||
pub enum EmbedderSettings {
|
pub enum EmbedderSettings {
|
||||||
HuggingFace(Setting<HfEmbedderSettings>),
|
HuggingFace(Setting<HfEmbedderSettings>),
|
||||||
OpenAi(Setting<OpenAiEmbedderSettings>),
|
OpenAi(Setting<OpenAiEmbedderSettings>),
|
||||||
|
UserProvided(UserProvidedSettings),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E> Deserr<E> for EmbedderSettings
|
impl<E> Deserr<E> for EmbedderSettings
|
||||||
@ -145,11 +146,17 @@ where
|
|||||||
location.push_key(&k),
|
location.push_key(&k),
|
||||||
)?,
|
)?,
|
||||||
))),
|
))),
|
||||||
|
"userProvided" => Ok(EmbedderSettings::UserProvided(
|
||||||
|
UserProvidedSettings::deserialize_from_value(
|
||||||
|
v.into_value(),
|
||||||
|
location.push_key(&k),
|
||||||
|
)?,
|
||||||
|
)),
|
||||||
other => Err(deserr::take_cf_content(E::error::<V>(
|
other => Err(deserr::take_cf_content(E::error::<V>(
|
||||||
None,
|
None,
|
||||||
deserr::ErrorKind::UnknownKey {
|
deserr::ErrorKind::UnknownKey {
|
||||||
key: other,
|
key: other,
|
||||||
accepted: &["huggingFace", "openAi"],
|
accepted: &["huggingFace", "openAi", "userProvided"],
|
||||||
},
|
},
|
||||||
location,
|
location,
|
||||||
))),
|
))),
|
||||||
@ -182,6 +189,9 @@ impl From<crate::vector::EmbedderOptions> for EmbedderSettings {
|
|||||||
crate::vector::EmbedderOptions::OpenAi(openai) => {
|
crate::vector::EmbedderOptions::OpenAi(openai) => {
|
||||||
Self::OpenAi(Setting::Set(openai.into()))
|
Self::OpenAi(Setting::Set(openai.into()))
|
||||||
}
|
}
|
||||||
|
crate::vector::EmbedderOptions::UserProvided(user_provided) => {
|
||||||
|
Self::UserProvided(user_provided.into())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -192,9 +202,12 @@ impl From<EmbedderSettings> for crate::vector::EmbedderOptions {
|
|||||||
EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()),
|
EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()),
|
||||||
EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()),
|
EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()),
|
||||||
EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()),
|
EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()),
|
||||||
EmbedderSettings::OpenAi(_setting) => Self::OpenAi(
|
EmbedderSettings::OpenAi(_setting) => {
|
||||||
crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()),
|
Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None))
|
||||||
),
|
}
|
||||||
|
EmbedderSettings::UserProvided(user_provided) => {
|
||||||
|
Self::UserProvided(user_provided.into())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -286,7 +299,7 @@ impl OpenAiEmbedderSettings {
|
|||||||
impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
|
impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
|
||||||
fn from(value: crate::vector::openai::EmbedderOptions) -> Self {
|
fn from(value: crate::vector::openai::EmbedderOptions) -> Self {
|
||||||
Self {
|
Self {
|
||||||
api_key: Setting::Set(value.api_key),
|
api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset),
|
||||||
embedding_model: Setting::Set(value.embedding_model),
|
embedding_model: Setting::Set(value.embedding_model),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -295,14 +308,25 @@ impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
|
|||||||
impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions {
|
impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions {
|
||||||
fn from(value: OpenAiEmbedderSettings) -> Self {
|
fn from(value: OpenAiEmbedderSettings) -> Self {
|
||||||
let OpenAiEmbedderSettings { api_key, embedding_model } = value;
|
let OpenAiEmbedderSettings { api_key, embedding_model } = value;
|
||||||
Self {
|
Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() }
|
||||||
api_key: api_key.set().unwrap_or_else(infer_api_key),
|
|
||||||
embedding_model: embedding_model.set().unwrap_or_default(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_api_key() -> String {
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
||||||
/// FIXME: get key from instance options?
|
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||||
std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default()
|
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||||
|
pub struct UserProvidedSettings {
|
||||||
|
pub dimensions: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<UserProvidedSettings> for crate::vector::manual::EmbedderOptions {
|
||||||
|
fn from(value: UserProvidedSettings) -> Self {
|
||||||
|
Self { dimensions: value.dimensions }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<crate::vector::manual::EmbedderOptions> for UserProvidedSettings {
|
||||||
|
fn from(value: crate::vector::manual::EmbedderOptions) -> Self {
|
||||||
|
Self { dimensions: value.dimensions }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user