Move crates under a sub folder to clean up the code

This commit is contained in:
Clément Renault 2024-10-21 08:18:43 +02:00
parent 30f3c30389
commit 9c1e54a2c8
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
1062 changed files with 19 additions and 20 deletions

View file

@ -0,0 +1,435 @@
use std::collections::BTreeMap;
use std::path::PathBuf;
use hf_hub::api::sync::ApiError;
use super::parsed_vectors::ParsedVectorsDiff;
use super::rest::ConfigurationSource;
use crate::error::FaultSource;
use crate::{FieldDistribution, PanicCatched};
#[derive(Debug, thiserror::Error)]
#[error("Error while generating embeddings: {inner}")]
pub struct Error {
pub inner: Box<ErrorKind>,
}
impl<I: Into<ErrorKind>> From<I> for Error {
fn from(value: I) -> Self {
Self { inner: Box::new(value.into()) }
}
}
impl Error {
pub fn fault(&self) -> FaultSource {
match &*self.inner {
ErrorKind::NewEmbedderError(inner) => inner.fault,
ErrorKind::EmbedError(inner) => inner.fault,
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ErrorKind {
#[error(transparent)]
NewEmbedderError(#[from] NewEmbedderError),
#[error(transparent)]
EmbedError(#[from] EmbedError),
}
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct EmbedError {
pub kind: EmbedErrorKind,
pub fault: FaultSource,
}
#[derive(Debug, thiserror::Error)]
pub enum EmbedErrorKind {
#[error("could not tokenize:\n - {0}")]
Tokenize(Box<dyn std::error::Error + Send + Sync>),
#[error("unexpected tensor shape:\n - {0}")]
TensorShape(candle_core::Error),
#[error("unexpected tensor value:\n - {0}")]
TensorValue(candle_core::Error),
#[error("could not run model:\n - {0}")]
ModelForward(candle_core::Error),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided:\n - `{0}`")]
ManualEmbed(String),
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually{}", option_info(.0.as_deref(), "server replied with "))]
OllamaModelNotFoundError(Option<String>),
#[error("error deserialization the response body as JSON:\n - {0}")]
RestResponseDeserialization(std::io::Error),
#[error("expected a response containing {0} embeddings, got only {1}")]
RestResponseEmbeddingCount(usize, usize),
#[error("could not authenticate against {embedding} server{server_reply}{hint}", embedding=match *.1 {
ConfigurationSource::User => "embedding",
ConfigurationSource::OpenAi => "OpenAI",
ConfigurationSource::Ollama => "ollama"
},
server_reply=option_info(.0.as_deref(), "server replied with "),
hint=match *.1 {
ConfigurationSource::User => "\n - Hint: Check the `apiKey` parameter in the embedder configuration",
ConfigurationSource::OpenAi => "\n - Hint: Check the `apiKey` parameter in the embedder configuration, and the `MEILI_OPENAI_API_KEY` and `OPENAI_API_KEY` environment variables",
ConfigurationSource::Ollama => "\n - Hint: Check the `apiKey` parameter in the embedder configuration"
})]
RestUnauthorized(Option<String>, ConfigurationSource),
#[error("sent too many requests to embedding server{}", option_info(.0.as_deref(), "server replied with "))]
RestTooManyRequests(Option<String>),
#[error("sent a bad request to embedding server{}{}",
if ConfigurationSource::User == *.1 {
"\n - Hint: check that the `request` in the embedder configuration matches the remote server's API"
} else {
""
},
option_info(.0.as_deref(), "server replied with "))]
RestBadRequest(Option<String>, ConfigurationSource),
#[error("received internal error HTTP {0} from embedding server{}", option_info(.1.as_deref(), "server replied with "))]
RestInternalServerError(u16, Option<String>),
#[error("received unexpected HTTP {0} from embedding server{}", option_info(.1.as_deref(), "server replied with "))]
RestOtherStatusCode(u16, Option<String>),
#[error("could not reach embedding server:\n - {0}")]
RestNetwork(ureq::Transport),
#[error("error extracting embeddings from the response:\n - {0}")]
RestExtractionError(String),
#[error("was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")]
UnexpectedDimension(usize, usize),
#[error("no embedding was produced")]
MissingEmbedding,
#[error(transparent)]
PanicInThreadPool(#[from] PanicCatched),
}
fn option_info(info: Option<&str>, prefix: &str) -> String {
match info {
Some(info) => format!("\n - {prefix}`{info}`"),
None => String::new(),
}
}
impl EmbedError {
pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self {
Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime }
}
pub fn tensor_shape(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug }
}
pub fn tensor_value(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug }
}
pub fn model_forward(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
}
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
}
pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
}
pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
Self {
kind: EmbedErrorKind::RestResponseDeserialization(error),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
Self {
kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_unauthorized(
error_response: Option<String>,
configuration_source: ConfigurationSource,
) -> EmbedError {
Self {
kind: EmbedErrorKind::RestUnauthorized(error_response, configuration_source),
fault: FaultSource::User,
}
}
pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestTooManyRequests(error_response),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_bad_request(
error_response: Option<String>,
configuration_source: ConfigurationSource,
) -> EmbedError {
Self {
kind: EmbedErrorKind::RestBadRequest(error_response, configuration_source),
fault: FaultSource::User,
}
}
pub(crate) fn rest_internal_server_error(
code: u16,
error_response: Option<String>,
) -> EmbedError {
Self {
kind: EmbedErrorKind::RestInternalServerError(code, error_response),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
fault: FaultSource::Undecided,
}
}
pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
}
pub(crate) fn rest_unexpected_dimension(expected: usize, got: usize) -> EmbedError {
Self {
kind: EmbedErrorKind::UnexpectedDimension(expected, got),
fault: FaultSource::Runtime,
}
}
pub(crate) fn missing_embedding() -> EmbedError {
Self { kind: EmbedErrorKind::MissingEmbedding, fault: FaultSource::Undecided }
}
pub(crate) fn rest_extraction_error(error: String) -> EmbedError {
Self { kind: EmbedErrorKind::RestExtractionError(error), fault: FaultSource::Runtime }
}
}
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct NewEmbedderError {
pub kind: NewEmbedderErrorKind,
pub fault: FaultSource,
}
impl NewEmbedderError {
pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError {
let open_config = OpenConfig { filename: config_filename, inner };
Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime }
}
pub fn deserialize_config(
model_name: String,
config: String,
config_filename: PathBuf,
inner: serde_json::Error,
) -> NewEmbedderError {
match serde_json::from_str(&config) {
Ok(value) => {
let value: serde_json::Value = value;
let architectures = match value.get("architectures") {
Some(serde_json::Value::Array(architectures)) => architectures
.iter()
.filter_map(|value| match value {
serde_json::Value::String(s) => Some(s.to_owned()),
_ => None,
})
.collect(),
_ => vec![],
};
let unsupported_model = UnsupportedModel { model_name, inner, architectures };
Self {
kind: NewEmbedderErrorKind::UnsupportedModel(unsupported_model),
fault: FaultSource::User,
}
}
Err(error) => {
let deserialize_config =
DeserializeConfig { model_name, filename: config_filename, inner: error };
Self {
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
fault: FaultSource::Runtime,
}
}
}
}
pub fn open_tokenizer(
tokenizer_filename: PathBuf,
inner: Box<dyn std::error::Error + Send + Sync>,
) -> NewEmbedderError {
let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner };
Self {
kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer),
fault: FaultSource::Runtime,
}
}
pub fn new_api_fail(inner: ApiError) -> Self {
Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug }
}
pub fn api_get(inner: ApiError) -> Self {
Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided }
}
pub fn pytorch_weight(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
}
pub fn safetensor_weight(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::SafetensorWeight(inner), fault: FaultSource::Runtime }
}
pub fn load_model(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
}
pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_could_not_parse_template(message: String) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotParseTemplate(message),
fault: FaultSource::User,
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("could not open config at {filename}: {inner}")]
pub struct OpenConfig {
pub filename: PathBuf,
pub inner: std::io::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
pub struct DeserializeConfig {
pub model_name: String,
pub filename: PathBuf,
pub inner: serde_json::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
if architectures.is_empty() {
"\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
} else {
format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
})]
pub struct UnsupportedModel {
pub model_name: String,
pub inner: serde_json::Error,
pub architectures: Vec<String>,
}
#[derive(Debug, thiserror::Error)]
#[error("could not open tokenizer at {filename}: {inner}")]
pub struct OpenTokenizer {
pub filename: PathBuf,
#[source]
pub inner: Box<dyn std::error::Error + Send + Sync>,
}
#[derive(Debug, thiserror::Error)]
pub enum NewEmbedderErrorKind {
// hf
#[error(transparent)]
OpenConfig(OpenConfig),
#[error(transparent)]
DeserializeConfig(DeserializeConfig),
#[error(transparent)]
UnsupportedModel(UnsupportedModel),
#[error(transparent)]
OpenTokenizer(OpenTokenizer),
#[error("could not build weights from Pytorch weights:\n - {0}")]
PytorchWeight(candle_core::Error),
#[error("could not build weights from Safetensor weights:\n - {0}")]
SafetensorWeight(candle_core::Error),
#[error("could not spawn HG_HUB API client:\n - {0}")]
NewApiFail(ApiError),
#[error("fetching file from HG_HUB failed:\n - {0}")]
ApiGet(ApiError),
#[error("could not determine model dimensions:\n - test embedding failed with {0}")]
CouldNotDetermineDimension(EmbedError),
#[error("loading model failed:\n - {0}")]
LoadModel(candle_core::Error),
#[error("{0}")]
CouldNotParseTemplate(String),
}
pub struct PossibleEmbeddingMistakes {
vectors_mistakes: BTreeMap<String, u64>,
}
impl PossibleEmbeddingMistakes {
pub fn new(field_distribution: &FieldDistribution) -> Self {
let mut vectors_mistakes = BTreeMap::new();
let builder = levenshtein_automata::LevenshteinAutomatonBuilder::new(2, true);
let automata = builder.build_dfa("_vectors");
for (field, count) in field_distribution {
if *count == 0 {
continue;
}
if field.contains('.') {
continue;
}
match automata.eval(field) {
levenshtein_automata::Distance::Exact(0) => continue,
levenshtein_automata::Distance::Exact(_) => {
vectors_mistakes.insert(field.to_string(), *count);
}
levenshtein_automata::Distance::AtLeast(_) => continue,
}
}
Self { vectors_mistakes }
}
pub fn vector_mistakes(&self) -> impl Iterator<Item = (&str, u64)> {
self.vectors_mistakes.iter().map(|(misspelling, count)| (misspelling.as_str(), *count))
}
pub fn embedder_mistakes<'a>(
&'a self,
embedder_name: &'a str,
unused_vectors_distributions: &'a UnusedVectorsDistribution,
) -> impl Iterator<Item = (&'a str, u64)> + 'a {
let builder = levenshtein_automata::LevenshteinAutomatonBuilder::new(2, true);
let automata = builder.build_dfa(embedder_name);
unused_vectors_distributions.0.iter().filter_map(move |(field, count)| {
match automata.eval(field) {
levenshtein_automata::Distance::Exact(0) => None,
levenshtein_automata::Distance::Exact(_) => Some((field.as_str(), *count)),
levenshtein_automata::Distance::AtLeast(_) => None,
}
})
}
}
#[derive(Default)]
pub struct UnusedVectorsDistribution(BTreeMap<String, u64>);
impl UnusedVectorsDistribution {
pub fn new() -> Self {
Self::default()
}
pub fn append(&mut self, parsed_vectors_diff: ParsedVectorsDiff) {
for name in parsed_vectors_diff.into_new_vectors_keys_iter() {
*self.0.entry(name).or_default() += 1;
}
}
}

View file

@ -0,0 +1,214 @@
use candle_core::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
pub use super::error::{EmbedError, Error, NewEmbedderError};
use super::{DistributionShift, Embedding, Embeddings};
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
PartialEq,
Eq,
serde::Deserialize,
serde::Serialize,
deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
enum WeightSource {
#[default]
Safetensors,
Pytorch,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub model: String,
pub revision: Option<String>,
pub distribution: Option<DistributionShift>,
}
impl EmbedderOptions {
pub fn new() -> Self {
Self {
model: "BAAI/bge-base-en-v1.5".to_string(),
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
distribution: None,
}
}
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::new()
}
}
/// Perform embedding of documents and queries
pub struct Embedder {
model: BertModel,
tokenizer: Tokenizer,
options: EmbedderOptions,
dimensions: usize,
}
impl std::fmt::Debug for Embedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Embedder")
.field("model", &self.options.model)
.field("tokenizer", &self.tokenizer)
.field("options", &self.options)
.finish()
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
let device = match candle_core::Device::cuda_if_available(0) {
Ok(device) => device,
Err(error) => {
tracing::warn!("could not initialize CUDA device for Hugging Face embedder, defaulting to CPU: {}", error);
candle_core::Device::Cpu
}
};
let repo = match options.revision.clone() {
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
None => Repo::model(options.model.clone()),
};
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
let api = api.repo(repo);
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
let (weights, source) = {
api.get("model.safetensors")
.map(|filename| (filename, WeightSource::Safetensors))
.or_else(|_| {
api.get("pytorch_model.bin")
.map(|filename| (filename, WeightSource::Pytorch))
})
.map_err(NewEmbedderError::api_get)?
};
(config, tokenizer, weights, source)
};
let config = std::fs::read_to_string(&config_filename)
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
let config: Config = serde_json::from_str(&config).map_err(|inner| {
NewEmbedderError::deserialize_config(
options.model.clone(),
config,
config_filename,
inner,
)
})?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
let vb = match weight_source {
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
.map_err(NewEmbedderError::pytorch_weight)?,
WeightSource::Safetensors => unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
.map_err(NewEmbedderError::safetensor_weight)?
},
};
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let mut this = Self { model, tokenizer, options, dimensions: 0 };
let embeddings = this
.embed(vec!["test".into()])
.map_err(NewEmbedderError::could_not_determine_dimension)?;
this.dimensions = embeddings.first().unwrap().dimension();
Ok(this)
}
pub fn embed(
&self,
mut texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
let tokens = match texts.len() {
1 => vec![self
.tokenizer
.encode(texts.pop().unwrap(), true)
.map_err(EmbedError::tokenize)?],
_ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?,
};
let token_ids = tokens
.iter()
.map(|tokens| {
let mut tokens = tokens.get_ids().to_vec();
tokens.truncate(512);
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
})
.collect::<Result<Vec<_>, EmbedError>>()?;
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
let embeddings =
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
.map_err(EmbedError::tensor_shape)?;
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
pub fn chunk_count_hint(&self) -> usize {
1
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.distribution.or_else(|| {
if self.options.model == "BAAI/bge-base-en-v1.5" {
Some(DistributionShift {
current_mean: ordered_float::OrderedFloat(0.85),
current_sigma: ordered_float::OrderedFloat(0.1),
})
} else {
None
}
})
}
}

View file

@ -0,0 +1,970 @@
//! Module to manipulate JSON templates.
//!
//! This module allows two main operations:
//! 1. Render JSON values from a template and a context value.
//! 2. Retrieve data from a template and JSON values.
#![warn(rustdoc::broken_intra_doc_links)]
#![warn(missing_docs)]
use serde::Deserialize;
use serde_json::{Map, Value};
type ValuePath = Vec<PathComponent>;
/// Encapsulates a JSON template and allows injecting and extracting values from it.
#[derive(Debug)]
pub struct ValueTemplate {
template: Value,
value_kind: ValueKind,
}
#[derive(Debug)]
enum ValueKind {
Single(ValuePath),
Array(ArrayPath),
}
#[derive(Debug)]
struct ArrayPath {
repeated_value: Value,
path_to_array: ValuePath,
value_path_in_array: ValuePath,
}
/// Component of a path to a Value
#[derive(Debug, Clone)]
pub enum PathComponent {
/// A key inside of an object
MapKey(String),
/// An index inside of an array
ArrayIndex(usize),
}
impl PartialEq for PathComponent {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::MapKey(l0), Self::MapKey(r0)) => l0 == r0,
(Self::ArrayIndex(l0), Self::ArrayIndex(r0)) => l0 == r0,
_ => false,
}
}
}
impl Eq for PathComponent {}
/// Error that occurs when no few value was provided to a template for injection.
#[derive(Debug)]
pub struct MissingValue;
/// Error that occurs when trying to parse a template in [`ValueTemplate::new`]
#[derive(Debug)]
pub enum TemplateParsingError {
/// A repeat string appears inside a repeated value
NestedRepeatString(ValuePath),
/// A repeat string appears outside of an array
RepeatStringNotInArray(ValuePath),
/// A repeat string appears in an array, but not in the second position
BadIndexForRepeatString(ValuePath, usize),
/// A repeated value lacks a placeholder
MissingPlaceholderInRepeatedValue(ValuePath),
/// Multiple repeat string appear in the template
MultipleRepeatString(ValuePath, ValuePath),
/// Multiple placeholder strings appear in the template
MultiplePlaceholderString(ValuePath, ValuePath),
/// No placeholder string appear in the template
MissingPlaceholderString,
/// A placeholder appears both inside a repeated value and outside of it
BothArrayAndSingle {
/// Path to the single value
single_path: ValuePath,
/// Path to the array of repeated values
path_to_array: ValuePath,
/// Path to placeholder inside each repeated value, starting from the array
array_to_placeholder: ValuePath,
},
}
impl TemplateParsingError {
/// Produce an error message from the error kind, the name of the root object, the placeholder string and the repeat string
pub fn error_message(&self, root: &str, placeholder: &str, repeat: &str) -> String {
match self {
TemplateParsingError::NestedRepeatString(path) => {
format!(
r#"in {}: "{repeat}" appears nested inside of a value that is itself repeated"#,
path_with_root(root, path)
)
}
TemplateParsingError::RepeatStringNotInArray(path) => format!(
r#"in {}: "{repeat}" appears outside of an array"#,
path_with_root(root, path)
),
TemplateParsingError::BadIndexForRepeatString(path, index) => format!(
r#"in {}: "{repeat}" expected at position #1, but found at position #{index}"#,
path_with_root(root, path)
),
TemplateParsingError::MissingPlaceholderInRepeatedValue(path) => format!(
r#"in {}: Expected "{placeholder}" inside of the repeated value"#,
path_with_root(root, path)
),
TemplateParsingError::MultipleRepeatString(current, previous) => format!(
r#"in {}: Found "{repeat}", but it was already present in {}"#,
path_with_root(root, current),
path_with_root(root, previous)
),
TemplateParsingError::MultiplePlaceholderString(current, previous) => format!(
r#"in {}: Found "{placeholder}", but it was already present in {}"#,
path_with_root(root, current),
path_with_root(root, previous)
),
TemplateParsingError::MissingPlaceholderString => {
format!(r#"in `{root}`: "{placeholder}" not found"#)
}
TemplateParsingError::BothArrayAndSingle {
single_path,
path_to_array,
array_to_placeholder,
} => {
let path_to_first_repeated = path_to_array
.iter()
.chain(std::iter::once(&PathComponent::ArrayIndex(0)))
.chain(array_to_placeholder.iter());
format!(
r#"in {}: Found "{placeholder}", but it was already present in {} (repeated)"#,
path_with_root(root, single_path),
path_with_root(root, path_to_first_repeated)
)
}
}
}
fn prepend_path(self, mut prepended_path: ValuePath) -> Self {
match self {
TemplateParsingError::NestedRepeatString(mut path) => {
prepended_path.append(&mut path);
TemplateParsingError::NestedRepeatString(prepended_path)
}
TemplateParsingError::RepeatStringNotInArray(mut path) => {
prepended_path.append(&mut path);
TemplateParsingError::RepeatStringNotInArray(prepended_path)
}
TemplateParsingError::BadIndexForRepeatString(mut path, index) => {
prepended_path.append(&mut path);
TemplateParsingError::BadIndexForRepeatString(prepended_path, index)
}
TemplateParsingError::MissingPlaceholderInRepeatedValue(mut path) => {
prepended_path.append(&mut path);
TemplateParsingError::MissingPlaceholderInRepeatedValue(prepended_path)
}
TemplateParsingError::MultipleRepeatString(mut path, older_path) => {
let older_prepended_path =
prepended_path.iter().cloned().chain(older_path).collect();
prepended_path.append(&mut path);
TemplateParsingError::MultipleRepeatString(prepended_path, older_prepended_path)
}
TemplateParsingError::MultiplePlaceholderString(mut path, older_path) => {
let older_prepended_path =
prepended_path.iter().cloned().chain(older_path).collect();
prepended_path.append(&mut path);
TemplateParsingError::MultiplePlaceholderString(
prepended_path,
older_prepended_path,
)
}
TemplateParsingError::MissingPlaceholderString => {
TemplateParsingError::MissingPlaceholderString
}
TemplateParsingError::BothArrayAndSingle {
single_path,
mut path_to_array,
array_to_placeholder,
} => {
// note, this case is not super logical, but is also likely to be dead code
let single_prepended_path =
prepended_path.iter().cloned().chain(single_path).collect();
prepended_path.append(&mut path_to_array);
// we don't prepend the array_to_placeholder path as it is the array path that is prepended
TemplateParsingError::BothArrayAndSingle {
single_path: single_prepended_path,
path_to_array: prepended_path,
array_to_placeholder,
}
}
}
}
}
/// Error that occurs when [`ValueTemplate::extract`] fails.
#[derive(Debug)]
pub struct ExtractionError {
/// The cause of the failure
pub kind: ExtractionErrorKind,
/// The context where the failure happened: the operation that failed
pub context: ExtractionErrorContext,
}
impl ExtractionError {
/// Produce an error message from the error, the name of the root object, the placeholder string and the expected value type
pub fn error_message(
&self,
root: &str,
placeholder: &str,
expected_value_type: &str,
) -> String {
let context = match &self.context {
ExtractionErrorContext::ExtractingSingleValue => {
format!(r#"extracting a single "{placeholder}""#)
}
ExtractionErrorContext::FindingPathToArray => {
format!(r#"extracting the array of "{placeholder}"s"#)
}
ExtractionErrorContext::ExtractingArrayItem(index) => {
format!(r#"extracting item #{index} from the array of "{placeholder}"s"#)
}
};
match &self.kind {
ExtractionErrorKind::MissingPathComponent { missing_index, path, key_suggestion } => {
let last_named_object = last_named_object(root, path.iter().take(*missing_index));
format!(
"in {}, while {context}, configuration expects {}, which is missing in response{}",
path_with_root(root, path.iter().take(*missing_index)),
missing_component(path.get(*missing_index)),
match key_suggestion {
Some(key_suggestion) => format!("\n - Hint: {last_named_object} has key `{key_suggestion}`, did you mean {} in embedder configuration?",
path_with_root(root, path.iter().take(*missing_index).chain(std::iter::once(&PathComponent::MapKey(key_suggestion.to_owned()))))),
None => "".to_owned(),
}
)
}
ExtractionErrorKind::WrongPathComponent { wrong_component, index, path } => {
let last_named_object = last_named_object(root, path.iter().take(*index));
format!(
"in {}, while {context}, configuration expects {last_named_object} to be {} but server sent {wrong_component}",
path_with_root(root, path.iter().take(*index)),
expected_component(path.get(*index))
)
}
ExtractionErrorKind::DeserializationError { error, path } => {
let last_named_object = last_named_object(root, path);
format!(
"in {}, while {context}, expected {last_named_object} to be {expected_value_type}, but failed to parse server response:\n - {error}",
path_with_root(root, path)
)
}
}
}
}
fn missing_component(component: Option<&PathComponent>) -> String {
match component {
Some(PathComponent::ArrayIndex(index)) => {
format!(r#"item #{index}"#)
}
Some(PathComponent::MapKey(key)) => {
format!(r#"key "{key}""#)
}
None => "unknown".to_string(),
}
}
fn expected_component(component: Option<&PathComponent>) -> String {
match component {
Some(PathComponent::ArrayIndex(index)) => {
format!(r#"an array with at least {} item(s)"#, index.saturating_add(1))
}
Some(PathComponent::MapKey(key)) => {
format!("an object with key `{}`", key)
}
None => "unknown".to_string(),
}
}
fn last_named_object<'a>(
root: &'a str,
path: impl IntoIterator<Item = &'a PathComponent> + 'a,
) -> LastNamedObject<'a> {
let mut last_named_object = LastNamedObject::Object { name: root };
for component in path.into_iter() {
last_named_object = match (component, last_named_object) {
(PathComponent::MapKey(name), _) => LastNamedObject::Object { name },
(PathComponent::ArrayIndex(index), LastNamedObject::Object { name }) => {
LastNamedObject::ArrayInsideObject { object_name: name, index: *index }
}
(
PathComponent::ArrayIndex(index),
LastNamedObject::ArrayInsideObject { object_name, index: _ },
) => LastNamedObject::NestedArrayInsideObject {
object_name,
index: *index,
nesting_level: 0,
},
(
PathComponent::ArrayIndex(index),
LastNamedObject::NestedArrayInsideObject { object_name, index: _, nesting_level },
) => LastNamedObject::NestedArrayInsideObject {
object_name,
index: *index,
nesting_level: nesting_level.saturating_add(1),
},
}
}
last_named_object
}
impl<'a> std::fmt::Display for LastNamedObject<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LastNamedObject::Object { name } => write!(f, "`{name}`"),
LastNamedObject::ArrayInsideObject { object_name, index } => {
write!(f, "item #{index} inside `{object_name}`")
}
LastNamedObject::NestedArrayInsideObject { object_name, index, nesting_level } => {
if *nesting_level == 0 {
write!(f, "item #{index} inside nested array in `{object_name}`")
} else {
write!(f, "item #{index} inside nested array ({} levels of nesting) in `{object_name}`", nesting_level + 1)
}
}
}
}
}
#[derive(Debug, Clone, Copy)]
enum LastNamedObject<'a> {
Object { name: &'a str },
ArrayInsideObject { object_name: &'a str, index: usize },
NestedArrayInsideObject { object_name: &'a str, index: usize, nesting_level: usize },
}
/// Builds a string representation of a path, preprending the name of the root value.
pub fn path_with_root<'a>(
root: &str,
path: impl IntoIterator<Item = &'a PathComponent> + 'a,
) -> String {
use std::fmt::Write as _;
let mut res = format!("`{root}");
for component in path.into_iter() {
match component {
PathComponent::MapKey(key) => {
let _ = write!(&mut res, ".{key}");
}
PathComponent::ArrayIndex(index) => {
let _ = write!(&mut res, "[{index}]");
}
}
}
res.push('`');
res
}
/// Context where an extraction failure happened
///
/// The operation that failed
#[derive(Debug, Clone, Copy)]
pub enum ExtractionErrorContext {
/// Failure happened while extracting a value at a single location
ExtractingSingleValue,
/// Failure happened while extracting an array of values
FindingPathToArray,
/// Failure happened while extracting a value inside of an array
ExtractingArrayItem(usize),
}
/// Kind of errors that can happen during extraction
#[derive(Debug)]
pub enum ExtractionErrorKind {
/// An expected path component is missing
MissingPathComponent {
/// Index of the missing component in the path
missing_index: usize,
/// Path where a component is missing
path: ValuePath,
/// Possible matching key in object
key_suggestion: Option<String>,
},
/// An expected path component cannot be found because its container is the wrong type
WrongPathComponent {
/// String representation of the wrong component
wrong_component: String,
/// Index of the wrong component in the path
index: usize,
/// Path where a component has the wrong type
path: ValuePath,
},
/// Could not deserialize an extracted value to its requested type
DeserializationError {
/// inner deserialization error
error: serde_json::Error,
/// path to extracted value
path: ValuePath,
},
}
enum ArrayParsingContext<'a> {
Nested,
NotNested(&'a mut Option<ArrayPath>),
}
impl ValueTemplate {
/// Prepare a template for injection or extraction.
///
/// # Parameters
///
/// - `template`: JSON value that acts a template. Its placeholder values will be replaced by actual values during injection,
/// and actual values will be recovered from their location during extraction.
/// - `placeholder_string`: Value that a JSON string should assume to act as a placeholder value that can be injected into or
/// extracted from.
/// - `repeat_string`: Sentinel value that can be placed as the second value in an array to indicate that the first value can be repeated
/// any number of times. The first value should contain exactly one placeholder string.
///
/// # Errors
///
/// - [`TemplateParsingError`]: refer to the documentation of this type
pub fn new(
template: Value,
placeholder_string: &str,
repeat_string: &str,
) -> Result<Self, TemplateParsingError> {
let mut value_path = None;
let mut array_path = None;
let mut current_path = Vec::new();
Self::parse_value(
&template,
placeholder_string,
repeat_string,
&mut value_path,
&mut ArrayParsingContext::NotNested(&mut array_path),
&mut current_path,
)?;
let value_kind = match (array_path, value_path) {
(None, None) => return Err(TemplateParsingError::MissingPlaceholderString),
(None, Some(value_path)) => ValueKind::Single(value_path),
(Some(array_path), None) => ValueKind::Array(array_path),
(Some(array_path), Some(value_path)) => {
return Err(TemplateParsingError::BothArrayAndSingle {
single_path: value_path,
path_to_array: array_path.path_to_array,
array_to_placeholder: array_path.value_path_in_array,
})
}
};
Ok(Self { template, value_kind })
}
/// Whether there is a placeholder that can be repeated.
///
/// - During injection, all values are injected in the array placeholder,
/// - During extraction, all repeatable placeholders are extracted from the array.
pub fn has_array_value(&self) -> bool {
matches!(self.value_kind, ValueKind::Array(_))
}
/// Render a value from the template and context values.
///
/// # Error
///
/// - [`MissingValue`]: if the number of injected values is 0.
pub fn inject(&self, values: impl IntoIterator<Item = Value>) -> Result<Value, MissingValue> {
let mut rendered = self.template.clone();
let mut values = values.into_iter();
match &self.value_kind {
ValueKind::Single(injection_path) => {
let Some(injected_value) = values.next() else { return Err(MissingValue) };
inject_value(&mut rendered, injection_path, injected_value);
}
ValueKind::Array(ArrayPath { repeated_value, path_to_array, value_path_in_array }) => {
// 1. build the array of repeated values
let mut array = Vec::new();
for injected_value in values {
let mut repeated_value = repeated_value.clone();
inject_value(&mut repeated_value, value_path_in_array, injected_value);
array.push(repeated_value);
}
if array.is_empty() {
return Err(MissingValue);
}
// 2. inject at the injection point in the rendered value
inject_value(&mut rendered, path_to_array, Value::Array(array));
}
}
Ok(rendered)
}
/// Extract sub values from the template and a value.
///
/// # Errors
///
/// - if a single placeholder is missing.
/// - if there is no value corresponding to an array placeholder
/// - if the value corresponding to an array placeholder is not an array
pub fn extract<T>(&self, mut value: Value) -> Result<Vec<T>, ExtractionError>
where
T: for<'de> Deserialize<'de>,
{
Ok(match &self.value_kind {
ValueKind::Single(extraction_path) => {
let extracted_value =
extract_value(extraction_path, &mut value).with_context(|kind| {
ExtractionError {
kind,
context: ExtractionErrorContext::ExtractingSingleValue,
}
})?;
vec![extracted_value]
}
ValueKind::Array(ArrayPath {
repeated_value: _,
path_to_array,
value_path_in_array,
}) => {
// get the array
let array = extract_value(path_to_array, &mut value).with_context(|kind| {
ExtractionError { kind, context: ExtractionErrorContext::FindingPathToArray }
})?;
let array = match array {
Value::Array(array) => array,
not_array => {
let mut path = path_to_array.clone();
path.push(PathComponent::ArrayIndex(0));
return Err(ExtractionError {
kind: ExtractionErrorKind::WrongPathComponent {
wrong_component: format_value(&not_array),
index: path_to_array.len(),
path,
},
context: ExtractionErrorContext::FindingPathToArray,
});
}
};
let mut extracted_values = Vec::with_capacity(array.len());
for (index, mut item) in array.into_iter().enumerate() {
let extracted_value = extract_value(value_path_in_array, &mut item)
.with_context(|kind| ExtractionError {
kind,
context: ExtractionErrorContext::ExtractingArrayItem(index),
})?;
extracted_values.push(extracted_value);
}
extracted_values
}
})
}
fn parse_array(
array: &[Value],
placeholder_string: &str,
repeat_string: &str,
value_path: &mut Option<ValuePath>,
mut array_path: &mut ArrayParsingContext,
current_path: &mut ValuePath,
) -> Result<(), TemplateParsingError> {
// two modes for parsing array.
match array {
// 1. array contains a repeat string in second position
[first, second, rest @ ..] if second == repeat_string => {
let ArrayParsingContext::NotNested(array_path) = &mut array_path else {
return Err(TemplateParsingError::NestedRepeatString(current_path.clone()));
};
if let Some(array_path) = array_path {
return Err(TemplateParsingError::MultipleRepeatString(
current_path.clone(),
array_path.path_to_array.clone(),
));
}
if first == repeat_string {
return Err(TemplateParsingError::BadIndexForRepeatString(
current_path.clone(),
0,
));
}
if let Some(position) = rest.iter().position(|value| value == repeat_string) {
let position = position + 2;
return Err(TemplateParsingError::BadIndexForRepeatString(
current_path.clone(),
position,
));
}
let value_path_in_array = {
let mut value_path = None;
let mut current_path_in_array = Vec::new();
Self::parse_value(
first,
placeholder_string,
repeat_string,
&mut value_path,
&mut ArrayParsingContext::Nested,
&mut current_path_in_array,
)
.map_err(|error| error.prepend_path(current_path.to_vec()))?;
value_path.ok_or_else(|| {
let mut repeated_value_path = current_path.clone();
repeated_value_path.push(PathComponent::ArrayIndex(0));
TemplateParsingError::MissingPlaceholderInRepeatedValue(repeated_value_path)
})?
};
**array_path = Some(ArrayPath {
repeated_value: first.to_owned(),
path_to_array: current_path.clone(),
value_path_in_array,
});
}
// 2. array does not contain a repeat string
array => {
if let Some(position) = array.iter().position(|value| value == repeat_string) {
return Err(TemplateParsingError::BadIndexForRepeatString(
current_path.clone(),
position,
));
}
for (index, value) in array.iter().enumerate() {
current_path.push(PathComponent::ArrayIndex(index));
Self::parse_value(
value,
placeholder_string,
repeat_string,
value_path,
array_path,
current_path,
)?;
current_path.pop();
}
}
}
Ok(())
}
fn parse_object(
object: &Map<String, Value>,
placeholder_string: &str,
repeat_string: &str,
value_path: &mut Option<ValuePath>,
array_path: &mut ArrayParsingContext,
current_path: &mut ValuePath,
) -> Result<(), TemplateParsingError> {
for (key, value) in object.iter() {
current_path.push(PathComponent::MapKey(key.to_owned()));
Self::parse_value(
value,
placeholder_string,
repeat_string,
value_path,
array_path,
current_path,
)?;
current_path.pop();
}
Ok(())
}
fn parse_value(
value: &Value,
placeholder_string: &str,
repeat_string: &str,
value_path: &mut Option<ValuePath>,
array_path: &mut ArrayParsingContext,
current_path: &mut ValuePath,
) -> Result<(), TemplateParsingError> {
match value {
Value::String(str) => {
if placeholder_string == str {
if let Some(value_path) = value_path {
return Err(TemplateParsingError::MultiplePlaceholderString(
current_path.clone(),
value_path.clone(),
));
}
*value_path = Some(current_path.clone());
}
if repeat_string == str {
return Err(TemplateParsingError::RepeatStringNotInArray(current_path.clone()));
}
}
Value::Null | Value::Bool(_) | Value::Number(_) => {}
Value::Array(array) => Self::parse_array(
array,
placeholder_string,
repeat_string,
value_path,
array_path,
current_path,
)?,
Value::Object(object) => Self::parse_object(
object,
placeholder_string,
repeat_string,
value_path,
array_path,
current_path,
)?,
}
Ok(())
}
}
fn inject_value(rendered: &mut Value, injection_path: &Vec<PathComponent>, injected_value: Value) {
let mut current_value = rendered;
for injection_component in injection_path {
current_value = match injection_component {
PathComponent::MapKey(key) => current_value.get_mut(key).unwrap(),
PathComponent::ArrayIndex(index) => current_value.get_mut(index).unwrap(),
}
}
*current_value = injected_value;
}
fn format_value(value: &Value) -> String {
match value {
Value::Array(array) => format!("an array of size {}", array.len()),
Value::Object(object) => {
format!("an object with {} field(s)", object.len())
}
value => value.to_string(),
}
}
fn extract_value<T>(
extraction_path: &[PathComponent],
initial_value: &mut Value,
) -> Result<T, ExtractionErrorKind>
where
T: for<'de> Deserialize<'de>,
{
let mut current_value = initial_value;
for (path_index, extraction_component) in extraction_path.iter().enumerate() {
current_value = {
match extraction_component {
PathComponent::MapKey(key) => {
if !current_value.is_object() {
return Err(ExtractionErrorKind::WrongPathComponent {
wrong_component: format_value(current_value),
index: path_index,
path: extraction_path.to_vec(),
});
}
if let Some(object) = current_value.as_object_mut() {
if !object.contains_key(key) {
let typos =
levenshtein_automata::LevenshteinAutomatonBuilder::new(2, true)
.build_dfa(key);
let mut key_suggestion = None;
'check_typos: for (key, _) in object.iter() {
match typos.eval(key) {
levenshtein_automata::Distance::Exact(0) => { /* ??? */ }
levenshtein_automata::Distance::Exact(_) => {
key_suggestion = Some(key.to_owned());
break 'check_typos;
}
levenshtein_automata::Distance::AtLeast(_) => continue,
}
}
return Err(ExtractionErrorKind::MissingPathComponent {
missing_index: path_index,
path: extraction_path.to_vec(),
key_suggestion,
});
}
if let Some(value) = object.get_mut(key) {
value
} else {
// borrow checking limit: the borrow checker cannot be convinced that `object` is no longer mutably borrowed on the
// `else` branch of the `if let`, so we cannot return MissingPathComponent here.
// As a workaround, we checked that the object does not contain the key above, making this `else` unreachable.
unreachable!()
}
} else {
// borrow checking limit: the borrow checker cannot be convinced that `current_value` is no longer mutably borrowed
// on the `else` branch of the `if let`, so we cannot return WrongPathComponent here.
// As a workaround, we checked that the value was not a map above, making this `else` unreachable.
unreachable!()
}
}
PathComponent::ArrayIndex(index) => {
if !current_value.is_array() {
return Err(ExtractionErrorKind::WrongPathComponent {
wrong_component: format_value(current_value),
index: path_index,
path: extraction_path.to_vec(),
});
}
match current_value.get_mut(index) {
Some(value) => value,
None => {
return Err(ExtractionErrorKind::MissingPathComponent {
missing_index: path_index,
path: extraction_path.to_vec(),
key_suggestion: None,
});
}
}
}
}
};
}
serde_json::from_value(current_value.take()).map_err(|error| {
ExtractionErrorKind::DeserializationError { error, path: extraction_path.to_vec() }
})
}
trait ExtractionResultErrorContext<T> {
fn with_context<F>(self, f: F) -> Result<T, ExtractionError>
where
F: FnOnce(ExtractionErrorKind) -> ExtractionError;
}
impl<T> ExtractionResultErrorContext<T> for Result<T, ExtractionErrorKind> {
fn with_context<F>(self, f: F) -> Result<T, ExtractionError>
where
F: FnOnce(ExtractionErrorKind) -> ExtractionError,
{
match self {
Ok(t) => Ok(t),
Err(kind) => Err(f(kind)),
}
}
}
#[cfg(test)]
mod test {
use serde_json::{json, Value};
use super::{PathComponent, TemplateParsingError, ValueTemplate};
fn new_template(template: Value) -> Result<ValueTemplate, TemplateParsingError> {
ValueTemplate::new(template, "{{text}}", "{{..}}")
}
#[test]
fn empty_template() {
let template = json!({
"toto": "no template at all",
"titi": ["this", "will", "not", "work"],
"tutu": null
});
let error = new_template(template.clone()).unwrap_err();
assert!(matches!(error, TemplateParsingError::MissingPlaceholderString))
}
#[test]
fn single_template() {
let template = json!({
"toto": "text",
"titi": ["this", "will", "still", "{{text}}"],
"tutu": null
});
let basic = new_template(template.clone()).unwrap();
assert!(!basic.has_array_value());
assert_eq!(
basic.inject(vec!["work".into(), Value::Null, "test".into()]).unwrap(),
json!({
"toto": "text",
"titi": ["this", "will", "still", "work"],
"tutu": null
})
);
}
#[test]
fn too_many_placeholders() {
let template = json!({
"toto": "{{text}}",
"titi": ["this", "will", "still", "{{text}}"],
"tutu": "text"
});
match new_template(template.clone()) {
Err(TemplateParsingError::MultiplePlaceholderString(left, right)) => {
assert_eq!(
left,
vec![PathComponent::MapKey("titi".into()), PathComponent::ArrayIndex(3)]
);
assert_eq!(right, vec![PathComponent::MapKey("toto".into())])
}
_ => panic!("should error"),
}
}
#[test]
fn dynamic_template() {
let template = json!({
"toto": "text",
"titi": [{
"type": "text",
"data": "{{text}}"
}, "{{..}}"],
"tutu": null
});
let basic = new_template(template.clone()).unwrap();
assert!(basic.has_array_value());
let injected_values = vec![
"work".into(),
Value::Null,
42.into(),
"test".into(),
"tata".into(),
"titi".into(),
"tutu".into(),
];
let rendered = basic.inject(injected_values.clone()).unwrap();
assert_eq!(
rendered,
json!({
"toto": "text",
"titi": [
{
"type": "text",
"data": "work"
},
{
"type": "text",
"data": Value::Null
},
{
"type": "text",
"data": 42
},
{
"type": "text",
"data": "test"
},
{
"type": "text",
"data": "tata"
},
{
"type": "text",
"data": "titi"
},
{
"type": "text",
"data": "tutu"
}
],
"tutu": null
})
);
let extracted_values: Vec<Value> = basic.extract(rendered).unwrap();
assert_eq!(extracted_values, injected_values);
}
}

View file

@ -0,0 +1,40 @@
use super::error::EmbedError;
use super::{DistributionShift, Embeddings};
#[derive(Debug, Clone, Copy)]
pub struct Embedder {
dimensions: usize,
distribution: Option<DistributionShift>,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub dimensions: usize,
pub distribution: Option<DistributionShift>,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Self {
Self { dimensions: options.dimensions, distribution: options.distribution }
}
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let Some(text) = texts.pop() else { return Ok(Default::default()) };
Err(EmbedError::embed_on_manual_embedder(text.chars().take(250).collect()))
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution
}
}

View file

@ -0,0 +1,606 @@
use std::collections::HashMap;
use std::sync::Arc;
use arroy::distances::{Angular, BinaryQuantizedAngular};
use arroy::ItemId;
use deserr::{DeserializeError, Deserr};
use heed::{RoTxn, RwTxn, Unspecified};
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::{Prompt, PromptData};
use crate::ThreadPoolNoAbort;
pub mod error;
pub mod hf;
pub mod json_template;
pub mod manual;
pub mod openai;
pub mod parsed_vectors;
pub mod settings;
pub mod ollama;
pub mod rest;
pub use self::error::Error;
pub type Embedding = Vec<f32>;
pub const REQUEST_PARALLELISM: usize = 40;
pub struct ArroyWrapper {
quantized: bool,
index: u16,
database: arroy::Database<Unspecified>,
}
impl ArroyWrapper {
pub fn new(database: arroy::Database<Unspecified>, index: u16, quantized: bool) -> Self {
Self { database, index, quantized }
}
pub fn index(&self) -> u16 {
self.index
}
pub fn dimensions(&self, rtxn: &RoTxn) -> Result<usize, arroy::Error> {
if self.quantized {
Ok(arroy::Reader::open(rtxn, self.index, self.quantized_db())?.dimensions())
} else {
Ok(arroy::Reader::open(rtxn, self.index, self.angular_db())?.dimensions())
}
}
pub fn quantize(
&mut self,
wtxn: &mut RwTxn,
index: u16,
dimension: usize,
) -> Result<(), arroy::Error> {
if !self.quantized {
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
writer.prepare_changing_distance::<BinaryQuantizedAngular>(wtxn)?;
self.quantized = true;
}
Ok(())
}
pub fn need_build(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).need_build(rtxn)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).need_build(rtxn)
}
}
pub fn build<R: rand::Rng + rand::SeedableRng>(
&self,
wtxn: &mut RwTxn,
rng: &mut R,
dimension: usize,
) -> Result<(), arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).build(wtxn, rng, None)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).build(wtxn, rng, None)
}
}
pub fn add_item(
&self,
wtxn: &mut RwTxn,
dimension: usize,
item_id: arroy::ItemId,
vector: &[f32],
) -> Result<(), arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension)
.add_item(wtxn, item_id, vector)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension)
.add_item(wtxn, item_id, vector)
}
}
pub fn del_item(
&self,
wtxn: &mut RwTxn,
dimension: usize,
item_id: arroy::ItemId,
) -> Result<bool, arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).del_item(wtxn, item_id)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).del_item(wtxn, item_id)
}
}
pub fn clear(&self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).clear(wtxn)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).clear(wtxn)
}
}
pub fn is_empty(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).is_empty(rtxn)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).is_empty(rtxn)
}
}
pub fn contains_item(
&self,
rtxn: &RoTxn,
dimension: usize,
item: arroy::ItemId,
) -> Result<bool, arroy::Error> {
if self.quantized {
arroy::Writer::new(self.quantized_db(), self.index, dimension).contains_item(rtxn, item)
} else {
arroy::Writer::new(self.angular_db(), self.index, dimension).contains_item(rtxn, item)
}
}
pub fn nns_by_item(
&self,
rtxn: &RoTxn,
item: ItemId,
limit: usize,
filter: Option<&RoaringBitmap>,
) -> Result<Option<Vec<(ItemId, f32)>>, arroy::Error> {
if self.quantized {
arroy::Reader::open(rtxn, self.index, self.quantized_db())?
.nns_by_item(rtxn, item, limit, None, None, filter)
} else {
arroy::Reader::open(rtxn, self.index, self.angular_db())?
.nns_by_item(rtxn, item, limit, None, None, filter)
}
}
pub fn nns_by_vector(
&self,
txn: &RoTxn,
item: &[f32],
limit: usize,
filter: Option<&RoaringBitmap>,
) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
if self.quantized {
arroy::Reader::open(txn, self.index, self.quantized_db())?
.nns_by_vector(txn, item, limit, None, None, filter)
} else {
arroy::Reader::open(txn, self.index, self.angular_db())?
.nns_by_vector(txn, item, limit, None, None, filter)
}
}
pub fn item_vector(&self, rtxn: &RoTxn, docid: u32) -> Result<Option<Vec<f32>>, arroy::Error> {
if self.quantized {
arroy::Reader::open(rtxn, self.index, self.quantized_db())?.item_vector(rtxn, docid)
} else {
arroy::Reader::open(rtxn, self.index, self.angular_db())?.item_vector(rtxn, docid)
}
}
fn angular_db(&self) -> arroy::Database<Angular> {
self.database.remap_data_type()
}
fn quantized_db(&self) -> arroy::Database<BinaryQuantizedAngular> {
self.database.remap_data_type()
}
}
/// One or multiple embeddings stored consecutively in a flat vector.
pub struct Embeddings<F> {
data: Vec<F>,
dimension: usize,
}
impl<F> Embeddings<F> {
/// Declares an empty vector of embeddings of the specified dimensions.
pub fn new(dimension: usize) -> Self {
Self { data: Default::default(), dimension }
}
/// Declares a vector of embeddings containing a single element.
///
/// The dimension is inferred from the length of the passed embedding.
pub fn from_single_embedding(embedding: Vec<F>) -> Self {
Self { dimension: embedding.len(), data: embedding }
}
/// Declares a vector of embeddings from its components.
///
/// `data.len()` must be a multiple of `dimension`, otherwise an error is returned.
pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
let mut this = Self::new(dimension);
this.append(data)?;
Ok(this)
}
/// Returns the number of embeddings in this vector of embeddings.
pub fn embedding_count(&self) -> usize {
self.data.len() / self.dimension
}
/// Dimension of a single embedding.
pub fn dimension(&self) -> usize {
self.dimension
}
/// Deconstructs self into the inner flat vector.
pub fn into_inner(self) -> Vec<F> {
self.data
}
/// A reference to the inner flat vector.
pub fn as_inner(&self) -> &[F] {
&self.data
}
/// Iterates over the embeddings contained in the flat vector.
pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
self.data.as_slice().chunks_exact(self.dimension)
}
/// Push an embedding at the end of the embeddings.
///
/// If `embedding.len() != self.dimension`, then the push operation fails.
pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
if embedding.len() != self.dimension {
return Err(embedding);
}
self.data.append(&mut embedding);
Ok(())
}
/// Append a flat vector of embeddings a the end of the embeddings.
///
/// If `embeddings.len() % self.dimension != 0`, then the append operation fails.
pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
if embeddings.len() % self.dimension != 0 {
return Err(embeddings);
}
self.data.append(&mut embeddings);
Ok(())
}
}
/// An embedder can be used to transform text into embeddings.
#[derive(Debug)]
pub enum Embedder {
/// An embedder based on running local models, fetched from the Hugging Face Hub.
HuggingFace(hf::Embedder),
/// An embedder based on making embedding queries against the OpenAI API.
OpenAi(openai::Embedder),
/// An embedder based on the user providing the embeddings in the documents and queries.
UserProvided(manual::Embedder),
/// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
Ollama(ollama::Embedder),
/// An embedder based on making embedding queries against a generic JSON/REST embedding server.
Rest(rest::Embedder),
}
/// Configuration for an embedder.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig {
/// Options of the embedder, specific to each kind of embedder
pub embedder_options: EmbedderOptions,
/// Document template
pub prompt: PromptData,
/// If this embedder is binary quantized
pub quantized: Option<bool>,
// TODO: add metrics and anything needed
}
impl EmbeddingConfig {
pub fn quantized(&self) -> bool {
self.quantized.unwrap_or_default()
}
}
/// Map of embedder configurations.
///
/// Each configuration is mapped to a name.
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>);
impl EmbeddingConfigs {
/// Create the map from its internal component.s
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>) -> Self {
Self(data)
}
/// Get an embedder configuration and template from its name.
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>, bool)> {
self.0.get(name).cloned()
}
pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
&self.0
}
pub fn into_inner(self) -> HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
self.0
}
}
impl IntoIterator for EmbeddingConfigs {
type Item = (String, (Arc<Embedder>, Arc<Prompt>, bool));
type IntoIter =
std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>, bool)>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
/// Options of an embedder, specific to each kind of embedder.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
Ollama(ollama::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
Rest(rest::EmbedderOptions),
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::HuggingFace(Default::default())
}
}
impl Embedder {
/// Spawns a new embedder built from its options.
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
EmbedderOptions::Rest(options) => {
Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?)
}
})
}
/// Embed one or multiple texts.
///
/// Each text can be embedded as one or multiple embeddings.
pub fn embed(
&self,
texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts),
Embedder::Ollama(embedder) => embedder.embed(texts),
Embedder::UserProvided(embedder) => embedder.embed(texts),
Embedder::Rest(embedder) => embedder.embed(texts),
}
}
pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> {
let mut embeddings = self.embed(vec![text])?;
let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?;
Ok(if embeddings.iter().nth(1).is_some() {
tracing::warn!("Ignoring embeddings past the first one in long search query");
embeddings.iter().next().unwrap().to_vec()
} else {
embeddings.into_inner()
})
}
/// Embed multiple chunks of texts.
///
/// Each chunk is composed of one or multiple texts.
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
}
}
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
pub fn chunk_count_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
Embedder::Rest(embedder) => embedder.chunk_count_hint(),
}
}
/// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`]
pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
}
}
/// Indicates the dimensions of a single embedding produced by the embedder.
pub fn dimensions(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::Ollama(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
Embedder::Rest(embedder) => embedder.dimensions(),
}
}
/// An optional distribution used to apply an affine transformation to the similarity score of a document.
pub fn distribution(&self) -> Option<DistributionShift> {
match self {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(embedder) => embedder.distribution(),
Embedder::Rest(embedder) => embedder.distribution(),
}
}
pub fn uses_document_template(&self) -> bool {
match self {
Embedder::HuggingFace(_)
| Embedder::OpenAi(_)
| Embedder::Ollama(_)
| Embedder::Rest(_) => true,
Embedder::UserProvided(_) => false,
}
}
}
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
///
/// The intended use is to make the similarity score more comparable to the regular ranking score.
/// This allows to correct effects where results are too "packed" around a certain value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(from = "DistributionShiftSerializable")]
#[serde(into = "DistributionShiftSerializable")]
pub struct DistributionShift {
/// Value where the results are "packed".
///
/// Similarity scores are translated so that they are packed around 0.5 instead
pub current_mean: OrderedFloat<f32>,
/// standard deviation of a similarity score.
///
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
pub current_sigma: OrderedFloat<f32>,
}
impl<E> Deserr<E> for DistributionShift
where
E: DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef<'_>,
) -> Result<Self, E> {
let value = DistributionShiftSerializable::deserialize_from_value(value, location)?;
if value.mean < 0. || value.mean > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution mean must be in the range [0, 1], got {}",
value.mean
),
},
location,
)));
}
if value.sigma <= 0. || value.sigma > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution sigma must be in the range ]0, 1], got {}",
value.sigma
),
},
location,
)));
}
Ok(value.into())
}
}
#[derive(Serialize, Deserialize, Deserr)]
#[serde(deny_unknown_fields)]
#[deserr(deny_unknown_fields)]
struct DistributionShiftSerializable {
mean: f32,
sigma: f32,
}
impl From<DistributionShift> for DistributionShiftSerializable {
fn from(
DistributionShift {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}: DistributionShift,
) -> Self {
Self { mean: current_mean, sigma: current_sigma }
}
}
impl From<DistributionShiftSerializable> for DistributionShift {
fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self {
Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }
}
}
impl DistributionShift {
/// `None` if sigma <= 0.
pub fn new(mean: f32, sigma: f32) -> Option<Self> {
if sigma <= 0.0 {
None
} else {
Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
}
}
pub fn shift(&self, score: f32) -> f32 {
let current_mean = self.current_mean.0;
let current_sigma = self.current_sigma.0;
// <https://math.stackexchange.com/a/2894689>
// We're somewhat abusively mapping the distribution of distances to a gaussian.
// The parameters we're given is the mean and sigma of the native result distribution.
// We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
let target_mean = 0.5;
let target_sigma = 0.4;
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
let factor = target_sigma / current_sigma;
// a*mu1 + b = mu2 => b = mu2 - a*mu1
let offset = target_mean - (factor * current_mean);
let mut score = factor * score + offset;
// clamp the final score in the ]0, 1] interval.
if score <= 0.0 {
score = f32::EPSILON;
}
if score > 1.0 {
score = 1.0;
}
score
}
}
/// Whether CUDA is supported in this version of Meilisearch.
pub const fn is_cuda_enabled() -> bool {
cfg!(feature = "cuda")
}
pub fn arroy_db_range_for_embedder(embedder_id: u8) -> impl Iterator<Item = u16> {
let embedder_id = (embedder_id as u16) << 8;
(0..=u8::MAX).map(move |k| embedder_id | (k as u16))
}

View file

@ -0,0 +1,123 @@
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, Embeddings};
use crate::error::FaultSource;
use crate::ThreadPoolNoAbort;
#[derive(Debug)]
pub struct Embedder {
rest_embedder: RestEmbedder,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub embedding_model: String,
pub url: Option<String>,
pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
pub dimensions: Option<usize>,
}
impl EmbedderOptions {
pub fn with_default_model(
api_key: Option<String>,
url: Option<String>,
dimensions: Option<usize>,
) -> Self {
Self {
embedding_model: "nomic-embed-text".into(),
api_key,
url,
distribution: None,
dimensions,
}
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let model = options.embedding_model.as_str();
let rest_embedder = match RestEmbedder::new(
RestEmbedderOptions {
api_key: options.api_key,
dimensions: options.dimensions,
distribution: options.distribution,
url: options.url.unwrap_or_else(get_ollama_path),
request: serde_json::json!({
"model": model,
"prompt": super::rest::REQUEST_PLACEHOLDER,
}),
response: serde_json::json!({
"embedding": super::rest::RESPONSE_PLACEHOLDER,
}),
headers: Default::default(),
},
super::rest::ConfigurationSource::Ollama,
) {
Ok(embedder) => embedder,
Err(NewEmbedderError {
kind:
NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
fault: _,
}),
fault: _,
}) => {
return Err(NewEmbedderError::could_not_determine_dimension(
EmbedError::ollama_model_not_found(error),
))
}
Err(error) => return Err(error),
};
Ok(Self { rest_embedder })
}
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
match self.rest_embedder.embed(texts) {
Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
Err(EmbedError::ollama_model_not_found(error))
}
Err(error) => Err(error),
}
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
fault: FaultSource::Bug,
})?
}
pub fn chunk_count_hint(&self) -> usize {
self.rest_embedder.chunk_count_hint()
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
self.rest_embedder.prompt_count_in_chunk_hint()
}
pub fn dimensions(&self) -> usize {
self.rest_embedder.dimensions()
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.rest_embedder.distribution()
}
}
fn get_ollama_path() -> String {
// Important: Hostname not enough, has to be entire path to embeddings endpoint
std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())
}

View file

@ -0,0 +1,274 @@
use ordered_float::OrderedFloat;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use super::error::{EmbedError, NewEmbedderError};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, Embeddings};
use crate::error::FaultSource;
use crate::vector::error::EmbedErrorKind;
use crate::ThreadPoolNoAbort;
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub url: Option<String>,
pub api_key: Option<String>,
pub embedding_model: EmbeddingModel,
pub dimensions: Option<usize>,
pub distribution: Option<DistributionShift>,
}
impl EmbedderOptions {
pub fn dimensions(&self) -> usize {
if self.embedding_model.supports_overriding_dimensions() {
self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
} else {
self.embedding_model.default_dimensions()
}
}
pub fn request(&self) -> serde_json::Value {
let model = self.embedding_model.name();
let mut request = serde_json::json!({
"model": model,
"input": [super::rest::REQUEST_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER]
});
if self.embedding_model.supports_overriding_dimensions() {
if let Some(dimensions) = self.dimensions {
request["dimensions"] = dimensions.into();
}
}
request
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution.or(self.embedding_model.distribution())
}
}
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
PartialEq,
Eq,
serde::Serialize,
serde::Deserialize,
deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel {
// # WARNING
//
// If ever adding a model, make sure to add it to the list of supported models below.
#[serde(rename = "text-embedding-ada-002")]
#[deserr(rename = "text-embedding-ada-002")]
TextEmbeddingAda002,
#[default]
#[serde(rename = "text-embedding-3-small")]
#[deserr(rename = "text-embedding-3-small")]
TextEmbedding3Small,
#[serde(rename = "text-embedding-3-large")]
#[deserr(rename = "text-embedding-3-large")]
TextEmbedding3Large,
}
impl EmbeddingModel {
pub fn supported_models() -> &'static [&'static str] {
&["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
}
pub fn max_token(&self) -> usize {
match self {
EmbeddingModel::TextEmbeddingAda002 => 8191,
EmbeddingModel::TextEmbedding3Large => 8191,
EmbeddingModel::TextEmbedding3Small => 8191,
}
}
pub fn default_dimensions(&self) -> usize {
match self {
EmbeddingModel::TextEmbeddingAda002 => 1536,
EmbeddingModel::TextEmbedding3Large => 3072,
EmbeddingModel::TextEmbedding3Small => 1536,
}
}
pub fn name(&self) -> &'static str {
match self {
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
}
}
pub fn from_name(name: &str) -> Option<Self> {
match name {
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
"text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large),
"text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small),
_ => None,
}
}
fn distribution(&self) -> Option<DistributionShift> {
match self {
EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift {
current_mean: OrderedFloat(0.90),
current_sigma: OrderedFloat(0.08),
}),
EmbeddingModel::TextEmbedding3Large => Some(DistributionShift {
current_mean: OrderedFloat(0.70),
current_sigma: OrderedFloat(0.1),
}),
EmbeddingModel::TextEmbedding3Small => Some(DistributionShift {
current_mean: OrderedFloat(0.75),
current_sigma: OrderedFloat(0.1),
}),
}
}
pub fn supports_overriding_dimensions(&self) -> bool {
match self {
EmbeddingModel::TextEmbeddingAda002 => false,
EmbeddingModel::TextEmbedding3Large => true,
EmbeddingModel::TextEmbedding3Small => true,
}
}
}
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self {
Self {
api_key,
embedding_model: Default::default(),
dimensions: None,
distribution: None,
url: None,
}
}
}
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
}
#[derive(Debug)]
pub struct Embedder {
tokenizer: tiktoken_rs::CoreBPE,
rest_embedder: RestEmbedder,
options: EmbedderOptions,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut inferred_api_key = Default::default();
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
inferred_api_key = infer_api_key();
&inferred_api_key
});
let url = options.url.as_deref().unwrap_or(OPENAI_EMBEDDINGS_URL).to_owned();
let rest_embedder = RestEmbedder::new(
RestEmbedderOptions {
api_key: (!api_key.is_empty()).then(|| api_key.clone()),
distribution: None,
dimensions: Some(options.dimensions()),
url,
request: options.request(),
response: serde_json::json!({
"data": [{
"embedding": super::rest::RESPONSE_PLACEHOLDER
},
super::rest::REPEAT_PLACEHOLDER
]
}),
headers: Default::default(),
},
super::rest::ConfigurationSource::OpenAi,
)?;
// looking at the code it is very unclear that this can actually fail.
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
Ok(Self { options, rest_embedder, tokenizer })
}
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
match self.rest_embedder.embed_ref(&texts) {
Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
self.try_embed_tokenized(&texts)
}
Err(error) => Err(error),
}
}
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let mut all_embeddings = Vec::with_capacity(text.len());
for text in text {
let max_token_count = self.options.embedding_model.max_token();
let encoded = self.tokenizer.encode_ordinary(text.as_str());
let len = encoded.len();
if len < max_token_count {
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
continue;
}
let tokens = &encoded.as_slice()[0..max_token_count];
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
let embedding = self.rest_embedder.embed_tokens(tokens)?;
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
EmbedError::rest_unexpected_dimension(self.dimensions(), got.len())
})?;
all_embeddings.push(embeddings_for_prompt);
}
Ok(all_embeddings)
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
fault: FaultSource::Bug,
})?
}
pub fn chunk_count_hint(&self) -> usize {
self.rest_embedder.chunk_count_hint()
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
self.rest_embedder.prompt_count_in_chunk_hint()
}
pub fn dimensions(&self) -> usize {
self.options.dimensions()
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.distribution()
}
}

View file

@ -0,0 +1,417 @@
use std::collections::{BTreeMap, BTreeSet};
use deserr::{take_cf_content, DeserializeError, Deserr, Sequence};
use obkv::KvReader;
use serde_json::{from_slice, Value};
use super::Embedding;
use crate::index::IndexEmbeddingConfig;
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
use crate::{DocumentId, FieldId, InternalError, UserError};
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
#[derive(serde::Serialize, Debug)]
#[serde(untagged)]
pub enum Vectors {
ImplicitlyUserProvided(VectorOrArrayOfVectors),
Explicit(ExplicitVectors),
}
impl<E: DeserializeError> Deserr<E> for Vectors {
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef<'_>,
) -> Result<Self, E> {
match value {
deserr::Value::Sequence(_) | deserr::Value::Null => {
Ok(Vectors::ImplicitlyUserProvided(VectorOrArrayOfVectors::deserialize_from_value(
value, location,
)?))
}
deserr::Value::Map(_) => {
Ok(Vectors::Explicit(ExplicitVectors::deserialize_from_value(value, location)?))
}
value => Err(take_cf_content(E::error(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[
deserr::ValueKind::Sequence,
deserr::ValueKind::Map,
deserr::ValueKind::Null,
],
},
location,
))),
}
}
}
impl Vectors {
pub fn must_regenerate(&self) -> bool {
match self {
Vectors::ImplicitlyUserProvided(_) => false,
Vectors::Explicit(ExplicitVectors { regenerate, .. }) => *regenerate,
}
}
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
match self {
Vectors::ImplicitlyUserProvided(embeddings) => {
Some(embeddings.into_array_of_vectors().unwrap_or_default())
}
Vectors::Explicit(ExplicitVectors { embeddings, regenerate: _ }) => {
embeddings.map(|embeddings| embeddings.into_array_of_vectors().unwrap_or_default())
}
}
}
}
#[derive(serde::Serialize, Deserr, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ExplicitVectors {
#[serde(default)]
#[deserr(default)]
pub embeddings: Option<VectorOrArrayOfVectors>,
pub regenerate: bool,
}
pub enum VectorState {
Inline(Vectors),
Manual,
Generated,
}
impl VectorState {
pub fn must_regenerate(&self) -> bool {
match self {
VectorState::Inline(vectors) => vectors.must_regenerate(),
VectorState::Manual => false,
VectorState::Generated => true,
}
}
}
pub enum VectorsState {
NoVectorsFid,
NoVectorsFieldInDocument,
Vectors(BTreeMap<String, Vectors>),
}
pub struct ParsedVectorsDiff {
old: BTreeMap<String, VectorState>,
new: VectorsState,
}
impl ParsedVectorsDiff {
pub fn new(
docid: DocumentId,
embedders_configs: &[IndexEmbeddingConfig],
documents_diff: KvReader<'_, FieldId>,
old_vectors_fid: Option<FieldId>,
new_vectors_fid: Option<FieldId>,
) -> Result<Self, Error> {
let mut old = match old_vectors_fid
.and_then(|vectors_fid| documents_diff.get(vectors_fid))
.map(KvReaderDelAdd::new)
.map(|obkv| to_vector_map(obkv, DelAdd::Deletion))
.transpose()
{
Ok(del) => del,
// ignore wrong shape for old version of documents, use an empty map in this case
Err(Error::InvalidMap(value)) => {
tracing::warn!(%value, "Previous version of the `_vectors` field had a wrong shape");
Default::default()
}
Err(error) => {
return Err(error);
}
}
.flatten().map_or(BTreeMap::default(), |del| del.into_iter().map(|(name, vec)| (name, VectorState::Inline(vec))).collect());
for embedding_config in embedders_configs {
if embedding_config.user_provided.contains(docid) {
old.entry(embedding_config.name.to_string()).or_insert(VectorState::Manual);
}
}
let new = 'new: {
let Some(new_vectors_fid) = new_vectors_fid else {
break 'new VectorsState::NoVectorsFid;
};
let Some(bytes) = documents_diff.get(new_vectors_fid) else {
break 'new VectorsState::NoVectorsFieldInDocument;
};
let obkv = KvReaderDelAdd::new(bytes);
match to_vector_map(obkv, DelAdd::Addition)? {
Some(new) => VectorsState::Vectors(new),
None => VectorsState::NoVectorsFieldInDocument,
}
};
Ok(Self { old, new })
}
pub fn remove(&mut self, embedder_name: &str) -> (VectorState, VectorState) {
let old = self.old.remove(embedder_name).unwrap_or(VectorState::Generated);
let state_from_old = match old {
// assume a userProvided is still userProvided
VectorState::Manual => VectorState::Manual,
// generated is still generated
VectorState::Generated => VectorState::Generated,
// weird case that shouldn't happen were the previous docs version is inline,
// but it was removed in the new version
// Since it is not in the new version, we switch to generated
VectorState::Inline(_) => VectorState::Generated,
};
let new = match &mut self.new {
VectorsState::Vectors(new) => {
new.remove(embedder_name).map(VectorState::Inline).unwrap_or(state_from_old)
}
_ =>
// if no `_vectors` field is present in the new document,
// the state depends on the previous version of the document
{
state_from_old
}
};
(old, new)
}
pub fn into_new_vectors_keys_iter(self) -> impl Iterator<Item = String> {
let maybe_it = match self.new {
VectorsState::NoVectorsFid => None,
VectorsState::NoVectorsFieldInDocument => None,
VectorsState::Vectors(vectors) => Some(vectors.into_keys()),
};
maybe_it.into_iter().flatten()
}
}
pub struct ParsedVectors(pub BTreeMap<String, Vectors>);
impl<E: DeserializeError> Deserr<E> for ParsedVectors {
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef<'_>,
) -> Result<Self, E> {
let value = <BTreeMap<String, Vectors>>::deserialize_from_value(value, location)?;
Ok(ParsedVectors(value))
}
}
impl ParsedVectors {
pub fn from_bytes(value: &[u8]) -> Result<Self, Error> {
let value: serde_json::Value = from_slice(value).map_err(Error::InternalSerdeJson)?;
deserr::deserialize(value).map_err(|error| Error::InvalidEmbedderConf { error })
}
pub fn retain_not_embedded_vectors(&mut self, embedders: &BTreeSet<String>) {
self.0.retain(|k, _v| !embedders.contains(k))
}
}
pub enum Error {
InvalidMap(Value),
InvalidEmbedderConf { error: deserr::errors::JsonError },
InternalSerdeJson(serde_json::Error),
}
impl Error {
pub fn to_crate_error(self, document_id: String) -> crate::Error {
match self {
Error::InvalidMap(value) => {
crate::Error::UserError(UserError::InvalidVectorsMapType { document_id, value })
}
Error::InvalidEmbedderConf { error } => {
crate::Error::UserError(UserError::InvalidVectorsEmbedderConf {
document_id,
error,
})
}
Error::InternalSerdeJson(error) => {
crate::Error::InternalError(InternalError::SerdeJson(error))
}
}
}
}
fn to_vector_map(
obkv: KvReaderDelAdd<'_>,
side: DelAdd,
) -> Result<Option<BTreeMap<String, Vectors>>, Error> {
Ok(if let Some(value) = obkv.get(side) {
let ParsedVectors(parsed_vectors) = ParsedVectors::from_bytes(value)?;
Some(parsed_vectors)
} else {
None
})
}
/// Represents either a vector or an array of multiple vectors.
#[derive(serde::Serialize, Debug)]
#[serde(transparent)]
pub struct VectorOrArrayOfVectors {
#[serde(with = "either::serde_untagged_optional")]
inner: Option<either::Either<Vec<Embedding>, Embedding>>,
}
impl<E: DeserializeError> Deserr<E> for VectorOrArrayOfVectors {
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef<'_>,
) -> Result<Self, E> {
match value {
deserr::Value::Null => Ok(VectorOrArrayOfVectors { inner: None }),
deserr::Value::Sequence(seq) => {
let mut iter = seq.into_iter();
match iter.next().map(|v| v.into_value()) {
None => {
// With the strange way serde serialize the `Either`, we must send the left part
// otherwise it'll consider we returned [[]]
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(Vec::new())) })
}
Some(val @ deserr::Value::Sequence(_)) => {
let first = Embedding::deserialize_from_value(val, location.push_index(0))?;
let mut collect = vec![first];
let mut tail = iter
.enumerate()
.map(|(i, v)| {
Embedding::deserialize_from_value(
v.into_value(),
location.push_index(i + 1),
)
})
.collect::<Result<Vec<_>, _>>()?;
collect.append(&mut tail);
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(collect)) })
}
Some(
val @ deserr::Value::Integer(_)
| val @ deserr::Value::NegativeInteger(_)
| val @ deserr::Value::Float(_),
) => {
let first = <f32>::deserialize_from_value(val, location.push_index(0))?;
let mut embedding = iter
.enumerate()
.map(|(i, v)| {
<f32>::deserialize_from_value(
v.into_value(),
location.push_index(i + 1),
)
})
.collect::<Result<Vec<_>, _>>()?;
embedding.insert(0, first);
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Right(embedding)) })
}
Some(value) => Err(take_cf_content(E::error(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Float],
},
location.push_index(0),
))),
}
}
value => Err(take_cf_content(E::error(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Null],
},
location,
))),
}
}
}
impl VectorOrArrayOfVectors {
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
match self.inner? {
either::Either::Left(vectors) => Some(vectors),
either::Either::Right(vector) => Some(vec![vector]),
}
}
pub fn from_array_of_vectors(array_of_vec: Vec<Embedding>) -> Self {
Self { inner: Some(either::Either::Left(array_of_vec)) }
}
pub fn from_vector(vec: Embedding) -> Self {
Self { inner: Some(either::Either::Right(vec)) }
}
}
impl From<Embedding> for VectorOrArrayOfVectors {
fn from(vec: Embedding) -> Self {
Self::from_vector(vec)
}
}
impl From<Vec<Embedding>> for VectorOrArrayOfVectors {
fn from(vec: Vec<Embedding>) -> Self {
Self::from_array_of_vectors(vec)
}
}
#[cfg(test)]
mod test {
use super::VectorOrArrayOfVectors;
fn embedding_from_str(s: &str) -> Result<VectorOrArrayOfVectors, deserr::errors::JsonError> {
let value: serde_json::Value = serde_json::from_str(s).unwrap();
deserr::deserialize(value)
}
#[test]
fn array_of_vectors() {
let null = embedding_from_str("null").unwrap();
let empty = embedding_from_str("[]").unwrap();
let one = embedding_from_str("[0.1]").unwrap();
let two = embedding_from_str("[0.1, 0.2]").unwrap();
let one_vec = embedding_from_str("[[0.1, 0.2]]").unwrap();
let two_vecs = embedding_from_str("[[0.1, 0.2], [0.3, 0.4]]").unwrap();
insta::assert_json_snapshot!(null.into_array_of_vectors(), @"null");
insta::assert_json_snapshot!(empty.into_array_of_vectors(), @"[]");
insta::assert_json_snapshot!(one.into_array_of_vectors(), @r###"
[
[
0.1
]
]
"###);
insta::assert_json_snapshot!(two.into_array_of_vectors(), @r###"
[
[
0.1,
0.2
]
]
"###);
insta::assert_json_snapshot!(one_vec.into_array_of_vectors(), @r###"
[
[
0.1,
0.2
]
]
"###);
insta::assert_json_snapshot!(two_vecs.into_array_of_vectors(), @r###"
[
[
0.1,
0.2
],
[
0.3,
0.4
]
]
"###);
}
}

View file

@ -0,0 +1,411 @@
use std::collections::BTreeMap;
use deserr::Deserr;
use rand::Rng;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use serde::{Deserialize, Serialize};
use super::error::EmbedErrorKind;
use super::json_template::ValueTemplate;
use super::{
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
};
use crate::error::FaultSource;
use crate::ThreadPoolNoAbort;
// retrying in case of failure
pub struct Retry {
pub error: EmbedError,
strategy: RetryStrategy,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigurationSource {
OpenAi,
Ollama,
User,
}
pub enum RetryStrategy {
GiveUp,
Retry,
RetryTokenized,
RetryAfterRateLimit,
}
impl Retry {
pub fn give_up(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
pub fn retry_later(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
pub fn retry_tokenized(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryTokenized }
}
pub fn rate_limited(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
RetryStrategy::RetryAfterRateLimit => {
Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
}
}
}
pub fn must_tokenize(&self) -> bool {
matches!(self.strategy, RetryStrategy::RetryTokenized)
}
pub fn into_error(self) -> EmbedError {
self.error
}
}
#[derive(Debug)]
pub struct Embedder {
data: EmbedderData,
dimensions: usize,
distribution: Option<DistributionShift>,
}
/// All data needed to perform requests and parse responses
#[derive(Debug)]
struct EmbedderData {
client: ureq::Agent,
bearer: Option<String>,
headers: BTreeMap<String, String>,
url: String,
request: Request,
response: Response,
configuration_source: ConfigurationSource,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct EmbedderOptions {
pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
pub dimensions: Option<usize>,
pub url: String,
pub request: serde_json::Value,
pub response: serde_json::Value,
pub headers: BTreeMap<String, String>,
}
impl std::hash::Hash for EmbedderOptions {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.api_key.hash(state);
self.distribution.hash(state);
self.dimensions.hash(state);
self.url.hash(state);
// skip hashing the request and response
// collisions in regular usage should be minimal,
// and the list is limited to 256 values anyway
}
}
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
#[serde(rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
enum InputType {
Text,
TextArray,
}
impl Embedder {
pub fn new(
options: EmbedderOptions,
configuration_source: ConfigurationSource,
) -> Result<Self, NewEmbedderError> {
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
let client = ureq::AgentBuilder::new()
.max_idle_connections(REQUEST_PARALLELISM * 2)
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.build();
let request = Request::new(options.request)?;
let response = Response::new(options.response, &request)?;
let data = EmbedderData {
client,
bearer,
url: options.url,
request,
response,
configuration_source,
headers: options.headers,
};
let dimensions = if let Some(dimensions) = options.dimensions {
dimensions
} else {
infer_dimensions(&data)?
};
Ok(Self { data, dimensions, distribution: options.distribution })
}
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions))
}
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
where
S: AsRef<str> + Serialize,
{
embed(&self.data, texts, texts.len(), Some(self.dimensions))
}
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?;
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
Ok(embeddings.pop().unwrap())
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
fault: FaultSource::Bug,
})?
}
pub fn chunk_count_hint(&self) -> usize {
super::REQUEST_PARALLELISM
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self.data.request.input_type() {
InputType::Text => 1,
InputType::TextArray => 10,
}
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution
}
}
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
let v = embed(data, ["test"].as_slice(), 1, None)
.map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
Ok(v.first().unwrap().dimension())
}
fn embed<S>(
data: &EmbedderData,
inputs: &[S],
expected_count: usize,
expected_dimension: Option<usize>,
) -> Result<Vec<Embeddings<f32>>, EmbedError>
where
S: Serialize,
{
let request = data.client.post(&data.url);
let request = if let Some(bearer) = &data.bearer {
request.set("Authorization", bearer)
} else {
request
};
let mut request = request.set("Content-Type", "application/json");
for (header, value) in &data.headers {
request = request.set(header.as_str(), value.as_str());
}
let body = data.request.inject_texts(inputs);
for attempt in 0..10 {
let response = request.clone().send_json(&body);
let result = check_response(response, data.configuration_source);
let retry_duration = match result {
Ok(response) => {
return response_to_embedding(response, data, expected_count, expected_dimension)
}
Err(retry) => {
tracing::warn!("Failed: {}", retry.error);
retry.into_duration(attempt)
}
}?;
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
// randomly up to double the retry duration
let retry_duration = retry_duration
+ rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration);
tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
std::thread::sleep(retry_duration);
}
let response = request.send_json(&body);
let result = check_response(response, data.configuration_source);
result.map_err(Retry::into_error).and_then(|response| {
response_to_embedding(response, data, expected_count, expected_dimension)
})
}
fn check_response(
response: Result<ureq::Response, ureq::Error>,
configuration_source: ConfigurationSource,
) -> Result<ureq::Response, Retry> {
match response {
Ok(response) => Ok(response),
Err(ureq::Error::Status(code, response)) => {
let error_response: Option<String> = response.into_string().ok();
Err(match code {
401 => Retry::give_up(EmbedError::rest_unauthorized(
error_response,
configuration_source,
)),
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
400 => Retry::give_up(EmbedError::rest_bad_request(
error_response,
configuration_source,
)),
500..=599 => {
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
}
402..=499 => {
Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
}
_ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
})
}
Err(ureq::Error::Transport(transport)) => {
Err(Retry::retry_later(EmbedError::rest_network(transport)))
}
}
}
fn response_to_embedding(
response: ureq::Response,
data: &EmbedderData,
expected_count: usize,
expected_dimensions: Option<usize>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let response: serde_json::Value =
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
let embeddings = data.response.extract_embeddings(response)?;
if embeddings.len() != expected_count {
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
}
if let Some(dimensions) = expected_dimensions {
for embedding in &embeddings {
if embedding.dimension() != dimensions {
return Err(EmbedError::rest_unexpected_dimension(
dimensions,
embedding.dimension(),
));
}
}
}
Ok(embeddings)
}
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
#[derive(Debug)]
pub struct Request {
template: ValueTemplate,
}
impl Request {
pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
Ok(template) => template,
Err(error) => {
let message =
error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
Ok(Self { template })
}
fn input_type(&self) -> InputType {
if self.template.has_array_value() {
InputType::TextArray
} else {
InputType::Text
}
}
pub fn inject_texts<S: Serialize>(
&self,
texts: impl IntoIterator<Item = S>,
) -> serde_json::Value {
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
}
}
#[derive(Debug)]
pub struct Response {
template: ValueTemplate,
}
impl Response {
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
{
Ok(template) => template,
Err(error) => {
let message =
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
match (template.has_array_value(), request.template.has_array_value()) {
(true, true) | (false, false) => Ok(Self {template}),
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
}
}
pub fn extract_embeddings(
&self,
response: serde_json::Value,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
Ok(extracted_values) => extracted_values,
Err(error) => {
let error_message =
error.error_message("response", "{{embedding}}", "an array of numbers");
return Err(EmbedError::rest_extraction_error(error_message));
}
};
let embeddings: Vec<Embeddings<f32>> =
extracted_values.into_iter().map(Embeddings::from_single_embedding).collect();
Ok(embeddings)
}
}

View file

@ -0,0 +1,770 @@
use std::collections::BTreeMap;
use std::num::NonZeroUsize;
use deserr::Deserr;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use super::{ollama, openai, DistributionShift};
use crate::prompt::{default_max_bytes, PromptData};
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
use crate::UserError;
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub source: Setting<EmbedderSource>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub model: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub revision: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub dimensions: Setting<usize>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub binary_quantized: Setting<bool>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub document_template: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub document_template_max_bytes: Setting<usize>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub url: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub request: Setting<serde_json::Value>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub response: Setting<serde_json::Value>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub headers: Setting<BTreeMap<String, String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub distribution: Setting<DistributionShift>,
}
pub fn check_unset<T>(
key: &Setting<T>,
field: &'static str,
source: EmbedderSource,
embedder_name: &str,
) -> Result<(), UserError> {
if matches!(key, Setting::NotSet) {
Ok(())
} else {
Err(UserError::InvalidFieldForSource {
embedder_name: embedder_name.to_owned(),
source_: source,
field,
allowed_fields_for_source: EmbeddingSettings::allowed_fields_for_source(source),
allowed_sources_for_field: EmbeddingSettings::allowed_sources_for_field(field),
})
}
}
/// Indicates what action should take place during a reindexing operation for an embedder
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReindexAction {
/// An indexing operation should take place for this embedder, keeping existing vectors
/// and checking whether the document template changed or not
RegeneratePrompts,
/// An indexing operation should take place for all documents for this embedder, removing existing vectors
/// (except userProvided ones)
FullReindex,
}
pub enum SettingsDiff {
Remove,
Reindex { action: ReindexAction, updated_settings: EmbeddingSettings, quantize: bool },
UpdateWithoutReindex { updated_settings: EmbeddingSettings, quantize: bool },
}
#[derive(Default, Debug)]
pub struct EmbedderAction {
pub was_quantized: bool,
pub is_being_quantized: bool,
pub write_back: Option<WriteBackToDocuments>,
pub reindex: Option<ReindexAction>,
}
impl EmbedderAction {
pub fn is_being_quantized(&self) -> bool {
self.is_being_quantized
}
pub fn write_back(&self) -> Option<&WriteBackToDocuments> {
self.write_back.as_ref()
}
pub fn reindex(&self) -> Option<&ReindexAction> {
self.reindex.as_ref()
}
pub fn with_is_being_quantized(mut self, quantize: bool) -> Self {
self.is_being_quantized = quantize;
self
}
pub fn with_write_back(write_back: WriteBackToDocuments, was_quantized: bool) -> Self {
Self {
was_quantized,
is_being_quantized: false,
write_back: Some(write_back),
reindex: None,
}
}
pub fn with_reindex(reindex: ReindexAction, was_quantized: bool) -> Self {
Self { was_quantized, is_being_quantized: false, write_back: None, reindex: Some(reindex) }
}
}
#[derive(Debug)]
pub struct WriteBackToDocuments {
pub embedder_id: u8,
pub user_provided: RoaringBitmap,
}
impl SettingsDiff {
pub fn from_settings(
embedder_name: &str,
old: EmbeddingSettings,
new: Setting<EmbeddingSettings>,
) -> Result<Self, UserError> {
let ret = match new {
Setting::Set(new) => {
let EmbeddingSettings {
mut source,
mut model,
mut revision,
mut api_key,
mut dimensions,
mut document_template,
mut url,
mut request,
mut response,
mut distribution,
mut headers,
mut document_template_max_bytes,
binary_quantized: mut binary_quantize,
} = old;
let EmbeddingSettings {
source: new_source,
model: new_model,
revision: new_revision,
api_key: new_api_key,
dimensions: new_dimensions,
document_template: new_document_template,
url: new_url,
request: new_request,
response: new_response,
distribution: new_distribution,
headers: new_headers,
document_template_max_bytes: new_document_template_max_bytes,
binary_quantized: new_binary_quantize,
} = new;
if matches!(binary_quantize, Setting::Set(true))
&& matches!(new_binary_quantize, Setting::Set(false))
{
return Err(UserError::InvalidDisableBinaryQuantization {
embedder_name: embedder_name.to_string(),
});
}
let mut reindex_action = None;
// **Warning**: do not use short-circuiting || here, we want all these operations applied
if source.apply(new_source) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
// when the source changes, we need to reapply the default settings for the new source
apply_default_for_source(
&source,
&mut model,
&mut revision,
&mut dimensions,
&mut url,
&mut request,
&mut response,
&mut document_template,
&mut document_template_max_bytes,
&mut headers,
)
}
if model.apply(new_model) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
}
if revision.apply(new_revision) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
}
if dimensions.apply(new_dimensions) {
match source {
// regenerate on dimensions change in OpenAI since truncation is supported
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {
ReindexAction::push_action(
&mut reindex_action,
ReindexAction::FullReindex,
);
}
// for all other embedders, the parameter is a hint that should not be able to change the result
// and so won't cause a reindex by itself.
_ => {}
}
}
let binary_quantize_changed = binary_quantize.apply(new_binary_quantize);
if url.apply(new_url) {
match source {
// do not regenerate on an url change in OpenAI
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {}
_ => {
ReindexAction::push_action(
&mut reindex_action,
ReindexAction::FullReindex,
);
}
}
}
if request.apply(new_request) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
}
if response.apply(new_response) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
}
if document_template.apply(new_document_template) {
ReindexAction::push_action(
&mut reindex_action,
ReindexAction::RegeneratePrompts,
);
}
if document_template_max_bytes.apply(new_document_template_max_bytes) {
let previous_document_template_max_bytes =
document_template_max_bytes.set().unwrap_or(default_max_bytes().get());
let new_document_template_max_bytes =
new_document_template_max_bytes.set().unwrap_or(default_max_bytes().get());
// only reindex if the size increased. Reasoning:
// - size decrease is a performance optimization, so we don't reindex and we keep the more accurate vectors
// - size increase is an accuracy optimization, so we want to reindex
if new_document_template_max_bytes > previous_document_template_max_bytes {
ReindexAction::push_action(
&mut reindex_action,
ReindexAction::RegeneratePrompts,
)
}
}
distribution.apply(new_distribution);
api_key.apply(new_api_key);
headers.apply(new_headers);
let updated_settings = EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
request,
response,
distribution,
headers,
document_template_max_bytes,
binary_quantized: binary_quantize,
};
match reindex_action {
Some(action) => Self::Reindex {
action,
updated_settings,
quantize: binary_quantize_changed,
},
None => Self::UpdateWithoutReindex {
updated_settings,
quantize: binary_quantize_changed,
},
}
}
Setting::Reset => Self::Remove,
Setting::NotSet => {
Self::UpdateWithoutReindex { updated_settings: old, quantize: false }
}
};
Ok(ret)
}
}
impl ReindexAction {
fn push_action(this: &mut Option<Self>, other: Self) {
*this = match (*this, other) {
(_, ReindexAction::FullReindex) => Some(ReindexAction::FullReindex),
(Some(ReindexAction::FullReindex), _) => Some(ReindexAction::FullReindex),
(_, ReindexAction::RegeneratePrompts) => Some(ReindexAction::RegeneratePrompts),
}
}
}
#[allow(clippy::too_many_arguments)] // private function
fn apply_default_for_source(
source: &Setting<EmbedderSource>,
model: &mut Setting<String>,
revision: &mut Setting<String>,
dimensions: &mut Setting<usize>,
url: &mut Setting<String>,
request: &mut Setting<serde_json::Value>,
response: &mut Setting<serde_json::Value>,
document_template: &mut Setting<String>,
document_template_max_bytes: &mut Setting<usize>,
headers: &mut Setting<BTreeMap<String, String>>,
) {
match source {
Setting::Set(EmbedderSource::HuggingFace) => {
*model = Setting::Reset;
*revision = Setting::Reset;
*dimensions = Setting::NotSet;
*url = Setting::NotSet;
*request = Setting::NotSet;
*response = Setting::NotSet;
*headers = Setting::NotSet;
}
Setting::Set(EmbedderSource::Ollama) => {
*model = Setting::Reset;
*revision = Setting::NotSet;
*dimensions = Setting::Reset;
*url = Setting::NotSet;
*request = Setting::NotSet;
*response = Setting::NotSet;
*headers = Setting::NotSet;
}
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {
*model = Setting::Reset;
*revision = Setting::NotSet;
*dimensions = Setting::NotSet;
*url = Setting::Reset;
*request = Setting::NotSet;
*response = Setting::NotSet;
*headers = Setting::NotSet;
}
Setting::Set(EmbedderSource::Rest) => {
*model = Setting::NotSet;
*revision = Setting::NotSet;
*dimensions = Setting::Reset;
*url = Setting::Reset;
*request = Setting::Reset;
*response = Setting::Reset;
*headers = Setting::Reset;
}
Setting::Set(EmbedderSource::UserProvided) => {
*model = Setting::NotSet;
*revision = Setting::NotSet;
*dimensions = Setting::Reset;
*url = Setting::NotSet;
*request = Setting::NotSet;
*response = Setting::NotSet;
*document_template = Setting::NotSet;
*document_template_max_bytes = Setting::NotSet;
*headers = Setting::NotSet;
}
Setting::NotSet => {}
}
}
pub fn check_set<T>(
key: &Setting<T>,
field: &'static str,
source: EmbedderSource,
embedder_name: &str,
) -> Result<(), UserError> {
if matches!(key, Setting::Set(_)) {
Ok(())
} else {
Err(UserError::MissingFieldForSource {
field,
source_: source,
embedder_name: embedder_name.to_owned(),
})
}
}
impl EmbeddingSettings {
pub const SOURCE: &'static str = "source";
pub const MODEL: &'static str = "model";
pub const REVISION: &'static str = "revision";
pub const API_KEY: &'static str = "apiKey";
pub const DIMENSIONS: &'static str = "dimensions";
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
pub const DOCUMENT_TEMPLATE_MAX_BYTES: &'static str = "documentTemplateMaxBytes";
pub const URL: &'static str = "url";
pub const REQUEST: &'static str = "request";
pub const RESPONSE: &'static str = "response";
pub const HEADERS: &'static str = "headers";
pub const DISTRIBUTION: &'static str = "distribution";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field {
Self::SOURCE => &[
EmbedderSource::HuggingFace,
EmbedderSource::OpenAi,
EmbedderSource::UserProvided,
EmbedderSource::Rest,
EmbedderSource::Ollama,
],
Self::MODEL => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
}
Self::REVISION => &[EmbedderSource::HuggingFace],
Self::API_KEY => {
&[EmbedderSource::OpenAi, EmbedderSource::Ollama, EmbedderSource::Rest]
}
Self::DIMENSIONS => &[
EmbedderSource::OpenAi,
EmbedderSource::UserProvided,
EmbedderSource::Ollama,
EmbedderSource::Rest,
],
Self::DOCUMENT_TEMPLATE => &[
EmbedderSource::HuggingFace,
EmbedderSource::OpenAi,
EmbedderSource::Ollama,
EmbedderSource::Rest,
],
Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest, EmbedderSource::OpenAi],
Self::REQUEST => &[EmbedderSource::Rest],
Self::RESPONSE => &[EmbedderSource::Rest],
Self::HEADERS => &[EmbedderSource::Rest],
Self::DISTRIBUTION => &[
EmbedderSource::HuggingFace,
EmbedderSource::Ollama,
EmbedderSource::OpenAi,
EmbedderSource::Rest,
EmbedderSource::UserProvided,
],
_other => unreachable!("unknown field"),
}
}
pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] {
match source {
EmbedderSource::OpenAi => &[
Self::SOURCE,
Self::MODEL,
Self::API_KEY,
Self::DOCUMENT_TEMPLATE,
Self::DIMENSIONS,
Self::DISTRIBUTION,
Self::URL,
],
EmbedderSource::HuggingFace => &[
Self::SOURCE,
Self::MODEL,
Self::REVISION,
Self::DOCUMENT_TEMPLATE,
Self::DISTRIBUTION,
],
EmbedderSource::Ollama => &[
Self::SOURCE,
Self::MODEL,
Self::DOCUMENT_TEMPLATE,
Self::URL,
Self::API_KEY,
Self::DIMENSIONS,
Self::DISTRIBUTION,
],
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS, Self::DISTRIBUTION],
EmbedderSource::Rest => &[
Self::SOURCE,
Self::API_KEY,
Self::DIMENSIONS,
Self::DOCUMENT_TEMPLATE,
Self::URL,
Self::REQUEST,
Self::RESPONSE,
Self::HEADERS,
Self::DISTRIBUTION,
],
}
}
pub(crate) fn apply_default_source(setting: &mut Setting<EmbeddingSettings>) {
if let Setting::Set(EmbeddingSettings {
source: source @ (Setting::NotSet | Setting::Reset),
..
}) = setting
{
*source = Setting::Set(EmbedderSource::default())
}
}
pub(crate) fn apply_default_openai_model(setting: &mut Setting<EmbeddingSettings>) {
if let Setting::Set(EmbeddingSettings {
source: Setting::Set(EmbedderSource::OpenAi),
model: model @ (Setting::NotSet | Setting::Reset),
..
}) = setting
{
*model = Setting::Set(openai::EmbeddingModel::default().name().to_owned())
}
}
}
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbedderSource {
#[default]
OpenAi,
HuggingFace,
Ollama,
UserProvided,
Rest,
}
impl std::fmt::Display for EmbedderSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
EmbedderSource::OpenAi => "openAi",
EmbedderSource::HuggingFace => "huggingFace",
EmbedderSource::UserProvided => "userProvided",
EmbedderSource::Ollama => "ollama",
EmbedderSource::Rest => "rest",
};
f.write_str(s)
}
}
impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self {
let EmbeddingConfig { embedder_options, prompt, quantized } = value;
let document_template_max_bytes =
Setting::Set(prompt.max_bytes.unwrap_or(default_max_bytes()).get());
match embedder_options {
super::EmbedderOptions::HuggingFace(super::hf::EmbedderOptions {
model,
revision,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::HuggingFace),
model: Setting::Set(model),
revision: Setting::some_or_not_set(revision),
api_key: Setting::NotSet,
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
document_template_max_bytes,
url: Setting::NotSet,
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
distribution: Setting::some_or_not_set(distribution),
binary_quantized: Setting::some_or_not_set(quantized),
},
super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions {
url,
api_key,
embedding_model,
dimensions,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::OpenAi),
model: Setting::Set(embedding_model.name().to_owned()),
revision: Setting::NotSet,
api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template),
document_template_max_bytes,
url: Setting::some_or_not_set(url),
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
distribution: Setting::some_or_not_set(distribution),
binary_quantized: Setting::some_or_not_set(quantized),
},
super::EmbedderOptions::Ollama(super::ollama::EmbedderOptions {
embedding_model,
url,
api_key,
distribution,
dimensions,
}) => Self {
source: Setting::Set(EmbedderSource::Ollama),
model: Setting::Set(embedding_model),
revision: Setting::NotSet,
api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template),
document_template_max_bytes,
url: Setting::some_or_not_set(url),
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
distribution: Setting::some_or_not_set(distribution),
binary_quantized: Setting::some_or_not_set(quantized),
},
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::UserProvided),
model: Setting::NotSet,
revision: Setting::NotSet,
api_key: Setting::NotSet,
dimensions: Setting::Set(dimensions),
document_template: Setting::NotSet,
document_template_max_bytes: Setting::NotSet,
url: Setting::NotSet,
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
distribution: Setting::some_or_not_set(distribution),
binary_quantized: Setting::some_or_not_set(quantized),
},
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key,
dimensions,
url,
request,
response,
distribution,
headers,
}) => Self {
source: Setting::Set(EmbedderSource::Rest),
model: Setting::NotSet,
revision: Setting::NotSet,
api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template),
document_template_max_bytes,
url: Setting::Set(url),
request: Setting::Set(request),
response: Setting::Set(response),
distribution: Setting::some_or_not_set(distribution),
headers: Setting::Set(headers),
binary_quantized: Setting::some_or_not_set(quantized),
},
}
}
}
impl From<EmbeddingSettings> for EmbeddingConfig {
fn from(value: EmbeddingSettings) -> Self {
let mut this = Self::default();
let EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
document_template_max_bytes,
url,
request,
response,
distribution,
headers,
binary_quantized,
} = value;
this.quantized = binary_quantized.set();
if let Some(source) = source.set() {
match source {
EmbedderSource::OpenAi => {
let mut options = super::openai::EmbedderOptions::with_default_model(None);
if let Some(model) = model.set() {
if let Some(model) = super::openai::EmbeddingModel::from_name(&model) {
options.embedding_model = model;
}
}
if let Some(url) = url.set() {
options.url = Some(url);
}
if let Some(api_key) = api_key.set() {
options.api_key = Some(api_key);
}
if let Some(dimensions) = dimensions.set() {
options.dimensions = Some(dimensions);
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::OpenAi(options);
}
EmbedderSource::Ollama => {
let mut options: ollama::EmbedderOptions =
super::ollama::EmbedderOptions::with_default_model(
api_key.set(),
url.set(),
dimensions.set(),
);
if let Some(model) = model.set() {
options.embedding_model = model;
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::Ollama(options);
}
EmbedderSource::HuggingFace => {
let mut options = super::hf::EmbedderOptions::default();
if let Some(model) = model.set() {
options.model = model;
// Reset the revision if we are setting the model.
// This allows the following:
// "huggingFace": {} -> default model with default revision
// "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision
// "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision
options.revision = None;
}
if let Some(revision) = revision.set() {
options.revision = Some(revision);
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::HuggingFace(options);
}
EmbedderSource::UserProvided => {
this.embedder_options =
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions: dimensions.set().unwrap(),
distribution: distribution.set(),
});
}
EmbedderSource::Rest => {
this.embedder_options =
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key: api_key.set(),
dimensions: dimensions.set(),
url: url.set().unwrap(),
request: request.set().unwrap(),
response: response.set().unwrap(),
distribution: distribution.set(),
headers: headers.set().unwrap_or_default(),
})
}
}
}
if let Setting::Set(template) = document_template {
let max_bytes = document_template_max_bytes
.set()
.and_then(NonZeroUsize::new)
.unwrap_or(default_max_bytes());
this.prompt = PromptData { template, max_bytes: Some(max_bytes) }
}
this
}
}